Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes #4141 #4169

Merged
merged 4 commits into from
Oct 15, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
35 changes: 35 additions & 0 deletions pytorch_lightning/core/step_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,9 @@ def get_epoch_log_metrics(self) -> dict:
if k == '_internal':
continue

if options['forked']:
continue

if options['logger'] and options['on_epoch']:
if isinstance(self[k], Metric):
result[k] = self[k].compute()
Expand All @@ -299,6 +302,9 @@ def get_epoch_pbar_metrics(self):
if k == '_internal':
continue

if options['forked']:
continue

if options['prog_bar'] and options['on_epoch']:
if isinstance(self[k], Metric):
result[k] = self[k].compute()
Expand All @@ -311,6 +317,22 @@ def get_epoch_pbar_metrics(self):

return result

def get_forked_metrics(self):
"""
Gets the metrics to log at the end of epoch
"""
result = {}

meta = self['meta']
for k, options in meta.items():
if k == '_internal':
continue

if options['forked']:
result[k] = self[k]

return result

def get_batch_pbar_metrics(self, include_forked_originals=True):
"""
Gets the metrics to log at the end of the batch step
Expand Down Expand Up @@ -443,6 +465,11 @@ def reduce_on_epoch_end(cls, outputs):
if k == '_internal' or isinstance(result[k], Metric):
continue

# for forked metrics don't reduce, just take the last val
if option['forked']:
result[k] = choose_last(result[k])
continue

if option['on_epoch']:
fx = option['reduce_fx']
if fx == torch.mean:
Expand Down Expand Up @@ -531,6 +558,14 @@ def rename_keys(self, map_dict: dict):
del meta[source]


def choose_last(x):
if isinstance(x, (torch.Tensor, list)):
return x[-1]
if isinstance(x, dict):
for k, v in x.items():
x[k] = x[k][-1]


def recursive_gather(outputs: Sequence[dict], result: Optional[MutableMapping] = None) -> Optional[MutableMapping]:
for out in outputs:
if 'meta' in out:
Expand Down
8 changes: 6 additions & 2 deletions pytorch_lightning/trainer/connectors/logger_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,12 +189,16 @@ def _log_on_evaluation_epoch_end_metrics(self, epoch_logs):
self.callback_metrics.update(logger_metrics)
self.callback_metrics.update(pbar_metrics)

# forked metrics were dropped, enable them for callbacks
forked_metrics = reduced_epoch_metrics.get_forked_metrics()
self.callback_metrics.update(forked_metrics)

# track the final results for the dataloader
self.eval_loop_results.append(deepcopy(self.callback_metrics))

# actually log
if len(epoch_logger_metrics) > 0:
metrics_to_log.append(epoch_logger_metrics)
if len(logger_metrics) > 0:
metrics_to_log.append(logger_metrics)

# log all the metrics as a s single dict
metrics_to_log = dict(ChainMap(*metrics_to_log))
Expand Down
3 changes: 3 additions & 0 deletions pytorch_lightning/trainer/evaluation_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,9 @@ def log_epoch_metrics(self, deprecated_eval_results, epoch_logs, test_mode):
def __run_eval_epoch_end(self, num_dataloaders, using_eval_result):
model = self.trainer.get_model()

# reset results
model._results = Result()

# with a single dataloader don't pass an array
outputs = self.outputs
eval_results = outputs
Expand Down
3 changes: 1 addition & 2 deletions tests/core/test_metric_result_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ def _ddp_test_fn(rank, worldsize):

epoch_expected = {
"b": cumulative_sum * worldsize,
"a": cumulative_sum * worldsize,
"a_epoch": cumulative_sum * worldsize
}

Expand Down Expand Up @@ -136,7 +135,7 @@ def test_result_metric_integration():
assert metric_b.x == metric_b._defaults['x']
assert metric_c.x == metric_c._defaults['x']

epoch_expected = {"b": cumulative_sum, "a": cumulative_sum, "a_epoch": cumulative_sum}
epoch_expected = {"b": cumulative_sum, "a_epoch": cumulative_sum}

assert set(epoch_log.keys()) == set(epoch_expected.keys())
for k in epoch_expected.keys():
Expand Down
74 changes: 70 additions & 4 deletions tests/trainer/logging/test_eval_loop_logging_1_0.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
Tests to ensure that the training loop works with a dict (1.0)
"""
from pytorch_lightning import Trainer
from pytorch_lightning import callbacks
from pytorch_lightning import callbacks, seed_everything
from tests.base.deterministic_model import DeterministicModel
from tests.base import SimpleModule, BoringModel
import os
Expand Down Expand Up @@ -68,7 +68,6 @@ def backward(self, loss, optimizer, optimizer_idx):
'a2',
'a_step',
'a_epoch',
'b',
'b_step/epoch_0',
'b_step/epoch_1',
'b_epoch',
Expand Down Expand Up @@ -142,12 +141,10 @@ def backward(self, loss, optimizer, optimizer_idx):
'b_step',
'b_epoch',
'c',
'd',
'd_step/epoch_0',
'd_step/epoch_1',
'd_epoch',
'e',
'f',
'f_step/epoch_0',
'f_step/epoch_1',
'f_epoch',
Expand Down Expand Up @@ -247,6 +244,75 @@ def validation_step(self, batch, batch_idx):
assert logged_metrics == expected_logged_metrics


