aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/quantize
diff options
context:
space:
mode:
authorGravatar Suharsh Sivakumar <suharshs@google.com>2018-05-14 16:17:46 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-14 16:20:31 -0700
commit5334631d7650d2212926fae661c2d0f8b9e7b358 (patch)
tree07faf2ac9ba7456661e394686e9059fc92b00dd4 /tensorflow/contrib/quantize
parent2451eef12c6b6b09dbf6b5b4a19d95272e197409 (diff)
Make sure that variables aren't created as partition variables since only non-scalar partition variables are supported.
PiperOrigin-RevId: 196584749
Diffstat (limited to 'tensorflow/contrib/quantize')
-rw-r--r--tensorflow/contrib/quantize/python/quant_ops.py6
-rw-r--r--tensorflow/contrib/quantize/python/quant_ops_test.py32
2 files changed, 36 insertions, 2 deletions
diff --git a/tensorflow/contrib/quantize/python/quant_ops.py b/tensorflow/contrib/quantize/python/quant_ops.py
index 5c0e17dc86..27069444a4 100644
--- a/tensorflow/contrib/quantize/python/quant_ops.py
+++ b/tensorflow/contrib/quantize/python/quant_ops.py
@@ -81,7 +81,8 @@ def LastValueQuantize(inputs,
a tensor containing quantized values.
"""
with variable_scope.variable_scope(
- None, default_name=name_prefix, values=[inputs], reuse=reuse):
+ None, default_name=name_prefix, values=[inputs], reuse=reuse) as scope:
+ scope.set_partitioner(None)
input_shape = inputs.get_shape()
input_dim = len(input_shape)
if per_channel:
@@ -189,7 +190,8 @@ def MovingAvgQuantize(inputs,
a tensor containing quantized values.
"""
with variable_scope.variable_scope(
- None, default_name=name_prefix, values=[inputs], reuse=reuse):
+ None, default_name=name_prefix, values=[inputs], reuse=reuse) as scope:
+ scope.set_partitioner(None)
input_shape = inputs.get_shape()
input_dim = len(input_shape)
if per_channel:
diff --git a/tensorflow/contrib/quantize/python/quant_ops_test.py b/tensorflow/contrib/quantize/python/quant_ops_test.py
index 3884679602..c2a8def480 100644
--- a/tensorflow/contrib/quantize/python/quant_ops_test.py
+++ b/tensorflow/contrib/quantize/python/quant_ops_test.py
@@ -23,6 +23,8 @@ from tensorflow.python.client import session
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import partitioned_variables
+from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
@@ -73,6 +75,36 @@ class QuantOpsTest(googletest.TestCase):
self.assertGreater(max_value, 0.0)
self.assertLess(max_value, 1.0)
+ def testVariablesNotParitioned_LastValue(self):
+ # Variables added should not use a default partiioner since they are
+ # scalar. There would be a tensorflow error thrown if the partitioner was
+ # respected by the rewrite.
+ with ops.Graph().as_default():
+ with variable_scope.variable_scope(
+ 'part', partitioner=partitioned_variables.fixed_size_partitioner(2)):
+ x = array_ops.placeholder(dtypes.float32, shape=[2])
+ _ = quant_ops.LastValueQuantize(
+ x,
+ init_min=0.0,
+ init_max=0.0,
+ is_training=True,
+ vars_collection=_MIN_MAX_VARS)
+
+ def testVariablesNotParitioned_MovingAvg(self):
+ # Variables added should not use a default partiioner since they are
+ # scalar. There would be a tensorflow error thrown if the partitioner was
+ # respected by the rewrite.
+ with ops.Graph().as_default():
+ with variable_scope.variable_scope(
+ 'part', partitioner=partitioned_variables.fixed_size_partitioner(2)):
+ x = array_ops.placeholder(dtypes.float32, shape=[2])
+ _ = quant_ops.MovingAvgQuantize(
+ x,
+ init_min=0.0,
+ init_max=0.0,
+ is_training=True,
+ vars_collection=_MIN_MAX_VARS)
+
def _GetMinMaxValues(self, sess):
min_max_vars = ops.get_collection(_MIN_MAX_VARS)
self.assertEqual(len(min_max_vars), 2)