diff options
author | 2017-03-16 06:15:48 -0800 | |
---|---|---|
committer | 2017-03-16 07:34:05 -0700 | |
commit | 69f8657157d46dd5e4e447d77f23d2c1db7e9800 (patch) | |
tree | 0e1bcbb42aabaf3a9cf3c57e1ac7236748b41b47 | |
parent | bed0e5c3292bcc094a2890183cfaec8273541fff (diff) |
Change data_flow_ops.cc for deprecated TensorArray ops
to have unknown shape function for cases where the shape
function is incorrectly returning no outputs.
Change: 150315933
-rw-r--r-- | tensorflow/core/framework/shape_inference.h | 11 | ||||
-rw-r--r-- | tensorflow/core/framework/shape_inference_test.cc | 2 | ||||
-rw-r--r-- | tensorflow/core/ops/data_flow_ops.cc | 22 |
3 files changed, 23 insertions, 12 deletions
diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h index 6f11b2ab17..b7f1725c5f 100644 --- a/tensorflow/core/framework/shape_inference.h +++ b/tensorflow/core/framework/shape_inference.h @@ -182,6 +182,15 @@ class InferenceContext { if (!s.ok()) { return AttachContext(s); } +#if 0 + // TODO(cwhipkey): enable this check +#ifndef NDEBUG + for (int i = 0; i < num_outputs(); ++i) { + DCHECK(output(i).IsSet()) << i << " for " << node_def().name() + << " of type " << node_def().op(); + } +#endif // NDEBUG +#endif return s; } @@ -239,7 +248,7 @@ class InferenceContext { } int32 Rank(ShapeHandle s) const { DCHECK(s.IsSet()); - return s->rank_; + return s.IsSet() ? s->rank_ : kUnknownRank; } bool RankKnown(ShapeHandle s) const { return (s.IsSet() && (Rank(s) != kUnknownRank)); diff --git a/tensorflow/core/framework/shape_inference_test.cc b/tensorflow/core/framework/shape_inference_test.cc index ff25063255..c82b506e4b 100644 --- a/tensorflow/core/framework/shape_inference_test.cc +++ b/tensorflow/core/framework/shape_inference_test.cc @@ -134,6 +134,7 @@ TEST_F(ShapeInferenceTest, Run) { ShapeHandle h; TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 6, &h)); c->set_output(0, c->input(0)); + c->set_output(1, c->input(0)); return Status::OK(); }; TF_ASSERT_OK(c.Run(fn)); @@ -144,6 +145,7 @@ TEST_F(ShapeInferenceTest, Run) { ShapeHandle h; TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 0, &h)); c->set_output(0, c->input(0)); + c->set_output(1, c->input(0)); return Status::OK(); }; Status s = c.Run(fn); diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc index e811f1a4b0..10b5df91f1 100644 --- a/tensorflow/core/ops/data_flow_ops.cc +++ b/tensorflow/core/ops/data_flow_ops.cc @@ -1431,7 +1431,7 @@ REGISTER_OP("TensorArray") .Attr("element_shape: shape = { unknown_rank: true }") .Output("handle: Ref(string)") .SetIsStateful() - .SetShapeFn([](InferenceContext* c) { return Status::OK(); }) + .SetShapeFn(shape_inference::UnknownShape) .Deprecated(16, "Use TensorArrayV3"); // TODO(cwhipkey): mark this deprecated in favor of V3. REGISTER_OP("TensorArrayV2") @@ -1456,7 +1456,7 @@ REGISTER_OP("TensorArrayGrad") .Output("grad_handle: Ref(string)") .Attr("source: string") .SetIsStateful() - .SetShapeFn([](InferenceContext* c) { return Status::OK(); }) + .SetShapeFn(shape_inference::UnknownShape) .Deprecated(16, "Use TensorArrayGradV3"); // TODO(cwhipkey): mark this deprecated in favor of V3. REGISTER_OP("TensorArrayGradV2") @@ -1481,7 +1481,7 @@ REGISTER_OP("TensorArrayWrite") .Input("flow_in: float") .Output("flow_out: float") .Attr("T: type") - .SetShapeFn([](InferenceContext* c) { return Status::OK(); }) + .SetShapeFn(shape_inference::UnknownShape) .Deprecated(16, "Use TensorArrayWriteV3"); // TODO(cwhipkey): mark this deprecated in favor of V3. REGISTER_OP("TensorArrayWriteV2") @@ -1509,7 +1509,7 @@ REGISTER_OP("TensorArrayRead") .Input("flow_in: float") .Output("value: dtype") .Attr("dtype: type") - .SetShapeFn([](InferenceContext* c) { return Status::OK(); }) + .SetShapeFn(shape_inference::UnknownShape) .Deprecated(16, "Use TensorArrayReadV3"); // TODO(cwhipkey): mark this deprecated in favor of V3. REGISTER_OP("TensorArrayReadV2") @@ -1535,7 +1535,7 @@ REGISTER_OP("TensorArrayPack") .Output("value: dtype") .Attr("dtype: type") .Attr("element_shape: shape = { unknown_rank: true }") - .SetShapeFn([](InferenceContext* c) { return Status::OK(); }) + .SetShapeFn(shape_inference::UnknownShape) .Deprecated(16, "Use TensorArrayGatherV3 with RangeOp"); REGISTER_OP("TensorArrayUnpack") .Input("handle: Ref(string)") @@ -1543,7 +1543,7 @@ REGISTER_OP("TensorArrayUnpack") .Input("flow_in: float") .Output("flow_out: float") .Attr("T: type") - .SetShapeFn([](InferenceContext* c) { return Status::OK(); }) + .SetShapeFn(shape_inference::UnknownShape) .Deprecated(20, "Use TensorArrayScatterV3 with RangeOp"); REGISTER_OP("TensorArrayGather") .Input("handle: Ref(string)") @@ -1552,7 +1552,7 @@ REGISTER_OP("TensorArrayGather") .Output("value: dtype") .Attr("dtype: type") .Attr("element_shape: shape = { unknown_rank: true }") - .SetShapeFn([](InferenceContext* c) { return Status::OK(); }) + .SetShapeFn(shape_inference::UnknownShape) .Deprecated(16, "Use TensorArrayGatherV3"); // TODO(cwhipkey): mark this deprecated in favor of V3. REGISTER_OP("TensorArrayGatherV2") @@ -1579,7 +1579,7 @@ REGISTER_OP("TensorArrayScatter") .Input("flow_in: float") .Output("flow_out: float") .Attr("T: type") - .SetShapeFn([](InferenceContext* c) { return Status::OK(); }) + .SetShapeFn(shape_inference::UnknownShape) .Deprecated(19, "Use TensorArrayGradV3"); // TODO(cwhipkey): mark this deprecated in favor of V3. REGISTER_OP("TensorArrayScatterV2") @@ -1606,7 +1606,7 @@ REGISTER_OP("TensorArrayConcat") .Output("lengths: int64") .Attr("dtype: type") .Attr("element_shape_except0: shape = { unknown_rank: true }") - .SetShapeFn([](InferenceContext* c) { return Status::OK(); }) + .SetShapeFn(shape_inference::UnknownShape) .Deprecated(16, "Use TensorArrayGradV3"); REGISTER_OP("TensorArrayConcatV2") .Input("handle: string") @@ -1634,7 +1634,7 @@ REGISTER_OP("TensorArraySplit") .Input("flow_in: float") .Output("flow_out: float") .Attr("T: type") - .SetShapeFn([](InferenceContext* c) { return Status::OK(); }) + .SetShapeFn(shape_inference::UnknownShape) .Deprecated(16, "Use TensorArraySplitV3"); // TODO(cwhipkey): mark this deprecated in favor of V3. REGISTER_OP("TensorArraySplitV2") @@ -1659,7 +1659,7 @@ REGISTER_OP("TensorArraySize") .Input("handle: Ref(string)") .Input("flow_in: float") .Output("size: int32") - .SetShapeFn([](InferenceContext* c) { return Status::OK(); }) + .SetShapeFn(shape_inference::UnknownShape) .Deprecated(16, "Use TensorArraySizeV3"); // TODO(cwhipkey): mark this deprecated in favor of V3. REGISTER_OP("TensorArraySizeV2") |