diff options
author | Yanan Cao <ycao@google.com> | 2018-09-27 17:04:17 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-27 17:11:56 -0700 |
commit | 0a9ee95ed9c26bef58e9daadcb6935807d90fcd3 (patch) | |
tree | d196e84c1508f7c70b07e57f9c355308b197ffbf /tensorflow/contrib/compiler | |
parent | a0de15424803bb2688aafd496c30b78c4eb6e1c3 (diff) |
Disable summary ops from lower-level xla.compile API rather than xla.estimator_model_fn
PiperOrigin-RevId: 214860981
Diffstat (limited to 'tensorflow/contrib/compiler')
-rw-r--r-- | tensorflow/contrib/compiler/xla.py | 13 |
1 files changed, 6 insertions, 7 deletions
diff --git a/tensorflow/contrib/compiler/xla.py b/tensorflow/contrib/compiler/xla.py index 1e30525159..873b03580d 100644 --- a/tensorflow/contrib/compiler/xla.py +++ b/tensorflow/contrib/compiler/xla.py @@ -293,7 +293,8 @@ def _compile_internal(computation, inputs=None): saved_use_resource = vscope.use_resource vscope.set_use_resource(True) - outputs = computation(*computation_inputs) + with _disable_summary_context(): + outputs = computation(*computation_inputs) # Restore variable scope after computation. vscope.set_use_resource(saved_use_resource) @@ -371,13 +372,13 @@ def _disable_summary_context(): Yields: None. """ - origional_skip_summary_func = summary_op_util.skip_summary + original_skip_summary_func = summary_op_util.skip_summary summary_op_util.skip_summary = lambda: True try: yield finally: - summary_op_util.skip_summary = origional_skip_summary_func + summary_op_util.skip_summary = original_skip_summary_func class _CapturedObject(object): @@ -436,8 +437,7 @@ class _ModelFnWrapper(object): if mode == model_fn_lib.ModeKeys.TRAIN: train_step, captured_scaffold_fn = self._make_train_step( features, labels, params) - with _disable_summary_context(): - (loss,) = compile(train_step) + (loss,) = compile(train_step) return model_fn_lib.EstimatorSpec( mode=mode, loss=loss, @@ -446,8 +446,7 @@ class _ModelFnWrapper(object): elif mode == model_fn_lib.ModeKeys.EVAL: eval_step, captured_eval_metric_fn, captured_scaffold_fn = ( self._make_eval_step(features, labels, params)) - with _disable_summary_context(): - outputs = compile(eval_step) + outputs = compile(eval_step) loss = outputs[0] # Calculate eval_metric_ops if eval_metric_fn is set and captured. |