aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Chris Ying <chrisying@google.com>2017-10-13 18:00:04 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-13 18:04:17 -0700
commit860f8c50753bcbfca8243c585033b3d44c4b7c7f (patch)
treee355a1b7e5ba0f20cb5d7b79aa8bdb49eaf41895
parent5a8c47079f664b280bb28eb34ce2c93534305cda (diff)
Fix case where broadcasting is not necessary.
PiperOrigin-RevId: 172169909
-rw-r--r--tensorflow/python/layers/normalization.py13
1 files changed, 8 insertions, 5 deletions
diff --git a/tensorflow/python/layers/normalization.py b/tensorflow/python/layers/normalization.py
index d82946382f..df2b97f03e 100644
--- a/tensorflow/python/layers/normalization.py
+++ b/tensorflow/python/layers/normalization.py
@@ -477,7 +477,8 @@ class BatchNormalization(base.Layer):
# Compute the axes along which to reduce the mean / variance
input_shape = inputs.get_shape()
- reduction_axes = [i for i in range(len(input_shape)) if i not in self.axis]
+ ndims = len(input_shape)
+ reduction_axes = [i for i in range(ndims) if i not in self.axis]
if self.virtual_batch_size is not None:
del reduction_axes[1] # Do not reduce along virtual batch dim
@@ -541,13 +542,15 @@ class BatchNormalization(base.Layer):
else:
mean, variance = self.moving_mean, self.moving_variance
- # Broadcasting only necessary for single-axis batch norm
- broadcast_shape = [1] * len(input_shape)
+ # Broadcasting only necessary for single-axis batch norm where the axis is
+ # not the last dimension
+ broadcast_shape = [1] * ndims
broadcast_shape[self.axis[0]] = input_shape[self.axis[0]].value
rank = len(inputs.get_shape())
def _broadcast(v):
- if v is not None and len(v.get_shape()) != rank:
- assert len(self.axis) == 1
+ if (v is not None and
+ len(v.get_shape()) != rank and
+ reduction_axes != list(range(ndims))[:-1]):
return array_ops.reshape(v, broadcast_shape)
return v