diff options
author | 2018-08-24 09:12:59 -0700 | |
---|---|---|
committer | 2018-08-24 09:18:02 -0700 | |
commit | 6f00ab7b8f16f450c00375df271c45da4dc72be5 (patch) | |
tree | 25163f5e9d441de60df20f1e990758271cbcf316 /tensorflow/contrib/distribute/python | |
parent | 247b81a7c47fe52a383c86a9a32efa536ead6fa6 (diff) |
For ParameterServerStrategy, make sure that variables are added to collections
as their AggregatingVariable wrapper instead of as the variable it wraps.
PiperOrigin-RevId: 210107528
Diffstat (limited to 'tensorflow/contrib/distribute/python')
5 files changed, 86 insertions, 31 deletions
diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD index 753fdddbb2..8173b5d4ba 100644 --- a/tensorflow/contrib/distribute/python/BUILD +++ b/tensorflow/contrib/distribute/python/BUILD @@ -106,6 +106,38 @@ py_library( "//tensorflow/python:training", "//tensorflow/python:util", "//tensorflow/python/distribute:multi_worker_util", + "//tensorflow/python/eager:context", + ], +) + +cuda_py_test( + name = "parameter_server_strategy_test", + srcs = ["parameter_server_strategy_test.py"], + additional_deps = [ + ":combinations", + ":multi_worker_test_base", + ":parameter_server_strategy", + ":values", + "@absl_py//absl/testing:parameterized", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:gradients", + "//tensorflow/python:layers", + "//tensorflow/python:session", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//tensorflow/python/distribute:multi_worker_util", + "//tensorflow/python/eager:context", + "//tensorflow/python/estimator:estimator_py", + ], + tags = [ + "multi_and_single_gpu", + "no_pip", ], ) @@ -240,36 +272,6 @@ py_test( ) cuda_py_test( - name = "parameter_server_strategy_test", - srcs = ["parameter_server_strategy_test.py"], - additional_deps = [ - ":combinations", - ":multi_worker_test_base", - ":parameter_server_strategy", - "@absl_py//absl/testing:parameterized", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework_ops", - "//tensorflow/python:gradients", - "//tensorflow/python:layers", - "//tensorflow/python:session", - "//tensorflow/python:training", - 
"//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - "//tensorflow/python/distribute:multi_worker_util", - "//tensorflow/python/eager:context", - "//tensorflow/python/estimator:estimator_py", - ], - tags = [ - "multi_and_single_gpu", - "no_pip", - ], -) - -cuda_py_test( name = "mirrored_strategy_multigpu_test", srcs = ["mirrored_strategy_multigpu_test.py"], additional_deps = [ diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index ecaf60f350..e87b48ba41 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -276,6 +276,9 @@ def _create_mirrored_variable(devices, real_mirrored_creator, *args, **kwargs): else: result = values.MirroredVariable(index, index[devices[0]], aggregation) + # Add the wrapped variable to the requested collections. + # The handling of eager mode and the global step matches + # ResourceVariable._init_from_args(). 
if not context.executing_eagerly(): g = ops.get_default_graph() # If "trainable" is True, next_creator() will add the member variables @@ -289,6 +292,9 @@ def _create_mirrored_variable(devices, real_mirrored_creator, *args, **kwargs): for v in index.values(): l.remove(v) g.add_to_collections(collections, result) + elif ops.GraphKeys.GLOBAL_STEP in collections: + ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result) + return result diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy.py b/tensorflow/contrib/distribute/python/parameter_server_strategy.py index 3b1cc0217b..361c8be590 100644 --- a/tensorflow/contrib/distribute/python/parameter_server_strategy.py +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy.py @@ -22,6 +22,7 @@ from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ from tensorflow.contrib.distribute.python import mirrored_strategy from tensorflow.contrib.distribute.python import values from tensorflow.python.distribute import multi_worker_util +from tensorflow.python.eager import context from tensorflow.python.framework import device as tf_device from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -240,8 +241,35 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): " for variable: " + kwargs["name"]) def var_creator(*args, **kwargs): + # Record what collections this variable should be added to. + collections = kwargs.pop("collections", None) + if collections is None: + collections = [ops.GraphKeys.GLOBAL_VARIABLES] + kwargs["collections"] = [] + + # Create and wrap the variable. v = next_creator(*args, **kwargs) - return values.AggregatingVariable(v, aggregation) + wrapped = values.AggregatingVariable(v, aggregation) + + # Add the wrapped variable to the requested collections. + # The handling of eager mode and the global step matches + # ResourceVariable._init_from_args(). 
+ if not context.executing_eagerly(): + g = ops.get_default_graph() + # If "trainable" is True, next_creator() will add the contained + # variable to the TRAINABLE_VARIABLES collection, so we manually + # remove it and replace with the wrapper. We can't set "trainable" + # to False for next_creator() since that causes functions like + # implicit_gradients to skip those variables. + if kwargs.get("trainable", True): + collections.append(ops.GraphKeys.TRAINABLE_VARIABLES) + l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES) + l.remove(v) + g.add_to_collections(collections, wrapped) + elif ops.GraphKeys.GLOBAL_STEP in collections: + ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, wrapped) + + return wrapped else: var_creator = next_creator diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py index 3ec9e6f27b..0e2bfcec5f 100644 --- a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py @@ -24,6 +24,7 @@ from absl.testing import parameterized from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import multi_worker_test_base from tensorflow.contrib.distribute.python import parameter_server_strategy +from tensorflow.contrib.distribute.python import values from tensorflow.python.distribute import multi_worker_util from tensorflow.python.eager import context from tensorflow.python.estimator import run_config @@ -38,6 +39,7 @@ from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.training import device_util from tensorflow.python.training import distribution_strategy_context +from tensorflow.python.training import training_util CHIEF = run_config.TaskType.CHIEF WORKER = run_config.TaskType.WORKER @@ -473,6 +475,19 @@ class 
ParameterServerStrategyWithChiefTest(ParameterServerStrategyTestBase, self._run_between_graph_clients(self._test_minimize_loss_graph, self._cluster_spec, num_gpus) + def testGlobalStepIsWrapped(self): + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=2) + with ops.Graph().as_default(), distribution.scope(): + created_step = training_util.create_global_step() + get_step = training_util.get_global_step() + self.assertEqual(created_step, get_step, + msg=('created_step %s type %s vs. get_step %s type %s' % + (id(created_step), created_step.__class__.__name__, + id(get_step), get_step.__class__.__name__))) + self.assertIs(values.AggregatingVariable, type(created_step)) + self.assertIs(values.AggregatingVariable, type(get_step)) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py index a58bb3a849..1b9fdef5b0 100644 --- a/tensorflow/contrib/distribute/python/values.py +++ b/tensorflow/contrib/distribute/python/values.py @@ -1180,6 +1180,10 @@ class AggregatingVariable(checkpointable.CheckpointableBase): def __repr__(self): return repr(self._v) + def _should_act_as_resource_variable(self): + """Pass resource_variable_ops.is_resource_variable check.""" + pass + # Register a conversion function which reads the value of the variable, # allowing instances of the class to be used as tensors. |