Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fetching only non trashed metadata containers #1033

Merged
merged 3 commits into from
Oct 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Expand Up @@ -6,6 +6,7 @@
### Changes
- More consistent and strict way of git repository, source files and entrypoint detection ([#1007](https://github.com/neptune-ai/neptune-client/pull/1007))
- Moved neptune and neptune_cli to src dir ([#1027](https://github.com/neptune-ai/neptune-client/pull/1027))
- `fetch_runs_table(...)`, `fetch_models_table(...)` and `fetch_model_versions_table(...)` now queries only non-trashed ([#1033](https://github.com/neptune-ai/neptune-client/pull/1033))

## neptune-client 0.16.9

Expand Down
18 changes: 4 additions & 14 deletions e2e_tests/management/test_management.py
Expand Up @@ -408,13 +408,13 @@ def test_trash_runs_and_models(self, project, environment):

# THEN trashed runs are marked as trashed
runs = project.fetch_runs_table().to_pandas()
assert run1_id in runs[runs["sys/trashed"] == True]["sys/id"].tolist()
assert run2_id in runs[runs["sys/trashed"] == False]["sys/id"].tolist()
assert len(runs) == 1

# AND trashed models are marked as trashed
models = project.fetch_models_table().to_pandas()
assert model1_id in models[models["sys/trashed"] == True]["sys/id"].tolist()
assert model2_id in models[models["sys/trashed"] == False]["sys/id"].tolist()
assert len(models) == 1

def test_trash_model_version(self, environment):
# WITH model
Expand All @@ -436,14 +436,11 @@ def test_trash_model_version(self, environment):

# THEN expect this version to be trashed
model_versions = model.fetch_model_versions_table().to_pandas()
assert (
model_version1
in model_versions[model_versions["sys/trashed"] == True]["sys/id"].tolist()
)
assert (
model_version2
in model_versions[model_versions["sys/trashed"] == False]["sys/id"].tolist()
)
assert len(model_versions) == 1

# WHEN whole model is trashed
trash_objects(environment.project, model_id)
Expand All @@ -452,11 +449,4 @@ def test_trash_model_version(self, environment):

# THEN expect all its versions to be trashed
model_versions = model.fetch_model_versions_table().to_pandas()
assert (
model_version1
in model_versions[model_versions["sys/trashed"] == True]["sys/id"].tolist()
)
assert (
model_version2
in model_versions[model_versions["sys/trashed"] == True]["sys/id"].tolist()
)
assert len(model_versions) == 0
11 changes: 9 additions & 2 deletions src/neptune/new/internal/backends/nql.py
Expand Up @@ -22,6 +22,7 @@
"NQLQueryAttribute",
]

import typing
from dataclasses import dataclass
from enum import Enum
from typing import Iterable
Expand Down Expand Up @@ -62,14 +63,20 @@ class NQLAttributeType(str, Enum):
STRING = "string"
STRING_SET = "stringSet"
EXPERIMENT_STATE = "experimentState"
BOOLEAN = "bool"


@dataclass
class NQLQueryAttribute(NQLQuery):
name: str
type: NQLAttributeType
operator: NQLAttributeOperator
value: str
value: typing.Union[str, bool]

def __str__(self) -> str:
return f'(`{self.name}`:{self.type.value} {self.operator.value} "{self.value}")'
if isinstance(self.value, bool):
value = str(self.value).lower()
else:
value = f'"{self.value}"'

return f"(`{self.name}`:{self.type.value} {self.operator.value} {value})"
23 changes: 18 additions & 5 deletions src/neptune/new/metadata_containers/model.py
Expand Up @@ -16,8 +16,10 @@
from typing import Iterable, Optional

from neptune.new.internal.backends.nql import (
NQLAggregator,
NQLAttributeOperator,
NQLAttributeType,
NQLQueryAggregate,
NQLQueryAttribute,
)
from neptune.new.internal.container_type import ContainerType
Expand Down Expand Up @@ -102,11 +104,22 @@ def fetch_model_versions_table(self, columns: Optional[Iterable[str]] = None) ->
return MetadataContainer._fetch_entries(
self,
child_type=ContainerType.MODEL_VERSION,
query=NQLQueryAttribute(
name="sys/model_id",
value=self._sys_id,
operator=NQLAttributeOperator.EQUALS,
type=NQLAttributeType.STRING,
query=NQLQueryAggregate(
items=[
NQLQueryAttribute(
name="sys/model_id",
value=self._sys_id,
operator=NQLAttributeOperator.EQUALS,
type=NQLAttributeType.STRING,
),
NQLQueryAttribute(
name="sys/trashed",
type=NQLAttributeType.BOOLEAN,
operator=NQLAttributeOperator.EQUALS,
value=False,
),
],
aggregator=NQLAggregator.AND,
),
columns=columns,
)
17 changes: 14 additions & 3 deletions src/neptune/new/metadata_containers/project.py
Expand Up @@ -21,7 +21,6 @@
NQLAggregator,
NQLAttributeOperator,
NQLAttributeType,
NQLEmptyQuery,
NQLQueryAggregate,
NQLQueryAttribute,
)
Expand Down Expand Up @@ -96,7 +95,14 @@ def _metadata_url(self) -> str:

@staticmethod
def _prepare_nql_query(ids, states, owners, tags):
query_items = []
query_items = [
NQLQueryAttribute(
name="sys/trashed",
type=NQLAttributeType.BOOLEAN,
operator=NQLAttributeOperator.EQUALS,
value=False,
)
]

if ids:
query_items.append(
Expand Down Expand Up @@ -304,7 +310,12 @@ def fetch_models_table(self, columns: Optional[Iterable[str]] = None) -> Table:
return MetadataContainer._fetch_entries(
self,
child_type=ContainerType.MODEL,
query=NQLEmptyQuery(),
query=NQLQueryAttribute(
name="sys/trashed",
type=NQLAttributeType.BOOLEAN,
operator=NQLAttributeOperator.EQUALS,
value=False,
),
columns=columns,
)

Expand Down
11 changes: 11 additions & 0 deletions tests/neptune/new/internal/backends/test_nql.py
Expand Up @@ -59,6 +59,17 @@ def test_attributes(self):
),
'(`sys/state`:experimentState = "running")',
)
self.assertEqual(
str(
NQLQueryAttribute(
name="sys/trashed",
type=NQLAttributeType.BOOLEAN,
operator=NQLAttributeOperator.EQUALS,
value=False,
)
),
"(`sys/trashed`:bool = false)",
)

def test_multiple_attribute_values(self):
self.assertEqual(
Expand Down