diff options
author | 2018-03-21 12:07:51 -0700 | |
---|---|---|
committer | 2018-03-21 12:10:30 -0700 | |
commit | 2d0531d72c7dcbb0e149cafdd3a16ee8c3ff357a (patch) | |
tree | 1179ecdd684d10c6549f85aa95f33dd79463a093 /tensorflow/python/layers | |
parent | cbede3ea7574b36f429710bc08617d08455bcc21 (diff) |
Merge changes from github.
PiperOrigin-RevId: 189945839
Diffstat (limited to 'tensorflow/python/layers')
-rw-r--r-- | tensorflow/python/layers/base.py | 2 | ||||
-rw-r--r-- | tensorflow/python/layers/normalization.py | 9 |
2 files changed, 9 insertions, 2 deletions
diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py index e9066d3fda..e4395bea92 100644 --- a/tensorflow/python/layers/base.py +++ b/tensorflow/python/layers/base.py @@ -578,7 +578,7 @@ class Layer(checkpointable.CheckpointableBase): if isinstance(variable, tf_variables.PartitionedVariable): raise RuntimeError( 'Partitioned variable regularization is not yet ' - 'supported when executing eagerly. File a feature request' + 'supported when executing eagerly. File a feature request ' 'if this is important to you.') # Save a zero-argument lambda which runs the regularizer on the # variable, to be executed when `Layer.losses` is requested. diff --git a/tensorflow/python/layers/normalization.py b/tensorflow/python/layers/normalization.py index 11daf01670..29fb92ccb5 100644 --- a/tensorflow/python/layers/normalization.py +++ b/tensorflow/python/layers/normalization.py @@ -664,9 +664,16 @@ def batch_normalization(inputs, Note: when training, the moving_mean and moving_variance need to be updated. By default the update ops are placed in `tf.GraphKeys.UPDATE_OPS`, so they - need to be added as a dependency to the `train_op`. For example: + need to be added as a dependency to the `train_op`. Also, be sure to add + any batch_normalization ops before getting the update_ops collection. + Otherwise, update_ops will be empty, and training/inference will not work + properly. For example: ```python + x_norm = tf.layers.batch_normalization(x, training=training) + + # ... + update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) with tf.control_dependencies(update_ops): train_op = optimizer.minimize(loss) |