diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 31d9b9a8..1df710e5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,11 +1,11 @@ default_language_version: - python: python3.8 + python: python3 default_stages: [commit, push] repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.3.0 + rev: v4.4.0 hooks: - id: trailing-whitespace - id: check-yaml diff --git a/inseq/utils/torch_utils.py b/inseq/utils/torch_utils.py index dea56541..926c5b0c 100644 --- a/inseq/utils/torch_utils.py +++ b/inseq/utils/torch_utils.py @@ -224,6 +224,8 @@ def get_default_device() -> str: if is_cuda_available() and is_cuda_built(): return "cuda" elif is_mps_available() and is_mps_built(): - return "mps" + # temporarily fix mps-enabled devices on cpu until mps is able to support all operations this package needs + # change this value on your own risk as it might break things depending on the attribution functions used + return "cpu" else: return "cpu" diff --git a/poetry.lock b/poetry.lock index 0d4e72b8..f391aa63 100644 --- a/poetry.lock +++ b/poetry.lock @@ -95,10 +95,10 @@ python-versions = ">=3.5" dev = ["cloudpickle", "coverage[toml] (>=5.0.2)", "furo", "hypothesis", "mypy (>=0.900,!=0.940)", "pre-commit", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "sphinx", "sphinx-notfound-page", "zope.interface"] docs = ["furo", "sphinx", "sphinx-notfound-page", "zope.interface"] tests = ["cloudpickle", "coverage[toml] (>=5.0.2)", "hypothesis", "mypy (>=0.900,!=0.940)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "zope.interface"] -tests-no-zope = ["cloudpickle", "coverage[toml] (>=5.0.2)", "hypothesis", "mypy (>=0.900,!=0.940)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins"] +tests_no_zope = ["cloudpickle", "coverage[toml] (>=5.0.2)", "hypothesis", "mypy (>=0.900,!=0.940)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins"] [[package]] -name = "babel" +name = "Babel" version = "2.11.0" description = "Internationalization utilities" category = "dev" @@ -212,7 +212,7 @@ optional = false python-versions = ">=3.6.0" [package.extras] -unicode-backport = ["unicodedata2"] +unicode_backport = ["unicodedata2"] [[package]] name = "click" @@ -339,7 +339,7 @@ docs = ["s3fs"] quality = ["black (>=22.0,<23.0)", "flake8 (>=3.8.3)", "isort (>=5.0.0)", "pyyaml (>=5.3.1)"] s3 = ["boto3", "botocore", "fsspec", "s3fs"] tensorflow = ["tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow-macos"] -tensorflow-gpu = ["tensorflow-gpu (>=2.2.0,!=2.6.0,!=2.6.1)"] +tensorflow_gpu = ["tensorflow-gpu (>=2.2.0,!=2.6.0,!=2.6.1)"] tests = ["Pillow (>=6.2.1)", "Werkzeug (>=1.0.1)", "absl-py", "aiobotocore (>=2.0.1)", "apache-beam (>=2.26.0)", "bert-score (>=0.3.6)", "boto3 (>=1.19.8)", "botocore (>=1.22.8)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "fsspec[s3]", "jiwer", "langdetect", "librosa", "lz4", "mauve-text", "moto[s3,server] (==2.0.4)", "nltk", "py7zr", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "requests-file (>=1.5.1)", "rouge-score", "s3fs (>=2021.11.1)", "sacrebleu", "sacremoses", "scikit-learn", "scipy", "sentencepiece", "seqeval", "six (>=1.15.0,<1.16.0)", "soundfile", "spacy (>=3.0.0)", "sqlalchemy", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "tensorflow-macos", "texttable (>=1.6.3)", "tldextract", "tldextract (>=3.1.0)", "toml (>=0.10.1)", "torch", "torchaudio (<0.12.0)", "transformers", "typer (<0.5.0)", "zstandard"] torch = ["torch"] vision = ["Pillow (>=6.2.1)"] @@ -535,7 +535,7 @@ python-versions = ">=3.7" smmap = ">=3.0.1,<6" [[package]] -name = "gitpython" +name = "GitPython" version = "3.1.29" description = "GitPython is a python library used to interact with Git repositories" category = "dev" @@ -574,7 +574,7 @@ typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "t [[package]] name = "identify" -version = "2.5.9" +version = "2.5.10" description = "File identification library for Python" category = "dev" optional = false @@ -685,7 +685,7 @@ notebook = ["ipywidgets", "notebook"] parallel = ["ipyparallel"] qtconsole = ["qtconsole"] test = ["pytest (<7.1)", "pytest-asyncio", "testpath"] -test-extra = ["curio", "matplotlib (!=3.2.0)", "nbformat", "numpy (>=1.20)", "pandas", "pytest (<7.1)", "pytest-asyncio", "testpath", "trio"] +test_extra = ["curio", "matplotlib (!=3.2.0)", "nbformat", "numpy (>=1.20)", "pandas", "pytest (<7.1)", "pytest-asyncio", "testpath", "trio"] [[package]] name = "ipywidgets" @@ -707,7 +707,7 @@ test = ["jsonschema", "pytest (>=3.6.0)", "pytest-cov", "pytz"] [[package]] name = "isort" -version = "5.11.1" +version = "5.11.2" description = "A Python utility / library to sort Python imports." category = "dev" optional = false @@ -739,7 +739,7 @@ qa = ["flake8 (==3.8.3)", "mypy (==0.782)"] testing = ["Django (<3.1)", "attrs", "colorama", "docopt", "pytest (<7.0.0)"] [[package]] -name = "jinja2" +name = "Jinja2" version = "3.1.2" description = "A very fast and expressive template engine." category = "dev" @@ -831,7 +831,7 @@ optional = false python-versions = ">=3.7" [[package]] -name = "markdown" +name = "Markdown" version = "3.4.1" description = "Python implementation of Markdown." category = "dev" @@ -845,7 +845,7 @@ importlib-metadata = {version = ">=4.4", markers = "python_version < \"3.10\""} testing = ["coverage", "pyyaml"] [[package]] -name = "markupsafe" +name = "MarkupSafe" version = "2.1.1" description = "Safely add untrusted strings to HTML/XML markup." category = "dev" @@ -1077,7 +1077,7 @@ optional = true python-versions = "*" [[package]] -name = "pillow" +name = "Pillow" version = "9.3.0" description = "Python Imaging Library (Fork)" category = "main" @@ -1125,7 +1125,7 @@ pastel = ">=0.2.1,<0.3.0" tomli = ">=1.2.2" [package.extras] -poetry-plugin = ["poetry (>=1.0,<2.0)"] +poetry_plugin = ["poetry (>=1.0,<2.0)"] [[package]] name = "pre-commit" @@ -1236,7 +1236,7 @@ optional = false python-versions = ">=3.6" [[package]] -name = "pygments" +name = "Pygments" version = "2.13.0" description = "Pygments is a syntax highlighting package written in Python." category = "main" @@ -1354,7 +1354,7 @@ optional = true python-versions = "*" [[package]] -name = "pyyaml" +name = "PyYAML" version = "6.0" description = "YAML parser and emitter for Python" category = "main" @@ -1410,7 +1410,7 @@ urllib3 = ">=1.21.1,<1.27" [package.extras] socks = ["PySocks (>=1.5.6,!=1.5.7)"] -use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] +use_chardet_on_py3 = ["chardet (>=3.0.2,<6)"] [[package]] name = "responses" @@ -1444,7 +1444,7 @@ pygments = ">=2.6.0,<3.0.0" jupyter = ["ipywidgets (>=7.5.1,<8.0.0)"] [[package]] -name = "ruamel-yaml" +name = "ruamel.yaml" version = "0.17.21" description = "ruamel.yaml is a YAML parser/emitter that supports roundtrip preservation of comments, seq/map flow style, and map key order" category = "dev" @@ -1459,7 +1459,7 @@ docs = ["ryd"] jinja2 = ["ruamel.yaml.jinja2 (>=0.2)"] [[package]] -name = "ruamel-yaml-clib" +name = "ruamel.yaml.clib" version = "0.2.7" description = "C version of reader, parser and emitter for ruamel.yaml derived from libyaml" category = "dev" @@ -1586,7 +1586,7 @@ optional = false python-versions = "*" [[package]] -name = "sphinx" +name = "Sphinx" version = "5.3.0" description = "Python documentation generator" category = "dev" @@ -1629,7 +1629,7 @@ python-versions = ">=3.6" sphinx = ">=1.8" [package.extras] -code-style = ["pre-commit (==2.12.1)"] +code_style = ["pre-commit (==2.12.1)"] rtd = ["ipython", "sphinx", "sphinx-book-theme"] [[package]] @@ -1846,17 +1846,17 @@ python-versions = ">=3.6" [[package]] name = "torch" -version = "1.13.0" +version = "1.13.1" description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" category = "main" optional = false python-versions = ">=3.7.0" [package.dependencies] -nvidia-cublas-cu11 = "11.10.3.66" -nvidia-cuda-nvrtc-cu11 = "11.7.99" -nvidia-cuda-runtime-cu11 = "11.7.99" -nvidia-cudnn-cu11 = "8.5.0.96" +nvidia-cublas-cu11 = {version = "11.10.3.66", markers = "platform_system == \"Linux\""} +nvidia-cuda-nvrtc-cu11 = {version = "11.7.99", markers = "platform_system == \"Linux\""} +nvidia-cuda-runtime-cu11 = {version = "11.7.99", markers = "platform_system == \"Linux\""} +nvidia-cudnn-cu11 = {version = "8.5.0.96", markers = "platform_system == \"Linux\""} typing-extensions = "*" [package.extras] @@ -1946,7 +1946,7 @@ dev = ["GitPython (<3.1.19)", "Pillow", "accelerate (>=0.10.0)", "beautifulsoup4 dev-tensorflow = ["GitPython (<3.1.19)", "Pillow", "beautifulsoup4", "black (==22.3)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flake8 (>=3.8.3)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf (<=3.20.2)", "psutil", "pyctcdecode (>=0.4.0)", "pytest", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "safetensors (>=0.2.1)", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorflow (>=2.4,<2.11)", "tensorflow-text", "tf2onnx", "timeout-decorator", "tokenizers (>=0.11.1,!=0.11.3,<0.14)"] dev-torch = ["GitPython (<3.1.19)", "Pillow", "beautifulsoup4", "black (==22.3)", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flake8 (>=3.8.3)", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "librosa", "nltk", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf (<=3.20.2)", "psutil", "pyctcdecode (>=0.4.0)", "pyknp (>=0.6.1)", "pytest", "pytest-timeout", "pytest-xdist", "ray[tune]", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "safetensors (>=0.2.1)", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "timeout-decorator", "timm", "tokenizers (>=0.11.1,!=0.11.3,<0.14)", "torch (>=1.7,!=1.12.0)", "torchaudio", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)"] docs = ["Pillow", "accelerate (>=0.10.0)", "codecarbon (==1.2.0)", "flax (>=0.4.1)", "hf-doc-builder", "jax (>=0.2.8,!=0.3.2,<=0.3.6)", "jaxlib (>=0.1.65,<=0.3.6)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8)", "optuna", "phonemizer", "protobuf (<=3.20.2)", "pyctcdecode (>=0.4.0)", "ray[tune]", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.4,<2.11)", "tensorflow-text", "tf2onnx", "timm", "tokenizers (>=0.11.1,!=0.11.3,<0.14)", "torch (>=1.7,!=1.12.0)", "torchaudio"] -docs-specific = ["hf-doc-builder"] +docs_specific = ["hf-doc-builder"] fairscale = ["fairscale (>0.3)"] flax = ["flax (>=0.4.1)", "jax (>=0.2.8,!=0.3.2,<=0.3.6)", "jaxlib (>=0.1.65,<=0.3.6)", "optax (>=0.0.8)"] flax-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] @@ -2103,7 +2103,7 @@ sklearn = ["scikit-learn", "joblib"] [metadata] lock-version = "1.1" python-versions = ">=3.8.1,<3.12" -content-hash = "59d32488c3d75b567dcb1784290dc581a921dc007aa3f82779b247c3e89f8661" +content-hash = "e1f95d71145b778f3f05d42294911fa3bb69450733f03522cf5c70d39ed76985" [metadata.files] aiohttp = [ @@ -2223,7 +2223,7 @@ attrs = [ {file = "attrs-22.1.0-py2.py3-none-any.whl", hash = "sha256:86efa402f67bf2df34f51a335487cf46b1ec130d02b8d39fd248abfd30da551c"}, {file = "attrs-22.1.0.tar.gz", hash = "sha256:29adc2665447e5191d0e7c568fde78b21f9672d344281d0c6e1ab085429b22b6"}, ] -babel = [ +Babel = [ {file = "Babel-2.11.0-py3-none-any.whl", hash = "sha256:1ad3eca1c885218f6dce2ab67291178944f810a10a9b5f3cb8382a5a232b64fe"}, {file = "Babel-2.11.0.tar.gz", hash = "sha256:5ef4b3226b0180dedded4229651c8b0e1a3a6a2837d45a073272f313e4cf97f6"}, ] @@ -2630,7 +2630,7 @@ gitdb = [ {file = "gitdb-4.0.10-py3-none-any.whl", hash = "sha256:c286cf298426064079ed96a9e4a9d39e7f3e9bf15ba60701e95f5492f28415c7"}, {file = "gitdb-4.0.10.tar.gz", hash = "sha256:6eb990b69df4e15bad899ea868dc46572c3f75339735663b81de79b06f17eb9a"}, ] -gitpython = [ +GitPython = [ {file = "GitPython-3.1.29-py3-none-any.whl", hash = "sha256:41eea0deec2deea139b459ac03656f0dd28fc4a3387240ec1d3c259a2c47850f"}, {file = "GitPython-3.1.29.tar.gz", hash = "sha256:cc36bfc4a3f913e66805a28e84703e419d9c264c1077e537b54f0e1af85dbefd"}, ] @@ -2639,8 +2639,8 @@ huggingface-hub = [ {file = "huggingface_hub-0.11.1.tar.gz", hash = "sha256:8b9ebf9bbb1782f6f0419ec490973a6487c6c4ed84293a8a325d34c4f898f53f"}, ] identify = [ - {file = "identify-2.5.9-py2.py3-none-any.whl", hash = "sha256:a390fb696e164dbddb047a0db26e57972ae52fbd037ae68797e5ae2f4492485d"}, - {file = "identify-2.5.9.tar.gz", hash = "sha256:906036344ca769539610436e40a684e170c3648b552194980bb7b617a8daeb9f"}, + {file = "identify-2.5.10-py2.py3-none-any.whl", hash = "sha256:fb7c2feaeca6976a3ffa31ec3236a6911fbc51aec9acc111de2aed99f244ade2"}, + {file = "identify-2.5.10.tar.gz", hash = "sha256:dce9e31fee7dbc45fea36a9e855c316b8fbf807e65a862f160840bb5a2bf5dfd"}, ] idna = [ {file = "idna-3.4-py3-none-any.whl", hash = "sha256:90b77e79eaa3eba6de819a0c442c0b4ceefc341a7a2ab77d7562bf49f425c5c2"}, @@ -2671,14 +2671,14 @@ ipywidgets = [ {file = "ipywidgets-8.0.3.tar.gz", hash = "sha256:2ec50df8538a1d4ddd5d454830d010922ad1015e81ac23efb27c0908bbc1eece"}, ] isort = [ - {file = "isort-5.11.1-py3-none-any.whl", hash = "sha256:bf02c95f1fe615ebbe13a619cfed1619ddfe8941274c9e3de3143adca406cb02"}, - {file = "isort-5.11.1.tar.gz", hash = "sha256:7c5bd998504826b6f1e6f2f98b533976b066baba29b8bae83fdeefd0b89c6b70"}, + {file = "isort-5.11.2-py3-none-any.whl", hash = "sha256:e486966fba83f25b8045f8dd7455b0a0d1e4de481e1d7ce4669902d9fb85e622"}, + {file = "isort-5.11.2.tar.gz", hash = "sha256:dd8bbc5c0990f2a095d754e50360915f73b4c26fc82733eb5bfc6b48396af4d2"}, ] jedi = [ {file = "jedi-0.18.2-py2.py3-none-any.whl", hash = "sha256:203c1fd9d969ab8f2119ec0a3342e0b49910045abe6af0a3ae83a5764d54639e"}, {file = "jedi-0.18.2.tar.gz", hash = "sha256:bae794c30d07f6d910d32a7048af09b5a39ed740918da923c6b780790ebac612"}, ] -jinja2 = [ +Jinja2 = [ {file = "Jinja2-3.1.2-py3-none-any.whl", hash = "sha256:6088930bfe239f0e6710546ab9c19c9ef35e29792895fed6e6e31a023a182a61"}, {file = "Jinja2-3.1.2.tar.gz", hash = "sha256:31351a702a408a9e7595a8fc6150fc3f43bb6bf7e319770cbc0db9df9437e852"}, ] @@ -2793,11 +2793,11 @@ lazy-object-proxy = [ {file = "lazy_object_proxy-1.8.0-pp38-pypy38_pp73-any.whl", hash = "sha256:7e1561626c49cb394268edd00501b289053a652ed762c58e1081224c8d881cec"}, {file = "lazy_object_proxy-1.8.0-pp39-pypy39_pp73-any.whl", hash = "sha256:ce58b2b3734c73e68f0e30e4e725264d4d6be95818ec0a0be4bb6bf9a7e79aa8"}, ] -markdown = [ +Markdown = [ {file = "Markdown-3.4.1-py3-none-any.whl", hash = "sha256:08fb8465cffd03d10b9dd34a5c3fea908e20391a2a90b88d66362cb05beed186"}, {file = "Markdown-3.4.1.tar.gz", hash = "sha256:3b809086bb6efad416156e00a0da66fe47618a5d6918dd688f53f40c8e4cfeff"}, ] -markupsafe = [ +MarkupSafe = [ {file = "MarkupSafe-2.1.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:86b1f75c4e7c2ac2ccdaec2b9022845dbb81880ca318bb7a0a01fbf7813e3812"}, {file = "MarkupSafe-2.1.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:f121a1420d4e173a5d96e47e9a0c0dcff965afdf1626d28de1460815f7c4ee7a"}, {file = "MarkupSafe-2.1.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a49907dd8420c5685cfa064a1335b6754b74541bbb3706c259c02ed65b644b3e"}, @@ -3098,7 +3098,9 @@ pickleshare = [ {file = "pickleshare-0.7.5-py2.py3-none-any.whl", hash = "sha256:9649af414d74d4df115d5d718f82acb59c9d418196b7b4290ed47a12ce62df56"}, {file = "pickleshare-0.7.5.tar.gz", hash = "sha256:87683d47965c1da65cdacaf31c8441d12b8044cdec9aca500cd78fc2c683afca"}, ] -pillow = [ +Pillow = [ + {file = "Pillow-9.3.0-1-cp37-cp37m-win32.whl", hash = "sha256:e6ea6b856a74d560d9326c0f5895ef8050126acfdc7ca08ad703eb0081e82b74"}, + {file = "Pillow-9.3.0-1-cp37-cp37m-win_amd64.whl", hash = "sha256:32a44128c4bdca7f31de5be641187367fe2a450ad83b833ef78910397db491aa"}, {file = "Pillow-9.3.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:0b7257127d646ff8676ec8a15520013a698d1fdc48bc2a79ba4e53df792526f2"}, {file = "Pillow-9.3.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:b90f7616ea170e92820775ed47e136208e04c967271c9ef615b6fbd08d9af0e3"}, {file = "Pillow-9.3.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:68943d632f1f9e3dce98908e873b3a090f6cba1cbb1b892a9e8d97c938871fbe"}, @@ -3270,7 +3272,7 @@ pyflakes = [ {file = "pyflakes-3.0.1-py2.py3-none-any.whl", hash = "sha256:ec55bf7fe21fff7f1ad2f7da62363d749e2a470500eab1b555334b67aa1ef8cf"}, {file = "pyflakes-3.0.1.tar.gz", hash = "sha256:ec8b276a6b60bd80defed25add7e439881c19e64850afd9b346283d4165fd0fd"}, ] -pygments = [ +Pygments = [ {file = "Pygments-2.13.0-py3-none-any.whl", hash = "sha256:f643f331ab57ba3c9d89212ee4a2dabc6e94f117cf4eefde99a0574720d14c42"}, {file = "Pygments-2.13.0.tar.gz", hash = "sha256:56a8508ae95f98e2b9bdf93a6be5ae3f7d8af858b43e02c5a2ff083726be40c1"}, ] @@ -3318,7 +3320,7 @@ pywin32 = [ {file = "pywin32-305-cp39-cp39-win32.whl", hash = "sha256:9d968c677ac4d5cbdaa62fd3014ab241718e619d8e36ef8e11fb930515a1e918"}, {file = "pywin32-305-cp39-cp39-win_amd64.whl", hash = "sha256:50768c6b7c3f0b38b7fb14dd4104da93ebced5f1a50dc0e834594bff6fbe1271"}, ] -pyyaml = [ +PyYAML = [ {file = "PyYAML-6.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d4db7c7aef085872ef65a8fd7d6d09a14ae91f691dec3e87ee5ee0539d516f53"}, {file = "PyYAML-6.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9df7ed3b3d2e0ecfe09e14741b857df43adb5a3ddadc919a2d94fbdf78fea53c"}, {file = "PyYAML-6.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:77f396e6ef4c73fdc33a9157446466f1cff553d979bd00ecb64385760c6babdc"}, @@ -3542,11 +3544,11 @@ rich = [ {file = "rich-10.16.2-py3-none-any.whl", hash = "sha256:c59d73bd804c90f747c8d7b1d023b88f2a9ac2454224a4aeaf959b21eeb42d03"}, {file = "rich-10.16.2.tar.gz", hash = "sha256:720974689960e06c2efdb54327f8bf0cdbdf4eae4ad73b6c94213cad405c371b"}, ] -ruamel-yaml = [ +"ruamel.yaml" = [ {file = "ruamel.yaml-0.17.21-py3-none-any.whl", hash = "sha256:742b35d3d665023981bd6d16b3d24248ce5df75fdb4e2924e93a05c1f8b61ca7"}, {file = "ruamel.yaml-0.17.21.tar.gz", hash = "sha256:8b7ce697a2f212752a35c1ac414471dc16c424c9573be4926b56ff3f5d23b7af"}, ] -ruamel-yaml-clib = [ +"ruamel.yaml.clib" = [ {file = "ruamel.yaml.clib-0.2.7-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:d5859983f26d8cd7bb5c287ef452e8aacc86501487634573d260968f753e1d71"}, {file = "ruamel.yaml.clib-0.2.7-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:debc87a9516b237d0466a711b18b6ebeb17ba9f391eb7f91c649c5c4ec5006c7"}, {file = "ruamel.yaml.clib-0.2.7-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:df5828871e6648db72d1c19b4bd24819b80a755c4541d3409f0f7acd0f335c80"}, @@ -3555,6 +3557,7 @@ ruamel-yaml-clib = [ {file = "ruamel.yaml.clib-0.2.7-cp310-cp310-win_amd64.whl", hash = "sha256:d000f258cf42fec2b1bbf2863c61d7b8918d31ffee905da62dede869254d3b8a"}, {file = "ruamel.yaml.clib-0.2.7-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:045e0626baf1c52e5527bd5db361bc83180faaba2ff586e763d3d5982a876a9e"}, {file = "ruamel.yaml.clib-0.2.7-cp311-cp311-macosx_12_6_arm64.whl", hash = "sha256:721bc4ba4525f53f6a611ec0967bdcee61b31df5a56801281027a3a6d1c2daf5"}, + {file = "ruamel.yaml.clib-0.2.7-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:41d0f1fa4c6830176eef5b276af04c89320ea616655d01327d5ce65e50575c94"}, {file = "ruamel.yaml.clib-0.2.7-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:4b3a93bb9bc662fc1f99c5c3ea8e623d8b23ad22f861eb6fce9377ac07ad6072"}, {file = "ruamel.yaml.clib-0.2.7-cp36-cp36m-macosx_12_0_arm64.whl", hash = "sha256:a234a20ae07e8469da311e182e70ef6b199d0fbeb6c6cc2901204dd87fb867e8"}, {file = "ruamel.yaml.clib-0.2.7-cp36-cp36m-manylinux2014_aarch64.whl", hash = "sha256:15910ef4f3e537eea7fe45f8a5d19997479940d9196f357152a09031c5be59f3"}, @@ -3690,7 +3693,7 @@ snowballstemmer = [ {file = "snowballstemmer-2.2.0-py2.py3-none-any.whl", hash = "sha256:c8e1716e83cc398ae16824e5572ae04e0d9fc2c6b985fb0f900f5f0c96ecba1a"}, {file = "snowballstemmer-2.2.0.tar.gz", hash = "sha256:09b16deb8547d3412ad7b590689584cd0fe25ec8db3be37788be3810cbf19cb1"}, ] -sphinx = [ +Sphinx = [ {file = "Sphinx-5.3.0.tar.gz", hash = "sha256:51026de0a9ff9fc13c05d74913ad66047e104f56a129ff73e174eb5c3ee794b5"}, {file = "sphinx-5.3.0-py3-none-any.whl", hash = "sha256:060ca5c9f7ba57a08a1219e547b269fadf125ae25b06b9fa7f66768efb652d6d"}, ] @@ -3807,27 +3810,27 @@ tomlkit = [ {file = "tomlkit-0.11.6.tar.gz", hash = "sha256:71b952e5721688937fb02cf9d354dbcf0785066149d2855e44531ebdd2b65d73"}, ] torch = [ - {file = "torch-1.13.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:f68edfea71ade3862039ba66bcedf954190a2db03b0c41a9b79afd72210abd97"}, - {file = "torch-1.13.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:d2d2753519415d154de4d3e64d2eaaeefdba6b6fd7d69d5ffaef595988117700"}, - {file = "torch-1.13.0-cp310-cp310-win_amd64.whl", hash = "sha256:6c227c16626e4ce766cca5351cc62a2358a11e8e466410a298487b9dff159eb1"}, - {file = "torch-1.13.0-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:49a949b8136b32b2ec0724cbf4c6678b54e974b7d68f19f1231eea21cde5c23b"}, - {file = "torch-1.13.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:0fdd38c96230947b1ed870fed4a560252f8d23c3a2bf4dab9d2d42b18f2e67c8"}, - {file = "torch-1.13.0-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:43db0723fc66ad6486f86dc4890c497937f7cd27429f28f73fb7e4d74b7482e2"}, - {file = "torch-1.13.0-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:e643ac8d086706e82f77b5d4dfcf145a9dd37b69e03e64177fc23821754d2ed7"}, - {file = "torch-1.13.0-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:bb33a911460475d1594a8c8cb73f58c08293211760796d99cae8c2509b86d7f1"}, - {file = "torch-1.13.0-cp37-cp37m-win_amd64.whl", hash = "sha256:220325d0f4e69ee9edf00c04208244ef7cf22ebce083815ce272c7491f0603f5"}, - {file = "torch-1.13.0-cp37-none-macosx_10_9_x86_64.whl", hash = "sha256:cd1e67db6575e1b173a626077a54e4911133178557aac50683db03a34e2b636a"}, - {file = "torch-1.13.0-cp37-none-macosx_11_0_arm64.whl", hash = "sha256:9197ec216833b836b67e4d68e513d31fb38d9789d7cd998a08fba5b499c38454"}, - {file = "torch-1.13.0-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:fa768432ce4b8ffa29184c79a3376ab3de4a57b302cdf3c026a6be4c5a8ab75b"}, - {file = "torch-1.13.0-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:635dbb99d981a6483ca533b3dc7be18ef08dd9e1e96fb0bb0e6a99d79e85a130"}, - {file = "torch-1.13.0-cp38-cp38-win_amd64.whl", hash = "sha256:857c7d5b1624c5fd979f66d2b074765733dba3f5e1cc97b7d6909155a2aae3ce"}, - {file = "torch-1.13.0-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:ef934a21da6f6a516d0a9c712a80d09c56128abdc6af8dc151bee5199b4c3b4e"}, - {file = "torch-1.13.0-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:f01a9ae0d4b69d2fc4145e8beab45b7877342dddbd4838a7d3c11ca7f6680745"}, - {file = "torch-1.13.0-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:9ac382cedaf2f70afea41380ad8e7c06acef6b5b7e2aef3971cdad666ca6e185"}, - {file = "torch-1.13.0-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:e20df14d874b024851c58e8bb3846249cb120e677f7463f60c986e3661f88680"}, - {file = "torch-1.13.0-cp39-cp39-win_amd64.whl", hash = "sha256:4a378f5091307381abfb30eb821174e12986f39b1cf7c4522bf99155256819eb"}, - {file = "torch-1.13.0-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:922a4910613b310fbeb87707f00cb76fec328eb60cc1349ed2173e7c9b6edcd8"}, - {file = "torch-1.13.0-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:47fe6228386bff6d74319a2ffe9d4ed943e6e85473d78e80502518c607d644d2"}, + {file = "torch-1.13.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:fd12043868a34a8da7d490bf6db66991108b00ffbeecb034228bfcbbd4197143"}, + {file = "torch-1.13.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:d9fe785d375f2e26a5d5eba5de91f89e6a3be5d11efb497e76705fdf93fa3c2e"}, + {file = "torch-1.13.1-cp310-cp310-win_amd64.whl", hash = "sha256:98124598cdff4c287dbf50f53fb455f0c1e3a88022b39648102957f3445e9b76"}, + {file = "torch-1.13.1-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:393a6273c832e047581063fb74335ff50b4c566217019cc6ace318cd79eb0566"}, + {file = "torch-1.13.1-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:0122806b111b949d21fa1a5f9764d1fd2fcc4a47cb7f8ff914204fd4fc752ed5"}, + {file = "torch-1.13.1-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:22128502fd8f5b25ac1cd849ecb64a418382ae81dd4ce2b5cebaa09ab15b0d9b"}, + {file = "torch-1.13.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:76024be052b659ac1304ab8475ab03ea0a12124c3e7626282c9c86798ac7bc11"}, + {file = "torch-1.13.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:ea8dda84d796094eb8709df0fcd6b56dc20b58fdd6bc4e8d7109930dafc8e419"}, + {file = "torch-1.13.1-cp37-cp37m-win_amd64.whl", hash = "sha256:2ee7b81e9c457252bddd7d3da66fb1f619a5d12c24d7074de91c4ddafb832c93"}, + {file = "torch-1.13.1-cp37-none-macosx_10_9_x86_64.whl", hash = "sha256:0d9b8061048cfb78e675b9d2ea8503bfe30db43d583599ae8626b1263a0c1380"}, + {file = "torch-1.13.1-cp37-none-macosx_11_0_arm64.whl", hash = "sha256:f402ca80b66e9fbd661ed4287d7553f7f3899d9ab54bf5c67faada1555abde28"}, + {file = "torch-1.13.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:727dbf00e2cf858052364c0e2a496684b9cb5aa01dc8a8bc8bbb7c54502bdcdd"}, + {file = "torch-1.13.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:df8434b0695e9ceb8cc70650afc1310d8ba949e6db2a0525ddd9c3b2b181e5fe"}, + {file = "torch-1.13.1-cp38-cp38-win_amd64.whl", hash = "sha256:5e1e722a41f52a3f26f0c4fcec227e02c6c42f7c094f32e49d4beef7d1e213ea"}, + {file = "torch-1.13.1-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:33e67eea526e0bbb9151263e65417a9ef2d8fa53cbe628e87310060c9dcfa312"}, + {file = "torch-1.13.1-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:eeeb204d30fd40af6a2d80879b46a7efbe3cf43cdbeb8838dd4f3d126cc90b2b"}, + {file = "torch-1.13.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:50ff5e76d70074f6653d191fe4f6a42fdbe0cf942fbe2a3af0b75eaa414ac038"}, + {file = "torch-1.13.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:2c3581a3fd81eb1f0f22997cddffea569fea53bafa372b2c0471db373b26aafc"}, + {file = "torch-1.13.1-cp39-cp39-win_amd64.whl", hash = "sha256:0aa46f0ac95050c604bcf9ef71da9f1172e5037fdf2ebe051962d47b123848e7"}, + {file = "torch-1.13.1-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:6930791efa8757cb6974af73d4996b6b50c592882a324b8fb0589c6a9ba2ddaf"}, + {file = "torch-1.13.1-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:e0df902a7c7dd6c795698532ee5970ce898672625635d885eade9976e5a04949"}, ] torchtyping = [ {file = "torchtyping-0.1.4-py3-none-any.whl", hash = "sha256:485fb6ef3965c39b0de15f00d6f49373e0a3a6993e9733942a63c5e207d35390"}, diff --git a/pyproject.toml b/pyproject.toml index 2538438f..acaad28d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,8 @@ classifiers = [ #! Update me "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.11" + ] [tool.poetry.scripts] @@ -53,7 +54,7 @@ captum = "^0.5.0" numpy = "^1.22.4" torchtyping = "^0.1.4" json-tricks = "^3.15.5" -torch = "^1.13.0" +torch = "^1.13.1" scipy = "^1.8.1" matplotlib = "^3.5.2" tqdm = "^4.64.0" @@ -101,7 +102,7 @@ notebook = ["ipykernel", "ipywidgets"] [tool.poe.tasks] upgrade-pip = "python -m pip install --upgrade pip" -torch-cpu = "python -m pip install torch==1.13.0+cpu -f https://download.pytorch.org/whl/torch_stable.html" +torch-cpu = "python -m pip install torch==1.13.1+cpu -f https://download.pytorch.org/whl/torch_stable.html" torch-cuda11 = "python -m pip install torch --extra-index-url https://download.pytorch.org/whl/cu116" [tool.black] diff --git a/requirements-dev.txt b/requirements-dev.txt index eea14aab..16e434ae 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -41,7 +41,7 @@ fsspec[http]==2022.11.0 ; python_full_version >= "3.8.1" and python_version < "3 gitdb==4.0.10 ; python_full_version >= "3.8.1" and python_version < "3.12" gitpython==3.1.29 ; python_full_version >= "3.8.1" and python_version < "3.12" huggingface-hub==0.11.1 ; python_full_version >= "3.8.1" and python_version < "3.12" -identify==2.5.9 ; python_full_version >= "3.8.1" and python_version < "3.12" +identify==2.5.10 ; python_full_version >= "3.8.1" and python_version < "3.12" idna==3.4 ; python_full_version >= "3.8.1" and python_version < "3.12" imagesize==1.4.1 ; python_full_version >= "3.8.1" and python_version < "3.12" importlib-metadata==5.1.0 ; python_full_version >= "3.8.1" and python_version < "3.10" @@ -50,8 +50,8 @@ ipykernel==6.19.2 ; python_full_version >= "3.8.1" and python_version < "3.12" ipykernel[notebook]==6.19.2 ; python_full_version >= "3.8.1" and python_version < "3.12" ipython==8.7.0 ; python_full_version >= "3.8.1" and python_version < "3.12" ipywidgets[notebook]==8.0.3 ; python_full_version >= "3.8.1" and python_version < "3.12" -isort==5.11.1 ; python_full_version >= "3.8.1" and python_version < "3.12" -isort[colors]==5.11.1 ; python_full_version >= "3.8.1" and python_version < "3.12" +isort==5.11.2 ; python_full_version >= "3.8.1" and python_version < "3.12" +isort[colors]==5.11.2 ; python_full_version >= "3.8.1" and python_version < "3.12" jedi==0.18.2 ; python_full_version >= "3.8.1" and python_version < "3.12" jinja2==3.1.2 ; python_full_version >= "3.8.1" and python_version < "3.12" joblib==1.2.0 ; python_full_version >= "3.8.1" and python_version < "3.12" @@ -73,10 +73,10 @@ mypy-extensions==0.4.3 ; python_full_version >= "3.8.1" and python_version < "3. nest-asyncio==1.5.6 ; python_full_version >= "3.8.1" and python_version < "3.12" nodeenv==1.7.0 ; python_full_version >= "3.8.1" and python_version < "3.12" numpy==1.23.5 ; python_version < "3.12" and python_full_version >= "3.8.1" -nvidia-cublas-cu11==11.10.3.66 ; python_full_version >= "3.8.1" and python_version < "3.12" -nvidia-cuda-nvrtc-cu11==11.7.99 ; python_full_version >= "3.8.1" and python_version < "3.12" -nvidia-cuda-runtime-cu11==11.7.99 ; python_full_version >= "3.8.1" and python_version < "3.12" -nvidia-cudnn-cu11==8.5.0.96 ; python_full_version >= "3.8.1" and python_version < "3.12" +nvidia-cublas-cu11==11.10.3.66 ; python_full_version >= "3.8.1" and python_version < "3.12" and platform_system == "Linux" +nvidia-cuda-nvrtc-cu11==11.7.99 ; python_full_version >= "3.8.1" and python_version < "3.12" and platform_system == "Linux" +nvidia-cuda-runtime-cu11==11.7.99 ; python_full_version >= "3.8.1" and python_version < "3.12" and platform_system == "Linux" +nvidia-cudnn-cu11==8.5.0.96 ; python_full_version >= "3.8.1" and python_version < "3.12" and platform_system == "Linux" packaging==22.0 ; python_full_version >= "3.8.1" and python_version < "3.12" pandas==1.5.2 ; python_full_version >= "3.8.1" and python_version < "3.12" parso==0.8.3 ; python_full_version >= "3.8.1" and python_version < "3.12" @@ -148,7 +148,7 @@ tokenizers==0.13.2 ; python_full_version >= "3.8.1" and python_version < "3.12" toml==0.10.2 ; python_full_version >= "3.8.1" and python_version < "3.12" tomli==2.0.1 ; python_full_version >= "3.8.1" and python_version < "3.12" tomlkit==0.11.6 ; python_full_version >= "3.8.1" and python_version < "3.12" -torch==1.13.0 ; python_full_version >= "3.8.1" and python_version < "3.12" +torch==1.13.1 ; python_full_version >= "3.8.1" and python_version < "3.12" torchtyping==0.1.4 ; python_full_version >= "3.8.1" and python_version < "3.12" tornado==6.2 ; python_full_version >= "3.8.1" and python_version < "3.12" tqdm==4.64.1 ; python_full_version >= "3.8.1" and python_version < "3.12" @@ -159,7 +159,7 @@ typing-extensions==4.4.0 ; python_full_version >= "3.8.1" and python_version < " urllib3==1.26.13 ; python_full_version >= "3.8.1" and python_version < "3.12" virtualenv==20.17.1 ; python_full_version >= "3.8.1" and python_version < "3.12" wcwidth==0.2.5 ; python_full_version >= "3.8.1" and python_version < "3.12" -wheel==0.38.4 ; python_full_version >= "3.8.1" and python_version < "3.12" +wheel==0.38.4 ; python_full_version >= "3.8.1" and python_version < "3.12" and platform_system == "Linux" widgetsnbextension==4.0.4 ; python_full_version >= "3.8.1" and python_version < "3.12" wrapt==1.14.1 ; python_full_version >= "3.8.1" and python_version < "3.12" xxhash==3.1.0 ; python_full_version >= "3.8.1" and python_version < "3.12" diff --git a/requirements.txt b/requirements.txt index 9022f6d0..9238eb88 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,10 +13,10 @@ json-tricks==3.16.1 ; python_full_version >= "3.8.1" and python_version < "3.12" kiwisolver==1.4.4 ; python_full_version >= "3.8.1" and python_version < "3.12" matplotlib==3.6.2 ; python_full_version >= "3.8.1" and python_version < "3.12" numpy==1.23.5 ; python_full_version >= "3.8.1" and python_version < "3.12" -nvidia-cublas-cu11==11.10.3.66 ; python_full_version >= "3.8.1" and python_version < "3.12" -nvidia-cuda-nvrtc-cu11==11.7.99 ; python_full_version >= "3.8.1" and python_version < "3.12" -nvidia-cuda-runtime-cu11==11.7.99 ; python_full_version >= "3.8.1" and python_version < "3.12" -nvidia-cudnn-cu11==8.5.0.96 ; python_full_version >= "3.8.1" and python_version < "3.12" +nvidia-cublas-cu11==11.10.3.66 ; python_full_version >= "3.8.1" and python_version < "3.12" and platform_system == "Linux" +nvidia-cuda-nvrtc-cu11==11.7.99 ; python_full_version >= "3.8.1" and python_version < "3.12" and platform_system == "Linux" +nvidia-cuda-runtime-cu11==11.7.99 ; python_full_version >= "3.8.1" and python_version < "3.12" and platform_system == "Linux" +nvidia-cudnn-cu11==8.5.0.96 ; python_full_version >= "3.8.1" and python_version < "3.12" and platform_system == "Linux" packaging==22.0 ; python_full_version >= "3.8.1" and python_version < "3.12" pastel==0.2.1 ; python_full_version >= "3.8.1" and python_version < "3.12" pillow==9.3.0 ; python_full_version >= "3.8.1" and python_version < "3.12" @@ -36,11 +36,11 @@ setuptools==65.6.3 ; python_full_version >= "3.8.1" and python_version < "3.12" six==1.16.0 ; python_full_version >= "3.8.1" and python_version < "3.12" tokenizers==0.13.2 ; python_full_version >= "3.8.1" and python_version < "3.12" tomli==2.0.1 ; python_full_version >= "3.8.1" and python_version < "3.12" -torch==1.13.0 ; python_full_version >= "3.8.1" and python_version < "3.12" +torch==1.13.1 ; python_full_version >= "3.8.1" and python_version < "3.12" torchtyping==0.1.4 ; python_full_version >= "3.8.1" and python_version < "3.12" tqdm==4.64.1 ; python_full_version >= "3.8.1" and python_version < "3.12" transformers[sentencepiece,tokenizers,torch]==4.25.1 ; python_full_version >= "3.8.1" and python_version < "3.12" typeguard==2.13.3 ; python_full_version >= "3.8.1" and python_version < "3.12" typing-extensions==4.4.0 ; python_full_version >= "3.8.1" and python_version < "3.12" urllib3==1.26.13 ; python_full_version >= "3.8.1" and python_version < "3.12" -wheel==0.38.4 ; python_full_version >= "3.8.1" and python_version < "3.12" +wheel==0.38.4 ; python_full_version >= "3.8.1" and python_version < "3.12" and platform_system == "Linux"