author     Eugene Brevdo <ebrevdo@google.com>    2017-02-02 16:48:24 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>    2017-02-02 17:07:41 -0800
commit     0fc2568652193e0743785dc232757636355e44fc (patch)
tree       3a7431eb303cb559bc9d31d31222fccc32bea110 /tensorflow/contrib/legacy_seq2seq
parent     4fe798cdf5dbe2f62dfadd7c64c76d6211a3ed8a (diff)
Make legacy_seq2seq seq2seq models perform a shallow copy of the incoming cell.
This forces a separate cell instance to be used in the encoder and decoder(s). Right now this is a no-op, but most RNNCells will soon keep track of their variables, and using the same cell for an encoder and a decoder pair without tied weights will not be legal. At that point this change will still be a no-op, but without it the RNNCells would raise an error.

Change: 146423203
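For context, a minimal standalone sketch (not part of this change) of the pattern being introduced, using a hypothetical PlaceholderCell class standing in for an RNNCell: copy.copy produces a distinct cell object, so the encoder no longer aliases the caller's cell, while the copy changes nothing for cells that keep no per-instance variables today.

    import copy

    class PlaceholderCell(object):
      # Hypothetical stand-in for an RNNCell; real RNNCells will soon
      # track the variables they create on the instance itself.
      def __init__(self, num_units):
        self.num_units = num_units

    cell = PlaceholderCell(128)
    encoder_cell = copy.copy(cell)    # shallow copy: a separate instance
    assert encoder_cell is not cell   # encoder no longer shares the caller's cell object
    assert encoder_cell.num_units == cell.num_units  # configuration is still shared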
Diffstat (limited to 'tensorflow/contrib/legacy_seq2seq')
-rw-r--r--  tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py  17
1 file changed, 12 insertions, 5 deletions
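The caller-facing API of the touched functions is unchanged; a single cell is still passed in and copied internally. A minimal usage sketch for basic_rnn_seq2seq, assuming a TF 1.x environment where tf.contrib.rnn and tf.contrib.legacy_seq2seq are available (batch size, input size, cell size, and step count below are made up):

    import tensorflow as tf  # assumes TF 1.x with tf.contrib

    batch_size, input_size, num_units, steps = 32, 10, 64, 5
    # Lists of 2-D tensors, one entry per time step, as the docstrings describe.
    encoder_inputs = [tf.placeholder(tf.float32, [batch_size, input_size])
                      for _ in range(steps)]
    decoder_inputs = [tf.placeholder(tf.float32, [batch_size, input_size])
                      for _ in range(steps)]

    cell = tf.contrib.rnn.BasicLSTMCell(num_units)
    # One cell object is passed; the encoder works on a shallow copy internally.
    outputs, state = tf.contrib.legacy_seq2seq.basic_rnn_seq2seq(
        encoder_inputs, decoder_inputs, cell)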
diff --git a/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py b/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py
index bc36e9dced..ed2d49a1bb 100644
--- a/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py
+++ b/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py
@@ -56,6 +56,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import copy
+
# We disable pylint because we need python3 compatibility.
from six.moves import xrange # pylint: disable=redefined-builtin
from six.moves import zip # pylint: disable=redefined-builtin
@@ -180,7 +182,8 @@ def basic_rnn_seq2seq(encoder_inputs,
It is a 2D Tensor of shape [batch_size x cell.state_size].
"""
with variable_scope.variable_scope(scope or "basic_rnn_seq2seq"):
- _, enc_state = core_rnn.static_rnn(cell, encoder_inputs, dtype=dtype)
+ enc_cell = copy.copy(cell)
+ _, enc_state = core_rnn.static_rnn(enc_cell, encoder_inputs, dtype=dtype)
return rnn_decoder(decoder_inputs, enc_state, cell)
@@ -352,8 +355,9 @@ def embedding_rnn_seq2seq(encoder_inputs,
dtype = scope.dtype
# Encoder.
+ encoder_cell = copy.copy(cell)
encoder_cell = core_rnn_cell.EmbeddingWrapper(
- cell,
+ encoder_cell,
embedding_classes=num_encoder_symbols,
embedding_size=embedding_size)
_, encoder_state = core_rnn.static_rnn(
@@ -842,8 +846,9 @@ def embedding_attention_seq2seq(encoder_inputs,
scope or "embedding_attention_seq2seq", dtype=dtype) as scope:
dtype = scope.dtype
# Encoder.
+ encoder_cell = copy.copy(cell)
encoder_cell = core_rnn_cell.EmbeddingWrapper(
- cell,
+ encoder_cell,
embedding_classes=num_encoder_symbols,
embedding_size=embedding_size)
encoder_outputs, encoder_state = core_rnn.static_rnn(
@@ -964,8 +969,9 @@ def one2many_rnn_seq2seq(encoder_inputs,
dtype = scope.dtype
# Encoder.
+ encoder_cell = copy.copy(cell)
encoder_cell = core_rnn_cell.EmbeddingWrapper(
- cell,
+ encoder_cell,
embedding_classes=num_encoder_symbols,
embedding_size=embedding_size)
_, encoder_state = core_rnn.static_rnn(
@@ -977,8 +983,9 @@ def one2many_rnn_seq2seq(encoder_inputs,
with variable_scope.variable_scope("one2many_decoder_" + str(
name)) as scope:
+ decoder_cell = copy.copy(cell)
decoder_cell = core_rnn_cell.OutputProjectionWrapper(
- cell, num_decoder_symbols)
+ decoder_cell, num_decoder_symbols)
if isinstance(feed_previous, bool):
outputs, state = embedding_rnn_decoder(
decoder_inputs,