diff options
Diffstat (limited to 'tensorflow/compiler/tf2xla/xla_op_kernel.h')
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_op_kernel.h | 49 |
1 files changed, 35 insertions, 14 deletions
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index 667dc262ca..71990b57d9 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -18,6 +18,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/platform/macros.h" @@ -66,16 +68,26 @@ class XlaOpKernelContext { // Returns the number of inputs to the operator. int num_inputs() const { return context_->num_inputs(); } - // Returns the type of input 'index'. - DataType input_type(int index) { return context_->input(index).dtype(); } + // Returns the type of input `index`. + DataType input_type(int index) const; - // Returns the shape of input 'index'. + // Returns the type of input `index` as an xla::PrimitiveType. If the type + // is not representable as an XLA type, sets an error status and returns + // xla::PRIMITIVE_TYPE_INVALID. + xla::PrimitiveType input_xla_type(int index); + + // Returns the shape of input `index`. TensorShape InputShape(int index); - // Returns input 'index' as a XlaOp. Unlike + // Returns the shape of input `name`. + TensorShape InputShape(StringPiece name); + + // Returns input `index` as a XlaOp. Unlike // OpKernelContext::Input returns a symbolic value rather than a concrete // Tensor. const xla::XlaOp& Input(int index); + // Returns input `name` as a XlaOp. + const xla::XlaOp& Input(StringPiece name); // Returns true if all inputs are the same shape, otherwise sets the // status to a non-OK value and returns false. @@ -90,13 +102,13 @@ class XlaOpKernelContext { // Helper methods for constant inputs. - // Evaluates input 'index' and stores it in '*constant_literal'. If the + // Evaluates input `index` and stores it in `*constant_literal`. If the // expression cannot be evaluated, e.g., because it depends on unbound // parameters, returns a non-OK status. Status ConstantInput(int index, xla::Literal* constant_literal); - // Evaluates input 'index', reshapes it to 'new_shape' if new_shape != - // InputShape(index), and stores it in '*constant_literal'. If the input + // Evaluates input `index`, reshapes it to `new_shape` if new_shape != + // InputShape(index), and stores it in `*constant_literal`. If the input // cannot be evaluated, e.g., because it depends on unbound parameters, // returns a non-Ok status. If InputShape(index).num_elements() != // new_shape.num_elements(), returns an error status. @@ -131,17 +143,17 @@ class XlaOpKernelContext { return context_->expected_output_dtype(index); } - // Sets output 'index' to the XlaOp 'handle'. + // Sets output `index` to the XlaOp `handle`. // All outputs should be set using SetOutput and SetConstantOutput, not // via the underlying OpKernelContext. void SetOutput(int index, const xla::XlaOp& handle); - // Sets output 'index' to compile-time constant 'host_tensor', where - // 'host_tensor' is a tensor in host memory. It is preferable to use + // Sets output `index` to compile-time constant `host_tensor`, where + // `host_tensor` is a tensor in host memory. It is preferable to use // SetConstantOutput where possible. void SetConstantOutput(int index, const Tensor& host_tensor); - // Sets output 'index' to an invalid value. + // Sets output `index` to an invalid value. // Any subsequent attempt to consume this output will cause an error. void SetInvalidOutput(int index); @@ -151,10 +163,10 @@ class XlaOpKernelContext { // Variables - // Sets '*resource' to the resource associated with input `index`. + // Sets `*resource` to the resource associated with input `index`. Status GetResourceInput(int index, XlaResource** resource); - // Sets output 'index' to be a reference to resource 'resource'. + // Sets output `index` to be a reference to resource `resource`. void SetResourceOutput(int index, XlaResource* resource); // Sets `*type` and `*shape` to the current type and shape of a variable's @@ -163,17 +175,23 @@ class XlaOpKernelContext { TensorShape* shape) const; // Reads the current value of the resouce variable referred to by input - // 'index'. If `shape` is not nullptr, sets `*shape` to the shape of the + // `index`. If `shape` is not nullptr, sets `*shape` to the shape of the // variable. Returns an error if the variable has not been initialized, or if // its type does not match `type`. Status ReadVariableInput(int index, DataType type, TensorShape* shape, xla::XlaOp* value); + // Reads the current value of the resouce variable referred to by input + // `name`. + Status ReadVariableInput(StringPiece name, DataType type, TensorShape* shape, + xla::XlaOp* value); // Assigns the value `handle` to the variable referenced by input // `input_index`. The variable must be of `type`. Returns an error if the // variable has been initialized with a different type or with a // different shape. Status AssignVariable(int input_index, DataType type, xla::XlaOp handle); + // Assigns the value `handle` to the variable referenced by input `name`. + Status AssignVariable(StringPiece name, DataType type, xla::XlaOp handle); // Helper routines for the OP_REQUIRES macros void CtxFailure(const Status& s); @@ -221,6 +239,9 @@ class XlaOpKernelContext { const xla::XlaComputation* GetOrCreateMul(const DataType type); private: + // Returns the tensor of input `name`. + const Tensor& GetInputTensorByName(StringPiece name); + OpKernelContext* const context_; }; |