aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/quantization
diff options
context:
space:
mode:
authorGravatar Vijay Vasudevan <vrv@google.com>2016-08-10 15:43:11 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-08-10 16:48:21 -0700
commitf5df66259d65665a3d5393a384e653b6b6241c1e (patch)
tree89863c526d18a27f4eb2246703ec75b5e1488e69 /tensorflow/contrib/quantization
parent331a2faead0876a50861709c82dd41485074c60d (diff)
Add shape function for QuantizedConcat, moving ConcatShape into
a common shape fn. Change: 129928393
Diffstat (limited to 'tensorflow/contrib/quantization')
-rw-r--r--tensorflow/contrib/quantization/ops/array_ops.cc10
1 files changed, 10 insertions, 0 deletions
diff --git a/tensorflow/contrib/quantization/ops/array_ops.cc b/tensorflow/contrib/quantization/ops/array_ops.cc
index e1cf3ded93..ea59a9ae3e 100644
--- a/tensorflow/contrib/quantization/ops/array_ops.cc
+++ b/tensorflow/contrib/quantization/ops/array_ops.cc
@@ -166,6 +166,16 @@ REGISTER_OP("QuantizedConcat")
.Output("output_max: float")
.Attr("N: int >= 2")
.Attr("T: type")
+ .SetShapeFn([](InferenceContext* c) {
+ TF_RETURN_IF_ERROR(shape_inference::ConcatShape(c));
+ const Shape* unused;
+ for (int i = 2; i < c->num_inputs(); ++i) {
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 0, &unused));
+ }
+ c->set_output(1, c->Scalar());
+ c->set_output(2, c->Scalar());
+ return Status::OK();
+ })
.Doc(R"doc(
Concatenates quantized tensors along one dimension.