Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pytorch Load / Save Plugin #1114

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
72 changes: 72 additions & 0 deletions bandit/plugins/pytorch_load_save.py
@@ -0,0 +1,72 @@
# Copyright (c) 2024 Stacklok, Inc.
#
# SPDX-License-Identifier: Apache-2.0
r"""
==========================================
B613: Test for unsafe PyTorch load or save
==========================================

This plugin checks for the use of `torch.load` and `torch.save`. Using
`torch.load` with untrusted data can lead to arbitrary code execution, and
improper use of `torch.save` might expose sensitive data or lead to data
corruption. A safe alternative is to use `torch.load` with the `safetensors`
library from hugingface, which provides a safe deserialization mechanism.

:Example:

.. code-block:: none

>> Issue: Use of unsafe PyTorch load or save
Severity: Medium Confidence: High
CWE: CWE-94 (https://cwe.mitre.org/data/definitions/94.html)
Location: examples/pytorch_load_save.py:8
7 loaded_model.load_state_dict(torch.load('model_weights.pth'))
8 another_model.load_state_dict(torch.load('model_weights.pth',
map_location='cpu'))
9
10 print("Model loaded successfully!")

.. seealso::

- https://cwe.mitre.org/data/definitions/94.html
- https://pytorch.org/docs/stable/generated/torch.load.html#torch.load
- https://github.com/huggingface/safetensors

lukehinds marked this conversation as resolved.
Show resolved Hide resolved
.. versionadded:: 1.7.9

"""
import bandit
from bandit.core import issue
from bandit.core import test_properties as test


@test.checks("Call")
@test.test_id("B613")
def pytorch_load_save(context):
"""
This plugin checks for the use of `torch.load` and `torch.save`. Using
`torch.load` with untrusted data can lead to arbitrary code execution,
and improper use of `torch.save` might expose sensitive data or lead
to data corruption.
"""
imported = context.is_module_imported_exact("torch")
qualname = context.call_function_name_qual
if not imported and isinstance(qualname, str):
return

qualname_list = qualname.split(".")
func = qualname_list[-1]
if all(
[
"torch" in qualname_list,
func in ["load", "save"],
not context.check_call_arg_value("map_location", "cpu"),
]
):
return bandit.Issue(
severity=bandit.MEDIUM,
confidence=bandit.HIGH,
lukehinds marked this conversation as resolved.
Show resolved Hide resolved
text="Use of unsafe PyTorch load or save",
lukehinds marked this conversation as resolved.
Show resolved Hide resolved
cwe=issue.Cwe.DESERIALIZATION_OF_UNTRUSTED_DATA,
lineno=context.get_lineno_for_call_arg("load"),
)
5 changes: 5 additions & 0 deletions doc/source/plugins/b704_pytorch_load_save.rst
@@ -0,0 +1,5 @@
-----------------------
B613: pytorch_load_save
-----------------------

.. automodule:: bandit.plugins.pytorch_load_save
21 changes: 21 additions & 0 deletions examples/pytorch_load_save.py
@@ -0,0 +1,21 @@
import torch
import torchvision.models as models

# Example of saving a model
model = models.resnet18(pretrained=True)
torch.save(model.state_dict(), 'model_weights.pth')

# Example of loading the model weights in an insecure way
loaded_model = models.resnet18()
loaded_model.load_state_dict(torch.load('model_weights.pth'))

# Save the model
torch.save(loaded_model.state_dict(), 'model_weights.pth')

# Another example using torch.load with more parameters
another_model = models.resnet18()
another_model.load_state_dict(torch.load('model_weights.pth', map_location='cpu'))
lukehinds marked this conversation as resolved.
Show resolved Hide resolved

# Save the model
torch.save(another_model.state_dict(), 'model_weights.pth')

3 changes: 3 additions & 0 deletions setup.cfg
Expand Up @@ -152,6 +152,9 @@ bandit.plugins =
#bandit/plugins/tarfile_unsafe_members.py
tarfile_unsafe_members = bandit.plugins.tarfile_unsafe_members:tarfile_unsafe_members

#bandit/plugins/pytorch_load_save.py
pytorch_load_save = bandit.plugins.pytorch_load_save:pytorch_load_save

[build_sphinx]
all_files = 1
build-dir = doc/build
Expand Down
8 changes: 8 additions & 0 deletions tests/functional/test_functional.py
Expand Up @@ -929,3 +929,11 @@ def test_tarfile_unsafe_members(self):
"CONFIDENCE": {"UNDEFINED": 0, "LOW": 1, "MEDIUM": 2, "HIGH": 2},
}
self.check_example("tarfile_extractall.py", expect)

def test_pytorch_load_save(self):
"""Test insecure usage of torch.load and torch.save."""
expect = {
"SEVERITY": {"UNDEFINED": 0, "LOW": 1, "MEDIUM": 3, "HIGH": 0},
"CONFIDENCE": {"UNDEFINED": 0, "LOW": 0, "MEDIUM": 1, "HIGH": 3},
}
self.check_example("pytorch_load_save.py", expect)