Skip to content

Commit

Permalink
[DPR] Correct init (#13796)
Browse files Browse the repository at this point in the history
* update

* add to docs and init

* make fix-copies
  • Loading branch information
patrickvonplaten committed Sep 30, 2021
1 parent 44eb8bd commit 41436d3
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 32 deletions.
7 changes: 7 additions & 0 deletions docs/source/model_doc/dpr.rst
Expand Up @@ -41,6 +41,13 @@ DPRConfig
:members:


DPRPreTrainedModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.DPRPreTrainedModel
:members:


DPRContextEncoderTokenizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
Expand Up @@ -773,6 +773,7 @@
"DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST",
"DPRContextEncoder",
"DPRPretrainedContextEncoder",
"DPRPreTrainedModel",
"DPRPretrainedQuestionEncoder",
"DPRPretrainedReader",
"DPRQuestionEncoder",
Expand Down Expand Up @@ -2512,6 +2513,7 @@
DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST,
DPRContextEncoder,
DPRPretrainedContextEncoder,
DPRPreTrainedModel,
DPRPretrainedQuestionEncoder,
DPRPretrainedReader,
DPRQuestionEncoder,
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/models/dpr/__init__.py
Expand Up @@ -46,6 +46,7 @@
"DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST",
"DPRContextEncoder",
"DPRPretrainedContextEncoder",
"DPRPreTrainedModel",
"DPRPretrainedQuestionEncoder",
"DPRPretrainedReader",
"DPRQuestionEncoder",
Expand Down Expand Up @@ -89,6 +90,7 @@
DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST,
DPRContextEncoder,
DPRPretrainedContextEncoder,
DPRPreTrainedModel,
DPRPretrainedQuestionEncoder,
DPRPretrainedReader,
DPRQuestionEncoder,
Expand Down
59 changes: 27 additions & 32 deletions src/transformers/models/dpr/modeling_dpr.py
Expand Up @@ -147,7 +147,29 @@ class DPRReaderOutput(ModelOutput):
attentions: Optional[Tuple[torch.FloatTensor]] = None


class DPREncoder(PreTrainedModel):
class DPRPreTrainedModel(PreTrainedModel):
    def _init_weights(self, module):
        """Initialize module weights following BERT's scheme (normal for
        Linear/Embedding, ones/zeros for LayerNorm)."""
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version, which uses truncated_normal
            # for initialization — cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                # Keep the padding embedding at zero after re-initialization.
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.weight.data.fill_(1.0)
            module.bias.data.zero_()

    def _set_gradient_checkpointing(self, module, value=False):
        # Gradient checkpointing is toggled on the wrapped BERT encoder stack.
        if isinstance(module, BertEncoder):
            module.gradient_checkpointing = value


class DPREncoder(DPRPreTrainedModel):

base_model_prefix = "bert_model"

Expand Down Expand Up @@ -200,13 +222,8 @@ def embeddings_size(self) -> int:
return self.encode_proj.out_features
return self.bert_model.config.hidden_size

def init_weights(self):
self.bert_model.init_weights()
if self.projection_dim > 0:
self.encode_proj.apply(self.bert_model._init_weights)


class DPRSpanPredictor(PreTrainedModel):
class DPRSpanPredictor(DPRPreTrainedModel):

base_model_prefix = "encoder"

Expand Down Expand Up @@ -262,16 +279,13 @@ def forward(
attentions=outputs.attentions,
)

def init_weights(self):
self.encoder.init_weights()


##################
# PreTrainedModel
##################


class DPRPretrainedContextEncoder(PreTrainedModel):
class DPRPretrainedContextEncoder(DPRPreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
Expand All @@ -282,11 +296,8 @@ class DPRPretrainedContextEncoder(PreTrainedModel):
base_model_prefix = "ctx_encoder"
_keys_to_ignore_on_load_missing = [r"position_ids"]

def init_weights(self):
self.ctx_encoder.init_weights()


class DPRPretrainedQuestionEncoder(PreTrainedModel):
class DPRPretrainedQuestionEncoder(DPRPreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
Expand All @@ -297,15 +308,8 @@ class DPRPretrainedQuestionEncoder(PreTrainedModel):
base_model_prefix = "question_encoder"
_keys_to_ignore_on_load_missing = [r"position_ids"]

def init_weights(self):
self.question_encoder.init_weights()

def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, BertEncoder):
module.gradient_checkpointing = value


class DPRPretrainedReader(PreTrainedModel):
class DPRPretrainedReader(DPRPreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
Expand All @@ -316,15 +320,6 @@ class DPRPretrainedReader(PreTrainedModel):
base_model_prefix = "span_predictor"
_keys_to_ignore_on_load_missing = [r"position_ids"]

def init_weights(self):
self.span_predictor.encoder.init_weights()
self.span_predictor.qa_classifier.apply(self.span_predictor.encoder.bert_model._init_weights)
self.span_predictor.qa_outputs.apply(self.span_predictor.encoder.bert_model._init_weights)

def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, BertEncoder):
module.gradient_checkpointing = value


###############
# Actual Models
Expand Down
9 changes: 9 additions & 0 deletions src/transformers/utils/dummy_pt_objects.py
Expand Up @@ -1462,6 +1462,15 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


# Auto-generated dummy (utils/dummy_pt_objects.py): stands in for the real
# DPRPreTrainedModel when torch is unavailable; `requires_backends` presumably
# raises an informative error on use — keep in sync via `make fix-copies`.
class DPRPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class DPRPretrainedQuestionEncoder:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
Expand Down
14 changes: 14 additions & 0 deletions tests/test_modeling_dpr.py
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.


import tempfile
import unittest

from transformers import DPRConfig, is_torch_available
Expand Down Expand Up @@ -213,6 +214,19 @@ def test_reader_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_reader(*config_and_inputs)

def test_init_changed_config(self):
    """Reload a saved encoder with a changed ``projection_dim``.

    The newly-sized projection head has no weights in the checkpoint, so
    loading must fall back to fresh initialization (the behavior this
    commit's ``DPRPreTrainedModel._init_weights`` enables) instead of
    raising. NOTE(review): asserts only that loading succeeds, not the
    init values themselves.
    """
    config = self.model_tester.prepare_config_and_inputs()[0]

    model = DPRQuestionEncoder(config=config)
    model.to(torch_device)
    model.eval()

    with tempfile.TemporaryDirectory() as tmp_dirname:
        model.save_pretrained(tmp_dirname)
        # projection_dim=512 differs from the saved config, forcing missing
        # head weights to be created on load.
        model = DPRQuestionEncoder.from_pretrained(tmp_dirname, projection_dim=512)

    self.assertIsNotNone(model)

@slow
def test_model_from_pretrained(self):
for model_name in DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
Expand Down

0 comments on commit 41436d3

Please sign in to comment.