Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pytorch Load / Save Plugin #1114

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
24 changes: 24 additions & 0 deletions bandit/blacklists/calls.py
Expand Up @@ -320,6 +320,19 @@
| | | - os.tmpnam | |
+------+---------------------+------------------------------------+-----------+

B704: pytorch_load_save

Use of unsafe PyTorch load. `torch.load` can lead to arbitrary code execution,
and improper use of `torch.save` might expose sensitive data or lead to data
corruption.

+------+---------------------+--------------------------------------+-----------+
| ID | Name | Calls | Severity |
+======+=====================+======================================+===========+
| B704 | pytorch_load_save| | - torch.load | Medium |
| B704 | pytorch_load_save| | - torch.save | Medium |
+------+---------------------+--------------------------------------+-----------+
lukehinds marked this conversation as resolved.
Show resolved Hide resolved

"""
import sys

Expand Down Expand Up @@ -685,6 +698,17 @@ def gen_blacklist():
)
)

sets.append(
utils.build_conf_dict(
"pytorch_load_save",
"B704",
issue.Cwe.DESERIALIZATION_OF_UNTRUSTED_DATA,
lukehinds marked this conversation as resolved.
Show resolved Hide resolved
["torch.load", "torch.save"],
"Use of unsafe PyTorch load or save",
"MEDIUM",
)
)

# skipped B324 (used in bandit/plugins/hashlib_new_insecure_functions.py)

# skipped B325 as the check for a call to os.tempnam and os.tmpnam have
Expand Down
68 changes: 68 additions & 0 deletions bandit/plugins/pytorch_load_save.py
@@ -0,0 +1,68 @@
# Copyright (c) 2024 Stacklok, Inc.
#
# SPDX-License-Identifier: Apache-2.0
r"""
=========================================
B704: Test for unsafe PyTorch load or save
=========================================
lukehinds marked this conversation as resolved.
Show resolved Hide resolved

This plugin checks for the use of `torch.load` and `torch.save`. Using `torch.load`
with untrusted data can lead to arbitrary code execution, and improper use of
`torch.save` might expose sensitive data or lead to data corruption.

:Example:

.. code-block:: none

>> Issue: Use of unsafe PyTorch load or save
Severity: Medium Confidence: High
CWE: CWE-94 (https://cwe.mitre.org/data/definitions/94.html)
Location: examples/pytorch_load_save.py:8
7 loaded_model.load_state_dict(torch.load('model_weights.pth'))
8 another_model.load_state_dict(torch.load('model_weights.pth', map_location='cpu'))
9
10 print("Model loaded successfully!")

.. seealso::

- https://cwe.mitre.org/data/definitions/94.html

lukehinds marked this conversation as resolved.
Show resolved Hide resolved
.. versionadded:: 1.7.8
lukehinds marked this conversation as resolved.
Show resolved Hide resolved

"""
import bandit
from bandit.core import issue
from bandit.core import test_properties as test


@test.checks("Call")
@test.test_id(
"B704"
) # Ensure the test ID is unique and does not conflict with existing Bandit tests
lukehinds marked this conversation as resolved.
Show resolved Hide resolved
def pytorch_load_save(context):
"""
This plugin checks for the use of `torch.load` and `torch.save`. Using `torch.load`
with untrusted data can lead to arbitrary code execution, and improper use of
`torch.save` might expose sensitive data or lead to data corruption.
"""
imported = context.is_module_imported_exact("torch")
qualname = context.call_function_name_qual
if not imported and isinstance(qualname, str):
return

qualname_list = qualname.split(".")
func = qualname_list[-1]
if all(
[
"torch" in qualname_list,
func in ["load"],
lukehinds marked this conversation as resolved.
Show resolved Hide resolved
not context.check_call_arg_value("map_location", "cpu"),
]
):
return bandit.Issue(
severity=bandit.MEDIUM,
confidence=bandit.HIGH,
lukehinds marked this conversation as resolved.
Show resolved Hide resolved
text="Use of unsafe PyTorch load or save",
lukehinds marked this conversation as resolved.
Show resolved Hide resolved
cwe=issue.Cwe.UNTRUSTED_INPUT,
lineno=context.get_lineno_for_call_arg("load"),
)
5 changes: 5 additions & 0 deletions doc/source/plugins/b704_pytorch_load_save.rst
@@ -0,0 +1,5 @@
-----------------------
B704: pytorch_load_save
-----------------------

.. automodule:: bandit.plugins.pytorch_load_save
16 changes: 16 additions & 0 deletions examples/pytorch_load_save.py
@@ -0,0 +1,16 @@
import torch
import torchvision.models as models

# Example of saving a model
model = models.resnet18(pretrained=True)
torch.save(model.state_dict(), 'model_weights.pth')

# Example of loading the model weights in an insecure way
loaded_model = models.resnet18()
loaded_model.load_state_dict(torch.load('model_weights.pth'))

# Another example using torch.load with more parameters
another_model = models.resnet18()
another_model.load_state_dict(torch.load('model_weights.pth', map_location='cpu'))
lukehinds marked this conversation as resolved.
Show resolved Hide resolved

print("Model loaded successfully!")
3 changes: 3 additions & 0 deletions setup.cfg
Expand Up @@ -148,6 +148,9 @@ bandit.plugins =
#bandit/plugins/tarfile_unsafe_members.py
tarfile_unsafe_members = bandit.plugins.tarfile_unsafe_members:tarfile_unsafe_members

#bandit/plugins/pytorch_load_save.py
pytorch_load_save = bandit.plugins.pytorch_load_save:pytorch_load_save

[build_sphinx]
all_files = 1
build-dir = doc/build
Expand Down
8 changes: 8 additions & 0 deletions tests/functional/test_functional.py
Expand Up @@ -930,3 +930,11 @@ def test_tarfile_unsafe_members(self):
"CONFIDENCE": {"UNDEFINED": 0, "LOW": 1, "MEDIUM": 2, "HIGH": 2},
}
self.check_example("tarfile_extractall.py", expect)

def test_pytorch_load_save(self):
"""Test insecure usage of torch.load and torch.save."""
expect = {
"SEVERITY": {"UNDEFINED": 0, "LOW": 1, "MEDIUM": 3, "HIGH": 0},
"CONFIDENCE": {"UNDEFINED": 0, "LOW": 0, "MEDIUM": 1, "HIGH": 3},
}
self.check_example("pytorch_load_save.py", expect)