diff options
Diffstat (limited to 'tensorflow/examples/tutorials/monitors/iris_monitors.py')
-rw-r--r-- | tensorflow/examples/tutorials/monitors/iris_monitors.py | 30 |
1 files changed, 3 insertions, 27 deletions
diff --git a/tensorflow/examples/tutorials/monitors/iris_monitors.py b/tensorflow/examples/tutorials/monitors/iris_monitors.py index a4bf353856..850d105f7b 100644 --- a/tensorflow/examples/tutorials/monitors/iris_monitors.py +++ b/tensorflow/examples/tutorials/monitors/iris_monitors.py @@ -21,7 +21,6 @@ import os import numpy as np import tensorflow as tf -from tensorflow.contrib.learn.python.learn.metric_spec import MetricSpec tf.logging.set_verbosity(tf.logging.INFO) @@ -41,18 +40,15 @@ def main(unused_argv): "accuracy": tf.contrib.learn.MetricSpec( metric_fn=tf.contrib.metrics.streaming_accuracy, - prediction_key= - tf.contrib.learn.prediction_key.PredictionKey.CLASSES), + prediction_key="classes"), "precision": tf.contrib.learn.MetricSpec( metric_fn=tf.contrib.metrics.streaming_precision, - prediction_key= - tf.contrib.learn.prediction_key.PredictionKey.CLASSES), + prediction_key="classes"), "recall": tf.contrib.learn.MetricSpec( metric_fn=tf.contrib.metrics.streaming_recall, - prediction_key= - tf.contrib.learn.prediction_key.PredictionKey.CLASSES) + prediction_key="classes") } validation_monitor = tf.contrib.learn.monitors.ValidationMonitor( test_set.data, @@ -66,26 +62,6 @@ def main(unused_argv): # Specify that all features have real-value data feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)] - validation_metrics = { - "accuracy": MetricSpec( - metric_fn=tf.contrib.metrics.streaming_accuracy, - prediction_key="classes"), - "recall": MetricSpec( - metric_fn=tf.contrib.metrics.streaming_recall, - prediction_key="classes"), - "precision": MetricSpec( - metric_fn=tf.contrib.metrics.streaming_precision, - prediction_key="classes") - } - validation_monitor = tf.contrib.learn.monitors.ValidationMonitor( - test_set.data, - test_set.target, - every_n_steps=50, - metrics=validation_metrics, - early_stopping_metric="loss", - early_stopping_metric_minimize=True, - early_stopping_rounds=200) - # Build 3 layer DNN with 10, 20, 10 units respectively. classifier = tf.contrib.learn.DNNClassifier( feature_columns=feature_columns, |