aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/cudnn_rnn/kernels/cudnn_rnn_ops.cc
blob: 6b0452e7af6ef37d3c9cac222a9a22d9215cb0b7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS

#include <stddef.h>
#include <atomic>
#include <cmath>
#include <functional>
#include <limits>
#include <string>
#include <unordered_set>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/env_var.h"

#if GOOGLE_CUDA
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/util/stream_executor_util.h"
#endif  // GOOGLE_CUDA

/*
 * This module implements ops that fuse a multi-layer multi-step RNN/LSTM model
 * using the underlying Cudnn library.
 *
 * Cudnn RNN library exposes an opaque parameter buffer with unknown layout and
 * format. It is very likely that such a buffer, if saved, cannot be reused
 * across different GPUs. So users need to first query the size of the opaque
 * parameter buffer, and convert it to and from its canonical form. But each
 * actual training step is carried out with the opaque parameter buffer.
 *
 * Similar to many other ops, the forward op has two flavors: training and
 * inference. When training is specified, additional data in reserve_space will
 * be produced for the backward pass. So there is a performance penalty.
 *
 * In addition to the actual data and reserve_space, Cudnn also needs more
 * memory as temporary workspace. The memory management to and from
 * stream-executor is done through ScratchAllocator. In general,
 * stream-executor is responsible for requesting memory of the proper size,
 * and TensorFlow is responsible for making sure the memory stays alive long
 * enough and is recycled afterwards.
 *
 */
namespace tensorflow {

using CPUDevice = Eigen::ThreadPoolDevice;

#if GOOGLE_CUDA

using GPUDevice = Eigen::GpuDevice;

// Primary templates for the Cudnn RNN kernels. They are only declared here;
// the kernels are provided as partial specializations on the Device parameter
// (see the GPUDevice specializations later in this file).

// Kernel reporting the size of the opaque parameter buffer. Index is the
// element type of the size output.
template <typename Device, typename T, typename Index>
class CudnnRNNParamsSizeOp;

// Kernel converting the opaque parameter buffer into canonical weights and
// biases.
template <typename Device, typename T>
class CudnnRNNParamsToCanonical;

// Kernel converting canonical weights and biases into the opaque parameter
// buffer.
template <typename Device, typename T>
class CudnnRNNCanonicalToParams;

// Forward-pass kernel.
template <typename Device, typename T>
class CudnnRNNForwardOp;

// Backward-pass kernel.
template <typename Device, typename T>
class CudnnRNNBackwardOp;

// Input mode as requested through the "input_mode" op attribute. It is
// translated into the stream-executor RnnInputMode by ToRNNInputMode().
enum class TFRNNInputMode {
  kRNNLinearInput = 0,  // Parsed from "linear_input".
  kRNNSkipInput = 1,    // Parsed from "skip_input".
  // Parsed from "auto_select". ToRNNInputMode() resolves it to skip-input when
  // input_size == num_units, and to linear input otherwise.
  kAutoSelect = 9999999
};

namespace {
using perftools::gputools::DeviceMemory;
using perftools::gputools::DeviceMemoryBase;
using perftools::gputools::ScratchAllocator;
using perftools::gputools::dnn::RnnDirectionMode;
using perftools::gputools::dnn::RnnInputMode;
using perftools::gputools::dnn::RnnMode;
using perftools::gputools::dnn::ToDataType;
using perftools::gputools::port::StatusOr;

// Translates the textual RNN mode ("rnn_relu", "rnn_tanh", "lstm" or "gru")
// into the matching RnnMode value, stored in *rnn_mode.
// Returns InvalidArgument, leaving *rnn_mode untouched, for any other string.
Status ParseRNNMode(const string& str, RnnMode* rnn_mode) {
  if (str == "rnn_relu") {
    *rnn_mode = RnnMode::kRnnRelu;
  } else if (str == "rnn_tanh") {
    *rnn_mode = RnnMode::kRnnTanh;
  } else if (str == "lstm") {
    *rnn_mode = RnnMode::kRnnLstm;
  } else if (str == "gru") {
    *rnn_mode = RnnMode::kRnnGru;
  } else {
    return errors::InvalidArgument("Invalid RNN mode: ", str);
  }
  return Status::OK();
}

// Translates the textual input mode ("linear_input", "skip_input" or
// "auto_select") into the matching TFRNNInputMode, stored in *rnn_input_mode.
// Returns InvalidArgument, leaving *rnn_input_mode untouched, for any other
// string.
Status ParseTFRNNInputMode(const string& str, TFRNNInputMode* rnn_input_mode) {
  struct NamedMode {
    const char* name;
    TFRNNInputMode mode;
  };
  static const NamedMode kKnownModes[] = {
      {"linear_input", TFRNNInputMode::kRNNLinearInput},
      {"skip_input", TFRNNInputMode::kRNNSkipInput},
      {"auto_select", TFRNNInputMode::kAutoSelect},
  };
  for (const NamedMode& candidate : kKnownModes) {
    if (str == candidate.name) {
      *rnn_input_mode = candidate.mode;
      return Status::OK();
    }
  }
  return errors::InvalidArgument("Invalid RNN input mode: ", str);
}

// Translates the textual direction ("unidirectional" or "bidirectional") into
// the matching RnnDirectionMode, stored in *rnn_dir_mode.
// Returns InvalidArgument, leaving *rnn_dir_mode untouched, for any other
// string.
Status ParseRNNDirectionMode(const string& str,
                             RnnDirectionMode* rnn_dir_mode) {
  const bool is_unidirectional = (str == "unidirectional");
  const bool is_bidirectional = (str == "bidirectional");
  if (!is_unidirectional && !is_bidirectional) {
    return errors::InvalidArgument("Invalid RNN direction mode: ", str);
  }
  *rnn_dir_mode = is_unidirectional ? RnnDirectionMode::kRnnUnidirectional
                                    : RnnDirectionMode::kRnnBidirectional;
  return Status::OK();
}

// Converts a TFRNNInputMode into the stream-executor RnnInputMode, stored in
// *input_mode. kAutoSelect picks skip-input when input_size equals num_units
// and linear input otherwise; the two sizes are ignored for the other modes.
// Returns InvalidArgument for a value outside the enum.
Status ToRNNInputMode(TFRNNInputMode tf_input_mode, int num_units,
                      int input_size, RnnInputMode* input_mode) {
  if (tf_input_mode == TFRNNInputMode::kRNNLinearInput) {
    *input_mode = RnnInputMode::kRnnLinearSkip;
    return Status::OK();
  }
  if (tf_input_mode == TFRNNInputMode::kRNNSkipInput) {
    *input_mode = RnnInputMode::kRnnSkipInput;
    return Status::OK();
  }
  if (tf_input_mode == TFRNNInputMode::kAutoSelect) {
    const bool sizes_match = (input_size == num_units);
    *input_mode = sizes_match ? RnnInputMode::kRnnSkipInput
                              : RnnInputMode::kRnnLinearSkip;
    return Status::OK();
  }
  return errors::InvalidArgument("Invalid TF input mode: ",
                                 static_cast<int>(tf_input_mode));
}

// TODO(zhengxq): Merge those into stream_executor_util.h.
// Wraps the flattened contents of a read-only tensor as DeviceMemory<T>.
// Constness is cast away on the data pointer because DeviceMemory is built
// from a mutable pointer.
template <typename T>
const DeviceMemory<T> AsDeviceMemory(const Tensor* tensor) {
  auto flat_view = tensor->template flat<T>();
  T* mutable_data = const_cast<T*>(flat_view.data());
  return DeviceMemory<T>::MakeFromByteSize(mutable_data,
                                           flat_view.size() * sizeof(T));
}

// Wraps the flattened contents of a mutable tensor as DeviceMemory<T>.
template <typename T>
DeviceMemory<T> AsDeviceMemory(Tensor* tensor) {
  auto flat_view = tensor->template flat<T>();
  return DeviceMemory<T>::MakeFromByteSize(flat_view.data(),
                                           flat_view.size() * sizeof(T));
}

// Wraps the flattened contents of a tensor holding elements of type T as
// DeviceMemory<U>. The byte size is computed from T, the tensor's element
// type.
template <typename U, typename T>
DeviceMemory<U> CastDeviceMemory(Tensor* tensor) {
  auto flat_view = tensor->template flat<T>();
  return DeviceMemory<U>::MakeFromByteSize(flat_view.data(),
                                           flat_view.size() * sizeof(T));
}

// Returns a DeviceMemoryBase covering `size` bytes starting `offset` bytes
// into device_memory.
//
// CHECK-fails when offset or size is negative, or when the slice does not lie
// within device_memory. The slice is validated before any pointer arithmetic
// so that an out-of-range pointer is never formed.
DeviceMemoryBase SliceDeviceMemory(const DeviceMemoryBase& device_memory,
                                   int64 offset, int64 size) {
  CHECK(offset >= 0 && size >= 0)
      << "The slice offset and size must be non-negative.";
  const auto region_size = device_memory.size();
  // Written as two comparisons rather than `offset + size <= region_size` so
  // that the sum cannot overflow.
  CHECK(offset <= region_size && size <= region_size - offset)
      << "The slice is not within the region of DeviceMemory.";
  const void* base_ptr = device_memory.opaque();
  void* offset_ptr =
      const_cast<char*>(reinterpret_cast<const char*>(base_ptr) + offset);
  return DeviceMemoryBase(offset_ptr, size);
}

// Converts a stream-executor Status into a TensorFlow Status. The numeric
// error code is carried over unchanged, together with the error message.
inline Status FromExecutorStatus(const perftools::gputools::port::Status& s) {
  if (s.ok()) {
    return Status::OK();
  }
  const auto tf_code =
      static_cast<tensorflow::error::Code>(static_cast<int>(s.code()));
  return Status(tf_code, s.error_message());
}

// Converts the status part of a stream-executor StatusOr<T> into a TensorFlow
// Status. Any value held by `s` is ignored.
template <typename T>
inline Status FromExecutorStatus(
    const perftools::gputools::port::StatusOr<T>& s) {
  const perftools::gputools::port::Status& executor_status = s.status();
  return FromExecutorStatus(executor_status);
}

// Converts a TensorFlow Status into a stream-executor Status. The numeric
// error code is carried over unchanged, together with the error message.
inline perftools::gputools::port::Status ToExecutorStatus(const Status& s) {
  if (s.ok()) {
    return perftools::gputools::port::Status::OK();
  }
  const auto executor_code =
      static_cast<perftools::gputools::port::error::Code>(
          static_cast<int>(s.code()));
  return perftools::gputools::port::Status(executor_code, s.error_message());
}

// Scratch allocator that hands out temporary workspace memory for Cudnn RNN
// models. Each buffer is obtained through OpKernelContext::allocate_temp and
// a reference to its tensor is retained by this object, so the memory stays
// valid for as long as the allocator itself is alive. The allocator is
// expected to live for the span of the Cudnn RNN call.
class CudnnRNNWorkspaceAllocator : public ScratchAllocator {
 public:
  explicit CudnnRNNWorkspaceAllocator(OpKernelContext* context)
      : context_(context) {}
  ~CudnnRNNWorkspaceAllocator() override {}

