Skip to content

Commit

Permalink
Fixes #4141 (#4169)
Browse files Browse the repository at this point in the history
* fix val epoch agg

* fix val agg metrics

* fix val agg metrics

* fix val agg metrics
  • Loading branch information
williamFalcon committed Oct 15, 2020
1 parent f064682 commit 45d05ff
Show file tree
Hide file tree
Showing 6 changed files with 117 additions and 10 deletions.
35 changes: 35 additions & 0 deletions pytorch_lightning/core/step_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,9 @@ def get_epoch_log_metrics(self) -> dict:
if k == '_internal':
continue

if options['forked']:
continue

if options['logger'] and options['on_epoch']:
if isinstance(self[k], Metric):
result[k] = self[k].compute()
Expand All @@ -299,6 +302,9 @@ def get_epoch_pbar_metrics(self):
if k == '_internal':
continue

if options['forked']:
continue

if options['prog_bar'] and options['on_epoch']:
if isinstance(self[k], Metric):
result[k] = self[k].compute()
Expand All @@ -311,6 +317,22 @@ def get_epoch_pbar_metrics(self):

return result

def get_forked_metrics(self):
    """
    Collects the metrics whose meta options are flagged as ``forked``.

    Returns:
        dict mapping each forked metric name to the value stored under that
        name on this object. Empty when no metric is forked.
    """
    # the '_internal' meta entry is skipped before its options are inspected
    return {
        name: self[name]
        for name, opts in self['meta'].items()
        if name != '_internal' and opts['forked']
    }

def get_batch_pbar_metrics(self, include_forked_originals=True):
"""
Gets the metrics to log at the end of the batch step
Expand Down Expand Up @@ -443,6 +465,11 @@ def reduce_on_epoch_end(cls, outputs):
if k == '_internal' or isinstance(result[k], Metric):
continue

# for forked metrics don't reduce, just take the last val
if option['forked']:
result[k] = choose_last(result[k])
continue

if option['on_epoch']:
fx = option['reduce_fx']
if fx == torch.mean:
Expand Down Expand Up @@ -531,6 +558,14 @@ def rename_keys(self, map_dict: dict):
del meta[source]


def choose_last(x):
    """
    Picks the last entry of ``x``.

    Args:
        x: a tensor or list (indexed with ``-1``), or a dict whose values are
            each replaced, in place, by their own last entry.

    Returns:
        ``x[-1]`` for a tensor or list. For a dict, the same dict object, now
        holding the last entry of every value. Any other input is returned
        unchanged.
    """
    if isinstance(x, (torch.Tensor, list)):
        return x[-1]

    if isinstance(x, dict):
        # only existing keys are reassigned, so the dict size never changes
        # while it is being iterated
        for k, v in x.items():
            x[k] = v[-1]

    # the dict branch used to fall off the end and return None, which made
    # callers that assign the result overwrite their metric with None
    return x


def recursive_gather(outputs: Sequence[dict], result: Optional[MutableMapping] = None) -> Optional[MutableMapping]:
for out in outputs:
if 'meta' in out:
Expand Down
8 changes: 6 additions & 2 deletions pytorch_lightning/trainer/connectors/logger_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,12 +189,16 @@ def _log_on_evaluation_epoch_end_metrics(self, epoch_logs):
self.callback_metrics.update(logger_metrics)
self.callback_metrics.update(pbar_metrics)

# forked metrics were dropped, enable them for callbacks
forked_metrics = reduced_epoch_metrics.get_forked_metrics()
self.callback_metrics.update(forked_metrics)

# track the final results for the dataloader
self.eval_loop_results.append(deepcopy(self.callback_metrics))

# actually log
if len(epoch_logger_metrics) > 0:
metrics_to_log.append(epoch_logger_metrics)
if len(logger_metrics) > 0:
metrics_to_log.append(logger_metrics)

# log all the metrics as a single dict
metrics_to_log = dict(ChainMap(*metrics_to_log))
Expand Down
3 changes: 3 additions & 0 deletions pytorch_lightning/trainer/evaluation_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,9 @@ def log_epoch_metrics(self, deprecated_eval_results, epoch_logs, test_mode):
def __run_eval_epoch_end(self, num_dataloaders, using_eval_result):
model = self.trainer.get_model()

# reset results
model._results = Result()

# with a single dataloader don't pass an array
outputs = self.outputs
eval_results = outputs
Expand Down
3 changes: 1 addition & 2 deletions tests/core/test_metric_result_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ def _ddp_test_fn(rank, worldsize):

epoch_expected = {
"b": cumulative_sum * worldsize,
"a": cumulative_sum * worldsize,
"a_epoch": cumulative_sum * worldsize
}

Expand Down Expand Up @@ -136,7 +135,7 @@ def test_result_metric_integration():
assert metric_b.x == metric_b._defaults['x']
assert metric_c.x == metric_c._defaults['x']

epoch_expected = {"b": cumulative_sum, "a": cumulative_sum, "a_epoch": cumulative_sum}
epoch_expected = {"b": cumulative_sum, "a_epoch": cumulative_sum}

assert set(epoch_log.keys()) == set(epoch_expected.keys())
for k in epoch_expected.keys():
Expand Down
74 changes: 70 additions & 4 deletions tests/trainer/logging/test_eval_loop_logging_1_0.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
Tests to ensure that the training loop works with a dict (1.0)
"""
from pytorch_lightning import Trainer
from pytorch_lightning import callbacks
from pytorch_lightning import callbacks, seed_everything
from tests.base.deterministic_model import DeterministicModel
from tests.base import SimpleModule, BoringModel
import os
Expand Down Expand Up @@ -68,7 +68,6 @@ def backward(self, loss, optimizer, optimizer_idx):
'a2',
'a_step',
'a_epoch',
'b',
'b_step/epoch_0',
'b_step/epoch_1',
'b_epoch',
Expand Down Expand Up @@ -142,12 +141,10 @@ def backward(self, loss, optimizer, optimizer_idx):
'b_step',
'b_epoch',
'c',
'd',
'd_step/epoch_0',
'd_step/epoch_1',
'd_epoch',
'e',
'f',
'f_step/epoch_0',
'f_step/epoch_1',
'f_epoch',
Expand Down Expand Up @@ -247,6 +244,75 @@ def validation_step(self, batch, batch_idx):
assert logged_metrics == expected_logged_metrics


