author    Jianwei Xie <xiejw@google.com>    2016-12-15 09:00:04 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>    2016-12-15 09:06:41 -0800
commit    5a21e678bd8b0fe294717a0362d41c6dfc467cef (patch)
tree      8e386fe2a0abcd85b5e334551f5ffbd1bebe61ad /tensorflow/contrib/legacy_seq2seq
parent    fd7ff167e6f02fe0966fa70ef52a99d16e0490ec (diff)
Move the implementation code of static_rnn, static_bidirectional_rnn, and static_state_saving_rnn from core to contrib.
Change: 142148394
Diffstat (limited to 'tensorflow/contrib/legacy_seq2seq')
-rw-r--r--  tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py | 18
1 file changed, 10 insertions(+), 8 deletions(-)
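The diff below swaps every call to the core rnn.rnn for contrib_rnn.static_rnn, which this commit relocates to contrib. For downstream code making the same migration, the change is mechanical; a minimal sketch, assuming a TF 0.12-era graph-mode setup with a hypothetical cell size and input shapes:

    # Migration sketch (hypothetical shapes and cell size); the behavior of the
    # statically unrolled RNN is unchanged, only its import path moved.
    import tensorflow as tf
    from tensorflow.contrib.rnn.python.ops import core_rnn as contrib_rnn

    cell = tf.nn.rnn_cell.GRUCell(32)
    # static_rnn consumes a Python list of per-timestep [batch, input] tensors.
    encoder_inputs = [tf.placeholder(tf.float32, [None, 16]) for _ in range(5)]
    # Before: _, enc_state = rnn.rnn(cell, encoder_inputs, dtype=tf.float32)
    _, enc_state = contrib_rnn.static_rnn(cell, encoder_inputs, dtype=tf.float32)
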
diff --git a/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py b/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py
index d2018bb219..b04b6ce16a 100644
--- a/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py
+++ b/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py
@@ -60,6 +60,7 @@ from __future__ import print_function
from six.moves import xrange # pylint: disable=redefined-builtin
from six.moves import zip # pylint: disable=redefined-builtin
+from tensorflow.contrib.rnn.python.ops import core_rnn as contrib_rnn
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -67,7 +68,6 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
-from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import variable_scope
@@ -180,7 +180,7 @@ def basic_rnn_seq2seq(encoder_inputs,
It is a 2D Tensor of shape [batch_size x cell.state_size].
"""
with variable_scope.variable_scope(scope or "basic_rnn_seq2seq"):
- _, enc_state = rnn.rnn(cell, encoder_inputs, dtype=dtype)
+ _, enc_state = contrib_rnn.static_rnn(cell, encoder_inputs, dtype=dtype)
return rnn_decoder(decoder_inputs, enc_state, cell)
@@ -216,7 +216,8 @@ def tied_rnn_seq2seq(encoder_inputs,
"""
with variable_scope.variable_scope("combined_tied_rnn_seq2seq"):
scope = scope or "tied_rnn_seq2seq"
- _, enc_state = rnn.rnn(cell, encoder_inputs, dtype=dtype, scope=scope)
+ _, enc_state = contrib_rnn.static_rnn(
+ cell, encoder_inputs, dtype=dtype, scope=scope)
variable_scope.get_variable_scope().reuse_variables()
return rnn_decoder(
decoder_inputs,
@@ -355,7 +356,8 @@ def embedding_rnn_seq2seq(encoder_inputs,
cell,
embedding_classes=num_encoder_symbols,
embedding_size=embedding_size)
- _, encoder_state = rnn.rnn(encoder_cell, encoder_inputs, dtype=dtype)
+ _, encoder_state = contrib_rnn.static_rnn(
+ encoder_cell, encoder_inputs, dtype=dtype)
# Decoder.
if output_projection is None:
@@ -844,9 +846,8 @@ def embedding_attention_seq2seq(encoder_inputs,
cell,
embedding_classes=num_encoder_symbols,
embedding_size=embedding_size)
- encoder_outputs, encoder_state = rnn.rnn(encoder_cell,
- encoder_inputs,
- dtype=dtype)
+ encoder_outputs, encoder_state = contrib_rnn.static_rnn(
+ encoder_cell, encoder_inputs, dtype=dtype)
# First calculate a concatenation of encoder outputs to put attention on.
top_states = [
@@ -967,7 +968,8 @@ def one2many_rnn_seq2seq(encoder_inputs,
cell,
embedding_classes=num_encoder_symbols,
embedding_size=embedding_size)
- _, encoder_state = rnn.rnn(encoder_cell, encoder_inputs, dtype=dtype)
+ _, encoder_state = contrib_rnn.static_rnn(
+ encoder_cell, encoder_inputs, dtype=dtype)
# Decoder.
for name, decoder_inputs in decoder_inputs_dict.items():
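
All five updated call sites in seq2seq.py follow the same pattern, so the public contrib seq2seq API is untouched by this change. A short usage sketch (hypothetical sizes, not taken from the diff), assuming the TF 0.12-era contrib layout:

    # Smoke-test sketch for basic_rnn_seq2seq after the migration; shapes here
    # are illustrative assumptions.
    import tensorflow as tf
    from tensorflow.contrib.legacy_seq2seq.python.ops import seq2seq

    cell = tf.nn.rnn_cell.BasicLSTMCell(8)
    enc = [tf.placeholder(tf.float32, [None, 8]) for _ in range(3)]
    dec = [tf.placeholder(tf.float32, [None, 8]) for _ in range(3)]
    # Same signature as before this commit; the encoder side now runs through
    # contrib_rnn.static_rnn internally.
    outputs, state = seq2seq.basic_rnn_seq2seq(enc, dec, cell)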