diff options
Diffstat (limited to 'tensorflow/python/ops/nn_impl.py')
-rw-r--r-- | tensorflow/python/ops/nn_impl.py | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py index 4710af0d9f..73aea2c260 100644 --- a/tensorflow/python/ops/nn_impl.py +++ b/tensorflow/python/ops/nn_impl.py @@ -514,15 +514,15 @@ def sufficient_statistics(x, axes, shift=None, keep_dims=False, name=None): with ops.name_scope(name, "sufficient_statistics", [x, shift]): x = ops.convert_to_tensor(x, name="x") x_shape = x.get_shape() - if x_shape.is_fully_defined(): + if all(x_shape[d].value is not None for d in axes): counts = 1 for d in axes: counts *= x_shape[d].value counts = constant_op.constant(counts, dtype=x.dtype) else: # shape needs to be inferred at runtime. - x_dims = array_ops.gather(array_ops.shape(x), axes) - counts = math_ops.cast( - math_ops.reduce_prod(x_dims), x.dtype, name="count") + x_dims = array_ops.gather( + math_ops.cast(array_ops.shape(x), x.dtype), axes) + counts = math_ops.reduce_prod(x_dims, name="count") if shift is not None: shift = ops.convert_to_tensor(shift, name="shift") m_ss = math_ops.sub(x, shift) |