Skip to content

Commit

Permalink
QueryBuilder: add flat keyword to first method
Browse files Browse the repository at this point in the history
This keyword already exists for the `all` method and it will likewise be
useful for `first` when only a single quantity is projected. In that
case, often the caller doesn't want a list as a return value but simply
the projected quantity. Allowing to get this directly from the method
call as opposed to manually dereferencing the first item from the
returned list often makes for cleaner code.
  • Loading branch information
sphuber committed Mar 9, 2022
1 parent 6866d03 commit 5c0564d
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 7 deletions.
18 changes: 13 additions & 5 deletions aiida/orm/querybuilder.py
Expand Up @@ -989,20 +989,28 @@ def _get_aiida_entity_res(value) -> Any:
except TypeError:
return value

def first(self) -> Optional[List[Any]]:
"""Executes the query, asking for the first row of results.
def first(self, flat: bool = False) -> Optional[Union[List[Any], Any]]:
"""Return the first result of the query.
Note, this may change if several rows are valid for the query,
as persistent ordering is not guaranteed unless explicitly specified.
Calling ``first`` results in an execution of the underlying query.
Note, this may change if several rows are valid for the query, as persistent ordering is not guaranteed unless
explicitly specified.
:param flat: if True, return just the projected quantity if there is just a single projection.
:returns: One row of results as a list, or None if no result returned.
"""
result = self._impl.first(self.as_dict())

if result is None:
return None

return [self._get_aiida_entity_res(rowitem) for rowitem in result]
result = [self._get_aiida_entity_res(rowitem) for rowitem in result]

if flat and len(result) == 1:
return result[0]

return result

def count(self) -> int:
"""
Expand Down
17 changes: 15 additions & 2 deletions tests/orm/test_querybuilder.py
Expand Up @@ -649,7 +649,7 @@ def test_direction_keyword(self):
assert res2 == {d2.id, d4.id}

@staticmethod
def test_flat():
def test_all_flat():
"""Test the `flat` keyword for the `QueryBuilder.all()` method."""
pks = []
uuids = []
Expand All @@ -665,13 +665,26 @@ def test_flat():
assert len(result) == 10
assert result == pks

# Mutltiple projections
# Multiple projections
builder = orm.QueryBuilder().append(orm.Data, project=['id', 'uuid']).order_by({orm.Data: 'id'})
result = builder.all(flat=True)
assert isinstance(result, list)
assert len(result) == 20
assert result == list(chain.from_iterable(zip(pks, uuids)))

@staticmethod
def test_first_flat():
"""Test the `flat` keyword for the `QueryBuilder.first()` method."""
node = orm.Data().store()

# Single projected property
query = orm.QueryBuilder().append(orm.Data, project='id', filters={'id': node.pk})
assert query.first(flat=True) == node.pk

# Mutltiple projections
query = orm.QueryBuilder().append(orm.Data, project=['id', 'uuid'], filters={'id': node.pk})
assert query.first(flat=True) == [node.pk, node.uuid]

def test_query_links(self):
"""Test querying for links"""
d1, d2, d3, d4 = [orm.Data().store() for _ in range(4)]
Expand Down

0 comments on commit 5c0564d

Please sign in to comment.