diff options
author | Suharsh Sivakumar <suharshs@google.com> | 2016-07-28 18:17:30 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-07-28 19:33:24 -0700 |
commit | 5ba35496849903c3d458bc8443735a9740b29b2c (patch) | |
tree | 74d9ec28da691adc0abe6ecef9295b8eaf8972c2 /tensorflow/contrib/quantization | |
parent | 7d9181d38d9ee3aed2190bb31f09ada0828e1a08 (diff) |
Add C++ shape inference for QuantizeV2 and Dequantize.
Change: 128768449
Diffstat (limited to 'tensorflow/contrib/quantization')
-rw-r--r-- | tensorflow/contrib/quantization/ops/array_ops.cc | 21 |
1 file changed, 20 insertions, 1 deletion
diff --git a/tensorflow/contrib/quantization/ops/array_ops.cc b/tensorflow/contrib/quantization/ops/array_ops.cc index 35d0e7f4c9..e1cf3ded93 100644 --- a/tensorflow/contrib/quantization/ops/array_ops.cc +++ b/tensorflow/contrib/quantization/ops/array_ops.cc @@ -13,11 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" namespace tensorflow { -// -------------------------------------------------------------------------- +using shape_inference::InferenceContext; +using shape_inference::Shape; REGISTER_OP("QuantizeV2") .Input("input: float") @@ -28,6 +31,15 @@ REGISTER_OP("QuantizeV2") .Output("output_max: float") .Attr("T: quantizedtype") .Attr("mode: {'MIN_COMBINED', 'MIN_FIRST'} = 'MIN_COMBINED'") + .SetShapeFn([](InferenceContext* c) { + TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c)); + const Shape* unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); + c->set_output(1, c->Scalar()); + c->set_output(2, c->Scalar()); + return Status::OK(); + }) .Doc(R"doc( Quantize the 'input' tensor of type float to 'output' tensor of type 'T'. @@ -96,6 +108,13 @@ REGISTER_OP("Dequantize") .Output("output: float") .Attr("T: quantizedtype") .Attr("mode: {'MIN_COMBINED', 'MIN_FIRST'} = 'MIN_COMBINED'") + .SetShapeFn([](InferenceContext* c) { + TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c)); + const Shape* unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); + return Status::OK(); + }) .Doc(R"doc( Dequantize the 'input' tensor into a float Tensor. |