diff options
Diffstat (limited to 'tensorflow/examples/skflow/iris_custom_decay_dnn.py')
-rw-r--r-- | tensorflow/examples/skflow/iris_custom_decay_dnn.py | 43 |
1 files changed, 24 insertions, 19 deletions
diff --git a/tensorflow/examples/skflow/iris_custom_decay_dnn.py b/tensorflow/examples/skflow/iris_custom_decay_dnn.py index c1e7d22d53..1ce6a830e4 100644 --- a/tensorflow/examples/skflow/iris_custom_decay_dnn.py +++ b/tensorflow/examples/skflow/iris_custom_decay_dnn.py @@ -17,24 +17,29 @@ from __future__ import print_function from sklearn import datasets, metrics from sklearn.cross_validation import train_test_split - import tensorflow as tf -iris = datasets.load_iris() -X_train, X_test, y_train, y_test = train_test_split(iris.data, - iris.target, - test_size=0.2, - random_state=42) -# setup exponential decay function -def exp_decay(global_step): - return tf.train.exponential_decay( - learning_rate=0.1, global_step=global_step, - decay_steps=100, decay_rate=0.001) - -# use customized decay function in learning_rate -optimizer = tf.train.AdagradOptimizer(learning_rate=exp_decay) -classifier = tf.contrib.learn.DNNClassifier(hidden_units=[10, 20, 10], - n_classes=3, - optimizer=optimizer) -classifier.fit(X_train, y_train, steps=800) -score = metrics.accuracy_score(y_test, classifier.predict(X_test)) + +def optimizer_exp_decay(): + global_step = tf.contrib.framework.get_or_create_global_step() + learning_rate = tf.train.exponential_decay( + learning_rate=0.1, global_step=global_step, + decay_steps=100, decay_rate=0.001) + return tf.train.AdagradOptimizer(learning_rate=learning_rate) + +def main(unused_argv): + iris = datasets.load_iris() + x_train, x_test, y_train, y_test = train_test_split( + iris.data, iris.target, test_size=0.2, random_state=42) + + classifier = tf.contrib.learn.DNNClassifier(hidden_units=[10, 20, 10], + n_classes=3, + optimizer=optimizer_exp_decay) + + classifier.fit(x_train, y_train, steps=800) + score = metrics.accuracy_score(y_test, classifier.predict(x_test)) + print('Accuracy: {0:f}'.format(score)) + + +if __name__ == '__main__': + tf.app.run() |