aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-05-08 16:20:02 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-08 17:05:41 -0700
commitbbebae04db61e137e4013a031f429543422ae373 (patch)
tree88cc3a4f238b2b02165ebd5c46ebcc8a71d983ec
parent0028bf843d8846bd16b25bf5447b1649fde10fb7 (diff)
Only use integer values for event_ndims.
event_ndims has the semantics of an integer. However, other code paths (such as const_value) can return NumPy-wrapped arrays, which can interfere with how values are cached. Instead, extract everything as an integer. PiperOrigin-RevId: 195894216
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py10
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/chain.py44
-rw-r--r--tensorflow/python/kernel_tests/distributions/util_test.py26
-rw-r--r--tensorflow/python/ops/distributions/bijector_impl.py49
-rw-r--r--tensorflow/python/ops/distributions/util.py24
5 files changed, 106 insertions, 47 deletions
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py
index ca20442c39..dc45114b1c 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py
@@ -26,6 +26,7 @@ from tensorflow.contrib.distributions.python.ops.bijectors.exp import Exp
from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered import SoftmaxCentered
from tensorflow.contrib.distributions.python.ops.bijectors.softplus import Softplus
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops.distributions import bijector
from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency
from tensorflow.python.platform import test
@@ -188,6 +189,15 @@ class ChainBijectorTest(test.TestCase):
-np.log(6, dtype=np.float32) - np.sum(x),
self.evaluate(chain.inverse_log_det_jacobian(y, event_ndims=1)))
+ def testChainIldjWithPlaceholder(self):
+ chain = Chain((Exp(), Exp()))
+ samples = array_ops.placeholder(
+ dtype=np.float32, shape=[None, 10], name="samples")
+ ildj = chain.inverse_log_det_jacobian(samples, event_ndims=0)
+ self.assertTrue(ildj is not None)
+ with self.test_session():
+ ildj.eval({samples: np.zeros([2, 10], np.float32)})
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/chain.py b/tensorflow/contrib/distributions/python/ops/bijectors/chain.py
index 85ad23e413..b158a51bb0 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/chain.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/chain.py
@@ -20,10 +20,9 @@ from __future__ import print_function
import itertools
-from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import bijector
@@ -36,15 +35,6 @@ def _use_static_shape(input_tensor, ndims):
return input_tensor.shape.is_fully_defined() and isinstance(ndims, int)
-def _maybe_get_event_ndims_statically(event_ndims):
- static_event_ndims = (event_ndims if isinstance(event_ndims, int)
- else tensor_util.constant_value(event_ndims))
- if static_event_ndims is not None:
- return static_event_ndims
-
- return event_ndims
-
-
def _compute_min_event_ndims(bijector_list, compute_forward=True):
"""Computes the min_event_ndims associated with the given list of bijectors.
@@ -238,13 +228,13 @@ class Chain(bijector.Bijector):
return y
def _inverse_log_det_jacobian(self, y, **kwargs):
- ildj = constant_op.constant(
- 0., dtype=y.dtype.base_dtype, name="inverse_log_det_jacobian")
+ y = ops.convert_to_tensor(y, name="y")
+ ildj = math_ops.cast(0., dtype=y.dtype.base_dtype)
if not self.bijectors:
return ildj
- event_ndims = _maybe_get_event_ndims_statically(
+ event_ndims = self._maybe_get_event_ndims_statically(
self.inverse_min_event_ndims)
if _use_static_shape(y, event_ndims):
@@ -258,11 +248,12 @@ class Chain(bijector.Bijector):
if _use_static_shape(y, event_ndims):
event_shape = b.inverse_event_shape(event_shape)
- event_ndims = _maybe_get_event_ndims_statically(event_shape.ndims)
+ event_ndims = self._maybe_get_event_ndims_statically(
+ event_shape.ndims)
else:
event_shape = b.inverse_event_shape_tensor(event_shape)
- event_ndims = _maybe_get_event_ndims_statically(
- array_ops.rank(event_shape))
+ event_ndims = self._maybe_get_event_ndims_statically(
+ array_ops.size(event_shape))
y = b.inverse(y, **kwargs.get(b.name, {}))
return ildj
@@ -274,13 +265,12 @@ class Chain(bijector.Bijector):
def _forward_log_det_jacobian(self, x, **kwargs):
x = ops.convert_to_tensor(x, name="x")
- fldj = constant_op.constant(
- 0., dtype=x.dtype, name="inverse_log_det_jacobian")
+ fldj = math_ops.cast(0., dtype=x.dtype.base_dtype)
if not self.bijectors:
return fldj
- event_ndims = _maybe_get_event_ndims_statically(
+ event_ndims = self._maybe_get_event_ndims_statically(
self.forward_min_event_ndims)
if _use_static_shape(x, event_ndims):
@@ -293,13 +283,21 @@ class Chain(bijector.Bijector):
x, event_ndims=event_ndims, **kwargs.get(b.name, {}))
if _use_static_shape(x, event_ndims):
event_shape = b.forward_event_shape(event_shape)
- event_ndims = _maybe_get_event_ndims_statically(event_shape.ndims)
+ event_ndims = self._maybe_get_event_ndims_statically(event_shape.ndims)
else:
event_shape = b.forward_event_shape_tensor(event_shape)
- event_ndims = _maybe_get_event_ndims_statically(
- array_ops.rank(event_shape))
+ event_ndims = self._maybe_get_event_ndims_statically(
+ array_ops.size(event_shape))
x = b.forward(x, **kwargs.get(b.name, {}))
return fldj
+ def _maybe_get_event_ndims_statically(self, event_ndims):
+ event_ndims_ = super(Chain, self)._maybe_get_event_ndims_statically(
+ event_ndims)
+ if event_ndims_ is None:
+ return event_ndims
+ return event_ndims_
+
+
diff --git a/tensorflow/python/kernel_tests/distributions/util_test.py b/tensorflow/python/kernel_tests/distributions/util_test.py
index f54f146e0a..b9fe197679 100644
--- a/tensorflow/python/kernel_tests/distributions/util_test.py
+++ b/tensorflow/python/kernel_tests/distributions/util_test.py
@@ -147,6 +147,32 @@ class AssertCloseTest(test.TestCase):
array_ops.identity(w).eval(feed_dict=feed_dict)
+class MaybeGetStaticTest(test.TestCase):
+
+ def testGetStaticInt(self):
+ x = 2
+ self.assertEqual(x, du.maybe_get_static_value(x))
+ self.assertAllClose(
+ np.array(2.), du.maybe_get_static_value(x, dtype=np.float64))
+
+ def testGetStaticNumpyArray(self):
+ x = np.array(2, dtype=np.int32)
+ self.assertEqual(x, du.maybe_get_static_value(x))
+ self.assertAllClose(
+ np.array(2.), du.maybe_get_static_value(x, dtype=np.float64))
+
+ def testGetStaticConstant(self):
+ x = constant_op.constant(2, dtype=dtypes.int32)
+ self.assertEqual(np.array(2, dtype=np.int32), du.maybe_get_static_value(x))
+ self.assertAllClose(
+ np.array(2.), du.maybe_get_static_value(x, dtype=np.float64))
+
+ def testGetStaticPlaceholder(self):
+ x = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
+ self.assertEqual(None, du.maybe_get_static_value(x))
+ self.assertEqual(None, du.maybe_get_static_value(x, dtype=np.float64))
+
+
@test_util.with_c_api
class GetLogitsAndProbsTest(test.TestCase):
diff --git a/tensorflow/python/ops/distributions/bijector_impl.py b/tensorflow/python/ops/distributions/bijector_impl.py
index 36eee5ce78..caceadf53a 100644
--- a/tensorflow/python/ops/distributions/bijector_impl.py
+++ b/tensorflow/python/ops/distributions/bijector_impl.py
@@ -33,6 +33,7 @@ from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops.distributions import util as distribution_util
__all__ = [
@@ -527,8 +528,6 @@ class Bijector(object):
ValueError: If a member of `graph_parents` is not a `Tensor`.
"""
self._graph_parents = graph_parents or []
- forward_min_event_ndims = get_static_value(forward_min_event_ndims)
- inverse_min_event_ndims = get_static_value(inverse_min_event_ndims)
if forward_min_event_ndims is None and inverse_min_event_ndims is None:
raise ValueError("Must specify at least one of `forward_min_event_ndims` "
@@ -538,12 +537,23 @@ class Bijector(object):
elif forward_min_event_ndims is None:
forward_min_event_ndims = inverse_min_event_ndims
+ if not isinstance(forward_min_event_ndims, int):
+ raise TypeError("Expected forward_min_event_ndims to be of "
+ "type int, got {}".format(
+ type(forward_min_event_ndims).__name__))
+
+ if not isinstance(inverse_min_event_ndims, int):
+ raise TypeError("Expected inverse_min_event_ndims to be of "
+ "type int, got {}".format(
+ type(inverse_min_event_ndims).__name__))
+
if forward_min_event_ndims < 0:
raise ValueError("forward_min_event_ndims must be a non-negative "
"integer.")
if inverse_min_event_ndims < 0:
raise ValueError("inverse_min_event_ndims must be a non-negative "
"integer.")
+
self._forward_min_event_ndims = forward_min_event_ndims
self._inverse_min_event_ndims = inverse_min_event_ndims
self._is_constant_jacobian = is_constant_jacobian
@@ -994,7 +1004,6 @@ class Bijector(object):
def _reduce_jacobian_det_over_event(
self, y, ildj, min_event_ndims, event_ndims):
"""Reduce jacobian over event_ndims - min_event_ndims."""
- assert_static(min_event_ndims)
if not self.is_constant_jacobian:
return math_ops.reduce_sum(
@@ -1012,7 +1021,7 @@ class Bijector(object):
axis=self._get_event_reduce_dims(min_event_ndims, event_ndims))
# The multiplication by ones can change the inferred static shape so we try
# to recover as much as possible.
- event_ndims_ = get_static_value(event_ndims)
+ event_ndims_ = self._maybe_get_event_ndims_statically(event_ndims)
if (event_ndims_ is not None and
y.shape.ndims is not None and
ildj.shape.ndims is not None):
@@ -1027,8 +1036,7 @@ class Bijector(object):
def _get_event_reduce_dims(self, min_event_ndims, event_ndims):
"""Compute the reduction dimensions given event_ndims."""
- assert_static(min_event_ndims)
- event_ndims_ = get_static_value(event_ndims, np.int32)
+ event_ndims_ = self._maybe_get_event_ndims_statically(event_ndims)
if event_ndims_ is not None:
return [-index for index in range(1, event_ndims_ - min_event_ndims + 1)]
@@ -1038,8 +1046,7 @@ class Bijector(object):
def _check_valid_event_ndims(self, min_event_ndims, event_ndims):
"""Check whether event_ndims is at least min_event_ndims."""
- assert_static(min_event_ndims)
- event_ndims_ = get_static_value(event_ndims, np.int32)
+ event_ndims_ = self._maybe_get_event_ndims_statically(event_ndims)
assertions = []
if event_ndims_ is not None:
if min_event_ndims > event_ndims_:
@@ -1051,21 +1058,15 @@ class Bijector(object):
check_ops.assert_greater_equal(event_ndims, min_event_ndims)]
return assertions
+ def _maybe_get_event_ndims_statically(self, event_ndims):
+ """Helper which tries to return an integer static value."""
+ event_ndims_ = distribution_util.maybe_get_static_value(event_ndims)
-def get_static_value(x, dtype=None):
- """Helper which returns static value; casting when dtype is preferred."""
- if x is None:
- return x
- try:
- x_ = tensor_util.constant_value(x)
- except TypeError:
- x_ = x
- if x_ is None or dtype is None:
- return x_
- return np.array(x_, dtype)
-
+ if isinstance(event_ndims_, np.ndarray):
+ if (event_ndims_.dtype not in (np.int32, np.int64) or
+ len(event_ndims_.shape)):
+ raise ValueError("Expected a scalar integer, got {}".format(
+ event_ndims_))
+ event_ndims_ = event_ndims_.tolist()
-def assert_static(x):
- """Helper which asserts that input arg is known statically."""
- if x is None or type(x) != type(get_static_value(x)): # pylint: disable=unidiomatic-typecheck
- raise TypeError("Input must be known statically.")
+ return event_ndims_
diff --git a/tensorflow/python/ops/distributions/util.py b/tensorflow/python/ops/distributions/util.py
index 2e067eab45..3afa85fda0 100644
--- a/tensorflow/python/ops/distributions/util.py
+++ b/tensorflow/python/ops/distributions/util.py
@@ -162,6 +162,30 @@ def same_dynamic_shape(a, b):
lambda: constant_op.constant(False))
+def maybe_get_static_value(x, dtype=None):
+ """Helper which tries to return a static value.
+
+ Given `x`, extract its value statically, optionally casting to a specific
+ dtype. If this is not possible, None is returned.
+
+ Args:
+ x: `Tensor` for which to extract a value statically.
+ dtype: Optional dtype to cast to.
+
+ Returns:
+ Statically inferred value if possible, otherwise None.
+ """
+ if x is None:
+ return x
+ try:
+ x_ = tensor_util.constant_value(x)
+ except TypeError:
+ x_ = x
+ if x_ is None or dtype is None:
+ return x_
+ return np.array(x_, dtype)
+
+
def get_logits_and_probs(logits=None,
probs=None,
multidimensional=False,