diff options
author | Igor Saprykin <isaprykin@google.com> | 2017-10-18 17:26:17 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-10-18 17:29:59 -0700 |
commit | b1cf67b0f6b9c600a03bbdc3eec4fd7b2b6d2deb (patch) | |
tree | 2bf66bcc7f1ed030d6dc6f820c46998061489d97 /tensorflow/examples/learn | |
parent | bba3957467ad8ba9351b829036120412d5d006cb (diff) |
Migrate text_classification.py from .contrib utils to .core.
Some usages are left untouched: datasets, VocabularyProcessor. A tracking bug is filed for embed_sequence.
Tested by re-running and the loss numbers look similar to the ones before the change.
PiperOrigin-RevId: 172681096
Diffstat (limited to 'tensorflow/examples/learn')
-rw-r--r-- | tensorflow/examples/learn/text_classification.py | 6 |
1 files changed, 4 insertions, 2 deletions
diff --git a/tensorflow/examples/learn/text_classification.py b/tensorflow/examples/learn/text_classification.py index 26e6e086b3..ba89c532be 100644 --- a/tensorflow/examples/learn/text_classification.py +++ b/tensorflow/examples/learn/text_classification.py @@ -91,11 +91,11 @@ def rnn_model(features, labels, mode): word_list = tf.unstack(word_vectors, axis=1) # Create a Gated Recurrent Unit cell with hidden size of EMBEDDING_SIZE. - cell = tf.contrib.rnn.GRUCell(EMBEDDING_SIZE) + cell = tf.nn.rnn_cell.GRUCell(EMBEDDING_SIZE) # Create an unrolled Recurrent Neural Networks to length of # MAX_DOCUMENT_LENGTH and passes word_list as inputs for each unit. - _, encoding = tf.contrib.rnn.static_rnn(cell, word_list, dtype=tf.float32) + _, encoding = tf.nn.static_rnn(cell, word_list, dtype=tf.float32) # Given encoding of RNN, take encoding of last step (e.g hidden size of the # neural network of last step) and pass it as features for softmax @@ -107,6 +107,8 @@ def rnn_model(features, labels, mode): def main(unused_argv): global n_words + tf.logging.set_verbosity(tf.logging.INFO) + # Prepare training and testing data dbpedia = tf.contrib.learn.datasets.load_dataset( 'dbpedia', test_with_fake_data=FLAGS.test_with_fake_data) |