author     Peter Hawkins <phawkins@google.com>              2017-03-06 15:33:13 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>  2017-03-06 15:45:45 -0800
commit     1d8934f63718c6b5009f1f84b336d2ac218c0922
tree       47dd069a30bb6c7df294f6da151f6d1ad280e6ae
parent     9ccda399bc9e948255285326fb1dfd960b78df2d
Add convenience methods to ShapeInference: MakeShapeFromTensorShape() and MakeShapeFromPartialTensorShape(). Update existing users of MakeShapeFromShapeProto to use the new helper methods where possible.
Change: 149353816
-rw-r--r--  tensorflow/core/framework/partial_tensor_shape.h     |  2
-rw-r--r--  tensorflow/core/framework/shape_inference.cc          | 29
-rw-r--r--  tensorflow/core/framework/shape_inference.h           |  7
-rw-r--r--  tensorflow/core/framework/shape_inference_test.cc     | 33
-rw-r--r--  tensorflow/core/kernels/hexagon/graph_transferer.cc   | 14
5 files changed, 71 insertions(+), 14 deletions(-)
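
For context (not part of the commit itself), the sketch below shows how a caller might use the new helpers: code that already holds a TensorShape or PartialTensorShape no longer needs to round-trip through a TensorShapeProto before building a ShapeHandle. The function SetOutputFromShape and its setup are hypothetical and only illustrate the intended call pattern.

#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/core/errors.h"

using tensorflow::PartialTensorShape;
using tensorflow::Status;
using tensorflow::shape_inference::InferenceContext;
using tensorflow::shape_inference::ShapeHandle;

// Hypothetical helper: set output 0 of an op from a PartialTensorShape.
Status SetOutputFromShape(InferenceContext* c,
                          const PartialTensorShape& shape) {
  ShapeHandle out;
  // Before this change, the shape had to be serialized first:
  //   TensorShapeProto proto;
  //   shape.AsProto(&proto);
  //   TF_RETURN_IF_ERROR(c->MakeShapeFromShapeProto(proto, &out));
  // With the new convenience method, it can be passed directly:
  TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &out));
  c->set_output(0, out);
  return Status::OK();
}
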
diff --git a/tensorflow/core/framework/partial_tensor_shape.h b/tensorflow/core/framework/partial_tensor_shape.h
index 1504b8c983..7a70167d7c 100644
--- a/tensorflow/core/framework/partial_tensor_shape.h
+++ b/tensorflow/core/framework/partial_tensor_shape.h
@@ -40,7 +40,7 @@ class PartialTensorShape {
PartialTensorShape() : is_unknown_(true) {}
/// \brief Construct a `PartialTensorShape` from the provided sizes.
- /// REQUIRES: `dim_sizes[i] >= 0`
+ /// REQUIRES: `dim_sizes[i] >= -1`; `-1` means `unknown`.
explicit PartialTensorShape(gtl::ArraySlice<int64> dim_sizes);
PartialTensorShape(std::initializer_list<int64> dim_sizes)
: PartialTensorShape(gtl::ArraySlice<int64>(dim_sizes)) {}
diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc
index cbfa9bd20c..449d8f55f5 100644
--- a/tensorflow/core/framework/shape_inference.cc
+++ b/tensorflow/core/framework/shape_inference.cc
@@ -599,25 +599,36 @@ Status InferenceContext::MakeShapeFromTensor(const Tensor* t,
return ReturnCreatedShape(dims, out);
}
-Status InferenceContext::MakeShapeFromShapeProto(const TensorShapeProto& proto,
- ShapeHandle* out) {
+Status InferenceContext::MakeShapeFromPartialTensorShape(
+ const PartialTensorShape& partial_shape, ShapeHandle* out) {
*out = nullptr;
- TF_RETURN_IF_ERROR(PartialTensorShape::IsValidShape(proto));
- PartialTensorShape partial_shape(proto);
if (partial_shape.dims() == -1) {
return ReturnUnknownShape(out);
}
const int num_dims = partial_shape.dims();
- std::vector<DimensionHandle> dims;
- dims.reserve(partial_shape.dims());
+ std::vector<DimensionHandle> dims(num_dims);
for (int i = 0; i < num_dims; ++i) {
- // -1 is unknown in proto and in InferenceContext, so this size can be
- // passed directly to MakeDim.
- dims.push_back(MakeDim(partial_shape.dim_size(i)));
+ // -1 is unknown in PartialTensorShape and in InferenceContext, so this size
+ // can be passed directly to MakeDim.
+ dims[i] = MakeDim(partial_shape.dim_size(i));
}
return ReturnCreatedShape(dims, out);
}
+Status InferenceContext::MakeShapeFromTensorShape(const TensorShape& shape,
+ ShapeHandle* out) {
+ return MakeShapeFromPartialTensorShape(PartialTensorShape(shape.dim_sizes()),
+ out);
+}
+
+Status InferenceContext::MakeShapeFromShapeProto(const TensorShapeProto& proto,
+ ShapeHandle* out) {
+ *out = nullptr;
+ TF_RETURN_IF_ERROR(PartialTensorShape::IsValidShape(proto));
+ PartialTensorShape partial_shape(proto);
+ return MakeShapeFromPartialTensorShape(partial_shape, out);
+}
+
// Returns a new dimension whose value is given by a scalar input tensor.
Status InferenceContext::MakeDimForScalarInput(int idx, DimensionHandle* out) {
const Tensor* t = input_tensor(idx);
diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h
index dba8d30302..a24a615e33 100644
--- a/tensorflow/core/framework/shape_inference.h
+++ b/tensorflow/core/framework/shape_inference.h
@@ -355,6 +355,13 @@ class InferenceContext {
Status MakeShapeFromShapeProto(const TensorShapeProto& proto,
ShapeHandle* out);
+ // Returns in <out> a new shape corresponding to <partial_shape>.
+ Status MakeShapeFromPartialTensorShape(
+ const PartialTensorShape& partial_shape, ShapeHandle* out);
+
+ // Returns in <out> a new shape corresponding to <shape>.
+ Status MakeShapeFromTensorShape(const TensorShape& shape, ShapeHandle* out);
+
// Returns a new dimension of the given size. The returned value is owned by
// this context.
inline DimensionHandle MakeDim(DimensionOrConstant d) {
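
As an aside (not part of the diff), a short sketch of the unknown-dimension convention the new methods rely on, matching the expectations in the tests below: a default-constructed PartialTensorShape maps to a shape of unknown rank, and a dimension size of -1 maps to an unknown dimension. The specific sizes here are illustrative, and c stands for an InferenceContext set up as in those tests.

ShapeHandle h;
// Unknown rank: a default-constructed PartialTensorShape.
TF_CHECK_OK(c.MakeShapeFromPartialTensorShape(PartialTensorShape(), &h));
// c.DebugString(h) -> "?"
// Known rank with an unknown dimension: -1 becomes "?".
TF_CHECK_OK(c.MakeShapeFromPartialTensorShape(PartialTensorShape({2, -1}), &h));
// c.DebugString(h) -> "[2,?]"
// Fully known shape via the TensorShape overload.
TF_CHECK_OK(c.MakeShapeFromTensorShape(TensorShape({2, 3}), &h));
// c.DebugString(h) -> "[2,3]"
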
diff --git a/tensorflow/core/framework/shape_inference_test.cc b/tensorflow/core/framework/shape_inference_test.cc
index 9fc068aebe..ff25063255 100644
--- a/tensorflow/core/framework/shape_inference_test.cc
+++ b/tensorflow/core/framework/shape_inference_test.cc
@@ -922,6 +922,39 @@ TEST_F(ShapeInferenceTest, MakeShapeFromShapeTensor) {
}
}
+TEST_F(ShapeInferenceTest, MakeShapeFromPartialTensorShape) {
+ NodeDef def;
+ std::vector<ShapeHandle> empty;
+ InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
+
+ // With an unknown rank.
+ ShapeHandle out;
+ TF_ASSERT_OK(c.MakeShapeFromPartialTensorShape(PartialTensorShape(), &out));
+ EXPECT_EQ("?", c.DebugString(out));
+
+ // With a known rank.
+ TF_ASSERT_OK(
+ c.MakeShapeFromPartialTensorShape(PartialTensorShape({0}), &out));
+ EXPECT_EQ("[0]", c.DebugString(out));
+ TF_ASSERT_OK(c.MakeShapeFromPartialTensorShape(
+ PartialTensorShape({0, -1, 1000}), &out));
+ EXPECT_EQ("[0,?,1000]", c.DebugString(out));
+}
+
+TEST_F(ShapeInferenceTest, MakeShapeFromTensorShape) {
+ NodeDef def;
+ std::vector<ShapeHandle> empty;
+ InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
+
+ ShapeHandle out;
+ TF_ASSERT_OK(c.MakeShapeFromTensorShape(TensorShape(), &out));
+ EXPECT_EQ("[]", c.DebugString(out));
+ TF_ASSERT_OK(c.MakeShapeFromTensorShape(TensorShape({0}), &out));
+ EXPECT_EQ("[0]", c.DebugString(out));
+ TF_ASSERT_OK(c.MakeShapeFromTensorShape(TensorShape({0, 7, 1000}), &out));
+ EXPECT_EQ("[0,7,1000]", c.DebugString(out));
+}
+
TEST_F(ShapeInferenceTest, MakeShapeFromShapeProto) {
NodeDef def;
std::vector<ShapeHandle> empty;
diff --git a/tensorflow/core/kernels/hexagon/graph_transferer.cc b/tensorflow/core/kernels/hexagon/graph_transferer.cc
index c87aa82534..e00a72a8e7 100644
--- a/tensorflow/core/kernels/hexagon/graph_transferer.cc
+++ b/tensorflow/core/kernels/hexagon/graph_transferer.cc
@@ -80,6 +80,10 @@ Status GraphTransferer::LoadGraphFromProto(
if (shape_inference_for_unknown_shape && !input_node_info_list.empty()) {
auto visit = [&shape_refiner, &input_node_info_list, &status](Node* node) {
+ if (!status.ok()) {
+ return;
+ }
+
CHECK_NE(node, nullptr);
// If we visit an input node, we use the shape provided and set the
// shape accordingly.
@@ -89,16 +93,18 @@ Status GraphTransferer::LoadGraphFromProto(
if (node->name() == input_node_info.first) {
shape_inference::InferenceContext* context =
shape_refiner.GetContext(node);
- TensorShapeProto proto;
- input_node_info.second.shape().AsProto(&proto);
shape_inference::ShapeHandle handle;
- context->MakeShapeFromShapeProto(proto, &handle);
+ status = context->MakeShapeFromTensorShape(
+ input_node_info.second.shape(), &handle);
shape_refiner.SetShape(node, 0, handle);
is_input_node = true;
}
+ if (!status.ok()) {
+ break;
+ }
}
// If not an input node call AddNode() that recomputes the shape.
- if (!is_input_node) {
+ if (!is_input_node && status.ok()) {
status = shape_refiner.AddNode(node);
if (!status.ok()) {
VLOG(1) << "Shape inference failed for node: " << node->name();