aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/optimizer_v2
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-10-01 16:31:13 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-01 16:35:32 -0700
commitbb1f9e1a57c8bc18325b3c86298be96e6647a0a3 (patch)
tree66809f6708a49a5b8d408ef0b349fe837d905cda /tensorflow/contrib/optimizer_v2
parent49bbfec04b729960999ef054e3acab719631b101 (diff)
Change semantics of DistributionStrategy.update() to make sure the
output depends on the updates across all mirrors. Before this change, update() would return a Mirrored value that where each component was an update to a single mirror. This caused a problem since for reading purposes other DistributionStrategy methods would consider it okay to read any single component, and so if you for example did something like session.run(strategy.update(...)) it would only perform the update on one replica. The fix is to have the output be a Mirrored value that is actually the identity operation returning the output on that device, but that has a control dependency making sure that the update actually happens on all the replicas. This fix was already present in MirroredVariable._assign_func, this CL moves the fix into update() and generalizes it to multiple return values. To disable this new grouping behavior, you may now pass "grouped=False" to update(). For example, some callers (like Optimizer) are performing a lot of updates and they prefer to group all of them together at once for performance reasons. In this case, we still want to make sure the caller executes the update on all replicas, so we return an unwrapped value instead of a Mirrored value. This has the happy side effect of removing a bunch of unwrap calls in client code, since unwrapping was the only safe way to use the Mirrored value we used to return. PiperOrigin-RevId: 215301909
Diffstat (limited to 'tensorflow/contrib/optimizer_v2')
-rw-r--r--tensorflow/contrib/optimizer_v2/optimizer_v2.py32
1 files changed, 14 insertions, 18 deletions
diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2.py b/tensorflow/contrib/optimizer_v2/optimizer_v2.py
index 6af59dcfbf..53e27c08c4 100644
--- a/tensorflow/contrib/optimizer_v2/optimizer_v2.py
+++ b/tensorflow/contrib/optimizer_v2/optimizer_v2.py
@@ -30,7 +30,6 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
-from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import distribute as distribute_lib
@@ -965,8 +964,7 @@ class OptimizerV2(optimizer_v1.Optimizer):
# Use the processors to update the variables.
update_ops = []
for grad, var in grads_and_vars:
- update_ops.extend(distribution.unwrap(distribution.update(
- var, update, grad)))
+ update_ops.extend(distribution.update(var, update, grad, grouped=False))
# Give the child class a chance to do something after applying
# gradients
@@ -978,26 +976,24 @@ class OptimizerV2(optimizer_v1.Optimizer):
update_ops = control_flow_ops.group(update_ops)
with ops.control_dependencies([update_ops]):
- finish_updates = distribution.update_non_slot(non_slot_devices, finish)
- if finish_updates is None:
- finish_updates = update_ops
+ finish_updates = distribution.update_non_slot(
+ non_slot_devices, finish, grouped=False)
+ # We said grouped=False, which means finish_updates is always a list.
+ # It will be [None] when finish() returns None.
+ if finish_updates == [None]:
+ finish_updates = [update_ops]
# Update `global_step` (if any).
if global_step is None:
apply_updates = distribution.group(finish_updates, name=name)
else:
- with ops.control_dependencies(distribution.unwrap(finish_updates)):
-
- def update_global_step(global_step):
- if isinstance(global_step, resource_variable_ops.ResourceVariable):
- return global_step.assign_add(
- ops.convert_to_tensor(1, dtype=global_step.dtype),
- read_value=False)
- else:
- return state_ops.assign_add(global_step, 1)
-
- apply_updates = distribution.group(
- distribution.update(global_step, update_global_step), name=name)
+ with ops.control_dependencies(finish_updates):
+
+ def update_global_step(global_step, name):
+ return global_step.assign_add(1, read_value=False, name=name)
+
+ apply_updates = distribution.update(
+ global_step, update_global_step, name)
# Add the training op to the TRAIN_OP graph collection in graph mode.
if not eager_execution: