Skip to content

Commit

Permalink
feat: adding secure loading of models by default for haystack (#3901)
Browse files Browse the repository at this point in the history
* adding secure loading of models by default

* simplified set function

* testing import effect correctly

* added appropriate log line, adapted the test

* change log string formatting

Co-authored-by: Massimiliano Pippi <mpippi@gmail.com>

* remove extra closing bracket )

Co-authored-by: Julian Risch <julian.risch@deepset.ai>
Co-authored-by: Massimiliano Pippi <mpippi@gmail.com>
  • Loading branch information
3 people committed Jan 24, 2023
1 parent 739fc22 commit 5c53b2b
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 0 deletions.
2 changes: 2 additions & 0 deletions haystack/__init__.py
Expand Up @@ -20,6 +20,8 @@
from haystack.schema import Document, Answer, Label, MultiLabel, Span, EvaluationResult
from haystack.nodes.base import BaseComponent
from haystack.pipelines.base import Pipeline
from haystack.environment import set_pytorch_secure_model_loading


pd.options.display.max_colwidth = 80
set_pytorch_secure_model_loading()
13 changes: 13 additions & 0 deletions haystack/environment.py
@@ -1,3 +1,4 @@
import logging
import os
import platform
import sys
Expand All @@ -17,6 +18,18 @@

env_meta_data: Dict[str, Any] = {}

logger = logging.getLogger(__name__)


def set_pytorch_secure_model_loading(flag_val="1"):
# To load secure only model pytorch requires value of
# TORCH_FORCE_WEIGHTS_ONLY_LOAD to be ["1", "y", "yes", "true"]
os_flag_val = os.getenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD")
if os_flag_val is None:
os.environ["TORCH_FORCE_WEIGHTS_ONLY_LOAD"] = flag_val
else:
logger.info("TORCH_FORCE_WEIGHTS_ONLY_LOAD is already set to %s, Haystack will use the same.", os_flag_val)


def get_or_create_env_meta_data() -> Dict[str, Any]:
"""
Expand Down
14 changes: 14 additions & 0 deletions test/others/test_utils.py
@@ -1,4 +1,6 @@
import importlib
import logging
import os
from random import random
from typing import List

Expand All @@ -12,6 +14,7 @@
from ..conftest import fail_at_version, haystack_version

from haystack.errors import OpenAIRateLimitError
from haystack.environment import set_pytorch_secure_model_loading
from haystack.schema import Answer, Document, Span, Label
from haystack.utils.deepsetcloud import DeepsetCloud, DeepsetCloudExperiments
from haystack.utils.labels import aggregate_labels
Expand Down Expand Up @@ -1245,6 +1248,17 @@ def greet2(name: str):
assert greet2("John") == "Hello John"


def test_secure_model_loading(monkeypatch, caplog):
caplog.set_level(logging.INFO)
monkeypatch.setenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD", "0")

# now testing if just importing haystack is enough to enable secure loading of pytorch models
import haystack

importlib.reload(haystack)
assert "already set to" in caplog.text


class TestAggregateLabels:
@pytest.fixture
def standard_labels(self) -> List[Label]:
Expand Down

0 comments on commit 5c53b2b

Please sign in to comment.