diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-10-01 16:31:13 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-01 16:35:32 -0700 |
commit | bb1f9e1a57c8bc18325b3c86298be96e6647a0a3 (patch) | |
tree | 66809f6708a49a5b8d408ef0b349fe837d905cda /tensorflow/contrib/optimizer_v2 | |
parent | 49bbfec04b729960999ef054e3acab719631b101 (diff) |
Change semantics of DistributionStrategy.update() to make sure the
output depends on the updates across all mirrors. Before this change,
update() would return a Mirrored value where each component was
an update to a single mirror. This caused a problem since for reading
purposes other DistributionStrategy methods would consider it okay
to read any single component, and so if, for example, you did something
like session.run(strategy.update(...)) it would only perform the
update on one replica. The fix is to have the output be a Mirrored
value that is actually the identity operation returning the output on
that device, but that has a control dependency making sure that the
update actually happens on all the replicas. This fix was already
present in MirroredVariable._assign_func, this CL moves the fix into
update() and generalizes it to multiple return values.
To disable this new grouping behavior, you may now pass
"grouped=False" to update(). For example, some callers (like Optimizer)
are performing a lot of updates and they prefer to group all of them
together at once for performance reasons. In this case, we still want
to make sure the caller executes the update on all replicas, so we
return an unwrapped value instead of a Mirrored value. This has the
happy side effect of removing a bunch of unwrap calls in client code,
since unwrapping was the only safe way to use the Mirrored value we
used to return.
PiperOrigin-RevId: 215301909
Diffstat (limited to 'tensorflow/contrib/optimizer_v2')
-rw-r--r-- | tensorflow/contrib/optimizer_v2/optimizer_v2.py | 32 |
1 files changed, 14 insertions, 18 deletions
diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2.py b/tensorflow/contrib/optimizer_v2/optimizer_v2.py index 6af59dcfbf..53e27c08c4 100644 --- a/tensorflow/contrib/optimizer_v2/optimizer_v2.py +++ b/tensorflow/contrib/optimizer_v2/optimizer_v2.py @@ -30,7 +30,6 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gradients from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops -from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.training import distribute as distribute_lib @@ -965,8 +964,7 @@ class OptimizerV2(optimizer_v1.Optimizer): # Use the processors to update the variables. update_ops = [] for grad, var in grads_and_vars: - update_ops.extend(distribution.unwrap(distribution.update( - var, update, grad))) + update_ops.extend(distribution.update(var, update, grad, grouped=False)) # Give the child class a chance to do something after applying # gradients @@ -978,26 +976,24 @@ class OptimizerV2(optimizer_v1.Optimizer): update_ops = control_flow_ops.group(update_ops) with ops.control_dependencies([update_ops]): - finish_updates = distribution.update_non_slot(non_slot_devices, finish) - if finish_updates is None: - finish_updates = update_ops + finish_updates = distribution.update_non_slot( + non_slot_devices, finish, grouped=False) + # We said grouped=False, which means finish_updates is always a list. + # It will be [None] when finish() returns None. + if finish_updates == [None]: + finish_updates = [update_ops] # Update `global_step` (if any). 
if global_step is None: apply_updates = distribution.group(finish_updates, name=name) else: - with ops.control_dependencies(distribution.unwrap(finish_updates)): - - def update_global_step(global_step): - if isinstance(global_step, resource_variable_ops.ResourceVariable): - return global_step.assign_add( - ops.convert_to_tensor(1, dtype=global_step.dtype), - read_value=False) - else: - return state_ops.assign_add(global_step, 1) - - apply_updates = distribution.group( - distribution.update(global_step, update_global_step), name=name) + with ops.control_dependencies(finish_updates): + + def update_global_step(global_step, name): + return global_step.assign_add(1, read_value=False, name=name) + + apply_updates = distribution.update( + global_step, update_global_step, name) # Add the training op to the TRAIN_OP graph collection in graph mode. if not eager_execution: |