def test_eval_logging_auto_reduce(tmpdir):
"""
Tests that a metric logged in validation_step with both on_step and on_epoch
is reduced at epoch end.

`val_loss_epoch` must equal the mean of the per-batch losses, the per-step
values must be logged under `val_loss_step/epoch_0`, and `val_loss` must be
available in the callback metrics.
"""
seed_everything(1234)

os.environ['PL_DEV_DEBUG'] = '1'

class TestModel(BoringModel):
def on_pretrain_routine_end(self) -> None:
self.seen_vals = []
self.manual_epoch_end_mean = None

def on_validation_epoch_start(self) -> None:
# start every validation epoch with an empty record of losses
self.seen_vals = []

def validation_step(self, batch, batch_idx):
output = self.layer(batch)
loss = self.loss(batch, output)
# keep each batch loss so the logged values can be compared against it
self.seen_vals.append(loss)
self.log('val_loss', loss, on_epoch=True, on_step=True, prog_bar=True)
return {"x": loss}

def validation_epoch_end(self, outputs) -> None:
# the outputs handed to the hook must match what validation_step returned
for passed_in, manually_tracked in zip(outputs, self.seen_vals):
assert passed_in['x'] == manually_tracked
# reference value for the epoch-level metric checked below
self.manual_epoch_end_mean = torch.stack([x['x'] for x in outputs]).mean()

model = TestModel()

trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=3,
limit_val_batches=3,
max_epochs=1,
log_every_n_steps=1,
weights_summary=None,
checkpoint_callback=callbacks.ModelCheckpoint('val_loss')
)
trainer.fit(model)

# make sure all the metrics are available for callbacks
manual_mean = model.manual_epoch_end_mean
callback_metrics = set(trainer.callback_metrics.keys())
assert callback_metrics == {'debug_epoch', 'val_loss', 'val_loss_epoch'}

# make sure values are correct
assert trainer.logged_metrics['val_loss_epoch'] == manual_mean
assert trainer.callback_metrics['val_loss'] == trainer.logged_metrics['val_loss_step/epoch_0']

# make sure correct values were logged
logged_val = trainer.dev_debugger.logged_metrics

# sanity check
# NOTE(review): the origin of these first two entries is not visible here - confirm
assert logged_val[0]['global_step'] == 0
assert logged_val[1]['global_step'] == 0

# 3 val batches
assert logged_val[2]['val_loss_step/epoch_0'] == model.seen_vals[0]
assert logged_val[3]['val_loss_step/epoch_0'] == model.seen_vals[1]
assert logged_val[4]['val_loss_step/epoch_0'] == model.seen_vals[2]

# epoch mean
assert logged_val[5]['val_loss_epoch'] == model.manual_epoch_end_mean

# only those logged
assert len(logged_val) == 6


def test_monitor_val_epoch_end(tmpdir):
epoch_min_loss_override = 0
model = SimpleModule()
Expand Down
4 changes: 2 additions & 2 deletions tests/trainer/logging/test_train_loop_logging_1_0.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,8 +395,8 @@ def val_dataloader(self):

generated = set(trainer.logger_connector.logged_metrics)
expected = {
'a_epoch', 'a',
'n', 'n_step/epoch_0', 'n_epoch',
'a_epoch',
'n_step/epoch_0', 'n_epoch',
'epoch'
}

Expand Down

0 comments on commit 45d05ff

Please sign in to comment.