[fbsync] Add SimpleCopyPaste augmentation (pytorch#5825)
Summary:
* added simple POC

* added jitter and crop options

* added references

* moved simplecopypaste to detection module

* working POC for simple copy paste in detection

* added comments

* remove transforms from class
updated the labels
added gaussian blur

* removed loop for mask calculation

* replaced Gaussian blur with functional api

* added inplace operations

* added changes to accept tuples instead of tensors

* - make copy paste functional
- make only one copy of batch and target

* add inplace support within copy paste functional

* Updated code for copy-paste transform

* Fixed code formatting

* [skip ci] removed manual thresholding

* Replaced cropping by resizing data to paste

* Removed inplace arg (as useless) and put a check on iscrowd target

* code-formatting

* Updated copypaste op to make it torch scriptable
Added fallbacks to support LSJ

* Fixed flake8

* Updates according to the review

Differential Revision: D37212651

fbshipit-source-id: 467b670164150dd5cc424f4d616d436295ce818d

Co-authored-by: vfdev-5 <vfdev.5@gmail.com>
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
3 people authored and facebook-github-bot committed Jun 16, 2022
1 parent f9f6782 commit b3a6867
Showing 2 changed files with 177 additions and 1 deletion.
23 changes: 22 additions & 1 deletion references/detection/train.py
@@ -31,6 +31,8 @@
from coco_utils import get_coco, get_coco_kp
from engine import train_one_epoch, evaluate
from group_by_aspect_ratio import GroupedBatchSampler, create_aspect_ratio_groups
from torchvision.transforms import InterpolationMode
from transforms import SimpleCopyPaste


def get_dataset(name, image_set, transform, data_path):
@@ -145,6 +147,13 @@ def get_args_parser(add_help=True):
    # Mixed precision training parameters
    parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")

    # Use CopyPaste augmentation training parameter
    parser.add_argument(
        "--use-copypaste",
        action="store_true",
        help="Use CopyPaste data augmentation. Works only with data-augmentation='lsj'.",
    )

    return parser


@@ -180,8 +189,20 @@ def main(args):
    else:
        train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, args.batch_size, drop_last=True)

    train_collate_fn = utils.collate_fn
    if args.use_copypaste:
        if args.data_augmentation != "lsj":
            raise RuntimeError("SimpleCopyPaste algorithm currently only supports the 'lsj' data augmentation policies")

        copypaste = SimpleCopyPaste(resize_interpolation=InterpolationMode.BILINEAR, blending=True)

        def copypaste_collate_fn(batch):
            return copypaste(*utils.collate_fn(batch))

        train_collate_fn = copypaste_collate_fn

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_sampler=train_batch_sampler, num_workers=args.workers, collate_fn=utils.collate_fn
        dataset, batch_sampler=train_batch_sampler, num_workers=args.workers, collate_fn=train_collate_fn
    )

    data_loader_test = torch.utils.data.DataLoader(
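A minimal standalone sketch of the collate_fn wiring above, assuming (as in the detection reference utilities) that utils.collate_fn is the usual detection collate that turns a list of (image, target) pairs into an (images, targets) pair of tuples:

# Illustrative sketch, not part of the diff.
from torchvision.transforms import InterpolationMode

from transforms import SimpleCopyPaste  # references/detection/transforms.py


def default_collate_fn(batch):
    # mirrors the assumed behaviour of references/detection/utils.collate_fn:
    # a list of (image, target) pairs becomes a pair of tuples (images, targets)
    return tuple(zip(*batch))


copypaste = SimpleCopyPaste(resize_interpolation=InterpolationMode.BILINEAR, blending=True)


def copypaste_collate_fn(batch):
    # paste instances between neighbouring samples of the collated batch
    return copypaste(*default_collate_fn(batch))

# The training DataLoader is then built with collate_fn=copypaste_collate_fn,
# exactly as in the hunk above.
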
155 changes: 155 additions & 0 deletions references/detection/transforms.py
@@ -3,6 +3,7 @@
import torch
import torchvision
from torch import nn, Tensor
from torchvision import ops
from torchvision.transforms import functional as F
from torchvision.transforms import transforms as T, InterpolationMode

@@ -437,3 +438,157 @@ def forward(
                )

        return image, target


def _copy_paste(
    image: torch.Tensor,
    target: Dict[str, Tensor],
    paste_image: torch.Tensor,
    paste_target: Dict[str, Tensor],
    blending: bool = True,
    resize_interpolation: F.InterpolationMode = F.InterpolationMode.BILINEAR,
) -> Tuple[torch.Tensor, Dict[str, Tensor]]:

    # Random paste targets selection:
    num_masks = len(paste_target["masks"])

    if num_masks < 1:
        # Such a degenerate case with num_masks=0 can happen with LSJ
        # Let's just return (image, target)
        return image, target

    # We have to please torch script by explicitly specifying dtype as torch.long
    random_selection = torch.randint(0, num_masks, (num_masks,), device=paste_image.device)
    random_selection = torch.unique(random_selection).to(torch.long)

    paste_masks = paste_target["masks"][random_selection]
    paste_boxes = paste_target["boxes"][random_selection]
    paste_labels = paste_target["labels"][random_selection]

    masks = target["masks"]

    # We resize source and paste data if they have different sizes
    # This is something we introduced here as originally the algorithm works
    # on equal-sized data (for example, coming from LSJ data augmentations)
    size1 = image.shape[-2:]
    size2 = paste_image.shape[-2:]
    if size1 != size2:
        paste_image = F.resize(paste_image, size1, interpolation=resize_interpolation)
        paste_masks = F.resize(paste_masks, size1, interpolation=F.InterpolationMode.NEAREST)
        # resize bboxes:
        ratios = torch.tensor((size1[1] / size2[1], size1[0] / size2[0]), device=paste_boxes.device)
        paste_boxes = paste_boxes.view(-1, 2, 2).mul(ratios).view(paste_boxes.shape)

    paste_alpha_mask = paste_masks.sum(dim=0) > 0

    if blending:
        paste_alpha_mask = F.gaussian_blur(
            paste_alpha_mask.unsqueeze(0),
            kernel_size=(5, 5),
            sigma=[
                2.0,
            ],
        )

    # Copy-paste images:
    image = (image * (~paste_alpha_mask)) + (paste_image * paste_alpha_mask)

    # Copy-paste masks:
    masks = masks * (~paste_alpha_mask)
    non_all_zero_masks = masks.sum((-1, -2)) > 0
    masks = masks[non_all_zero_masks]

    # Do a shallow copy of the target dict
    out_target = {k: v for k, v in target.items()}

    out_target["masks"] = torch.cat([masks, paste_masks])

    # Copy-paste boxes and labels
    boxes = ops.masks_to_boxes(masks)
    out_target["boxes"] = torch.cat([boxes, paste_boxes])

    labels = target["labels"][non_all_zero_masks]
    out_target["labels"] = torch.cat([labels, paste_labels])

    # Update additional optional keys: area and iscrowd, if they exist
    if "area" in target:
        out_target["area"] = out_target["masks"].sum((-1, -2)).to(torch.float32)

    if "iscrowd" in target and "iscrowd" in paste_target:
        # target['iscrowd'] size can differ from the mask size (non_all_zero_masks)
        # For example, if a previous transform geometrically modified masks/boxes/labels but
        # did not update "iscrowd"
        if len(target["iscrowd"]) == len(non_all_zero_masks):
            iscrowd = target["iscrowd"][non_all_zero_masks]
            paste_iscrowd = paste_target["iscrowd"][random_selection]
            out_target["iscrowd"] = torch.cat([iscrowd, paste_iscrowd])

    # Check for degenerate boxes and remove them
    boxes = out_target["boxes"]
    degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
    if degenerate_boxes.any():
        valid_targets = ~degenerate_boxes.any(dim=1)

        out_target["boxes"] = boxes[valid_targets]
        out_target["masks"] = out_target["masks"][valid_targets]
        out_target["labels"] = out_target["labels"][valid_targets]

        if "area" in out_target:
            out_target["area"] = out_target["area"][valid_targets]
        if "iscrowd" in out_target and len(out_target["iscrowd"]) == len(valid_targets):
            out_target["iscrowd"] = out_target["iscrowd"][valid_targets]

    return image, out_target


class SimpleCopyPaste(torch.nn.Module):
    def __init__(self, blending=True, resize_interpolation=F.InterpolationMode.BILINEAR):
        super().__init__()
        self.resize_interpolation = resize_interpolation
        self.blending = blending

    def forward(
        self, images: List[torch.Tensor], targets: List[Dict[str, Tensor]]
    ) -> Tuple[List[torch.Tensor], List[Dict[str, Tensor]]]:
        torch._assert(
            isinstance(images, (list, tuple)) and all([isinstance(v, torch.Tensor) for v in images]),
            "images should be a list of tensors",
        )
        torch._assert(
            isinstance(targets, (list, tuple)) and len(images) == len(targets),
            "targets should be a list of the same size as images",
        )
        for target in targets:
            # Cannot check for instance type dict inside torch.jit.script
            # torch._assert(isinstance(target, dict), "targets item should be a dict")
            for k in ["masks", "boxes", "labels"]:
                torch._assert(k in target, f"Key {k} should be present in targets")
                torch._assert(isinstance(target[k], torch.Tensor), f"Value for the key {k} should be a tensor")

        # images = [t1, t2, ..., tN]
        # Let's define paste_images as shifted list of input images
        # paste_images = [t2, t3, ..., tN, t1]
        # FYI: in TF they mix data on the dataset level
        images_rolled = images[-1:] + images[:-1]
        targets_rolled = targets[-1:] + targets[:-1]

        output_images: List[torch.Tensor] = []
        output_targets: List[Dict[str, Tensor]] = []

        for image, target, paste_image, paste_target in zip(images, targets, images_rolled, targets_rolled):
            output_image, output_data = _copy_paste(
                image,
                target,
                paste_image,
                paste_target,
                blending=self.blending,
                resize_interpolation=self.resize_interpolation,
            )
            output_images.append(output_image)
            output_targets.append(output_data)

        return output_images, output_targets

    def __repr__(self) -> str:
        s = f"{self.__class__.__name__}(blending={self.blending}, resize_interpolation={self.resize_interpolation})"
        return s
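
A hypothetical smoke test for the new transform (the make_sample helper, shapes and values below are made up for illustration): two toy samples with rectangular instance masks are run through SimpleCopyPaste, which pastes each image's instances onto the other and rebuilds masks, boxes and labels.

# Hypothetical usage sketch, not part of the diff.
import torch
from torchvision.transforms import InterpolationMode

from transforms import SimpleCopyPaste  # references/detection/transforms.py


def make_sample(height, width, box, label):
    # one image with a single rectangular instance mask and a matching box
    x1, y1, x2, y2 = box
    mask = torch.zeros((1, height, width), dtype=torch.uint8)
    mask[:, y1:y2, x1:x2] = 1
    target = {
        "masks": mask,
        "boxes": torch.tensor([box], dtype=torch.float32),
        "labels": torch.tensor([label], dtype=torch.int64),
    }
    return torch.rand(3, height, width), target


image1, target1 = make_sample(96, 96, (10, 10, 40, 40), label=1)
image2, target2 = make_sample(128, 128, (50, 60, 90, 100), label=2)  # different size exercises the resize fallback

copypaste = SimpleCopyPaste(blending=True, resize_interpolation=InterpolationMode.BILINEAR)
out_images, out_targets = copypaste([image1, image2], [target1, target2])

for t in out_targets:
    # each output target now holds its original instance(s) plus the pasted one(s)
    print(t["boxes"].shape, t["labels"])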
