From 7e84723fe4e9a232e5e27dc38aed373c0c7ab94a Mon Sep 17 00:00:00 2001
From: Alara Dirik <8944735+alaradirik@users.noreply.github.com>
Date: Fri, 23 Sep 2022 16:24:28 +0300
Subject: [PATCH] Add semantic segmentation post-processing method to MobileViT
 (#19105)

* add post-processing method for semantic segmentation

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
---
Review notes (text in this section is not part of the commit message):
 - A tensor `target_sizes` is converted with `.tolist()` rather than `.numpy()`,
   so each size reaches `torch.nn.functional.interpolate` as a list of Python
   ints and tensors that live on an accelerator are handled as well.
 - `target_sizes` is annotated `Optional[...]` because its default is `None`.
 - Docstring grammar fix ("corresponds").
 - All edits are same-line replacements; hunk line counts are unchanged.

 docs/source/en/model_doc/mobilevit.mdx        |  1 +
 .../mobilevit/feature_extraction_mobilevit.py | 50 ++++++++++++++++++-
 .../mobilevit/test_modeling_mobilevit.py      | 24 +++++++++
 3 files changed, 73 insertions(+), 2 deletions(-)

diff --git a/docs/source/en/model_doc/mobilevit.mdx b/docs/source/en/model_doc/mobilevit.mdx
index 5725bd5ce5835..e0799d2962f2f 100644
--- a/docs/source/en/model_doc/mobilevit.mdx
+++ b/docs/source/en/model_doc/mobilevit.mdx
@@ -66,6 +66,7 @@ This model was contributed by [matthijs](https://huggingface.co/Matthijs). The T
 
 [[autodoc]] MobileViTFeatureExtractor
     - __call__
+    - post_process_semantic_segmentation
 
 ## MobileViTModel
 
diff --git a/src/transformers/models/mobilevit/feature_extraction_mobilevit.py b/src/transformers/models/mobilevit/feature_extraction_mobilevit.py
index 51e022b809c92..75bd6d51bc15c 100644
--- a/src/transformers/models/mobilevit/feature_extraction_mobilevit.py
+++ b/src/transformers/models/mobilevit/feature_extraction_mobilevit.py
@@ -14,16 +14,19 @@
 # limitations under the License.
 """Feature extractor class for MobileViT."""
 
-from typing import Optional, Union
+from typing import List, Optional, Tuple, Union
 
 import numpy as np
 from PIL import Image
 
 from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
 from ...image_utils import ImageFeatureExtractionMixin, ImageInput, is_torch_tensor
-from ...utils import TensorType, logging
+from ...utils import TensorType, is_torch_available, logging
 
 
+if is_torch_available():
+    import torch
+
 logger = logging.get_logger(__name__)
 
 
@@ -151,3 +154,46 @@ def __call__(
         encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
 
         return encoded_inputs
+
+    def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[List[Tuple]] = None):
+        """
+        Converts the output of [`MobileViTForSemanticSegmentation`] into semantic segmentation maps. Only supports
+        PyTorch.
+
+        Args:
+            outputs ([`MobileViTForSemanticSegmentation`]):
+                Raw outputs of the model.
+            target_sizes (`List[Tuple]`, *optional*):
+                A list of length `batch_size`, where each item is a `Tuple[int, int]` corresponding to the requested
+                final size (height, width) of each prediction. If left to None, predictions will not be resized.
+        Returns:
+            `List[torch.Tensor]`:
+                A list of length `batch_size`, where each item is a semantic segmentation map of shape (height, width)
+                corresponding to the target_sizes entry (if `target_sizes` is specified). Each entry of each
+                `torch.Tensor` corresponds to a semantic class id.
+        """
+        logits = outputs.logits
+
+        # Resize logits and compute semantic segmentation maps
+        if target_sizes is not None:
+            if len(logits) != len(target_sizes):
+                raise ValueError(
+                    "Make sure that you pass in as many target sizes as the batch dimension of the logits"
+                )
+
+            if is_torch_tensor(target_sizes):
+                target_sizes = target_sizes.tolist()  # list of Python ints for interpolate(); device-agnostic
+
+            semantic_segmentation = []
+
+            for idx in range(len(logits)):
+                resized_logits = torch.nn.functional.interpolate(
+                    logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
+                )
+                semantic_map = resized_logits[0].argmax(dim=0)
+                semantic_segmentation.append(semantic_map)
+        else:
+            semantic_segmentation = logits.argmax(dim=1)
+            semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
+
+        return semantic_segmentation
diff --git a/tests/models/mobilevit/test_modeling_mobilevit.py b/tests/models/mobilevit/test_modeling_mobilevit.py
index 84ffc7b89bc54..bb86cbc451fe6 100644
--- a/tests/models/mobilevit/test_modeling_mobilevit.py
+++ b/tests/models/mobilevit/test_modeling_mobilevit.py
@@ -340,3 +340,27 @@ def test_inference_semantic_segmentation(self):
         )
 
         self.assertTrue(torch.allclose(logits[0, :3, :3, :3], expected_slice, atol=1e-4))
+
+    @slow
+    def test_post_processing_semantic_segmentation(self):
+        model = MobileViTForSemanticSegmentation.from_pretrained("apple/deeplabv3-mobilevit-xx-small")
+        model = model.to(torch_device)
+
+        feature_extractor = MobileViTFeatureExtractor.from_pretrained("apple/deeplabv3-mobilevit-xx-small")
+
+        image = prepare_img()
+        inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device)
+
+        # forward pass
+        with torch.no_grad():
+            outputs = model(**inputs)
+
+        outputs.logits = outputs.logits.detach().cpu()
+
+        segmentation = feature_extractor.post_process_semantic_segmentation(outputs=outputs, target_sizes=[(50, 60)])
+        expected_shape = torch.Size((50, 60))
+        self.assertEqual(segmentation[0].shape, expected_shape)
+
+        segmentation = feature_extractor.post_process_semantic_segmentation(outputs=outputs)
+        expected_shape = torch.Size((32, 32))
+        self.assertEqual(segmentation[0].shape, expected_shape)