Skip to content

Commit

Permalink
added integration test
Browse files Browse the repository at this point in the history
  • Loading branch information
kamalkraj committed Aug 27, 2021
1 parent 4ed282a commit 9c75ba8
Showing 1 changed file with 18 additions and 0 deletions.
18 changes: 18 additions & 0 deletions tests/test_modeling_flax_albert.py
Expand Up @@ -23,6 +23,7 @@


if is_flax_available():
import jax.numpy as jnp
from transformers.models.albert.modeling_flax_albert import (
FlaxAlbertForMaskedLM,
FlaxAlbertForMultipleChoice,
Expand Down Expand Up @@ -141,3 +142,20 @@ def test_model_from_pretrained(self):
model = model_class_name.from_pretrained("albert-base-v2", from_pt=True)
outputs = model(np.ones((1, 1)))
self.assertIsNotNone(outputs)


@require_flax
class AlbertModelIntegrationTest(unittest.TestCase):
@slow
def test_inference_no_head_absolute_embedding(self):
model = FlaxAlbertModel.from_pretrained("albert-base-v2", from_pt=True)
input_ids = np.array([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 1588, 2]])
attention_mask = np.array([[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
output = model(input_ids, attention_mask=attention_mask)[0]
expected_shape = (1, 11, 768)
self.assertEqual(output.shape, expected_shape)
expected_slice = np.array(
[[[-0.6513, 1.5035, -0.2766], [-0.6515, 1.5046, -0.2780], [-0.6512, 1.5049, -0.2784]]]
)

self.assertTrue(jnp.allclose(output[:, 1:4, 1:4], expected_slice, atol=1e-4))

0 comments on commit 9c75ba8

Please sign in to comment.