
GatherLayer on batch axis #1089

Open · wants to merge 5 commits into master

Conversation

vieting (Contributor) commented Aug 4, 2022

This PR fixes #1087. Since I face the issue in the context of supervised multilingual training, I also added a more general test case for that scenario, which does not necessarily need to go into the main branch. The fix is similar to how the size_placeholder is modified in the ShiftAxisLayer.
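
For context, here is a minimal plain-TF sketch of what gathering on the batch axis implies for the dynamic sizes (made-up values; the actual change lives inside GatherLayer):

import tensorflow as tf

values = tf.random.normal([4, 7, 5])    # [B, T, F]
seq_lens = tf.constant([7, 3, 5, 6])    # [B], dynamic sizes of the T axis
position = tf.constant([0, 2, 3])       # [B'], indices into the batch axis

gathered = tf.gather(values, position, axis=0)  # [B', T, F]
new_seq_lens = tf.gather(seq_lens, position)    # [B'], the sizes must be gathered as well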

vieting requested review from albertz and a team as code owners on August 4, 2022 09:32
vieting (Contributor Author) commented Aug 10, 2022

Hi @albertz, what do you think about the way the size placeholder and dim tag are modified in general? Right now there is a failing test case where we first do flatten_batch and then gather on the batch axis. I'm not sure what the desired behavior in this case should look like. It would be nice if you could comment on what you think here.

kind=Dim.Types.Spatial, description="%s_gather_axis" % self.name,
dyn_size=new_size, batch=self.output.batch,
src_data=self.output, src_axis=axis, auto_generated=True)
self.output.size_placeholder[axis] = new_size
albertz (Member):

You don't use the Dim object you created?
Instead of assigning size_placeholder, I think it would be better to set the newly created dim tag.

vieting (Contributor Author):

What is the usual way to set dim tags? I can't just reassign self.output.dim_tags. declare_same_as is used elsewhere, but I'm not sure if it applies here.

albertz (Member):

See most other layers. Usually you set dim_tags in get_out_data_from_opts. You should not assign a new dim tag in __init__. In __init__, you might just assign the dyn_size_ext or dyn_size_ext.placeholder of a dim tag which was previously created in get_out_data_from_opts.
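
For illustration, a standalone sketch of the dyn_size_ext pattern on a freestanding Dim, using the Dim/Data API quoted elsewhere in this thread (a rough example under assumed imports, not actual layer code):

import tensorflow as tf
from returnn.tf.util.data import Data, Dim, batch_dim

# the dim tag is created up front (in a layer: in get_out_data_from_opts), without sizes yet
gather_dim = Dim(kind=Dim.Types.Spatial, description="gather_axis", auto_generated=True)

# later (in a layer: in __init__), only the dynamic sizes are filled in, as a Data object,
# instead of assigning dyn_size or size_placeholder directly
gather_dim.dyn_size_ext = Data(
    name="gather_axis_size", dim_tags=[batch_dim], dtype="int32",
    placeholder=tf.constant([7, 3, 5]))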

# gather targets and encoder outputs
"tgt": {"class": "gather", "from": "data", "axis": "B", "position": "idx"}, # B', T (sparse)
"enc_raw": {"class": "gather", "from": "base:encoder", "axis": "B", "position": "idx"}, # B', T, F
"enc": {"class": "reinterpret_data", "size_base": "tgt", "from": "enc_raw"}, # B', T, F
albertz (Member):

Why is this needed?

returnn/tf/layers/basic.py (outdated review thread, resolved)
from ..util.data import Dim
Dim(
kind=Dim.Types.Spatial, description="%s_gather_axis" % self.name,
dyn_size=new_size, batch=self.output.batch,
albertz (Member):

You're missing dyn_size_ext here.

albertz (Member):

You should not set dyn_size in case it is non-standard. Set dyn_size_ext instead.

albertz (Member) commented Aug 10, 2022

When you modify the batch dim, you should create a new BatchInfo object as well, and assign that to output.

@@ -1341,6 +1341,17 @@ def __init__(self, position, axis, **kwargs):
# (BatchAxes.., InputAxesBeforeGatherAxis, PositionAxes.., InputAxesAfterGatherAxis..)
self.output.placeholder = tf.gather(params=params, indices=indices, axis=gather_axis, batch_dims=batch_dims)

if input_data.dim_tags[old_gather_axis].is_batch_dim():
  for axis in self.output.size_placeholder:
    new_size = tf.gather(params=self.output.size_placeholder[axis], indices=position_data.placeholder)
albertz (Member):

You assume that position_data is of shape [new-batch]?

vieting (Contributor Author):

Right, in the case I have in mind, yes. But for the failing test case this is different, and we need to take that into account.

albertz (Member):

What is it in that case?

vieting (Contributor Author):

There it's of shape [B,T,F]; however, in the input, B and T are packed:

>>> input_data
Data{'flat_output', [B&Packed{'time'},F|F'feature'(5)]}
>>> self.output
Data{'output_output', [B,T|'time'[B],'other-spatial'(7),F|F'feature'(5)]}
>>> position_data
Data{'indices_flat_output', [B,T|'time'[B],F|'other-spatial'(7)], dtype='int32'}

albertz (Member):

Which test case is that? The one you added, test_rand_indices?
Why is position_data of this shape? As described, it should have some new-batch dim in it, right? Or basically just the shape [new-batch], when you gather into the batch dim? It definitely should not have the old batch dim in its shape.
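
(For a concrete, made-up example: with an old batch size B = 4 and position = [0, 2, 3], the gather keeps three sequences, so the output gets a new batch dim B' = 3, and position itself only carries that B' dim, not the old B.)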

vieting (Contributor Author):

I see that I need to assign it for output. But it should come from position_data, right?

albertz (Member):

Why is it None for position_data? I don't mean in the test case, I mean in the real case which motivated this test case. In the real case, you would not have such InternalLayer.

albertz (Member):

It should never be done if the data has a batch dim, unless something is wrong. If that happens in the test case, then the test case is buggy.

vieting (Contributor Author):

Right, this is about the test case. However, in the case that I'm interested in, input_data.batch == position_data.batch is still True. This is probably because I'm using an EvalLayer to get the batch indices from a 0/1 vector with shape (B,), and that EvalLayer does not set the output correctly. Then we would need a layer which does that correctly, right?
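
As an aside, deriving such indices from a 0/1 mask could look roughly like this in plain TF (illustration only, not the EvalLayer code from the config):

import tensorflow as tf

mask = tf.constant([1, 0, 1, 1])               # [B], 0/1 per sequence
idx = tf.squeeze(tf.where(mask > 0), axis=-1)  # [B'] int64 indices, e.g. [0, 2, 3]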

albertz (Member):

An EvalLayer should never change the shape. If it does, and you are not very careful in setting the output data, then yes, this is a bug in your config.

vieting (Contributor Author) commented Aug 10, 2022

> When you modify the batch dim, you should create a new BatchInfo object as well, and assign that to output.

As I said, the fix is similar to what is done in the ShiftAxisLayer. Do you have another layer which modifies the batch axis and could serve as a good example?

albertz (Member) commented Aug 10, 2022

> When you modify the batch dim, you should create a new BatchInfo object as well, and assign that to output.
>
> As I said, the fix is similar to what is done in the ShiftAxisLayer. Do you have another layer which modifies the batch axis and could serve as a good example?

ShiftAxisLayer does not modify the batch dim. You probably mean the size adaption. That code in ShiftAxisLayer is a bit ugly/outdated/deprecated/hacky and might not work correctly in all cases (but anyway, it's simpler there because the batch dim is not changed).

albertz (Member) commented Aug 10, 2022

> Do you have another layer which modifies the batch axis and could serve as a good example?

Not many layers do that. I just recall FlattenBatchLayer right now.

kind=Dim.Types.Spatial, description="%s_gather_axis" % self.name,
dyn_size=new_size, batch=self.output.batch,
src_data=self.output, src_axis=axis, auto_generated=True)
self.output.size_placeholder[axis] = new_size
albertz (Member):

You should not assign size_placeholder but rather the dim tags.

from ..util.data import Dim
Dim(
kind=Dim.Types.Spatial, description="%s_gather_axis" % self.name,
dyn_size=new_size, batch=self.output.batch,
albertz (Member):

You should not assign dyn_size but rather dyn_size_ext.

for dim_tag in self.output.dim_tags:
  if dim_tag.is_spatial_dim():
    axis = self.output.get_batch_axis_excluding_batch(self.output.get_axis_by_tag_name(dim_tag.description))
    new_size = tf.gather(params=self.output.size_placeholder[axis], indices=position_data.placeholder)
albertz (Member):

You should not access size_placeholder but rather dim_tag.dyn_size_ext.

if input_data.dim_tags[old_gather_axis].is_batch_dim():
  for dim_tag in self.output.dim_tags:
    if dim_tag.is_spatial_dim():
      axis = self.output.get_batch_axis_excluding_batch(self.output.get_axis_by_tag_name(dim_tag.description))
albertz (Member):

This is:

  • way too complicated: you can simply do for axis, dim_tag in enumerate(self.output.dim_tags)
  • wrong: do not rely on get_axis_by_tag_name and dim_tag.description
  • not necessary: just use dim_tag.dyn_size_ext
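
A minimal sketch of the simplified loop following these points (it refers to the same variables as the diff under review, so it is illustrative only):

for axis, dim_tag in enumerate(self.output.dim_tags):
  if dim_tag.is_spatial_dim() and dim_tag.dyn_size_ext is not None:
    # gather the per-sequence sizes along the (old) batch axis
    new_size = tf.gather(dim_tag.dyn_size_ext.placeholder, indices=position_data.placeholder)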

Comment on lines +5482 to +5488
position = InternalLayer(
  name="position", network=net,
  output=Data(
    name="position",
    placeholder=tf.constant(position_np, dtype=tf.int64),
    batch_dim_axis=0, shape=[], dtype="int64",
  ))
vieting (Contributor Author):

@albertz do I need to change the creation of position in order to make it have a different batch axis dim tag here?

albertz (Member):

Yes. It's actually not so simple because of the special treatment of the batch dim tag. I'm not sure it's really possible currently.

In practice, in your real code, how would you end up with position?

vieting (Contributor Author):

> In practice, in your real code, how would you end up with position?
Do you mean what dim tag I get there?

>>> position.output.dim_tags[0].description
'batch:position'

So it's not actually the global batch dim. I was just confused because I got

>>> position.output.dim_tags[0] == values.output.dim_tags[0]
True

but this is because the check does not cover this case, see comment here: #1089 (comment)

vieting (Contributor Author) commented Aug 30, 2022

As discussed offline, it is possible to get the desired results in my use case using the MaskedComputationLayer. Instead of the indices to gather, we need a boolean mask over the batch axis. In my use case, I have this anyway and only computed the indices from the mask. We can use the mask like this:

network = {
    "encoder": {...},  # B, T, F
    "boolean_mask": {...},  # B
    "encoder_masked": {
        "class": "masked_computation",
        "mask": "boolean_mask",
        "unit": {"class": "copy", "from": "encoder"}
    },  # B', T, F
    ...
}

Since that does exactly what I need, I'll close this PR and the corresponding issue.
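
Conceptually, masking over the batch axis corresponds to tf.boolean_mask on axis 0; a small illustrative sketch with made-up shapes (not what MaskedComputationLayer literally does internally):

import tensorflow as tf

encoder = tf.random.normal([4, 7, 5])                            # [B, T, F]
boolean_mask = tf.constant([True, False, True, True])            # [B]
encoder_masked = tf.boolean_mask(encoder, boolean_mask, axis=0)  # [B', T, F]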

vieting closed this Aug 30, 2022
albertz (Member) commented Aug 30, 2022

Well, GatherLayer on the batch axis may still sometimes be a valid thing someone wants to do. I would leave this PR open.

albertz reopened this Aug 30, 2022