  // No limit is imposed on the amount of workspace that may be requested.
  int64 GetMemoryLimitInBytes(perftools::gputools::Stream* stream) override {
    return std::numeric_limits<int64>::max();
  }

  // Allocates byte_size bytes of temporary memory. On failure, the allocation
  // error is returned as a stream-executor status.
  StatusOr<DeviceMemory<uint8>> AllocateBytes(
      perftools::gputools::Stream* stream, int64 byte_size) override {
    Tensor scratch;
    const Status status =
        context_->allocate_temp(DT_UINT8, TensorShape({byte_size}), &scratch);
    if (!status.ok()) {
      return ToExecutorStatus(status);
    }
    // Retain a reference to the tensor so that its buffer outlives this call.
    allocated_tensors_.push_back(scratch);
    total_byte_size_ += byte_size;
    DeviceMemory<uint8> workspace = AsDeviceMemory<uint8>(&scratch);
    return StatusOr<DeviceMemory<uint8>>(workspace);
  }

  // Sum of the byte sizes of all successful allocations so far.
  int64 TotalByteSize() { return total_byte_size_; }

 private:
  int64 total_byte_size_ = 0;
  OpKernelContext* context_;  // not owned
  std::vector<Tensor> allocated_tensors_;
};

// Scratch allocator for the Cudnn RNN reserve space. The memory is allocated
// as output `output_index` of the kernel, as a 1-D tensor of element type T,
// so that it can be fed into the backward pass. The memory is expected to
// stay alive until the backward pass has finished.
template <typename T>
class CudnnRNNReserveSpaceAllocator : public ScratchAllocator {
 public:
  CudnnRNNReserveSpaceAllocator(OpKernelContext* context, int output_index)
      : context_(context), output_index_(output_index) {}
  ~CudnnRNNReserveSpaceAllocator() override {}

  // No limit is imposed on the amount of reserve space that may be requested.
  int64 GetMemoryLimitInBytes(perftools::gputools::Stream* stream) override {
    return std::numeric_limits<int64>::max();
  }

  // Allocates the kernel output with enough elements of type T to hold
  // byte_size bytes (rounded up). CHECK-fails if a previous call has already
  // recorded a non-zero size. On failure, the allocation error is returned as
  // a stream-executor status.
  StatusOr<DeviceMemory<uint8>> AllocateBytes(
      perftools::gputools::Stream* stream, int64 byte_size) override {
    CHECK(total_byte_size_ == 0)
        << "Reserve space allocator can only be called once";
    const int64 element_size = static_cast<int64>(sizeof(T));
    const int64 element_count = Eigen::divup(byte_size, element_size);

    Tensor* reserve_space = nullptr;
    const Status status = context_->allocate_output(
        output_index_, TensorShape({element_count}), &reserve_space);
    if (!status.ok()) {
      return ToExecutorStatus(status);
    }
    total_byte_size_ += byte_size;
    // Expose the whole output tensor, viewed as raw bytes.
    auto elements = reserve_space->template flat<T>();
    return StatusOr<DeviceMemory<uint8>>(
        DeviceMemory<uint8>::MakeFromByteSize(elements.data(),
                                              elements.size() * sizeof(T)));
  }

  // Number of bytes requested by the successful allocation, or 0 if none.
  int64 TotalByteSize() { return total_byte_size_; }

