{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "AOpGoE2T-YXS"
      },
      "source": [
        "##### Copyright 2018 The TensorFlow Authors.\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\").\n",
        "\n",
        "# Neural Machine Translation with Attention\n",
        "\n",
        "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\u003ctd\u003e\n",
        "\u003ca target=\"_blank\"  href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb\"\u003e\n",
        "    \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e  \n",
        "\u003c/td\u003e\u003ctd\u003e\n",
        "\u003ca target=\"_blank\"  href=\"https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "CiwtNgENbx2g"
      },
      "source": [
        "This notebook trains a sequence to sequence (seq2seq) model for Spanish to English translation using [tf.keras](https://www.tensorflow.org/programmers_guide/keras) and [eager execution](https://www.tensorflow.org/programmers_guide/eager). This is an advanced example that assumes some knowledge of sequence to sequence models.\n",
        "\n",
        "After training the model in this notebook, you will be able to input a Spanish sentence, such as *\"¿todavia estan en casa?\"*, and return the English translation: *\"are you still at home?\"*\n",
        "\n",
        "The translation quality is reasonable for a toy example, but the generated attention plot is perhaps more interesting. This shows which parts of the input sentence has the model's attention while translating:\n",
        "\n",
        "\u003cimg src=\"https://tensorflow.org/images/spanish-english.png\" alt=\"spanish-english attention plot\"\u003e\n",
        "\n",
        "Note: This example takes approximately 10 mintues to run on a single P100 GPU."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "tnxXKDjq3jEL"
      },
      "outputs": [],
      "source": [
        "from __future__ import absolute_import, division, print_function\n",
        "\n",
        "# Import TensorFlow \u003e= 1.10 and enable eager execution\n",
        "import tensorflow as tf\n",
        "\n",
        "tf.enable_eager_execution()\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "from sklearn.model_selection import train_test_split\n",
        "\n",
        "import unicodedata\n",
        "import re\n",
        "import numpy as np\n",
        "import os\n",
        "import time\n",
        "\n",
        "print(tf.__version__)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "wfodePkj3jEa"
      },
      "source": [
        "## Download and prepare the dataset\n",
        "\n",
        "We'll use a language dataset provided by http://www.manythings.org/anki/. This dataset contains language translation pairs in the format:\n",
        "\n",
        "```\n",
        "May I borrow this book?\t¿Puedo tomar prestado este libro?\n",
        "```\n",
        "\n",
        "There are a variety of languages available, but we'll use the English-Spanish dataset. For convenience, we've hosted a copy of this dataset on Google Cloud, but you can also download your own copy. After downloading the dataset, here are the steps we'll take to prepare the data:\n",
        "\n",
        "1. Add a *start* and *end* token to each sentence.\n",
        "2. Clean the sentences by removing special characters.\n",
        "3. Create a word index and reverse word index (dictionaries mapping from word → id and id → word).\n",
        "4. Pad each sentence to a maximum length."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "kRVATYOgJs1b"
      },
      "outputs": [],
      "source": [
        "# Download the file\n",
        "path_to_zip = tf.keras.utils.get_file(\n",
        "    'spa-eng.zip', origin='http://download.tensorflow.org/data/spa-eng.zip', \n",
        "    extract=True)\n",
        "\n",
        "path_to_file = os.path.dirname(path_to_zip)+\"/spa-eng/spa.txt\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "rd0jw-eC3jEh"
      },
      "outputs": [],
      "source": [
        "# Converts the unicode file to ascii\n",
        "def unicode_to_ascii(s):\n",
        "    return ''.join(c for c in unicodedata.normalize('NFD', s)\n",
        "        if unicodedata.category(c) != 'Mn')\n",
        "\n",
        "\n",
        "def preprocess_sentence(w):\n",
        "    w = unicode_to_ascii(w.lower().strip())\n",
        "    \n",
        "    # creating a space between a word and the punctuation following it\n",
        "    # eg: \"he is a boy.\" =\u003e \"he is a boy .\" \n",
        "    # Reference:- https://stackoverflow.com/questions/3645931/python-padding-punctuation-with-white-spaces-keeping-punctuation\n",
        "    w = re.sub(r\"([?.!,¿])\", r\" \\1 \", w)\n",
        "    w = re.sub(r'[\" \"]+', \" \", w)\n",
        "    \n",
        "    # replacing everything with space except (a-z, A-Z, \".\", \"?\", \"!\", \",\")\n",
        "    w = re.sub(r\"[^a-zA-Z?.!,¿]+\", \" \", w)\n",
        "    \n",
        "    w = w.rstrip().strip()\n",
        "    \n",
        "    # adding a start and an end token to the sentence\n",
        "    # so that the model know when to start and stop predicting.\n",
        "    w = '\u003cstart\u003e ' + w + ' \u003cend\u003e'\n",
        "    return w"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "OHn4Dct23jEm"
      },
      "outputs": [],
      "source": [
        "# 1. Remove the accents\n",
        "# 2. Clean the sentences\n",
        "# 3. Return word pairs in the format: [ENGLISH, SPANISH]\n",
        "def create_dataset(path, num_examples):\n",
        "    lines = open(path, encoding='UTF-8').read().strip().split('\\n')\n",
        "    \n",
        "    word_pairs = [[preprocess_sentence(w) for w in l.split('\\t')]  for l in lines[:num_examples]]\n",
        "    \n",
        "    return word_pairs"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "9xbqO7Iie9bb"
      },
      "outputs": [],
      "source": [
        "# This class creates a word -\u003e index mapping (e.g,. \"dad\" -\u003e 5) and vice-versa \n",
        "# (e.g., 5 -\u003e \"dad\") for each language,\n",
        "class LanguageIndex():\n",
        "  def __init__(self, lang):\n",
        "    self.lang = lang\n",
        "    self.word2idx = {}\n",
        "    self.idx2word = {}\n",
        "    self.vocab = set()\n",
        "    \n",
        "    self.create_index()\n",
        "    \n",
        "  def create_index(self):\n",
        "    for phrase in self.lang:\n",
        "      self.vocab.update(phrase.split(' '))\n",
        "    \n",
        "    self.vocab = sorted(self.vocab)\n",
        "    \n",
        "    self.word2idx['\u003cpad\u003e'] = 0\n",
        "    for index, word in enumerate(self.vocab):\n",
        "      self.word2idx[word] = index + 1\n",
        "    \n",
        "    for word, index in self.word2idx.items():\n",
        "      self.idx2word[index] = word"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "eAY9k49G3jE_"
      },
      "outputs": [],
      "source": [
        "def max_length(tensor):\n",
        "    return max(len(t) for t in tensor)\n",
        "\n",
        "\n",
        "def load_dataset(path, num_examples):\n",
        "    # creating cleaned input, output pairs\n",
        "    pairs = create_dataset(path, num_examples)\n",
        "\n",
        "    # index language using the class defined above    \n",
        "    inp_lang = LanguageIndex(sp for en, sp in pairs)\n",
        "    targ_lang = LanguageIndex(en for en, sp in pairs)\n",
        "    \n",
        "    # Vectorize the input and target languages\n",
        "    \n",
        "    # Spanish sentences\n",
        "    input_tensor = [[inp_lang.word2idx[s] for s in sp.split(' ')] for en, sp in pairs]\n",
        "    \n",
        "    # English sentences\n",
        "    target_tensor = [[targ_lang.word2idx[s] for s in en.split(' ')] for en, sp in pairs]\n",
        "    \n",
        "    # Calculate max_length of input and output tensor\n",
        "    # Here, we'll set those to the longest sentence in the dataset\n",
        "    max_length_inp, max_length_tar = max_length(input_tensor), max_length(target_tensor)\n",
        "    \n",
        "    # Padding the input and output tensor to the maximum length\n",
        "    input_tensor = tf.keras.preprocessing.sequence.pad_sequences(input_tensor, \n",
        "                                                                 maxlen=max_length_inp,\n",
        "                                                                 padding='post')\n",
        "    \n",
        "    target_tensor = tf.keras.preprocessing.sequence.pad_sequences(target_tensor, \n",
        "                                                                  maxlen=max_length_tar, \n",
        "                                                                  padding='post')\n",
        "    \n",
        "    return input_tensor, target_tensor, inp_lang, targ_lang, max_length_inp, max_length_tar"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "GOi42V79Ydlr"
      },
      "source": [
        "### Limit the size of the dataset to experiment faster (optional)\n",
        "\n",
        "Training on the complete dataset of \u003e100,000 sentences will take a long time. To train faster, we can limit the size of the dataset to 30,000 sentences (of course, translation quality degrades with less data):"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "cnxC7q-j3jFD"
      },
      "outputs": [],
      "source": [
        "# Try experimenting with the size of that dataset\n",
        "num_examples = 30000\n",
        "input_tensor, target_tensor, inp_lang, targ_lang, max_length_inp, max_length_targ = load_dataset(path_to_file, num_examples)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "4QILQkOs3jFG"
      },
      "outputs": [],
      "source": [
        "# Creating training and validation sets using an 80-20 split\n",
        "input_tensor_train, input_tensor_val, target_tensor_train, target_tensor_val = train_test_split(input_tensor, target_tensor, test_size=0.2)\n",
        "\n",
        "# Show length\n",
        "len(input_tensor_train), len(target_tensor_train), len(input_tensor_val), len(target_tensor_val)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "rgCLkfv5uO3d"
      },
      "source": [
        "### Create a tf.data dataset"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "TqHsArVZ3jFS"
      },
      "outputs": [],
      "source": [
        "BUFFER_SIZE = len(input_tensor_train)\n",
        "BATCH_SIZE = 64\n",
        "N_BATCH = BUFFER_SIZE//BATCH_SIZE\n",
        "embedding_dim = 256\n",
        "units = 1024\n",
        "vocab_inp_size = len(inp_lang.word2idx)\n",
        "vocab_tar_size = len(targ_lang.word2idx)\n",
        "\n",
        "dataset = tf.data.Dataset.from_tensor_slices((input_tensor_train, target_tensor_train)).shuffle(BUFFER_SIZE)\n",
        "dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "TNfHIF71ulLu"
      },
      "source": [
        "## Write the encoder and decoder model\n",
        "\n",
        "Here, we'll implement an encoder-decoder model with attention which you can read about in the TensorFlow [Neural Machine Translation (seq2seq) tutorial](https://www.tensorflow.org/tutorials/seq2seq). This example uses a more recent set of APIs. This notebook implements the [attention equations](https://www.tensorflow.org/tutorials/seq2seq#background_on_the_attention_mechanism) from the seq2seq tutorial. The following diagram shows that each input words is assigned a weight by the attention mechanism which is then used by the decoder to predict the next word in the sentence.\n",
        "\n",
        "\u003cimg src=\"https://www.tensorflow.org/images/seq2seq/attention_mechanism.jpg\" width=\"500\" alt=\"attention mechanism\"\u003e\n",
        "\n",
        "The input is put through an encoder model which gives us the encoder output of shape *(batch_size, max_length, hidden_size)* and the encoder hidden state of shape *(batch_size, hidden_size)*. \n",
        "\n",
        "Here are the equations that are implemented:\n",
        "\n",
        "\u003cimg src=\"https://www.tensorflow.org/images/seq2seq/attention_equation_0.jpg\" alt=\"attention equation 0\" width=\"800\"\u003e\n",
        "\u003cimg src=\"https://www.tensorflow.org/images/seq2seq/attention_equation_1.jpg\" alt=\"attention equation 1\" width=\"800\"\u003e\n",
        "\n",
        "We're using *Bahdanau attention*. Lets decide on notation before writing the simplified form:\n",
        "\n",
        "* FC = Fully connected (dense) layer\n",
        "* EO = Encoder output\n",
        "* H = hidden state\n",
        "* X = input to the decoder\n",
        "\n",
        "And the pseudo-code:\n",
        "\n",
        "* `score = FC(tanh(FC(EO) + FC(H)))`\n",
        "* `attention weights = softmax(score, axis = 1)`. Softmax by default is applied on the last axis but here we want to apply it on the *1st axis*, since the shape of score is *(batch_size, max_length, hidden_size)*. `Max_length` is the length of our input. Since we are trying to assign a weight to each input, softmax should be applied on that axis.\n",
        "* `context vector = sum(attention weights * EO, axis = 1)`. Same reason as above for choosing axis as 1.\n",
        "* `embedding output` = The input to the decoder X is passed through an embedding layer.\n",
        "* `merged vector = concat(embedding output, context vector)`\n",
        "* This merged vector is then given to the GRU\n",
        "  \n",
        "The shapes of all the vectors at each step have been specified in the comments in the code:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "avyJ_4VIUoHb"
      },
      "outputs": [],
      "source": [
        "def gru(units):\n",
        "  # If you have a GPU, we recommend using CuDNNGRU(provides a 3x speedup than GRU)\n",
        "  # the code automatically does that.\n",
        "  if tf.test.is_gpu_available():\n",
        "    return tf.keras.layers.CuDNNGRU(units, \n",
        "                                    return_sequences=True, \n",
        "                                    return_state=True, \n",
        "                                    recurrent_initializer='glorot_uniform')\n",
        "  else:\n",
        "    return tf.keras.layers.GRU(units, \n",
        "                               return_sequences=True, \n",
        "                               return_state=True, \n",
        "                               recurrent_activation='sigmoid', \n",
        "                               recurrent_initializer='glorot_uniform')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "nZ2rI24i3jFg"
      },
      "outputs": [],
      "source": [
        "class Encoder(tf.keras.Model):\n",
        "    def __init__(self, vocab_size, embedding_dim, enc_units, batch_sz):\n",
        "        super(Encoder, self).__init__()\n",
        "        self.batch_sz = batch_sz\n",
        "        self.enc_units = enc_units\n",
        "        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)\n",
        "        self.gru = gru(self.enc_units)\n",
        "        \n",
        "    def call(self, x, hidden):\n",
        "        x = self.embedding(x)\n",
        "        output, state = self.gru(x, initial_state = hidden)        \n",
        "        return output, state\n",
        "    \n",
        "    def initialize_hidden_state(self):\n",
        "        return tf.zeros((self.batch_sz, self.enc_units))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "yJ_B3mhW3jFk"
      },
      "outputs": [],
      "source": [
        "class Decoder(tf.keras.Model):\n",
        "    def __init__(self, vocab_size, embedding_dim, dec_units, batch_sz):\n",
        "        super(Decoder, self).__init__()\n",
        "        self.batch_sz = batch_sz\n",
        "        self.dec_units = dec_units\n",
        "        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)\n",
        "        self.gru = gru(self.dec_units)\n",
        "        self.fc = tf.keras.layers.Dense(vocab_size)\n",
        "        \n",
        "        # used for attention\n",
        "        self.W1 = tf.keras.layers.Dense(self.dec_units)\n",
        "        self.W2 = tf.keras.layers.Dense(self.dec_units)\n",
        "        self.V = tf.keras.layers.Dense(1)\n",
        "        \n",
        "    def call(self, x, hidden, enc_output):\n",
        "        # enc_output shape == (batch_size, max_length, hidden_size)\n",
        "        \n",
        "        # hidden shape == (batch_size, hidden size)\n",
        "        # hidden_with_time_axis shape == (batch_size, 1, hidden size)\n",
        "        # we are doing this to perform addition to calculate the score\n",
        "        hidden_with_time_axis = tf.expand_dims(hidden, 1)\n",
        "        \n",
        "        # score shape == (batch_size, max_length, hidden_size)\n",
        "        score = tf.nn.tanh(self.W1(enc_output) + self.W2(hidden_with_time_axis))\n",
        "        \n",
        "        # attention_weights shape == (batch_size, max_length, 1)\n",
        "        # we get 1 at the last axis because we are applying score to self.V\n",
        "        attention_weights = tf.nn.softmax(self.V(score), axis=1)\n",
        "        \n",
        "        # context_vector shape after sum == (batch_size, hidden_size)\n",
        "        context_vector = attention_weights * enc_output\n",
        "        context_vector = tf.reduce_sum(context_vector, axis=1)\n",
        "        \n",
        "        # x shape after passing through embedding == (batch_size, 1, embedding_dim)\n",
        "        x = self.embedding(x)\n",
        "        \n",
        "        # x shape after concatenation == (batch_size, 1, embedding_dim + hidden_size)\n",
        "        x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)\n",
        "        \n",
        "        # passing the concatenated vector to the GRU\n",
        "        output, state = self.gru(x)\n",
        "        \n",
        "        # output shape == (batch_size * 1, hidden_size)\n",
        "        output = tf.reshape(output, (-1, output.shape[2]))\n",
        "        \n",
        "        # output shape == (batch_size * 1, vocab)\n",
        "        x = self.fc(output)\n",
        "        \n",
        "        return x, state, attention_weights\n",
        "        \n",
        "    def initialize_hidden_state(self):\n",
        "        return tf.zeros((self.batch_sz, self.dec_units))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "P5UY8wko3jFp"
      },
      "outputs": [],
      "source": [
        "encoder = Encoder(vocab_inp_size, embedding_dim, units, BATCH_SIZE)\n",
        "decoder = Decoder(vocab_tar_size, embedding_dim, units, BATCH_SIZE)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "_ch_71VbIRfK"
      },
      "source": [
        "## Define the optimizer and the loss function"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "WmTHr5iV3jFr"
      },
      "outputs": [],
      "source": [
        "optimizer = tf.train.AdamOptimizer()\n",
        "\n",
        "\n",
        "def loss_function(real, pred):\n",
        "  mask = 1 - np.equal(real, 0)\n",
        "  loss_ = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=real, logits=pred) * mask\n",
        "  return tf.reduce_mean(loss_)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "DMVWzzsfNl4e"
      },
      "source": [
        "## Checkpoints (Object-based saving)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "Zj8bXQTgNwrF"
      },
      "outputs": [],
      "source": [
        "checkpoint_dir = './training_checkpoints'\n",
        "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")\n",
        "checkpoint = tf.train.Checkpoint(optimizer=optimizer,\n",
        "                                 encoder=encoder,\n",
        "                                 decoder=decoder)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "hpObfY22IddU"
      },
      "source": [
        "## Training\n",
        "\n",
        "1. Pass the *input* through the *encoder* which return *encoder output* and the *encoder hidden state*.\n",
        "2. The encoder output, encoder hidden state and the decoder input (which is the *start token*) is passed to the decoder.\n",
        "3. The decoder returns the *predictions* and the *decoder hidden state*.\n",
        "4. The decoder hidden state is then passed back into the model and the predictions are used to calculate the loss.\n",
        "5. Use *teacher forcing* to decide the next input to the decoder.\n",
        "6. *Teacher forcing* is the technique where the *target word* is passed as the *next input* to the decoder.\n",
        "7. The final step is to calculate the gradients and apply it to the optimizer and backpropagate."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "ddefjBMa3jF0"
      },
      "outputs": [],
      "source": [
        "EPOCHS = 10\n",
        "\n",
        "for epoch in range(EPOCHS):\n",
        "    start = time.time()\n",
        "    \n",
        "    hidden = encoder.initialize_hidden_state()\n",
        "    total_loss = 0\n",
        "    \n",
        "    for (batch, (inp, targ)) in enumerate(dataset):\n",
        "        loss = 0\n",
        "        \n",
        "        with tf.GradientTape() as tape:\n",
        "            enc_output, enc_hidden = encoder(inp, hidden)\n",
        "            \n",
        "            dec_hidden = enc_hidden\n",
        "            \n",
        "            dec_input = tf.expand_dims([targ_lang.word2idx['\u003cstart\u003e']] * BATCH_SIZE, 1)       \n",
        "            \n",
        "            # Teacher forcing - feeding the target as the next input\n",
        "            for t in range(1, targ.shape[1]):\n",
        "                # passing enc_output to the decoder\n",
        "                predictions, dec_hidden, _ = decoder(dec_input, dec_hidden, enc_output)\n",
        "                \n",
        "                loss += loss_function(targ[:, t], predictions)\n",
        "                \n",
        "                # using teacher forcing\n",
        "                dec_input = tf.expand_dims(targ[:, t], 1)\n",
        "        \n",
        "        batch_loss = (loss / int(targ.shape[1]))\n",
        "        \n",
        "        total_loss += batch_loss\n",
        "        \n",
        "        variables = encoder.variables + decoder.variables\n",
        "        \n",
        "        gradients = tape.gradient(loss, variables)\n",
        "        \n",
        "        optimizer.apply_gradients(zip(gradients, variables))\n",
        "        \n",
        "        if batch % 100 == 0:\n",
        "            print('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1,\n",
        "                                                         batch,\n",
        "                                                         batch_loss.numpy()))\n",
        "    # saving (checkpoint) the model every 2 epochs\n",
        "    if (epoch + 1) % 2 == 0:\n",
        "      checkpoint.save(file_prefix = checkpoint_prefix)\n",
        "    \n",
        "    print('Epoch {} Loss {:.4f}'.format(epoch + 1,\n",
        "                                        total_loss / N_BATCH))\n",
        "    print('Time taken for 1 epoch {} sec\\n'.format(time.time() - start))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "mU3Ce8M6I3rz"
      },
      "source": [
        "## Translate\n",
        "\n",
        "* The evaluate function is similar to the training loop, except we don't use *teacher forcing* here. The input to the decoder at each time step is its previous predictions along with the hidden state and the encoder output.\n",
        "* Stop predicting when the model predicts the *end token*.\n",
        "* And store the *attention weights for every time step*.\n",
        "\n",
        "Note: The encoder output is calculated only once for one input."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "EbQpyYs13jF_"
      },
      "outputs": [],
      "source": [
        "def evaluate(sentence, encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ):\n",
        "    attention_plot = np.zeros((max_length_targ, max_length_inp))\n",
        "    \n",
        "    sentence = preprocess_sentence(sentence)\n",
        "\n",
        "    inputs = [inp_lang.word2idx[i] for i in sentence.split(' ')]\n",
        "    inputs = tf.keras.preprocessing.sequence.pad_sequences([inputs], maxlen=max_length_inp, padding='post')\n",
        "    inputs = tf.convert_to_tensor(inputs)\n",
        "    \n",
        "    result = ''\n",
        "\n",
        "    hidden = [tf.zeros((1, units))]\n",
        "    enc_out, enc_hidden = encoder(inputs, hidden)\n",
        "\n",
        "    dec_hidden = enc_hidden\n",
        "    dec_input = tf.expand_dims([targ_lang.word2idx['\u003cstart\u003e']], 0)\n",
        "\n",
        "    for t in range(max_length_targ):\n",
        "        predictions, dec_hidden, attention_weights = decoder(dec_input, dec_hidden, enc_out)\n",
        "        \n",
        "        # storing the attention weigths to plot later on\n",
        "        attention_weights = tf.reshape(attention_weights, (-1, ))\n",
        "        attention_plot[t] = attention_weights.numpy()\n",
        "\n",
        "        predicted_id = tf.argmax(predictions[0]).numpy()\n",
        "\n",
        "        result += targ_lang.idx2word[predicted_id] + ' '\n",
        "\n",
        "        if targ_lang.idx2word[predicted_id] == '\u003cend\u003e':\n",
        "            return result, sentence, attention_plot\n",
        "        \n",
        "        # the predicted ID is fed back into the model\n",
        "        dec_input = tf.expand_dims([predicted_id], 0)\n",
        "\n",
        "    return result, sentence, attention_plot"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "s5hQWlbN3jGF"
      },
      "outputs": [],
      "source": [
        "# function for plotting the attention weights\n",
        "def plot_attention(attention, sentence, predicted_sentence):\n",
        "    fig = plt.figure(figsize=(10,10))\n",
        "    ax = fig.add_subplot(1, 1, 1)\n",
        "    ax.matshow(attention, cmap='viridis')\n",
        "    \n",
        "    fontdict = {'fontsize': 14}\n",
        "    \n",
        "    ax.set_xticklabels([''] + sentence, fontdict=fontdict, rotation=90)\n",
        "    ax.set_yticklabels([''] + predicted_sentence, fontdict=fontdict)\n",
        "\n",
        "    plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "sl9zUHzg3jGI"
      },
      "outputs": [],
      "source": [
        "def translate(sentence, encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ):\n",
        "    result, sentence, attention_plot = evaluate(sentence, encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)\n",
        "        \n",
        "    print('Input: {}'.format(sentence))\n",
        "    print('Predicted translation: {}'.format(result))\n",
        "    \n",
        "    attention_plot = attention_plot[:len(result.split(' ')), :len(sentence.split(' '))]\n",
        "    plot_attention(attention_plot, sentence.split(' '), result.split(' '))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "n250XbnjOaqP"
      },
      "source": [
        "## Restore the latest checkpoint and test"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "UJpT9D5_OgP6"
      },
      "outputs": [],
      "source": [
        "# restoring the latest checkpoint in checkpoint_dir\n",
        "checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "WrAM0FDomq3E"
      },
      "outputs": [],
      "source": [
        "translate('hace mucho frio aqui.', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "zSx2iM36EZQZ"
      },
      "outputs": [],
      "source": [
        "translate('esta es mi vida.', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "A3LLCx3ZE0Ls"
      },
      "outputs": [],
      "source": [
        "translate('¿todavia estan en casa?', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "DUQVLVqUE1YW"
      },
      "outputs": [],
      "source": [
        "# wrong translation\n",
        "translate('trata de averiguarlo.', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "RTe5P5ioMJwN"
      },
      "source": [
        "## Next steps\n",
        "\n",
        "* [Download a different dataset](http://www.manythings.org/anki/) to experiment with translations, for example, English to German, or English to French.\n",
        "* Experiment with training on a larger dataset, or using more epochs\n"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "nmt_with_attention.ipynb",
      "private_outputs": true,
      "provenance": [
        {
          "file_id": "1C4fpM7_7IL8ZzF7Gc5abywqQjeQNS2-U",
          "timestamp": 1527858391290
        },
        {
          "file_id": "1pExo6aUuw0S6MISFWoinfJv0Ftm9V4qv",
          "timestamp": 1527776041613
        }
      ],
      "toc_visible": true,
      "version": "0.3.2"
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}