Skip to content

Commit

Permalink
QueryBuilder: add flat keyword to first method
Browse files Browse the repository at this point in the history
This keyword already exists for the `all` method and it will likewise be
useful for `first` when only a single quantity is projected. In that
case, often the caller doesn't want a list as a return value but simply
the projected quantity. Allowing to get this directly from the method
call as opposed to manually dereferencing the first item from the
returned list often makes for cleaner code.
  • Loading branch information
sphuber committed Mar 6, 2022
1 parent 09765ec commit 3bedc9b
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 5 deletions.
12 changes: 9 additions & 3 deletions aiida/orm/querybuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -989,20 +989,26 @@ def _get_aiida_entity_res(value) -> Any:
except TypeError:
return value

def first(self) -> Optional[List[Any]]:
"""Executes the query, asking for the first row of results.
def first(self, flat: bool = False) -> Optional[Union[List[Any], Any]]:
"""Return the first result of the query.
Note, this may change if several rows are valid for the query,
as persistent ordering is not guaranteed unless explicitly specified.
:param flat: if True, return just the projected quantity if there is just a single projection.
:returns: One row of results as a list, or None if no result returned.
"""
result = self._impl.first(self.as_dict())

if result is None:
return None

return [self._get_aiida_entity_res(rowitem) for rowitem in result]
result = [self._get_aiida_entity_res(rowitem) for rowitem in result]

if flat and len(result) == 1:
return result[0]

return result

def count(self) -> int:
"""
Expand Down
17 changes: 15 additions & 2 deletions tests/orm/test_querybuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,7 +649,7 @@ def test_direction_keyword(self):
assert res2 == {d2.id, d4.id}

@staticmethod
def test_flat():
def test_all_flat():
"""Test the `flat` keyword for the `QueryBuilder.all()` method."""
pks = []
uuids = []
Expand All @@ -665,13 +665,26 @@ def test_flat():
assert len(result) == 10
assert result == pks

# Mutltiple projections
# Multiple projections
builder = orm.QueryBuilder().append(orm.Data, project=['id', 'uuid']).order_by({orm.Data: 'id'})
result = builder.all(flat=True)
assert isinstance(result, list)
assert len(result) == 20
assert result == list(chain.from_iterable(zip(pks, uuids)))

@staticmethod
def test_first_flat():
"""Test the `flat` keyword for the `QueryBuilder.first()` method."""
node = orm.Data().store()

# Single projected property
query = orm.QueryBuilder().append(orm.Data, project='id', filters={'id': node.pk})
assert query.first(flat=True) == node.pk

# Mutltiple projections
query = orm.QueryBuilder().append(orm.Data, project=['id', 'uuid'], filters={'id': node.pk})
assert query.first(flat=True) == [node.pk, node.uuid]

def test_query_links(self):
"""Test querying for links"""
d1, d2, d3, d4 = [orm.Data().store() for _ in range(4)]
Expand Down

0 comments on commit 3bedc9b

Please sign in to comment.