
⚡🩹 Fix lightning training without validation #1158

Merged: 12 commits, Nov 19, 2022
2 changes: 2 additions & 0 deletions src/pykeen/contrib/lightning.py
@@ -157,6 +157,8 @@ def train_dataloader(self):
def val_dataloader(self):
"""Create the validation data loader."""
# TODO: In sLCWA, we still want to calculate validation *metrics* in LCWA
if self.dataset.validation is None:
return []
return self._dataloader(triples_factory=self.dataset.validation, shuffle=False)

def configure_optimizers(self):
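
For context, the key effect of this change is that val_dataloader() now returns an empty list when the dataset has no validation split, which PyTorch Lightning treats as "no validation loop" rather than an error. A minimal sketch of the resulting behavior, assuming the SLCWALitModule class name and constructor keywords (they are not shown in this diff, though the same keywords appear in the test below):

# Minimal sketch; SLCWALitModule and its keyword arguments are assumed, not part of this diff.
from pykeen.contrib.lightning import SLCWALitModule
from pykeen.datasets import EagerDataset, get_dataset

# Rebuild Nations without its validation split.
base = get_dataset(dataset="nations")
no_validation = EagerDataset(training=base.training, testing=base.testing)

module = SLCWALitModule(dataset=no_validation, model="transe")
# With the fix, this yields an empty list, so Lightning skips validation entirely.
assert module.val_dataloader() == []
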
2 changes: 1 addition & 1 deletion src/pykeen/stoppers/stopper.py
@@ -48,7 +48,7 @@ def get_summary_dict(self) -> Mapping[str, Any]:
"""Get a summary dict."""
raise NotImplementedError

def _write_from_summary_dict(
def _write_from_summary_dict( # noqa: B027
self,
*,
frequency: int,
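
For context, flake8-bugbear's B027 check flags an empty, non-abstract method defined on an abstract base class; the added # noqa: B027 marks _write_from_summary_dict as an intentional no-op hook that subclasses may override, without declaring it @abstractmethod. A simplified illustration of the pattern the check targets (not PyKEEN's actual class):

from abc import ABC

class Example(ABC):
    # B027 would flag this: the body is effectively empty, but the method is not abstract.
    def hook(self) -> None:
        """Intentional no-op; subclasses may override."""
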
28 changes: 27 additions & 1 deletion tests/test_lightning.py
@@ -11,7 +11,7 @@
except ImportError:
LIT_MODULES = []
from pykeen import models
from pykeen.datasets import get_dataset
from pykeen.datasets import EagerDataset, get_dataset
from pykeen.typing import TRAINING

EMBEDDING_DIM = 8
@@ -103,3 +103,29 @@ def test_lit_training(model, model_kwargs, training_loop):
max_epochs=2,
),
)


def test_lit_pipeline_with_dataset_without_validation():
"""Test training on a dataset without validation triples."""
from pykeen.contrib.lightning import lit_pipeline

dataset = get_dataset(dataset="nations")
dataset = EagerDataset(training=dataset.training, testing=dataset.testing, metadata=dataset.metadata)

lit_pipeline(
training_loop="slcwa",
training_loop_kwargs=dict(
model="transe",
dataset=dataset,
),
trainer_kwargs=dict(
# automatically choose accelerator
accelerator="auto",
# defaults to TensorBoard; explicitly disabled here
logger=False,
# disable checkpointing
enable_checkpointing=False,
# fast run
max_epochs=2,
),
)
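
Assuming the repository's standard pytest setup and that the lightning extra is installed, the new test can be run in isolation with something like: pytest tests/test_lightning.py -k without_validation.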