def test_eval_logging_auto_reduce(tmpdir):
"""
Tests that only training_step can be used
"""
seed_everything(1234)

os.environ['PL_DEV_DEBUG'] = '1'

class TestModel(BoringModel):
def on_pretrain_routine_end(self) -> None:
self.seen_vals = []
self.manual_epoch_end_mean = None

def on_validation_epoch_start(self) -> None:
self.seen_vals = []

def validation_step(self, batch, batch_idx):
output = self.layer(batch)
loss = self.loss(batch, output)
self.seen_vals.append(loss)
self.log('val_loss', loss, on_epoch=True, on_step=True, prog_bar=True)
return {"x": loss}

def validation_epoch_end(self, outputs) -> None:
for passed_in, manually_tracked in zip(outputs, self.seen_vals):
assert passed_in['x'] == manually_tracked
self.manual_epoch_end_mean = torch.stack([x['x'] for x in outputs]).mean()

model = TestModel()

trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=3,
limit_val_batches=3,
max_epochs=1,
log_every_n_steps=1,
weights_summary=None,
checkpoint_callback=callbacks.ModelCheckpoint('val_loss')
)
trainer.fit(model)

# make sure all the metrics are available for callbacks
manual_mean = model.manual_epoch_end_mean
callback_metrics = set(trainer.callback_metrics.keys())
assert callback_metrics == {'debug_epoch', 'val_loss', 'val_loss_epoch'}

# make sure values are correct
assert trainer.logged_metrics['val_loss_epoch'] == manual_mean
assert trainer.callback_metrics['val_loss'] == trainer.logged_metrics['val_loss_step/epoch_0']

# make sure correct values were logged
logged_val = trainer.dev_debugger.logged_metrics

# sanity check
assert logged_val[0]['global_step'] == 0
assert logged_val[1]['global_step'] == 0

# 3 val batches
assert logged_val[2]['val_loss_step/epoch_0'] == model.seen_vals[0]
assert logged_val[3]['val_loss_step/epoch_0'] == model.seen_vals[1]
assert logged_val[4]['val_loss_step/epoch_0'] == model.seen_vals[2]

# epoch mean
assert logged_val[5]['val_loss_epoch'] == model.manual_epoch_end_mean

# only those logged
assert len(logged_val) == 6


def test_monitor_val_epoch_end(tmpdir):
epoch_min_loss_override = 0
model = SimpleModule()
Expand Down
4 changes: 2 additions & 2 deletions tests/trainer/logging/test_train_loop_logging_1_0.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,8 +395,8 @@ def val_dataloader(self):

generated = set(trainer.logger_connector.logged_metrics)
expected = {
'a_epoch', 'a',
'n', 'n_step/epoch_0', 'n_epoch',
'a_epoch',
'n_step/epoch_0', 'n_epoch',
'epoch'
}

Expand Down