diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index f98053c7d4fab8..52cf51e615f449 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -1590,6 +1590,7 @@ tf_kernel_library( tf_kernel_library( name = "cudnn_rnn_kernels", srcs = ["cudnn_rnn_ops.cc"], + hdrs = ["cudnn_rnn_ops.h"], visibility = ["//visibility:public"], deps = [ ":bounds_check_lib", @@ -1602,6 +1603,24 @@ tf_kernel_library( ], ) +tf_cuda_cc_test( + name = "cudnn_rnn_kernels_test", + size = "small", + srcs = ["cudnn_rnn_ops_test.cc"], + deps = [ + ":cudnn_rnn_kernels", + ":ops_testutil", + ":ops_util", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + tf_cc_test( name = "batch_norm_op_test", size = "small", diff --git a/tensorflow/core/kernels/cudnn_rnn_ops.cc b/tensorflow/core/kernels/cudnn_rnn_ops.cc index b9b96d3fc70fb2..0a92849b26f663 100644 --- a/tensorflow/core/kernels/cudnn_rnn_ops.cc +++ b/tensorflow/core/kernels/cudnn_rnn_ops.cc @@ -34,6 +34,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/gpu_utils.h" +#include "tensorflow/core/kernels/cudnn_rnn_ops.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/stringpiece.h" @@ -132,17 +133,6 @@ using se::dnn::RnnStateTensorDescriptor; using se::dnn::ToDataType; using se::port::StatusOr; -uint64 HashList(const std::vector& list) { - if (list.empty()) { - return 0; - } - uint64 hash_code = list[0]; - for (int i = 1; i < list.size(); i++) { - hash_code = Hash64Combine(hash_code, list[i]); - } - return hash_code; -} - // Encapsulate all the shape information that is used in both forward and // backward rnn operations. class CudnnRnnParameters { @@ -491,53 +481,6 @@ struct CudnnModelTypes { } }; -// A helper class that collects the shapes to describe a RNN model. -struct CudnnRnnModelShapes { - int num_layers; - int input_size; - int num_units; - int dir_count; - int max_seq_length; - int batch_size; - int cell_num_units = 0; - TensorShape input_shape; - TensorShape output_shape; - TensorShape hidden_state_shape; - TensorShape cell_state_shape; - // At present only fields related to cached RnnDescriptor are concerned. - bool IsCompatibleWith(const CudnnRnnModelShapes& rhs) const { - return num_layers == rhs.num_layers && input_size == rhs.input_size && - num_units == rhs.num_units && dir_count == rhs.dir_count && - cell_num_units == rhs.cell_num_units; - } - string DebugString() const { - return strings::Printf( - "[num_layers, input_size, num_units, dir_count, max_seq_length, " - "batch_size, cell_num_units]: [%d, %d, %d, %d, %d, %d, %d] ", - num_layers, input_size, num_units, dir_count, max_seq_length, - batch_size, cell_num_units); - } -}; - -// Utility class for using CudnnRnnConfig and AlgorithmDesc pair a hash table -// key. -struct CudnnRnnConfigHasher { - uint64 operator()( - const std::pair>& - to_hash) const { - auto& shapes = to_hash.first; - auto& algo_desc = to_hash.second; - - uint64 hash = - HashList({shapes.num_layers, shapes.input_size, shapes.num_units, - shapes.dir_count, shapes.batch_size}); - if (algo_desc.has_value()) { - hash = Hash64Combine(hash, algo_desc->hash()); - } - return hash; - } -}; - // Utility class for using CudnnRnnModelShapes and AlgorithmDesc pair as a hash // table key. struct CudnnRnnConfigComparator { diff --git a/tensorflow/core/kernels/cudnn_rnn_ops.h b/tensorflow/core/kernels/cudnn_rnn_ops.h new file mode 100644 index 00000000000000..9befbf6dc93cda --- /dev/null +++ b/tensorflow/core/kernels/cudnn_rnn_ops.h @@ -0,0 +1,101 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_CUDNN_RNN_OPS_H +#define TENSORFLOW_CORE_KERNELS_CUDNN_RNN_OPS_H 1 + +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/util/use_cudnn.h" + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#include "tensorflow/core/platform/stream_executor.h" +#include "tensorflow/core/util/stream_executor_util.h" +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +namespace tensorflow { + +namespace { +using se::dnn::AlgorithmConfig; +using se::dnn::AlgorithmDesc; + +uint64 HashList(const std::vector& list) { + if (list.empty()) { + return 0; + } + uint64 hash_code = list[0]; + for (int i = 1; i < list.size(); i++) { + hash_code = Hash64Combine(hash_code, list[i]); + } + return hash_code; +} + +// A helper class that collects the shapes to describe a RNN model. +struct CudnnRnnModelShapes { + int num_layers; + int input_size; + int num_units; + int dir_count; + int max_seq_length; + int batch_size; + int cell_num_units = 0; + // If you add new field to this structure, please take care of + // updating IsCompatibleWith() below as well as the hash function in + // CudnnRnnConfigHasher. + TensorShape input_shape; + TensorShape output_shape; + TensorShape hidden_state_shape; + TensorShape cell_state_shape; + // At present only fields related to cached RnnDescriptor are concerned. + bool IsCompatibleWith(const CudnnRnnModelShapes& rhs) const { + return num_layers == rhs.num_layers && input_size == rhs.input_size && + num_units == rhs.num_units && dir_count == rhs.dir_count && + cell_num_units == rhs.cell_num_units && max_seq_length == rhs.max_seq_length; + } + string DebugString() const { + return strings::Printf( + "[num_layers, input_size, num_units, dir_count, max_seq_length, " + "batch_size, cell_num_units]: [%d, %d, %d, %d, %d, %d, %d] ", + num_layers, input_size, num_units, dir_count, max_seq_length, + batch_size, cell_num_units); + } +}; + +// Utility class for using CudnnRnnConfig and AlgorithmDesc pair a hash table +// key. +struct CudnnRnnConfigHasher { + uint64 operator()( + const std::pair>& + to_hash) const { + auto& shapes = to_hash.first; + auto& algo_desc = to_hash.second; + + uint64 hash = + HashList({shapes.num_layers, shapes.input_size, shapes.num_units, + shapes.dir_count, shapes.max_seq_length, shapes.batch_size}); + if (algo_desc.has_value()) { + hash = Hash64Combine(hash, algo_desc->hash()); + } + return hash; + } +}; + +} + +} + +#endif // TENSORFLOW_CORE_KERNELS_CUDNN_RNN_OPS_H diff --git a/tensorflow/core/kernels/cudnn_rnn_ops_test.cc b/tensorflow/core/kernels/cudnn_rnn_ops_test.cc new file mode 100644 index 00000000000000..248407c5312480 --- /dev/null +++ b/tensorflow/core/kernels/cudnn_rnn_ops_test.cc @@ -0,0 +1,129 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/kernels/cudnn_rnn_ops.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" + +namespace tensorflow { + +class CudnnRnnModelShapesTest : public OpsTestBase {}; + +TEST_F(CudnnRnnModelShapesTest, NonVariable_SeqLength) { + CudnnRnnModelShapes model_shape_a; + model_shape_a.num_layers = 1; + model_shape_a.input_size = 2; + model_shape_a.num_units = 1; + model_shape_a.dir_count = 0; + model_shape_a.max_seq_length = 1; + model_shape_a.cell_num_units = 2; + + CudnnRnnModelShapes model_shape_b; + model_shape_b.num_layers = 1; + model_shape_b.input_size = 2; + model_shape_b.num_units = 1; + model_shape_b.dir_count = 0; + model_shape_b.max_seq_length = 1; + model_shape_b.cell_num_units = 2; + + EXPECT_TRUE(model_shape_a.IsCompatibleWith(model_shape_b)); +} + +TEST_F(CudnnRnnModelShapesTest, Variable_SeqLength) { + CudnnRnnModelShapes model_shape_a; + model_shape_a.num_layers = 1; + model_shape_a.input_size = 2; + model_shape_a.num_units = 1; + model_shape_a.dir_count = 0; + model_shape_a.max_seq_length = 1; + model_shape_a.cell_num_units = 2; + + CudnnRnnModelShapes model_shape_b; + model_shape_b.num_layers = 1; + model_shape_b.input_size = 2; + model_shape_b.num_units = 1; + model_shape_b.dir_count = 0; + model_shape_b.max_seq_length = 2; + model_shape_b.cell_num_units = 2; + + EXPECT_FALSE(model_shape_a.IsCompatibleWith(model_shape_b)); +} + +class CudnnRnnConfigHasherTest : public OpsTestBase {}; + +TEST_F(CudnnRnnConfigHasherTest, NonVariable_SeqLength) { + AlgorithmConfig algo_config; + AlgorithmDesc algo_desc(DebugCudnnRnnAlgo(), DebugCudnnRnnUseTensorOps()); + algo_config.set_algorithm(algo_desc); + + CudnnRnnModelShapes model_shape_a; + model_shape_a.num_layers = 1; + model_shape_a.input_size = 2; + model_shape_a.num_units = 1; + model_shape_a.dir_count = 0; + model_shape_a.max_seq_length = 1; + model_shape_a.batch_size = 1; + model_shape_a.cell_num_units = 2; + + CudnnRnnModelShapes model_shape_b; + model_shape_b.num_layers = 1; + model_shape_b.input_size = 2; + model_shape_b.num_units = 1; + model_shape_b.dir_count = 0; + model_shape_b.max_seq_length = 1; + model_shape_b.batch_size = 1; + model_shape_b.cell_num_units = 2; + + uint64 hash_a = CudnnRnnConfigHasher()(std::make_pair(model_shape_a, algo_config.algorithm())); + uint64 hash_b = CudnnRnnConfigHasher()(std::make_pair(model_shape_b, algo_config.algorithm())); + + EXPECT_TRUE(hash_a == hash_b); +} + +TEST_F(CudnnRnnConfigHasherTest, Variable_SeqLength) { + AlgorithmConfig algo_config; + AlgorithmDesc algo_desc(DebugCudnnRnnAlgo(), DebugCudnnRnnUseTensorOps()); + algo_config.set_algorithm(algo_desc); + + CudnnRnnModelShapes model_shape_a; + model_shape_a.num_layers = 1; + model_shape_a.input_size = 2; + model_shape_a.num_units = 1; + model_shape_a.dir_count = 0; + model_shape_a.max_seq_length = 1; + model_shape_a.batch_size = 1; + model_shape_a.cell_num_units = 2; + + CudnnRnnModelShapes model_shape_b; + model_shape_b.num_layers = 1; + model_shape_b.input_size = 2; + model_shape_b.num_units = 1; + model_shape_b.dir_count = 0; + model_shape_b.max_seq_length = 2; + model_shape_b.batch_size = 1; + model_shape_b.cell_num_units = 2; + + uint64 hash_a = CudnnRnnConfigHasher()(std::make_pair(model_shape_a, algo_config.algorithm())); + uint64 hash_b = CudnnRnnConfigHasher()(std::make_pair(model_shape_b, algo_config.algorithm())); + + EXPECT_TRUE(hash_a != hash_b); +} + +} // end namespace tensorflow