diff options
Diffstat (limited to 'tensorflow/contrib/quantize/python/quant_ops_test.py')
-rw-r--r-- | tensorflow/contrib/quantize/python/quant_ops_test.py | 32 |
1 files changed, 32 insertions, 0 deletions
diff --git a/tensorflow/contrib/quantize/python/quant_ops_test.py b/tensorflow/contrib/quantize/python/quant_ops_test.py index 3884679602..c2a8def480 100644 --- a/tensorflow/contrib/quantize/python/quant_ops_test.py +++ b/tensorflow/contrib/quantize/python/quant_ops_test.py @@ -23,6 +23,8 @@ from tensorflow.python.client import session from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import partitioned_variables +from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import googletest @@ -73,6 +75,36 @@ class QuantOpsTest(googletest.TestCase): self.assertGreater(max_value, 0.0) self.assertLess(max_value, 1.0) + def testVariablesNotParitioned_LastValue(self): + # Variables added should not use a default partiioner since they are + # scalar. There would be a tensorflow error thrown if the partitioner was + # respected by the rewrite. + with ops.Graph().as_default(): + with variable_scope.variable_scope( + 'part', partitioner=partitioned_variables.fixed_size_partitioner(2)): + x = array_ops.placeholder(dtypes.float32, shape=[2]) + _ = quant_ops.LastValueQuantize( + x, + init_min=0.0, + init_max=0.0, + is_training=True, + vars_collection=_MIN_MAX_VARS) + + def testVariablesNotParitioned_MovingAvg(self): + # Variables added should not use a default partiioner since they are + # scalar. There would be a tensorflow error thrown if the partitioner was + # respected by the rewrite. + with ops.Graph().as_default(): + with variable_scope.variable_scope( + 'part', partitioner=partitioned_variables.fixed_size_partitioner(2)): + x = array_ops.placeholder(dtypes.float32, shape=[2]) + _ = quant_ops.MovingAvgQuantize( + x, + init_min=0.0, + init_max=0.0, + is_training=True, + vars_collection=_MIN_MAX_VARS) + def _GetMinMaxValues(self, sess): min_max_vars = ops.get_collection(_MIN_MAX_VARS) self.assertEqual(len(min_max_vars), 2) |