diff options
author | 2018-09-18 22:33:38 -0700 | |
---|---|---|
committer | 2018-09-18 22:37:39 -0700 | |
commit | 50125bf0d8ee9f47b868211f62cb545c5701a032 (patch) | |
tree | 6c65f7bf4af58772bfb1a8cf1941904436514020 /tensorflow/contrib/compiler | |
parent | d7cc73c300b12e7c02507bcfaff146d6c4955f19 (diff) |
Add xla.compile(), a low-level API that compiles graph with XLA.
PiperOrigin-RevId: 213574904
Diffstat (limited to 'tensorflow/contrib/compiler')
-rw-r--r-- | tensorflow/contrib/compiler/BUILD | 3 | ||||
-rw-r--r-- | tensorflow/contrib/compiler/xla.py | 149 |
2 files changed, 151 insertions, 1 deletions
diff --git a/tensorflow/contrib/compiler/BUILD b/tensorflow/contrib/compiler/BUILD
index d7583be6d8..3b0e8f6cda 100644
--- a/tensorflow/contrib/compiler/BUILD
+++ b/tensorflow/contrib/compiler/BUILD
@@ -53,11 +53,14 @@ py_library(
     srcs = ["xla.py"],
     srcs_version = "PY2AND3",
     deps = [
+        "//tensorflow/compiler/jit:xla_ops_py",
+        "//tensorflow/contrib/tpu:tpu_lib",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:control_flow_ops",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:platform",
         "//tensorflow/python:util",
+        "//tensorflow/python:variable_scope",
         "//tensorflow/python/estimator:model_fn",
     ],
 )
diff --git a/tensorflow/contrib/compiler/xla.py b/tensorflow/contrib/compiler/xla.py
index 60f5af1662..0aae695f92 100644
--- a/tensorflow/contrib/compiler/xla.py
+++ b/tensorflow/contrib/compiler/xla.py
@@ -12,18 +12,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # =============================================================================
-"""xla provides experimental xla support API."""
+"""xla is an experimental library that provides XLA support APIs."""
 
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import collections
 from six.moves import xrange  # pylint: disable=redefined-builtin
 
+from tensorflow.compiler.jit.ops import xla_ops
+from tensorflow.contrib.tpu.python.tpu import tpu_function
 from tensorflow.core.framework import attr_value_pb2
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import variable_scope
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util import compat
 
@@ -51,6 +55,30 @@ _UNSUPPORTED_OPS = set([
 ])
 
 
