Skip to content

Commit

Permalink
Upd depts and insure compatibility (#201)
Browse files Browse the repository at this point in the history
* Bump version

* Upd docstring

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Temporarily disable Windows tests

* Unpin setuptools

* Clear cache

* Clear cache

* Constraint pip

* poetry lock

* poetry lock

* Fix: Fix import for pl>=1.9.0

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix: Fix import for pl>=1.9.0

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix: Fix compatibility for pl>=1.9.0

* Re-enable tests on Windows

* Update quaterion/train/callbacks/cleanup_callback.py

Co-authored-by: George <george.panchuk@qdrant.tech>

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: George <george.panchuk@qdrant.tech>
  • Loading branch information
3 people committed Mar 29, 2023
1 parent 3f0a96c commit 1745697
Show file tree
Hide file tree
Showing 6 changed files with 1,660 additions and 926 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ jobs:
strategy:
fail-fast: true
matrix:
python-version: [3.8, 3.9, "3.10"] # [3.7, 3.8, 3.9] # - temporary disable due to Actions spend limit
os: [ubuntu-latest, windows-latest] # [ubuntu-latest, macOS-latest, windows-latest]
python-version: [3.8, 3.9, "3.10"]
os: [ubuntu-latest, windows-latest]

steps:
- uses: actions/checkout@v1
Expand Down
2,555 changes: 1,637 additions & 918 deletions poetry.lock

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "quaterion"
version = "0.1.34"
version = "0.1.35"
description = "Similarity Learning fine-tuning framework"
authors = ["Quaterion Authors <team@quaterion.tech>"]
packages = [
Expand Down Expand Up @@ -31,7 +31,7 @@ ipdb = "^0.13.9"
sphinx = ">=5.0.1"
qdrant-sphinx-theme = { git = "https://github.com/qdrant/qdrant_sphinx_theme.git", branch = "master" }
black = "^22.3.0"
torchvision = "^0.12.0"
torchvision = ">=0.12.0"


[tool.poetry.extras]
Expand Down
4 changes: 3 additions & 1 deletion quaterion/loss/circle_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@

class CircleLoss(GroupLoss):
"""Implements Circle Loss as defined in https://arxiv.org/abs/2002.10857.
Args:
margin: Margin value to push negative examples.
scale_factor: scale factor γ determines the largest scale of each similarity score.
Refer to sections 4.1 and 4.5 in the paper for default values and evaluation of margin and scaling_factor hyperparameters.
Note:
Refer to sections 4.1 and 4.5 in the paper for default values and evaluation of margin and scaling_factor hyperparameters.
"""

def __init__(
Expand Down
13 changes: 11 additions & 2 deletions quaterion/train/callbacks/cleanup_callback.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
from typing import Optional

import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.trainer.states import TrainerFn

from quaterion.train.trainable_model import TrainableModel

try: # fix for version >= 1.9.0
from pytorch_lightning import Callback
except ImportError:
from pytorch_lightning.callbacks.base import Callback


class CleanupCallback(Callback):
def teardown(
Expand All @@ -18,7 +22,12 @@ def teardown(
# If encoders were wrapped, unwrap them
pl_module.unwrap_cache()

trainer.reset_train_val_dataloaders()
try: # fix for pl>=1.9.0
trainer.reset_train_val_dataloaders()
except NotImplementedError:
trainer.reset_train_dataloader()
trainer.reset_test_dataloader()

# Restore Data Loaders if they were modified for cache
train_dataloader = trainer.train_dataloader.loaders
pl_module.setup_dataloader(train_dataloader)
Expand Down
6 changes: 5 additions & 1 deletion quaterion/train/callbacks/metrics_callback.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback

try: # fix for version >= 1.9.0
from pytorch_lightning import Callback
except ImportError:
from pytorch_lightning.callbacks.base import Callback


class MetricsCallback(Callback):
Expand Down

0 comments on commit 1745697

Please sign in to comment.