path: root/tensorflow/contrib/layers
author     A. Unique TensorFlower <gardener@tensorflow.org>  2018-08-20 18:49:04 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2018-08-20 18:54:59 -0700
commit     a2a561e7e8305734e75492be961b28197c07c261 (patch)
tree       20ee753a8156d974de3e667efd81bdcfb27678e5 /tensorflow/contrib/layers
parent     d29759fa5370c7fe4ba5f12ac26cba08d5bb3c4f (diff)
Ensure that functools.wraps is not called on functools.partial objects in rev_block.
PiperOrigin-RevId: 209524010
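For background, a minimal sketch in plain Python (the scale/double names are hypothetical, not from the TensorFlow sources) of why functools.wraps and functools.partial do not mix:

import functools

def scale(x, factor):
  """Multiply x by factor."""
  return x * factor

double = functools.partial(scale, factor=2)

# functools.wraps copies metadata such as __name__ from the wrapped callable
# onto the wrapper.  A functools.partial object carries no __name__, so there
# is nothing to copy.
print(hasattr(double, '__name__'))  # False

# On Python 2, which TensorFlow still supported at the time, applying
# functools.wraps(double) to a function raised AttributeError; Python 3's
# update_wrapper silently skips missing attributes instead.
try:
  @functools.wraps(double)
  def call_double(*args, **kwargs):
    return double(*args, **kwargs)
except AttributeError as e:
  print('functools.wraps failed:', e)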
Diffstat (limited to 'tensorflow/contrib/layers')
-rw-r--r--  tensorflow/contrib/layers/python/layers/rev_block_lib.py  | 16
1 file changed, 13 insertions(+), 3 deletions(-)
diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib.py b/tensorflow/contrib/layers/python/layers/rev_block_lib.py
index dad3da3748..b25f11b5a6 100644
--- a/tensorflow/contrib/layers/python/layers/rev_block_lib.py
+++ b/tensorflow/contrib/layers/python/layers/rev_block_lib.py
@@ -151,9 +151,19 @@ def _rev_block_forward(x1,
   return y1, y2
 
 
+def _safe_wraps(fn):
+  if isinstance(fn, functools.partial):
+    # functools.partial objects cannot be wrapped as they are missing the
+    # necessary properties (__name__, __module__, __doc__).
+    def passthrough(f):
+      return f
+    return passthrough
+  return functools.wraps(fn)
+
+
 def _scope_wrap(fn, scope):
 
-  @functools.wraps(fn)
+  @_safe_wraps(fn)
   def wrap(*args, **kwargs):
     with variable_scope.variable_scope(scope, use_resource=True):
       return fn(*args, **kwargs)
@@ -430,7 +440,7 @@ def rev_block(x1,
 def enable_with_args(dec):
   """A decorator for decorators to enable their usage with or without args."""
 
-  @functools.wraps(dec)
+  @_safe_wraps(dec)
   def new_dec(*args, **kwargs):
     if len(args) == 1 and not kwargs and callable(args[0]):
       # Used as decorator without args
@@ -477,7 +487,7 @@ def recompute_grad(fn, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
     tf.gradients).
   """
 
-  @functools.wraps(fn)
+  @_safe_wraps(fn)
   def wrapped(*args):
     return _recompute_grad(
         fn, args, use_data_dep=use_data_dep, tupleize_grads=tupleize_grads)
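As a usage illustration, a small self-contained sketch (log_calls and add are hypothetical stand-ins for decorators like _scope_wrap or recompute_grad, not TensorFlow code) of how the _safe_wraps pattern keeps a decorator usable for both plain functions and partials:

import functools

def _safe_wraps(fn):
  # Same idea as the helper added in this change: fall back to an identity
  # decorator when fn is a functools.partial and cannot supply __name__.
  if isinstance(fn, functools.partial):
    def passthrough(f):
      return f
    return passthrough
  return functools.wraps(fn)

def log_calls(fn):
  # Hypothetical decorator in the style of _scope_wrap / recompute_grad.
  @_safe_wraps(fn)
  def wrapper(*args, **kwargs):
    print('calling %s' % getattr(fn, '__name__', repr(fn)))
    return fn(*args, **kwargs)
  return wrapper

def add(a, b):
  return a + b

logged_add = log_calls(add)
print(logged_add(1, 2))   # 3; metadata was copied, logged_add.__name__ == 'add'

# Before this change, passing a partial here raised AttributeError inside
# functools.wraps on Python 2; with the fallback it is wrapped without
# copying metadata.
logged_inc = log_calls(functools.partial(add, b=1))
print(logged_inc(41))     # 42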