+def compile(computation, inputs=None):  # pylint: disable=redefined-builtin
+  """Builds an operator that compiles and runs `computation` with XLA.
+
+  Args:
+    computation: A Python function that builds a computation to apply to the
+      input. If the function takes n inputs, 'inputs' should be a list of n
+      tensors.
+
+      `computation` may return a list of operations and tensors. Tensors must
+      come before operations in the returned list. The return value of
+      `compile` is a list of tensors corresponding to the tensors from the
+      output of `computation`.
+
+      All `Operation`s returned from `computation` will be executed when
+      evaluating any of the returned output tensors.
+    inputs: A list of input tensors or `None` (equivalent to an empty list).
+
+  Returns:
+    A list of output tensors, or a single NoOp if there are no tensor outputs.
+  """
+  # pylint: disable=protected-access
+  return _compile_internal(computation, inputs)
+
+
 class XLACompileContext(control_flow_ops.XLAControlFlowContext):
   """A `ControlFlowContext` for nodes inside an XLA computation cluster.
 
@@ -206,3 +234,122 @@ class XLACompileContext(control_flow_ops.XLAControlFlowContext):
     if self.GetWhileContext():
       return self.GetWhileContext().back_prop
     return False
+
+
+def _compile_internal(computation, inputs=None):
+  """Builds graph operators that compile and symbolically execute computation.
+
+  Args:
+    computation: A Python function that builds the computation to compile and
+      execute.
+    inputs: A list of input tensors or `None` (equivalent to `[]`). Its order
+      should match the ordering of the computation's arguments.
+  Returns:
+    A list of output tensors, or a single NoOp if there are no tensor outputs.
+  Raises:
+    ValueError: If a computation output is neither an Operation nor convertible
+      to a tensor, or if an Operation precedes a Tensor in the outputs.
+    TypeError: If `inputs` is not a sequence or does not match `computation`.
+  """
+  if inputs is None:
+    inputs = []
+
+  if not isinstance(inputs, collections.Sequence):
+    raise TypeError('inputs must be a list')
+
+  # Converts inputs to Tensors.
+  inputs = [ops.convert_to_tensor(x) for x in inputs]
+  input_arity = len(inputs)
+
+  arg_error = tpu_function.check_function_argument_count(
+      computation, input_arity, infeed_queue=None)
+  if arg_error is not None:
+    raise TypeError(
+        'Supplied computation cannot be called with the specified inputs. You '
+        'specified %d inputs: %s, but the computation needs %s' %
+        (input_arity, str([i.name for i in inputs]), arg_error))
+
+  cluster_name = ops.get_default_graph().unique_name('cluster')
+  pivot = control_flow_ops.no_op(name=cluster_name + '/pivot')
+  context = XLACompileContext(name=cluster_name, pivot=pivot)
+  try:
+    context.Enter()
+
+    # Add identity ops so even unused inputs are 'consumed' by the
+    # computation.
+    computation_inputs = [
+        array_ops.identity(x, name='input_{}'.format(i))
+        for i, x in enumerate(inputs)
+    ]
+
+    # Only resource variables work inside an XLA computation, so turn on
+    # resource variables for the computation.
+    vscope = variable_scope.get_variable_scope()
+    saved_use_resource = vscope.use_resource
+    vscope.set_use_resource(True)
+    try:
+      outputs = computation(*computation_inputs)
+    finally:
+      # Restore the variable scope setting even if `computation` raises.
+      vscope.set_use_resource(saved_use_resource)
+
+    # If the computation returns `None`, make it an empty tuple.
+    if outputs is None:
+      outputs = tuple()
+    # If the computation only returned one value, make it a tuple.
+    if not isinstance(outputs, collections.Sequence):
+      outputs = (outputs,)
+
+    # Append a `no_op` (to a copy, so a returned list is not mutated) so there
+    # is always at least one op that can trigger the XlaLaunch node.
+    outputs = tuple(outputs) + (control_flow_ops.no_op(),)
+    try:
+      outputs = [
+          o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o)
+          for o in outputs
+      ]
+    except Exception as e:
+      raise ValueError(
+          'XLA computation function return values must all either be Operations'
+          ' or convertible to Tensors. Got error: "%s"' % str(e))
+
+    # Separates the returned Operations and Tensors.
+    output_operations = [o for o in outputs if isinstance(o, ops.Operation)]
+    output_tensors = [o for o in outputs if not isinstance(o, ops.Operation)]
+
+    if outputs != output_tensors + output_operations:
+      raise ValueError(
+          'XLA computation function must return zero or more Tensor values '
+          'followed by zero or more Operations.')
+    output_arity = len(output_tensors)
+
+    new_output_tensors = []
+    for t in output_tensors:
+      with ops.device(t.device if t.device else ''):
+        new_output_tensors.append(array_ops.identity(t))
+
+    output_tensors = new_output_tensors
+    context.ExitResult(output_tensors)
+  finally:
+    context.report_unsupported_operations()
+    context.Exit()
+
+  outputs = [
+      xla_ops.xla_cluster_output(output_tensors[i], name='output{}'.format(i))
+      for i in xrange(output_arity)
+  ]
+
+  with ops.control_dependencies(output_operations):
+    if output_arity == 0:
+      # When XLA computation returns only operations and no tensors, a NoOp
+      # dependent on the operations in outputs is returned. Otherwise final
+      # outputs would be empty and there is no way to trigger returned
+      # operations.
+      return control_flow_ops.no_op(name='output_0')
+    else:
+      # Wrap the outputs in identity operators that carry the control
+      # dependencies.
+      return [
+          array_ops.identity(outputs[i], name='output_%d' % i)
+          for i in xrange(output_arity)
+      ]