{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "automatic_differentiation.ipynb",
      "version": "0.3.2",
      "views": {},
      "default_view": {},
      "provenance": [],
      "private_outputs": true,
      "collapsed_sections": [],
      "toc_visible": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "metadata": {
        "id": "t09eeeR5prIJ",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "##### Copyright 2018 The TensorFlow Authors."
      ]
    },
    {
      "metadata": {
        "id": "GCCk8_dHpuNf",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "cellView": "form"
      },
      "cell_type": "code",
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "xh8WkEwWpnm7",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "# Automatic differentiation and gradient tape"
      ]
    },
    {
      "metadata": {
        "id": "idv0bPeCp325",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\"><td>\n",
        "<a target=\"_blank\"  href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb\">\n",
        "    <img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "</td><td>\n",
        "<a target=\"_blank\"  href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb\"><img width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a></td></table>"
      ]
    },
    {
      "metadata": {
        "id": "vDJ4XzMqodTy",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "In the previous tutorial we introduced `Tensor`s and operations on them. In this tutorial we will cover [automatic differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation), a key technique for optimizing machine learning models."
      ]
    },
    {
      "metadata": {
        "id": "GQJysDM__Qb0",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## Setup\n"
      ]
    },
    {
      "metadata": {
        "id": "OiMPZStlibBv",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        }
      },
      "cell_type": "code",
      "source": [
        "import tensorflow as tf\n",
        "tf.enable_eager_execution()\n",
        "\n",
        "tfe = tf.contrib.eager # Shorthand for some symbols"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "1CLWJl0QliB0",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## Derivatives of a function\n",
        "\n",
        "TensorFlow provides APIs for automatic differentiation - computing the derivative of a function. The way that more closely mimics the math is to encapsulate the computation in a Python function, say `f`, and use `tfe.gradients_function` to create a function that computes the derivatives of `f` with respect to its arguments. If you're familiar with [autograd](https://github.com/HIPS/autograd) for differentiating numpy functions, this will be familiar. For example: "
      ]
    },
    {
      "metadata": {
        "id": "9FViq92UX7P8",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        }
      },
      "cell_type": "code",
      "source": [
        "from math import pi\n",
        "\n",
        "def f(x):\n",
        "  return tf.square(tf.sin(x))\n",
        "\n",
        "assert f(pi/2).numpy() == 1.0\n",
        "\n",
        "\n",
        "# grad_f will return a list of derivatives of f\n",
        "# with respect to its arguments. Since f() has a single argument,\n",
        "# grad_f will return a list with a single element.\n",
        "grad_f = tfe.gradients_function(f)\n",
        "assert tf.abs(grad_f(pi/2)[0]).numpy() < 1e-7"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "v9fPs8RyopCf",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "### Higher-order gradients\n",
        "\n",
        "The same API can be used to differentiate as many times as you like:\n"
      ]
    },
    {
      "metadata": {
        "id": "3D0ZvnGYo0rW",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        }
      },
      "cell_type": "code",
      "source": [
        "def f(x):\n",
        "  return tf.square(tf.sin(x))\n",
        "\n",
        "def grad(f):\n",
        "  return lambda x: tfe.gradients_function(f)(x)[0]\n",
        "\n",
        "x = tf.lin_space(-2*pi, 2*pi, 100)  # 100 points between -2π and +2π\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "plt.plot(x, f(x), label=\"f\")\n",
        "plt.plot(x, grad(f)(x), label=\"first derivative\")\n",
        "plt.plot(x, grad(grad(f))(x), label=\"second derivative\")\n",
        "plt.plot(x, grad(grad(grad(f)))(x), label=\"third derivative\")\n",
        "plt.legend()\n",
        "plt.show()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "-39gouo7mtgu",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## Gradient tapes\n",
        "\n",
        "Every differentiable TensorFlow operation has an associated gradient function. For example, the gradient function of `tf.square(x)` would be a function that returns `2.0 * x`.  To compute the gradient of a user-defined function (like `f(x)` in the example above), TensorFlow first \"records\" all the operations applied to compute the output of the function. We call this record a \"tape\". It then uses that tape and the gradients functions associated with each primitive operation to compute the gradients of the user-defined function using [reverse mode differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation).\n",
        "\n",
        "Since operations are recorded as they are executed, Python control flow (using `if`s and `while`s for example) is naturally handled:\n",
        "\n"
      ]
    },
    {
      "metadata": {
        "id": "MH0UfjympWf7",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        }
      },
      "cell_type": "code",
      "source": [
        "def f(x, y):\n",
        "  output = 1\n",
        "  for i in range(y):\n",
        "    output = tf.multiply(output, x)\n",
        "  return output\n",
        "\n",
        "def g(x, y):\n",
        "  # Return the gradient of `f` with respect to it's first parameter\n",
        "  return tfe.gradients_function(f)(x, y)[0]\n",
        "\n",
        "assert f(3.0, 2).numpy() == 9.0   # f(x, 2) is essentially x * x\n",
        "assert g(3.0, 2).numpy() == 6.0   # And its gradient will be 2 * x\n",
        "assert f(4.0, 3).numpy() == 64.0  # f(x, 3) is essentially x * x * x\n",
        "assert g(4.0, 3).numpy() == 48.0  # And its gradient will be 3 * x * x"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "aNmR5-jhpX2t",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "At times it may be inconvenient to encapsulate computation of interest into a function. For example, if you want the gradient of the output with respect to intermediate values computed in the function. In such cases, the slightly more verbose but explicit [tf.GradientTape](https://www.tensorflow.org/api_docs/python/tf/GradientTape) context is useful. All computation inside the context of a `tf.GradientTape` is \"recorded\".\n",
        "\n",
        "For example:"
      ]
    },
    {
      "metadata": {
        "id": "bAFeIE8EuVIq",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        }
      },
      "cell_type": "code",
      "source": [
        "x = tf.ones((2, 2))\n",
        "  \n",
        "# TODO(b/78880779): Remove the 'persistent=True' argument and use\n",
        "# a single t.gradient() call when the bug is resolved.\n",
        "with tf.GradientTape(persistent=True) as t:\n",
        "  # TODO(ashankar): Explain with \"watch\" argument better?\n",
        "  t.watch(x)\n",
        "  y = tf.reduce_sum(x)\n",
        "  z = tf.multiply(y, y)\n",
        "\n",
        "# Use the same tape to compute the derivative of z with respect to the\n",
        "# intermediate value y.\n",
        "dz_dy = t.gradient(z, y)\n",
        "assert dz_dy.numpy() == 8.0\n",
        "\n",
        "# Derivative of z with respect to the original input tensor x\n",
        "dz_dx = t.gradient(z, x)\n",
        "for i in [0, 1]:\n",
        "  for j in [0, 1]:\n",
        "    assert dz_dx[i][j].numpy() == 8.0"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "DK05KXrAAld3",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "### Higher-order gradients\n",
        "\n",
        "Operations inside of the `GradientTape` context manager are recorded for automatic differentiation. If gradients are computed in that context, then the gradient computation is recorded as well. As a result, the exact same API works for higher-order gradients as well. For example:"
      ]
    },
    {
      "metadata": {
        "id": "cPQgthZ7ugRJ",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        }
      },
      "cell_type": "code",
      "source": [
        "# TODO(ashankar): Should we use the persistent tape here instead? Follow up on Tom and Alex's discussion\n",
        "\n",
        "x = tf.constant(1.0)  # Convert the Python 1.0 to a Tensor object\n",
        "\n",
        "with tf.GradientTape() as t:\n",
        "  with tf.GradientTape() as t2:\n",
        "    t2.watch(x)\n",
        "    y = x * x * x\n",
        "  # Compute the gradient inside the 't' context manager\n",
        "  # which means the gradient computation is differentiable as well.\n",
        "  dy_dx = t2.gradient(y, x)\n",
        "d2y_dx2 = t.gradient(dy_dx, x)\n",
        "\n",
        "assert dy_dx.numpy() == 3.0\n",
        "assert d2y_dx2.numpy() == 6.0"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "4U1KKzUpNl58",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## Next Steps\n",
        "\n",
        "In this tutorial we covered gradient computation in TensorFlow. With that we have enough of the primitives required to build an train neural networks, which we will cover in the [next tutorial](https://github.com/tensorflow/models/tree/master/official/contrib/eager/python/examples/notebooks/3_neural_networks.ipynb)."
      ]
    }
  ]
}