Move sync_batch_norm to a separate method
NielsRogge committed Jul 12, 2021
1 parent ec89c84 commit a8a8a3b
Showing 3 changed files with 48 additions and 28 deletions.
2 changes: 0 additions & 2 deletions src/transformers/models/layoutlmv2/configuration_layoutlmv2.py
@@ -80,8 +80,6 @@ class LayoutLMv2Config(PretrainedConfig):
The maximum number of relative 2D positions in the self-attention mechanism.
rel_2d_pos_bins (:obj:`int`, `optional`, defaults to 64):
The number of 2D relative position bins in the self-attention mechanism.
convert_sync_batchnorm (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to convert BatchNorm layers to SyncNorm layers before wrapping the visual backbone with DDP.
image_feature_pool_shape (:obj:`List[int]`, `optional`, defaults to [7, 7, 256]):
The shape of the average-pooled feature map.
coordinate_size (:obj:`int`, `optional`, defaults to 128):
18 changes: 17 additions & 1 deletion src/transformers/models/layoutlmv2/detectron2_config.py
@@ -1,4 +1,20 @@
# -*- coding: utf-8 -*-
# coding=utf-8
# Copyright Microsoft Research and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" LayoutLMv2 Detectron2 visual backbone configuration """


def add_layoutlmv2_config(cfg):
_C = cfg
# -----------------------------------------------------------------------------
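
For context, add_layoutlmv2_config mutates a detectron2 config node in place. A minimal usage sketch, assuming detectron2 is installed and importing the function from the file shown above (illustrative only, not part of the commit):

# Illustrative sketch: apply the LayoutLMv2 defaults to a detectron2 CfgNode.
from detectron2.config import get_cfg
from transformers.models.layoutlmv2.detectron2_config import add_layoutlmv2_config

cfg = get_cfg()             # detectron2's default, unfrozen config node
add_layoutlmv2_config(cfg)  # mutates cfg in place with LayoutLMv2-specific settings
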
56 changes: 31 additions & 25 deletions src/transformers/models/layoutlmv2/modeling_layoutlmv2.py
@@ -359,7 +359,7 @@ def __init__(self, config):
self.rel_pos_x_bias = nn.Linear(self.rel_2d_pos_onehot_size, config.num_attention_heads, bias=False)
self.rel_pos_y_bias = nn.Linear(self.rel_2d_pos_onehot_size, config.num_attention_heads, bias=False)

def _cal_1d_pos_emb(self, hidden_states, position_ids):
def _calculate_1d_position_embeddings(self, hidden_states, position_ids):
rel_pos_mat = position_ids.unsqueeze(-2) - position_ids.unsqueeze(-1)
rel_pos = relative_position_bucket(
rel_pos_mat,
@@ -371,7 +371,7 @@ def _cal_1d_pos_emb(self, hidden_states, position_ids):
rel_pos = rel_pos.contiguous()
return rel_pos

def _cal_2d_pos_emb(self, hidden_states, bbox):
def _calculate_2d_position_embeddings(self, hidden_states, bbox):
position_coord_x = bbox[:, :, 0]
position_coord_y = bbox[:, :, 3]
rel_pos_x_2d_mat = position_coord_x.unsqueeze(-2) - position_coord_x.unsqueeze(-1)
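
Both renamed helpers build their relative-position matrices with the same broadcasting trick: subtracting a coordinate tensor from an expanded copy of itself yields every pairwise difference at once. A small self-contained sketch of that step (the values are made up for illustration):

import torch

position_ids = torch.tensor([[0, 1, 2, 3]])  # (batch, seq_len)
rel_pos_mat = position_ids.unsqueeze(-2) - position_ids.unsqueeze(-1)
# rel_pos_mat has shape (batch, seq_len, seq_len) and
# rel_pos_mat[b, i, j] == position_ids[b, j] - position_ids[b, i]
print(rel_pos_mat[0])
# tensor([[ 0,  1,  2,  3],
#         [-1,  0,  1,  2],
#         [-2, -1,  0,  1],
#         [-3, -2, -1,  0]])
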
@@ -409,8 +409,14 @@ def forward(
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None

rel_pos = self._cal_1d_pos_emb(hidden_states, position_ids) if self.has_relative_attention_bias else None
rel_2d_pos = self._cal_2d_pos_emb(hidden_states, bbox) if self.has_spatial_attention_bias else None
rel_pos = (
self._calculate_1d_position_embeddings(hidden_states, position_ids)
if self.has_relative_attention_bias
else None
)
rel_2d_pos = (
self._calculate_2d_position_embeddings(hidden_states, bbox) if self.has_spatial_attention_bias else None
)

for i, layer_module in enumerate(self.layer):
if output_hidden_states:
@@ -529,27 +535,6 @@ def __init__(self, config):
model = META_ARCH_REGISTRY.get(meta_arch)(self.cfg)
assert isinstance(model.backbone, detectron2.modeling.backbone.FPN)
self.backbone = model.backbone
if (
config.convert_sync_batchnorm
and torch.distributed.is_available()
and torch.distributed.is_initialized()
and torch.distributed.get_rank() > -1
):
self_rank = torch.distributed.get_rank()
node_size = torch.cuda.device_count()
world_size = torch.distributed.get_world_size()
assert world_size % node_size == 0

node_global_ranks = [
list(range(i * node_size, (i + 1) * node_size)) for i in range(world_size // node_size)
]
sync_bn_groups = [
torch.distributed.new_group(ranks=node_global_ranks[i]) for i in range(world_size // node_size)
]
node_rank = self_rank // node_size
assert self_rank in node_global_ranks[node_rank]

self.backbone = my_convert_sync_batchnorm(self.backbone, process_group=sync_bn_groups[node_rank])

assert len(self.cfg.MODEL.PIXEL_MEAN) == len(self.cfg.MODEL.PIXEL_STD)
num_channels = len(self.cfg.MODEL.PIXEL_MEAN)
@@ -582,6 +567,27 @@ def forward(self, images):
features = self.pool(features).flatten(start_dim=2).transpose(1, 2).contiguous()
return features

def synchronize_batch_norm(self):
assert (
torch.distributed.is_available()
and torch.distributed.is_initialized()
and torch.distributed.get_rank() > -1
), "Please make sure torch.distributed is set up properly."

self_rank = torch.distributed.get_rank()
node_size = torch.cuda.device_count()
world_size = torch.distributed.get_world_size()
assert world_size % node_size == 0

node_global_ranks = [list(range(i * node_size, (i + 1) * node_size)) for i in range(world_size // node_size)]
sync_bn_groups = [
torch.distributed.new_group(ranks=node_global_ranks[i]) for i in range(world_size // node_size)
]
node_rank = self_rank // node_size
assert self_rank in node_global_ranks[node_rank]

self.backbone = my_convert_sync_batchnorm(self.backbone, process_group=sync_bn_groups[node_rank])


LAYOUTLMV2_START_DOCSTRING = r"""
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class. Use
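
With the conversion moved out of __init__, BatchNorm synchronization becomes an explicit, opt-in step that must run after the distributed process group exists. A hedged sketch of where such a call could sit in a DDP training script; the attribute path model.layoutlmv2.visual is an assumption for illustration, not taken from this diff:

# Hypothetical usage sketch: convert the visual backbone's BatchNorm layers to
# SyncBatchNorm once torch.distributed has been set up.
import torch.distributed as dist

def maybe_synchronize_batch_norm(model):
    # The method asserts these same preconditions internally.
    if dist.is_available() and dist.is_initialized():
        # Assumed attribute path for illustration; adjust to the actual model structure.
        model.layoutlmv2.visual.synchronize_batch_norm()

Because synchronize_batch_norm derives per-node process groups from torch.cuda.device_count() and the global world size, it assumes one process per GPU and the same number of GPUs on every node.
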
