aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distribute/python
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-24 09:12:59 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-24 09:18:02 -0700
commit6f00ab7b8f16f450c00375df271c45da4dc72be5 (patch)
tree25163f5e9d441de60df20f1e990758271cbcf316 /tensorflow/contrib/distribute/python
parent247b81a7c47fe52a383c86a9a32efa536ead6fa6 (diff)
For ParameterServerStrategy, make sure to include the AggregatingVariable
wrapper for variables in collections instead of what it wraps. PiperOrigin-RevId: 210107528
Diffstat (limited to 'tensorflow/contrib/distribute/python')
-rw-r--r--tensorflow/contrib/distribute/python/BUILD62
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy.py6
-rw-r--r--tensorflow/contrib/distribute/python/parameter_server_strategy.py30
-rw-r--r--tensorflow/contrib/distribute/python/parameter_server_strategy_test.py15
-rw-r--r--tensorflow/contrib/distribute/python/values.py4
5 files changed, 86 insertions, 31 deletions
diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD
index 753fdddbb2..8173b5d4ba 100644
--- a/tensorflow/contrib/distribute/python/BUILD
+++ b/tensorflow/contrib/distribute/python/BUILD
@@ -106,6 +106,38 @@ py_library(
"//tensorflow/python:training",
"//tensorflow/python:util",
"//tensorflow/python/distribute:multi_worker_util",
+ "//tensorflow/python/eager:context",
+ ],
+)
+
+cuda_py_test(
+ name = "parameter_server_strategy_test",
+ srcs = ["parameter_server_strategy_test.py"],
+ additional_deps = [
+ ":combinations",
+ ":multi_worker_test_base",
+ ":parameter_server_strategy",
+ ":values",
+ "@absl_py//absl/testing:parameterized",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:gradients",
+ "//tensorflow/python:layers",
+ "//tensorflow/python:session",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/distribute:multi_worker_util",
+ "//tensorflow/python/eager:context",
+ "//tensorflow/python/estimator:estimator_py",
+ ],
+ tags = [
+ "multi_and_single_gpu",
+ "no_pip",
],
)
@@ -240,36 +272,6 @@ py_test(
)
cuda_py_test(
- name = "parameter_server_strategy_test",
- srcs = ["parameter_server_strategy_test.py"],
- additional_deps = [
- ":combinations",
- ":multi_worker_test_base",
- ":parameter_server_strategy",
- "@absl_py//absl/testing:parameterized",
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:gradients",
- "//tensorflow/python:layers",
- "//tensorflow/python:session",
- "//tensorflow/python:training",
- "//tensorflow/python:variable_scope",
- "//tensorflow/python:variables",
- "//tensorflow/python/distribute:multi_worker_util",
- "//tensorflow/python/eager:context",
- "//tensorflow/python/estimator:estimator_py",
- ],
- tags = [
- "multi_and_single_gpu",
- "no_pip",
- ],
-)
-
-cuda_py_test(
name = "mirrored_strategy_multigpu_test",
srcs = ["mirrored_strategy_multigpu_test.py"],
additional_deps = [
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index ecaf60f350..e87b48ba41 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -276,6 +276,9 @@ def _create_mirrored_variable(devices, real_mirrored_creator, *args, **kwargs):
else:
result = values.MirroredVariable(index, index[devices[0]], aggregation)
+ # Add the wrapped variable to the requested collections.
+ # The handling of eager mode and the global step matches
+ # ResourceVariable._init_from_args().
if not context.executing_eagerly():
g = ops.get_default_graph()
# If "trainable" is True, next_creator() will add the member variables
@@ -289,6 +292,9 @@ def _create_mirrored_variable(devices, real_mirrored_creator, *args, **kwargs):
for v in index.values():
l.remove(v)
g.add_to_collections(collections, result)
+ elif ops.GraphKeys.GLOBAL_STEP in collections:
+ ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result)
+
return result
diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy.py b/tensorflow/contrib/distribute/python/parameter_server_strategy.py
index 3b1cc0217b..361c8be590 100644
--- a/tensorflow/contrib/distribute/python/parameter_server_strategy.py
+++ b/tensorflow/contrib/distribute/python/parameter_server_strategy.py
@@ -22,6 +22,7 @@ from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_
from tensorflow.contrib.distribute.python import mirrored_strategy
from tensorflow.contrib.distribute.python import values
from tensorflow.python.distribute import multi_worker_util
+from tensorflow.python.eager import context
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -240,8 +241,35 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy):
" for variable: " + kwargs["name"])
def var_creator(*args, **kwargs):
+ # Record what collections this variable should be added to.
+ collections = kwargs.pop("collections", None)
+ if collections is None:
+ collections = [ops.GraphKeys.GLOBAL_VARIABLES]
+ kwargs["collections"] = []
+
+ # Create and wrap the variable.
v = next_creator(*args, **kwargs)
- return values.AggregatingVariable(v, aggregation)
+ wrapped = values.AggregatingVariable(v, aggregation)
+
+ # Add the wrapped variable to the requested collections.
+ # The handling of eager mode and the global step matches
+ # ResourceVariable._init_from_args().
+ if not context.executing_eagerly():
+ g = ops.get_default_graph()
+ # If "trainable" is True, next_creator() will add the contained
+ # variable to the TRAINABLE_VARIABLES collection, so we manually
+ # remove it and replace with the wrapper. We can't set "trainable"
+ # to False for next_creator() since that causes functions like
+ # implicit_gradients to skip those variables.
+ if kwargs.get("trainable", True):
+ collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
+ l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
+ l.remove(v)
+ g.add_to_collections(collections, wrapped)
+ elif ops.GraphKeys.GLOBAL_STEP in collections:
+ ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, wrapped)
+
+ return wrapped
else:
var_creator = next_creator
diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py
index 3ec9e6f27b..0e2bfcec5f 100644
--- a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py
+++ b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py
@@ -24,6 +24,7 @@ from absl.testing import parameterized
from tensorflow.contrib.distribute.python import combinations
from tensorflow.contrib.distribute.python import multi_worker_test_base
from tensorflow.contrib.distribute.python import parameter_server_strategy
+from tensorflow.contrib.distribute.python import values
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.eager import context
from tensorflow.python.estimator import run_config
@@ -38,6 +39,7 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import device_util
from tensorflow.python.training import distribution_strategy_context
+from tensorflow.python.training import training_util
CHIEF = run_config.TaskType.CHIEF
WORKER = run_config.TaskType.WORKER
@@ -473,6 +475,19 @@ class ParameterServerStrategyWithChiefTest(ParameterServerStrategyTestBase,
self._run_between_graph_clients(self._test_minimize_loss_graph,
self._cluster_spec, num_gpus)
+ def testGlobalStepIsWrapped(self):
+ distribution = parameter_server_strategy.ParameterServerStrategy(
+ num_gpus_per_worker=2)
+ with ops.Graph().as_default(), distribution.scope():
+ created_step = training_util.create_global_step()
+ get_step = training_util.get_global_step()
+ self.assertEqual(created_step, get_step,
+ msg=('created_step %s type %s vs. get_step %s type %s' %
+ (id(created_step), created_step.__class__.__name__,
+ id(get_step), get_step.__class__.__name__)))
+ self.assertIs(values.AggregatingVariable, type(created_step))
+ self.assertIs(values.AggregatingVariable, type(get_step))
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index a58bb3a849..1b9fdef5b0 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -1180,6 +1180,10 @@ class AggregatingVariable(checkpointable.CheckpointableBase):
def __repr__(self):
return repr(self._v)
+ def _should_act_as_resource_variable(self):
+ """Pass resource_variable_ops.is_resource_variable check."""
+ pass
+
# Register a conversion function which reads the value of the variable,
# allowing instances of the class to be used as tensors.