author    A. Unique TensorFlower <gardener@tensorflow.org>  2016-08-19 08:15:45 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>   2016-08-19 09:33:09 -0700
commit    1a5210752d444e7c0e6c2ab58ad034e7b736d573 (patch)
tree      e3f7933d959382004b7dcc33f60919c07e9a2735
parent    94948d379919ada6a3521b44ce409df758a72bc8 (diff)
Implement multi-batch, multivariate ShiftAndScale bijector. Fine-tune ShapeUtil API.
Change: 130758781
-rw-r--r--  tensorflow/contrib/distributions/BUILD                                          |  11
-rw-r--r--  tensorflow/contrib/distributions/python/kernel_tests/bijector_test.py           | 270
-rw-r--r--  tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py  |  79
-rw-r--r--  tensorflow/contrib/distributions/python/kernel_tests/shape_test.py              | 440
-rw-r--r--  tensorflow/contrib/distributions/python/ops/bijector.py                         | 553
-rw-r--r--  tensorflow/contrib/distributions/python/ops/distribution_util.py                | 132
-rw-r--r--  tensorflow/contrib/distributions/python/ops/shape.py                            | 600
7 files changed, 1519 insertions, 566 deletions
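
A minimal usage sketch of the new `_ShiftAndScale` bijector, pieced together from the docstrings and tests in the diff below. `_ShiftAndScale` is module-private, so the import mirrors the test file, and the contemporaneous `tf.Session` API is assumed; treat this as an illustration, not a supported entry point:

```python
import numpy as np
import tensorflow as tf

# _ShiftAndScale is module-private; this import mirrors bijector_test.py.
from tensorflow.contrib.distributions.python.ops.bijector import _ShiftAndScale

with tf.Session() as sess:
  # No batch, multivariate: loc has shape [d], scale has shape [d, d].
  mu = [1., -1.]
  sigma = np.eye(2, dtype=np.float32)
  bijector = _ShiftAndScale(loc=mu, scale=sigma, event_ndims=1)

  x = [[1., 1.]]
  y = bijector.forward(x)                      # scale.x + loc -> [[2., 0.]]
  x_back = bijector.inverse(y)                 # recovers [[1., 1.]]
  ildj = bijector.inverse_log_det_jacobian(y)  # -sum(log(diag(scale))) -> [0.]
  print(sess.run([y, x_back, ildj]))
```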
diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD
index 72d8629a43..3369437539 100644
--- a/tensorflow/contrib/distributions/BUILD
+++ b/tensorflow/contrib/distributions/BUILD
@@ -299,6 +299,17 @@ cuda_py_tests(
)
cuda_py_tests(
+ name = "distribution_util_test",
+ size = "small",
+ srcs = ["python/kernel_tests/distribution_util_test.py"],
+ additional_deps = [
+ ":distributions_py",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+cuda_py_tests(
name = "shape_test",
size = "small",
srcs = ["python/kernel_tests/shape_test.py"],
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijector_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijector_test.py
index fd2cf58fd2..fe4ac93171 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijector_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijector_test.py
@@ -20,11 +20,12 @@ from __future__ import print_function
import math
+import numpy as np
import tensorflow as tf
-from tensorflow.contrib.distributions.python.ops.bijector import _Exp # pylint: disable=line-too-long
-from tensorflow.contrib.distributions.python.ops.bijector import _Identity # pylint: disable=line-too-long
-from tensorflow.contrib.distributions.python.ops.shape import _ShapeUtil # pylint: disable=line-too-long
+from tensorflow.contrib.distributions.python.ops.bijector import _Exp
+from tensorflow.contrib.distributions.python.ops.bijector import _Identity
+from tensorflow.contrib.distributions.python.ops.bijector import _ShiftAndScale
class IdentityBijectorTest(tf.test.TestCase):
@@ -32,16 +33,16 @@ class IdentityBijectorTest(tf.test.TestCase):
def testBijector(self):
with self.test_session():
- bijector = _Identity(_ShapeUtil(batch_ndims=1, event_ndims=1))
- self.assertEqual(bijector.name, 'Identity')
- x = [[[0.], [1]]]
- self.assertAllEqual(bijector.forward(x).eval(), x)
- self.assertAllEqual(bijector.inverse(x).eval(), x)
- self.assertAllEqual(bijector.inverse_log_det_jacobian(x).eval(),
- [[0., 0]])
+ bijector = _Identity()
+ self.assertEqual("Identity", bijector.name)
+ x = [[[0.],
+ [1.]]]
+ self.assertAllEqual(x, bijector.forward(x).eval())
+ self.assertAllEqual(x, bijector.inverse(x).eval())
+ self.assertAllEqual(0., bijector.inverse_log_det_jacobian(x).eval())
rev, jac = bijector.inverse_and_inverse_log_det_jacobian(x)
- self.assertAllEqual(rev.eval(), x)
- self.assertAllEqual(jac.eval(), [[0., 0]])
+ self.assertAllEqual(x, rev.eval())
+ self.assertAllEqual(0., jac.eval())
class ExpBijectorTest(tf.test.TestCase):
@@ -49,19 +50,240 @@ class ExpBijectorTest(tf.test.TestCase):
def testBijector(self):
with self.test_session():
- bijector = _Exp(_ShapeUtil(batch_ndims=1, event_ndims=1))
- self.assertEqual(bijector.name, 'Exp')
- x = [[[1.], [2]]]
- self.assertAllClose(bijector.forward(x).eval(),
- [[[math.exp(1.)], [math.exp(2.)]]])
- self.assertAllClose(bijector.inverse(x).eval(),
- [[[math.log(1.)], [math.log(2.)]]])
- self.assertAllClose(bijector.inverse_log_det_jacobian(x).eval(),
- [[0., -math.log(2.)]])
+ bijector = _Exp(event_ndims=1)
+ self.assertEqual("Exp", bijector.name)
+ x = [[[1.],
+ [2.]]]
+ self.assertAllClose(np.exp(x), bijector.forward(x).eval())
+ self.assertAllClose(np.log(x), bijector.inverse(x).eval())
+ self.assertAllClose([[0., -math.log(2.)]],
+ bijector.inverse_log_det_jacobian(x).eval())
rev, jac = bijector.inverse_and_inverse_log_det_jacobian(x)
- self.assertAllClose(rev.eval(), [[[math.log(1.)], [math.log(2.)]]])
- self.assertAllClose(jac.eval(), [[0., -math.log(2.)]])
+ self.assertAllClose(np.log(x), rev.eval())
+ self.assertAllClose([[0., -math.log(2.)]], jac.eval())
-if __name__ == '__main__':
+class _ShiftAndScaleBijectorTest(tf.test.TestCase):
+
+ def testProperties(self):
+ with self.test_session():
+ mu = -1.
+ sigma = 2.
+ bijector = _ShiftAndScale(loc=mu, scale=sigma)
+ self.assertEqual("ShiftAndScale", bijector.name)
+
+ def testNoBatchScalar(self):
+ with self.test_session() as sess:
+ def static_run(fun, x):
+ return fun(x).eval()
+
+ def dynamic_run(fun, x_value):
+ x_value = np.array(x_value)
+ x = tf.placeholder(tf.float32, name="x")
+ return sess.run(fun(x), feed_dict={x: x_value})
+
+ for run in (static_run, dynamic_run):
+ mu = -1.
+ sigma = 2. # Scalar.
+ bijector = _ShiftAndScale(loc=mu, scale=sigma)
+ self.assertEqual(0, bijector.shaper.batch_ndims.eval()) # "no batches"
+ self.assertEqual(0, bijector.shaper.event_ndims.eval()) # "is scalar"
+ x = [1., 2, 3] # Three scalar samples (no batches).
+ self.assertAllClose([1., 3, 5], run(bijector.forward, x))
+ self.assertAllClose([1., 1.5, 2.], run(bijector.inverse, x))
+ self.assertAllClose([-math.log(2.)],
+ run(bijector.inverse_log_det_jacobian, x))
+
+ def testWeirdSampleNoBatchScalar(self):
+ with self.test_session() as sess:
+ def static_run(fun, x):
+ return fun(x).eval()
+
+ def dynamic_run(fun, x_value):
+ x_value = np.array(x_value)
+ x = tf.placeholder(tf.float32, name="x")
+ return sess.run(fun(x), feed_dict={x: x_value})
+
+ for run in (static_run, dynamic_run):
+ mu = -1.
+ sigma = 2. # Scalar.
+ bijector = _ShiftAndScale(loc=mu, scale=sigma)
+ self.assertEqual(0, bijector.shaper.batch_ndims.eval()) # "no batches"
+ self.assertEqual(0, bijector.shaper.event_ndims.eval()) # "is scalar"
+ x = [[1., 2, 3],
+ [4, 5, 6]] # Weird sample shape.
+ self.assertAllClose([[1., 3, 5],
+ [7, 9, 11]],
+ run(bijector.forward, x))
+ self.assertAllClose([[1., 1.5, 2.],
+ [2.5, 3, 3.5]],
+ run(bijector.inverse, x))
+ self.assertAllClose([-math.log(2.)],
+ run(bijector.inverse_log_det_jacobian, x))
+
+ def testOneBatchScalar(self):
+ with self.test_session() as sess:
+ def static_run(fun, x):
+ return fun(x).eval()
+
+ def dynamic_run(fun, x_value):
+ x_value = np.array(x_value)
+ x = tf.placeholder(tf.float32, name="x")
+ return sess.run(fun(x), feed_dict={x: x_value})
+
+ for run in (static_run, dynamic_run):
+ mu = [1.]
+ sigma = [1.] # One batch, scalar.
+ bijector = _ShiftAndScale(loc=mu, scale=sigma)
+ self.assertEqual(
+ 1, bijector.shaper.batch_ndims.eval()) # "one batch dim"
+ self.assertEqual(
+ 0, bijector.shaper.event_ndims.eval()) # "is scalar"
+        x = [1.]  # One sample from one batch.
+ self.assertAllClose([2.], run(bijector.forward, x))
+ self.assertAllClose([0.], run(bijector.inverse, x))
+ self.assertAllClose([0.],
+ run(bijector.inverse_log_det_jacobian, x))
+
+ def testTwoBatchScalar(self):
+ with self.test_session() as sess:
+ def static_run(fun, x):
+ return fun(x).eval()
+
+ def dynamic_run(fun, x_value):
+ x_value = np.array(x_value)
+ x = tf.placeholder(tf.float32, name="x")
+ return sess.run(fun(x), feed_dict={x: x_value})
+
+ for run in (static_run, dynamic_run):
+ mu = [1., -1]
+ sigma = [1., 1] # Univariate, two batches.
+ bijector = _ShiftAndScale(loc=mu, scale=sigma)
+ self.assertEqual(
+ 1, bijector.shaper.batch_ndims.eval()) # "one batch dim"
+ self.assertEqual(
+ 0, bijector.shaper.event_ndims.eval()) # "is scalar"
+ x = [1., 1] # One sample from each of two batches.
+ self.assertAllClose([2., 0], run(bijector.forward, x))
+ self.assertAllClose([0., 2], run(bijector.inverse, x))
+ self.assertAllClose([0., 0],
+ run(bijector.inverse_log_det_jacobian, x))
+
+ def testNoBatchMultivariate(self):
+ with self.test_session() as sess:
+ def static_run(fun, x):
+ return fun(x).eval()
+
+ def dynamic_run(fun, x_value):
+ x_value = np.array(x_value)
+ x = tf.placeholder(tf.float32, name="x")
+ return sess.run(fun(x), feed_dict={x: x_value})
+
+ for run in (static_run, dynamic_run):
+ mu = [1., -1]
+ sigma = np.eye(2, dtype=np.float32)
+ bijector = _ShiftAndScale(loc=mu, scale=sigma, event_ndims=1)
+ self.assertEqual(0, bijector.shaper.batch_ndims.eval()) # "no batches"
+ self.assertEqual(1, bijector.shaper.event_ndims.eval()) # "is vector"
+ x = [1., 1]
+ self.assertAllClose([2., 0], run(bijector.forward, x))
+ self.assertAllClose([0., 2], run(bijector.inverse, x))
+ self.assertAllClose([0.], run(bijector.inverse_log_det_jacobian, x))
+
+ x = [[1., 1],
+ [-1., -1]]
+ self.assertAllClose([[2., 0],
+ [0, -2]],
+ run(bijector.forward, x))
+ self.assertAllClose([[0., 2],
+ [-2., 0]],
+ run(bijector.inverse, x))
+ self.assertAllClose([0.], run(bijector.inverse_log_det_jacobian, x))
+
+ # When mu is a scalar and x is multivariate then the location is
+ # broadcast.
+ for run in (static_run, dynamic_run):
+ mu = 1.
+ sigma = np.eye(2, dtype=np.float32)
+ bijector = _ShiftAndScale(loc=mu, scale=sigma, event_ndims=1)
+ self.assertEqual(0, bijector.shaper.batch_ndims.eval()) # "no batches"
+ self.assertEqual(1, bijector.shaper.event_ndims.eval()) # "is vector"
+ x = [1., 1]
+ self.assertAllClose([2., 2], run(bijector.forward, x))
+ self.assertAllClose([0., 0], run(bijector.inverse, x))
+ self.assertAllClose([0.], run(bijector.inverse_log_det_jacobian, x))
+ x = [[1., 1]]
+ self.assertAllClose([[2., 2]], run(bijector.forward, x))
+ self.assertAllClose([[0., 0]], run(bijector.inverse, x))
+ self.assertAllClose([0.], run(bijector.inverse_log_det_jacobian, x))
+
+ def testNoBatchMultivariateFullDynamic(self):
+ with self.test_session() as sess:
+ x = tf.placeholder(tf.float32, name="x")
+ mu = tf.placeholder(tf.float32, name="mu")
+ sigma = tf.placeholder(tf.float32, name="sigma")
+ event_ndims = tf.placeholder(tf.int32, name="event_ndims")
+
+ x_value = np.array([[1., 1]], dtype=np.float32)
+ mu_value = np.array([1., -1], dtype=np.float32)
+ sigma_value = np.eye(2, dtype=np.float32)
+ event_ndims_value = np.array(1, dtype=np.int32)
+ feed_dict = {x: x_value, mu: mu_value, sigma: sigma_value, event_ndims:
+ event_ndims_value}
+
+ bijector = _ShiftAndScale(loc=mu, scale=sigma, event_ndims=event_ndims)
+ self.assertEqual(0, sess.run(bijector.shaper.batch_ndims, feed_dict))
+ self.assertEqual(1, sess.run(bijector.shaper.event_ndims, feed_dict))
+ self.assertAllClose([[2., 0]], sess.run(bijector.forward(x), feed_dict))
+ self.assertAllClose([[0., 2]], sess.run(bijector.inverse(x), feed_dict))
+ self.assertAllClose(
+ [0.], sess.run(bijector.inverse_log_det_jacobian(x), feed_dict))
+
+ def testBatchMultivariate(self):
+ with self.test_session() as sess:
+ def static_run(fun, x):
+ return fun(x).eval()
+
+ def dynamic_run(fun, x_value):
+ x_value = np.array(x_value, dtype=np.float32)
+ x = tf.placeholder(tf.float32, name="x")
+ return sess.run(fun(x), feed_dict={x: x_value})
+
+ for run in (static_run, dynamic_run):
+ mu = [[1., -1]]
+ sigma = np.array([np.eye(2, dtype=np.float32)])
+ bijector = _ShiftAndScale(loc=mu, scale=sigma, event_ndims=1)
+ self.assertEqual(
+ 1, bijector.shaper.batch_ndims.eval()) # "one batch dim"
+ self.assertEqual(
+ 1, bijector.shaper.event_ndims.eval()) # "is vector"
+ x = [[[1., 1]]]
+ self.assertAllClose([[[2., 0]]], run(bijector.forward, x))
+ self.assertAllClose([[[0., 2]]], run(bijector.inverse, x))
+ self.assertAllClose([0.], run(bijector.inverse_log_det_jacobian, x))
+
+ def testBatchMultivariateFullDynamic(self):
+ with self.test_session() as sess:
+ x = tf.placeholder(tf.float32, name="x")
+ mu = tf.placeholder(tf.float32, name="mu")
+ sigma = tf.placeholder(tf.float32, name="sigma")
+ event_ndims = tf.placeholder(tf.int32, name="event_ndims")
+
+ x_value = np.array([[[1., 1]]], dtype=np.float32)
+ mu_value = np.array([[1., -1]], dtype=np.float32)
+ sigma_value = np.array([np.eye(2, dtype=np.float32)])
+ event_ndims_value = np.array(1, dtype=np.int32)
+ feed_dict = {x: x_value, mu: mu_value, sigma: sigma_value,
+ event_ndims: event_ndims_value}
+
+ bijector = _ShiftAndScale(loc=mu, scale=sigma, event_ndims=event_ndims)
+ self.assertEqual(1, sess.run(bijector.shaper.batch_ndims, feed_dict))
+ self.assertEqual(1, sess.run(bijector.shaper.event_ndims, feed_dict))
+ self.assertAllClose([[[2., 0]]], sess.run(bijector.forward(x), feed_dict))
+ self.assertAllClose([[[0., 2]]], sess.run(bijector.inverse(x), feed_dict))
+ self.assertAllClose(
+ [0.], sess.run(bijector.inverse_log_det_jacobian(x), feed_dict))
+
+
+if __name__ == "__main__":
tf.test.main()
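
For reference, the `[[0., -math.log(2.)]]` expectation in `ExpBijectorTest` is just the inverse log-Jacobian of `exp` reduced over the event dimension; a NumPy-only sketch of that arithmetic (illustrative, no TensorFlow required):

```python
import numpy as np

# Same input as the test: shape [1, 2, 1] = one sample, two batch members,
# length-1 event vectors.
p = np.array([[[1.], [2.]]])

# The inverse of y = exp(x) is log(y); its log|Jacobian| at p is -log(p),
# reduced over the event dimension (event_ndims=1 -> the last axis).
ildj = -np.sum(np.log(p), axis=-1)

assert np.allclose(ildj, [[0., -np.log(2.)]])
```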
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py b/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py
new file mode 100644
index 0000000000..b76fc9afef
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py
@@ -0,0 +1,79 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for utility functions."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.contrib.distributions.python.ops import distribution_util
+
+
+class DistributionUtilTest(tf.test.TestCase):
+
+ def _np_rotate_transpose(self, x, shift):
+ if not isinstance(x, np.ndarray):
+ x = np.array(x)
+ return np.transpose(x, np.roll(np.arange(len(x.shape)), shift))
+
+ def testRollStatic(self):
+ with self.test_session():
+ with self.assertRaisesRegexp(
+ ValueError, "None values not supported."):
+ distribution_util.rotate_transpose(None, 1)
+ for x in (np.ones(1), np.ones((2, 1)), np.ones((3, 2, 1))):
+ for shift in np.arange(-5, 5):
+ y = distribution_util.rotate_transpose(x, shift)
+ self.assertAllEqual(self._np_rotate_transpose(x, shift),
+ y.eval())
+ self.assertAllEqual(np.roll(x.shape, shift),
+ y.get_shape().as_list())
+
+ def testRollDynamic(self):
+ with self.test_session() as sess:
+ x = tf.placeholder(tf.float32)
+ shift = tf.placeholder(tf.int32)
+ for x_value in (np.ones(1, dtype=x.dtype.as_numpy_dtype()),
+ np.ones((2, 1), dtype=x.dtype.as_numpy_dtype()),
+ np.ones((3, 2, 1), dtype=x.dtype.as_numpy_dtype())):
+ for shift_value in np.arange(-5, 5):
+ self.assertAllEqual(
+ self._np_rotate_transpose(x_value, shift_value),
+ sess.run(distribution_util.rotate_transpose(x, shift),
+ feed_dict={x: x_value, shift: shift_value}))
+
+ def testChooseVector(self):
+ with self.test_session():
+ x = np.arange(10, 12)
+ y = np.arange(15, 18)
+ self.assertAllEqual(
+ x, distribution_util.pick_vector(
+ tf.less(0, 5), x, y).eval())
+ self.assertAllEqual(
+ y, distribution_util.pick_vector(
+ tf.less(5, 0), x, y).eval())
+ self.assertAllEqual(
+ x, distribution_util.pick_vector(
+ tf.constant(True), x, y)) # No eval.
+ self.assertAllEqual(
+ y, distribution_util.pick_vector(
+ tf.constant(False), x, y)) # No eval.
+
+
+if __name__ == "__main__":
+ tf.test.main()
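
The reference semantics these tests encode fit in a few lines of NumPy; a small sketch mirroring `_np_rotate_transpose` above (illustrative only):

```python
import numpy as np

def np_rotate_transpose(x, shift):
  """Circularly rotates the dimension order of `x` by `shift`."""
  x = np.asarray(x)
  return np.transpose(x, np.roll(np.arange(x.ndim), shift))

x = np.zeros((1, 2, 3, 4))
print(np_rotate_transpose(x, -1).shape)  # (2, 3, 4, 1)
print(np_rotate_transpose(x, 1).shape)   # (4, 1, 2, 3)
print(np_rotate_transpose(x, 7).shape)   # same as shift=3: (2, 3, 4, 1)
```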
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/shape_test.py b/tensorflow/contrib/distributions/python/kernel_tests/shape_test.py
index 351c69c747..f4173f5d05 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/shape_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/shape_test.py
@@ -20,145 +20,343 @@ from __future__ import print_function
import numpy as np
import tensorflow as tf
-from tensorflow.contrib.distributions.python.ops.shape import _ShapeUtil # pylint: disable=line-too-long
+from tensorflow.contrib.distributions.python.ops.shape import _DistributionShape
+from tensorflow.python.framework import tensor_util
-class ShapeUtilTest(tf.test.TestCase):
- def testShapeUtilGetNdims(self):
+_empty_shape = np.array([], dtype=np.int32)
+
+
+def _eval(x):
+ if hasattr(x, "__iter__"):
+ return [x.eval() for x in x]
+ return x.eval()
+
+
+def _constant(x):
+ if hasattr(x, "__iter__"):
+ return [tensor_util.constant_value(x) for x in x]
+ return tensor_util.constant_value(x)
+
+
+class DistributionShapeTest(tf.test.TestCase):
+
+ def setUp(self):
+ self._rng = np.random.RandomState(42)
+
+ def _random_sample(self, sample_shape, dtype=tf.float64):
+ return self._rng.random_sample(sample_shape).astype(dtype.as_numpy_dtype())
+
+ def _assertNdArrayEqual(self, expected, actual):
+ """Helper which properly compares two np.ndarray-like objects.
+
+ This function checks for exact equality so is probably only suitable for
+ integers or powers of 2.
+
+ Args:
+ expected: np.ndarray. Ground-truth value.
+ actual: np.ndarray. Observed value.
+ """
+ expected = np.asarray(expected)
+ actual = np.asarray(actual)
+ self.assertEqual(
+ expected.shape, actual.shape,
+ "Shape mismatch: expected %s, got %s." % (expected.shape, actual.shape))
+ actual_item = actual.flat
+ for expected_item in expected.flat:
+ self.assertAllEqual(expected_item, next(actual_item))
+
+ def testDistributionShapeGetNdimsStatic(self):
with self.test_session():
- shaper = _ShapeUtil(batch_ndims=0, event_ndims=0)
+ shaper = _DistributionShape(batch_ndims=0, event_ndims=0)
x = 1
- self.assertEqual(shaper.get_sample_ndims(x), 0)
- self.assertEqual(shaper.batch_ndims, 0)
- self.assertEqual(shaper.event_ndims, 0)
-
- shaper = _ShapeUtil(batch_ndims=1, event_ndims=1)
- x = [[[0., 1, 2], [3, 4, 5]]]
- self.assertAllEqual(shaper.get_ndims(x), 3)
- self.assertEqual(shaper.get_sample_ndims(x), 1)
- self.assertEqual(shaper.batch_ndims, 1)
- self.assertEqual(shaper.event_ndims, 1)
-
- x += [[[6, 7, 8], [9, 10, 11]]]
- self.assertAllEqual(shaper.get_ndims(x), 3)
- self.assertEqual(shaper.get_sample_ndims(x), 1)
- self.assertEqual(shaper.batch_ndims, 1)
- self.assertEqual(shaper.event_ndims, 1)
+ self.assertEqual(0, shaper.get_sample_ndims(x).eval())
+ self.assertEqual(0, shaper.batch_ndims.eval())
+ self.assertEqual(0, shaper.event_ndims.eval())
+
+ shaper = _DistributionShape(batch_ndims=1, event_ndims=1)
+ x = self._random_sample((1, 2, 3))
+ self.assertAllEqual(3, shaper.get_ndims(x).eval())
+ self.assertEqual(1, shaper.get_sample_ndims(x).eval())
+ self.assertEqual(1, shaper.batch_ndims.eval())
+ self.assertEqual(1, shaper.event_ndims.eval())
+
+ x += self._random_sample((1, 2, 3))
+ self.assertAllEqual(3, shaper.get_ndims(x).eval())
+ self.assertEqual(1, shaper.get_sample_ndims(x).eval())
+ self.assertEqual(1, shaper.batch_ndims.eval())
+ self.assertEqual(1, shaper.event_ndims.eval())
# Test ndims functions work, even despite unfed Tensors.
y = tf.placeholder(tf.float32, shape=(1024, None, 1024))
- self.assertAllEqual(shaper.get_ndims(y), 3)
- self.assertEqual(shaper.get_sample_ndims(y), 1)
- self.assertEqual(shaper.batch_ndims, 1)
- self.assertEqual(shaper.event_ndims, 1)
+ self.assertEqual(3, shaper.get_ndims(y).eval())
+ self.assertEqual(1, shaper.get_sample_ndims(y).eval())
+ self.assertEqual(1, shaper.batch_ndims.eval())
+ self.assertEqual(1, shaper.event_ndims.eval())
- with self.assertRaises(ValueError):
- y = tf.placeholder(tf.float32)
- shaper.get_ndims(y)
+ def testDistributionShapeGetNdimsDynamic(self):
+ with self.test_session() as sess:
+ batch_ndims = tf.placeholder(tf.int32)
+ event_ndims = tf.placeholder(tf.int32)
+ shaper = _DistributionShape(batch_ndims=batch_ndims,
+ event_ndims=event_ndims)
+ y = tf.placeholder(tf.float32)
+ y_value = np.ones((4, 2), dtype=y.dtype.as_numpy_dtype())
+ feed_dict = {y: y_value, batch_ndims: 1, event_ndims: 1}
+ self.assertEqual(2, sess.run(shaper.get_ndims(y),
+ feed_dict=feed_dict))
- def testShapeUtilGetDims(self):
+ def testDistributionShapeGetDimsStatic(self):
with self.test_session():
- shaper = _ShapeUtil(batch_ndims=0, event_ndims=0)
- with self.assertRaises(ValueError):
- y = tf.placeholder(tf.float32)
- shaper.get_sample_dims(y)
- with self.assertRaises(ValueError):
- y = tf.placeholder(tf.float32)
- shaper.get_batch_dims(y)
- with self.assertRaises(ValueError):
- y = tf.placeholder(tf.float32)
- shaper.get_event_dims(y)
-
- shaper = _ShapeUtil(batch_ndims=0, event_ndims=0)
+ shaper = _DistributionShape(batch_ndims=0, event_ndims=0)
+ shaper = _DistributionShape(batch_ndims=0, event_ndims=0)
x = 1
- self.assertAllEqual(shaper.get_sample_dims(x), [])
- self.assertAllEqual(shaper.get_batch_dims(x), [])
- self.assertAllEqual(shaper.get_event_dims(x), [])
- self.assertAllEqual(shaper.get_dims(x, sample=False), [])
-
- shaper = _ShapeUtil(batch_ndims=1, event_ndims=2)
- x = [[[[0., 1], [2, 4]]]]
- self.assertAllEqual(shaper.get_sample_dims(x), [0])
- self.assertAllEqual(shaper.get_batch_dims(x), [1])
- self.assertAllEqual(shaper.get_event_dims(x), [2, 3])
- self.assertAllEqual(shaper.get_dims(x, sample=False), [1, 2, 3])
-
+ self.assertAllEqual((_empty_shape, _empty_shape, _empty_shape),
+ _constant(shaper.get_dims(x)))
+ shaper = _DistributionShape(batch_ndims=1, event_ndims=2)
+ x += self._random_sample((1, 1, 2, 2))
+ self._assertNdArrayEqual(
+ ([0], [1], [2, 3]),
+ _constant(shaper.get_dims(x)))
x += x
- self.assertAllEqual(shaper.get_sample_dims(x), [0])
- self.assertAllEqual(shaper.get_batch_dims(x), [1])
- self.assertAllEqual(shaper.get_event_dims(x), [2, 3])
- self.assertAllEqual(shaper.get_dims(x, sample=False), [1, 2, 3])
-
- # Test dims functions work, despite unfed Tensors.
- y = tf.placeholder(tf.float32, shape=(1024, None, 5, 5))
- self.assertAllEqual(shaper.get_sample_dims(y), [0])
- self.assertAllEqual(shaper.get_batch_dims(y), [1])
- self.assertAllEqual(shaper.get_event_dims(y), [2, 3])
-
- def testShapeUtilGetShape(self):
+ self._assertNdArrayEqual(
+ ([0], [1], [2, 3]),
+ _constant(shaper.get_dims(x)))
+
+ def testDistributionShapeGetDimsDynamic(self):
with self.test_session() as sess:
- shaper = _ShapeUtil(batch_ndims=0, event_ndims=0)
- with self.assertRaises(ValueError):
- y = tf.placeholder(tf.float32)
- shaper.get_sample_shape(y)
- with self.assertRaises(ValueError):
- y = tf.placeholder(tf.float32)
- shaper.get_batch_shape(y)
- with self.assertRaises(ValueError):
- y = tf.placeholder(tf.float32)
- shaper.get_event_shape(y)
-
- shaper = _ShapeUtil(batch_ndims=0, event_ndims=0)
- x = 1
- self.assertAllEqual(shaper.get_sample_shape(x), [])
- self.assertAllEqual(shaper.get_batch_shape(x), [])
- self.assertAllEqual(shaper.get_event_shape(x), [])
- self.assertAllEqual(shaper.get_shape(x, batch=False), [])
-
- shaper = _ShapeUtil(batch_ndims=1, event_ndims=1)
- x = [[[0., 1, 2], [3, 4, 5]]]
- self.assertAllEqual(shaper.get_sample_shape(x), [1])
- self.assertAllEqual(shaper.get_batch_shape(x), [2])
- self.assertAllEqual(shaper.get_event_shape(x), [3])
- self.assertAllEqual(shaper.get_shape(x, batch=False), [1, 3])
-
- x += [[[6, 7, 8], [9, 10, 11]]]
- self.assertAllEqual(shaper.get_sample_shape(x), [2])
- self.assertAllEqual(shaper.get_batch_shape(x), [2])
- self.assertAllEqual(shaper.get_event_shape(x), [3])
- self.assertAllEqual(shaper.get_shape(x, batch=False), [2, 3])
-
- shaper = _ShapeUtil(batch_ndims=0, event_ndims=1)
- x = tf.ones((3, 2))
- self.assertAllEqual(shaper.get_shape(x, sample=False), (2,))
-
- def feed_eval(fun, build_shape=(None, None, 2), graph_shape=(3, 4, 2)):
- """Helper to use a deferred-shape tensor eval'ed at graph runtime."""
- y = tf.placeholder(tf.int32, shape=build_shape)
- y_value = np.ones(graph_shape, dtype=y.dtype.as_numpy_dtype())
- return sess.run(fun(y),
- feed_dict={y: y_value})
-
- shaper = _ShapeUtil(batch_ndims=1, event_ndims=1)
- self.assertAllEqual(feed_eval(shaper.get_sample_shape), [3])
- self.assertAllEqual(feed_eval(shaper.get_batch_shape), [4])
- self.assertAllEqual(feed_eval(shaper.get_event_shape), [2])
+ # Works for static {batch,event}_ndims despite unfed input.
+ shaper = _DistributionShape(batch_ndims=1, event_ndims=2)
+ y = tf.placeholder(tf.float32, shape=(10, None, 5, 5))
+ self._assertNdArrayEqual([[0], [1], [2, 3]], _eval(shaper.get_dims(y)))
+
+ # Works for deferred {batch,event}_ndims.
+ batch_ndims = tf.placeholder(tf.int32)
+ event_ndims = tf.placeholder(tf.int32)
+ shaper = _DistributionShape(batch_ndims=batch_ndims,
+ event_ndims=event_ndims)
+ y = tf.placeholder(tf.float32)
+ y_value = self._random_sample((10, 3, 5, 5), dtype=y.dtype)
+ feed_dict = {y: y_value, batch_ndims: 1, event_ndims: 2}
+ self._assertNdArrayEqual(
+ ([0], [1], [2, 3]),
+ sess.run(shaper.get_dims(y), feed_dict=feed_dict))
+
+ def testDistributionShapeGetShapeStatic(self):
+ with self.test_session():
+ shaper = _DistributionShape(batch_ndims=0, event_ndims=0)
+ self.assertAllEqual((_empty_shape, _empty_shape, _empty_shape),
+ _constant(shaper.get_shape(1.)))
+ self._assertNdArrayEqual(([1], _empty_shape, _empty_shape),
+ _constant(shaper.get_shape(np.ones(1))))
+ self._assertNdArrayEqual(([2, 2], _empty_shape, _empty_shape),
+ _constant(shaper.get_shape(np.ones((2, 2)))))
+ self._assertNdArrayEqual(([3, 2, 1], _empty_shape, _empty_shape),
+ _constant(shaper.get_shape(np.ones((3, 2, 1)))))
+
+ shaper = _DistributionShape(batch_ndims=0, event_ndims=1)
+ with self.assertRaisesRegexp(ValueError, "expected .* <= ndims"):
+ shaper.get_shape(1.)
+ self._assertNdArrayEqual((_empty_shape, _empty_shape, [1]),
+ _constant(shaper.get_shape(np.ones(1))))
+ self._assertNdArrayEqual(([2], _empty_shape, [2]),
+ _constant(shaper.get_shape(np.ones((2, 2)))))
+ self._assertNdArrayEqual(([3, 2], _empty_shape, [1]),
+ _constant(shaper.get_shape(np.ones((3, 2, 1)))))
+
+ shaper = _DistributionShape(batch_ndims=1, event_ndims=0)
+ with self.assertRaisesRegexp(ValueError, "expected .* <= ndims"):
+ shaper.get_shape(1.)
+ self._assertNdArrayEqual((_empty_shape, [1], _empty_shape),
+ _constant(shaper.get_shape(np.ones(1))))
+ self._assertNdArrayEqual(([2], [2], _empty_shape),
+ _constant(shaper.get_shape(np.ones((2, 2)))))
+ self._assertNdArrayEqual(([3, 2], [1], _empty_shape),
+ _constant(shaper.get_shape(np.ones((3, 2, 1)))))
+
+ shaper = _DistributionShape(batch_ndims=1, event_ndims=1)
+ with self.assertRaisesRegexp(ValueError, "expected .* <= ndims"):
+ shaper.get_shape(1.)
+ with self.assertRaisesRegexp(ValueError, "expected .* <= ndims"):
+ shaper.get_shape(np.ones(1))
+ self._assertNdArrayEqual((_empty_shape, [2], [2]),
+ _constant(shaper.get_shape(np.ones((2, 2)))))
+ self._assertNdArrayEqual(([3], [2], [1]),
+ _constant(shaper.get_shape(np.ones((3, 2, 1)))))
+
+ def testDistributionShapeGetShapeDynamic(self):
+ with self.test_session() as sess:
+ # Works for static ndims despite unknown static shape.
+ shaper = _DistributionShape(batch_ndims=1, event_ndims=1)
+ y = tf.placeholder(tf.int32, shape=(None, None, 2))
+ y_value = np.ones((3, 4, 2), dtype=y.dtype.as_numpy_dtype())
+ self._assertNdArrayEqual(
+ ([3], [4], [2]),
+ sess.run(shaper.get_shape(y), feed_dict={y: y_value}))
+
+ shaper = _DistributionShape(batch_ndims=0, event_ndims=1)
+ y = tf.placeholder(tf.int32, shape=(None, None))
+ y_value = np.ones((3, 2), dtype=y.dtype.as_numpy_dtype())
+ self._assertNdArrayEqual(
+ ([3], _empty_shape, [2]),
+ sess.run(shaper.get_shape(y), feed_dict={y: y_value}))
+
+ # Works for deferred {batch,event}_ndims.
+ batch_ndims = tf.placeholder(tf.int32)
+ event_ndims = tf.placeholder(tf.int32)
+ shaper = _DistributionShape(batch_ndims=batch_ndims,
+ event_ndims=event_ndims)
+ y = tf.placeholder(tf.float32)
+ y_value = self._random_sample((3, 4, 2), dtype=y.dtype)
+ feed_dict = {y: y_value, batch_ndims: 1, event_ndims: 1}
+ self._assertNdArrayEqual(
+ ([3], [4], [2]),
+ sess.run(shaper.get_shape(y), feed_dict=feed_dict))
+
+ y_value = self._random_sample((3, 2), dtype=y.dtype)
+ feed_dict = {y: y_value, batch_ndims: 0, event_ndims: 1}
+ self._assertNdArrayEqual(
+ ([3], _empty_shape, [2]),
+ sess.run(shaper.get_shape(y), feed_dict=feed_dict))
+
+ def testDistributionShapeMakeBatchReadyStatic(self):
+ with self.test_session() as sess:
+ x = self._random_sample((1, 2, 3))
+ shaper = _DistributionShape(batch_ndims=1, event_ndims=1)
+ y, sample_shape = shaper.make_batch_of_event_sample_matrices(x)
+ self.assertAllEqual(np.transpose(x, axes=(1, 2, 0)), y.eval())
+ self.assertAllEqual((1,), sample_shape.eval())
+ should_be_x_value = shaper.undo_make_batch_of_event_sample_matrices(
+ y, sample_shape)
+ self.assertAllEqual(x, should_be_x_value.eval())
+
+ shaper = _DistributionShape(batch_ndims=1, event_ndims=1)
+ x = tf.placeholder(tf.float32)
+ x_value = self._random_sample((3, 4, 2), dtype=x.dtype)
+ feed_dict = {x: x_value}
+ y, sample_shape = shaper.make_batch_of_event_sample_matrices(x)
+ self.assertAllEqual(
+ (3,),
+ sess.run(sample_shape, feed_dict=feed_dict))
+ self.assertAllClose(
+ np.transpose(np.reshape(x_value, (-1, 4, 2)), (1, 2, 0)),
+ sess.run(y, feed_dict=feed_dict),
+ rtol=1e-3)
+ should_be_x_value = shaper.undo_make_batch_of_event_sample_matrices(
+ y, sample_shape)
+ self.assertAllEqual(x_value, sess.run(should_be_x_value,
+ feed_dict=feed_dict))
+
+ shaper = _DistributionShape(batch_ndims=0, event_ndims=0)
+ x = tf.placeholder(tf.float32)
+ x_value = np.ones((3,), dtype=x.dtype.as_numpy_dtype())
+ feed_dict = {x: x_value}
+ y, sample_shape = shaper.make_batch_of_event_sample_matrices(x)
+ self.assertAllEqual(
+ (3,),
+ sess.run(sample_shape, feed_dict=feed_dict))
+ # The following check shows we don't need to manually set_shape in the
+ # ShapeUtil.
+ self.assertAllEqual((1, 1, None),
+ y.get_shape().ndims and y.get_shape().as_list())
+ self.assertAllEqual(
+ np.ones((1, 1, 3), dtype=x.dtype.as_numpy_dtype()),
+ sess.run(y, feed_dict=feed_dict))
+ should_be_x_value = shaper.undo_make_batch_of_event_sample_matrices(
+ y, sample_shape)
+ self.assertAllEqual(x_value, sess.run(should_be_x_value,
+ feed_dict=feed_dict))
+
+ def testDistributionShapeMakeBatchReadyDynamic(self):
+ with self.test_session() as sess:
+ shaper = _DistributionShape(batch_ndims=1, event_ndims=1)
+ x = tf.placeholder(tf.float32, shape=(1, 2, 3))
+ x_value = self._random_sample(x.get_shape().as_list(), dtype=x.dtype)
+ y, sample_shape = sess.run(
+ shaper.make_batch_of_event_sample_matrices(x),
+ feed_dict={x: x_value})
+ self.assertAllEqual(np.transpose(x_value, (1, 2, 0)), y)
+ self.assertAllEqual((1,), sample_shape)
+
+ feed_dict = {x: x_value}
+ y, sample_shape = shaper.make_batch_of_event_sample_matrices(x)
+ self.assertAllEqual(
+ (1,),
+ sess.run(sample_shape, feed_dict=feed_dict))
+ self.assertAllEqual(
+ np.transpose(x_value, (1, 2, 0)),
+ sess.run(y, feed_dict=feed_dict))
+ should_be_x_value = shaper.undo_make_batch_of_event_sample_matrices(
+ y, sample_shape)
+ self.assertAllEqual(x_value, sess.run(should_be_x_value,
+ feed_dict=feed_dict))
+
+ batch_ndims = tf.placeholder(tf.int32)
+ event_ndims = tf.placeholder(tf.int32)
+ shaper = _DistributionShape(batch_ndims=batch_ndims,
+ event_ndims=event_ndims)
+
+ # batch_ndims = 1, event_ndims = 1.
+ x = tf.placeholder(tf.float32)
+ x_value = np.ones((3, 4, 2), dtype=x.dtype.as_numpy_dtype())
+ feed_dict = {x: x_value, batch_ndims: 1, event_ndims: 1}
+ y, sample_shape = shaper.make_batch_of_event_sample_matrices(x)
+ self.assertAllEqual(
+ (3,),
+ sess.run(sample_shape, feed_dict=feed_dict))
+ self.assertAllEqual(
+ np.ones((4, 2, 3), dtype=x.dtype.as_numpy_dtype()),
+ sess.run(y, feed_dict=feed_dict))
+ should_be_x_value = shaper.undo_make_batch_of_event_sample_matrices(
+ y, sample_shape)
+ self.assertAllEqual(x_value, sess.run(should_be_x_value,
+ feed_dict=feed_dict))
+
+ # batch_ndims = 0, event_ndims = 0.
+ x_value = np.ones((3,), dtype=x.dtype.as_numpy_dtype())
+ feed_dict = {x: x_value, batch_ndims: 0, event_ndims: 0}
+ y, sample_shape = shaper.make_batch_of_event_sample_matrices(x)
+ self.assertAllEqual(
+ (3,),
+ sess.run(sample_shape, feed_dict=feed_dict))
+ self.assertAllEqual(
+ np.ones((1, 1, 3), dtype=x.dtype.as_numpy_dtype()),
+ sess.run(y, feed_dict=feed_dict))
+ should_be_x_value = shaper.undo_make_batch_of_event_sample_matrices(
+ y, sample_shape)
+ self.assertAllEqual(x_value, sess.run(should_be_x_value,
+ feed_dict=feed_dict))
+
+ # batch_ndims = 0, event_ndims = 1.
+ x_value = np.ones((1, 2,), dtype=x.dtype.as_numpy_dtype())
+ feed_dict = {x: x_value, batch_ndims: 0, event_ndims: 1}
+ y, sample_shape = shaper.make_batch_of_event_sample_matrices(x)
+ self.assertAllEqual(
+ (1,),
+ sess.run(sample_shape, feed_dict=feed_dict))
self.assertAllEqual(
- feed_eval(lambda y: shaper.get_shape(y, batch=False)),
- [3, 2])
+ np.ones((1, 2, 1), dtype=x.dtype.as_numpy_dtype()),
+ sess.run(y, feed_dict=feed_dict))
+ should_be_x_value = shaper.undo_make_batch_of_event_sample_matrices(
+ y, sample_shape)
+ self.assertAllEqual(x_value, sess.run(should_be_x_value,
+ feed_dict=feed_dict))
- shaper = _ShapeUtil(batch_ndims=0, event_ndims=1)
+ # batch_ndims = 1, event_ndims = 0.
+ x_value = np.ones((1, 2), dtype=x.dtype.as_numpy_dtype())
+ feed_dict = {x: x_value, batch_ndims: 1, event_ndims: 0}
+ y, sample_shape = shaper.make_batch_of_event_sample_matrices(x)
self.assertAllEqual(
- feed_eval(lambda y: shaper.get_shape(y, batch=False),
- (None, None),
- (3, 2)),
- [3, 2])
+ (1,),
+ sess.run(sample_shape, feed_dict=feed_dict))
self.assertAllEqual(
- feed_eval(lambda y: shaper.get_shape(y, sample=False),
- (None, None),
- (3, 2)),
- [2])
+ np.ones((2, 1, 1), dtype=x.dtype.as_numpy_dtype()),
+ sess.run(y, feed_dict=feed_dict))
+ should_be_x_value = shaper.undo_make_batch_of_event_sample_matrices(
+ y, sample_shape)
+ self.assertAllEqual(x_value, sess.run(should_be_x_value,
+ feed_dict=feed_dict))
if __name__ == "__main__":
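
The `MakeBatchReady` tests above amount to flattening the sample dimensions and rotating them behind the batch and event dimensions; a NumPy sketch of the shape bookkeeping they assert, assuming `batch_ndims=1` and `event_ndims=1` (illustrative only):

```python
import numpy as np

x = np.random.rand(1, 2, 3)   # (sample, batch, event) shapes: ([1], [2], [3]).

# make_batch_of_event_sample_matrices: flatten the sample dims and rotate them
# to the rear, so each batch member holds an event-by-sample matrix.
sample_shape = x.shape[:1]                            # (1,)
y = np.transpose(x.reshape((-1, 2, 3)), (1, 2, 0))    # shape (2, 3, 1)

# undo_make_batch_of_event_sample_matrices: rotate the sample dims back to the
# front and restore their original shape.
x_back = np.transpose(y, (2, 0, 1)).reshape(sample_shape + (2, 3))
assert np.allclose(x, x_back)
```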
diff --git a/tensorflow/contrib/distributions/python/ops/bijector.py b/tensorflow/contrib/distributions/python/ops/bijector.py
index 41e0e6f73c..c74fdc79d4 100644
--- a/tensorflow/contrib/distributions/python/ops/bijector.py
+++ b/tensorflow/contrib/distributions/python/ops/bijector.py
@@ -17,35 +17,73 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import contextlib
+
+from tensorflow.contrib.distributions.python.ops.shape import _DistributionShape
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
class _Bijector(object):
- """An interface for transforming random variable(s).
+ """An interface for transforming a `Distribution` `Tensor`.
+
+ Recall that a `Distribution` `Tensor` has dimensions which have `sample`,
+ `batch`, and `event` semantics. (See `DistributionShape` for more details.)
- A bijector is characterized by three operations:
+ A `Bijector` implements a bijective, differentiable function by transforming
+ an input `Tensor`. The output `Tensor` shape is constrained by the input
+ `sample`, `batch`, and `event` shape. A `Bijector` is characterized by three
+ operations:
- 1) Forward Evaluation
- Useful for turning one random outcome into another random outcome from a
- different distribution.
+ (1) Forward Evaluation
+ Useful for turning one random outcome into another random outcome from a
+ different distribution.
- 2) Inverse Evaluation
- Useful for "reversing" a transformation to compute one probability in terms
- of another.
+ (2) Inverse Evaluation
+ Useful for "reversing" a transformation to compute one probability in
+ terms of another.
- 3) (log o det o Jacobian o inverse)(x)
- "The log of the determinant of the matrix of all first-order partial
- derivatives of the inverse function."
- Useful for inverting a transformation to compute one probability in terms
- of another. Geometrically, the det(Jacobian) is the volume of the
- transformation and is used to scale the probability.
+ (3) (log o det o Jacobian o inverse)(x)
+ "The log of the determinant of the matrix of all first-order partial
+ derivatives of the inverse function."
+ Useful for inverting a transformation to compute one probability in terms
+ of another. Geometrically, the det(Jacobian) is the volume of the
+ transformation and is used to scale the probability.
By convention, transformations of random variables are named in terms of the
forward transformation. The forward transformation creates samples, the
inverse is useful for computing probabilities.
+ Example Use:
+ Basic properties:
+
+ ```python
+ x = ... # A tensor.
+ # Evaluate forward transformation.
+ fwd_x = my_bijector.forward(x)
+ x == my_bijector.inverse(fwd_x)
+ x != my_bijector.forward(fwd_x) # Not equal because g(x) != g(g(x)).
+ ```
+
+ Computing a log-likelihood:
+
+ ```python
+ def transformed_log_pdf(bijector, log_pdf, x):
+ return (bijector.inverse_log_det_jacobian(x) +
+ log_pdf(bijector.inverse(x)))
+ ```
+
+ Transforming a random outcome:
+
+ ```python
+ def transformed_sample(bijector, x):
+ return bijector.forward(x)
+ ```
+
Example transformations:
"Exponential"
@@ -82,60 +120,83 @@ class _Bijector(object):
MultivariateNormal(inv(sqrtSigma) * (y - mu); 0, I_d)
```
- Example use:
- Basic properties:
-
- ```python
- x = ... # A tensor.
- # Evaluate forward transformation.
- fwd_x = my_bijector.forward(x)
- x != my_bijector.forward(fwd_x) # Not equal because g(x) != g(g(x)).
- x == my_bijector.inverse(fwd_x)
- ```
-
- Computing a log-likelihood:
-
- ```python
- def transformed_log_pdf(bijector, log_pdf, x):
- return (bijector.inverse_log_det_jacobian(x) +
- log_pdf(bijector.inverse(x)))
- ```
-
- Transforming a random outcome:
-
- ```python
- def transformed_sample(bijector, x):
- return bijector.forward(x)
- ```
-
+ Example of why a `Bijector` needs to understand sample, batch, event
+ partitioning:
+ Consider the `Exp` `Bijector` applied to a `Tensor` which has sample, batch,
+ and event (S, B, E) shape semantics. Suppose
+ the `Tensor`'s partitioned-shape is `(S=[4], B=[2], E=[3, 3])`.
+
+ For `Exp`, the shape of the `Tensor` returned by `forward` and `inverse` is
+ unchanged, i.e., `[4, 2, 3, 3]`. However the shape returned by
+ `inverse_log_det_jacobian` is `[4, 2]` because the Jacobian is a reduction
+ over the event dimensions.
+
+ Subclass Requirements:
+ Subclasses are expected to implement `_forward` and one or both of:
+ - `_inverse`, `_inverse_log_det_jacobian`,
+ - `_inverse_and_inverse_log_det_jacobian`.
+
+ If computation can be shared among `_inverse` and
+ `_inverse_log_det_jacobian` it is preferable to implement
+ `_inverse_and_inverse_log_det_jacobian`. This usually reduces
+ graph-construction overhead because a `Distribution`'s implementation of
+ `log_prob` will need to evaluate both the inverse Jacobian as well as the
+ inverse function.
+
+ If an additional use case needs just `inverse` or just
+ `inverse_log_det_jacobian` then he or she may also wish to implement these
+ functions to avoid computing the `inverse_log_det_jacobian` or the
+ `inverse`, respectively.
"""
- # TODO(b/30476956): Try to remove constructor dependence on shape util.
- def __init__(self, shaper=None, name=None):
+ # TODO(b/30476956): Try to remove constructor dependence on ndims.
+ def __init__(self,
+ batch_ndims=None,
+ event_ndims=None,
+ parameters=None,
+ is_constant_jacobian=False,
+ validate_args=True,
+ dtype=None,
+ name=None):
"""Constructs Bijector.
- A bijector transforms random variables into new random variables. Managing
- shape is typically an important piece of this so a Bijector is usually
- composed of ShapeUtil. The ShapeUtil object handles input shape checks as
- well as reshaping/transposing for easier linear algebra operations.
+ A `Bijector` transforms random variables into new random variables.
+
+ Examples:
- Example:
```python
# Create the Y = g(X) = X transform which operates on 4-Tensors of vectors.
- identity = Identity(ShapeUtil(batch_ndims=4, event_ndims=1))
+ identity = Identity(batch_ndims=4, event_ndims=1)
# Create the Y = g(X) = exp(X) transform which operates on matrices.
- exp = Exp(ShapeUtil(batch_ndims=0, event_ndims=2))
+ exp = Exp(batch_ndims=0, event_ndims=2)
```
- See Bijector subclass doc for more details and examples.
+ See `Bijector` subclass docstring for more details and specific examples.
Args:
- shaper: object used for managing and manipulating shape, typically an
- instance of ShapeUtil.
+ batch_ndims: number of dimensions associated with batch coordinates.
+ event_ndims: number of dimensions associated with event coordinates.
+ parameters: Dictionary of parameters used by this `Bijector`
+ is_constant_jacobian: `Boolean` indicating that the Jacobian is not a
+ function of the input.
+ validate_args: `Boolean`. If true, Tensor arguments are
+ checked for correctness. (Non-tensor arguments are always checked.)
+ dtype: `tf.dtype` supported by this `Bijector`. `None` means dtype is not
+ enforced.
name: The name to give Ops created by the initializer.
"""
- self._shaper = shaper
+ if batch_ndims is None or event_ndims is None:
+ self._shaper = None # Apparently subclass will create.
+ else:
+ self._shaper = _DistributionShape(
+ batch_ndims=batch_ndims,
+ event_ndims=event_ndims,
+ validate_args=validate_args)
+ self._parameters = parameters or {}
+ self._is_constant_jacobian = is_constant_jacobian
+ self._validate_args = validate_args
+ self._dtype = dtype
self._name = name or type(self).__name__
@property
@@ -144,12 +205,38 @@ class _Bijector(object):
return self._shaper
@property
+ def parameters(self):
+ """Returns this `Bijector`'s parameters as a name/value dictionary."""
+ return self._parameters
+
+ @property
+ def is_constant_jacobian(self):
+ """Returns true iff the Jacobian is not a function of x.
+
+ Note: Jacobian is either constant for both forward and inverse or neither.
+
+ Returns:
+ `Boolean`.
+ """
+ return self._is_constant_jacobian
+
+ @property
+ def validate_args(self):
+ """Returns True if Tensor arguments will be validated."""
+ return self._validate_args
+
+ @property
+ def dtype(self):
+ """dtype of `Tensor`s transformable by this distribution."""
+ return self._dtype
+
+ @property
def name(self):
- """Returns the string name of this bijector."""
+ """Returns the string name of this `Bijector`."""
return self._name
- def forward(self, x, name='forward'):
- """Returns the forward bijector evaluation, i.e., X = g(Y).
+ def forward(self, x, name="forward"):
+ """Returns the forward `Bijector` evaluation, i.e., X = g(Y).
Args:
x: `Tensor`. The input to the "forward" evaluation.
@@ -157,14 +244,19 @@ class _Bijector(object):
Returns:
`Tensor`.
+
+ Raises:
+ TypeError: if `self.dtype` is specified and `x.dtype` is not
+ `self.dtype`.
+ AttributeError: if `_forward` is not implemented.
"""
- with ops.name_scope(self.name):
- with ops.name_scope(name, values=[x]):
- x = ops.convert_to_tensor(x)
- return self._forward(x)
+ with self._name_scope(name, [x]):
+ x = ops.convert_to_tensor(x, name="x")
+ self._maybe_assert_dtype(x)
+ return self._forward(x)
- def inverse(self, x, name='inverse'):
- """Returns the inverse bijector evaluation, i.e., X = g^{-1}(Y).
+ def inverse(self, x, name="inverse"):
+ """Returns the inverse `Bijector` evaluation, i.e., X = g^{-1}(Y).
Args:
x: `Tensor`. The input to the "inverse" evaluation.
@@ -172,37 +264,56 @@ class _Bijector(object):
Returns:
`Tensor`.
+
+ Raises:
+ TypeError: if `self.dtype` is specified and `x.dtype` is not
+ `self.dtype`.
+ AttributeError: if neither `_inverse` nor
+ `_inverse_and_inverse_log_det_jacobian` are implemented.
"""
- with ops.name_scope(self.name):
- with ops.name_scope(name, values=[x]):
- x = ops.convert_to_tensor(x)
- try:
- return self._inverse(x)
- except NotImplementedError:
- return self._inverse_and_inverse_log_det_jacobian(x)[0]
-
- def inverse_log_det_jacobian(self, x, name='inverse_log_det_jacobian'):
+ with self._name_scope(name, [x]):
+ x = ops.convert_to_tensor(x, name="x")
+ self._maybe_assert_dtype(x)
+ try:
+ return self._inverse(x)
+ except AttributeError:
+ # Since _inverse was not implemented, try to see if it's implemented
+ # by the _inverse_and_inverse_log_det_jacobian member.
+ return self._inverse_and_inverse_log_det_jacobian(x)[0]
+
+ def inverse_log_det_jacobian(self, x, name="inverse_log_det_jacobian"):
"""Returns the (log o det o Jacobian o inverse)(x).
Mathematically, returns: log(det(dY/dX g^{-1}))(Y).
+ Note that forward_log_det_jacobian is the negative of this function. (See
+ is_constant_jacobian for related proof.)
+
Args:
x: `Tensor`. The input to the "inverse" Jacobian evaluation.
name: The name to give this op.
Returns:
`Tensor`.
+
+ Raises:
+ TypeError: if `self.dtype` is specified and `x.dtype` is not
+ `self.dtype`.
+ AttributeError: if neither `_inverse_log_det_jacobian` nor
+ `_inverse_and_inverse_log_det_jacobian` are implemented.
"""
- with ops.name_scope(self.name):
- with ops.name_scope(name, values=[x]):
- x = ops.convert_to_tensor(x)
- try:
- return self._inverse_log_det_jacobian(x)
- except NotImplementedError:
- return self._inverse_and_inverse_log_det_jacobian(x)[1]
+ with self._name_scope(name, [x]):
+ x = ops.convert_to_tensor(x, name="x")
+ self._maybe_assert_dtype(x)
+ try:
+ return self._inverse_log_det_jacobian(x)
+ except AttributeError:
+ # Since _inverse_log_det_jacobian was not implemented, try to see if
+ # it's implemented by the _inverse_and_inverse_log_det_jacobian member.
+ return self._inverse_and_inverse_log_det_jacobian(x)[1]
def inverse_and_inverse_log_det_jacobian(
- self, x, name='inverse_and_inverse_log_det_jacobian'):
+ self, x, name="inverse_and_inverse_log_det_jacobian"):
"""Returns both the inverse evaluation and inverse_log_det_jacobian.
Enables possibly more efficient calculation when both inverse and
@@ -216,79 +327,48 @@ class _Bijector(object):
Returns:
`Tensor`.
- """
- with ops.name_scope(self.name):
- with ops.name_scope(name, values=[x]):
- x = ops.convert_to_tensor(x)
- try:
- return self._inverse_and_inverse_log_det_jacobian(x)
- except NotImplementedError:
- return self._inverse(x), self._inverse_log_det_jacobian(x)
-
- # Subclass interface.
- def _forward(self, x):
- """Subclass implementation of forward().
-
- Args:
- x: `Tensor`. The input to the "forward" evaluation.
-
- Raises:
- `NotImplementedError`: if subclass implementation not provided
-
- Returns:
- `Tensor`.
- """
- raise NotImplementedError('_forward not implemented')
-
- def _inverse(self, x):
- """Subclass implementation of inverse().
-
- Args:
- x: `Tensor`. The input to the "inverse" evaluation.
Raises:
- `NotImplementedError`: if subclass implementation not provided
-
- Returns:
- `Tensor`.
+ TypeError: if `self.dtype` is specified and `x.dtype` is not
+ `self.dtype`.
+ AttributeError: if neither `_inverse_and_inverse_log_det_jacobian` nor
+ {`_inverse`, `_inverse_log_det_jacobian`} are implemented.
"""
- raise NotImplementedError('_inverse not implemented')
-
- def _inverse_log_det_jacobian(self, x):
- """Subclass implementation of inverse_log_det_jacobian().
-
- Args:
- x: `Tensor`. The input to the "inverse" Jacobian evaluation.
-
- Raises:
- `NotImplementedError`: if subclass implementation not provided
-
- Returns:
- `Tensor`.
- """
- raise NotImplementedError('_inverse_log_det_jacobian not implemented')
-
- def _inverse_and_inverse_log_det_jacobian(self, x):
- """Subclass implementation of inverse_and_inverse_log_det_jacobian().
-
- Args:
- x: `Tensor`. The input to the "inverse" evaluation.
+ with self._name_scope(name, [x]):
+ x = ops.convert_to_tensor(x, name="x")
+ self._maybe_assert_dtype(x)
+ try:
+ return self._inverse_and_inverse_log_det_jacobian(x)
+ except AttributeError:
+ # Since _inverse_and_inverse_log_det_jacobian was not implemented, try
+ # to see if we can separately use _inverse and
+ # _inverse_log_det_jacobian members.
+ return self._inverse(x), self._inverse_log_det_jacobian(x)
+
+ @contextlib.contextmanager
+ def _name_scope(self, name=None, values=None):
+ """Helper function to standardize op scope."""
+ with ops.name_scope(self.name):
+ with ops.name_scope(name, values=(
+ (values or []) + list(self.parameters.values()))) as scope:
+ yield scope
- Returns:
- List of two `Tensor` items, inverse and inverse_log_det_jacobian.
- """
- raise NotImplementedError(
- '_inverse_and_inverse_log_det_jacobian not implemented')
+ def _maybe_assert_dtype(self, x):
+ """Helper to check dtype when self.dtype is known."""
+ if self.dtype is not None and self.dtype.base_dtype != x.dtype.base_dtype:
+ raise TypeError("Input had dtype %s but expected %s." %
+ (self.dtype, x.dtype))
class _Identity(_Bijector):
"""Bijector which computes Y = g(X) = X.
Example Use:
+
```python
- # Create the Y=g(X)=X transform which works only on Tensors with 1 batch
- # ndims and 1 event ndim (i.e., vector of vectors).
- identity = Identity(ShapeUtil(batch_ndims=1, event_ndims=1))
+ # Create the Y=g(X)=X transform which is intended for Tensors with 1 batch
+ # ndim and 1 event ndim (i.e., vector of vectors).
+ identity = Identity(batch_ndims=1, event_ndims=1)
x = [[1., 2],
[3, 4]]
x == identity.forward(x) == identity.inverse(x)
@@ -296,9 +376,14 @@ class _Identity(_Bijector):
"""
- # TODO(b/30476956): Try to remove constructor dependence on shape util.
- def __init__(self, shaper=None, name='Identity'):
- super(_Identity, self).__init__(shaper, name)
+ def __init__(self, validate_args=True, name="Identity"):
+ super(_Identity, self).__init__(
+ batch_ndims=0,
+ event_ndims=0,
+ is_constant_jacobian=True,
+ validate_args=validate_args,
+ name=name)
+ self._is_constant_jacobian = True
def _forward(self, x):
return x
@@ -307,19 +392,18 @@ class _Identity(_Bijector):
return x
def _inverse_log_det_jacobian(self, x):
- result_shape = self.shaper.get_shape(
- x, sample=True, batch=True, event=False)
- return array_ops.zeros(result_shape, dtype=x.dtype)
+ return constant_op.constant(0., dtype=x.dtype)
class _Exp(_Bijector):
"""Bijector which computes Y = g(X) = exp(X).
Example Use:
+
```python
# Create the Y=g(X)=exp(X) transform which works only on Tensors with 1
- # batch ndims and 2 event ndim (i.e., vector of matrices).
- exp = Exp(ShapeUtil(batch_ndims=1, event_ndims=2))
+ # batch ndim and 2 event ndims (i.e., vector of matrices).
+ exp = Exp(batch_ndims=1, event_ndims=2)
x = [[[1., 2],
[3, 4]],
[[5, 6],
@@ -328,11 +412,20 @@ class _Exp(_Bijector):
log(x) == exp.inverse(x)
```
+ Note: the exp(.) is applied element-wise but the Jacobian is a reduction
+ over the event space.
"""
- # TODO(b/30476956): Try to remove constructor dependence on shape util.
- def __init__(self, shaper=None, name='Exp'):
- super(_Exp, self).__init__(shaper, name)
+ # TODO(b/30476956): Try to remove constructor dependence on ndims.
+ def __init__(self,
+ event_ndims=0,
+ validate_args=True,
+ name="Exp"):
+ super(_Exp, self).__init__(
+ batch_ndims=0,
+ event_ndims=event_ndims,
+ validate_args=validate_args,
+ name=name)
def _forward(self, x):
return math_ops.exp(x)
@@ -341,10 +434,158 @@ class _Exp(_Bijector):
return math_ops.log(x)
def _inverse_log_det_jacobian(self, x):
- d = self.shaper.get_event_dims(x)
- return -math_ops.reduce_sum(math_ops.log(x), d)
+ if self.shaper is None:
+ raise ValueError("Jacobian cannot be computed with unknown event_ndims")
+ _, _, event_dims = self.shaper.get_dims(x)
+ return -math_ops.reduce_sum(math_ops.log(x), reduction_indices=event_dims)
def _inverse_and_inverse_log_det_jacobian(self, x):
+ if self.shaper is None:
+ raise ValueError("Jacobian cannot be computed with unknown event_ndims")
y = math_ops.log(x)
- d = self.shaper.get_event_dims(x)
- return y, -math_ops.reduce_sum(y, d)
+ _, _, event_dims = self.shaper.get_dims(x)
+ return y, -math_ops.reduce_sum(y, reduction_indices=event_dims)
+
+
+class _ShiftAndScale(_Bijector):
+ """Bijector which computes Y = g(X; loc, scale) = scale * X + loc.
+
+ Example Use:
+
+ ```python
+ # No batch, scalar.
+ mu = 0 # shape=[]
+ sigma = 1 # shape=[]
+ b = ShiftAndScale(loc=mu, scale=sigma)
+ # b.shaper.batch_ndims == 0
+ # b.shaper.event_ndims == 0
+
+ # One batch, scalar.
+ mu = ... # shape=[b], b>0
+ sigma = ... # shape=[b], b>0
+ b = ShiftAndScale(loc=mu, scale=sigma)
+ # b.shaper.batch_ndims == 1
+ # b.shaper.event_ndims == 0
+
+ # No batch, multivariate.
+ mu = ... # shape=[d], d>0
+ sigma = ... # shape=[d, d], d>0
+ b = ShiftAndScale(loc=mu, scale=sigma, event_ndims=1)
+ # b.shaper.batch_ndims == 0
+ # b.shaper.event_ndims == 1
+
+ # (B1*B2*...*Bb)-batch, multivariate.
+ mu = ... # shape=[B1,...,Bb, d], b>0, d>0
+ sigma = ... # shape=[B1,...,Bb, d, d], b>0, d>0
+ b = ShiftAndScale(loc=mu, scale=sigma, event_ndims=1)
+ # b.shaper.batch_ndims == b
+ # b.shaper.event_ndims == 1
+
+ # Mu is broadcast:
+ mu = 1
+ sigma = [I, I] # I is a 3x3 identity matrix.
+ b = ShiftAndScale(loc=mu, scale=sigma, event_ndims=1)
+ x = numpy.ones(S + sigma.shape)
+ b.forward(x) # == x + 1
+ ```
+
+ """
+
+ def __init__(self,
+ loc,
+ scale,
+ event_ndims=0,
+ validate_args=True,
+ name="ShiftAndScale"):
+ self._parameters = {}
+ self._name = name
+ with self._name_scope("init", values=[loc, scale, event_ndims]):
+ self._loc = ops.convert_to_tensor(loc, name="loc")
+ self._scale = ops.convert_to_tensor(scale, name="scale")
+ event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims")
+ if self.loc.dtype.base_dtype != self.scale.dtype.base_dtype:
+ raise TypeError("%s.dtype=%s does not match %s.dtype=%s" %
+ (self.loc.name, self.loc.dtype, self.scale.name,
+ self.scale.dtype))
+ if event_ndims.dtype.base_dtype != dtypes.int32.base_dtype:
+ raise TypeError("%s.dtype=%s does not match %s" %
+ (event_ndims.name, event_ndims.dtype, dtypes.int32))
+ self._scale, batch_ndims = self._process_scale(self.scale, event_ndims)
+ super(_ShiftAndScale, self).__init__(
+ batch_ndims=batch_ndims,
+ event_ndims=event_ndims,
+ parameters={"loc": self.loc, "scale": self.scale},
+ is_constant_jacobian=True,
+ validate_args=validate_args,
+ name=name)
+
+ def _process_scale(self, scale, event_ndims):
+ """Helper to __init__ which gets scale in batch-ready form.
+
+ This function expands dimensions of `scale` according to the following
+ table:
+ event_ndims
+ scale.ndims 0 1
+ 0 [1]+S+[1,1] "silent error"
+ 1 [ ]+S+[1,1] "silent error"
+ 2 [ ]+S+[1,1] [1]+S+[ ]
+ 3 [ ]+S+[1,1] [ ]+S+[ ]
+ ... (same) (same)
+
+ The idea is that we want to convert `scale` into something which can always
+ work for, say, the left-hand argument of `batch_matmul`.
+
+ Args:
+ scale: `Tensor`.
+ event_ndims: `Tensor` (0D, `int32`).
+
+ Returns:
+ scale: `Tensor` with dims expanded according to [above] table.
+ batch_ndims: `Tensor` (0D, `int32`). The ndims of the `batch` portion.
+ """
+ ndims = array_ops.rank(scale)
+ left = math_ops.select(
+ math_ops.reduce_any([
+ math_ops.reduce_all([
+ math_ops.equal(ndims, 0),
+ math_ops.equal(event_ndims, 0)
+ ]),
+ math_ops.reduce_all([
+ math_ops.equal(ndims, 2),
+ math_ops.equal(event_ndims, 1)
+ ])]), 1, 0)
+ right = math_ops.select(math_ops.equal(event_ndims, 0), 2, 0)
+ pad = array_ops.concat(0, (
+ array_ops.ones([left], dtype=dtypes.int32),
+ array_ops.shape(scale),
+ array_ops.ones([right], dtype=dtypes.int32)))
+ scale = array_ops.reshape(scale, pad)
+ batch_ndims = ndims - 2 + right
+ return scale, batch_ndims
+
+ @property
+ def loc(self):
+ return self._loc
+
+ @property
+ def scale(self):
+ return self._scale
+
+ def _forward(self, x):
+ x, sample_shape = self.shaper.make_batch_of_event_sample_matrices(x)
+ x = math_ops.batch_matmul(self.scale, x)
+ x = self.shaper.undo_make_batch_of_event_sample_matrices(x, sample_shape)
+ x += self.loc
+ return x
+
+ def _inverse(self, x):
+ x -= self.loc
+ x, sample_shape = self.shaper.make_batch_of_event_sample_matrices(x)
+ x = linalg_ops.batch_matrix_triangular_solve(self.scale, x)
+ x = self.shaper.undo_make_batch_of_event_sample_matrices(x, sample_shape)
+ return x
+
+ def _inverse_log_det_jacobian(self, x): # pylint: disable=unused-argument
+ return -math_ops.reduce_sum(
+ math_ops.log(array_ops.batch_matrix_diag_part(self.scale)),
+ reduction_indices=[-1])
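A small numpy sketch (not part of this patch) of the math the `_forward`, `_inverse`, and `_inverse_log_det_jacobian` methods above implement for a single batch; it assumes a lower-triangular `scale` with positive diagonal, and the `mu`, `sigma`, and `x` values are made up:

```python
import numpy as np

mu = np.array([1., -1., 0.])
sigma = np.array([[2., 0., 0.],
                  [1., 3., 0.],
                  [0., 1., 4.]])  # lower triangular, positive diagonal
x = np.random.randn(5, 3)         # 5 draws of a trivariate event

y = x.dot(sigma.T) + mu                       # forward: y = scale @ x + loc
x_rec = np.linalg.solve(sigma, (y - mu).T).T  # inverse: solve against scale
assert np.allclose(x, x_rec)

ildj = -np.log(np.diag(sigma)).sum()  # constant inverse log|det(Jacobian)|
```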
diff --git a/tensorflow/contrib/distributions/python/ops/distribution_util.py b/tensorflow/contrib/distributions/python/ops/distribution_util.py
index 0cc12deded..c6386b905b 100644
--- a/tensorflow/contrib/distributions/python/ops/distribution_util.py
+++ b/tensorflow/contrib/distributions/python/ops/distribution_util.py
@@ -21,7 +21,9 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
@@ -242,3 +244,133 @@ def batch_matrix_diag_transform(matrix, transform=None, name=None):
transformed_mat = array_ops.batch_matrix_set_diag(matrix, transformed_diag)
return transformed_mat
+
+
+def rotate_transpose(x, shift, name="rotate_transpose"):
+ """Circularly moves dims left or right.
+
+ Effectively identical to:
+
+ ```python
+ numpy.transpose(x, numpy.roll(numpy.arange(len(x.shape)), shift))
+ ```
+
+ When `validate_args=True` additional graph-runtime checks are
+ performed. These checks entail moving data from GPU to CPU.
+
+ Example:
+
+ ```python
+ x = ... # Tensor of shape [1, 2, 3, 4].
+ rotate_transpose(x, -1) # result shape: [2, 3, 4, 1]
+ rotate_transpose(x, -2) # result shape: [3, 4, 1, 2]
+ rotate_transpose(x, 1) # result shape: [4, 1, 2, 3]
+ rotate_transpose(x, 2) # result shape: [3, 4, 1, 2]
+ rotate_transpose(x, 7) == rotate_transpose(x, 3)
+ rotate_transpose(x, -7) == rotate_transpose(x, -3)
+ ```
+
+ Args:
+ x: `Tensor`.
+ shift: `Tensor`. Number of dimensions to transpose left (shift<0) or
+ transpose right (shift>0).
+ name: `String`. The name to give this op.
+
+ Returns:
+ rotated_x: Input `Tensor` with dimensions circularly rotated by shift.
+
+ Raises:
+ TypeError: if shift is not integer type.
+ """
+ with ops.name_scope(name, values=[x, shift]):
+ x = ops.convert_to_tensor(x, name="x")
+ shift = ops.convert_to_tensor(shift, name="shift")
+ # We do not assign back to preserve constant-ness.
+ check_ops.assert_integer(shift)
+ shift_value_static = tensor_util.constant_value(shift)
+ ndims = x.get_shape().ndims
+ if ndims is not None and shift_value_static is not None:
+ if ndims < 2: return x
+ shift_value_static = np.sign(shift_value_static) * (
+ abs(shift_value_static) % ndims)
+ if shift_value_static == 0: return x
+ perm = np.roll(np.arange(ndims), shift_value_static)
+ return array_ops.transpose(x, perm=perm)
+ else:
+ # Consider if we always had a positive shift, and some specified
+ # direction.
+ # When shifting left we want the new array:
+ # last(x, n-shift) + first(x, shift)
+ # and if shifting right then we want:
+ # last(x, shift) + first(x, n-shift)
+ # Observe that last(a) == slice(a, n) and first(a) == slice(0, a).
+ # Also, we can encode direction and shift as one: direction * shift.
+ # Combining these facts, we have:
+ # a = cond(shift<0, -shift, n-shift)
+ # last(x, n-a) + first(x, a) == x[a:n] + x[0:a]
+ # Finally, we transform shift by modulo length so it can be specified
+ # independently from the array upon which it operates (like python).
+ ndims = array_ops.rank(x)
+ shift = math_ops.select(math_ops.less(shift, 0),
+ math_ops.mod(-shift, ndims),
+ ndims - math_ops.mod(shift, ndims))
+ first = math_ops.range(0, shift)
+ last = math_ops.range(shift, ndims)
+ perm = array_ops.concat(0, (last, first))
+ return array_ops.transpose(x, perm=perm)
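A numpy analogue (not part of this patch; `rotate_transpose_np` is a hypothetical name) of the dynamic-shape branch above, using the same `a = cond(shift < 0, -shift, n - shift)` construction:

```python
import numpy as np

def rotate_transpose_np(x, shift):
  n = x.ndim
  # a such that perm == [a, ..., n-1, 0, ..., a-1], as in the comments above.
  a = (-shift) % n if shift < 0 else n - (shift % n)
  perm = list(range(a, n)) + list(range(a))
  return np.transpose(x, perm)

x = np.zeros([1, 2, 3, 4])
assert rotate_transpose_np(x, -2).shape == (3, 4, 1, 2)
assert rotate_transpose_np(x, 1).shape == (4, 1, 2, 3)
```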
+
+
+def pick_vector(cond,
+ true_vector,
+ false_vector,
+ name="pick_vector"):
+ """Picks possibly different length row `Tensor`s based on condition.
+
+ Value `Tensor`s should have exactly one dimension.
+
+ If `cond` is a python Boolean or `tf.constant` then either `true_vector` or
+ `false_vector` is immediately returned. I.e., no graph nodes are created and
+ no validation happens.
+
+ Args:
+ cond: `Tensor`. Must have `dtype=tf.bool` and be scalar.
+ true_vector: `Tensor` of one dimension. Returned when cond is `True`.
+ false_vector: `Tensor` of one dimension. Returned when cond is `False`.
+ name: `String`. The name to give this op.
+
+ Example:
+
+ ```python
+ pick_vector(tf.less(0, 5), tf.range(10, 12), tf.range(15, 18))
+ # result is tensor: [10, 11].
+ pick_vector(tf.less(5, 0), tf.range(10, 12), tf.range(15, 18))
+ # result is tensor: [15, 16, 17].
+ ```
+
+ Returns:
+ true_or_false_vector: `Tensor`.
+
+ Raises:
+ TypeError: if `cond.dtype != tf.bool`
+ TypeError: if `cond` is not a constant and
+ `true_vector.dtype != false_vector.dtype`
+ """
+ with ops.op_scope((cond, true_vector, false_vector), name):
+ cond = ops.convert_to_tensor(cond, name="cond")
+ if cond.dtype != dtypes.bool:
+ raise TypeError("%s.dtype=%s which is not %s" %
+ (cond.name, cond.dtype, dtypes.bool))
+ cond_value_static = tensor_util.constant_value(cond)
+ if cond_value_static is not None:
+ return true_vector if cond_value_static else false_vector
+ true_vector = ops.convert_to_tensor(true_vector, name="true_vector")
+ false_vector = ops.convert_to_tensor(false_vector, name="false_vector")
+ if true_vector.dtype != false_vector.dtype:
+ raise TypeError(
+ "%s.dtype=%s does not match %s.dtype=%s"
+ % (true_vector.name, true_vector.dtype,
+ false_vector.name, false_vector.dtype))
+ n = array_ops.shape(true_vector)[0]
+ return array_ops.slice(array_ops.concat(0, (true_vector, false_vector)),
+ [math_ops.select(cond, 0, n)],
+ [math_ops.select(cond, n, -1)])
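The concat-and-slice trick in the final `return` above can be pictured with a small numpy analogue (not part of this patch; `pick_vector_np` is a hypothetical name):

```python
import numpy as np

def pick_vector_np(cond, true_vector, false_vector):
  n = len(true_vector)
  both = np.concatenate((true_vector, false_vector))
  # begin/size play the role of the two math_ops.select calls above; numpy
  # slices have no "-1 means to the end", so the false size is spelled out.
  begin = 0 if cond else n
  size = n if cond else len(false_vector)
  return both[begin:begin + size]

print(pick_vector_np(True, np.arange(10, 12), np.arange(15, 18)))   # [10 11]
print(pick_vector_np(False, np.arange(10, 12), np.arange(15, 18)))  # [15 16 17]
```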
diff --git a/tensorflow/contrib/distributions/python/ops/shape.py b/tensorflow/contrib/distributions/python/ops/shape.py
index ac394bc1d5..856ed8144f 100644
--- a/tensorflow/contrib/distributions/python/ops/shape.py
+++ b/tensorflow/contrib/distributions/python/ops/shape.py
@@ -17,136 +17,194 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import contextlib
+
+from tensorflow.contrib.distributions.python.ops import distribution_util
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
-class _ShapeUtil(object):
- """Class which helps infer/identify subsets of tensor dimensions.
+class _DistributionShape(object):
+ """Manage and manipulate `Distribution` shape.
Terminology:
Recall that a `Tensor` has:
- shape: sizes of tensor dimensions,
- ndims: size of shape; number of tensor dimensions,
- dims: indexes into shape; useful for transpose, reduce.
-
- Tensors sampled from a `Distribution` can be partitioned by:
- sample dims: indexes independent, identically distributed (iid) draws,
- batch dims: indexes non-identical draws,
- event dims: indexes coordinates of a single draw.
+ - `shape`: size of `Tensor` dimensions,
+ - `ndims`: size of `shape`; number of `Tensor` dimensions,
+ - `dims`: indexes into `shape`; useful for transpose, reduce.
+
+ `Tensor`s sampled from a `Distribution` can be partitioned by `sample_dims`,
+ `batch_dims`, and `event_dims`. To understand the semantics of these
+ dimensions, consider when two of the three are fixed and the remaining
+ is varied:
+ - `sample_dims`: indexes independent draws from identical
+ parameterizations of the `Distribution`.
+ - `batch_dims`: indexes independent draws from non-identical
+ parameterizations of the `Distribution`.
+ - `event_dims`: indexes event coordinates from one sample.
+
+ The `sample`, `batch`, and `event` dimensions constitute the entirety of a
+ `Distribution` `Tensor`'s shape.
+
+ The dimensions are always in `sample`, `batch`, `event` order.
+
+ Purpose:
+ This class partitions `Tensor` notions of `shape`, `ndims`, and `dims` into
+ `Distribution` notions of `sample,` `batch,` and `event` dimensions. That
+ is, it computes any of:
- The sample, batch, and event dimensions constitute the entirety of a
- `Tensor` shape. The dimensions are always in sample, batch, event order.
+ ```
+ sample_shape batch_shape event_shape
+ sample_dims batch_dims event_dims
+ sample_ndims batch_ndims event_ndims
+ ```
- Assumptions:
- We assume that batch_ndims and event_ndims are statically known for both
- creating this object and for inputs to its functions.
- TODO(jvdillon): Relax this assumption and support fully unknown shape.
+ for a given `Tensor`, e.g., the result of
+ `Distribution.sample(sample_shape=...)`.
- We also assume that the `Tensor` rank is static, i.e., `x.get_shape().ndims
- is not None`.
+ For a given `Tensor`, this class computes the above table using minimal
+ information: `batch_ndims` and `event_ndims`.
- Possible use-cases:
- ~ Sample dimensions:
+ Examples of `Distribution` `shape` semantics:
+ - Sample dimensions:
Computing summary statistics, i.e., the average is a reduction over sample
dimensions.
- ~ Batch dimensions:
- Log-likelihood under model predicted location:
```python
- mu = ... # vector of predictions, one for each covariate.
- neg_log_likelihood = -tf.reduce_mean(
- Normal(loc=mu, scale=1).log_pdf(x),
- reduce_dims=[0])
+ sample_dims = [0]
+ tf.reduce_mean(Normal(mu=1.3, sigma=1.).sample_n(1000),
+ reduction_indices=sample_dims) # ~= 1.3
```
+ - Batch dimensions:
Monte Carlo estimation of a marginal probability:
Average over batch dimensions where batch dimensions are associated with
- random draws of a prior.
+ random draws from a prior.
E.g., suppose we want to find the Monte Carlo estimate of the marginal
- distribution of a Normal with a random Laplace location:
+ distribution of a `Normal` with a random `Laplace` location:
+
```
- P(X=x) = integral P(X=x|y) P(Y=y) dy
- ~= 1/n sum_{i=1}^n P(X=x|y_i), y_i ~iid Laplace(0,1)
- = tf.reduce_mean(Normal(loc=Laplace(0, 1).sample_n(n=1000),
- scale=tf.ones([1000, 1])).pdf(x),
- reduce_dims=[0])
+ P(X=x) = integral P(X=x|y) P(Y=y) dy
+ ~= 1/n sum_{i=1}^n P(X=x|y_i), y_i ~iid Laplace(0,1)
+ = tf.reduce_mean(Normal(mu=Laplace(0., 1.).sample_n(n=1000),
+ sigma=tf.ones(1000)).pdf(x),
+ reduction_indices=batch_dims)
```
- The `Laplace` distribution generates a tensor of shape [1000, 1]. When fed
- to a `Normal`, this is interpreted as 1000 different locations, i.e.,
- 1000 non-identical Normals. Therefore a single call to pdf(x) yields 1000
- probabilities, one for every location. The average over this batch yields
- the marginal.
+ The `Laplace` distribution generates a `Tensor` of shape `[1000]`. When
+ fed to a `Normal`, this is interpreted as 1000 different locations, i.e.,
+ 1000 non-identical Normals. Therefore a single call to `pdf(x)` yields
+ 1000 probabilities, one for every location. The average over this batch
+ yields the marginal.
- ~ Event dimensions:
+ - Event dimensions:
Computing the determinant of the Jacobian of a function of a random
variable involves a reduction over event dimensions.
+ E.g., Jacobian of the transform `Y = g(X) = exp(X)`:
+
+ ```python
+ tf.div(1., tf.reduce_prod(x, event_dims))
+ ```
- Examples:
- Write S, B, E for sample shape, batch shape, and event shape (resp.).
+ Examples using this class:
+ Write `S, B, E` for `sample_shape`, `batch_shape`, and `event_shape`.
```python
- x.get_shape() == S + B + E # For statically known x shape.
-
- # 100 iid samples from one multivariate Normal with two
- # degrees of freedom (DF).
+ # 150 iid samples from one multivariate Normal with two degrees of freedom.
mu = [0., 0]
sigma = [[1., 0],
[0, 1]]
- X = MultivariateNormal(loc=mu, scale=sigma).sample_n(n=100)
- # S = [100]
+ mvn = MultivariateNormal(mu, sigma)
+ rand_mvn = mvn.sample(sample_shape=[3, 50])
+ shaper = DistributionShape(batch_ndims=0, event_ndims=1)
+ S, B, E = shaper.get_shape(rand_mvn)
+ # S = [3, 50]
# B = []
# E = [2]
- # 100 iid samples from one Wishart with 2x2 DF.
+ # 12 iid samples from one Wishart with 2x2 events.
sigma = [[1., 0],
- [0, 1]]
- X = Wishart(scale=sigma).sample_n(n=100)
- # S = [100]
+ [2, 1]]
+ wishart = Wishart(df=5, scale=sigma)
+ rand_wishart = wishart.sample(sample_shape=[3, 4])
+ shaper = DistributionShape(batch_ndims=0, event_ndims=2)
+ S, B, E = shaper.get_shape(rand_wishart)
+ # S = [3, 4]
# B = []
# E = [2, 2]
- # 100 iid samples (with shape [2, 50]) from two, non-identical bivariate
- # Normal distributions.
- mu = ... # shape(2, 2)
- sigma = ... # shape(2, 2, 2)
- X = MultivariateNormal(loc=mu, scale=sigma).sample(shape=[2, 50])
- # S = [2, 50]
+ # 100 iid samples from two, non-identical trivariate Normal distributions.
+ mu = ... # shape(2, 3)
+ sigma = ... # shape(2, 3, 3)
+ X = MultivariateNormal(mu, sigma).sample(shape=[4, 25])
+ # S = [4, 25]
# B = [2]
- # E = [2]
+ # E = [3]
```
+ Argument Validation:
+ When `validate_args=True`, checks that cannot be done during
+ graph construction are performed at graph execution. This may result in a
+ performance degradation because data must be switched from GPU to CPU.
+
+ For example, when `validate_args=True` and `event_ndims` is a
+ non-constant `Tensor`, it is checked to be a non-negative integer at graph
+ execution. (Same for `batch_ndims`). Constant `Tensor`s and non-`Tensor`
+ arguments are always checked for correctness since this can be done for
+ "free," i.e., during graph construction.
"""
- def __init__(self, batch_ndims=None, event_ndims=None, name='ShapeUtil'):
- """Construct ShapeUtil with known sample, batch, and/or event ndims.
+ def __init__(self,
+ batch_ndims=None,
+ event_ndims=None,
+ validate_args=True,
+ name="DistributionShape"):
+ """Construct `DistributionShape` with fixed `batch_ndims`, `event_ndims`.
- Typically, batch_ndims and event_ndims are fixed throughout the lifetime of
- a Distribution.
+ `batch_ndims` and `event_ndims` are fixed throughout the lifetime of a
+ `Distribution`. They may only be known at graph execution.
+
+ If both `batch_ndims` and `event_ndims` are python scalars (rather than
+ either being a `Tensor`), functions in this class automatically perform
+ sanity checks during graph construction.
Args:
- batch_ndims: number of dims (rank) of the batch portion of indexes of a
- `Tensor`. A "batch" is a non-identical distribution, i.e, Normal with
- different parameters.
- event_ndims: number of dims (rank) of the event portion of indexes of a
- `Tensor`. An "event" is what is sampled from a distribution, i.e., a
- trivariate Normal has an event shape of [3] and a 4 dimensional Wishart
- has an event shape of [4, 4].
- name: `String`. The name to give Ops created by this class.
+ batch_ndims: `Tensor`. Number of `dims` (`rank`) of the batch portion of
+ indexes of a `Tensor`. A "batch" is a non-identical distribution, i.e.,
+ Normal with different parameters.
+ event_ndims: `Tensor`. Number of `dims` (`rank`) of the event portion of
+ indexes of a `Tensor`. An "event" is what is sampled from a
+ distribution, i.e., a trivariate Normal has an event shape of [3] and a
+ 4 dimensional Wishart has an event shape of [4, 4].
+ validate_args: `Boolean`. When `True`, non-`tf.constant` `Tensor`
+ arguments are checked for correctness. (`tf.constant` arguments are
+ always checked.)
+ name: `String`. The name prepended to Ops created by this class.
Raises:
- ValueError: if batch_ndims or event_ndims are invalid.
+ ValueError: if either `batch_ndims` or `event_ndims` is `None`,
+ negative, or not `int32`.
"""
- if batch_ndims < 0:
- raise ValueError('must specify non-negative batch_ndims(%d)', batch_ndims)
- if batch_ndims > 0 and event_ndims < 1:
- raise ValueError('must specify positive event_ndims(%d) when '
- 'batch_ndims(%d) is positive', event_ndims, batch_ndims)
- # TODO(jvdillon): Support batches of scalars.
- self._name = name
+ if batch_ndims is None: raise ValueError("batch_ndims cannot be None")
+ if event_ndims is None: raise ValueError("event_ndims cannot be None")
self._batch_ndims = batch_ndims
self._event_ndims = event_ndims
+ self._validate_args = validate_args
+ self._name = name
+ with self._name_scope("init"):
+ self._batch_ndims = self._assert_non_negative_int32_scalar(
+ ops.convert_to_tensor(batch_ndims, name="batch_ndims"))
+ self._batch_ndims_static, self._batch_ndims_is_0 = self._introspect_ndims(
+ self._batch_ndims)
+ self._event_ndims = self._assert_non_negative_int32_scalar(
+ ops.convert_to_tensor(event_ndims, name="event_ndims"))
+ self._event_ndims_static, self._event_ndims_is_0 = self._introspect_ndims(
+ self._event_ndims)
@property
def name(self):
@@ -163,234 +221,246 @@ class _ShapeUtil(object):
"""Returns number of dimensions needed to index a sample's coordinates."""
return self._event_ndims
- def get_ndims(self, x, name='get_ndims'):
- """Get tensor ndims (rank).
-
- Args:
- x: `Tensor`.
- name: `String`. The name to give this op.
-
- Raises:
- ValueError: if ndims is not statically known.
-
- Returns:
- `Scalar` number of dimensions associated with a `Tensor`.
- """
- if x is None:
- raise ValueError('Input was None which does not have known ndims.')
- with ops.name_scope(self.name):
- with ops.name_scope(name, values=[x]):
- ndims = ops.convert_to_tensor(x).get_shape().ndims
- if ndims is None:
- raise ValueError('ShapeUtil assumes static number of '
- 'dimensions(%d)', ndims)
- return ndims
-
- def get_sample_ndims(self, x):
- """Returns number of dimensions corresponding to iid draws.
-
- Args:
- x: `Tensor`.
-
- Raises:
- ValueError: if batch_ndims or event_ndims are not statically known.
- ValueError: if static sample_ndims does not match inferred
-
- Returns:
- Scalar number of dimensions associated with a sample.
- """
- ndims = self.get_ndims(x)
- sample_ndims = ndims - self.batch_ndims - self.event_ndims
- if sample_ndims < 0:
- raise ValueError('expected batch_ndims(%d) + event_ndims(%d) < ndims(%d)',
- self.batch_ndims, self.event_ndims, ndims)
- return sample_ndims
-
- def get_dims(self, x, sample=True, batch=True, event=True):
- """Returns subset of tensor's dimension indexes (indexes into shape).
-
- Args:
- x: `Tensor`.
- sample: `Boolean`. Include sample dimensions or not.
- batch: `Boolean`. Include batch dimensions or not.
- event: `Boolean`. Include event dimensions or not.
-
- Raises:
- ValueError: if `x.get_shape().ndims` is `None`
-
- Returns:
- List enumerating requested dimensions.
- """
- ndims = self.get_ndims(x)
-
- if sample and batch and event:
- return list(range(ndims))
-
- sample_start = 0
- batch_start = self.get_sample_ndims(x)
- event_start = batch_start + self.batch_ndims
-
- sample_shape = list(range(sample_start, batch_start)) if sample else []
- batch_shape = list(range(batch_start, event_start)) if batch else []
- event_shape = list(range(event_start, ndims)) if event else []
-
- return sample_shape + batch_shape + event_shape
+ @property
+ def validate_args(self):
+ """Returns True if graph-runtime `Tensor` checks are enabled."""
+ return self._validate_args
- def get_shape(self, x, sample=True, batch=True, event=True, name='get_shape'):
- """Returns subset of tensor's shape (size of dimensions).
+ def get_ndims(self, x, name="get_ndims"):
+ """Get `Tensor` number of dimensions (rank).
Args:
x: `Tensor`.
- sample: `Boolean`. Include sample shape or not.
- batch: `Boolean`. Include batch shape or not.
- event: `Boolean`. Include event shape or not.
name: `String`. The name to give this op.
- Raises:
- ValueError: if `x.get_shape().ndims` is `None`
-
Returns:
- List describing event shape if known statically, `Tensor` otherwise.
+ ndims: Scalar number of dimensions associated with a `Tensor`.
"""
- if not sample and not batch and not event:
- return []
- with ops.name_scope(self._name):
- with ops.name_scope(name, values=[x]):
- x = ops.convert_to_tensor(x)
- shape = (x.get_shape().as_list()
- if x.get_shape().is_fully_defined()
- else array_ops.shape(x))
-
- if sample and batch and event:
- return shape
-
- sample_start = 0
- batch_start = self.get_sample_ndims(x)
- event_start = batch_start + self.batch_ndims
-
- sample_shape = shape[sample_start:batch_start] if sample else []
- batch_shape = shape[batch_start:event_start] if batch else []
- event_shape = shape[event_start:] if event else []
-
- if not batch and not event:
- return sample_shape
- if not sample and not event:
- return batch_shape
- if not sample and not batch:
- return event_shape
-
- if x.get_shape().is_fully_defined():
- return sample_shape + batch_shape + event_shape
- else:
- return array_ops.concat(0, [sample_shape, batch_shape, event_shape])
-
- def get_sample_dims(self, x):
- """Returns dimension indexes corresponding to sample.
-
- Convenience function; identical to:
+ with self._name_scope(name, values=[x]):
+ x = ops.convert_to_tensor(x, name="x")
+ ndims = x.get_shape().ndims
+ if ndims is None:
+ return array_ops.rank(x, name="ndims")
+ return ops.convert_to_tensor(ndims, dtype=dtypes.int32, name="ndims")
- ```python
- get_dims(x, sample=True, batch=False, event=False)
- ```
+ def get_sample_ndims(self, x, name="get_sample_ndims"):
+ """Returns number of dimensions corresponding to iid draws ("sample").
Args:
x: `Tensor`.
-
- Raises:
- ValueError: if `x.get_shape().ndims` is `None`
+ name: `String`. The name to give this op.
Returns:
- List enumerating sample dimensions.
- """
- return self.get_dims(x, sample=True, batch=False, event=False)
-
- def get_batch_dims(self, x):
- """Returns dimension indexes corresponding to batch.
-
- Convenience function; identical to:
-
- ```python
- get_dims(x, sample=False, batch=True, event=False)
- ```
-
- Args:
- x: `Tensor`.
+ sample_ndims: `Tensor` (0D, `int32`).
Raises:
- ValueError: if `x.get_shape().ndims` is `None`
-
- Returns:
- List enumerating batch dimensions.
+ ValueError: if `sample_ndims` is calculated to be negative.
"""
- return self.get_dims(x, sample=False, batch=True, event=False)
-
- def get_event_dims(self, x):
- """Returns dimension indexes corresponding to event.
-
- Convenience function; identical to:
+ with self._name_scope(name, values=[x]):
+ ndims = self.get_ndims(x, name=name)
+ if self._is_all_constant_helper(ndims, self.batch_ndims,
+ self.event_ndims):
+ ndims = tensor_util.constant_value(ndims)
+ sample_ndims = (ndims - self._batch_ndims_static -
+ self._event_ndims_static)
+ if sample_ndims < 0:
+ raise ValueError(
+ "expected batch_ndims(%d) + event_ndims(%d) <= ndims(%d)" %
+ (self._batch_ndims_static, self._event_ndims_static, ndims))
+ return ops.convert_to_tensor(sample_ndims, name="sample_ndims")
+ else:
+ with ops.name_scope(name="sample_ndims"):
+ sample_ndims = ndims - self.batch_ndims - self.event_ndims
+ if self.validate_args:
+ sample_ndims = control_flow_ops.with_dependencies(
+ [check_ops.assert_non_negative(sample_ndims)], sample_ndims)
+ return sample_ndims
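For example (illustrative shapes only), the subtraction above gives:

```python
ndims, batch_ndims, event_ndims = 4, 1, 1  # e.g., x.shape == [3, 50, 2, 4]
sample_ndims = ndims - batch_ndims - event_ndims
assert sample_ndims == 2  # the leading [3, 50] dims index iid draws
```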
+
+ def get_dims(self, x, name="get_dims"):
+ """Returns dimensions indexing `sample_shape`, `batch_shape`, `event_shape`.
+
+ Example:
```python
- get_dims(x, sample=False, batch=False, event=True)
+ x = ... # Tensor with shape [4, 3, 2, 1]
+ sample_dims, batch_dims, event_dims = _DistributionShape(
+ batch_ndims=2, event_ndims=1).get_dims(x)
+ # sample_dims == [0]
+ # batch_dims == [1, 2]
+ # event_dims == [3]
+ # Note that these are not the shape parts, but rather indexes into shape.
```
Args:
x: `Tensor`.
-
- Raises:
- ValueError: if `x.get_shape().ndims` is `None`
+ name: `String`. The name to give this op.
Returns:
- List enumerating event dimensions.
+ sample_dims: `Tensor` (1D, `int32`).
+ batch_dims: `Tensor` (1D, `int32`).
+ event_dims: `Tensor` (1D, `int32`).
"""
- return self.get_dims(x, sample=False, batch=False, event=True)
-
- def get_sample_shape(self, x):
- """Returns shape corresponding to sample.
-
- Convenience function; identical to:
+ with self._name_scope(name, values=[x]):
+ def make_dims(start_sum, size, name):
+ """Closure to make dims range."""
+ start_sum = start_sum if start_sum else (
+ array_ops.zeros((), dtype=dtypes.int32, name="zero"),)
+ if self._is_all_constant_helper(size, *start_sum):
+ start = sum([tensor_util.constant_value(s) for s in start_sum])
+ stop = start + tensor_util.constant_value(size)
+ return ops.convert_to_tensor(
+ list(range(start, stop)), dtype=dtypes.int32, name=name)
+ else:
+ start = sum(start_sum)
+ return math_ops.range(start, start + size)
+ sample_ndims = self.get_sample_ndims(x, name=name)
+ return (make_dims((), sample_ndims, name="sample_dims"),
+ make_dims((sample_ndims,), self.batch_ndims, name="batch_dims"),
+ make_dims((sample_ndims, self.batch_ndims),
+ self.event_ndims, name="event_dims"))
- ```python
- get_shape(x, sample=True, batch=False, event=False)
- ```
+ def get_shape(self, x, name="get_shape"):
+ """Returns `Tensor`'s shape partitioned into `sample`, `batch`, `event`.
Args:
x: `Tensor`.
+ name: `String`. The name to give this op.
Returns:
- List describing sample shape if known statically, `Tensor` otherwise.
+ sample_shape: `Tensor` (1D, `int32`).
+ batch_shape: `Tensor` (1D, `int32`).
+ event_shape: `Tensor` (1D, `int32`).
"""
- return self.get_shape(x, sample=True, batch=False, event=False)
-
- def get_batch_shape(self, x):
- """Returns shape corresponding to batch.
-
- Convenience function; identical to:
-
- ```python
- get_shape(x, sample=False, batch=True, event=False)
- ```
+ with self._name_scope(name, values=[x]):
+ x = ops.convert_to_tensor(x, name="x")
+ def slice_shape(start_sum, size, name):
+ """Closure to slice out shape."""
+ start_sum = start_sum if start_sum else (
+ array_ops.zeros((), dtype=dtypes.int32, name="zero"),)
+ if (x.get_shape().ndims is not None and
+ self._is_all_constant_helper(size, *start_sum)):
+ start = sum([tensor_util.constant_value(s) for s in start_sum])
+ stop = start + tensor_util.constant_value(size)
+ slice_ = x.get_shape()[start:stop].as_list()
+ if all(s is not None for s in slice_):
+ return ops.convert_to_tensor(slice_, dtype=dtypes.int32, name=name)
+ # Fall-through intended.
+ return array_ops.slice(array_ops.shape(x), (sum(start_sum),), (size,))
+ sample_ndims = self.get_sample_ndims(x, name=name)
+ return (slice_shape((), sample_ndims,
+ name="sample_shape"),
+ slice_shape((sample_ndims,), self.batch_ndims,
+ name="batch_shape"),
+ slice_shape((sample_ndims, self.batch_ndims), self.event_ndims,
+ name="event_shape"))
+
+ def make_batch_of_event_sample_matrices(
+ self, x, name="make_batch_of_event_sample_matrices"):
+ """Reshapes/transposes `Distribution` `Tensor` from S+B+E to B_+E_+S_.
+
+ Where:
+ - `B_ = B if B else [1]`,
+ - `E_ = E if E else [1]`,
+ - `S_ = [tf.reduce_prod(S)]`.
Args:
x: `Tensor`.
+ name: `String`. The name to give this op.
Returns:
- List describing batch shape if known statically, `Tensor` otherwise.
+ x: `Tensor`. Input transposed/reshaped to `B_+E_+S_`.
+ sample_shape: `Tensor` (1D, `int32`).
"""
- return self.get_shape(x, sample=False, batch=True, event=False)
-
- def get_event_shape(self, x):
- """Returns shape corresponding to event.
-
- Convenience function; identical to:
-
- ```python
- get_shape(x, sample=False, batch=False, event=True)
- ```
+ with self._name_scope(name, values=[x]):
+ x = ops.convert_to_tensor(x, name="x")
+ sample_shape, batch_shape, event_shape = self.get_shape(x)
+ event_shape = distribution_util.pick_vector(
+ self._event_ndims_is_0, (1,), event_shape)
+ batch_shape = distribution_util.pick_vector(
+ self._batch_ndims_is_0, (1,), batch_shape)
+ new_shape = array_ops.concat(0, ((-1,), batch_shape, event_shape))
+ x = array_ops.reshape(x, shape=new_shape)
+ x = distribution_util.rotate_transpose(x, shift=-1)
+ return x, sample_shape
+
+ def undo_make_batch_of_event_sample_matrices(
+ self, x, sample_shape, name="undo_make_batch_of_event_sample_matrices"):
+ """Reshapes/transposes `Distribution` `Tensor` from B_+E_+S_ to S+B+E.
+
+ Where:
+ - `B_ = B if B else [1]`,
+ - `E_ = E if E else [1]`,
+ - `S_ = [tf.reduce_prod(S)]`.
+
+ This function "reverses" `make_batch_of_event_sample_matrices`.
Args:
- x: `Tensor`.
+ x: `Tensor` of shape `B_+E_+S_`.
+ sample_shape: `Tensor` (1D, `int32`).
+ name: `String`. The name to give this op.
Returns:
- List describing event shape if known statically, `Tensor` otherwise.
+ x: `Tensor`. Input transposed/reshaped to `S+B+E`.
"""
- return self.get_shape(x, sample=False, batch=False, event=True)
+ with self._name_scope(name, values=[x, sample_shape]):
+ x = ops.convert_to_tensor(x, name="x")
+ sample_shape = ops.convert_to_tensor(sample_shape, name="sample_shape")
+ x = distribution_util.rotate_transpose(x, shift=1)
+ if self._is_all_constant_helper(self.batch_ndims, self.event_ndims):
+ if self._batch_ndims_is_0 or self._event_ndims_is_0:
+ b = ((min(-2, -1 - self._event_ndims_static),)
+ if self._batch_ndims_is_0 else ())
+ e = (-1,) if self._event_ndims_is_0 else ()
+ x = array_ops.squeeze(x, squeeze_dims=b + e)
+ _, batch_shape, event_shape = self.get_shape(x)
+ else:
+ s = (x.get_shape().as_list() if x.get_shape().is_fully_defined()
+ else array_ops.shape(x))
+ batch_shape = array_ops.slice(s, (1,), (self.batch_ndims,))
+ # Since sample_dims=1 and is left-most, we add 1 to the number of
+ # batch_ndims to get the event start dim.
+ event_start = math_ops.select(
+ self._batch_ndims_is_0, 2, 1 + self.batch_ndims)
+ event_shape = array_ops.slice(s, (event_start,), (self.event_ndims,))
+ new_shape = array_ops.concat(0, (sample_shape, batch_shape, event_shape))
+ x = array_ops.reshape(x, shape=new_shape)
+ return x
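A shape-only numpy sketch (not part of this patch) of what `make_batch_of_event_sample_matrices` and its inverse above do when all shapes are static; `S`, `B`, `E` are made-up example shapes:

```python
import numpy as np

S, B, E = (3, 50), (2,), (4,)
x = np.zeros(S + B + E)                    # S+B+E == (3, 50, 2, 4)

# make_batch_of_event_sample_matrices: S+B+E -> B_+E_+S_, S_ = [prod(S)].
m = np.moveaxis(x.reshape((-1,) + B + E), 0, -1)
assert m.shape == (2, 4, 150)

# undo_make_batch_of_event_sample_matrices: rotate the flattened sample dim
# back to the front, then restore the original sample shape S.
u = np.moveaxis(m, -1, 0).reshape(S + B + E)
assert u.shape == x.shape
```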
+
+ @contextlib.contextmanager
+ def _name_scope(self, name=None, values=None):
+ """Helper function to standardize op scope."""
+ with ops.name_scope(self.name):
+ with ops.name_scope(name, values=(
+ (values or []) + [self.batch_ndims, self.event_ndims])) as scope:
+ yield scope
+
+ def _is_all_constant_helper(self, *args):
+ """Helper which returns True if all inputs are constant_value."""
+ return all(tensor_util.constant_value(x) is not None for x in args)
+
+ def _assert_non_negative_int32_scalar(self, x):
+ """Helper which ensures that input is a non-negative, int32, scalar."""
+ x = ops.convert_to_tensor(x, name="x")
+ if x.dtype.base_dtype != dtypes.int32.base_dtype:
+ raise TypeError("%s.dtype=%s is not %s" % (x.name, x.dtype, dtypes.int32))
+ x_value_static = tensor_util.constant_value(x)
+ if x.get_shape().ndims is not None and x_value_static is not None:
+ if x.get_shape().ndims != 0:
+ raise ValueError("%s.ndims=%d is not 0 (scalar)" %
+ (x.name, x.get_shape().ndims))
+ if x_value_static < 0:
+ raise ValueError("%s.value=%d cannot be negative" %
+ (x.name, x_value_static))
+ return x
+ if self.validate_args:
+ x = control_flow_ops.with_dependencies([
+ check_ops.assert_rank(x, 0),
+ check_ops.assert_non_negative(x)], x)
+ return x
+
+ def _introspect_ndims(self, ndims):
+ """Helper to establish some properties of input ndims args."""
+ if self._is_all_constant_helper(ndims):
+ return (tensor_util.constant_value(ndims),
+ tensor_util.constant_value(ndims) == 0)
+ return None, math_ops.equal(ndims, 0)