diff options
author     2018-01-13 17:23:49 -0800
committer  2018-01-13 17:27:40 -0800
commit     9c94a3f9370f535ccaa705403c60da67dd473bea (patch)
tree       3973e80c6b15f8d304977ac715e7cf7e8fb8dbe7
parent     e6ff665dbe4888aa5fdff8f34c44405acca2ddd1 (diff)
Add checkpoint conversion for very old models that use the attention mechanism implemented in
tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py
PiperOrigin-RevId: 181867510
 tensorflow/contrib/rnn/python/tools/checkpoint_convert.py      | 19 ++++++++++++-----
 tensorflow/contrib/rnn/python/tools/checkpoint_convert_test.py |  2 +-
 2 files changed, 16 insertions(+), 5 deletions(-)
diff --git a/tensorflow/contrib/rnn/python/tools/checkpoint_convert.py b/tensorflow/contrib/rnn/python/tools/checkpoint_convert.py
index 5536a01328..460e172a6d 100644
--- a/tensorflow/contrib/rnn/python/tools/checkpoint_convert.py
+++ b/tensorflow/contrib/rnn/python/tools/checkpoint_convert.py
@@ -128,10 +128,8 @@ RNN_NAME_REPLACEMENTS = collections.OrderedDict([
      'attention_cell_wrapper/attention/bias'),
     ############################################################################
     # contrib/legacy_seq2seq/python/ops/seq2seq.py
-    ('attention_decoder/weights',
-     'attention_decoder/kernel'),
-    ('attention_decoder/biases',
-     'attention_decoder/bias'),
+    ('attention_decoder/weights', 'attention_decoder/kernel'),
+    ('attention_decoder/biases', 'attention_decoder/bias'),
     ('attention_decoder/Attention_0/weights',
      'attention_decoder/Attention_0/kernel'),
     ('attention_decoder/Attention_0/biases',
@@ -140,6 +138,19 @@ RNN_NAME_REPLACEMENTS = collections.OrderedDict([
      'attention_decoder/AttnOutputProjection/kernel'),
     ('attention_decoder/AttnOutputProjection/biases',
      'attention_decoder/AttnOutputProjection/bias'),
+    # contrib/legacy_seq2seq/python/ops/seq2seq.py before cl/140060366
+    ('attention_decoder/Attention_0/Linear/Bias',
+     'attention_decoder/Attention_0/bias'),
+    ('attention_decoder/Attention_0/Linear/Matrix',
+     'attention_decoder/Attention_0/kernel'),
+    ('attention_decoder/AttnOutputProjection/Linear/Bias',
+     'attention_decoder/AttnOutputProjection/bias'),
+    ('attention_decoder/AttnOutputProjection/Linear/Matrix',
+     'attention_decoder/AttnOutputProjection/kernel'),
+    ('attention_decoder/LSTMCell/B', 'attention_decoder/lstm_cell/bias'),
+    ('attention_decoder/LSTMCell/W_0', 'attention_decoder/lstm_cell/kernel'),
+    ('attention_decoder/Linear/Bias', 'attention_decoder/bias'),
+    ('attention_decoder/Linear/Matrix', 'attention_decoder/kernel')
 ])

 _RNN_SHARDED_NAME_REPLACEMENTS = collections.OrderedDict([
diff --git a/tensorflow/contrib/rnn/python/tools/checkpoint_convert_test.py b/tensorflow/contrib/rnn/python/tools/checkpoint_convert_test.py
index a9e7949463..b4785ee395 100644
--- a/tensorflow/contrib/rnn/python/tools/checkpoint_convert_test.py
+++ b/tensorflow/contrib/rnn/python/tools/checkpoint_convert_test.py
@@ -67,7 +67,7 @@ class CheckpointConvertTest(test.TestCase):
         self._old_ckpt_path, self._new_ckpt_path)
     self.assertTrue(glob.glob(self._new_ckpt_path + "*"))
     self.assertItemsEqual(
-        ["a"] + list(checkpoint_convert.RNN_NAME_REPLACEMENTS.values()),
+        set(checkpoint_convert.RNN_NAME_REPLACEMENTS.values()).union(["a"]),
        new_var_map.keys())
     self.assertEqual(checkpoint_convert.RNN_NAME_REPLACEMENTS, conversion_map)