diff --git a/src/lightning_app/components/serve/python_server.py b/src/lightning_app/components/serve/python_server.py index f0361f9db5046..731bf1c37e969 100644 --- a/src/lightning_app/components/serve/python_server.py +++ b/src/lightning_app/components/serve/python_server.py @@ -3,6 +3,7 @@ from pathlib import Path from typing import Any, Dict, Optional +import torch import uvicorn from fastapi import FastAPI from pydantic import BaseModel @@ -105,7 +106,7 @@ def predict(self, request): self._input_type = input_type self._output_type = output_type - def setup(self) -> None: + def setup(self, *args, **kwargs) -> None: """This method is called before the server starts. Override this if you need to download the model or initialize the weights, setting up pipelines etc. @@ -154,7 +155,8 @@ def _attach_predict_fn(self, fastapi_app: FastAPI) -> None: output_type: type = self.configure_output_type() def predict_fn(request: input_type): # type: ignore - return self.predict(request) + with torch.inference_mode(): + return self.predict(request) fastapi_app.post("/predict", response_model=output_type)(predict_fn) @@ -207,7 +209,7 @@ def run(self, *args: Any, **kwargs: Any) -> Any: Normally, you don't need to override this method. """ - self.setup() + self.setup(*args, **kwargs) fastapi_app = FastAPI() self._attach_predict_fn(fastapi_app)