feat: adding secure loading of models by default for haystack #3901

Merged
merged 6 commits on Jan 24, 2023
2 changes: 2 additions & 0 deletions haystack/__init__.py
@@ -20,6 +20,8 @@
from haystack.schema import Document, Answer, Label, MultiLabel, Span, EvaluationResult
from haystack.nodes.base import BaseComponent
from haystack.pipelines.base import Pipeline
from haystack.environment import set_pytorch_secure_model_loading


pd.options.display.max_colwidth = 80
set_pytorch_secure_model_loading()
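
With this change, importing Haystack is enough to turn on PyTorch's secure (weights-only) model loading. A minimal sketch of the intended import-time effect, assuming the variable is not already set in the caller's environment:

import os

import haystack  # haystack/__init__.py calls set_pytorch_secure_model_loading() at import time

# with no user override, the flag defaults to "1"
assert os.environ.get("TORCH_FORCE_WEIGHTS_ONLY_LOAD") == "1"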
14 changes: 14 additions & 0 deletions haystack/environment.py
@@ -1,3 +1,4 @@
import logging
import os
import platform
import sys
@@ -17,6 +18,19 @@

env_meta_data: Dict[str, Any] = {}

logger = logging.getLogger(__name__)


def set_pytorch_secure_model_loading(flag_val="1"):
    # To load models securely, PyTorch requires the value of
    # TORCH_FORCE_WEIGHTS_ONLY_LOAD to be one of ["1", "y", "yes", "true"]
Member:
I think we should check whether this flag is set to any value different from ["1", "y", "yes", "true"] before setting it to 1. Maybe a user explicitly set it to False. In that case we shouldn't silently overwrite it. Instead, let's at least log a warning.
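
A rough sketch of the suggested check inside the function body (the truthy_values tuple and the warning text are placeholders, not taken from the PR):

    truthy_values = ("1", "y", "yes", "true")
    os_flag_val = os.getenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD")
    if os_flag_val is None:
        # nothing set by the user, safe to enable secure loading
        os.environ["TORCH_FORCE_WEIGHTS_ONLY_LOAD"] = flag_val
    elif os_flag_val.lower() not in truthy_values:
        # the user explicitly opted out; don't silently overwrite, but make it visible
        logger.warning("TORCH_FORCE_WEIGHTS_ONLY_LOAD is set to %s, secure model loading is not enforced.", os_flag_val)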

Collaborator (author):
Changed it. I think your suggestion is taken care of, but please do let me know if you disagree.

    os_flag_val = os.getenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD")
    if os_flag_val is None:
        os.environ["TORCH_FORCE_WEIGHTS_ONLY_LOAD"] = flag_val
    else:
        logger.info("TORCH_FORCE_WEIGHTS_ONLY_LOAD is already set to %s, Haystack will use the same.", os_flag_val)


def get_or_create_env_meta_data() -> Dict[str, Any]:
"""
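
A short usage sketch of the override path handled by the else branch above (the "0" value is only an example of an explicit user setting, not something the PR prescribes):

import os

# an explicit user choice, set before Haystack is imported
os.environ["TORCH_FORCE_WEIGHTS_ONLY_LOAD"] = "0"

import haystack  # logs that the variable is already set and leaves it unchanged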
14 changes: 14 additions & 0 deletions test/others/test_utils.py
@@ -1,4 +1,6 @@
import importlib
import logging
import os
from random import random
from typing import List

@@ -12,6 +14,7 @@
from ..conftest import fail_at_version, haystack_version

from haystack.errors import OpenAIRateLimitError
from haystack.environment import set_pytorch_secure_model_loading
from haystack.schema import Answer, Document, Span, Label
from haystack.utils.deepsetcloud import DeepsetCloud, DeepsetCloudExperiments
from haystack.utils.labels import aggregate_labels
@@ -1245,6 +1248,17 @@ def greet2(name: str):
assert greet2("John") == "Hello John"


def test_secure_model_loading(monkeypatch, caplog):
    caplog.set_level(logging.INFO)
    monkeypatch.setenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD", "0")

    # importing haystack should be enough to enable secure loading of PyTorch models
    import haystack

    importlib.reload(haystack)
    assert "already set to" in caplog.text


class TestAggregateLabels:
    @pytest.fixture
    def standard_labels(self) -> List[Label]: