Skip to content

Commit

Permalink
Merge pull request #1926 from opensafely-core/evansd/fix-population-e…
Browse files Browse the repository at this point in the history
…dge-case

Fix edge case in constructing population table
  • Loading branch information
evansd committed Feb 23, 2024
2 parents cee1471 + 01ab7a8 commit a2f4389
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 11 deletions.
22 changes: 12 additions & 10 deletions ehrql/query_engines/base_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,21 +141,23 @@ def select_patient_id_for_population(self, population_expression):
available patient_ids. We could then use such tables if available rather than
messing around with UNIONS.
"""
# Get all the tables needed to evaluate the population expression
# Get all the tables needed to evaluate the population expression and select
# patient IDs from each one
tables = get_patient_id_tables(population_expression)
if len(tables) > 1:
# Select all patient IDs from all tables referenced in the expression
id_selects = [
sqlalchemy.select(table.c.patient_id.label("patient_id"))
for table in tables
]
id_selects = [
sqlalchemy.select(table.c.patient_id.label("patient_id"))
for table in tables
]
if len(id_selects) > 1:
# Create a table which contains the union of all these IDs. (Note UNION
# rather than UNION ALL so we don't get duplicates.)
all_ids_table = self.reify_query(sqlalchemy.union(*id_selects))
return sqlalchemy.select(all_ids_table.c.patient_id)
elif len(tables) == 1:
# If there's only one table then use the IDs from that
return sqlalchemy.select(tables[0].c.patient_id.label("patient_id"))
elif len(id_selects) == 1:
# If there's only one table then we have to use DISTINCT rather than UNION
# to remove duplicates
distinct_ids_table = self.reify_query(id_selects[0].distinct())
return sqlalchemy.select(distinct_ids_table.c.patient_id)
else:
# Gracefully handle the degenerate case where the population expression
# doesn't reference any tables at all. Our validation rules ensure that such
Expand Down
31 changes: 30 additions & 1 deletion tests/integration/test_query_engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import sqlalchemy

from ehrql import create_dataset, minimum_of, when
from ehrql.query_model.nodes import Function, Value
from ehrql.query_model.nodes import AggregateByPatient, Function, Value
from ehrql.tables import (
EventFrame,
PatientFrame,
Expand Down Expand Up @@ -321,3 +321,32 @@ def test_horizontal_aggregation_wrapping_a_series_containment_query_works(engine
{"patient_id": 1, "match": "T"},
{"patient_id": 2, "match": "F"},
]


def test_population_which_uses_combine_as_set_and_no_patient_frame(engine):
    # Population definitions have to be patient-level, so one which refers only to
    # event frames needs an aggregation in it somewhere. Most aggregations produce a
    # fresh patient-level SQL table, but CombineAsSet does not, which means it can be
    # used to build a population SQL expression that references nothing but a single
    # event-level SQL table. An earlier assumption in the engine did not allow for
    # that, so this test checks the case is handled correctly.
    has_matching_event = Function.In(
        Value(1),
        AggregateByPatient.CombineAsSet(as_query_model(events.i)),
    )
    variables = {
        "population": has_matching_event,
        "v": Value(True),
    }

    # The same patient appears twice in the only table the population refers to;
    # the expected output below contains that patient exactly once.
    duplicated_rows = [{"patient_id": 1, "i": 1} for _ in range(2)]
    engine.populate({events: duplicated_rows})

    expected = [
        {"patient_id": 1, "v": True},
    ]
    results = engine.extract_qm(variables)
    assert results == expected

0 comments on commit a2f4389

Please sign in to comment.