diff options
author | Eugene Brevdo <ebrevdo@google.com> | 2017-02-02 16:48:24 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-02-02 17:07:41 -0800 |
commit | 0fc2568652193e0743785dc232757636355e44fc (patch) | |
tree | 3a7431eb303cb559bc9d31d31222fccc32bea110 /tensorflow/contrib/legacy_seq2seq | |
parent | 4fe798cdf5dbe2f62dfadd7c64c76d6211a3ed8a (diff) |
Make legacy_seq2seq seq2seq models perform a shallow copy of the incoming cell.
This forces a separate cell instance to be used in the encoder and decoder(s).
Right now this is a no-op, but most RNNCells will soon keep track of their
variables, and using the same cell for an encoder and decoder pair without
tied weights will not be legal. At that point, this change will still
be a no-op, but without it the RNNCells would raise an error.
Change: 146423203
Diffstat (limited to 'tensorflow/contrib/legacy_seq2seq')
-rw-r--r-- | tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py | 17 |
1 files changed, 12 insertions, 5 deletions
diff --git a/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py b/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py index bc36e9dced..ed2d49a1bb 100644 --- a/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py +++ b/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py @@ -56,6 +56,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import copy + # We disable pylint because we need python3 compatibility. from six.moves import xrange # pylint: disable=redefined-builtin from six.moves import zip # pylint: disable=redefined-builtin @@ -180,7 +182,8 @@ def basic_rnn_seq2seq(encoder_inputs, It is a 2D Tensor of shape [batch_size x cell.state_size]. """ with variable_scope.variable_scope(scope or "basic_rnn_seq2seq"): - _, enc_state = core_rnn.static_rnn(cell, encoder_inputs, dtype=dtype) + enc_cell = copy.copy(cell) + _, enc_state = core_rnn.static_rnn(enc_cell, encoder_inputs, dtype=dtype) return rnn_decoder(decoder_inputs, enc_state, cell) @@ -352,8 +355,9 @@ def embedding_rnn_seq2seq(encoder_inputs, dtype = scope.dtype # Encoder. + encoder_cell = copy.copy(cell) encoder_cell = core_rnn_cell.EmbeddingWrapper( - cell, + encoder_cell, embedding_classes=num_encoder_symbols, embedding_size=embedding_size) _, encoder_state = core_rnn.static_rnn( @@ -842,8 +846,9 @@ def embedding_attention_seq2seq(encoder_inputs, scope or "embedding_attention_seq2seq", dtype=dtype) as scope: dtype = scope.dtype # Encoder. + encoder_cell = copy.copy(cell) encoder_cell = core_rnn_cell.EmbeddingWrapper( - cell, + encoder_cell, embedding_classes=num_encoder_symbols, embedding_size=embedding_size) encoder_outputs, encoder_state = core_rnn.static_rnn( @@ -964,8 +969,9 @@ def one2many_rnn_seq2seq(encoder_inputs, dtype = scope.dtype # Encoder. 
+ encoder_cell = copy.copy(cell) encoder_cell = core_rnn_cell.EmbeddingWrapper( - cell, + encoder_cell, embedding_classes=num_encoder_symbols, embedding_size=embedding_size) _, encoder_state = core_rnn.static_rnn( @@ -977,8 +983,9 @@ def one2many_rnn_seq2seq(encoder_inputs, with variable_scope.variable_scope("one2many_decoder_" + str( name)) as scope: + decoder_cell = copy.copy(cell) decoder_cell = core_rnn_cell.OutputProjectionWrapper( - cell, num_decoder_symbols) + decoder_cell, num_decoder_symbols) if isinstance(feed_previous, bool): outputs, state = embedding_rnn_decoder( decoder_inputs, |