Diffstat (limited to 'tensorflow/contrib/distribute/python')
-rw-r--r--  tensorflow/contrib/distribute/python/BUILD | 133
-rw-r--r--  tensorflow/contrib/distribute/python/checkpoint_utils_test.py | 4
-rw-r--r--  tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py | 123
-rw-r--r--  tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py | 120
-rw-r--r--  tensorflow/contrib/distribute/python/combinations.py | 28
-rw-r--r--  tensorflow/contrib/distribute/python/cross_tower_ops.py | 69
-rw-r--r--  tensorflow/contrib/distribute/python/cross_tower_ops_test.py | 58
-rw-r--r--  tensorflow/contrib/distribute/python/estimator_training_test.py | 659
-rw-r--r--  tensorflow/contrib/distribute/python/examples/BUILD | 15
-rw-r--r--  tensorflow/contrib/distribute/python/examples/keras_mnist.py | 126
-rw-r--r--  tensorflow/contrib/distribute/python/examples/keras_model_with_estimator.py (renamed from tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py) | 10
-rw-r--r--  tensorflow/contrib/distribute/python/input_ops_test.py | 8
-rw-r--r--  tensorflow/contrib/distribute/python/keras_test.py | 39
-rw-r--r--  tensorflow/contrib/distribute/python/minimize_loss_test.py | 12
-rw-r--r--  tensorflow/contrib/distribute/python/mirrored_strategy.py | 161
-rw-r--r--  tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py | 50
-rw-r--r--  tensorflow/contrib/distribute/python/mirrored_strategy_test.py | 19
-rw-r--r--  tensorflow/contrib/distribute/python/monitor_test.py | 2
-rw-r--r--  tensorflow/contrib/distribute/python/multi_worker_strategy.py | 141
-rw-r--r--  tensorflow/contrib/distribute/python/multi_worker_strategy_test.py | 62
-rw-r--r--  tensorflow/contrib/distribute/python/multi_worker_test_base.py | 107
-rw-r--r--  tensorflow/contrib/distribute/python/optimizer_v2_test.py | 2
-rw-r--r--  tensorflow/contrib/distribute/python/parameter_server_strategy.py | 76
-rw-r--r--  tensorflow/contrib/distribute/python/parameter_server_strategy_test.py | 190
-rw-r--r--  tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py | 6
-rw-r--r--  tensorflow/contrib/distribute/python/step_fn_test.py | 2
-rw-r--r--  tensorflow/contrib/distribute/python/strategy_test_lib.py | 5
-rw-r--r--  tensorflow/contrib/distribute/python/tpu_strategy.py | 21
-rw-r--r--  tensorflow/contrib/distribute/python/values.py | 211
-rw-r--r--  tensorflow/contrib/distribute/python/values_test.py | 20
-rw-r--r--  tensorflow/contrib/distribute/python/warm_starting_util_test.py | 4
31 files changed, 1889 insertions, 594 deletions
diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD
index ae50d4e3fc..94deb2a432 100644
--- a/tensorflow/contrib/distribute/python/BUILD
+++ b/tensorflow/contrib/distribute/python/BUILD
@@ -23,8 +23,6 @@ py_library(
deps = [
":input_ops",
":prefetching_ops_v2",
- "//tensorflow/contrib/data/python/ops:batching",
- "//tensorflow/contrib/eager/python:datasets",
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:device_util",
@@ -72,49 +70,72 @@ py_library(
":cross_tower_ops",
":shared_variable_creator",
":values",
+ "//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:control_flow_ops",
"//tensorflow/python:device",
"//tensorflow/python:device_util",
"//tensorflow/python:distribute",
"//tensorflow/python:framework_ops",
- "//tensorflow/python:math_ops",
"//tensorflow/python:pywrap_tensorflow",
"//tensorflow/python:training",
+ "//tensorflow/python:util",
"//tensorflow/python:variable_scope",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/distribute:multi_worker_util",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:tape",
- "@six_archive//:six",
],
)
py_library(
- name = "multi_worker_strategy",
- srcs = ["multi_worker_strategy.py"],
+ name = "parameter_server_strategy",
+ srcs = ["parameter_server_strategy.py"],
visibility = ["//tensorflow:internal"],
deps = [
+ ":cross_tower_ops",
":mirrored_strategy",
":values",
"//tensorflow/core:protos_all_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:resource_variable_ops",
"//tensorflow/python:training",
"//tensorflow/python:util",
+ "//tensorflow/python/distribute:multi_worker_util",
+ "//tensorflow/python/eager:context",
],
)
-py_library(
- name = "parameter_server_strategy",
- srcs = ["parameter_server_strategy.py"],
- visibility = ["//tensorflow:internal"],
- deps = [
- ":cross_tower_ops",
- ":mirrored_strategy",
+cuda_py_test(
+ name = "parameter_server_strategy_test",
+ srcs = ["parameter_server_strategy_test.py"],
+ additional_deps = [
+ ":combinations",
+ ":multi_worker_test_base",
+ ":parameter_server_strategy",
":values",
+ "@absl_py//absl/testing:parameterized",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_ops",
- "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python:gradients",
+ "//tensorflow/python:layers",
+ "//tensorflow/python:session",
"//tensorflow/python:training",
- "//tensorflow/python:util",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python:variables",
"//tensorflow/python/distribute:multi_worker_util",
+ "//tensorflow/python/eager:context",
+ "//tensorflow/python/estimator:estimator_py",
+ ],
+ tags = [
+ "multi_and_single_gpu",
+ "no_pip",
],
)
@@ -148,6 +169,7 @@ py_library(
"//tensorflow/python:collective_ops",
"//tensorflow/python:framework_ops",
"//tensorflow/python:training",
+ "//tensorflow/python/distribute:multi_worker_util",
"//tensorflow/python/eager:context",
],
)
@@ -185,7 +207,6 @@ py_library(
],
deps = [
":mirrored_strategy",
- ":multi_worker_strategy",
":one_device_strategy",
":tpu_strategy",
"//tensorflow/contrib/cluster_resolver:cluster_resolver_pip",
@@ -220,9 +241,13 @@ py_test(
],
deps = [
":mirrored_strategy",
+ ":multi_worker_test_base",
":strategy_test_lib",
+ "//tensorflow/python:constant_op",
"//tensorflow/python:distribute",
+ "//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:training",
"//tensorflow/python:variable_scope",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:test",
@@ -244,40 +269,12 @@ py_test(
],
)
-py_test(
- name = "parameter_server_strategy_test",
- srcs = ["parameter_server_strategy_test.py"],
- srcs_version = "PY2AND3",
- tags = [
- "no_pip",
- ],
- deps = [
- ":combinations",
- ":multi_worker_test_base",
- ":parameter_server_strategy",
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:gradients",
- "//tensorflow/python:layers",
- "//tensorflow/python:session",
- "//tensorflow/python:training",
- "//tensorflow/python:variable_scope",
- "//tensorflow/python:variables",
- "//tensorflow/python/eager:context",
- "//tensorflow/python/estimator:estimator_py",
- "@absl_py//absl/testing:parameterized",
- ],
-)
-
cuda_py_test(
name = "mirrored_strategy_multigpu_test",
srcs = ["mirrored_strategy_multigpu_test.py"],
additional_deps = [
":mirrored_strategy",
+ ":multi_worker_test_base",
":values",
":strategy_test_lib",
"//tensorflow/python:distribute",
@@ -346,19 +343,17 @@ py_library(
],
)
-py_test(
+cuda_py_test(
name = "collective_all_reduce_strategy_test",
srcs = ["collective_all_reduce_strategy_test.py"],
- srcs_version = "PY2AND3",
- tags = [
- "no_pip",
- ],
- deps = [
+ additional_deps = [
":collective_all_reduce_strategy",
":combinations",
":cross_tower_utils",
":multi_worker_test_base",
":strategy_test_lib",
+ "@absl_py//absl/testing:parameterized",
+ "//third_party/py/numpy",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
@@ -372,8 +367,10 @@ py_test(
"//tensorflow/python:variables",
"//tensorflow/python/eager:context",
"//tensorflow/python/estimator:estimator_py",
- "//third_party/py/numpy",
- "@absl_py//absl/testing:parameterized",
+ ],
+ tags = [
+ "multi_and_single_gpu",
+ "no_pip",
],
)
@@ -453,6 +450,35 @@ cuda_py_test(
],
)
+cuda_py_test(
+ name = "estimator_training_test",
+ size = "large",
+ srcs = ["estimator_training_test.py"],
+ additional_deps = [
+ ":combinations",
+ ":mirrored_strategy",
+ ":multi_worker_test_base",
+ ":parameter_server_strategy",
+ "//third_party/py/numpy",
+ "//tensorflow/contrib/optimizer_v2:training",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/distribute",
+ "//tensorflow/python/eager:test",
+ "//tensorflow/python/estimator:estimator_py",
+ "//tensorflow/python/feature_column",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:summary",
+ ],
+ tags = [
+ "manual",
+ "multi_and_single_gpu",
+ "no_pip",
+ "nogpu",
+ "notap",
+ ],
+)
+
py_library(
name = "single_loss_example",
srcs = ["single_loss_example.py"],
@@ -608,6 +634,7 @@ cuda_py_test(
":combinations",
":cross_tower_ops",
":multi_worker_test_base",
+ ":mirrored_strategy",
":values",
"@absl_py//absl/testing:parameterized",
"//tensorflow/python:array_ops",
diff --git a/tensorflow/contrib/distribute/python/checkpoint_utils_test.py b/tensorflow/contrib/distribute/python/checkpoint_utils_test.py
index bcb977f640..865dba803f 100644
--- a/tensorflow/contrib/distribute/python/checkpoint_utils_test.py
+++ b/tensorflow/contrib/distribute/python/checkpoint_utils_test.py
@@ -48,7 +48,7 @@ class CheckpointUtilsWithDistributionStrategyTest(
mode=["graph"]))
def testInitFromCheckpoint(self, distribution, in_tower_mode):
checkpoint_dir = self.get_temp_dir()
- with self.test_session() as session:
+ with self.cached_session() as session:
v1_value, v2_value, _, _ = checkpoint_utils_test._create_checkpoints(
session, checkpoint_dir)
@@ -62,7 +62,7 @@ class CheckpointUtilsWithDistributionStrategyTest(
"var1": "new_var1",
"var2": "new_var2"
})
- with self.test_session(graph=g) as session:
+ with self.session(graph=g) as session:
session.run(variables.global_variables_initializer())
self.assertAllEqual(v1_value, self.evaluate(v1))
self.assertAllEqual(v2_value, self.evaluate(v2))
diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py
index 9afcaecf78..2331444261 100644
--- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py
+++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py
@@ -18,30 +18,15 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import json
-import os
-
from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib
from tensorflow.contrib.distribute.python import cross_tower_utils
from tensorflow.contrib.distribute.python import mirrored_strategy
from tensorflow.contrib.distribute.python import values
-from tensorflow.core.protobuf import cluster_pb2
+from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import collective_ops
-from tensorflow.python.training import server_lib
-
-
-# TODO(yuefengz): move this function to a common util file.
-def _normalize_cluster_spec(cluster_spec):
- if isinstance(cluster_spec, (dict, cluster_pb2.ClusterDef)):
- return server_lib.ClusterSpec(cluster_spec)
- elif not isinstance(cluster_spec, server_lib.ClusterSpec):
- raise ValueError(
- "`cluster_spec' should be dict or a `tf.train.ClusterSpec` or a "
- "`tf.train.ClusterDef` object")
- return cluster_spec
# TODO(yuefengz): shard the dataset.
@@ -52,51 +37,45 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy):
"""Distribution strategy that uses collective ops for all-reduce.
It is similar to the MirroredStrategy but it uses collective ops for
- reduction. It currently only works for between-graph replication and its
- reduction will reduce across all workers.
+ reduction.
+
+ When `cluster_spec` is given by the `configure` method, it turns into the
+ multi-worker version that works on multiple workers with between-graph
+ replication.
+
+ Note: `configure` will be called by higher-level APIs if running in
+ distributed environment.
"""
- def __init__(self,
- num_gpus_per_worker=0,
- cluster_spec=None,
- task_type="worker",
- task_id=0):
+ def __init__(self, num_gpus_per_worker=0):
"""Initializes the object.
Args:
num_gpus_per_worker: number of local GPUs or GPUs per worker.
- cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the
- cluster configurations.
- task_type: the current task type, such as "worker".
- task_id: the current task id.
-
- Raises:
- ValueError: if `task_type` is not in the `cluster_spec`.
"""
self._num_gpus_per_worker = num_gpus_per_worker
- self._initialize(cluster_spec, task_type, task_id)
+ self._initialize(None, None, None)
def _initialize(self, cluster_spec, task_type, task_id):
- if task_type not in ["chief", "worker"]:
- raise ValueError(
- "Unrecognized task_type: %r, valid task types are: \"chief\", "
- "\"worker\"." % task_type)
if cluster_spec:
- self._cluster_spec = _normalize_cluster_spec(cluster_spec)
+ if task_type is None or task_id is None:
+ raise ValueError("When `cluster_spec` is given, you must also specify "
+ "`task_type` and `task_id`")
+ if task_type not in ["chief", "worker"]:
+ raise ValueError(
+ "Unrecognized task_type: %r, valid task types are: \"chief\", "
+ "\"worker\"." % task_type)
+ self._cluster_spec = multi_worker_util.normalize_cluster_spec(
+ cluster_spec)
worker_device = "/job:%s/task:%d" % (task_type, task_id)
- num_workers = len(self._cluster_spec.as_dict().get(task_type, []))
- if "chief" in self._cluster_spec.as_dict():
- num_workers += 1
+ num_workers = len(self._cluster_spec.as_dict().get("worker", [])) + len(
+ self._cluster_spec.as_dict().get("chief", []))
if not num_workers:
- raise ValueError("`task_type` shoud be in `cluster_spec`.")
+ raise ValueError("No `worker` or `chief` tasks can be found in "
+ "`cluster_spec`.")
- # TODO(yuefengz): create a utility to infer chief.
- if "chief" in self._cluster_spec.as_dict() and task_type == "chief":
- assert task_id == 0
- self._is_chief = True
- else:
- assert task_type == "worker"
- self._is_chief = task_id == 0
+ self._is_chief = multi_worker_util.is_chief(cluster_spec, task_type,
+ task_id)
else:
self._cluster_spec = None
self._is_chief = True
@@ -187,19 +166,41 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy):
return mirrored_strategy._create_mirrored_variable(
devices, _real_mirrored_creator, *args, **kwargs)
- def configure(self, session_config=None):
- # Use TF_CONFIG to get the cluster spec and the current job.
- if not self._cluster_spec:
- tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
- cluster_spec = _normalize_cluster_spec(tf_config.get("cluster", {}))
+ def configure(self,
+ session_config=None,
+ cluster_spec=None,
+ task_type=None,
+ task_id=None):
+ """Configures the object.
- task_env = tf_config.get("task", {})
- if task_env:
- task_type = task_env.get("type", "worker")
- task_id = int(task_env.get("index", "0"))
- else:
- task_type = "worker"
- task_id = 0
+ Args:
+ session_config: a @{tf.ConfigProto}
+ cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the
+ cluster configurations.
+ task_type: the current task type, such as "worker".
+ task_id: the current task id.
- if cluster_spec:
- self._initialize(cluster_spec, task_type, task_id)
+ Raises:
+ ValueError: if `task_type` is not in the `cluster_spec`.
+ """
+ # TODO(yuefengz): we'll need to mutate the session_config to add
+ # configurations for collective ops.
+ del session_config
+ if not self._cluster_spec and cluster_spec:
+ self._initialize(cluster_spec, task_type, task_id)
+
+ @property
+ def between_graph(self):
+ return True
+
+ @property
+ def should_init(self):
+ return True
+
+ @property
+ def should_checkpoint(self):
+ return self._is_chief
+
+ @property
+ def should_save_summary(self):
+ return self._is_chief
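Note on the change above: the cluster configuration moves out of the CollectiveAllReduceStrategy constructor and into the new `configure` method. The snippet below is a minimal sketch of the resulting setup flow, assuming a hypothetical two-task cluster and that `configure` is called directly (higher-level APIs normally call it when running distributed); it is not code from this diff.

    # Minimal sketch of the post-change setup; the cluster addresses and task
    # assignment below are placeholders, not values from this diff.
    from tensorflow.contrib.distribute.python import collective_all_reduce_strategy

    strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy(
        num_gpus_per_worker=1)
    # Higher-level APIs invoke configure() in a distributed environment;
    # calling it directly here only illustrates the new signature.
    strategy.configure(
        cluster_spec={"chief": ["host0:2222"], "worker": ["host1:2222"]},
        task_type="worker",
        task_id=0)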
diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py
index b5e54e3b7d..e284969b1a 100644
--- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py
+++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py
@@ -25,10 +25,8 @@ from tensorflow.contrib.distribute.python import collective_all_reduce_strategy
from tensorflow.contrib.distribute.python import combinations
from tensorflow.contrib.distribute.python import cross_tower_utils
from tensorflow.contrib.distribute.python import multi_worker_test_base
-from tensorflow.contrib.distribute.python import strategy_test_lib
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.eager import context
-from tensorflow.python.estimator import run_config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -41,53 +39,43 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import test
-class DistributedCollectiveAllReduceStrategyTest(
- multi_worker_test_base.MultiWorkerTestBase, parameterized.TestCase):
+class CollectiveAllReduceStrategyTestBase(
+ multi_worker_test_base.MultiWorkerTestBase):
collective_key_base = 0
- @classmethod
- def setUpClass(cls):
- """Create a local cluster with 2 workers."""
- cls._workers, cls._ps = multi_worker_test_base.create_in_process_cluster(
- num_workers=3, num_ps=0)
- cls._cluster_spec = {
- run_config.TaskType.WORKER: [
- 'fake_worker_0', 'fake_worker_1', 'fake_worker_2'
- ]
- }
-
def setUp(self):
self._run_options = config_pb2.RunOptions()
self._run_options.experimental.collective_graph_key = 6
self._sess_config = config_pb2.ConfigProto()
- self._sess_config.experimental.collective_group_leader = (
- '/job:worker/replica:0/task:0')
# We use a different key_base for each test so that collective keys won't be
# reused.
# TODO(yuefengz, tucker): enable it to reuse collective keys in different
# tests.
- DistributedCollectiveAllReduceStrategyTest.collective_key_base += 100000
- super(DistributedCollectiveAllReduceStrategyTest, self).setUp()
+ CollectiveAllReduceStrategyTestBase.collective_key_base += 100000
+ super(CollectiveAllReduceStrategyTestBase, self).setUp()
def _get_test_object(self, task_type, task_id, num_gpus=0):
distribution = collective_all_reduce_strategy.CollectiveAllReduceStrategy(
- num_gpus_per_worker=num_gpus,
- cluster_spec=self._cluster_spec,
- task_type=task_type,
- task_id=task_id)
+ num_gpus_per_worker=num_gpus)
+ if task_type and task_id is not None:
+ distribution.configure(
+ cluster_spec=self._cluster_spec, task_type=task_type, task_id=task_id)
collective_keys = cross_tower_utils.CollectiveKeys(
group_key_start=10 * num_gpus +
- DistributedCollectiveAllReduceStrategyTest.collective_key_base,
+ CollectiveAllReduceStrategyTestBase.collective_key_base,
instance_key_start=num_gpus * 100 +
- DistributedCollectiveAllReduceStrategyTest.collective_key_base,
+ CollectiveAllReduceStrategyTestBase.collective_key_base,
instance_key_with_id_start=num_gpus * 10000 +
- DistributedCollectiveAllReduceStrategyTest.collective_key_base)
+ CollectiveAllReduceStrategyTestBase.collective_key_base)
distribution._collective_keys = collective_keys
distribution._cross_tower_ops._collective_keys = collective_keys
- return distribution, self._workers[task_id].target
+ if task_type and task_id is not None:
+ return distribution, 'grpc://' + self._cluster_spec[task_type][task_id]
+ else:
+ return distribution, ''
def _test_minimize_loss_graph(self, task_type, task_id, num_gpus):
d, master_target = self._get_test_object(task_type, task_id, num_gpus)
@@ -155,12 +143,6 @@ class DistributedCollectiveAllReduceStrategyTest(
self.assertLess(error_after, error_before)
return error_after < error_before
- @combinations.generate(
- combinations.combine(mode=['graph'], num_gpus=[0, 1, 2]))
- def testMinimizeLossGraph(self, num_gpus):
- self._run_between_graph_clients(self._test_minimize_loss_graph,
- self._cluster_spec, num_gpus)
-
def _test_variable_initialization(self, task_type, task_id, num_gpus):
distribution, master_target = self._get_test_object(task_type, task_id,
num_gpus)
@@ -182,16 +164,74 @@ class DistributedCollectiveAllReduceStrategyTest(
distribution.reduce(
variable_scope.VariableAggregation.MEAN, x,
destinations='/cpu:0'))[0]
+ x = distribution.unwrap(x)[0]
sess.run(
variables.global_variables_initializer(), options=self._run_options)
+
x_value, reduced_x_value = sess.run(
[x, reduced_x], options=self._run_options)
- self.assertTrue(np.array_equal(x_value, reduced_x_value))
- return np.array_equal(x_value, reduced_x_value)
+ self.assertTrue(
+ np.allclose(x_value, reduced_x_value, atol=1e-5),
+ msg=('x_value = %r, reduced_x_value = %r' % (x_value,
+ reduced_x_value)))
+ return np.allclose(x_value, reduced_x_value, atol=1e-5)
+
+
+class DistributedCollectiveAllReduceStrategyTest(
+ CollectiveAllReduceStrategyTestBase, parameterized.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ """Create a local cluster with 3 workers."""
+ cls._cluster_spec = multi_worker_test_base.create_in_process_cluster(
+ num_workers=3, num_ps=0)
+
+ def setUp(self):
+ super(DistributedCollectiveAllReduceStrategyTest, self).setUp()
+ self._sess_config.experimental.collective_group_leader = (
+ '/job:worker/replica:0/task:0')
+
+ @combinations.generate(
+ combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1))
+ def testMinimizeLossGraph(self, num_gpus):
+ self._run_between_graph_clients(self._test_minimize_loss_graph,
+ self._cluster_spec, num_gpus)
+
+ @combinations.generate(
+ combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1))
+ def testVariableInitialization(self, num_gpus):
+ if context.num_gpus() < num_gpus:
+ return
+ self._run_between_graph_clients(
+ self._test_variable_initialization,
+ self._cluster_spec,
+ num_gpus=num_gpus)
+
+
+class DistributedCollectiveAllReduceStrategyTestWithChief(
+ CollectiveAllReduceStrategyTestBase, parameterized.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ """Create a local cluster with 3 workers and 1 chief."""
+ cls._cluster_spec = multi_worker_test_base.create_in_process_cluster(
+ num_workers=3, num_ps=0, has_chief=True)
+
+ def setUp(self):
+ super(DistributedCollectiveAllReduceStrategyTestWithChief, self).setUp()
+ self._run_options.experimental.collective_graph_key = 7
+ self._sess_config.experimental.collective_group_leader = (
+ '/job:chief/replica:0/task:0')
@combinations.generate(
- combinations.combine(mode=['graph'], num_gpus=[0, 1, 2]))
+ combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1))
+ def testMinimizeLossGraph(self, num_gpus):
+ self._run_between_graph_clients(self._test_minimize_loss_graph,
+ self._cluster_spec, num_gpus)
+
+ @combinations.generate(
+ combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1))
def testVariableInitialization(self, num_gpus):
if context.num_gpus() < num_gpus:
return
@@ -201,16 +241,14 @@ class DistributedCollectiveAllReduceStrategyTest(
num_gpus=num_gpus)
-class LocalCollectiveAllReduceStrategy(strategy_test_lib.DistributionTestBase,
- parameterized.TestCase):
+class LocalCollectiveAllReduceStrategy(
+ CollectiveAllReduceStrategyTestBase, parameterized.TestCase):
def testMinimizeLossGraph(self, num_gpus=2):
# Collective ops doesn't support strategy with one device.
if context.num_gpus() < num_gpus:
return
- distribution = collective_all_reduce_strategy.CollectiveAllReduceStrategy(
- num_gpus_per_worker=num_gpus)
- self._test_minimize_loss_graph(distribution)
+ self._test_minimize_loss_graph(None, None, num_gpus)
if __name__ == '__main__':
diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py
index aeec9c44d7..2301ba9233 100644
--- a/tensorflow/contrib/distribute/python/combinations.py
+++ b/tensorflow/contrib/distribute/python/combinations.py
@@ -48,7 +48,6 @@ import six
from tensorflow.contrib.cluster_resolver import TPUClusterResolver
from tensorflow.contrib.distribute.python import mirrored_strategy as mirrored_lib
-from tensorflow.contrib.distribute.python import multi_worker_strategy
from tensorflow.contrib.distribute.python import one_device_strategy as one_device_lib
from tensorflow.contrib.distribute.python import tpu_strategy as tpu_lib
from tensorflow.contrib.optimizer_v2 import adam as adam_v2
@@ -342,33 +341,6 @@ mirrored_strategy_with_two_gpus = NamedDistribution(
["/gpu:0", "/gpu:1"], prefetch_on_device=False),
required_gpus=2)
-multi_worker_strategy_with_cpu = NamedDistribution(
- "MultiWorkerCPU",
- lambda: multi_worker_strategy.MultiWorkerMirroredStrategy(
- cluster={
- "worker": [
- "/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1"
- ]
- },
- num_gpus_per_worker=0), 0)
-multi_worker_strategy_with_one_gpu = NamedDistribution(
- "MultiWorker1GPU",
- lambda: multi_worker_strategy.MultiWorkerMirroredStrategy(
- cluster={
- "worker": [
- "/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1"
- ]
- },
- num_gpus_per_worker=1), 1)
-multi_worker_strategy_with_two_gpus = NamedDistribution(
- "MultiWorker2GPUs",
- lambda: multi_worker_strategy.MultiWorkerMirroredStrategy(
- cluster={
- "worker": [
- "/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1"
- ]
- },
- num_gpus_per_worker=2), 2)
adam_optimizer_v1_fn = NamedObject(
"AdamV1", lambda: adam.AdamOptimizer(0.2, epsilon=1))
diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops.py b/tensorflow/contrib/distribute/python/cross_tower_ops.py
index 3a7addf221..2a653b0f10 100644
--- a/tensorflow/contrib/distribute/python/cross_tower_ops.py
+++ b/tensorflow/contrib/distribute/python/cross_tower_ops.py
@@ -53,7 +53,7 @@ def validate_destinations(destinations):
if not isinstance(
destinations,
(value_lib.DistributedValues, resource_variable_ops.ResourceVariable,
- six.string_types, list)):
+ value_lib.AggregatingVariable, six.string_types, list)):
raise ValueError("destinations must be one of a `DistributedValues` object,"
" a tf.Variable object, a device string, a list of device "
"strings or None")
@@ -62,7 +62,44 @@ def validate_destinations(destinations):
raise ValueError("destinations can not be empty")
+def _make_tensor_into_per_device(input_tensor):
+ """Converts a single tensor into a PerDevice object."""
+ if isinstance(input_tensor, (tuple, list)):
+ raise ValueError("Cannot convert `input_tensor` to a `PerDevice` object, "
+ "got %r but expected a object that is not a tuple or list."
+ % (input_tensor,))
+ if isinstance(input_tensor, value_lib.PerDevice):
+ return input_tensor
+
+ try:
+ device = input_tensor.device
+ except AttributeError:
+ raise ValueError("Cannot convert `input_tensor` to a `PerDevice` object "
+ "because it doesn't have device set.")
+
+ return value_lib.PerDevice({device: input_tensor})
+
+
+def _normalize_value_destination_pairs(value_destination_pairs):
+ """Converts each tensor into a PerDevice object in the input list."""
+ result = []
+ if not isinstance(value_destination_pairs, (list, tuple)):
+ raise ValueError("`value_destination_pairs` should be a list or tuple")
+ for pair in value_destination_pairs:
+ if not isinstance(pair, tuple):
+ raise ValueError(
+ "Each element of `value_destination_pairs` should be a tuple.")
+ if len(pair) != 2:
+ raise ValueError("Each element of `value_destination_pairs` should be a "
+ "tuple of size 2.")
+
+ per_device = _make_tensor_into_per_device(pair[0])
+ result.append((per_device, pair[1]))
+ return result
+
+
def _validate_value_destination_pairs(value_destination_pairs):
+ # TODO(yuefengz): raise exceptions instead of returning False.
# pylint: disable=g-missing-docstring
if not value_destination_pairs: return False
if not isinstance(value_destination_pairs, (list, tuple)): return False
@@ -78,12 +115,15 @@ def _validate_value_destination_pairs(value_destination_pairs):
def get_devices_from(destinations):
if isinstance(destinations, value_lib.DistributedValues):
return list(destinations.devices)
- elif isinstance(destinations, resource_variable_ops.ResourceVariable):
+ elif isinstance(destinations, (resource_variable_ops.ResourceVariable,
+ value_lib.AggregatingVariable)):
return [destinations.device]
elif isinstance(destinations, six.string_types):
return [device_util.resolve(destinations)]
- else:
+ elif isinstance(destinations, (list, tuple)):
return [device_util.resolve(destination) for destination in destinations]
+ else:
+ return [destinations.device]
def _devices_match(left, right):
@@ -158,7 +198,7 @@ class CrossTowerOps(object):
Args:
aggregation: Indicates how a variable will be aggregated. Accepted values
are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`.
- per_device_value: a PerDevice object.
+ per_device_value: a PerDevice object or a tensor with device set.
destinations: the reduction destinations.
Returns:
@@ -168,7 +208,8 @@ class CrossTowerOps(object):
ValueError: if per_device_value is not a PerDevice object.
"""
if not isinstance(per_device_value, value_lib.PerDevice):
- raise ValueError("`per_device_value` must be a `PerDevice` object.")
+ per_device_value = _make_tensor_into_per_device(per_device_value)
+
if destinations is not None:
validate_destinations(destinations)
return self._reduce(aggregation, per_device_value, destinations)
@@ -183,8 +224,9 @@ class CrossTowerOps(object):
aggregation: Indicates how a variable will be aggregated. Accepted values
are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`.
value_destination_pairs: a list or a tuple of tuples of PerDevice objects
- and destinations. If a destination is None, then the destinations
- are set to match the devices of the input PerDevice object.
+ (or tensors with device set if there is one tower) and destinations. If
+ a destination is None, then the destinations are set to match the
+ devices of the input PerDevice object.
Returns:
a list of Mirrored objects.
@@ -194,8 +236,11 @@ class CrossTowerOps(object):
tuples of PerDevice objects and destinations
"""
if not _validate_value_destination_pairs(value_destination_pairs):
- raise ValueError("`value_destination_pairs` must be a list or a tuple of "
- "tuples of PerDevice objects and destinations")
+ # If the first element of each pair is a tensor, we try to turn it into a
+ # PerDevice object.
+ value_destination_pairs = _normalize_value_destination_pairs(
+ value_destination_pairs)
+
for _, d in value_destination_pairs:
if d is not None:
validate_destinations(d)
@@ -756,7 +801,7 @@ class CollectiveAllReduce(CrossTowerOps):
)
super(CollectiveAllReduce, self).__init__()
- # TODO(yuefengz, tucker): is index slices supported by collective ops?
+ # TODO(yuefengz, tucker): is indexed slices supported by collective ops?
def _reduce(self, aggregation, per_device_value, destinations):
all_reduced = self._batch_all_reduce(aggregation, [per_device_value])[0]
if destinations is None or _devices_match(per_device_value, destinations):
@@ -768,8 +813,10 @@ class CollectiveAllReduce(CrossTowerOps):
if d in all_reduced._index:
index[d] = all_reduced._index[d]
else:
- with ops.device(d):
+ with ops.control_dependencies(list(
+ all_reduced._index.values())), ops.device(d):
index[d] = array_ops.identity(list(all_reduced._index.values())[0])
+
return value_lib.Mirrored(index)
def _batch_reduce(self, aggregation, value_destination_pairs):
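The new `_make_tensor_into_per_device` helper above lets `reduce` and `batch_reduce` accept a bare tensor that carries a device, wrapping it into a single-entry PerDevice. The following is a simplified, standalone sketch of that normalization; `PerDevice` here is a stand-in holder, not the real class from `values.py`.

    # Stand-in for values.PerDevice: just a device -> value mapping.
    class PerDevice(object):
      def __init__(self, index):
        self.index = index  # maps a device string to the value placed on it

    def make_tensor_into_per_device(value):
      """Mirrors, in simplified form, the wrapping logic added in this diff."""
      if isinstance(value, (tuple, list)):
        raise ValueError("expected a single value, got %r" % (value,))
      if isinstance(value, PerDevice):
        return value
      try:
        device = value.device  # tensors and variables expose a .device attribute
      except AttributeError:
        raise ValueError("value has no device set")
      return PerDevice({device: value})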
diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py
index aec53b01d7..2ad91d56e9 100644
--- a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py
+++ b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py
@@ -26,12 +26,12 @@ import numpy as np
from tensorflow.contrib.distribute.python import combinations
from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib
from tensorflow.contrib.distribute.python import cross_tower_utils
+from tensorflow.contrib.distribute.python import mirrored_strategy
from tensorflow.contrib.distribute.python import multi_worker_test_base
from tensorflow.contrib.distribute.python import values as value_lib
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.eager import context
from tensorflow.python.eager import test
-from tensorflow.python.estimator import run_config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -40,9 +40,17 @@ from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.training import device_util
-def _make_per_device(values, devices):
+def _make_per_device(values, devices, regroup=False):
devices = cross_tower_ops_lib.get_devices_from(devices)
assert len(values) == len(devices)
+
+ # We simulate the result of regroup called on PerDevice which strips the
+ # PerDevice wrapper if it has only one value.
+ if len(values) == 1 and regroup:
+ with ops.device(devices[0]):
+ placed_v = array_ops.identity(values[0])
+ return placed_v
+
index = {}
for d, v in zip(devices, values):
with ops.device(d):
@@ -368,14 +376,27 @@ class MultiWorkerCrossTowerOpsTest(multi_worker_test_base.MultiWorkerTestBase,
("xring", 2, -1)], 0, 0, 0)),
],
distribution=[
- combinations.multi_worker_strategy_with_cpu,
- combinations.multi_worker_strategy_with_one_gpu,
- combinations.multi_worker_strategy_with_two_gpus
+ combinations.NamedDistribution(
+ "MirroredCPU",
+ lambda: mirrored_strategy.MirroredStrategy(num_gpus=0),
+ required_gpus=0),
+ combinations.NamedDistribution(
+ "Mirrored1GPU",
+ lambda: mirrored_strategy.MirroredStrategy(num_gpus=1),
+ required_gpus=1),
+ combinations.NamedDistribution(
+ "Mirrored2GPUs",
+ lambda: mirrored_strategy.MirroredStrategy(num_gpus=2),
+ required_gpus=2),
],
mode=["graph"])
@combinations.generate(multi_worker_allreduce_combinations)
def testReductionAndBroadcast(self, cross_tower_ops, distribution):
+ distribution.configure(cluster_spec={
+ "worker":
+ ["/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1"]
+ })
with distribution.scope():
self._testReductionAndBroadcast(cross_tower_ops, distribution)
@@ -388,13 +409,8 @@ class MultiWorkerCollectiveAllReduceTest(
@classmethod
def setUpClass(cls):
"""Create a local cluster with 2 workers."""
- cls._workers, cls._ps = multi_worker_test_base.create_in_process_cluster(
+ cls._cluster_spec = multi_worker_test_base.create_in_process_cluster(
num_workers=3, num_ps=0)
- cls._cluster_spec = {
- run_config.TaskType.WORKER: [
- "fake_worker_0", "fake_worker_1", "fake_worker_2"
- ]
- }
def setUp(self):
super(MultiWorkerCollectiveAllReduceTest, self).setUp()
@@ -417,7 +433,7 @@ class MultiWorkerCollectiveAllReduceTest(
devices = ["/device:GPU:%d" % i for i in range(num_gpus)]
else:
devices = ["/device:CPU:0"]
- return collective_all_reduce_ops, devices, "local"
+ return collective_all_reduce_ops, devices, ""
else:
collective_all_reduce_ops = cross_tower_ops_lib.CollectiveAllReduce(
3, num_gpus, collective_keys=collective_keys)
@@ -428,7 +444,8 @@ class MultiWorkerCollectiveAllReduceTest(
]
else:
devices = ["/job:%s/task:%d" % (task_type, task_id)]
- return collective_all_reduce_ops, devices, self._workers[task_id].target
+ return (collective_all_reduce_ops, devices,
+ "grpc://" + self._cluster_spec[task_type][task_id])
def _assert_values_equal(self, left, right, sess):
if isinstance(left, list):
@@ -455,7 +472,8 @@ class MultiWorkerCollectiveAllReduceTest(
num_workers = 1
worker_device = None
else:
- num_workers = len(self._workers)
+ num_workers = len(self._cluster_spec.get("chief", [])) + len(
+ self._cluster_spec.get("worker", []))
worker_device = "/job:%s/task:%d" % (task_type, task_id)
with ops.Graph().as_default(), \
ops.device(worker_device), \
@@ -463,7 +481,7 @@ class MultiWorkerCollectiveAllReduceTest(
# Collective ops doesn't support scalar tensors, so we have to construct
# 1-d tensors.
values = [constant_op.constant([float(d)]) for d in range(len(devices))]
- per_device = _make_per_device(values, devices)
+ per_device = _make_per_device(values, devices, regroup=True)
mean = np.array([(len(devices) - 1.) / 2.])
values_2 = [constant_op.constant([d + 1.0]) for d in range(len(devices))]
@@ -476,7 +494,7 @@ class MultiWorkerCollectiveAllReduceTest(
destination_list = devices
all_destinations = [
- None, destination_mirrored, destination_different, destination_str,
+ destination_different, None, destination_mirrored, destination_str,
destination_list
]
@@ -533,13 +551,19 @@ class MultiWorkerCollectiveAllReduceTest(
return True
@combinations.generate(
- combinations.combine(mode=["graph"], num_gpus=[0, 1, 2]))
+ combinations.combine(mode=["graph"], num_gpus=[0, 1, 2], required_gpus=1))
def testReductionDistributed(self, num_gpus):
if context.num_gpus() < num_gpus:
return
self._run_between_graph_clients(self._test_reduction, self._cluster_spec,
num_gpus)
+ # Collective ops doesn't support strategy with one device.
+ def testReductionLocal(self, num_gpus=2):
+ if context.num_gpus() < num_gpus:
+ return
+ self._test_reduction(None, None, num_gpus, local_mode=True)
+
if __name__ == "__main__":
test.main()
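With the multi_worker_strategy combinations removed from combinations.py, the test above obtains multi-worker behavior by configuring a plain MirroredStrategy with a cluster_spec. A minimal sketch of that pattern, using the same hypothetical two-task worker job as the test:

    # Sketch only: mirrors the pattern used in testReductionAndBroadcast above.
    from tensorflow.contrib.distribute.python import mirrored_strategy

    distribution = mirrored_strategy.MirroredStrategy(num_gpus=2)
    distribution.configure(cluster_spec={
        "worker": ["/job:worker/replica:0/task:0",
                   "/job:worker/replica:0/task:1"]
    })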
diff --git a/tensorflow/contrib/distribute/python/estimator_training_test.py b/tensorflow/contrib/distribute/python/estimator_training_test.py
new file mode 100644
index 0000000000..5348512016
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/estimator_training_test.py
@@ -0,0 +1,659 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests that show Distribute Coordinator works with Estimator."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import glob
+import json
+import os
+import sys
+import tempfile
+import threading
+from absl.testing import parameterized
+import numpy as np
+import six
+
+_portpicker_import_error = None
+try:
+ import portpicker # pylint: disable=g-import-not-at-top
+except ImportError as _error: # pylint: disable=invalid-name
+ _portpicker_import_error = _error
+ portpicker = None
+
+# pylint: disable=g-import-not-at-top
+from tensorflow.contrib.distribute.python import combinations
+from tensorflow.contrib.distribute.python import mirrored_strategy
+from tensorflow.contrib.distribute.python import parameter_server_strategy
+from tensorflow.contrib.optimizer_v2 import adagrad
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.distribute import distribute_coordinator as dc
+from tensorflow.python.distribute import estimator_training as dc_training
+from tensorflow.python.distribute.distribute_config import DistributeConfig
+from tensorflow.python.eager import context
+from tensorflow.python.estimator import exporter as exporter_lib
+from tensorflow.python.estimator import run_config as run_config_lib
+from tensorflow.python.estimator import training as estimator_training
+from tensorflow.python.estimator.canned import dnn_linear_combined
+from tensorflow.python.estimator.canned import prediction_keys
+from tensorflow.python.estimator.export import export as export_lib
+from tensorflow.python.feature_column import feature_column
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import test
+from tensorflow.python.summary import summary_iterator
+from tensorflow.python.summary.writer import writer_cache
+from tensorflow.python.training import server_lib
+
+BATCH_SIZE = 10
+LABEL_DIMENSION = 2
+DATA = np.linspace(
+ 0., 2., BATCH_SIZE * LABEL_DIMENSION, dtype=np.float32).reshape(
+ BATCH_SIZE, LABEL_DIMENSION)
+EVAL_NAME = "foo"
+EXPORTER_NAME = "saved_model_exporter"
+MAX_STEPS = 10
+
+CHIEF = dc._TaskType.CHIEF
+EVALUATOR = dc._TaskType.EVALUATOR
+WORKER = dc._TaskType.WORKER
+PS = dc._TaskType.PS
+
+original_run_distribute_coordinator = dc.run_distribute_coordinator
+
+
+# TODO(yuefengz): merge this method back to test_util.
+def _create_local_cluster(num_workers,
+ num_ps,
+ has_eval=False,
+ protocol="grpc",
+ worker_config=None,
+ ps_config=None):
+ if _portpicker_import_error:
+ raise _portpicker_import_error # pylint: disable=raising-bad-type
+ worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)]
+ ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)]
+
+ cluster_dict = {
+ "worker": ["localhost:%s" % port for port in worker_ports],
+ "ps": ["localhost:%s" % port for port in ps_ports]
+ }
+ if has_eval:
+ cluster_dict["evaluator"] = ["localhost:%s" % portpicker.pick_unused_port()]
+
+ cs = server_lib.ClusterSpec(cluster_dict)
+
+ workers = [
+ server_lib.Server(
+ cs,
+ job_name="worker",
+ protocol=protocol,
+ task_index=ix,
+ config=worker_config,
+ start=True) for ix in range(num_workers)
+ ]
+ ps_servers = [
+ server_lib.Server(
+ cs,
+ job_name="ps",
+ protocol=protocol,
+ task_index=ix,
+ config=ps_config,
+ start=True) for ix in range(num_ps)
+ ]
+ if has_eval:
+ evals = [
+ server_lib.Server(
+ cs,
+ job_name="evaluator",
+ protocol=protocol,
+ task_index=0,
+ config=worker_config,
+ start=True)
+ ]
+ else:
+ evals = []
+
+ return workers, ps_servers, evals
+
+
+def _create_in_process_cluster(num_workers, num_ps, has_eval=False):
+ """Create an in-process cluster that consists of only standard server."""
+ # Leave some memory for cuda runtime.
+ if has_eval:
+ gpu_mem_frac = 0.7 / (num_workers + 1)
+ else:
+ gpu_mem_frac = 0.7 / num_workers
+
+ worker_config = config_pb2.ConfigProto()
+ worker_config.gpu_options.per_process_gpu_memory_fraction = gpu_mem_frac
+
+ # Enable collective ops which has no impact on non-collective ops.
+ # TODO(yuefengz, tucker): removing this after we move the initialization of
+ # collective mgr to the session level.
+ worker_config.experimental.collective_group_leader = (
+ "/job:worker/replica:0/task:0")
+
+ ps_config = config_pb2.ConfigProto()
+ ps_config.device_count["GPU"] = 0
+
+ return _create_local_cluster(
+ num_workers,
+ num_ps=num_ps,
+ has_eval=has_eval,
+ worker_config=worker_config,
+ ps_config=ps_config,
+ protocol="grpc")
+
+
+def _create_cluster_spec(has_chief=False,
+ num_workers=1,
+ num_ps=0,
+ has_eval=False):
+ if _portpicker_import_error:
+ raise _portpicker_import_error # pylint: disable=raising-bad-type
+
+ cluster_spec = {}
+ if has_chief:
+ cluster_spec[CHIEF] = ["localhost:%s" % portpicker.pick_unused_port()]
+ if num_workers:
+ cluster_spec[WORKER] = [
+ "localhost:%s" % portpicker.pick_unused_port()
+ for _ in range(num_workers)
+ ]
+ if num_ps:
+ cluster_spec[PS] = [
+ "localhost:%s" % portpicker.pick_unused_port() for _ in range(num_ps)
+ ]
+ if has_eval:
+ cluster_spec[EVALUATOR] = ["localhost:%s" % portpicker.pick_unused_port()]
+ return cluster_spec
+
+
+def _bytes_to_str(maybe_bytes):
+ if isinstance(maybe_bytes, six.string_types):
+ return maybe_bytes
+ else:
+ return str(maybe_bytes, "utf-8")
+
+
+def _strip_protocol(target):
+ # cluster_spec expects "host:port" strings.
+ if "//" in target:
+ return target.split("//")[1]
+ else:
+ return target
+
+
+class DistributeCoordinatorIntegrationTest(test.TestCase,
+ parameterized.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ """Create a local cluster with 2 workers."""
+ cls._workers, cls._ps, cls._evals = _create_in_process_cluster(
+ num_workers=3, num_ps=2, has_eval=True)
+ cls._cluster_spec = {
+ "worker": [
+ _strip_protocol(_bytes_to_str(w.target)) for w in cls._workers
+ ],
+ "ps": [_strip_protocol(_bytes_to_str(ps.target)) for ps in cls._ps],
+ "evaluator": [
+ _strip_protocol(_bytes_to_str(e.target)) for e in cls._evals
+ ]
+ }
+
+ def setUp(self):
+ self._model_dir = tempfile.mkdtemp()
+ self._event = threading.Event()
+ super(DistributeCoordinatorIntegrationTest, self).setUp()
+
+ def dataset_input_fn(self, x, y, batch_size, shuffle):
+
+ def input_fn():
+ dataset = dataset_ops.Dataset.from_tensor_slices((x, y))
+ if shuffle:
+ dataset = dataset.shuffle(batch_size)
+ dataset = dataset.repeat(100).batch(batch_size)
+ return dataset
+
+ return input_fn
+
+ def _get_exporter(self, name, fc):
+ feature_spec = feature_column.make_parse_example_spec(fc)
+ serving_input_receiver_fn = (
+ export_lib.build_parsing_serving_input_receiver_fn(feature_spec))
+ return exporter_lib.LatestExporter(
+ name, serving_input_receiver_fn=serving_input_receiver_fn)
+
+ def _extract_loss_and_global_step(self, event_folder):
+ """Returns the loss and global step in last event."""
+ event_paths = glob.glob(os.path.join(event_folder, "events*"))
+
+ loss = None
+ global_step_count = None
+
+ for e in summary_iterator.summary_iterator(event_paths[-1]):
+ current_loss = None
+ for v in e.summary.value:
+ if v.tag == "loss":
+ current_loss = v.simple_value
+
+ # If loss is not found, global step is meaningless.
+ if current_loss is None:
+ continue
+
+ current_global_step = e.step
+ if global_step_count is None or current_global_step > global_step_count:
+ global_step_count = current_global_step
+ loss = current_loss
+
+ return (loss, global_step_count)
+
+ def _get_estimator(self,
+ train_distribute,
+ eval_distribute,
+ remote_cluster=None):
+ input_dimension = LABEL_DIMENSION
+ linear_feature_columns = [
+ feature_column.numeric_column("x", shape=(input_dimension,))
+ ]
+ dnn_feature_columns = [
+ feature_column.numeric_column("x", shape=(input_dimension,))
+ ]
+
+ return dnn_linear_combined.DNNLinearCombinedRegressor(
+ linear_feature_columns=linear_feature_columns,
+ dnn_hidden_units=(2, 2),
+ dnn_feature_columns=dnn_feature_columns,
+ label_dimension=LABEL_DIMENSION,
+ model_dir=self._model_dir,
+ dnn_optimizer=adagrad.AdagradOptimizer(0.001),
+ linear_optimizer=adagrad.AdagradOptimizer(0.001),
+ config=run_config_lib.RunConfig(
+ experimental_distribute=DistributeConfig(
+ train_distribute=train_distribute,
+ eval_distribute=eval_distribute,
+ remote_cluster=remote_cluster)))
+
+ def _complete_flow(self,
+ train_distribute,
+ eval_distribute,
+ remote_cluster=None):
+ estimator = self._get_estimator(train_distribute, eval_distribute,
+ remote_cluster)
+
+ input_dimension = LABEL_DIMENSION
+ train_input_fn = self.dataset_input_fn(
+ x={"x": DATA},
+ y=DATA,
+ batch_size=BATCH_SIZE // len(train_distribute.worker_devices),
+ shuffle=True)
+ if eval_distribute:
+ eval_batch_size = BATCH_SIZE // len(eval_distribute.worker_devices)
+ else:
+ eval_batch_size = BATCH_SIZE
+ eval_input_fn = self.dataset_input_fn(
+ x={"x": DATA}, y=DATA, batch_size=eval_batch_size, shuffle=False)
+
+ linear_feature_columns = [
+ feature_column.numeric_column("x", shape=(input_dimension,))
+ ]
+ dnn_feature_columns = [
+ feature_column.numeric_column("x", shape=(input_dimension,))
+ ]
+ feature_columns = linear_feature_columns + dnn_feature_columns
+
+ estimator_training.train_and_evaluate(
+ estimator,
+ estimator_training.TrainSpec(train_input_fn, max_steps=MAX_STEPS),
+ estimator_training.EvalSpec(
+ name=EVAL_NAME,
+ input_fn=eval_input_fn,
+ steps=None,
+ exporters=self._get_exporter(EXPORTER_NAME, feature_columns),
+ start_delay_secs=0,
+ throttle_secs=1))
+ return estimator
+
+ def _inspect_train_and_eval_events(self, estimator):
+ # Make sure nothing is stuck in limbo.
+ writer_cache.FileWriterCache.clear()
+
+ # Examine the training events. Use a range to check global step to avoid
+ # flakiness due to a global step race condition.
+ training_loss, _ = self._extract_loss_and_global_step(self._model_dir)
+ self.assertIsNotNone(training_loss)
+
+ # Examine the eval events. The global step should be accurate.
+ eval_dir = os.path.join(self._model_dir, "eval_" + EVAL_NAME)
+ eval_loss, eval_global_step = self._extract_loss_and_global_step(
+ event_folder=eval_dir)
+ self.assertIsNotNone(eval_loss)
+ self.assertGreaterEqual(eval_global_step, MAX_STEPS)
+
+ # Examine the export folder.
+ export_dir = os.path.join(
+ os.path.join(self._model_dir, "export"), EXPORTER_NAME)
+ self.assertTrue(gfile.Exists(export_dir))
+
+ # Examine the ckpt for predict.
+ def predict_input_fn():
+ return dataset_ops.Dataset.from_tensor_slices({
+ "x": DATA
+ }).batch(BATCH_SIZE)
+
+ predicted_proba = np.array([
+ x[prediction_keys.PredictionKeys.PREDICTIONS]
+ for x in estimator.predict(predict_input_fn)
+ ])
+ self.assertAllEqual((BATCH_SIZE, LABEL_DIMENSION), predicted_proba.shape)
+
+ @combinations.generate(
+ combinations.combine(
+ mode=["graph"],
+ train_distribute_cls=[
+ mirrored_strategy.MirroredStrategy,
+ parameter_server_strategy.ParameterServerStrategy
+ ],
+ eval_distribute_cls=[
+ None, mirrored_strategy.MirroredStrategy,
+ parameter_server_strategy.ParameterServerStrategy
+ ],
+ required_gpus=1))
+ def test_complete_flow_standalone_client(self, train_distribute_cls,
+ eval_distribute_cls):
+ try:
+ train_distribute = train_distribute_cls(num_gpus=context.num_gpus())
+ except TypeError:
+ train_distribute = train_distribute_cls(num_gpus_per_worker=2)
+
+ if eval_distribute_cls:
+ eval_distribute = eval_distribute_cls()
+ else:
+ eval_distribute = None
+
+ estimator = self._complete_flow(
+ train_distribute, eval_distribute, remote_cluster=self._cluster_spec)
+ self._inspect_train_and_eval_events(estimator)
+
+ def _mock_run_distribute_coordinator(
+ self,
+ worker_fn,
+ strategy,
+ eval_fn,
+ eval_strategy,
+ mode=dc.CoordinatorMode.STANDALONE_CLIENT,
+ cluster_spec=None,
+ session_config=None):
+ # Calls the original `run_distribute_coordinator` method but gets task config
+ # from environment variables and then signals the caller.
+ task_type = None
+ task_id = None
+ if not cluster_spec:
+ cluster_spec = None
+ tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
+ if not cluster_spec:
+ cluster_spec = tf_config.get("cluster", {})
+ task_env = tf_config.get("task", {})
+ if task_env:
+ task_type = task_env.get("type", task_type)
+ task_id = int(task_env.get("index", task_id))
+ self._event.set()
+ original_run_distribute_coordinator(
+ worker_fn,
+ strategy,
+ eval_fn,
+ eval_strategy,
+ mode=mode,
+ cluster_spec=cluster_spec,
+ task_type=task_type,
+ task_id=task_id,
+ session_config=session_config)
+
+ def _task_thread(self, train_distribute, eval_distribute):
+ with test.mock.patch.object(dc, "run_distribute_coordinator",
+ self._mock_run_distribute_coordinator):
+ self._complete_flow(train_distribute, eval_distribute)
+
+ def _run_task_in_thread(self, cluster_spec, task_type, task_id,
+ train_distribute, eval_distribute):
+ if task_type:
+ tf_config = {
+ "cluster": cluster_spec,
+ "task": {
+ "type": task_type,
+ "index": task_id
+ }
+ }
+ else:
+ tf_config = {
+ "cluster": cluster_spec,
+ "task": {
+ "type": task_type,
+ "index": task_id
+ }
+ }
+ self._event.clear()
+ t = threading.Thread(
+ target=self._task_thread, args=(train_distribute, eval_distribute))
+ with test.mock.patch.dict("os.environ",
+ {"TF_CONFIG": json.dumps(tf_config)}):
+ t.start()
+ self._event.wait()
+ return t
+
+ def _run_multiple_tasks_in_threads(self, cluster_spec, train_distribute,
+ eval_distribute):
+ threads = {}
+ for task_type in cluster_spec.keys():
+ threads[task_type] = []
+ for task_id in range(len(cluster_spec[task_type])):
+ t = self._run_task_in_thread(cluster_spec, task_type, task_id,
+ train_distribute, eval_distribute)
+ threads[task_type].append(t)
+ return threads
+
+ @combinations.generate(
+ combinations.combine(
+ mode=["graph"],
+ train_distribute_cls=[
+ parameter_server_strategy.ParameterServerStrategy,
+ ],
+ eval_distribute_cls=[
+ None, mirrored_strategy.MirroredStrategy,
+ parameter_server_strategy.ParameterServerStrategy
+ ],
+ required_gpus=1))
+ def test_complete_flow_indepedent_worker_between_graph(
+ self, train_distribute_cls, eval_distribute_cls):
+ train_distribute = train_distribute_cls(
+ num_gpus_per_worker=context.num_gpus())
+
+ if eval_distribute_cls:
+ eval_distribute = eval_distribute_cls()
+ else:
+ eval_distribute = None
+
+ cluster_spec = _create_cluster_spec(num_workers=3, num_ps=2, has_eval=True)
+ threads = self._run_multiple_tasks_in_threads(
+ cluster_spec, train_distribute, eval_distribute)
+ for task_type, ts in threads.items():
+ if task_type == PS:
+ continue
+ for t in ts:
+ t.join()
+
+ estimator = self._get_estimator(train_distribute, eval_distribute)
+ self._inspect_train_and_eval_events(estimator)
+
+ @combinations.generate(
+ combinations.combine(
+ mode=["graph"],
+ train_distribute_cls=[mirrored_strategy.MirroredStrategy],
+ eval_distribute_cls=[None, mirrored_strategy.MirroredStrategy],
+ required_gpus=1))
+ def test_complete_flow_indepedent_worker_in_graph(self, train_distribute_cls,
+ eval_distribute_cls):
+ train_distribute = train_distribute_cls(num_gpus=context.num_gpus())
+
+ if eval_distribute_cls:
+ eval_distribute = eval_distribute_cls()
+ else:
+ eval_distribute = None
+
+ cluster_spec = _create_cluster_spec(num_workers=3, num_ps=2, has_eval=True)
+ threads = self._run_multiple_tasks_in_threads(
+ cluster_spec, train_distribute, eval_distribute)
+ threads[WORKER][0].join()
+ threads[EVALUATOR][0].join()
+
+ estimator = self._get_estimator(train_distribute, eval_distribute)
+ self._inspect_train_and_eval_events(estimator)
+
+
+TF_CONFIG_WITH_CHIEF = {
+ "cluster": {
+ "chief": ["fake_chief"],
+ },
+ "task": {
+ "type": "chief",
+ "index": 0
+ }
+}
+
+TF_CONFIG_WITH_MASTER = {
+ "cluster": {
+ "master": ["fake_master"],
+ },
+ "task": {
+ "type": "master",
+ "index": 0
+ }
+}
+
+TF_CONFIG_WITHOUT_TASK = {"cluster": {"chief": ["fake_worker"]}}
+
+
+class RunConfigTest(test.TestCase):
+
+ def test_previously_unexpected_cluster_spec(self):
+ with test.mock.patch.dict(
+ "os.environ", {"TF_CONFIG": json.dumps(TF_CONFIG_WITHOUT_TASK)}):
+ run_config_lib.RunConfig(
+ experimental_distribute=DistributeConfig(
+ train_distribute=mirrored_strategy.MirroredStrategy(num_gpus=2)))
+
+ def test_should_run_distribute_coordinator(self):
+ """Tests that should_run_distribute_coordinator return a correct value."""
+ # We don't use distribute coordinator for local training.
+ self.assertFalse(
+ dc_training.should_run_distribute_coordinator(
+ run_config_lib.RunConfig()))
+
+ # When `train_distribute` is not specified, don't use distribute
+ # coordinator.
+ with test.mock.patch.dict("os.environ",
+ {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_CHIEF)}):
+ self.assertFalse(
+ dc_training.should_run_distribute_coordinator(
+ run_config_lib.RunConfig()))
+
+ # When `train_distribute` is specified and TF_CONFIG is detected, use
+ # distribute coordinator.
+ with test.mock.patch.dict("os.environ",
+ {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_CHIEF)}):
+ config_with_train_distribute = run_config_lib.RunConfig(
+ experimental_distribute=DistributeConfig(
+ train_distribute=mirrored_strategy.MirroredStrategy(num_gpus=2)))
+ config_with_eval_distribute = run_config_lib.RunConfig(
+ experimental_distribute=DistributeConfig(
+ eval_distribute=mirrored_strategy.MirroredStrategy(num_gpus=2)))
+ self.assertTrue(
+ dc_training.should_run_distribute_coordinator(
+ config_with_train_distribute))
+ self.assertFalse(
+ dc_training.should_run_distribute_coordinator(
+ config_with_eval_distribute))
+
+ # With a master in the cluster, don't run distribute coordinator.
+ with test.mock.patch.dict("os.environ",
+ {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_MASTER)}):
+ config = run_config_lib.RunConfig(
+ experimental_distribute=DistributeConfig(
+ train_distribute=mirrored_strategy.MirroredStrategy(num_gpus=2)))
+ self.assertFalse(dc_training.should_run_distribute_coordinator(config))
+
+ def test_init_run_config_duplicate_distribute(self):
+ with self.assertRaises(ValueError):
+ run_config_lib.RunConfig(
+ train_distribute=mirrored_strategy.MirroredStrategy(),
+ experimental_distribute=DistributeConfig(
+ train_distribute=mirrored_strategy.MirroredStrategy()))
+
+ with self.assertRaises(ValueError):
+ run_config_lib.RunConfig(
+ eval_distribute=mirrored_strategy.MirroredStrategy(),
+ experimental_distribute=DistributeConfig(
+ eval_distribute=mirrored_strategy.MirroredStrategy()))
+
+ def test_init_run_config_none_distribute_coordinator_mode(self):
+ # We don't use distribute coordinator for local training.
+ config = run_config_lib.RunConfig(
+ train_distribute=mirrored_strategy.MirroredStrategy())
+ dc_training.init_run_config(config, {})
+ self.assertIsNone(config._distribute_coordinator_mode)
+
+ # With a master in the cluster, don't run distribute coordinator.
+ with test.mock.patch.dict("os.environ",
+ {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_MASTER)}):
+ config = run_config_lib.RunConfig(
+ train_distribute=mirrored_strategy.MirroredStrategy())
+ self.assertIsNone(config._distribute_coordinator_mode)
+
+ # When `train_distribute` is not specified, don't use distribute
+ # coordinator.
+ with test.mock.patch.dict("os.environ",
+ {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_CHIEF)}):
+ config = run_config_lib.RunConfig()
+ self.assertFalse(hasattr(config, "_distribute_coordinator_mode"))
+
+ def test_init_run_config_independent_worker(self):
+ # When `train_distribute` is specified and TF_CONFIG is detected, use
+ # distribute coordinator with INDEPENDENT_WORKER mode.
+ with test.mock.patch.dict("os.environ",
+ {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_CHIEF)}):
+ config = run_config_lib.RunConfig(
+ train_distribute=mirrored_strategy.MirroredStrategy())
+ self.assertEqual(config._distribute_coordinator_mode,
+ dc.CoordinatorMode.INDEPENDENT_WORKER)
+
+ def test_init_run_config_standalone_client(self):
+ # When `train_distribute` is specified, TF_CONFIG is detected and
+ # `experimental.remote_cluster` is set, use distribute coordinator with
+ # STANDALONE_CLIENT mode.
+ config = run_config_lib.RunConfig(
+ train_distribute=mirrored_strategy.MirroredStrategy(),
+ experimental_distribute=DistributeConfig(
+ remote_cluster={"chief": ["fake_worker"]}))
+ self.assertEqual(config._distribute_coordinator_mode,
+ dc.CoordinatorMode.STANDALONE_CLIENT)
+
+
+if __name__ == "__main__":
+ with test.mock.patch.object(sys, "exit", os._exit):
+ test.main()
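For reference, a minimal sketch of the RunConfig wiring these tests exercise, using the same module aliases as the test file above (run_config_lib, mirrored_strategy, dc_training, dc, DistributeConfig); the cluster addresses and TF_CONFIG contents are placeholders:

import json
import os

# Standalone-client mode: `train_distribute` plus an explicit remote cluster.
config = run_config_lib.RunConfig(
    train_distribute=mirrored_strategy.MirroredStrategy(num_gpus=2),
    experimental_distribute=DistributeConfig(
        remote_cluster={"chief": ["localhost:2222"]}))
assert (config._distribute_coordinator_mode ==
        dc.CoordinatorMode.STANDALONE_CLIENT)

# Independent-worker mode: `train_distribute` plus TF_CONFIG in the environment.
os.environ["TF_CONFIG"] = json.dumps({
    "cluster": {"chief": ["localhost:2222"]},
    "task": {"type": "chief", "index": 0},
})
config = run_config_lib.RunConfig(
    train_distribute=mirrored_strategy.MirroredStrategy(num_gpus=2))
assert dc_training.should_run_distribute_coordinator(config)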
diff --git a/tensorflow/contrib/distribute/python/examples/BUILD b/tensorflow/contrib/distribute/python/examples/BUILD
index cbfd178502..84b106545e 100644
--- a/tensorflow/contrib/distribute/python/examples/BUILD
+++ b/tensorflow/contrib/distribute/python/examples/BUILD
@@ -19,9 +19,20 @@ py_binary(
)
py_binary(
- name = "simple_tfkeras_example",
+ name = "keras_model_with_estimator",
srcs = [
- "simple_tfkeras_example.py",
+ "keras_model_with_estimator.py",
+ ],
+ deps = [
+ "//tensorflow:tensorflow_py",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_binary(
+ name = "keras_mnist",
+ srcs = [
+ "keras_mnist.py",
],
deps = [
"//tensorflow:tensorflow_py",
diff --git a/tensorflow/contrib/distribute/python/examples/keras_mnist.py b/tensorflow/contrib/distribute/python/examples/keras_mnist.py
new file mode 100644
index 0000000000..a20069c4fe
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/examples/keras_mnist.py
@@ -0,0 +1,126 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""An example training a Keras Model using MirroredStrategy and native APIs."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+
+NUM_CLASSES = 10
+
+
+def get_input_datasets():
+ """Downloads the MNIST dataset and creates train and eval dataset objects.
+
+ Returns:
+ Train dataset, eval dataset and input shape.
+
+ """
+ # input image dimensions
+ img_rows, img_cols = 28, 28
+
+ # the data, split between train and test sets
+ (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
+
+ if tf.keras.backend.image_data_format() == 'channels_first':
+ x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
+ x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
+ input_shape = (1, img_rows, img_cols)
+ else:
+ x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
+ x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
+ input_shape = (img_rows, img_cols, 1)
+
+ x_train = x_train.astype('float32')
+ x_test = x_test.astype('float32')
+ x_train /= 255
+ x_test /= 255
+
+ # convert class vectors to binary class matrices
+ y_train = tf.keras.utils.to_categorical(y_train, NUM_CLASSES)
+ y_test = tf.keras.utils.to_categorical(y_test, NUM_CLASSES)
+
+ # train dataset
+ train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
+ train_ds = train_ds.repeat()
+ train_ds = train_ds.shuffle(100)
+ train_ds = train_ds.batch(64)
+
+ # eval dataset
+ eval_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))
+ eval_ds = eval_ds.repeat()
+ eval_ds = eval_ds.shuffle(100)
+ eval_ds = eval_ds.batch(64)
+
+ return train_ds, eval_ds, input_shape
+
+
+def get_model(input_shape):
+ """Builds a Sequential CNN model to recognize MNIST digits.
+
+ Args:
+ input_shape: Shape of the input depending on the `image_data_format`.
+
+ Returns:
+ A Keras model.
+
+ """
+ # Define a CNN model to recognize MNIST digits.
+ model = tf.keras.models.Sequential()
+ model.add(tf.keras.layers.Conv2D(32, kernel_size=(3, 3),
+ activation='relu',
+ input_shape=input_shape))
+ model.add(tf.keras.layers.Conv2D(64, (3, 3), activation='relu'))
+ model.add(tf.keras.layers.MaxPooling2D(pool_size=(2, 2)))
+ model.add(tf.keras.layers.Dropout(0.25))
+ model.add(tf.keras.layers.Flatten())
+ model.add(tf.keras.layers.Dense(128, activation='relu'))
+ model.add(tf.keras.layers.Dropout(0.5))
+ model.add(tf.keras.layers.Dense(NUM_CLASSES, activation='softmax'))
+ return model
+
+
+def main(_):
+ # Build the train and eval datasets from the MNIST data. Also return the
+ # input shape, which is constructed based on the `image_data_format`,
+ # i.e. channels_first or channels_last.
+ train_ds, eval_ds, input_shape = get_input_datasets()
+ model = get_model(input_shape)
+
+ # Instantiate the MirroredStrategy object. If we don't specify `num_gpus` or
+ # the `devices` argument then all the GPUs available on the machine are used.
+ strategy = tf.contrib.distribute.MirroredStrategy()
+
+ # Compile the model by passing the distribution strategy object to the
+ # `distribute` argument. `fit`, `evaluate` and `predict` will be distributed
+ # based on the strategy instantiated.
+ model.compile(loss=tf.keras.losses.categorical_crossentropy,
+ optimizer=tf.train.RMSPropOptimizer(learning_rate=0.001),
+ metrics=['accuracy'],
+ distribute=strategy)
+
+ # Train the model with the train dataset.
+ model.fit(x=train_ds, epochs=20, steps_per_epoch=310)
+
+ # Evaluate the model with the eval dataset.
+ score = model.evaluate(eval_ds, steps=10, verbose=0)
+ print('Test loss:', score[0])
+ print('Test accuracy:', score[1])
+
+
+if __name__ == '__main__':
+ tf.app.run()
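If the device placement needs to be pinned explicitly instead of using every local GPU, the strategy could also be built with an explicit device list, as the estimator example below does (a sketch; the device names are placeholders):

strategy = tf.contrib.distribute.MirroredStrategy(
    devices=['/device:GPU:0', '/device:GPU:1'])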
diff --git a/tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py b/tensorflow/contrib/distribute/python/examples/keras_model_with_estimator.py
index 518ec9c423..8d117eb7e8 100644
--- a/tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py
+++ b/tensorflow/contrib/distribute/python/examples/keras_model_with_estimator.py
@@ -42,19 +42,19 @@ def main(args):
model_dir = args[1]
print('Using %s to store checkpoints.' % model_dir)
- # Define tf.keras Model.
+ # Define a Keras Model.
model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(16, activation='relu', input_shape=(10,)))
model.add(tf.keras.layers.Dense(1, activation='sigmoid'))
- # Compile tf.keras Model.
+ # Compile the model.
optimizer = tf.train.GradientDescentOptimizer(0.2)
model.compile(loss='binary_crossentropy', optimizer=optimizer)
model.summary()
tf.keras.backend.set_learning_phase(True)
- # Define a DistributionStrategy and convert the tf.keras Model to a
- # tf.Estimator that utilizes the DistributionStrategy.
+ # Define a DistributionStrategy and convert the Keras Model to an
+ # Estimator that utilizes the DistributionStrategy.
strategy = tf.contrib.distribute.MirroredStrategy(
['/device:GPU:0', '/device:GPU:1'])
config = tf.estimator.RunConfig(
@@ -62,7 +62,7 @@ def main(args):
keras_estimator = tf.keras.estimator.model_to_estimator(
keras_model=model, config=config, model_dir=model_dir)
- # Train and evaluate the tf.Estimator.
+ # Train and evaluate the model.
keras_estimator.train(input_fn=input_fn, steps=10)
eval_result = keras_estimator.evaluate(input_fn=input_fn)
print('Eval result: {}'.format(eval_result))
diff --git a/tensorflow/contrib/distribute/python/input_ops_test.py b/tensorflow/contrib/distribute/python/input_ops_test.py
index 16179c3a49..c5acb7ced4 100644
--- a/tensorflow/contrib/distribute/python/input_ops_test.py
+++ b/tensorflow/contrib/distribute/python/input_ops_test.py
@@ -91,7 +91,7 @@ class AutoShardDatasetTest(test.TestCase):
def _verifySimpleShardingOutput(self, dataset, record_fn):
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for f in range(self._shard_index, self._num_files, self._num_shards):
for r in range(self._num_records):
self.assertAllEqual(record_fn(r, f), sess.run(next_element))
@@ -150,7 +150,7 @@ class AutoShardDatasetTest(test.TestCase):
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
actual, expected = [], []
for f in range(self._shard_index, self._num_files, self._num_shards):
for r in range(self._num_records):
@@ -182,7 +182,7 @@ class AutoShardDatasetTest(test.TestCase):
# Verify output.
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
actual = []
num_iterations = (self._num_files * self._num_records * num_epochs) // (
self._num_shards * batch_size)
@@ -218,7 +218,7 @@ class AutoShardDatasetTest(test.TestCase):
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for f in range(self._shard_index, self._num_files, self._num_shards):
for r in range(self._num_records):
self.assertAllEqual(self._record(r, f), sess.run(next_element))
diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py
index 4facd72d12..d39fd57294 100644
--- a/tensorflow/contrib/distribute/python/keras_test.py
+++ b/tensorflow/contrib/distribute/python/keras_test.py
@@ -116,7 +116,7 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase):
model_dir=self._base_dir,
train_distribute=dist,
eval_distribute=dist)
- with self.test_session():
+ with self.cached_session():
est_keras = keras_lib.model_to_estimator(
keras_model=keras_model, config=config)
before_eval_results = est_keras.evaluate(
@@ -139,7 +139,7 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase):
config = run_config_lib.RunConfig(tf_random_seed=_RANDOM_SEED,
model_dir=self._base_dir,
train_distribute=dist)
- with self.test_session():
+ with self.cached_session():
est_keras = keras_lib.model_to_estimator(
keras_model=keras_model, config=config)
before_eval_results = est_keras.evaluate(
@@ -163,7 +163,7 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase):
config = run_config_lib.RunConfig(tf_random_seed=_RANDOM_SEED,
model_dir=self._base_dir,
train_distribute=dist)
- with self.test_session():
+ with self.cached_session():
est_keras = keras_lib.model_to_estimator(keras_model=keras_model,
config=config)
with self.assertRaisesRegexp(ValueError,
@@ -178,7 +178,7 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase):
class TestWithDistributionStrategy(test.TestCase):
def test_validating_dataset_input_tensors_with_shape_mismatch(self):
- with self.test_session():
+ with self.cached_session():
strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0',
'/device:CPU:0'])
a = constant_op.constant([1, 2], shape=(1, 2))
@@ -197,7 +197,7 @@ class TestWithDistributionStrategy(test.TestCase):
strategy, x, y)
def test_validating_dataset_input_tensors_with_dtype_mismatch(self):
- with self.test_session():
+ with self.cached_session():
strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0',
'/device:CPU:0'])
a = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.int32)
@@ -216,7 +216,7 @@ class TestWithDistributionStrategy(test.TestCase):
strategy, x, y)
def test_calling_model_on_same_dataset(self):
- with self.test_session():
+ with self.cached_session():
x = keras.layers.Input(shape=(3,), name='input')
y = keras.layers.Dense(4, name='dense')(x)
model = keras.Model(x, y)
@@ -242,7 +242,7 @@ class TestWithDistributionStrategy(test.TestCase):
model.predict(dataset, steps=2)
def test_fit_with_tuple_and_dict_dataset_inputs(self):
- with self.test_session():
+ with self.cached_session():
a = keras.layers.Input(shape=(3,), name='input_a')
b = keras.layers.Input(shape=(3,), name='input_b')
@@ -283,7 +283,7 @@ class TestWithDistributionStrategy(test.TestCase):
model.fit(dataset_dict, epochs=1, steps_per_epoch=2, verbose=1)
def test_fit_eval_and_predict_methods_on_dataset(self):
- with self.test_session():
+ with self.cached_session():
x = keras.layers.Input(shape=(3,), name='input')
y = keras.layers.Dense(4, name='dense')(x)
model = keras.Model(x, y)
@@ -320,7 +320,7 @@ class TestWithDistributionStrategy(test.TestCase):
def __call__(self, y_true, y_pred):
return y_pred - y_true
- with self.test_session():
+ with self.cached_session():
x = keras.layers.Input(shape=(3,), name='input')
y = keras.layers.Dense(4, name='dense')(x)
model = keras.Model(x, y)
@@ -336,7 +336,7 @@ class TestWithDistributionStrategy(test.TestCase):
model.compile(optimizer, loss, metrics=metrics, distribute=strategy)
def test_unsupported_features(self):
- with self.test_session():
+ with self.cached_session():
x = keras.layers.Input(shape=(3,), name='input')
y = keras.layers.Dense(4, name='dense')(x)
model = keras.Model(x, y)
@@ -367,8 +367,8 @@ class TestWithDistributionStrategy(test.TestCase):
# Test with sample weight.
sample_weight = np.random.random((10,))
with self.assertRaisesRegexp(
- NotImplementedError, 'sample_weight is currently not supported when '
- 'using DistributionStrategy.'):
+ NotImplementedError, '`sample_weight` is currently not supported '
+ 'when using DistributionStrategy.'):
model.fit(
dataset,
epochs=1,
@@ -389,7 +389,7 @@ class TestWithDistributionStrategy(test.TestCase):
model.predict(dataset, verbose=0)
def test_calling_with_unsupported_predefined_callbacks(self):
- with self.test_session():
+ with self.cached_session():
x = keras.layers.Input(shape=(3,), name='input')
y = keras.layers.Dense(4, name='dense')(x)
model = keras.Model(x, y)
@@ -428,7 +428,7 @@ class TestWithDistributionStrategy(test.TestCase):
callbacks=[keras.callbacks.TensorBoard(histogram_freq=10)])
def test_dataset_input_shape_validation(self):
- with self.test_session():
+ with self.cached_session():
x = keras.layers.Input(shape=(3,), name='input')
y = keras.layers.Dense(4, name='dense')(x)
model = keras.Model(x, y)
@@ -465,7 +465,7 @@ class TestWithDistributionStrategy(test.TestCase):
# TODO(anjalisridhar): Modify this test to use Lambdas since we can compare
# meaningful values. Currently we don't pass the learning phase if the
# Lambda layer uses the learning phase.
- with self.test_session():
+ with self.cached_session():
x = keras.layers.Input(shape=(16,), name='input')
y = keras.layers.Dense(16)(x)
z = keras.layers.Dropout(0.9999)(y)
@@ -498,7 +498,7 @@ class TestWithDistributionStrategy(test.TestCase):
class LossMaskingWithDistributionStrategyTest(test.TestCase):
def test_masking(self):
- with self.test_session():
+ with self.cached_session():
np.random.seed(1337)
x = np.array([[[1], [1]], [[0], [0]]])
model = keras.models.Sequential()
@@ -523,7 +523,7 @@ class LossMaskingWithDistributionStrategyTest(test.TestCase):
class NormalizationLayerWithDistributionStrategyTest(test.TestCase):
def test_batchnorm_correctness(self):
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
norm = keras.layers.BatchNormalization(input_shape=(10,), momentum=0.8)
model.add(norm)
@@ -550,7 +550,7 @@ class NormalizationLayerWithDistributionStrategyTest(test.TestCase):
class CorrectnessWithDistributionStrategyTest(test.TestCase):
def test_correctness(self):
- with self.test_session():
+ with self.cached_session():
keras.backend.set_image_data_format('channels_last')
num_samples = 10000
x_train = np.random.rand(num_samples, 1)
@@ -565,8 +565,7 @@ class CorrectnessWithDistributionStrategyTest(test.TestCase):
dataset_with = dataset_ops.Dataset.from_tensor_slices((x_train, y_train))
dataset_with = dataset_with.batch(32)
strategy = mirrored_strategy.MirroredStrategy(devices=['/device:CPU:0',
- '/device:GPU:0'],
- prefetch_on_device=False)
+ '/device:GPU:0'])
model.compile(loss=keras.losses.mean_squared_error,
optimizer=gradient_descent.GradientDescentOptimizer(0.5),
diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py
index 516ede7ade..bdac4fb58c 100644
--- a/tensorflow/contrib/distribute/python/minimize_loss_test.py
+++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py
@@ -71,7 +71,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
self.evaluate(distribution.initialize())
if not context.executing_eagerly():
- with self.test_session() as sess:
+ with self.cached_session() as sess:
run_step = sess.make_callable(run_step())
self.evaluate(variables_lib.global_variables_initializer())
@@ -108,7 +108,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
model_fn, iterator.get_next(), run_concurrently=layer.built))
if not context.executing_eagerly():
- with self.test_session() as sess:
+ with self.cached_session() as sess:
run_step = sess.make_callable(run_step())
self.evaluate(variables_lib.global_variables_initializer())
@@ -168,7 +168,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
self.evaluate(distribution.initialize())
if not context.executing_eagerly():
- with self.test_session() as sess:
+ with self.cached_session() as sess:
run_step = sess.make_callable(run_step())
self.evaluate(variables_lib.global_variables_initializer())
@@ -249,7 +249,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
self.evaluate(distribution.initialize())
if not context.executing_eagerly():
- with self.test_session() as sess:
+ with self.cached_session() as sess:
run_step = sess.make_callable(run_step())
self.evaluate(variables_lib.global_variables_initializer())
@@ -343,7 +343,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
self.evaluate(distribution.initialize())
if not context.executing_eagerly():
- with self.test_session() as sess:
+ with self.cached_session() as sess:
run_step = sess.make_callable(run_step())
self.evaluate(variables_lib.global_variables_initializer())
@@ -466,7 +466,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
self.evaluate(distribution.initialize())
if not context.executing_eagerly():
- with self.test_session() as sess:
+ with self.cached_session() as sess:
run_step = sess.make_callable(run_step())
self.evaluate(variables_lib.global_variables_initializer())
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index edd5c6d17a..e87b48ba41 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -19,12 +19,14 @@ from __future__ import division
from __future__ import print_function
import contextlib
+from functools import partial
import threading
from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib
from tensorflow.contrib.distribute.python import shared_variable_creator
from tensorflow.contrib.distribute.python import values
from tensorflow.python import pywrap_tensorflow
+from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.eager import context
from tensorflow.python.eager import tape
from tensorflow.python.framework import constant_op
@@ -274,6 +276,9 @@ def _create_mirrored_variable(devices, real_mirrored_creator, *args, **kwargs):
else:
result = values.MirroredVariable(index, index[devices[0]], aggregation)
+ # Add the wrapped variable to the requested collections.
+ # The handling of eager mode and the global step matches
+ # ResourceVariable._init_from_args().
if not context.executing_eagerly():
g = ops.get_default_graph()
# If "trainable" is True, next_creator() will add the member variables
@@ -287,13 +292,55 @@ def _create_mirrored_variable(devices, real_mirrored_creator, *args, **kwargs):
for v in index.values():
l.remove(v)
g.add_to_collections(collections, result)
+ elif ops.GraphKeys.GLOBAL_STEP in collections:
+ ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result)
+
return result
class MirroredStrategy(distribute_lib.DistributionStrategy):
- """Mirrors vars to distribute across multiple devices on a single machine.
+ """Mirrors vars to distribute across multiple devices and machines.
+
+ This strategy uses one tower per device and sync replication for its multi-GPU
+ version.
+
+ When `cluster_spec` is given by the `configure` method, it turns into the
+ multi-worker version that works on multiple workers with in-graph replication.
+ Note: `configure` will be called by higher-level APIs if running in a
+ distributed environment.
+
+ There are several important concepts for distributed TensorFlow, e.g.
+ `client`, `job`, `task`, `cluster`, `in-graph replication` and
+ `synchronous training`, which are defined in
+ [TensorFlow's documentation](https://www.tensorflow.org/deploy/distributed).
+ The distribution strategy inherits these concepts and, in addition,
+ clarifies several more:
+ * **In-graph replication**: the `client` creates a single `tf.Graph` that
+ specifies tasks for devices on all workers. The `client` then creates a
+ client session which will talk to the `master` service of a `worker`. Then
+ the `master` will partition the graph and distribute the work to all
+ participating workers.
+ * **Worker**: A `worker` is a TensorFlow `task` that usually maps to one
+ physical machine. We will have multiple `worker`s with different `task`
+ indices. They all do similar things, except that one worker also checkpoints
+ model variables, writes summaries, etc. in addition to its ordinary work.
+
+ The multi-worker version of this class maps one tower to one device on a
+ worker. It mirrors all model variables on all towers. For example, if you have
+ two `worker`s and each `worker` has 4 GPUs, it will create 8 copies of the
+ model variables on these 8 GPUs. Then, like in MirroredStrategy, each tower
+ performs its computation with its own copy of the variables, unless in
+ cross-tower mode where variable or tensor reduction happens.
- This strategy uses one tower per device and sync replication.
+ Args:
+ devices: a list of device strings.
+ num_gpus: number of GPUs. For local training, either specify `devices` or
+ `num_gpus`. In distributed training, this must be specified as the number of
+ GPUs on each worker.
+ cross_tower_ops: optional, a descendant of `CrossTowerOps`. If this is not
+ set, the `configure` method will try to find the best one.
+ prefetch_on_device: optional boolean to specify whether to prefetch input
+ data to devices.
"""
def __init__(self,
@@ -302,13 +349,73 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
cross_tower_ops=None,
prefetch_on_device=None):
super(MirroredStrategy, self).__init__()
+
+ self._cross_tower_ops = cross_tower_ops
+ self._prefetch_on_device = prefetch_on_device
+ # Remember num GPUs which might be needed by the `configure` method.
+ self._num_gpus = num_gpus
+
+ self._initialize_local(num_gpus, devices)
+
+ def _initialize_local(self, num_gpus, devices):
+ """Initializes the object for local training."""
+ self._cluster_spec = None
# Convert `num_gpus` into `devices`, shouldn't specify both.
if devices is None:
if num_gpus is None:
num_gpus = context.num_gpus()
- devices = ["/device:GPU:%d" % d for d in range(num_gpus)]
+ if num_gpus == 0:
+ devices = ["/device:CPU:0"]
+ else:
+ devices = ["/device:GPU:%d" % d for d in range(num_gpus)]
elif num_gpus is not None:
raise ValueError("Must only specify one of `devices` and `num_gpus`.")
+ self._num_gpus = num_gpus
+ # TODO(yuefengz): consider setting the default device.
+
+ assert devices, "Must specify at least one device."
+ assert len(set(devices)) == len(devices), (
+ "No duplicates allowed in `devices` argument.")
+ # TODO(josh11b): Require at least 2 devices?
+ self._devices = [device_util.resolve(d) for d in devices]
+ self._canonical_device_set = set(self._devices)
+ self._device_index = values.PerDevice({d: i for i, d in enumerate(devices)})
+
+ def _initialize_multi_worker(self, num_gpus, cluster_spec):
+ """Initializes the object for multi-worker training."""
+ cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec)
+ self._cluster_spec = cluster_spec
+
+ self._workers = []
+ for job in ["chief", "worker"]:
+ for task in range(len(cluster_spec.as_dict().get(job, []))):
+ self._workers.append("/job:%s/task:%d" % (job, task))
+
+ if num_gpus is None:
+ raise ValueError("`num_gpus` is required if `cluster_spec` is given.")
+ if num_gpus > 0:
+ self._worker_device_map = {
+ worker: [
+ device_util.canonicalize(worker + "/device:GPU:%d" % gpu)
+ for gpu in range(num_gpus)
+ ] for worker in self._workers
+ }
+ else:
+ self._worker_device_map = {
+ worker: [device_util.canonicalize(worker, "/device:CPU:0")]
+ for worker in self._workers
+ }
+
+ devices = nest.flatten(self._worker_device_map)
+
+ # Setting `_default_device` will add a device scope in the
+ # distribution.scope. We set the default device to the first worker. When
+ # users specify device under distribution.scope by
+ # with tf.device("/cpu:0"):
+ # ...
+ # their ops will end up on the cpu device of its first worker, e.g.
+ # "/job:worker/task:0/device:CPU:0". Note this is not used in tower mode.
+ self._default_device = self._workers[0]
assert devices, "Must specify at least one device."
assert len(set(devices)) == len(devices), (
@@ -318,9 +425,6 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
self._canonical_device_set = set(self._devices)
self._device_index = values.PerDevice(
{d: i for i, d in enumerate(devices)})
- self._cross_tower_ops = cross_tower_ops
- self._prefetch_on_device = prefetch_on_device
- # TODO(yuefengz): consider setting the default device.
def _create_variable(self, next_creator, *args, **kwargs):
"""Create a mirrored variable. See `DistributionStrategy.scope`."""
@@ -357,9 +461,14 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
**kwargs)
def distribute_dataset(self, dataset_fn):
- return values.PerDeviceDataset(
- self._call_dataset_fn(dataset_fn), self._devices,
- self._prefetch_on_device)
+ if self._cluster_spec:
+ return values.MultiWorkerDataset(
+ partial(self._call_dataset_fn, dataset_fn), self._worker_device_map,
+ self._prefetch_on_device)
+ else:
+ return values.PerDeviceDataset(
+ self._call_dataset_fn(dataset_fn), self._devices,
+ self._prefetch_on_device)
# TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
def _run_steps_on_dataset(self, fn, iterator, iterations,
@@ -444,10 +553,22 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
# in addition to PerDevice data.
return values.PerDevice({k: values.MapOutput(v) for k, v in index.items()})
- def configure(self, session_config=None):
+ def configure(self,
+ session_config=None,
+ cluster_spec=None,
+ task_type=None,
+ task_id=None):
+ del task_type, task_id
+ if cluster_spec:
+ self._initialize_multi_worker(self._num_gpus, cluster_spec)
+
if self._cross_tower_ops is None:
- self._cross_tower_ops = cross_tower_ops_lib.choose_the_best(
- self._devices, session_config=session_config)
+ if self._cluster_spec:
+ self._cross_tower_ops = cross_tower_ops_lib.MultiWorkerAllReduce(
+ self._workers, self._num_gpus)
+ else:
+ self._cross_tower_ops = cross_tower_ops_lib.choose_the_best(
+ self._devices, session_config=session_config)
def _get_cross_tower_ops(self):
if self._cross_tower_ops is None:
@@ -532,6 +653,22 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
def parameter_devices(self):
return list(self._devices)
+ @property
+ def between_graph(self):
+ return False
+
+ @property
+ def should_init(self):
+ return True
+
+ @property
+ def should_checkpoint(self):
+ return True
+
+ @property
+ def should_save_summary(self):
+ return True
+
def non_slot_devices(self, var_list):
del var_list
return list(self._devices)
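As a rough illustration of the multi-worker path added above: calling `configure` with a cluster spec switches a MirroredStrategy into its in-graph multi-worker mode (a minimal sketch; the worker addresses are placeholders, and higher-level APIs normally call `configure` on the user's behalf):

from tensorflow.contrib.distribute.python import mirrored_strategy

# `num_gpus` is required once a cluster_spec is supplied.
strategy = mirrored_strategy.MirroredStrategy(num_gpus=2)
strategy.configure(
    cluster_spec={"worker": ["host1:2222", "host2:2222"]})

# After configure(), distribute_dataset() uses a MultiWorkerDataset and
# cross-tower reductions default to MultiWorkerAllReduce.
with strategy.scope():
    pass  # build the replicated model here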
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
index 9a4cc0a897..a12ff662db 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import sys
from tensorflow.contrib.distribute.python import mirrored_strategy
+from tensorflow.contrib.distribute.python import multi_worker_test_base
from tensorflow.contrib.distribute.python import strategy_test_lib
from tensorflow.contrib.distribute.python import values
from tensorflow.core.protobuf import config_pb2
@@ -41,6 +42,7 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import device_util
from tensorflow.python.training import distribution_strategy_context
+from tensorflow.python.training import server_lib
GPU_TEST = "test_gpu" in sys.argv[0]
@@ -886,8 +888,18 @@ class MirroredVariableUpdateTest(test.TestCase):
self.assertIsInstance(mirrored_var, values.MirroredVariable)
self.evaluate(variables.global_variables_initializer())
self.assertEquals(1.0, self.evaluate(mirrored_var))
- mirrored_var_result = self.evaluate(mirrored_var.assign_add(6.0))
+
+ # read_value == True
+ mirrored_var_result = self.evaluate(
+ mirrored_var.assign_add(6.0, read_value=True))
self.assertEquals(7.0, mirrored_var_result)
+ self.assertEquals(7.0, self.evaluate(mirrored_var.get("/device:CPU:0")))
+ self.assertEquals(7.0, self.evaluate(mirrored_var.get("/device:GPU:0")))
+
+ # read_value == False
+ self.evaluate(mirrored_var.assign_add(2.0, read_value=False))
+ self.assertEquals(9.0, self.evaluate(mirrored_var.get("/device:CPU:0")))
+ self.assertEquals(9.0, self.evaluate(mirrored_var.get("/device:GPU:0")))
@test_util.run_in_graph_and_eager_modes(config=config)
def testAssignAddMirroredVarTowerContext(self):
@@ -954,6 +966,8 @@ class MirroredVariableUpdateTest(test.TestCase):
self.assertEquals(5.0, self.evaluate(mirrored_var))
mirrored_var_result = self.evaluate(mirrored_var.assign_sub(2.0))
self.assertEquals(3.0, mirrored_var_result)
+ self.assertEquals(3.0, self.evaluate(mirrored_var.get("/device:GPU:0")))
+ self.assertEquals(3.0, self.evaluate(mirrored_var.get("/device:CPU:0")))
@test_util.run_in_graph_and_eager_modes(config=config)
def testAssignSubMirroredVarTowerContext(self):
@@ -1244,5 +1258,39 @@ class MirroredStrategyDefunTest(test.TestCase):
self._call_and_check(fn1, [factors], expected_result, [fn1])
+class MultiWorkerMirroredStrategyTest(
+ multi_worker_test_base.MultiWorkerTestBase,
+ strategy_test_lib.DistributionTestBase):
+
+ def _get_distribution_strategy(self):
+ cluster_spec = server_lib.ClusterSpec({
+ "worker": ["/job:worker/task:0", "/job:worker/task:1"]
+ })
+ strategy = mirrored_strategy.MirroredStrategy(num_gpus=context.num_gpus())
+ strategy.configure(cluster_spec=cluster_spec)
+ return strategy
+
+ def testMinimizeLossGraph(self):
+ self._test_minimize_loss_graph(self._get_distribution_strategy(),
+ learning_rate=0.05)
+
+
+class MultiWorkerMirroredStrategyTestWithChief(
+ multi_worker_test_base.MultiWorkerTestBase,
+ strategy_test_lib.DistributionTestBase):
+
+ @classmethod
+ def setUpClass(cls):
+ """Create a local cluster with 2 workers and 1 chief."""
+ cls._cluster_spec = multi_worker_test_base.create_in_process_cluster(
+ num_workers=2, num_ps=0, has_chief=True)
+ cls._default_target = "grpc://" + cls._cluster_spec["chief"][0]
+
+ def testMinimizeLossGraph(self):
+ strategy = mirrored_strategy.MirroredStrategy(num_gpus=context.num_gpus())
+ strategy.configure(cluster_spec=self._cluster_spec)
+ self._test_minimize_loss_graph(strategy, learning_rate=0.05)
+
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_test.py
index 5db2fff239..969e126956 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy_test.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy_test.py
@@ -22,6 +22,8 @@ from tensorflow.contrib.distribute.python import mirrored_strategy
from tensorflow.contrib.distribute.python import strategy_test_lib
from tensorflow.python.eager import context
from tensorflow.python.eager import test
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import variable_scope
from tensorflow.python.training import distribution_strategy_context
@@ -60,6 +62,7 @@ class VariableCreatorStackTest(test.TestCase):
def model_fn(device_id):
assert isinstance(device_id, int)
+
def thread_creator_fn(next_creator, *args, **kwargs):
return next_creator(*args, **kwargs) + ":thread_" + str(device_id)
@@ -86,5 +89,21 @@ class VariableCreatorStackTest(test.TestCase):
self.assertEquals(expected, result)
+class MultiWorkerMirroredStrategyTest(test.TestCase):
+
+ def testDeviceScope(self):
+ """Test the device scope of multi-worker MirroredStrategy."""
+ with context.graph_mode():
+ strategy = mirrored_strategy.MirroredStrategy(num_gpus=context.num_gpus())
+ strategy.configure(
+ cluster_spec={"worker": ["/job:worker/task:0", "/job:worker/task:1"]})
+ with strategy.scope():
+ a = constant_op.constant(1.)
+ with ops.device("/cpu:0"):
+ b = constant_op.constant(1.)
+ self.assertEqual(a.device, "/job:worker/task:0")
+ self.assertEqual(b.device, "/job:worker/task:0/device:CPU:0")
+
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/distribute/python/monitor_test.py b/tensorflow/contrib/distribute/python/monitor_test.py
index 2892ce4394..16be839e1d 100644
--- a/tensorflow/contrib/distribute/python/monitor_test.py
+++ b/tensorflow/contrib/distribute/python/monitor_test.py
@@ -45,7 +45,7 @@ class MonitorTest(test.TestCase, parameterized.TestCase):
if context.executing_eagerly():
monitor = monitor_lib.Monitor(single_loss_step, None)
else:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
monitor = monitor_lib.Monitor(single_loss_step, sess)
monitor.run_steps(1)
diff --git a/tensorflow/contrib/distribute/python/multi_worker_strategy.py b/tensorflow/contrib/distribute/python/multi_worker_strategy.py
deleted file mode 100644
index cbfe5df61d..0000000000
--- a/tensorflow/contrib/distribute/python/multi_worker_strategy.py
+++ /dev/null
@@ -1,141 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Classes implementing a mirrored DistributionStrategy for multiple workers."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from functools import partial
-
-from tensorflow.contrib.distribute.python import values
-from tensorflow.contrib.distribute.python.mirrored_strategy import MirroredStrategy
-from tensorflow.core.protobuf import cluster_pb2
-from tensorflow.python.training import device_util
-from tensorflow.python.training import server_lib
-from tensorflow.python.util import nest
-
-
-# TODO(yuefengz): support between-graph replication.
-# TODO(yuefengz): merge this class into its base class.
-# TODO(yuefengz): in some cases, we probably want to use configure method to
-# configure this class.
-# TODO(yuefengz): MirroredStrategy.worker_devices may be confusing after the
-# class is introduced.
-class MultiWorkerMirroredStrategy(MirroredStrategy):
- """Mirrored strategy that works on multiple workers with in-graph replication.
-
- There are several important concepts for distributed TensorFlow, e.g.
- `client`, `job`, 'task', `cluster`, `in-graph replication` and
- 'synchronous training' and they have already been defined in the
- [TensorFlow's documentation](https://www.tensorflow.org/deploy/distributed).
- The distribution strategy inherits these concepts as well and in addition to
- that we also clarify several more concepts:
- * **In-graph replication**: the `client` creates a single `tf.Graph` that
- specifies tasks for devices on all workers. The `client` then creates a
- client session which will talk to the `master` service of a `worker`. Then
- the `master` will partition the graph and distribute the work to all
- participating workers.
- * **Worker**: A `worker` is a TensorFlow `task` that usually maps to one
- physical machine. We will have multiple `worker`s with different `task`
- index. They all do similar things except for one worker checkpointing model
- variables, writing summaries, etc. in addition to its ordinary work.
-
- This class maps one tower to one device on a worker. It mirrors all model
- variables on all towers. For example, if you have two `worker`s and each
- `worker` has 4 GPUs, it will create 8 copies of the model variables on these 8
- GPUs. Then like in MirroredStrategy, each tower performs their computation
- with their own copy of variables unless in cross-tower model where variable or
- tensor reduction happens.
- """
-
- def __init__(self,
- num_gpus_per_worker=1,
- worker_job_name=None,
- num_workers=None,
- cluster=None,
- cross_tower_ops=None,
- prefetch_on_device=None):
- """Initialize the strategy object.
-
- Args:
- num_gpus_per_worker: number of GPUs per work. If it is zero, the local
- CPU will be used.
- worker_job_name: the job name for `worker`, typically just 'worker'.
- num_workers: the number of workers. If it is 0, it regenerates to
- single-worker MirroredStrategy.
- cluster: a `tf.train.ClusterSpec` object or a dict that can be used to
- construct a `tf.train.ClusterSpec` object or a `tf.train.ClusterDef`
- proto buffer. It is an alternative way to initialize this object.
- cross_tower_ops: the cross tower ops to use. If None, a default one will
- be used. If configure method is called, a best one for the configuration
- will be chosen.
- prefetch_on_device: a boolean to specify whether to prefetech input to
- each worker's devices.
-
- Raises:
- ValueError: if got an unexpected `cluster`.
- """
- if cluster is None:
- self._workers = [
- '/job:%s/task:%d' % (worker_job_name, task_index)
- for task_index in range(num_workers)
- ]
- else:
- if isinstance(cluster, (dict, cluster_pb2.ClusterDef)):
- cluster_spec = server_lib.ClusterSpec(cluster)
- elif isinstance(cluster, server_lib.ClusterSpec):
- cluster_spec = cluster
- else:
- raise ValueError(
- "`cluster_spec' should be dict or a `tf.train.ClusterSpec` or a "
- '`tf.train.ClusterDef` object')
-
- self._workers = []
- for job in sorted(cluster_spec.jobs):
- for task in range(cluster_spec.num_tasks(job)):
- self._workers.append('/job:%s/task:%d' % (job, task))
-
- self._num_gpus_per_worker = num_gpus_per_worker
- if num_gpus_per_worker > 0:
- self._worker_device_map = {
- worker: [
- device_util.canonicalize(worker + '/device:GPU:%d' % gpu)
- for gpu in range(num_gpus_per_worker)
- ] for worker in self._workers
- }
- else:
- self._worker_device_map = {
- worker: [device_util.canonicalize(worker, '/device:CPU:0')]
- for worker in self._workers
- }
- self._devices = nest.flatten(self._worker_device_map)
-
- super(MultiWorkerMirroredStrategy, self).__init__(
- devices=self._devices, prefetch_on_device=prefetch_on_device)
-
- # Setting `_default_device` will add a device scope in the
- # distribution.scope. We set the default device to the first worker. When
- # users specify device under distribution.scope by
- # with tf.device("/cpu:0"):
- # ...
- # their ops will end up on the cpu device of its first worker, e.g.
- # "/job:worker/task:0/device:CPU:0". Note this is not used in tower mode.
- self._default_device = self._workers[0]
-
- def distribute_dataset(self, dataset_fn):
- return values.MultiWorkerDataset(
- partial(self._call_dataset_fn, dataset_fn), self._worker_device_map,
- self._prefetch_on_device)
diff --git a/tensorflow/contrib/distribute/python/multi_worker_strategy_test.py b/tensorflow/contrib/distribute/python/multi_worker_strategy_test.py
deleted file mode 100644
index 09c859b32a..0000000000
--- a/tensorflow/contrib/distribute/python/multi_worker_strategy_test.py
+++ /dev/null
@@ -1,62 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for MultiWorkerMirroredStrategy."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.distribute.python import multi_worker_strategy
-from tensorflow.contrib.distribute.python import multi_worker_test_base
-from tensorflow.contrib.distribute.python import strategy_test_lib
-from tensorflow.python.eager import context
-from tensorflow.python.eager import test
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import ops
-from tensorflow.python.training import server_lib
-
-
-class MultiWorkerStrategyTest(multi_worker_test_base.MultiWorkerTestBase,
- strategy_test_lib.DistributionTestBase):
-
- def _get_distribution_strategy(self):
- return multi_worker_strategy.MultiWorkerMirroredStrategy(
- cluster=server_lib.ClusterSpec({
- 'worker': ['/job:worker/task:0', '/job:worker/task:1']
- }),
- num_gpus_per_worker=context.num_gpus())
-
- def testMinimizeLossGraph(self):
- self._test_minimize_loss_graph(self._get_distribution_strategy())
-
-
-class DeviceScopeTest(test.TestCase):
- """Test the device scope of MultiWorkerMirroredStrategy."""
-
- def testDeviceScope(self):
- with context.graph_mode():
- strategy = multi_worker_strategy.MultiWorkerMirroredStrategy(
- cluster={'worker': ['/job:worker/task:0', '/job:worker/task:1']},
- num_gpus_per_worker=context.num_gpus())
- with strategy.scope():
- a = constant_op.constant(1.)
- with ops.device('/cpu:0'):
- b = constant_op.constant(1.)
- self.assertEqual(a.device, '/job:worker/task:0')
- self.assertEqual(b.device, '/job:worker/task:0/device:CPU:0')
-
-
-if __name__ == '__main__':
- test.main()
diff --git a/tensorflow/contrib/distribute/python/multi_worker_test_base.py b/tensorflow/contrib/distribute/python/multi_worker_test_base.py
index 249de01f08..18b4503eff 100644
--- a/tensorflow/contrib/distribute/python/multi_worker_test_base.py
+++ b/tensorflow/contrib/distribute/python/multi_worker_test_base.py
@@ -23,26 +23,105 @@ import copy
import threading
import numpy as np
+_portpicker_import_error = None
+try:
+ import portpicker # pylint: disable=g-import-not-at-top
+except ImportError as _error: # pylint: disable=invalid-name
+ _portpicker_import_error = _error
+ portpicker = None
+
+# pylint: disable=g-import-not-at-top
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session
from tensorflow.python.estimator import run_config
from tensorflow.python.platform import test
-from tensorflow.python.framework import test_util
-
-
-def create_in_process_cluster(num_workers, num_ps):
+from tensorflow.python.training import server_lib
+
+
+def _create_cluster(num_workers,
+ num_ps,
+ has_chief=False,
+ has_eval=False,
+ protocol='grpc',
+ worker_config=None,
+ ps_config=None):
+ """Creates and starts local servers and returns the cluster_spec dict."""
+ if _portpicker_import_error:
+ raise _portpicker_import_error # pylint: disable=raising-bad-type
+ worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)]
+ ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)]
+
+ cluster_dict = {}
+ if num_workers > 0:
+ cluster_dict['worker'] = ['localhost:%s' % port for port in worker_ports]
+ if num_ps > 0:
+ cluster_dict['ps'] = ['localhost:%s' % port for port in ps_ports]
+ if has_eval:
+ cluster_dict['evaluator'] = ['localhost:%s' % portpicker.pick_unused_port()]
+ if has_chief:
+ cluster_dict['chief'] = ['localhost:%s' % portpicker.pick_unused_port()]
+
+ cs = server_lib.ClusterSpec(cluster_dict)
+
+ for i in range(num_workers):
+ server_lib.Server(
+ cs,
+ job_name='worker',
+ protocol=protocol,
+ task_index=i,
+ config=worker_config,
+ start=True)
+
+ for i in range(num_ps):
+ server_lib.Server(
+ cs,
+ job_name='ps',
+ protocol=protocol,
+ task_index=i,
+ config=ps_config,
+ start=True)
+
+ if has_chief:
+ server_lib.Server(
+ cs,
+ job_name='chief',
+ protocol=protocol,
+ task_index=0,
+ config=worker_config,
+ start=True)
+
+ if has_eval:
+ server_lib.Server(
+ cs,
+ job_name='evaluator',
+ protocol=protocol,
+ task_index=0,
+ config=worker_config,
+ start=True)
+
+ return cluster_dict
+
+
+def create_in_process_cluster(num_workers,
+ num_ps,
+ has_chief=False,
+ has_eval=False):
"""Create an in-process cluster that consists of only standard server."""
# Leave some memory for cuda runtime.
- gpu_mem_frac = 0.7 / num_workers
+ gpu_mem_frac = 0.7 / (num_workers + int(has_chief) + int(has_eval))
worker_config = config_pb2.ConfigProto()
worker_config.gpu_options.per_process_gpu_memory_fraction = gpu_mem_frac
# Enable collective ops which has no impact on non-collective ops.
# TODO(yuefengz, tucker): removing this after we move the initialization of
# collective mgr to the session level.
- worker_config.experimental.collective_group_leader = (
- '/job:worker/replica:0/task:0')
+ if has_chief:
+ worker_config.experimental.collective_group_leader = (
+ '/job:chief/replica:0/task:0')
+ else:
+ worker_config.experimental.collective_group_leader = (
+ '/job:worker/replica:0/task:0')
ps_config = config_pb2.ConfigProto()
ps_config.device_count['GPU'] = 0
@@ -56,9 +135,10 @@ def create_in_process_cluster(num_workers, num_ps):
# 2) there is something global in CUDA such that if we initialize CUDA in the
# parent process, the child process cannot initialize it again and thus cannot
# use GPUs (https://stackoverflow.com/questions/22950047).
- return test_util.create_local_cluster(
+ return _create_cluster(
num_workers,
num_ps=num_ps,
+ has_chief=has_chief,
worker_config=worker_config,
ps_config=ps_config,
protocol='grpc')
@@ -70,7 +150,8 @@ class MultiWorkerTestBase(test.TestCase):
@classmethod
def setUpClass(cls):
"""Create a local cluster with 2 workers."""
- cls._workers, cls._ps = create_in_process_cluster(num_workers=2, num_ps=0)
+ cls._cluster_spec = create_in_process_cluster(num_workers=2, num_ps=0)
+ cls._default_target = 'grpc://' + cls._cluster_spec['worker'][0]
def setUp(self):
# We only cache the session in one test because another test may have a
@@ -111,17 +192,17 @@ class MultiWorkerTestBase(test.TestCase):
config.graph_options.rewrite_options.constant_folding = (
rewriter_config_pb2.RewriterConfig.OFF)
+ if target is None:
+ target = self._default_target
if graph is None:
if getattr(self._thread_local, 'cached_session', None) is None:
self._thread_local.cached_session = session.Session(
- graph=None, config=config, target=target or self._workers[0].target)
+ graph=None, config=config, target=target)
sess = self._thread_local.cached_session
with sess.graph.as_default(), sess.as_default():
yield sess
else:
- with session.Session(
- graph=graph, config=config, target=target or
- self._workers[0].target) as sess:
+ with session.Session(graph=graph, config=config, target=target) as sess:
yield sess
def _run_client(self, client_fn, task_type, task_id, num_gpus, *args,
diff --git a/tensorflow/contrib/distribute/python/optimizer_v2_test.py b/tensorflow/contrib/distribute/python/optimizer_v2_test.py
index a2d736e422..6e9ba37a19 100644
--- a/tensorflow/contrib/distribute/python/optimizer_v2_test.py
+++ b/tensorflow/contrib/distribute/python/optimizer_v2_test.py
@@ -51,7 +51,7 @@ class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase):
model_fn, iterator.get_next(), run_concurrently=layer.built)))
if not context.executing_eagerly():
- with self.test_session() as sess:
+ with self.cached_session() as sess:
run_step = sess.make_callable(run_step())
self.evaluate(variables.global_variables_initializer())
diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy.py b/tensorflow/contrib/distribute/python/parameter_server_strategy.py
index 8041eb0f34..361c8be590 100644
--- a/tensorflow/contrib/distribute/python/parameter_server_strategy.py
+++ b/tensorflow/contrib/distribute/python/parameter_server_strategy.py
@@ -22,10 +22,12 @@ from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_
from tensorflow.contrib.distribute.python import mirrored_strategy
from tensorflow.contrib.distribute.python import values
from tensorflow.python.distribute import multi_worker_util
+from tensorflow.python.eager import context
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.training import device_setter
from tensorflow.python.training import device_util
from tensorflow.python.training import distribute as distribute_lib
@@ -55,7 +57,11 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy):
assigned to.
This class assumes between-graph replication will be used and works on a graph
- for a particular worker.
+ for a particular worker. Note that each graph and worker is independent.
+ This means that while each worker will synchronously compute a single gradient
+ update across all GPUs, updates between workers proceed asynchronously.
+ Operations that occur only on the first tower (such as incrementing the global
+ step) will occur on the first tower *of every worker*.
It is expected to call `call_for_each_tower(fn, *args, **kwargs)` for any
operations which potentially can be replicated across towers (i.e. multiple
@@ -73,7 +79,7 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy):
3) It is also not recommended to open a colocation scope (i.e. calling
`tf.colocate_with`) under the strategy's scope. For colocating variables,
use `distribution.colocate_vars_with` instead. Colocation of ops will possibly
- create conflicts of device assignement.
+ create conflicts of device assignment.
"""
def __init__(self,
@@ -81,7 +87,7 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy):
cluster_spec=None,
task_type=None,
task_id=None):
- """Initiailizes this strategy.
+ """Initializes this strategy.
Args:
num_gpus_per_worker: number of local GPUs or GPUs per worker.
@@ -89,11 +95,18 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy):
cluster configurations.
task_type: the current task type.
task_id: the current task id.
+
+ Raises:
+ ValueError: if `cluster_spec` is given but `task_type` or `task_id` is
+ not.
"""
super(ParameterServerStrategy, self).__init__()
self._num_gpus_per_worker = num_gpus_per_worker
if cluster_spec:
cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec)
+ if task_type is None or task_id is None:
+ raise ValueError("When `cluster_spec` is given, must also specify "
+ "`task_type` and `task_id`.")
self._cluster_spec = cluster_spec
# We typically don't need to do all-reduce in this strategy.
@@ -217,14 +230,57 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy):
# TODO(yuefengz): not all ops in device_setter.STANDARD_PS_OPS will go through
# this creator, such as "MutableHashTable".
def _create_variable(self, next_creator, *args, **kwargs):
+ if self.num_towers > 1:
+ aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE)
+ if aggregation not in (
+ vs.VariableAggregation.NONE,
+ vs.VariableAggregation.SUM,
+ vs.VariableAggregation.MEAN
+ ):
+ raise ValueError("Invalid variable aggregation mode: " + aggregation +
+ " for variable: " + kwargs["name"])
+
+ def var_creator(*args, **kwargs):
+ # Record what collections this variable should be added to.
+ collections = kwargs.pop("collections", None)
+ if collections is None:
+ collections = [ops.GraphKeys.GLOBAL_VARIABLES]
+ kwargs["collections"] = []
+
+ # Create and wrap the variable.
+ v = next_creator(*args, **kwargs)
+ wrapped = values.AggregatingVariable(v, aggregation)
+
+ # Add the wrapped variable to the requested collections.
+ # The handling of eager mode and the global step matches
+ # ResourceVariable._init_from_args().
+ if not context.executing_eagerly():
+ g = ops.get_default_graph()
+ # If "trainable" is True, next_creator() will add the contained
+ # variable to the TRAINABLE_VARIABLES collection, so we manually
+ # remove it and replace with the wrapper. We can't set "trainable"
+ # to False for next_creator() since that causes functions like
+ # implicit_gradients to skip those variables.
+ if kwargs.get("trainable", True):
+ collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
+ l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
+ l.remove(v)
+ g.add_to_collections(collections, wrapped)
+ elif ops.GraphKeys.GLOBAL_STEP in collections:
+ ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, wrapped)
+
+ return wrapped
+ else:
+ var_creator = next_creator
+
if "colocate_with" in kwargs:
with ops.device(None):
with ops.colocate_with(kwargs["colocate_with"]):
- return next_creator(*args, **kwargs)
+ return var_creator(*args, **kwargs)
with ops.colocate_with(None, ignore_existing=True):
with ops.device(self._variable_device):
- return next_creator(*args, **kwargs)
+ return var_creator(*args, **kwargs)
def _call_for_each_tower(self, fn, *args, **kwargs):
# pylint: disable=protected-access
@@ -246,7 +302,6 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy):
# pylint: disable=protected-access
return mirrored_strategy._reduce_non_distributed_value(
self, aggregation, value, destinations)
-
return self._cross_tower_ops.reduce(
aggregation, value, destinations=destinations)
@@ -279,6 +334,8 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy):
return nest.map_structure(_select_fn, structured)
def _update(self, var, fn, *args, **kwargs):
+ if isinstance(var, values.AggregatingVariable):
+ var = var.get()
if not isinstance(var, resource_variable_ops.ResourceVariable):
raise ValueError(
"You can not update `var` %r. It must be a Variable." % var)
@@ -323,6 +380,10 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy):
cluster configurations.
task_type: the current task type.
task_id: the current task id.
+
+ Raises:
+ ValueError: if `cluster_spec` is given but `task_type` or `task_id` is
+ not.
"""
del session_config
@@ -331,6 +392,9 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy):
if not self._cluster_spec and cluster_spec:
self._cluster_spec = multi_worker_util.normalize_cluster_spec(
cluster_spec)
+ if task_type is None or task_id is None:
+ raise ValueError("When `cluster_spec` is given, must also specify "
+ "`task_type` and `task_id`.")
self._initialize_devices(self._num_gpus_per_worker, self._cluster_spec,
task_type, task_id)
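Downstream, callers typically construct the strategy locally and attach the cluster afterwards via `configure()`. A hedged usage sketch based on the test changes below (the addresses are placeholders, and this assumes the contrib-era build that contains this change):

    from tensorflow.contrib.distribute.python import parameter_server_strategy

    distribution = parameter_server_strategy.ParameterServerStrategy(
        num_gpus_per_worker=0)
    distribution.configure(
        cluster_spec={"worker": ["localhost:2222", "localhost:2223"],
                      "ps": ["localhost:2224"]},
        task_type="worker", task_id=0)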
diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py
index 0df65714fb..0e2bfcec5f 100644
--- a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py
+++ b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py
@@ -24,6 +24,8 @@ from absl.testing import parameterized
from tensorflow.contrib.distribute.python import combinations
from tensorflow.contrib.distribute.python import multi_worker_test_base
from tensorflow.contrib.distribute.python import parameter_server_strategy
+from tensorflow.contrib.distribute.python import values
+from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.eager import context
from tensorflow.python.estimator import run_config
from tensorflow.python.framework import constant_op
@@ -37,21 +39,15 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import device_util
from tensorflow.python.training import distribution_strategy_context
+from tensorflow.python.training import training_util
+CHIEF = run_config.TaskType.CHIEF
+WORKER = run_config.TaskType.WORKER
+PS = run_config.TaskType.PS
-class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase,
- parameterized.TestCase):
- @classmethod
- def setUpClass(cls):
- cls._workers, cls._ps = multi_worker_test_base.create_in_process_cluster(
- num_workers=3, num_ps=2)
- cls._cluster_spec = {
- run_config.TaskType.WORKER: [
- 'fake_worker_0', 'fake_worker_1', 'fake_worker_2'
- ],
- run_config.TaskType.PS: ['fake_ps_0', 'fake_ps_1']
- }
+class ParameterServerStrategyTestBase(
+ multi_worker_test_base.MultiWorkerTestBase):
def setUp(self):
self._result = 0
@@ -60,7 +56,7 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase,
self._init_reached = 0
self._finish_condition = threading.Condition()
self._finish_reached = 0
- super(ParameterServerStrategyTest, self).setUp()
+ super(ParameterServerStrategyTestBase, self).setUp()
def _get_test_objects(self, task_type, task_id, num_gpus):
distribution = parameter_server_strategy.ParameterServerStrategy(
@@ -70,13 +66,13 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase,
distribution.configure(
cluster_spec=self._cluster_spec, task_type=task_type, task_id=task_id)
- return distribution, self._workers[task_id].target
+ return distribution, 'grpc://' + self._cluster_spec[WORKER][task_id]
def _test_device_assignment_distributed(self, task_type, task_id, num_gpus):
worker_device = '/job:%s/replica:0/task:%d' % (task_type, task_id)
d, _ = self._get_test_objects(task_type, task_id, num_gpus)
with ops.Graph().as_default(), \
- self.test_session(target=self._workers[0].target) as sess, \
+ self.test_session(target=self._default_target) as sess, \
d.scope():
# Define a variable outside the call_for_each_tower scope. This is not
@@ -101,7 +97,9 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase,
# The device scope is ignored for variables but not for normal ops.
with ops.device('/job:worker/task:0'):
- x = variable_scope.get_variable('x', initializer=10.0)
+ x = variable_scope.get_variable(
+ 'x', initializer=10.0,
+ aggregation=variable_scope.VariableAggregation.SUM)
x_add = x.assign_add(c)
e = a + c
# The variable x is on the task 1 since the device_function has been
@@ -113,18 +111,26 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase,
# The colocate_vars_with can override the distribution's device.
with d.colocate_vars_with(x):
- y = variable_scope.get_variable('y', initializer=20.0)
- y_add = y.assign_add(x_add)
+ y = variable_scope.get_variable(
+ 'y', initializer=20.0,
+ aggregation=variable_scope.VariableAggregation.SUM)
+ # We add an identity here to avoid complaints about summing
+ # non-distributed values.
+ y_add = y.assign_add(array_ops.identity(x_add))
self.assertEqual(y.device, '/job:ps/task:1')
self.assertEqual(y_add.device, y.device)
self.assertEqual(y.device, x.device)
- z = variable_scope.get_variable('z', initializer=10.0)
+ z = variable_scope.get_variable(
+ 'z', initializer=10.0,
+ aggregation=variable_scope.VariableAggregation.SUM)
self.assertEqual(z.device, '/job:ps/task:0')
self.assertNotEqual(z.device, x.device)
with ops.control_dependencies([y_add]):
- z_add = z.assign_add(y)
+ # We add an identity here to avoid complaints about summing
+ # non-distributed values.
+ z_add = z.assign_add(array_ops.identity(y))
with ops.control_dependencies([z_add]):
f = z + c
self.assertEqual(f.device, worker_device + '/' + last_part_device)
@@ -162,18 +168,13 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase,
self.assertEqual(z_val, 43.0)
self.assertEqual(f_val, 46.0)
- @combinations.generate(
- combinations.combine(mode=['graph'], num_gpus=[0, 1, 2]))
- def testDeviceAssignmentDistributed(self, num_gpus):
- self._test_device_assignment_distributed('worker', 1, num_gpus)
-
def _test_device_assignment_local(self,
d,
compute_device='CPU',
variable_device='CPU',
num_gpus=0):
with ops.Graph().as_default(), \
- self.test_session(target=self._workers[0].target) as sess, \
+ self.test_session(target=self._default_target) as sess, \
d.scope():
def model_fn():
@@ -202,7 +203,9 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase,
# The device scope is ignored for variables but not for normal ops.
with ops.device('/device:GPU:2'):
- x = variable_scope.get_variable('x', initializer=10.0)
+ x = variable_scope.get_variable(
+ 'x', initializer=10.0,
+ aggregation=variable_scope.VariableAggregation.SUM)
x_add = x.assign_add(c)
e = a + c
self.assertEqual(
@@ -212,19 +215,27 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase,
# The colocate_vars_with can override the distribution's device.
with d.colocate_vars_with(x):
- y = variable_scope.get_variable('y', initializer=20.0)
- y_add = y.assign_add(x_add)
+ y = variable_scope.get_variable(
+ 'y', initializer=20.0,
+ aggregation=variable_scope.VariableAggregation.SUM)
+ # We add an identity here to avoid complaints about summing
+ # non-distributed values.
+ y_add = y.assign_add(array_ops.identity(x_add))
self.assertEqual(
device_util.canonicalize(y.device), tower_variable_device)
self.assertEqual(y_add.device, y.device)
self.assertEqual(y.device, x.device)
- z = variable_scope.get_variable('z', initializer=10.0)
+ z = variable_scope.get_variable(
+ 'z', initializer=10.0,
+ aggregation=variable_scope.VariableAggregation.SUM)
self.assertEqual(
device_util.canonicalize(z.device), tower_variable_device)
with ops.control_dependencies([y_add]):
- z_add = z.assign_add(y)
+ # We add an identity here to avoid complaints about summing
+ # non-distributed values.
+ z_add = z.assign_add(array_ops.identity(y))
with ops.control_dependencies([z_add]):
f = z + c
self.assertEqual(f.device, tower_compute_device)
@@ -256,29 +267,12 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase,
self.assertEqual(z_val, 43.0)
self.assertEqual(f_val, 46.0)
- def testDeviceAssignmentLocalCPU(self):
- distribution = parameter_server_strategy.ParameterServerStrategy(
- num_gpus_per_worker=0)
- self._test_device_assignment_local(
- distribution, compute_device='CPU', variable_device='CPU', num_gpus=0)
-
- def testDeviceAssignmentLocalOneGPU(self):
- distribution = parameter_server_strategy.ParameterServerStrategy(
- num_gpus_per_worker=1)
- self._test_device_assignment_local(
- distribution, compute_device='GPU', variable_device='GPU', num_gpus=1)
-
- def testDeviceAssignmentLocalTwoGPUs(self):
- distribution = parameter_server_strategy.ParameterServerStrategy(
- num_gpus_per_worker=2)
- self._test_device_assignment_local(
- distribution, compute_device='GPU', variable_device='CPU', num_gpus=2)
-
def _test_simple_increment(self, task_type, task_id, num_gpus):
d, master_target = self._get_test_objects(task_type, task_id, num_gpus)
if hasattr(d, '_cluster_spec') and d._cluster_spec:
- num_workers = len(d._cluster_spec.as_dict().get('worker',
- ['dummy_worker']))
+ num_workers = len(d._cluster_spec.as_dict().get(WORKER))
+ if 'chief' in d._cluster_spec.as_dict():
+ num_workers += 1
else:
num_workers = 1
with ops.Graph().as_default(), \
@@ -286,11 +280,18 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase,
d.scope():
def model_fn():
- x = variable_scope.get_variable('x', initializer=10.0)
- y = variable_scope.get_variable('y', initializer=20.0)
-
- x_add = x.assign_add(1.0, use_locking=True)
- y_add = y.assign_add(1.0, use_locking=True)
+ x = variable_scope.get_variable(
+ 'x', initializer=10.0,
+ aggregation=variable_scope.VariableAggregation.SUM)
+ y = variable_scope.get_variable(
+ 'y', initializer=20.0,
+ aggregation=variable_scope.VariableAggregation.SUM)
+
+ # We explicitly make a constant tensor here to avoid complaints about
+ # summing non-distributed values.
+ one = constant_op.constant(1.0)
+ x_add = x.assign_add(one, use_locking=True)
+ y_add = y.assign_add(one, use_locking=True)
train_op = control_flow_ops.group([x_add, y_add])
return x, y, train_op
@@ -330,6 +331,11 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase,
def _test_minimize_loss_graph(self, task_type, task_id, num_gpus):
d, master_target = self._get_test_objects(task_type, task_id, num_gpus)
+ assert hasattr(d, '_cluster_spec') and d._cluster_spec
+ num_workers = len(d._cluster_spec.as_dict().get(WORKER))
+ if CHIEF in d._cluster_spec.as_dict():
+ num_workers += 1
+
with ops.Graph().as_default(), \
self.test_session(target=master_target) as sess, \
d.scope():
@@ -378,13 +384,13 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase,
if context.num_gpus() < d._num_gpus_per_worker:
return True
- if task_id == 0:
+ if multi_worker_util.is_chief(d._cluster_spec, task_type, task_id):
variables.global_variables_initializer().run()
# Workers waiting for chief worker's initializing variables.
self._init_condition.acquire()
self._init_reached += 1
- while self._init_reached != 3:
+ while self._init_reached != num_workers:
self._init_condition.wait()
self._init_condition.notify_all()
self._init_condition.release()
@@ -401,9 +407,42 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase,
self.assertLess(error_after, error_before)
return error_after < error_before
+
+class ParameterServerStrategyTest(ParameterServerStrategyTestBase,
+ parameterized.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ cls._cluster_spec = multi_worker_test_base.create_in_process_cluster(
+ num_workers=3, num_ps=2)
+ cls._default_target = 'grpc://' + cls._cluster_spec[WORKER][0]
+
+ def testDeviceAssignmentLocalCPU(self):
+ distribution = parameter_server_strategy.ParameterServerStrategy(
+ num_gpus_per_worker=0)
+ self._test_device_assignment_local(
+ distribution, compute_device='CPU', variable_device='CPU', num_gpus=0)
+
+ def testDeviceAssignmentLocalOneGPU(self):
+ distribution = parameter_server_strategy.ParameterServerStrategy(
+ num_gpus_per_worker=1)
+ self._test_device_assignment_local(
+ distribution, compute_device='GPU', variable_device='GPU', num_gpus=1)
+
+ def testDeviceAssignmentLocalTwoGPUs(self):
+ distribution = parameter_server_strategy.ParameterServerStrategy(
+ num_gpus_per_worker=2)
+ self._test_device_assignment_local(
+ distribution, compute_device='GPU', variable_device='CPU', num_gpus=2)
+
+ @combinations.generate(
+ combinations.combine(mode=['graph'], num_gpus=[0, 1, 2]))
+ def testDeviceAssignmentDistributed(self, num_gpus):
+ self._test_device_assignment_distributed('worker', 1, num_gpus)
+
def testSimpleBetweenGraph(self):
self._run_between_graph_clients(self._test_simple_increment,
- self._cluster_spec, 0)
+ self._cluster_spec, context.num_gpus())
@combinations.generate(
combinations.combine(mode=['graph'], num_gpus=[0, 1, 2]))
@@ -417,5 +456,38 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase,
self._cluster_spec, num_gpus)
+class ParameterServerStrategyWithChiefTest(ParameterServerStrategyTestBase,
+ parameterized.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ cls._cluster_spec = multi_worker_test_base.create_in_process_cluster(
+ num_workers=3, num_ps=2, has_chief=True)
+ cls._default_target = 'grpc://' + cls._cluster_spec[CHIEF][0]
+
+ def testSimpleBetweenGraph(self):
+ self._run_between_graph_clients(self._test_simple_increment,
+ self._cluster_spec, context.num_gpus())
+
+ @combinations.generate(
+ combinations.combine(mode=['graph'], num_gpus=[0, 1, 2]))
+ def testMinimizeLossGraph(self, num_gpus):
+ self._run_between_graph_clients(self._test_minimize_loss_graph,
+ self._cluster_spec, num_gpus)
+
+ def testGlobalStepIsWrapped(self):
+ distribution = parameter_server_strategy.ParameterServerStrategy(
+ num_gpus_per_worker=2)
+ with ops.Graph().as_default(), distribution.scope():
+ created_step = training_util.create_global_step()
+ get_step = training_util.get_global_step()
+ self.assertEqual(created_step, get_step,
+ msg=('created_step %s type %s vs. get_step %s type %s' %
+ (id(created_step), created_step.__class__.__name__,
+ id(get_step), get_step.__class__.__name__)))
+ self.assertIs(values.AggregatingVariable, type(created_step))
+ self.assertIs(values.AggregatingVariable, type(get_step))
+
+
if __name__ == '__main__':
test.main()
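A hedged sketch of what `testGlobalStepIsWrapped` exercises, using the same modules this test imports: with more than one tower, the global step created under the strategy's scope comes back as the wrapper type, and the collection-based getter sees the same wrapped object.

    distribution = parameter_server_strategy.ParameterServerStrategy(
        num_gpus_per_worker=2)
    with ops.Graph().as_default(), distribution.scope():
      step = training_util.create_global_step()
      # The getter returns the same wrapper instance that was created.
      assert training_util.get_global_step() is step
      assert isinstance(step, values.AggregatingVariable)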
diff --git a/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py b/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py
index a68dbce6c7..bb10b546a1 100644
--- a/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py
+++ b/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py
@@ -37,7 +37,7 @@ class PrefetchingOpsV2Test(test.TestCase):
iterator = device_dataset.make_one_shot_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in range(10):
self.assertEqual(i, sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
@@ -55,7 +55,7 @@ class PrefetchingOpsV2Test(test.TestCase):
next_element = iterator.get_next()
output = []
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for _ in range(5):
result = sess.run(next_element)
self.assertEqual(2, len(result))
@@ -75,7 +75,7 @@ class PrefetchingOpsV2Test(test.TestCase):
iterator = device_dataset.make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
for _ in range(5):
sess.run(next_element)
diff --git a/tensorflow/contrib/distribute/python/step_fn_test.py b/tensorflow/contrib/distribute/python/step_fn_test.py
index 8605ab1f7d..f1ada49fa3 100644
--- a/tensorflow/contrib/distribute/python/step_fn_test.py
+++ b/tensorflow/contrib/distribute/python/step_fn_test.py
@@ -49,7 +49,7 @@ class SingleLossStepTest(test.TestCase, parameterized.TestCase):
if context.executing_eagerly():
run_step = single_loss_step
else:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
run_step = sess.make_callable(single_loss_step())
self.evaluate(variables.global_variables_initializer())
diff --git a/tensorflow/contrib/distribute/python/strategy_test_lib.py b/tensorflow/contrib/distribute/python/strategy_test_lib.py
index 371b97ba96..6ee26e19ac 100644
--- a/tensorflow/contrib/distribute/python/strategy_test_lib.py
+++ b/tensorflow/contrib/distribute/python/strategy_test_lib.py
@@ -130,7 +130,8 @@ class DistributionTestBase(test.TestCase):
# Error should go down
self.assertLess(error_after, error_before)
- def _test_minimize_loss_graph(self, d, soft_placement=False):
+ def _test_minimize_loss_graph(self, d, soft_placement=False,
+ learning_rate=0.2):
config = config_pb2.ConfigProto()
config.allow_soft_placement = soft_placement
config.gpu_options.per_process_gpu_memory_fraction = 0.3
@@ -150,7 +151,7 @@ class DistributionTestBase(test.TestCase):
grad_fn = backprop.implicit_grad(loss)
def update(v, g):
- return v.assign_sub(0.2 * g)
+ return v.assign_sub(learning_rate * g)
one = d.broadcast(constant_op.constant([[1.]]))
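The new `learning_rate` argument simply parameterizes the plain SGD step the helper already used; an equivalent standalone closure (illustrative):

    def make_update_fn(learning_rate=0.2):
      def update(v, g):
        # One gradient-descent step: v <- v - learning_rate * g.
        return v.assign_sub(learning_rate * g)
      return update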
diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py
index 77fc56de36..6202a0750a 100644
--- a/tensorflow/contrib/distribute/python/tpu_strategy.py
+++ b/tensorflow/contrib/distribute/python/tpu_strategy.py
@@ -51,7 +51,7 @@ def get_tpu_system_metadata(tpu_cluster_resolver):
tpu_system_metadata_lib._query_tpu_system_metadata(
master,
cluster_def=cluster_def,
- query_topology=True))
+ query_topology=False))
return tpu_system_metadata
@@ -59,7 +59,7 @@ def get_tpu_system_metadata(tpu_cluster_resolver):
class TPUStrategy(one_device_strategy.OneDeviceStrategy):
"""Experimental TPU distribution strategy implementation."""
- def __init__(self, tpu_cluster_resolver, steps_per_run):
+ def __init__(self, tpu_cluster_resolver, steps_per_run, num_cores=None):
"""Initializes the TPUStrategy object.
Args:
@@ -70,6 +70,8 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
metrics, summaries etc.
This parameter is only used when Distribution Strategy is used with
estimator or keras.
+ num_cores: Number of cores to use on the TPU. If None, the number of cores
+ and the topology of the TPU system are auto-detected.
"""
# TODO(isaprykin): Generalize the defaults. They are currently tailored for
# the unit test.
@@ -77,13 +79,15 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
self._tpu_cluster_resolver = tpu_cluster_resolver
self._tpu_metadata = get_tpu_system_metadata(self._tpu_cluster_resolver)
+ self._num_cores_override = num_cores
- # TODO(priyag): This should not be hardcoded here.
- self._host = '/device:CPU:0'
# TODO(sourabhbajaj): Remove this once performance of running one step
# at a time is comparable to multiple steps.
self.steps_per_run = steps_per_run
+ # TODO(frankchn): This should not be hardcoded here for pod purposes.
+ self._host = self.tpu_host_cpu_device(0)
+
def distribute_dataset(self, dataset_fn):
# TODO(priyag): Perhaps distribute across cores here.
return self._call_dataset_fn(dataset_fn)
@@ -106,6 +110,7 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
"""Enqueue ops for one iteration."""
control_deps = []
sharded_inputs = []
+ # TODO(sourabhbajaj): Add support for TPU pods
with ops.device(self._host):
for _ in range(self.num_towers):
# Use control dependencies to ensure a deterministic ordering.
@@ -258,4 +263,10 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
@property
def num_towers(self):
- return self._tpu_metadata.num_of_cores_per_host
+ return self._num_cores_override or self._tpu_metadata.num_cores
+
+ def tpu_host_cpu_device(self, host_id):
+ if self._tpu_cluster_resolver.get_master() in ('', 'local'):
+ return '/replica:0/task:0/device:CPU:0'
+ return '/job:%s/task:%d/device:CPU:0' % ('tpu_worker', host_id)
+
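The host-device selection added above can be summarized with a small standalone sketch (plain Python, keeping the hardcoded 'tpu_worker' job name from the change):

    def tpu_host_cpu_device(master, host_id, job_name='tpu_worker'):
      # Local or in-process masters use the local CPU; otherwise address the
      # host CPU of the given TPU worker task.
      if master in ('', 'local'):
        return '/replica:0/task:0/device:CPU:0'
      return '/job:%s/task:%d/device:CPU:0' % (job_name, host_id)

    assert tpu_host_cpu_device('', 0) == '/replica:0/task:0/device:CPU:0'
    assert tpu_host_cpu_device('grpc://10.2.3.4:8470', 1) == (
        '/job:tpu_worker/task:1/device:CPU:0')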
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index 8548a86421..3ccaa2690e 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -183,6 +183,14 @@ class Mirrored(DistributedDelegate):
return self._index[device]
return list(self._index.values())[0]
+ def _as_graph_element(self):
+ obj = self.get()
+ # pylint: disable=protected-access
+ conv_fn = getattr(obj, "_as_graph_element", None)
+ if conv_fn and callable(conv_fn):
+ return conv_fn()
+ return obj
+
def _assign_on_device(device, variable, tensor):
with ops.device(device):
@@ -296,6 +304,10 @@ class DistributedVariable(DistributedDelegate):
self._primary_var.op.type)
return self.get().op
+ @property
+ def _in_graph_mode(self):
+ return self._primary_var._in_graph_mode # pylint: disable=protected-access
+
def read_value(self):
return distribution_strategy_context.get_distribution_strategy().read_var(
self)
@@ -308,26 +320,6 @@ class DistributedVariable(DistributedDelegate):
ops.register_dense_tensor_like_type(DistributedVariable)
-def _get_update_device():
- """Validate we are in update/update_non_slot() and return current device.
-
- This is used in MirroredVariable.assign* members, to make sure they
- are only called via an update method, to make sure all components of the
- variable are being updated in a consistent way.
-
- Returns:
- A string device.
-
- Raises:
- RuntimeError: If not in distribution.update()/.update_non_slot().
- """
- device = distribute_lib.get_update_device()
- if device is None:
- raise RuntimeError(
- "Use DistributionStrategy.update() to modify a MirroredVariable.")
- return device
-
-
class _MirroredSaveable(saver.BaseSaverBuilder.ResourceVariableSaveable):
"""Class for defining how to restore a MirroredVariable."""
@@ -366,15 +358,27 @@ class MirroredVariable(DistributedVariable, Mirrored,
f = kwargs.pop("f")
if distribution_strategy_context.get_cross_tower_context():
update_device = distribute_lib.get_update_device()
- # We are calling update on the mirrored variable in cross tower context.
if update_device is not None:
- # We are calling an assign function on the mirrored variable in cross
- # tower context.
+ # We are calling an assign function on the mirrored variable in an
+ # update context.
v = self.get(device=update_device)
return f(v, *args, **kwargs)
- return distribution_strategy_context.get_distribution_strategy().update(
- self, f, *args, **kwargs)
+ # We are calling assign on the mirrored variable in cross tower context,
+ # so we use strategy.update() to update the variable.
+ strategy = distribution_strategy_context.get_distribution_strategy()
+ updates = strategy.update(self, f, *args, **kwargs)
+ grouped = strategy.group(updates)
+ if isinstance(updates, DistributedValues) and updates.is_tensor_like:
+ # Make sure we run all updates. Without this, something like
+ # session.run(mirrored_var.assign*(...)) may only update one tower.
+ index = {}
+ for d in updates.devices:
+ with ops.device(d), ops.control_dependencies([grouped]):
+ index[d] = array_ops.identity(updates.get(d))
+ return Mirrored(index)
+ else:
+ return grouped
else:
_assert_tower_context()
# We are calling an assign function on the mirrored variable in tower
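The "make sure we run all updates" logic above relies on a standard graph-mode trick: re-read each per-device update through an identity op that has a control dependency on the grouped update, so fetching any single component forces every assign to run. A minimal sketch of that trick in TF 1.x-style graph mode (not the MirroredVariable code itself):

    import tensorflow as tf  # assumes 1.x-style graph execution

    graph = tf.Graph()
    with graph.as_default():
      v0 = tf.Variable(0.0)
      v1 = tf.Variable(0.0)
      updates = [v0.assign_add(1.0), v1.assign_add(1.0)]
      grouped = tf.group(*updates)
      with tf.control_dependencies([grouped]):
        # Fetching either read forces both assign_add ops to run.
        reads = [tf.identity(u) for u in updates]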
@@ -1057,3 +1061,160 @@ def value_container(val):
if container is not None:
return container
return val
+
+
+# TODO(josh11b): Descend from Variable.
+class AggregatingVariable(checkpointable.CheckpointableBase):
+ """A wrapper around a variable that aggregates updates across towers."""
+
+ def __init__(self, v, aggregation):
+ self._v = v
+ # TODO(josh11b): Set v._distributed_container?
+ # v._distributed_container = weakref.ref(self) # pylint: disable=protected-access
+ self._aggregation = aggregation
+
+ def get(self):
+ return self._v
+
+ def __getattr__(self, name):
+ return getattr(self._v, name)
+
+ def _assign_func(self, *args, **kwargs):
+ f = kwargs.pop("f")
+ if distribution_strategy_context.get_cross_tower_context():
+ update_device = distribute_lib.get_update_device()
+ if update_device is not None:
+ # We are calling an assign function in an update context.
+ return f(self._v, *args, **kwargs)
+
+ # We are calling an assign function in cross tower context, wrap it in an
+ # update call.
+ return distribution_strategy_context.get_distribution_strategy().update(
+ self, f, *args, **kwargs)
+ else:
+ assert distribution_strategy_context.get_tower_context()
+ # We are calling an assign function in tower context.
+ # We reduce the value we want to assign/add/sub. More details about how we
+ # handle the different use cases can be found in the _reduce method.
+ # We call the function with the reduced value.
+ if self._aggregation == vs.VariableAggregation.NONE:
+ raise ValueError("You must specify an aggregation method to update a "
+ "a variable in Tower Context.")
+
+ def merge_fn(strategy, value, *other_args, **other_kwargs):
+ return strategy.update(
+ self, f,
+ strategy.reduce(
+ aggregation=self._aggregation, value=value, destinations=self),
+ *other_args, **other_kwargs)
+
+ return distribution_strategy_context.get_tower_context().merge_call(
+ merge_fn, *args, **kwargs)
+
+ def assign_sub(self, *args, **kwargs):
+ assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw)
+ return self._assign_func(f=assign_sub_fn, *args, **kwargs)
+
+ def assign_add(self, *args, **kwargs):
+ assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw)
+ return self._assign_func(f=assign_add_fn, *args, **kwargs)
+
+ def assign(self, *args, **kwargs):
+ assign_fn = lambda var, *a, **kw: var.assign(*a, **kw)
+ return self._assign_func(f=assign_fn, *args, **kwargs)
+
+ @property
+ def aggregation(self):
+ return self._aggregation
+
+ @property
+ def name(self):
+ return self._v.name
+
+ @property
+ def dtype(self):
+ return self._v.dtype
+
+ # TODO(josh11b): Test saving & restoring.
+ def _gather_saveables_for_checkpoint(self):
+ return {checkpointable.VARIABLE_VALUE_KEY: self._v}
+
+ # pylint: disable=multiple-statements
+ def __add__(self, o): return self._v + o
+ def __radd__(self, o): return o + self._v
+ def __sub__(self, o): return self._v - o
+ def __rsub__(self, o): return o - self._v
+ def __mul__(self, o): return self._v * o
+ def __rmul__(self, o): return o * self._v
+ def __truediv__(self, o): return self._v / o
+ def __rtruediv__(self, o): return o / self._v
+ def __floordiv__(self, o): return self._v // o
+ def __rfloordiv__(self, o): return o // self._v
+ def __mod__(self, o): return self._v % o
+ def __rmod__(self, o): return o % self._v
+ def __lt__(self, o): return self._v < o
+ def __le__(self, o): return self._v <= o
+ def __gt__(self, o): return self._v > o
+ def __ge__(self, o): return self._v >= o
+ def __and__(self, o): return self._v & o
+ def __rand__(self, o): return o & self._v
+ def __or__(self, o): return self._v | o
+ def __ror__(self, o): return o | self._v
+ def __xor__(self, o): return self._v ^ o
+ def __rxor__(self, o): return o ^ self._v
+ def __getitem__(self, o): return self._v[o]
+ def __pow__(self, o, modulo=None): return pow(self._v, o, modulo)
+ def __rpow__(self, o): return pow(o, self._v)
+ def __invert__(self): return ~self._v
+ def __neg__(self): return -self._v
+ def __abs__(self): return abs(self._v)
+
+ def __div__(self, o):
+ try:
+ return self._v.__div__(o)
+ except AttributeError:
+ # See https://docs.python.org/3/library/constants.html#NotImplemented
+ return NotImplemented
+
+ def __rdiv__(self, o):
+ try:
+ return self._v.__rdiv__(o)
+ except AttributeError:
+ # See https://docs.python.org/3/library/constants.html#NotImplemented
+ return NotImplemented
+
+ def __matmul__(self, o):
+ try:
+ return self._v.__matmul__(o)
+ except AttributeError:
+ # See https://docs.python.org/3/library/constants.html#NotImplemented
+ return NotImplemented
+
+ def __rmatmul__(self, o):
+ try:
+ return self._v.__rmatmul__(o)
+ except AttributeError:
+ # See https://docs.python.org/3/library/constants.html#NotImplemented
+ return NotImplemented
+
+ def __str__(self):
+ return str(self._v)
+
+ def __repr__(self):
+ return repr(self._v)
+
+ def _should_act_as_resource_variable(self):
+ """Pass resource_variable_ops.is_resource_variable check."""
+ pass
+
+
+# Register a conversion function which reads the value of the variable,
+# allowing instances of the class to be used as tensors.
+def _tensor_conversion_aggregate(var, dtype=None, name=None, as_ref=False):
+ return ops.internal_convert_to_tensor(
+ var.get(), dtype=dtype, name=name, as_ref=as_ref)
+
+
+ops.register_tensor_conversion_function(
+ AggregatingVariable, _tensor_conversion_aggregate)
+ops.register_dense_tensor_like_type(AggregatingVariable)
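The long run of operator overloads in `AggregatingVariable` is a delegation pattern: attribute access and arithmetic are forwarded to the wrapped variable so the wrapper can stand in for it almost everywhere. A tiny pure-Python sketch of the same idea:

    class _Delegate(object):
      """Forward attribute access and arithmetic to a wrapped value."""

      def __init__(self, v):
        self._v = v

      def get(self):
        return self._v

      def __getattr__(self, name):  # anything not defined on the wrapper
        return getattr(self._v, name)

      def __add__(self, o):
        return self._v + o

      def __radd__(self, o):
        return o + self._v

    d = _Delegate(3.0)
    assert d + 1 == 4.0 and 1 + d == 4.0
    assert d.get() == 3.0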
diff --git a/tensorflow/contrib/distribute/python/values_test.py b/tensorflow/contrib/distribute/python/values_test.py
index 91a43d4999..3602f4d128 100644
--- a/tensorflow/contrib/distribute/python/values_test.py
+++ b/tensorflow/contrib/distribute/python/values_test.py
@@ -653,7 +653,7 @@ class MirroredVariableTest(test.TestCase):
def _save_mirrored(self):
"""Save variables with mirroring, returns save_path."""
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
v, devices, mirrored = _make_mirrored()
# Overwrite the initial values.
@@ -668,7 +668,7 @@ class MirroredVariableTest(test.TestCase):
def _save_normal(self):
"""Save variables without mirroring, returns save_path."""
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
var = variable_scope.get_variable(
name="v", initializer=1., use_resource=True)
@@ -684,7 +684,7 @@ class MirroredVariableTest(test.TestCase):
def _restore_normal(self, save_path):
"""Restore to variables without mirroring in a fresh graph."""
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
var = variable_scope.get_variable(
name="v", initializer=7., use_resource=True)
@@ -698,7 +698,7 @@ class MirroredVariableTest(test.TestCase):
def _restore_mirrored(self, save_path):
"""Restore to variables with mirroring in a fresh graph."""
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
v, devices, mirrored = _make_mirrored()
# Overwrite the initial values.
@@ -864,7 +864,7 @@ class TowerLocalVariableTest(test.TestCase):
def _save_tower_local_mean(self):
"""Save variables with mirroring, returns save_path."""
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
v, tower_local = _make_tower_local(
variable_scope.VariableAggregation.MEAN)
@@ -881,7 +881,7 @@ class TowerLocalVariableTest(test.TestCase):
def _save_tower_local_sum(self):
"""Save variables with mirroring, returns save_path."""
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
v, tower_local = _make_tower_local("sum")
# Overwrite the initial values.
@@ -897,7 +897,7 @@ class TowerLocalVariableTest(test.TestCase):
def _save_normal(self):
"""Save variables without mirroring, returns save_path."""
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
var = variable_scope.get_variable(
name="v", initializer=1., use_resource=True)
@@ -913,7 +913,7 @@ class TowerLocalVariableTest(test.TestCase):
def _restore_normal(self, save_path):
"""Restore to variables without mirroring in a fresh graph."""
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
var = variable_scope.get_variable(
name="v", initializer=7., use_resource=True)
@@ -927,7 +927,7 @@ class TowerLocalVariableTest(test.TestCase):
def _restore_tower_local_mean(self, save_path):
"""Restore to variables with mirroring in a fresh graph."""
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
v, tower_local = _make_tower_local(
variable_scope.VariableAggregation.MEAN)
@@ -942,7 +942,7 @@ class TowerLocalVariableTest(test.TestCase):
def _restore_tower_local_sum(self, save_path):
"""Restore to variables with mirroring in a fresh graph."""
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
v, tower_local = _make_tower_local(variable_scope.VariableAggregation.SUM)
# Overwrite the initial values.
diff --git a/tensorflow/contrib/distribute/python/warm_starting_util_test.py b/tensorflow/contrib/distribute/python/warm_starting_util_test.py
index d8bacdb338..5d57d144c1 100644
--- a/tensorflow/contrib/distribute/python/warm_starting_util_test.py
+++ b/tensorflow/contrib/distribute/python/warm_starting_util_test.py
@@ -56,7 +56,7 @@ class WarmStartingUtilWithDistributionStrategyTest(
# Create variable and save checkpoint from which to warm-start.
def create_var(g):
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
var = variable_scope.get_variable(var_name, initializer=original_value)
sess.run(variables.global_variables_initializer())
saver = saver_lib.Saver()
@@ -75,7 +75,7 @@ class WarmStartingUtilWithDistributionStrategyTest(
self.assertAllEqual(original_value, prev_init_val)
def warm_start(g):
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
# Initialize with zeros.
var = variable_scope.get_variable(
var_name, initializer=[[0., 0.], [0., 0.]])