diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-06-22 09:27:13 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-06-22 09:34:14 -0700 |
commit | 0bad3db1e71051c6382395889b3ff22c1b8e786e (patch) | |
tree | 08ec2dfffa70ecb9e4c5969310e43d80ba69f414 | |
parent | 77f27c95fb565a9feafe8211d3de3b40e86c4402 (diff) |
Have RestoreV2's shape fn set all outputs to unknown shape.
PiperOrigin-RevId: 159835723
-rw-r--r-- | tensorflow/core/ops/io_ops.cc | 3 | ||||
-rw-r--r-- | tensorflow/core/ops/io_ops_test.cc | 20 |
2 files changed, 21 insertions, 2 deletions
diff --git a/tensorflow/core/ops/io_ops.cc b/tensorflow/core/ops/io_ops.cc index fa12816c92..b92c18416f 100644 --- a/tensorflow/core/ops/io_ops.cc +++ b/tensorflow/core/ops/io_ops.cc @@ -109,8 +109,7 @@ REGISTER_OP("RestoreV2") TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &shape1)); TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &shape2)); TF_RETURN_IF_ERROR(c->Merge(shape1, shape2, &shape0)); - c->set_output(0, c->UnknownShape()); - return Status::OK(); + return UnknownShape(c); }) .Doc(R"doc( Restores tensors from a V2 checkpoint. diff --git a/tensorflow/core/ops/io_ops_test.cc b/tensorflow/core/ops/io_ops_test.cc index a915cdbe12..785fb96c64 100644 --- a/tensorflow/core/ops/io_ops_test.cc +++ b/tensorflow/core/ops/io_ops_test.cc @@ -79,6 +79,26 @@ TEST(IoOpsTest, Restore_ShapeFn) { INFER_ERROR("Shape must be rank 0 but is rank 1", op, "[];[?]"); } +TEST(IoOpsTest, RestoreV2_ShapeFn) { + ShapeInferenceTestOp op("RestoreV2"); + + TF_ASSERT_OK(NodeDefBuilder("test", op.name) + .Input({"prefix", 0, DT_STRING}) + .Input({"tensor_names", 0, DT_STRING}) + .Input({"shapes_and_slices", 0, DT_STRING}) + .Attr("dtypes", {DT_FLOAT, DT_INT64}) + .Finalize(&op.node_def)); + + INFER_OK(op, "?;?;?", "?;?"); + INFER_OK(op, "[];[10];[10]", "?;?"); + + // Input shape validation. + INFER_ERROR("Shape must be rank 0 but is rank 1", op, "[?];[?];[?]"); + INFER_ERROR("Shape must be rank 1 but is rank 2", op, "[];[?,?];[?]"); + INFER_ERROR("Shape must be rank 1 but is rank 2", op, "[];[?];[?,?]"); + INFER_ERROR("in both shapes must be equal", op, "[];[10];[20]"); +} + TEST(IoOpsTest, RestoreSlice_ShapeFn) { ShapeInferenceTestOp op("RestoreSlice"); |