aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distributions
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-16 08:45:18 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-16 08:49:10 -0700
commitd27953bb69ba44431b85fdf7ac43ed83c4422e40 (patch)
tree31b77c8e764361610994537d8853649f649257c4 /tensorflow/contrib/distributions
parente3aa44ec207bfdc798e26e92a38e80c3f9c5453b (diff)
Fix bug in masked_autoregressive_default_template where custom name was not creating custom variable scopes.
PiperOrigin-RevId: 204747987
Diffstat (limited to 'tensorflow/contrib/distributions')
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py8
1 files changed, 3 insertions, 5 deletions
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py
index b8f2a4b2c7..296e66f2b2 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py
@@ -514,9 +514,8 @@ def masked_autoregressive_default_template(
Masked Autoencoder for Distribution Estimation. In _International
Conference on Machine Learning_, 2015. https://arxiv.org/abs/1502.03509
"""
-
- with ops.name_scope(name, "masked_autoregressive_default_template",
- values=[log_scale_min_clip, log_scale_max_clip]):
+ name = name or "masked_autoregressive_default_template"
+ with ops.name_scope(name, values=[log_scale_min_clip, log_scale_max_clip]):
def _fn(x):
"""MADE parameterized via `masked_autoregressive_default_template`."""
# TODO(b/67594795): Better support of dynamic shape.
@@ -552,8 +551,7 @@ def masked_autoregressive_default_template(
else _clip_by_value_preserve_grad)
log_scale = which_clip(log_scale, log_scale_min_clip, log_scale_max_clip)
return shift, log_scale
- return template_ops.make_template(
- "masked_autoregressive_default_template", _fn)
+ return template_ops.make_template(name, _fn)
@deprecation.deprecated(