diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2016-11-14 13:22:48 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-11-14 13:43:16 -0800 |
commit | 0c2018764516add343acc9ee56c487715bc98c2a (patch) | |
tree | 4a61a5235b6c339bedf9d97fc048570fd72cec22 | |
parent | 750c98508c4ea51bf45694d32773b075eb2c8c8d (diff) |
Move ValidateSparseTensor to common_shape_fns.h.
Change: 139112567
-rw-r--r-- | tensorflow/contrib/metrics/ops/set_ops.cc | 12 | ||||
-rw-r--r-- | tensorflow/core/framework/common_shape_fns.cc | 41 | ||||
-rw-r--r-- | tensorflow/core/framework/common_shape_fns.h | 5 | ||||
-rw-r--r-- | tensorflow/core/framework/common_shape_fns_test.cc | 146 | ||||
-rw-r--r-- | tensorflow/core/framework/shape_inference.h | 44 | ||||
-rw-r--r-- | tensorflow/core/framework/shape_inference_test.cc | 133 | ||||
-rw-r--r-- | tensorflow/core/ops/array_ops.cc | 8 |
7 files changed, 202 insertions, 187 deletions
diff --git a/tensorflow/contrib/metrics/ops/set_ops.cc b/tensorflow/contrib/metrics/ops/set_ops.cc index ee377c57e9..3da83ddae2 100644 --- a/tensorflow/contrib/metrics/ops/set_ops.cc +++ b/tensorflow/contrib/metrics/ops/set_ops.cc @@ -173,8 +173,8 @@ REGISTER_OP("DenseToSparseSetOperation") } else { output_rank = c->UnknownDim(); } - TF_RETURN_IF_ERROR( - c->ValidateSparseTensor(c->input(1), c->input(2), c->input(3))); + TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor( + c, c->input(1), c->input(2), c->input(3))); DimensionHandle output_num_elements = c->Dim(input0_shape, 0); if (!c->ValueKnown(output_num_elements)) { output_num_elements = c->UnknownDim(); @@ -239,10 +239,10 @@ REGISTER_OP("SparseToSparseSetOperation") } // The following should stay in sync with `ComputeSparseToSparse` shape // assertions in kernels/set_kernels.cc. - TF_RETURN_IF_ERROR( - c->ValidateSparseTensor(c->input(0), c->input(1), c->input(2))); - TF_RETURN_IF_ERROR( - c->ValidateSparseTensor(c->input(3), c->input(4), c->input(5))); + TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor( + c, c->input(0), c->input(1), c->input(2))); + TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor( + c, c->input(3), c->input(4), c->input(5))); c->set_output(0, c->Matrix(c->UnknownDim(), c->UnknownDim())); c->set_output(1, c->Vector(c->UnknownDim())); c->set_output(2, c->Vector(c->UnknownDim())); diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc index 2434127acc..794f9c37cf 100644 --- a/tensorflow/core/framework/common_shape_fns.cc +++ b/tensorflow/core/framework/common_shape_fns.cc @@ -860,5 +860,46 @@ Status BroadcastBinaryOpShapeFn(InferenceContext* c) { return Status::OK(); } +Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape, + ShapeHandle values_shape, ShapeHandle shape_shape) { + // Validate ranks. + ShapeHandle unused_shape; + TF_RETURN_IF_ERROR(c->WithRank(indices_shape, 2, &unused_shape)); + TF_RETURN_IF_ERROR(c->WithRank(values_shape, 1, &unused_shape)); + TF_RETURN_IF_ERROR(c->WithRank(shape_shape, 1, &unused_shape)); + + // Number of elements in indices and values must match. + DimensionHandle num_index_elements_dim = c->Dim(indices_shape, 0); + if (c->ValueKnown(num_index_elements_dim)) { + DimensionHandle num_values_elements_dim = c->Dim(values_shape, 0); + if (c->ValueKnown(num_values_elements_dim)) { + int64 num_index_elements = c->Value(num_index_elements_dim); + int64 num_values_elements = c->Value(num_values_elements_dim); + if (num_index_elements != num_values_elements) { + return errors::InvalidArgument("Number of elements in index (", + num_index_elements, ") and values (", + num_values_elements, ") do not match."); + } + } + } + + // Rank embedded in indices must match shape. + DimensionHandle index_rank_dim = c->Dim(indices_shape, 1); + if (c->ValueKnown(index_rank_dim)) { + DimensionHandle shape_rank_dim = c->Dim(shape_shape, 0); + if (c->ValueKnown(shape_rank_dim)) { + int64 index_rank = c->Value(index_rank_dim); + int32 shape_rank = c->Value(shape_rank_dim); + if (index_rank != shape_rank) { + return errors::InvalidArgument("Index rank (", index_rank, + ") and shape rank (", shape_rank, + ") do not match."); + } + } + } + + return Status::OK(); +} + } // namespace shape_inference } // namespace tensorflow diff --git a/tensorflow/core/framework/common_shape_fns.h b/tensorflow/core/framework/common_shape_fns.h index fc1288f298..305b7b4056 100644 --- a/tensorflow/core/framework/common_shape_fns.h +++ b/tensorflow/core/framework/common_shape_fns.h @@ -203,6 +203,11 @@ Status ConcatV2Shape(shape_inference::InferenceContext* c); // Tested by ops/math_ops_test.cc. Status BroadcastBinaryOpShapeFn(InferenceContext* c); +// Validates the 3 component tensors of a sparse tensor have the proper +// shapes. This mimics SparseTensor.__init__ in python/framework/ops.py. +Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape, + ShapeHandle values_shape, ShapeHandle shape_shape); + } // namespace shape_inference } // namespace tensorflow diff --git a/tensorflow/core/framework/common_shape_fns_test.cc b/tensorflow/core/framework/common_shape_fns_test.cc index ca1326844c..2be771e3a9 100644 --- a/tensorflow/core/framework/common_shape_fns_test.cc +++ b/tensorflow/core/framework/common_shape_fns_test.cc @@ -40,6 +40,19 @@ TensorShapeProto Unknown() { return ret; } +OpDef MakeOpDef(int num_inputs, int num_outputs) { + OpRegistrationData op_reg_data; + OpDefBuilder b("dummy"); + for (int i = 0; i < num_inputs; ++i) { + b.Input(strings::StrCat("i", i, ": float")); + } + for (int i = 0; i < num_outputs; ++i) { + b.Output(strings::StrCat("o", i, ": float")); + } + CHECK(b.Attr("foo:string").Finalize(&op_reg_data).ok()); + return op_reg_data.op_def; +} + } // namespace TEST(CommonShapeFnsTest, NoOutputShapeTest) { @@ -840,5 +853,138 @@ TEST(CommonShapeFnsTest, ReduceForReduceJoin_ShapeFn) { INFER_OK(op, "[?,?,?];[2]", "?"); } +TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownShapes) { + NodeDef def; + InferenceContext c(&def, MakeOpDef(3, 1), {Unknown(), Unknown(), Unknown()}, + {}, {}, {}, {}); + EXPECT_EQ(3, c.num_inputs()); + EXPECT_EQ(1, c.num_outputs()); + + auto indices = c.input(0); + auto values = c.input(1); + auto shape = c.input(2); + TF_EXPECT_OK(ValidateSparseTensor(&c, indices, values, shape)); +} + +TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownDims) { + NodeDef def; + InferenceContext c(&def, MakeOpDef(3, 1), {S({-1, -1}), S({-1}), S({-1})}, {}, + {}, {}, {}); + EXPECT_EQ(3, c.num_inputs()); + EXPECT_EQ(1, c.num_outputs()); + + auto indices = c.input(0); + auto values = c.input(1); + auto shape = c.input(2); + TF_EXPECT_OK(ValidateSparseTensor(&c, indices, values, shape)); +} + +TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidIndicesRank) { + NodeDef def; + InferenceContext c(&def, MakeOpDef(3, 1), {S({-1}), S({-1}), S({-1})}, {}, {}, + {}, {}); + EXPECT_EQ(3, c.num_inputs()); + EXPECT_EQ(1, c.num_outputs()); + + auto indices = c.input(0); + auto values = c.input(1); + auto shape = c.input(2); + EXPECT_EQ(error::INVALID_ARGUMENT, + ValidateSparseTensor(&c, indices, values, shape).code()); +} + +TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidNumElements) { + NodeDef def; + InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({4}), S({3})}, {}, {}, + {}, {}); + EXPECT_EQ(3, c.num_inputs()); + EXPECT_EQ(1, c.num_outputs()); + + auto indices = c.input(0); + auto values = c.input(1); + auto shape = c.input(2); + EXPECT_EQ(error::INVALID_ARGUMENT, + ValidateSparseTensor(&c, indices, values, shape).code()); +} + +TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidRank) { + NodeDef def; + InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({4})}, {}, {}, + {}, {}); + EXPECT_EQ(3, c.num_inputs()); + EXPECT_EQ(1, c.num_outputs()); + + auto indices = c.input(0); + auto values = c.input(1); + auto shape = c.input(2); + EXPECT_EQ(error::INVALID_ARGUMENT, + ValidateSparseTensor(&c, indices, values, shape).code()); +} + +TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownNumIndexElements) { + NodeDef def; + InferenceContext c(&def, MakeOpDef(3, 1), {S({-1, 3}), S({5}), S({3})}, {}, + {}, {}, {}); + EXPECT_EQ(3, c.num_inputs()); + EXPECT_EQ(1, c.num_outputs()); + + auto indices = c.input(0); + auto values = c.input(1); + auto shape = c.input(2); + TF_EXPECT_OK(ValidateSparseTensor(&c, indices, values, shape)); +} + +TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownNumValueElements) { + NodeDef def; + InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({-1}), S({3})}, {}, + {}, {}, {}); + EXPECT_EQ(3, c.num_inputs()); + EXPECT_EQ(1, c.num_outputs()); + + auto indices = c.input(0); + auto values = c.input(1); + auto shape = c.input(2); + TF_EXPECT_OK(ValidateSparseTensor(&c, indices, values, shape)); +} + +TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownIndexRank) { + NodeDef def; + InferenceContext c(&def, MakeOpDef(3, 1), {S({5, -1}), S({5}), S({3})}, {}, + {}, {}, {}); + EXPECT_EQ(3, c.num_inputs()); + EXPECT_EQ(1, c.num_outputs()); + + auto indices = c.input(0); + auto values = c.input(1); + auto shape = c.input(2); + TF_EXPECT_OK(ValidateSparseTensor(&c, indices, values, shape)); +} + +TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownShapeRank) { + NodeDef def; + InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({-1})}, {}, + {}, {}, {}); + EXPECT_EQ(3, c.num_inputs()); + EXPECT_EQ(1, c.num_outputs()); + + auto indices = c.input(0); + auto values = c.input(1); + auto shape = c.input(2); + TF_EXPECT_OK(ValidateSparseTensor(&c, indices, values, shape)); +} + +TEST(CommonShapeFnsTest, ValidateSparseTensor) { + NodeDef def; + InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({3})}, {}, {}, + {}, {}); + EXPECT_EQ(3, c.num_inputs()); + EXPECT_EQ(1, c.num_outputs()); + + auto indices = c.input(0); + auto values = c.input(1); + auto shape = c.input(2); + TF_EXPECT_OK(ValidateSparseTensor(&c, indices, values, shape)); +} + } // namespace shape_inference } // namespace tensorflow diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h index 6aa2a9fd4f..d91775152c 100644 --- a/tensorflow/core/framework/shape_inference.h +++ b/tensorflow/core/framework/shape_inference.h @@ -424,50 +424,6 @@ class InferenceContext { return output_handle_dtype_[idx]; } - // Validates the 3 component tensors of a sparse tensor have the proper - // shapes. This mimics SparseTensor.__init__ in python/framework/ops.py. - Status ValidateSparseTensor(ShapeHandle indices_shape, - ShapeHandle values_shape, - ShapeHandle shape_shape) { - // Validate ranks. - ShapeHandle unused_shape; - TF_RETURN_IF_ERROR(WithRank(indices_shape, 2, &unused_shape)); - TF_RETURN_IF_ERROR(WithRank(values_shape, 1, &unused_shape)); - TF_RETURN_IF_ERROR(WithRank(shape_shape, 1, &unused_shape)); - - // Number of elements in indices and values must match. - DimensionHandle num_index_elements_dim = Dim(indices_shape, 0); - if (ValueKnown(num_index_elements_dim)) { - DimensionHandle num_values_elements_dim = Dim(values_shape, 0); - if (ValueKnown(num_values_elements_dim)) { - int64 num_index_elements = Value(num_index_elements_dim); - int64 num_values_elements = Value(num_values_elements_dim); - if (num_index_elements != num_values_elements) { - return errors::InvalidArgument( - "Number of elements in index (", num_index_elements, - ") and values (", num_values_elements, ") do not match."); - } - } - } - - // Rank embedded in indices must match shape. - DimensionHandle index_rank_dim = Dim(indices_shape, 1); - if (ValueKnown(index_rank_dim)) { - DimensionHandle shape_rank_dim = Dim(shape_shape, 0); - if (ValueKnown(shape_rank_dim)) { - int64 index_rank = Value(index_rank_dim); - int32 shape_rank = Value(shape_rank_dim); - if (index_rank != shape_rank) { - return errors::InvalidArgument("Index rank (", index_rank, - ") and shape rank (", shape_rank, - ") do not match."); - } - } - } - - return Status::OK(); - } - // Note that shape functions should usually call MakeShapeFromShapeTensor, // as it does more analysis to provide partial shapes. // diff --git a/tensorflow/core/framework/shape_inference_test.cc b/tensorflow/core/framework/shape_inference_test.cc index 8d6b4ac021..80a8639c02 100644 --- a/tensorflow/core/framework/shape_inference_test.cc +++ b/tensorflow/core/framework/shape_inference_test.cc @@ -1264,138 +1264,5 @@ TEST_F(ShapeInferenceTest, Max) { EXPECT_TRUE(SameHandle(d_2, out)); } -TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownShapes) { - NodeDef def; - InferenceContext c(&def, MakeOpDef(3, 1), {Unknown(), Unknown(), Unknown()}, - {}, {}, {}, {}); - EXPECT_EQ(3, c.num_inputs()); - EXPECT_EQ(1, c.num_outputs()); - - auto indices = c.input(0); - auto values = c.input(1); - auto shape = c.input(2); - TF_EXPECT_OK(c.ValidateSparseTensor(indices, values, shape)); -} - -TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownDims) { - NodeDef def; - InferenceContext c(&def, MakeOpDef(3, 1), {S({-1, -1}), S({-1}), S({-1})}, {}, - {}, {}, {}); - EXPECT_EQ(3, c.num_inputs()); - EXPECT_EQ(1, c.num_outputs()); - - auto indices = c.input(0); - auto values = c.input(1); - auto shape = c.input(2); - TF_EXPECT_OK(c.ValidateSparseTensor(indices, values, shape)); -} - -TEST_F(ShapeInferenceTest, ValidateSparseTensor_InvalidIndicesRank) { - NodeDef def; - InferenceContext c(&def, MakeOpDef(3, 1), {S({-1}), S({-1}), S({-1})}, {}, {}, - {}, {}); - EXPECT_EQ(3, c.num_inputs()); - EXPECT_EQ(1, c.num_outputs()); - - auto indices = c.input(0); - auto values = c.input(1); - auto shape = c.input(2); - EXPECT_EQ(error::INVALID_ARGUMENT, - c.ValidateSparseTensor(indices, values, shape).code()); -} - -TEST_F(ShapeInferenceTest, ValidateSparseTensor_InvalidNumElements) { - NodeDef def; - InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({4}), S({3})}, {}, {}, - {}, {}); - EXPECT_EQ(3, c.num_inputs()); - EXPECT_EQ(1, c.num_outputs()); - - auto indices = c.input(0); - auto values = c.input(1); - auto shape = c.input(2); - EXPECT_EQ(error::INVALID_ARGUMENT, - c.ValidateSparseTensor(indices, values, shape).code()); -} - -TEST_F(ShapeInferenceTest, ValidateSparseTensor_InvalidRank) { - NodeDef def; - InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({4})}, {}, {}, - {}, {}); - EXPECT_EQ(3, c.num_inputs()); - EXPECT_EQ(1, c.num_outputs()); - - auto indices = c.input(0); - auto values = c.input(1); - auto shape = c.input(2); - EXPECT_EQ(error::INVALID_ARGUMENT, - c.ValidateSparseTensor(indices, values, shape).code()); -} - -TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownNumIndexElements) { - NodeDef def; - InferenceContext c(&def, MakeOpDef(3, 1), {S({-1, 3}), S({5}), S({3})}, {}, - {}, {}, {}); - EXPECT_EQ(3, c.num_inputs()); - EXPECT_EQ(1, c.num_outputs()); - - auto indices = c.input(0); - auto values = c.input(1); - auto shape = c.input(2); - TF_EXPECT_OK(c.ValidateSparseTensor(indices, values, shape)); -} - -TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownNumValueElements) { - NodeDef def; - InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({-1}), S({3})}, {}, - {}, {}, {}); - EXPECT_EQ(3, c.num_inputs()); - EXPECT_EQ(1, c.num_outputs()); - - auto indices = c.input(0); - auto values = c.input(1); - auto shape = c.input(2); - TF_EXPECT_OK(c.ValidateSparseTensor(indices, values, shape)); -} - -TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownIndexRank) { - NodeDef def; - InferenceContext c(&def, MakeOpDef(3, 1), {S({5, -1}), S({5}), S({3})}, {}, - {}, {}, {}); - EXPECT_EQ(3, c.num_inputs()); - EXPECT_EQ(1, c.num_outputs()); - - auto indices = c.input(0); - auto values = c.input(1); - auto shape = c.input(2); - TF_EXPECT_OK(c.ValidateSparseTensor(indices, values, shape)); -} - -TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownShapeRank) { - NodeDef def; - InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({-1})}, {}, - {}, {}, {}); - EXPECT_EQ(3, c.num_inputs()); - EXPECT_EQ(1, c.num_outputs()); - - auto indices = c.input(0); - auto values = c.input(1); - auto shape = c.input(2); - TF_EXPECT_OK(c.ValidateSparseTensor(indices, values, shape)); -} - -TEST_F(ShapeInferenceTest, ValidateSparseTensor) { - NodeDef def; - InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({3})}, {}, {}, - {}, {}); - EXPECT_EQ(3, c.num_inputs()); - EXPECT_EQ(1, c.num_outputs()); - - auto indices = c.input(0); - auto values = c.input(1); - auto shape = c.input(2); - TF_EXPECT_OK(c.ValidateSparseTensor(indices, values, shape)); -} - } // namespace shape_inference } // namespace tensorflow diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index cd287b95fd..2cc545dd70 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -997,10 +997,10 @@ REGISTER_OP("EditDistance") .Attr("T: type") .Output("output: float") .SetShapeFn([](InferenceContext* c) { - TF_RETURN_IF_ERROR( - c->ValidateSparseTensor(c->input(0), c->input(1), c->input(2))); - TF_RETURN_IF_ERROR( - c->ValidateSparseTensor(c->input(3), c->input(4), c->input(5))); + TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor( + c, c->input(0), c->input(1), c->input(2))); + TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor( + c, c->input(3), c->input(4), c->input(5))); const Tensor* hypothesis_shape_t = c->input_tensor(2); const Tensor* truth_shape_t = c->input_tensor(5); if (hypothesis_shape_t == nullptr || truth_shape_t == nullptr) { |