path: root/tensorflow/examples/learn
author    Igor Saprykin <isaprykin@google.com>    2017-10-18 17:26:17 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2017-10-18 17:29:59 -0700
commit    b1cf67b0f6b9c600a03bbdc3eec4fd7b2b6d2deb (patch)
tree      2bf66bcc7f1ed030d6dc6f820c46998061489d97 /tensorflow/examples/learn
parent    bba3957467ad8ba9351b829036120412d5d006cb (diff)
Migrate text_classification.py from .contrib utils to .core.
Some usages are left untouched: datasets, VocabularyProcessor. A tracking bug is filed for embed_sequence. Tested by re-running; the loss numbers look similar to those before the change.
PiperOrigin-RevId: 172681096
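For context, a minimal stand-alone sketch of the encoder as it looks after this change, using only the core (non-contrib) RNN symbols the diff below switches to. This assumes a TensorFlow 1.x runtime; the constants and placeholder shape are illustrative stand-ins, not taken from the file.

    import tensorflow as tf

    EMBEDDING_SIZE = 50        # illustrative stand-in for the file's constant
    MAX_DOCUMENT_LENGTH = 10   # illustrative stand-in

    # [batch, time, depth] word vectors, split into a per-timestep list.
    word_vectors = tf.placeholder(
        tf.float32, [None, MAX_DOCUMENT_LENGTH, EMBEDDING_SIZE])
    word_list = tf.unstack(word_vectors, axis=1)

    # Core replacements for the removed tf.contrib.rnn symbols.
    cell = tf.nn.rnn_cell.GRUCell(EMBEDDING_SIZE)
    _, encoding = tf.nn.static_rnn(cell, word_list, dtype=tf.float32)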
Diffstat (limited to 'tensorflow/examples/learn')
-rw-r--r--  tensorflow/examples/learn/text_classification.py | 6 ++++--
1 file changed, 4 insertions(+), 2 deletions(-)
diff --git a/tensorflow/examples/learn/text_classification.py b/tensorflow/examples/learn/text_classification.py
index 26e6e086b3..ba89c532be 100644
--- a/tensorflow/examples/learn/text_classification.py
+++ b/tensorflow/examples/learn/text_classification.py
@@ -91,11 +91,11 @@ def rnn_model(features, labels, mode):
   word_list = tf.unstack(word_vectors, axis=1)
 
   # Create a Gated Recurrent Unit cell with hidden size of EMBEDDING_SIZE.
-  cell = tf.contrib.rnn.GRUCell(EMBEDDING_SIZE)
+  cell = tf.nn.rnn_cell.GRUCell(EMBEDDING_SIZE)
 
   # Create an unrolled Recurrent Neural Network of length MAX_DOCUMENT_LENGTH
   # and pass word_list as inputs for each unit.
-  _, encoding = tf.contrib.rnn.static_rnn(cell, word_list, dtype=tf.float32)
+  _, encoding = tf.nn.static_rnn(cell, word_list, dtype=tf.float32)
 
   # Given the encoding of the RNN, take the encoding of the last step (i.e.
   # the hidden state of the last unit) and pass it as features for softmax
@@ -107,6 +107,8 @@ def rnn_model(features, labels, mode):
 
 def main(unused_argv):
   global n_words
+  tf.logging.set_verbosity(tf.logging.INFO)
+
   # Prepare training and testing data
   dbpedia = tf.contrib.learn.datasets.load_dataset(
       'dbpedia', test_with_fake_data=FLAGS.test_with_fake_data)
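The second hunk raises the log verbosity so that tf.estimator training emits periodic loss lines, which is what makes the commit message's "loss numbers look similar" check observable. A minimal sketch of that pattern, assuming TensorFlow 1.x; the model_fn and data here are hypothetical placeholders, not the file's rnn_model or the dbpedia dataset:

    import numpy as np
    import tensorflow as tf

    def model_fn(features, labels, mode):
      # Hypothetical one-layer classifier standing in for rnn_model.
      logits = tf.layers.dense(features['x'], 2)
      loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
      train_op = tf.train.AdamOptimizer(0.01).minimize(
          loss, global_step=tf.train.get_global_step())
      return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

    def main(unused_argv):
      # With INFO verbosity, Estimator.train logs loss every 100 steps by default.
      tf.logging.set_verbosity(tf.logging.INFO)
      x = np.random.rand(100, 4).astype(np.float32)
      y = np.random.randint(0, 2, size=100)
      input_fn = tf.estimator.inputs.numpy_input_fn(
          {'x': x}, y, batch_size=16, num_epochs=None, shuffle=True)
      tf.estimator.Estimator(model_fn).train(input_fn, steps=200)

    if __name__ == '__main__':
      tf.app.run(main=main)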