diff options
author | Suharsh Sivakumar <suharshs@google.com> | 2018-05-14 16:17:46 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-14 16:20:31 -0700 |
commit | 5334631d7650d2212926fae661c2d0f8b9e7b358 (patch) | |
tree | 07faf2ac9ba7456661e394686e9059fc92b00dd4 /tensorflow/contrib/quantize | |
parent | 2451eef12c6b6b09dbf6b5b4a19d95272e197409 (diff) |
Make sure that variables aren't created as partition variables since only non-scalar partition variables are supported.
PiperOrigin-RevId: 196584749
Diffstat (limited to 'tensorflow/contrib/quantize')
-rw-r--r-- | tensorflow/contrib/quantize/python/quant_ops.py | 6 | ||||
-rw-r--r-- | tensorflow/contrib/quantize/python/quant_ops_test.py | 32 |
2 files changed, 36 insertions, 2 deletions
diff --git a/tensorflow/contrib/quantize/python/quant_ops.py b/tensorflow/contrib/quantize/python/quant_ops.py index 5c0e17dc86..27069444a4 100644 --- a/tensorflow/contrib/quantize/python/quant_ops.py +++ b/tensorflow/contrib/quantize/python/quant_ops.py @@ -81,7 +81,8 @@ def LastValueQuantize(inputs, a tensor containing quantized values. """ with variable_scope.variable_scope( - None, default_name=name_prefix, values=[inputs], reuse=reuse): + None, default_name=name_prefix, values=[inputs], reuse=reuse) as scope: + scope.set_partitioner(None) input_shape = inputs.get_shape() input_dim = len(input_shape) if per_channel: @@ -189,7 +190,8 @@ def MovingAvgQuantize(inputs, a tensor containing quantized values. """ with variable_scope.variable_scope( - None, default_name=name_prefix, values=[inputs], reuse=reuse): + None, default_name=name_prefix, values=[inputs], reuse=reuse) as scope: + scope.set_partitioner(None) input_shape = inputs.get_shape() input_dim = len(input_shape) if per_channel: diff --git a/tensorflow/contrib/quantize/python/quant_ops_test.py b/tensorflow/contrib/quantize/python/quant_ops_test.py index 3884679602..c2a8def480 100644 --- a/tensorflow/contrib/quantize/python/quant_ops_test.py +++ b/tensorflow/contrib/quantize/python/quant_ops_test.py @@ -23,6 +23,8 @@ from tensorflow.python.client import session from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import partitioned_variables +from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import googletest @@ -73,6 +75,36 @@ class QuantOpsTest(googletest.TestCase): self.assertGreater(max_value, 0.0) self.assertLess(max_value, 1.0) + def testVariablesNotParitioned_LastValue(self): + # Variables added should not use a default partiioner since they are + # scalar. There would be a tensorflow error thrown if the partitioner was + # respected by the rewrite. + with ops.Graph().as_default(): + with variable_scope.variable_scope( + 'part', partitioner=partitioned_variables.fixed_size_partitioner(2)): + x = array_ops.placeholder(dtypes.float32, shape=[2]) + _ = quant_ops.LastValueQuantize( + x, + init_min=0.0, + init_max=0.0, + is_training=True, + vars_collection=_MIN_MAX_VARS) + + def testVariablesNotParitioned_MovingAvg(self): + # Variables added should not use a default partiioner since they are + # scalar. There would be a tensorflow error thrown if the partitioner was + # respected by the rewrite. + with ops.Graph().as_default(): + with variable_scope.variable_scope( + 'part', partitioner=partitioned_variables.fixed_size_partitioner(2)): + x = array_ops.placeholder(dtypes.float32, shape=[2]) + _ = quant_ops.MovingAvgQuantize( + x, + init_min=0.0, + init_max=0.0, + is_training=True, + vars_collection=_MIN_MAX_VARS) + def _GetMinMaxValues(self, sess): min_max_vars = ops.get_collection(_MIN_MAX_VARS) self.assertEqual(len(min_max_vars), 2) |