aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/legacy_seq2seq
diff options
context:
space:
mode:
authorGravatar Eugene Brevdo <ebrevdo@google.com>2017-02-08 12:31:53 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-02-08 12:46:20 -0800
commit6d85f317e6773aa653d8166a28dc601f58aa7198 (patch)
tree96da7c9a8554b9126b2432d55a462c2121983997 /tensorflow/contrib/legacy_seq2seq
parent1fa028759586e1fa5a6bdfe4675490087dc3f6fe (diff)
Prepare legacy_seq2seq for coming changes to RNNCell.
* Create new instances of RNNCells as appropriate for testing. * Disable certain tests because there is no way to for them to work until we allow variable reuse from an RNNCell's current scope within another scope (in April). Change: 146942741
Diffstat (limited to 'tensorflow/contrib/legacy_seq2seq')
-rw-r--r--tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py305
1 files changed, 175 insertions, 130 deletions
diff --git a/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py b/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py
index 900f609681..0f9f0a955c 100644
--- a/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py
+++ b/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import functools
import math
import random
import sys
@@ -110,14 +111,17 @@ class Seq2SeqTest(test.TestCase):
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
- cell = core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=True)
+ cell_fn = lambda: core_rnn_cell_impl.BasicLSTMCell(2)
+ cell = cell_fn()
_, enc_state = core_rnn.static_rnn(cell, inp, dtype=dtypes.float32)
dec_inp = [
constant_op.constant(
i, dtypes.int32, shape=[2]) for i in range(3)
]
+ # Use a new cell instance since the attention decoder uses a
+ # different variable scope.
dec, mem = seq2seq_lib.embedding_rnn_decoder(
- dec_inp, enc_state, cell, num_symbols=4, embedding_size=2)
+ dec_inp, enc_state, cell_fn(), num_symbols=4, embedding_size=2)
sess.run([variables.global_variables_initializer()])
res = sess.run(dec)
self.assertEqual(3, len(res))
@@ -140,7 +144,8 @@ class Seq2SeqTest(test.TestCase):
constant_op.constant(
i, dtypes.int32, shape=[2]) for i in range(3)
]
- cell = core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=True)
+ cell_fn = lambda: core_rnn_cell_impl.BasicLSTMCell(2)
+ cell = cell_fn()
dec, mem = seq2seq_lib.embedding_rnn_seq2seq(
enc_inp,
dec_inp,
@@ -159,11 +164,11 @@ class Seq2SeqTest(test.TestCase):
# Test with state_is_tuple=False.
with variable_scope.variable_scope("no_tuple"):
- cell1 = core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False)
+ cell_nt = core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False)
dec, mem = seq2seq_lib.embedding_rnn_seq2seq(
enc_inp,
dec_inp,
- cell1,
+ cell_nt,
num_encoder_symbols=2,
num_decoder_symbols=5,
embedding_size=2)
@@ -182,7 +187,7 @@ class Seq2SeqTest(test.TestCase):
dec, _ = seq2seq_lib.embedding_rnn_seq2seq(
enc_inp,
dec_inp,
- cell,
+ cell_fn(),
num_encoder_symbols=2,
num_decoder_symbols=5,
embedding_size=2,
@@ -201,7 +206,7 @@ class Seq2SeqTest(test.TestCase):
d3, _ = seq2seq_lib.embedding_rnn_seq2seq(
enc_inp,
dec_inp2,
- cell,
+ cell_fn(),
num_encoder_symbols=2,
num_decoder_symbols=5,
embedding_size=2,
@@ -211,7 +216,7 @@ class Seq2SeqTest(test.TestCase):
d1, _ = seq2seq_lib.embedding_rnn_seq2seq(
enc_inp,
dec_inp,
- cell,
+ cell_fn(),
num_encoder_symbols=2,
num_decoder_symbols=5,
embedding_size=2,
@@ -219,7 +224,7 @@ class Seq2SeqTest(test.TestCase):
d2, _ = seq2seq_lib.embedding_rnn_seq2seq(
enc_inp,
dec_inp2,
- cell,
+ cell_fn(),
num_encoder_symbols=2,
num_decoder_symbols=5,
embedding_size=2,
@@ -242,9 +247,11 @@ class Seq2SeqTest(test.TestCase):
constant_op.constant(
i, dtypes.int32, shape=[2]) for i in range(3)
]
- cell = core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=True)
+ cell = functools.partial(
+ core_rnn_cell_impl.BasicLSTMCell,
+ 2, state_is_tuple=True)
dec, mem = seq2seq_lib.embedding_tied_rnn_seq2seq(
- enc_inp, dec_inp, cell, num_symbols=5, embedding_size=2)
+ enc_inp, dec_inp, cell(), num_symbols=5, embedding_size=2)
sess.run([variables.global_variables_initializer()])
res = sess.run(dec)
self.assertEqual(3, len(res))
@@ -260,7 +267,7 @@ class Seq2SeqTest(test.TestCase):
dec, mem = seq2seq_lib.embedding_tied_rnn_seq2seq(
enc_inp,
dec_inp,
- cell,
+ cell(),
num_symbols=5,
num_decoder_symbols=3,
embedding_size=2)
@@ -276,7 +283,7 @@ class Seq2SeqTest(test.TestCase):
dec, _ = seq2seq_lib.embedding_tied_rnn_seq2seq(
enc_inp,
dec_inp,
- cell,
+ cell(),
num_symbols=5,
embedding_size=2,
output_projection=(w, b))
@@ -291,7 +298,7 @@ class Seq2SeqTest(test.TestCase):
d3, _ = seq2seq_lib.embedding_tied_rnn_seq2seq(
enc_inp,
dec_inp2,
- cell,
+ cell(),
num_symbols=5,
embedding_size=2,
feed_previous=constant_op.constant(True))
@@ -300,14 +307,14 @@ class Seq2SeqTest(test.TestCase):
d1, _ = seq2seq_lib.embedding_tied_rnn_seq2seq(
enc_inp,
dec_inp,
- cell,
+ cell(),
num_symbols=5,
embedding_size=2,
feed_previous=True)
d2, _ = seq2seq_lib.embedding_tied_rnn_seq2seq(
enc_inp,
dec_inp2,
- cell,
+ cell(),
num_symbols=5,
embedding_size=2,
feed_previous=True)
@@ -321,7 +328,8 @@ class Seq2SeqTest(test.TestCase):
with self.test_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
- cell = core_rnn_cell_impl.GRUCell(2)
+ cell_fn = lambda: core_rnn_cell_impl.GRUCell(2)
+ cell = cell_fn()
inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
enc_outputs, enc_state = core_rnn.static_rnn(
cell, inp, dtype=dtypes.float32)
@@ -329,8 +337,11 @@ class Seq2SeqTest(test.TestCase):
array_ops.reshape(e, [-1, 1, cell.output_size]) for e in enc_outputs
], 1)
dec_inp = [constant_op.constant(0.4, shape=[2, 2])] * 3
+
+ # Create a new cell instance for the decoder, since it uses a
+ # different variable scope
dec, mem = seq2seq_lib.attention_decoder(
- dec_inp, enc_state, attn_states, cell, output_size=4)
+ dec_inp, enc_state, attn_states, cell_fn(), output_size=4)
sess.run([variables.global_variables_initializer()])
res = sess.run(dec)
self.assertEqual(3, len(res))
@@ -343,7 +354,8 @@ class Seq2SeqTest(test.TestCase):
with self.test_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
- cell = core_rnn_cell_impl.GRUCell(2)
+ cell_fn = lambda: core_rnn_cell_impl.GRUCell(2)
+ cell = cell_fn()
inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
enc_outputs, enc_state = core_rnn.static_rnn(
cell, inp, dtype=dtypes.float32)
@@ -351,8 +363,12 @@ class Seq2SeqTest(test.TestCase):
array_ops.reshape(e, [-1, 1, cell.output_size]) for e in enc_outputs
], 1)
dec_inp = [constant_op.constant(0.4, shape=[2, 2])] * 3
+
+ # Use a new cell instance since the attention decoder uses a
+ # different variable scope.
dec, mem = seq2seq_lib.attention_decoder(
- dec_inp, enc_state, attn_states, cell, output_size=4, num_heads=2)
+ dec_inp, enc_state, attn_states, cell_fn(),
+ output_size=4, num_heads=2)
sess.run([variables.global_variables_initializer()])
res = sess.run(dec)
self.assertEqual(3, len(res))
@@ -365,14 +381,18 @@ class Seq2SeqTest(test.TestCase):
with self.test_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
- cell = core_rnn_cell_impl.GRUCell(2)
+ cell_fn = lambda: core_rnn_cell_impl.GRUCell(2)
+ cell = cell_fn()
inp = constant_op.constant(0.5, shape=[2, 2, 2])
enc_outputs, enc_state = rnn.dynamic_rnn(
cell, inp, dtype=dtypes.float32)
attn_states = enc_outputs
dec_inp = [constant_op.constant(0.4, shape=[2, 2])] * 3
+
+ # Use a new cell instance since the attention decoder uses a
+ # different variable scope.
dec, mem = seq2seq_lib.attention_decoder(
- dec_inp, enc_state, attn_states, cell, output_size=4)
+ dec_inp, enc_state, attn_states, cell_fn(), output_size=4)
sess.run([variables.global_variables_initializer()])
res = sess.run(dec)
self.assertEqual(3, len(res))
@@ -385,14 +405,19 @@ class Seq2SeqTest(test.TestCase):
with self.test_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
- cell = core_rnn_cell_impl.GRUCell(2)
+ cell_fn = lambda: core_rnn_cell_impl.GRUCell(2)
+ cell = cell_fn()
inp = constant_op.constant(0.5, shape=[2, 2, 2])
enc_outputs, enc_state = rnn.dynamic_rnn(
cell, inp, dtype=dtypes.float32)
attn_states = enc_outputs
dec_inp = [constant_op.constant(0.4, shape=[2, 2])] * 3
+
+ # Use a new cell instance since the attention decoder uses a
+ # different variable scope.
dec, mem = seq2seq_lib.attention_decoder(
- dec_inp, enc_state, attn_states, cell, output_size=4, num_heads=2)
+ dec_inp, enc_state, attn_states, cell_fn(),
+ output_size=4, num_heads=2)
sess.run([variables.global_variables_initializer()])
res = sess.run(dec)
self.assertEqual(3, len(res))
@@ -407,8 +432,9 @@ class Seq2SeqTest(test.TestCase):
"root", initializer=init_ops.constant_initializer(0.5)):
single_cell = lambda: core_rnn_cell_impl.BasicLSTMCell( # pylint: disable=g-long-lambda
2, state_is_tuple=True)
- cell = core_rnn_cell_impl.MultiRNNCell(
+ cell_fn = lambda: core_rnn_cell_impl.MultiRNNCell( # pylint: disable=g-long-lambda
cells=[single_cell() for _ in range(2)], state_is_tuple=True)
+ cell = cell_fn()
inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
enc_outputs, enc_state = core_rnn.static_rnn(
cell, inp, dtype=dtypes.float32)
@@ -416,8 +442,11 @@ class Seq2SeqTest(test.TestCase):
array_ops.reshape(e, [-1, 1, cell.output_size]) for e in enc_outputs
], 1)
dec_inp = [constant_op.constant(0.4, shape=[2, 2])] * 3
+
+ # Use a new cell instance since the attention decoder uses a
+ # different variable scope.
dec, mem = seq2seq_lib.attention_decoder(
- dec_inp, enc_state, attn_states, cell, output_size=4)
+ dec_inp, enc_state, attn_states, cell_fn(), output_size=4)
sess.run([variables.global_variables_initializer()])
res = sess.run(dec)
self.assertEqual(3, len(res))
@@ -434,11 +463,9 @@ class Seq2SeqTest(test.TestCase):
with self.test_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
- single_cell = lambda: core_rnn_cell_impl.BasicLSTMCell( # pylint: disable=g-long-lambda
- 2, state_is_tuple=True)
-
- cell = core_rnn_cell_impl.MultiRNNCell(
- cells=[single_cell() for _ in range(2)], state_is_tuple=True)
+ cell_fn = lambda: core_rnn_cell_impl.MultiRNNCell( # pylint: disable=g-long-lambda
+ cells=[core_rnn_cell_impl.BasicLSTMCell(2) for _ in range(2)])
+ cell = cell_fn()
inp = constant_op.constant(0.5, shape=[2, 2, 2])
enc_outputs, enc_state = core_rnn.static_rnn(
cell, inp, dtype=dtypes.float32)
@@ -447,8 +474,11 @@ class Seq2SeqTest(test.TestCase):
for e in enc_outputs
], 1)
dec_inp = [constant_op.constant(0.4, shape=[2, 2])] * 3
+
+ # Use a new cell instance since the attention decoder uses a
+ # different variable scope.
dec, mem = seq2seq_lib.attention_decoder(
- dec_inp, enc_state, attn_states, cell, output_size=4)
+ dec_inp, enc_state, attn_states, cell_fn(), output_size=4)
sess.run([variables.global_variables_initializer()])
res = sess.run(dec)
self.assertEqual(3, len(res))
@@ -466,7 +496,8 @@ class Seq2SeqTest(test.TestCase):
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
- cell = core_rnn_cell_impl.GRUCell(2)
+ cell_fn = lambda: core_rnn_cell_impl.GRUCell(2)
+ cell = cell_fn()
enc_outputs, enc_state = core_rnn.static_rnn(
cell, inp, dtype=dtypes.float32)
attn_states = array_ops.concat([
@@ -476,11 +507,14 @@ class Seq2SeqTest(test.TestCase):
constant_op.constant(
i, dtypes.int32, shape=[2]) for i in range(3)
]
+
+ # Use a new cell instance since the attention decoder uses a
+ # different variable scope.
dec, mem = seq2seq_lib.embedding_attention_decoder(
dec_inp,
enc_state,
attn_states,
- cell,
+ cell_fn(),
num_symbols=4,
embedding_size=2,
output_size=3)
@@ -504,7 +538,8 @@ class Seq2SeqTest(test.TestCase):
constant_op.constant(
i, dtypes.int32, shape=[2]) for i in range(3)
]
- cell = core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=True)
+ cell_fn = lambda: core_rnn_cell_impl.BasicLSTMCell(2)
+ cell = cell_fn()
dec, mem = seq2seq_lib.embedding_attention_seq2seq(
enc_inp,
dec_inp,
@@ -523,11 +558,14 @@ class Seq2SeqTest(test.TestCase):
# Test with state_is_tuple=False.
with variable_scope.variable_scope("no_tuple"):
- cell = core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False)
+ cell_fn = functools.partial(
+ core_rnn_cell_impl.BasicLSTMCell,
+ 2, state_is_tuple=False)
+ cell_nt = cell_fn()
dec, mem = seq2seq_lib.embedding_attention_seq2seq(
enc_inp,
dec_inp,
- cell,
+ cell_nt,
num_encoder_symbols=2,
num_decoder_symbols=5,
embedding_size=2)
@@ -546,7 +584,7 @@ class Seq2SeqTest(test.TestCase):
dec, _ = seq2seq_lib.embedding_attention_seq2seq(
enc_inp,
dec_inp,
- cell,
+ cell_fn(),
num_encoder_symbols=2,
num_decoder_symbols=5,
embedding_size=2,
@@ -556,43 +594,47 @@ class Seq2SeqTest(test.TestCase):
self.assertEqual(3, len(res))
self.assertEqual((2, 2), res[0].shape)
- # Test that previous-feeding model ignores inputs after the first.
- dec_inp2 = [
- constant_op.constant(
- 0, dtypes.int32, shape=[2]) for _ in range(3)
- ]
- with variable_scope.variable_scope("other"):
- d3, _ = seq2seq_lib.embedding_attention_seq2seq(
- enc_inp,
- dec_inp2,
- cell,
- num_encoder_symbols=2,
- num_decoder_symbols=5,
- embedding_size=2,
- feed_previous=constant_op.constant(True))
- sess.run([variables.global_variables_initializer()])
- variable_scope.get_variable_scope().reuse_variables()
- d1, _ = seq2seq_lib.embedding_attention_seq2seq(
- enc_inp,
- dec_inp,
- cell,
- num_encoder_symbols=2,
- num_decoder_symbols=5,
- embedding_size=2,
- feed_previous=True)
- d2, _ = seq2seq_lib.embedding_attention_seq2seq(
- enc_inp,
- dec_inp2,
- cell,
- num_encoder_symbols=2,
- num_decoder_symbols=5,
- embedding_size=2,
- feed_previous=True)
- res1 = sess.run(d1)
- res2 = sess.run(d2)
- res3 = sess.run(d3)
- self.assertAllClose(res1, res2)
- self.assertAllClose(res1, res3)
+ # TODO(ebrevdo, lukaszkaiser): Re-enable once RNNCells allow reuse
+ # within a variable scope that already has a weights tensor.
+ #
+ # # Test that previous-feeding model ignores inputs after the first.
+ # dec_inp2 = [
+ # constant_op.constant(
+ # 0, dtypes.int32, shape=[2]) for _ in range(3)
+ # ]
+ # with variable_scope.variable_scope("other"):
+ # d3, _ = seq2seq_lib.embedding_attention_seq2seq(
+ # enc_inp,
+ # dec_inp2,
+ # cell_fn(),
+ # num_encoder_symbols=2,
+ # num_decoder_symbols=5,
+ # embedding_size=2,
+ # feed_previous=constant_op.constant(True))
+ # sess.run([variables.global_variables_initializer()])
+ # variable_scope.get_variable_scope().reuse_variables()
+ # cell = cell_fn()
+ # d1, _ = seq2seq_lib.embedding_attention_seq2seq(
+ # enc_inp,
+ # dec_inp,
+ # cell,
+ # num_encoder_symbols=2,
+ # num_decoder_symbols=5,
+ # embedding_size=2,
+ # feed_previous=True)
+ # d2, _ = seq2seq_lib.embedding_attention_seq2seq(
+ # enc_inp,
+ # dec_inp2,
+ # cell,
+ # num_encoder_symbols=2,
+ # num_decoder_symbols=5,
+ # embedding_size=2,
+ # feed_previous=True)
+ # res1 = sess.run(d1)
+ # res2 = sess.run(d2)
+ # res3 = sess.run(d3)
+ # self.assertAllClose(res1, res2)
+ # self.assertAllClose(res1, res3)
def testOne2ManyRNNSeq2Seq(self):
with self.test_session() as sess:
@@ -734,61 +776,64 @@ class Seq2SeqTest(test.TestCase):
res = sess.run(loss_per_sequence)
self.assertAllClose(np.asarray([4.828314, 4.828314]), res)
- def testModelWithBucketsScopeAndLoss(self):
- """Test that variable scope reuse is not reset after model_with_buckets."""
- classes = 10
- buckets = [(4, 4), (8, 8)]
-
- with self.test_session():
- # Here comes a sample Seq2Seq model using GRU cells.
- def SampleGRUSeq2Seq(enc_inp, dec_inp, weights, per_example_loss):
- """Example sequence-to-sequence model that uses GRU cells."""
-
- def GRUSeq2Seq(enc_inp, dec_inp):
- cell = core_rnn_cell_impl.MultiRNNCell(
- [core_rnn_cell_impl.GRUCell(24) for _ in range(2)],
- state_is_tuple=True)
- return seq2seq_lib.embedding_attention_seq2seq(
- enc_inp,
- dec_inp,
- cell,
- num_encoder_symbols=classes,
- num_decoder_symbols=classes,
- embedding_size=24)
-
- targets = [dec_inp[i + 1] for i in range(len(dec_inp) - 1)] + [0]
- return seq2seq_lib.model_with_buckets(
- enc_inp,
- dec_inp,
- targets,
- weights,
- buckets,
- GRUSeq2Seq,
- per_example_loss=per_example_loss)
-
- # Now we construct the copy model.
- inp = [
- array_ops.placeholder(
- dtypes.int32, shape=[None]) for _ in range(8)
- ]
- out = [
- array_ops.placeholder(
- dtypes.int32, shape=[None]) for _ in range(8)
- ]
- weights = [
- array_ops.ones_like(
- inp[0], dtype=dtypes.float32) for _ in range(8)
- ]
- with variable_scope.variable_scope("root"):
- _, losses1 = SampleGRUSeq2Seq(inp, out, weights, per_example_loss=False)
- # Now check that we did not accidentally set reuse.
- self.assertEqual(False, variable_scope.get_variable_scope().reuse)
- # Construct one more model with per-example loss.
- variable_scope.get_variable_scope().reuse_variables()
- _, losses2 = SampleGRUSeq2Seq(inp, out, weights, per_example_loss=True)
- # First loss is scalar, the second one is a 1-dimensinal tensor.
- self.assertEqual([], losses1[0].get_shape().as_list())
- self.assertEqual([None], losses2[0].get_shape().as_list())
+ # TODO(ebrevdo, lukaszkaiser): Re-enable once RNNCells allow reuse
+ # within a variable scope that already has a weights tensor.
+ #
+ # def testModelWithBucketsScopeAndLoss(self):
+ # """Test variable scope reuse is not reset after model_with_buckets."""
+ # classes = 10
+ # buckets = [(4, 4), (8, 8)]
+
+ # with self.test_session():
+ # # Here comes a sample Seq2Seq model using GRU cells.
+ # def SampleGRUSeq2Seq(enc_inp, dec_inp, weights, per_example_loss):
+ # """Example sequence-to-sequence model that uses GRU cells."""
+
+ # def GRUSeq2Seq(enc_inp, dec_inp):
+ # cell = core_rnn_cell_impl.MultiRNNCell(
+ # [core_rnn_cell_impl.GRUCell(24) for _ in range(2)])
+ # return seq2seq_lib.embedding_attention_seq2seq(
+ # enc_inp,
+ # dec_inp,
+ # cell,
+ # num_encoder_symbols=classes,
+ # num_decoder_symbols=classes,
+ # embedding_size=24)
+
+ # targets = [dec_inp[i + 1] for i in range(len(dec_inp) - 1)] + [0]
+ # return seq2seq_lib.model_with_buckets(
+ # enc_inp,
+ # dec_inp,
+ # targets,
+ # weights,
+ # buckets,
+ # GRUSeq2Seq,
+ # per_example_loss=per_example_loss)
+
+ # # Now we construct the copy model.
+ # inp = [
+ # array_ops.placeholder(
+ # dtypes.int32, shape=[None]) for _ in range(8)
+ # ]
+ # out = [
+ # array_ops.placeholder(
+ # dtypes.int32, shape=[None]) for _ in range(8)
+ # ]
+ # weights = [
+ # array_ops.ones_like(
+ # inp[0], dtype=dtypes.float32) for _ in range(8)
+ # ]
+ # with variable_scope.variable_scope("root"):
+ # _, losses1 = SampleGRUSeq2Seq(
+ # inp, out, weights, per_example_loss=False)
+ # # Now check that we did not accidentally set reuse.
+ # self.assertEqual(False, variable_scope.get_variable_scope().reuse)
+ # with variable_scope.variable_scope("new"):
+ # _, losses2 = SampleGRUSeq2Seq
+ # inp, out, weights, per_example_loss=True)
+ # # First loss is scalar, the second one is a 1-dimensinal tensor.
+ # self.assertEqual([], losses1[0].get_shape().as_list())
+ # self.assertEqual([None], losses2[0].get_shape().as_list())
def testModelWithBuckets(self):
"""Larger tests that does full sequence-to-sequence model training."""