From df4ef50932de18d904e13f8ea4dbdcc4d5be2281 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Fri, 11 Nov 2016 10:58:09 -0800
Subject: Fix QuantizedConcat shape function. Enable it from python as well.

Change: 138895028
---
 tensorflow/core/framework/common_shape_fns.cc | 18 +++++------
 tensorflow/core/framework/common_shape_fns.h  |  6 +++-
 tensorflow/core/ops/array_ops.cc              |  9 ++++--
 tensorflow/core/ops/array_ops_test.cc         | 46 +++++++++++++++++++++++++++
 tensorflow/python/ops/array_ops.py            |  3 +-
 5 files changed, 68 insertions(+), 14 deletions(-)

diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc
index cc6cd10a08..2434127acc 100644
--- a/tensorflow/core/framework/common_shape_fns.cc
+++ b/tensorflow/core/framework/common_shape_fns.cc
@@ -715,12 +715,8 @@ Status ReductionShapeForReduceJoin(InferenceContext* c) {
   return Status::OK();
 }
 
-Status ConcatShapeHelper(InferenceContext* c, bool dim_is_last_argument) {
-  const int dim_index = dim_is_last_argument ? c->num_inputs() - 1 : 0;
-  const int start_value_index = dim_is_last_argument ? 0 : 1;
-  const int end_value_index =
-      dim_is_last_argument ? c->num_inputs() - 1 : c->num_inputs();
-
+Status ConcatShapeHelper(InferenceContext* c, int start_value_index,
+                         int end_value_index, int dim_index) {
   ShapeHandle unused;
   TF_RETURN_IF_ERROR(c->WithRank(c->input(dim_index), 0, &unused));
   const Tensor* concat_dim_t = c->input_tensor(dim_index);
@@ -788,12 +784,16 @@ Status ConcatShapeHelper(InferenceContext* c, bool dim_is_last_argument) {
   return Status::OK();
 }
 
-Status ConcatShape(InferenceContext* c) {
-  return ConcatShapeHelper(c, /* dim_is_last_argument */ false);
+Status ConcatShape(InferenceContext* c, int num_inputs_to_concat) {
+  return ConcatShapeHelper(c, 1 /* start_value_index */,
+                           1 + num_inputs_to_concat /* end_value_index */,
+                           0 /* dim_index */);
 }
 
 Status ConcatV2Shape(InferenceContext* c) {
-  return ConcatShapeHelper(c, /* dim_is_last_argument */ true);
+  return ConcatShapeHelper(c, 0 /* start_value_index */,
+                           c->num_inputs() - 1 /* end_value_index */,
+                           c->num_inputs() - 1 /* dim_index */);
 }
 
 Status BroadcastBinaryOpShapeFn(InferenceContext* c) {
diff --git a/tensorflow/core/framework/common_shape_fns.h b/tensorflow/core/framework/common_shape_fns.h
index 176ea3519d..fc1288f298 100644
--- a/tensorflow/core/framework/common_shape_fns.h
+++ b/tensorflow/core/framework/common_shape_fns.h
@@ -190,7 +190,11 @@ Status ReductionShape(shape_inference::InferenceContext* c);
 Status ReductionShapeForReduceJoin(shape_inference::InferenceContext* c);
 
 // Shape function for concat operations.
-Status ConcatShape(shape_inference::InferenceContext* c);
+// <num_inputs_to_concat> is the number of inputs to concatenate and are taken
+// from inputs
+// [1,num_inputs_to_concat] of the op. Input 0 is the concat_dim input.
+Status ConcatShape(shape_inference::InferenceContext* c,
+                   int num_inputs_to_concat);
 
 // Shape function for concat operations.
 Status ConcatV2Shape(shape_inference::InferenceContext* c);
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index cbc62d805a..a8139f3ee2 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -295,7 +295,9 @@ REGISTER_OP("Concat")
     .Output("output: T")
     .Attr("N: int >= 2")
     .Attr("T: type")
-    .SetShapeFn(shape_inference::ConcatShape)
+    .SetShapeFn([](InferenceContext* c) {
+      return shape_inference::ConcatShape(c, c->num_inputs() - 1);
+    })
     .Doc(R"doc(
 Concatenates tensors along one dimension.
 
@@ -4377,9 +4379,10 @@ REGISTER_OP("QuantizedConcat")
     .Attr("N: int >= 2")
     .Attr("T: type")
     .SetShapeFn([](InferenceContext* c) {
-      TF_RETURN_IF_ERROR(shape_inference::ConcatShape(c));
+      const int n = (c->num_inputs() - 1) / 3;
+      TF_RETURN_IF_ERROR(shape_inference::ConcatShape(c, n));
       ShapeHandle unused;
-      for (int i = std::max(0, c->num_inputs() - 2); i < c->num_inputs(); ++i) {
+      for (int i = n + 1; i < c->num_inputs(); ++i) {
         TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 0, &unused));
       }
       c->set_output(1, c->Scalar());
diff --git a/tensorflow/core/ops/array_ops_test.cc b/tensorflow/core/ops/array_ops_test.cc
index 691380fd26..cf6b7718dd 100644
--- a/tensorflow/core/ops/array_ops_test.cc
+++ b/tensorflow/core/ops/array_ops_test.cc
@@ -1552,4 +1552,50 @@ TEST(ArrayOpsTest, FakeQuantWithMinMaxVarsPerChannelGradient) {
   INFER_ERROR("must be equal", op, "[1,3];[1,3];[3];[2]");
 }
 
+TEST(ArrayOpsTest, QuantizedConcat_ShapeFn) {
+  ShapeInferenceTestOp op("QuantizedConcat");
+  auto set_n = [&op](int n) {
+    std::vector<NodeDefBuilder::NodeOut> src_list;
+    std::vector<NodeDefBuilder::NodeOut> limit_list;
+    for (int i = 0; i < n; ++i) {
+      src_list.emplace_back("a", 0, DT_QUINT8);
+      limit_list.emplace_back("b", 0, DT_FLOAT);
+    }
+    TF_ASSERT_OK(NodeDefBuilder("test", "QuantizedConcat")
+                     .Input({"concat_dim", 0, DT_INT32})
+                     .Input(src_list)
+                     .Input(limit_list)
+                     .Input(limit_list)
+                     .Attr("N", n)
+                     .Finalize(&op.node_def));
+  };
+
+  // Confirm dimension[0] of the input (the concat_dim) is a scalar.
+  set_n(1);
+  INFER_ERROR("Shape must be rank 0 but is rank 1", op, "[1];?;?;?");
+
+  // Last 2*<N> are all scalars.
+  set_n(2);
+  INFER_ERROR("must be rank 0", op, "[];?;?;?;?;?;[1]");
+  INFER_ERROR("must be rank 0", op, "[];?;?;?;?;[1];?");
+  INFER_ERROR("must be rank 0", op, "[];?;?;?;[1];?;?");
+  INFER_ERROR("must be rank 0", op, "[];?;?;[1];?;?;?");
+
+  // First is concat dim; next N must be compatible for concat.
+  set_n(2);
+  INFER_ERROR("must be rank 2", op, "[];[1,2];[1,2,3];?;?;?;?");
+  INFER_OK(op, "[];[1,2];[1,3];?;?;?;?", "[?,?];[];[]");
+
+  // Test when the concat_dim tensor is known. The concatenated dimension is
+  // summed across all input tensors, and other dimensions are merged.
+  Tensor concat_dim_t;
+  op.input_tensors.push_back(&concat_dim_t);
+  set_n(2);
+  concat_dim_t = test::AsScalar<int32>(0);  // Sum dim 0, merge the other two dims.
+  INFER_OK(op, "[];[100,2,?];[10,?,3];?;?;?;?", "[110,d1_1,d2_2];[];[]");
+  INFER_ERROR("Dimension 1 in both shapes must be equal, but are 5 and 3", op,
+              "[];[100,2,5];[10,?,3];?;?;?;?");
+  // Note that other cases of concat are covered in the Concat tests.
+}
+
 }  // end namespace tensorflow
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 11a95a1847..79c909c7f6 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -2617,7 +2617,8 @@ def _QuantizedReshapeShape(op):
 # TODO(cwhipkey): Verify and enable shape functions for these.
 ops.RegisterShape("QuantizeV2")(None)
 ops.RegisterShape("QuantizedBatchNormWithGlobalNormalization")(None)
-ops.RegisterShape("QuantizedConcat")(None)
+
+ops.RegisterShape("QuantizedConcat")(common_shapes.call_cpp_shape_fn)
 
 
 @ops.RegisterShape("ScatterNd")
-- 
cgit v1.2.3
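The core of the fix is how the shape function partitions QuantizedConcat's inputs: the op receives 1 + 3*N inputs (the concat_dim scalar at index 0, the N tensors to concatenate at indices [1, N], then 2*N min/max scalars). The old code passed every input to ConcatShape and only rank-checked the last two inputs; the new code derives N from the input count, concatenates only inputs [1, N], and rank-checks everything after index N. A small standalone C++ sketch of that index arithmetic follows; the struct and function names below are illustrative only and are not TensorFlow APIs.

#include <cassert>
#include <cstdio>

// Illustrative only: mirrors `const int n = (c->num_inputs() - 1) / 3;` and
// the index ranges used by ConcatShape and the WithRank loop in the patch.
struct QuantizedConcatLayout {
  int n;                   // number of tensors being concatenated (attr N)
  int dim_index;           // index of the concat_dim input
  int first_value_index;   // first tensor to concatenate
  int last_value_index;    // last tensor to concatenate (inclusive)
  int first_scalar_index;  // first of the 2*N min/max scalar inputs
};

QuantizedConcatLayout LayoutForNumInputs(int num_inputs) {
  assert((num_inputs - 1) % 3 == 0);  // concat_dim + N values + 2*N scalars
  const int n = (num_inputs - 1) / 3;
  return {n, /*dim_index=*/0, /*first_value_index=*/1,
          /*last_value_index=*/n, /*first_scalar_index=*/n + 1};
}

int main() {
  // With N = 2 the op has 7 inputs: concat_dim, 2 values, 2 mins, 2 maxes.
  const QuantizedConcatLayout layout = LayoutForNumInputs(7);
  std::printf("N=%d, values at [%d,%d], min/max scalars start at %d\n",
              layout.n, layout.first_value_index, layout.last_value_index,
              layout.first_scalar_index);
  return 0;
}

For the N = 2 case above this prints "N=2, values at [1,2], min/max scalars start at 3", which is exactly the split the new test exercises with its seven-entry shape strings ("[];?;?;?;?;?;?").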