aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/quantization
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/quantization')
-rw-r--r--tensorflow/contrib/quantization/python/math_ops.py12
-rw-r--r--tensorflow/contrib/quantization/python/nn_ops.py56
2 files changed, 5 insertions, 63 deletions
diff --git a/tensorflow/contrib/quantization/python/math_ops.py b/tensorflow/contrib/quantization/python/math_ops.py
index 43c1409358..d4fabbd36b 100644
--- a/tensorflow/contrib/quantization/python/math_ops.py
+++ b/tensorflow/contrib/quantization/python/math_ops.py
@@ -23,16 +23,6 @@ from tensorflow.contrib.quantization.ops import gen_math_ops
from tensorflow.contrib.quantization.ops.gen_math_ops import *
from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
-# QuantizedMatMul* ops.
-@ops.RegisterShape("QuantizedMatMul")
-def _QuantizedMatMulShape(op):
- unused_a_min = op.inputs[2].get_shape().merge_with(tensor_shape.scalar())
- unused_a_max = op.inputs[3].get_shape().merge_with(tensor_shape.scalar())
- unused_b_min = op.inputs[4].get_shape().merge_with(tensor_shape.scalar())
- unused_b_max = op.inputs[5].get_shape().merge_with(tensor_shape.scalar())
- result = common_shapes.matmul_shape(op)
- result.extend([tensor_shape.scalar(), tensor_shape.scalar()])
- return result
+ops.RegisterShape("QuantizedMatMul")(common_shapes.call_cpp_shape_fn)
diff --git a/tensorflow/contrib/quantization/python/nn_ops.py b/tensorflow/contrib/quantization/python/nn_ops.py
index 122d93fd23..d31f1d4e68 100644
--- a/tensorflow/contrib/quantization/python/nn_ops.py
+++ b/tensorflow/contrib/quantization/python/nn_ops.py
@@ -23,60 +23,12 @@ from tensorflow.contrib.quantization.ops import gen_nn_ops
from tensorflow.contrib.quantization.ops.gen_nn_ops import *
from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
-
-
-# QuantizedAvgPool* ops.
-@ops.RegisterShape("QuantizedAvgPool")
-def _QuantizedAvgPoolShape(op):
- return [common_shapes.avg_pool_shape(op)[0], tensor_shape.scalar(),
- tensor_shape.scalar()]
-
-
-# QuantizedBiasAdd op.
-@ops.RegisterShape("QuantizedBiasAdd")
-def _QuantizedBiasAddShape(op):
- """Returns the same shape as the input, plus min and max scalar values.
-
- Args:
- op: Input operation.
- Returns:
- Shape of ops first input, plus min and max tensors.
- """
- unused_input_min = op.inputs[2].get_shape().merge_with(tensor_shape.scalar())
- unused_input_max = op.inputs[3].get_shape().merge_with(tensor_shape.scalar())
- unused_bias_min = op.inputs[4].get_shape().merge_with(tensor_shape.scalar())
- unused_bias_max = op.inputs[5].get_shape().merge_with(tensor_shape.scalar())
- return [op.inputs[0].get_shape(), tensor_shape.scalar(),
- tensor_shape.scalar()]
-
-
-# QuantizedConv2D* ops.
-@ops.RegisterShape("QuantizedConv2D")
-def _QuantizedConv2DShape(op):
- """Returns the same shape as Conv2D, plus min and max scalar values.
-
- Args:
- op: Input operation.
- Returns:
- Shape of float Conv2D, plus min and max tensors.
- """
- unused_input_min = op.inputs[2].get_shape().merge_with(tensor_shape.scalar())
- unused_input_max = op.inputs[3].get_shape().merge_with(tensor_shape.scalar())
- unused_filter_min = op.inputs[4].get_shape().merge_with(tensor_shape.scalar())
- unused_filter_max = op.inputs[5].get_shape().merge_with(tensor_shape.scalar())
- result = common_shapes.conv2d_shape(op)
- result.extend([tensor_shape.scalar(), tensor_shape.scalar()])
- return result
-
-
-# QuantizedMaxPool* ops.
-@ops.RegisterShape("QuantizedMaxPool")
-def _QuantizedMaxPoolShape(op):
- return [common_shapes.max_pool_shape(op)[0], tensor_shape.scalar(),
- tensor_shape.scalar()]
+ops.RegisterShape("QuantizedAvgPool")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("QuantizedBiasAdd")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("QuantizedConv2D")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("QuantizedMaxPool")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("QuantizedRelu")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("QuantizedRelu6")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("QuantizedReluX")(common_shapes.call_cpp_shape_fn)