diff options
author | Chris Ying <chrisying@google.com> | 2017-10-13 18:00:04 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-10-13 18:04:17 -0700 |
commit | 860f8c50753bcbfca8243c585033b3d44c4b7c7f (patch) | |
tree | e355a1b7e5ba0f20cb5d7b79aa8bdb49eaf41895 | |
parent | 5a8c47079f664b280bb28eb34ce2c93534305cda (diff) |
Fix case where broadcasting is not necessary.
PiperOrigin-RevId: 172169909
-rw-r--r-- | tensorflow/python/layers/normalization.py | 13 |
1 file changed, 8 insertions, 5 deletions
diff --git a/tensorflow/python/layers/normalization.py b/tensorflow/python/layers/normalization.py index d82946382f..df2b97f03e 100644 --- a/tensorflow/python/layers/normalization.py +++ b/tensorflow/python/layers/normalization.py @@ -477,7 +477,8 @@ class BatchNormalization(base.Layer): # Compute the axes along which to reduce the mean / variance input_shape = inputs.get_shape() - reduction_axes = [i for i in range(len(input_shape)) if i not in self.axis] + ndims = len(input_shape) + reduction_axes = [i for i in range(ndims) if i not in self.axis] if self.virtual_batch_size is not None: del reduction_axes[1] # Do not reduce along virtual batch dim @@ -541,13 +542,15 @@ class BatchNormalization(base.Layer): else: mean, variance = self.moving_mean, self.moving_variance - # Broadcasting only necessary for single-axis batch norm - broadcast_shape = [1] * len(input_shape) + # Broadcasting only necessary for single-axis batch norm where the axis is + # not the last dimension + broadcast_shape = [1] * ndims broadcast_shape[self.axis[0]] = input_shape[self.axis[0]].value rank = len(inputs.get_shape()) def _broadcast(v): - if v is not None and len(v.get_shape()) != rank: - assert len(self.axis) == 1 + if (v is not None and + len(v.get_shape()) != rank and + reduction_axes != list(range(ndims))[:-1]): return array_ops.reshape(v, broadcast_shape) return v |