aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tests/eager_test.py
blob: e438832a23a670596d12cbc67d71a9f561b82193 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test cases for eager execution using XLA."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.layers import convolutional
from tensorflow.python.layers import pooling
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import googletest
from tensorflow.python.training import adam


class EagerTest(XLATestCase):
  """Eager-execution tests for ops and resource variables on the XLA device.

  Ops are created inside `self.test_scope()`; several tests read results back
  with `.numpy()` after that scope has been exited.
  """

  def testBasic(self):
    """Multiplies two scalar constants eagerly."""
    with self.test_scope():
      three = constant_op.constant(3)
      five = constant_op.constant(5)
      product = three * five
      self.assertAllEqual(15, product)

  def testGradientTape(self):
    """Takes two gradients of a = x + y + x*y from one tape."""
    with self.test_scope():

      x = constant_op.constant(1.0)
      y = constant_op.constant(10.0)
      # persistent=True because tape.gradient() is called twice below.
      with backprop.GradientTape(persistent=True) as tape:
        tape.watch(x)
        tape.watch(y)
        a = x + y + x * y
      da_dx = tape.gradient(a, x)
      da_dy = tape.gradient(a, y)

    # da/dx = 1 + y = 11 and da/dy = 1 + x = 2.
    self.assertEqual(11.0, da_dx.numpy())
    self.assertEqual(2.0, da_dy.numpy())

  def testExecuteListOutputLen0(self):
    """Unstacking an empty tensor returns an empty Python list."""
    with self.test_scope():
      empty = constant_op.constant([], dtype=dtypes.float32)
      result = array_ops.unstack(empty, 0)
      self.assertIsInstance(result, list)
      self.assertEqual(0, len(result))

  def testExecuteListOutputLen1(self):
    """Splitting into a single piece returns a one-element list."""
    with self.test_scope():
      split_dim = constant_op.constant(1)
      value = constant_op.constant([[0., 1., 2.], [3., 4., 5.]])
      result = array_ops.split(value, 1, axis=split_dim)
      self.assertIsInstance(result, list)
      self.assertEqual(1, len(result))
      self.assertAllEqual([[0, 1, 2], [3, 4, 5]], result[0])

  def testExecuteListOutputLen3(self):
    """Splitting a 2x3 tensor along axis 1 into three pieces yields columns."""
    with self.test_scope():
      split_dim = constant_op.constant(1)
      value = constant_op.constant([[0., 1., 2.], [3., 4., 5.]])
      result = array_ops.split(value, 3, axis=split_dim)
      self.assertIsInstance(result, list)
      self.assertEqual(3, len(result))
      self.assertAllEqual([[0], [3]], result[0])
      self.assertAllEqual([[1], [4]], result[1])
      self.assertAllEqual([[2], [5]], result[2])

  def testBasicGraph(self):
    """Runs the same multiplication eagerly and then in graph mode."""
    # Run some ops eagerly.
    with self.test_scope():
      three = constant_op.constant(3)
      five = constant_op.constant(5)
      product = three * five
      self.assertAllEqual(15, product)

    # Run the same ops in graph mode, evaluated through a session.
    with context.graph_mode(), self.test_session() as sess:
      with self.test_scope():
        three = constant_op.constant(3)
        five = constant_op.constant(5)
        product = three * five
        self.assertAllEqual(15, sess.run(product))

  def testDegenerateSlices(self):
    """Slices that select nothing match NumPy's empty results."""
    with self.test_scope():
      npt = np.arange(1, 19, dtype=np.float32).reshape(3, 2, 3)
      t = constant_op.constant(npt)
      # Degenerate: a forward interval combined with a negative stride.
      self.assertAllEqual(npt[0:-1:-1, :, :], t[0:-1:-1, :, :])
      # Degenerate: a reverse interval combined with a positive stride.
      self.assertAllEqual(npt[-1:0, :, :], t[-1:0, :, :])
      # An empty interval in every dimension.
      self.assertAllEqual(npt[-1:0, 2:2, 2:3:-1], t[-1:0, 2:2, 2:3:-1])

  def testIdentity(self):
    """identity() of a Python scalar returns that value."""
    with self.test_scope():
      self.assertAllEqual(2, array_ops.identity(2))

  def testIdentityOnVariable(self):
    """identity() of a resource variable returns the variable's value."""
    with self.test_scope():
      v = resource_variable_ops.ResourceVariable(True)
      i = array_ops.identity(v)
    self.assertAllEqual(True, i.numpy())

  def testAssignAddVariable(self):
    """assign_add updates a resource variable in place."""
    with self.test_scope():
      v = resource_variable_ops.ResourceVariable(1.0)
      v.assign_add(2.0)
    self.assertEqual(3.0, v.numpy())

  def testReadAssignRead(self):
    """A value read before assign_add is not affected by the update."""
    with self.test_scope():
      v = resource_variable_ops.ResourceVariable(1.0)
      val1 = v.read_value()
      v.assign_add(2.0)
      val2 = v.read_value()
    self.assertEqual(1.0, val1.numpy())
    self.assertEqual(3.0, val2.numpy())

  def testGradient(self):
    """The gradient of the identity function equals the upstream dy."""
    def f(x):
      return x

    with self.test_scope():
      grad_fn = backprop.gradients_function(f)
      self.assertAllEqual(2., grad_fn(1., dy=2.)[0])

  def testVariableGradient(self):
    """implicit_grad differentiates v0 * v0 w.r.t. the captured variable."""
    with self.test_scope():
      v0 = resource_variable_ops.ResourceVariable(1.0)

      def f():
        x = v0 * v0
        return x

      grads = backprop.implicit_grad(f)()
    # grads[0] is the (gradient, variable) pair for v0; d(v0^2)/dv0 = 2 * 1.
    self.assertEqual(2., grads[0][0].numpy())

  def testMultipleVariableReads(self):
    """Consecutive reads of a variable do not copy its underlying memory."""
    with self.test_scope():
      # Create a 128MiB variable (32 * 1024 * 1024 float32 values).
      var = resource_variable_ops.ResourceVariable(
          array_ops.ones([32, 1024, 1024]))

      # Read the same variable 100 times. If the underlying tensor
      # is not copied, this is a trivial operation. If it is copied,
      # this will eat over 13GB and OOM. Every read is kept alive in
      # `values` so that copies, if any, would accumulate.
      values = []
      for _ in range(100):
        values.append(var.value())

  # Shape, ShapeN, Size, and Rank are tested here because their
  # execution kernels (as opposed to compilation-only tf2xla kernels)
  # are distinct from the tf2xla kernels.

  def testShape(self):
    """shape() is correct for constants and for op outputs on the device."""
    def const(value):
      return array_ops.shape(
          constant_op.constant(value)).numpy()

    def ones(value):
      return array_ops.shape(
          array_ops.ones(value)).numpy()

    with self.test_scope():
      # Shapes of directly constructed tensors
      self.assertAllEqual([], const(3))
      self.assertAllEqual([3], const([1.0, 2.0, 3.0]))
      self.assertAllEqual([2, 2], const([[1.0, 2.0], [3.0, 4.0]]))
      self.assertAllEqual([2, 1, 2], const([[[1.0, 2.0]], [[3.0, 4.0]]]))

      # Shapes of tensors created by op running on device
      # We make this distinction because directly constructed tensors
      # are treated differently in a few places that can influence shape:
      #  - they always have on_host_tensor
      #  - they and their shapes can be cached
      #  - they end up on device via a copy, instead of as program output
      self.assertAllEqual([], ones([]))
      self.assertAllEqual([3], ones([3]))
      self.assertAllEqual([2, 2], ones([2, 2]))
      self.assertAllEqual([2, 1, 2], ones([2, 1, 2]))

  def testShapeN(self):
    """shape_n() is correct for constants and for op outputs on the device."""
    with self.test_scope():
      # Shapes of directly constructed tensors
      shapes = array_ops.shape_n([
          constant_op.constant(1.0),
          constant_op.constant([1.0, 2.0, 3.0]),
          constant_op.constant([[1.0, 2.0], [3.0, 4.0]])])
      self.assertAllEqual(
          [[], [3], [2, 2]],
          [x.numpy().tolist() for x in shapes])

      # Shapes of tensors created by op running on device
      shapes = array_ops.shape_n([
          array_ops.ones([]),
          array_ops.ones([3]),
          array_ops.ones([2, 2])])
      self.assertAllEqual(
          [[], [3], [2, 2]],
          [x.numpy().tolist() for x in shapes])

  def testSize(self):
    """size() counts the elements of scalar, vector, and matrix constants."""
    with self.test_scope():
      self.assertEqual(
          1, array_ops.size(constant_op.constant(1.0)).numpy())
      self.assertEqual(
          3, array_ops.size(constant_op.constant([1.0, 2.0, 3.0])).numpy())
      self.assertEqual(
          4, array_ops.size(
              constant_op.constant([[1.0, 2.0], [3.0, 4.0]])).numpy())

  def testRank(self):
    """rank() is 0, 1, and 2 for scalar, vector, and matrix constants."""
    with self.test_scope():
      self.assertEqual(
          0, array_ops.rank(constant_op.constant(1.0)).numpy())
      self.assertEqual(
          1, array_ops.rank(constant_op.constant([1.0, 2.0, 3.0])).numpy())
      self.assertEqual(
          2, array_ops.rank(
              constant_op.constant([[1.0, 2.0], [3.0, 4.0]])).numpy())

  def testAdam(self):
    """One dense Adam step on y = x^2 moves x from 10.0 to about 9.9."""
    with self.test_scope():
      optimizer = adam.AdamOptimizer(0.1)
      x = resource_variable_ops.ResourceVariable(10.0)
      with backprop.GradientTape() as tape:
        y = x * x
      dy_dx = tape.gradient(y, x)
      optimizer.apply_gradients([(dy_dx, x)])
      # The first Adam step changes x by roughly the learning rate (0.1).
      self.assertAlmostEqual(9.9, x.numpy(), places=3)

  def testAdamSparse(self):
    """Adam applies a sparse (IndexedSlices) gradient to a CPU variable."""
    with ops.device('/cpu:0'):
      # Create 2-D embedding for 3 objects on CPU because sparse/sliced updates
      # are not implemented on TPU.
      embedding_matrix = resource_variable_ops.ResourceVariable(
          array_ops.ones([3, 2]))

    with self.test_scope():
      with backprop.GradientTape() as tape:
        embedding = embedding_ops.embedding_lookup(embedding_matrix, [1])
        y = math_ops.reduce_sum(embedding)
      dy_dx = tape.gradient(y, embedding_matrix)
      self.assertIsInstance(dy_dx, ops.IndexedSlices)
      optimizer = adam.AdamOptimizer(0.1)
      # The gradient application operations will run on CPU because optimizer
      # updates are always colocated with the variable.
      optimizer.apply_gradients([(dy_dx, embedding_matrix)])

      # This assign_add will run on CPU because when an input to an
      # operation is a resource, this operation is placed on the resource's
      # device by the eager runtime.
      embedding_matrix.assign_add(array_ops.ones([3, 2]))

    # Only row 1 was looked up, so only it received the ~0.1 Adam decrement
    # before the +1 from assign_add.
    self.assertAllClose([[2.0, 2.0],
                         [1.9, 1.9],
                         [2.0, 2.0]], embedding_matrix.numpy())


