[DPR] Correct init #13796

Merged
7 changes: 7 additions & 0 deletions docs/source/model_doc/dpr.rst
@@ -41,6 +41,13 @@ DPRConfig
     :members:
 
 
+DPRPreTrainedModel
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.DPRPreTrainedModel
+    :members:
+
+
 DPRContextEncoderTokenizer
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -773,6 +773,7 @@
             "DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST",
             "DPRContextEncoder",
             "DPRPretrainedContextEncoder",
+            "DPRPreTrainedModel",
             "DPRPretrainedQuestionEncoder",
             "DPRPretrainedReader",
             "DPRQuestionEncoder",
@@ -2512,6 +2513,7 @@
         DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST,
         DPRContextEncoder,
         DPRPretrainedContextEncoder,
+        DPRPreTrainedModel,
         DPRPretrainedQuestionEncoder,
         DPRPretrainedReader,
         DPRQuestionEncoder,
2 changes: 2 additions & 0 deletions src/transformers/models/dpr/__init__.py
@@ -46,6 +46,7 @@
         "DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST",
         "DPRContextEncoder",
         "DPRPretrainedContextEncoder",
+        "DPRPreTrainedModel",
         "DPRPretrainedQuestionEncoder",
         "DPRPretrainedReader",
         "DPRQuestionEncoder",
@@ -89,6 +90,7 @@
         DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST,
         DPRContextEncoder,
         DPRPretrainedContextEncoder,
+        DPRPreTrainedModel,
         DPRPretrainedQuestionEncoder,
         DPRPretrainedReader,
         DPRQuestionEncoder,
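With both `__init__.py` files updated, the new base class is exposed at the top level of the package as well as in the DPR subpackage. A quick sanity check (a sketch, assuming a PyTorch-enabled install that includes this change):

# Both import paths should resolve to the same class object.
from transformers import DPRPreTrainedModel
from transformers.models.dpr import DPRPreTrainedModel as DPRPreTrainedModelAlias

assert DPRPreTrainedModel is DPRPreTrainedModelAlias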
59 changes: 27 additions & 32 deletions src/transformers/models/dpr/modeling_dpr.py
@@ -147,7 +147,29 @@ class DPRReaderOutput(ModelOutput):
     attentions: Optional[Tuple[torch.FloatTensor]] = None
 
 
-class DPREncoder(PreTrainedModel):
+class DPRPreTrainedModel(PreTrainedModel):
+    def _init_weights(self, module):
+        """Initialize the weights"""
+        if isinstance(module, nn.Linear):
+            # Slightly different from the TF version which uses truncated_normal for initialization
+            # cf https://github.com/pytorch/pytorch/pull/5617
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+
+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, BertEncoder):
+            module.gradient_checkpointing = value
+
+
+class DPREncoder(DPRPreTrainedModel):
 
     base_model_prefix = "bert_model"

@@ -200,13 +222,8 @@ def embeddings_size(self) -> int:
             return self.encode_proj.out_features
         return self.bert_model.config.hidden_size
 
-    def init_weights(self):
-        self.bert_model.init_weights()
-        if self.projection_dim > 0:
-            self.encode_proj.apply(self.bert_model._init_weights)
-
 
-class DPRSpanPredictor(PreTrainedModel):
+class DPRSpanPredictor(DPRPreTrainedModel):
 
     base_model_prefix = "encoder"
 
@@ -262,16 +279,13 @@ def forward(
             attentions=outputs.attentions,
         )
 
-    def init_weights(self):
-        self.encoder.init_weights()
-
 
 ##################
 # PreTrainedModel
 ##################
 
 
-class DPRPretrainedContextEncoder(PreTrainedModel):
+class DPRPretrainedContextEncoder(DPRPreTrainedModel):
     """
     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
     models.
@@ -282,11 +296,8 @@ class DPRPretrainedContextEncoder(PreTrainedModel):
     base_model_prefix = "ctx_encoder"
     _keys_to_ignore_on_load_missing = [r"position_ids"]
 
-    def init_weights(self):
-        self.ctx_encoder.init_weights()
-
 
-class DPRPretrainedQuestionEncoder(PreTrainedModel):
+class DPRPretrainedQuestionEncoder(DPRPreTrainedModel):
     """
     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
     models.
@@ -297,15 +308,8 @@ class DPRPretrainedQuestionEncoder(PreTrainedModel):
     base_model_prefix = "question_encoder"
     _keys_to_ignore_on_load_missing = [r"position_ids"]
 
-    def init_weights(self):
-        self.question_encoder.init_weights()
-
-    def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, BertEncoder):
-            module.gradient_checkpointing = value
-
 
-class DPRPretrainedReader(PreTrainedModel):
+class DPRPretrainedReader(DPRPreTrainedModel):
     """
     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
     models.
@@ -316,15 +320,6 @@ class DPRPretrainedReader(PreTrainedModel):
     base_model_prefix = "span_predictor"
     _keys_to_ignore_on_load_missing = [r"position_ids"]
 
-    def init_weights(self):
-        self.span_predictor.encoder.init_weights()
-        self.span_predictor.qa_classifier.apply(self.span_predictor.encoder.bert_model._init_weights)
-        self.span_predictor.qa_outputs.apply(self.span_predictor.encoder.bert_model._init_weights)
-
-    def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, BertEncoder):
-            module.gradient_checkpointing = value
-
 
 ###############
 # Actual Models
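The net effect of this file's changes: the per-class `init_weights` overrides, which delegated initialization by hand and could miss freshly created modules (e.g. a resized `encode_proj`), are replaced by a single `_init_weights` hook on the shared `DPRPreTrainedModel` base. The base `PreTrainedModel` machinery applies that hook to every submodule, so all DPR wrappers are covered uniformly. A minimal sketch of that mechanism, using illustrative stand-in classes rather than the real library internals:

import torch.nn as nn


# Stand-ins for illustration only; the real classes live in transformers.
class SketchPreTrainedModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

    def init_weights(self):
        # nn.Module.apply visits every submodule recursively, so a single
        # _init_weights override on a shared base covers every DPR head,
        # including projection layers created after loading.
        self.apply(self._init_weights)


class SketchDPRModel(SketchPreTrainedModel):
    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()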
9 changes: 9 additions & 0 deletions src/transformers/utils/dummy_pt_objects.py
@@ -1462,6 +1462,15 @@ def __init__(self, *args, **kwargs):
         requires_backends(self, ["torch"])
 
 
+class DPRPreTrainedModel:
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+
 class DPRPretrainedQuestionEncoder:
     def __init__(self, *args, **kwargs):
         requires_backends(self, ["torch"])
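The dummy object keeps `from transformers import DPRPreTrainedModel` working in an installation without PyTorch, while failing loudly the moment the class is actually used. Roughly, in a torch-free environment (illustration only):

from transformers import DPRPreTrainedModel  # resolves to the dummy class

try:
    DPRPreTrainedModel()
except ImportError as err:
    print(err)  # requires_backends reports that the PyTorch backend is missing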
14 changes: 14 additions & 0 deletions tests/test_modeling_dpr.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 
+import tempfile
 import unittest
 
 from transformers import DPRConfig, is_torch_available
@@ -213,6 +214,19 @@ def test_reader_model(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_reader(*config_and_inputs)
 
+    def test_init_changed_config(self):
+        config = self.model_tester.prepare_config_and_inputs()[0]
+
+        model = DPRQuestionEncoder(config=config)
+        model.to(torch_device)
+        model.eval()
+
+        with tempfile.TemporaryDirectory() as tmp_dirname:
+            model.save_pretrained(tmp_dirname)
+            model = DPRQuestionEncoder.from_pretrained(tmp_dirname, projection_dim=512)
+
+        self.assertIsNotNone(model)
+
     @slow
     def test_model_from_pretrained(self):
         for model_name in DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
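The new test pins down the regression this PR fixes: reloading a saved DPR checkpoint with a different `projection_dim` creates a fresh `encode_proj` layer, and without a working `_init_weights` there was nothing to initialize it with. The same scenario outside the test harness looks roughly like this (a sketch; the config values are illustrative):

import tempfile

from transformers import DPRConfig, DPRQuestionEncoder

# Save an encoder without a projection head, then reload it with one. The
# missing encode_proj weights are now initialized by DPRPreTrainedModel._init_weights.
config = DPRConfig(projection_dim=0)
model = DPRQuestionEncoder(config)

with tempfile.TemporaryDirectory() as tmp_dirname:
    model.save_pretrained(tmp_dirname)
    reloaded = DPRQuestionEncoder.from_pretrained(tmp_dirname, projection_dim=512)

print(reloaded.config.projection_dim)  # 512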