aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distributions
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-05-08 16:20:02 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-08 17:05:41 -0700
commitbbebae04db61e137e4013a031f429543422ae373 (patch)
tree88cc3a4f238b2b02165ebd5c46ebcc8a71d983ec /tensorflow/contrib/distributions
parent0028bf843d8846bd16b25bf5447b1649fde10fb7 (diff)
Only use integer values for event_ndims.
event_ndims has the semantics of being an integer. However, other code paths (such as const_value) can return numpy-wrapped arrays, which can mess with how values are cached. Instead, extract everything as an integer. PiperOrigin-RevId: 195894216
Diffstat (limited to 'tensorflow/contrib/distributions')
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py10
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/chain.py44
2 files changed, 31 insertions, 23 deletions
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py
index ca20442c39..dc45114b1c 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py
@@ -26,6 +26,7 @@ from tensorflow.contrib.distributions.python.ops.bijectors.exp import Exp
from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered import SoftmaxCentered
from tensorflow.contrib.distributions.python.ops.bijectors.softplus import Softplus
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops.distributions import bijector
from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency
from tensorflow.python.platform import test
@@ -188,6 +189,15 @@ class ChainBijectorTest(test.TestCase):
-np.log(6, dtype=np.float32) - np.sum(x),
self.evaluate(chain.inverse_log_det_jacobian(y, event_ndims=1)))
+ def testChainIldjWithPlaceholder(self):
+ chain = Chain((Exp(), Exp()))
+ samples = array_ops.placeholder(
+ dtype=np.float32, shape=[None, 10], name="samples")
+ ildj = chain.inverse_log_det_jacobian(samples, event_ndims=0)
+ self.assertTrue(ildj is not None)
+ with self.test_session():
+ ildj.eval({samples: np.zeros([2, 10], np.float32)})
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/chain.py b/tensorflow/contrib/distributions/python/ops/bijectors/chain.py
index 85ad23e413..b158a51bb0 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/chain.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/chain.py
@@ -20,10 +20,9 @@ from __future__ import print_function
import itertools
-from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import bijector
@@ -36,15 +35,6 @@ def _use_static_shape(input_tensor, ndims):
return input_tensor.shape.is_fully_defined() and isinstance(ndims, int)
-def _maybe_get_event_ndims_statically(event_ndims):
- static_event_ndims = (event_ndims if isinstance(event_ndims, int)
- else tensor_util.constant_value(event_ndims))
- if static_event_ndims is not None:
- return static_event_ndims
-
- return event_ndims
-
-
def _compute_min_event_ndims(bijector_list, compute_forward=True):
"""Computes the min_event_ndims associated with the give list of bijectors.
@@ -238,13 +228,13 @@ class Chain(bijector.Bijector):
return y
def _inverse_log_det_jacobian(self, y, **kwargs):
- ildj = constant_op.constant(
- 0., dtype=y.dtype.base_dtype, name="inverse_log_det_jacobian")
+ y = ops.convert_to_tensor(y, name="y")
+ ildj = math_ops.cast(0., dtype=y.dtype.base_dtype)
if not self.bijectors:
return ildj
- event_ndims = _maybe_get_event_ndims_statically(
+ event_ndims = self._maybe_get_event_ndims_statically(
self.inverse_min_event_ndims)
if _use_static_shape(y, event_ndims):
@@ -258,11 +248,12 @@ class Chain(bijector.Bijector):
if _use_static_shape(y, event_ndims):
event_shape = b.inverse_event_shape(event_shape)
- event_ndims = _maybe_get_event_ndims_statically(event_shape.ndims)
+ event_ndims = self._maybe_get_event_ndims_statically(
+ event_shape.ndims)
else:
event_shape = b.inverse_event_shape_tensor(event_shape)
- event_ndims = _maybe_get_event_ndims_statically(
- array_ops.rank(event_shape))
+ event_ndims = self._maybe_get_event_ndims_statically(
+ array_ops.size(event_shape))
y = b.inverse(y, **kwargs.get(b.name, {}))
return ildj
@@ -274,13 +265,12 @@ class Chain(bijector.Bijector):
def _forward_log_det_jacobian(self, x, **kwargs):
x = ops.convert_to_tensor(x, name="x")
- fldj = constant_op.constant(
- 0., dtype=x.dtype, name="inverse_log_det_jacobian")
+ fldj = math_ops.cast(0., dtype=x.dtype.base_dtype)
if not self.bijectors:
return fldj
- event_ndims = _maybe_get_event_ndims_statically(
+ event_ndims = self._maybe_get_event_ndims_statically(
self.forward_min_event_ndims)
if _use_static_shape(x, event_ndims):
@@ -293,13 +283,21 @@ class Chain(bijector.Bijector):
x, event_ndims=event_ndims, **kwargs.get(b.name, {}))
if _use_static_shape(x, event_ndims):
event_shape = b.forward_event_shape(event_shape)
- event_ndims = _maybe_get_event_ndims_statically(event_shape.ndims)
+ event_ndims = self._maybe_get_event_ndims_statically(event_shape.ndims)
else:
event_shape = b.forward_event_shape_tensor(event_shape)
- event_ndims = _maybe_get_event_ndims_statically(
- array_ops.rank(event_shape))
+ event_ndims = self._maybe_get_event_ndims_statically(
+ array_ops.size(event_shape))
x = b.forward(x, **kwargs.get(b.name, {}))
return fldj
+ def _maybe_get_event_ndims_statically(self, event_ndims):
+ event_ndims_ = super(Chain, self)._maybe_get_event_ndims_statically(
+ event_ndims)
+ if event_ndims_ is None:
+ return event_ndims
+ return event_ndims_
+
+