diff options
author | 2018-10-04 09:42:13 -0700 | |
---|---|---|
committer | 2018-10-04 09:46:29 -0700 | |
commit | c2552cd33c05fa84f280e766e33ba01308ffbcb2 (patch) | |
tree | ad154f46d284c4cdfb0d160b10c558fd00fed629 /tensorflow/contrib | |
parent | 1fb84c2e41c454939a02a69093cb214673eab343 (diff) |
Skip numeric checking in BROADCAST mode.
PiperOrigin-RevId: 215752559
Diffstat (limited to 'tensorflow/contrib')
-rw-r--r-- | tensorflow/contrib/tpu/python/tpu/tpu_estimator.py | 37 |
1 files changed, 26 insertions, 11 deletions
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index 3aa5b6efa1..8d15c857f8 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -177,14 +177,29 @@ def _create_or_get_iterations_per_loop(): use_resource=True) -def _sync_variables_ops(): - # Gets the variables back from TPU nodes. This means the variables updated - # by TPU will now be *synced* to host memory. - return [ - array_ops.check_numerics(v.read_value(), - 'Gradient for %s is NaN' % v.name).op - for v in variables.trainable_variables() - ] +def _sync_variables_ops(ctx): + """Create varriables synchronization ops. + + Gets the variables back from TPU nodes. This means the variables updated + by TPU will now be *synced* to host memory. + In BROADCAST mode, we skip this sync since the variables are ususally too + big to transmit via RPC. + + Args: + ctx: A `_InternalTPUContext` instance with mode. + + Returns: + A list of sync ops. + """ + + if not ctx.is_input_broadcast_with_iterators(): + return [ + array_ops.check_numerics(v.read_value(), + 'Gradient for %s is NaN' % v.name).op + for v in variables.trainable_variables() + ] + else: + return [control_flow_ops.no_op()] def _increase_eval_step_op(iterations_per_loop): @@ -2567,7 +2582,7 @@ class TPUEstimator(estimator_lib.Estimator): summary.scalar(model_fn_lib.LOSS_METRIC_KEY, loss) with ops.control_dependencies([loss]): - update_ops = _sync_variables_ops() + update_ops = _sync_variables_ops(ctx) # Validate the TPU training graph to catch basic errors _validate_tpu_training_graph() @@ -2600,7 +2615,7 @@ class TPUEstimator(estimator_lib.Estimator): # After TPU evaluation computation is done (the mean_loss tensor), # reads all variables back from TPU and updates the eval step # counter properly - internal_ops_to_run = _sync_variables_ops() + internal_ops_to_run = _sync_variables_ops(ctx) internal_ops_to_run.append( _increase_eval_step_op(iterations_per_loop_var)) with ops.control_dependencies(internal_ops_to_run): @@ -2645,7 +2660,7 @@ class TPUEstimator(estimator_lib.Estimator): scaffold, prediction_hooks) = _predict_on_tpu_system( ctx, model_fn_wrapper, dequeue_fn) with ops.control_dependencies([dummy_predict_op]): - internal_ops_to_run = _sync_variables_ops() + internal_ops_to_run = _sync_variables_ops(ctx) with ops.control_dependencies(internal_ops_to_run): dummy_predict_op = control_flow_ops.no_op() |