aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/compiler
diff options
context:
space:
mode:
authorGravatar Yanan Cao <ycao@google.com>2018-09-18 22:33:38 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-18 22:37:39 -0700
commit50125bf0d8ee9f47b868211f62cb545c5701a032 (patch)
tree6c65f7bf4af58772bfb1a8cf1941904436514020 /tensorflow/contrib/compiler
parentd7cc73c300b12e7c02507bcfaff146d6c4955f19 (diff)
Add xla.compile(), a low-level API that compiles graph with XLA.
PiperOrigin-RevId: 213574904
Diffstat (limited to 'tensorflow/contrib/compiler')
-rw-r--r--tensorflow/contrib/compiler/BUILD3
-rw-r--r--tensorflow/contrib/compiler/xla.py149
2 files changed, 151 insertions, 1 deletions
diff --git a/tensorflow/contrib/compiler/BUILD b/tensorflow/contrib/compiler/BUILD
index d7583be6d8..3b0e8f6cda 100644
--- a/tensorflow/contrib/compiler/BUILD
+++ b/tensorflow/contrib/compiler/BUILD
@@ -53,11 +53,14 @@ py_library(
srcs = ["xla.py"],
srcs_version = "PY2AND3",
deps = [
+ "//tensorflow/compiler/jit:xla_ops_py",
+ "//tensorflow/contrib/tpu:tpu_lib",
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_ops",
"//tensorflow/python:platform",
"//tensorflow/python:util",
+ "//tensorflow/python:variable_scope",
"//tensorflow/python/estimator:model_fn",
],
)
diff --git a/tensorflow/contrib/compiler/xla.py b/tensorflow/contrib/compiler/xla.py
index 60f5af1662..0aae695f92 100644
--- a/tensorflow/contrib/compiler/xla.py
+++ b/tensorflow/contrib/compiler/xla.py
@@ -12,18 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
-"""xla provides experimental xla support API."""
+"""xla is an experimental library that provides XLA support APIs."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import collections
from six.moves import xrange # pylint: disable=redefined-builtin
+from tensorflow.compiler.jit.ops import xla_ops
+from tensorflow.contrib.tpu.python.tpu import tpu_function
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
@@ -51,6 +55,30 @@ _UNSUPPORTED_OPS = set([
])
+def compile(computation, inputs=None): # pylint: disable=redefined-builtin
+ """Builds an operator that compiles and runs `computation` with XLA.
+
+ Args:
+ computation: A Python function that builds a computation to apply to the
+ input. If the function takes n inputs, 'inputs' should be a list of n
+ tensors.
+
+ `computation` may return a list of operations and tensors. Tensors must
+ come before operations in the returned list. The return value of
+ `compile` is a list of tensors corresponding to the tensors from the
+ output of `computation`.
+
+ All `Operation`s returned from `computation` will be executed when
+ evaluating any of the returned output tensors.
+ inputs: A list of input tensors or `None` (equivalent to an empty list).
+
+ Returns:
+ A list of output tensors.
+ """
+ # pylint: disable=protected-access
+ return _compile_internal(computation, inputs)
+
+
class XLACompileContext(control_flow_ops.XLAControlFlowContext):
"""A `ControlFlowContext` for nodes inside an XLA computation cluster.
@@ -206,3 +234,122 @@ class XLACompileContext(control_flow_ops.XLAControlFlowContext):
if self.GetWhileContext():
return self.GetWhileContext().back_prop
return False
+
+
+def _compile_internal(computation, inputs=None):
+ """Builds graph operators that compiles and symbolically executes computation.
+
+ Args:
+ computation: A Python function that builds the computation to compile and
+ execute.
+ inputs: A list of input tensors or `None` (equivalent to `[]`). Its order
+ should match ordering of computation arguments.
+ Returns:
+ A list of output tensors from computation.
+ Raises:
+ ValueError: If any element in computation outputs is neither an operations
+ or a value that can be converted to tensor.
+ TypeError: If `inputs` is not a list or tuple.
+ """
+ if inputs is None:
+ inputs = []
+
+ if not isinstance(inputs, collections.Sequence):
+ raise TypeError('inputs must be a list')
+
+ # Converts inputs to Tensors.
+ inputs = [ops.convert_to_tensor(x) for x in inputs]
+ input_arity = len(inputs)
+
+ arg_error = tpu_function.check_function_argument_count(
+ computation, input_arity, infeed_queue=None)
+ if arg_error is not None:
+ raise TypeError(
+ 'Supplied computation cannot be called with the specified inputs. You '
+ 'specified %d inputs: %s, but the computation needs %s' %
+ (input_arity, str([i.name for i in inputs[0]]), arg_error))
+
+ cluster_name = ops.get_default_graph().unique_name('cluster')
+ pivot = control_flow_ops.no_op(name=cluster_name + '/pivot')
+ context = XLACompileContext(name=cluster_name, pivot=pivot)
+ try:
+ context.Enter()
+
+ # Add identity ops so even unused inputs are 'consumed' by the
+ # computation.
+ computation_inputs = [
+ array_ops.identity(x, name='input_{}'.format(i))
+ for i, x in enumerate(inputs)
+ ]
+
+ # Only resource variables work inside an XLA computation, so turn on
+ # resource variables for the computation.
+ vscope = variable_scope.get_variable_scope()
+ saved_use_resource = vscope.use_resource
+ vscope.set_use_resource(True)
+
+ outputs = computation(*computation_inputs)
+
+ # Restore variable scope after computation.
+ vscope.set_use_resource(saved_use_resource)
+
+ # If the computation returns `None`, make it an empty tuple.
+ if outputs is None:
+ outputs = tuple()
+ # If the computation only returned one value, make it a tuple.
+ if not isinstance(outputs, collections.Sequence):
+ outputs = (outputs,)
+
+ # Append `no_op` here so that return value of this function always contains
+ # at least one op that can trigger XlaLaunch node.
+ outputs += (control_flow_ops.no_op(),)
+ try:
+ outputs = [
+ o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o)
+ for o in outputs
+ ]
+ except Exception as e:
+ raise ValueError(
+ 'XLA computation function return values must all either be Operations'
+ ' or convertible to Tensors. Got error: "%s"' % str(e))
+
+ # Separates the returned Operations and Tensors.
+ output_operations = [o for o in outputs if isinstance(o, ops.Operation)]
+ output_tensors = [o for o in outputs if not isinstance(o, ops.Operation)]
+
+ if outputs != output_tensors + output_operations:
+ raise ValueError(
+ 'XLA computation function must return zero or more Tensor values '
+ 'followed by zero or more Operations.')
+ output_arity = len(output_tensors)
+
+ new_output_tensors = []
+ for t in output_tensors:
+ with ops.device(t.device if t.device else ''):
+ new_output_tensors.append(array_ops.identity(t))
+
+ output_tensors = new_output_tensors
+ context.ExitResult(output_tensors)
+ finally:
+ context.report_unsupported_operations()
+ context.Exit()
+
+ outputs = [
+ xla_ops.xla_cluster_output(output_tensors[i], name='output{}'.format(i))
+ for i in xrange(output_arity)
+ ]
+
+ with ops.control_dependencies(output_operations):
+ if output_arity == 0:
+ # When XLA computation returns only operations and no tensors, a NoOp
+ # dependent on the operations in outputs is returned. Otherwise final
+ # outputs would be empty and there is no way to trigger returned
+ # operations.
+ return control_flow_ops.no_op(name='output_0')
+ else:
+ # Wraps the outputs in identity operators that carries control
+ # dependencies.
+ return [
+ array_ops.identity(outputs[i], name='output_%d' % i)
+ for i in xrange(output_arity)
+ ]