diff options
author | Vijay Vasudevan <vrv@google.com> | 2016-08-10 15:43:11 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-08-10 16:48:21 -0700 |
commit | f5df66259d65665a3d5393a384e653b6b6241c1e (patch) | |
tree | 89863c526d18a27f4eb2246703ec75b5e1488e69 /tensorflow/contrib/quantization | |
parent | 331a2faead0876a50861709c82dd41485074c60d (diff) |
Add shape function for QuantizedConcat, moving ConcatShape into
a common shape fn.
Change: 129928393
Diffstat (limited to 'tensorflow/contrib/quantization')
-rw-r--r-- | tensorflow/contrib/quantization/ops/array_ops.cc | 10 |
1 files changed, 10 insertions, 0 deletions
diff --git a/tensorflow/contrib/quantization/ops/array_ops.cc b/tensorflow/contrib/quantization/ops/array_ops.cc index e1cf3ded93..ea59a9ae3e 100644 --- a/tensorflow/contrib/quantization/ops/array_ops.cc +++ b/tensorflow/contrib/quantization/ops/array_ops.cc @@ -166,6 +166,16 @@ REGISTER_OP("QuantizedConcat") .Output("output_max: float") .Attr("N: int >= 2") .Attr("T: type") + .SetShapeFn([](InferenceContext* c) { + TF_RETURN_IF_ERROR(shape_inference::ConcatShape(c)); + const Shape* unused; + for (int i = 2; i < c->num_inputs(); ++i) { + TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 0, &unused)); + } + c->set_output(1, c->Scalar()); + c->set_output(2, c->Scalar()); + return Status::OK(); + }) .Doc(R"doc( Concatenates quantized tensors along one dimension. |