aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distributions
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-05-23 20:03:20 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-23 20:05:39 -0700
commit42e50daa384183d2f64e0ab5ae3f9bed07128e07 (patch)
treee5ab14a2a1fb34c0248c226be51c80ee7c98748b /tensorflow/contrib/distributions
parentcd468ceee10646c5e023661537a20915f52677f9 (diff)
Set the correct shape in transformed distribution.
Also add distribution_util.maybe_get_static_event_ndims to be reused in bijector and transformed distribution classes. PiperOrigin-RevId: 197831651
Diffstat (limited to 'tensorflow/contrib/distributions')
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/conditional_bijector_test.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/chain.py30
-rw-r--r--tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py13
3 files changed, 20 insertions, 25 deletions
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/conditional_bijector_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/conditional_bijector_test.py
index 8b279ebcd9..f8a52615b0 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/conditional_bijector_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/conditional_bijector_test.py
@@ -59,7 +59,7 @@ class ConditionalBijectorTest(test.TestCase):
for name in ["inverse_log_det_jacobian", "forward_log_det_jacobian"]:
method = getattr(b, name)
with self.assertRaisesRegexp(ValueError, name + ".*b1.*b2"):
- method(1., event_ndims=0., arg1="b1", arg2="b2")
+ method(1., event_ndims=0, arg1="b1", arg2="b2")
if __name__ == "__main__":
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/chain.py b/tensorflow/contrib/distributions/python/ops/bijectors/chain.py
index b158a51bb0..16f959560c 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/chain.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/chain.py
@@ -234,7 +234,7 @@ class Chain(bijector.Bijector):
if not self.bijectors:
return ildj
- event_ndims = self._maybe_get_event_ndims_statically(
+ event_ndims = self._maybe_get_static_event_ndims(
self.inverse_min_event_ndims)
if _use_static_shape(y, event_ndims):
@@ -248,12 +248,15 @@ class Chain(bijector.Bijector):
if _use_static_shape(y, event_ndims):
event_shape = b.inverse_event_shape(event_shape)
- event_ndims = self._maybe_get_event_ndims_statically(
+ event_ndims = self._maybe_get_static_event_ndims(
event_shape.ndims)
else:
event_shape = b.inverse_event_shape_tensor(event_shape)
- event_ndims = self._maybe_get_event_ndims_statically(
- array_ops.size(event_shape))
+ event_ndims = array_ops.size(event_shape)
+ event_ndims_ = self._maybe_get_static_event_ndims(event_ndims)
+ if event_ndims_ is not None:
+ event_ndims = event_ndims_
+
y = b.inverse(y, **kwargs.get(b.name, {}))
return ildj
@@ -270,7 +273,7 @@ class Chain(bijector.Bijector):
if not self.bijectors:
return fldj
- event_ndims = self._maybe_get_event_ndims_statically(
+ event_ndims = self._maybe_get_static_event_ndims(
self.forward_min_event_ndims)
if _use_static_shape(x, event_ndims):
@@ -283,21 +286,14 @@ class Chain(bijector.Bijector):
x, event_ndims=event_ndims, **kwargs.get(b.name, {}))
if _use_static_shape(x, event_ndims):
event_shape = b.forward_event_shape(event_shape)
- event_ndims = self._maybe_get_event_ndims_statically(event_shape.ndims)
+ event_ndims = self._maybe_get_static_event_ndims(event_shape.ndims)
else:
event_shape = b.forward_event_shape_tensor(event_shape)
- event_ndims = self._maybe_get_event_ndims_statically(
- array_ops.size(event_shape))
+ event_ndims = array_ops.size(event_shape)
+ event_ndims_ = self._maybe_get_static_event_ndims(event_ndims)
+ if event_ndims_ is not None:
+ event_ndims = event_ndims_
x = b.forward(x, **kwargs.get(b.name, {}))
return fldj
-
- def _maybe_get_event_ndims_statically(self, event_ndims):
- event_ndims_ = super(Chain, self)._maybe_get_event_ndims_statically(
- event_ndims)
- if event_ndims_ is None:
- return event_ndims
- return event_ndims_
-
-
diff --git a/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py b/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py
index 10b4536135..3598c8d23e 100644
--- a/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py
+++ b/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py
@@ -20,7 +20,6 @@ from __future__ import print_function
from tensorflow.contrib.distributions.python.ops import conditional_distribution
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import transformed_distribution
@@ -106,7 +105,7 @@ class ConditionalTransformedDistribution(
bijector_kwargs = bijector_kwargs or {}
distribution_kwargs = distribution_kwargs or {}
x = self.bijector.inverse(y, **bijector_kwargs)
- event_ndims = self._maybe_get_event_ndims_statically()
+ event_ndims = self._maybe_get_static_event_ndims()
ildj = self.bijector.inverse_log_det_jacobian(
y, event_ndims=event_ndims, **bijector_kwargs)
if self.bijector._is_injective: # pylint: disable=protected-access
@@ -131,7 +130,7 @@ class ConditionalTransformedDistribution(
bijector_kwargs = bijector_kwargs or {}
distribution_kwargs = distribution_kwargs or {}
x = self.bijector.inverse(y, **bijector_kwargs)
- event_ndims = self._maybe_get_event_ndims_statically()
+ event_ndims = self._maybe_get_static_event_ndims()
ildj = self.bijector.inverse_log_det_jacobian(
y, event_ndims=event_ndims, **bijector_kwargs)
if self.bijector._is_injective: # pylint: disable=protected-access
@@ -220,14 +219,14 @@ class ConditionalTransformedDistribution(
inv_cdf = self.distribution.quantile(value, **distribution_kwargs)
return self.bijector.forward(inv_cdf, **bijector_kwargs)
- def _maybe_get_event_ndims_statically(self):
+ def _maybe_get_static_event_ndims(self):
if self.event_shape.ndims is not None:
return self.event_shape.ndims
event_ndims = array_ops.size(self.event_shape_tensor())
- static_event_ndims = tensor_util.constant_value(event_ndims)
+ event_ndims_ = distribution_util.maybe_get_static_value(event_ndims)
- if static_event_ndims is not None:
- return static_event_ndims
+ if event_ndims_ is not None:
+ return event_ndims_
return event_ndims