diff options
author | Yanan Cao <ycao@google.com> | 2018-08-27 13:48:19 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-27 13:58:07 -0700 |
commit | 59f3c57182fac4d745bb01f3976bb9832c06333d (patch) | |
tree | 82c703daccdf80cb7cfad8cef90093f091243ed5 /tensorflow/contrib/compiler | |
parent | 0f1b3bcf48eaaca4dccdf2d3208b0305b1c6056b (diff) |
[TF/XLA] Add XLACompileContext that marks ops inside for XLA compilation.
PiperOrigin-RevId: 210424333
Diffstat (limited to 'tensorflow/contrib/compiler')
-rw-r--r-- | tensorflow/contrib/compiler/BUILD | 34 | ||||
-rw-r--r-- | tensorflow/contrib/compiler/xla.py | 208 | ||||
-rw-r--r-- | tensorflow/contrib/compiler/xla_test.py | 180 |
3 files changed, 422 insertions, 0 deletions
diff --git a/tensorflow/contrib/compiler/BUILD b/tensorflow/contrib/compiler/BUILD index bcee0b04c8..d7583be6d8 100644 --- a/tensorflow/contrib/compiler/BUILD +++ b/tensorflow/contrib/compiler/BUILD @@ -8,6 +8,7 @@ package_group( packages = ["//tensorflow/..."], ) +load("//tensorflow:tensorflow.bzl", "tf_py_test") load("//tensorflow:tensorflow.bzl", "cuda_py_test") py_library( @@ -46,3 +47,36 @@ cuda_py_test( ], xla_enabled = True, ) + +py_library( + name = "xla", + srcs = ["xla.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:platform", + "//tensorflow/python:util", + "//tensorflow/python/estimator:model_fn", + ], +) + +tf_py_test( + name = "xla_test", + srcs = ["xla_test.py"], + additional_deps = [ + ":xla", + "@six_archive//:six", + "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:control_flow_util", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform", + "//tensorflow/python:state_ops", + "//tensorflow/python:summary", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + ], + tags = ["no_pip"], +) diff --git a/tensorflow/contrib/compiler/xla.py b/tensorflow/contrib/compiler/xla.py new file mode 100644 index 0000000000..60f5af1662 --- /dev/null +++ b/tensorflow/contrib/compiler/xla.py @@ -0,0 +1,208 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""xla provides experimental xla support API.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from six.moves import xrange # pylint: disable=redefined-builtin + +from tensorflow.core.framework import attr_value_pb2 +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util import compat + +_XLA_COMPILE_ATTR = '_xla_compile_id' +_MAX_WARNING_LINES = 5 + +# Operations that indicate some error in the users graph. For example, XLA +# computation should not have any Placeholder op. +_BLACKLISTED_OPS = set([ + 'Placeholder', +]) + +# XLA doesn't currently support reading of intermediate tensors, thus some ops +# are not supported. +_UNSUPPORTED_OPS = set([ + 'AudioSummary', + 'AudioSummaryV2', + 'HistogramSummary', + 'ImageSummary', + 'MergeSummary', + 'Print', + 'ScalarSummary', + 'TensorSummary', + 'TensorSummaryV2', +]) + + +class XLACompileContext(control_flow_ops.XLAControlFlowContext): + """A `ControlFlowContext` for nodes inside an XLA computation cluster. + + THIS IS ONLY FOR TENSORFLOW INTERNAL IMPLEMENTATION, DO NO USE DIRECTLY. + + The primary role of `XLACompileContext` is to mark operators inside a + xla.compile() computation with attribute "_xla_compile_id=XYZ", where XYZ is + a unique name. + + `ControlFlowContext` is used to perform the annotation since it integrates + with Tensorflow constructs like ResourceVariables. For example, if a + `ResourceVariable` is constructed inside a xla.compile() block, the + `ResourceVariable` implementation can use + `with ops.control_dependencies(None)` to build the variable's definition + outside the compiled computation. + """ + + def __init__(self, name, pivot): + """Builds a new XLACompileContext. + + Args: + name: a unique name for the context, used to populate the + `_xla_compile_id` attribute. + pivot: a pivot node. Nodes in the XLACompileContext that do not have any + inputs will have a control dependency on the pivot node. This ensures + that nodes are correctly included in any enclosing control flow + contexts. + """ + super(XLACompileContext, self).__init__() + self._name = name + self._name_as_bytes = compat.as_bytes(name) + self._unsupported_ops = [] + self._pivot = pivot + + def report_unsupported_operations(self): + if self._unsupported_ops: + op_str = '\n'.join([ + ' %s (%s)' % (op.type, op.name) + for op in self._unsupported_ops[:_MAX_WARNING_LINES] + ]) + logging.warning('%d unsupported operations found: \n%s', + len(self._unsupported_ops), op_str) + if len(self._unsupported_ops) > _MAX_WARNING_LINES: + logging.warning('... and %d more', + len(self._unsupported_ops) - _MAX_WARNING_LINES) + + def AddOp(self, op): + """Create op in XLACompileContext and notifies outer context recursively.""" + # pylint: disable=protected-access + if op.type in _BLACKLISTED_OPS: + logging.error( + 'Operation of type %s (%s) is not supported in XLA. Execution will ' + 'fail if this op is used in the graph. ', op.type, op.name) + + # TODO(ycao): Automatically disable summaries instead of reporting them. + if op.type in _UNSUPPORTED_OPS: + self._unsupported_ops.append(op) + + if any(x.dtype._is_ref_dtype for x in op.inputs): + raise NotImplementedError( + 'Non-resource Variables are not supported inside XLA computations ' + '(operator name: %s)' % op.name) + + if _XLA_COMPILE_ATTR in op.node_def.attr: + raise ValueError('XLA compiled computations cannot be nested, (operator ' + 'name: %s)' % op.name) + + op._set_attr( + _XLA_COMPILE_ATTR, attr_value_pb2.AttrValue(s=self._name_as_bytes)) + + op.graph.prevent_feeding(op) + op.graph.prevent_fetching(op) + + # Remove any control edges from outer control flow contexts. These may cause + # mismatched frame errors. An example is when one of op's inputs is + # generated in a different While control flow context. + (internal_control_inputs, + external_control_inputs) = self._RemoveExternalControlEdges(op) + + if not op.inputs: + # Add a control edge from the control pivot to this op. + if not internal_control_inputs: + # pylint: disable=protected-access + op._add_control_input(self._pivot) + # pylint: enable=protected-access + else: + for index in xrange(len(op.inputs)): + x = op.inputs[index] + real_x = self.AddValue(x) + if real_x != x: + op._update_input(index, real_x) # pylint: disable=protected-access + + if external_control_inputs: + # Use an identity to pull control inputs as data inputs. Note that we + # ignore ops which don't have outputs. TODO(phawkins): fix that. + with ops.control_dependencies(None): + self.Enter() + external_control_inputs = [ + array_ops.identity(x.outputs[0]).op + for x in external_control_inputs + if x.outputs + ] + self.Exit() + # pylint: disable=protected-access + op._add_control_inputs(external_control_inputs) + # pylint: enable=protected-access + + # Mark op's outputs as seen by this context and any outer contexts. + output_names = [x.name for x in op.outputs] + context = self + while context is not None: + # pylint: disable=protected-access + context._values.update(output_names) + context = context._outer_context + # pylint: enable=protected-access + + if self._outer_context: + self._outer_context.AddInnerOp(op) + + def AddValue(self, val): + """Add `val` to the current context and its outer context recursively.""" + if val.name in self._values: + # Use the real value if it comes from outer context. + result = self._external_values.get(val.name) + return val if result is None else result + + result = val + self._values.add(val.name) + if self._outer_context: + result = self._outer_context.AddValue(val) + self._values.add(result.name) + + self._external_values[val.name] = result + + return result + + def AddInnerOp(self, op): + self.AddOp(op) + if self._outer_context: + self._outer_context.AddInnerOp(op) + + @property + def grad_state(self): + # Define the gradient loop state associated with the XLACompileContext to + # be None as the XLACompileContext does not get nested nor does the + # grad_state outside the XLACompileContext affect the graph inside so the + # grad_state should be as if this is the top-level gradient state. + return None + + @property + def back_prop(self): + """Forwards to the enclosing while context, if any.""" + if self.GetWhileContext(): + return self.GetWhileContext().back_prop + return False diff --git a/tensorflow/contrib/compiler/xla_test.py b/tensorflow/contrib/compiler/xla_test.py new file mode 100644 index 0000000000..a306b56f63 --- /dev/null +++ b/tensorflow/contrib/compiler/xla_test.py @@ -0,0 +1,180 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""Tests for contrib.compiler.xla.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.compiler import xla +from tensorflow.python import summary +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import control_flow_util +from tensorflow.python.ops import logging_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import summary_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.platform import test + + +class XLACompileContextTest(test.TestCase): + + def create_test_xla_compile_context(self): + computation_name = ops.get_default_graph().unique_name('computation') + pivot = control_flow_ops.no_op(name=computation_name + '/pivot') + return xla.XLACompileContext(name=computation_name, pivot=pivot) + + def test_report_unsupported_operations(self): + """Tests that unsupported operations are detected.""" + context = self.create_test_xla_compile_context() + context.Enter() + dummy_tensor = constant_op.constant(1.1) + audio_summary = summary.audio('audio_summary', dummy_tensor, 0.5) + histogram_summary = summary.histogram('histogram_summary', dummy_tensor) + image_summary = summary.image('image_summary', dummy_tensor) + scalar_summary = summary.scalar('scalar_summary', dummy_tensor) + tensor_summary = summary_ops.tensor_summary('tensor_summary', dummy_tensor) + summary.merge( + [ + audio_summary, histogram_summary, image_summary, scalar_summary, + tensor_summary + ], + name='merge_summary') + logging_ops.Print(dummy_tensor, [dummy_tensor], name='print_op') + context.Exit() + + unsupported_ops_names = [op.name for op in context._unsupported_ops] + self.assertEqual(unsupported_ops_names, [ + u'audio_summary', u'histogram_summary', u'image_summary', + u'scalar_summary', u'tensor_summary', u'merge_summary/merge_summary', + u'print_op' + ]) + + def test_resource_variable(self): + """Tests that resource variable usage is allowed.""" + a = variable_scope.get_variable( + name='variable_a', shape=(1), use_resource=True) + + context = self.create_test_xla_compile_context() + context.Enter() + state_ops.assign(a, a + 1) + context.Exit() + + def test_non_resource_variable_error(self): + """Tests that non-resource variable usage is disallowed.""" + a = variable_scope.get_variable( + name='variable_a', shape=(1), use_resource=False) + + context = self.create_test_xla_compile_context() + context.Enter() + with self.assertRaisesRegexp( + NotImplementedError, 'Non-resource Variables are not supported inside ' + r'XLA computations \(operator name: Assign\)'): + state_ops.assign(a, a + 1) + context.Exit() + + def test_nested_xla_compile_error(self): + """Tests that nested XLA computation leads to fatal error.""" + context1 = self.create_test_xla_compile_context() + context1.Enter() + + context2 = self.create_test_xla_compile_context() + context2.Enter() + with self.assertRaisesRegexp(ValueError, + 'XLA compiled computations cannot be nested'): + constant_op.constant(1) + context2.Exit() + context1.Exit() + + def test_xla_compile_attr(self): + """Tests that ops are tagged with XLA compile ID attribute.""" + context = self.create_test_xla_compile_context() + context.Enter() + op = constant_op.constant(1) + context.Exit() + self.assertIn('_xla_compile_id', op.op.node_def.attr) + + def test_op_without_input(self): + """Tests that ops without inputs depend on pivot correctly.""" + context = self.create_test_xla_compile_context() + context.Enter() + op = constant_op.constant(1) + context.Exit() + + self.assertIn(context._pivot, op.op.control_inputs) + + def test_external_control_edges(self): + """Tests that external control edges are handled correctly.""" + i = constant_op.constant(1) + op1 = constant_op.constant(1) + + with ops.control_dependencies([op1]): + op2 = constant_op.constant(1) + self.assertIn(op1.op, op2.op.control_inputs) + + def while_body(i): + del i # unused + context = self.create_test_xla_compile_context() + context.Enter() + with ops.control_dependencies([op1]): + op3 = constant_op.constant(1) + context.Exit() + self.assertNotIn(op1.op, op3.op.control_inputs) + return op3 + + control_flow_ops.while_loop( + cond=lambda i: math_ops.less(i, 10), body=while_body, loop_vars=[i]) + + def test_op_output_marked_as_seen(self): + """Tests that any op output is marked as seen in context.""" + context = self.create_test_xla_compile_context() + context.Enter() + op = constant_op.constant(1) + context.Exit() + + self.assertIn(op.name, context._values) + + def testOpIsInContext(self): + """Tests that XLACompileContext is recognized as an XLA context.""" + op1 = constant_op.constant(1) + context = self.create_test_xla_compile_context() + context.Enter() + op2 = constant_op.constant(2) + context.Exit() + self.assertFalse(control_flow_util.IsInXLAContext(op1.op)) + self.assertTrue(control_flow_util.IsInXLAContext(op2.op)) + + def testOpPreventFeeding(self): + """Tests that ops created inside XLACompileContext can not be fed.""" + context = self.create_test_xla_compile_context() + context.Enter() + op = constant_op.constant(1) + context.Exit() + self.assertFalse(op.graph.is_feedable(op.op)) + + def testOpPreventFetching(self): + """Tests that ops created inside XLACompileContext can not be fetched.""" + context = self.create_test_xla_compile_context() + context.Enter() + op = constant_op.constant(1) + context.Exit() + self.assertFalse(op.graph.is_fetchable(op.op)) + + +if __name__ == '__main__': + test.main() |