aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/nn_impl.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/nn_impl.py')
-rw-r--r--tensorflow/python/ops/nn_impl.py8
1 files changed, 4 insertions, 4 deletions
diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py
index 4710af0d9f..73aea2c260 100644
--- a/tensorflow/python/ops/nn_impl.py
+++ b/tensorflow/python/ops/nn_impl.py
@@ -514,15 +514,15 @@ def sufficient_statistics(x, axes, shift=None, keep_dims=False, name=None):
with ops.name_scope(name, "sufficient_statistics", [x, shift]):
x = ops.convert_to_tensor(x, name="x")
x_shape = x.get_shape()
- if x_shape.is_fully_defined():
+ if all(x_shape[d].value is not None for d in axes):
counts = 1
for d in axes:
counts *= x_shape[d].value
counts = constant_op.constant(counts, dtype=x.dtype)
else: # shape needs to be inferred at runtime.
- x_dims = array_ops.gather(array_ops.shape(x), axes)
- counts = math_ops.cast(
- math_ops.reduce_prod(x_dims), x.dtype, name="count")
+ x_dims = array_ops.gather(
+ math_ops.cast(array_ops.shape(x), x.dtype), axes)
+ counts = math_ops.reduce_prod(x_dims, name="count")
if shift is not None:
shift = ops.convert_to_tensor(shift, name="shift")
m_ss = math_ops.sub(x, shift)