From 9442cbb955147567792c37ace0090f7ae9df8360 Mon Sep 17 00:00:00 2001 From: Alexandre Lissy Date: Tue, 28 Jul 2020 20:41:54 +0200 Subject: [PATCH] Fix #41630: include max_seq_length in cudnn descriptor cache key --- tensorflow/core/kernels/cudnn_rnn_ops.cc | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/kernels/cudnn_rnn_ops.cc b/tensorflow/core/kernels/cudnn_rnn_ops.cc index b9b96d3fc70fb2..1a3f05fdcd9700 100644 --- a/tensorflow/core/kernels/cudnn_rnn_ops.cc +++ b/tensorflow/core/kernels/cudnn_rnn_ops.cc @@ -500,6 +500,9 @@ struct CudnnRnnModelShapes { int max_seq_length; int batch_size; int cell_num_units = 0; + // If you add new field to this structure, please take care of + // updating IsCompatibleWith() below as well as the hash function in + // CudnnRnnConfigHasher. TensorShape input_shape; TensorShape output_shape; TensorShape hidden_state_shape; @@ -508,7 +511,7 @@ struct CudnnRnnModelShapes { bool IsCompatibleWith(const CudnnRnnModelShapes& rhs) const { return num_layers == rhs.num_layers && input_size == rhs.input_size && num_units == rhs.num_units && dir_count == rhs.dir_count && - cell_num_units == rhs.cell_num_units; + cell_num_units == rhs.cell_num_units && max_seq_length == rhs.max_seq_length; } string DebugString() const { return strings::Printf( @@ -530,7 +533,7 @@ struct CudnnRnnConfigHasher { uint64 hash = HashList({shapes.num_layers, shapes.input_size, shapes.num_units, - shapes.dir_count, shapes.batch_size}); + shapes.dir_count, shapes.max_seq_length, shapes.batch_size}); if (algo_desc.has_value()) { hash = Hash64Combine(hash, algo_desc->hash()); }