aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/examples/skflow/iris_val_based_early_stopping.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/examples/skflow/iris_val_based_early_stopping.py')
-rw-r--r--tensorflow/examples/skflow/iris_val_based_early_stopping.py8
1 files changed, 5 insertions, 3 deletions
diff --git a/tensorflow/examples/skflow/iris_val_based_early_stopping.py b/tensorflow/examples/skflow/iris_val_based_early_stopping.py
index 72e0595544..05dfa96a07 100644
--- a/tensorflow/examples/skflow/iris_val_based_early_stopping.py
+++ b/tensorflow/examples/skflow/iris_val_based_early_stopping.py
@@ -34,21 +34,23 @@ def main(unused_argv):
x_val, y_val, early_stopping_rounds=200)
# classifier with early stopping on training data
- classifier1 = learn.TensorFlowDNNClassifier(
+ classifier1 = learn.DNNClassifier(
hidden_units=[10, 20, 10], n_classes=3, model_dir='/tmp/iris_model/')
classifier1.fit(x=x_train, y=y_train, steps=2000)
score1 = metrics.accuracy_score(y_test, classifier1.predict(x_test))
# classifier with early stopping on validation data, save frequently for
# monitor to pick up new checkpoints.
- classifier2 = learn.TensorFlowDNNClassifier(
+ classifier2 = learn.DNNClassifier(
hidden_units=[10, 20, 10], n_classes=3, model_dir='/tmp/iris_model_val/',
config=tf.contrib.learn.RunConfig(save_checkpoints_secs=1))
classifier2.fit(x=x_train, y=y_train, steps=2000, monitors=[val_monitor])
score2 = metrics.accuracy_score(y_test, classifier2.predict(x_test))
# In many applications, the score is improved by using early stopping
- print(score2 > score1)
+ print('score1: ', score1)
+ print('score2: ', score2)
+ print('score2 > score1: ', score2 > score1)
if __name__ == '__main__':