Diffstat (limited to 'tensorflow/core/ops/cudnn_rnn_ops.cc')
-rw-r--r--  tensorflow/core/ops/cudnn_rnn_ops.cc  79
1 file changed, 79 insertions, 0 deletions
diff --git a/tensorflow/core/ops/cudnn_rnn_ops.cc b/tensorflow/core/ops/cudnn_rnn_ops.cc
index 37d70a22ef..f78f7a897a 100644
--- a/tensorflow/core/ops/cudnn_rnn_ops.cc
+++ b/tensorflow/core/ops/cudnn_rnn_ops.cc
@@ -99,6 +99,49 @@ REGISTER_OP("CudnnRNN")
      return Status::OK();
    });

+REGISTER_OP("CudnnRNNV2")
+    .Input("input: T")
+    .Input("input_h: T")
+    .Input("input_c: T")
+    .Input("params: T")
+    .SetIsStateful()
+    .Output("output: T")
+    .Output("output_h: T")
+    .Output("output_c: T")
+    .Output("reserve_space: T")
+    .Output("host_reserved: int8")
+    .Attr("T: {float16, float32, float64}")
+    .Attr(kRNNModeAttrs)
+    .Attr(kRNNInputModeAttrs)
+    .Attr(kRNNDirectionAttrs)
+    .Attr("dropout: float = 0.0")
+    .Attr("seed: int = 0")
+    .Attr("seed2: int = 0")
+    .Attr("is_training: bool = true")
+    .SetShapeFn([](InferenceContext* c) {
+      auto input_shape = c->input(0);
+      auto input_h_shape = c->input(1);
+      auto seq_length = c->Dim(input_shape, 0);
+      auto batch_size = c->Dim(input_shape, 1);
+      auto num_units = c->Dim(input_h_shape, 2);
+      string direction;
+      TF_RETURN_IF_ERROR(c->GetAttr("direction", &direction));
+      string rnn_mode;
+      TF_RETURN_IF_ERROR(c->GetAttr("rnn_mode", &rnn_mode));
+      int dir_count = (direction == "bidirectional") ? 2 : 1;
+      DimensionHandle output_size;
+      TF_RETURN_IF_ERROR(c->Multiply(num_units, dir_count, &output_size));
+      auto output_shape = c->MakeShape({seq_length, batch_size, output_size});
+      auto output_h_shape = input_h_shape;
+      auto output_c_shape TF_ATTRIBUTE_UNUSED =
+          (rnn_mode == "lstm") ? output_h_shape : c->MakeShape({});
+      c->set_output(0, output_shape);
+      c->set_output(1, output_h_shape);
+      c->set_output(2, output_c_shape);
+      c->set_output(3, c->UnknownShape());
+      c->set_output(4, c->UnknownShape());
+      return Status::OK();
+    });
+
 REGISTER_OP("CudnnRNNBackprop")
     .Input("input: T")
@@ -136,6 +179,42 @@ REGISTER_OP("CudnnRNNBackprop")
      return Status::OK();
    });

+REGISTER_OP("CudnnRNNBackpropV2")
+    .Input("input: T")
+    .Input("input_h: T")
+    .Input("input_c: T")
+    .Input("params: T")
+    .Input("output: T")
+    .Input("output_h: T")
+    .Input("output_c: T")
+    .Input("output_backprop: T")
+    .Input("output_h_backprop: T")
+    .Input("output_c_backprop: T")
+    .Input("reserve_space: T")
+    .Input("host_reserved: int8")
+    .SetIsStateful()
+    .Output("input_backprop: T")
+    .Output("input_h_backprop: T")
+    .Output("input_c_backprop: T")
+    .Output("params_backprop: T")
+    .Attr("T: {float16, float32, float64}")
+    .Attr(kRNNModeAttrs)
+    .Attr(kRNNInputModeAttrs)
+    .Attr(kRNNDirectionAttrs)
+    .Attr("dropout: float = 0.0")
+    .Attr("seed: int = 0")
+    .Attr("seed2: int = 0")
+    .SetShapeFn([](InferenceContext* c) {
+      auto input_shape = c->input(0);
+      auto input_h_shape = c->input(1);
+      auto input_c_shape = c->input(2);
+      auto params_shape = c->input(3);
+      c->set_output(0, input_shape);
+      c->set_output(1, input_h_shape);
+      c->set_output(2, input_c_shape);
+      c->set_output(3, params_shape);
+      return Status::OK();
+    });
+
 REGISTER_OP("CudnnRNNParamsToCanonical")
     .Input("num_layers: int32")