 private:
  int64 total_byte_size_ = 0;
  OpKernelContext* context_;  // not owned
  int output_index_;
};

// Scratch allocator for Cudnn RNN memory that is expected to live between
// kernel invocations. The memory is obtained through
// OpKernelContext::allocate_persistent and kept in a PersistentTensor owned
// by this object.
// This class is not thread-safe.
class CudnnRNNPersistentSpaceAllocator : public ScratchAllocator {
 public:
  explicit CudnnRNNPersistentSpaceAllocator(OpKernelContext* context)
      : context_(context) {}

  ~CudnnRNNPersistentSpaceAllocator() override {}

  // No limit is imposed on the amount of memory that may be requested.
  int64 GetMemoryLimitInBytes(perftools::gputools::Stream* stream) override {
    return std::numeric_limits<int64>::max();
  }

  // Allocates byte_size bytes of persistent memory. Returns
  // FAILED_PRECONDITION if a previous call has already recorded a non-zero
  // size, or the allocation error converted to a stream-executor status.
  StatusOr<DeviceMemory<uint8>> AllocateBytes(
      perftools::gputools::Stream* stream, int64 byte_size) override {
    const bool already_allocated = (total_byte_size_ != 0);
    if (already_allocated) {
      return Status(error::FAILED_PRECONDITION,
                    "Persistent space allocator can only be called once");
    }

    const Status status = context_->allocate_persistent(
        DT_UINT8, TensorShape({byte_size}), &handle_, nullptr);
    if (!status.ok()) {
      return ToExecutorStatus(status);
    }
    total_byte_size_ += byte_size;
    Tensor* persistent_tensor = handle_.AccessTensor(context_);
    return AsDeviceMemory<uint8>(persistent_tensor);
  }

  // Number of bytes requested by the successful allocation, or 0 if none.
  int64 TotalByteSize() { return total_byte_size_; }

 private:
  int64 total_byte_size_ = 0;
  PersistentTensor handle_;
  OpKernelContext* context_;  // not owned
};

// Bundles the mode settings (cell type, input mode and direction) that select
// which kind of Cudnn RNN model is run.
struct CudnnModelTypes {
  RnnMode rnn_mode;
  TFRNNInputMode rnn_input_mode;
  RnnDirectionMode rnn_direction_mode;

  // Whether the model takes an input-c state in addition to input-h.
  // For Cudnn 5.0 that is the case for LSTM only; every other model uses
  // input-h alone.
  bool HasInputC() const {
    const bool is_lstm = (rnn_mode == RnnMode::kRnnLstm);
    return is_lstm;
  }
};

// Collects the sizes and tensor shapes that describe one RNN model instance.
struct CudnnModelShapes {
  int num_layers;
  int input_size;
  int num_units;
  int seq_length;
  int batch_size;
  int dir_count;
  TensorShape input_shape;
  TensorShape output_shape;
  TensorShape hidden_state_shape;

  // Two shape sets are compatible when the fields that matter for a cached
  // RnnDescriptor agree: num_layers, input_size, num_units and dir_count.
  // seq_length, batch_size and the tensor shapes are not compared at present.
  bool IsCompatibleWith(const CudnnModelShapes& rhs) const {
    if (num_layers != rhs.num_layers) return false;
    if (input_size != rhs.input_size) return false;
    if (num_units != rhs.num_units) return false;
    return dir_count == rhs.dir_count;
  }

  // Human-readable rendering of the fields compared by IsCompatibleWith().
  string RnnDescDebugString() {
    return strings::Printf(
        "[num_layers, input_size, num_units, dir_count]: [%d, %d, %d, %d]",
        num_layers, input_size, num_units, dir_count);
  }
};

// Hash functor that allows CudnnModelShapes to serve as a hash table key.
// It folds num_layers, input_size, num_units and dir_count, in that order,
// which are the same fields CudnnModelShapes::IsCompatibleWith() compares.
struct CudnnModelShapesHasher {
  uint64 operator()(const CudnnModelShapes& to_hash) const {
    const int remaining_fields[] = {to_hash.input_size, to_hash.num_units,
                                    to_hash.dir_count};
    uint64 digest = static_cast<uint64>(to_hash.num_layers);
    for (const int field : remaining_fields) {
      digest =
          tensorflow::FingerprintCat64(digest, static_cast<uint64>(field));
    }
    return digest;
  }
};

// Equality functor that allows CudnnModelShapes to serve as a hash table key.
// Two keys are equal when CudnnModelShapes::IsCompatibleWith() says so.
struct CudnnModelShapesComparator {
  bool operator()(const CudnnModelShapes& first,
                  const CudnnModelShapes& second) const {
    const bool compatible = first.IsCompatibleWith(second);
    return compatible;
  }
};

// Extracts and checks the forward input tensors, parameters, and shapes from
// the OpKernelContext.
//
// On success *input, *input_h and *params point at the corresponding kernel
// inputs, *input_c is set only when model_types.HasInputC() is true, and
// model_shapes is filled in as follows:
//   - seq_length, batch_size, input_size: the three dimensions of "input".
//   - dir_count: 2 for a bidirectional model, 1 otherwise.
//   - num_layers: dimension 0 of "input_h" divided by dir_count.
//   - num_units: dimension 2 of "input_h".
//   - hidden_state_shape: [dir_count * num_layers, batch_size, num_units].
//   - output_shape: [seq_length, batch_size, dir_count * num_units].
//
// Returns InvalidArgument when "input" or "input_h" is not 3-D, when "input_h"
// does not have hidden_state_shape, or when "input_c" differs in shape from
// "input_h". Errors from fetching the inputs are propagated unchanged.
Status ExtractForwardInput(OpKernelContext* context,
                           const CudnnModelTypes& model_types,
                           const Tensor** input, const Tensor** input_h,
                           const Tensor** input_c, const Tensor** params,
                           CudnnModelShapes* model_shapes) {
  TF_RETURN_IF_ERROR(context->input("input", input));
  TF_RETURN_IF_ERROR(context->input("input_h", input_h));
  if (model_types.HasInputC()) {
    TF_RETURN_IF_ERROR(context->input("input_c", input_c));
  }
  TF_RETURN_IF_ERROR(context->input("params", params));

  if ((*input)->dims() != 3) {
    return errors::InvalidArgument("RNN input must be a 3-D vector.");
  }
  model_shapes->seq_length = (*input)->dim_size(0);
  model_shapes->batch_size = (*input)->dim_size(1);
  model_shapes->input_size = (*input)->dim_size(2);
  model_shapes->input_shape = (*input)->shape();
  model_shapes->dir_count =
      (model_types.rnn_direction_mode == RnnDirectionMode::kRnnBidirectional)
          ? 2
          : 1;

  if ((*input_h)->dims() != 3) {
    // Name the offending tensor; this used to repeat the message for "input".
    return errors::InvalidArgument("RNN input_h must be a 3-D vector.");
  }
  model_shapes->num_layers = (*input_h)->dim_size(0) / model_shapes->dir_count;
  model_shapes->num_units = (*input_h)->dim_size(2);

  // Rebuild the expected hidden-state shape from the derived sizes. The
  // comparison below also rejects an input_h whose first dimension is not a
  // multiple of dir_count, because the integer division above truncates.
  model_shapes->hidden_state_shape =
      TensorShape({model_shapes->dir_count * model_shapes->num_layers,
                   model_shapes->batch_size, model_shapes->num_units});
  if ((*input_h)->shape() != model_shapes->hidden_state_shape) {
    return errors::InvalidArgument(
        "Invalid input_h shape: ", (*input_h)->shape().DebugString(), " ",
        model_shapes->hidden_state_shape.DebugString());
  }
  if (model_types.HasInputC()) {
    if ((*input_h)->shape() != (*input_c)->shape()) {
      return errors::InvalidArgument(
          "input_h and input_c must have the same shape: ",
          (*input_h)->shape().DebugString(), " ",
          (*input_c)->shape().DebugString());
    }
  }
  model_shapes->output_shape =
      TensorShape({model_shapes->seq_length, model_shapes->batch_size,
                   model_shapes->dir_count * model_shapes->num_units});
  return Status::OK();
}

using perftools::gputools::dnn::RnnDescriptor;

// Copies each canonical parameter tensor in params_input into its region of
// the opaque parameter buffer *data_dst, enqueuing one memcpy per region on
// `stream`. params[i] gives the byte offset and byte size of the region that
// receives params_input[i].
//
// CHECK-fails when the number of tensors differs from the number of regions,
// or when a tensor's element count does not match its region size.
template <typename T>
void RestoreParams(const OpInputList params_input,
                   const std::vector<RnnDescriptor::ParamsRegion>& params,
                   DeviceMemoryBase* data_dst,
                   perftools::gputools::Stream* stream) {
  int region_count = params.size();
  CHECK(params_input.size() == region_count)
      << "Number of params mismatch. Expected " << params_input.size()
      << ", got " << region_count;
  int index = 0;
  for (const auto& region : params) {
    const int64 region_bytes = region.size;
    const int64 element_count = region_bytes / sizeof(T);
    const Tensor& source = params_input[index];
    CHECK(element_count == source.NumElements())
        << "Params size mismatch. Expected " << element_count << ", got "
        << source.NumElements();
    auto source_memory = StreamExecutorUtil::AsDeviceMemory<T>(source);
    DeviceMemoryBase target_memory =
        SliceDeviceMemory(*data_dst, region.offset, region_bytes);
    stream->ThenMemcpy(&target_memory, source_memory, region_bytes);
    ++index;
  }
}

}  // namespace

// Note: all following kernels depend on a RnnDescriptor instance, which
// according to Cudnn official doc should be kept around and reused across all
// Cudnn kernels in the same model.
// In TensorFlow, we don't pass the reference across different OpKernels;
// rather, we recreate it separately in each OpKernel. This does not cause
// issues: CudnnDropoutDescriptor keeps a reference to the memory holding the
// random number generator state, and that state is lost during recreation.
// However, only forward-pass Cudnn APIs make use of the state.

// A common base class for RNN kernels. It reads the attributes shared by all
// Cudnn RNN kernels ("dropout", "seed", "seed2", "rnn_mode", "input_mode",
// "direction") and offers a helper to build an RnnDescriptor from the
// "num_layers", "num_units" and "input_size" inputs.
class CudnnRNNKernelCommon : public OpKernel {
 protected:
  // Reads and parses the shared attributes. Any missing attribute or
  // unparsable mode string is reported on `context`.
  explicit CudnnRNNKernelCommon(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("dropout", &dropout_));
    OP_REQUIRES_OK(context, context->GetAttr("seed", &seed_));
    OP_REQUIRES_OK(context, context->GetAttr("seed2", &seed2_));
    string str;
    OP_REQUIRES_OK(context, context->GetAttr("rnn_mode", &str));
    OP_REQUIRES_OK(context, ParseRNNMode(str, &model_types_.rnn_mode));
    OP_REQUIRES_OK(context, context->GetAttr("input_mode", &str));
    OP_REQUIRES_OK(context,
                   ParseTFRNNInputMode(str, &model_types_.rnn_input_mode));
    OP_REQUIRES_OK(context, context->GetAttr("direction", &str));
    OP_REQUIRES_OK(
        context, ParseRNNDirectionMode(str, &model_types_.rnn_direction_mode));
    // Reset CudnnRnnDescriptor and related random number generator states in
    // every Compute() call. Controlled by an environment variable; off by
    // default.
    OP_REQUIRES_OK(context, ReadBoolFromEnvVar("TF_CUDNN_RESET_RND_GEN_STATE",
                                               false, &reset_rnd_gen_state_));
  }

