aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/eager
diff options
context:
space:
mode:
authorGravatar Akshay Modi <nareshmodi@google.com>2018-10-04 16:10:21 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-04 16:14:56 -0700
commitcf8e7cf89abb4a7783b9a99f17574ea128fa767a (patch)
tree52e733a0ec849c70356ed51675e8ac46916bbc18 /tensorflow/python/eager
parentd6a2e7bcca5683c377b592f177bcac9aeb1c550f (diff)
Pin ops with small integer inputs (already on the cpu) to the cpu in eager.
An environment variable (TF_EAGER_ENABLE_SMALL_TENSOR_CPU_PINNING) is provided to turn this off if necessary (its on by default). PiperOrigin-RevId: 215821915
Diffstat (limited to 'tensorflow/python/eager')
-rw-r--r--tensorflow/python/eager/core_test.py28
1 files changed, 28 insertions, 0 deletions
diff --git a/tensorflow/python/eager/core_test.py b/tensorflow/python/eager/core_test.py
index fb5442b646..e601aa376f 100644
--- a/tensorflow/python/eager/core_test.py
+++ b/tensorflow/python/eager/core_test.py
@@ -631,6 +631,34 @@ class TFETest(test_util.TensorFlowTestCase):
for t in tensors:
self.assertIsInstance(t, ops.EagerTensor)
+ def testSmallIntegerOpsForcedToCPU(self):
+ if not context.context().num_gpus():
+ self.skipTest('No GPUs found')
+
+ a = constant_op.constant((1, 2, 3, 4, 5), dtype=dtypes.int64)
+ b = constant_op.constant((2, 3, 4, 5, 6), dtype=dtypes.int64)
+ with context.device('gpu:0'):
+ c = a + b
+
+ # Op forced to CPU since all constants are integers and small.
+ self.assertEqual(c.device, '/job:localhost/replica:0/task:0/device:CPU:0')
+
+ a = array_ops.zeros((8, 10), dtype=dtypes.int64)
+ b = array_ops.ones((8, 10), dtype=dtypes.int64)
+
+ with context.device('gpu:0'):
+ c = a + b
+
+ # Op not forced to CPU since the tensors are larger than 64 elements.
+ self.assertEqual(c.device, '/job:localhost/replica:0/task:0/device:GPU:0')
+
+ a = constant_op.constant((1, 2, 3, 4, 5), dtype=dtypes.float32)
+ b = constant_op.constant((2, 3, 4, 5, 6), dtype=dtypes.float32)
+ with context.device('gpu:0'):
+ c = a + b
+
+ # Op not forced to CPU since the constants are not integers.
+ self.assertEqual(c.device, '/job:localhost/replica:0/task:0/device:GPU:0')
class SendRecvTest(test_util.TensorFlowTestCase):