diff options
author | Igor Saprykin <isaprykin@google.com> | 2017-07-07 14:33:47 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-07-07 14:38:06 -0700 |
commit | 204c367ab1a38cee71dac2a64164e96abaeffbf2 (patch) | |
tree | 62837fea81b98e53a3474d9bd96974625c4b8135 /tensorflow/examples/tutorials | |
parent | e729dd303a80db6fb628a2e63f15aa12c28c956b (diff) |
Fix KeyError when looking for 'softmax_tensor' in the layers tutorial.
PiperOrigin-RevId: 161247019
Diffstat (limited to 'tensorflow/examples/tutorials')
-rw-r--r-- | tensorflow/examples/tutorials/layers/cnn_mnist.py | 15 |
1 files changed, 8 insertions, 7 deletions
diff --git a/tensorflow/examples/tutorials/layers/cnn_mnist.py b/tensorflow/examples/tutorials/layers/cnn_mnist.py index f92277dac7..2124843fcb 100644 --- a/tensorflow/examples/tutorials/layers/cnn_mnist.py +++ b/tensorflow/examples/tutorials/layers/cnn_mnist.py @@ -86,13 +86,14 @@ def cnn_model_fn(features, labels, mode): # Output Tensor Shape: [batch_size, 10] logits = tf.layers.dense(inputs=dropout, units=10) - # Generate Predictions (for PREDICT mode) - predicted_classes = tf.argmax(input=logits, axis=1) + predictions = { + # Generate predictions (for PREDICT and EVAL mode) + "classes": tf.argmax(input=logits, axis=1), + # Add `softmax_tensor` to the graph. It is used for PREDICT and by the + # `logging_hook`. + "probabilities": tf.nn.softmax(logits, name="softmax_tensor") + } if mode == tf.estimator.ModeKeys.PREDICT: - predictions = { - "classes": predicted_classes, - "probabilities": tf.nn.softmax(logits, name="softmax_tensor") - } return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions) # Calculate Loss (for both TRAIN and EVAL modes) @@ -111,7 +112,7 @@ def cnn_model_fn(features, labels, mode): # Add evaluation metrics (for EVAL mode) eval_metric_ops = { "accuracy": tf.metrics.accuracy( - labels=labels, predictions=predicted_classes)} + labels=labels, predictions=predictions["classes"])} return tf.estimator.EstimatorSpec( mode=mode, loss=loss, eval_metric_ops=eval_metric_ops) |