author     Alexander Gorban <gorban@google.com>             2018-01-13 17:23:49 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-01-13 17:27:40 -0800
commit     9c94a3f9370f535ccaa705403c60da67dd473bea
tree       3973e80c6b15f8d304977ac715e7cf7e8fb8dbe7
parent     e6ff665dbe4888aa5fdff8f34c44405acca2ddd1
Add checkpoint conversion for very old models that use the attention mechanism
implemented in tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py.

PiperOrigin-RevId: 181867510
-rw-r--r--  tensorflow/contrib/rnn/python/tools/checkpoint_convert.py       19
-rw-r--r--  tensorflow/contrib/rnn/python/tools/checkpoint_convert_test.py   2
2 files changed, 16 insertions, 5 deletions
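
For orientation, here is a minimal sketch of running the converter on a checkpoint written by the pre-cl/140060366 attention decoder. The paths are hypothetical; the convert_names call and its return values follow the usage in checkpoint_convert_test.py below.

    # Minimal sketch (hypothetical paths) of converting an old legacy_seq2seq
    # attention checkpoint with the tool modified in this change.
    from tensorflow.contrib.rnn.python.tools import checkpoint_convert

    # convert_names rewrites old variable names (e.g. 'attention_decoder/Linear/Matrix')
    # to their current equivalents (e.g. 'attention_decoder/kernel'), writes the
    # converted checkpoint, and returns the new variable map plus the mapping used.
    new_var_map, conversion_map = checkpoint_convert.convert_names(
        '/tmp/old_attention_model.ckpt',  # hypothetical path to the old checkpoint
        '/tmp/new_attention_model.ckpt')  # hypothetical path for the converted checkpoint
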
diff --git a/tensorflow/contrib/rnn/python/tools/checkpoint_convert.py b/tensorflow/contrib/rnn/python/tools/checkpoint_convert.py
index 5536a01328..460e172a6d 100644
--- a/tensorflow/contrib/rnn/python/tools/checkpoint_convert.py
+++ b/tensorflow/contrib/rnn/python/tools/checkpoint_convert.py
@@ -128,10 +128,8 @@ RNN_NAME_REPLACEMENTS = collections.OrderedDict([
'attention_cell_wrapper/attention/bias'),
############################################################################
# contrib/legacy_seq2seq/python/ops/seq2seq.py
- ('attention_decoder/weights',
- 'attention_decoder/kernel'),
- ('attention_decoder/biases',
- 'attention_decoder/bias'),
+ ('attention_decoder/weights', 'attention_decoder/kernel'),
+ ('attention_decoder/biases', 'attention_decoder/bias'),
('attention_decoder/Attention_0/weights',
'attention_decoder/Attention_0/kernel'),
('attention_decoder/Attention_0/biases',
@@ -140,6 +138,19 @@ RNN_NAME_REPLACEMENTS = collections.OrderedDict([
'attention_decoder/AttnOutputProjection/kernel'),
('attention_decoder/AttnOutputProjection/biases',
'attention_decoder/AttnOutputProjection/bias'),
+ # contrib/legacy_seq2seq/python/ops/seq2seq.py before cl/140060366
+ ('attention_decoder/Attention_0/Linear/Bias',
+ 'attention_decoder/Attention_0/bias'),
+ ('attention_decoder/Attention_0/Linear/Matrix',
+ 'attention_decoder/Attention_0/kernel'),
+ ('attention_decoder/AttnOutputProjection/Linear/Bias',
+ 'attention_decoder/AttnOutputProjection/bias'),
+ ('attention_decoder/AttnOutputProjection/Linear/Matrix',
+ 'attention_decoder/AttnOutputProjection/kernel'),
+ ('attention_decoder/LSTMCell/B', 'attention_decoder/lstm_cell/bias'),
+ ('attention_decoder/LSTMCell/W_0', 'attention_decoder/lstm_cell/kernel'),
+ ('attention_decoder/Linear/Bias', 'attention_decoder/bias'),
+ ('attention_decoder/Linear/Matrix', 'attention_decoder/kernel')
])
_RNN_SHARDED_NAME_REPLACEMENTS = collections.OrderedDict([
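
As an aside, a one-off lookup illustrates the renaming direction of the entries added above; RNN_NAME_REPLACEMENTS is the real dict from checkpoint_convert.py, and the example name is one of the pre-cl/140060366 names introduced in this hunk.

    from tensorflow.contrib.rnn.python.tools import checkpoint_convert

    # An old LSTMCell variable name maps to the modern lstm_cell name; names not
    # present in the mapping are simply kept as-is.
    old_name = 'attention_decoder/LSTMCell/W_0'
    new_name = checkpoint_convert.RNN_NAME_REPLACEMENTS.get(old_name, old_name)
    print(new_name)  # attention_decoder/lstm_cell/kernel
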
diff --git a/tensorflow/contrib/rnn/python/tools/checkpoint_convert_test.py b/tensorflow/contrib/rnn/python/tools/checkpoint_convert_test.py
index a9e7949463..b4785ee395 100644
--- a/tensorflow/contrib/rnn/python/tools/checkpoint_convert_test.py
+++ b/tensorflow/contrib/rnn/python/tools/checkpoint_convert_test.py
@@ -67,7 +67,7 @@ class CheckpointConvertTest(test.TestCase):
self._old_ckpt_path, self._new_ckpt_path)
self.assertTrue(glob.glob(self._new_ckpt_path + "*"))
self.assertItemsEqual(
- ["a"] + list(checkpoint_convert.RNN_NAME_REPLACEMENTS.values()),
+ set(checkpoint_convert.RNN_NAME_REPLACEMENTS.values()).union(["a"]),
new_var_map.keys())
self.assertEqual(checkpoint_convert.RNN_NAME_REPLACEMENTS, conversion_map)
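
The test change is needed because several of the newly added old names map to values that already appear in RNN_NAME_REPLACEMENTS, so the dict's values now contain duplicates; comparing the converted checkpoint's variable names against a plain list of values would over-count them. A small illustration, using two entries taken from the diff above:

    import collections

    # Two old names now map to the same new name, so values() contains a duplicate.
    replacements = collections.OrderedDict([
        ('attention_decoder/weights', 'attention_decoder/kernel'),
        ('attention_decoder/Linear/Matrix', 'attention_decoder/kernel'),
    ])
    print(list(replacements.values()))  # ['attention_decoder/kernel', 'attention_decoder/kernel']
    print(set(replacements.values()))   # {'attention_decoder/kernel'}: one variable in the new checkpoint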