Diffstat (limited to 'tensorflow/compiler/tf2xla/xla_op_kernel.h')
-rw-r--r--  tensorflow/compiler/tf2xla/xla_op_kernel.h | 174
1 file changed, 174 insertions(+), 0 deletions(-)
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h
new file mode 100644
index 0000000000..0c614005be
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h
@@ -0,0 +1,174 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_
+#define TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_
+
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/platform/macros.h"
+
+namespace tensorflow {
+
+class XlaOpKernelContext;
+
+// Implementations of operators that generate XLA code should usually subclass
+// XlaOpKernel and implement the Compile() method. Unlike a regular OpKernel,
+// an XlaOpKernel produces and consumes symbolic values during compilation.
+//
+// See the comments in xla_context.h for more details.
+class XlaOpKernel : public OpKernel {
+ public:
+ explicit XlaOpKernel(OpKernelConstruction* construction);
+
+ // Subclasses should implement Compile(), much as standard OpKernels implement
+ // Compute().
+ virtual void Compile(XlaOpKernelContext* context) = 0;
+
+ private:
+ void Compute(OpKernelContext* context) final;
+};
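+
+// Example: a minimal sketch of an XlaOpKernel subclass (hypothetical NegOp;
+// Neg is assumed to be available on xla::ComputationBuilder, and kernel
+// registration is omitted):
+//
+//   class NegOp : public XlaOpKernel {
+//    public:
+//     explicit NegOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+//
+//     void Compile(XlaOpKernelContext* ctx) override {
+//       // Consumes a symbolic input and produces a symbolic output; no
+//       // concrete Tensor data is read at compile time.
+//       ctx->SetOutput(0, ctx->builder()->Neg(ctx->Input(0)));
+//     }
+//   };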
+
+// The context passed to the Compile() method of XlaOpKernel. An
+// XlaOpKernelContext is a variant of the standard OpKernelContext class,
+// tailored for implementing operators that perform symbolic execution as
+// part of the XLA compiler. The key difference is that XlaOpKernelContext
+// produces and consumes data as XLA computations rather than as standard
+// Tensors. (Under the hood,
+// symbolic execution communicates using special Tensors, but that is an
+// implementation detail that this class hides.)
+class XlaOpKernelContext {
+ public:
+ explicit XlaOpKernelContext(OpKernelContext* context);
+
+ // Returns the XLA ComputationBuilder containing the output of compilation.
+ xla::ComputationBuilder* builder() const;
+
+ // Inputs
+
+ // Returns the number of inputs to the operator.
+ int num_inputs() const { return context_->num_inputs(); }
+
+ // Returns the type of input 'index'.
+ DataType input_type(int index) { return context_->input(index).dtype(); }
+
+ // Returns the shape of input 'index'.
+ TensorShape InputShape(int index);
+
+  // Returns input 'index' as a ComputationDataHandle. Unlike
+  // OpKernelContext::Input, this returns a symbolic value rather than a
+  // concrete Tensor.
+ const xla::ComputationDataHandle& Input(int index);
+
+ // Returns true if all inputs are the same shape, otherwise sets the
+ // status to a non-OK value and returns false.
+ // Usage: if (!context->ValidateInputsAreSameShape(this)) return;
+ bool ValidateInputsAreSameShape(OpKernel* op) TF_MUST_USE_RESULT;
+
+  // Returns the handles and shapes of the named list-valued immutable input
+  // in 'handles' and 'shapes', as defined in the OpDef. If the named input
+  // is not list-valued, returns a one-element list.
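+  //
+  // Example (hypothetical usage from inside Compile(), for an op whose OpDef
+  // declares a list input named "values"):
+  //   std::vector<xla::ComputationDataHandle> handles;
+  //   std::vector<TensorShape> shapes;
+  //   OP_REQUIRES_OK(ctx, ctx->InputList("values", &handles, &shapes));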
+ Status InputList(StringPiece name,
+ std::vector<xla::ComputationDataHandle>* handles,
+ std::vector<TensorShape>* shapes);
+
+ // Helper methods for constant inputs.
+
+ // Evaluates input 'index' and stores it in '*constant_literal'. If the
+ // expression cannot be evaluated, e.g., because it depends on unbound
+ // parameters, returns a non-OK status.
+ Status ConstantInput(int index, xla::Literal* constant_literal);
+
+ // Evaluates input 'index', reshapes it to 'new_shape' if new_shape !=
+ // InputShape(index), and stores it in '*constant_literal'. If the input
+ // cannot be evaluated, e.g., because it depends on unbound parameters,
+  // returns a non-OK status. If InputShape(index).num_elements() !=
+ // new_shape.num_elements(), returns an error status.
+ Status ConstantInputReshaped(int index, gtl::ArraySlice<int64> new_shape,
+ xla::Literal* constant_literal);
+
+ // Converts a constant 1D int32 or int64 tensor into a vector of int64s.
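+  //
+  // Example (hypothetical: an op whose operand 1 carries axis indices that
+  // must be known at compile time):
+  //   std::vector<int64> axes;
+  //   OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &axes));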
+ Status ConstantInputAsIntVector(int index, std::vector<int64>* out);
+
+ // Converts a constant 1D int32 or int64 tensor into a TensorShape.
+ Status ConstantInputAsShape(int index, TensorShape* shape);
+
+  // Returns the named list-valued immutable input in 'literals', as
+  // defined in the OpDef. If the named input is not list-valued,
+  // returns a one-element list.
+ Status ConstantInputList(StringPiece name,
+ std::vector<xla::Literal>* literals);
+
+ // Outputs
+
+ int num_outputs() const { return context_->num_outputs(); }
+ DataType expected_output_dtype(int index) const {
+ return context_->expected_output_dtype(index);
+ }
+
+ // Sets output 'index' to the ComputationDataHandle 'handle'.
+ // All outputs should be set using SetOutput and SetConstantOutput, not
+ // via the underlying OpKernelContext.
+ void SetOutput(int index, const xla::ComputationDataHandle& handle);
+
+  // Sets output 'index' to the compile-time constant 'host_tensor', where
+  // 'host_tensor' is a tensor in host memory. Prefer SetConstantOutput over
+  // SetOutput whenever a value is known at compile time.
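+  //
+  // Example (hypothetical: emitting a scalar constant output):
+  //   Tensor zero(DT_FLOAT, TensorShape({}));
+  //   zero.scalar<float>()() = 0.0f;
+  //   ctx->SetConstantOutput(0, zero);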
+ void SetConstantOutput(int index, const Tensor& host_tensor);
+
+ // Status handling.
+ void SetStatus(const Status& status) { context_->SetStatus(status); }
+ Status status() { return context_->status(); }
+
+ // Helper routines for the OP_REQUIRES macros
+ void CtxFailure(Status s);
+ void CtxFailureWithWarning(Status s);
+
+ // If this kernel invocation is within a function execution,
+ // call_frame() returns the call frame for the function call.
+ FunctionCallFrame* call_frame() const { return context_->call_frame(); }
+
+ FunctionLibraryRuntime* function_library() const {
+ return context_->function_library();
+ }
+
+ const OpKernel& op_kernel() const { return context_->op_kernel(); }
+
+ // Returns the underlying OpKernelContext. Use rarely.
+ OpKernelContext* op_kernel_context() const { return context_; }
+
+ // TODO(phawkins): find a better home for these helpers.
+
+  // Gets an XLA lambda to compute Max. This is cached in the
+ // XlaContext since it may be used by multiple Ops. There is a
+ // separate specialization of the computation for each DataType.
+ const xla::Computation* GetOrCreateMax(const DataType type);
+
+  // Gets an XLA lambda to compute Add. This is cached in the
+ // XlaContext since it may be used by multiple Ops. There is a
+ // separate specialization of the computation for each DataType.
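+  //
+  // Example (hypothetical: a full sum reduction over dimension 0; ConstantR0
+  // and Reduce are assumed from the xla::ComputationBuilder client API):
+  //   const xla::Computation* add = ctx->GetOrCreateAdd(ctx->input_type(0));
+  //   auto zero = ctx->builder()->ConstantR0<float>(0.0f);
+  //   auto sum = ctx->builder()->Reduce(ctx->Input(0), zero, *add, {0});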
+ const xla::Computation* GetOrCreateAdd(const DataType type);
+
+  // Gets an XLA lambda to compute Sigmoid. This is cached in the
+ // XlaContext since it may be used by multiple Ops. There is a
+ // separate specialization of the computation for each DataType.
+ const xla::Computation* GetOrCreateSigmoid(const DataType type);
+
+ private:
+ OpKernelContext* const context_;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_