Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #41630: include max_seq_length in cudnn descriptor cache key #41832

Merged
merged 1 commit into from Aug 7, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 5 additions & 2 deletions tensorflow/core/kernels/cudnn_rnn_ops.cc
Expand Up @@ -500,6 +500,9 @@ struct CudnnRnnModelShapes {
int max_seq_length;
int batch_size;
int cell_num_units = 0;
// If you add new field to this structure, please take care of
// updating IsCompatibleWith() below as well as the hash function in
// CudnnRnnConfigHasher.
TensorShape input_shape;
TensorShape output_shape;
TensorShape hidden_state_shape;
Expand All @@ -508,7 +511,7 @@ struct CudnnRnnModelShapes {
bool IsCompatibleWith(const CudnnRnnModelShapes& rhs) const {
return num_layers == rhs.num_layers && input_size == rhs.input_size &&
num_units == rhs.num_units && dir_count == rhs.dir_count &&
cell_num_units == rhs.cell_num_units;
cell_num_units == rhs.cell_num_units && max_seq_length == rhs.max_seq_length;
lissyx marked this conversation as resolved.
Show resolved Hide resolved
}
string DebugString() const {
return strings::Printf(
Expand All @@ -530,7 +533,7 @@ struct CudnnRnnConfigHasher {

uint64 hash =
HashList({shapes.num_layers, shapes.input_size, shapes.num_units,
shapes.dir_count, shapes.batch_size});
shapes.dir_count, shapes.max_seq_length, shapes.batch_size});
if (algo_desc.has_value()) {
hash = Hash64Combine(hash, algo_desc->hash());
}
Expand Down