diff options
Diffstat (limited to 'tensorflow/contrib/quantize/python/quantize_graph_test.py')
-rw-r--r-- | tensorflow/contrib/quantize/python/quantize_graph_test.py | 15 |
1 files changed, 15 insertions, 0 deletions
diff --git a/tensorflow/contrib/quantize/python/quantize_graph_test.py b/tensorflow/contrib/quantize/python/quantize_graph_test.py index 54faf582f1..e80d2183a6 100644 --- a/tensorflow/contrib/quantize/python/quantize_graph_test.py +++ b/tensorflow/contrib/quantize/python/quantize_graph_test.py @@ -20,10 +20,12 @@ from __future__ import print_function from tensorflow.contrib.layers.python.layers import layers from tensorflow.contrib.quantize.python import quantize_graph +from tensorflow.python import training from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.platform import googletest @@ -145,6 +147,19 @@ class QuantizeGraphTest(test_util.TensorFlowTestCase): self.assertTrue(('int64_val: %i' % quant_delay) in const_value) self.assertTrue(quant_delay_found) + def testTrainingOpsCheck(self): + self._RunTestOverTrainingRewrites(self._TestTrainingOpsCheck) + + def _TestTrainingOpsCheck(self, rewrite_fn): + with ops.Graph().as_default(): + output = self._ConvLayer() + output_scalar = math_ops.reduce_sum(output) + loss = math_ops.square(output_scalar - 1) + opt = training.gradient_descent.GradientDescentOptimizer(0.0001) + opt.minimize(loss) + with self.assertRaisesRegexp(ValueError, 'Training op found in graph'): + rewrite_fn() + def testWeightBits(self): self._RunTestOverExperimentalRewrites(self._TestWeightBits) |