From 9b2c80c4354cd08f3fda9ce75295226be72aa9d0 Mon Sep 17 00:00:00 2001 From: Vijay Vasudevan Date: Tue, 2 Aug 2016 13:41:55 -0800 Subject: TensorFlow: Finish off quantization shape functions in contrib Change: 129143268 --- tensorflow/contrib/quantization/ops/math_ops.cc | 9 +++++++++ tensorflow/contrib/quantization/ops/nn_ops.cc | 20 ++++++++++++++++++++ 2 files changed, 29 insertions(+) (limited to 'tensorflow/contrib') diff --git a/tensorflow/contrib/quantization/ops/math_ops.cc b/tensorflow/contrib/quantization/ops/math_ops.cc index 6bc408531a..ed0930c2d6 100644 --- a/tensorflow/contrib/quantization/ops/math_ops.cc +++ b/tensorflow/contrib/quantization/ops/math_ops.cc @@ -80,6 +80,15 @@ REGISTER_OP("QuantizeDownAndShrinkRange") .Output("output_max: float") .Attr("Tinput: quantizedtype") .Attr("out_type: quantizedtype") + .SetShapeFn([](InferenceContext* c) { + TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c)); + const Shape* unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); + c->set_output(1, c->Scalar()); + c->set_output(2, c->Scalar()); + return Status::OK(); + }) .Doc(R"doc( Convert the quantized 'input' tensor into a lower-precision 'output', using the actual distribution of the values to maximize the usage of the lower bit depth diff --git a/tensorflow/contrib/quantization/ops/nn_ops.cc b/tensorflow/contrib/quantization/ops/nn_ops.cc index 18db2b0eaa..c33f318c6e 100644 --- a/tensorflow/contrib/quantization/ops/nn_ops.cc +++ b/tensorflow/contrib/quantization/ops/nn_ops.cc @@ -21,6 +21,7 @@ limitations under the License. namespace tensorflow { +using shape_inference::Dimension; using shape_inference::InferenceContext; using shape_inference::Shape; @@ -292,6 +293,25 @@ REGISTER_OP("QuantizedBatchNormWithGlobalNormalization") .Attr("out_type: quantizedtype") .Attr("variance_epsilon: float") .Attr("scale_after_normalization: bool") + .SetShapeFn([](InferenceContext* c) { + const Shape* input; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input)); + + const Dimension* last_dim = c->Dim(input, 3); + for (int i = 1; i < 5; ++i) { // covers m, v, beta, gamma + const Shape* vec; + TF_RETURN_IF_ERROR(c->WithRank(c->input(i * 3), 1, &vec)); + TF_RETURN_IF_ERROR(c->Merge(last_dim, c->Dim(vec, 0), &last_dim)); + } + + const Shape* out; + TF_RETURN_IF_ERROR(c->ReplaceDim(input, 3, last_dim, &out)); + c->set_output(0, out); + c->set_output(1, c->Scalar()); + c->set_output(2, c->Scalar()); + + return Status::OK(); + }) .Doc(R"doc( Quantized Batch normalization. -- cgit v1.2.3