aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/examples/tutorials
diff options
context:
space:
mode:
authorGravatar Igor Saprykin <isaprykin@google.com>2017-07-07 14:33:47 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-07 14:38:06 -0700
commit204c367ab1a38cee71dac2a64164e96abaeffbf2 (patch)
tree62837fea81b98e53a3474d9bd96974625c4b8135 /tensorflow/examples/tutorials
parente729dd303a80db6fb628a2e63f15aa12c28c956b (diff)
Fix KeyError when looking for 'softmax_tensor' in the layers tutorial.
PiperOrigin-RevId: 161247019
Diffstat (limited to 'tensorflow/examples/tutorials')
-rw-r--r--tensorflow/examples/tutorials/layers/cnn_mnist.py15
1 files changed, 8 insertions, 7 deletions
diff --git a/tensorflow/examples/tutorials/layers/cnn_mnist.py b/tensorflow/examples/tutorials/layers/cnn_mnist.py
index f92277dac7..2124843fcb 100644
--- a/tensorflow/examples/tutorials/layers/cnn_mnist.py
+++ b/tensorflow/examples/tutorials/layers/cnn_mnist.py
@@ -86,13 +86,14 @@ def cnn_model_fn(features, labels, mode):
# Output Tensor Shape: [batch_size, 10]
logits = tf.layers.dense(inputs=dropout, units=10)
- # Generate Predictions (for PREDICT mode)
- predicted_classes = tf.argmax(input=logits, axis=1)
+ predictions = {
+ # Generate predictions (for PREDICT and EVAL mode)
+ "classes": tf.argmax(input=logits, axis=1),
+ # Add `softmax_tensor` to the graph. It is used for PREDICT and by the
+ # `logging_hook`.
+ "probabilities": tf.nn.softmax(logits, name="softmax_tensor")
+ }
if mode == tf.estimator.ModeKeys.PREDICT:
- predictions = {
- "classes": predicted_classes,
- "probabilities": tf.nn.softmax(logits, name="softmax_tensor")
- }
return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)
# Calculate Loss (for both TRAIN and EVAL modes)
@@ -111,7 +112,7 @@ def cnn_model_fn(features, labels, mode):
# Add evaluation metrics (for EVAL mode)
eval_metric_ops = {
"accuracy": tf.metrics.accuracy(
- labels=labels, predictions=predicted_classes)}
+ labels=labels, predictions=predictions["classes"])}
return tf.estimator.EstimatorSpec(
mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)