class EagerFunctionTest(XLATestCase):
  """Tests for `function.defun`-wrapped functions on the XLA test device."""

  def testBasic(self):
    """A defun-wrapped matmul honors the transpose_a keyword argument."""
    with self.test_scope():
      defun_matmul = function.defun(math_ops.matmul)
      mat = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
      # transpose(mat) @ mat == [[10, 14], [14, 20]].
      gram = defun_matmul(mat, mat, transpose_a=True)
      self.assertAllEqual(gram.numpy().reshape(-1), [10, 14, 14, 20])

  def testConv(self):
    """Runs a defun-wrapped Conv2D followed by MaxPooling2D."""
    if 'GPU' in self.device:
      # TODO(b/32333178)
      self.skipTest('Current implementation of RandomStandardNormal kernel '
                    'is very slow on GPU, and has been blacklisted.')
    with self.test_scope():
      fmt = 'channels_last'
      conv_layer = convolutional.Conv2D(
          filters=1, kernel_size=2, padding='VALID',
          data_format=fmt, activation=nn_ops.relu,
          kernel_initializer=init_ops.ones_initializer(),
          bias_initializer=init_ops.zeros_initializer())
      pool_layer = pooling.MaxPooling2D(2, 2, data_format=fmt)

      @function.defun
      def model(inputs):
        return pool_layer(conv_layer(inputs))

      images = array_ops.ones([1, 4, 4, 1])
      output = model(images)
      # An all-ones 2x2 kernel with zero bias over all-ones input gives 4 at
      # every position of the 3x3 VALID output, and pooling leaves one value.
      self.assertAllEqual(output.numpy(), [[[[4.]]]])

  def testReadVariable(self):
    """A defun can read a resource variable it captures."""
    with self.test_scope():
      var = resource_variable_ops.ResourceVariable(1.0)

      @function.defun
      def read():
        return var.read_value()

      value = read()
      self.assertEqual(1.0, value.numpy())

  def testUpdateVariable(self):
    """A defun can assign_add to a variable passed in as an argument."""
    with self.test_scope():
      counter = resource_variable_ops.ResourceVariable(1.0)

      @function.defun
      def increment(v):
        v.assign_add(1.0)
        return v

      result = increment(counter)
      self.assertEqual(2.0, result.numpy())

  def testAllArgumentKinds(self):
    """Tests a defun whose arguments mix every tf2xla argument kind.

    When tf2xla translates, compiles, and runs a defun, it sorts the
    arguments into compile-time constants, regular tensors, and resources.
    The function below takes a mix of all three kinds, deliberately
    interleaved in its signature.

    The constant arguments are also fed into arithmetic (c1 into an
    addition, c2 into a cast and an addition). This covers a compile-time
    constant that also reaches ops which normally expect their inputs in
    device memory.
    """
    with self.test_scope():

      @function.defun
      def foo(c1, r1, v1, c2, v2, r2):
        # c1, c2: compile-time constants (begin and size of the slice).
        # r1, r2: regular tensors.
        # v1, v2: resource variables.
        a = c1 + r1
        b = math_ops.cast(c2, dtypes.float32) + v2
        c = array_ops.slice(v1, c1, c2)
        d = r2 * v2
        return a, b, c, d

      begin = [0, 0]
      size = array_ops.ones([2], dtype=dtypes.int32)

      ones_vec = array_ops.ones([2])
      scale = [[2., 2.], [3., 3.]]

      small = resource_variable_ops.ResourceVariable([[1., 2.], [3., 4.]])
      large = resource_variable_ops.ResourceVariable([[10., 20.], [30., 40.]])

      a, b, c, d = foo(begin, ones_vec, small, size, large, scale)

      self.assertAllEqual([1, 1], a.numpy())
      self.assertAllEqual([[11., 21.], [31., 41.]], b.numpy())
      self.assertAllEqual([[1.]], c.numpy())
      self.assertAllEqual([[20., 40.], [90., 120.]], d.numpy())

  def testDefunInGradientTape(self):
    """Gradients flow through a defun to a variable it captures."""
    with self.test_scope():
      weight = resource_variable_ops.ResourceVariable(5.0)

      @function.defun
      def f(x):
        return weight * weight * x

      x = constant_op.constant(3.0)
      with backprop.GradientTape() as tape:
        y = f(x)
      dy = tape.gradient(y, weight)

    # With w = 5 and x = 3: y = w * w * x = 75 and dy/dw = 2 * w * x = 30.
    self.assertEqual(75, y.numpy())
    self.assertEqual(30, dy.numpy())

  def testSliceInDefun(self):
    """Strided slicing with a tensor start index inside a compiled defun."""
    with self.test_scope():

      @function.defun(compiled=True)
      def f(x, y):
        return x[0::2, y:, ...]

      inp = array_ops.ones([2, 3, 4])
      start = array_ops.ones([], dtype=dtypes.int32)
      with backprop.GradientTape() as tape:
        tape.watch(inp)
        tape.watch(start)
        out = f(inp, start)
      grad = tape.gradient(out, inp)

      # 0::2 keeps one of two rows and start == 1 drops the first column.
      self.assertAllEqual(np.ones([1, 2, 4]), out.numpy())
      self.assertAllEqual((2, 3, 4), grad.shape.as_list())


