Skip to content

Commit

Permalink
Torch inference mode for prediction (#15719)
Browse files Browse the repository at this point in the history
torch inference mode for prediction
  • Loading branch information
Sherin Thomas committed Nov 19, 2022
1 parent f40eb2c commit 08d14ec
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions src/lightning_app/components/serve/python_server.py
Expand Up @@ -3,6 +3,7 @@
from pathlib import Path
from typing import Any, Dict, Optional

import torch
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel
Expand Down Expand Up @@ -105,7 +106,7 @@ def predict(self, request):
self._input_type = input_type
self._output_type = output_type

def setup(self) -> None:
def setup(self, *args, **kwargs) -> None:
"""This method is called before the server starts. Override this if you need to download the model or
initialize the weights, setting up pipelines etc.
Expand Down Expand Up @@ -154,7 +155,8 @@ def _attach_predict_fn(self, fastapi_app: FastAPI) -> None:
output_type: type = self.configure_output_type()

def predict_fn(request: input_type): # type: ignore
return self.predict(request)
with torch.inference_mode():
return self.predict(request)

fastapi_app.post("/predict", response_model=output_type)(predict_fn)

Expand Down Expand Up @@ -207,7 +209,7 @@ def run(self, *args: Any, **kwargs: Any) -> Any:
Normally, you don't need to override this method.
"""
self.setup()
self.setup(*args, **kwargs)

fastapi_app = FastAPI()
self._attach_predict_fn(fastapi_app)
Expand Down

0 comments on commit 08d14ec

Please sign in to comment.