aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/quantize/python/quantize_graph_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/quantize/python/quantize_graph_test.py')
-rw-r--r--tensorflow/contrib/quantize/python/quantize_graph_test.py15
1 files changed, 15 insertions, 0 deletions
diff --git a/tensorflow/contrib/quantize/python/quantize_graph_test.py b/tensorflow/contrib/quantize/python/quantize_graph_test.py
index 54faf582f1..e80d2183a6 100644
--- a/tensorflow/contrib/quantize/python/quantize_graph_test.py
+++ b/tensorflow/contrib/quantize/python/quantize_graph_test.py
@@ -20,10 +20,12 @@ from __future__ import print_function
from tensorflow.contrib.layers.python.layers import layers
from tensorflow.contrib.quantize.python import quantize_graph
+from tensorflow.python import training
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import googletest
@@ -145,6 +147,19 @@ class QuantizeGraphTest(test_util.TensorFlowTestCase):
self.assertTrue(('int64_val: %i' % quant_delay) in const_value)
self.assertTrue(quant_delay_found)
+ def testTrainingOpsCheck(self):
+ self._RunTestOverTrainingRewrites(self._TestTrainingOpsCheck)
+
+ def _TestTrainingOpsCheck(self, rewrite_fn):
+ with ops.Graph().as_default():
+ output = self._ConvLayer()
+ output_scalar = math_ops.reduce_sum(output)
+ loss = math_ops.square(output_scalar - 1)
+ opt = training.gradient_descent.GradientDescentOptimizer(0.0001)
+ opt.minimize(loss)
+ with self.assertRaisesRegexp(ValueError, 'Training op found in graph'):
+ rewrite_fn()
+
def testWeightBits(self):
self._RunTestOverExperimentalRewrites(self._TestWeightBits)