  bool HasInputC() const { return model_types_.HasInputC(); }
  RnnMode rnn_mode() const { return model_types_.rnn_mode; }
  TFRNNInputMode rnn_input_mode() const { return model_types_.rnn_input_mode; }
  RnnDirectionMode rnn_direction_mode() const {
    return model_types_.rnn_direction_mode;
  }
  CudnnModelTypes model_types() const { return model_types_; }
  float dropout() const { return dropout_; }

  // Combines the two 32-bit seed attributes into one 64-bit seed: "seed"
  // occupies the high 32 bits and "seed2" the low 32 bits.
  uint64 seed() {
    // seed2_ is a signed int. Mask it to its low 32 bits so that a negative
    // value is not sign-extended over the bits contributed by seed_.
    return (static_cast<uint64>(seed_) << 32) |
           (static_cast<uint64>(seed2_) & 0xFFFFFFFFULL);
  }
  bool ResetRndGenState() { return reset_rnd_gen_state_; }

  // Builds an RnnDescriptor for element type T from the scalar inputs
  // "num_layers", "num_units" and "input_size" plus the attributes read at
  // construction, and stores it in *rnn_desc.
  //
  // Returns InvalidArgument if any of the three inputs is not a scalar;
  // errors from looking up an input, from ToRNNInputMode(), or from creating
  // the descriptor are propagated.
  template <typename T>
  Status ExtractCudnnRNNParamsInfo(OpKernelContext* context,
                                   std::unique_ptr<RnnDescriptor>* rnn_desc) {
    const Tensor* num_layers_t = nullptr;
    TF_RETURN_IF_ERROR(context->input("num_layers", &num_layers_t));
    if (!TensorShapeUtils::IsScalar(num_layers_t->shape())) {
      return errors::InvalidArgument("num_layers is not a scalar");
    }
    int num_layers = num_layers_t->scalar<int>()();
    const Tensor* num_units_t = nullptr;
    TF_RETURN_IF_ERROR(context->input("num_units", &num_units_t));
    if (!TensorShapeUtils::IsScalar(num_units_t->shape())) {
      return errors::InvalidArgument("num_units is not a scalar");
    }
    int num_units = num_units_t->scalar<int>()();
    const Tensor* input_size_t = nullptr;
    TF_RETURN_IF_ERROR(context->input("input_size", &input_size_t));
    if (!TensorShapeUtils::IsScalar(input_size_t->shape())) {
      return errors::InvalidArgument("input_size is not a scalar");
    }
    int input_size = input_size_t->scalar<int>()();

    RnnInputMode input_mode;
    TF_RETURN_IF_ERROR(
        ToRNNInputMode(rnn_input_mode(), num_units, input_size, &input_mode));

    auto* stream = context->op_device_context()->stream();
    // ExtractCudnnRNNParamsInfo is only called by op_kernels that do not
    // require random number generator, therefore set state_allocator to
    // nullptr.
    auto rnn_desc_s = stream->parent()->createRnnDescriptor(
        num_layers, num_units, input_size, input_mode, rnn_direction_mode(),
        rnn_mode(), ToDataType<T>::value, dropout(), seed(),
        nullptr /* state_allocator */);
    if (!rnn_desc_s.ok()) {
      return FromExecutorStatus(rnn_desc_s);
    }
    *rnn_desc = rnn_desc_s.ConsumeValueOrDie();
    return Status::OK();
  }

