-rw-r--r--  tensorflow/python/layers/normalization.py  13
1 file changed, 8 insertions(+), 5 deletions(-)
diff --git a/tensorflow/python/layers/normalization.py b/tensorflow/python/layers/normalization.py
index d82946382f..df2b97f03e 100644
--- a/tensorflow/python/layers/normalization.py
+++ b/tensorflow/python/layers/normalization.py
@@ -477,7 +477,8 @@ class BatchNormalization(base.Layer):
     # Compute the axes along which to reduce the mean / variance
     input_shape = inputs.get_shape()
-    reduction_axes = [i for i in range(len(input_shape)) if i not in self.axis]
+    ndims = len(input_shape)
+    reduction_axes = [i for i in range(ndims) if i not in self.axis]
     if self.virtual_batch_size is not None:
       del reduction_axes[1]  # Do not reduce along virtual batch dim
@@ -541,13 +542,15 @@ class BatchNormalization(base.Layer):
     else:
       mean, variance = self.moving_mean, self.moving_variance
-    # Broadcasting only necessary for single-axis batch norm
-    broadcast_shape = [1] * len(input_shape)
+    # Broadcasting only necessary for single-axis batch norm where the axis is
+    # not the last dimension
+    broadcast_shape = [1] * ndims
     broadcast_shape[self.axis[0]] = input_shape[self.axis[0]].value
     rank = len(inputs.get_shape())
     def _broadcast(v):
-      if v is not None and len(v.get_shape()) != rank:
-        assert len(self.axis) == 1
+      if (v is not None and
+          len(v.get_shape()) != rank and
+          reduction_axes != list(range(ndims))[:-1]):
         return array_ops.reshape(v, broadcast_shape)
       return v
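
In effect, the patch makes _broadcast a no-op for the common channels-last case: when reduction_axes covers every dimension except the last (which is what the comparison against list(range(ndims))[:-1] detects), the per-channel moving statistics and gamma/beta already broadcast correctly against the input, so no reshape is needed and the old single-axis assert can be dropped. The reshape only matters when the normalization axis is not the last dimension, e.g. channels-first data. Below is a minimal NumPy sketch of the two cases; the shapes and the names x_nhwc, x_nchw, and gamma are assumptions chosen for illustration, not part of the patch.

# Illustrative sketch (NumPy), not part of the patch.
import numpy as np

x_nhwc = np.random.rand(2, 4, 4, 3)   # channels-last input, axis = [3]
x_nchw = np.random.rand(2, 3, 4, 4)   # channels-first input, axis = [1]
gamma = np.ones(3)                     # per-channel parameter, shape (3,)

# Channels-last: reduction_axes == [0, 1, 2] == list(range(ndims))[:-1], so
# plain broadcasting already lines gamma up with the last dimension and the
# reshape in _broadcast is skipped.
out_last = x_nhwc * gamma

# Channels-first: gamma cannot broadcast against (2, 3, 4, 4) directly; it
# must first be reshaped to broadcast_shape == [1, 3, 1, 1], which is what
# array_ops.reshape(v, broadcast_shape) does in the patched code.
out_first = x_nchw * gamma.reshape([1, 3, 1, 1])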