aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/ops
diff options
context:
space:
mode:
authorGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-19 17:19:43 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-19 17:19:43 -0700
commitdfa0447740b42edff5ae2d76d5957aa688ae8053 (patch)
treed2a94a6385280f239603d7cf70b32f10e8780673 /tensorflow/core/ops
parent1d78936a3989f6ee5a9945746cd329c37e82287c (diff)
parentb127c201cda558db21ce5f48f5899593d73da46b (diff)
Merge pull request #21114 from yongtang:07242018-CudnnRNNParamsSize
PiperOrigin-RevId: 213726710
Diffstat (limited to 'tensorflow/core/ops')
-rw-r--r--tensorflow/core/ops/cudnn_rnn_ops.cc9
-rw-r--r--tensorflow/core/ops/cudnn_rnn_ops_test.cc11
2 files changed, 16 insertions, 4 deletions
diff --git a/tensorflow/core/ops/cudnn_rnn_ops.cc b/tensorflow/core/ops/cudnn_rnn_ops.cc
index f78f7a897a..f84142c992 100644
--- a/tensorflow/core/ops/cudnn_rnn_ops.cc
+++ b/tensorflow/core/ops/cudnn_rnn_ops.cc
@@ -37,7 +37,6 @@ using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;
-
REGISTER_OP("CudnnRNNParamsSize")
.Input("num_layers: int32")
.Input("num_units: int32")
@@ -52,11 +51,16 @@ REGISTER_OP("CudnnRNNParamsSize")
.Attr("seed2: int = 0")
.Output("params_size: S")
.SetShapeFn([](InferenceContext* c) {
+ ShapeHandle unused;
+ // num_layers, num_units, and input_size should be scalars.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+
c->set_output(0, c->Vector(1));
return Status::OK();
});
-
REGISTER_OP("CudnnRNN")
.Input("input: T")
.Input("input_h: T")
@@ -248,7 +252,6 @@ REGISTER_OP("CudnnRNNParamsToCanonical")
return Status::OK();
});
-
REGISTER_OP("CudnnRNNCanonicalToParams")
.Input("num_layers: int32")
.Input("num_units: int32")
diff --git a/tensorflow/core/ops/cudnn_rnn_ops_test.cc b/tensorflow/core/ops/cudnn_rnn_ops_test.cc
index 2dd867561b..13c3b933f4 100644
--- a/tensorflow/core/ops/cudnn_rnn_ops_test.cc
+++ b/tensorflow/core/ops/cudnn_rnn_ops_test.cc
@@ -26,7 +26,16 @@ namespace tensorflow {
TEST(CudnnRNNOpsTest, ParamsSize_ShapeFn) {
ShapeInferenceTestOp op("CudnnRNNParamsSize");
- INFER_OK(op, "[1];[1];[1]", "[1]");
+ INFER_OK(op, "[];[];[]", "[1]");
+ INFER_OK(op, "?;[];[]", "[1]");
+ INFER_OK(op, "[];?;[]", "[1]");
+ INFER_OK(op, "[];[];?", "[1]");
+ INFER_OK(op, "[];?;?", "[1]");
+ INFER_OK(op, "?;?;?", "[1]");
+
+ INFER_ERROR("Shape must be rank 0 ", op, "[1,2];?;[]");
+ INFER_ERROR("Shape must be rank 0 ", op, "?;[2];[]");
+ INFER_ERROR("Shape must be rank 0 ", op, "?;?;[1]");
}
TEST(CudnnRNNOpsTest, ForwardLstm_ShapeFn) {