 private:
  int seed_;
  int seed2_;
  float dropout_;
  bool reset_rnd_gen_state_;

  CudnnModelTypes model_types_;
};

// A class that returns the size of the opaque parameter buffer. The user should
// use that to create the actual parameter buffer for training. However, it
// should not be used for saving and restoring.
template <typename T, typename Index>
class CudnnRNNParamsSizeOp<GPUDevice, T, Index> : public CudnnRNNKernelCommon {
 public:
  typedef GPUDevice Device;
  explicit CudnnRNNParamsSizeOp(OpKernelConstruction* context)
      : CudnnRNNKernelCommon(context) {}

  void Compute(OpKernelContext* context) override {
    std::unique_ptr<RnnDescriptor> rnn_desc;
    OP_REQUIRES_OK(context, ExtractCudnnRNNParamsInfo<T>(context, &rnn_desc));
    int64 params_size_in_bytes = rnn_desc->ParamsSizeInBytes();
    CHECK(params_size_in_bytes % sizeof(T) == 0)
        << "params_size_in_bytes must be multiple of element size";
    int64 params_size = params_size_in_bytes / sizeof(T);

    Tensor* output_t = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, {1}, &output_t));
    *output_t->template flat<Index>().data() = params_size;
  }
};

// Registers CudnnRNNParamsSizeOp as the GPU kernel of the "CudnnRNNParamsSize"
// op for element type T, with an int32 size output (type attribute "S").
// "num_layers", "num_units", "input_size" and "params_size" are declared as
// host memory; the kernel reads and writes these values directly in Compute().
// The macro is instantiated for half, float and double and then undefined, so
// the name can be reused by later registrations.
#define REGISTER_GPU(T)                                    \
  REGISTER_KERNEL_BUILDER(Name("CudnnRNNParamsSize")       \
                              .Device(DEVICE_GPU)          \
                              .HostMemory("num_layers")    \
                              .HostMemory("num_units")     \
                              .HostMemory("input_size")    \
                              .HostMemory("params_size")   \
                              .TypeConstraint<T>("T")      \
                              .TypeConstraint<int32>("S"), \
                          CudnnRNNParamsSizeOp<GPUDevice, T, int32>);

TF_CALL_half(REGISTER_GPU);
TF_CALL_float(REGISTER_GPU);
TF_CALL_double(REGISTER_GPU);
#undef REGISTER_GPU

// Convert weight and bias params from a platform-specific layout to the
// canonical form.
template <typename T>
class CudnnRNNParamsToCanonical<GPUDevice, T> : public CudnnRNNKernelCommon {
 public:
  typedef GPUDevice Device;
  explicit CudnnRNNParamsToCanonical(OpKernelConstruction* context)
      : CudnnRNNKernelCommon(context) {
    OP_REQUIRES_OK(context, context->GetAttr("num_params", &num_params_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(3);
    auto input_ptr = StreamExecutorUtil::AsDeviceMemory<T>(input);
    auto* stream = context->op_device_context()->stream();

    std::unique_ptr<RnnDescriptor> rnn_desc;
    OP_REQUIRES_OK(context, ExtractCudnnRNNParamsInfo<T>(context, &rnn_desc));
    int64 params_size_in_bytes = rnn_desc->ParamsSizeInBytes();
    CHECK(params_size_in_bytes % sizeof(T) == 0)
        << "params_size_in_bytes must be multiple of element size";

    const Tensor* num_units_t = nullptr;
    OP_REQUIRES_OK(context, context->input("num_units", &num_units_t));
    CHECK(TensorShapeUtils::IsScalar(num_units_t->shape()))
        << "num_units is not a scalar";
    int num_units = num_units_t->scalar<int>()();

    const Tensor* input_size_t = nullptr;
    OP_REQUIRES_OK(context, context->input("input_size", &input_size_t));
    CHECK(TensorShapeUtils::IsScalar(input_size_t->shape()))
        << "input_size is not a scalar";
    int input_size = input_size_t->scalar<int>()();

    const Tensor* num_layers_t = nullptr;
    OP_REQUIRES_OK(context, context->input("num_layers", &num_layers_t));
    CHECK(TensorShapeUtils::IsScalar(num_layers_t->shape()))
        << "num_layers is not a scalar";
    int num_layers = num_layers_t->scalar<int>()();
    int num_dirs = 1;
    if (rnn_direction_mode() == RnnDirectionMode::kRnnBidirectional) {
      num_dirs = 2;
    }
    const int num_params_per_layer = num_params_ / num_layers / num_dirs;
    // Number of params applied on inputs. The rest are applied on recurrent
    // hiddden states.
    const int num_params_input_state = num_params_per_layer / 2;
    CHECK(num_params_ % (num_layers * num_dirs) == 0)
        << "Number of params is not a multiple of num_layers * num_dirs.";
    CHECK(num_params_per_layer % 2 == 0)
        << "Number of params per layer is not a even number.";

    CHECK(num_params_ == rnn_desc->ParamsWeightRegions().size())
        << "Number of params mismatch. Expected " << num_params_ << ", got "
        << rnn_desc->ParamsWeightRegions().size();
    for (int i = 0; i < rnn_desc->ParamsWeightRegions().size(); i++) {
      int64 size_in_bytes = rnn_desc->ParamsWeightRegions()[i].size;
      int64 size = size_in_bytes / sizeof(T);
      const int layer_idx = i / num_params_per_layer;
      const int index_within_layer = i % num_params_per_layer;
      int width = 0, height = num_units;
      // In CuDNN layout, each layer has num_params_per_layer params, with the
      // first half a.k.a num_params_input_state params applied on the inputs,
      // and the second half on the recurrent hidden states.
      bool apply_on_input_state = index_within_layer < num_params_input_state;
      if (rnn_direction_mode() == RnnDirectionMode::kRnnUnidirectional) {
        if (layer_idx == 0 && apply_on_input_state) {
          width = input_size;
        } else {
          width = num_units;
        }
      } else {
        if (apply_on_input_state) {
          if (layer_idx <= 1) {
            // First fwd or bak layer.
            width = input_size;
          } else {
            // Following layers, cell inputs are concatenated outputs of
            // its prior layer.
            width = 2 * num_units;
          }
        } else {
          width = num_units;
        }
      }
      CHECK(size == width * height) << "Params size mismatch. Expected "
                                    << width * height << ", got " << size;
      Tensor* output = nullptr;
      OP_REQUIRES_OK(context, context->allocate_output(
                                  i, TensorShape({height, width}), &output));
      DeviceMemoryBase data_src_ptr = SliceDeviceMemory(
          input_ptr, rnn_desc->ParamsWeightRegions()[i].offset, size_in_bytes);
      auto data_dst_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*output);
      stream->ThenMemcpy(&data_dst_ptr, data_src_ptr, size_in_bytes);
    }

    OP_REQUIRES(context, num_params_ == rnn_desc->ParamsBiasRegions().size(),
                errors::InvalidArgument("Number of params mismatch. Expected ",
                                        num_params_, ", got ",
                                        rnn_desc->ParamsBiasRegions().size()));
    for (int i = 0; i < rnn_desc->ParamsBiasRegions().size(); i++) {
      int64 size_in_bytes = rnn_desc->ParamsBiasRegions()[i].size;
      int64 size = size_in_bytes / sizeof(T);
      OP_REQUIRES(context, size == num_units,
                  errors::InvalidArgument("Params size mismatch. Expected ",
                                          num_units, ", got ", size));

      Tensor* output = nullptr;
      OP_REQUIRES_OK(context,
                     context->allocate_output(num_params_ + i,
                                              TensorShape({size}), &output));
      DeviceMemoryBase data_src_ptr = SliceDeviceMemory(
          input_ptr, rnn_desc->ParamsBiasRegions()[i].offset, size_in_bytes);
      auto data_dst_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*output);
      stream->ThenMemcpy(&data_dst_ptr, data_src_ptr, size_in_bytes);
    }
  }

