
How data gets processed in a Gradio Interface


1. Standard data flow (when a user provides a prediction fn, inputs, and outputs to construct an Interface)

Let's take this image classification Space (abidlabs/pytorch-image-classifier) as an example:

import requests
from torchvision import transforms
import torch
import gradio as gr

# Load a pretrained ResNet-18 from torchvision and put it in inference mode
model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet18', pretrained=True).eval()

# Download human-readable labels for ImageNet.
response = requests.get("https://git.io/JJkYN")
labels = response.text.split("\n")

def predict(inp):
  # Convert the PIL image to a tensor (scaled to [0, 1]) and add a batch dimension
  inp = transforms.ToTensor()(inp).unsqueeze(0)
  with torch.no_grad():
    prediction = torch.nn.functional.softmax(model(inp)[0], dim=0)
    # Map each of the 1,000 ImageNet labels to its predicted confidence
    confidences = {labels[i]: float(prediction[i]) for i in range(1000)}
  return confidences

gr.Interface(fn=predict, 
             inputs=gr.inputs.Image(type="pil"),
             outputs=gr.outputs.Label(num_top_classes=3)).launch()


The data flow is as follows:

  1. When a user uploads an image and clicks Submit, the image is serialized into base64 format so that it can be sent to the /api/predict endpoint on the server where the Gradio application is running.
  2. Based on the type of the input component, Gradio automatically applies certain preprocessing steps to convert the input into the format that the user's prediction fn expects. In this case, the base64 image is converted into a PIL image (steps 2 and 4 are sketched in code after this list).
  3. The image is then run through the fn, which in this case returns a dictionary mapping labels to confidences.
  4. The output is postprocessed into the appropriate format based on the parameters the user specified for the Label component. In this case, the postprocessing identifies the 3 labels with the highest confidence. Depending on the component (e.g., an Image output), the result may also need to be serialized.
  5. This serialized output is then sent to the front end and displayed.
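
To make steps 2 and 4 more concrete, here is a minimal sketch of roughly what the Image(type="pil") preprocessing and the Label postprocessing do. The function names are illustrative rather than Gradio's actual internals, and the exact output structure for the Label component is an assumption about what the front end expects:

import base64
import io
from PIL import Image

def image_preprocess(b64_data):
    # Strip the data-URL header if present, e.g. "data:image/png;base64,..."
    if "," in b64_data:
        b64_data = b64_data.split(",", 1)[1]
    # Decode the base64 payload from the front end into a PIL image
    return Image.open(io.BytesIO(base64.b64decode(b64_data)))

def label_postprocess(confidences, num_top_classes=3):
    # Keep only the num_top_classes labels with the highest confidence
    top = sorted(confidences.items(), key=lambda kv: kv[1], reverse=True)[:num_top_classes]
    # Shape the result into the label/confidences structure the front end renders
    return {
        "label": top[0][0],
        "confidences": [{"label": name, "confidence": score} for name, score in top],
    }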

2. "API mode" (when a user loads a model or Space using Interface.load())

This special case applies when building a Gradio demo on top of a Space or model on the Hub using its Inference API endpoint. Let's take this image classification Space (abidlabs/vision-transformer) as an example:

import gradio as gr
# Build a demo around this model's Inference API endpoint on the Hub
gr.Interface.load("huggingface/google/vit-base-patch16-224").launch()


The data flow is as follows:

  1. When a user uploads an image and clicks Submit, the image is serialized into base64 format so that it can be sent to the /api/predict endpoint on the server where the Gradio application is running.
  2. Based on the type of the Inference API endpoint, Gradio automatically applies certain preprocessing steps to convert the input image into a particular format. In this case, the base64 image is decoded into an RGB image file, and its filepath is passed along.
  3. Based on the type of the input component, Gradio then serializes the data again, into a format that can be sent to the Inference API endpoint on the Hugging Face Hub.
  4. The image is then sent to the Inference API endpoint, and a response is returned (steps 3 to 5 are sketched in code after this list).
  5. Gradio deserializes the response data to get a dictionary of labels.
  6. Based on the type of the Inference API endpoint, Gradio automatically applies certain postprocessing steps to convert the dictionary of labels into a particular format. In this case, the label with the highest confidence is extracted and passed along with the rest of the dictionary of labels.
  7. This output is then sent to the front end and displayed.
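
As a rough sketch of steps 3 to 5, assuming the Inference API's standard image-classification response shape (a list of {"label": ..., "score": ...} entries), the serialization and deserialization amount to something like the following. This is illustrative rather than Gradio's actual code, and a real call may require an Authorization header with a Hugging Face token:

import requests

API_URL = "https://api-inference.huggingface.co/models/google/vit-base-patch16-224"

def query_inference_api(image_path):
    # Serialize: read the raw bytes of the preprocessed image file
    with open(image_path, "rb") as f:
        data = f.read()
    # Send the bytes to the Inference API endpoint on the Hugging Face Hub
    response = requests.post(API_URL, data=data)
    # Deserialize: the endpoint returns a list like
    # [{"label": "Egyptian cat", "score": 0.94}, ...]
    return {item["label"]: item["score"] for item in response.json()}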