Skip to content

Commit

Permalink
Add accelerate support for ViLT (huggingface#18683)
Browse files Browse the repository at this point in the history
  • Loading branch information
younesbelkada authored and oneraghavan committed Sep 26, 2022
1 parent e29a559 commit 3d257ce
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 7 deletions.
3 changes: 2 additions & 1 deletion src/transformers/models/vilt/modeling_vilt.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,7 @@ def forward(self, hidden_states, attention_mask=None, head_mask=None, output_att
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights

# first residual connection
hidden_states = attention_output + hidden_states
hidden_states = attention_output + hidden_states.to(attention_output.device)

# in ViLT, layernorm is also applied after self-attention
layer_output = self.layernorm_after(hidden_states)
Expand Down Expand Up @@ -573,6 +573,7 @@ class ViltPreTrainedModel(PreTrainedModel):
config_class = ViltConfig
base_model_prefix = "vilt"
supports_gradient_checkpointing = True
_no_split_modules = ["ViltSelfAttention"]

def _init_weights(self, module):
"""Initialize the weights"""
Expand Down
4 changes: 0 additions & 4 deletions src/transformers/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -772,7 +772,6 @@ class CaptureStd:
```"""

def __init__(self, out=True, err=True, replay=True):

self.replay = replay

if out:
Expand Down Expand Up @@ -1122,7 +1121,6 @@ def get_auto_remove_tmp_dir(self, tmp_dir=None, before=None, after=None):
tmp_dir(`string`): either the same value as passed via *tmp_dir* or the path to the auto-selected tmp dir
"""
if tmp_dir is not None:

# defining the most likely desired behavior for when a custom path is provided.
# this most likely indicates the debug mode where we want an easily locatable dir that:
# 1. gets cleared out before the test (if it already exists)
Expand Down Expand Up @@ -1200,7 +1198,6 @@ def python_one_liner_max_rss(self, one_liner_str):
return max_rss

def tearDown(self):

# get_auto_remove_tmp_dir feature: remove registered temp dirs
for path in self.teardown_tmp_dirs:
shutil.rmtree(path, ignore_errors=True)
Expand Down Expand Up @@ -1472,7 +1469,6 @@ def tee(line, sink, pipe, label=""):


def execute_subprocess_async(cmd, env=None, stdin=None, timeout=180, quiet=False, echo=True) -> _RunOutput:

loop = asyncio.get_event_loop()
result = loop.run_until_complete(
_stream_subprocess(cmd, env=env, stdin=stdin, timeout=timeout, quiet=quiet, echo=echo)
Expand Down
2 changes: 0 additions & 2 deletions tests/models/vilt/test_modeling_vilt.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,6 @@ def prepare_pixel_values(self):

@require_torch
class ViltModelTest(ModelTesterMixin, unittest.TestCase):

all_model_classes = (
(
ViltModel,
Expand Down Expand Up @@ -512,7 +511,6 @@ def test_model_from_pretrained(self):

@require_torch
class ViltForImagesAndTextClassificationModelTest(ViltModelTest, unittest.TestCase):

all_model_classes = (ViltForImagesAndTextClassification,) if is_torch_available() else ()

def setUp(self):
Expand Down
10 changes: 10 additions & 0 deletions tests/test_modeling_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2307,6 +2307,7 @@ def test_disk_offload(self):
inputs_dict = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config).eval()
model = model.to(torch_device)
torch.manual_seed(0)
base_output = model(**inputs_dict)

model_size = compute_module_sizes(model)[""]
Expand All @@ -2324,6 +2325,7 @@ def test_disk_offload(self):
)

self.check_device_map_is_respected(new_model, new_model.hf_device_map)
torch.manual_seed(0)
new_output = new_model(**inputs_dict)

self.assertTrue(torch.allclose(base_output[0], new_output[0]))
Expand All @@ -2340,6 +2342,8 @@ def test_cpu_offload(self):
inputs_dict = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config).eval()
model = model.to(torch_device)

torch.manual_seed(0)
base_output = model(**inputs_dict)

model_size = compute_module_sizes(model)[""]
Expand All @@ -2355,6 +2359,8 @@ def test_cpu_offload(self):
self.assertSetEqual(set(new_model.hf_device_map.values()), {0, "cpu"})

self.check_device_map_is_respected(new_model, new_model.hf_device_map)

torch.manual_seed(0)
new_output = new_model(**inputs_dict)

self.assertTrue(torch.allclose(base_output[0], new_output[0]))
Expand All @@ -2371,6 +2377,8 @@ def test_model_parallelism(self):
inputs_dict = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config).eval()
model = model.to(torch_device)

torch.manual_seed(0)
base_output = model(**inputs_dict)

model_size = compute_module_sizes(model)[""]
Expand All @@ -2386,6 +2394,8 @@ def test_model_parallelism(self):
self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1})

self.check_device_map_is_respected(new_model, new_model.hf_device_map)

torch.manual_seed(0)
new_output = new_model(**inputs_dict)

self.assertTrue(torch.allclose(base_output[0], new_output[0]))
Expand Down

0 comments on commit 3d257ce

Please sign in to comment.