diff options
-rw-r--r-- | tensorflow/python/ops/distributions/distribution.py | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/tensorflow/python/ops/distributions/distribution.py b/tensorflow/python/ops/distributions/distribution.py index c055ca43e8..0866fa8b0b 100644 --- a/tensorflow/python/ops/distributions/distribution.py +++ b/tensorflow/python/ops/distributions/distribution.py @@ -593,7 +593,7 @@ class Distribution(_BaseDistribution): Returns: batch_shape: `TensorShape`, possibly unknown. """ - return self._batch_shape() + return tensor_shape.as_shape(self._batch_shape()) def _event_shape_tensor(self): raise NotImplementedError("event_shape_tensor is not implemented") @@ -626,7 +626,7 @@ class Distribution(_BaseDistribution): Returns: event_shape: `TensorShape`, possibly unknown. """ - return self._event_shape() + return tensor_shape.as_shape(self._event_shape()) def is_scalar_event(self, name="is_scalar_event"): """Indicates that `event_shape == []`. |