diff options
author | Eugene Brevdo <ebrevdo@google.com> | 2017-02-05 13:34:27 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-02-05 13:54:42 -0800 |
commit | c8a5ad4bcd8561ac196573ebe79e39ae43e1fdff (patch) | |
tree | f8c4ca2501d8d82466b41ab6832d474cac4aa0b9 /tensorflow/contrib/legacy_seq2seq | |
parent | 824629120293f92bcaba07fd33d7e6546b5804a5 (diff) |
Make legacy_seq2seq seq2seq models perform a **deep** copy of the incoming cell.
Shallow copies are not enough when using a MultiRNNCell.
Change: 146611566
Diffstat (limited to 'tensorflow/contrib/legacy_seq2seq')
-rw-r--r-- | tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py | 10 |
1 files changed, 5 insertions, 5 deletions
diff --git a/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py b/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py index ed2d49a1bb..0d6eac33de 100644 --- a/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py +++ b/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py @@ -182,7 +182,7 @@ def basic_rnn_seq2seq(encoder_inputs, It is a 2D Tensor of shape [batch_size x cell.state_size]. """ with variable_scope.variable_scope(scope or "basic_rnn_seq2seq"): - enc_cell = copy.copy(cell) + enc_cell = copy.deepcopy(cell) _, enc_state = core_rnn.static_rnn(enc_cell, encoder_inputs, dtype=dtype) return rnn_decoder(decoder_inputs, enc_state, cell) @@ -355,7 +355,7 @@ def embedding_rnn_seq2seq(encoder_inputs, dtype = scope.dtype # Encoder. - encoder_cell = copy.copy(cell) + encoder_cell = copy.deepcopy(cell) encoder_cell = core_rnn_cell.EmbeddingWrapper( encoder_cell, embedding_classes=num_encoder_symbols, @@ -846,7 +846,7 @@ def embedding_attention_seq2seq(encoder_inputs, scope or "embedding_attention_seq2seq", dtype=dtype) as scope: dtype = scope.dtype # Encoder. - encoder_cell = copy.copy(cell) + encoder_cell = copy.deepcopy(cell) encoder_cell = core_rnn_cell.EmbeddingWrapper( encoder_cell, embedding_classes=num_encoder_symbols, @@ -969,7 +969,7 @@ def one2many_rnn_seq2seq(encoder_inputs, dtype = scope.dtype # Encoder. - encoder_cell = copy.copy(cell) + encoder_cell = copy.deepcopy(cell) encoder_cell = core_rnn_cell.EmbeddingWrapper( encoder_cell, embedding_classes=num_encoder_symbols, @@ -983,7 +983,7 @@ def one2many_rnn_seq2seq(encoder_inputs, with variable_scope.variable_scope("one2many_decoder_" + str( name)) as scope: - decoder_cell = copy.copy(cell) + decoder_cell = copy.deepcopy(cell) decoder_cell = core_rnn_cell.OutputProjectionWrapper( decoder_cell, num_decoder_symbols) if isinstance(feed_previous, bool): |