Skip to content

Commit

Permalink
QueryBuilder: add flat keyword to first method (#5410)
Browse files Browse the repository at this point in the history
This keyword already exists for the `all` method and it will likewise be
useful for `first` when only a single quantity is projected. In that
case, often the caller doesn't want a list as a return value but simply
the projected quantity. Allowing to get this directly from the method
call as opposed to manually dereferencing the first item from the
returned list often makes for cleaner code.
  • Loading branch information
sphuber committed Mar 9, 2022
1 parent ffedc8b commit 5b10cd3
Show file tree
Hide file tree
Showing 7 changed files with 64 additions and 29 deletions.
3 changes: 1 addition & 2 deletions aiida/orm/nodes/data/upf.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def upload_upf_family(folder, group_label, group_description, stop_if_existing=T
md5sum = md5_file(filename)
builder = orm.QueryBuilder(backend=backend)
builder.append(UpfData, filters={'attributes.md5': {'==': md5sum}})
existing_upf = builder.first()
existing_upf = builder.first(flat=True)

if existing_upf is None:
# return the upfdata instances, not stored
Expand All @@ -133,7 +133,6 @@ def upload_upf_family(folder, group_label, group_description, stop_if_existing=T
else:
if stop_if_existing:
raise ValueError(f'A UPF with identical MD5 to {filename} cannot be added with stop_if_existing')
existing_upf = existing_upf[0]
pseudo_and_created.append((existing_upf, False))

# check whether pseudo are unique per element
Expand Down
30 changes: 25 additions & 5 deletions aiida/orm/querybuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
An instance of one of the implementation classes becomes a member of the :func:`QueryBuilder` instance
when instantiated by the user.
"""
from __future__ import annotations

from copy import deepcopy
from inspect import isclass as inspect_isclass
from typing import (
Expand All @@ -27,6 +29,7 @@
Dict,
Iterable,
List,
Literal,
NamedTuple,
Optional,
Sequence,
Expand All @@ -35,6 +38,7 @@
Type,
Union,
cast,
overload,
)
import warnings

Expand Down Expand Up @@ -989,20 +993,36 @@ def _get_aiida_entity_res(value) -> Any:
except TypeError:
return value

def first(self) -> Optional[List[Any]]:
"""Executes the query, asking for the first row of results.
@overload
def first(self, flat: Literal[False]) -> Optional[list[Any]]:
...

@overload
def first(self, flat: Literal[True]) -> Optional[Any]:
...

def first(self, flat: bool = False) -> Optional[list[Any] | Any]:
"""Return the first result of the query.
Note, this may change if several rows are valid for the query,
as persistent ordering is not guaranteed unless explicitly specified.
Calling ``first`` results in an execution of the underlying query.
Note, this may change if several rows are valid for the query, as persistent ordering is not guaranteed unless
explicitly specified.
:param flat: if True, return just the projected quantity if there is just a single projection.
:returns: One row of results as a list, or None if no result returned.
"""
result = self._impl.first(self.as_dict())

if result is None:
return None

return [self._get_aiida_entity_res(rowitem) for rowitem in result]
result = [self._get_aiida_entity_res(rowitem) for rowitem in result]

if flat and len(result) == 1:
return result[0]

return result

def count(self) -> int:
"""
Expand Down
4 changes: 2 additions & 2 deletions aiida/restapi/translator/nodes/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def _get_content(self):
return {}

# otherwise ...
node = self.qbobj.first()[0]
node = self.qbobj.first()[0] # pylint: disable=unsubscriptable-object

# content/attributes
if self._content_type == 'attributes':
Expand Down Expand Up @@ -643,7 +643,7 @@ def get_node_description(node):
nodes = []

if qb_obj.count() > 0:
main_node = qb_obj.first()[0]
main_node = qb_obj.first(flat=True)
pk = main_node.pk
uuid = main_node.uuid
nodetype = main_node.node_type
Expand Down
4 changes: 2 additions & 2 deletions tests/orm/test_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def test_group_uuid_hashing_for_querybuidler(self):
# Search for the UUID of the stored group
builder = orm.QueryBuilder()
builder.append(orm.Group, project=['uuid'], filters={'label': {'==': 'test_group'}})
[uuid] = builder.first()
uuid = builder.first(flat=True)

# Look the node with the previously returned UUID
builder = orm.QueryBuilder()
Expand All @@ -279,7 +279,7 @@ def test_group_uuid_hashing_for_querybuidler(self):

# And that the results are correct
assert builder.count() == 1
assert builder.first()[0] == group.id
assert builder.first(flat=True) == group.id


@pytest.mark.usefixtures('aiida_profile_clean')
Expand Down
28 changes: 22 additions & 6 deletions tests/orm/test_querybuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,7 +649,7 @@ def test_direction_keyword(self):
assert res2 == {d2.id, d4.id}

@staticmethod
def test_flat():
def test_all_flat():
"""Test the `flat` keyword for the `QueryBuilder.all()` method."""
pks = []
uuids = []
Expand All @@ -665,13 +665,26 @@ def test_flat():
assert len(result) == 10
assert result == pks

# Mutltiple projections
# Multiple projections
builder = orm.QueryBuilder().append(orm.Data, project=['id', 'uuid']).order_by({orm.Data: 'id'})
result = builder.all(flat=True)
assert isinstance(result, list)
assert len(result) == 20
assert result == list(chain.from_iterable(zip(pks, uuids)))

@staticmethod
def test_first_flat():
"""Test the `flat` keyword for the `QueryBuilder.first()` method."""
node = orm.Data().store()

# Single projected property
query = orm.QueryBuilder().append(orm.Data, project='id', filters={'id': node.pk})
assert query.first(flat=True) == node.pk

# Mutltiple projections
query = orm.QueryBuilder().append(orm.Data, project=['id', 'uuid'], filters={'id': node.pk})
assert query.first(flat=True) == [node.pk, node.uuid]

def test_query_links(self):
"""Test querying for links"""
d1, d2, d3, d4 = [orm.Data().store() for _ in range(4)]
Expand Down Expand Up @@ -703,13 +716,16 @@ def test_first_multiple_projections(self):
orm.Data().store()
orm.Data().store()

result = orm.QueryBuilder().append(orm.User, tag='user',
project=['email']).append(orm.Data, with_user='user', project=['*']).first()
query = orm.QueryBuilder()
query.append(orm.User, tag='user', project=['email'])
query.append(orm.Data, with_user='user', project=['*'])

result = query.first()

assert isinstance(result, list)
assert len(result) == 2
assert isinstance(result[0], str)
assert isinstance(result[1], orm.Data)
assert isinstance(result[0], str) # pylint: disable=unsubscriptable-object
assert isinstance(result[1], orm.Data) # pylint: disable=unsubscriptable-object


class TestRepresentations:
Expand Down
4 changes: 2 additions & 2 deletions tests/test_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def test_node_uuid_hashing_for_querybuidler(self):
# Search for the UUID of the stored node
qb = orm.QueryBuilder()
qb.append(orm.Data, project=['uuid'], filters={'id': {'==': n.id}})
[uuid] = qb.first()
uuid = qb.first(flat=True)

# Look the node with the previously returned UUID
qb = orm.QueryBuilder()
Expand All @@ -99,7 +99,7 @@ def test_node_uuid_hashing_for_querybuidler(self):
qb.all()
# And that the results are correct
assert qb.count() == 1
assert qb.first()[0] == n.id
assert qb.first(flat=True) == n.id

@staticmethod
def create_folderdata_with_empty_file():
Expand Down
20 changes: 10 additions & 10 deletions tests/tools/archive/orm/test_computers.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,17 +79,17 @@ def test_same_computer_import(tmp_path, aiida_profile_clean, aiida_localhost):
builder = orm.QueryBuilder()
builder.append(orm.CalcJobNode, project=['label'])
assert builder.count() == 1, 'Only one calculation should be found.'
assert str(builder.first()[0]) == calc1_label, 'The calculation label is not correct.'
assert str(builder.first(flat=True)) == calc1_label, 'The calculation label is not correct.'

# Check that the referenced computer is imported correctly.
builder = orm.QueryBuilder()
builder.append(orm.Computer, project=['label', 'uuid', 'id'])
assert builder.count() == 1, 'Only one computer should be found.'
assert str(builder.first()[0]) == comp_name, 'The computer name is not correct.'
assert str(builder.first()[1]) == comp_uuid, 'The computer uuid is not correct.'
assert str(builder.first()[0]) == comp_name, 'The computer name is not correct.' # pylint: disable=unsubscriptable-object
assert str(builder.first()[1]) == comp_uuid, 'The computer uuid is not correct.' # pylint: disable=unsubscriptable-object

# Store the id of the computer
comp_id = builder.first()[2]
comp_id = builder.first()[2] # pylint: disable=unsubscriptable-object

# Import the second calculation
import_archive(filename2)
Expand All @@ -99,9 +99,9 @@ def test_same_computer_import(tmp_path, aiida_profile_clean, aiida_localhost):
builder = orm.QueryBuilder()
builder.append(orm.Computer, project=['label', 'uuid', 'id'])
assert builder.count() == 1, f'Found {builder.count()} computersbut only one computer should be found.'
assert str(builder.first()[0]) == comp_name, 'The computer name is not correct.'
assert str(builder.first()[1]) == comp_uuid, 'The computer uuid is not correct.'
assert builder.first()[2] == comp_id, 'The computer id is not correct.'
assert str(builder.first()[0]) == comp_name, 'The computer name is not correct.' # pylint: disable=unsubscriptable-object
assert str(builder.first()[1]) == comp_uuid, 'The computer uuid is not correct.' # pylint: disable=unsubscriptable-object
assert builder.first()[2] == comp_id, 'The computer id is not correct.' # pylint: disable=unsubscriptable-object

# Check that now you have two calculations attached to the same
# computer.
Expand Down Expand Up @@ -175,13 +175,13 @@ def test_same_computer_different_name_import(tmp_path, aiida_profile_clean, aiid
builder = orm.QueryBuilder()
builder.append(orm.CalcJobNode, project=['label'])
assert builder.count() == 1, 'Only one calculation should be found.'
assert str(builder.first()[0]) == calc1_label, 'The calculation label is not correct.'
assert str(builder.first(flat=True)) == calc1_label, 'The calculation label is not correct.'

# Check that the referenced computer is imported correctly.
builder = orm.QueryBuilder()
builder.append(orm.Computer, project=['label', 'uuid', 'id'])
assert builder.count() == 1, 'Only one computer should be found.'
assert str(builder.first()[0]) == comp1_name, 'The computer name is not correct.'
assert str(builder.first()[0]) == comp1_name, 'The computer name is not correct.' # pylint: disable=unsubscriptable-object

# Import the second calculation
import_archive(filename2)
Expand All @@ -191,7 +191,7 @@ def test_same_computer_different_name_import(tmp_path, aiida_profile_clean, aiid
builder = orm.QueryBuilder()
builder.append(orm.Computer, project=['label'])
assert builder.count() == 1, f'Found {builder.count()} computersbut only one computer should be found.'
assert str(builder.first()[0]) == comp1_name, 'The computer name is not correct.'
assert str(builder.first(flat=True)) == comp1_name, 'The computer name is not correct.'


def test_different_computer_same_name_import(tmp_path, aiida_profile_clean, aiida_localhost_factory):
Expand Down

0 comments on commit 5b10cd3

Please sign in to comment.