diff options
Diffstat (limited to 'tensorflow/core/ops/cudnn_rnn_ops.cc')
-rw-r--r-- | tensorflow/core/ops/cudnn_rnn_ops.cc | 79 |
1 files changed, 79 insertions, 0 deletions
diff --git a/tensorflow/core/ops/cudnn_rnn_ops.cc b/tensorflow/core/ops/cudnn_rnn_ops.cc index 37d70a22ef..f78f7a897a 100644 --- a/tensorflow/core/ops/cudnn_rnn_ops.cc +++ b/tensorflow/core/ops/cudnn_rnn_ops.cc @@ -99,6 +99,49 @@ REGISTER_OP("CudnnRNN") return Status::OK(); }); +REGISTER_OP("CudnnRNNV2") + .Input("input: T") + .Input("input_h: T") + .Input("input_c: T") + .Input("params: T") + .SetIsStateful() + .Output("output: T") + .Output("output_h: T") + .Output("output_c: T") + .Output("reserve_space: T") + .Output("host_reserved: int8") + .Attr("T: {float16, float32, float64}") + .Attr(kRNNModeAttrs) + .Attr(kRNNInputModeAttrs) + .Attr(kRNNDirectionAttrs) + .Attr("dropout: float = 0.0") + .Attr("seed: int = 0") + .Attr("seed2: int = 0") + .Attr("is_training: bool = true") + .SetShapeFn([](InferenceContext* c) { + auto input_shape = c->input(0); + auto input_h_shape = c->input(1); + auto seq_length = c->Dim(input_shape, 0); + auto batch_size = c->Dim(input_shape, 1); + auto num_units = c->Dim(input_h_shape, 2); + string direction; + TF_RETURN_IF_ERROR(c->GetAttr("direction", &direction)); + string rnn_mode; + TF_RETURN_IF_ERROR(c->GetAttr("rnn_mode", &rnn_mode)); + int dir_count = (direction == "bidirectional") ? 2 : 1; + DimensionHandle output_size; + TF_RETURN_IF_ERROR(c->Multiply(num_units, dir_count, &output_size)); + auto output_shape = c->MakeShape({seq_length, batch_size, output_size}); + auto output_h_shape = input_h_shape; + auto output_c_shape TF_ATTRIBUTE_UNUSED = + (rnn_mode == "lstm") ? output_h_shape : c->MakeShape({}); + c->set_output(0, output_shape); + c->set_output(1, output_h_shape); + c->set_output(2, output_c_shape); + c->set_output(3, c->UnknownShape()); + c->set_output(4, c->UnknownShape()); + return Status::OK(); + }); REGISTER_OP("CudnnRNNBackprop") .Input("input: T") @@ -136,6 +179,42 @@ REGISTER_OP("CudnnRNNBackprop") return Status::OK(); }); +REGISTER_OP("CudnnRNNBackpropV2") + .Input("input: T") + .Input("input_h: T") + .Input("input_c: T") + .Input("params: T") + .Input("output: T") + .Input("output_h: T") + .Input("output_c: T") + .Input("output_backprop: T") + .Input("output_h_backprop: T") + .Input("output_c_backprop: T") + .Input("reserve_space: T") + .Input("host_reserved: int8") + .SetIsStateful() + .Output("input_backprop: T") + .Output("input_h_backprop: T") + .Output("input_c_backprop: T") + .Output("params_backprop: T") + .Attr("T: {float16, float32, float64}") + .Attr(kRNNModeAttrs) + .Attr(kRNNInputModeAttrs) + .Attr(kRNNDirectionAttrs) + .Attr("dropout: float = 0.0") + .Attr("seed: int = 0") + .Attr("seed2: int = 0") + .SetShapeFn([](InferenceContext* c) { + auto input_shape = c->input(0); + auto input_h_shape = c->input(1); + auto input_c_shape = c->input(2); + auto params_shape = c->input(3); + c->set_output(0, input_shape); + c->set_output(1, input_h_shape); + c->set_output(2, input_c_shape); + c->set_output(3, params_shape); + return Status::OK(); + }); REGISTER_OP("CudnnRNNParamsToCanonical") .Input("num_layers: int32") |