diff options
author | Martin Wicke <wicke@google.com> | 2016-12-14 15:46:53 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-12-14 16:04:11 -0800 |
commit | 2e4869af1afe55135d522142be3a2a483162a1b1 (patch) | |
tree | c676b228682e796ff0c2896b20faf97b66b8a1f2 /tensorflow/examples/tutorials | |
parent | 811629aed466db32eeefbd60783e199d2fe154a9 (diff) |
Merge changes from github.
Change: 142074581
Diffstat (limited to 'tensorflow/examples/tutorials')
-rw-r--r-- | tensorflow/examples/tutorials/monitors/iris_monitors.py | 21 |
1 files changed, 21 insertions, 0 deletions
diff --git a/tensorflow/examples/tutorials/monitors/iris_monitors.py b/tensorflow/examples/tutorials/monitors/iris_monitors.py index e2a46baf48..041592b9b0 100644 --- a/tensorflow/examples/tutorials/monitors/iris_monitors.py +++ b/tensorflow/examples/tutorials/monitors/iris_monitors.py @@ -21,6 +21,7 @@ import os import numpy as np import tensorflow as tf +from tensorflow.contrib.learn.python.learn.metric_spec import MetricSpec tf.logging.set_verbosity(tf.logging.INFO) @@ -65,6 +66,26 @@ def main(unused_argv): # Specify that all features have real-value data feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)] + validation_metrics = { + "accuracy": MetricSpec( + metric_fn=tf.contrib.metrics.streaming_accuracy, + prediction_key="classes"), + "recall": MetricSpec( + metric_fn=tf.contrib.metrics.streaming_recall, + prediction_key="classes"), + "precision": MetricSpec( + metric_fn=tf.contrib.metrics.streaming_precision, + prediction_key="classes") + } + validation_monitor = tf.contrib.learn.monitors.ValidationMonitor( + test_set.data, + test_set.target, + every_n_steps=50, + metrics=validation_metrics, + early_stopping_metric="loss", + early_stopping_metric_minimize=True, + early_stopping_rounds=200) + # Build 3 layer DNN with 10, 20, 10 units respectively. classifier = tf.contrib.learn.DNNClassifier( feature_columns=feature_columns, |