Skip to content

Commit

Permalink
Merge pull request #41832 from lissyx:issue41630
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 325454474
Change-Id: I5fb9f4cfa4d9836056f3ba14c12f2a04d1a09a55
  • Loading branch information
tensorflower-gardener committed Aug 7, 2020
2 parents 17cdd71 + 9442cbb commit 0fb8017
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions tensorflow/core/kernels/cudnn_rnn_ops.cc
Expand Up @@ -500,6 +500,9 @@ struct CudnnRnnModelShapes {
int max_seq_length;
int batch_size;
int cell_num_units = 0;
// If you add new field to this structure, please take care of
// updating IsCompatibleWith() below as well as the hash function in
// CudnnRnnConfigHasher.
TensorShape input_shape;
TensorShape output_shape;
TensorShape hidden_state_shape;
Expand All @@ -508,7 +511,8 @@ struct CudnnRnnModelShapes {
bool IsCompatibleWith(const CudnnRnnModelShapes& rhs) const {
return num_layers == rhs.num_layers && input_size == rhs.input_size &&
num_units == rhs.num_units && dir_count == rhs.dir_count &&
cell_num_units == rhs.cell_num_units;
cell_num_units == rhs.cell_num_units &&
max_seq_length == rhs.max_seq_length;
}
string DebugString() const {
return strings::Printf(
Expand All @@ -530,7 +534,7 @@ struct CudnnRnnConfigHasher {

uint64 hash =
HashList({shapes.num_layers, shapes.input_size, shapes.num_units,
shapes.dir_count, shapes.batch_size});
shapes.dir_count, shapes.max_seq_length, shapes.batch_size});
if (algo_desc.has_value()) {
hash = Hash64Combine(hash, algo_desc->hash());
}
Expand Down

0 comments on commit 0fb8017

Please sign in to comment.