diff options
Diffstat (limited to 'tensorflow/examples/skflow/iris_val_based_early_stopping.py')
-rw-r--r-- | tensorflow/examples/skflow/iris_val_based_early_stopping.py | 8 |
1 files changed, 5 insertions, 3 deletions
diff --git a/tensorflow/examples/skflow/iris_val_based_early_stopping.py b/tensorflow/examples/skflow/iris_val_based_early_stopping.py index 72e0595544..05dfa96a07 100644 --- a/tensorflow/examples/skflow/iris_val_based_early_stopping.py +++ b/tensorflow/examples/skflow/iris_val_based_early_stopping.py @@ -34,21 +34,23 @@ def main(unused_argv): x_val, y_val, early_stopping_rounds=200) # classifier with early stopping on training data - classifier1 = learn.TensorFlowDNNClassifier( + classifier1 = learn.DNNClassifier( hidden_units=[10, 20, 10], n_classes=3, model_dir='/tmp/iris_model/') classifier1.fit(x=x_train, y=y_train, steps=2000) score1 = metrics.accuracy_score(y_test, classifier1.predict(x_test)) # classifier with early stopping on validation data, save frequently for # monitor to pick up new checkpoints. - classifier2 = learn.TensorFlowDNNClassifier( + classifier2 = learn.DNNClassifier( hidden_units=[10, 20, 10], n_classes=3, model_dir='/tmp/iris_model_val/', config=tf.contrib.learn.RunConfig(save_checkpoints_secs=1)) classifier2.fit(x=x_train, y=y_train, steps=2000, monitors=[val_monitor]) score2 = metrics.accuracy_score(y_test, classifier2.predict(x_test)) # In many applications, the score is improved by using early stopping - print(score2 > score1) + print('score1: ', score1) + print('score2: ', score2) + print('score2 > score1: ', score2 > score1) if __name__ == '__main__': |