aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-05-10 11:57:20 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-10 12:00:17 -0700
commit3ffa132c03ff02decc86a31d8bf888e9381278a7 (patch)
tree136099708056150061c0fce5bdaae676177b1daa /tensorflow
parent71b88284d9834f83a5d73feda3cf67944b878362 (diff)
Use distribution_util.arguments instead of locals. This fixes a bug in newer python version
where locals is a dynamic list. PiperOrigin-RevId: 196150149
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/contrib/distributions/python/ops/autoregressive.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/batch_reshape.py3
-rw-r--r--tensorflow/contrib/distributions/python/ops/binomial.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/cauchy.py3
-rw-r--r--tensorflow/contrib/distributions/python/ops/chi2.py5
-rw-r--r--tensorflow/contrib/distributions/python/ops/deterministic.py3
-rw-r--r--tensorflow/contrib/distributions/python/ops/geometric.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/gumbel.py3
-rw-r--r--tensorflow/contrib/distributions/python/ops/half_normal.py3
-rw-r--r--tensorflow/contrib/distributions/python/ops/independent.py3
-rw-r--r--tensorflow/contrib/distributions/python/ops/inverse_gamma.py4
-rw-r--r--tensorflow/contrib/distributions/python/ops/logistic.py3
-rw-r--r--tensorflow/contrib/distributions/python/ops/mixture.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/mixture_same_family.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/mvn_diag.py4
-rw-r--r--tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py3
-rw-r--r--tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/mvn_tril.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/negative_binomial.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/onehot_categorical.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/poisson.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/poisson_lognormal.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/quantized_distribution.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/vector_student_t.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/wishart.py6
-rw-r--r--tensorflow/python/kernel_tests/distributions/BUILD1
-rw-r--r--tensorflow/python/kernel_tests/distributions/util_test.py56
-rw-r--r--tensorflow/python/ops/distributions/bernoulli.py2
-rw-r--r--tensorflow/python/ops/distributions/beta.py4
-rw-r--r--tensorflow/python/ops/distributions/categorical.py2
-rw-r--r--tensorflow/python/ops/distributions/dirichlet.py2
-rw-r--r--tensorflow/python/ops/distributions/dirichlet_multinomial.py2
-rw-r--r--tensorflow/python/ops/distributions/distribution.py3
-rw-r--r--tensorflow/python/ops/distributions/exponential.py5
-rw-r--r--tensorflow/python/ops/distributions/gamma.py4
-rw-r--r--tensorflow/python/ops/distributions/laplace.py5
-rw-r--r--tensorflow/python/ops/distributions/multinomial.py2
-rw-r--r--tensorflow/python/ops/distributions/normal.py5
-rw-r--r--tensorflow/python/ops/distributions/student_t.py4
-rw-r--r--tensorflow/python/ops/distributions/transformed_distribution.py2
-rw-r--r--tensorflow/python/ops/distributions/uniform.py3
-rw-r--r--tensorflow/python/ops/distributions/util.py38
52 files changed, 169 insertions, 60 deletions
diff --git a/tensorflow/contrib/distributions/python/ops/autoregressive.py b/tensorflow/contrib/distributions/python/ops/autoregressive.py
index 88ed012784..d813831bef 100644
--- a/tensorflow/contrib/distributions/python/ops/autoregressive.py
+++ b/tensorflow/contrib/distributions/python/ops/autoregressive.py
@@ -144,7 +144,7 @@ class Autoregressive(distribution_lib.Distribution):
`distribution_fn(sample0).event_shape.num_elements()` are both `None`.
ValueError: if `num_steps < 1`.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name) as name:
self._distribution_fn = distribution_fn
self._sample0 = sample0
diff --git a/tensorflow/contrib/distributions/python/ops/batch_reshape.py b/tensorflow/contrib/distributions/python/ops/batch_reshape.py
index bf5590cd55..8a4041cf43 100644
--- a/tensorflow/contrib/distributions/python/ops/batch_reshape.py
+++ b/tensorflow/contrib/distributions/python/ops/batch_reshape.py
@@ -28,6 +28,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import distribution as distribution_lib
+from tensorflow.python.ops.distributions import util as distribution_util
__all__ = [
@@ -104,7 +105,7 @@ class BatchReshape(distribution_lib.Distribution):
ValueError: if `batch_shape` size is not the same as a
`distribution.batch_shape` size.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
name = name or "BatchReshape" + distribution.name
self._distribution = distribution
with ops.name_scope(name, values=[batch_shape]) as name:
diff --git a/tensorflow/contrib/distributions/python/ops/binomial.py b/tensorflow/contrib/distributions/python/ops/binomial.py
index 12d1603178..24b26bf124 100644
--- a/tensorflow/contrib/distributions/python/ops/binomial.py
+++ b/tensorflow/contrib/distributions/python/ops/binomial.py
@@ -163,7 +163,7 @@ class Binomial(distribution.Distribution):
more of the statistic's batch members are undefined.
name: Python `str` name prefixed to Ops created by this class.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[total_count, logits, probs]) as name:
self._total_count = self._maybe_assert_valid_total_count(
ops.convert_to_tensor(total_count, name="total_count"),
diff --git a/tensorflow/contrib/distributions/python/ops/cauchy.py b/tensorflow/contrib/distributions/python/ops/cauchy.py
index daacfe657f..f5ffdd8731 100644
--- a/tensorflow/contrib/distributions/python/ops/cauchy.py
+++ b/tensorflow/contrib/distributions/python/ops/cauchy.py
@@ -29,6 +29,7 @@ from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
+from tensorflow.python.ops.distributions import util as distribution_util
__all__ = [
"Cauchy",
@@ -120,7 +121,7 @@ class Cauchy(distribution.Distribution):
Raises:
TypeError: if `loc` and `scale` have different `dtype`.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[loc, scale]) as name:
with ops.control_dependencies([check_ops.assert_positive(scale)]
if validate_args else []):
diff --git a/tensorflow/contrib/distributions/python/ops/chi2.py b/tensorflow/contrib/distributions/python/ops/chi2.py
index c77c5fd208..08cdc15828 100644
--- a/tensorflow/contrib/distributions/python/ops/chi2.py
+++ b/tensorflow/contrib/distributions/python/ops/chi2.py
@@ -25,6 +25,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import gamma
+from tensorflow.python.ops.distributions import util as distribution_util
__all__ = [
@@ -83,7 +84,7 @@ class Chi2(gamma.Gamma):
more of the statistic's batch members are undefined.
name: Python `str` name prefixed to Ops created by this class.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
# Even though all stats of chi2 are defined for valid parameters, this is
# not true in the parent class "gamma." therefore, passing
# allow_nan_stats=True
@@ -119,7 +120,7 @@ class Chi2WithAbsDf(Chi2):
validate_args=False,
allow_nan_stats=True,
name="Chi2WithAbsDf"):
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[df]) as name:
super(Chi2WithAbsDf, self).__init__(
df=math_ops.floor(
diff --git a/tensorflow/contrib/distributions/python/ops/deterministic.py b/tensorflow/contrib/distributions/python/ops/deterministic.py
index a42350430e..6d7d6d307b 100644
--- a/tensorflow/contrib/distributions/python/ops/deterministic.py
+++ b/tensorflow/contrib/distributions/python/ops/deterministic.py
@@ -32,6 +32,7 @@ from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import distribution
+from tensorflow.python.ops.distributions import util as distribution_util
__all__ = [
"Deterministic",
@@ -86,7 +87,7 @@ class _BaseDeterministic(distribution.Distribution):
Raises:
ValueError: If `loc` is a scalar.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[loc, atol, rtol]) as name:
loc = ops.convert_to_tensor(loc, name="loc")
if is_vector and validate_args:
diff --git a/tensorflow/contrib/distributions/python/ops/geometric.py b/tensorflow/contrib/distributions/python/ops/geometric.py
index 53dd42f4c8..446cff6ec2 100644
--- a/tensorflow/contrib/distributions/python/ops/geometric.py
+++ b/tensorflow/contrib/distributions/python/ops/geometric.py
@@ -85,7 +85,7 @@ class Geometric(distribution.Distribution):
name: Python `str` name prefixed to Ops created by this class.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[logits, probs]) as name:
self._logits, self._probs = distribution_util.get_logits_and_probs(
logits, probs, validate_args=validate_args, name=name)
diff --git a/tensorflow/contrib/distributions/python/ops/gumbel.py b/tensorflow/contrib/distributions/python/ops/gumbel.py
index 2c261073ee..ed9ea6f4f3 100644
--- a/tensorflow/contrib/distributions/python/ops/gumbel.py
+++ b/tensorflow/contrib/distributions/python/ops/gumbel.py
@@ -29,6 +29,7 @@ from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
+from tensorflow.python.ops.distributions import util as distribution_util
class _Gumbel(distribution.Distribution):
@@ -124,7 +125,7 @@ class _Gumbel(distribution.Distribution):
Raises:
TypeError: if loc and scale are different dtypes.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[loc, scale]) as name:
with ops.control_dependencies([check_ops.assert_positive(scale)] if
validate_args else []):
diff --git a/tensorflow/contrib/distributions/python/ops/half_normal.py b/tensorflow/contrib/distributions/python/ops/half_normal.py
index d0df2befd6..7e12767f6d 100644
--- a/tensorflow/contrib/distributions/python/ops/half_normal.py
+++ b/tensorflow/contrib/distributions/python/ops/half_normal.py
@@ -31,6 +31,7 @@ from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import special_math
+from tensorflow.python.ops.distributions import util as distribution_util
__all__ = [
@@ -105,7 +106,7 @@ class HalfNormal(distribution.Distribution):
if one or more of the statistic's batch members are undefined.
name: Python `str` name prefixed to Ops created by this class.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[scale]) as name:
with ops.control_dependencies([check_ops.assert_positive(scale)] if
validate_args else []):
diff --git a/tensorflow/contrib/distributions/python/ops/independent.py b/tensorflow/contrib/distributions/python/ops/independent.py
index fbde55ef31..fa89fff3b7 100644
--- a/tensorflow/contrib/distributions/python/ops/independent.py
+++ b/tensorflow/contrib/distributions/python/ops/independent.py
@@ -29,6 +29,7 @@ from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import distribution as distribution_lib
from tensorflow.python.ops.distributions import kullback_leibler
+from tensorflow.python.ops.distributions import util as distribution_util
class Independent(distribution_lib.Distribution):
@@ -116,7 +117,7 @@ class Independent(distribution_lib.Distribution):
ValueError: if `reinterpreted_batch_ndims` exceeds
`distribution.batch_ndims`
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
name = name or "Independent" + distribution.name
self._distribution = distribution
with ops.name_scope(name) as name:
diff --git a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py
index 502bd4f493..85e8e10466 100644
--- a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py
+++ b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py
@@ -125,7 +125,7 @@ class InverseGamma(distribution.Distribution):
Raises:
TypeError: if `concentration` and `rate` are different dtypes.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[concentration, rate]) as name:
with ops.control_dependencies([
check_ops.assert_positive(concentration),
@@ -280,7 +280,7 @@ class InverseGammaWithSoftplusConcentrationRate(InverseGamma):
validate_args=False,
allow_nan_stats=True,
name="InverseGammaWithSoftplusConcentrationRate"):
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[concentration, rate]) as name:
super(InverseGammaWithSoftplusConcentrationRate, self).__init__(
concentration=nn.softplus(concentration,
diff --git a/tensorflow/contrib/distributions/python/ops/logistic.py b/tensorflow/contrib/distributions/python/ops/logistic.py
index c83b5bc2e3..0103283259 100644
--- a/tensorflow/contrib/distributions/python/ops/logistic.py
+++ b/tensorflow/contrib/distributions/python/ops/logistic.py
@@ -31,6 +31,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
+from tensorflow.python.ops.distributions import util as distribution_util
class Logistic(distribution.Distribution):
@@ -119,7 +120,7 @@ class Logistic(distribution.Distribution):
Raises:
TypeError: if loc and scale are different dtypes.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[loc, scale]) as name:
with ops.control_dependencies([check_ops.assert_positive(scale)] if
validate_args else []):
diff --git a/tensorflow/contrib/distributions/python/ops/mixture.py b/tensorflow/contrib/distributions/python/ops/mixture.py
index 2ef294af2e..d54f30dc63 100644
--- a/tensorflow/contrib/distributions/python/ops/mixture.py
+++ b/tensorflow/contrib/distributions/python/ops/mixture.py
@@ -116,7 +116,7 @@ class Mixture(distribution.Distribution):
matching static batch shapes, or all components do not
have matching static event shapes.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
if not isinstance(cat, categorical.Categorical):
raise TypeError("cat must be a Categorical distribution, but saw: %s" %
cat)
diff --git a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py
index 0b1301e551..c7c90cf875 100644
--- a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py
+++ b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py
@@ -130,7 +130,7 @@ class MixtureSameFamily(distribution.Distribution):
ValueError: if `mixture_distribution` categories does not equal
`components_distribution` rightmost batch shape.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name) as name:
self._mixture_distribution = mixture_distribution
self._components_distribution = components_distribution
diff --git a/tensorflow/contrib/distributions/python/ops/mvn_diag.py b/tensorflow/contrib/distributions/python/ops/mvn_diag.py
index e3236c2db9..cad398582b 100644
--- a/tensorflow/contrib/distributions/python/ops/mvn_diag.py
+++ b/tensorflow/contrib/distributions/python/ops/mvn_diag.py
@@ -193,7 +193,7 @@ class MultivariateNormalDiag(
Raises:
ValueError: if at most `scale_identity_multiplier` is specified.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name) as name:
with ops.name_scope("init", values=[
loc, scale_diag, scale_identity_multiplier]):
@@ -224,7 +224,7 @@ class MultivariateNormalDiagWithSoftplusScale(MultivariateNormalDiag):
validate_args=False,
allow_nan_stats=True,
name="MultivariateNormalDiagWithSoftplusScale"):
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[scale_diag]) as name:
super(MultivariateNormalDiagWithSoftplusScale, self).__init__(
loc=loc,
diff --git a/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py
index 2f6a6f198c..1c11594df3 100644
--- a/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py
+++ b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py
@@ -215,7 +215,7 @@ class MultivariateNormalDiagPlusLowRank(
Raises:
ValueError: if at most `scale_identity_multiplier` is specified.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
def _convert_to_tensor(x, name):
return None if x is None else ops.convert_to_tensor(x, name=name)
with ops.name_scope(name) as name:
diff --git a/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py b/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py
index 5d06a396fe..47d7d13cf3 100644
--- a/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py
+++ b/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py
@@ -24,6 +24,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops.distributions import util as distribution_util
__all__ = [
@@ -155,7 +156,7 @@ class MultivariateNormalFullCovariance(mvn_tril.MultivariateNormalTriL):
Raises:
ValueError: if neither `loc` nor `covariance_matrix` are specified.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
# Convert the covariance_matrix up to a scale_tril and call MVNTriL.
with ops.name_scope(name) as name:
diff --git a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py
index 44c92312c7..79916fef8d 100644
--- a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py
+++ b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py
@@ -170,7 +170,7 @@ class MultivariateNormalLinearOperator(
ValueError: if `scale` is unspecified.
TypeError: if not `scale.dtype.is_floating`
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
if scale is None:
raise ValueError("Missing required `scale` parameter.")
if not scale.dtype.is_floating:
diff --git a/tensorflow/contrib/distributions/python/ops/mvn_tril.py b/tensorflow/contrib/distributions/python/ops/mvn_tril.py
index d6f8b731cb..d6b0ed994e 100644
--- a/tensorflow/contrib/distributions/python/ops/mvn_tril.py
+++ b/tensorflow/contrib/distributions/python/ops/mvn_tril.py
@@ -179,7 +179,7 @@ class MultivariateNormalTriL(
Raises:
ValueError: if neither `loc` nor `scale_tril` are specified.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
def _convert_to_tensor(x, name):
return None if x is None else ops.convert_to_tensor(x, name=name)
if loc is None and scale_tril is None:
diff --git a/tensorflow/contrib/distributions/python/ops/negative_binomial.py b/tensorflow/contrib/distributions/python/ops/negative_binomial.py
index eeaf9c0a5e..1085c56dc8 100644
--- a/tensorflow/contrib/distributions/python/ops/negative_binomial.py
+++ b/tensorflow/contrib/distributions/python/ops/negative_binomial.py
@@ -90,7 +90,7 @@ class NegativeBinomial(distribution.Distribution):
name: Python `str` name prefixed to Ops created by this class.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[total_count, logits, probs]) as name:
self._logits, self._probs = distribution_util.get_logits_and_probs(
logits, probs, validate_args=validate_args, name=name)
diff --git a/tensorflow/contrib/distributions/python/ops/onehot_categorical.py b/tensorflow/contrib/distributions/python/ops/onehot_categorical.py
index 305b138fdc..a4b9f3b78d 100644
--- a/tensorflow/contrib/distributions/python/ops/onehot_categorical.py
+++ b/tensorflow/contrib/distributions/python/ops/onehot_categorical.py
@@ -115,7 +115,7 @@ class OneHotCategorical(distribution.Distribution):
more of the statistic's batch members are undefined.
name: Python `str` name prefixed to Ops created by this class.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[logits, probs]) as name:
self._logits, self._probs = distribution_util.get_logits_and_probs(
name=name, logits=logits, probs=probs, validate_args=validate_args,
diff --git a/tensorflow/contrib/distributions/python/ops/poisson.py b/tensorflow/contrib/distributions/python/ops/poisson.py
index a84aad6fc9..b345394021 100644
--- a/tensorflow/contrib/distributions/python/ops/poisson.py
+++ b/tensorflow/contrib/distributions/python/ops/poisson.py
@@ -93,7 +93,7 @@ class Poisson(distribution.Distribution):
TypeError: if `rate` is not a float-type.
TypeError: if `log_rate` is not a float-type.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[rate]) as name:
if (rate is None) == (log_rate is None):
raise ValueError("Must specify exactly one of `rate` and `log_rate`.")
diff --git a/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py b/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py
index 19c99dcee9..fe72091d7d 100644
--- a/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py
+++ b/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py
@@ -255,7 +255,7 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution):
TypeError: if `quadrature_grid` and `quadrature_probs` have different base
`dtype`.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[loc, scale]) as name:
if loc is not None:
loc = ops.convert_to_tensor(loc, name="loc")
diff --git a/tensorflow/contrib/distributions/python/ops/quantized_distribution.py b/tensorflow/contrib/distributions/python/ops/quantized_distribution.py
index eb94760ad7..584d2c385f 100644
--- a/tensorflow/contrib/distributions/python/ops/quantized_distribution.py
+++ b/tensorflow/contrib/distributions/python/ops/quantized_distribution.py
@@ -263,7 +263,7 @@ class QuantizedDistribution(distributions.Distribution):
`Distribution` or continuous.
NotImplementedError: If the base distribution does not implement `cdf`.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
values = (
list(distribution.parameters.values()) +
[low, high])
diff --git a/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py b/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py
index 84c8d29072..0362996e68 100644
--- a/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py
+++ b/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py
@@ -165,7 +165,7 @@ class RelaxedBernoulli(transformed_distribution.TransformedDistribution):
Raises:
ValueError: If both `probs` and `logits` are passed, or if neither.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[logits, probs, temperature]) as name:
with ops.control_dependencies([check_ops.assert_positive(temperature)]
if validate_args else []):
diff --git a/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py b/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py
index 325f41e37c..910c430ae7 100644
--- a/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py
+++ b/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py
@@ -162,7 +162,7 @@ class ExpRelaxedOneHotCategorical(distribution.Distribution):
more of the statistic's batch members are undefined.
name: Python `str` name prefixed to Ops created by this class.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[logits, probs, temperature]) as name:
self._logits, self._probs = distribution_util.get_logits_and_probs(
diff --git a/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py b/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py
index 03828fa612..f04dc8da39 100644
--- a/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py
+++ b/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py
@@ -132,7 +132,7 @@ class SinhArcsinh(transformed_distribution.TransformedDistribution):
if one or more of the statistic's batch members are undefined.
name: Python `str` name prefixed to Ops created by this class.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name,
values=[loc, scale, skewness, tailweight]) as name:
diff --git a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py
index af6ff8162b..cd6d749959 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py
@@ -395,7 +395,7 @@ class VectorDiffeomixture(distribution_lib.Distribution):
ValueError: if `not distribution.is_scalar_batch`.
ValueError: if `not distribution.is_scalar_event`.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[mix_loc, temperature]) as name:
if not scale or len(scale) < 2:
raise ValueError("Must specify list (or list-like object) of scale "
diff --git a/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py b/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py
index e265b5d0f7..3465d66b30 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py
@@ -175,7 +175,7 @@ class VectorExponentialDiag(
Raises:
ValueError: if at most `scale_identity_multiplier` is specified.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name) as name:
with ops.name_scope("init", values=[
loc, scale_diag, scale_identity_multiplier]):
diff --git a/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py b/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py
index 89136d6760..2c31b01984 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py
@@ -175,7 +175,7 @@ class VectorExponentialLinearOperator(
ValueError: if `scale` is unspecified.
TypeError: if not `scale.dtype.is_floating`
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
if scale is None:
raise ValueError("Missing required `scale` parameter.")
if not scale.dtype.is_floating:
diff --git a/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py b/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py
index 8dd983b750..6a36018d6f 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py
@@ -210,7 +210,7 @@ class VectorLaplaceDiag(
Raises:
ValueError: if at most `scale_identity_multiplier` is specified.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name):
with ops.name_scope("init", values=[
loc, scale_diag, scale_identity_multiplier]):
diff --git a/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py b/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py
index ec485c95c1..97e5c76d80 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py
@@ -191,7 +191,7 @@ class VectorLaplaceLinearOperator(
ValueError: if `scale` is unspecified.
TypeError: if not `scale.dtype.is_floating`
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
if scale is None:
raise ValueError("Missing required `scale` parameter.")
if not scale.dtype.is_floating:
diff --git a/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py b/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py
index 1438ede265..ff5ca45257 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py
@@ -163,7 +163,7 @@ class VectorSinhArcsinhDiag(transformed_distribution.TransformedDistribution):
Raises:
ValueError: if at most `scale_identity_multiplier` is specified.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(
name,
diff --git a/tensorflow/contrib/distributions/python/ops/vector_student_t.py b/tensorflow/contrib/distributions/python/ops/vector_student_t.py
index 7e78ded9df..4742f75218 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_student_t.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_student_t.py
@@ -175,7 +175,7 @@ class _VectorStudentT(transformed_distribution.TransformedDistribution):
if one or more of the statistic's batch members are undefined.
name: Python `str` name prefixed to Ops created by this class.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
graph_parents = [df, loc, scale_identity_multiplier, scale_diag,
scale_tril, scale_perturb_factor, scale_perturb_diag]
with ops.name_scope(name) as name:
diff --git a/tensorflow/contrib/distributions/python/ops/wishart.py b/tensorflow/contrib/distributions/python/ops/wishart.py
index 91453fed5d..f555867e7f 100644
--- a/tensorflow/contrib/distributions/python/ops/wishart.py
+++ b/tensorflow/contrib/distributions/python/ops/wishart.py
@@ -107,7 +107,7 @@ class _WishartLinearOperator(distribution.Distribution):
ValueError: if df < k, where scale operator event shape is
`(k, k)`
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
self._cholesky_input_output_matrices = cholesky_input_output_matrices
with ops.name_scope(name) as name:
with ops.name_scope("init", values=[df, scale_operator]):
@@ -530,7 +530,7 @@ class WishartCholesky(_WishartLinearOperator):
more of the statistic's batch members are undefined.
name: Python `str` name prefixed to Ops created by this class.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[scale]) as name:
with ops.name_scope("init", values=[scale]):
scale = ops.convert_to_tensor(scale)
@@ -646,7 +646,7 @@ class WishartFull(_WishartLinearOperator):
more of the statistic's batch members are undefined.
name: Python `str` name prefixed to Ops created by this class.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name) as name:
with ops.name_scope("init", values=[scale]):
scale = ops.convert_to_tensor(scale)
diff --git a/tensorflow/python/kernel_tests/distributions/BUILD b/tensorflow/python/kernel_tests/distributions/BUILD
index f3cc9636f9..cf2e8832fd 100644
--- a/tensorflow/python/kernel_tests/distributions/BUILD
+++ b/tensorflow/python/kernel_tests/distributions/BUILD
@@ -41,6 +41,7 @@ cuda_py_test(
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
],
+ shard_count = 3,
)
cuda_py_test(
diff --git a/tensorflow/python/kernel_tests/distributions/util_test.py b/tensorflow/python/kernel_tests/distributions/util_test.py
index b9fe197679..8569b36539 100644
--- a/tensorflow/python/kernel_tests/distributions/util_test.py
+++ b/tensorflow/python/kernel_tests/distributions/util_test.py
@@ -1017,6 +1017,62 @@ class SoftplusTest(test.TestCase):
self.assertAllEqual(
np.ones_like(grads).astype(np.bool), np.isfinite(grads))
+class ArgumentsTest(test.TestCase):
+
+ def testNoArguments(self):
+ def foo():
+ return du.parent_frame_arguments()
+
+ self.assertEqual({}, foo())
+
+ def testPositionalArguments(self):
+ def foo(a, b, c, d): # pylint: disable=unused-argument
+ return du.parent_frame_arguments()
+
+ self.assertEqual({"a": 1, "b": 2, "c": 3, "d": 4}, foo(1, 2, 3, 4))
+
+ # Tests that it does not matter where this function is called, and
+ # no other local variables are returned back.
+ def bar(a, b, c):
+ unused_x = a * b
+ unused_y = c * 3
+ return du.parent_frame_arguments()
+
+ self.assertEqual({"a": 1, "b": 2, "c": 3}, bar(1, 2, 3))
+
+ def testOverloadedArgumentValues(self):
+ def foo(a, b, c): # pylint: disable=unused-argument
+ a = 42
+ b = 31
+ c = 42
+ return du.parent_frame_arguments()
+ self.assertEqual({"a": 42, "b": 31, "c": 42}, foo(1, 2, 3))
+
+ def testKeywordArguments(self):
+ def foo(**kwargs): # pylint: disable=unused-argument
+ return du.parent_frame_arguments()
+
+ self.assertEqual({"a": 1, "b": 2, "c": 3, "d": 4}, foo(a=1, b=2, c=3, d=4))
+
+ def testPositionalKeywordArgs(self):
+ def foo(a, b, c, **kwargs): # pylint: disable=unused-argument
+ return du.parent_frame_arguments()
+
+ self.assertEqual({"a": 1, "b": 2, "c": 3}, foo(a=1, b=2, c=3))
+ self.assertEqual({"a": 1, "b": 2, "c": 3, "unicorn": None},
+ foo(a=1, b=2, c=3, unicorn=None))
+
+ def testNoVarargs(self):
+ def foo(a, b, c, *varargs, **kwargs): # pylint: disable=unused-argument
+ return du.parent_frame_arguments()
+
+ self.assertEqual({"a": 1, "b": 2, "c": 3}, foo(a=1, b=2, c=3))
+ self.assertEqual({"a": 1, "b": 2, "c": 3}, foo(1, 2, 3, *[1, 2, 3]))
+ self.assertEqual({"a": 1, "b": 2, "c": 3, "unicorn": None},
+ foo(1, 2, 3, unicorn=None))
+ self.assertEqual({"a": 1, "b": 2, "c": 3, "unicorn": None},
+ foo(1, 2, 3, *[1, 2, 3], unicorn=None))
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/ops/distributions/bernoulli.py b/tensorflow/python/ops/distributions/bernoulli.py
index 2c9f0e9a32..d7fb3f1f78 100644
--- a/tensorflow/python/ops/distributions/bernoulli.py
+++ b/tensorflow/python/ops/distributions/bernoulli.py
@@ -71,7 +71,7 @@ class Bernoulli(distribution.Distribution):
Raises:
ValueError: If p and logits are passed, or if neither are passed.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name) as name:
self._logits, self._probs = distribution_util.get_logits_and_probs(
logits=logits,
diff --git a/tensorflow/python/ops/distributions/beta.py b/tensorflow/python/ops/distributions/beta.py
index 8beab99bf8..b697848600 100644
--- a/tensorflow/python/ops/distributions/beta.py
+++ b/tensorflow/python/ops/distributions/beta.py
@@ -150,7 +150,7 @@ class Beta(distribution.Distribution):
more of the statistic's batch members are undefined.
name: Python `str` name prefixed to Ops created by this class.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[concentration1, concentration0]) as name:
self._concentration1 = self._maybe_assert_valid_concentration(
ops.convert_to_tensor(concentration1, name="concentration1"),
@@ -321,7 +321,7 @@ class BetaWithSoftplusConcentration(Beta):
validate_args=False,
allow_nan_stats=True,
name="BetaWithSoftplusConcentration"):
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[concentration1,
concentration0]) as name:
super(BetaWithSoftplusConcentration, self).__init__(
diff --git a/tensorflow/python/ops/distributions/categorical.py b/tensorflow/python/ops/distributions/categorical.py
index 8f25b1149c..bbdc8c455a 100644
--- a/tensorflow/python/ops/distributions/categorical.py
+++ b/tensorflow/python/ops/distributions/categorical.py
@@ -182,7 +182,7 @@ class Categorical(distribution.Distribution):
more of the statistic's batch members are undefined.
name: Python `str` name prefixed to Ops created by this class.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[logits, probs]) as name:
self._logits, self._probs = distribution_util.get_logits_and_probs(
logits=logits,
diff --git a/tensorflow/python/ops/distributions/dirichlet.py b/tensorflow/python/ops/distributions/dirichlet.py
index eafcd5c78f..8d0d1d860b 100644
--- a/tensorflow/python/ops/distributions/dirichlet.py
+++ b/tensorflow/python/ops/distributions/dirichlet.py
@@ -154,7 +154,7 @@ class Dirichlet(distribution.Distribution):
more of the statistic's batch members are undefined.
name: Python `str` name prefixed to Ops created by this class.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[concentration]) as name:
self._concentration = self._maybe_assert_valid_concentration(
ops.convert_to_tensor(concentration, name="concentration"),
diff --git a/tensorflow/python/ops/distributions/dirichlet_multinomial.py b/tensorflow/python/ops/distributions/dirichlet_multinomial.py
index fe0ed7e07d..3a35e0caa0 100644
--- a/tensorflow/python/ops/distributions/dirichlet_multinomial.py
+++ b/tensorflow/python/ops/distributions/dirichlet_multinomial.py
@@ -191,7 +191,7 @@ class DirichletMultinomial(distribution.Distribution):
more of the statistic's batch members are undefined.
name: Python `str` name prefixed to Ops created by this class.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[total_count, concentration]) as name:
# Broadcasting works because:
# * The broadcasting convention is to prepend dimensions of size [1], and
diff --git a/tensorflow/python/ops/distributions/distribution.py b/tensorflow/python/ops/distributions/distribution.py
index 3815abf72d..fd08bda9b9 100644
--- a/tensorflow/python/ops/distributions/distribution.py
+++ b/tensorflow/python/ops/distributions/distribution.py
@@ -524,7 +524,8 @@ class Distribution(_BaseDistribution):
def parameters(self):
"""Dictionary of parameters used to instantiate this `Distribution`."""
# Remove "self", "__class__", or other special variables. These can appear
- # if the subclass used `parameters = locals()`.
+ # if the subclass used:
+ # `parameters = distribution_util.parent_frame_arguments()`.
return dict((k, v) for k, v in self._parameters.items()
if not k.startswith("__") and k != "self")
diff --git a/tensorflow/python/ops/distributions/exponential.py b/tensorflow/python/ops/distributions/exponential.py
index cf0e729e1a..1e08f48d52 100644
--- a/tensorflow/python/ops/distributions/exponential.py
+++ b/tensorflow/python/ops/distributions/exponential.py
@@ -27,6 +27,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import gamma
+from tensorflow.python.ops.distributions import util as distribution_util
from tensorflow.python.util.tf_export import tf_export
@@ -90,7 +91,7 @@ class Exponential(gamma.Gamma):
more of the statistic's batch members are undefined.
name: Python `str` name prefixed to Ops created by this class.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
# Even though all statistics of are defined for valid inputs, this is not
# true in the parent class "Gamma." Therefore, passing
# allow_nan_stats=True
@@ -143,7 +144,7 @@ class ExponentialWithSoftplusRate(Exponential):
validate_args=False,
allow_nan_stats=True,
name="ExponentialWithSoftplusRate"):
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[rate]) as name:
super(ExponentialWithSoftplusRate, self).__init__(
rate=nn.softplus(rate, name="softplus_rate"),
diff --git a/tensorflow/python/ops/distributions/gamma.py b/tensorflow/python/ops/distributions/gamma.py
index d39f7c56d3..7ca690d9d2 100644
--- a/tensorflow/python/ops/distributions/gamma.py
+++ b/tensorflow/python/ops/distributions/gamma.py
@@ -126,7 +126,7 @@ class Gamma(distribution.Distribution):
Raises:
TypeError: if `concentration` and `rate` are different dtypes.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[concentration, rate]) as name:
with ops.control_dependencies([
check_ops.assert_positive(concentration),
@@ -261,7 +261,7 @@ class GammaWithSoftplusConcentrationRate(Gamma):
validate_args=False,
allow_nan_stats=True,
name="GammaWithSoftplusConcentrationRate"):
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[concentration, rate]) as name:
super(GammaWithSoftplusConcentrationRate, self).__init__(
concentration=nn.softplus(concentration,
diff --git a/tensorflow/python/ops/distributions/laplace.py b/tensorflow/python/ops/distributions/laplace.py
index 3ccfc618d1..ee3a6a40ff 100644
--- a/tensorflow/python/ops/distributions/laplace.py
+++ b/tensorflow/python/ops/distributions/laplace.py
@@ -33,6 +33,7 @@ from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import special_math
+from tensorflow.python.ops.distributions import util as distribution_util
from tensorflow.python.util.tf_export import tf_export
@@ -100,7 +101,7 @@ class Laplace(distribution.Distribution):
Raises:
TypeError: if `loc` and `scale` are of different dtype.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[loc, scale]) as name:
with ops.control_dependencies([check_ops.assert_positive(scale)] if
validate_args else []):
@@ -217,7 +218,7 @@ class LaplaceWithSoftplusScale(Laplace):
validate_args=False,
allow_nan_stats=True,
name="LaplaceWithSoftplusScale"):
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[loc, scale]) as name:
super(LaplaceWithSoftplusScale, self).__init__(
loc=loc,
diff --git a/tensorflow/python/ops/distributions/multinomial.py b/tensorflow/python/ops/distributions/multinomial.py
index ab77f5c1f8..036ba45ccc 100644
--- a/tensorflow/python/ops/distributions/multinomial.py
+++ b/tensorflow/python/ops/distributions/multinomial.py
@@ -182,7 +182,7 @@ class Multinomial(distribution.Distribution):
more of the statistic's batch members are undefined.
name: Python `str` name prefixed to Ops created by this class.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[total_count, logits, probs]) as name:
self._total_count = ops.convert_to_tensor(total_count, name="total_count")
if validate_args:
diff --git a/tensorflow/python/ops/distributions/normal.py b/tensorflow/python/ops/distributions/normal.py
index 20d4420e91..0620aae10d 100644
--- a/tensorflow/python/ops/distributions/normal.py
+++ b/tensorflow/python/ops/distributions/normal.py
@@ -32,6 +32,7 @@ from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import special_math
+from tensorflow.python.ops.distributions import util as distribution_util
from tensorflow.python.util.tf_export import tf_export
@@ -131,7 +132,7 @@ class Normal(distribution.Distribution):
Raises:
TypeError: if `loc` and `scale` have different `dtype`.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[loc, scale]) as name:
with ops.control_dependencies([check_ops.assert_positive(scale)] if
validate_args else []):
@@ -243,7 +244,7 @@ class NormalWithSoftplusScale(Normal):
validate_args=False,
allow_nan_stats=True,
name="NormalWithSoftplusScale"):
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[scale]) as name:
super(NormalWithSoftplusScale, self).__init__(
loc=loc,
diff --git a/tensorflow/python/ops/distributions/student_t.py b/tensorflow/python/ops/distributions/student_t.py
index 961b07a7bd..9330b930b5 100644
--- a/tensorflow/python/ops/distributions/student_t.py
+++ b/tensorflow/python/ops/distributions/student_t.py
@@ -157,7 +157,7 @@ class StudentT(distribution.Distribution):
Raises:
TypeError: if loc and scale are different dtypes.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[df, loc, scale]) as name:
with ops.control_dependencies([check_ops.assert_positive(df)]
if validate_args else []):
@@ -349,7 +349,7 @@ class StudentTWithAbsDfSoftplusScale(StudentT):
validate_args=False,
allow_nan_stats=True,
name="StudentTWithAbsDfSoftplusScale"):
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[df, scale]) as name:
super(StudentTWithAbsDfSoftplusScale, self).__init__(
df=math_ops.floor(math_ops.abs(df)),
diff --git a/tensorflow/python/ops/distributions/transformed_distribution.py b/tensorflow/python/ops/distributions/transformed_distribution.py
index bc321900dc..9392464ec1 100644
--- a/tensorflow/python/ops/distributions/transformed_distribution.py
+++ b/tensorflow/python/ops/distributions/transformed_distribution.py
@@ -252,7 +252,7 @@ class TransformedDistribution(distribution_lib.Distribution):
name: Python `str` name prefixed to Ops created by this class. Default:
`bijector.name + distribution.name`.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
name = name or (("" if bijector is None else bijector.name) +
distribution.name)
with ops.name_scope(name, values=[event_shape, batch_shape]) as name:
diff --git a/tensorflow/python/ops/distributions/uniform.py b/tensorflow/python/ops/distributions/uniform.py
index 087797c653..dfa10331e3 100644
--- a/tensorflow/python/ops/distributions/uniform.py
+++ b/tensorflow/python/ops/distributions/uniform.py
@@ -29,6 +29,7 @@ from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
+from tensorflow.python.ops.distributions import util as distribution_util
from tensorflow.python.util.tf_export import tf_export
@@ -102,7 +103,7 @@ class Uniform(distribution.Distribution):
Raises:
InvalidArgumentError: if `low >= high` and `validate_args=False`.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[low, high]) as name:
with ops.control_dependencies([
check_ops.assert_less(
diff --git a/tensorflow/python/ops/distributions/util.py b/tensorflow/python/ops/distributions/util.py
index 3afa85fda0..59c89d21f9 100644
--- a/tensorflow/python/ops/distributions/util.py
+++ b/tensorflow/python/ops/distributions/util.py
@@ -33,6 +33,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
+from tensorflow.python.util import tf_inspect
def assert_close(
@@ -1297,6 +1298,43 @@ def pad(x, axis, front=False, back=False, value=0, count=1, name=None):
return x
+def parent_frame_arguments():
+ """Returns parent frame arguments.
+
+ When called inside a function, returns a dictionary with the caller's function
+ arguments. These are positional arguments and keyword arguments (**kwargs),
+ while variable arguments (*varargs) are excluded.
+
+ When called at global scope, this will return an empty dictionary, since there
+ are no arguments.
+
+ WARNING: If caller function argument names are overloaded before invoking
+ this method, then values will reflect the overloaded value. For this reason,
+ we recommend calling `parent_frame_arguments` at the beginning of the
+ function.
+ """
+ # All arguments and the names used for *varargs, and **kwargs
+ arg_names, variable_arg_name, keyword_arg_name, local_vars = (
+ tf_inspect._inspect.getargvalues( # pylint: disable=protected-access
+ # Get the first frame of the caller of this method.
+ tf_inspect._inspect.stack()[1][0])) # pylint: disable=protected-access
+
+ # Remove the *varargs, and flatten the **kwargs. Both are
+ # nested lists.
+ local_vars.pop(variable_arg_name, {})
+ keyword_args = local_vars.pop(keyword_arg_name, {})
+
+ final_args = {}
+ # Copy over arguments and their values. In general, local_vars
+ # may contain more than just the arguments, since this method
+ # can be called anywhere in a function.
+ for arg_name in arg_names:
+ final_args[arg_name] = local_vars.pop(arg_name)
+ final_args.update(keyword_args)
+
+ return final_args
+
+
class AppendDocstring(object):
"""Helper class to promote private subclass docstring to public counterpart.