Skip to content

Commit

Permalink
Add semantic segmentation post-processing method to MobileViT (#19105)
Browse files Browse the repository at this point in the history
* add post-processing method for semantic segmentation

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
  • Loading branch information
alaradirik and sgugger committed Sep 23, 2022
1 parent 905635f commit 7e84723
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 2 deletions.
1 change: 1 addition & 0 deletions docs/source/en/model_doc/mobilevit.mdx
Expand Up @@ -66,6 +66,7 @@ This model was contributed by [matthijs](https://huggingface.co/Matthijs). The T

[[autodoc]] MobileViTFeatureExtractor
- __call__
- post_process_semantic_segmentation

## MobileViTModel

Expand Down
50 changes: 48 additions & 2 deletions src/transformers/models/mobilevit/feature_extraction_mobilevit.py
Expand Up @@ -14,16 +14,19 @@
# limitations under the License.
"""Feature extractor class for MobileViT."""

from typing import Optional, Union
from typing import List, Optional, Tuple, Union

import numpy as np
from PIL import Image

from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
from ...image_utils import ImageFeatureExtractionMixin, ImageInput, is_torch_tensor
from ...utils import TensorType, logging
from ...utils import TensorType, is_torch_available, logging


if is_torch_available():
import torch

logger = logging.get_logger(__name__)


Expand Down Expand Up @@ -151,3 +154,46 @@ def __call__(
encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)

return encoded_inputs

def post_process_semantic_segmentation(self, outputs, target_sizes: List[Tuple] = None):
    """
    Converts the output of [`MobileViTForSemanticSegmentation`] into semantic segmentation maps. Only supports
    PyTorch.

    Args:
        outputs ([`MobileViTForSemanticSegmentation`]):
            Raw outputs of the model.
        target_sizes (`List[Tuple]`, *optional*):
            A list of length `batch_size`, where each item is a `Tuple[int, int]` corresponding to the requested
            final size (height, width) of each prediction. If left to None, predictions will not be resized.

    Returns:
        `List[torch.Tensor]`:
            A list of length `batch_size`, where each item is a semantic segmentation map of shape (height, width)
            corresponding to the target_sizes entry (if `target_sizes` is specified). Each entry of each
            `torch.Tensor` correspond to a semantic class id.
    """
    # Logits come out as (batch_size, num_labels, height, width).
    logits = outputs.logits

    # Resize logits and compute semantic segmentation maps
    if target_sizes is not None:
        if len(logits) != len(target_sizes):
            raise ValueError(
                "Make sure that you pass in as many target sizes as the batch dimension of the logits"
            )

        # This method already uses torch unconditionally, so torch.is_tensor is safe here.
        if torch.is_tensor(target_sizes):
            target_sizes = target_sizes.numpy()

        semantic_segmentation = []

        for idx in range(len(logits)):
            # Coerce the per-image size to a plain (height, width) tuple of ints:
            # a raw numpy row (from a tensor/ndarray `target_sizes`) would be
            # rejected by torch.nn.functional.interpolate's `size` argument.
            resized_logits = torch.nn.functional.interpolate(
                logits[idx].unsqueeze(dim=0),
                size=tuple(int(dim) for dim in target_sizes[idx]),
                mode="bilinear",
                align_corners=False,
            )
            # argmax over the class channel gives a (height, width) map of class ids.
            semantic_map = resized_logits[0].argmax(dim=0)
            semantic_segmentation.append(semantic_map)
    else:
        # No resizing requested: argmax at the native logits resolution,
        # then split the batch into a list of per-image maps.
        semantic_segmentation = logits.argmax(dim=1)
        semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]

    return semantic_segmentation
24 changes: 24 additions & 0 deletions tests/models/mobilevit/test_modeling_mobilevit.py
Expand Up @@ -340,3 +340,27 @@ def test_inference_semantic_segmentation(self):
)

self.assertTrue(torch.allclose(logits[0, :3, :3, :3], expected_slice, atol=1e-4))

@slow
def test_post_processing_semantic_segmentation(self):
    """Check that post-processed segmentation maps honor (or omit) `target_sizes`."""
    extractor = MobileViTFeatureExtractor.from_pretrained("apple/deeplabv3-mobilevit-xx-small")
    model = MobileViTForSemanticSegmentation.from_pretrained("apple/deeplabv3-mobilevit-xx-small").to(torch_device)

    encoded = extractor(images=prepare_img(), return_tensors="pt").to(torch_device)

    # forward pass
    with torch.no_grad():
        outputs = model(**encoded)

    outputs.logits = outputs.logits.detach().cpu()

    # With target sizes, each map is resized to the requested (height, width).
    maps = extractor.post_process_semantic_segmentation(outputs=outputs, target_sizes=[(50, 60)])
    self.assertEqual(maps[0].shape, torch.Size((50, 60)))

    # Without target sizes, maps keep the raw logits resolution.
    maps = extractor.post_process_semantic_segmentation(outputs=outputs)
    self.assertEqual(maps[0].shape, torch.Size((32, 32)))

0 comments on commit 7e84723

Please sign in to comment.