From 860f8c50753bcbfca8243c585033b3d44c4b7c7f Mon Sep 17 00:00:00 2001
From: Chris Ying
Date: Fri, 13 Oct 2017 18:00:04 -0700
Subject: Fix case where broadcasting is not necessary.

PiperOrigin-RevId: 172169909
---
 tensorflow/python/layers/normalization.py | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/tensorflow/python/layers/normalization.py b/tensorflow/python/layers/normalization.py
index d82946382f..df2b97f03e 100644
--- a/tensorflow/python/layers/normalization.py
+++ b/tensorflow/python/layers/normalization.py
@@ -477,7 +477,8 @@ class BatchNormalization(base.Layer):
 
     # Compute the axes along which to reduce the mean / variance
     input_shape = inputs.get_shape()
-    reduction_axes = [i for i in range(len(input_shape)) if i not in self.axis]
+    ndims = len(input_shape)
+    reduction_axes = [i for i in range(ndims) if i not in self.axis]
     if self.virtual_batch_size is not None:
       del reduction_axes[1]     # Do not reduce along virtual batch dim
 
@@ -541,13 +542,15 @@ class BatchNormalization(base.Layer):
     else:
       mean, variance = self.moving_mean, self.moving_variance
 
-    # Broadcasting only necessary for single-axis batch norm
-    broadcast_shape = [1] * len(input_shape)
+    # Broadcasting only necessary for single-axis batch norm where the axis is
+    # not the last dimension
+    broadcast_shape = [1] * ndims
     broadcast_shape[self.axis[0]] = input_shape[self.axis[0]].value
     rank = len(inputs.get_shape())
     def _broadcast(v):
-      if v is not None and len(v.get_shape()) != rank:
-        assert len(self.axis) == 1
+      if (v is not None and
+          len(v.get_shape()) != rank and
+          reduction_axes != list(range(ndims))[:-1]):
         return array_ops.reshape(v, broadcast_shape)
       return v
 
--
cgit v1.2.3
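
Not part of the patch: a minimal standalone NumPy sketch of why the added
`reduction_axes != list(range(ndims))[:-1]` check lets `_broadcast` skip the
reshape when the normalization axis is the last dimension. The variable names
and example shapes below are illustrative assumptions, not code from the patch.

import numpy as np

# Channels-last input (e.g. NHWC): gamma of shape [C] lines up with the last
# axis, so plain NumPy/TF broadcasting already does the right thing.
x_nhwc = np.ones((2, 4, 4, 3))
gamma = np.arange(3, dtype=np.float64)
assert (x_nhwc * gamma).shape == (2, 4, 4, 3)  # no reshape needed

# Channels-first input (e.g. NCHW): gamma of shape [C] would broadcast against
# the wrong (last) axis, so it must first be reshaped to [1, C, 1, 1], which is
# what broadcast_shape provides in the patched code.
x_nchw = np.ones((2, 3, 4, 4))
broadcast_shape = [1, 3, 1, 1]
assert (x_nchw * gamma.reshape(broadcast_shape)).shape == (2, 3, 4, 4)

# The patched condition detects the channels-last case: reducing over every
# dimension except the last means reduction_axes == list(range(ndims))[:-1],
# so _broadcast returns v unchanged.
ndims, axis = 4, [3]
reduction_axes = [i for i in range(ndims) if i not in axis]
assert reduction_axes == list(range(ndims))[:-1]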