aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/model_pruning
diff options
context:
space:
mode:
authorGravatar Jesse <jessehagenaars@gmail.com>2018-06-05 14:35:38 +0200
committerGravatar GitHub <noreply@github.com>2018-06-05 14:35:38 +0200
commitf2e22502fd58e8d81c9e080b9242375fbf2bc772 (patch)
tree187f4639e7f44299b2fdf412d066b1b7230d7889 /tensorflow/contrib/model_pruning
parent83543deedb68fef61ea7e709de3f462a1edd13ce (diff)
Updated line for creating global step + grammar
tf.train.get_global_step() returns None if there is no global step, preventing the pruning from working. Therefore, tf.train.get_or_create_global_step() is a safer option.
Diffstat (limited to 'tensorflow/contrib/model_pruning')
-rw-r--r--tensorflow/contrib/model_pruning/README.md10
1 files changed, 5 insertions, 5 deletions
diff --git a/tensorflow/contrib/model_pruning/README.md b/tensorflow/contrib/model_pruning/README.md
index 86f4fd6adf..50e7e5d7cd 100644
--- a/tensorflow/contrib/model_pruning/README.md
+++ b/tensorflow/contrib/model_pruning/README.md
@@ -66,10 +66,10 @@ is the sparsity_function_begin_step. In this equation, the
sparsity_function_exponent is set to 3.
### Adding pruning ops to the training graph
-The final step involves adding ops to the training graph that monitors the
-distribution of the layer's weight magnitudes and determines the layer threshold
-such masking all the weights below this threshold achieves the sparsity level
-desired for the current training step. This can be achieved as follows:
+The final step involves adding ops to the training graph that monitor the
+distribution of the layer's weight magnitudes and determine the layer threshold,
+such that masking all the weights below this threshold achieves the sparsity
+level desired for the current training step. This can be achieved as follows:
```python
tf.app.flags.DEFINE_string(
@@ -79,7 +79,7 @@ tf.app.flags.DEFINE_string(
with tf.graph.as_default():
# Create global step variable
- global_step = tf.train.get_global_step()
+ global_step = tf.train.get_or_create_global_step()
# Parse pruning hyperparameters
pruning_hparams = pruning.get_pruning_hparams().parse(FLAGS.pruning_hparams)