diff options
author | 2018-03-29 15:40:14 -0700 | |
---|---|---|
committer | 2018-03-29 15:42:56 -0700 | |
commit | 2bc52cd2d481a89c9724d20e827097efa4ff3f1e (patch) | |
tree | 768f6197ff7f7fa453ddf98103809829c67aee32 /tensorflow/contrib/framework | |
parent | af670bdc0e61802778f61778dd1623c87f30e874 (diff) |
- Expose slim arg_scope function to compute keys to enable tessting.
- Add is_training=None option to mobinenet arg_scopes. This allows the users to set is_training from an outer scope.
PiperOrigin-RevId: 190997959
Diffstat (limited to 'tensorflow/contrib/framework')
-rw-r--r-- | tensorflow/contrib/framework/python/ops/arg_scope.py | 28 |
1 files changed, 14 insertions, 14 deletions
diff --git a/tensorflow/contrib/framework/python/ops/arg_scope.py b/tensorflow/contrib/framework/python/ops/arg_scope.py index 3cad1fee19..5b15033995 100644 --- a/tensorflow/contrib/framework/python/ops/arg_scope.py +++ b/tensorflow/contrib/framework/python/ops/arg_scope.py @@ -68,7 +68,7 @@ from tensorflow.python.util import tf_decorator __all__ = [ 'arg_scope', 'add_arg_scope', 'current_arg_scope', 'has_arg_scope', - 'arg_scoped_arguments' + 'arg_scoped_arguments', 'arg_scope_func_key' ] _ARGSTACK = [{}] @@ -89,7 +89,7 @@ def current_arg_scope(): return stack[-1] -def _key_op(op): +def arg_scope_func_key(op): return getattr(op, '_key_op', str(op)) @@ -103,9 +103,9 @@ def _kwarg_names(func): def _add_op(op): - key_op = _key_op(op) - if key_op not in _DECORATED_OPS: - _DECORATED_OPS[key_op] = _kwarg_names(op) + key = arg_scope_func_key(op) + if key not in _DECORATED_OPS: + _DECORATED_OPS[key] = _kwarg_names(op) @tf_contextlib.contextmanager @@ -147,16 +147,16 @@ def arg_scope(list_ops_or_scope, **kwargs): try: current_scope = current_arg_scope().copy() for op in list_ops_or_scope: - key_op = _key_op(op) + key = arg_scope_func_key(op) if not has_arg_scope(op): raise ValueError('%s is not decorated with @add_arg_scope', _name_op(op)) - if key_op in current_scope: - current_kwargs = current_scope[key_op].copy() + if key in current_scope: + current_kwargs = current_scope[key].copy() current_kwargs.update(kwargs) - current_scope[key_op] = current_kwargs + current_scope[key] = current_kwargs else: - current_scope[key_op] = kwargs.copy() + current_scope[key] = kwargs.copy() _get_arg_stack().append(current_scope) yield current_scope finally: @@ -176,14 +176,14 @@ def add_arg_scope(func): def func_with_args(*args, **kwargs): current_scope = current_arg_scope() current_args = kwargs - key_func = _key_op(func) + key_func = arg_scope_func_key(func) if key_func in current_scope: current_args = current_scope[key_func].copy() current_args.update(kwargs) return func(*args, **current_args) _add_op(func) - setattr(func_with_args, '_key_op', _key_op(func)) + setattr(func_with_args, '_key_op', arg_scope_func_key(func)) return tf_decorator.make_decorator(func, func_with_args) @@ -196,7 +196,7 @@ def has_arg_scope(func): Returns: a boolean. """ - return _key_op(func) in _DECORATED_OPS + return arg_scope_func_key(func) in _DECORATED_OPS def arg_scoped_arguments(func): @@ -209,4 +209,4 @@ def arg_scoped_arguments(func): a list of kwargs names. """ assert has_arg_scope(func) - return _DECORATED_OPS[_key_op(func)] + return _DECORATED_OPS[arg_scope_func_key(func)] |