aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/examples/learn
diff options
context:
space:
mode:
authorGravatar Igor Saprykin <isaprykin@google.com>2017-10-19 11:22:37 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-19 11:26:14 -0700
commit31587244e4821fbb4eebcf7847281a1df3da6a2a (patch)
tree63e80db62092ac0913228ee354bf3845c5cdad71 /tensorflow/examples/learn
parentfb7892e6d0d749251415fe308c618667058c1c7c (diff)
Remove sklearn from text_classification_cnn.py.
PiperOrigin-RevId: 172772457
Diffstat (limited to 'tensorflow/examples/learn')
-rw-r--r--tensorflow/examples/learn/text_classification_cnn.py13
1 files changed, 2 insertions, 11 deletions
diff --git a/tensorflow/examples/learn/text_classification_cnn.py b/tensorflow/examples/learn/text_classification_cnn.py
index 0ee2405c8b..be262285a3 100644
--- a/tensorflow/examples/learn/text_classification_cnn.py
+++ b/tensorflow/examples/learn/text_classification_cnn.py
@@ -22,7 +22,6 @@ import sys
import numpy as np
import pandas
-from sklearn import metrics
import tensorflow as tf
FLAGS = None
@@ -134,23 +133,15 @@ def main(unused_argv):
shuffle=True)
classifier.train(input_fn=train_input_fn, steps=100)
- # Predict.
+ # Evaluate.
test_input_fn = tf.estimator.inputs.numpy_input_fn(
x={WORDS_FEATURE: x_test},
y=y_test,
num_epochs=1,
shuffle=False)
- predictions = classifier.predict(input_fn=test_input_fn)
- y_predicted = np.array(list(p['class'] for p in predictions))
- y_predicted = y_predicted.reshape(np.array(y_test).shape)
- # Score with sklearn.
- score = metrics.accuracy_score(y_test, y_predicted)
- print('Accuracy (sklearn): {0:f}'.format(score))
-
- # Score with tensorflow.
scores = classifier.evaluate(input_fn=test_input_fn)
- print('Accuracy (tensorflow): {0:f}'.format(scores['accuracy']))
+ print('Accuracy: {0:f}'.format(scores['accuracy']))
if __name__ == '__main__':