diff options
author | 2018-05-08 16:20:02 -0700 | |
---|---|---|
committer | 2018-05-08 17:05:41 -0700 | |
commit | bbebae04db61e137e4013a031f429543422ae373 (patch) | |
tree | 88cc3a4f238b2b02165ebd5c46ebcc8a71d983ec | |
parent | 0028bf843d8846bd16b25bf5447b1649fde10fb7 (diff) |
Only use integer values for event_ndims.
`event_ndims` is semantically an integer. However, other code paths (such as `tensor_util.constant_value`)
can return NumPy-wrapped arrays, which can interfere with how values are cached. Instead, extract
every such value as a plain integer.
PiperOrigin-RevId: 195894216
5 files changed, 106 insertions, 47 deletions
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py index ca20442c39..dc45114b1c 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py @@ -26,6 +26,7 @@ from tensorflow.contrib.distributions.python.ops.bijectors.exp import Exp from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered import SoftmaxCentered from tensorflow.contrib.distributions.python.ops.bijectors.softplus import Softplus from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops from tensorflow.python.ops.distributions import bijector from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency from tensorflow.python.platform import test @@ -188,6 +189,15 @@ class ChainBijectorTest(test.TestCase): -np.log(6, dtype=np.float32) - np.sum(x), self.evaluate(chain.inverse_log_det_jacobian(y, event_ndims=1))) + def testChainIldjWithPlaceholder(self): + chain = Chain((Exp(), Exp())) + samples = array_ops.placeholder( + dtype=np.float32, shape=[None, 10], name="samples") + ildj = chain.inverse_log_det_jacobian(samples, event_ndims=0) + self.assertTrue(ildj is not None) + with self.test_session(): + ildj.eval({samples: np.zeros([2, 10], np.float32)}) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/chain.py b/tensorflow/contrib/distributions/python/ops/bijectors/chain.py index 85ad23e413..b158a51bb0 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/chain.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/chain.py @@ -20,10 +20,9 @@ from __future__ import print_function import itertools -from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops -from 
tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector @@ -36,15 +35,6 @@ def _use_static_shape(input_tensor, ndims): return input_tensor.shape.is_fully_defined() and isinstance(ndims, int) -def _maybe_get_event_ndims_statically(event_ndims): - static_event_ndims = (event_ndims if isinstance(event_ndims, int) - else tensor_util.constant_value(event_ndims)) - if static_event_ndims is not None: - return static_event_ndims - - return event_ndims - - def _compute_min_event_ndims(bijector_list, compute_forward=True): """Computes the min_event_ndims associated with the give list of bijectors. @@ -238,13 +228,13 @@ class Chain(bijector.Bijector): return y def _inverse_log_det_jacobian(self, y, **kwargs): - ildj = constant_op.constant( - 0., dtype=y.dtype.base_dtype, name="inverse_log_det_jacobian") + y = ops.convert_to_tensor(y, name="y") + ildj = math_ops.cast(0., dtype=y.dtype.base_dtype) if not self.bijectors: return ildj - event_ndims = _maybe_get_event_ndims_statically( + event_ndims = self._maybe_get_event_ndims_statically( self.inverse_min_event_ndims) if _use_static_shape(y, event_ndims): @@ -258,11 +248,12 @@ class Chain(bijector.Bijector): if _use_static_shape(y, event_ndims): event_shape = b.inverse_event_shape(event_shape) - event_ndims = _maybe_get_event_ndims_statically(event_shape.ndims) + event_ndims = self._maybe_get_event_ndims_statically( + event_shape.ndims) else: event_shape = b.inverse_event_shape_tensor(event_shape) - event_ndims = _maybe_get_event_ndims_statically( - array_ops.rank(event_shape)) + event_ndims = self._maybe_get_event_ndims_statically( + array_ops.size(event_shape)) y = b.inverse(y, **kwargs.get(b.name, {})) return ildj @@ -274,13 +265,12 @@ class Chain(bijector.Bijector): def _forward_log_det_jacobian(self, x, **kwargs): x = ops.convert_to_tensor(x, name="x") - fldj = constant_op.constant( - 0., 
dtype=x.dtype, name="inverse_log_det_jacobian") + fldj = math_ops.cast(0., dtype=x.dtype.base_dtype) if not self.bijectors: return fldj - event_ndims = _maybe_get_event_ndims_statically( + event_ndims = self._maybe_get_event_ndims_statically( self.forward_min_event_ndims) if _use_static_shape(x, event_ndims): @@ -293,13 +283,21 @@ class Chain(bijector.Bijector): x, event_ndims=event_ndims, **kwargs.get(b.name, {})) if _use_static_shape(x, event_ndims): event_shape = b.forward_event_shape(event_shape) - event_ndims = _maybe_get_event_ndims_statically(event_shape.ndims) + event_ndims = self._maybe_get_event_ndims_statically(event_shape.ndims) else: event_shape = b.forward_event_shape_tensor(event_shape) - event_ndims = _maybe_get_event_ndims_statically( - array_ops.rank(event_shape)) + event_ndims = self._maybe_get_event_ndims_statically( + array_ops.size(event_shape)) x = b.forward(x, **kwargs.get(b.name, {})) return fldj + def _maybe_get_event_ndims_statically(self, event_ndims): + event_ndims_ = super(Chain, self)._maybe_get_event_ndims_statically( + event_ndims) + if event_ndims_ is None: + return event_ndims + return event_ndims_ + + diff --git a/tensorflow/python/kernel_tests/distributions/util_test.py b/tensorflow/python/kernel_tests/distributions/util_test.py index f54f146e0a..b9fe197679 100644 --- a/tensorflow/python/kernel_tests/distributions/util_test.py +++ b/tensorflow/python/kernel_tests/distributions/util_test.py @@ -147,6 +147,32 @@ class AssertCloseTest(test.TestCase): array_ops.identity(w).eval(feed_dict=feed_dict) +class MaybeGetStaticTest(test.TestCase): + + def testGetStaticInt(self): + x = 2 + self.assertEqual(x, du.maybe_get_static_value(x)) + self.assertAllClose( + np.array(2.), du.maybe_get_static_value(x, dtype=np.float64)) + + def testGetStaticNumpyArray(self): + x = np.array(2, dtype=np.int32) + self.assertEqual(x, du.maybe_get_static_value(x)) + self.assertAllClose( + np.array(2.), du.maybe_get_static_value(x, dtype=np.float64)) + + def 
testGetStaticConstant(self): + x = constant_op.constant(2, dtype=dtypes.int32) + self.assertEqual(np.array(2, dtype=np.int32), du.maybe_get_static_value(x)) + self.assertAllClose( + np.array(2.), du.maybe_get_static_value(x, dtype=np.float64)) + + def testGetStaticPlaceholder(self): + x = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) + self.assertEqual(None, du.maybe_get_static_value(x)) + self.assertEqual(None, du.maybe_get_static_value(x, dtype=np.float64)) + + @test_util.with_c_api class GetLogitsAndProbsTest(test.TestCase): diff --git a/tensorflow/python/ops/distributions/bijector_impl.py b/tensorflow/python/ops/distributions/bijector_impl.py index 36eee5ce78..caceadf53a 100644 --- a/tensorflow/python/ops/distributions/bijector_impl.py +++ b/tensorflow/python/ops/distributions/bijector_impl.py @@ -33,6 +33,7 @@ from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import util as distribution_util __all__ = [ @@ -527,8 +528,6 @@ class Bijector(object): ValueError: If a member of `graph_parents` is not a `Tensor`. 
""" self._graph_parents = graph_parents or [] - forward_min_event_ndims = get_static_value(forward_min_event_ndims) - inverse_min_event_ndims = get_static_value(inverse_min_event_ndims) if forward_min_event_ndims is None and inverse_min_event_ndims is None: raise ValueError("Must specify at least one of `forward_min_event_ndims` " @@ -538,12 +537,23 @@ class Bijector(object): elif forward_min_event_ndims is None: forward_min_event_ndims = inverse_min_event_ndims + if not isinstance(forward_min_event_ndims, int): + raise TypeError("Expected forward_min_event_ndims to be of " + "type int, got {}".format( + type(forward_min_event_ndims).__name__)) + + if not isinstance(inverse_min_event_ndims, int): + raise TypeError("Expected inverse_min_event_ndims to be of " + "type int, got {}".format( + type(inverse_min_event_ndims).__name__)) + if forward_min_event_ndims < 0: raise ValueError("forward_min_event_ndims must be a non-negative " "integer.") if inverse_min_event_ndims < 0: raise ValueError("inverse_min_event_ndims must be a non-negative " "integer.") + self._forward_min_event_ndims = forward_min_event_ndims self._inverse_min_event_ndims = inverse_min_event_ndims self._is_constant_jacobian = is_constant_jacobian @@ -994,7 +1004,6 @@ class Bijector(object): def _reduce_jacobian_det_over_event( self, y, ildj, min_event_ndims, event_ndims): """Reduce jacobian over event_ndims - min_event_ndims.""" - assert_static(min_event_ndims) if not self.is_constant_jacobian: return math_ops.reduce_sum( @@ -1012,7 +1021,7 @@ class Bijector(object): axis=self._get_event_reduce_dims(min_event_ndims, event_ndims)) # The multiplication by ones can change the inferred static shape so we try # to recover as much as possible. 
- event_ndims_ = get_static_value(event_ndims) + event_ndims_ = self._maybe_get_event_ndims_statically(event_ndims) if (event_ndims_ is not None and y.shape.ndims is not None and ildj.shape.ndims is not None): @@ -1027,8 +1036,7 @@ class Bijector(object): def _get_event_reduce_dims(self, min_event_ndims, event_ndims): """Compute the reduction dimensions given event_ndims.""" - assert_static(min_event_ndims) - event_ndims_ = get_static_value(event_ndims, np.int32) + event_ndims_ = self._maybe_get_event_ndims_statically(event_ndims) if event_ndims_ is not None: return [-index for index in range(1, event_ndims_ - min_event_ndims + 1)] @@ -1038,8 +1046,7 @@ class Bijector(object): def _check_valid_event_ndims(self, min_event_ndims, event_ndims): """Check whether event_ndims is atleast min_event_ndims.""" - assert_static(min_event_ndims) - event_ndims_ = get_static_value(event_ndims, np.int32) + event_ndims_ = self._maybe_get_event_ndims_statically(event_ndims) assertions = [] if event_ndims_ is not None: if min_event_ndims > event_ndims_: @@ -1051,21 +1058,15 @@ class Bijector(object): check_ops.assert_greater_equal(event_ndims, min_event_ndims)] return assertions + def _maybe_get_event_ndims_statically(self, event_ndims): + """Helper which returns tries to return an integer static value.""" + event_ndims_ = distribution_util.maybe_get_static_value(event_ndims) -def get_static_value(x, dtype=None): - """Helper which returns static value; casting when dtype is preferred.""" - if x is None: - return x - try: - x_ = tensor_util.constant_value(x) - except TypeError: - x_ = x - if x_ is None or dtype is None: - return x_ - return np.array(x_, dtype) - + if isinstance(event_ndims_, np.ndarray): + if (event_ndims_.dtype not in (np.int32, np.int64) or + len(event_ndims_.shape)): + raise ValueError("Expected a scalar integer, got {}".format( + event_ndims_)) + event_ndims_ = event_ndims_.tolist() -def assert_static(x): - """Helper which asserts that input arg is known 
statically.""" - if x is None or type(x) != type(get_static_value(x)): # pylint: disable=unidiomatic-typecheck - raise TypeError("Input must be known statically.") + return event_ndims_ diff --git a/tensorflow/python/ops/distributions/util.py b/tensorflow/python/ops/distributions/util.py index 2e067eab45..3afa85fda0 100644 --- a/tensorflow/python/ops/distributions/util.py +++ b/tensorflow/python/ops/distributions/util.py @@ -162,6 +162,30 @@ def same_dynamic_shape(a, b): lambda: constant_op.constant(False)) +def maybe_get_static_value(x, dtype=None): + """Helper which tries to return a static value. + + Given `x`, extract it's value statically, optionally casting to a specific + dtype. If this is not possible, None is returned. + + Args: + x: `Tensor` for which to extract a value statically. + dtype: Optional dtype to cast to. + + Returns: + Statically inferred value if possible, otherwise None. + """ + if x is None: + return x + try: + x_ = tensor_util.constant_value(x) + except TypeError: + x_ = x + if x_ is None or dtype is None: + return x_ + return np.array(x_, dtype) + + def get_logits_and_probs(logits=None, probs=None, multidimensional=False, |