aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Anjali Sridhar <anjalisridhar@google.com>2018-02-13 19:16:08 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-13 19:19:43 -0800
commite23948da81a007b27869a75c4a7dbe8f91ea8c03 (patch)
treec120e5fe5c5d5f7dc84035a5db20a7240dafcf95
parent7325919e07e2ae45b3b5436db1dc9f26a51af6c6 (diff)
For models running in Eager mode, do not update the weights of the BatchNorm layer if the layer's trainable argument is False.
This change is required in Eager mode to freeze a layer's weights when we set the layer's trainable attribute to False. This should not be confused with the "training" attribute which refers to a model's training or inference mode behavior. PiperOrigin-RevId: 185625661
-rw-r--r--tensorflow/python/layers/normalization.py4
1 files changed, 4 insertions, 0 deletions
diff --git a/tensorflow/python/layers/normalization.py b/tensorflow/python/layers/normalization.py
index 656d566ab5..323a9f8ee3 100644
--- a/tensorflow/python/layers/normalization.py
+++ b/tensorflow/python/layers/normalization.py
@@ -493,6 +493,7 @@ class BatchNormalization(base.Layer):
return (r, d, new_mean, new_variance)
def call(self, inputs, training=False):
+ in_eager_mode = context.in_eager_mode()
if self.virtual_batch_size is not None:
# Virtual batches (aka ghost batches) can be simulated by reshaping the
# Tensor and reusing the existing batch norm implementation
@@ -595,6 +596,9 @@ class BatchNormalization(base.Layer):
axis=1, keep_dims=True)
def _do_update(var, value):
+ if in_eager_mode and not self.trainable:
+ return
+
return moving_averages.assign_moving_average(
var, value, self.momentum, zero_debias=False)