author	Sergio Guadarrama <sguada@google.com>	2016-07-19 10:27:52 -0800
committer	TensorFlower Gardener <gardener@tensorflow.org>	2016-07-19 11:32:52 -0700
commit 580cdddab964d65e45d9bfeb7b9eca43f18db7eb (patch)
tree 677bf107cb2821b273fa150d8ab243773107bc7c
parent 4307ebc0d79b7f725c4b309bff5d73a7506ac720 (diff)
Updated comment about batch_norm layer.
Change: 127853412
-rw-r--r-- tensorflow/contrib/layers/python/layers/layers.py | 5
1 file changed, 4 insertions(+), 1 deletion(-)
diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py
index 7588daa8f4..e4a25fa113 100644
--- a/tensorflow/contrib/layers/python/layers/layers.py
+++ b/tensorflow/contrib/layers/python/layers/layers.py
@@ -215,10 +215,13 @@ def batch_norm(inputs,
trainable=False,
collections=moving_variance_collections)
+ # If `is_training` doesn't have a constant value (because it is a `Tensor`,
+ # a `Variable`, or a `Placeholder`), then `is_training_value` will be None
+ # and `need_moments` will be True.
is_training_value = utils.constant_value(is_training)
- # Calculate the moments based on the individual batch.
need_moments = is_training_value is None or is_training_value
if need_moments:
+ # Calculate the moments based on the individual batch.
mean, variance = nn.moments(inputs, axis, shift=moving_mean)
moving_vars_fn = lambda: (moving_mean, moving_variance)
if updates_collections is None:
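
For context, the new comment describes how `utils.constant_value` decides whether the layer must compute batch moments. Below is a minimal sketch (assuming the TF 1.x `tf.contrib.layers` API of this commit's era; the placeholder names and shapes are illustrative, not part of the patch) of the distinction between a constant `is_training` and a run-time one:

# Minimal sketch, assuming the TF 1.x contrib API of this commit's era.
# Names (`x`, `is_training`) and shapes are illustrative, not from the patch.
import tensorflow as tf
from tensorflow.contrib.layers.python.layers import utils

# A Python bool has a recoverable constant value...
print(utils.constant_value(True))        # True  -> moments computed at graph-build time
print(utils.constant_value(False))       # False -> moving averages are used directly

# ...but a placeholder (or Tensor/Variable) does not, so `is_training_value`
# is None, `need_moments` is True, and the moments are built into the graph
# and selected at run time.
is_training = tf.placeholder(tf.bool, shape=[], name='is_training')
print(utils.constant_value(is_training))  # None

x = tf.placeholder(tf.float32, shape=[None, 8])
y = tf.contrib.layers.batch_norm(x, is_training=is_training)

This is why the moved comment only applies inside the `if need_moments:` branch: when `is_training` is a constant False, `nn.moments` is never added to the graph.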