diff options
author | Yanan Cao <ycao@google.com> | 2018-09-21 19:24:00 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-21 19:27:39 -0700 |
commit | 0695e9ad8fe6f50942c8c18d648aea982541eeae (patch) | |
tree | d46244608d3efa3f795eacca72664cb6ba8267d5 /tensorflow/contrib/compiler | |
parent | 174e782ded74187fa81f034bb3cfedf2b100286d (diff) |
xla.estimator_model_fn can be used to decorate a model_fn written for estimator API in order to compile entire model with XLA.
PiperOrigin-RevId: 214078470
Diffstat (limited to 'tensorflow/contrib/compiler')
-rw-r--r-- | tensorflow/contrib/compiler/BUILD | 20 | ||||
-rw-r--r-- | tensorflow/contrib/compiler/xla.py | 293 |
2 files changed, 294 insertions, 19 deletions
diff --git a/tensorflow/contrib/compiler/BUILD b/tensorflow/contrib/compiler/BUILD index 3b0e8f6cda..9c7fbee838 100644 --- a/tensorflow/contrib/compiler/BUILD +++ b/tensorflow/contrib/compiler/BUILD @@ -59,27 +59,9 @@ py_library( "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_ops", "//tensorflow/python:platform", + "//tensorflow/python:summary_op_util", "//tensorflow/python:util", "//tensorflow/python:variable_scope", "//tensorflow/python/estimator:model_fn", ], ) - -tf_py_test( - name = "xla_test", - srcs = ["xla_test.py"], - additional_deps = [ - ":xla", - "@six_archive//:six", - "//tensorflow/python:constant_op", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:control_flow_util", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform", - "//tensorflow/python:state_ops", - "//tensorflow/python:summary", - "//tensorflow/python:training", - "//tensorflow/python:variable_scope", - ], - tags = ["no_pip"], -) diff --git a/tensorflow/contrib/compiler/xla.py b/tensorflow/contrib/compiler/xla.py index 0aae695f92..1e30525159 100644 --- a/tensorflow/contrib/compiler/xla.py +++ b/tensorflow/contrib/compiler/xla.py @@ -19,17 +19,22 @@ from __future__ import division from __future__ import print_function import collections +import contextlib from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.compiler.jit.ops import xla_ops from tensorflow.contrib.tpu.python.tpu import tpu_function from tensorflow.core.framework import attr_value_pb2 +from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import summary_op_util from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import compat +from tensorflow.python.util import function_utils +from tensorflow.python.util import tf_decorator _XLA_COMPILE_ATTR = '_xla_compile_id' _MAX_WARNING_LINES = 5 @@ -353,3 +358,291 @@ def _compile_internal(computation, inputs=None): array_ops.identity(outputs[i], name='output_%d' % i) for i in xrange(output_arity) ] + + +@contextlib.contextmanager +def _disable_summary_context(): + """Enters a context where all summary ops are skipped. + + Summaries are not yet supported in xla.compile(). So we provide this context + manager that can skip creating summary ops. This is a temporary workaround due + to XLA not supporting summary ops. + + Yields: + None. + """ + origional_skip_summary_func = summary_op_util.skip_summary + summary_op_util.skip_summary = lambda: True + + try: + yield + finally: + summary_op_util.skip_summary = origional_skip_summary_func + + +class _CapturedObject(object): + """A placeholder to capture an object.""" + + def __init__(self): + self._object = None + + def capture(self, o): + if self._object: + raise RuntimeError( + 'InternalError: _CapturedObject can capture only once. Please file ' + 'bug.') + + self._object = o + + def get(self): + return self._object + + +def _get_scaffold(captured_scaffold_fn): + """Retrieves the Scaffold from `captured_scaffold_fn`.""" + scaffold_fn = captured_scaffold_fn.get() + + if not scaffold_fn: + return None + + scaffold = scaffold_fn() + if scaffold is None: + raise ValueError( + 'TPUEstimatorSpec.scaffold_fn returns None, which is not allowed') + + return scaffold + + +class _ModelFnWrapper(object): + """_ModelFnWrapper supports executing model_fn with XLA.""" + + def __init__(self, function): + self._model_fn = function + + def __call__(self, features, labels, mode, params): + + # TPUEstimator compiles model_fn when use_tpu=True. To avoid double + # compilation, we use this params['use_tpu'] as a hint. When it is set to + # True, model_fn is called without compilation. + # Note that this condition isn't accurate for the case of exporting a model. + # In that case we should ideally not compile so that user can see detailed + # graph. However, we don't have enough information to tell whether model_fn + # is being called for export mode or not. + # TODO(ycao): Make this condition more accurate when implementing PREDICT + # mode. + if params.get('use_tpu'): + return self._call_model_fn(features, labels, mode, params) + + if mode == model_fn_lib.ModeKeys.TRAIN: + train_step, captured_scaffold_fn = self._make_train_step( + features, labels, params) + with _disable_summary_context(): + (loss,) = compile(train_step) + return model_fn_lib.EstimatorSpec( + mode=mode, + loss=loss, + train_op=array_ops.identity(loss), + scaffold=_get_scaffold(captured_scaffold_fn)) + elif mode == model_fn_lib.ModeKeys.EVAL: + eval_step, captured_eval_metric_fn, captured_scaffold_fn = ( + self._make_eval_step(features, labels, params)) + with _disable_summary_context(): + outputs = compile(eval_step) + loss = outputs[0] + + # Calculate eval_metric_ops if eval_metric_fn is set and captured. + eval_metric_fn = captured_eval_metric_fn.get() + if eval_metric_fn: + eval_metric_fn_tensors = outputs[1:] + eval_metric_ops = eval_metric_fn(*eval_metric_fn_tensors) + else: + eval_metric_ops = None + + return model_fn_lib.EstimatorSpec( + mode=mode, + loss=loss, + eval_metric_ops=eval_metric_ops, + scaffold=_get_scaffold(captured_scaffold_fn)) + else: + raise NotImplementedError('%s is not implemented, only TRAIN and EVAL are' + ' supported' % mode) + + def _make_train_step(self, features, labels, params): + """Creates a single step of training for xla.compile().""" + captured_scaffold_fn = _CapturedObject() + + def train_step(): + """A single step of training.""" + estimator_spec = self._call_model_fn(features, labels, + model_fn_lib.ModeKeys.TRAIN, params) + + try: + captured_scaffold_fn.capture(estimator_spec.scaffold_fn) + except AttributeError: + captured_scaffold_fn.capture(None) + + # train_step will be run by xla.compile(). xla.compile() only supports + # tensor output while train_op can be either an operation or a tensor. + # Even though xla.compile() automatically adds operation-typed train_op as + # control dependency of other tensor outputs, it doesn't do so for + # tensor-typed train_op. Thus, we need to set it explicitly here. + with ops.control_dependencies([estimator_spec.train_op]): + return array_ops.identity(estimator_spec.loss) + + return train_step, captured_scaffold_fn + + def _make_eval_step(self, features, labels, params): + """Creates a single step of evaluation for xla.compile().""" + captured_eval_metric_fn = _CapturedObject() + captured_scaffold_fn = _CapturedObject() + + def eval_step(): + """A single step of evaluation.""" + estimator_spec = self._call_model_fn(features, labels, + model_fn_lib.ModeKeys.EVAL, params) + + try: + captured_scaffold_fn.capture(estimator_spec.scaffold_fn) + except AttributeError: + captured_scaffold_fn.capture(None) + + eval_metric_fn = None + eval_metric_fn_tensors = [] + try: + if estimator_spec.eval_metrics: + (eval_metric_fn, eval_metric_fn_tensors) = estimator_spec.eval_metrics + except AttributeError: + pass + + # If a dictionary is provided, we need to convert it into a list sorted + # according to order of eval_metric_fn positional arguments. + if isinstance(eval_metric_fn_tensors, dict): + eval_metric_fn_args = function_utils.fn_args(eval_metric_fn) + eval_metric_fn_tensors = [ + eval_metric_fn_tensors[i] for i in eval_metric_fn_args + ] + + captured_eval_metric_fn.capture(eval_metric_fn) + + return tuple([estimator_spec.loss] + eval_metric_fn_tensors) + + return eval_step, captured_eval_metric_fn, captured_scaffold_fn + + def _call_model_fn(self, features, labels, mode, params): + """Calls the model_fn with required parameters.""" + model_fn_args = function_utils.fn_args(self._model_fn) + kwargs = {} + + if 'labels' in model_fn_args: + kwargs['labels'] = labels + elif labels is not None: + raise ValueError( + 'model_fn does not take labels, but input_fn returns labels.') + if 'mode' in model_fn_args: + kwargs['mode'] = mode + + if 'params' in model_fn_args: + kwargs['params'] = params + + return self._verify_estimator_spec( + self._model_fn(features=features, **kwargs)) + + def _verify_estimator_spec(self, estimator_spec): + """Verifies estimator spec contains correct data.""" + # TODO(ycao): Implement estimator spec verification for other modes. + + try: + if estimator_spec.scaffold: + logging.warning('EstimatorSpec.scaffold is ignored with XLA compilation' + '. Please use TPUEstimatorSpec.scaffold_fn instead.') + except AttributeError: + pass + + try: + if estimator_spec.eval_metric_ops: + raise ValueError('EstimatorSpec.eval_metric_ops is not supported with ' + 'XLA compilation. Please use ' + 'TPUEstimatorSpec.eval_metrics instead.') + except AttributeError: + pass + + if estimator_spec.mode == model_fn_lib.ModeKeys.EVAL: + # If estimator_spec is of type TPUEstimatorSpec and contains eval_metrics, + # check that eval_metrics contains eval_metric_fn and + # eval_metric_fn_tensors with matching arguments. + try: + eval_metrics = estimator_spec.eval_metrics + except AttributeError: + eval_metrics = None + + if eval_metrics: + (eval_metric_fn, eval_metric_fn_tensors) = eval_metrics + eval_metric_fn_args = function_utils.fn_args(eval_metric_fn) + + if isinstance(eval_metric_fn_tensors, dict): + missing_tensors = [ + i for i in eval_metric_fn_args if i not in eval_metric_fn_tensors + ] + additional_tensors = [ + i for i in eval_metric_fn_tensors if i not in eval_metric_fn_args + ] + + if missing_tensors: + raise ValueError('Arguments %s are needed by metric_fn (first ' + 'element of TPUEstimatorSpec.eval_metrics) but ' + 'they are not provided by evaluation tensors ' + '(second element of TPUEstimatorSpec.eval_metrics)' + '.' % missing_tensors) + + if additional_tensors: + raise ValueError('Arguments %s are provided by evaluation tensors ' + '(second element of TPUEstimatorSpec.eval_metrics)' + ' but they are not needed by metric_fn (first ' + 'element of TPUEstimatorSpec.eval_metrics).' % + additional_tensors) + + return estimator_spec + + +def estimator_model_fn(target_model_fn=None): + """estimator_model_fn decorates a model_fn to be compiled for execution. + + Currently only it only works with `TPUEstimator`. If you need to use it with + base `Estimator`, please add `tf.enable_resource_variables()` at beginning of + your program. + + Example 1, decorating model_fn: + ``` + @xla.estimator_model_fn() + def model_fn(features, labels, mode, params): + ... + return EstimatorSpec(...) + + + est = Estimator(model_fn=model_fn, ...) + est.train(...) + + ``` + + Example 2, decorator as function: + ``` + def model_fn(features, labels, mode, params): + ... + return EstimatorSpec(...) + + est = Estimator(model_fn=xla.estimator_model_fn(model_fn), ...) + est.train(...) + ``` + + Args: + target_model_fn: model_fn to be decorated. This is only needed when + decorator is used in function call form (example 2). + + Returns: + Decorated target_model_fn. + """ + + def decorated(function): + return tf_decorator.make_decorator(function, _ModelFnWrapper(function)) + + return decorated(target_model_fn) if target_model_fn else decorated |