Diffstat (limited to 'tensorflow/compiler/tests/eager_test.py')
-rw-r--r--  tensorflow/compiler/tests/eager_test.py | 48 +++++++++++++++++++++++++++++++++++++++++++++++-
1 file changed, 47 insertions(+), 1 deletion(-)
diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py
index 3524666499..6ead15da13 100644
--- a/tensorflow/compiler/tests/eager_test.py
+++ b/tensorflow/compiler/tests/eager_test.py
@@ -403,7 +403,7 @@ class EagerFunctionTest(xla_test.XLATestCase):
   def testSliceInDefun(self):
     with self.test_scope():
 
-      @function.defun(compiled=True)
+      @function.defun
       def f(x, y):
         return x[0::2, y:, ...]
 
@@ -418,6 +418,22 @@ class EagerFunctionTest(xla_test.XLATestCase):
     self.assertAllEqual(np.ones([1, 2, 4]), z.numpy())
     self.assertAllEqual((2, 3, 4), dz.shape.as_list())
 
+  def testNestedDefun(self):
+    self.skipTest('Nested defuns do not work on TPU at the moment')
+    with self.test_scope():
+
+      @function.defun
+      def times_two(x):
+        return 2 * x
+
+      @function.defun
+      def two_x_plus_1(x):
+        return times_two(x) + 1
+
+      x = constant_op.constant([2, 3, 4])
+      y = two_x_plus_1(x)
+      self.assertAllEqual([5, 7, 9], y.numpy())
+
 
 class ExcessivePaddingTest(xla_test.XLATestCase):
   """Test that eager execution works with TPU flattened tensors.
@@ -470,6 +486,36 @@ class ExcessivePaddingTest(xla_test.XLATestCase):
       self.assertAllEqual(100 * [[36.0]], reduced)
 
 
+def multiple_tpus():
+  devices = context.context().devices()
+  return len([d for d in devices if 'device:TPU:' in d]) > 1
+
+
+class MultiDeviceTest(xla_test.XLATestCase):
+  """Test running TPU computation on more than one core."""
+
+  def testBasic(self):
+    if not multiple_tpus():
+      self.skipTest('MultiDeviceTest requires multiple TPU devices.')
+
+    # Compute 10 on TPU core 0
+    with ops.device('device:TPU:0'):
+      two = constant_op.constant(2)
+      five = constant_op.constant(5)
+      ten = two * five
+      self.assertAllEqual(10, ten)
+
+    # Compute 6 on TPU core 1
+    with ops.device('device:TPU:1'):
+      two = constant_op.constant(2)
+      three = constant_op.constant(3)
+      six = two * three
+      self.assertAllEqual(6, six)
+
+    # Copy 10 and 6 to CPU and sum them
+    self.assertAllEqual(16, ten + six)
+
+
 if __name__ == '__main__':
   ops.enable_eager_execution(
       config=config_pb2.ConfigProto(log_device_placement=True))
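
For reference outside the test harness, here is a minimal standalone sketch of the nested-defun pattern that testNestedDefun exercises, run in plain eager mode rather than under test_scope(). It assumes a TF 1.x-era build where these internal modules exist; note the test itself skips on TPU, where nested defuns do not yet work:

# Minimal sketch of one defun calling another (assumes the TF 1.x-era
# internal APIs used in the diff above; not tied to any particular device).
from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops

ops.enable_eager_execution()

@function.defun
def times_two(x):
  return 2 * x

@function.defun
def two_x_plus_1(x):
  # One compiled function calling another.
  return times_two(x) + 1

print(two_x_plus_1(constant_op.constant([2, 3, 4])).numpy())  # [5 7 9]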
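
Similarly, the per-core placement pattern from MultiDeviceTest can be sketched standalone. Assumptions: the same TF 1.x-era internal modules, and a machine that actually exposes 'device:TPU:0' and 'device:TPU:1'; on other hardware the device strings would differ:

# Minimal sketch of placing eager ops on two TPU cores (device names
# are illustrative; the guard mirrors the multiple_tpus() helper above).
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops

ops.enable_eager_execution()

tpus = [d for d in context.context().devices() if 'device:TPU:' in d]
assert len(tpus) > 1, 'needs at least two visible TPU cores'

with ops.device('device:TPU:0'):
  ten = constant_op.constant(2) * constant_op.constant(5)  # on core 0

with ops.device('device:TPU:1'):
  six = constant_op.constant(2) * constant_op.constant(3)  # on core 1

# As in the test, both operands are copied to a common device (the CPU)
# before the final add.
print((ten + six).numpy())  # 16

The up-front device check matters because placing ops on 'device:TPU:1' can only succeed when a second core is visible to the runtime, which is exactly why the test skips itself via multiple_tpus().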