aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb
blob: 1a5a186e7a3e456cc43f8091370d3eeb795d5e0e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "image_captioning_with_attention.ipynb",
      "version": "0.3.2",
      "views": {},
      "default_view": {},
      "provenance": [
        {
          "file_id": "1HI8OK2sMjcx9CTWVn0122QAHOuXaOaMg",
          "timestamp": 1530222436922
        }
      ],
      "private_outputs": true,
      "collapsed_sections": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "metadata": {
        "id": "K2s1A9eLRPEj",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "##### Copyright 2018 The TensorFlow Authors.\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\").\n"
      ]
    },
    {
      "metadata": {
        "id": "Cffg2i257iMS",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "# Image Captioning with Attention\n",
        "\n",
        "<table class=\"tfo-notebook-buttons\" align=\"left\"><td>\n",
        "<a target=\"_blank\"  href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb\">\n",
        "    <img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>  \n",
        "</td><td>\n",
        "<a target=\"_blank\"  href=\"https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb\"><img width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a></td></table>"
      ]
    },
    {
      "metadata": {
        "id": "QASbY_HGo4Lq",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "Image captioning is the task of generating a caption for an image. Given an image like this:\n",
        "\n",
        "![Man Surfing](https://tensorflow.org/images/surf.jpg) \n",
        "\n",
        "[Image Source](https://commons.wikimedia.org/wiki/Surfing#/media/File:Surfing_in_Hawaii.jpg), License: Public Domain\n",
        "\n",
        "Our goal is generate a caption, such as \"a surfer riding on a wave\". Here, we'll use an attention based model. This enables us to see which parts of the image the model focuses on as it generates a caption.\n",
        "\n",
        "![Prediction](https://tensorflow.org/images/imcap_prediction.png)\n",
        "\n",
        "This model architecture below is similar to [Show, Attend and Tell: Neural Image Caption Generation with Visual Attention](https://arxiv.org/abs/1502.03044). \n",
        "\n",
        "The code uses [tf.keras](https://www.tensorflow.org/programmers_guide/keras) and [eager execution](https://www.tensorflow.org/programmers_guide/eager), which you can learn more about in the linked guides.\n",
        "\n",
        "This notebook is an end-to-end example. If you run it, it will download the  [MS-COCO](http://cocodataset.org/#home) dataset, preprocess and cache a subset of the images using Inception V3, train an encoder-decoder model, and use it to generate captions on new images.\n",
        "\n",
        "The code requires TensorFlow version >=1.9. If you're running this in [Colab]()\n",
        "\n",
        "In this example, we're training on a relatively small amount of data as an example. On a single P100 GPU, this example will take about ~2 hours to train. We train on the first 30,000 captions (corresponding to about ~20,000 images depending on shuffling, as there are multiple captions per image in the dataset)\n"
      ]
    },
    {
      "metadata": {
        "id": "U8l4RJ0XRPEm",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        }
      },
      "cell_type": "code",
      "source": [
        "# Import TensorFlow and enable eager execution\n",
        "# This code requires TensorFlow version >=1.9\n",
        "import tensorflow as tf\n",
        "tf.enable_eager_execution()\n",
        "\n",
        "# We'll generate plots of attention in order to see which parts of an image\n",
        "# our model focuses on during captioning\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "# Scikit-learn includes many helpful utilities\n",
        "from sklearn.model_selection import train_test_split\n",
        "from sklearn.utils import shuffle\n",
        "\n",
        "import re\n",
        "import numpy as np\n",
        "import os\n",
        "import time\n",
        "import json\n",
        "from glob import glob\n",
        "from PIL import Image\n",
        "import pickle"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "b6qbGw8MRPE5",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## Download and prepare the MS-COCO dataset\n",
        "\n",
        "We will use the [MS-COCO dataset](http://cocodataset.org/#home) to train our model. This dataset contains >82,000 images, each of which has been annotated with at least 5 different captions. The code code below will download and extract the dataset automatically.  \n",
        "\n",
        "**Caution: large download ahead**. We'll use the training set, it's a 13GB file."
      ]
    },
    {
      "metadata": {
        "id": "krQuPYTtRPE7",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        }
      },
      "cell_type": "code",
      "source": [
        "annotation_zip = tf.keras.utils.get_file('captions.zip', \n",
        "                                          cache_subdir=os.path.abspath('.'),\n",
        "                                          origin = 'http://images.cocodataset.org/annotations/annotations_trainval2014.zip',\n",
        "                                          extract = True)\n",
        "annotation_file = os.path.dirname(annotation_zip)+'/annotations/captions_train2014.json'\n",
        "\n",
        "name_of_zip = 'train2014.zip'\n",
        "if not os.path.exists(os.path.abspath('.') + '/' + name_of_zip):\n",
        "  image_zip = tf.keras.utils.get_file(name_of_zip, \n",
        "                                      cache_subdir=os.path.abspath('.'),\n",
        "                                      origin = 'http://images.cocodataset.org/zips/train2014.zip',\n",
        "                                      extract = True)\n",
        "  PATH = os.path.dirname(image_zip)+'/train2014/'\n",
        "else:\n",
        "  PATH = os.path.abspath('.')+'/train2014/'"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "aANEzb5WwSzg",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## Optionally, limit the size of the training set for faster training\n",
        "For this example, we'll select a subset of 30,000 captions and use these and the corresponding images to train our model. As always, captioning quality will improve if you choose to use more data."
      ]
    },
    {
      "metadata": {
        "id": "4G3b8x8_RPFD",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        }
      },
      "cell_type": "code",
      "source": [
        "# read the json file\n",
        "with open(annotation_file, 'r') as f:\n",
        "    annotations = json.load(f)\n",
        "\n",
        "# storing the captions and the image name in vectors\n",
        "all_captions = []\n",
        "all_img_name_vector = []\n",
        "\n",
        "for annot in annotations['annotations']:\n",
        "    caption = '<start> ' + annot['caption'] + ' <end>'\n",
        "    image_id = annot['image_id']\n",
        "    full_coco_image_path = PATH + 'COCO_train2014_' + '%012d.jpg' % (image_id)\n",
        "    \n",
        "    all_img_name_vector.append(full_coco_image_path)\n",
        "    all_captions.append(caption)\n",
        "\n",
        "# shuffling the captions and image_names together\n",
        "# setting a random state\n",
        "train_captions, img_name_vector = shuffle(all_captions,\n",
        "                                          all_img_name_vector,\n",
        "                                          random_state=1)\n",
        "\n",
        "# selecting the first 30000 captions from the shuffled set\n",
        "num_examples = 30000\n",
        "train_captions = train_captions[:num_examples]\n",
        "img_name_vector = img_name_vector[:num_examples]"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "mPBMgK34RPFL",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        }
      },
      "cell_type": "code",
      "source": [
        "len(train_captions), len(all_captions)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "8cSW4u-ORPFQ",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## Preprocess the images using InceptionV3\n",
        "Next, we will use InceptionV3 (pretrained on Imagenet) to classify each image. We will extract features from the last convolutional layer. \n",
        "\n",
        "First, we will need to convert the images into the format inceptionV3 expects by:\n",
        "* Resizing the image to (299, 299)\n",
        "* Using the [preprocess_input](https://www.tensorflow.org/api_docs/python/tf/keras/applications/inception_v3/preprocess_input) method to place the pixels in the range of -1 to 1 (to match the format of the images used to train InceptionV3)."
      ]
    },
    {
      "metadata": {
        "id": "zXR0217aRPFR",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        }
      },
      "cell_type": "code",
      "source": [
        "def load_image(image_path):\n",
        "    img = tf.read_file(image_path)\n",
        "    img = tf.image.decode_jpeg(img, channels=3)\n",
        "    img = tf.image.resize_images(img, (299, 299))\n",
        "    img = tf.keras.applications.inception_v3.preprocess_input(img)\n",
        "    return img, image_path"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "MDvIu4sXRPFV",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## Initialize InceptionV3 and load the pretrained Imagenet weights\n",
        "\n",
        "To do so, we'll create a tf.keras model where the output layer is the last convolutional layer in the InceptionV3 architecture. \n",
        "* Each image is forwarded through the network and the vector that we get at the end is stored in a dictionary (image_name --> feature_vector). \n",
        "* We use the last convolutional layer because we are using attention in this example. The shape of the output of this layer is ```8x8x2048```. \n",
        "* We avoid doing this during training so it does not become a bottleneck. \n",
        "* After all the images are passed through the network, we pickle the dictionary and save it to disk."
      ]
    },
    {
      "metadata": {
        "id": "RD3vW4SsRPFW",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        }
      },
      "cell_type": "code",
      "source": [
        "image_model = tf.keras.applications.InceptionV3(include_top=False, \n",
        "                                                weights='imagenet')\n",
        "new_input = image_model.input\n",
        "hidden_layer = image_model.layers[-1].output\n",
        "\n",
        "image_features_extract_model = tf.keras.Model(new_input, hidden_layer)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "rERqlR3WRPGO",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## Caching the features extracted from InceptionV3\n",
        "\n",
        "We will pre-process each image with InceptionV3 and cache the output to disk. Caching the output in RAM would be faster but memory intensive, requiring 8 \\* 8 \\* 2048 floats per image. At the time of writing, this would exceed the memory limitations of Colab (although these may change, an instance appears to have about 12GB of memory currently). \n",
        "\n",
        "Performance could be improved with a more sophisticated caching strategy (e.g., by sharding the images to reduce random access disk I/O) at the cost of more code.\n",
        "\n",
        "This will take about 10 minutes to run in Colab with a GPU. If you'd like to see a progress bar, you could: install [tqdm](https://github.com/tqdm/tqdm) (```!pip install tqdm```), then change this line: \n",
        "\n",
        "```for img, path in image_dataset:``` \n",
        "\n",
        "to:\n",
        "\n",
        "```for img, path in tqdm(image_dataset):```."
      ]
    },
    {
      "metadata": {
        "id": "Dx_fvbVgRPGQ",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        }
      },
      "cell_type": "code",
      "source": [
        "# getting the unique images\n",
        "encode_train = sorted(set(img_name_vector))\n",
        "\n",
        "# feel free to change the batch_size according to your system configuration\n",
        "image_dataset = tf.data.Dataset.from_tensor_slices(\n",
        "                                encode_train).map(load_image).batch(16)\n",
        "\n",
        "for img, path in image_dataset:\n",
        "  batch_features = image_features_extract_model(img)\n",
        "  batch_features = tf.reshape(batch_features, \n",
        "                              (batch_features.shape[0], -1, batch_features.shape[3]))\n",
        "\n",
        "  for bf, p in zip(batch_features, path):\n",
        "    path_of_feature = p.numpy().decode(\"utf-8\")\n",
        "    np.save(path_of_feature, bf.numpy())"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "nyqH3zFwRPFi",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## Preprocess and tokenize the captions\n",
        "\n",
        "* First, we'll tokenize the captions (e.g., by splitting on spaces). This will give us a  vocabulary of all the unique words in the data (e.g., \"surfing\", \"football\", etc).\n",
        "* Next, we'll limit the vocabulary size to the top 5,000 words to save memory. We'll replace all other words with the token \"UNK\" (for unknown).\n",
        "* Finally, we create a word --> index mapping and vice-versa.\n",
        "* We will then pad all sequences to the be same length as the longest one. "
      ]
    },
    {
      "metadata": {
        "id": "HZfK8RhQRPFj",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        }
      },
      "cell_type": "code",
      "source": [
        "# This will find the maximum length of any caption in our dataset\n",
        "def calc_max_length(tensor):\n",
        "    return max(len(t) for t in tensor)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "oJGE34aiRPFo",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        }
      },
      "cell_type": "code",
      "source": [
        "# The steps above is a general process of dealing with text processing\n",
        "\n",
        "# choosing the top 5000 words from the vocabulary\n",
        "top_k = 5000\n",
        "tokenizer = tf.keras.preprocessing.text.Tokenizer(num_words=top_k, \n",
        "                                                  oov_token=\"<unk>\", \n",
        "                                                  filters='!\"#$%&()*+.,-/:;=?@[\\]^_`{|}~ ')\n",
        "tokenizer.fit_on_texts(train_captions)\n",
        "train_seqs = tokenizer.texts_to_sequences(train_captions)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "8Q44tNQVRPFt",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        }
      },
      "cell_type": "code",
      "source": [
        "tokenizer.word_index = {key:value for key, value in tokenizer.word_index.items() if value <= top_k}\n",
        "# putting <unk> token in the word2idx dictionary\n",
        "tokenizer.word_index[tokenizer.oov_token] = top_k + 1\n",
        "tokenizer.word_index['<pad>'] = 0"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "0fpJb5ojRPFv",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        }
      },
      "cell_type": "code",
      "source": [
        "# creating the tokenized vectors\n",
        "train_seqs = tokenizer.texts_to_sequences(train_captions)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "olQArbgbRPF1",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        }
      },
      "cell_type": "code",
      "source": [
        "# creating a reverse mapping (index -> word)\n",
        "index_word = {value:key for key, value in tokenizer.word_index.items()}"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "AidglIZVRPF4",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        }
      },
      "cell_type": "code",
      "source": [
        "# padding each vector to the max_length of the captions\n",
        "# if the max_length parameter is not provided, pad_sequences calculates that automatically\n",
        "cap_vector = tf.keras.preprocessing.sequence.pad_sequences(train_seqs, padding='post')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "gL0wkttkRPGA",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        }
      },
      "cell_type": "code",
      "source": [
        "# calculating the max_length \n",
        "# used to store the attention weights\n",
        "max_length = calc_max_length(train_seqs)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "M3CD75nDpvTI",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## Split the data into training and testing"
      ]
    },
    {
      "metadata": {
        "id": "iS7DDMszRPGF",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        }
      },
      "cell_type": "code",
      "source": [
        "# Create training and validation sets using 80-20 split\n",
        "img_name_train, img_name_val, cap_train, cap_val = train_test_split(img_name_vector, \n",
        "                                                                    cap_vector, \n",
        "                                                                    test_size=0.2, \n",
        "                                                                    random_state=0)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "XmViPkRFRPGH",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        }
      },
      "cell_type": "code",
      "source": [
        "len(img_name_train), len(cap_train), len(img_name_val), len(cap_val)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "uEWM9xrYcg45",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## Our images and captions are ready! Next, let's create a tf.data dataset to use for training our model.\n",
        "\n"
      ]
    },
    {
      "metadata": {
        "id": "Q3TnZ1ToRPGV",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        }
      },
      "cell_type": "code",
      "source": [
        "# feel free to change these parameters according to your system's configuration\n",
        "\n",
        "BATCH_SIZE = 64\n",
        "BUFFER_SIZE = 1000\n",
        "embedding_dim = 256\n",
        "units = 512\n",
        "vocab_size = len(tokenizer.word_index)\n",
        "# shape of the vector extracted from InceptionV3 is (64, 2048)\n",
        "# these two variables represent that\n",
        "features_shape = 2048\n",
        "attention_features_shape = 64"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "SmZS2N0bXG3T",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        }
      },
      "cell_type": "code",
      "source": [
        "# loading the numpy files \n",
        "def map_func(img_name, cap):\n",
        "    img_tensor = np.load(img_name.decode('utf-8')+'.npy')\n",
        "    return img_tensor, cap"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "FDF_Nm3tRPGZ",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        }
      },
      "cell_type": "code",
      "source": [
        "dataset = tf.data.Dataset.from_tensor_slices((img_name_train, cap_train))\n",
        "\n",
        "# using map to load the numpy files in parallel\n",
        "# NOTE: Be sure to set num_parallel_calls to the number of CPU cores you have\n",
        "# https://www.tensorflow.org/api_docs/python/tf/py_func\n",
        "dataset = dataset.map(lambda item1, item2: tf.py_func(\n",
        "          map_func, [item1, item2], [tf.float32, tf.int32]), num_parallel_calls=8)\n",
        "\n",
        "# shuffling and batching\n",
        "dataset = dataset.shuffle(BUFFER_SIZE)\n",
        "# https://www.tensorflow.org/api_docs/python/tf/contrib/data/batch_and_drop_remainder\n",
        "dataset = dataset.batch(BATCH_SIZE)\n",
        "dataset = dataset.prefetch(1)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "nrvoDphgRPGd",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## Model\n",
        "\n",
        "Fun fact, the decoder below is identical to the one in the example for [Neural Machine Translation with Attention]( https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb).\n",
        "\n",
        "The model architecture is inspired by the [Show, Attend and Tell](https://arxiv.org/pdf/1502.03044.pdf) paper.\n",
        "\n",
        "* In this example, we extract the features from the lower convolutional layer of InceptionV3 giving us a vector of shape (8, 8, 2048). \n",
        "* We squash that to a shape of (64, 2048).\n",
        "* This vector is then passed through the CNN Encoder(which consists of a single Fully connected layer).\n",
        "* The RNN(here GRU) attends over the image to predict the next word."
      ]
    },
    {
      "metadata": {
        "id": "AAppCGLKRPGd",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        }
      },
      "cell_type": "code",
      "source": [
        "def gru(units):\n",
        "  # If you have a GPU, we recommend using the CuDNNGRU layer (it provides a \n",
        "  # significant speedup).\n",
        "  if tf.test.is_gpu_available():\n",
        "    return tf.keras.layers.CuDNNGRU(units, \n",
        "                                    return_sequences=True, \n",
        "                                    return_state=True, \n",
        "                                    recurrent_initializer='glorot_uniform')\n",
        "  else:\n",
        "    return tf.keras.layers.GRU(units, \n",
        "                               return_sequences=True, \n",
        "                               return_state=True, \n",
        "                               recurrent_activation='sigmoid', \n",
        "                               recurrent_initializer='glorot_uniform')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "ja2LFTMSdeV3",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        }
      },
      "cell_type": "code",
      "source": [
        "class BahdanauAttention(tf.keras.Model):\n",
        "  # Additive attention: scores every feature location against the decoder\n",
        "  # hidden state and returns the attention-weighted sum of the features.\n",
        "  def __init__(self, units):\n",
        "    super(BahdanauAttention, self).__init__()\n",
        "    # W1 projects the features and W2 the hidden state (both to `units`);\n",
        "    # V reduces the combined projection to a single score per location.\n",
        "    self.W1 = tf.keras.layers.Dense(units)\n",
        "    self.W2 = tf.keras.layers.Dense(units)\n",
        "    self.V = tf.keras.layers.Dense(1)\n",
        "\n",
        "  def call(self, features, hidden):\n",
        "    # features (CNN_Encoder output) is expected to be (batch_size, 64, embedding_dim)\n",
        "    # and hidden to be (batch_size, hidden_size)\n",
        "\n",
        "    # add a time axis so the hidden state broadcasts over the feature locations:\n",
        "    # (batch_size, 1, hidden_size)\n",
        "    query = tf.expand_dims(hidden, 1)\n",
        "\n",
        "    # both projections end in `units`, so score is (batch_size, 64, units)\n",
        "    projected_features = self.W1(features)\n",
        "    projected_query = self.W2(query)\n",
        "    score = tf.nn.tanh(projected_features + projected_query)\n",
        "\n",
        "    # V collapses the last axis to 1 and the softmax normalises over axis 1:\n",
        "    # attention_weights shape == (batch_size, 64, 1)\n",
        "    attention_weights = tf.nn.softmax(self.V(score), axis=1)\n",
        "\n",
        "    # weight every location and sum over axis 1; the result keeps the last\n",
        "    # axis of `features`: (batch_size, embedding_dim)\n",
        "    context_vector = tf.reduce_sum(attention_weights * features, axis=1)\n",
        "\n",
        "    return context_vector, attention_weights"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "AZ7R1RxHRPGf",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        }
      },
      "cell_type": "code",
      "source": [
        "class CNN_Encoder(tf.keras.Model):\n",
        "    # The image features are extracted and stored ahead of time, so this\n",
        "    # encoder only projects them with a fully connected layer and a ReLU.\n",
        "    def __init__(self, embedding_dim):\n",
        "        super(CNN_Encoder, self).__init__()\n",
        "        # fc maps the last axis to embedding_dim,\n",
        "        # i.e. (batch_size, 64, embedding_dim) for the features used here\n",
        "        self.fc = tf.keras.layers.Dense(embedding_dim)\n",
        "\n",
        "    def call(self, x):\n",
        "        return tf.nn.relu(self.fc(x))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "V9UbGQmERPGi",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        }
      },
      "cell_type": "code",
      "source": [
        "class RNN_Decoder(tf.keras.Model):\n",
        "  # Single-step decoder: attends over the image features, runs the context\n",
        "  # vector together with the embedded input token through a GRU and projects\n",
        "  # the GRU output to vocabulary logits.\n",
        "  def __init__(self, embedding_dim, units, vocab_size):\n",
        "    super(RNN_Decoder, self).__init__()\n",
        "    self.units = units\n",
        "\n",
        "    self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)\n",
        "    self.gru = gru(self.units)\n",
        "    self.fc1 = tf.keras.layers.Dense(self.units)\n",
        "    self.fc2 = tf.keras.layers.Dense(vocab_size)\n",
        "\n",
        "    # attention is kept as a separate model\n",
        "    self.attention = BahdanauAttention(self.units)\n",
        "\n",
        "  def call(self, x, features, hidden):\n",
        "    context_vector, attention_weights = self.attention(features, hidden)\n",
        "\n",
        "    # (batch_size, 1, embedding_dim) for an input of shape (batch_size, 1)\n",
        "    embedded = self.embedding(x)\n",
        "\n",
        "    # give the context vector a time axis and join it with the embedding\n",
        "    # along the last axis\n",
        "    context_with_time_axis = tf.expand_dims(context_vector, 1)\n",
        "    gru_input = tf.concat([context_with_time_axis, embedded], axis=-1)\n",
        "\n",
        "    # NOTE(review): `hidden` is only consumed by the attention layer; the GRU\n",
        "    # is called without an explicit initial state.\n",
        "    output, state = self.gru(gru_input)\n",
        "\n",
        "    # (batch_size, time_steps, units)\n",
        "    projected = self.fc1(output)\n",
        "\n",
        "    # merge the batch and time axes: (batch_size * time_steps, units)\n",
        "    projected = tf.reshape(projected, (-1, projected.shape[2]))\n",
        "\n",
        "    # (batch_size * time_steps, vocab_size)\n",
        "    logits = self.fc2(projected)\n",
        "\n",
        "    return logits, state, attention_weights\n",
        "\n",
        "  def reset_state(self, batch_size):\n",
        "    # all-zero hidden state of shape (batch_size, units)\n",
        "    return tf.zeros((batch_size, self.units))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "Qs_Sr03wRPGk",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        }
      },
      "cell_type": "code",
      "source": [
        "# build the two models; embedding_dim, units and vocab_size are expected to\n",
        "# be defined earlier in the notebook\n",
        "encoder = CNN_Encoder(embedding_dim=embedding_dim)\n",
        "decoder = RNN_Decoder(embedding_dim=embedding_dim,\n",
        "                      units=units,\n",
        "                      vocab_size=vocab_size)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "-bYN7xA0RPGl",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        }
      },
      "cell_type": "code",
      "source": [
        "optimizer = tf.train.AdamOptimizer()\n",
        "\n",
        "def loss_function(real, pred):\n",
        "    # Cross entropy between the target ids and the logits; positions whose\n",
        "    # target id is 0 (padding) are zeroed out before the mean is taken.\n",
        "    is_padding = np.equal(real, 0)\n",
        "    per_position_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=real,\n",
        "                                                                       logits=pred)\n",
        "    return tf.reduce_mean(per_position_loss * (1 - is_padding))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "PHod7t72RPGn",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## Training\n",
        "\n",
        "* We extract the features stored in the respective `.npy` files and then pass those features through the encoder.\n",
        "* The encoder output, the hidden state (initialized to 0), and the decoder input (which is the start token) are passed to the decoder.\n",
        "* The decoder returns the predictions and the decoder hidden state.\n",
        "* The decoder hidden state is then passed back into the model and the predictions are used to calculate the loss.\n",
        "* Use teacher forcing to decide the next input to the decoder.\n",
        "* Teacher forcing is the technique where the target word is passed as the next input to the decoder.\n",
        "* The final step is to calculate the gradients by backpropagation and apply them with the optimizer.\n"
      ]
    },
    {
      "metadata": {
        "id": "Vt4WZ5mhJE-E",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        }
      },
      "cell_type": "code",
      "source": [
        "# kept in its own cell so that re-running the training cell keeps appending\n",
        "# to the existing history instead of resetting loss_plot\n",
        "loss_plot = list()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "UlA4VIQpRPGo",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        }
      },
      "cell_type": "code",
      "source": [
        "EPOCHS = 20\n",
        "\n",
        "for epoch in range(EPOCHS):\n",
        "    start = time.time()\n",
        "    total_loss = 0\n",
        "    \n",
        "    for (batch, (img_tensor, target)) in enumerate(dataset):\n",
        "        loss = 0\n",
        "        \n",
        "        # a batch may hold fewer than BATCH_SIZE examples (e.g. the last one of\n",
        "        # an epoch), so the hidden state and the start tokens are both sized\n",
        "        # from the batch that was actually read\n",
        "        current_batch_size = int(target.shape[0])\n",
        "        \n",
        "        # initializing the hidden state for each batch\n",
        "        # because the captions are not related from image to image\n",
        "        hidden = decoder.reset_state(batch_size=current_batch_size)\n",
        "\n",
        "        # every caption is decoded starting from the <start> token\n",
        "        dec_input = tf.expand_dims([tokenizer.word_index['<start>']] * current_batch_size, 1)\n",
        "        \n",
        "        with tf.GradientTape() as tape:\n",
        "            features = encoder(img_tensor)\n",
        "            \n",
        "            # position 0 of the target is skipped: it is predicted from dec_input\n",
        "            for i in range(1, target.shape[1]):\n",
        "                # passing the features through the decoder\n",
        "                predictions, hidden, _ = decoder(dec_input, features, hidden)\n",
        "\n",
        "                loss += loss_function(target[:, i], predictions)\n",
        "                \n",
        "                # using teacher forcing: the target word is the next input\n",
        "                dec_input = tf.expand_dims(target[:, i], 1)\n",
        "        \n",
        "        # loss averaged over the caption length, accumulated over the epoch\n",
        "        total_loss += (loss / int(target.shape[1]))\n",
        "        \n",
        "        variables = encoder.variables + decoder.variables\n",
        "        \n",
        "        gradients = tape.gradient(loss, variables) \n",
        "        \n",
        "        optimizer.apply_gradients(zip(gradients, variables), tf.train.get_or_create_global_step())\n",
        "        \n",
        "        if batch % 100 == 0:\n",
        "            print ('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1, \n",
        "                                                          batch, \n",
        "                                                          loss.numpy() / int(target.shape[1])))\n",
        "    # storing the epoch end loss value to plot later\n",
        "    loss_plot.append(total_loss / len(cap_vector))\n",
        "    \n",
        "    print ('Epoch {} Loss {:.6f}'.format(epoch + 1, \n",
        "                                         total_loss/len(cap_vector)))\n",
        "    print ('Time taken for 1 epoch {} sec\\n'.format(time.time() - start))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "1Wm83G-ZBPcC",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        }
      },
      "cell_type": "code",
      "source": [
        "# loss value stored at the end of every epoch\n",
        "plt.plot(loss_plot)\n",
        "plt.title('Loss Plot')\n",
        "plt.xlabel('Epochs')\n",
        "plt.ylabel('Loss')\n",
        "plt.show()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "xGvOcLQKghXN",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## Caption!\n",
        "\n",
        "* The evaluate function is similar to the training loop, except we don't use teacher forcing here. The input to the decoder at each time step is its previous predictions along with the hidden state and the encoder output.\n",
        "* Stop predicting when the model predicts the end token.\n",
        "* And store the attention weights for every time step."
      ]
    },
    {
      "metadata": {
        "id": "RCWpDtyNRPGs",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        }
      },
      "cell_type": "code",
      "source": [
        "def evaluate(image):\n",
        "    # Generates a caption for `image` (passed on to load_image).\n",
        "    # Returns (result, attention_plot): the list of predicted words and an\n",
        "    # array with one row of attention weights per predicted word.\n",
        "    attention_plot = np.zeros((max_length, attention_features_shape))\n",
        "\n",
        "    hidden = decoder.reset_state(batch_size=1)\n",
        "\n",
        "    # extract the image features and flatten the two middle axes into one\n",
        "    temp_input = tf.expand_dims(load_image(image)[0], 0)\n",
        "    img_tensor_val = image_features_extract_model(temp_input)\n",
        "    img_tensor_val = tf.reshape(img_tensor_val, (img_tensor_val.shape[0], -1, img_tensor_val.shape[3]))\n",
        "\n",
        "    features = encoder(img_tensor_val)\n",
        "\n",
        "    dec_input = tf.expand_dims([tokenizer.word_index['<start>']], 0)\n",
        "    result = []\n",
        "\n",
        "    for i in range(max_length):\n",
        "        predictions, hidden, attention_weights = decoder(dec_input, features, hidden)\n",
        "\n",
        "        attention_plot[i] = tf.reshape(attention_weights, (-1, )).numpy()\n",
        "\n",
        "        # tf.multinomial expects unnormalized log-probabilities, which is what\n",
        "        # the decoder returns, so the logits are passed as they are\n",
        "        predicted_id = tf.multinomial(predictions, num_samples=1)[0][0].numpy()\n",
        "        result.append(index_word[predicted_id])\n",
        "\n",
        "        # stop as soon as the end token is predicted\n",
        "        if index_word[predicted_id] == '<end>':\n",
        "            break\n",
        "\n",
        "        # no teacher forcing here: the prediction is the next decoder input\n",
        "        dec_input = tf.expand_dims([predicted_id], 0)\n",
        "\n",
        "    # keep only the rows that belong to a predicted word\n",
        "    attention_plot = attention_plot[:len(result), :]\n",
        "    return result, attention_plot"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "fD_y7PD6RPGt",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        }
      },
      "cell_type": "code",
      "source": [
        "def plot_attention(image, result, attention_plot):\n",
        "    # Draws the image once per entry of `result`, titled with that entry and\n",
        "    # overlaid with its row of `attention_plot` resized to 8x8.\n",
        "    temp_image = np.array(Image.open(image))\n",
        "\n",
        "    fig = plt.figure(figsize=(10, 10))\n",
        "\n",
        "    len_result = len(result)\n",
        "    # a square grid of side len_result//2 cannot hold results shorter than\n",
        "    # 4 entries (it would be 0x0 or 1x1), so the side never goes below 2\n",
        "    grid_size = max(len_result // 2, 2)\n",
        "    for idx in range(len_result):\n",
        "        temp_att = np.resize(attention_plot[idx], (8, 8))\n",
        "        ax = fig.add_subplot(grid_size, grid_size, idx + 1)\n",
        "        ax.set_title(result[idx])\n",
        "        img = ax.imshow(temp_image)\n",
        "        # the attention map is stretched over the extent of the image\n",
        "        ax.imshow(temp_att, cmap='gray', alpha=0.6, extent=img.get_extent())\n",
        "\n",
        "    plt.tight_layout()\n",
        "    plt.show()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "io7ws3ReRPGv",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        }
      },
      "cell_type": "code",
      "source": [
        "# caption a randomly chosen image from the validation set\n",
        "rid = np.random.randint(0, len(img_name_val))\n",
        "image = img_name_val[rid]\n",
        "# id 0 is skipped when mapping the caption ids back to words\n",
        "real_words = [index_word[i] for i in cap_val[rid] if i != 0]\n",
        "real_caption = ' '.join(real_words)\n",
        "result, attention_plot = evaluate(image)\n",
        "\n",
        "print ('Real Caption:', real_caption)\n",
        "print ('Prediction Caption:', ' '.join(result))\n",
        "plot_attention(image, result, attention_plot)\n",
        "# opening the image\n",
        "Image.open(image)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "Rprk3HEvZuxb",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## Try it on your own images\n",
        "For fun, below we've provided a method you can use to caption your own images with the model we've just trained. Keep in mind, it was trained on a relatively small amount of data, and your images may be different from the training data (so be prepared for weird results!)\n"
      ]
    },
    {
      "metadata": {
        "id": "9Psd1quzaAWg",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        }
      },
      "cell_type": "code",
      "source": [
        "image_url = 'https://tensorflow.org/images/surf.jpg'\n",
        "image_extension = image_url[-4:]\n",
        "# tf.keras.utils.get_file reuses an already downloaded file with the same\n",
        "# name, so the download is named after the last path component of the URL;\n",
        "# with a fixed name, changing image_url would keep returning the old image\n",
        "image_name = image_url.split('/')[-1]\n",
        "image_path = tf.keras.utils.get_file(image_name, \n",
        "                                     origin=image_url)\n",
        "\n",
        "result, attention_plot = evaluate(image_path)\n",
        "print ('Prediction Caption:', ' '.join(result))\n",
        "plot_attention(image_path, result, attention_plot)\n",
        "# opening the image\n",
        "Image.open(image_path)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "VJZXyJco6uLO",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "# Next steps\n",
        "\n",
        "Congrats! You've just trained an image captioning model with attention. Next, we recommend taking a look at this example [Neural Machine Translation with Attention](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb). It uses a similar architecture to translate between Spanish and English sentences. You can also experiment with training the model in this notebook on a different dataset."
      ]
    }
  ]
}