author    Vijay Vasudevan <vrv@google.com>                2016-08-02 13:41:55 -0800
committer TensorFlower Gardener <gardener@tensorflow.org> 2016-08-02 14:47:56 -0700
commit    9b2c80c4354cd08f3fda9ce75295226be72aa9d0 (patch)
tree      05078473150dc657141a42f92411f6eb5d02c155 /tensorflow/contrib/quantization
parent    930b8a0d58ab6617969d23be92d4de1f122ffedf (diff)
TensorFlow: Finish off quantization shape functions in contrib
Change: 129143268
Diffstat (limited to 'tensorflow/contrib/quantization')
-rw-r--r--  tensorflow/contrib/quantization/ops/math_ops.cc |  9 +
-rw-r--r--  tensorflow/contrib/quantization/ops/nn_ops.cc   | 20 +
2 files changed, 29 insertions(+), 0 deletions(-)
diff --git a/tensorflow/contrib/quantization/ops/math_ops.cc b/tensorflow/contrib/quantization/ops/math_ops.cc
index 6bc408531a..ed0930c2d6 100644
--- a/tensorflow/contrib/quantization/ops/math_ops.cc
+++ b/tensorflow/contrib/quantization/ops/math_ops.cc
@@ -80,6 +80,15 @@ REGISTER_OP("QuantizeDownAndShrinkRange")
     .Output("output_max: float")
     .Attr("Tinput: quantizedtype")
     .Attr("out_type: quantizedtype")
+    .SetShapeFn([](InferenceContext* c) {
+      TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
+      const Shape* unused;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+      c->set_output(1, c->Scalar());
+      c->set_output(2, c->Scalar());
+      return Status::OK();
+    })
     .Doc(R"doc(
 Convert the quantized 'input' tensor into a lower-precision 'output', using the
 actual distribution of the values to maximize the usage of the lower bit depth
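
For reference, the contract this hunk enforces can be read off the calls: output 0 keeps the input's shape (UnchangedShape), inputs 1 and 2 (the min/max range scalars) must be rank 0, and outputs 1 and 2 are scalars. Below is a minimal standalone sketch of that contract in plain C++; it is hypothetical illustration, not the TensorFlow shape-inference API, and the names Shape, QdsrShapes, and InferQdsrShapes are invented for this sketch.

// Hypothetical sketch of the QuantizeDownAndShrinkRange shape contract.
// A shape is modeled as a vector of dimension sizes; rank 0 == scalar.
#include <cstdio>
#include <stdexcept>
#include <vector>

using Shape = std::vector<long long>;

struct QdsrShapes {
  Shape output;      // same shape as the quantized input (UnchangedShape)
  Shape output_min;  // scalar, like c->Scalar()
  Shape output_max;  // scalar, like c->Scalar()
};

QdsrShapes InferQdsrShapes(const Shape& input, const Shape& input_min,
                           const Shape& input_max) {
  // Mirrors the two WithRank(..., 0, &unused) checks in the hunk above.
  if (!input_min.empty() || !input_max.empty()) {
    throw std::invalid_argument("input_min and input_max must be scalars");
  }
  return {input, {}, {}};
}

int main() {
  QdsrShapes s = InferQdsrShapes({2, 3, 4}, {}, {});
  std::printf("output rank: %zu\n", s.output.size());  // prints: output rank: 3
}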
diff --git a/tensorflow/contrib/quantization/ops/nn_ops.cc b/tensorflow/contrib/quantization/ops/nn_ops.cc
index 18db2b0eaa..c33f318c6e 100644
--- a/tensorflow/contrib/quantization/ops/nn_ops.cc
+++ b/tensorflow/contrib/quantization/ops/nn_ops.cc
@@ -21,6 +21,7 @@ limitations under the License.
 
 namespace tensorflow {
 
+using shape_inference::Dimension;
 using shape_inference::InferenceContext;
 using shape_inference::Shape;
 
@@ -292,6 +293,25 @@ REGISTER_OP("QuantizedBatchNormWithGlobalNormalization")
     .Attr("out_type: quantizedtype")
     .Attr("variance_epsilon: float")
     .Attr("scale_after_normalization: bool")
+    .SetShapeFn([](InferenceContext* c) {
+      const Shape* input;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
+
+      const Dimension* last_dim = c->Dim(input, 3);
+      for (int i = 1; i < 5; ++i) {  // covers m, v, beta, gamma
+        const Shape* vec;
+        TF_RETURN_IF_ERROR(c->WithRank(c->input(i * 3), 1, &vec));
+        TF_RETURN_IF_ERROR(c->Merge(last_dim, c->Dim(vec, 0), &last_dim));
+      }
+
+      const Shape* out;
+      TF_RETURN_IF_ERROR(c->ReplaceDim(input, 3, last_dim, &out));
+      c->set_output(0, out);
+      c->set_output(1, c->Scalar());
+      c->set_output(2, c->Scalar());
+
+      return Status::OK();
+    })
     .Doc(R"doc(
 Quantized Batch normalization.
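
The index arithmetic in the loop follows from the op's input ordering (t, t_min, t_max, m, m_min, m_max, v, v_min, v_max, beta, beta_min, beta_max, gamma, gamma_min, gamma_max): the rank-1 parameter vectors sit at input indices 3, 6, 9, and 12, hence c->input(i * 3) for i in [1, 5). The sketch below spells out the resulting contract in plain C++; it is a hypothetical illustration, not the real InferenceContext API, and it simplifies by assuming all dimensions are known (the real Merge() also unifies unknown dims).

// Hypothetical sketch of the QuantizedBatchNormWithGlobalNormalization
// shape contract, with all dimensions assumed known.
#include <stdexcept>
#include <vector>

using Shape = std::vector<long long>;

// `input` is the rank-4 tensor t; `params` holds the rank-1 vectors
// m, v, beta, gamma in order. Returns the shape of output 0; the min/max
// outputs are rank-0 scalars and omitted here.
Shape InferQuantizedBatchNormShape(const Shape& input,
                                   const std::vector<Shape>& params) {
  if (input.size() != 4) {
    throw std::invalid_argument("t must be rank 4");  // WithRank(input(0), 4)
  }
  long long channels = input[3];  // last_dim in the shape function above
  for (const Shape& p : params) {
    if (p.size() != 1) {
      throw std::invalid_argument("m/v/beta/gamma must be rank 1");
    }
    if (p[0] != channels) {  // the Merge(last_dim, Dim(vec, 0)) step
      throw std::invalid_argument("channel-count mismatch");
    }
  }
  Shape out = input;  // ReplaceDim(input, 3, last_dim, &out)
  out[3] = channels;
  return out;
}

int main() {
  Shape out = InferQuantizedBatchNormShape({8, 32, 32, 16},
                                           {{16}, {16}, {16}, {16}});
  return out[3] == 16 ? 0 : 1;  // output keeps the merged channel count
}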