 private:
  int num_params_;
};

// Registers CudnnRNNParamsToCanonical<GPUDevice, T> as the DEVICE_GPU kernel
// for the "CudnnRNNParamsToCanonical" op, once for each of half, float and
// double. The "num_layers", "num_units" and "input_size" arguments are declared
// as host-memory arguments. The helper macro is local to this registration and
// is undefined again right after use.
#define REGISTER_GPU(T)                                     \
  REGISTER_KERNEL_BUILDER(Name("CudnnRNNParamsToCanonical") \
                              .Device(DEVICE_GPU)           \
                              .HostMemory("num_layers")     \
                              .HostMemory("num_units")      \
                              .HostMemory("input_size")     \
                              .TypeConstraint<T>("T"),      \
                          CudnnRNNParamsToCanonical<GPUDevice, T>);
TF_CALL_half(REGISTER_GPU);
TF_CALL_float(REGISTER_GPU);
TF_CALL_double(REGISTER_GPU);
#undef REGISTER_GPU

// Convert weight and bias params from the canonical form to a
// platform-specific layout.
template <typename T>
class CudnnRNNCanonicalToParams<GPUDevice, T> : public CudnnRNNKernelCommon {
 public:
  typedef GPUDevice Device;
  explicit CudnnRNNCanonicalToParams(OpKernelConstruction* context)
      : CudnnRNNKernelCommon(context) {}

  void Compute(OpKernelContext* context) override {
    std::unique_ptr<RnnDescriptor> rnn_desc;
    OP_REQUIRES_OK(context, ExtractCudnnRNNParamsInfo<T>(context, &rnn_desc));
    int64 params_size_in_bytes = rnn_desc->ParamsSizeInBytes();
    CHECK(params_size_in_bytes % sizeof(T) == 0)
        << "params_size_in_bytes must be multiple of element size";
    Tensor* output = nullptr;
    int params_size = params_size_in_bytes / sizeof(T);
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, {params_size}, &output));
    auto output_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*output);
    auto* stream = context->op_device_context()->stream();

    OpInputList weights;
    OP_REQUIRES_OK(context, context->input_list("weights", &weights));
    RestoreParams<T>(weights, rnn_desc->ParamsWeightRegions(), &output_ptr,
                     stream);

    OpInputList biases;
    OP_REQUIRES_OK(context, context->input_list("biases", &biases));
    RestoreParams<T>(biases, rnn_desc->ParamsBiasRegions(), &output_ptr,
                     stream);
  }
};

// Registers CudnnRNNCanonicalToParams<GPUDevice, T> as the DEVICE_GPU kernel
// for the "CudnnRNNCanonicalToParams" op, once for each of half, float and
// double. The "num_layers", "num_units" and "input_size" arguments are declared
// as host-memory arguments. The helper macro is local to this registration and
// is undefined again right after use.
#define REGISTER_GPU(T)                                     \
  REGISTER_KERNEL_BUILDER(Name("CudnnRNNCanonicalToParams") \
                              .Device(DEVICE_GPU)           \
                              .HostMemory("num_layers")     \
                              .HostMemory("num_units")      \
                              .HostMemory("input_size")     \
                              .TypeConstraint<T>("T"),      \
                          CudnnRNNCanonicalToParams<GPUDevice, T>);
TF_CALL_half(REGISTER_GPU);
TF_CALL_float(REGISTER_GPU);
TF_CALL_double(REGISTER_GPU);
#undef REGISTER_GPU

// Pointers to RNN scratch space for a specific set of shape parameters (used as
// a hash table value in CudnnRNNForwardOp and CudnnRNNBackwardOp).
struct RnnScratchSpace {
  // RNN descriptor returned by createRnnDescriptor(). Null for a freshly
  // inserted cache entry, which is how the ops detect that it must be built.
  std::unique_ptr<RnnDescriptor> rnn_desc;
  // Allocator passed to createRnnDescriptor() when rnn_desc is built; owned
  // here alongside the descriptor.
  std::unique_ptr<CudnnRNNPersistentSpaceAllocator> dropout_state_allocator;
};

