aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/python/ops/distributions/distribution.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/tensorflow/python/ops/distributions/distribution.py b/tensorflow/python/ops/distributions/distribution.py
index c055ca43e8..0866fa8b0b 100644
--- a/tensorflow/python/ops/distributions/distribution.py
+++ b/tensorflow/python/ops/distributions/distribution.py
@@ -593,7 +593,7 @@ class Distribution(_BaseDistribution):
Returns:
batch_shape: `TensorShape`, possibly unknown.
"""
- return self._batch_shape()
+ return tensor_shape.as_shape(self._batch_shape())
def _event_shape_tensor(self):
raise NotImplementedError("event_shape_tensor is not implemented")
@@ -626,7 +626,7 @@ class Distribution(_BaseDistribution):
Returns:
event_shape: `TensorShape`, possibly unknown.
"""
- return self._event_shape()
+ return tensor_shape.as_shape(self._event_shape())
def is_scalar_event(self, name="is_scalar_event"):
"""Indicates that `event_shape == []`.