Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sanitize argument-free object params before logging #19771

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/lightning/fabric/CHANGELOG.md
Expand Up @@ -9,6 +9,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Enabled consolidating distributed checkpoints through `fabric consolidate` in the new CLI [#19560](https://github.com/Lightning-AI/pytorch-lightning/pull/19560)

- Added object sanitization before logging hyperparameters to WandB and Neptune [#19771](https://github.com/Lightning-AI/pytorch-lightning/pull/19771)

- Enabled consolidating distributed checkpoints through `fabric consolidate` in the new CLI ([#19560](https://github.com/Lightning-AI/pytorch-lightning/pull/19560))

- Added the ability to explicitly mark forward methods in Fabric via `_FabricModule.mark_forward_method()` ([#19690](https://github.com/Lightning-AI/pytorch-lightning/pull/19690))
Expand All @@ -17,7 +21,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added `ModelParallelStrategy` to support 2D parallelism ([#19846](https://github.com/Lightning-AI/pytorch-lightning/pull/19846), [#19852](https://github.com/Lightning-AI/pytorch-lightning/pull/19852), [#19870](https://github.com/Lightning-AI/pytorch-lightning/pull/19870), [#19872](https://github.com/Lightning-AI/pytorch-lightning/pull/19872))


### Changed

- Renamed `lightning run model` to `fabric run` ([#19442](https://github.com/Lightning-AI/pytorch-lightning/pull/19442), [#19527](https://github.com/Lightning-AI/pytorch-lightning/pull/19527))
Expand Down
6 changes: 6 additions & 0 deletions src/lightning/fabric/utilities/logger.py
Expand Up @@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import json
from argparse import Namespace
from dataclasses import asdict, is_dataclass
Expand Down Expand Up @@ -52,6 +54,10 @@ def _sanitize_callable_params(params: Dict[str, Any]) -> Dict[str, Any]:
"""

def _sanitize_callable(val: Any) -> Any:
# If it's a class, return the name; otherwise for a class without any initialization arguments,
# the configuration will store an instance of the class, which is not what we want.
if inspect.isclass(val):
return getattr(val, "__name__", None)
# Give them one chance to return a value. Don't go rabbit hole of recursive call
if callable(val):
try:
Expand Down
17 changes: 16 additions & 1 deletion tests/tests_fabric/utilities/test_logger.py
Expand Up @@ -92,7 +92,7 @@ class B:


def test_sanitize_callable_params():
"""Callback function are not serializiable.
"""Callback functions are not serializable.

Therefore, we get them a chance to return something and if the returned type is not accepted, return None.

Expand All @@ -104,11 +104,24 @@ def return_something():
def wrapper_something():
return return_something

class Something:
def __init__(self):
pass

class SomethingElse:
def __init__(self, arg):
self.arg = arg

def __repr__(self):
return "SomethingElseElse"

params = Namespace(
foo="bar",
something=return_something,
wrapper_something_wo_name=(lambda: lambda: "1"),
wrapper_something=wrapper_something,
something_class=Something,
something_else=SomethingElse,
)

params = _convert_params(params)
Expand All @@ -118,6 +131,8 @@ def wrapper_something():
assert params["something"] == "something"
assert params["wrapper_something"] == "wrapper_something"
assert params["wrapper_something_wo_name"] == "<lambda>"
assert params["something_class"] == "Something"
assert params["something_else"] == "SomethingElseElse"


def test_sanitize_params():
Expand Down