// Run the forward operation of the RNN model.
template <typename T>
class CudnnRNNForwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
 public:
  typedef GPUDevice Device;
  explicit CudnnRNNForwardOp(OpKernelConstruction* context)
      : CudnnRNNKernelCommon(context) {
    OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor* input = nullptr;
    const Tensor* input_h = nullptr;
    const Tensor* input_c = nullptr;
    const Tensor* params = nullptr;
    CudnnModelShapes model_shapes;
    OP_REQUIRES_OK(context,
                   ExtractForwardInput(context, model_types(), &input, &input_h,
                                       &input_c, &params, &model_shapes));
    const auto& input_shape = model_shapes.input_shape;
    const auto& hidden_state_shape = model_shapes.hidden_state_shape;
    const auto& output_shape = model_shapes.output_shape;

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
    Tensor* output_h = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(1, hidden_state_shape, &output_h));
    Tensor* output_c = nullptr;
    if (HasInputC()) {
      // Only LSTM uses input_c and output_c. So for all other models, we only
      // need to create dummy outputs.
      OP_REQUIRES_OK(
          context, context->allocate_output(2, hidden_state_shape, &output_c));
    } else {
      OP_REQUIRES_OK(context, context->allocate_output(2, {}, &output_c));
    }

    auto* stream = context->op_device_context()->stream();
    auto* executor = stream->parent();
    RnnInputMode input_mode;
    OP_REQUIRES_OK(context,
                   ToRNNInputMode(rnn_input_mode(), model_shapes.num_units,
                                  model_shapes.input_size, &input_mode));
    auto data_type = ToDataType<T>::value;

    auto input_desc_s = executor->createRnnSequenceTensorDescriptor(
        input_shape.dim_size(0), input_shape.dim_size(1),
        input_shape.dim_size(2), data_type);
    OP_REQUIRES_OK(context, FromExecutorStatus(input_desc_s));
    auto input_desc = input_desc_s.ConsumeValueOrDie();

    auto hidden_state_desc_s = executor->createRnnStateTensorDescriptor(
        hidden_state_shape.dim_size(0), hidden_state_shape.dim_size(1),
        hidden_state_shape.dim_size(2), data_type);
    OP_REQUIRES_OK(context, FromExecutorStatus(hidden_state_desc_s));
    auto hidden_state_desc = hidden_state_desc_s.ConsumeValueOrDie();

    auto output_desc_s = executor->createRnnSequenceTensorDescriptor(
        output_shape.dim_size(0), output_shape.dim_size(1),
        output_shape.dim_size(2), data_type);
    OP_REQUIRES_OK(context, FromExecutorStatus(output_desc_s));
    auto output_desc = output_desc_s.ConsumeValueOrDie();

    auto input_data = AsDeviceMemory<T>(input);
    auto input_h_data = AsDeviceMemory<T>(input_h);
    DeviceMemory<T> input_c_data;
    if (HasInputC()) {
      input_c_data = AsDeviceMemory<T>(input_c);
    }
    auto params_data = AsDeviceMemory<T>(params);
    auto output_data = AsDeviceMemory<T>(output);
    auto output_h_data = AsDeviceMemory<T>(output_h);
    DeviceMemory<T> output_c_data;
    if (HasInputC()) {
      output_c_data = AsDeviceMemory<T>(output_c);
    }

    // Creates a memory callback for the reserve_space. The memory lives in the
    // output of this kernel. And it will be fed into the backward pass when
    // needed.
    CudnnRNNReserveSpaceAllocator<T> reserve_space_allocator(context, 3);
    if (!is_training_) {
      Tensor* dummy_reserve_space = nullptr;
      OP_REQUIRES_OK(context,
                     context->allocate_output(3, {}, &dummy_reserve_space));
    }
    // Creates a memory callback for the workspace. The memory lives to the end
    // of this kernel calls.
    CudnnRNNWorkspaceAllocator workspace_allocator(context);
    bool launch_status = false;
    {
      mutex_lock l(mu_);
      RnnScratchSpace& rnn_state = rnn_state_cache_[model_shapes];
      if (rnn_state.rnn_desc == nullptr || ResetRndGenState()) {
        CudnnRNNPersistentSpaceAllocator* dropout_state_allocator =
            new CudnnRNNPersistentSpaceAllocator(context);
        rnn_state.dropout_state_allocator.reset(dropout_state_allocator);
        auto rnn_desc_s = executor->createRnnDescriptor(
            model_shapes.num_layers, model_shapes.num_units,
            model_shapes.input_size, input_mode, rnn_direction_mode(),
            rnn_mode(), data_type, dropout(), seed(), dropout_state_allocator);
        OP_REQUIRES_OK(context, FromExecutorStatus(rnn_desc_s));
        rnn_state.rnn_desc = std::move(rnn_desc_s.ConsumeValueOrDie());
      }
      launch_status =
          stream
              ->ThenRnnForward(*rnn_state.rnn_desc, *input_desc, input_data,
                               *hidden_state_desc, input_h_data,
                               *hidden_state_desc, input_c_data, params_data,
                               *output_desc, &output_data, *hidden_state_desc,
                               &output_h_data, *hidden_state_desc,
                               &output_c_data, is_training_,
                               &reserve_space_allocator, &workspace_allocator)
              .ok();
    }
    OP_REQUIRES(context, launch_status,
                errors::Internal("Failed to call ThenRnnForward"));
  }

 private:
  mutex mu_;
  bool is_training_;
  std::unordered_map<CudnnModelShapes, RnnScratchSpace, CudnnModelShapesHasher,
                     CudnnModelShapesComparator>
      rnn_state_cache_ GUARDED_BY(mu_);
};

// Registers CudnnRNNForwardOp<GPUDevice, T> as the DEVICE_GPU kernel for the
// "CudnnRNN" op, once for each of half, float and double. The helper macro is
// local to this registration and is undefined again right after use.
#define REGISTER_GPU(T)                                           \
  REGISTER_KERNEL_BUILDER(                                        \
      Name("CudnnRNN").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      CudnnRNNForwardOp<GPUDevice, T>);

TF_CALL_half(REGISTER_GPU);
TF_CALL_float(REGISTER_GPU);
TF_CALL_double(REGISTER_GPU);
#undef REGISTER_GPU

