-rw-r--r--  tensorflow/python/BUILD                                              8
-rw-r--r--  tensorflow/python/framework/test_util.py                            56
-rw-r--r--  tensorflow/python/kernel_tests/BUILD                                 3
-rw-r--r--  tensorflow/python/kernel_tests/tensor_array_ops_test.py            117
-rw-r--r--  tensorflow/python/ops/rnn.py                                        19
-rw-r--r--  tensorflow/python/ops/tensor_array_ops.py                          136
-rw-r--r--  tensorflow/python/platform/test.py                                   5
-rw-r--r--  tensorflow/python/training/localhost_cluster_performance_test.py   32
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.test.pbtxt                    4
-rw-r--r--  tensorflow/tools/pip_package/BUILD                                   1
10 files changed, 214 insertions, 167 deletions
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 648745e931..2f183413dc 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -588,6 +588,7 @@ py_library(
":platform",
":platform_test",
":pywrap_tensorflow",
+ ":training",
":util",
"//third_party/py/numpy",
"@six_archive//:six",
@@ -595,6 +596,12 @@ py_library(
)
py_library(
+ name = "distributed_framework_test_lib",
+ srcs_version = "PY2AND3",
+ deps = [":framework_test_lib"],
+)
+
+py_library(
name = "client_testlib",
srcs = ["platform/test.py"],
srcs_version = "PY2AND3",
@@ -2829,6 +2836,7 @@ cuda_py_test(
additional_deps = [
":client",
":client_testlib",
+ ":distributed_framework_test_lib",
":framework_for_generated_wrappers",
":partitioned_variables",
":training",
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index ac551a6e1a..194d608a88 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -27,6 +27,7 @@ import tempfile
import threading
import numpy as np
+import portpicker
import six
from google.protobuf import descriptor_pool
@@ -45,6 +46,7 @@ from tensorflow.python.framework import versions
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import googletest
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import server_lib
from tensorflow.python.util import compat
from tensorflow.python.util.protobuf import compare
@@ -776,3 +778,57 @@ class TensorFlowTestCase(googletest.TestCase):
assertItemsEqual = googletest.TestCase.assertCountEqual
# pylint: enable=invalid-name
+
+
+def create_local_cluster(num_workers, num_ps, protocol="grpc"):
+ """Create and start local servers and return the associated `Server` objects.
+
+ Example:
+ ```python
+ workers, _ = tf.test.create_local_cluster(num_workers=2, num_ps=2)
+
+ worker_sessions = [tf.Session(w.target) for w in workers]
+
+ with tf.device("/job:ps/task:0"):
+ ...
+ with tf.device("/job:ps/task:1"):
+ ...
+ with tf.device("/job:worker/task:0"):
+ ...
+ with tf.device("/job:worker/task:1"):
+ ...
+
+ worker_sessions[0].run(...)
+ ```
+
+ Args:
+ num_workers: Number of worker servers to start.
+ num_ps: Number of PS servers to start.
+ protocol: Communication protocol. Allowed values are documented in
+ `tf.train.Server`.
+
+ Returns:
+ A tuple `(worker_servers, ps_servers)`. `worker_servers` is a list of
+ `num_workers` objects of type `tf.train.Server` (all running locally),
+ and `ps_servers` is a list of `num_ps` servers of the same type.
+ """
+ worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)]
+ ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)]
+ cluster_dict = {
+ "worker": ["localhost:%s" % port for port in worker_ports],
+ "ps": ["localhost:%s" % port for port in ps_ports]
+ }
+ cs = server_lib.ClusterSpec(cluster_dict)
+
+ workers = [
+ server_lib.Server(
+ cs, job_name="worker", protocol=protocol, task_index=ix, start=True)
+ for ix in range(num_workers)
+ ]
+ ps_servers = [
+ server_lib.Server(
+ cs, job_name="ps", protocol=protocol, task_index=ix, start=True)
+ for ix in range(num_ps)
+ ]
+
+ return workers, ps_servers
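
For reference, a minimal end-to-end sketch of how the new helper composes with `tf.device` and a worker session, mirroring the pattern used in `localhost_cluster_performance_test.py` below. It assumes a TensorFlow 1.x build that includes this change (so `tf.test.create_local_cluster` is exported); it is an illustration, not part of the commit.

```python
# Hedged usage sketch: place a variable on the PS task, update it from a
# worker, and run against the first worker's in-process gRPC server.
import tensorflow as tf

workers, _ = tf.test.create_local_cluster(num_workers=2, num_ps=1)

with tf.device("/job:ps/task:0"):
  var0 = tf.Variable(0.0)
with tf.device("/job:worker/task:0"):
  update = tf.assign_add(var0, 1.0)

with tf.Session(workers[0].target) as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(update))  # 1.0, computed on worker 0 against the PS variable
```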
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index 5d9534a206..ab69b33cb1 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -1114,6 +1114,7 @@ cuda_py_test(
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:data_flow_ops",
"//tensorflow/python:data_flow_ops_gen",
+ "//tensorflow/python:distributed_framework_test_lib",
"//tensorflow/python:errors",
"//tensorflow/python:framework",
"//tensorflow/python:framework_for_generated_wrappers",
@@ -1894,10 +1895,12 @@ cuda_py_test(
"//tensorflow/python:client_testlib",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:data_flow_ops_gen",
+ "//tensorflow/python:distributed_framework_test_lib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:gradients",
"//tensorflow/python:math_ops",
"//tensorflow/python:nn_grad",
+ "//tensorflow/python:training",
"//tensorflow/python:tensor_array_grad",
"//tensorflow/python:tensor_array_ops",
"//tensorflow/python:variables",
diff --git a/tensorflow/python/kernel_tests/tensor_array_ops_test.py b/tensorflow/python/kernel_tests/tensor_array_ops_test.py
index 41fe29e006..5b0f318efe 100644
--- a/tensorflow/python/kernel_tests/tensor_array_ops_test.py
+++ b/tensorflow/python/kernel_tests/tensor_array_ops_test.py
@@ -20,6 +20,8 @@ from __future__ import print_function
import numpy as np
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.client import session as session_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -1194,82 +1196,83 @@ class TensorArrayTest(test.TestCase):
self.assertAllEqual(expected_grad, grad_vals[0])
def testTensorArrayGetsDeviceFromFirstWrite(self):
- with ops.device("/gpu:1"):
+ with ops.device("/job:worker/task:0/cpu:0"):
+ # this initial device will be ignored.
ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=2)
- # parent device was ignored when creating the TensorArray
- self.assertEqual(ta.handle.device, "")
- self.assertEqual(ta.flow.device, "")
- with ops.device("/gpu:0"):
- # the first write sets the op's device
+ with ops.device("/job:worker/task:1/cpu:0"):
+ # the first write sets the op's device.
ta = ta.write(0, 1.0)
- self.assertTrue("gpu:0" in ta.handle.device.lower())
- self.assertTrue("gpu:0" in ta.flow.device.lower())
- with ops.device("/gpu:1"):
- # subsequent writes do not modify the op's device
+ with ops.device("/job:worker/task:2/cpu:0"):
+ # subsequent writes do not modify the op's device.
ta = ta.write(1, 1.0)
- self.assertTrue("gpu:0" in ta.handle.device.lower())
- self.assertTrue("gpu:0" in ta.flow.device.lower())
+ # The gradient TA will sit on the same device as the forward TA.
ta_grad = ta.grad("grad")
- self.assertTrue("gpu:0" in ta_grad.handle.device.lower())
- self.assertTrue("gpu:0" in ta_grad.flow.device.lower())
+ flows = [ta.flow, ta_grad.flow]
# Similar tests for unpack and split
- ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=2)
- self.assertEqual(ta.handle.device, "")
- self.assertEqual(ta.flow.device, "")
- with ops.device("/gpu:0"):
+ with ops.device("/job:worker/task:0/cpu:0"):
+ ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=3)
+ with ops.device("/job:worker/task:1/cpu:0"):
ta = ta.unstack([1.0, 2.0])
- self.assertTrue("gpu:0" in ta.handle.device.lower())
- self.assertTrue("gpu:0" in ta.flow.device.lower())
- with ops.device("/gpu:1"):
- ta = ta.unstack([1.0, 2.0])
- self.assertTrue("gpu:0" in ta.handle.device.lower())
- self.assertTrue("gpu:0" in ta.flow.device.lower())
+ with ops.device("/job:worker/task:2/cpu:0"):
+ ta = ta.write(2, 3.0)
+ flows.append(ta.flow)
- ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=2)
- self.assertEqual(ta.handle.device, "")
- self.assertEqual(ta.flow.device, "")
- with ops.device("/gpu:0"):
- ta = ta.split([1.0, 2.0], [1, 1])
- self.assertTrue("gpu:0" in ta.handle.device.lower())
- self.assertTrue("gpu:0" in ta.flow.device.lower())
- with ops.device("/gpu:1"):
+ with ops.device("/job:worker/task:0/cpu:0"):
+ ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=2)
+ with ops.device("/job:worker/task:1/cpu:0"):
ta = ta.split([1.0, 2.0], [1, 1])
- self.assertTrue("gpu:0" in ta.handle.device.lower())
- self.assertTrue("gpu:0" in ta.flow.device.lower())
+ flows.append(ta.flow)
+
+ workers, _ = test.create_local_cluster(num_workers=3, num_ps=0)
+ session = session_lib.Session(workers[0].target)
+
+ run_options = config_pb2.RunOptions(
+ trace_level=config_pb2.RunOptions.FULL_TRACE)
+ run_metadata = config_pb2.RunMetadata()
+
+ session.run(flows, options=run_options, run_metadata=run_metadata)
+ self.assertTrue(run_metadata.HasField("step_stats"))
+ dev_stats = {d.device: d.node_stats
+ for d in run_metadata.step_stats.dev_stats}
+ for d in dev_stats:
+ if "/task:1/" in d:
+ self.assertTrue(
+ [s for s in dev_stats[d] if "/TensorArray" in s.node_name])
+ else:
+ self.assertFalse(
+ [s for s in dev_stats[d] if "/TensorArray" in s.node_name])
def testTensorArrayGetsDeviceFromFirstWriteInWhileLoop(self):
- ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=2)
+ with ops.device("/job:worker/task:0/cpu:0"):
+ ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=2)
def _body(i, ta_i):
- with ops.device("/gpu:0"):
+ with ops.device("/job:worker/task:1/cpu:0"):
return i + 1, ta_i.write(i, 0.0)
- self.assertEqual(ta.handle.device, "")
- self.assertEqual(ta.flow.device, "")
-
_, ta_out = control_flow_ops.while_loop(
lambda i, ta: i < 2, _body, loop_vars=[0, ta])
- self.assertTrue("gpu:0" in ta_out.handle.device.lower())
- self.assertTrue("gpu:0" in ta.handle.device.lower())
-
- def testTensorArrayLazyDeviceSettingDoesNotConfuseInitialAccess(self):
- with self.test_session(use_gpu=True) as session:
- ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=2)
- self.assertEqual(ta.handle.device, "")
-
- with ops.device("/cpu:0"):
- size = ta.size()
- with ops.device("/gpu:0"):
- ta = ta.write(0, 0.0)
-
- self.assertTrue("gpu:0" in ta.handle.device.lower())
-
- # This should use the TensorArray on /gpu:0
- size_value, _ = session.run((size, ta.flow))
- self.assertEqual(2, size_value)
+ workers, _ = test.create_local_cluster(num_workers=3, num_ps=0)
+ session = session_lib.Session(workers[0].target)
+
+ run_options = config_pb2.RunOptions(
+ trace_level=config_pb2.RunOptions.FULL_TRACE)
+ run_metadata = config_pb2.RunMetadata()
+
+ session.run(ta_out.flow, options=run_options, run_metadata=run_metadata)
+ self.assertTrue(run_metadata.HasField("step_stats"))
+ dev_stats = {d.device: d.node_stats
+ for d in run_metadata.step_stats.dev_stats}
+ for d in dev_stats:
+ if "/task:1/" in d:
+ self.assertTrue(
+ [s for s in dev_stats[d] if "/TensorArray" in s.node_name])
+ else:
+ self.assertFalse(
+ [s for s in dev_stats[d] if "/TensorArray" in s.node_name])
def testTensorArrayIdentity(self):
with self.test_session(use_gpu=True) as session:
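
The two rewritten tests above share one verification idiom: build the graph with different device scopes, trace a run with `FULL_TRACE`, and assert that the TensorArray kernels executed only on the task that issued the first write. A condensed, hedged sketch of that idiom, using the same internal imports as the test file and assuming this commit's `create_local_cluster` is available:

```python
# Sketch of the placement check: step_stats records, per device, every node
# that actually executed there, so the TensorArray nodes should appear only
# under /job:worker/task:1 (the device of the first write).
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session as session_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.platform import test

with ops.device("/job:worker/task:0/cpu:0"):
  ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=1)
with ops.device("/job:worker/task:1/cpu:0"):
  ta = ta.write(0, 1.0)  # the first write decides where the TensorArray lives

workers, _ = test.create_local_cluster(num_workers=2, num_ps=0)
sess = session_lib.Session(workers[0].target)

run_options = config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE)
run_metadata = config_pb2.RunMetadata()
sess.run(ta.flow, options=run_options, run_metadata=run_metadata)

for dev in run_metadata.step_stats.dev_stats:
  ta_nodes = [s.node_name for s in dev.node_stats if "TensorArray" in s.node_name]
  print(dev.device, ta_nodes)
```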
diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py
index 2aa288e36a..a6fba046da 100644
--- a/tensorflow/python/ops/rnn.py
+++ b/tensorflow/python/ops/rnn.py
@@ -98,15 +98,6 @@ def _infer_state_dtype(explicit_dtype, state):
return state.dtype
-def _on_device(fn, device):
- """Build the subgraph defined by lambda `fn` on `device` if it's not None."""
- if device:
- with ops.device(device):
- return fn()
- else:
- return fn()
-
-
# pylint: disable=unused-argument
def _rnn_step(
time, sequence_length, min_sequence_length, max_sequence_length,
@@ -168,9 +159,8 @@ def _rnn_step(
def _copy_one_through(output, new_output):
copy_cond = (time >= sequence_length)
- return _on_device(
- lambda: array_ops.where(copy_cond, output, new_output),
- device=new_output.op.device)
+ with ops.colocate_with(new_output):
+ return array_ops.where(copy_cond, output, new_output)
def _copy_some_through(flat_new_output, flat_new_state):
# Use broadcasting select to determine which values should get
@@ -1020,9 +1010,8 @@ def raw_rnn(cell, loop_fn,
def _copy_some_through(current, candidate):
"""Copy some tensors through via array_ops.where."""
def copy_fn(cur_i, cand_i):
- return _on_device(
- lambda: array_ops.where(elements_finished, cur_i, cand_i),
- device=cand_i.op.device)
+ with ops.colocate_with(cand_i):
+ return array_ops.where(elements_finished, cur_i, cand_i)
return nest.map_structure(copy_fn, current, candidate)
emit_output = _copy_some_through(zero_emit, emit_output)
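
The replacement for `_on_device` relies on colocation rather than a copied device string: instead of hard-coding whatever device `new_output` happened to resolve to (which is often empty before placement), the `where` op is simply colocated with it. A hedged before/after sketch, with `copy_cond`, `output`, and `new_output` standing in for the tensors involved:

```python
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops

def copy_one_through_old(copy_cond, output, new_output):
  # Old pattern: read the resolved device string off the op and, if set,
  # open an explicit device scope; an empty string silently falls through.
  device = new_output.op.device
  if device:
    with ops.device(device):
      return array_ops.where(copy_cond, output, new_output)
  return array_ops.where(copy_cond, output, new_output)

def copy_one_through_new(copy_cond, output, new_output):
  # New pattern: ask the placer to put the select on the same device as
  # `new_output`, whatever device that tensor eventually lands on.
  with ops.colocate_with(new_output):
    return array_ops.where(copy_cond, output, new_output)
```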
diff --git a/tensorflow/python/ops/tensor_array_ops.py b/tensorflow/python/ops/tensor_array_ops.py
index b1c7d74a0c..8b119f5842 100644
--- a/tensorflow/python/ops/tensor_array_ops.py
+++ b/tensorflow/python/ops/tensor_array_ops.py
@@ -22,6 +22,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import contextlib
+
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
@@ -31,24 +33,6 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.util import tf_should_use
-def _maybe_set_device(handle_op, value_t):
- # NOTE(ebrevdo): Do not try this at home, kids
- # _______________________________________________
- # | I WILL NOT ACCESS PRIVATE METHODS ^^^^^^^^\ |
- # | I WILL NOT ACCESS PRIVATE METHODS | | |
- # | I WILL NOT ACCESS PRIVATE METHODS |_ __ | |
- # | I WILL NOT ACCESS PRIVATE METHODS (.(. ) | |
- # | I WILL NOT ACCESS PRIVATE (_ ) |
- # | \\ /___/' / |
- # | _\\_ \ | |
- # | (( ) /====| |
- # | \ <.__._- \ |
- # |___________________________ <//___. ||
- #
- if not handle_op.device and value_t.device:
- handle_op._set_device(value_t.device) # pylint: disable=protected-access
-
-
# TensorArray object accesses many of the hidden generated ops, but is
# in fact built to wrap these methods.
# pylint: disable=protected-access
@@ -132,6 +116,12 @@ class TensorArray(object):
dynamic_size = dynamic_size or False
self._dtype = dtype
+
+ # Used to keep track of what tensors the TensorArray should be
+ # colocated with. We choose to colocate the TensorArray with the
+ # first tensor written to it.
+ self._colocate_with = []
+
# Record the current static shape for the array elements. The element
# shape is defined either by `element_shape` or the shape of the tensor
# of the first write. If `infer_shape` is true, all writes check for
@@ -197,6 +187,24 @@ class TensorArray(object):
else:
self._element_shape.append(shape)
+ @contextlib.contextmanager
+ def _maybe_colocate_with(self, value):
+ """Colocate operations with an internal colocation group or `value`.
+
+ Args:
+ value: `Tensor`, the tensor to try to colocate with.
+
+ Yields:
+ Does not yield anything, but the new context is a colocation context.
+
+ If no internal colocation group is set, colocate with `value` and set
+ the internal colocation group to be `value`.
+ """
+ if not self._colocate_with:
+ self._colocate_with.append(value)
+ with ops.colocate_with(self._colocate_with[0]):
+ yield
+
def identity(self):
"""Returns a TensorArray with the same content and properties.
@@ -209,6 +217,7 @@ class TensorArray(object):
ta = TensorArray(dtype=self._dtype, handle=self._handle, flow=flow,
infer_shape=self._infer_shape)
ta._element_shape = self._element_shape
+ ta._colocate_with = self._colocate_with
return ta
def grad(self, source, flow=None, name=None):
@@ -242,16 +251,15 @@ class TensorArray(object):
Returns:
The tensor at index `index`.
"""
- with ops.colocate_with(self._handle):
- value = gen_data_flow_ops._tensor_array_read_v3(
- handle=self._handle,
- index=index,
- flow_in=self._flow,
- dtype=self._dtype,
- name=name)
- if self._element_shape:
- value.set_shape(self._element_shape[0].dims)
- return value
+ value = gen_data_flow_ops._tensor_array_read_v3(
+ handle=self._handle,
+ index=index,
+ flow_in=self._flow,
+ dtype=self._dtype,
+ name=name)
+ if self._element_shape:
+ value.set_shape(self._element_shape[0].dims)
+ return value
@tf_should_use.should_use_result
def write(self, index, value, name=None):
@@ -271,8 +279,7 @@ class TensorArray(object):
"""
with ops.name_scope(name, "TensorArrayWrite", [self._handle, index, value]):
value = ops.convert_to_tensor(value, name="value")
- _maybe_set_device(self._handle.op, value)
- with ops.colocate_with(self._handle):
+ with self._maybe_colocate_with(value):
flow_out = gen_data_flow_ops._tensor_array_write_v3(
handle=self._handle,
index=index,
@@ -282,6 +289,7 @@ class TensorArray(object):
ta = TensorArray(dtype=self._dtype, handle=self._handle, flow=flow_out)
ta._infer_shape = self._infer_shape
ta._element_shape = self._element_shape
+ ta._colocate_with = self._colocate_with
if ta._infer_shape:
ta._merge_element_shape(value.get_shape())
return ta
@@ -316,21 +324,20 @@ class TensorArray(object):
Returns:
The tensors in the `TensorArray` selected by `indices`, packed into one tensor.
"""
- with ops.colocate_with(self._handle):
- if self._element_shape:
- element_shape = self._element_shape[0]
- else:
- element_shape = tensor_shape.TensorShape(None)
- value = gen_data_flow_ops._tensor_array_gather_v3(
- handle=self._handle,
- indices=indices,
- flow_in=self._flow,
- dtype=self._dtype,
- name=name,
- element_shape=element_shape)
- if self._element_shape and self._element_shape[0].dims is not None:
- value.set_shape([None] + self._element_shape[0].dims)
- return value
+ if self._element_shape:
+ element_shape = self._element_shape[0]
+ else:
+ element_shape = tensor_shape.TensorShape(None)
+ value = gen_data_flow_ops._tensor_array_gather_v3(
+ handle=self._handle,
+ indices=indices,
+ flow_in=self._flow,
+ dtype=self._dtype,
+ name=name,
+ element_shape=element_shape)
+ if self._element_shape and self._element_shape[0].dims is not None:
+ value.set_shape([None] + self._element_shape[0].dims)
+ return value
def concat(self, name=None):
"""Return the values in the TensorArray as a concatenated `Tensor`.
@@ -349,16 +356,15 @@ class TensorArray(object):
tensor_shape.TensorShape(self._element_shape[0].dims[1:]))
else:
element_shape_except0 = tensor_shape.TensorShape(None)
- with ops.colocate_with(self._handle):
- value, _ = gen_data_flow_ops._tensor_array_concat_v3(
- handle=self._handle,
- flow_in=self._flow,
- dtype=self._dtype,
- name=name,
- element_shape_except0=element_shape_except0)
- if self._element_shape and self._element_shape[0].dims is not None:
- value.set_shape([None] + self._element_shape[0].dims[1:])
- return value
+ value, _ = gen_data_flow_ops._tensor_array_concat_v3(
+ handle=self._handle,
+ flow_in=self._flow,
+ dtype=self._dtype,
+ name=name,
+ element_shape_except0=element_shape_except0)
+ if self._element_shape and self._element_shape[0].dims is not None:
+ value.set_shape([None] + self._element_shape[0].dims[1:])
+ return value
@tf_should_use.should_use_result
def unstack(self, value, name=None):
@@ -403,8 +409,7 @@ class TensorArray(object):
with ops.name_scope(name, "TensorArrayScatter",
[self._handle, value, indices]):
value = ops.convert_to_tensor(value, name="value")
- _maybe_set_device(self._handle.op, value)
- with ops.colocate_with(self._handle):
+ with self._maybe_colocate_with(value):
flow_out = gen_data_flow_ops._tensor_array_scatter_v3(
handle=self._handle,
indices=indices,
@@ -414,6 +419,7 @@ class TensorArray(object):
ta = TensorArray(dtype=self._dtype, handle=self._handle, flow=flow_out)
ta._infer_shape = self._infer_shape
ta._element_shape = self._element_shape
+ ta._colocate_with = self._colocate_with
if ta._infer_shape:
val_shape = flow_out.op.inputs[2].get_shape()
element_shape = tensor_shape.unknown_shape()
@@ -442,9 +448,8 @@ class TensorArray(object):
with ops.name_scope(name, "TensorArraySplit",
[self._handle, value, lengths]):
value = ops.convert_to_tensor(value, name="value")
- _maybe_set_device(self._handle.op, value)
- lengths_64 = math_ops.to_int64(lengths)
- with ops.colocate_with(self._handle):
+ with self._maybe_colocate_with(value):
+ lengths_64 = math_ops.to_int64(lengths)
flow_out = gen_data_flow_ops._tensor_array_split_v3(
handle=self._handle,
value=value,
@@ -454,6 +459,7 @@ class TensorArray(object):
ta = TensorArray(dtype=self._dtype, handle=self._handle, flow=flow_out)
ta._infer_shape = self._infer_shape
ta._element_shape = self._element_shape
+ ta._colocate_with = self._colocate_with
if ta._infer_shape:
val_shape = flow_out.op.inputs[1].get_shape()
clengths = tensor_util.constant_value(flow_out.op.inputs[2])
@@ -467,15 +473,13 @@ class TensorArray(object):
def size(self, name=None):
"""Return the size of the TensorArray."""
- with ops.colocate_with(self._handle):
- return gen_data_flow_ops._tensor_array_size_v3(
- handle=self._handle, flow_in=self.flow, name=name)
+ return gen_data_flow_ops._tensor_array_size_v3(
+ handle=self._handle, flow_in=self.flow, name=name)
@tf_should_use.should_use_result
def close(self, name=None):
"""Close the current TensorArray."""
- with ops.colocate_with(self._handle):
- return gen_data_flow_ops._tensor_array_close_v3(
- handle=self._handle, name=name)
+ return gen_data_flow_ops._tensor_array_close_v3(
+ handle=self._handle, name=name)
# pylint: enable=protected-access
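
The colocation bookkeeping added to `TensorArray` boils down to a small, lazily initialized context manager: the first written value defines the colocation group, and every later write/scatter/split joins it. Because the same list object is handed to each TensorArray returned by `write`/`scatter`/`split`, the whole flow-chained family shares one group. A standalone, illustrative sketch of the pattern (the class name here is hypothetical, not TensorFlow API):

```python
import contextlib
from tensorflow.python.framework import ops

class _LazilyColocated(object):
  """Illustrative stand-in for the TensorArray colocation bookkeeping."""

  def __init__(self):
    # Filled in by the first write; later ops colocate with this tensor.
    # Sharing this list by reference keeps derived objects in the same group.
    self._colocate_with = []

  @contextlib.contextmanager
  def _maybe_colocate_with(self, value):
    # The first caller establishes the colocation group; everyone else joins it.
    if not self._colocate_with:
      self._colocate_with.append(value)
    with ops.colocate_with(self._colocate_with[0]):
      yield
```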
diff --git a/tensorflow/python/platform/test.py b/tensorflow/python/platform/test.py
index 452b8f5d3b..5cb2c152b0 100644
--- a/tensorflow/python/platform/test.py
+++ b/tensorflow/python/platform/test.py
@@ -27,12 +27,15 @@ See the @{$python/test} guide.
@@gpu_device_name
@@compute_gradient
@@compute_gradient_error
+@@create_local_cluster
+
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+
# pylint: disable=g-bad-import-order
from tensorflow.python.client import device_lib as _device_lib
from tensorflow.python.framework import test_util as _test_util
@@ -41,6 +44,7 @@ from tensorflow.python.util.all_util import remove_undocumented
# pylint: disable=unused-import
from tensorflow.python.framework.test_util import assert_equal_graph_def
+from tensorflow.python.framework.test_util import create_local_cluster
from tensorflow.python.framework.test_util import TensorFlowTestCase as TestCase
from tensorflow.python.framework.test_util import gpu_device_name
@@ -108,6 +112,7 @@ def is_gpu_available(cuda_only=False):
return any((x.device_type == 'GPU' or x.device_type == 'SYCL')
for x in _device_lib.list_local_devices())
+
_allowed_symbols = [
# We piggy-back googletest documentation.
'Benchmark',
diff --git a/tensorflow/python/training/localhost_cluster_performance_test.py b/tensorflow/python/training/localhost_cluster_performance_test.py
index 9de681837d..7c097b943d 100644
--- a/tensorflow/python/training/localhost_cluster_performance_test.py
+++ b/tensorflow/python/training/localhost_cluster_performance_test.py
@@ -21,7 +21,6 @@ from __future__ import print_function
import time
import numpy as np
-import portpicker
from tensorflow.python.client import session as session_lib
from tensorflow.python.framework import dtypes
@@ -31,37 +30,12 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import device_setter
-from tensorflow.python.training import server_lib
-
-
-def create_local_cluster(num_workers, num_ps, protocol="grpc"):
- """Create local GRPC servers and return their servers."""
- worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)]
- ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)]
- cluster_dict = {
- "worker": ["localhost:%s" % port for port in worker_ports],
- "ps": ["localhost:%s" % port for port in ps_ports]
- }
- cs = server_lib.ClusterSpec(cluster_dict)
-
- workers = [
- server_lib.Server(
- cs, job_name="worker", protocol=protocol, task_index=ix, start=True)
- for ix in range(num_workers)
- ]
- ps_servers = [
- server_lib.Server(
- cs, job_name="ps", protocol=protocol, task_index=ix, start=True)
- for ix in range(num_ps)
- ]
-
- return workers, ps_servers
class CreateLocalClusterTest(test.TestCase):
def testCreateLocalCluster(self):
- workers, _ = create_local_cluster(num_workers=2, num_ps=2)
+ workers, _ = test.create_local_cluster(num_workers=2, num_ps=2)
worker_sessions = [session_lib.Session(w.target) for w in workers]
with ops.device("/job:ps/task:0"):
var0 = variables.Variable(0.0)
@@ -88,7 +62,7 @@ class CreateLocalClusterBenchmark(test.Benchmark):
iters = 5
for _ in range(iters):
start_time = time.time()
- create_local_cluster(num_workers=1, num_ps=10)
+ test.create_local_cluster(num_workers=1, num_ps=10)
end_time = time.time()
deltas.append(end_time - start_time)
@@ -104,7 +78,7 @@ class CreateLocalClusterBenchmark(test.Benchmark):
class PartitionedVariablesBenchmark(test.Benchmark):
def benchmark_create_1000_partitions_with_100_parameter_servers(self):
- workers, _ = create_local_cluster(num_workers=1, num_ps=100)
+ workers, _ = test.create_local_cluster(num_workers=1, num_ps=100)
worker_sessions = [session_lib.Session(w.target) for w in workers]
worker = worker_sessions[0]
partition_sizes = (1, 512, 1024 * 32, 1024 * 128)
diff --git a/tensorflow/tools/api/golden/tensorflow.test.pbtxt b/tensorflow/tools/api/golden/tensorflow.test.pbtxt
index c4768a68bf..1e717ad237 100644
--- a/tensorflow/tools/api/golden/tensorflow.test.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.test.pbtxt
@@ -25,6 +25,10 @@ tf_module {
argspec: "args=[\'x\', \'x_shape\', \'y\', \'y_shape\', \'x_init_value\', \'delta\', \'init_targets\', \'extra_feed_dict\'], varargs=None, keywords=None, defaults=[\'None\', \'0.001\', \'None\', \'None\'], "
}
member_method {
+ name: "create_local_cluster"
+ argspec: "args=[\'num_workers\', \'num_ps\', \'protocol\'], varargs=None, keywords=None, defaults=[\'grpc\'], "
+ }
+ member_method {
name: "get_temp_dir"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD
index 83be430e7d..377a687c34 100644
--- a/tensorflow/tools/pip_package/BUILD
+++ b/tensorflow/tools/pip_package/BUILD
@@ -150,6 +150,7 @@ sh_binary(
"//tensorflow/contrib/tensor_forest:init_py",
"//tensorflow/contrib/tensor_forest/hybrid:hybrid_pip",
"//tensorflow/examples/tutorials/mnist:package",
+ "//tensorflow/python:distributed_framework_test_lib",
"//tensorflow/python:util_example_parser_configuration",
"//tensorflow/python/debug:debug_pip",
"//tensorflow/python/saved_model:saved_model",