aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tests/eager_test.py
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-05-08 02:11:52 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-08 13:54:39 -0700
commita799cdbe78ca2c2e9c41f2b1bf8a3f57162fbcea (patch)
tree81351088d5f3e6d0ae2b2c4575923332416b5c4b /tensorflow/compiler/tests/eager_test.py
parent069f3124eedab44b4e884c3c64ba8d5eccadfe93 (diff)
Automated g4 rollback of changelist 195748721
PiperOrigin-RevId: 195790581
Diffstat (limited to 'tensorflow/compiler/tests/eager_test.py')
-rw-r--r--tensorflow/compiler/tests/eager_test.py112
1 files changed, 4 insertions, 108 deletions
diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py
index 5ab1585f8c..bdd0185dfe 100644
--- a/tensorflow/compiler/tests/eager_test.py
+++ b/tensorflow/compiler/tests/eager_test.py
@@ -24,16 +24,10 @@ from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
-from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
-from tensorflow.python.layers import convolutional
-from tensorflow.python.layers import pooling
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import init_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import googletest
@@ -49,7 +43,7 @@ class EagerTest(XLATestCase):
def testExecuteListOutputLen0(self):
with self.test_scope():
- empty = constant_op.constant([], dtype=dtypes.float32)
+ empty = constant_op.constant([], dtype=dtypes.int32)
result = array_ops.unstack(empty, 0)
self.assertTrue(isinstance(result, list))
self.assertEqual(0, len(result))
@@ -57,7 +51,7 @@ class EagerTest(XLATestCase):
def testExecuteListOutputLen1(self):
with self.test_scope():
split_dim = constant_op.constant(1)
- value = constant_op.constant([[0., 1., 2.], [3., 4., 5.]])
+ value = constant_op.constant([[0, 1, 2], [3, 4, 5]])
result = array_ops.split(value, 1, axis=split_dim)
self.assertTrue(isinstance(result, list))
self.assertEqual(1, len(result))
@@ -66,7 +60,7 @@ class EagerTest(XLATestCase):
def testExecuteListOutputLen3(self):
with self.test_scope():
split_dim = constant_op.constant(1)
- value = constant_op.constant([[0., 1., 2.], [3., 4., 5.]])
+ value = constant_op.constant([[0, 1, 2], [3, 4, 5]])
result = array_ops.split(value, 3, axis=split_dim)
self.assertTrue(isinstance(result, list))
self.assertEqual(3, len(result))
@@ -137,105 +131,7 @@ class EagerTest(XLATestCase):
self.assertEqual(2., grads[0][0].numpy())
-class EagerFunctionTest(XLATestCase):
-
- def testBasic(self):
- with self.test_scope():
- matmul = function.defun(math_ops.matmul, compiled=True)
- t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
- sq = matmul(t, t, transpose_a=True)
- self.assertAllEqual(sq.numpy().reshape(-1), [10, 14, 14, 20])
-
- def testConv(self):
- if 'GPU' in self.device:
- # TODO(b/32333178)
- self.skipTest('Current implementation of RandomStandardNormal kernel '
- 'is very slow on GPU, and has been blacklisted.')
- with self.test_scope():
- data_format = 'channels_last'
- conv = convolutional.Conv2D(
- filters=1, kernel_size=2, padding='VALID',
- data_format=data_format, activation=nn_ops.relu,
- kernel_initializer=init_ops.ones_initializer(),
- bias_initializer=init_ops.zeros_initializer())
- pool = pooling.MaxPooling2D(2, 2, data_format=data_format)
-
- def model(x):
- x = conv(x)
- return pool(x)
- model = function.defun(model, compiled=True)
-
- x = array_ops.ones([1, 4, 4, 1])
- y = model(x)
- self.assertAllEqual(y.numpy(), [[[[4.]]]])
-
- def testReadVariable(self):
- with self.test_scope():
- v = resource_variable_ops.ResourceVariable(1.0)
-
- @function.defun(compiled=True)
- def f():
- return v.read_value()
-
- var = f()
- self.assertEqual(1.0, var.numpy())
-
- def testUpdateVariable(self):
- with self.test_scope():
- v = resource_variable_ops.ResourceVariable(1.0)
-
- def f(v):
- v.assign_add(1.0)
- return v
-
- f = function.defun(f, compiled=True)
-
- var = f(v)
- self.assertEqual(2.0, var.numpy())
-
- def testAllArgumentKinds(self):
- """Test a complex function that takes different argument kinds.
-
- tf2xla machinery that translates, compiles, and runs defuns
- classifies arguments into: compile-time constants, regular tensors,
- and resources. This test creates a function with a mix of all these
- kinds. Moreover, the order of function arguments is intentionally mixed up.
-
- This also tests the case when the same argument is a compile-time constant
- as well as used in an operation that normally expects its inputs to be
- in device memory - addition in this case.
- """
- with self.test_scope():
- def foo(c1, r1, v1, c2, v2, r2):
- # c1 and c2 are compile-time constants
- # r1 and r2 are regular tensors
- # v1 and v2 are resource variables
- a = c1 + r1
- b = math_ops.cast(c2, dtypes.float32) + v2
- c = array_ops.slice(v1, c1, c2)
- d = r2 * v2
- return a, b, c, d
-
- foo = function.defun(foo, compiled=True)
-
- c1 = [0, 0]
- c2 = array_ops.ones([2], dtype=dtypes.int32)
-
- r1 = array_ops.ones([2])
- r2 = [[2., 2.], [3., 3.]]
-
- v1 = resource_variable_ops.ResourceVariable([[1., 2.], [3., 4.]])
- v2 = resource_variable_ops.ResourceVariable([[10., 20.], [30., 40.]])
-
- a, b, c, d = foo(c1, r1, v1, c2, v2, r2)
-
- self.assertAllEqual([1, 1], a.numpy())
- self.assertAllEqual([[11., 21.], [31., 41.]], b.numpy())
- self.assertAllEqual([[1.]], c.numpy())
- self.assertAllEqual([[20., 40.], [90., 120.]], d.numpy())
-
-
-if __name__ == '__main__':
+if __name__ == "__main__":
ops.enable_eager_execution(
config=config_pb2.ConfigProto(log_device_placement=True))
googletest.main()