aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/framework
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-29 15:40:14 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-29 15:42:56 -0700
commit2bc52cd2d481a89c9724d20e827097efa4ff3f1e (patch)
tree768f6197ff7f7fa453ddf98103809829c67aee32 /tensorflow/contrib/framework
parentaf670bdc0e61802778f61778dd1623c87f30e874 (diff)
- Expose slim arg_scope function to compute keys to enable tessting.
- Add is_training=None option to mobinenet arg_scopes. This allows the users to set is_training from an outer scope. PiperOrigin-RevId: 190997959
Diffstat (limited to 'tensorflow/contrib/framework')
-rw-r--r--tensorflow/contrib/framework/python/ops/arg_scope.py28
1 files changed, 14 insertions, 14 deletions
diff --git a/tensorflow/contrib/framework/python/ops/arg_scope.py b/tensorflow/contrib/framework/python/ops/arg_scope.py
index 3cad1fee19..5b15033995 100644
--- a/tensorflow/contrib/framework/python/ops/arg_scope.py
+++ b/tensorflow/contrib/framework/python/ops/arg_scope.py
@@ -68,7 +68,7 @@ from tensorflow.python.util import tf_decorator
__all__ = [
'arg_scope', 'add_arg_scope', 'current_arg_scope', 'has_arg_scope',
- 'arg_scoped_arguments'
+ 'arg_scoped_arguments', 'arg_scope_func_key'
]
_ARGSTACK = [{}]
@@ -89,7 +89,7 @@ def current_arg_scope():
return stack[-1]
-def _key_op(op):
+def arg_scope_func_key(op):
return getattr(op, '_key_op', str(op))
@@ -103,9 +103,9 @@ def _kwarg_names(func):
def _add_op(op):
- key_op = _key_op(op)
- if key_op not in _DECORATED_OPS:
- _DECORATED_OPS[key_op] = _kwarg_names(op)
+ key = arg_scope_func_key(op)
+ if key not in _DECORATED_OPS:
+ _DECORATED_OPS[key] = _kwarg_names(op)
@tf_contextlib.contextmanager
@@ -147,16 +147,16 @@ def arg_scope(list_ops_or_scope, **kwargs):
try:
current_scope = current_arg_scope().copy()
for op in list_ops_or_scope:
- key_op = _key_op(op)
+ key = arg_scope_func_key(op)
if not has_arg_scope(op):
raise ValueError('%s is not decorated with @add_arg_scope',
_name_op(op))
- if key_op in current_scope:
- current_kwargs = current_scope[key_op].copy()
+ if key in current_scope:
+ current_kwargs = current_scope[key].copy()
current_kwargs.update(kwargs)
- current_scope[key_op] = current_kwargs
+ current_scope[key] = current_kwargs
else:
- current_scope[key_op] = kwargs.copy()
+ current_scope[key] = kwargs.copy()
_get_arg_stack().append(current_scope)
yield current_scope
finally:
@@ -176,14 +176,14 @@ def add_arg_scope(func):
def func_with_args(*args, **kwargs):
current_scope = current_arg_scope()
current_args = kwargs
- key_func = _key_op(func)
+ key_func = arg_scope_func_key(func)
if key_func in current_scope:
current_args = current_scope[key_func].copy()
current_args.update(kwargs)
return func(*args, **current_args)
_add_op(func)
- setattr(func_with_args, '_key_op', _key_op(func))
+ setattr(func_with_args, '_key_op', arg_scope_func_key(func))
return tf_decorator.make_decorator(func, func_with_args)
@@ -196,7 +196,7 @@ def has_arg_scope(func):
Returns:
a boolean.
"""
- return _key_op(func) in _DECORATED_OPS
+ return arg_scope_func_key(func) in _DECORATED_OPS
def arg_scoped_arguments(func):
@@ -209,4 +209,4 @@ def arg_scoped_arguments(func):
a list of kwargs names.
"""
assert has_arg_scope(func)
- return _DECORATED_OPS[_key_op(func)]
+ return _DECORATED_OPS[arg_scope_func_key(func)]