author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-06-11 17:24:32 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-06-11 17:27:27 -0700
commit 8c5d37c3b96cdbcb8a3b657144d4fb63fb3dc100 (patch)
tree   5504a8840440f9cdd63fee2f72d9a034b7532fd2 /tensorflow/contrib/distributions
parent 5ebfc750447fd100e1b1c3bd747b87f460b50a81 (diff)
Add `move_dimension` utility to move a single dimension within a Tensor.
PiperOrigin-RevId: 200141207
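A minimal usage sketch (not part of the commit), mirroring the tests added below; the contrib import path is the file this change touches:

```python
import tensorflow as tf
from tensorflow.contrib.distributions.python.ops import distribution_util

x = tf.random_normal(shape=[200, 30, 4, 1, 6])
# Move dimension 0 to index 3; the result has shape [30, 4, 1, 200, 6].
x_perm = distribution_util.move_dimension(x, 0, 3)
```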
Diffstat (limited to 'tensorflow/contrib/distributions')
-rw-r--r--  tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py | 48
-rw-r--r--  tensorflow/contrib/distributions/python/ops/distribution_util.py               | 79
2 files changed, 127 insertions, 0 deletions
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py b/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py
index 31d24aa9ea..bbbec2103a 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py
@@ -29,7 +29,9 @@ from tensorflow.contrib.distributions.python.ops import mvn_diag
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import categorical
from tensorflow.python.ops.distributions import normal
from tensorflow.python.ops.linalg import linear_operator_diag
@@ -540,5 +542,51 @@ class PadDynamicTest(_PadTest, test.TestCase):
    return False
+class TestMoveDimension(test.TestCase):
+
+  @test_util.run_in_graph_and_eager_modes()
+  def test_move_dimension_static_shape(self):
+
+    x = random_ops.random_normal(shape=[200, 30, 4, 1, 6])
+
+    x_perm = distribution_util.move_dimension(x, 1, 1)
+    self.assertAllEqual(x_perm.shape.as_list(), [200, 30, 4, 1, 6])
+
+    x_perm = distribution_util.move_dimension(x, 0, 3)
+    self.assertAllEqual(x_perm.shape.as_list(), [30, 4, 1, 200, 6])
+
+    x_perm = distribution_util.move_dimension(x, 0, -2)
+    self.assertAllEqual(x_perm.shape.as_list(), [30, 4, 1, 200, 6])
+
+    x_perm = distribution_util.move_dimension(x, 4, 2)
+    self.assertAllEqual(x_perm.shape.as_list(), [200, 30, 6, 4, 1])
+
+  @test_util.run_in_graph_and_eager_modes()
+  def test_move_dimension_dynamic_shape(self):
+
+    x_ = random_ops.random_normal(shape=[200, 30, 4, 1, 6])
+    x = array_ops.placeholder_with_default(input=x_, shape=None)
+
+    x_perm = distribution_util.move_dimension(x, 1, 1)
+    self.assertAllEqual(self.evaluate(array_ops.shape(x_perm)),
+                        [200, 30, 4, 1, 6])
+
+    x_perm = distribution_util.move_dimension(x, 0, 3)
+    self.assertAllEqual(self.evaluate(array_ops.shape(x_perm)),
+                        [30, 4, 1, 200, 6])
+
+    x_perm = distribution_util.move_dimension(x, 0, -2)
+    self.assertAllEqual(self.evaluate(array_ops.shape(x_perm)),
+                        [30, 4, 1, 200, 6])
+
+    x_perm = distribution_util.move_dimension(x, 4, 2)
+    self.assertAllEqual(self.evaluate(array_ops.shape(x_perm)),
+                        [200, 30, 6, 4, 1])
+
+    x_perm = distribution_util.move_dimension(x, -1, 2)
+    self.assertAllEqual(self.evaluate(array_ops.shape(x_perm)),
+                        [200, 30, 6, 4, 1])
+
+
if __name__ == "__main__":
  test.main()
diff --git a/tensorflow/contrib/distributions/python/ops/distribution_util.py b/tensorflow/contrib/distributions/python/ops/distribution_util.py
index 289e1d50e1..6959b3e877 100644
--- a/tensorflow/contrib/distributions/python/ops/distribution_util.py
+++ b/tensorflow/contrib/distributions/python/ops/distribution_util.py
@@ -21,12 +21,19 @@ from __future__ import print_function
from tensorflow.contrib import linalg
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import smart_cond
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import distribution as distribution_lib
+
+# The following two lines are redundant, in a sense. The first enables
+# good coding practice *within* this file (`util.prefer_static_value`
+# rather than `prefer_static_value`). The second ensures that users
+# also get the core utils when they import this file.
+from tensorflow.python.ops.distributions import util
from tensorflow.python.ops.distributions.util import * # pylint: disable=wildcard-import
@@ -484,3 +491,75 @@ def pad_mixture_dimensions(x, mixture_distribution, categorical_distribution,
def static_value(x):
  """Returns the static value of a `Tensor` or `None`."""
  return tensor_util.constant_value(ops.convert_to_tensor(x))
+
+
+def move_dimension(x, source_idx, dest_idx):
+  """Move a single tensor dimension within its shape.
+
+  This is a special case of `tf.transpose()`, which applies
+  arbitrary permutations to tensor dimensions.
+
+  Args:
+    x: Tensor of rank `ndims`.
+    source_idx: Integer index into `x.shape` (negative indexing is
+      supported).
+    dest_idx: Integer index into `x.shape` (negative indexing is
+      supported).
+
+  Returns:
+    x_perm: Tensor of rank `ndims`, in which the dimension at original
+      index `source_idx` has been moved to new index `dest_idx`, with
+      all other dimensions retained in their original order.
+
+  Example:
+
+  ```python
+  x = tf.placeholder(tf.float32, shape=[200, 30, 4, 1, 6])
+  x_perm = move_dimension(x, 1, 1)   # no-op
+  x_perm = move_dimension(x, 0, 3)   # result shape [30, 4, 1, 200, 6]
+  x_perm = move_dimension(x, 0, -2)  # equivalent to previous
+  x_perm = move_dimension(x, 4, 2)   # result shape [200, 30, 6, 4, 1]
+  ```
+  """
+  ndims = util.prefer_static_rank(x)
+  if isinstance(source_idx, int):
+    dtype = dtypes.int32
+  else:
+    dtype = dtypes.as_dtype(source_idx.dtype)
+
+  # Handle negative indexing. Since ndims might be dynamic, this makes
+  # source_idx and dest_idx also possibly dynamic.
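+  # For example, with ndims = 5 a source_idx of -1 refers to dimension 4,
+  # matching the (-1, 2) case exercised in the tests above.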
+  if source_idx < 0:
+    source_idx = ndims + source_idx
+  if dest_idx < 0:
+    dest_idx = ndims + dest_idx
+
+  # Construct the appropriate permutation of dimensions, depending on
+  # whether the source is before or after the destination.
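+  # For example, moving dimension 0 of a rank-5 tensor to index 3
+  # ("moving right") uses the permutation [1, 2, 3, 0, 4].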
+  def move_left_permutation():
+    return util.prefer_static_value(
+        array_ops.concat([
+            math_ops.range(0, dest_idx, dtype=dtype),
+            [source_idx],
+            math_ops.range(dest_idx, source_idx, dtype=dtype),
+            math_ops.range(source_idx+1, ndims, dtype=dtype)], axis=0))
+
+  def move_right_permutation():
+    return util.prefer_static_value(
+        array_ops.concat([
+            math_ops.range(0, source_idx, dtype=dtype),
+            math_ops.range(source_idx+1, dest_idx+1, dtype=dtype),
+            [source_idx],
+            math_ops.range(dest_idx+1, ndims, dtype=dtype)], axis=0))
+
+  def x_permuted():
+    return array_ops.transpose(
+        x, perm=smart_cond.smart_cond(source_idx < dest_idx,
+                                      move_right_permutation,
+                                      move_left_permutation))
+
+  # One final conditional to handle the special case where source
+  # and destination indices are equal.
+  return smart_cond.smart_cond(math_ops.equal(source_idx, dest_idx),
+                               lambda: x,
+                               x_permuted)
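For a single dimension, this behaves like NumPy's `np.moveaxis`; a small cross-check sketch (not part of the commit, assuming a TF 1.x graph session):

```python
import numpy as np
import tensorflow as tf
from tensorflow.contrib.distributions.python.ops import distribution_util

x_np = np.random.randn(200, 30, 4, 1, 6).astype(np.float32)
# np.moveaxis(x, 4, 2) also yields shape (200, 30, 6, 4, 1).
expected = np.moveaxis(x_np, 4, 2)

with tf.Session() as sess:
  actual = sess.run(distribution_util.move_dimension(tf.constant(x_np), 4, 2))
np.testing.assert_allclose(actual, expected)
```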