Skip to content

Commit

Permalink
Fix tensorflow#41630: include max_seq_length in cudnn descriptor cach…
Browse files Browse the repository at this point in the history
…e key
  • Loading branch information
lissyx authored and Alexandre Lissy committed Jul 29, 2020
1 parent 111f48d commit a63731f
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions tensorflow/core/kernels/cudnn_rnn_ops.cc
Expand Up @@ -506,9 +506,19 @@ struct CudnnRnnModelShapes {
TensorShape cell_state_shape;
// At present only fields related to cached RnnDescriptor are concerned.
bool IsCompatibleWith(const CudnnRnnModelShapes& rhs) const {
// Ensure the size of structure does not change, this will help
// people adding new field not forgetting about adding those fields to
// IsCompatibleWith()
//
// sizeof(struct CudnnRnnModelShapes) == 128
// sizeof(int) * 7 == 28
// sizeof(TensorShape) * 4 == 96
static_assert(sizeof(struct CudnnRnnModelShapes) >= 124
&& sizeof(struct CudnnRnnModelShapes) <= 128,
"check struct CudnnRnnModelShapes members");
return num_layers == rhs.num_layers && input_size == rhs.input_size &&
num_units == rhs.num_units && dir_count == rhs.dir_count &&
cell_num_units == rhs.cell_num_units;
cell_num_units == rhs.cell_num_units && max_seq_length == rhs.max_seq_length;
}
string DebugString() const {
return strings::Printf(
Expand All @@ -530,7 +540,7 @@ struct CudnnRnnConfigHasher {

uint64 hash =
HashList({shapes.num_layers, shapes.input_size, shapes.num_units,
shapes.dir_count, shapes.batch_size});
shapes.dir_count, shapes.max_seq_length, shapes.batch_size});
if (algo_desc.has_value()) {
hash = Hash64Combine(hash, algo_desc->hash());
}
Expand Down

0 comments on commit a63731f

Please sign in to comment.