
More rigorous shape inference in to_tf_dataset #4763

Merged
Rocketknight1 merged 5 commits into main from update_tf_shape_inference on Sep 8, 2022

Conversation

Rocketknight1 (Member) commented on Jul 28, 2022

tf.data needs to know the shape of tensors emitted from a tf.data.Dataset. Although None dimensions are possible, overusing them can cause problems - Keras uses the dataset tensor spec at compile-time, and so saying that a dimension is None when it's actually constant can hurt performance, or even cause training to fail for dimensions that are needed to determine the shape of weight tensors!

The compromise I used here was to sample several batches from the underlying dataset and apply the collate_fn to them, and then to see which dimensions were "empirically variable". There's an obvious problem here, though - if you sample 10 batches and they all have the same shape on a certain dimension, there's still a small chance that the 11th batch will be different, and Keras will throw an error if a dataset tries to emit a tensor whose shape doesn't match the spec.

I encountered this bug in practice once or twice for datasets that were mostly-but-not-totally constant on a given dimension, and I still don't have a perfect solution, but this PR should greatly reduce the risk. It samples many more batches, and also samples very small batches (size 2) - this increases the variability, making it more likely that a few outlier samples will be detected.

Ideally, of course, we'd determine the full output shape analytically, but that's surprisingly tricky when the collate_fn can be any arbitrary Python code!
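To make the approach above concrete, here is a minimal, hypothetical sketch of "empirical" shape inference. It is not the actual implementation in datasets: the helper name, its arguments, and the assumption that collate_fn returns a dict of arrays with a .shape attribute are all illustrative. The idea is simply to collate several small sampled batches and mark any axis whose size changes between batches as None.

```python
# Hypothetical sketch of empirical shape inference, not the real datasets code.
# Assumptions: dataset[i] returns a dict of features, and collate_fn returns a
# dict of arrays/tensors exposing a .shape attribute.
import random


def infer_shapes_by_sampling(dataset, collate_fn, num_test_batches=20, batch_size=2):
    """Sample small batches and mark dimensions that vary across batches as None."""
    inferred = {}
    for _ in range(num_test_batches):
        indices = random.sample(range(len(dataset)), batch_size)
        batch = collate_fn([dataset[i] for i in indices])
        for key, tensor in batch.items():
            shape = list(tensor.shape)
            if key not in inferred:
                inferred[key] = shape
            else:
                # Any axis whose size differs between two sampled batches is variable.
                inferred[key] = [
                    seen if seen == dim else None
                    for seen, dim in zip(inferred[key], shape)
                ]
    return inferred
```

In the PR itself, logic along these lines lives in _get_output_signature, whose result becomes the tf.data output signature (a tf.TensorSpec per column).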

HuggingFaceDocBuilderDev commented on Jul 28, 2022

The documentation is not available anymore as the PR was closed or merged.

@@ -420,6 +420,26 @@ def to_tf_dataset(
batch_size=batch_size if drop_remainder else None,
)

shape_verification_signature, _ = dataset._get_output_signature(
lhoestq (Member) commented on this line:

Why do you need to call it a second time? Can't this logic be inside _get_output_signature?

Rocketknight1 (Member, Author) replied:
That would make sense, actually! I'll move it.

Rocketknight1 (Member, Author) commented:
@lhoestq I cleaned things up a lot based on your feedback - _get_output_signature is only called once, and it now immediately samples 200 batches of size 2 to infer the shape, but then overwrites the batch size element of the inferred shape with the actual batch size.
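As a rough illustration of that last step, the sketch below keeps the per-column shapes inferred from the small sampled batches but replaces the batch axis with the batch size the user actually requested, mirroring the batch_size if drop_remainder else None pattern visible in the diff. The function name and arguments are hypothetical, not the PR's actual code; the only real API used is tf.TensorSpec.

```python
# Hypothetical sketch: overwrite the batch dimension of an inferred signature
# (a dict of column name -> tf.TensorSpec) with the real batch size, or None
# when the smaller remainder batch is kept.
import tensorflow as tf


def fix_batch_dimension(inferred_signature, batch_size, drop_remainder):
    actual_batch = batch_size if drop_remainder else None
    return {
        key: tf.TensorSpec(shape=(actual_batch, *spec.shape[1:]), dtype=spec.dtype)
        for key, spec in inferred_signature.items()
    }
```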

lhoestq (Member) replied:

Cool! :)

I also think 10 batches is a good default; going up to 200 batches can take too much time for some datasets, IMO.

Rocketknight1 (Member, Author) replied:
I actually specifically had problems with incorrect inferences when using 10! I think it's preferable for to_tf_dataset() to be a little slow sometimes (it's only called once at dataset creation time) than to infer wrong shapes and create tricky bugs for users.

If you want, though, I can make num_test_batches an argument to to_tf_dataset?

lhoestq (Member) replied:

> I actually specifically had problems with incorrect inferences when using 10!

Can you explain what problems?

Rocketknight1 (Member, Author) replied:

In some cases, sampling 10 batches from the dataset makes it look like the dataset has a constant shape when it actually doesn't. This is particularly common when datasets have been truncated. For example, if the average length before truncation is much greater than 512 and we truncate at 512, most batches will come out with length 512; but if some samples are shorter than 512, a batch that happens to contain only those shorter samples will occasionally come out with length < 512 as well.

By reducing the batch size for shape inference and increasing the number of batches sampled, this problem is resolved in all the cases I know about!
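A quick back-of-the-envelope calculation illustrates why the smaller sampling batch size matters here. It assumes the collator pads each batch to its longest sample, so a batch only falls below the truncation length when every sample in it is short; the numbers are illustrative rather than measurements from a real dataset.

```python
# If a fraction p_short of samples are shorter than the truncation length,
# a batch of size b only comes out shorter when all b samples are short,
# i.e. with probability p_short ** b. The chance that no sampled batch ever
# reveals the variable dimension across n batches is then (1 - p_short**b) ** n.
def prob_variation_missed(p_short, batch_size, num_batches):
    p_short_batch = p_short ** batch_size
    return (1 - p_short_batch) ** num_batches


print(prob_variation_missed(0.10, batch_size=8, num_batches=10))   # ~1.0  (almost certainly missed)
print(prob_variation_missed(0.10, batch_size=2, num_batches=200))  # ~0.13 (usually detected)
```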

lhoestq (Member) replied:

What about adding a way for users to specify whether the shapes are fixed or not? It could be via a new parameter, or by checking whether the feature type is Sequence(..., length=512).

Rocketknight1 (Member, Author) replied:
I think that's a good idea! We'll still need shape inference but it might be useful, and I can look into adding it when I get back!

Rocketknight1 (Member, Author) commented:
@lhoestq Reading the shape from Sequence features has been added!
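For reference, the idea is roughly the following sketch (not the merged code; the helper name is made up). In datasets, a Sequence feature stores length=-1 when its length is not fixed, so a fixed length can be read straight from the feature type and only the remaining dimensions need empirical inference.

```python
# Sketch: read a constant dimension directly from the feature type instead of
# relying purely on sampled batches. Sequence.length is -1 for variable length.
from datasets import Sequence


def fixed_length_from_feature(feature):
    if isinstance(feature, Sequence) and feature.length != -1:
        return feature.length  # e.g. Sequence(Value("int64"), length=512) -> 512
    return None  # unknown/variable: fall back to empirical shape inference
```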

gante (Member) left a comment:

LGTM 👍

lhoestq (Member) left a comment:

Thanks!

One further review thread on src/datasets/arrow_dataset.py was marked outdated and resolved.
Rocketknight1 merged commit 08a7b38 into main on Sep 8, 2022.
Rocketknight1 deleted the update_tf_shape_inference branch on September 8, 2022 at 19:15.