From 668f182b1fdfc31568a44fe650324fe2ddedbbe1 Mon Sep 17 00:00:00 2001 From: "Joshua V. Dillon" Date: Sun, 25 Mar 2018 21:57:09 -0700 Subject: Always cast `tf.distributions.Distribution` `_event_shape`, `_batch_shape`. PiperOrigin-RevId: 190415923 --- tensorflow/python/ops/distributions/distribution.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/ops/distributions/distribution.py b/tensorflow/python/ops/distributions/distribution.py index c055ca43e8..0866fa8b0b 100644 --- a/tensorflow/python/ops/distributions/distribution.py +++ b/tensorflow/python/ops/distributions/distribution.py @@ -593,7 +593,7 @@ class Distribution(_BaseDistribution): Returns: batch_shape: `TensorShape`, possibly unknown. """ - return self._batch_shape() + return tensor_shape.as_shape(self._batch_shape()) def _event_shape_tensor(self): raise NotImplementedError("event_shape_tensor is not implemented") @@ -626,7 +626,7 @@ class Distribution(_BaseDistribution): Returns: event_shape: `TensorShape`, possibly unknown. """ - return self._event_shape() + return tensor_shape.as_shape(self._event_shape()) def is_scalar_event(self, name="is_scalar_event"): """Indicates that `event_shape == []`. -- cgit v1.2.3