diff options
author | Igor Saprykin <isaprykin@google.com> | 2017-10-18 20:28:27 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-10-18 20:31:48 -0700 |
commit | b8f8a3d3660c75a4034bbe69a766d481638a6a4e (patch) | |
tree | 2a42055a2c692039e4c09a844133b7029d193e63 /tensorflow/examples/learn | |
parent | 6c297fa9d5a0add0e38aceaceb57b0c6d83e0aca (diff) |
Fix the build file and a memory issue for text_classification_character_cnn.py.
PiperOrigin-RevId: 172695522
Diffstat (limited to 'tensorflow/examples/learn')
-rw-r--r-- | tensorflow/examples/learn/text_classification_character_cnn.py | 11 |
1 files changed, 4 insertions, 7 deletions
diff --git a/tensorflow/examples/learn/text_classification_character_cnn.py b/tensorflow/examples/learn/text_classification_character_cnn.py index 5f7c8e7371..363ff00362 100644 --- a/tensorflow/examples/learn/text_classification_character_cnn.py +++ b/tensorflow/examples/learn/text_classification_character_cnn.py @@ -30,7 +30,6 @@ import sys import numpy as np import pandas -from sklearn import metrics import tensorflow as tf FLAGS = None @@ -106,6 +105,8 @@ def char_cnn_model(features, labels, mode): def main(unused_argv): + tf.logging.set_verbosity(tf.logging.INFO) + # Prepare training and testing data dbpedia = tf.contrib.learn.datasets.load_dataset( 'dbpedia', test_with_fake_data=FLAGS.test_with_fake_data, size='large') @@ -130,7 +131,7 @@ def main(unused_argv): train_input_fn = tf.estimator.inputs.numpy_input_fn( x={CHARS_FEATURE: x_train}, y=y_train, - batch_size=len(x_train), + batch_size=128, num_epochs=None, shuffle=True) classifier.train(input_fn=train_input_fn, steps=100) @@ -145,13 +146,9 @@ def main(unused_argv): y_predicted = np.array(list(p['class'] for p in predictions)) y_predicted = y_predicted.reshape(np.array(y_test).shape) - # Score with sklearn. - score = metrics.accuracy_score(y_test, y_predicted) - print('Accuracy (sklearn): {0:f}'.format(score)) - # Score with tensorflow. scores = classifier.evaluate(input_fn=test_input_fn) - print('Accuracy (tensorflow): {0:f}'.format(scores['accuracy'])) + print('Accuracy: {0:f}'.format(scores['accuracy'])) if __name__ == '__main__': |