author     2018-02-13 19:16:08 -0800
committer  2018-02-13 19:19:43 -0800
commit     e23948da81a007b27869a75c4a7dbe8f91ea8c03 (patch)
tree       c120e5fe5c5d5f7dc84035a5db20a7240dafcf95
parent     7325919e07e2ae45b3b5436db1dc9f26a51af6c6 (diff)
For models running in Eager mode, do not update the weights of the BatchNorm layer if the layer's trainable attribute is False.
This change is required so that, in Eager mode, setting a layer's trainable attribute to False actually freezes the layer's weights.
This should not be confused with the "training" argument to call(), which selects between the layer's training-mode and inference-mode behavior.
PiperOrigin-RevId: 185625661
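
A minimal usage sketch of the behavior this change enables (not part of the commit; assumes a TF 1.x build where tf.enable_eager_execution() and tf.layers.BatchNormalization are available):

    import tensorflow as tf

    tf.enable_eager_execution()

    # With trainable=False, the moving statistics should no longer be updated
    # in eager mode, even when the layer is called with training=True (which
    # only selects batch statistics for normalization).
    layer = tf.layers.BatchNormalization(trainable=False)
    x = tf.random_normal([4, 8])

    layer(x, training=True)                  # first call builds the layer
    before = layer.moving_mean.numpy().copy()
    layer(x, training=True)                  # would update moving_mean before this patch
    after = layer.moving_mean.numpy()
    print((before == after).all())           # expected: True with this change

Here, trainable freezes the layer's variables (including the moving averages), while training only switches the layer between batch statistics and moving statistics at call time.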
-rw-r--r--  tensorflow/python/layers/normalization.py  | 4 ++++
1 file changed, 4 insertions(+), 0 deletions(-)
diff --git a/tensorflow/python/layers/normalization.py b/tensorflow/python/layers/normalization.py
index 656d566ab5..323a9f8ee3 100644
--- a/tensorflow/python/layers/normalization.py
+++ b/tensorflow/python/layers/normalization.py
@@ -493,6 +493,7 @@ class BatchNormalization(base.Layer):
     return (r, d, new_mean, new_variance)
 
   def call(self, inputs, training=False):
+    in_eager_mode = context.in_eager_mode()
     if self.virtual_batch_size is not None:
       # Virtual batches (aka ghost batches) can be simulated by reshaping the
       # Tensor and reusing the existing batch norm implementation
@@ -595,6 +596,9 @@ class BatchNormalization(base.Layer):
                                             axis=1, keep_dims=True)
 
       def _do_update(var, value):
+        if in_eager_mode and not self.trainable:
+          return
+
         return moving_averages.assign_moving_average(
             var, value, self.momentum, zero_debias=False)
 
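
Note (an observation on the patch, not from the commit message): placing the guard inside _do_update covers both the moving-mean and moving-variance updates, since both funnel through this helper, and in eager mode the assignment would otherwise execute immediately, so returning early is enough to skip it.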