diff options
author | Igor Saprykin <isaprykin@google.com> | 2017-10-19 11:22:37 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-10-19 11:26:14 -0700 |
commit | 31587244e4821fbb4eebcf7847281a1df3da6a2a (patch) | |
tree | 63e80db62092ac0913228ee354bf3845c5cdad71 /tensorflow/examples/learn | |
parent | fb7892e6d0d749251415fe308c618667058c1c7c (diff) |
Remove sklearn from text_classification_cnn.py.
PiperOrigin-RevId: 172772457
Diffstat (limited to 'tensorflow/examples/learn')
-rw-r--r-- | tensorflow/examples/learn/text_classification_cnn.py | 13 |
1 files changed, 2 insertions, 11 deletions
diff --git a/tensorflow/examples/learn/text_classification_cnn.py b/tensorflow/examples/learn/text_classification_cnn.py index 0ee2405c8b..be262285a3 100644 --- a/tensorflow/examples/learn/text_classification_cnn.py +++ b/tensorflow/examples/learn/text_classification_cnn.py @@ -22,7 +22,6 @@ import sys import numpy as np import pandas -from sklearn import metrics import tensorflow as tf FLAGS = None @@ -134,23 +133,15 @@ def main(unused_argv): shuffle=True) classifier.train(input_fn=train_input_fn, steps=100) - # Predict. + # Evaluate. test_input_fn = tf.estimator.inputs.numpy_input_fn( x={WORDS_FEATURE: x_test}, y=y_test, num_epochs=1, shuffle=False) - predictions = classifier.predict(input_fn=test_input_fn) - y_predicted = np.array(list(p['class'] for p in predictions)) - y_predicted = y_predicted.reshape(np.array(y_test).shape) - # Score with sklearn. - score = metrics.accuracy_score(y_test, y_predicted) - print('Accuracy (sklearn): {0:f}'.format(score)) - - # Score with tensorflow. scores = classifier.evaluate(input_fn=test_input_fn) - print('Accuracy (tensorflow): {0:f}'.format(scores['accuracy'])) + print('Accuracy: {0:f}'.format(scores['accuracy'])) if __name__ == '__main__': |