
[RFC] Simpler mechanism to publish data from LightningModule and Callback to Trainer #11715

Closed
ananthsub opened this issue Feb 3, 2022 · 4 comments

ananthsub (Contributor) commented Feb 3, 2022

🚀 Feature

This RFC expresses a desire for a simpler mechanism to send metric data from the LightningModule and Callback to the Trainer.

Motivation

The current approach with LightningModule.log has many downsides:

  1. Poor debugging experience: we started seeing this sort of failure ([RFC] Deprecate the move_metrics_to_cpu Trainer argument #10595) after Save the loop progress state by default #10784. Here is an example stack trace: https://gist.github.com/ananthsub/45c154145d0f852503c6a547f59e91f0 . It is very hard to tell where in the logging I went wrong. We see this error even after updating our torchmetrics dependency.

  2. LightningModule.log conflates too many things:

  • Handles synchronization of tensors, allowing user-provided syncing functions
  • Handles aggregation of tensors across steps and for the end of the epoch
  • Selectively samples what data to log based on log_every_n_steps and the global step
  • Makes the trainer aware of which metrics to reset at the end of the epoch
  • Also handles partitioning keys based on the dataloader idx??
  • Supports passing the batch_size for weighted averages to take into account. If the batch_size isn't passed, Lightning silently makes its best attempt to figure this out (not foolproof!)

Many of these assumptions came from the original Result object. This was a class that preceded the whole torchmetrics project.

The log API is not straightforward to use given the large number of options available, and the differing implementations when logging floats/tensors vs torchmetrics Metric objects.
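
For illustration, here is a sketch of what a single call can look like today once several of these options come into play (the keyword arguments below are existing log() parameters; the values are made up):

self.log(
    "val_acc",
    acc,  # a float, a Tensor, or a torchmetrics Metric
    on_step=False,
    on_epoch=True,
    prog_bar=True,
    logger=True,
    sync_dist=True,
    reduce_fx="mean",
    batch_size=batch.size(0),
    rank_zero_only=False,
)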

Pitch

Provide a new API like this:

import time
from dataclasses import dataclass
from typing import List, Mapping, Optional, Union

import torch

@dataclass
class PutData:
    name: str
    val: torch.Tensor
    destinations: List[str]
    timestamp: float
    step: Optional[int] = None

class PutMixin:
    def __init__(self):
        self.pl_put_records: List[PutData] = []

    def put(self, name: str, val: Union[float, torch.Tensor], destinations: Optional[List[str]] = None) -> None:
        destinations = destinations or ["progress_bar", "callbacks", "loggers"]
        self.pl_put_records.append(PutData(name=name, val=val, destinations=destinations, timestamp=time.monotonic()))

    def put_dict(self, dictionary: Mapping[str, Union[float, torch.Tensor]], destinations: Optional[List[str]] = None) -> None:
        for k, v in dictionary.items():
            self.put(k, v, destinations)

example calling code:

from torchmetrics import MeanMetric

def __init__(self, ...):
    self.loss_avg = MeanMetric()
    self.metric = ...
    self.my_fancy_metric = ...

def training_step(self, batch, batch_idx):
    loss = compute_loss(batch)
    self.loss_avg.update(loss, batch_size(batch))  # no more guesswork from the trainer
    self.metric.update(batch, loss)
    self.put("loss", self.loss_avg.compute())  # the user always passes tensors instead of a mix of tensors and Metric instances
    self.put("metric", self.metric.compute(), destinations=["callbacks"])  # only send this info to callbacks, not loggers
    with self.my_fancy_metric.sync_context():
        self.put("fancy_metric", self.my_fancy_metric.compute())
    return loss

def on_train_epoch_end(self):
    self.loss_avg.reset()
    self.metric.reset()
    self.my_fancy_metric.reset()

The trainer already calls all of the hooks offered by the LightningModule & Callback APIs. We have the logic in the trainer here that can inspect the data and reset it after every hook is called: https://github.com/PyTorchLightning/pytorch-lightning/blob/9ebd7df22acc6e0de4569edacd0ec8319ab4be21/pytorch_lightning/trainer/trainer.py#L1522-L1587. This means data can be taken from there, tagged with the global_step or other metadata the trainer is aware of, and routed to the relevant destinations (callbacks/loggers/metrics).
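
A minimal sketch of how this draining step could look inside the Trainer (the method and attribute names here are hypothetical, not existing Trainer internals):

def _drain_put_records(self, component) -> None:
    # Hypothetical Trainer method: drain the records accumulated by a PutMixin
    # (LightningModule or Callback) after each hook call and route them.
    records, component.pl_put_records = component.pl_put_records, []
    for record in records:
        if record.step is None:
            record.step = self.global_step  # attach trainer-owned metadata
        if "callbacks" in record.destinations:
            self.callback_metrics[record.name] = record.val
        if "progress_bar" in record.destinations:
            self.progress_bar_metrics[record.name] = record.val
        if "loggers" in record.destinations and self.logger is not None:
            self.logger.log_metrics({record.name: record.val}, step=record.step)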

Pros:

  • Simpler API for users with fewer side-effects to consider, especially around checkpointing & syncing states
  • Simpler API means it's easier to onboard for new users. No need to get familiar with different arguments for sync_dist, on_step/on_epoch, rank_zero_only or metric_attribute amongst others
  • Simpler implementation that backs this. This is critical for users to debug failures. A secondary benefit is it's easier for developers to maintain the framework over time.
  • No need to duplicate logic between torchmetrics and Lightning. All of the metric syncing logic is delegated to torchmetrics in user land, which solves this more elegantly
  • Users should be able to publish these metrics anytime. There should be no restrictions around when data is stored, unlike log today.
  • We don't need to store the current_fx name on the module before calling each lightningmodule hook, just so that we could set the defaults for on_step and on_epoch inside of log
  • The user-facing API can be wrapped in a small mixin to share between the LightningModule & Callback. Then the Trainer doesn't need to dynamically patch the LightningModule's log onto the Callback anymore: https://github.com/PyTorchLightning/pytorch-lightning/blob/9ebd7df22acc6e0de4569edacd0ec8319ab4be21/pytorch_lightning/trainer/connectors/callback_connector.py#L258-L262
  • With loss explicitly tracked as a metric from the user side, the user can precisely specify the batch size to compute the correct weighted average. There is no need for Lightning to try to guess the batch size from the black-box batch object, which could silently fail.
  • A new name like put makes clear that it's separate from logging like Python logging and Lightning's own Loggers. This is generally a means through which the user passes data to the Trainer for usage in other places like callbacks, progress bar, or loggers.

Alternatives

Additional context



cc @Borda @tchaton @justusschock @awaelchli @carmocca @edward-io @ananthsub @rohitgr7 @kamil-kaczmarek @Raalsky @Blaizzy

ananthsub added the feature, logging, and design labels on Feb 3, 2022

carmocca (Member) commented Feb 4, 2022

The design of "logging" has been iteratively developed since before 1.0. log is the most minimalistic user-facing API that supports all of "logging" and feeds data into callbacks, the progress bar, and loggers.

A small recap of the current design:

  • The LightningModule and Callbacks generate results through log. These results are basically torchmetrics.Metrics with added metadata saved in a dictionary (the ResultCollection).
  • The Trainer manages some hook calling information for error checking. The trainer loops notify the LoggerConnector of the loop state.
  • The LoggerConnector acts as a middleman between the loops and the results. It returns the data computed from the results at specific points in training.

Poor debugging experience

The debugging experience is hard by design. A minimalistic API will always be a tradeoff against flexibility and visibility. This is great for the general user base but tricky when things go wrong. It's not really a problem but a tradeoff.

LightningModule.log conflates too many things

It might seem that way from the viewpoint of a user unaware of the internal design, but different components take responsibility for the different actions:

Handles synchronization of tensors, allowing user-provided syncing functions

ResultMetric takes care of this. The class is a Metric or a Metric proxy depending on what the user logged. log merely processes the data from the user into a result.

Handles aggregation of tensors across steps and for the end of the epoch

The trainer loops drive this, using the LoggerConnector as a middleman.

Selectively samples what data to log based on log_every_n_steps and the global step

This is not really part of logging but part of loggers. I understand that logging and loggers are basically the same word, but internally they are completely different things. It might be easier to think of loggers as writers. log_every_n_steps is more like write_every_n_steps.
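
For context, log_every_n_steps is a Trainer constructor argument, e.g.:

trainer = Trainer(log_every_n_steps=50)  # logged values are written out by the loggers every 50 training steps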

Makes the trainer aware of which metrics to reset at the end of the epoch

I'm not sure what you mean by this point. Everything is reset at the moment. There's an issue suggesting customization: #11262 (comment)

Also handles partitioning keys based on the dataloader idx??

Yes, but why would that be a problem? One of the biggest features of Lightning is to be able to easily use multiple dataloaders.

and the differing implementations when logging floats/tensors vs torchmetrics Metric objects.

This is on purpose, as we try to be as efficient as possible when tensors are logged. We could simplify a lot if we always forced users to use Metrics, but then we would be adding boilerplate for simple use cases.

Provide a new API...

Honestly, there's nothing wrong with your pitch. It basically has the same underlying ideas as log. And log would look just like this if we dropped many of the features and invisible management it provides (again, simplicity vs flexibility).

However, if you were to take this proposal and try supporting all the features log does, I think you would end up with something very similar.

One thing to note is that logging is technically optional in Lightning. You could choose not to log and use the proposed alternative. Callbacks and loggers would work as long as the *_metrics dictionaries are populated. You could even avoid the logger connector via loop customization.
I'm saying this in case you want to work on a PoC to try things out. But I don't think this would ever replace the current features and user-facing design of log. At best we would end up with two options offering different degrees of management and flexibility.

mannatsingh commented Mar 3, 2022

I think @ananthsub's pitch mitigates a bunch of the issues I have run into while using metrics in lightning.

Sharing my experience as a researcher about the problems I have run into -

I tried using torchmetrics inside lightning and noticed a few peculiar behaviors. After spending a few hours trying to make sure there are no bugs, I wanted to note my observations so that other users can make note of them in case that's useful, and to ask what the recommended way to log meters is.

  • It seems that if you derive from Metric and set compute_on_step=True (the default behavior), the metric is computed twice - once as you would expect a metric to work, and then a second time by resetting the internal state, computing the metric, and then restoring the state.
    • As a user, I saw reset being called at every step which got me really worried and I had to debug what was happening. I think some docs around this are really important! Ideally, this "magic" shouldn't ever happen where the internal state changes for a while so that users don't run into the issue, but having documentation is the next best thing. I realized later on that the part about calling the metric twice is mentioned in https://torchmetrics.readthedocs.io/en/stable/pages/implement.html. As a new user though, I think these docs should be placed in the overview section - even if I don't implement metrics, I should know what they're doing.
    • Secondly, the current docs say that compute_on_step is deprecated but don't say what the default behavior will be.
  • When logging any numbers with lightning through self.log(my_dict, on_epoch=True), the sync_dist option is set to False by default. This tells me that this is a simple logging function which doesn't do any communication - but that was not what was happening.
  • It seems to me that for logging meters (or other values for that matter), users should always do the syncs on their side and not rely on the logging logic (even if it works today for accuracy, it'll introduce a bug in the future when the reduction method isn't all reduce mean).
  • If I do want full control over syncs, what is the recommended approach to log my data to the logger? How do I disable the sync at the end of an epoch?
    • For logging at the end of an epoch I could call self.log(on_step=True) but that statement will be confusing for readers.
    • The docs for rank_zero_only are really confusing - "Whether the value will be logged only on rank 0. This will prevent synchronization which would produce a deadlock as not all processes would perform this log call." - am I supposed to ensure this statement is only called on rank 0 on my end? What happens when on_epoch=True here?
  • How do I log this information to the python logger? I don't see support for this. I could recompute metrics, but given lightning owns the synced data inside self.log(), I don't have access to this information. I want to make sure a consistent value is logged everywhere. I could implement a custom logger, but I'm surprised it's not already there - the progress bar isn't a great experience.
  • Lastly, I realized that compute() also forces a synchronization. I actually don't know if there is any way to get the current state of the meter without a synchronization - this seems like an important feature to have. Also, are these behaviors for torchmetrics different inside / outside lightning, i.e. does lightning change the behavior of compute to add the sync call? That should not be the case!
  • I think there will be some solution that exists to support my exact requirements, but the high level issue is that there should generally only be one way to do things - and that's not the case in this setup and that gets me worried.

The debugging experience is hard by design. A minimalistic API will always be a tradeoff against flexibility and visibility. This is great for the general user base but tricky when things go wrong. It's not really a problem but a tradeoff.

I think this is a fine call to make - but if that is made, then we should be clear in saying lightning chooses a "minimalistic API" over debuggability and user control - there is no free lunch. For certain things, I see no way in the docs to take back control from lightning (like how do I compute a metric without the syncing?).

mannatsingh commented Mar 3, 2022

It seems the compute() syncing stuff is mentioned in the docs - https://torchmetrics.readthedocs.io/en/latest/pages/implement.html. I'm not sure if I can just use _compute() to avoid syncs, and then later call compute() when I do need a sync?
The code in https://github.com/facebookresearch/recipes/blob/main/torchrecipes/vision/image_classification/metrics/multilabel_accuracy.py confused me since it overrides update() and compute() directly.

carmocca self-assigned this Apr 26, 2022

carmocca (Member) commented

The default on-epoch compute sync has been removed in #13364. I believe this fixes most of your issues, @mannatsingh.

I would like to re-iterate that one could circumvent the logging internals entirely by using torchmetrics and calling update, compute, reset at the desired hooks (or torcheval's equivalents: https://github.com/pytorch-labs/torcheval#using-torcheval). The results can then be directly sent to the loggers as desired and/or made visible to our callbacks by updating trainer.callback_metrics.
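
A minimal sketch of this self-managed approach, assuming a plain LightningModule and no self.log calls (the metric name and loss helper are made up):

from pytorch_lightning import LightningModule
from torchmetrics import MeanMetric

class MyModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.loss_avg = MeanMetric()

    def training_step(self, batch, batch_idx):
        loss = self.compute_loss(batch)  # hypothetical loss computation
        self.loss_avg.update(loss)
        return loss

    def on_train_epoch_end(self):
        epoch_loss = self.loss_avg.compute()  # syncs across processes
        if self.logger is not None:
            # send directly to the logger and expose to callbacks
            self.logger.log_metrics({"train_loss_epoch": epoch_loss.item()}, step=self.global_step)
        self.trainer.callback_metrics["train_loss_epoch"] = epoch_loss
        self.loss_avg.reset()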

Doing so is very different from "magically" calling self.log inside a hook. But that is to be expected.

If you (or any future readers) have issues with getting a self-managed logging solution like this to work, I will be happy to help with bugs or refactors to make the internals more flexible. Thank you!
