diff options
author | 2017-10-19 15:27:52 -0700 | |
---|---|---|
committer | 2017-10-19 16:05:30 -0700 | |
commit | f080052284a4a39113051fb1178d91365e9872a8 (patch) | |
tree | 2fd7de12aac27e5d0b1e938caa18002f75f6b149 /tensorflow/examples/learn | |
parent | 2cd178ef5a4e5cac27b55729f0203c4864540063 (diff) |
Move text_classification_character_rnn from .contrib utils to .core utils.
Also removes sklearn comparison.
PiperOrigin-RevId: 172808535
Diffstat (limited to 'tensorflow/examples/learn')
-rw-r--r-- | tensorflow/examples/learn/text_classification_character_rnn.py | 19 |
1 files changed, 5 insertions, 14 deletions
diff --git a/tensorflow/examples/learn/text_classification_character_rnn.py b/tensorflow/examples/learn/text_classification_character_rnn.py index 1fc9388a1a..86adc056ad 100644 --- a/tensorflow/examples/learn/text_classification_character_rnn.py +++ b/tensorflow/examples/learn/text_classification_character_rnn.py @@ -30,7 +30,6 @@ import sys import numpy as np import pandas -from sklearn import metrics import tensorflow as tf FLAGS = None @@ -46,8 +45,8 @@ def char_rnn_model(features, labels, mode): byte_vectors = tf.one_hot(features[CHARS_FEATURE], 256, 1., 0.) byte_list = tf.unstack(byte_vectors, axis=1) - cell = tf.contrib.rnn.GRUCell(HIDDEN_SIZE) - _, encoding = tf.contrib.rnn.static_rnn(cell, byte_list, dtype=tf.float32) + cell = tf.nn.rnn_cell.GRUCell(HIDDEN_SIZE) + _, encoding = tf.nn.static_rnn(cell, byte_list, dtype=tf.float32) logits = tf.layers.dense(encoding, MAX_LABEL, activation=None) @@ -98,28 +97,20 @@ def main(unused_argv): train_input_fn = tf.estimator.inputs.numpy_input_fn( x={CHARS_FEATURE: x_train}, y=y_train, - batch_size=len(x_train), + batch_size=128, num_epochs=None, shuffle=True) classifier.train(input_fn=train_input_fn, steps=100) - # Predict. + # Eval. test_input_fn = tf.estimator.inputs.numpy_input_fn( x={CHARS_FEATURE: x_test}, y=y_test, num_epochs=1, shuffle=False) - predictions = classifier.predict(input_fn=test_input_fn) - y_predicted = np.array(list(p['class'] for p in predictions)) - y_predicted = y_predicted.reshape(np.array(y_test).shape) - # Score with sklearn. - score = metrics.accuracy_score(y_test, y_predicted) - print('Accuracy (sklearn): {0:f}'.format(score)) - - # Score with tensorflow. scores = classifier.evaluate(input_fn=test_input_fn) - print('Accuracy (tensorflow): {0:f}'.format(scores['accuracy'])) + print('Accuracy: {0:f}'.format(scores['accuracy'])) if __name__ == '__main__': |