Skip to content

Commit

Permalink
Fetching only non trashed metadata containers (#1033)
Browse files Browse the repository at this point in the history
  • Loading branch information
Raalsky committed Oct 6, 2022
1 parent 81e3d17 commit bfffee4
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 24 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Expand Up @@ -6,6 +6,7 @@
### Changes
- More consistent and strict way of git repository, source files and entrypoint detection ([#1007](https://github.com/neptune-ai/neptune-client/pull/1007))
- Moved neptune and neptune_cli to src dir ([#1027](https://github.com/neptune-ai/neptune-client/pull/1027))
- `fetch_runs_table(...)`, `fetch_models_table(...)` and `fetch_model_versions_table(...)` now queries only non-trashed ([#1033](https://github.com/neptune-ai/neptune-client/pull/1033))

## neptune-client 0.16.9

Expand Down
18 changes: 4 additions & 14 deletions e2e_tests/management/test_management.py
Expand Up @@ -408,13 +408,13 @@ def test_trash_runs_and_models(self, project, environment):

# THEN trashed runs are marked as trashed
runs = project.fetch_runs_table().to_pandas()
assert run1_id in runs[runs["sys/trashed"] == True]["sys/id"].tolist()
assert run2_id in runs[runs["sys/trashed"] == False]["sys/id"].tolist()
assert len(runs) == 1

# AND trashed models are marked as trashed
models = project.fetch_models_table().to_pandas()
assert model1_id in models[models["sys/trashed"] == True]["sys/id"].tolist()
assert model2_id in models[models["sys/trashed"] == False]["sys/id"].tolist()
assert len(models) == 1

def test_trash_model_version(self, environment):
# WITH model
Expand All @@ -436,14 +436,11 @@ def test_trash_model_version(self, environment):

# THEN expect this version to be trashed
model_versions = model.fetch_model_versions_table().to_pandas()
assert (
model_version1
in model_versions[model_versions["sys/trashed"] == True]["sys/id"].tolist()
)
assert (
model_version2
in model_versions[model_versions["sys/trashed"] == False]["sys/id"].tolist()
)
assert len(model_versions) == 1

# WHEN whole model is trashed
trash_objects(environment.project, model_id)
Expand All @@ -452,11 +449,4 @@ def test_trash_model_version(self, environment):

# THEN expect all its versions to be trashed
model_versions = model.fetch_model_versions_table().to_pandas()
assert (
model_version1
in model_versions[model_versions["sys/trashed"] == True]["sys/id"].tolist()
)
assert (
model_version2
in model_versions[model_versions["sys/trashed"] == True]["sys/id"].tolist()
)
assert len(model_versions) == 0
11 changes: 9 additions & 2 deletions src/neptune/new/internal/backends/nql.py
Expand Up @@ -22,6 +22,7 @@
"NQLQueryAttribute",
]

import typing
from dataclasses import dataclass
from enum import Enum
from typing import Iterable
Expand Down Expand Up @@ -62,14 +63,20 @@ class NQLAttributeType(str, Enum):
STRING = "string"
STRING_SET = "stringSet"
EXPERIMENT_STATE = "experimentState"
BOOLEAN = "bool"


@dataclass
class NQLQueryAttribute(NQLQuery):
name: str
type: NQLAttributeType
operator: NQLAttributeOperator
value: str
value: typing.Union[str, bool]

def __str__(self) -> str:
return f'(`{self.name}`:{self.type.value} {self.operator.value} "{self.value}")'
if isinstance(self.value, bool):
value = str(self.value).lower()
else:
value = f'"{self.value}"'

return f"(`{self.name}`:{self.type.value} {self.operator.value} {value})"
23 changes: 18 additions & 5 deletions src/neptune/new/metadata_containers/model.py
Expand Up @@ -16,8 +16,10 @@
from typing import Iterable, Optional

from neptune.new.internal.backends.nql import (
NQLAggregator,
NQLAttributeOperator,
NQLAttributeType,
NQLQueryAggregate,
NQLQueryAttribute,
)
from neptune.new.internal.container_type import ContainerType
Expand Down Expand Up @@ -102,11 +104,22 @@ def fetch_model_versions_table(self, columns: Optional[Iterable[str]] = None) ->
return MetadataContainer._fetch_entries(
self,
child_type=ContainerType.MODEL_VERSION,
query=NQLQueryAttribute(
name="sys/model_id",
value=self._sys_id,
operator=NQLAttributeOperator.EQUALS,
type=NQLAttributeType.STRING,
query=NQLQueryAggregate(
items=[
NQLQueryAttribute(
name="sys/model_id",
value=self._sys_id,
operator=NQLAttributeOperator.EQUALS,
type=NQLAttributeType.STRING,
),
NQLQueryAttribute(
name="sys/trashed",
type=NQLAttributeType.BOOLEAN,
operator=NQLAttributeOperator.EQUALS,
value=False,
),
],
aggregator=NQLAggregator.AND,
),
columns=columns,
)
17 changes: 14 additions & 3 deletions src/neptune/new/metadata_containers/project.py
Expand Up @@ -21,7 +21,6 @@
NQLAggregator,
NQLAttributeOperator,
NQLAttributeType,
NQLEmptyQuery,
NQLQueryAggregate,
NQLQueryAttribute,
)
Expand Down Expand Up @@ -96,7 +95,14 @@ def _metadata_url(self) -> str:

@staticmethod
def _prepare_nql_query(ids, states, owners, tags):
query_items = []
query_items = [
NQLQueryAttribute(
name="sys/trashed",
type=NQLAttributeType.BOOLEAN,
operator=NQLAttributeOperator.EQUALS,
value=False,
)
]

if ids:
query_items.append(
Expand Down Expand Up @@ -304,7 +310,12 @@ def fetch_models_table(self, columns: Optional[Iterable[str]] = None) -> Table:
return MetadataContainer._fetch_entries(
self,
child_type=ContainerType.MODEL,
query=NQLEmptyQuery(),
query=NQLQueryAttribute(
name="sys/trashed",
type=NQLAttributeType.BOOLEAN,
operator=NQLAttributeOperator.EQUALS,
value=False,
),
columns=columns,
)

Expand Down
11 changes: 11 additions & 0 deletions tests/neptune/new/internal/backends/test_nql.py
Expand Up @@ -59,6 +59,17 @@ def test_attributes(self):
),
'(`sys/state`:experimentState = "running")',
)
self.assertEqual(
str(
NQLQueryAttribute(
name="sys/trashed",
type=NQLAttributeType.BOOLEAN,
operator=NQLAttributeOperator.EQUALS,
value=False,
)
),
"(`sys/trashed`:bool = false)",
)

def test_multiple_attribute_values(self):
self.assertEqual(
Expand Down

0 comments on commit bfffee4

Please sign in to comment.