feat: adding secure loading of models by default for haystack #3901

Merged
merged 6 commits into from Jan 24, 2023
Changes from 3 commits
2 changes: 2 additions & 0 deletions haystack/__init__.py
@@ -20,6 +20,8 @@
from haystack.schema import Document, Answer, Label, MultiLabel, Span, EvaluationResult
from haystack.nodes.base import BaseComponent
from haystack.pipelines.base import Pipeline
from haystack.environment import set_pytorch_secure_model_loading


pd.options.display.max_colwidth = 80
set_pytorch_secure_model_loading()
6 changes: 6 additions & 0 deletions haystack/environment.py
@@ -18,6 +18,12 @@
env_meta_data: Dict[str, Any] = {}


def set_pytorch_secure_model_loading(flag_val="1"):
    # To restrict loading to model weights only, PyTorch requires the value of
    # TORCH_FORCE_WEIGHTS_ONLY_LOAD to be one of ["1", "y", "yes", "true"]

Member

I think we should check whether this flag is set to any value different from ["1", "y", "yes", "true"] before setting it to 1. Maybe a user explicitly set it to False. In that case we shouldn't silently overwrite it. Instead, let's at least log a warning.

Collaborator Author

Changed it. I think your suggestion is taken care of, but please do let me know if you disagree.

    os.environ["TORCH_FORCE_WEIGHTS_ONLY_LOAD"] = flag_val
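The reviewer's suggestion (leave an explicit opt-out alone and log a warning instead of overwriting it) could be sketched roughly as follows. This is only an illustration and not the code that was merged: the function name `set_secure_loading_respecting_user`, the `_TRUTHY` set, and the lower-casing of the existing value are assumptions made for the example.

```python
import logging
import os

logger = logging.getLogger(__name__)

ENV_VAR = "TORCH_FORCE_WEIGHTS_ONLY_LOAD"
# Values quoted in the diff's code comment as enabling weights-only loading.
_TRUTHY = {"1", "y", "yes", "true"}


def set_secure_loading_respecting_user(flag_val: str = "1") -> str:
    """Hypothetical variant: set the flag only if the user has not set it.

    Returns the value of the environment variable after the call.
    """
    current = os.environ.get(ENV_VAR)
    if current is None:
        # Nothing set by the user, so apply the secure default.
        os.environ[ENV_VAR] = flag_val
    elif current.lower() not in _TRUTHY:
        # The user opted out explicitly: keep their value, but warn.
        logger.warning(
            "%s is set to %r; secure (weights-only) model loading stays disabled.",
            ENV_VAR,
            current,
        )
    return os.environ[ENV_VAR]
```

With this variant, a pre-existing value of "0" would stay "0" after the call, whereas the function shown in this diff overwrites it with "1".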


def get_or_create_env_meta_data() -> Dict[str, Any]:
    """
    Collects meta data about the setup that is used with Haystack, such as: operating system, python version, Haystack version, transformers version, pytorch version, number of GPUs, execution environment, and the value stored in the env variable HAYSTACK_EXECUTION_CONTEXT.
16 changes: 16 additions & 0 deletions test/others/test_utils.py
@@ -1,4 +1,5 @@
import logging
import os
from random import random
from typing import List

@@ -1245,6 +1246,21 @@ def greet2(name: str):
    assert greet2("John") == "Hello John"


def test_secure_model_loading():
    # setting the flag explicitly to zero
    os.environ["TORCH_FORCE_WEIGHTS_ONLY_LOAD"] = "0"
    env_val = os.getenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD")
    assert env_val == "0"

    # now testing if just importing haystack is enough to enable secure loading of pytorch models
    import importlib
    import haystack

    importlib.reload(haystack)
    env_val = os.getenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD", "0")
    assert env_val == "1"
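The test above writes to `os.environ` and does not restore the previous value, so the change can leak into tests that run later. One possible way to contain it, shown here only as a sketch, is `unittest.mock.patch.dict`, which restores the mapping when the block exits. To keep the example runnable without Haystack installed, `set_pytorch_secure_model_loading` below is a local stand-in that mirrors the function in this diff, and `check_flag_is_forced_on` is a hypothetical helper name.

```python
import os
from unittest import mock

ENV_VAR = "TORCH_FORCE_WEIGHTS_ONLY_LOAD"


def set_pytorch_secure_model_loading(flag_val="1"):
    # Local stand-in mirroring the function added in haystack/environment.py.
    os.environ[ENV_VAR] = flag_val


def check_flag_is_forced_on() -> str:
    """Run the same check as the test, without leaking environment changes."""
    # patch.dict applies the override on entry and restores os.environ on exit.
    with mock.patch.dict(os.environ, {ENV_VAR: "0"}):
        assert os.environ[ENV_VAR] == "0"
        set_pytorch_secure_model_loading()
        return os.environ[ENV_VAR]
```

In a pytest suite, the built-in `monkeypatch` fixture (`monkeypatch.setenv`) would serve the same purpose.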


class TestAggregateLabels:
    @pytest.fixture
    def standard_labels(self) -> List[Label]: