Diffstat (limited to 'tensorflow/contrib/quantize/python/quant_ops_test.py')
-rw-r--r--  tensorflow/contrib/quantize/python/quant_ops_test.py | 32
1 file changed, 32 insertions, 0 deletions
diff --git a/tensorflow/contrib/quantize/python/quant_ops_test.py b/tensorflow/contrib/quantize/python/quant_ops_test.py
index 3884679602..c2a8def480 100644
--- a/tensorflow/contrib/quantize/python/quant_ops_test.py
+++ b/tensorflow/contrib/quantize/python/quant_ops_test.py
@@ -23,6 +23,8 @@ from tensorflow.python.client import session
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import partitioned_variables
+from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
@@ -73,6 +75,36 @@ class QuantOpsTest(googletest.TestCase):
    self.assertGreater(max_value, 0.0)
    self.assertLess(max_value, 1.0)
+  def testVariablesNotPartitioned_LastValue(self):
+    # Variables added should not use a default partitioner since they are
+    # scalar. TensorFlow would throw an error if the partitioner were
+    # respected by the rewrite.
+    with ops.Graph().as_default():
+      with variable_scope.variable_scope(
+          'part', partitioner=partitioned_variables.fixed_size_partitioner(2)):
+        x = array_ops.placeholder(dtypes.float32, shape=[2])
+        _ = quant_ops.LastValueQuantize(
+            x,
+            init_min=0.0,
+            init_max=0.0,
+            is_training=True,
+            vars_collection=_MIN_MAX_VARS)
+
+  def testVariablesNotPartitioned_MovingAvg(self):
+    # Variables added should not use a default partitioner since they are
+    # scalar. TensorFlow would throw an error if the partitioner were
+    # respected by the rewrite.
+    with ops.Graph().as_default():
+      with variable_scope.variable_scope(
+          'part', partitioner=partitioned_variables.fixed_size_partitioner(2)):
+        x = array_ops.placeholder(dtypes.float32, shape=[2])
+        _ = quant_ops.MovingAvgQuantize(
+            x,
+            init_min=0.0,
+            init_max=0.0,
+            is_training=True,
+            vars_collection=_MIN_MAX_VARS)
+
  def _GetMinMaxValues(self, sess):
    min_max_vars = ops.get_collection(_MIN_MAX_VARS)
    self.assertEqual(len(min_max_vars), 2)
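
Note on the tests added above: a fixed-size partitioner splits variables along axis 0, so vector-shaped variables created inside the 'part' scope become multi-shard PartitionedVariables, while the scalar min/max variables created by LastValueQuantize and MovingAvgQuantize have no axis to split and must bypass the inherited partitioner (otherwise, per the test comments, TensorFlow raises an error). The sketch below is illustrative and not part of the diff: the scope name 'demo' and variable name 'v' are made up, and it only reuses framework modules the test file already imports.

# Illustrative sketch (not part of quant_ops_test.py): what the scope's
# fixed_size_partitioner(2) does to a variable that has an axis to split.
from tensorflow.python.framework import ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import variable_scope

with ops.Graph().as_default():
  with variable_scope.variable_scope(
      'demo', partitioner=partitioned_variables.fixed_size_partitioner(2)):
    # A length-4 vector is split along axis 0 into two [2]-shaped shards.
    v = variable_scope.get_variable(
        'v', shape=[4], initializer=init_ops.zeros_initializer())
    print([shard.get_shape().as_list() for shard in v])  # [[2], [2]]

A scalar has no axis 0 to shard, which is why the scalar min/max variables added by the quantization rewrite must opt out of the scope's partitioner rather than inherit it.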