author     Yanan Cao <ycao@google.com>  2018-09-21 19:24:00 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-09-21 19:27:39 -0700
commit     0695e9ad8fe6f50942c8c18d648aea982541eeae (patch)
tree       d46244608d3efa3f795eacca72664cb6ba8267d5 /tensorflow/contrib/compiler
parent     174e782ded74187fa81f034bb3cfedf2b100286d (diff)
xla.estimator_model_fn can be used to decorate a model_fn written for the Estimator API in order to compile the entire model with XLA.
PiperOrigin-RevId: 214078470
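
For illustration, a minimal sketch of the intended usage. The model body and names such as `my_model_fn` are hypothetical and not part of this change; only the decorator itself comes from the patch below:

```
import tensorflow as tf
from tensorflow.contrib.compiler import xla

@xla.estimator_model_fn()
def my_model_fn(features, labels, mode, params):
  # Hypothetical model body; any Estimator-compatible model_fn works here.
  logits = tf.layers.dense(features, 10)
  loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
  optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
  train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
  return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

# Per the docstring in xla.py, base Estimator also needs resource variables.
tf.enable_resource_variables()
est = tf.estimator.Estimator(model_fn=my_model_fn)
# est.train(input_fn=..., steps=...) as with any Estimator.
```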
Diffstat (limited to 'tensorflow/contrib/compiler')
-rw-r--r--  tensorflow/contrib/compiler/BUILD   |  20
-rw-r--r--  tensorflow/contrib/compiler/xla.py  | 293
2 files changed, 294 insertions, 19 deletions
diff --git a/tensorflow/contrib/compiler/BUILD b/tensorflow/contrib/compiler/BUILD
index 3b0e8f6cda..9c7fbee838 100644
--- a/tensorflow/contrib/compiler/BUILD
+++ b/tensorflow/contrib/compiler/BUILD
@@ -59,27 +59,9 @@ py_library(
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_ops",
"//tensorflow/python:platform",
+ "//tensorflow/python:summary_op_util",
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
"//tensorflow/python/estimator:model_fn",
],
)
-
-tf_py_test(
- name = "xla_test",
- srcs = ["xla_test.py"],
- additional_deps = [
- ":xla",
- "@six_archive//:six",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:control_flow_util",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:platform",
- "//tensorflow/python:state_ops",
- "//tensorflow/python:summary",
- "//tensorflow/python:training",
- "//tensorflow/python:variable_scope",
- ],
- tags = ["no_pip"],
-)
diff --git a/tensorflow/contrib/compiler/xla.py b/tensorflow/contrib/compiler/xla.py
index 0aae695f92..1e30525159 100644
--- a/tensorflow/contrib/compiler/xla.py
+++ b/tensorflow/contrib/compiler/xla.py
@@ -19,17 +19,22 @@ from __future__ import division
from __future__ import print_function
import collections
+import contextlib
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.compiler.jit.ops import xla_ops
from tensorflow.contrib.tpu.python.tpu import tpu_function
from tensorflow.core.framework import attr_value_pb2
+from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import summary_op_util
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
+from tensorflow.python.util import function_utils
+from tensorflow.python.util import tf_decorator
_XLA_COMPILE_ATTR = '_xla_compile_id'
_MAX_WARNING_LINES = 5
@@ -353,3 +358,291 @@ def _compile_internal(computation, inputs=None):
array_ops.identity(outputs[i], name='output_%d' % i)
for i in xrange(output_arity)
]
+
+
+@contextlib.contextmanager
+def _disable_summary_context():
+ """Enters a context where all summary ops are skipped.
+
+  Summaries are not yet supported in xla.compile(), so this context manager
+  skips creating summary ops while it is active. This is a temporary
+  workaround until XLA supports summary ops.
+
+ Yields:
+ None.
+ """
+  original_skip_summary_func = summary_op_util.skip_summary
+ summary_op_util.skip_summary = lambda: True
+
+ try:
+ yield
+ finally:
+    summary_op_util.skip_summary = original_skip_summary_func
+
+
+class _CapturedObject(object):
+ """A placeholder to capture an object."""
+
+ def __init__(self):
+ self._object = None
+
+ def capture(self, o):
+ if self._object:
+      raise RuntimeError(
+          'InternalError: _CapturedObject can capture only once. Please file '
+          'a bug.')
+
+ self._object = o
+
+ def get(self):
+ return self._object
+
+
+def _get_scaffold(captured_scaffold_fn):
+ """Retrieves the Scaffold from `captured_scaffold_fn`."""
+ scaffold_fn = captured_scaffold_fn.get()
+
+ if not scaffold_fn:
+ return None
+
+ scaffold = scaffold_fn()
+ if scaffold is None:
+ raise ValueError(
+ 'TPUEstimatorSpec.scaffold_fn returns None, which is not allowed')
+
+ return scaffold
+
+
+class _ModelFnWrapper(object):
+ """_ModelFnWrapper supports executing model_fn with XLA."""
+
+ def __init__(self, function):
+ self._model_fn = function
+
+ def __call__(self, features, labels, mode, params):
+
+    # TPUEstimator compiles model_fn when use_tpu=True. To avoid double
+    # compilation, we use params['use_tpu'] as a hint. When it is set to
+    # True, model_fn is called without compilation.
+    # Note that this condition isn't accurate for the case of exporting a
+    # model. In that case we should ideally not compile so that the user can
+    # see the detailed graph. However, we don't have enough information to
+    # tell whether model_fn is being called for export mode or not.
+    # TODO(ycao): Make this condition more accurate when implementing PREDICT
+    # mode.
+ if params.get('use_tpu'):
+ return self._call_model_fn(features, labels, mode, params)
+
+ if mode == model_fn_lib.ModeKeys.TRAIN:
+ train_step, captured_scaffold_fn = self._make_train_step(
+ features, labels, params)
+ with _disable_summary_context():
+ (loss,) = compile(train_step)
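+      # train_step has already attached the real train_op as a control
+      # dependency of the loss tensor, so an identity of the loss suffices
+      # as the train_op here.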
+ return model_fn_lib.EstimatorSpec(
+ mode=mode,
+ loss=loss,
+ train_op=array_ops.identity(loss),
+ scaffold=_get_scaffold(captured_scaffold_fn))
+ elif mode == model_fn_lib.ModeKeys.EVAL:
+ eval_step, captured_eval_metric_fn, captured_scaffold_fn = (
+ self._make_eval_step(features, labels, params))
+ with _disable_summary_context():
+ outputs = compile(eval_step)
+ loss = outputs[0]
+
+ # Calculate eval_metric_ops if eval_metric_fn is set and captured.
+ eval_metric_fn = captured_eval_metric_fn.get()
+ if eval_metric_fn:
+ eval_metric_fn_tensors = outputs[1:]
+ eval_metric_ops = eval_metric_fn(*eval_metric_fn_tensors)
+ else:
+ eval_metric_ops = None
+
+ return model_fn_lib.EstimatorSpec(
+ mode=mode,
+ loss=loss,
+ eval_metric_ops=eval_metric_ops,
+ scaffold=_get_scaffold(captured_scaffold_fn))
+ else:
+ raise NotImplementedError('%s is not implemented, only TRAIN and EVAL are'
+ ' supported' % mode)
+
+ def _make_train_step(self, features, labels, params):
+ """Creates a single step of training for xla.compile()."""
+ captured_scaffold_fn = _CapturedObject()
+
+ def train_step():
+ """A single step of training."""
+ estimator_spec = self._call_model_fn(features, labels,
+ model_fn_lib.ModeKeys.TRAIN, params)
+
+ try:
+ captured_scaffold_fn.capture(estimator_spec.scaffold_fn)
+ except AttributeError:
+ captured_scaffold_fn.capture(None)
+
+      # train_step will be run by xla.compile(). xla.compile() only supports
+      # tensor output, while train_op can be either an operation or a tensor.
+      # Even though xla.compile() automatically adds an operation-typed
+      # train_op as a control dependency of other tensor outputs, it doesn't
+      # do so for a tensor-typed train_op. Thus, we need to set it explicitly
+      # here.
+ with ops.control_dependencies([estimator_spec.train_op]):
+ return array_ops.identity(estimator_spec.loss)
+
+ return train_step, captured_scaffold_fn
+
+ def _make_eval_step(self, features, labels, params):
+ """Creates a single step of evaluation for xla.compile()."""
+ captured_eval_metric_fn = _CapturedObject()
+ captured_scaffold_fn = _CapturedObject()
+
+ def eval_step():
+ """A single step of evaluation."""
+ estimator_spec = self._call_model_fn(features, labels,
+ model_fn_lib.ModeKeys.EVAL, params)
+
+ try:
+ captured_scaffold_fn.capture(estimator_spec.scaffold_fn)
+ except AttributeError:
+ captured_scaffold_fn.capture(None)
+
+ eval_metric_fn = None
+ eval_metric_fn_tensors = []
+ try:
+ if estimator_spec.eval_metrics:
+ (eval_metric_fn, eval_metric_fn_tensors) = estimator_spec.eval_metrics
+ except AttributeError:
+ pass
+
+      # If a dictionary is provided, convert it into a list ordered according
+      # to the positional arguments of eval_metric_fn.
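+      # For example, if eval_metric_fn takes (labels, predictions) and the
+      # dict is {'predictions': p, 'labels': l}, the list becomes [l, p].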
+ if isinstance(eval_metric_fn_tensors, dict):
+ eval_metric_fn_args = function_utils.fn_args(eval_metric_fn)
+ eval_metric_fn_tensors = [
+ eval_metric_fn_tensors[i] for i in eval_metric_fn_args
+ ]
+
+ captured_eval_metric_fn.capture(eval_metric_fn)
+
+ return tuple([estimator_spec.loss] + eval_metric_fn_tensors)
+
+ return eval_step, captured_eval_metric_fn, captured_scaffold_fn
+
+ def _call_model_fn(self, features, labels, mode, params):
+ """Calls the model_fn with required parameters."""
+ model_fn_args = function_utils.fn_args(self._model_fn)
+ kwargs = {}
+
+ if 'labels' in model_fn_args:
+ kwargs['labels'] = labels
+ elif labels is not None:
+ raise ValueError(
+ 'model_fn does not take labels, but input_fn returns labels.')
+ if 'mode' in model_fn_args:
+ kwargs['mode'] = mode
+
+ if 'params' in model_fn_args:
+ kwargs['params'] = params
+
+ return self._verify_estimator_spec(
+ self._model_fn(features=features, **kwargs))
+
+ def _verify_estimator_spec(self, estimator_spec):
+ """Verifies estimator spec contains correct data."""
+ # TODO(ycao): Implement estimator spec verification for other modes.
+
+ try:
+ if estimator_spec.scaffold:
+ logging.warning('EstimatorSpec.scaffold is ignored with XLA compilation'
+ '. Please use TPUEstimatorSpec.scaffold_fn instead.')
+ except AttributeError:
+ pass
+
+ try:
+ if estimator_spec.eval_metric_ops:
+ raise ValueError('EstimatorSpec.eval_metric_ops is not supported with '
+ 'XLA compilation. Please use '
+ 'TPUEstimatorSpec.eval_metrics instead.')
+ except AttributeError:
+ pass
+
+ if estimator_spec.mode == model_fn_lib.ModeKeys.EVAL:
+ # If estimator_spec is of type TPUEstimatorSpec and contains eval_metrics,
+ # check that eval_metrics contains eval_metric_fn and
+ # eval_metric_fn_tensors with matching arguments.
+ try:
+ eval_metrics = estimator_spec.eval_metrics
+ except AttributeError:
+ eval_metrics = None
+
+ if eval_metrics:
+ (eval_metric_fn, eval_metric_fn_tensors) = eval_metrics
+ eval_metric_fn_args = function_utils.fn_args(eval_metric_fn)
+
+ if isinstance(eval_metric_fn_tensors, dict):
+ missing_tensors = [
+ i for i in eval_metric_fn_args if i not in eval_metric_fn_tensors
+ ]
+ additional_tensors = [
+ i for i in eval_metric_fn_tensors if i not in eval_metric_fn_args
+ ]
+
+ if missing_tensors:
+ raise ValueError('Arguments %s are needed by metric_fn (first '
+ 'element of TPUEstimatorSpec.eval_metrics) but '
+ 'they are not provided by evaluation tensors '
+ '(second element of TPUEstimatorSpec.eval_metrics)'
+ '.' % missing_tensors)
+
+ if additional_tensors:
+ raise ValueError('Arguments %s are provided by evaluation tensors '
+ '(second element of TPUEstimatorSpec.eval_metrics)'
+ ' but they are not needed by metric_fn (first '
+ 'element of TPUEstimatorSpec.eval_metrics).' %
+ additional_tensors)
+
+ return estimator_spec
+
+
+def estimator_model_fn(target_model_fn=None):
+ """estimator_model_fn decorates a model_fn to be compiled for execution.
+
+  Currently it only works with `TPUEstimator`. If you need to use it with
+  base `Estimator`, please add `tf.enable_resource_variables()` at the
+  beginning of your program.
+
+ Example 1, decorating model_fn:
+ ```
+ @xla.estimator_model_fn()
+ def model_fn(features, labels, mode, params):
+ ...
+ return EstimatorSpec(...)
+
+
+ est = Estimator(model_fn=model_fn, ...)
+ est.train(...)
+  ```
+
+  Example 2, using the decorator as a function:
+ ```
+ def model_fn(features, labels, mode, params):
+ ...
+ return EstimatorSpec(...)
+
+ est = Estimator(model_fn=xla.estimator_model_fn(model_fn), ...)
+ est.train(...)
+ ```
+
+ Args:
+    target_model_fn: model_fn to be decorated. This is only needed when the
+      decorator is used in function-call form (Example 2).
+
+ Returns:
+ Decorated target_model_fn.
+ """
+
+ def decorated(function):
+ return tf_decorator.make_decorator(function, _ModelFnWrapper(function))
+
+ return decorated(target_model_fn) if target_model_fn else decorated
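
For reference, a hedged sketch of the EVAL path introduced above: a decorated model_fn can report metrics through `TPUEstimatorSpec.eval_metrics`, which the wrapper converts into `eval_metric_ops`. Names like `eval_model_fn` and the choice of metric are illustrative assumptions, not part of this change:

```
import tensorflow as tf
from tensorflow.contrib import tpu
from tensorflow.contrib.compiler import xla

@xla.estimator_model_fn()
def eval_model_fn(features, labels, mode, params):
  logits = tf.layers.dense(features, 10)
  loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

  def metric_fn(labels, logits):
    # The returned dict becomes eval_metric_ops on the resulting EstimatorSpec.
    predictions = tf.argmax(logits, axis=-1)
    return {'accuracy': tf.metrics.accuracy(labels=labels,
                                            predictions=predictions)}

  # The dict keys must match metric_fn's argument names; the wrapper reorders
  # the tensors to metric_fn's positional order before calling it.
  return tpu.TPUEstimatorSpec(
      mode=mode,
      loss=loss,
      eval_metrics=(metric_fn, {'labels': labels, 'logits': logits}))
```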