aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/platform
diff options
context:
space:
mode:
authorGravatar Eugene Brevdo <ebrevdo@google.com>2017-05-16 17:27:36 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-05-16 17:32:46 -0700
commitb0c21b654064f11d876392e94f0420eaaff841b3 (patch)
tree627c0f35d36030ff82481b7b0c3306ce76e8cb06 /tensorflow/python/platform
parentd726686d2dff34a11f85841dfc52bf45698ac49d (diff)
TensorArray now uses colocation instead of lazy device setters. Added test helper.
* Move a new test method making it easy to write local distributed unit tests into the tf.test namespace. * Update the rnn helper functions to use colocation in a few places as well. * Use it to test the new colocation behavior. PiperOrigin-RevId: 156254308
Diffstat (limited to 'tensorflow/python/platform')
-rw-r--r--tensorflow/python/platform/test.py5
1 files changed, 5 insertions, 0 deletions
diff --git a/tensorflow/python/platform/test.py b/tensorflow/python/platform/test.py
index 452b8f5d3b..5cb2c152b0 100644
--- a/tensorflow/python/platform/test.py
+++ b/tensorflow/python/platform/test.py
@@ -27,12 +27,15 @@ See the @{$python/test} guide.
@@gpu_device_name
@@compute_gradient
@@compute_gradient_error
+@@create_local_cluster
+
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+
# pylint: disable=g-bad-import-order
from tensorflow.python.client import device_lib as _device_lib
from tensorflow.python.framework import test_util as _test_util
@@ -41,6 +44,7 @@ from tensorflow.python.util.all_util import remove_undocumented
# pylint: disable=unused-import
from tensorflow.python.framework.test_util import assert_equal_graph_def
+from tensorflow.python.framework.test_util import create_local_cluster
from tensorflow.python.framework.test_util import TensorFlowTestCase as TestCase
from tensorflow.python.framework.test_util import gpu_device_name
@@ -108,6 +112,7 @@ def is_gpu_available(cuda_only=False):
return any((x.device_type == 'GPU' or x.device_type == 'SYCL')
for x in _device_lib.list_local_devices())
+
_allowed_symbols = [
# We piggy-back googletest documentation.
'Benchmark',