author     A. Unique TensorFlower <gardener@tensorflow.org>  2017-06-26 15:02:09 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2017-06-26 15:07:17 -0700
commit     5f84a0cb372ca6542385ad0a372496c1f56f7ce6 (patch)
tree       dc24f54907dabddab5cb974e6018396042b02e26 /tensorflow/examples/learn
parent     23e0d5044d6ddcdd14a673d7300a96e735c6663a (diff)
Updates text_classification example.
PiperOrigin-RevId: 160200457
Diffstat (limited to 'tensorflow/examples/learn')
-rw-r--r--  tensorflow/examples/learn/text_classification.py  144
1 file changed, 90 insertions(+), 54 deletions(-)
diff --git a/tensorflow/examples/learn/text_classification.py b/tensorflow/examples/learn/text_classification.py
index 7e10014c39..4b32bfc382 100644
--- a/tensorflow/examples/learn/text_classification.py
+++ b/tensorflow/examples/learn/text_classification.py
@@ -24,43 +24,67 @@ import numpy as np
import pandas
from sklearn import metrics
import tensorflow as tf
-from tensorflow.contrib.layers.python.layers import encoders
-
-learn = tf.contrib.learn
FLAGS = None
MAX_DOCUMENT_LENGTH = 10
EMBEDDING_SIZE = 50
n_words = 0
+MAX_LABEL = 15
+WORDS_FEATURE = 'words' # Name of the input words feature.
+
+
+def estimator_spec_for_softmax_classification(
+ logits, labels, mode):
+ """Returns EstimatorSpec instance for softmax classification."""
+ predicted_classes = tf.argmax(logits, 1)
+ if mode == tf.estimator.ModeKeys.PREDICT:
+ return tf.estimator.EstimatorSpec(
+ mode=mode,
+ predictions={
+ 'class': predicted_classes,
+ 'prob': tf.nn.softmax(logits)
+ })
+
+ onehot_labels = tf.one_hot(labels, MAX_LABEL, 1, 0)
+ loss = tf.losses.softmax_cross_entropy(
+ onehot_labels=onehot_labels, logits=logits)
+ if mode == tf.estimator.ModeKeys.TRAIN:
+ optimizer = tf.train.AdamOptimizer(learning_rate=0.01)
+ train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
+ return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
+
+ eval_metric_ops = {
+ 'accuracy': tf.metrics.accuracy(
+ labels=labels, predictions=predicted_classes)
+ }
+ return tf.estimator.EstimatorSpec(
+ mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)
+
+
+def bag_of_words_model(features, labels, mode):
+ """A bag-of-words model. Note it disregards the word order in the text."""
+ bow_column = tf.feature_column.categorical_column_with_identity(
+ WORDS_FEATURE, num_buckets=n_words)
+ bow_embedding_column = tf.feature_column.embedding_column(
+ bow_column, dimension=EMBEDDING_SIZE)
+ bow = tf.feature_column.input_layer(
+ features,
+ feature_columns=[bow_embedding_column])
+ logits = tf.layers.dense(bow, MAX_LABEL, activation=None)
+ return estimator_spec_for_softmax_classification(
+ logits=logits, labels=labels, mode=mode)
-def bag_of_words_model(features, target):
- """A bag-of-words model. Note it disregards the word order in the text."""
- target = tf.one_hot(target, 15, 1, 0)
- features = encoders.bow_encoder(
- features, vocab_size=n_words, embed_dim=EMBEDDING_SIZE)
- logits = tf.contrib.layers.fully_connected(features, 15, activation_fn=None)
- loss = tf.contrib.losses.softmax_cross_entropy(logits, target)
- train_op = tf.contrib.layers.optimize_loss(
- loss,
- tf.contrib.framework.get_global_step(),
- optimizer='Adam',
- learning_rate=0.01)
- return ({
- 'class': tf.argmax(logits, 1),
- 'prob': tf.nn.softmax(logits)
- }, loss, train_op)
-
-
-def rnn_model(features, target):
+
+def rnn_model(features, labels, mode):
"""RNN model to predict from sequence of words to a class."""
# Convert indexes of words into embeddings.
# This creates embeddings matrix of [n_words, EMBEDDING_SIZE] and then
# maps word indexes of the sequence into [batch_size, sequence_length,
# EMBEDDING_SIZE].
word_vectors = tf.contrib.layers.embed_sequence(
- features, vocab_size=n_words, embed_dim=EMBEDDING_SIZE, scope='words')
+ features[WORDS_FEATURE], vocab_size=n_words, embed_dim=EMBEDDING_SIZE)
  # Split into a list of per-word embeddings, removing the doc-length dim.
  # word_list is then a list of tensors of shape [batch_size, EMBEDDING_SIZE].
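
A note on the new bag-of-words path above: the contrib bow_encoder is replaced by core feature columns. categorical_column_with_identity treats each word id as its own category, embedding_column looks up an embedding per word and (with its default 'mean' combiner) averages them, and a dense layer produces the logits. A minimal standalone sketch of that mechanism, assuming the TF 1.x APIs used in this change and a hypothetical 5-word vocabulary rather than the dbpedia data:

    import tensorflow as tf

    # Hypothetical toy setup: 5 word ids, embedding size 2, two documents.
    words = tf.feature_column.categorical_column_with_identity(
        'words', num_buckets=5)
    words_embedded = tf.feature_column.embedding_column(words, dimension=2)
    features = {'words': tf.constant([[0, 1, 2], [3, 3, 4]], dtype=tf.int64)}
    # input_layer averages the per-word embeddings into one
    # [batch_size, 2] bag-of-words vector per document.
    bow = tf.feature_column.input_layer(features, [words_embedded])
    logits = tf.layers.dense(bow, 15, activation=None)

    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      print(sess.run(bow).shape)     # (2, 2)
      print(sess.run(logits).shape)  # (2, 15)
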
@@ -74,29 +98,17 @@ def rnn_model(features, target):
_, encoding = tf.contrib.rnn.static_rnn(cell, word_list, dtype=tf.float32)
# Given encoding of RNN, take encoding of last step (e.g hidden size of the
- # neural network of last step) and pass it as features for logistic
- # regression over output classes.
- target = tf.one_hot(target, 15, 1, 0)
- logits = tf.contrib.layers.fully_connected(encoding, 15, activation_fn=None)
- loss = tf.contrib.losses.softmax_cross_entropy(logits, target)
-
- # Create a training op.
- train_op = tf.contrib.layers.optimize_loss(
- loss,
- tf.contrib.framework.get_global_step(),
- optimizer='Adam',
- learning_rate=0.01)
-
- return ({
- 'class': tf.argmax(logits, 1),
- 'prob': tf.nn.softmax(logits)
- }, loss, train_op)
+ # neural network of last step) and pass it as features for softmax
+ # classification over output classes.
+ logits = tf.layers.dense(encoding, MAX_LABEL, activation=None)
+ return estimator_spec_for_softmax_classification(
+ logits=logits, labels=labels, mode=mode)
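
For the RNN path, the middle of the function (unchanged, so not shown in this hunk) unstacks the embedded sequence and runs a recurrent cell (a GRU in this example file) over it; the final state is the document encoding that the new tf.layers.dense call above turns into logits. A rough sketch of that chain under hypothetical toy shapes (vocabulary of 10, documents of 4 word ids):

    import tensorflow as tf

    word_ids = tf.constant([[1, 2, 3, 0], [4, 5, 6, 0]], dtype=tf.int64)
    word_vectors = tf.contrib.layers.embed_sequence(
        word_ids, vocab_size=10, embed_dim=3)         # [2, 4, 3]
    word_list = tf.unstack(word_vectors, axis=1)      # 4 tensors of [2, 3]
    cell = tf.contrib.rnn.GRUCell(6)
    _, encoding = tf.contrib.rnn.static_rnn(cell, word_list, dtype=tf.float32)
    # encoding is the final GRU state, shape [2, 6]; a dense layer maps it
    # to per-class logits.
    logits = tf.layers.dense(encoding, 15, activation=None)  # [2, 15]
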
def main(unused_argv):
global n_words
# Prepare training and testing data
- dbpedia = learn.datasets.load_dataset(
+ dbpedia = tf.contrib.learn.datasets.load_dataset(
'dbpedia', test_with_fake_data=FLAGS.test_with_fake_data)
x_train = pandas.DataFrame(dbpedia.train.data)[1]
y_train = pandas.Series(dbpedia.train.target)
@@ -104,14 +116,15 @@ def main(unused_argv):
y_test = pandas.Series(dbpedia.test.target)
# Process vocabulary
- vocab_processor = learn.preprocessing.VocabularyProcessor(MAX_DOCUMENT_LENGTH)
-
+ vocab_processor = tf.contrib.learn.preprocessing.VocabularyProcessor(
+ MAX_DOCUMENT_LENGTH)
+
x_transform_train = vocab_processor.fit_transform(x_train)
x_transform_test = vocab_processor.transform(x_test)
-
+
x_train = np.array(list(x_transform_train))
x_test = np.array(list(x_transform_test))
-
+
n_words = len(vocab_processor.vocabulary_)
print('Total words: %d' % n_words)
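
The vocabulary step above is what produces the integer word-id matrices fed to both models: VocabularyProcessor pads or truncates each document to MAX_DOCUMENT_LENGTH and assigns word ids starting at 1, reserving 0 for 'no word'. A small hypothetical illustration of its output (toy sentences, not the dbpedia data):

    import numpy as np
    import tensorflow as tf

    vocab_processor = tf.contrib.learn.preprocessing.VocabularyProcessor(
        max_document_length=4)
    docs = ['hello world', 'hello tensorflow examples']
    x = np.array(list(vocab_processor.fit_transform(docs)))
    print(x)
    # e.g. [[1 2 0 0]
    #       [1 3 4 0]]  (ids start at 1; 0 marks padding / 'no word')
    print(len(vocab_processor.vocabulary_))  # vocabulary size, used as n_words
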
@@ -119,17 +132,40 @@ def main(unused_argv):
# Switch between rnn_model and bag_of_words_model to test different models.
model_fn = rnn_model
if FLAGS.bow_model:
+ # Subtract 1 because VocabularyProcessor outputs a word-id matrix where word
+ # ids start from 1 and 0 means 'no word'. But
+ # categorical_column_with_identity assumes 0-based count and uses -1 for
+ # missing word.
+ x_train -= 1
+ x_test -= 1
model_fn = bag_of_words_model
- classifier = learn.Estimator(model_fn=model_fn)
-
- # Train and predict
- classifier.fit(x_train, y_train, steps=100)
- y_predicted = [
- p['class'] for p in classifier.predict(
- x_test, as_iterable=True)
- ]
+ classifier = tf.estimator.Estimator(model_fn=model_fn)
+
+ # Train.
+ train_input_fn = tf.estimator.inputs.numpy_input_fn(
+ x={WORDS_FEATURE: x_train},
+ y=y_train,
+ num_epochs=None,
+ shuffle=True)
+ classifier.train(input_fn=train_input_fn, steps=100)
+
+ # Predict.
+ test_input_fn = tf.estimator.inputs.numpy_input_fn(
+ x={WORDS_FEATURE: x_test},
+ y=y_test,
+ num_epochs=1,
+ shuffle=False)
+ predictions = classifier.predict(input_fn=test_input_fn)
+ y_predicted = np.array(list(p['class'] for p in predictions))
+ y_predicted = y_predicted.reshape(np.array(y_test).shape)
+
+ # Score with sklearn.
score = metrics.accuracy_score(y_test, y_predicted)
- print('Accuracy: {0:f}'.format(score))
+ print('Accuracy (sklearn): {0:f}'.format(score))
+
+ # Score with tensorflow.
+ scores = classifier.evaluate(input_fn=test_input_fn)
+ print('Accuracy (tensorflow): {0:f}'.format(scores['accuracy']))
if __name__ == '__main__':
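
One more note on the x_train -= 1 / x_test -= 1 adjustment in the bow_model branch: as the new comment there says, VocabularyProcessor emits ids starting at 1 with 0 meaning 'no word', while categorical_column_with_identity expects 0-based ids and uses -1 for a missing word, so shifting everything down by one realigns the two conventions. A hypothetical before/after on a single padded row:

    import numpy as np

    # Hypothetical row from VocabularyProcessor: two real words, two pads.
    ids_from_vocab_processor = np.array([[3, 7, 0, 0]])
    ids_for_identity_column = ids_from_vocab_processor - 1
    print(ids_for_identity_column)  # [[ 2  6 -1 -1]]
    # 2 and 6 are now valid 0-based buckets; per the comment above, the -1
    # entries are the 'missing word' value for categorical_column_with_identity.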