aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-11-11 10:58:09 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-11-11 11:05:40 -0800
commitdf4ef50932de18d904e13f8ea4dbdcc4d5be2281 (patch)
tree0cbd3f37aed1bc9ecf8a58817cbfdca839952713
parent3c5d874b0dd920e11b1cc3ad304438df7f1db870 (diff)
Fix QuantizedConcat shape function. Enable it from python as well.
Change: 138895028
-rw-r--r--tensorflow/core/framework/common_shape_fns.cc18
-rw-r--r--tensorflow/core/framework/common_shape_fns.h6
-rw-r--r--tensorflow/core/ops/array_ops.cc9
-rw-r--r--tensorflow/core/ops/array_ops_test.cc46
-rw-r--r--tensorflow/python/ops/array_ops.py3
5 files changed, 68 insertions, 14 deletions
diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc
index cc6cd10a08..2434127acc 100644
--- a/tensorflow/core/framework/common_shape_fns.cc
+++ b/tensorflow/core/framework/common_shape_fns.cc
@@ -715,12 +715,8 @@ Status ReductionShapeForReduceJoin(InferenceContext* c) {
return Status::OK();
}
-Status ConcatShapeHelper(InferenceContext* c, bool dim_is_last_argument) {
- const int dim_index = dim_is_last_argument ? c->num_inputs() - 1 : 0;
- const int start_value_index = dim_is_last_argument ? 0 : 1;
- const int end_value_index =
- dim_is_last_argument ? c->num_inputs() - 1 : c->num_inputs();
-
+Status ConcatShapeHelper(InferenceContext* c, int start_value_index,
+ int end_value_index, int dim_index) {
ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(dim_index), 0, &unused));
const Tensor* concat_dim_t = c->input_tensor(dim_index);
@@ -788,12 +784,16 @@ Status ConcatShapeHelper(InferenceContext* c, bool dim_is_last_argument) {
return Status::OK();
}
-Status ConcatShape(InferenceContext* c) {
- return ConcatShapeHelper(c, /* dim_is_last_argument */ false);
+Status ConcatShape(InferenceContext* c, int num_inputs_to_concat) {
+ return ConcatShapeHelper(c, 1 /* start_value_index */,
+ 1 + num_inputs_to_concat /* end_value_index */,
+ 0 /* dim_index */);
}
Status ConcatV2Shape(InferenceContext* c) {
- return ConcatShapeHelper(c, /* dim_is_last_argument */ true);
+ return ConcatShapeHelper(c, 0 /* start_value_index */,
+ c->num_inputs() - 1 /* end_value_index */,
+ c->num_inputs() - 1 /* dim_index */);
}
Status BroadcastBinaryOpShapeFn(InferenceContext* c) {
diff --git a/tensorflow/core/framework/common_shape_fns.h b/tensorflow/core/framework/common_shape_fns.h
index 176ea3519d..fc1288f298 100644
--- a/tensorflow/core/framework/common_shape_fns.h
+++ b/tensorflow/core/framework/common_shape_fns.h
@@ -190,7 +190,11 @@ Status ReductionShape(shape_inference::InferenceContext* c);
Status ReductionShapeForReduceJoin(shape_inference::InferenceContext* c);
// Shape function for concat operations.
-Status ConcatShape(shape_inference::InferenceContext* c);
+// <num_inputs_to_concat> is the number of inputs to concatenate and are taken
+// from inputs
+// [1,num_inputs_to_concat] of the op. Input 0 is the concat_dim input.
+Status ConcatShape(shape_inference::InferenceContext* c,
+ int num_inputs_to_concat);
// Shape function for concat operations.
Status ConcatV2Shape(shape_inference::InferenceContext* c);
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index cbc62d805a..a8139f3ee2 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -295,7 +295,9 @@ REGISTER_OP("Concat")
.Output("output: T")
.Attr("N: int >= 2")
.Attr("T: type")
- .SetShapeFn(shape_inference::ConcatShape)
+ .SetShapeFn([](InferenceContext* c) {
+ return shape_inference::ConcatShape(c, c->num_inputs() - 1);
+ })
.Doc(R"doc(
Concatenates tensors along one dimension.
@@ -4377,9 +4379,10 @@ REGISTER_OP("QuantizedConcat")
.Attr("N: int >= 2")
.Attr("T: type")
.SetShapeFn([](InferenceContext* c) {
- TF_RETURN_IF_ERROR(shape_inference::ConcatShape(c));
+ const int n = (c->num_inputs() - 1) / 3;
+ TF_RETURN_IF_ERROR(shape_inference::ConcatShape(c, n));
ShapeHandle unused;
- for (int i = std::max(0, c->num_inputs() - 2); i < c->num_inputs(); ++i) {
+ for (int i = n + 1; i < c->num_inputs(); ++i) {
TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 0, &unused));
}
c->set_output(1, c->Scalar());
diff --git a/tensorflow/core/ops/array_ops_test.cc b/tensorflow/core/ops/array_ops_test.cc
index 691380fd26..cf6b7718dd 100644
--- a/tensorflow/core/ops/array_ops_test.cc
+++ b/tensorflow/core/ops/array_ops_test.cc
@@ -1552,4 +1552,50 @@ TEST(ArrayOpsTest, FakeQuantWithMinMaxVarsPerChannelGradient) {
INFER_ERROR("must be equal", op, "[1,3];[1,3];[3];[2]");
}
+TEST(ArrayOpsTest, QuantizedConcat_ShapeFn) {
+ ShapeInferenceTestOp op("QuantizedConcat");
+ auto set_n = [&op](int n) {
+ std::vector<NodeDefBuilder::NodeOut> src_list;
+ std::vector<NodeDefBuilder::NodeOut> limit_list;
+ for (int i = 0; i < n; ++i) {
+ src_list.emplace_back("a", 0, DT_QUINT8);
+ limit_list.emplace_back("b", 0, DT_FLOAT);
+ }
+ TF_ASSERT_OK(NodeDefBuilder("test", "QuantizedConcat")
+ .Input({"concat_dim", 0, DT_INT32})
+ .Input(src_list)
+ .Input(limit_list)
+ .Input(limit_list)
+ .Attr("N", n)
+ .Finalize(&op.node_def));
+ };
+
+ // Confirm dimension[0] of the input (the concat_dim) is a scalar.
+ set_n(1);
+ INFER_ERROR("Shape must be rank 0 but is rank 1", op, "[1];?;?;?");
+
+ // Last 2*<N> are all scalars.
+ set_n(2);
+ INFER_ERROR("must be rank 0", op, "[];?;?;?;?;?;[1]");
+ INFER_ERROR("must be rank 0", op, "[];?;?;?;?;[1];?");
+ INFER_ERROR("must be rank 0", op, "[];?;?;?;[1];?;?");
+ INFER_ERROR("must be rank 0", op, "[];?;?;[1];?;?;?");
+
+ // First is concat dim; next N must be compatible for concat.
+ set_n(2);
+ INFER_ERROR("must be rank 2", op, "[];[1,2];[1,2,3];?;?;?;?");
+ INFER_OK(op, "[];[1,2];[1,3];?;?;?;?", "[?,?];[];[]");
+
+ // Test when the concat_dim tensor is known. The concatenated dimension is
+ // summed across all input tensors, and other dimensions are merged.
+ Tensor concat_dim_t;
+ op.input_tensors.push_back(&concat_dim_t);
+ set_n(2);
+ concat_dim_t = test::AsScalar(0); // Sum dim 0, merge the other two dims.
+ INFER_OK(op, "[];[100,2,?];[10,?,3];?;?;?;?", "[110,d1_1,d2_2];[];[]");
+ INFER_ERROR("Dimension 1 in both shapes must be equal, but are 5 and 3", op,
+ "[];[100,2,5];[10,?,3];?;?;?;?");
+ // Note that other cases of concat are covered in the Concat tests.
+}
+
} // end namespace tensorflow
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 11a95a1847..79c909c7f6 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -2617,7 +2617,8 @@ def _QuantizedReshapeShape(op):
# TODO(cwhipkey): Verify and enable shape functions for these.
ops.RegisterShape("QuantizeV2")(None)
ops.RegisterShape("QuantizedBatchNormWithGlobalNormalization")(None)
-ops.RegisterShape("QuantizedConcat")(None)
+
+ops.RegisterShape("QuantizedConcat")(common_shapes.call_cpp_shape_fn)
@ops.RegisterShape("ScatterNd")