aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/quantization
diff options
context:
space:
mode:
authorGravatar Suharsh Sivakumar <suharshs@google.com>2016-07-28 18:17:30 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-07-28 19:33:24 -0700
commit5ba35496849903c3d458bc8443735a9740b29b2c (patch)
tree74d9ec28da691adc0abe6ecef9295b8eaf8972c2 /tensorflow/contrib/quantization
parent7d9181d38d9ee3aed2190bb31f09ada0828e1a08 (diff)
Add C++ shape inference for quantizev2 and dequantize.
Change: 128768449
Diffstat (limited to 'tensorflow/contrib/quantization')
-rw-r--r--tensorflow/contrib/quantization/ops/array_ops.cc21
1 files changed, 20 insertions, 1 deletions
diff --git a/tensorflow/contrib/quantization/ops/array_ops.cc b/tensorflow/contrib/quantization/ops/array_ops.cc
index 35d0e7f4c9..e1cf3ded93 100644
--- a/tensorflow/contrib/quantization/ops/array_ops.cc
+++ b/tensorflow/contrib/quantization/ops/array_ops.cc
@@ -13,11 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
-// --------------------------------------------------------------------------
+using shape_inference::InferenceContext;
+using shape_inference::Shape;
REGISTER_OP("QuantizeV2")
.Input("input: float")
@@ -28,6 +31,15 @@ REGISTER_OP("QuantizeV2")
.Output("output_max: float")
.Attr("T: quantizedtype")
.Attr("mode: {'MIN_COMBINED', 'MIN_FIRST'} = 'MIN_COMBINED'")
+ .SetShapeFn([](InferenceContext* c) {
+ TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
+ const Shape* unused;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+ c->set_output(1, c->Scalar());
+ c->set_output(2, c->Scalar());
+ return Status::OK();
+ })
.Doc(R"doc(
Quantize the 'input' tensor of type float to 'output' tensor of type 'T'.
@@ -96,6 +108,13 @@ REGISTER_OP("Dequantize")
.Output("output: float")
.Attr("T: quantizedtype")
.Attr("mode: {'MIN_COMBINED', 'MIN_FIRST'} = 'MIN_COMBINED'")
+ .SetShapeFn([](InferenceContext* c) {
+ TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
+ const Shape* unused;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+ return Status::OK();
+ })
.Doc(R"doc(
Dequantize the 'input' tensor into a float Tensor.