aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-11-14 13:22:48 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-11-14 13:43:16 -0800
commit0c2018764516add343acc9ee56c487715bc98c2a (patch)
tree4a61a5235b6c339bedf9d97fc048570fd72cec22
parent750c98508c4ea51bf45694d32773b075eb2c8c8d (diff)
Move ValidateSparseTensor to common_shape_fns.h.
Change: 139112567
-rw-r--r--tensorflow/contrib/metrics/ops/set_ops.cc12
-rw-r--r--tensorflow/core/framework/common_shape_fns.cc41
-rw-r--r--tensorflow/core/framework/common_shape_fns.h5
-rw-r--r--tensorflow/core/framework/common_shape_fns_test.cc146
-rw-r--r--tensorflow/core/framework/shape_inference.h44
-rw-r--r--tensorflow/core/framework/shape_inference_test.cc133
-rw-r--r--tensorflow/core/ops/array_ops.cc8
7 files changed, 202 insertions, 187 deletions
diff --git a/tensorflow/contrib/metrics/ops/set_ops.cc b/tensorflow/contrib/metrics/ops/set_ops.cc
index ee377c57e9..3da83ddae2 100644
--- a/tensorflow/contrib/metrics/ops/set_ops.cc
+++ b/tensorflow/contrib/metrics/ops/set_ops.cc
@@ -173,8 +173,8 @@ REGISTER_OP("DenseToSparseSetOperation")
} else {
output_rank = c->UnknownDim();
}
- TF_RETURN_IF_ERROR(
- c->ValidateSparseTensor(c->input(1), c->input(2), c->input(3)));
+ TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor(
+ c, c->input(1), c->input(2), c->input(3)));
DimensionHandle output_num_elements = c->Dim(input0_shape, 0);
if (!c->ValueKnown(output_num_elements)) {
output_num_elements = c->UnknownDim();
@@ -239,10 +239,10 @@ REGISTER_OP("SparseToSparseSetOperation")
}
// The following should stay in sync with `ComputeSparseToSparse` shape
// assertions in kernels/set_kernels.cc.
- TF_RETURN_IF_ERROR(
- c->ValidateSparseTensor(c->input(0), c->input(1), c->input(2)));
- TF_RETURN_IF_ERROR(
- c->ValidateSparseTensor(c->input(3), c->input(4), c->input(5)));
+ TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor(
+ c, c->input(0), c->input(1), c->input(2)));
+ TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor(
+ c, c->input(3), c->input(4), c->input(5)));
c->set_output(0, c->Matrix(c->UnknownDim(), c->UnknownDim()));
c->set_output(1, c->Vector(c->UnknownDim()));
c->set_output(2, c->Vector(c->UnknownDim()));
diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc
index 2434127acc..794f9c37cf 100644
--- a/tensorflow/core/framework/common_shape_fns.cc
+++ b/tensorflow/core/framework/common_shape_fns.cc
@@ -860,5 +860,46 @@ Status BroadcastBinaryOpShapeFn(InferenceContext* c) {
return Status::OK();
}
+Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape,
+ ShapeHandle values_shape, ShapeHandle shape_shape) {
+ // Validate ranks.
+ ShapeHandle unused_shape;
+ TF_RETURN_IF_ERROR(c->WithRank(indices_shape, 2, &unused_shape));
+ TF_RETURN_IF_ERROR(c->WithRank(values_shape, 1, &unused_shape));
+ TF_RETURN_IF_ERROR(c->WithRank(shape_shape, 1, &unused_shape));
+
+ // Number of elements in indices and values must match.
+ DimensionHandle num_index_elements_dim = c->Dim(indices_shape, 0);
+ if (c->ValueKnown(num_index_elements_dim)) {
+ DimensionHandle num_values_elements_dim = c->Dim(values_shape, 0);
+ if (c->ValueKnown(num_values_elements_dim)) {
+ int64 num_index_elements = c->Value(num_index_elements_dim);
+ int64 num_values_elements = c->Value(num_values_elements_dim);
+ if (num_index_elements != num_values_elements) {
+ return errors::InvalidArgument("Number of elements in index (",
+ num_index_elements, ") and values (",
+ num_values_elements, ") do not match.");
+ }
+ }
+ }
+
+ // Rank embedded in indices must match shape.
+ DimensionHandle index_rank_dim = c->Dim(indices_shape, 1);
+ if (c->ValueKnown(index_rank_dim)) {
+ DimensionHandle shape_rank_dim = c->Dim(shape_shape, 0);
+ if (c->ValueKnown(shape_rank_dim)) {
+ int64 index_rank = c->Value(index_rank_dim);
+ int32 shape_rank = c->Value(shape_rank_dim);
+ if (index_rank != shape_rank) {
+ return errors::InvalidArgument("Index rank (", index_rank,
+ ") and shape rank (", shape_rank,
+ ") do not match.");
+ }
+ }
+ }
+
+ return Status::OK();
+}
+
} // namespace shape_inference
} // namespace tensorflow
diff --git a/tensorflow/core/framework/common_shape_fns.h b/tensorflow/core/framework/common_shape_fns.h
index fc1288f298..305b7b4056 100644
--- a/tensorflow/core/framework/common_shape_fns.h
+++ b/tensorflow/core/framework/common_shape_fns.h
@@ -203,6 +203,11 @@ Status ConcatV2Shape(shape_inference::InferenceContext* c);
// Tested by ops/math_ops_test.cc.
Status BroadcastBinaryOpShapeFn(InferenceContext* c);
+// Validates the 3 component tensors of a sparse tensor have the proper
+// shapes. This mimics SparseTensor.__init__ in python/framework/ops.py.
+Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape,
+ ShapeHandle values_shape, ShapeHandle shape_shape);
+
} // namespace shape_inference
} // namespace tensorflow
diff --git a/tensorflow/core/framework/common_shape_fns_test.cc b/tensorflow/core/framework/common_shape_fns_test.cc
index ca1326844c..2be771e3a9 100644
--- a/tensorflow/core/framework/common_shape_fns_test.cc
+++ b/tensorflow/core/framework/common_shape_fns_test.cc
@@ -40,6 +40,19 @@ TensorShapeProto Unknown() {
return ret;
}
+OpDef MakeOpDef(int num_inputs, int num_outputs) {
+ OpRegistrationData op_reg_data;
+ OpDefBuilder b("dummy");
+ for (int i = 0; i < num_inputs; ++i) {
+ b.Input(strings::StrCat("i", i, ": float"));
+ }
+ for (int i = 0; i < num_outputs; ++i) {
+ b.Output(strings::StrCat("o", i, ": float"));
+ }
+ CHECK(b.Attr("foo:string").Finalize(&op_reg_data).ok());
+ return op_reg_data.op_def;
+}
+
} // namespace
TEST(CommonShapeFnsTest, NoOutputShapeTest) {
@@ -840,5 +853,138 @@ TEST(CommonShapeFnsTest, ReduceForReduceJoin_ShapeFn) {
INFER_OK(op, "[?,?,?];[2]", "?");
}
+TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownShapes) {
+ NodeDef def;
+ InferenceContext c(&def, MakeOpDef(3, 1), {Unknown(), Unknown(), Unknown()},
+ {}, {}, {}, {});
+ EXPECT_EQ(3, c.num_inputs());
+ EXPECT_EQ(1, c.num_outputs());
+
+ auto indices = c.input(0);
+ auto values = c.input(1);
+ auto shape = c.input(2);
+ TF_EXPECT_OK(ValidateSparseTensor(&c, indices, values, shape));
+}
+
+TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownDims) {
+ NodeDef def;
+ InferenceContext c(&def, MakeOpDef(3, 1), {S({-1, -1}), S({-1}), S({-1})}, {},
+ {}, {}, {});
+ EXPECT_EQ(3, c.num_inputs());
+ EXPECT_EQ(1, c.num_outputs());
+
+ auto indices = c.input(0);
+ auto values = c.input(1);
+ auto shape = c.input(2);
+ TF_EXPECT_OK(ValidateSparseTensor(&c, indices, values, shape));
+}
+
+TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidIndicesRank) {
+ NodeDef def;
+ InferenceContext c(&def, MakeOpDef(3, 1), {S({-1}), S({-1}), S({-1})}, {}, {},
+ {}, {});
+ EXPECT_EQ(3, c.num_inputs());
+ EXPECT_EQ(1, c.num_outputs());
+
+ auto indices = c.input(0);
+ auto values = c.input(1);
+ auto shape = c.input(2);
+ EXPECT_EQ(error::INVALID_ARGUMENT,
+ ValidateSparseTensor(&c, indices, values, shape).code());
+}
+
+TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidNumElements) {
+ NodeDef def;
+ InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({4}), S({3})}, {}, {},
+ {}, {});
+ EXPECT_EQ(3, c.num_inputs());
+ EXPECT_EQ(1, c.num_outputs());
+
+ auto indices = c.input(0);
+ auto values = c.input(1);
+ auto shape = c.input(2);
+ EXPECT_EQ(error::INVALID_ARGUMENT,
+ ValidateSparseTensor(&c, indices, values, shape).code());
+}
+
+TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidRank) {
+ NodeDef def;
+ InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({4})}, {}, {},
+ {}, {});
+ EXPECT_EQ(3, c.num_inputs());
+ EXPECT_EQ(1, c.num_outputs());
+
+ auto indices = c.input(0);
+ auto values = c.input(1);
+ auto shape = c.input(2);
+ EXPECT_EQ(error::INVALID_ARGUMENT,
+ ValidateSparseTensor(&c, indices, values, shape).code());
+}
+
+TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownNumIndexElements) {
+ NodeDef def;
+ InferenceContext c(&def, MakeOpDef(3, 1), {S({-1, 3}), S({5}), S({3})}, {},
+ {}, {}, {});
+ EXPECT_EQ(3, c.num_inputs());
+ EXPECT_EQ(1, c.num_outputs());
+
+ auto indices = c.input(0);
+ auto values = c.input(1);
+ auto shape = c.input(2);
+ TF_EXPECT_OK(ValidateSparseTensor(&c, indices, values, shape));
+}
+
+TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownNumValueElements) {
+ NodeDef def;
+ InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({-1}), S({3})}, {},
+ {}, {}, {});
+ EXPECT_EQ(3, c.num_inputs());
+ EXPECT_EQ(1, c.num_outputs());
+
+ auto indices = c.input(0);
+ auto values = c.input(1);
+ auto shape = c.input(2);
+ TF_EXPECT_OK(ValidateSparseTensor(&c, indices, values, shape));
+}
+
+TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownIndexRank) {
+ NodeDef def;
+ InferenceContext c(&def, MakeOpDef(3, 1), {S({5, -1}), S({5}), S({3})}, {},
+ {}, {}, {});
+ EXPECT_EQ(3, c.num_inputs());
+ EXPECT_EQ(1, c.num_outputs());
+
+ auto indices = c.input(0);
+ auto values = c.input(1);
+ auto shape = c.input(2);
+ TF_EXPECT_OK(ValidateSparseTensor(&c, indices, values, shape));
+}
+
+TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownShapeRank) {
+ NodeDef def;
+ InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({-1})}, {},
+ {}, {}, {});
+ EXPECT_EQ(3, c.num_inputs());
+ EXPECT_EQ(1, c.num_outputs());
+
+ auto indices = c.input(0);
+ auto values = c.input(1);
+ auto shape = c.input(2);
+ TF_EXPECT_OK(ValidateSparseTensor(&c, indices, values, shape));
+}
+
+TEST(CommonShapeFnsTest, ValidateSparseTensor) {
+ NodeDef def;
+ InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({3})}, {}, {},
+ {}, {});
+ EXPECT_EQ(3, c.num_inputs());
+ EXPECT_EQ(1, c.num_outputs());
+
+ auto indices = c.input(0);
+ auto values = c.input(1);
+ auto shape = c.input(2);
+ TF_EXPECT_OK(ValidateSparseTensor(&c, indices, values, shape));
+}
+
} // namespace shape_inference
} // namespace tensorflow
diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h
index 6aa2a9fd4f..d91775152c 100644
--- a/tensorflow/core/framework/shape_inference.h
+++ b/tensorflow/core/framework/shape_inference.h
@@ -424,50 +424,6 @@ class InferenceContext {
return output_handle_dtype_[idx];
}
- // Validates the 3 component tensors of a sparse tensor have the proper
- // shapes. This mimics SparseTensor.__init__ in python/framework/ops.py.
- Status ValidateSparseTensor(ShapeHandle indices_shape,
- ShapeHandle values_shape,
- ShapeHandle shape_shape) {
- // Validate ranks.
- ShapeHandle unused_shape;
- TF_RETURN_IF_ERROR(WithRank(indices_shape, 2, &unused_shape));
- TF_RETURN_IF_ERROR(WithRank(values_shape, 1, &unused_shape));
- TF_RETURN_IF_ERROR(WithRank(shape_shape, 1, &unused_shape));
-
- // Number of elements in indices and values must match.
- DimensionHandle num_index_elements_dim = Dim(indices_shape, 0);
- if (ValueKnown(num_index_elements_dim)) {
- DimensionHandle num_values_elements_dim = Dim(values_shape, 0);
- if (ValueKnown(num_values_elements_dim)) {
- int64 num_index_elements = Value(num_index_elements_dim);
- int64 num_values_elements = Value(num_values_elements_dim);
- if (num_index_elements != num_values_elements) {
- return errors::InvalidArgument(
- "Number of elements in index (", num_index_elements,
- ") and values (", num_values_elements, ") do not match.");
- }
- }
- }
-
- // Rank embedded in indices must match shape.
- DimensionHandle index_rank_dim = Dim(indices_shape, 1);
- if (ValueKnown(index_rank_dim)) {
- DimensionHandle shape_rank_dim = Dim(shape_shape, 0);
- if (ValueKnown(shape_rank_dim)) {
- int64 index_rank = Value(index_rank_dim);
- int32 shape_rank = Value(shape_rank_dim);
- if (index_rank != shape_rank) {
- return errors::InvalidArgument("Index rank (", index_rank,
- ") and shape rank (", shape_rank,
- ") do not match.");
- }
- }
- }
-
- return Status::OK();
- }
-
// Note that shape functions should usually call MakeShapeFromShapeTensor,
// as it does more analysis to provide partial shapes.
//
diff --git a/tensorflow/core/framework/shape_inference_test.cc b/tensorflow/core/framework/shape_inference_test.cc
index 8d6b4ac021..80a8639c02 100644
--- a/tensorflow/core/framework/shape_inference_test.cc
+++ b/tensorflow/core/framework/shape_inference_test.cc
@@ -1264,138 +1264,5 @@ TEST_F(ShapeInferenceTest, Max) {
EXPECT_TRUE(SameHandle(d_2, out));
}
-TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownShapes) {
- NodeDef def;
- InferenceContext c(&def, MakeOpDef(3, 1), {Unknown(), Unknown(), Unknown()},
- {}, {}, {}, {});
- EXPECT_EQ(3, c.num_inputs());
- EXPECT_EQ(1, c.num_outputs());
-
- auto indices = c.input(0);
- auto values = c.input(1);
- auto shape = c.input(2);
- TF_EXPECT_OK(c.ValidateSparseTensor(indices, values, shape));
-}
-
-TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownDims) {
- NodeDef def;
- InferenceContext c(&def, MakeOpDef(3, 1), {S({-1, -1}), S({-1}), S({-1})}, {},
- {}, {}, {});
- EXPECT_EQ(3, c.num_inputs());
- EXPECT_EQ(1, c.num_outputs());
-
- auto indices = c.input(0);
- auto values = c.input(1);
- auto shape = c.input(2);
- TF_EXPECT_OK(c.ValidateSparseTensor(indices, values, shape));
-}
-
-TEST_F(ShapeInferenceTest, ValidateSparseTensor_InvalidIndicesRank) {
- NodeDef def;
- InferenceContext c(&def, MakeOpDef(3, 1), {S({-1}), S({-1}), S({-1})}, {}, {},
- {}, {});
- EXPECT_EQ(3, c.num_inputs());
- EXPECT_EQ(1, c.num_outputs());
-
- auto indices = c.input(0);
- auto values = c.input(1);
- auto shape = c.input(2);
- EXPECT_EQ(error::INVALID_ARGUMENT,
- c.ValidateSparseTensor(indices, values, shape).code());
-}
-
-TEST_F(ShapeInferenceTest, ValidateSparseTensor_InvalidNumElements) {
- NodeDef def;
- InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({4}), S({3})}, {}, {},
- {}, {});
- EXPECT_EQ(3, c.num_inputs());
- EXPECT_EQ(1, c.num_outputs());
-
- auto indices = c.input(0);
- auto values = c.input(1);
- auto shape = c.input(2);
- EXPECT_EQ(error::INVALID_ARGUMENT,
- c.ValidateSparseTensor(indices, values, shape).code());
-}
-
-TEST_F(ShapeInferenceTest, ValidateSparseTensor_InvalidRank) {
- NodeDef def;
- InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({4})}, {}, {},
- {}, {});
- EXPECT_EQ(3, c.num_inputs());
- EXPECT_EQ(1, c.num_outputs());
-
- auto indices = c.input(0);
- auto values = c.input(1);
- auto shape = c.input(2);
- EXPECT_EQ(error::INVALID_ARGUMENT,
- c.ValidateSparseTensor(indices, values, shape).code());
-}
-
-TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownNumIndexElements) {
- NodeDef def;
- InferenceContext c(&def, MakeOpDef(3, 1), {S({-1, 3}), S({5}), S({3})}, {},
- {}, {}, {});
- EXPECT_EQ(3, c.num_inputs());
- EXPECT_EQ(1, c.num_outputs());
-
- auto indices = c.input(0);
- auto values = c.input(1);
- auto shape = c.input(2);
- TF_EXPECT_OK(c.ValidateSparseTensor(indices, values, shape));
-}
-
-TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownNumValueElements) {
- NodeDef def;
- InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({-1}), S({3})}, {},
- {}, {}, {});
- EXPECT_EQ(3, c.num_inputs());
- EXPECT_EQ(1, c.num_outputs());
-
- auto indices = c.input(0);
- auto values = c.input(1);
- auto shape = c.input(2);
- TF_EXPECT_OK(c.ValidateSparseTensor(indices, values, shape));
-}
-
-TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownIndexRank) {
- NodeDef def;
- InferenceContext c(&def, MakeOpDef(3, 1), {S({5, -1}), S({5}), S({3})}, {},
- {}, {}, {});
- EXPECT_EQ(3, c.num_inputs());
- EXPECT_EQ(1, c.num_outputs());
-
- auto indices = c.input(0);
- auto values = c.input(1);
- auto shape = c.input(2);
- TF_EXPECT_OK(c.ValidateSparseTensor(indices, values, shape));
-}
-
-TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownShapeRank) {
- NodeDef def;
- InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({-1})}, {},
- {}, {}, {});
- EXPECT_EQ(3, c.num_inputs());
- EXPECT_EQ(1, c.num_outputs());
-
- auto indices = c.input(0);
- auto values = c.input(1);
- auto shape = c.input(2);
- TF_EXPECT_OK(c.ValidateSparseTensor(indices, values, shape));
-}
-
-TEST_F(ShapeInferenceTest, ValidateSparseTensor) {
- NodeDef def;
- InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({3})}, {}, {},
- {}, {});
- EXPECT_EQ(3, c.num_inputs());
- EXPECT_EQ(1, c.num_outputs());
-
- auto indices = c.input(0);
- auto values = c.input(1);
- auto shape = c.input(2);
- TF_EXPECT_OK(c.ValidateSparseTensor(indices, values, shape));
-}
-
} // namespace shape_inference
} // namespace tensorflow
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index cd287b95fd..2cc545dd70 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -997,10 +997,10 @@ REGISTER_OP("EditDistance")
.Attr("T: type")
.Output("output: float")
.SetShapeFn([](InferenceContext* c) {
- TF_RETURN_IF_ERROR(
- c->ValidateSparseTensor(c->input(0), c->input(1), c->input(2)));
- TF_RETURN_IF_ERROR(
- c->ValidateSparseTensor(c->input(3), c->input(4), c->input(5)));
+ TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor(
+ c, c->input(0), c->input(1), c->input(2)));
+ TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor(
+ c, c->input(3), c->input(4), c->input(5)));
const Tensor* hypothesis_shape_t = c->input_tensor(2);
const Tensor* truth_shape_t = c->input_tensor(5);
if (hypothesis_shape_t == nullptr || truth_shape_t == nullptr) {