Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Pydantic v1 incompatibility #1622

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ requires-python = ">=3.7"
dependencies = [
# intentionally loose. perhaps these should be vendored to not collide with user code?
"attrs>=20.1,<24",
"fastapi>=0.75.2,<0.99.0",
"pydantic>=1.9,<2",
"fastapi @ git+https://github.com/replicate/fastapi.git@v1-response-model",
"pydantic>=1.9",
"PyYAML",
"requests>=2,<3",
"structlog>=20,<25",
Expand Down
6 changes: 5 additions & 1 deletion python/cog/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from pydantic import BaseModel
try:
from pydantic.v1 import BaseModel # type: ignore
except ImportError:
from pydantic import BaseModel # pylint: disable=W0404


from .predictor import BasePredictor
from .types import ConcatenateIterator, File, Input, Path, Secret
Expand Down
2 changes: 1 addition & 1 deletion python/cog/command/ast_openapi_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@
}
},
"info": { "title": "Cog", "version": "0.1.0" },
"openapi": "3.0.2",
"openapi": "3.1.0",
"paths": {
"/": {
"get": {
Expand Down
6 changes: 5 additions & 1 deletion python/cog/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@
from types import GeneratorType
from typing import Any, Callable

from pydantic import BaseModel
try:
from pydantic.v1 import BaseModel # type: ignore
except ImportError:
from pydantic import BaseModel # pylint: disable=W0404


from .types import Path

Expand Down
12 changes: 10 additions & 2 deletions python/cog/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,16 @@
from typing_compat import get_args, get_origin # type: ignore

import yaml
from pydantic import BaseModel, Field, create_model
from pydantic.fields import FieldInfo

try:
from pydantic.v1 import BaseModel, Field, create_model # type: ignore
except ImportError:
from pydantic import BaseModel, Field, create_model # pylint: disable=W0404

try:
from pydantic.v1.fields import FieldInfo # type: ignore
except ImportError:
from pydantic.fields import FieldInfo # pylint: disable=W0404

# Added in Python 3.9. Can be from typing if we drop support for <3.9
from typing_extensions import Annotated
Expand Down
19 changes: 14 additions & 5 deletions python/cog/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,16 @@
from enum import Enum
from types import ModuleType

import pydantic
try:
from pydantic.v1 import AnyHttpUrl, BaseModel, Extra, create_model # type: ignore
except ImportError:
from pydantic import ( # pylint: disable=W0404
AnyHttpUrl,
BaseModel,
Extra,
create_model,
)


BUNDLED_SCHEMA_PATH = ".cog/schema.py"

Expand Down Expand Up @@ -38,7 +47,7 @@ def default_events(cls) -> t.List["WebhookEvent"]:
return [cls.START, cls.OUTPUT, cls.LOGS, cls.COMPLETED]


class PredictionBaseModel(pydantic.BaseModel, extra=pydantic.Extra.allow):
class PredictionBaseModel(BaseModel, extra=Extra.allow):
input: t.Dict[str, t.Any]


Expand All @@ -49,7 +58,7 @@ class PredictionRequest(PredictionBaseModel):
# TODO: deprecate this
output_file_prefix: t.Optional[str]

webhook: t.Optional[pydantic.AnyHttpUrl]
webhook: t.Optional[AnyHttpUrl]
webhook_events_filter: t.Optional[t.List[WebhookEvent]] = (
WebhookEvent.default_events()
)
Expand All @@ -59,7 +68,7 @@ def with_types(cls, input_type: t.Type[t.Any]) -> t.Any:
# [compat] Input is implicitly optional -- previous versions of the
# Cog HTTP API allowed input to be omitted (e.g. for models that don't
# have any inputs). We should consider changing this in future.
return pydantic.create_model(
return create_model(
cls.__name__, __base__=cls, input=(t.Optional[input_type], None)
)

Expand All @@ -85,7 +94,7 @@ def with_types(cls, input_type: t.Type[t.Any], output_type: t.Type[t.Any]) -> t.
# [compat] Input is implicitly optional -- previous versions of the
# Cog HTTP API allowed input to be omitted (e.g. for models that don't
# have any inputs). We should consider changing this in future.
return pydantic.create_model(
return create_model(
cls.__name__,
__base__=cls,
input=(t.Optional[input_type], None),
Expand Down
17 changes: 8 additions & 9 deletions python/cog/server/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,11 @@
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from pydantic import ValidationError
from pydantic.error_wrappers import ErrorWrapper

try:
from pydantic.v1 import ValidationError # type: ignore
except ImportError:
from pydantic import ValidationError # pylint: disable=W0404

from .. import schema
from ..errors import PredictorNotSet
Expand Down Expand Up @@ -291,12 +294,8 @@ async def predict_idempotent(
if request.id is not None and request.id != prediction_id:
raise RequestValidationError(
[
ErrorWrapper(
ValueError(
"prediction ID must match the ID supplied in the URL"
),
("body", "id"),
)
ValueError("prediction ID must match the ID supplied in the URL"),
("body", "id"),
]
)

Expand All @@ -310,7 +309,7 @@ async def predict_idempotent(
return _predict(request=request, respond_async=respond_async)

def _predict(
*, request: PredictionRequest, respond_async: bool = False
*, request: Optional[PredictionRequest], respond_async: bool = False
) -> Response:
# [compat] If no body is supplied, assume that this model can be run
# with empty input. This will throw a ValidationError if that's not
Expand Down
7 changes: 6 additions & 1 deletion python/cog/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,12 @@
from typing import Any, Dict, Iterator, List, Optional, TypeVar, Union

import requests
from pydantic import Field, SecretStr

try:
from pydantic.v1 import Field, SecretStr # type: ignore
except ImportError:
from pydantic import Field, SecretStr # pylint: disable=W0404


FILENAME_ILLEGAL_CHARS = set("\u0000/")

Expand Down
5 changes: 4 additions & 1 deletion python/tests/server/fixtures/complex_output.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from pydantic import BaseModel
try:
from pydantic.v1 import BaseModel # type: ignore
except ImportError:
from pydantic import BaseModel # pylint: disable=W0404


class Output(BaseModel):
Expand Down
6 changes: 5 additions & 1 deletion python/tests/server/fixtures/input_unsupported_type.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from cog import BasePredictor
from pydantic import BaseModel

try:
from pydantic.v1 import BaseModel # type: ignore
except ImportError:
from pydantic import BaseModel # pylint: disable=W0404


class Input(BaseModel):
text: str
Expand Down
6 changes: 5 additions & 1 deletion python/tests/server/fixtures/openapi_custom_output_type.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from cog import BasePredictor
from pydantic import BaseModel

try:
from pydantic.v1 import BaseModel # type: ignore
except ImportError:
from pydantic import BaseModel # pylint: disable=W0404


# Calling this `MyOutput` to test if cog renames it to `Output` in the schema
class MyOutput(BaseModel):
Expand Down
6 changes: 5 additions & 1 deletion python/tests/server/fixtures/openapi_output_type.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from cog import BasePredictor
from pydantic import BaseModel

try:
from pydantic.v1 import BaseModel # type: ignore
except ImportError:
from pydantic import BaseModel # pylint: disable=W0404


# An output object called `Output` needs to be special cased because pydantic tries to dedupe it with the internal `Output`
class Output(BaseModel):
Expand Down
6 changes: 5 additions & 1 deletion python/tests/server/fixtures/output_complex.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import io

from cog import BasePredictor, File
from pydantic import BaseModel

try:
from pydantic.v1 import BaseModel # type: ignore
except ImportError:
from pydantic import BaseModel # pylint: disable=W0404


class Output(BaseModel):
Expand Down
6 changes: 5 additions & 1 deletion python/tests/server/fixtures/output_iterator_complex.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from typing import Iterator, List

from cog import BasePredictor
from pydantic import BaseModel

try:
from pydantic.v1 import BaseModel # type: ignore
except ImportError:
from pydantic import BaseModel # pylint: disable=W0404


class Output(BaseModel):
Expand Down
4 changes: 2 additions & 2 deletions python/tests/server/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_openapi_specification(client, static_schema):

schema = resp.json()
assert schema == static_schema
assert schema["openapi"] == "3.0.2"
assert schema["openapi"] == "3.1.0"
assert schema["info"] == {"title": "Cog", "version": "0.1.0"}
assert schema["paths"]["/"] == {
"get": {
Expand Down Expand Up @@ -335,7 +335,7 @@ def test_train_openapi_specification(client):
assert resp.status_code == 200

schema = resp.json()
assert schema["openapi"] == "3.0.2"
assert schema["openapi"] == "3.1.0"
assert schema["info"] == {"title": "Cog", "version": "0.1.0"}

assert schema["components"]["schemas"]["TrainingInput"] == {
Expand Down
6 changes: 5 additions & 1 deletion python/tests/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
import numpy as np
from cog.files import upload_file
from cog.json import make_encodeable, upload_files
from pydantic import BaseModel

try:
from pydantic.v1 import BaseModel # type: ignore
except ImportError:
from pydantic import BaseModel # pylint: disable=W0404


def test_make_encodeable_recursively_encodes_tuples():
Expand Down
116 changes: 116 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# This file was autogenerated by uv via the following command:
# uv pip compile pyproject.toml --extra=dev -o requirements.txt
anyio==4.3.0
# via
# httpx
# starlette
# watchfiles
attrs==23.2.0
# via hypothesis
black==24.4.0
build==1.2.1
certifi==2024.2.2
# via
# httpcore
# httpx
# requests
charset-normalizer==3.3.2
# via requests
click==8.1.7
# via
# black
# uvicorn
coverage==7.4.4
# via pytest-cov
execnet==2.1.1
# via pytest-xdist
fastapi @ git+https://github.com/replicate/fastapi.git@7be995a943adde5315cb557a9330ee4ff4137849
h11==0.14.0
# via
# httpcore
# uvicorn
httpcore==1.0.5
# via httpx
httptools==0.6.1
# via uvicorn
httpx==0.27.0
hypothesis==6.100.1
idna==3.7
# via
# anyio
# httpx
# requests
iniconfig==2.0.0
# via pytest
markupsafe==2.1.5
# via werkzeug
mypy-extensions==1.0.0
# via black
nodeenv==1.8.0
# via pyright
numpy==1.26.4
packaging==24.0
# via
# black
# build
# pytest
# pytest-rerunfailures
pathspec==0.12.1
# via black
pillow==10.3.0
platformdirs==4.2.0
# via black
pluggy==1.4.0
# via pytest
pydantic==1.10.15
# via fastapi
pyproject-hooks==1.0.0
# via build
pyright==1.1.347
pytest==8.1.1
# via
# pytest-cov
# pytest-rerunfailures
# pytest-xdist
pytest-cov==5.0.0
pytest-httpserver==1.0.10
pytest-rerunfailures==14.0
pytest-xdist==3.5.0
python-dotenv==1.0.1
# via uvicorn
pyyaml==6.0.1
# via
# responses
# uvicorn
requests==2.31.0
# via responses
responses==0.25.0
ruff==0.4.0
setuptools==69.5.1
# via nodeenv
sniffio==1.3.1
# via
# anyio
# httpx
sortedcontainers==2.4.0
# via hypothesis
starlette==0.37.2
# via fastapi
structlog==24.1.0
typing-extensions==4.11.0
# via
# fastapi
# pydantic
urllib3==2.2.1
# via
# requests
# responses
uvicorn==0.29.0
uvloop==0.19.0
# via uvicorn
watchfiles==0.21.0
# via uvicorn
websockets==12.0
# via uvicorn
werkzeug==3.0.2
# via pytest-httpserver
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@

from cog import BasePredictor, Path
from typing import Optional
from pydantic import BaseModel

try:
from pydantic.v1 import BaseModel # type: ignore
except ImportError:
from pydantic import BaseModel # pylint: disable=W0404


class ModelOutput(BaseModel):
Expand Down