
Sparsemax layer does not actually support masking even though the supports_masking flag is True #2750

Open
pushpendre opened this issue Aug 23, 2022 · 2 comments

@pushpendre

Describe the bug

The Sparsemax layer sets the supports_masking flag to True, but its call method does not accept a mask argument, so any incoming Keras mask is silently ignored.

Code to reproduce the issue

See https://github.com/tensorflow/addons/blob/v0.17.0/tensorflow_addons/layers/sparsemax.py#L36
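A quick way to confirm the mismatch by inspecting the layer (a minimal sketch, assuming tensorflow-addons v0.17.0):

    import inspect
    from tensorflow_addons.layers import Sparsemax

    layer = Sparsemax()
    print(layer.supports_masking)          # True
    print(inspect.signature(layer.call))   # (inputs) -- no `mask` parameter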

Other info / logs

This issue can be fixed by accepting and applying the mask in call, for example:

    BIG_NUMBER = 1e6

    def call(self, inputs, mask=None):
        if mask is None:
            return sparsemax(inputs, axis=self.axis)
        mask = tf.cast(mask, inputs.dtype)
        # Shift masked logits far down so sparsemax gives them zero weight.
        return sparsemax(inputs - (1 - mask) * BIG_NUMBER, axis=self.axis) * mask

A large finite offset is used instead of -inf so that the intermediate arithmetic inside sparsemax stays finite.
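As a quick sanity check of the masking trick itself (a minimal sketch using the functional sparsemax activation; the 1e6 shift mirrors BIG_NUMBER above):

    import tensorflow as tf
    from tensorflow_addons.activations import sparsemax

    scores = tf.constant([[1.0, 2.0, 5.0]])
    mask = tf.constant([[1.0, 1.0, 0.0]])
    # Shifting the masked logit down by 1e6 forces its sparsemax weight to zero.
    print(sparsemax(scores - (1 - mask) * 1e6).numpy())  # -> [[0., 1., 0.]]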

The fix can be tested as follows:

    import numpy as np
    from tensorflow_addons.layers import Sparsemax
    from tensorflow_addons.activations.tests.sparsemax_test import _np_sparsemax

    att_scores = np.random.randn(2, 3, 5, 7).astype(np.float32)
    mask = (np.random.uniform(0, 1, (2, 1, 5, 7)) > 0.5).astype(np.float32)

    # The masked output should match the NumPy reference sparsemax computed
    # over only the unmasked positions.
    out_sparsemax = Sparsemax()(att_scores, mask).numpy()
    for i in range(2):
        for j in range(3):
            for k in range(5):
                m = mask[i, 0, k].astype(bool)
                a = out_sparsemax[i, j, k][m]
                b = _np_sparsemax(att_scores[i, j, k][m][np.newaxis, :])[0]
                assert np.allclose(a, b, atol=1e-6)

@bhack
Contributor

bhack commented Aug 23, 2022

Can you open a PR that extends the current tests and commits a fix?

@pushpendre
Author

I have a couple of deadlines coming up; I'll try to put something together, probably after Sep 15.
