-
Notifications
You must be signed in to change notification settings - Fork 74k
/
cudnn_rnn_ops.h
101 lines (86 loc) · 3.34 KB
/
cudnn_rnn_ops.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_CUDNN_RNN_OPS_H
#define TENSORFLOW_CORE_KERNELS_CUDNN_RNN_OPS_H 1
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/util/use_cudnn.h"
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/util/stream_executor_util.h"
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
namespace tensorflow {
namespace {
using se::dnn::AlgorithmConfig;
using se::dnn::AlgorithmDesc;
// Folds the elements of `list` into a single 64-bit value using
// Hash64Combine, seeded with the first element.
//
// Returns 0 for an empty list. For a single-element list, returns that element
// converted to uint64.
//
// Marked `inline` because this is defined in a header: it keeps the existing
// internal linkage, but avoids unused-function warnings in translation units
// that include the header without calling it.
inline uint64 HashList(const std::vector<int>& list) {
  if (list.empty()) {
    return 0;
  }
  uint64 hash_code = list[0];
  // An unsigned index avoids a signed/unsigned comparison against size().
  for (size_t i = 1; i < list.size(); ++i) {
    hash_code = Hash64Combine(hash_code, list[i]);
  }
  return hash_code;
}
// A helper class that collects the shapes that describe an RNN model.
struct CudnnRnnModelShapes {
  // Scalar dimensions are value-initialized so that a default-constructed
  // instance never exposes indeterminate values to IsCompatibleWith(),
  // DebugString() or any hasher that reads them.
  int num_layers = 0;
  int input_size = 0;
  int num_units = 0;
  int dir_count = 0;
  int max_seq_length = 0;
  int batch_size = 0;
  int cell_num_units = 0;
  // If you add a new field to this structure, please also update
  // IsCompatibleWith() below and the hash function in CudnnRnnConfigHasher.
  // Any field that is hashed must also be compared by the key equality,
  // otherwise keys that compare equal may hash differently.
  TensorShape input_shape;
  TensorShape output_shape;
  TensorShape hidden_state_shape;
  TensorShape cell_state_shape;

  // Returns true when `rhs` matches this instance on num_layers, input_size,
  // num_units, dir_count, cell_num_units and max_seq_length.
  // At present only fields related to the cached RnnDescriptor are
  // considered. batch_size and the TensorShape members are deliberately not
  // compared.
  bool IsCompatibleWith(const CudnnRnnModelShapes& rhs) const {
    return num_layers == rhs.num_layers && input_size == rhs.input_size &&
           num_units == rhs.num_units && dir_count == rhs.dir_count &&
           cell_num_units == rhs.cell_num_units &&
           max_seq_length == rhs.max_seq_length;
  }

  // Returns a one-line summary of the scalar dimensions. The TensorShape
  // members are not included.
  string DebugString() const {
    return strings::Printf(
        "[num_layers, input_size, num_units, dir_count, max_seq_length, "
        "batch_size, cell_num_units]: [%d, %d, %d, %d, %d, %d, %d] ",
        num_layers, input_size, num_units, dir_count, max_seq_length,
        batch_size, cell_num_units);
  }
};
// Hash functor that lets a (CudnnRnnModelShapes, optional AlgorithmDesc) pair
// be used as a hash table key.
//
// Only shape fields that CudnnRnnModelShapes::IsCompatibleWith() compares are
// mixed into the hash. Hashing a field that the key equality ignores (e.g.
// batch_size) would let two keys that compare equal produce different hashes.
// That breaks the hash-container requirement that equal keys hash equally.
// NOTE(review): assumes the key-equality functor paired with this hasher is
// built on IsCompatibleWith(); confirm against the code that declares the
// cache. Omitting a compared field from the hash is always safe, just coarser.
struct CudnnRnnConfigHasher {
  uint64 operator()(
      const std::pair<CudnnRnnModelShapes, absl::optional<AlgorithmDesc>>&
          to_hash) const {
    const auto& shapes = to_hash.first;
    const auto& algo_desc = to_hash.second;
    // batch_size is intentionally excluded: IsCompatibleWith() ignores it.
    uint64 hash =
        HashList({shapes.num_layers, shapes.input_size, shapes.num_units,
                  shapes.dir_count, shapes.max_seq_length});
    // The algorithm, when present, is part of the key and refines the hash.
    if (algo_desc.has_value()) {
      hash = Hash64Combine(hash, algo_desc->hash());
    }
    return hash;
  }
};
}
}
#endif // TENSORFLOW_CORE_KERNELS_CUDNN_RNN_OPS_H