// Run the backward operation of the RNN model.
//
// In addition to the forward inputs, this reads the inputs "output",
// "output_h", "output_c", "output_backprop", "output_h_backprop",
// "output_c_backprop" and "reserve_space". The *_c inputs are only read when
// HasInputC() is true.
// Outputs: 0 = input_backprop, 1 = input_h_backprop, 2 = input_c_backprop (an
// empty-shape placeholder when the model has no input_c), 3 = params_backprop.
template <typename T>
class CudnnRNNBackwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
 public:
  typedef GPUDevice Device;

  explicit CudnnRNNBackwardOp(OpKernelConstruction* context)
      : CudnnRNNKernelCommon(context) {}

  void Compute(OpKernelContext* context) override {
    // Fetch the forward inputs and the shapes derived from them.
    const Tensor* input = nullptr;
    const Tensor* input_h = nullptr;
    const Tensor* input_c = nullptr;
    const Tensor* params = nullptr;
    CudnnModelShapes model_shapes;
    OP_REQUIRES_OK(context,
                   ExtractForwardInput(context, model_types(), &input, &input_h,
                                       &input_c, &params, &model_shapes));

    const auto& input_shape = model_shapes.input_shape;
    const auto& hidden_state_shape = model_shapes.hidden_state_shape;
    const auto& output_shape = model_shapes.output_shape;

    auto data_type = ToDataType<T>::value;

    // Forward results: each must match the shape derived from the inputs.
    const Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->input("output", &output));
    // Report the two shapes that are actually compared here; the message must
    // not touch input_c, which is only used under HasInputC() in this method.
    OP_REQUIRES(context, output_shape == output->shape(),
                errors::InvalidArgument(
                    "Invalid output shape: ", output->shape().DebugString(),
                    " ", output_shape.DebugString()));
    const Tensor* output_h = nullptr;
    OP_REQUIRES_OK(context, context->input("output_h", &output_h));
    OP_REQUIRES(context, output_h->shape() == hidden_state_shape,
                errors::InvalidArgument(
                    "Invalid output_h shape: ", output_h->shape().DebugString(),
                    " ", hidden_state_shape.DebugString()));
    const Tensor* output_c = nullptr;
    if (HasInputC()) {
      // Only LSTM uses input_c and output_c, so output_c is only fetched and
      // validated for that model; otherwise it stays null.
      OP_REQUIRES_OK(context, context->input("output_c", &output_c));
      OP_REQUIRES(context, output_c->shape() == hidden_state_shape,
                  errors::InvalidArgument("Invalid output_c shape: ",
                                          output_c->shape().DebugString(), " ",
                                          hidden_state_shape.DebugString()));
    }

    // Incoming gradients, validated against the same shapes.
    const Tensor* output_backprop = nullptr;
    OP_REQUIRES_OK(context,
                   context->input("output_backprop", &output_backprop));
    OP_REQUIRES(context, output_backprop->shape() == output_shape,
                errors::InvalidArgument("Invalid output_backprop shapes: ",
                                        output_backprop->shape().DebugString(),
                                        " ", output_shape.DebugString()));

    const Tensor* output_h_backprop = nullptr;
    OP_REQUIRES_OK(context,
                   context->input("output_h_backprop", &output_h_backprop));
    OP_REQUIRES(
        context, output_h_backprop->shape() == hidden_state_shape,
        errors::InvalidArgument("Invalid output_h_backprop shapes: ",
                                output_h_backprop->shape().DebugString(), " ",
                                hidden_state_shape.DebugString()));
    const Tensor* output_c_backprop = nullptr;
    if (HasInputC()) {
      OP_REQUIRES_OK(context,
                     context->input("output_c_backprop", &output_c_backprop));
      OP_REQUIRES(
          context, output_c_backprop->shape() == hidden_state_shape,
          errors::InvalidArgument("Invalid output_c_backprop shapes: ",
                                  output_c_backprop->shape().DebugString(), " ",
                                  hidden_state_shape.DebugString()));
    }
    const Tensor* reserve_space_const = nullptr;
    // This is the same "reserve_space" created by the forward op.
    // It can also be modified by this backward operation.
    OP_REQUIRES_OK(context,
                   context->input("reserve_space", &reserve_space_const));
    // Cudnn needs the reserve space to be writeable. This is fine because they
    // are opaque.
    Tensor* reserve_space = const_cast<Tensor*>(reserve_space_const);

    // Gradient outputs mirror the shapes of the tensors they differentiate.
    Tensor* input_backprop = nullptr;
    OP_REQUIRES_OK(
        context, context->allocate_output(0, input->shape(), &input_backprop));
    Tensor* input_h_backprop = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(1, input_h->shape(),
                                                     &input_h_backprop));
    Tensor* input_c_backprop = nullptr;
    if (HasInputC()) {
      OP_REQUIRES_OK(context, context->allocate_output(2, input_c->shape(),
                                                       &input_c_backprop));
    } else {
      // Output 2 still has to be produced when there is no input_c; allocate a
      // dummy.
      OP_REQUIRES_OK(context,
                     context->allocate_output(2, {}, &input_c_backprop));
    }
    Tensor* params_backprop = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(3, params->shape(),
                                                     &params_backprop));

    auto* stream = context->op_device_context()->stream();
    auto* executor = stream->parent();
    RnnInputMode input_mode;
    OP_REQUIRES_OK(context,
                   ToRNNInputMode(rnn_input_mode(), model_shapes.num_units,
                                  model_shapes.input_size, &input_mode));

    // Tensor descriptors for the input sequence, the states and the output
    // sequence.
    auto input_desc_s = executor->createRnnSequenceTensorDescriptor(
        input_shape.dim_size(0), input_shape.dim_size(1),
        input_shape.dim_size(2), data_type);
    OP_REQUIRES_OK(context, FromExecutorStatus(input_desc_s));
    auto input_desc = input_desc_s.ConsumeValueOrDie();

    auto hidden_state_desc_s = executor->createRnnStateTensorDescriptor(
        hidden_state_shape.dim_size(0), hidden_state_shape.dim_size(1),
        hidden_state_shape.dim_size(2), data_type);
    OP_REQUIRES_OK(context, FromExecutorStatus(hidden_state_desc_s));
    auto hidden_state_desc = hidden_state_desc_s.ConsumeValueOrDie();

    auto output_desc_s = executor->createRnnSequenceTensorDescriptor(
        output_shape.dim_size(0), output_shape.dim_size(1),
        output_shape.dim_size(2), data_type);
    OP_REQUIRES_OK(context, FromExecutorStatus(output_desc_s));
    auto output_desc = output_desc_s.ConsumeValueOrDie();

    // Device-memory views of every tensor. The *_c views stay
    // default-constructed unless the model has input_c.
    auto input_data = AsDeviceMemory<T>(input);
    auto input_h_data = AsDeviceMemory<T>(input_h);
    DeviceMemory<T> input_c_data;
    if (HasInputC()) {
      input_c_data = AsDeviceMemory<T>(input_c);
    }
    auto params_data = AsDeviceMemory<T>(params);
    auto output_data = AsDeviceMemory<T>(output);
    auto output_h_data = AsDeviceMemory<T>(output_h);
    DeviceMemory<T> output_c_data;
    if (HasInputC()) {
      output_c_data = AsDeviceMemory<T>(output_c);
    }
    auto output_backprop_data = AsDeviceMemory<T>(output_backprop);
    auto output_h_backprop_data = AsDeviceMemory<T>(output_h_backprop);
    DeviceMemory<T> output_c_backprop_data;
    if (HasInputC()) {
      output_c_backprop_data = AsDeviceMemory<T>(output_c_backprop);
    }
    auto input_backprop_data = AsDeviceMemory<T>(input_backprop);
    auto input_h_backprop_data = AsDeviceMemory<T>(input_h_backprop);
    DeviceMemory<T> input_c_backprop_data;
    if (HasInputC()) {
      input_c_backprop_data = AsDeviceMemory<T>(input_c_backprop);
    }
    auto params_backprop_data = AsDeviceMemory<T>(params_backprop);
    auto reserve_space_uint8 = CastDeviceMemory<uint8, T>(reserve_space);
    // Creates a memory callback for the workspace. The memory lives to the end
    // of this kernel calls.
    CudnnRNNWorkspaceAllocator workspace_allocator(context);
    bool launch_status = false;
    {
      // The descriptor cache is shared between Compute() calls; hold the lock
      // while looking it up, (re)building the entry and launching.
      mutex_lock l(mu_);
      RnnScratchSpace& rnn_state = rnn_state_cache_[model_shapes];
      if (rnn_state.rnn_desc == nullptr || ResetRndGenState()) {
        CudnnRNNPersistentSpaceAllocator* dropout_state_allocator =
            new CudnnRNNPersistentSpaceAllocator(context);
        // The cache entry takes ownership of the allocator.
        rnn_state.dropout_state_allocator.reset(dropout_state_allocator);
        auto rnn_desc_s = executor->createRnnDescriptor(
            model_shapes.num_layers, model_shapes.num_units,
            model_shapes.input_size, input_mode, rnn_direction_mode(),
            rnn_mode(), data_type, dropout(), seed(), dropout_state_allocator);
        OP_REQUIRES_OK(context, FromExecutorStatus(rnn_desc_s));
        rnn_state.rnn_desc = std::move(rnn_desc_s.ConsumeValueOrDie());
      }
      launch_status =
          stream
              ->ThenRnnBackward(*rnn_state.rnn_desc, *input_desc, input_data,
                                *hidden_state_desc, input_h_data,
                                *hidden_state_desc, input_c_data, params_data,
                                *output_desc, output_data, *hidden_state_desc,
                                output_h_data, *hidden_state_desc,
                                output_c_data, output_backprop_data,
                                output_h_backprop_data, output_c_backprop_data,
                                &input_backprop_data, &input_h_backprop_data,
                                &input_c_backprop_data, &params_backprop_data,
                                &reserve_space_uint8, &workspace_allocator)
              .ok();
    }
    OP_REQUIRES(context, launch_status,
                errors::Internal("Failed to call ThenRnnBackward"));
  }

 private:
  // Guards rnn_state_cache_.
  mutex mu_;
  // Per-shape cache of the RNN descriptor and the allocator it was built with.
  std::unordered_map<CudnnModelShapes, RnnScratchSpace, CudnnModelShapesHasher,
                     CudnnModelShapesComparator>
      rnn_state_cache_ GUARDED_BY(mu_);
};

// Registers CudnnRNNBackwardOp<GPUDevice, T> as the DEVICE_GPU kernel for the
// "CudnnRNNBackprop" op, once for each of half, float and double. The helper
// macro is local to this registration and is undefined again right after use.
#define REGISTER_GPU(T)                                                   \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("CudnnRNNBackprop").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      CudnnRNNBackwardOp<GPUDevice, T>);

TF_CALL_half(REGISTER_GPU);
TF_CALL_float(REGISTER_GPU);
TF_CALL_double(REGISTER_GPU);
#undef REGISTER_GPU

// TODO(zhengxq): Add the conversion of Cudnn RNN Params from and to
// its canonical form.

#endif  // GOOGLE_CUDA

}  // namespace tensorflow