aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/quantize
diff options
context:
space:
mode:
authorGravatar Suharsh Sivakumar <suharshs@google.com>2018-10-02 11:30:04 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-02 11:33:56 -0700
commitfeb0dc87078698fd335b528c661c54226a58efa9 (patch)
tree39553b0f28c0aec14d7bf5fde1db0ca37d58e481 /tensorflow/contrib/quantize
parentdd66b78b38b457c7d37527472c4e92a7a07f4b09 (diff)
Remove dependency on contrib model_variable.
Also remove add_arg_scope. PiperOrigin-RevId: 215426187
Diffstat (limited to 'tensorflow/contrib/quantize')
-rw-r--r--tensorflow/contrib/quantize/BUILD1
-rw-r--r--tensorflow/contrib/quantize/python/quant_ops.py28
2 files changed, 19 insertions, 10 deletions
diff --git a/tensorflow/contrib/quantize/BUILD b/tensorflow/contrib/quantize/BUILD
index 23e3a25d71..94a2d9672d 100644
--- a/tensorflow/contrib/quantize/BUILD
+++ b/tensorflow/contrib/quantize/BUILD
@@ -138,7 +138,6 @@ py_library(
srcs = ["python/quant_ops.py"],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/contrib/framework:framework_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_ops",
"//tensorflow/python:init_ops",
diff --git a/tensorflow/contrib/quantize/python/quant_ops.py b/tensorflow/contrib/quantize/python/quant_ops.py
index 27069444a4..d9dc7fa62e 100644
--- a/tensorflow/contrib/quantize/python/quant_ops.py
+++ b/tensorflow/contrib/quantize/python/quant_ops.py
@@ -18,8 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.framework.python.ops import add_arg_scope
-from tensorflow.contrib.framework.python.ops import model_variable
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
@@ -29,7 +27,6 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.training import moving_averages
-@add_arg_scope
def FixedQuantize(inputs, init_min=-6.0, init_max=6.0, scope=None):
"""Adds a fake quantize layer with fixed quantization interval.
@@ -46,7 +43,21 @@ def FixedQuantize(inputs, init_min=-6.0, init_max=6.0, scope=None):
inputs, min=init_min, max=init_max)
-@add_arg_scope
+def _ModelVariable(name,
+ shape=None,
+ initializer=None,
+ collections=None,
+ trainable=None):
+ collections = list(collections or [])
+ collections += [ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.MODEL_VARIABLES]
+ return variable_scope.get_variable(
+ name,
+ shape=shape,
+ initializer=initializer,
+ collections=collections,
+ trainable=trainable)
+
+
def LastValueQuantize(inputs,
per_channel=False,
init_min=-6.0,
@@ -93,13 +104,13 @@ def LastValueQuantize(inputs,
else:
min_max_shape = []
- min_var = model_variable(
+ min_var = _ModelVariable(
'min',
shape=min_max_shape,
initializer=init_ops.constant_initializer(init_min),
collections=[vars_collection],
trainable=False)
- max_var = model_variable(
+ max_var = _ModelVariable(
'max',
shape=min_max_shape,
initializer=init_ops.constant_initializer(init_max),
@@ -153,7 +164,6 @@ def LastValueQuantize(inputs,
narrow_range=narrow_range)
-@add_arg_scope
def MovingAvgQuantize(inputs,
per_channel=False,
init_min=-6.0,
@@ -202,13 +212,13 @@ def MovingAvgQuantize(inputs,
else:
min_max_shape = []
- min_var = model_variable(
+ min_var = _ModelVariable(
'min',
shape=min_max_shape,
initializer=init_ops.constant_initializer(init_min),
collections=[vars_collection],
trainable=False)
- max_var = model_variable(
+ max_var = _ModelVariable(
'max',
shape=min_max_shape,
initializer=init_ops.constant_initializer(init_max),