aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Joshua V. Dillon <jvdillon@google.com>2018-03-25 21:57:09 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-25 21:59:14 -0700
commit668f182b1fdfc31568a44fe650324fe2ddedbbe1 (patch)
treebbe283b6a6d782890bca0d45b02758c7d523248a
parent9d9ea88abd63d2c317e445e54a4f9c90d747343a (diff)
Always cast `tf.distributions.Distribution` `_event_shape`, `_batch_shape`.
PiperOrigin-RevId: 190415923
-rw-r--r--tensorflow/python/ops/distributions/distribution.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/tensorflow/python/ops/distributions/distribution.py b/tensorflow/python/ops/distributions/distribution.py
index c055ca43e8..0866fa8b0b 100644
--- a/tensorflow/python/ops/distributions/distribution.py
+++ b/tensorflow/python/ops/distributions/distribution.py
@@ -593,7 +593,7 @@ class Distribution(_BaseDistribution):
Returns:
batch_shape: `TensorShape`, possibly unknown.
"""
- return self._batch_shape()
+ return tensor_shape.as_shape(self._batch_shape())
def _event_shape_tensor(self):
raise NotImplementedError("event_shape_tensor is not implemented")
@@ -626,7 +626,7 @@ class Distribution(_BaseDistribution):
Returns:
event_shape: `TensorShape`, possibly unknown.
"""
- return self._event_shape()
+ return tensor_shape.as_shape(self._event_shape())
def is_scalar_event(self, name="is_scalar_event"):
"""Indicates that `event_shape == []`.