aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/examples/skflow/text_classification_builtin_rnn_model.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/examples/skflow/text_classification_builtin_rnn_model.py')
-rw-r--r--tensorflow/examples/skflow/text_classification_builtin_rnn_model.py73
1 files changed, 73 insertions, 0 deletions
diff --git a/tensorflow/examples/skflow/text_classification_builtin_rnn_model.py b/tensorflow/examples/skflow/text_classification_builtin_rnn_model.py
new file mode 100644
index 0000000000..239aa48d9c
--- /dev/null
+++ b/tensorflow/examples/skflow/text_classification_builtin_rnn_model.py
@@ -0,0 +1,73 @@
+# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import division, print_function, absolute_import
+
+import numpy as np
+from sklearn import metrics
+import pandas
+
+import tensorflow as tf
+from tensorflow.contrib import skflow
+
+### Training data
+
+# Download dbpedia_csv.tar.gz from
+# https://drive.google.com/folderview?id=0Bz8a_Dbh9Qhbfll6bVpmNUtUcFdjYmF2SEpmZUZUcVNiMUw1TWN6RDV3a0JHT3kxLVhVR2M
+# Unpack: tar -xvf dbpedia_csv.tar.gz
+
# DBpedia CSVs have no header row: column 0 is the class label,
# column 2 is the document text.
train = pandas.read_csv('dbpedia_csv/train.csv', header=None)
X_train = train[2]
y_train = train[0]
test = pandas.read_csv('dbpedia_csv/test.csv', header=None)
X_test = test[2]
y_test = test[0]
+
### Process vocabulary

# Every document is mapped to exactly this many word ids (longer texts are
# cut, shorter ones padded).
MAX_DOCUMENT_LENGTH = 10

# Build the vocabulary from the training text only; the test set is merely
# transformed, so words unseen during fitting presumably map to the
# processor's unknown-word id — confirm against VocabularyProcessor docs.
vocab_processor = skflow.preprocessing.VocabularyProcessor(MAX_DOCUMENT_LENGTH)
X_train = np.array(list(vocab_processor.fit_transform(X_train)))
X_test = np.array(list(vocab_processor.transform(X_test)))

# Vocabulary size; read later to size the embedding matrix in input_op_fn.
n_words = len(vocab_processor.vocabulary_)
print('Total words: %d' % n_words)
+
### Models

EMBEDDING_SIZE = 50

def input_op_fn(X):
    """Transform a batch of word-id sequences into per-timestep embeddings.

    Creates an embedding matrix of shape [n_words, EMBEDDING_SIZE] and looks
    up the word indexes in X, yielding a tensor of shape
    [batch_size, MAX_DOCUMENT_LENGTH, EMBEDDING_SIZE]. That tensor is then
    split along the sequence dimension (squeezing it away) into a list of
    MAX_DOCUMENT_LENGTH tensors, each [batch_size, EMBEDDING_SIZE], which is
    the input format the RNN cell expects.
    """
    embedded = skflow.ops.categorical_variable(X, n_classes=n_words,
        embedding_size=EMBEDDING_SIZE, name='words')
    # One tensor per word position, sequence axis removed.
    return skflow.ops.split_squeeze(1, MAX_DOCUMENT_LENGTH, embedded)
+
# Single-layer, unidirectional GRU classifier over the embedded word list.
classifier = skflow.TensorFlowRNNClassifier(
    rnn_size=EMBEDDING_SIZE,
    n_classes=15,
    cell_type='gru',
    input_op_fn=input_op_fn,
    num_layers=1,
    bidirectional=False,
    sequence_length=None,
    steps=1000,
    optimizer='Adam',
    learning_rate=0.01,
    continue_training=True)

# Continuously train in 1000-step increments, reporting test accuracy
# after each round (continue_training=True resumes from the prior state).
while True:
    classifier.fit(X_train, y_train, logdir='/tmp/tf_examples/word_rnn')
    score = metrics.accuracy_score(y_test, classifier.predict(X_test))
    print('Accuracy: {0:f}'.format(score))