diff options
author | Jesse <jessehagenaars@gmail.com> | 2018-06-05 14:35:38 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-06-05 14:35:38 +0200 |
commit | f2e22502fd58e8d81c9e080b9242375fbf2bc772 (patch) | |
tree | 187f4639e7f44299b2fdf412d066b1b7230d7889 /tensorflow/contrib/model_pruning | |
parent | 83543deedb68fef61ea7e709de3f462a1edd13ce (diff) |
Updated line for creating global step + grammar
tf.train.get_global_step() returns None if there is no global step, preventing the pruning from working. Therefore, tf.train.get_or_create_global_step() is a safer option.
Diffstat (limited to 'tensorflow/contrib/model_pruning')
-rw-r--r-- | tensorflow/contrib/model_pruning/README.md | 10 |
1 files changed, 5 insertions, 5 deletions
diff --git a/tensorflow/contrib/model_pruning/README.md b/tensorflow/contrib/model_pruning/README.md index 86f4fd6adf..50e7e5d7cd 100644 --- a/tensorflow/contrib/model_pruning/README.md +++ b/tensorflow/contrib/model_pruning/README.md @@ -66,10 +66,10 @@ is the sparsity_function_begin_step. In this equation, the sparsity_function_exponent is set to 3. ### Adding pruning ops to the training graph -The final step involves adding ops to the training graph that monitors the -distribution of the layer's weight magnitudes and determines the layer threshold -such masking all the weights below this threshold achieves the sparsity level -desired for the current training step. This can be achieved as follows: +The final step involves adding ops to the training graph that monitor the +distribution of the layer's weight magnitudes and determine the layer threshold, +such that masking all the weights below this threshold achieves the sparsity +level desired for the current training step. This can be achieved as follows: ```python tf.app.flags.DEFINE_string( @@ -79,7 +79,7 @@ tf.app.flags.DEFINE_string( with tf.graph.as_default(): # Create global step variable - global_step = tf.train.get_global_step() + global_step = tf.train.get_or_create_global_step() # Parse pruning hyperparameters pruning_hparams = pruning.get_pruning_hparams().parse(FLAGS.pruning_hparams) |