aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-06-22 09:27:13 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-06-22 09:34:14 -0700
commit0bad3db1e71051c6382395889b3ff22c1b8e786e (patch)
tree08ec2dfffa70ecb9e4c5969310e43d80ba69f414
parent77f27c95fb565a9feafe8211d3de3b40e86c4402 (diff)
Have RestoreV2's shape fn set all outputs to unknown shape.
PiperOrigin-RevId: 159835723
-rw-r--r--tensorflow/core/ops/io_ops.cc3
-rw-r--r--tensorflow/core/ops/io_ops_test.cc20
2 files changed, 21 insertions, 2 deletions
diff --git a/tensorflow/core/ops/io_ops.cc b/tensorflow/core/ops/io_ops.cc
index fa12816c92..b92c18416f 100644
--- a/tensorflow/core/ops/io_ops.cc
+++ b/tensorflow/core/ops/io_ops.cc
@@ -109,8 +109,7 @@ REGISTER_OP("RestoreV2")
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &shape1));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &shape2));
TF_RETURN_IF_ERROR(c->Merge(shape1, shape2, &shape0));
- c->set_output(0, c->UnknownShape());
- return Status::OK();
+ return UnknownShape(c);
})
.Doc(R"doc(
Restores tensors from a V2 checkpoint.
diff --git a/tensorflow/core/ops/io_ops_test.cc b/tensorflow/core/ops/io_ops_test.cc
index a915cdbe12..785fb96c64 100644
--- a/tensorflow/core/ops/io_ops_test.cc
+++ b/tensorflow/core/ops/io_ops_test.cc
@@ -79,6 +79,26 @@ TEST(IoOpsTest, Restore_ShapeFn) {
INFER_ERROR("Shape must be rank 0 but is rank 1", op, "[];[?]");
}
+TEST(IoOpsTest, RestoreV2_ShapeFn) {
+ ShapeInferenceTestOp op("RestoreV2");
+
+ TF_ASSERT_OK(NodeDefBuilder("test", op.name)
+ .Input({"prefix", 0, DT_STRING})
+ .Input({"tensor_names", 0, DT_STRING})
+ .Input({"shapes_and_slices", 0, DT_STRING})
+ .Attr("dtypes", {DT_FLOAT, DT_INT64})
+ .Finalize(&op.node_def));
+
+ INFER_OK(op, "?;?;?", "?;?");
+ INFER_OK(op, "[];[10];[10]", "?;?");
+
+ // Input shape validation.
+ INFER_ERROR("Shape must be rank 0 but is rank 1", op, "[?];[?];[?]");
+ INFER_ERROR("Shape must be rank 1 but is rank 2", op, "[];[?,?];[?]");
+ INFER_ERROR("Shape must be rank 1 but is rank 2", op, "[];[?];[?,?]");
+ INFER_ERROR("in both shapes must be equal", op, "[];[10];[20]");
+}
+
TEST(IoOpsTest, RestoreSlice_ShapeFn) {
ShapeInferenceTestOp op("RestoreSlice");