aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/core/framework/shape_inference.h11
-rw-r--r--tensorflow/core/framework/shape_inference_test.cc2
-rw-r--r--tensorflow/core/ops/data_flow_ops.cc22
3 files changed, 23 insertions, 12 deletions
diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h
index 6f11b2ab17..b7f1725c5f 100644
--- a/tensorflow/core/framework/shape_inference.h
+++ b/tensorflow/core/framework/shape_inference.h
@@ -182,6 +182,15 @@ class InferenceContext {
if (!s.ok()) {
return AttachContext(s);
}
+#if 0
+ // TODO(cwhipkey): enable this check
+#ifndef NDEBUG
+ for (int i = 0; i < num_outputs(); ++i) {
+ DCHECK(output(i).IsSet()) << i << " for " << node_def().name()
+ << " of type " << node_def().op();
+ }
+#endif // NDEBUG
+#endif
return s;
}
@@ -239,7 +248,7 @@ class InferenceContext {
}
int32 Rank(ShapeHandle s) const {
DCHECK(s.IsSet());
- return s->rank_;
+ return s.IsSet() ? s->rank_ : kUnknownRank;
}
bool RankKnown(ShapeHandle s) const {
return (s.IsSet() && (Rank(s) != kUnknownRank));
diff --git a/tensorflow/core/framework/shape_inference_test.cc b/tensorflow/core/framework/shape_inference_test.cc
index ff25063255..c82b506e4b 100644
--- a/tensorflow/core/framework/shape_inference_test.cc
+++ b/tensorflow/core/framework/shape_inference_test.cc
@@ -134,6 +134,7 @@ TEST_F(ShapeInferenceTest, Run) {
ShapeHandle h;
TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 6, &h));
c->set_output(0, c->input(0));
+ c->set_output(1, c->input(0));
return Status::OK();
};
TF_ASSERT_OK(c.Run(fn));
@@ -144,6 +145,7 @@ TEST_F(ShapeInferenceTest, Run) {
ShapeHandle h;
TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 0, &h));
c->set_output(0, c->input(0));
+ c->set_output(1, c->input(0));
return Status::OK();
};
Status s = c.Run(fn);
diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc
index e811f1a4b0..10b5df91f1 100644
--- a/tensorflow/core/ops/data_flow_ops.cc
+++ b/tensorflow/core/ops/data_flow_ops.cc
@@ -1431,7 +1431,7 @@ REGISTER_OP("TensorArray")
.Attr("element_shape: shape = { unknown_rank: true }")
.Output("handle: Ref(string)")
.SetIsStateful()
- .SetShapeFn([](InferenceContext* c) { return Status::OK(); })
+ .SetShapeFn(shape_inference::UnknownShape)
.Deprecated(16, "Use TensorArrayV3");
// TODO(cwhipkey): mark this deprecated in favor of V3.
REGISTER_OP("TensorArrayV2")
@@ -1456,7 +1456,7 @@ REGISTER_OP("TensorArrayGrad")
.Output("grad_handle: Ref(string)")
.Attr("source: string")
.SetIsStateful()
- .SetShapeFn([](InferenceContext* c) { return Status::OK(); })
+ .SetShapeFn(shape_inference::UnknownShape)
.Deprecated(16, "Use TensorArrayGradV3");
// TODO(cwhipkey): mark this deprecated in favor of V3.
REGISTER_OP("TensorArrayGradV2")
@@ -1481,7 +1481,7 @@ REGISTER_OP("TensorArrayWrite")
.Input("flow_in: float")
.Output("flow_out: float")
.Attr("T: type")
- .SetShapeFn([](InferenceContext* c) { return Status::OK(); })
+ .SetShapeFn(shape_inference::UnknownShape)
.Deprecated(16, "Use TensorArrayWriteV3");
// TODO(cwhipkey): mark this deprecated in favor of V3.
REGISTER_OP("TensorArrayWriteV2")
@@ -1509,7 +1509,7 @@ REGISTER_OP("TensorArrayRead")
.Input("flow_in: float")
.Output("value: dtype")
.Attr("dtype: type")
- .SetShapeFn([](InferenceContext* c) { return Status::OK(); })
+ .SetShapeFn(shape_inference::UnknownShape)
.Deprecated(16, "Use TensorArrayReadV3");
// TODO(cwhipkey): mark this deprecated in favor of V3.
REGISTER_OP("TensorArrayReadV2")
@@ -1535,7 +1535,7 @@ REGISTER_OP("TensorArrayPack")
.Output("value: dtype")
.Attr("dtype: type")
.Attr("element_shape: shape = { unknown_rank: true }")
- .SetShapeFn([](InferenceContext* c) { return Status::OK(); })
+ .SetShapeFn(shape_inference::UnknownShape)
.Deprecated(16, "Use TensorArrayGatherV3 with RangeOp");
REGISTER_OP("TensorArrayUnpack")
.Input("handle: Ref(string)")
@@ -1543,7 +1543,7 @@ REGISTER_OP("TensorArrayUnpack")
.Input("flow_in: float")
.Output("flow_out: float")
.Attr("T: type")
- .SetShapeFn([](InferenceContext* c) { return Status::OK(); })
+ .SetShapeFn(shape_inference::UnknownShape)
.Deprecated(20, "Use TensorArrayScatterV3 with RangeOp");
REGISTER_OP("TensorArrayGather")
.Input("handle: Ref(string)")
@@ -1552,7 +1552,7 @@ REGISTER_OP("TensorArrayGather")
.Output("value: dtype")
.Attr("dtype: type")
.Attr("element_shape: shape = { unknown_rank: true }")
- .SetShapeFn([](InferenceContext* c) { return Status::OK(); })
+ .SetShapeFn(shape_inference::UnknownShape)
.Deprecated(16, "Use TensorArrayGatherV3");
// TODO(cwhipkey): mark this deprecated in favor of V3.
REGISTER_OP("TensorArrayGatherV2")
@@ -1579,7 +1579,7 @@ REGISTER_OP("TensorArrayScatter")
.Input("flow_in: float")
.Output("flow_out: float")
.Attr("T: type")
- .SetShapeFn([](InferenceContext* c) { return Status::OK(); })
+ .SetShapeFn(shape_inference::UnknownShape)
.Deprecated(19, "Use TensorArrayGradV3");
// TODO(cwhipkey): mark this deprecated in favor of V3.
REGISTER_OP("TensorArrayScatterV2")
@@ -1606,7 +1606,7 @@ REGISTER_OP("TensorArrayConcat")
.Output("lengths: int64")
.Attr("dtype: type")
.Attr("element_shape_except0: shape = { unknown_rank: true }")
- .SetShapeFn([](InferenceContext* c) { return Status::OK(); })
+ .SetShapeFn(shape_inference::UnknownShape)
.Deprecated(16, "Use TensorArrayGradV3");
REGISTER_OP("TensorArrayConcatV2")
.Input("handle: string")
@@ -1634,7 +1634,7 @@ REGISTER_OP("TensorArraySplit")
.Input("flow_in: float")
.Output("flow_out: float")
.Attr("T: type")
- .SetShapeFn([](InferenceContext* c) { return Status::OK(); })
+ .SetShapeFn(shape_inference::UnknownShape)
.Deprecated(16, "Use TensorArraySplitV3");
// TODO(cwhipkey): mark this deprecated in favor of V3.
REGISTER_OP("TensorArraySplitV2")
@@ -1659,7 +1659,7 @@ REGISTER_OP("TensorArraySize")
.Input("handle: Ref(string)")
.Input("flow_in: float")
.Output("size: int32")
- .SetShapeFn([](InferenceContext* c) { return Status::OK(); })
+ .SetShapeFn(shape_inference::UnknownShape)
.Deprecated(16, "Use TensorArraySizeV3");
// TODO(cwhipkey): mark this deprecated in favor of V3.
REGISTER_OP("TensorArraySizeV2")