diff --git a/tensorflow/core/kernels/cudnn_rnn_ops.cc b/tensorflow/core/kernels/cudnn_rnn_ops.cc index b9b96d3fc70fb2..d392452ca18647 100644 --- a/tensorflow/core/kernels/cudnn_rnn_ops.cc +++ b/tensorflow/core/kernels/cudnn_rnn_ops.cc @@ -506,9 +506,19 @@ struct CudnnRnnModelShapes { TensorShape cell_state_shape; // At present only fields related to cached RnnDescriptor are concerned. bool IsCompatibleWith(const CudnnRnnModelShapes& rhs) const { + // Ensure the size of structure does not change, this will help + // people adding new field not forgetting about adding those fields to + // IsCompatibleWith() + // + // sizeof(struct CudnnRnnModelShapes) == 128 + // sizeof(int) * 7 == 28 + // sizeof(TensorShape) * 4 == 96 + static_assert(sizeof(struct CudnnRnnModelShapes) >= 124 + && sizeof(struct CudnnRnnModelShapes) <= 128, + "check struct CudnnRnnModelShapes members"); return num_layers == rhs.num_layers && input_size == rhs.input_size && num_units == rhs.num_units && dir_count == rhs.dir_count && - cell_num_units == rhs.cell_num_units; + cell_num_units == rhs.cell_num_units && max_seq_length == rhs.max_seq_length; } string DebugString() const { return strings::Printf( @@ -530,7 +540,7 @@ struct CudnnRnnConfigHasher { uint64 hash = HashList({shapes.num_layers, shapes.input_size, shapes.num_units, - shapes.dir_count, shapes.batch_size}); + shapes.dir_count, shapes.max_seq_length, shapes.batch_size}); if (algo_desc.has_value()) { hash = Hash64Combine(hash, algo_desc->hash()); }