about summary refs log tree commit diff homepage
path: root/tensorflow/contrib/quantization
diff options
context:
space:
mode:
authorGravatar Vijay Vasudevan <vrv@google.com>2016-07-26 14:13:40 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-07-26 15:17:52 -0700
commitd79dd70648e3978a194f37f425ffbcf610bdd840 (patch)
tree4c98409165f0c9e7ba0103904574230d8cf49b5e /tensorflow/contrib/quantization
parentdaa18e62995a388957a25a4ea3221d2d121af16e (diff)
Add shape functions for some of the quantization ops.
Change: 128520578
Diffstat (limited to 'tensorflow/contrib/quantization')
-rw-r--r--  tensorflow/contrib/quantization/ops/nn_ops.cc  49
1 file changed, 49 insertions, 0 deletions
diff --git a/tensorflow/contrib/quantization/ops/nn_ops.cc b/tensorflow/contrib/quantization/ops/nn_ops.cc
index fd12d155db..814011e411 100644
--- a/tensorflow/contrib/quantization/ops/nn_ops.cc
+++ b/tensorflow/contrib/quantization/ops/nn_ops.cc
@@ -73,6 +73,17 @@ REGISTER_OP("QuantizedBiasAdd")
.Attr("T1: quantizedtype")
.Attr("T2: quantizedtype")
.Attr("out_type: quantizedtype")
+ .SetShapeFn([](InferenceContext* c) {
+ TF_RETURN_IF_ERROR(shape_inference::BiasAddShape(c));
+ const Shape* unused;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
+ c->set_output(1, c->Scalar());
+ c->set_output(2, c->Scalar());
+ return Status::OK();
+ })
.Doc(R"doc(
Adds Tensor 'bias' to Tensor 'input' for Quantized types.
@@ -103,6 +114,17 @@ REGISTER_OP("QuantizedConv2D")
.Attr("out_type: quantizedtype = DT_QINT32")
.Attr("strides: list(int)")
.Attr(GetPaddingAttrString())
+ .SetShapeFn([](InferenceContext* c) {
+ TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
+ const Shape* unused;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
+ c->set_output(1, c->Scalar());
+ c->set_output(2, c->Scalar());
+ return Status::OK();
+ })
.Doc(R"doc(
Computes a 2D convolution given quantized 4D input and filter tensors.
The inputs are quantized tensors where the lowest value represents the real
@@ -159,6 +181,15 @@ REGISTER_OP("QuantizedRelu")
.Output("max_activations: float")
.Attr("Tinput: quantizedtype")
.Attr("out_type: quantizedtype = DT_QUINT8")
+ .SetShapeFn([](InferenceContext* c) {
+ TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
+ const Shape* unused;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+ c->set_output(1, c->Scalar());
+ c->set_output(2, c->Scalar());
+ return Status::OK();
+ })
.Doc(R"doc(
Computes Quantized Rectified Linear: `max(features, 0)`
@@ -179,6 +210,15 @@ REGISTER_OP("QuantizedRelu6")
.Output("max_activations: float")
.Attr("Tinput: quantizedtype")
.Attr("out_type: quantizedtype = DT_QUINT8")
+ .SetShapeFn([](InferenceContext* c) {
+ TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
+ const Shape* unused;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+ c->set_output(1, c->Scalar());
+ c->set_output(2, c->Scalar());
+ return Status::OK();
+ })
.Doc(R"doc(
Computes Quantized Rectified Linear 6: `min(max(features, 0), 6)`
@@ -200,6 +240,15 @@ REGISTER_OP("QuantizedReluX")
.Output("max_activations: float")
.Attr("Tinput: quantizedtype")
.Attr("out_type: quantizedtype = DT_QUINT8")
+ .SetShapeFn([](InferenceContext* c) {
+ TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
+ const Shape* unused;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+ c->set_output(1, c->Scalar());
+ c->set_output(2, c->Scalar());
+ return Status::OK();
+ })
.Doc(R"doc(
Computes Quantized Rectified Linear X: `min(max(features, 0), max_value)`