author     A. Unique TensorFlower <gardener@tensorflow.org>  2017-06-27 11:49:20 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2017-06-27 11:53:04 -0700
commit     47f28360d2be7979571b4a2d9118651d36eedabd (patch)
tree       5be976421c44f159e7d9b6249fd69afc38a674a8 /tensorflow/examples/learn
parent     f0c3cbfc9b574245f6998756a12d804bedc08fd4 (diff)
Updates more text classification examples in examples/learn.
PiperOrigin-RevId: 160305131
Diffstat (limited to 'tensorflow/examples/learn')
-rw-r--r--  tensorflow/examples/learn/text_classification.py                  1
-rw-r--r--  tensorflow/examples/learn/text_classification_character_cnn.py  121
-rw-r--r--  tensorflow/examples/learn/text_classification_character_rnn.py   92
-rw-r--r--  tensorflow/examples/learn/text_classification_cnn.py            109
4 files changed, 216 insertions, 107 deletions
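
All three reworked examples follow the same migration: the deprecated tf.contrib.learn Estimator, whose model_fn took (features, target) and returned a (predictions, loss, train_op) tuple, is replaced by tf.estimator, whose model_fn takes (features, labels, mode) and returns a mode-specific tf.estimator.EstimatorSpec. A condensed sketch of the shared pattern, assembled from the hunks below (the network body is elided; my_model_fn is an illustrative name):

    def my_model_fn(features, labels, mode):
      logits = ...  # network over features, ending in [batch, MAX_LABEL] scores
      predicted_classes = tf.argmax(logits, 1)
      # PREDICT: no labels are available; return predictions only.
      if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions={'class': predicted_classes,
                         'prob': tf.nn.softmax(logits)})
      loss = tf.losses.softmax_cross_entropy(
          onehot_labels=tf.one_hot(labels, MAX_LABEL, 1, 0), logits=logits)
      # TRAIN: attach an optimizer step tied to the global step counter.
      if mode == tf.estimator.ModeKeys.TRAIN:
        train_op = tf.train.AdamOptimizer(learning_rate=0.01).minimize(
            loss, global_step=tf.train.get_global_step())
        return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
      # EVAL: report loss plus streaming accuracy.
      return tf.estimator.EstimatorSpec(
          mode=mode, loss=loss,
          eval_metric_ops={'accuracy': tf.metrics.accuracy(
              labels=labels, predictions=predicted_classes)})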
diff --git a/tensorflow/examples/learn/text_classification.py b/tensorflow/examples/learn/text_classification.py
index 4b32bfc382..21d98e9ea2 100644
--- a/tensorflow/examples/learn/text_classification.py
+++ b/tensorflow/examples/learn/text_classification.py
@@ -145,6 +145,7 @@ def main(unused_argv):
train_input_fn = tf.estimator.inputs.numpy_input_fn(
x={WORDS_FEATURE: x_train},
y=y_train,
+ batch_size=len(x_train),
num_epochs=None,
shuffle=True)
classifier.train(input_fn=train_input_fn, steps=100)
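
The one-line change above makes the implicit batch size explicit. Because num_epochs=None repeats the data indefinitely, batch_size=len(x_train) turns each of the 100 training steps into a full-batch gradient step over the shuffled training set. A minimal standalone sketch of the same input pipeline (the arrays and feature name here are made up for illustration):

    import numpy as np
    import tensorflow as tf

    x = np.random.randint(0, 256, size=(8, 100))  # 8 toy documents of byte ids
    y = np.random.randint(0, 15, size=(8,))       # 8 toy class labels
    train_input_fn = tf.estimator.inputs.numpy_input_fn(
        x={'words': x},
        y=y,
        batch_size=len(x),  # full batch: every step sees all 8 examples
        num_epochs=None,    # repeat forever; steps=100 bounds training
        shuffle=True)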
diff --git a/tensorflow/examples/learn/text_classification_character_cnn.py b/tensorflow/examples/learn/text_classification_character_cnn.py
index 5ad53acf9f..5f7c8e7371 100644
--- a/tensorflow/examples/learn/text_classification_character_cnn.py
+++ b/tensorflow/examples/learn/text_classification_character_cnn.py
@@ -11,8 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-"""This is an example of using convolutional networks over characters for
- DBpedia dataset to predict class from description of an entity.
+"""Example of using convolutional networks over characters for DBpedia dataset.
This model is similar to one described in this paper:
"Character-level Convolutional Networks for Text Classification"
@@ -34,8 +33,6 @@ import pandas
from sklearn import metrics
import tensorflow as tf
-learn = tf.contrib.learn
-
FLAGS = None
MAX_DOCUMENT_LENGTH = 100
@@ -44,53 +41,73 @@ FILTER_SHAPE1 = [20, 256]
FILTER_SHAPE2 = [20, N_FILTERS]
POOLING_WINDOW = 4
POOLING_STRIDE = 2
+MAX_LABEL = 15
+CHARS_FEATURE = 'chars' # Name of the input character feature.
-def char_cnn_model(features, target):
+def char_cnn_model(features, labels, mode):
"""Character level convolutional neural network model to predict classes."""
- target = tf.one_hot(target, 15, 1, 0)
- byte_list = tf.reshape(
- tf.one_hot(features, 256), [-1, MAX_DOCUMENT_LENGTH, 256, 1])
+ features_onehot = tf.one_hot(features[CHARS_FEATURE], 256)
+ input_layer = tf.reshape(
+ features_onehot, [-1, MAX_DOCUMENT_LENGTH, 256, 1])
with tf.variable_scope('CNN_Layer1'):
# Apply Convolution filtering on input sequence.
- conv1 = tf.contrib.layers.convolution2d(
- byte_list, N_FILTERS, FILTER_SHAPE1, padding='VALID')
- # Add a ReLU for non linearity.
- conv1 = tf.nn.relu(conv1)
+ conv1 = tf.layers.conv2d(
+ input_layer,
+ filters=N_FILTERS,
+ kernel_size=FILTER_SHAPE1,
+ padding='VALID',
+ # Add a ReLU for non linearity.
+ activation=tf.nn.relu)
# Max pooling across output of Convolution+Relu.
- pool1 = tf.nn.max_pool(
+ pool1 = tf.layers.max_pooling2d(
conv1,
- ksize=[1, POOLING_WINDOW, 1, 1],
- strides=[1, POOLING_STRIDE, 1, 1],
+ pool_size=POOLING_WINDOW,
+ strides=POOLING_STRIDE,
padding='SAME')
# Transpose matrix so that n_filters from convolution becomes width.
pool1 = tf.transpose(pool1, [0, 1, 3, 2])
with tf.variable_scope('CNN_Layer2'):
# Second level of convolution filtering.
- conv2 = tf.contrib.layers.convolution2d(
- pool1, N_FILTERS, FILTER_SHAPE2, padding='VALID')
+ conv2 = tf.layers.conv2d(
+ pool1,
+ filters=N_FILTERS,
+ kernel_size=FILTER_SHAPE2,
+ padding='VALID')
# Max across each filter to get useful features for classification.
pool2 = tf.squeeze(tf.reduce_max(conv2, 1), squeeze_dims=[1])
# Apply regular WX + B and classification.
- logits = tf.contrib.layers.fully_connected(pool2, 15, activation_fn=None)
- loss = tf.losses.softmax_cross_entropy(target, logits)
-
- train_op = tf.contrib.layers.optimize_loss(
- loss,
- tf.contrib.framework.get_global_step(),
- optimizer='Adam',
- learning_rate=0.01)
-
- return ({
- 'class': tf.argmax(logits, 1),
- 'prob': tf.nn.softmax(logits)
- }, loss, train_op)
+ logits = tf.layers.dense(pool2, MAX_LABEL, activation=None)
+
+ predicted_classes = tf.argmax(logits, 1)
+ if mode == tf.estimator.ModeKeys.PREDICT:
+ return tf.estimator.EstimatorSpec(
+ mode=mode,
+ predictions={
+ 'class': predicted_classes,
+ 'prob': tf.nn.softmax(logits)
+ })
+
+ onehot_labels = tf.one_hot(labels, MAX_LABEL, 1, 0)
+ loss = tf.losses.softmax_cross_entropy(
+ onehot_labels=onehot_labels, logits=logits)
+ if mode == tf.estimator.ModeKeys.TRAIN:
+ optimizer = tf.train.AdamOptimizer(learning_rate=0.01)
+ train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
+ return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
+
+ eval_metric_ops = {
+ 'accuracy': tf.metrics.accuracy(
+ labels=labels, predictions=predicted_classes)
+ }
+ return tf.estimator.EstimatorSpec(
+ mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)
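
For orientation, a rough shape trace through the rewritten character CNN (B = batch size; sizes follow from the constants above, the [-1, MAX_DOCUMENT_LENGTH, 1, 1] reshape applied in main below, and the assumption that the scalar pool_size/strides broadcast to both spatial dimensions):

    # features[CHARS_FEATURE]           [B, 100, 1, 1]         int byte ids
    # one_hot(., 256) + reshape         [B, 100, 256, 1]       input_layer
    # conv1, kernel [20, 256], VALID    [B, 81, 1, N_FILTERS]
    # pool1, pool 4, stride 2, SAME     [B, 41, 1, N_FILTERS]
    # transpose [0, 1, 3, 2]            [B, 41, N_FILTERS, 1]
    # conv2, kernel [20, N_FILTERS]     [B, 22, 1, N_FILTERS]
    # reduce_max(axis 1) + squeeze      [B, N_FILTERS]         pool2
    # dense                             [B, 15]                logits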
def main(unused_argv):
# Prepare training and testing data
- dbpedia = learn.datasets.load_dataset(
+ dbpedia = tf.contrib.learn.datasets.load_dataset(
'dbpedia', test_with_fake_data=FLAGS.test_with_fake_data, size='large')
x_train = pandas.DataFrame(dbpedia.train.data)[1]
y_train = pandas.Series(dbpedia.train.target)
@@ -98,21 +115,43 @@ def main(unused_argv):
y_test = pandas.Series(dbpedia.test.target)
# Process vocabulary
- char_processor = learn.preprocessing.ByteProcessor(MAX_DOCUMENT_LENGTH)
+ char_processor = tf.contrib.learn.preprocessing.ByteProcessor(
+ MAX_DOCUMENT_LENGTH)
x_train = np.array(list(char_processor.fit_transform(x_train)))
x_test = np.array(list(char_processor.transform(x_test)))
+ x_train = x_train.reshape([-1, MAX_DOCUMENT_LENGTH, 1, 1])
+ x_test = x_test.reshape([-1, MAX_DOCUMENT_LENGTH, 1, 1])
+
# Build model
- classifier = learn.Estimator(model_fn=char_cnn_model)
-
- # Train and predict
- classifier.fit(x_train, y_train, steps=100)
- y_predicted = [
- p['class'] for p in classifier.predict(
- x_test, as_iterable=True)
- ]
+ classifier = tf.estimator.Estimator(model_fn=char_cnn_model)
+
+ # Train.
+ train_input_fn = tf.estimator.inputs.numpy_input_fn(
+ x={CHARS_FEATURE: x_train},
+ y=y_train,
+ batch_size=len(x_train),
+ num_epochs=None,
+ shuffle=True)
+ classifier.train(input_fn=train_input_fn, steps=100)
+
+ # Predict.
+ test_input_fn = tf.estimator.inputs.numpy_input_fn(
+ x={CHARS_FEATURE: x_test},
+ y=y_test,
+ num_epochs=1,
+ shuffle=False)
+ predictions = classifier.predict(input_fn=test_input_fn)
+ y_predicted = np.array(list(p['class'] for p in predictions))
+ y_predicted = y_predicted.reshape(np.array(y_test).shape)
+
+ # Score with sklearn.
score = metrics.accuracy_score(y_test, y_predicted)
- print('Accuracy: {0:f}'.format(score))
+ print('Accuracy (sklearn): {0:f}'.format(score))
+
+ # Score with tensorflow.
+ scores = classifier.evaluate(input_fn=test_input_fn)
+ print('Accuracy (tensorflow): {0:f}'.format(scores['accuracy']))
if __name__ == '__main__':
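
Two details worth noting in the rewritten main: predict() now returns a generator of per-example dicts keyed by the PREDICT-mode predictions, so the class ids are collected into an array and reshaped to match y_test before scoring; and accuracy is reported twice, once via sklearn on the exported predictions and once via classifier.evaluate, which runs the tf.metrics.accuracy op from the EVAL branch over the same single-pass, unshuffled test_input_fn, so the two numbers should agree. (MAX_LABEL is presumably 15 rather than 14 because DBpedia's class labels are 1-indexed, leaving class id 0 unused.)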
diff --git a/tensorflow/examples/learn/text_classification_character_rnn.py b/tensorflow/examples/learn/text_classification_character_rnn.py
index 1cb2cd2f88..1fc9388a1a 100644
--- a/tensorflow/examples/learn/text_classification_character_rnn.py
+++ b/tensorflow/examples/learn/text_classification_character_rnn.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-"""This is an example of using recurrent neural networks over characters for DBpedia dataset to predict class from description of an entity.
+"""Example of recurrent neural networks over characters for DBpedia dataset.
This model is similar to one described in this paper:
"Character-level Convolutional Networks for Text Classification"
@@ -33,41 +33,52 @@ import pandas
from sklearn import metrics
import tensorflow as tf
-learn = tf.contrib.learn
-
FLAGS = None
MAX_DOCUMENT_LENGTH = 100
HIDDEN_SIZE = 20
+MAX_LABEL = 15
+CHARS_FEATURE = 'chars' # Name of the input character feature.
-def char_rnn_model(features, target):
+def char_rnn_model(features, labels, mode):
"""Character level recurrent neural network model to predict classes."""
- target = tf.one_hot(target, 15, 1, 0)
- byte_list = tf.one_hot(features, 256, 1, 0)
- byte_list = tf.unstack(byte_list, axis=1)
+ byte_vectors = tf.one_hot(features[CHARS_FEATURE], 256, 1., 0.)
+ byte_list = tf.unstack(byte_vectors, axis=1)
cell = tf.contrib.rnn.GRUCell(HIDDEN_SIZE)
_, encoding = tf.contrib.rnn.static_rnn(cell, byte_list, dtype=tf.float32)
- logits = tf.contrib.layers.fully_connected(encoding, 15, activation_fn=None)
- loss = tf.contrib.losses.softmax_cross_entropy(logits, target)
-
- train_op = tf.contrib.layers.optimize_loss(
- loss,
- tf.contrib.framework.get_global_step(),
- optimizer='Adam',
- learning_rate=0.01)
-
- return ({
- 'class': tf.argmax(logits, 1),
- 'prob': tf.nn.softmax(logits)
- }, loss, train_op)
+ logits = tf.layers.dense(encoding, MAX_LABEL, activation=None)
+
+ predicted_classes = tf.argmax(logits, 1)
+ if mode == tf.estimator.ModeKeys.PREDICT:
+ return tf.estimator.EstimatorSpec(
+ mode=mode,
+ predictions={
+ 'class': predicted_classes,
+ 'prob': tf.nn.softmax(logits)
+ })
+
+ onehot_labels = tf.one_hot(labels, MAX_LABEL, 1, 0)
+ loss = tf.losses.softmax_cross_entropy(
+ onehot_labels=onehot_labels, logits=logits)
+ if mode == tf.estimator.ModeKeys.TRAIN:
+ optimizer = tf.train.AdamOptimizer(learning_rate=0.01)
+ train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
+ return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
+
+ eval_metric_ops = {
+ 'accuracy': tf.metrics.accuracy(
+ labels=labels, predictions=predicted_classes)
+ }
+ return tf.estimator.EstimatorSpec(
+ mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)
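
The RNN variant keeps the identical EstimatorSpec scaffolding and changes only the encoder. Note that the one-hot on/off values became floats (1., 0.) so that static_rnn receives float32 inputs. A rough shape trace (B = batch size):

    # features[CHARS_FEATURE]      [B, 100]          int byte ids
    # one_hot(., 256, 1., 0.)      [B, 100, 256]     byte_vectors (float32)
    # unstack(axis=1)              list of 100 tensors, each [B, 256]
    # GRUCell(20) + static_rnn     encoding [B, 20]  (final hidden state)
    # dense                        logits [B, 15]

static_rnn unrolls all 100 time steps into the graph, and only the final GRU state is kept as the fixed-length document encoding.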
def main(unused_argv):
# Prepare training and testing data
- dbpedia = learn.datasets.load_dataset(
+ dbpedia = tf.contrib.learn.datasets.load_dataset(
'dbpedia', test_with_fake_data=FLAGS.test_with_fake_data)
x_train = pandas.DataFrame(dbpedia.train.data)[1]
y_train = pandas.Series(dbpedia.train.target)
@@ -75,21 +86,40 @@ def main(unused_argv):
y_test = pandas.Series(dbpedia.test.target)
# Process vocabulary
- char_processor = learn.preprocessing.ByteProcessor(MAX_DOCUMENT_LENGTH)
+ char_processor = tf.contrib.learn.preprocessing.ByteProcessor(
+ MAX_DOCUMENT_LENGTH)
x_train = np.array(list(char_processor.fit_transform(x_train)))
x_test = np.array(list(char_processor.transform(x_test)))
# Build model
- classifier = learn.Estimator(model_fn=char_rnn_model)
-
- # Train and predict
- classifier.fit(x_train, y_train, steps=100)
- y_predicted = [
- p['class'] for p in classifier.predict(
- x_test, as_iterable=True)
- ]
+ classifier = tf.estimator.Estimator(model_fn=char_rnn_model)
+
+ # Train.
+ train_input_fn = tf.estimator.inputs.numpy_input_fn(
+ x={CHARS_FEATURE: x_train},
+ y=y_train,
+ batch_size=len(x_train),
+ num_epochs=None,
+ shuffle=True)
+ classifier.train(input_fn=train_input_fn, steps=100)
+
+ # Predict.
+ test_input_fn = tf.estimator.inputs.numpy_input_fn(
+ x={CHARS_FEATURE: x_test},
+ y=y_test,
+ num_epochs=1,
+ shuffle=False)
+ predictions = classifier.predict(input_fn=test_input_fn)
+ y_predicted = np.array(list(p['class'] for p in predictions))
+ y_predicted = y_predicted.reshape(np.array(y_test).shape)
+
+ # Score with sklearn.
score = metrics.accuracy_score(y_test, y_predicted)
- print('Accuracy: {0:f}'.format(score))
+ print('Accuracy (sklearn): {0:f}'.format(score))
+
+ # Score with tensorflow.
+ scores = classifier.evaluate(input_fn=test_input_fn)
+ print('Accuracy (tensorflow): {0:f}'.format(scores['accuracy']))
if __name__ == '__main__':
diff --git a/tensorflow/examples/learn/text_classification_cnn.py b/tensorflow/examples/learn/text_classification_cnn.py
index 468a96b58f..0ee2405c8b 100644
--- a/tensorflow/examples/learn/text_classification_cnn.py
+++ b/tensorflow/examples/learn/text_classification_cnn.py
@@ -25,8 +25,6 @@ import pandas
from sklearn import metrics
import tensorflow as tf
-learn = tf.contrib.learn
-
FLAGS = None
MAX_DOCUMENT_LENGTH = 100
@@ -38,59 +36,78 @@ FILTER_SHAPE2 = [WINDOW_SIZE, N_FILTERS]
POOLING_WINDOW = 4
POOLING_STRIDE = 2
n_words = 0
+MAX_LABEL = 15
+WORDS_FEATURE = 'words' # Name of the input words feature.
-def cnn_model(features, target):
+def cnn_model(features, labels, mode):
"""2 layer ConvNet to predict from sequence of words to a class."""
# Convert indexes of words into embeddings.
# This creates embeddings matrix of [n_words, EMBEDDING_SIZE] and then
# maps word indexes of the sequence into [batch_size, sequence_length,
# EMBEDDING_SIZE].
- target = tf.one_hot(target, 15, 1, 0)
word_vectors = tf.contrib.layers.embed_sequence(
- features, vocab_size=n_words, embed_dim=EMBEDDING_SIZE, scope='words')
+ features[WORDS_FEATURE], vocab_size=n_words, embed_dim=EMBEDDING_SIZE)
word_vectors = tf.expand_dims(word_vectors, 3)
with tf.variable_scope('CNN_Layer1'):
# Apply Convolution filtering on input sequence.
- conv1 = tf.contrib.layers.convolution2d(
- word_vectors, N_FILTERS, FILTER_SHAPE1, padding='VALID')
- # Add a RELU for non linearity.
- conv1 = tf.nn.relu(conv1)
+ conv1 = tf.layers.conv2d(
+ word_vectors,
+ filters=N_FILTERS,
+ kernel_size=FILTER_SHAPE1,
+ padding='VALID',
+ # Add a ReLU for non linearity.
+ activation=tf.nn.relu)
# Max pooling across output of Convolution+Relu.
- pool1 = tf.nn.max_pool(
+ pool1 = tf.layers.max_pooling2d(
conv1,
- ksize=[1, POOLING_WINDOW, 1, 1],
- strides=[1, POOLING_STRIDE, 1, 1],
+ pool_size=POOLING_WINDOW,
+ strides=POOLING_STRIDE,
padding='SAME')
# Transpose matrix so that n_filters from convolution becomes width.
pool1 = tf.transpose(pool1, [0, 1, 3, 2])
with tf.variable_scope('CNN_Layer2'):
# Second level of convolution filtering.
- conv2 = tf.contrib.layers.convolution2d(
- pool1, N_FILTERS, FILTER_SHAPE2, padding='VALID')
+ conv2 = tf.layers.conv2d(
+ pool1,
+ filters=N_FILTERS,
+ kernel_size=FILTER_SHAPE2,
+ padding='VALID')
# Max across each filter to get useful features for classification.
pool2 = tf.squeeze(tf.reduce_max(conv2, 1), squeeze_dims=[1])
# Apply regular WX + B and classification.
- logits = tf.contrib.layers.fully_connected(pool2, 15, activation_fn=None)
- loss = tf.losses.softmax_cross_entropy(target, logits)
-
- train_op = tf.contrib.layers.optimize_loss(
- loss,
- tf.contrib.framework.get_global_step(),
- optimizer='Adam',
- learning_rate=0.01)
-
- return ({
- 'class': tf.argmax(logits, 1),
- 'prob': tf.nn.softmax(logits)
- }, loss, train_op)
+ logits = tf.layers.dense(pool2, MAX_LABEL, activation=None)
+
+ predicted_classes = tf.argmax(logits, 1)
+ if mode == tf.estimator.ModeKeys.PREDICT:
+ return tf.estimator.EstimatorSpec(
+ mode=mode,
+ predictions={
+ 'class': predicted_classes,
+ 'prob': tf.nn.softmax(logits)
+ })
+
+ onehot_labels = tf.one_hot(labels, MAX_LABEL, 1, 0)
+ loss = tf.losses.softmax_cross_entropy(
+ onehot_labels=onehot_labels, logits=logits)
+ if mode == tf.estimator.ModeKeys.TRAIN:
+ optimizer = tf.train.AdamOptimizer(learning_rate=0.01)
+ train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
+ return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
+
+ eval_metric_ops = {
+ 'accuracy': tf.metrics.accuracy(
+ labels=labels, predictions=predicted_classes)
+ }
+ return tf.estimator.EstimatorSpec(
+ mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)
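
The word-level CNN differs from the character CNN only in its first layer: embed_sequence builds an [n_words, EMBEDDING_SIZE] embedding matrix (now without the explicit 'words' scope) and looks up each word id, and expand_dims adds the channel axis that conv2d expects. Symbolically (B = batch size; the embedding and filter constants are defined above this hunk):

    # features[WORDS_FEATURE]  [B, MAX_DOCUMENT_LENGTH]                 int word ids
    # embed_sequence(...)      [B, MAX_DOCUMENT_LENGTH, EMBEDDING_SIZE]
    # expand_dims(., 3)        [B, MAX_DOCUMENT_LENGTH, EMBEDDING_SIZE, 1]
    # conv1, FILTER_SHAPE1 = [WINDOW_SIZE, EMBEDDING_SIZE], VALID
    #                          [B, MAX_DOCUMENT_LENGTH - WINDOW_SIZE + 1, 1, N_FILTERS]
    # ...then pooling, transpose, conv2, reduce_max and dense as in the
    # character CNN, ending in logits [B, MAX_LABEL].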
def main(unused_argv):
global n_words
# Prepare training and testing data
- dbpedia = learn.datasets.load_dataset(
+ dbpedia = tf.contrib.learn.datasets.load_dataset(
'dbpedia', test_with_fake_data=FLAGS.test_with_fake_data)
x_train = pandas.DataFrame(dbpedia.train.data)[1]
y_train = pandas.Series(dbpedia.train.target)
@@ -98,20 +115,42 @@ def main(unused_argv):
y_test = pandas.Series(dbpedia.test.target)
# Process vocabulary
- vocab_processor = learn.preprocessing.VocabularyProcessor(MAX_DOCUMENT_LENGTH)
+ vocab_processor = tf.contrib.learn.preprocessing.VocabularyProcessor(
+ MAX_DOCUMENT_LENGTH)
x_train = np.array(list(vocab_processor.fit_transform(x_train)))
x_test = np.array(list(vocab_processor.transform(x_test)))
n_words = len(vocab_processor.vocabulary_)
print('Total words: %d' % n_words)
# Build model
- classifier = learn.SKCompat(learn.Estimator(model_fn=cnn_model))
-
- # Train and predict
- classifier.fit(x_train, y_train, steps=100)
- y_predicted = classifier.predict(x_test)['class']
+ classifier = tf.estimator.Estimator(model_fn=cnn_model)
+
+ # Train.
+ train_input_fn = tf.estimator.inputs.numpy_input_fn(
+ x={WORDS_FEATURE: x_train},
+ y=y_train,
+ batch_size=len(x_train),
+ num_epochs=None,
+ shuffle=True)
+ classifier.train(input_fn=train_input_fn, steps=100)
+
+ # Predict.
+ test_input_fn = tf.estimator.inputs.numpy_input_fn(
+ x={WORDS_FEATURE: x_test},
+ y=y_test,
+ num_epochs=1,
+ shuffle=False)
+ predictions = classifier.predict(input_fn=test_input_fn)
+ y_predicted = np.array(list(p['class'] for p in predictions))
+ y_predicted = y_predicted.reshape(np.array(y_test).shape)
+
+ # Score with sklearn.
score = metrics.accuracy_score(y_test, y_predicted)
- print('Accuracy: {0:f}'.format(score))
+ print('Accuracy (sklearn): {0:f}'.format(score))
+
+ # Score with tensorflow.
+ scores = classifier.evaluate(input_fn=test_input_fn)
+ print('Accuracy (tensorflow): {0:f}'.format(scores['accuracy']))
if __name__ == '__main__':