diff --git a/tensorflow/core/kernels/cudnn_rnn_ops.cc b/tensorflow/core/kernels/cudnn_rnn_ops.cc index 4a27394f289a29..1eeab78943033b 100644 --- a/tensorflow/core/kernels/cudnn_rnn_ops.cc +++ b/tensorflow/core/kernels/cudnn_rnn_ops.cc @@ -500,6 +500,9 @@ struct CudnnRnnModelShapes { int max_seq_length; int batch_size; int cell_num_units = 0; + // If you add new field to this structure, please take care of + // updating IsCompatibleWith() below as well as the hash function in + // CudnnRnnConfigHasher. TensorShape input_shape; TensorShape output_shape; TensorShape hidden_state_shape; @@ -508,7 +511,7 @@ struct CudnnRnnModelShapes { bool IsCompatibleWith(const CudnnRnnModelShapes& rhs) const { return num_layers == rhs.num_layers && input_size == rhs.input_size && num_units == rhs.num_units && dir_count == rhs.dir_count && - cell_num_units == rhs.cell_num_units; + cell_num_units == rhs.cell_num_units && max_seq_length == rhs.max_seq_length; } string DebugString() const { return strings::Printf( @@ -530,7 +533,7 @@ struct CudnnRnnConfigHasher { uint64 hash = HashList({shapes.num_layers, shapes.input_size, shapes.num_units, - shapes.dir_count, shapes.batch_size}); + shapes.dir_count, shapes.max_seq_length, shapes.batch_size}); if (algo_desc.has_value()) { hash = Hash64Combine(hash, algo_desc->hash()); }