Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

⚡🩹 Fix lightning training without validation #1158

Merged
merged 12 commits into from
Nov 19, 2022
8 changes: 5 additions & 3 deletions src/pykeen/contrib/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
from pykeen.sampling import NegativeSampler
from pykeen.training import LCWATrainingLoop, SLCWATrainingLoop
from pykeen.triples.triples_factory import CoreTriplesFactory
from pykeen.typing import InductiveMode
from pykeen.typing import InductiveMode, OneOrSequence

__all__ = [
"LitModule",
Expand Down Expand Up @@ -150,13 +150,15 @@ def _dataloader(self, triples_factory: CoreTriplesFactory, shuffle: bool = False
"""Create a data loader."""
raise NotImplementedError

def train_dataloader(self) -> torch.utils.data.DataLoader:
    """Build and return the data loader over the training triples."""
    training_factory = self.dataset.training
    # shuffling is desirable during training
    return self._dataloader(triples_factory=training_factory, shuffle=True)

def val_dataloader(self) -> OneOrSequence[torch.utils.data.DataLoader]:
    """Build and return the data loader(s) over the validation triples.

    Returns an empty list when the dataset has no validation split, which
    signals to Lightning that validation should be skipped.
    """
    # TODO: In sLCWA, we still want to calculate validation *metrics* in LCWA
    validation_factory = self.dataset.validation
    if validation_factory is None:
        return []
    # no shuffling for evaluation
    return self._dataloader(triples_factory=validation_factory, shuffle=False)

def configure_optimizers(self):
Expand Down
39 changes: 33 additions & 6 deletions tests/test_lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,19 @@

import pytest

from pykeen import models
from pykeen.datasets import EagerDataset, get_dataset
from pykeen.typing import TRAINING
from tests.utils import needs_packages

try:
from pykeen.contrib.lightning import lit_module_resolver
from pykeen.contrib.lightning import lit_module_resolver, lit_pipeline

LIT_MODULES = lit_module_resolver.lookup_dict.keys()
except ImportError:
LIT_MODULES = []
from pykeen import models
from pykeen.datasets import get_dataset
from pykeen.typing import TRAINING
lit_pipeline = None


EMBEDDING_DIM = 8
# TODO: this could be shared with the model tests
Expand Down Expand Up @@ -61,12 +65,11 @@


# test combinations of models with training loops
@needs_packages("pytorch_lightning")
@pytest.mark.skipif(True, reason="instability related to https://github.com/Lightning-AI/lightning/pull/14117")
@pytest.mark.parametrize(("model", "model_kwargs", "training_loop"), TEST_CONFIGURATIONS)
def test_lit_training(model, model_kwargs, training_loop):
"""Test training models with PyTorch Lightning."""
from pykeen.contrib.lightning import lit_pipeline

# some models require inverse relations
create_inverse_triples = model is not models.RGCN
dataset = get_dataset(dataset="nations", dataset_kwargs=dict(create_inverse_triples=create_inverse_triples))
Expand Down Expand Up @@ -103,3 +106,27 @@ def test_lit_training(model, model_kwargs, training_loop):
max_epochs=2,
),
)


@needs_packages("pytorch_lightning")
def test_lit_pipeline_with_dataset_without_validation():
    """Test training on a dataset without validation triples."""
    # start from a full dataset, then rebuild it without the validation split
    full_dataset = get_dataset(dataset="nations")
    validation_free = EagerDataset(
        training=full_dataset.training,
        testing=full_dataset.testing,
        metadata=full_dataset.metadata,
    )
    trainer_kwargs = dict(
        # automatically choose accelerator
        accelerator="auto",
        # defaults to TensorBoard; explicitly disabled here
        logger=False,
        # disable checkpointing
        enable_checkpointing=False,
        # fast run
        max_epochs=2,
    )
    lit_pipeline(
        training_loop="slcwa",
        training_loop_kwargs=dict(model="transe", dataset=validation_free),
        trainer_kwargs=trainer_kwargs,
    )