diff options
author | Suharsh Sivakumar <suharshs@google.com> | 2018-10-02 11:30:04 -0700
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-02 11:33:56 -0700
commit | feb0dc87078698fd335b528c661c54226a58efa9 (patch)
tree | 39553b0f28c0aec14d7bf5fde1db0ca37d58e481 /tensorflow/contrib/quantize
parent | dd66b78b38b457c7d37527472c4e92a7a07f4b09 (diff)
Remove dependency on contrib model_variable.
Also remove add_arg_scope.
PiperOrigin-RevId: 215426187
Diffstat (limited to 'tensorflow/contrib/quantize')
-rw-r--r-- | tensorflow/contrib/quantize/BUILD | 1
-rw-r--r-- | tensorflow/contrib/quantize/python/quant_ops.py | 28
2 files changed, 19 insertions, 10 deletions
diff --git a/tensorflow/contrib/quantize/BUILD b/tensorflow/contrib/quantize/BUILD index 23e3a25d71..94a2d9672d 100644 --- a/tensorflow/contrib/quantize/BUILD +++ b/tensorflow/contrib/quantize/BUILD @@ -138,7 +138,6 @@ py_library( srcs = ["python/quant_ops.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/contrib/framework:framework_py", "//tensorflow/python:array_ops", "//tensorflow/python:framework_ops", "//tensorflow/python:init_ops", diff --git a/tensorflow/contrib/quantize/python/quant_ops.py b/tensorflow/contrib/quantize/python/quant_ops.py index 27069444a4..d9dc7fa62e 100644 --- a/tensorflow/contrib/quantize/python/quant_ops.py +++ b/tensorflow/contrib/quantize/python/quant_ops.py @@ -18,8 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.framework.python.ops import add_arg_scope -from tensorflow.contrib.framework.python.ops import model_variable from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops @@ -29,7 +27,6 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.training import moving_averages -@add_arg_scope def FixedQuantize(inputs, init_min=-6.0, init_max=6.0, scope=None): """Adds a fake quantize layer with fixed quantization interval. 
@@ -46,7 +43,21 @@ def FixedQuantize(inputs, init_min=-6.0, init_max=6.0, scope=None): inputs, min=init_min, max=init_max) -@add_arg_scope +def _ModelVariable(name, + shape=None, + initializer=None, + collections=None, + trainable=None): + collections = list(collections or []) + collections += [ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.MODEL_VARIABLES] + return variable_scope.get_variable( + name, + shape=shape, + initializer=initializer, + collections=collections, + trainable=trainable) + + def LastValueQuantize(inputs, per_channel=False, init_min=-6.0, @@ -93,13 +104,13 @@ def LastValueQuantize(inputs, else: min_max_shape = [] - min_var = model_variable( + min_var = _ModelVariable( 'min', shape=min_max_shape, initializer=init_ops.constant_initializer(init_min), collections=[vars_collection], trainable=False) - max_var = model_variable( + max_var = _ModelVariable( 'max', shape=min_max_shape, initializer=init_ops.constant_initializer(init_max), @@ -153,7 +164,6 @@ def LastValueQuantize(inputs, narrow_range=narrow_range) -@add_arg_scope def MovingAvgQuantize(inputs, per_channel=False, init_min=-6.0, @@ -202,13 +212,13 @@ def MovingAvgQuantize(inputs, else: min_max_shape = [] - min_var = model_variable( + min_var = _ModelVariable( 'min', shape=min_max_shape, initializer=init_ops.constant_initializer(init_min), collections=[vars_collection], trainable=False) - max_var = model_variable( + max_var = _ModelVariable( 'max', shape=min_max_shape, initializer=init_ops.constant_initializer(init_max), |