aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/legacy_seq2seq
diff options
context:
space:
mode:
authorGravatar Eugene Brevdo <ebrevdo@google.com>2017-02-05 13:34:27 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-02-05 13:54:42 -0800
commitc8a5ad4bcd8561ac196573ebe79e39ae43e1fdff (patch)
treef8c4ca2501d8d82466b41ab6832d474cac4aa0b9 /tensorflow/contrib/legacy_seq2seq
parent824629120293f92bcaba07fd33d7e6546b5804a5 (diff)
Make legacy_seq2seq seq2seq models perform a **deep** copy of the incoming cell.
Shallow copies are not enough when using a MultiRNNCell. Change: 146611566
Diffstat (limited to 'tensorflow/contrib/legacy_seq2seq')
-rw-r--r--tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py10
1 files changed, 5 insertions, 5 deletions
diff --git a/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py b/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py
index ed2d49a1bb..0d6eac33de 100644
--- a/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py
+++ b/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py
@@ -182,7 +182,7 @@ def basic_rnn_seq2seq(encoder_inputs,
It is a 2D Tensor of shape [batch_size x cell.state_size].
"""
with variable_scope.variable_scope(scope or "basic_rnn_seq2seq"):
- enc_cell = copy.copy(cell)
+ enc_cell = copy.deepcopy(cell)
_, enc_state = core_rnn.static_rnn(enc_cell, encoder_inputs, dtype=dtype)
return rnn_decoder(decoder_inputs, enc_state, cell)
@@ -355,7 +355,7 @@ def embedding_rnn_seq2seq(encoder_inputs,
dtype = scope.dtype
# Encoder.
- encoder_cell = copy.copy(cell)
+ encoder_cell = copy.deepcopy(cell)
encoder_cell = core_rnn_cell.EmbeddingWrapper(
encoder_cell,
embedding_classes=num_encoder_symbols,
@@ -846,7 +846,7 @@ def embedding_attention_seq2seq(encoder_inputs,
scope or "embedding_attention_seq2seq", dtype=dtype) as scope:
dtype = scope.dtype
# Encoder.
- encoder_cell = copy.copy(cell)
+ encoder_cell = copy.deepcopy(cell)
encoder_cell = core_rnn_cell.EmbeddingWrapper(
encoder_cell,
embedding_classes=num_encoder_symbols,
@@ -969,7 +969,7 @@ def one2many_rnn_seq2seq(encoder_inputs,
dtype = scope.dtype
# Encoder.
- encoder_cell = copy.copy(cell)
+ encoder_cell = copy.deepcopy(cell)
encoder_cell = core_rnn_cell.EmbeddingWrapper(
encoder_cell,
embedding_classes=num_encoder_symbols,
@@ -983,7 +983,7 @@ def one2many_rnn_seq2seq(encoder_inputs,
with variable_scope.variable_scope("one2many_decoder_" + str(
name)) as scope:
- decoder_cell = copy.copy(cell)
+ decoder_cell = copy.deepcopy(cell)
decoder_cell = core_rnn_cell.OutputProjectionWrapper(
decoder_cell, num_decoder_symbols)
if isinstance(feed_previous, bool):