diff options
Diffstat (limited to 'tensorflow/python/keras/_impl/keras/layers/normalization.py')
-rw-r--r-- | tensorflow/python/keras/_impl/keras/layers/normalization.py | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/tensorflow/python/keras/_impl/keras/layers/normalization.py b/tensorflow/python/keras/_impl/keras/layers/normalization.py index 5462a95d7d..c16fc07fb4 100644 --- a/tensorflow/python/keras/_impl/keras/layers/normalization.py +++ b/tensorflow/python/keras/_impl/keras/layers/normalization.py @@ -593,9 +593,9 @@ class BatchNormalization(Layer): # used during evaluation, it is more efficient to just update in one # step and should not make a significant difference in the result. new_mean = math_ops.reduce_mean(new_mean, - axis=1, keep_dims=True) + axis=1, keepdims=True) new_variance = math_ops.reduce_mean(new_variance, - axis=1, keep_dims=True) + axis=1, keepdims=True) def _do_update(var, value): if in_eager_mode and not self.trainable: |