diff options
author | 2018-09-19 17:19:43 -0700 | |
---|---|---|
committer | 2018-09-19 17:19:43 -0700 | |
commit | dfa0447740b42edff5ae2d76d5957aa688ae8053 (patch) | |
tree | d2a94a6385280f239603d7cf70b32f10e8780673 /tensorflow/core/ops | |
parent | 1d78936a3989f6ee5a9945746cd329c37e82287c (diff) | |
parent | b127c201cda558db21ce5f48f5899593d73da46b (diff) |
Merge pull request #21114 from yongtang:07242018-CudnnRNNParamsSize
PiperOrigin-RevId: 213726710
Diffstat (limited to 'tensorflow/core/ops')
-rw-r--r-- | tensorflow/core/ops/cudnn_rnn_ops.cc | 9 | ||||
-rw-r--r-- | tensorflow/core/ops/cudnn_rnn_ops_test.cc | 11 |
2 files changed, 16 insertions, 4 deletions
diff --git a/tensorflow/core/ops/cudnn_rnn_ops.cc b/tensorflow/core/ops/cudnn_rnn_ops.cc index f78f7a897a..f84142c992 100644 --- a/tensorflow/core/ops/cudnn_rnn_ops.cc +++ b/tensorflow/core/ops/cudnn_rnn_ops.cc @@ -37,7 +37,6 @@ using shape_inference::DimensionHandle; using shape_inference::InferenceContext; using shape_inference::ShapeHandle; - REGISTER_OP("CudnnRNNParamsSize") .Input("num_layers: int32") .Input("num_units: int32") @@ -52,11 +51,16 @@ REGISTER_OP("CudnnRNNParamsSize") .Attr("seed2: int = 0") .Output("params_size: S") .SetShapeFn([](InferenceContext* c) { + ShapeHandle unused; + // num_layers, num_units, and input_size should be scalars. + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); + c->set_output(0, c->Vector(1)); return Status::OK(); }); - REGISTER_OP("CudnnRNN") .Input("input: T") .Input("input_h: T") @@ -248,7 +252,6 @@ REGISTER_OP("CudnnRNNParamsToCanonical") return Status::OK(); }); - REGISTER_OP("CudnnRNNCanonicalToParams") .Input("num_layers: int32") .Input("num_units: int32") diff --git a/tensorflow/core/ops/cudnn_rnn_ops_test.cc b/tensorflow/core/ops/cudnn_rnn_ops_test.cc index 2dd867561b..13c3b933f4 100644 --- a/tensorflow/core/ops/cudnn_rnn_ops_test.cc +++ b/tensorflow/core/ops/cudnn_rnn_ops_test.cc @@ -26,7 +26,16 @@ namespace tensorflow { TEST(CudnnRNNOpsTest, ParamsSize_ShapeFn) { ShapeInferenceTestOp op("CudnnRNNParamsSize"); - INFER_OK(op, "[1];[1];[1]", "[1]"); + INFER_OK(op, "[];[];[]", "[1]"); + INFER_OK(op, "?;[];[]", "[1]"); + INFER_OK(op, "[];?;[]", "[1]"); + INFER_OK(op, "[];[];?", "[1]"); + INFER_OK(op, "[];?;?", "[1]"); + INFER_OK(op, "?;?;?", "[1]"); + + INFER_ERROR("Shape must be rank 0 ", op, "[1,2];?;[]"); + INFER_ERROR("Shape must be rank 0 ", op, "?;[2];[]"); + INFER_ERROR("Shape must be rank 0 ", op, "?;?;[1]"); } TEST(CudnnRNNOpsTest, ForwardLstm_ShapeFn) { |