diff options
Diffstat (limited to 'tensorflow/compiler/tests/eager_test.py')
-rw-r--r-- | tensorflow/compiler/tests/eager_test.py | 48 |
1 files changed, 47 insertions, 1 deletions
diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py index 3524666499..6ead15da13 100644 --- a/tensorflow/compiler/tests/eager_test.py +++ b/tensorflow/compiler/tests/eager_test.py @@ -403,7 +403,7 @@ class EagerFunctionTest(xla_test.XLATestCase): def testSliceInDefun(self): with self.test_scope(): - @function.defun(compiled=True) + @function.defun def f(x, y): return x[0::2, y:, ...] @@ -418,6 +418,22 @@ class EagerFunctionTest(xla_test.XLATestCase): self.assertAllEqual(np.ones([1, 2, 4]), z.numpy()) self.assertAllEqual((2, 3, 4), dz.shape.as_list()) + def testNestedDefun(self): + self.skipTest('Nested defuns do not work on TPU at the moment') + with self.test_scope(): + + @function.defun + def times_two(x): + return 2 * x + + @function.defun + def two_x_plus_1(x): + return times_two(x) + 1 + + x = constant_op.constant([2, 3, 4]) + y = two_x_plus_1(x) + self.assertAllEqual([5, 7, 9], y.numpy()) + class ExcessivePaddingTest(xla_test.XLATestCase): """Test that eager execution works with TPU flattened tensors. @@ -470,6 +486,36 @@ class ExcessivePaddingTest(xla_test.XLATestCase): self.assertAllEqual(100 * [[36.0]], reduced) +def multiple_tpus(): + devices = context.context().devices() + return len([d for d in devices if 'device:TPU:' in d]) > 1 + + +class MultiDeviceTest(xla_test.XLATestCase): + """Test running TPU computation on more than one core.""" + + def testBasic(self): + if not multiple_tpus(): + self.skipTest('MultiDeviceTest requires multiple TPU devices.') + + # Compute 10 on TPU core 0 + with ops.device('device:TPU:0'): + two = constant_op.constant(2) + five = constant_op.constant(5) + ten = two * five + self.assertAllEqual(10, ten) + + # Compute 6 on TPU core 1 + with ops.device('device:TPU:1'): + two = constant_op.constant(2) + three = constant_op.constant(3) + six = two * three + self.assertAllEqual(6, six) + + # Copy 10 and 6 to CPU and sum them + self.assertAllEqual(16, ten + six) + + if __name__ == '__main__': ops.enable_eager_execution( config=config_pb2.ConfigProto(log_device_placement=True)) |