Skip to content

Commit

Permalink
⚡🩹 Fix lightning training without validation (#1158)
Browse files Browse the repository at this point in the history
* add test for training on dataset without validation
* add fix for empty validation set

Co-authored-by: Charles Tapley Hoyt <cthoyt@gmail.com>
  • Loading branch information
mberr and cthoyt committed Nov 19, 2022
1 parent febbb27 commit ec6a71c
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 9 deletions.
8 changes: 5 additions & 3 deletions src/pykeen/contrib/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
from pykeen.sampling import NegativeSampler
from pykeen.training import LCWATrainingLoop, SLCWATrainingLoop
from pykeen.triples.triples_factory import CoreTriplesFactory
from pykeen.typing import InductiveMode
from pykeen.typing import InductiveMode, OneOrSequence

__all__ = [
"LitModule",
Expand Down Expand Up @@ -150,13 +150,15 @@ def _dataloader(self, triples_factory: CoreTriplesFactory, shuffle: bool = False
"""Create a data loader."""
raise NotImplementedError

def train_dataloader(self):
def train_dataloader(self) -> torch.utils.data.DataLoader:
    """Create the training data loader."""
    # Training triples always exist; shuffle for stochastic mini-batching.
    factory = self.dataset.training
    return self._dataloader(triples_factory=factory, shuffle=True)

def val_dataloader(self):
def val_dataloader(self) -> OneOrSequence[torch.utils.data.DataLoader]:
    """Create the validation data loader."""
    # TODO: In sLCWA, we still want to calculate validation *metrics* in LCWA
    validation = self.dataset.validation
    # Lightning accepts an empty sequence to signal "no validation step".
    if validation is None:
        return []
    return self._dataloader(triples_factory=validation, shuffle=False)

def configure_optimizers(self):
Expand Down
39 changes: 33 additions & 6 deletions tests/test_lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,19 @@

import pytest

from pykeen import models
from pykeen.datasets import EagerDataset, get_dataset
from pykeen.typing import TRAINING
from tests.utils import needs_packages

try:
from pykeen.contrib.lightning import lit_module_resolver
from pykeen.contrib.lightning import lit_module_resolver, lit_pipeline

LIT_MODULES = lit_module_resolver.lookup_dict.keys()
except ImportError:
LIT_MODULES = []
from pykeen import models
from pykeen.datasets import get_dataset
from pykeen.typing import TRAINING
lit_pipeline = None


EMBEDDING_DIM = 8
# TODO: this could be shared with the model tests
Expand Down Expand Up @@ -61,12 +65,11 @@


# test combinations of models with training loops
@needs_packages("pytorch_lightning")
@pytest.mark.skipif(True, reason="instability related to https://github.com/Lightning-AI/lightning/pull/14117")
@pytest.mark.parametrize(("model", "model_kwargs", "training_loop"), TEST_CONFIGURATIONS)
def test_lit_training(model, model_kwargs, training_loop):
"""Test training models with PyTorch Lightning."""
from pykeen.contrib.lightning import lit_pipeline

# some models require inverse relations
create_inverse_triples = model is not models.RGCN
dataset = get_dataset(dataset="nations", dataset_kwargs=dict(create_inverse_triples=create_inverse_triples))
Expand Down Expand Up @@ -103,3 +106,27 @@ def test_lit_training(model, model_kwargs, training_loop):
max_epochs=2,
),
)


@needs_packages("pytorch_lightning")
def test_lit_pipeline_with_dataset_without_validation():
    """Test training on a dataset without validation triples."""
    # Re-pack Nations without its validation split to exercise the empty-validation path.
    base = get_dataset(dataset="nations")
    dataset = EagerDataset(training=base.training, testing=base.testing, metadata=base.metadata)
    trainer_kwargs = dict(
        # automatically choose accelerator
        accelerator="auto",
        # defaults to TensorBoard; explicitly disabled here
        logger=False,
        # disable checkpointing
        enable_checkpointing=False,
        # fast run
        max_epochs=2,
    )
    lit_pipeline(
        training_loop="slcwa",
        training_loop_kwargs=dict(model="transe", dataset=dataset),
        trainer_kwargs=trainer_kwargs,
    )

0 comments on commit ec6a71c

Please sign in to comment.