aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/compiler
diff options
context:
space:
mode:
authorGravatar Yanan Cao <ycao@google.com>2018-08-27 13:48:19 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-27 13:58:07 -0700
commit59f3c57182fac4d745bb01f3976bb9832c06333d (patch)
tree82c703daccdf80cb7cfad8cef90093f091243ed5 /tensorflow/contrib/compiler
parent0f1b3bcf48eaaca4dccdf2d3208b0305b1c6056b (diff)
[TF/XLA] Add XLACompileContext that marks ops inside for XLA compilation.
PiperOrigin-RevId: 210424333
Diffstat (limited to 'tensorflow/contrib/compiler')
-rw-r--r--tensorflow/contrib/compiler/BUILD34
-rw-r--r--tensorflow/contrib/compiler/xla.py208
-rw-r--r--tensorflow/contrib/compiler/xla_test.py180
3 files changed, 422 insertions, 0 deletions
diff --git a/tensorflow/contrib/compiler/BUILD b/tensorflow/contrib/compiler/BUILD
index bcee0b04c8..d7583be6d8 100644
--- a/tensorflow/contrib/compiler/BUILD
+++ b/tensorflow/contrib/compiler/BUILD
@@ -8,6 +8,7 @@ package_group(
packages = ["//tensorflow/..."],
)
+load("//tensorflow:tensorflow.bzl", "tf_py_test")
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
py_library(
@@ -46,3 +47,36 @@ cuda_py_test(
],
xla_enabled = True,
)
+
+py_library(
+ name = "xla",
+ srcs = ["xla.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:util",
+ "//tensorflow/python/estimator:model_fn",
+ ],
+)
+
+tf_py_test(
+ name = "xla_test",
+ srcs = ["xla_test.py"],
+ additional_deps = [
+ ":xla",
+ "@six_archive//:six",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:control_flow_util",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:state_ops",
+ "//tensorflow/python:summary",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variable_scope",
+ ],
+ tags = ["no_pip"],
+)
diff --git a/tensorflow/contrib/compiler/xla.py b/tensorflow/contrib/compiler/xla.py
new file mode 100644
index 0000000000..60f5af1662
--- /dev/null
+++ b/tensorflow/contrib/compiler/xla.py
@@ -0,0 +1,208 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""xla provides experimental xla support API."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from six.moves import xrange # pylint: disable=redefined-builtin
+
+from tensorflow.core.framework import attr_value_pb2
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util import compat
+
+_XLA_COMPILE_ATTR = '_xla_compile_id'
+_MAX_WARNING_LINES = 5
+
+# Operations that indicate some error in the users graph. For example, XLA
+# computation should not have any Placeholder op.
+_BLACKLISTED_OPS = set([
+ 'Placeholder',
+])
+
+# XLA doesn't currently support reading of intermediate tensors, thus some ops
+# are not supported.
+_UNSUPPORTED_OPS = set([
+ 'AudioSummary',
+ 'AudioSummaryV2',
+ 'HistogramSummary',
+ 'ImageSummary',
+ 'MergeSummary',
+ 'Print',
+ 'ScalarSummary',
+ 'TensorSummary',
+ 'TensorSummaryV2',
+])
+
+
+class XLACompileContext(control_flow_ops.XLAControlFlowContext):
+ """A `ControlFlowContext` for nodes inside an XLA computation cluster.
+
+ THIS IS ONLY FOR TENSORFLOW INTERNAL IMPLEMENTATION, DO NO USE DIRECTLY.
+
+ The primary role of `XLACompileContext` is to mark operators inside a
+ xla.compile() computation with attribute "_xla_compile_id=XYZ", where XYZ is
+ a unique name.
+
+ `ControlFlowContext` is used to perform the annotation since it integrates
+ with Tensorflow constructs like ResourceVariables. For example, if a
+ `ResourceVariable` is constructed inside a xla.compile() block, the
+ `ResourceVariable` implementation can use
+ `with ops.control_dependencies(None)` to build the variable's definition
+ outside the compiled computation.
+ """
+
+ def __init__(self, name, pivot):
+ """Builds a new XLACompileContext.
+
+ Args:
+ name: a unique name for the context, used to populate the
+ `_xla_compile_id` attribute.
+ pivot: a pivot node. Nodes in the XLACompileContext that do not have any
+ inputs will have a control dependency on the pivot node. This ensures
+ that nodes are correctly included in any enclosing control flow
+ contexts.
+ """
+ super(XLACompileContext, self).__init__()
+ self._name = name
+ self._name_as_bytes = compat.as_bytes(name)
+ self._unsupported_ops = []
+ self._pivot = pivot
+
+ def report_unsupported_operations(self):
+ if self._unsupported_ops:
+ op_str = '\n'.join([
+ ' %s (%s)' % (op.type, op.name)
+ for op in self._unsupported_ops[:_MAX_WARNING_LINES]
+ ])
+ logging.warning('%d unsupported operations found: \n%s',
+ len(self._unsupported_ops), op_str)
+ if len(self._unsupported_ops) > _MAX_WARNING_LINES:
+ logging.warning('... and %d more',
+ len(self._unsupported_ops) - _MAX_WARNING_LINES)
+
+ def AddOp(self, op):
+ """Create op in XLACompileContext and notifies outer context recursively."""
+ # pylint: disable=protected-access
+ if op.type in _BLACKLISTED_OPS:
+ logging.error(
+ 'Operation of type %s (%s) is not supported in XLA. Execution will '
+ 'fail if this op is used in the graph. ', op.type, op.name)
+
+ # TODO(ycao): Automatically disable summaries instead of reporting them.
+ if op.type in _UNSUPPORTED_OPS:
+ self._unsupported_ops.append(op)
+
+ if any(x.dtype._is_ref_dtype for x in op.inputs):
+ raise NotImplementedError(
+ 'Non-resource Variables are not supported inside XLA computations '
+ '(operator name: %s)' % op.name)
+
+ if _XLA_COMPILE_ATTR in op.node_def.attr:
+ raise ValueError('XLA compiled computations cannot be nested, (operator '
+ 'name: %s)' % op.name)
+
+ op._set_attr(
+ _XLA_COMPILE_ATTR, attr_value_pb2.AttrValue(s=self._name_as_bytes))
+
+ op.graph.prevent_feeding(op)
+ op.graph.prevent_fetching(op)
+
+ # Remove any control edges from outer control flow contexts. These may cause
+ # mismatched frame errors. An example is when one of op's inputs is
+ # generated in a different While control flow context.
+ (internal_control_inputs,
+ external_control_inputs) = self._RemoveExternalControlEdges(op)
+
+ if not op.inputs:
+ # Add a control edge from the control pivot to this op.
+ if not internal_control_inputs:
+ # pylint: disable=protected-access
+ op._add_control_input(self._pivot)
+ # pylint: enable=protected-access
+ else:
+ for index in xrange(len(op.inputs)):
+ x = op.inputs[index]
+ real_x = self.AddValue(x)
+ if real_x != x:
+ op._update_input(index, real_x) # pylint: disable=protected-access
+
+ if external_control_inputs:
+ # Use an identity to pull control inputs as data inputs. Note that we
+ # ignore ops which don't have outputs. TODO(phawkins): fix that.
+ with ops.control_dependencies(None):
+ self.Enter()
+ external_control_inputs = [
+ array_ops.identity(x.outputs[0]).op
+ for x in external_control_inputs
+ if x.outputs
+ ]
+ self.Exit()
+ # pylint: disable=protected-access
+ op._add_control_inputs(external_control_inputs)
+ # pylint: enable=protected-access
+
+ # Mark op's outputs as seen by this context and any outer contexts.
+ output_names = [x.name for x in op.outputs]
+ context = self
+ while context is not None:
+ # pylint: disable=protected-access
+ context._values.update(output_names)
+ context = context._outer_context
+ # pylint: enable=protected-access
+
+ if self._outer_context:
+ self._outer_context.AddInnerOp(op)
+
+ def AddValue(self, val):
+ """Add `val` to the current context and its outer context recursively."""
+ if val.name in self._values:
+ # Use the real value if it comes from outer context.
+ result = self._external_values.get(val.name)
+ return val if result is None else result
+
+ result = val
+ self._values.add(val.name)
+ if self._outer_context:
+ result = self._outer_context.AddValue(val)
+ self._values.add(result.name)
+
+ self._external_values[val.name] = result
+
+ return result
+
+ def AddInnerOp(self, op):
+ self.AddOp(op)
+ if self._outer_context:
+ self._outer_context.AddInnerOp(op)
+
+ @property
+ def grad_state(self):
+ # Define the gradient loop state associated with the XLACompileContext to
+ # be None as the XLACompileContext does not get nested nor does the
+ # grad_state outside the XLACompileContext affect the graph inside so the
+ # grad_state should be as if this is the top-level gradient state.
+ return None
+
+ @property
+ def back_prop(self):
+ """Forwards to the enclosing while context, if any."""
+ if self.GetWhileContext():
+ return self.GetWhileContext().back_prop
+ return False
diff --git a/tensorflow/contrib/compiler/xla_test.py b/tensorflow/contrib/compiler/xla_test.py
new file mode 100644
index 0000000000..a306b56f63
--- /dev/null
+++ b/tensorflow/contrib/compiler/xla_test.py
@@ -0,0 +1,180 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Tests for contrib.compiler.xla."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.compiler import xla
+from tensorflow.python import summary
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import control_flow_util
+from tensorflow.python.ops import logging_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import summary_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.platform import test
+
+
+class XLACompileContextTest(test.TestCase):
+
+ def create_test_xla_compile_context(self):
+ computation_name = ops.get_default_graph().unique_name('computation')
+ pivot = control_flow_ops.no_op(name=computation_name + '/pivot')
+ return xla.XLACompileContext(name=computation_name, pivot=pivot)
+
+ def test_report_unsupported_operations(self):
+ """Tests that unsupported operations are detected."""
+ context = self.create_test_xla_compile_context()
+ context.Enter()
+ dummy_tensor = constant_op.constant(1.1)
+ audio_summary = summary.audio('audio_summary', dummy_tensor, 0.5)
+ histogram_summary = summary.histogram('histogram_summary', dummy_tensor)
+ image_summary = summary.image('image_summary', dummy_tensor)
+ scalar_summary = summary.scalar('scalar_summary', dummy_tensor)
+ tensor_summary = summary_ops.tensor_summary('tensor_summary', dummy_tensor)
+ summary.merge(
+ [
+ audio_summary, histogram_summary, image_summary, scalar_summary,
+ tensor_summary
+ ],
+ name='merge_summary')
+ logging_ops.Print(dummy_tensor, [dummy_tensor], name='print_op')
+ context.Exit()
+
+ unsupported_ops_names = [op.name for op in context._unsupported_ops]
+ self.assertEqual(unsupported_ops_names, [
+ u'audio_summary', u'histogram_summary', u'image_summary',
+ u'scalar_summary', u'tensor_summary', u'merge_summary/merge_summary',
+ u'print_op'
+ ])
+
+ def test_resource_variable(self):
+ """Tests that resource variable usage is allowed."""
+ a = variable_scope.get_variable(
+ name='variable_a', shape=(1), use_resource=True)
+
+ context = self.create_test_xla_compile_context()
+ context.Enter()
+ state_ops.assign(a, a + 1)
+ context.Exit()
+
+ def test_non_resource_variable_error(self):
+ """Tests that non-resource variable usage is disallowed."""
+ a = variable_scope.get_variable(
+ name='variable_a', shape=(1), use_resource=False)
+
+ context = self.create_test_xla_compile_context()
+ context.Enter()
+ with self.assertRaisesRegexp(
+ NotImplementedError, 'Non-resource Variables are not supported inside '
+ r'XLA computations \(operator name: Assign\)'):
+ state_ops.assign(a, a + 1)
+ context.Exit()
+
+ def test_nested_xla_compile_error(self):
+ """Tests that nested XLA computation leads to fatal error."""
+ context1 = self.create_test_xla_compile_context()
+ context1.Enter()
+
+ context2 = self.create_test_xla_compile_context()
+ context2.Enter()
+ with self.assertRaisesRegexp(ValueError,
+ 'XLA compiled computations cannot be nested'):
+ constant_op.constant(1)
+ context2.Exit()
+ context1.Exit()
+
+ def test_xla_compile_attr(self):
+ """Tests that ops are tagged with XLA compile ID attribute."""
+ context = self.create_test_xla_compile_context()
+ context.Enter()
+ op = constant_op.constant(1)
+ context.Exit()
+ self.assertIn('_xla_compile_id', op.op.node_def.attr)
+
+ def test_op_without_input(self):
+ """Tests that ops without inputs depend on pivot correctly."""
+ context = self.create_test_xla_compile_context()
+ context.Enter()
+ op = constant_op.constant(1)
+ context.Exit()
+
+ self.assertIn(context._pivot, op.op.control_inputs)
+
+ def test_external_control_edges(self):
+ """Tests that external control edges are handled correctly."""
+ i = constant_op.constant(1)
+ op1 = constant_op.constant(1)
+
+ with ops.control_dependencies([op1]):
+ op2 = constant_op.constant(1)
+ self.assertIn(op1.op, op2.op.control_inputs)
+
+ def while_body(i):
+ del i # unused
+ context = self.create_test_xla_compile_context()
+ context.Enter()
+ with ops.control_dependencies([op1]):
+ op3 = constant_op.constant(1)
+ context.Exit()
+ self.assertNotIn(op1.op, op3.op.control_inputs)
+ return op3
+
+ control_flow_ops.while_loop(
+ cond=lambda i: math_ops.less(i, 10), body=while_body, loop_vars=[i])
+
+ def test_op_output_marked_as_seen(self):
+ """Tests that any op output is marked as seen in context."""
+ context = self.create_test_xla_compile_context()
+ context.Enter()
+ op = constant_op.constant(1)
+ context.Exit()
+
+ self.assertIn(op.name, context._values)
+
+ def testOpIsInContext(self):
+ """Tests that XLACompileContext is recognized as an XLA context."""
+ op1 = constant_op.constant(1)
+ context = self.create_test_xla_compile_context()
+ context.Enter()
+ op2 = constant_op.constant(2)
+ context.Exit()
+ self.assertFalse(control_flow_util.IsInXLAContext(op1.op))
+ self.assertTrue(control_flow_util.IsInXLAContext(op2.op))
+
+ def testOpPreventFeeding(self):
+ """Tests that ops created inside XLACompileContext can not be fed."""
+ context = self.create_test_xla_compile_context()
+ context.Enter()
+ op = constant_op.constant(1)
+ context.Exit()
+ self.assertFalse(op.graph.is_feedable(op.op))
+
+ def testOpPreventFetching(self):
+ """Tests that ops created inside XLACompileContext can not be fetched."""
+ context = self.create_test_xla_compile_context()
+ context.Enter()
+ op = constant_op.constant(1)
+ context.Exit()
+ self.assertFalse(op.graph.is_fetchable(op.op))
+
+
+if __name__ == '__main__':
+ test.main()