aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib
diff options
context:
space:
mode:
authorGravatar Youlong Cheng <ylc@google.com>2018-10-04 09:42:13 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-04 09:46:29 -0700
commitc2552cd33c05fa84f280e766e33ba01308ffbcb2 (patch)
treead154f46d284c4cdfb0d160b10c558fd00fed629 /tensorflow/contrib
parent1fb84c2e41c454939a02a69093cb214673eab343 (diff)
Skip numeric checking in BROADCAST mode.
PiperOrigin-RevId: 215752559
Diffstat (limited to 'tensorflow/contrib')
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_estimator.py37
1 files changed, 26 insertions, 11 deletions
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index 3aa5b6efa1..8d15c857f8 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -177,14 +177,29 @@ def _create_or_get_iterations_per_loop():
use_resource=True)
-def _sync_variables_ops():
- # Gets the variables back from TPU nodes. This means the variables updated
- # by TPU will now be *synced* to host memory.
- return [
- array_ops.check_numerics(v.read_value(),
- 'Gradient for %s is NaN' % v.name).op
- for v in variables.trainable_variables()
- ]
+def _sync_variables_ops(ctx):
+ """Create variables synchronization ops.
+
+ Gets the variables back from TPU nodes. This means the variables updated
+ by TPU will now be *synced* to host memory.
+ In BROADCAST mode, we skip this sync since the variables are usually too
+ big to transmit via RPC.
+
+ Args:
+ ctx: A `_InternalTPUContext` instance with mode.
+
+ Returns:
+ A list of sync ops.
+ """
+
+ if not ctx.is_input_broadcast_with_iterators():
+ return [
+ array_ops.check_numerics(v.read_value(),
+ 'Gradient for %s is NaN' % v.name).op
+ for v in variables.trainable_variables()
+ ]
+ else:
+ return [control_flow_ops.no_op()]
def _increase_eval_step_op(iterations_per_loop):
@@ -2567,7 +2582,7 @@ class TPUEstimator(estimator_lib.Estimator):
summary.scalar(model_fn_lib.LOSS_METRIC_KEY, loss)
with ops.control_dependencies([loss]):
- update_ops = _sync_variables_ops()
+ update_ops = _sync_variables_ops(ctx)
# Validate the TPU training graph to catch basic errors
_validate_tpu_training_graph()
@@ -2600,7 +2615,7 @@ class TPUEstimator(estimator_lib.Estimator):
# After TPU evaluation computation is done (the mean_loss tensor),
# reads all variables back from TPU and updates the eval step
# counter properly
- internal_ops_to_run = _sync_variables_ops()
+ internal_ops_to_run = _sync_variables_ops(ctx)
internal_ops_to_run.append(
_increase_eval_step_op(iterations_per_loop_var))
with ops.control_dependencies(internal_ops_to_run):
@@ -2645,7 +2660,7 @@ class TPUEstimator(estimator_lib.Estimator):
scaffold, prediction_hooks) = _predict_on_tpu_system(
ctx, model_fn_wrapper, dequeue_fn)
with ops.control_dependencies([dummy_predict_op]):
- internal_ops_to_run = _sync_variables_ops()
+ internal_ops_to_run = _sync_variables_ops(ctx)
with ops.control_dependencies(internal_ops_to_run):
dummy_predict_op = control_flow_ops.no_op()