aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distributions
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-05-16 08:45:28 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-16 08:47:59 -0700
commit61108cc05e3eb6463fbef5eba9d1ff7b1130263d (patch)
tree5feaebe3e35180b0754996574dd306b3a5d8f4cc /tensorflow/contrib/distributions
parenteaa78c17269b97991355974d7a26d650de76bbcd (diff)
Modify tf.contrib.distributions.BatchReshape to behave a bit more like
tf.reshape: accept a single unknown dimension and infer partial shape information statically. PiperOrigin-RevId: 196833267
Diffstat (limited to 'tensorflow/contrib/distributions')
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py16
-rw-r--r--tensorflow/contrib/distributions/python/ops/batch_reshape.py189
2 files changed, 100 insertions, 105 deletions
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py b/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py
index 59d549b7b8..f2bb2d3325 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py
@@ -448,8 +448,7 @@ class _BatchReshapeTest(object):
else:
with self.test_session():
- with self.assertRaisesOpError(r"`batch_shape` size must match "
- r"`distributions.batch_shape` size"):
+ with self.assertRaisesOpError(r"Shape sizes do not match."):
batch_reshape_lib.BatchReshape(
distribution=mvn,
batch_shape=new_batch_shape_ph,
@@ -457,8 +456,13 @@ class _BatchReshapeTest(object):
def test_non_positive_shape(self):
dims = 2
- new_batch_shape = [-1, -2] # -1*-2=2 so will pass size check.
- old_batch_shape = [2]
+ old_batch_shape = [4]
+ if self.is_static_shape:
+ # Unknown first dimension does not trigger size check. Note that
+ # any dimension < 0 is treated statically as unknown.
+ new_batch_shape = [-1, 0]
+ else:
+ new_batch_shape = [-2, -2] # -2 * -2 = 4, same size as the old shape.
new_batch_shape_ph = (
constant_op.constant(np.int32(new_batch_shape)) if self.is_static_shape
@@ -471,7 +475,7 @@ class _BatchReshapeTest(object):
mvn = mvn_lib.MultivariateNormalDiag(scale_diag=scale_ph)
if self.is_static_shape:
- with self.assertRaisesRegexp(ValueError, r".*must be positive.*"):
+ with self.assertRaisesRegexp(ValueError, r".*must be >=-1.*"):
batch_reshape_lib.BatchReshape(
distribution=mvn,
batch_shape=new_batch_shape_ph,
@@ -479,7 +483,7 @@ class _BatchReshapeTest(object):
else:
with self.test_session():
- with self.assertRaisesOpError(r".*must be positive.*"):
+ with self.assertRaisesOpError(r".*must be >=-1.*"):
batch_reshape_lib.BatchReshape(
distribution=mvn,
batch_shape=new_batch_shape_ph,
diff --git a/tensorflow/contrib/distributions/python/ops/batch_reshape.py b/tensorflow/contrib/distributions/python/ops/batch_reshape.py
index 8a4041cf43..c709318f76 100644
--- a/tensorflow/contrib/distributions/python/ops/batch_reshape.py
+++ b/tensorflow/contrib/distributions/python/ops/batch_reshape.py
@@ -42,9 +42,6 @@ class BatchReshape(distribution_lib.Distribution):
This "meta-distribution" reshapes the batch dimensions of another
distribution.
- Note: Unlike `tf.reshape`, the `BatchReshape` distribution does not support
- `-1` for flattening.
-
#### Examples
```python
@@ -52,7 +49,7 @@ class BatchReshape(distribution_lib.Distribution):
dtype = np.float32
dims = 2
- new_batch_shape = [1, 2, 3]
+ new_batch_shape = [1, 2, -1]
old_batch_shape = [6]
scale = np.ones(old_batch_shape + [dims], dtype)
@@ -86,8 +83,9 @@ class BatchReshape(distribution_lib.Distribution):
Args:
distribution: The base distribution instance to reshape. Typically an
instance of `Distribution`.
- batch_shape: Positive `int`-like vector-shaped `Tensor` representing the
- new shape of the batch dimensions.
+ batch_shape: Positive `int`-like vector-shaped `Tensor` representing
+ the new shape of the batch dimensions. Up to one dimension may contain
+ `-1`, meaning the remainder of the batch size.
validate_args: Python `bool`, default `False`. When `True` distribution
parameters are checked for validity despite possibly degrading runtime
performance. When `False` invalid inputs may silently render incorrect
@@ -107,29 +105,26 @@ class BatchReshape(distribution_lib.Distribution):
"""
parameters = distribution_util.parent_frame_arguments()
name = name or "BatchReshape" + distribution.name
- self._distribution = distribution
with ops.name_scope(name, values=[batch_shape]) as name:
- self._batch_shape_ = ops.convert_to_tensor(
- batch_shape,
- dtype=dtypes.int32,
- name="batch_shape")
- self._batch_shape_static = tensor_util.constant_value(self._batch_shape_)
- if self._batch_shape_static is not None:
- self._batch_shape_static = np.int32(self._batch_shape_static)
- self._runtime_assertions = validate_init_args(
- self._distribution,
- self._batch_shape_,
- validate_args,
- self._batch_shape_static)
+ # The unexpanded batch shape may contain up to one dimension of -1.
+ self._batch_shape_unexpanded = ops.convert_to_tensor(
+ batch_shape, dtype=dtypes.int32, name="batch_shape")
+ validate_init_args_statically(distribution, self._batch_shape_unexpanded)
+ batch_shape, batch_shape_static, runtime_assertions = calculate_reshape(
+ distribution.batch_shape_tensor(), self._batch_shape_unexpanded,
+ validate_args)
+ self._distribution = distribution
+ self._batch_shape_ = batch_shape
+ self._batch_shape_static = batch_shape_static
+ self._runtime_assertions = runtime_assertions
super(BatchReshape, self).__init__(
- dtype=self._distribution.dtype,
- reparameterization_type=self._distribution.reparameterization_type,
+ dtype=distribution.dtype,
+ reparameterization_type=distribution.reparameterization_type,
validate_args=validate_args,
allow_nan_stats=allow_nan_stats,
parameters=parameters,
graph_parents=(
- [self._batch_shape_] +
- self._distribution._graph_parents), # pylint: disable=protected-access
+ [self._batch_shape_unexpanded] + distribution._graph_parents), # pylint: disable=protected-access
name=name)
@property
@@ -141,7 +136,7 @@ class BatchReshape(distribution_lib.Distribution):
return array_ops.identity(self._batch_shape_)
def _batch_shape(self):
- return tensor_shape.TensorShape(self._batch_shape_static)
+ return self._batch_shape_static
def _event_shape_tensor(self):
with ops.control_dependencies(self._runtime_assertions):
@@ -153,11 +148,13 @@ class BatchReshape(distribution_lib.Distribution):
def _sample_n(self, n, seed=None):
with ops.control_dependencies(self._runtime_assertions):
x = self.distribution.sample(sample_shape=n, seed=seed)
- new_shape = array_ops.concat([
- [n],
- self.batch_shape_tensor(),
- self.event_shape_tensor(),
- ], axis=0)
+ new_shape = array_ops.concat(
+ [
+ [n],
+ self._batch_shape_unexpanded,
+ self.event_shape_tensor(),
+ ],
+ axis=0)
return array_ops.reshape(x, new_shape)
def _log_prob(self, x):
@@ -214,9 +211,9 @@ class BatchReshape(distribution_lib.Distribution):
event_ndims = (array_ops.size(self.event_shape_tensor())
if self.event_shape.ndims is None
else self.event_shape.ndims)
- batch_ndims = (array_ops.size(self.batch_shape_tensor())
- if self.batch_shape.ndims is None
- else self.batch_shape.ndims)
+ batch_ndims = (
+ array_ops.size(self._batch_shape_unexpanded)
+ if self.batch_shape.ndims is None else self.batch_shape.ndims)
sample_ndims = x_ndims - batch_ndims - event_ndims
if isinstance(sample_ndims, int):
static_sample_shape = x.shape[:sample_ndims]
@@ -239,10 +236,11 @@ class BatchReshape(distribution_lib.Distribution):
self.event_shape_tensor(),
], axis=0)
result = fn(array_ops.reshape(x, old_shape))
- new_shape = array_ops.concat([
- sample_shape,
- self.batch_shape_tensor(),
- ], axis=0)
+ new_shape = array_ops.concat(
+ [
+ sample_shape,
+ self._batch_shape_unexpanded,
+ ], axis=0)
result = array_ops.reshape(result, new_shape)
if (static_sample_shape.ndims is not None and
self.batch_shape.ndims is not None):
@@ -262,8 +260,7 @@ class BatchReshape(distribution_lib.Distribution):
if static_event_shape_list is None:
static_event_shape_list = [self.event_shape]
new_shape = array_ops.concat(
- [self.batch_shape_tensor()] + event_shape_list,
- axis=0)
+ [self._batch_shape_unexpanded] + event_shape_list, axis=0)
result = array_ops.reshape(fn(), new_shape)
if (self.batch_shape.ndims is not None and
self.event_shape.ndims is not None):
@@ -282,9 +279,9 @@ class BatchReshape(distribution_lib.Distribution):
event_ndims = (array_ops.size(self.event_shape_tensor())
if self.event_shape.ndims is None
else self.event_shape.ndims)
- batch_ndims = (array_ops.size(self.batch_shape_tensor())
- if self.batch_shape.ndims is None
- else self.batch_shape.ndims)
+ batch_ndims = (
+ array_ops.size(self._batch_shape_unexpanded)
+ if self.batch_shape.ndims is None else self.batch_shape.ndims)
expected_batch_event_ndims = batch_ndims + event_ndims
if (isinstance(x_ndims, int) and
@@ -356,62 +353,56 @@ class BatchReshape(distribution_lib.Distribution):
return runtime_assertions
-def validate_init_args(
- distribution,
- batch_shape,
- validate_args,
- batch_shape_static):
+def calculate_reshape(original_shape, new_shape, validate=False, name=None):
+ """Calculates the reshaped dimensions (replacing up to one -1 in reshape)."""
+ batch_shape_static = tensor_util.constant_value_as_shape(new_shape)
+ if batch_shape_static.is_fully_defined():
+ return np.int32(batch_shape_static.as_list()), batch_shape_static, []
+ with ops.name_scope(name, "calculate_reshape", [original_shape, new_shape]):
+ original_size = math_ops.reduce_prod(original_shape)
+ implicit_dim = math_ops.equal(new_shape, -1)
+ size_implicit_dim = (
+ original_size // math_ops.maximum(1, -math_ops.reduce_prod(new_shape)))
+ new_ndims = array_ops.shape(new_shape)
+ expanded_new_shape = array_ops.where( # Assumes exactly one `-1`.
+ implicit_dim, array_ops.fill(new_ndims, size_implicit_dim), new_shape)
+ validations = [] if not validate else [
+ check_ops.assert_rank(
+ original_shape, 1, message="Original shape must be a vector."),
+ check_ops.assert_rank(
+ new_shape, 1, message="New shape must be a vector."),
+ check_ops.assert_less_equal(
+ math_ops.count_nonzero(implicit_dim, dtype=dtypes.int32),
+ 1,
+ message="At most one dimension can be unknown."),
+ check_ops.assert_positive(
+ expanded_new_shape, message="Shape elements must be >=-1."),
+ check_ops.assert_equal(
+ math_ops.reduce_prod(expanded_new_shape),
+ original_size,
+ message="Shape sizes do not match."),
+ ]
+ return expanded_new_shape, batch_shape_static, validations
+
+
+def validate_init_args_statically(distribution, batch_shape):
"""Helper to __init__ which makes or raises assertions."""
- with ops.name_scope(name="validate_init_args",
- values=[batch_shape] + distribution._graph_parents): # pylint: disable=protected-access
- runtime_assertions = []
-
- if batch_shape.shape.ndims is not None:
- if batch_shape.shape.ndims != 1:
- raise ValueError("`batch_shape` must be a vector "
- "(saw rank: {}).".format(
- batch_shape.shape.ndims))
- elif validate_args:
- runtime_assertions += [
- check_ops.assert_rank(
- batch_shape,
- 1,
- message="`batch_shape` must be a vector.",
- name="assert_batch_shape_is_vector"),
- ]
-
- batch_size_static = np.prod(batch_shape_static)
- dist_batch_size_static = (
- None if not distribution.batch_shape.is_fully_defined()
- else np.prod(distribution.batch_shape).value)
-
- if batch_size_static is not None and dist_batch_size_static is not None:
- if batch_size_static != dist_batch_size_static:
- raise ValueError("`batch_shape` size ({}) must match "
- "`distribution.batch_shape` size ({}).".format(
- batch_size_static,
- dist_batch_size_static))
- elif validate_args:
- runtime_assertions += [
- check_ops.assert_equal(
- math_ops.reduce_prod(batch_shape),
- math_ops.reduce_prod(distribution.batch_shape_tensor()),
- message=("`batch_shape` size must match "
- "`distributions.batch_shape` size."),
- name="assert_batch_size"),
- ]
-
- if batch_shape_static is not None:
- if np.any(batch_shape_static < 1):
- raise ValueError("`batch_shape` elements must be positive "
- "(i.e., larger than zero).")
- elif validate_args:
- runtime_assertions += [
- check_ops.assert_positive(
- batch_shape,
- message=("`batch_shape` elements must be positive "
- "(i.e., larger than zero)."),
- name="assert_batch_shape_positive")
- ]
-
- return runtime_assertions
+ if batch_shape.shape.ndims is not None:
+ if batch_shape.shape.ndims != 1:
+ raise ValueError("`batch_shape` must be a vector "
+ "(saw rank: {}).".format(batch_shape.shape.ndims))
+
+ batch_shape_static = tensor_util.constant_value_as_shape(batch_shape)
+ batch_size_static = batch_shape_static.num_elements()
+ dist_batch_size_static = distribution.batch_shape.num_elements()
+
+ if batch_size_static is not None and dist_batch_size_static is not None:
+ if batch_size_static != dist_batch_size_static:
+ raise ValueError("`batch_shape` size ({}) must match "
+ "`distribution.batch_shape` size ({}).".format(
+ batch_size_static, dist_batch_size_static))
+
+ if batch_shape_static.dims is not None:
+ if any(
+ dim.value is not None and dim.value < 1 for dim in batch_shape_static):
+ raise ValueError("`batch_shape` elements must be >=-1.")