path: root/tensorflow
author    Igor Ganichev <iga@google.com>    2018-07-12 13:36:06 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2018-07-12 13:39:30 -0700
commit    c61c7f1ea6a5e6aa0af19eb21d03b351031d944c (patch)
tree      288cb67e85946c4de0889c2dcce0c4842542ee99 /tensorflow
parent    0ef634190dc2e49e4002a841185fc850b80cc1b9 (diff)
Add a test to eagerly run ops on multiple TPU cores
PiperOrigin-RevId: 204354687
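
For context, a minimal sketch (not part of this commit) of the pattern the new MultiDeviceTest exercises: eagerly running one computation per TPU core, then combining the results on the host. It assumes a TF 1.9-era runtime with eager execution enabled and at least two visible TPU devices; the device strings mirror the ones used in the test.

# Hedged sketch, not part of the commit: run one op per TPU core eagerly,
# then combine the results on the host. Assumes eager execution and at
# least two visible TPU devices.
import tensorflow as tf

tf.enable_eager_execution()

with tf.device('device:TPU:0'):
  ten = tf.constant(2) * tf.constant(5)  # computed on core 0

with tf.device('device:TPU:1'):
  six = tf.constant(2) * tf.constant(3)  # computed on core 1

# The results are copied back to the host before being added.
print((ten + six).numpy())  # 16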
Diffstat (limited to 'tensorflow')
-rw-r--r--  tensorflow/compiler/tests/eager_test.py                               47
-rw-r--r--  tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py    24
-rw-r--r--  tensorflow/python/eager/function.py                                   29
3 files changed, 83 insertions, 17 deletions
diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py
index 8a3ed382a1..a8919d1afd 100644
--- a/tensorflow/compiler/tests/eager_test.py
+++ b/tensorflow/compiler/tests/eager_test.py
@@ -414,7 +414,7 @@ class EagerFunctionTest(xla_test.XLATestCase):
  def testSliceInDefun(self):
    with self.test_scope():
-      @function.defun(compiled=True)
+      @function.defun
      def f(x, y):
        return x[0::2, y:, ...]
@@ -429,6 +429,21 @@ class EagerFunctionTest(xla_test.XLATestCase):
      self.assertAllEqual(np.ones([1, 2, 4]), z.numpy())
      self.assertAllEqual((2, 3, 4), dz.shape.as_list())
+  def testNestedDefun(self):
+    with self.test_scope():
+
+      @function.defun
+      def times_two(x):
+        return 2 * x
+
+      @function.defun
+      def two_x_plus_1(x):
+        return times_two(x) + 1
+
+      x = constant_op.constant([2, 3, 4])
+      y = two_x_plus_1(x)
+      self.assertAllEqual([5, 7, 9], y.numpy())
+
class ExcessivePaddingTest(xla_test.XLATestCase):
  """Test that eager execution works with TPU flattened tensors.
@@ -481,6 +496,36 @@ class ExcessivePaddingTest(xla_test.XLATestCase):
      self.assertAllEqual(100 * [[36.0]], reduced)
+def multiple_tpus():
+  devices = context.context().devices()
+  return len([d for d in devices if 'device:TPU:' in d]) > 1
+
+
+class MultiDeviceTest(xla_test.XLATestCase):
+  """Test running TPU computation on more than one core."""
+
+  def testBasic(self):
+    if not multiple_tpus():
+      self.skipTest('MultiDeviceTest requires multiple TPU devices.')
+
+    # Compute 10 on TPU core 0
+    with ops.device('device:TPU:0'):
+      two = constant_op.constant(2)
+      five = constant_op.constant(5)
+      ten = two * five
+      self.assertAllEqual(10, ten)
+
+    # Compute 6 on TPU core 1
+    with ops.device('device:TPU:1'):
+      two = constant_op.constant(2)
+      three = constant_op.constant(3)
+      six = two * three
+      self.assertAllEqual(6, six)
+
+    # Copy 10 and 6 to CPU and sum them
+    self.assertAllEqual(16, ten + six)
+
+
if __name__ == '__main__':
  ops.enable_eager_execution(
      config=config_pb2.ConfigProto(log_device_placement=True))
diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py
index b14ef1df8f..07d8788882 100644
--- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py
+++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py
@@ -29,6 +29,7 @@ import tensorflow.contrib.eager as tfe
from tensorflow.contrib.eager.python.examples.resnet50 import resnet50
from tensorflow.contrib.summary import summary_test_util
from tensorflow.python.client import device_lib
+from tensorflow.python.eager import tape
def device_and_data_format():
@@ -49,13 +50,21 @@ def random_batch(batch_size, data_format):
  return images, one_hot
-def compute_gradients(model, images, labels):
-  with tf.GradientTape() as tape:
+def compute_gradients(model, images, labels, num_replicas=1):
+  with tf.GradientTape() as grad_tape:
    logits = model(images, training=True)
    loss = tf.losses.softmax_cross_entropy(
        logits=logits, onehot_labels=labels)
    tf.contrib.summary.scalar(name='loss', tensor=loss)
-  return tape.gradient(loss, model.variables)
+    if num_replicas != 1:
+      loss /= num_replicas
+
+  # TODO(b/110991947): We can mistakenly trace the gradient call in
+  # multi-threaded environment. Explicitly disable recording until
+  # this is fixed.
+  with tape.stop_recording():
+    grads = grad_tape.gradient(loss, model.variables)
+  return grads
def apply_gradients(model, optimizer, gradients):
@@ -188,11 +197,14 @@ class ResNet50Benchmarks(tf.test.Benchmark):
      return (32,)
    return (16, 32)
-  def _report(self, label, start, num_iters, device, batch_size, data_format):
+  def _report(self, label, start, num_iters, device, batch_size, data_format,
+              num_replicas=1):
    avg_time = (time.time() - start) / num_iters
    dev = tf.DeviceSpec.from_string(device).device_type.lower()
-    name = '%s_%s_batch_%d_%s' % (label, dev, batch_size, data_format)
-    extras = {'examples_per_sec': batch_size / avg_time}
+    replica_str = '' if num_replicas == 1 else 'replicas_%d_' % num_replicas
+    name = '%s_%s_batch_%d_%s%s' % (label, dev, batch_size,
+                                    replica_str, data_format)
+    extras = {'examples_per_sec': (num_replicas * batch_size) / avg_time}
    self.report_benchmark(
        iters=num_iters, wall_time=avg_time, name=name, extras=extras)
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index df83d673ad..29a3848bd8 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import collections
import functools
+import threading
import numpy as np
@@ -137,7 +138,7 @@ class CapturingGraph(ops.Graph):
      inputs[i] = self.capture(inp)
    return super(CapturingGraph, self).create_op(
        op_type, inputs, dtypes, input_types, name, attrs, op_def,
-        compute_shapes, compute_device)
+        compute_device=compute_device)
# pylint: disable=invalid-name
@@ -770,6 +771,11 @@ class _PolymorphicFunction(object):
  See the documentation for `defun` for more information on the semantics of
  defined functions.
+
+  _PolymorphicFunction class is thread-compatible meaning that minimal
+  usage of defuns (defining and calling) is thread-safe, but if users call other
+  methods or invoke the base `python_function` themselves, external
+  synchronization is necessary.
  """

  def __init__(self, python_function, name, compiled=False):
@@ -787,6 +793,8 @@
    self._arguments_to_functions = {}
    self._variables = []
+    self._lock = threading.Lock()
+
  def __get__(self, instance, owner):
    """Makes it possible to defun instance methods."""
    del owner
@@ -825,15 +833,16 @@ class _PolymorphicFunction(object):
    # signature so we don't improperly capture tensors such as variables.
    signature += tuple([context.executing_eagerly() or ops.get_default_graph()])
-    if signature not in self._arguments_to_functions:
-      graph_function = _trace_and_define_function(
-          self._name, self._python_function, self._compiled, args, kwds)
-      self._arguments_to_functions[signature] = graph_function
-      self._variables.extend(
-          [v for v in graph_function.variables if v not in self._variables])
-      return graph_function, inputs
-    else:
-      return self._arguments_to_functions[signature], inputs
+    with self._lock:
+      if signature not in self._arguments_to_functions:
+        graph_function = _trace_and_define_function(
+            self._name, self._python_function, self._compiled, args, kwds)
+        self._arguments_to_functions[signature] = graph_function
+        self._variables.extend(
+            [v for v in graph_function.variables if v not in self._variables])
+        return graph_function, inputs
+      else:
+        return self._arguments_to_functions[signature], inputs
  def __call__(self, *args, **kwds):
    """Calls a graph function specialized for this input signature."""