diff options
author | 2017-03-06 15:33:13 -0800 | |
---|---|---|
committer | 2017-03-06 15:45:45 -0800 | |
commit | 1d8934f63718c6b5009f1f84b336d2ac218c0922 (patch) | |
tree | 47dd069a30bb6c7df294f6da151f6d1ad280e6ae | |
parent | 9ccda399bc9e948255285326fb1dfd960b78df2d (diff) |
Add convenience methods to ShapeInference: MakeShapeFromTensorShape() and MakeShapeFromPartialTensorShape(). Update existing users of MakeShapeFromShapeProto to use the new helper methods where possible.
Change: 149353816
-rw-r--r-- | tensorflow/core/framework/partial_tensor_shape.h | 2 | ||||
-rw-r--r-- | tensorflow/core/framework/shape_inference.cc | 29 | ||||
-rw-r--r-- | tensorflow/core/framework/shape_inference.h | 7 | ||||
-rw-r--r-- | tensorflow/core/framework/shape_inference_test.cc | 33 | ||||
-rw-r--r-- | tensorflow/core/kernels/hexagon/graph_transferer.cc | 14 |
5 files changed, 71 insertions, 14 deletions
diff --git a/tensorflow/core/framework/partial_tensor_shape.h b/tensorflow/core/framework/partial_tensor_shape.h index 1504b8c983..7a70167d7c 100644 --- a/tensorflow/core/framework/partial_tensor_shape.h +++ b/tensorflow/core/framework/partial_tensor_shape.h @@ -40,7 +40,7 @@ class PartialTensorShape { PartialTensorShape() : is_unknown_(true) {} /// \brief Construct a `PartialTensorShape` from the provided sizes. - /// REQUIRES: `dim_sizes[i] >= 0` + /// REQUIRES: `dim_sizes[i] >= -1`; `-1` means `unknown`. explicit PartialTensorShape(gtl::ArraySlice<int64> dim_sizes); PartialTensorShape(std::initializer_list<int64> dim_sizes) : PartialTensorShape(gtl::ArraySlice<int64>(dim_sizes)) {} diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc index cbfa9bd20c..449d8f55f5 100644 --- a/tensorflow/core/framework/shape_inference.cc +++ b/tensorflow/core/framework/shape_inference.cc @@ -599,25 +599,36 @@ Status InferenceContext::MakeShapeFromTensor(const Tensor* t, return ReturnCreatedShape(dims, out); } -Status InferenceContext::MakeShapeFromShapeProto(const TensorShapeProto& proto, - ShapeHandle* out) { +Status InferenceContext::MakeShapeFromPartialTensorShape( + const PartialTensorShape& partial_shape, ShapeHandle* out) { *out = nullptr; - TF_RETURN_IF_ERROR(PartialTensorShape::IsValidShape(proto)); - PartialTensorShape partial_shape(proto); if (partial_shape.dims() == -1) { return ReturnUnknownShape(out); } const int num_dims = partial_shape.dims(); - std::vector<DimensionHandle> dims; - dims.reserve(partial_shape.dims()); + std::vector<DimensionHandle> dims(num_dims); for (int i = 0; i < num_dims; ++i) { - // -1 is unknown in proto and in InferenceContext, so this size can be - // passed directly to MakeDim. - dims.push_back(MakeDim(partial_shape.dim_size(i))); + // -1 is unknown in PartialTensorShape and in InferenceContext, so this size + // can be passed directly to MakeDim. + dims[i] = MakeDim(partial_shape.dim_size(i)); } return ReturnCreatedShape(dims, out); } +Status InferenceContext::MakeShapeFromTensorShape(const TensorShape& shape, + ShapeHandle* out) { + return MakeShapeFromPartialTensorShape(PartialTensorShape(shape.dim_sizes()), + out); +} + +Status InferenceContext::MakeShapeFromShapeProto(const TensorShapeProto& proto, + ShapeHandle* out) { + *out = nullptr; + TF_RETURN_IF_ERROR(PartialTensorShape::IsValidShape(proto)); + PartialTensorShape partial_shape(proto); + return MakeShapeFromPartialTensorShape(partial_shape, out); +} + // Returns a new dimension whose value is given by a scalar input tensor. Status InferenceContext::MakeDimForScalarInput(int idx, DimensionHandle* out) { const Tensor* t = input_tensor(idx); diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h index dba8d30302..a24a615e33 100644 --- a/tensorflow/core/framework/shape_inference.h +++ b/tensorflow/core/framework/shape_inference.h @@ -355,6 +355,13 @@ class InferenceContext { Status MakeShapeFromShapeProto(const TensorShapeProto& proto, ShapeHandle* out); + // Returns in <out> a new shape corresponding to <partial_shape>. + Status MakeShapeFromPartialTensorShape( + const PartialTensorShape& partial_shape, ShapeHandle* out); + + // Returns in <out> a new shape corresponding to <shape>. + Status MakeShapeFromTensorShape(const TensorShape& shape, ShapeHandle* out); + // Returns a new dimension of the given size. The returned value is owned by // this context. inline DimensionHandle MakeDim(DimensionOrConstant d) { diff --git a/tensorflow/core/framework/shape_inference_test.cc b/tensorflow/core/framework/shape_inference_test.cc index 9fc068aebe..ff25063255 100644 --- a/tensorflow/core/framework/shape_inference_test.cc +++ b/tensorflow/core/framework/shape_inference_test.cc @@ -922,6 +922,39 @@ TEST_F(ShapeInferenceTest, MakeShapeFromShapeTensor) { } } +TEST_F(ShapeInferenceTest, MakeShapeFromPartialTensorShape) { + NodeDef def; + std::vector<ShapeHandle> empty; + InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}, {}); + + // With an unknown rank. + ShapeHandle out; + TF_ASSERT_OK(c.MakeShapeFromPartialTensorShape(PartialTensorShape(), &out)); + EXPECT_EQ("?", c.DebugString(out)); + + // With a known rank. + TF_ASSERT_OK( + c.MakeShapeFromPartialTensorShape(PartialTensorShape({0}), &out)); + EXPECT_EQ("[0]", c.DebugString(out)); + TF_ASSERT_OK(c.MakeShapeFromPartialTensorShape( + PartialTensorShape({0, -1, 1000}), &out)); + EXPECT_EQ("[0,?,1000]", c.DebugString(out)); +} + +TEST_F(ShapeInferenceTest, MakeShapeFromTensorShape) { + NodeDef def; + std::vector<ShapeHandle> empty; + InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}, {}); + + ShapeHandle out; + TF_ASSERT_OK(c.MakeShapeFromTensorShape(TensorShape(), &out)); + EXPECT_EQ("[]", c.DebugString(out)); + TF_ASSERT_OK(c.MakeShapeFromTensorShape(TensorShape({0}), &out)); + EXPECT_EQ("[0]", c.DebugString(out)); + TF_ASSERT_OK(c.MakeShapeFromTensorShape(TensorShape({0, 7, 1000}), &out)); + EXPECT_EQ("[0,7,1000]", c.DebugString(out)); +} + TEST_F(ShapeInferenceTest, MakeShapeFromShapeProto) { NodeDef def; std::vector<ShapeHandle> empty; diff --git a/tensorflow/core/kernels/hexagon/graph_transferer.cc b/tensorflow/core/kernels/hexagon/graph_transferer.cc index c87aa82534..e00a72a8e7 100644 --- a/tensorflow/core/kernels/hexagon/graph_transferer.cc +++ b/tensorflow/core/kernels/hexagon/graph_transferer.cc @@ -80,6 +80,10 @@ Status GraphTransferer::LoadGraphFromProto( if (shape_inference_for_unknown_shape && !input_node_info_list.empty()) { auto visit = [&shape_refiner, &input_node_info_list, &status](Node* node) { + if (!status.ok()) { + return; + } + CHECK_NE(node, nullptr); // If we visit an input node, we use the shape provided and set the // shape accordingly. @@ -89,16 +93,18 @@ Status GraphTransferer::LoadGraphFromProto( if (node->name() == input_node_info.first) { shape_inference::InferenceContext* context = shape_refiner.GetContext(node); - TensorShapeProto proto; - input_node_info.second.shape().AsProto(&proto); shape_inference::ShapeHandle handle; - context->MakeShapeFromShapeProto(proto, &handle); + status = context->MakeShapeFromTensorShape( + input_node_info.second.shape(), &handle); shape_refiner.SetShape(node, 0, handle); is_input_node = true; } + if (!status.ok()) { + break; + } } // If not an input node call AddNode() that recomputes the shape. - if (!is_input_node) { + if (!is_input_node && status.ok()) { status = shape_refiner.AddNode(node); if (!status.ok()) { VLOG(1) << "Shape inference failed for node: " << node->name(); |