class ExcessivePaddingTest(XLATestCase):
  """Checks that eager execution handles tensors stored flattened on TPU.

  On TPU, a tensor that would be excessively padded when written to device
  memory is reshaped to a flat 1-D tensor instead. The flattening is meant
  to be transparent, so these tests should pass on every backend, not just
  on TPU, where the flattening actually happens today.
  """

  def testFromConstant(self):
    """Reduces a [100, 2, 1] constant along its middle axis."""
    with self.test_scope():
      # A [100, 2, 1] tensor like this one would be excessively padded on
      # TPU. reduce_sum over a single axis checks that per-dimension
      # handling still works on such a tensor.
      padded = constant_op.constant(100 * [[[10.0], [2.0]]])
      summed = math_ops.reduce_sum(padded, axis=1)
      self.assertAllEqual(100 * [[12.0]], summed)

  def testFromOperation(self):
    """Reduces an op-produced [3, 100, 2, 2] tensor down to axis 1."""
    with self.test_scope():
      produced = array_ops.ones([3, 100, 2, 2])
      # Summing 3 * 2 * 2 ones per position along axis 1 gives 12.
      summed = math_ops.reduce_sum(produced, axis=[0, 2, 3])
      self.assertAllEqual(100 * [12.0], summed)

  def testAsFunctionInput(self):
    """Passes a [100, 1, 2] constant into a defun that reduces it."""
    with self.test_scope():

      @function.defun
      def sum_last_axis(x):
        return math_ops.reduce_sum(x, axis=2)

      arg = constant_op.constant(100 * [[[10.0, 2.0]]])
      self.assertAllEqual(100 * [[12.0]], sum_last_axis(arg))

  def testAsFunctionOutput(self):
    """Returns a [100, 1, 2] tensor from a defun and reduces it eagerly."""
    with self.test_scope():

      @function.defun
      def scale_pattern(x):
        return x * constant_op.constant(100 * [[[10.0, 2.0]]])

      scaled = scale_pattern(3)
      # 3 * (10 + 2) = 36 for each of the 100 rows.
      self.assertAllEqual(100 * [[36.0]],
                          math_ops.reduce_sum(scaled, axis=2))


if __name__ == '__main__':
  # Turn on eager mode for every test, logging the device each op runs on.
  eager_config = config_pb2.ConfigProto(log_device_placement=True)
  ops.enable_eager_execution(config=eager_config)
  googletest.main()