diff options
author | Vijay Vasudevan <vrv@google.com> | 2016-07-26 14:13:40 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-07-26 15:17:52 -0700 |
commit | d79dd70648e3978a194f37f425ffbcf610bdd840 (patch) | |
tree | 4c98409165f0c9e7ba0103904574230d8cf49b5e /tensorflow/contrib/quantization | |
parent | daa18e62995a388957a25a4ea3221d2d121af16e (diff) |
Add shape functions for some of the quantization ops.
Change: 128520578
Diffstat (limited to 'tensorflow/contrib/quantization')
-rw-r--r-- | tensorflow/contrib/quantization/ops/nn_ops.cc | 49 |
1 files changed, 49 insertions, 0 deletions
diff --git a/tensorflow/contrib/quantization/ops/nn_ops.cc b/tensorflow/contrib/quantization/ops/nn_ops.cc index fd12d155db..814011e411 100644 --- a/tensorflow/contrib/quantization/ops/nn_ops.cc +++ b/tensorflow/contrib/quantization/ops/nn_ops.cc @@ -73,6 +73,17 @@ REGISTER_OP("QuantizedBiasAdd") .Attr("T1: quantizedtype") .Attr("T2: quantizedtype") .Attr("out_type: quantizedtype") + .SetShapeFn([](InferenceContext* c) { + TF_RETURN_IF_ERROR(shape_inference::BiasAddShape(c)); + const Shape* unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); + c->set_output(1, c->Scalar()); + c->set_output(2, c->Scalar()); + return Status::OK(); + }) .Doc(R"doc( Adds Tensor 'bias' to Tensor 'input' for Quantized types. @@ -103,6 +114,17 @@ REGISTER_OP("QuantizedConv2D") .Attr("out_type: quantizedtype = DT_QINT32") .Attr("strides: list(int)") .Attr(GetPaddingAttrString()) + .SetShapeFn([](InferenceContext* c) { + TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); + const Shape* unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); + c->set_output(1, c->Scalar()); + c->set_output(2, c->Scalar()); + return Status::OK(); + }) .Doc(R"doc( Computes a 2D convolution given quantized 4D input and filter tensors. The inputs are quantized tensors where the lowest value represents the real @@ -159,6 +181,15 @@ REGISTER_OP("QuantizedRelu") .Output("max_activations: float") .Attr("Tinput: quantizedtype") .Attr("out_type: quantizedtype = DT_QUINT8") + .SetShapeFn([](InferenceContext* c) { + TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c)); + const Shape* unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); + c->set_output(1, c->Scalar()); + c->set_output(2, c->Scalar()); + return Status::OK(); + }) .Doc(R"doc( Computes Quantized Rectified Linear: `max(features, 0)` @@ -179,6 +210,15 @@ REGISTER_OP("QuantizedRelu6") .Output("max_activations: float") .Attr("Tinput: quantizedtype") .Attr("out_type: quantizedtype = DT_QUINT8") + .SetShapeFn([](InferenceContext* c) { + TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c)); + const Shape* unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); + c->set_output(1, c->Scalar()); + c->set_output(2, c->Scalar()); + return Status::OK(); + }) .Doc(R"doc( Computes Quantized Rectified Linear 6: `min(max(features, 0), 6)` @@ -200,6 +240,15 @@ REGISTER_OP("QuantizedReluX") .Output("max_activations: float") .Attr("Tinput: quantizedtype") .Attr("out_type: quantizedtype = DT_QUINT8") + .SetShapeFn([](InferenceContext* c) { + TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c)); + const Shape* unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); + c->set_output(1, c->Scalar()); + c->set_output(2, c->Scalar()); + return Status::OK(); + }) .Doc(R"doc( Computes Quantized Rectified Linear X: `min(max(features, 0), max_value)` |