aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/compiler
diff options
context:
space:
mode:
authorGravatar Yanan Cao <ycao@google.com>2018-09-27 17:04:17 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-27 17:11:56 -0700
commit0a9ee95ed9c26bef58e9daadcb6935807d90fcd3 (patch)
treed196e84c1508f7c70b07e57f9c355308b197ffbf /tensorflow/contrib/compiler
parenta0de15424803bb2688aafd496c30b78c4eb6e1c3 (diff)
Disable summary ops from lower-level xla.compile API rather than xla.estimator_model_fn
PiperOrigin-RevId: 214860981
Diffstat (limited to 'tensorflow/contrib/compiler')
-rw-r--r--tensorflow/contrib/compiler/xla.py13
1 files changed, 6 insertions, 7 deletions
diff --git a/tensorflow/contrib/compiler/xla.py b/tensorflow/contrib/compiler/xla.py
index 1e30525159..873b03580d 100644
--- a/tensorflow/contrib/compiler/xla.py
+++ b/tensorflow/contrib/compiler/xla.py
@@ -293,7 +293,8 @@ def _compile_internal(computation, inputs=None):
saved_use_resource = vscope.use_resource
vscope.set_use_resource(True)
- outputs = computation(*computation_inputs)
+ with _disable_summary_context():
+ outputs = computation(*computation_inputs)
# Restore variable scope after computation.
vscope.set_use_resource(saved_use_resource)
@@ -371,13 +372,13 @@ def _disable_summary_context():
Yields:
None.
"""
- origional_skip_summary_func = summary_op_util.skip_summary
+ original_skip_summary_func = summary_op_util.skip_summary
summary_op_util.skip_summary = lambda: True
try:
yield
finally:
- summary_op_util.skip_summary = origional_skip_summary_func
+ summary_op_util.skip_summary = original_skip_summary_func
class _CapturedObject(object):
@@ -436,8 +437,7 @@ class _ModelFnWrapper(object):
if mode == model_fn_lib.ModeKeys.TRAIN:
train_step, captured_scaffold_fn = self._make_train_step(
features, labels, params)
- with _disable_summary_context():
- (loss,) = compile(train_step)
+ (loss,) = compile(train_step)
return model_fn_lib.EstimatorSpec(
mode=mode,
loss=loss,
@@ -446,8 +446,7 @@ class _ModelFnWrapper(object):
elif mode == model_fn_lib.ModeKeys.EVAL:
eval_step, captured_eval_metric_fn, captured_scaffold_fn = (
self._make_eval_step(features, labels, params))
- with _disable_summary_context():
- outputs = compile(eval_step)
+ outputs = compile(eval_step)
loss = outputs[0]
# Calculate eval_metric_ops if eval_metric_fn is set and captured.