From 0bad3db1e71051c6382395889b3ff22c1b8e786e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 22 Jun 2017 09:27:13 -0700 Subject: Have RestoreV2's shape fn set all outputs to unknown shape. PiperOrigin-RevId: 159835723 --- tensorflow/core/ops/io_ops.cc | 3 +-- tensorflow/core/ops/io_ops_test.cc | 20 ++++++++++++++++++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/ops/io_ops.cc b/tensorflow/core/ops/io_ops.cc index fa12816c92..b92c18416f 100644 --- a/tensorflow/core/ops/io_ops.cc +++ b/tensorflow/core/ops/io_ops.cc @@ -109,8 +109,7 @@ REGISTER_OP("RestoreV2") TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &shape1)); TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &shape2)); TF_RETURN_IF_ERROR(c->Merge(shape1, shape2, &shape0)); - c->set_output(0, c->UnknownShape()); - return Status::OK(); + return UnknownShape(c); }) .Doc(R"doc( Restores tensors from a V2 checkpoint. diff --git a/tensorflow/core/ops/io_ops_test.cc b/tensorflow/core/ops/io_ops_test.cc index a915cdbe12..785fb96c64 100644 --- a/tensorflow/core/ops/io_ops_test.cc +++ b/tensorflow/core/ops/io_ops_test.cc @@ -79,6 +79,26 @@ TEST(IoOpsTest, Restore_ShapeFn) { INFER_ERROR("Shape must be rank 0 but is rank 1", op, "[];[?]"); } +TEST(IoOpsTest, RestoreV2_ShapeFn) { + ShapeInferenceTestOp op("RestoreV2"); + + TF_ASSERT_OK(NodeDefBuilder("test", op.name) + .Input({"prefix", 0, DT_STRING}) + .Input({"tensor_names", 0, DT_STRING}) + .Input({"shapes_and_slices", 0, DT_STRING}) + .Attr("dtypes", {DT_FLOAT, DT_INT64}) + .Finalize(&op.node_def)); + + INFER_OK(op, "?;?;?", "?;?"); + INFER_OK(op, "[];[10];[10]", "?;?"); + + // Input shape validation. + INFER_ERROR("Shape must be rank 0 but is rank 1", op, "[?];[?];[?]"); + INFER_ERROR("Shape must be rank 1 but is rank 2", op, "[];[?,?];[?]"); + INFER_ERROR("Shape must be rank 1 but is rank 2", op, "[];[?];[?,?]"); + INFER_ERROR("in both shapes must be equal", op, "[];[10];[20]"); +} + TEST(IoOpsTest, RestoreSlice_ShapeFn) { ShapeInferenceTestOp op("RestoreSlice"); -- cgit v1.2.3