diff options
author | Eugene Brevdo <ebrevdo@google.com> | 2017-05-16 17:27:36 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-05-16 17:32:46 -0700 |
commit | b0c21b654064f11d876392e94f0420eaaff841b3 (patch) | |
tree | 627c0f35d36030ff82481b7b0c3306ce76e8cb06 /tensorflow/python/platform | |
parent | d726686d2dff34a11f85841dfc52bf45698ac49d (diff) |
TensorArray now uses colocation instead of lazy device setters. Added test helper.
* Move a new test method making it easy to write local distributed unit tests
into the tf.test namespace.
* Update the rnn helper functions to use colocation in a few places as well.
* Use it to test the new colocation behavior.
PiperOrigin-RevId: 156254308
Diffstat (limited to 'tensorflow/python/platform')
-rw-r--r-- | tensorflow/python/platform/test.py | 5 |
1 files changed, 5 insertions, 0 deletions
diff --git a/tensorflow/python/platform/test.py b/tensorflow/python/platform/test.py index 452b8f5d3b..5cb2c152b0 100644 --- a/tensorflow/python/platform/test.py +++ b/tensorflow/python/platform/test.py @@ -27,12 +27,15 @@ See the @{$python/test} guide. @@gpu_device_name @@compute_gradient @@compute_gradient_error +@@create_local_cluster + """ from __future__ import absolute_import from __future__ import division from __future__ import print_function + # pylint: disable=g-bad-import-order from tensorflow.python.client import device_lib as _device_lib from tensorflow.python.framework import test_util as _test_util @@ -41,6 +44,7 @@ from tensorflow.python.util.all_util import remove_undocumented # pylint: disable=unused-import from tensorflow.python.framework.test_util import assert_equal_graph_def +from tensorflow.python.framework.test_util import create_local_cluster from tensorflow.python.framework.test_util import TensorFlowTestCase as TestCase from tensorflow.python.framework.test_util import gpu_device_name @@ -108,6 +112,7 @@ def is_gpu_available(cuda_only=False): return any((x.device_type == 'GPU' or x.device_type == 'SYCL') for x in _device_lib.list_local_devices()) + _allowed_symbols = [ # We piggy-back googletest documentation. 'Benchmark', |