aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tests/eager_test.py
diff options
context:
space:
mode:
authorGravatar Igor Ganichev <iga@google.com>2018-05-08 16:43:54 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-08 17:09:23 -0700
commit14d5f219f33b1ab8e0a67b84d97204d046adb91f (patch)
treeb887f04458ef204e522d2b3d81d15128104b397c /tensorflow/compiler/tests/eager_test.py
parent79b773a4395caf7f0b17ce9ac84a1f34dd277bb9 (diff)
Make eager functions runable on TPU
PiperOrigin-RevId: 195897321
Diffstat (limited to 'tensorflow/compiler/tests/eager_test.py')
-rw-r--r--tensorflow/compiler/tests/eager_test.py112
1 files changed, 108 insertions, 4 deletions
diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py
index bdd0185dfe..5ab1585f8c 100644
--- a/tensorflow/compiler/tests/eager_test.py
+++ b/tensorflow/compiler/tests/eager_test.py
@@ -24,10 +24,16 @@ from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
+from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.layers import convolutional
+from tensorflow.python.layers import pooling
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import googletest
@@ -43,7 +49,7 @@ class EagerTest(XLATestCase):
def testExecuteListOutputLen0(self):
with self.test_scope():
- empty = constant_op.constant([], dtype=dtypes.int32)
+ empty = constant_op.constant([], dtype=dtypes.float32)
result = array_ops.unstack(empty, 0)
self.assertTrue(isinstance(result, list))
self.assertEqual(0, len(result))
@@ -51,7 +57,7 @@ class EagerTest(XLATestCase):
def testExecuteListOutputLen1(self):
with self.test_scope():
split_dim = constant_op.constant(1)
- value = constant_op.constant([[0, 1, 2], [3, 4, 5]])
+ value = constant_op.constant([[0., 1., 2.], [3., 4., 5.]])
result = array_ops.split(value, 1, axis=split_dim)
self.assertTrue(isinstance(result, list))
self.assertEqual(1, len(result))
@@ -60,7 +66,7 @@ class EagerTest(XLATestCase):
def testExecuteListOutputLen3(self):
with self.test_scope():
split_dim = constant_op.constant(1)
- value = constant_op.constant([[0, 1, 2], [3, 4, 5]])
+ value = constant_op.constant([[0., 1., 2.], [3., 4., 5.]])
result = array_ops.split(value, 3, axis=split_dim)
self.assertTrue(isinstance(result, list))
self.assertEqual(3, len(result))
@@ -131,7 +137,105 @@ class EagerTest(XLATestCase):
self.assertEqual(2., grads[0][0].numpy())
-if __name__ == "__main__":
+class EagerFunctionTest(XLATestCase):
+
+ def testBasic(self):
+ with self.test_scope():
+ matmul = function.defun(math_ops.matmul, compiled=True)
+ t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
+ sq = matmul(t, t, transpose_a=True)
+ self.assertAllEqual(sq.numpy().reshape(-1), [10, 14, 14, 20])
+
+ def testConv(self):
+ if 'GPU' in self.device:
+ # TODO(b/32333178)
+ self.skipTest('Current implementation of RandomStandardNormal kernel '
+ 'is very slow on GPU, and has been blacklisted.')
+ with self.test_scope():
+ data_format = 'channels_last'
+ conv = convolutional.Conv2D(
+ filters=1, kernel_size=2, padding='VALID',
+ data_format=data_format, activation=nn_ops.relu,
+ kernel_initializer=init_ops.ones_initializer(),
+ bias_initializer=init_ops.zeros_initializer())
+ pool = pooling.MaxPooling2D(2, 2, data_format=data_format)
+
+ def model(x):
+ x = conv(x)
+ return pool(x)
+ model = function.defun(model, compiled=True)
+
+ x = array_ops.ones([1, 4, 4, 1])
+ y = model(x)
+ self.assertAllEqual(y.numpy(), [[[[4.]]]])
+
+ def testReadVariable(self):
+ with self.test_scope():
+ v = resource_variable_ops.ResourceVariable(1.0)
+
+ @function.defun(compiled=True)
+ def f():
+ return v.read_value()
+
+ var = f()
+ self.assertEqual(1.0, var.numpy())
+
+ def testUpdateVariable(self):
+ with self.test_scope():
+ v = resource_variable_ops.ResourceVariable(1.0)
+
+ def f(v):
+ v.assign_add(1.0)
+ return v
+
+ f = function.defun(f, compiled=True)
+
+ var = f(v)
+ self.assertEqual(2.0, var.numpy())
+
+ def testAllArgumentKinds(self):
+ """Test a complex function that takes different argument kinds.
+
+ tf2xla machinery that translates, compiles, and runs defuns
+ classifies arguments into: compile-time constants, regular tensors,
+ and resources. This test creates a function with a mix of all these
+ kinds. Moreover, the order of function arguments is intentionally mixed up.
+
+ This also tests the case when the same argument is a compile-time constant
+ as well as used in an operation that normally expects its inputs to be
+ in device memory - addition in this case.
+ """
+ with self.test_scope():
+ def foo(c1, r1, v1, c2, v2, r2):
+ # c1 and c2 are compile-time constants
+ # r1 and r2 are regular tensors
+ # v1 and v2 are resource variables
+ a = c1 + r1
+ b = math_ops.cast(c2, dtypes.float32) + v2
+ c = array_ops.slice(v1, c1, c2)
+ d = r2 * v2
+ return a, b, c, d
+
+ foo = function.defun(foo, compiled=True)
+
+ c1 = [0, 0]
+ c2 = array_ops.ones([2], dtype=dtypes.int32)
+
+ r1 = array_ops.ones([2])
+ r2 = [[2., 2.], [3., 3.]]
+
+ v1 = resource_variable_ops.ResourceVariable([[1., 2.], [3., 4.]])
+ v2 = resource_variable_ops.ResourceVariable([[10., 20.], [30., 40.]])
+
+ a, b, c, d = foo(c1, r1, v1, c2, v2, r2)
+
+ self.assertAllEqual([1, 1], a.numpy())
+ self.assertAllEqual([[11., 21.], [31., 41.]], b.numpy())
+ self.assertAllEqual([[1.]], c.numpy())
+ self.assertAllEqual([[20., 40.], [90., 120.]], d.numpy())
+
+
+if __name__ == '__main__':
ops.enable_eager_execution(
config=config_pb2.ConfigProto(log_device_placement=True))
googletest.main()