aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/examples/skflow/iris_custom_decay_dnn.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/examples/skflow/iris_custom_decay_dnn.py')
-rw-r--r--tensorflow/examples/skflow/iris_custom_decay_dnn.py43
1 files changed, 24 insertions, 19 deletions
diff --git a/tensorflow/examples/skflow/iris_custom_decay_dnn.py b/tensorflow/examples/skflow/iris_custom_decay_dnn.py
index c1e7d22d53..1ce6a830e4 100644
--- a/tensorflow/examples/skflow/iris_custom_decay_dnn.py
+++ b/tensorflow/examples/skflow/iris_custom_decay_dnn.py
@@ -17,24 +17,29 @@ from __future__ import print_function
from sklearn import datasets, metrics
from sklearn.cross_validation import train_test_split
-
import tensorflow as tf
-iris = datasets.load_iris()
-X_train, X_test, y_train, y_test = train_test_split(iris.data,
- iris.target,
- test_size=0.2,
- random_state=42)
-# setup exponential decay function
-def exp_decay(global_step):
- return tf.train.exponential_decay(
- learning_rate=0.1, global_step=global_step,
- decay_steps=100, decay_rate=0.001)
-
-# use customized decay function in learning_rate
-optimizer = tf.train.AdagradOptimizer(learning_rate=exp_decay)
-classifier = tf.contrib.learn.DNNClassifier(hidden_units=[10, 20, 10],
- n_classes=3,
- optimizer=optimizer)
-classifier.fit(X_train, y_train, steps=800)
-score = metrics.accuracy_score(y_test, classifier.predict(X_test))
+
+def optimizer_exp_decay():
+ global_step = tf.contrib.framework.get_or_create_global_step()
+ learning_rate = tf.train.exponential_decay(
+ learning_rate=0.1, global_step=global_step,
+ decay_steps=100, decay_rate=0.001)
+ return tf.train.AdagradOptimizer(learning_rate=learning_rate)
+
+def main(unused_argv):
+ iris = datasets.load_iris()
+ x_train, x_test, y_train, y_test = train_test_split(
+ iris.data, iris.target, test_size=0.2, random_state=42)
+
+ classifier = tf.contrib.learn.DNNClassifier(hidden_units=[10, 20, 10],
+ n_classes=3,
+ optimizer=optimizer_exp_decay)
+
+ classifier.fit(x_train, y_train, steps=800)
+ score = metrics.accuracy_score(y_test, classifier.predict(x_test))
+ print('Accuracy: {0:f}'.format(score))
+
+
+if __name__ == '__main__':
+ tf.app.run()