diff options
author | Eugene Brevdo <ebrevdo@google.com> | 2017-02-10 11:45:07 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-02-10 12:08:37 -0800 |
commit | a83290927bc2c7d64560d442d8aaec2d1cf155ab (patch) | |
tree | 6bbd916fd950ee464e39bfbaef31669b92f482f3 /tensorflow/contrib/legacy_seq2seq | |
parent | b4e6d2959810dc2be17e578bc68449c253407b1c (diff) |
More seq2seq tweaks and test updates as RNNCells become more like layers.
Change: 147180372
Diffstat (limited to 'tensorflow/contrib/legacy_seq2seq')
-rw-r--r-- | tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py | 116 | ||||
-rw-r--r-- | tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py | 36 |
2 files changed, 89 insertions, 63 deletions
diff --git a/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py b/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py index 0f9f0a955c..5d77593619 100644 --- a/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py +++ b/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py @@ -211,24 +211,25 @@ class Seq2SeqTest(test.TestCase): num_decoder_symbols=5, embedding_size=2, feed_previous=constant_op.constant(True)) + with variable_scope.variable_scope("other_2"): + d1, _ = seq2seq_lib.embedding_rnn_seq2seq( + enc_inp, + dec_inp, + cell_fn(), + num_encoder_symbols=2, + num_decoder_symbols=5, + embedding_size=2, + feed_previous=True) + with variable_scope.variable_scope("other_3"): + d2, _ = seq2seq_lib.embedding_rnn_seq2seq( + enc_inp, + dec_inp2, + cell_fn(), + num_encoder_symbols=2, + num_decoder_symbols=5, + embedding_size=2, + feed_previous=True) sess.run([variables.global_variables_initializer()]) - variable_scope.get_variable_scope().reuse_variables() - d1, _ = seq2seq_lib.embedding_rnn_seq2seq( - enc_inp, - dec_inp, - cell_fn(), - num_encoder_symbols=2, - num_decoder_symbols=5, - embedding_size=2, - feed_previous=True) - d2, _ = seq2seq_lib.embedding_rnn_seq2seq( - enc_inp, - dec_inp2, - cell_fn(), - num_encoder_symbols=2, - num_decoder_symbols=5, - embedding_size=2, - feed_previous=True) res1 = sess.run(d1) res2 = sess.run(d2) res3 = sess.run(d3) @@ -302,22 +303,23 @@ class Seq2SeqTest(test.TestCase): num_symbols=5, embedding_size=2, feed_previous=constant_op.constant(True)) + with variable_scope.variable_scope("other_2"): + d1, _ = seq2seq_lib.embedding_tied_rnn_seq2seq( + enc_inp, + dec_inp, + cell(), + num_symbols=5, + embedding_size=2, + feed_previous=True) + with variable_scope.variable_scope("other_3"): + d2, _ = seq2seq_lib.embedding_tied_rnn_seq2seq( + enc_inp, + dec_inp2, + cell(), + num_symbols=5, + embedding_size=2, + feed_previous=True) sess.run([variables.global_variables_initializer()]) - variable_scope.get_variable_scope().reuse_variables() - d1, _ = seq2seq_lib.embedding_tied_rnn_seq2seq( - enc_inp, - dec_inp, - cell(), - num_symbols=5, - embedding_size=2, - feed_previous=True) - d2, _ = seq2seq_lib.embedding_tied_rnn_seq2seq( - enc_inp, - dec_inp2, - cell(), - num_symbols=5, - embedding_size=2, - feed_previous=True) res1 = sess.run(d1) res2 = sess.run(d2) res3 = sess.run(d3) @@ -654,9 +656,15 @@ class Seq2SeqTest(test.TestCase): i, dtypes.int32, shape=[2]) for i in range(4) ] dec_symbols_dict = {"0": 5, "1": 6} - cell = core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=True) + def EncCellFn(): + return core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=True) + def DecCellsFn(): + return dict( + (k, core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=True)) + for k in dec_symbols_dict) outputs_dict, state_dict = (seq2seq_lib.one2many_rnn_seq2seq( - enc_inp, dec_inp_dict, cell, 2, dec_symbols_dict, embedding_size=2)) + enc_inp, dec_inp_dict, EncCellFn(), DecCellsFn(), + 2, dec_symbols_dict, embedding_size=2)) sess.run([variables.global_variables_initializer()]) res = sess.run(outputs_dict["0"]) @@ -688,29 +696,33 @@ class Seq2SeqTest(test.TestCase): outputs_dict3, _ = seq2seq_lib.one2many_rnn_seq2seq( enc_inp, dec_inp_dict2, - cell, + EncCellFn(), + DecCellsFn(), 2, dec_symbols_dict, embedding_size=2, feed_previous=constant_op.constant(True)) + with variable_scope.variable_scope("other_2"): + outputs_dict1, _ = seq2seq_lib.one2many_rnn_seq2seq( + enc_inp, + dec_inp_dict, + EncCellFn(), + DecCellsFn(), + 2, + dec_symbols_dict, + embedding_size=2, + feed_previous=True) + with variable_scope.variable_scope("other_3"): + outputs_dict2, _ = seq2seq_lib.one2many_rnn_seq2seq( + enc_inp, + dec_inp_dict2, + EncCellFn(), + DecCellsFn(), + 2, + dec_symbols_dict, + embedding_size=2, + feed_previous=True) sess.run([variables.global_variables_initializer()]) - variable_scope.get_variable_scope().reuse_variables() - outputs_dict1, _ = seq2seq_lib.one2many_rnn_seq2seq( - enc_inp, - dec_inp_dict, - cell, - 2, - dec_symbols_dict, - embedding_size=2, - feed_previous=True) - outputs_dict2, _ = seq2seq_lib.one2many_rnn_seq2seq( - enc_inp, - dec_inp_dict2, - cell, - 2, - dec_symbols_dict, - embedding_size=2, - feed_previous=True) res1 = sess.run(outputs_dict1["0"]) res2 = sess.run(outputs_dict2["0"]) res3 = sess.run(outputs_dict3["0"]) diff --git a/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py b/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py index 0d6eac33de..8608054deb 100644 --- a/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py +++ b/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py @@ -917,7 +917,8 @@ def embedding_attention_seq2seq(encoder_inputs, def one2many_rnn_seq2seq(encoder_inputs, decoder_inputs_dict, - cell, + enc_cell, + dec_cells_dict, num_encoder_symbols, num_decoder_symbols_dict, embedding_size, @@ -936,7 +937,9 @@ def one2many_rnn_seq2seq(encoder_inputs, the corresponding decoder_inputs; each decoder_inputs is a list of 1D Tensors of shape [batch_size]; num_decoders is defined as len(decoder_inputs_dict). - cell: core_rnn_cell.RNNCell defining the cell function and size. + enc_cell: core_rnn_cell.RNNCell defining the encoder cell function and size. + dec_cells_dict: A dictionary mapping encoder name (string) to an + instance of core_rnn_cell.RNNCell. num_encoder_symbols: Integer; number of symbols on the encoder side. num_decoder_symbols_dict: A dictionary mapping decoder name (string) to an integer specifying number of symbols for the corresponding decoder; @@ -960,37 +963,48 @@ def one2many_rnn_seq2seq(encoder_inputs, state_dict: A mapping from decoder name (string) to the final state of the corresponding decoder RNN; it is a 2D Tensor of shape [batch_size x cell.state_size]. + + Raises: + TypeError: if enc_cell or any of the dec_cells are not instances of RNNCell. + ValueError: if len(dec_cells) != len(decoder_inputs_dict). """ outputs_dict = {} state_dict = {} + if not isinstance(enc_cell, core_rnn_cell.RNNCell): + raise TypeError("enc_cell is not an RNNCell: %s" % type(enc_cell)) + if set(dec_cells_dict) != set(decoder_inputs_dict): + raise ValueError("keys of dec_cells_dict != keys of decodre_inputs_dict") + for dec_cell in dec_cells_dict.values(): + if not isinstance(dec_cell, core_rnn_cell.RNNCell): + raise TypeError("dec_cell is not an RNNCell: %s" % type(dec_cell)) + with variable_scope.variable_scope( scope or "one2many_rnn_seq2seq", dtype=dtype) as scope: dtype = scope.dtype # Encoder. - encoder_cell = copy.deepcopy(cell) - encoder_cell = core_rnn_cell.EmbeddingWrapper( - encoder_cell, + enc_cell = core_rnn_cell.EmbeddingWrapper( + enc_cell, embedding_classes=num_encoder_symbols, embedding_size=embedding_size) _, encoder_state = core_rnn.static_rnn( - encoder_cell, encoder_inputs, dtype=dtype) + enc_cell, encoder_inputs, dtype=dtype) # Decoder. for name, decoder_inputs in decoder_inputs_dict.items(): num_decoder_symbols = num_decoder_symbols_dict[name] + dec_cell = dec_cells_dict[name] with variable_scope.variable_scope("one2many_decoder_" + str( name)) as scope: - decoder_cell = copy.deepcopy(cell) - decoder_cell = core_rnn_cell.OutputProjectionWrapper( - decoder_cell, num_decoder_symbols) + dec_cell = core_rnn_cell.OutputProjectionWrapper( + dec_cell, num_decoder_symbols) if isinstance(feed_previous, bool): outputs, state = embedding_rnn_decoder( decoder_inputs, encoder_state, - decoder_cell, + dec_cell, num_decoder_symbols, embedding_size, feed_previous=feed_previous) @@ -1005,7 +1019,7 @@ def one2many_rnn_seq2seq(encoder_inputs, outputs, state = embedding_rnn_decoder( decoder_inputs, encoder_state, - decoder_cell, + dec_cell, num_decoder_symbols, embedding_size, feed_previous=feed_previous) |