diff options
Diffstat (limited to 'tensorflow/compiler/tf2xla/xla_context.h')
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_context.h | 44 |
1 files changed, 14 insertions, 30 deletions
diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h index 657ead5391..3978baaf63 100644 --- a/tensorflow/compiler/tf2xla/xla_context.h +++ b/tensorflow/compiler/tf2xla/xla_context.h @@ -21,7 +21,6 @@ limitations under the License. #include <vector> #include "tensorflow/compiler/tf2xla/xla_compiler.h" -#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -31,6 +30,8 @@ limitations under the License. namespace tensorflow { +class XlaOpKernelContext; + // The XlaContext is the data structure that holds the state of an XLA // compilation, that is accessible from OpKernelContexts when compiling a // subgraph of Ops using XLA. @@ -55,16 +56,16 @@ class XlaContext : public ResourceBase { string name; // Is this a variable? - bool is_variable; + bool is_variable = false; HandleOrConstant value; + + int64 tensor_array_size = -1; }; // Retrieves the XlaContext of the current compilation. static XlaContext& Get(const OpKernelContext* ctx); - static XlaContext& Get(const XlaOpKernelContext* ctx) { - return Get(ctx->op_kernel_context()); - } + static XlaContext& Get(const XlaOpKernelContext* ctx); // Creates a new XlaContext. XlaContext(XlaCompiler* compiler, xla::ComputationBuilder* builder, @@ -105,33 +106,16 @@ class XlaContext : public ResourceBase { bool has_side_effects() const { return has_side_effects_; } - struct Variable { - // A descriptive name for the variable, used in error messages. - string name; - - // Current type and value of the variable. Uninitialized variables are - // represented by a default (zero) handle and type DT_INVALID. - // While the type of a variable is notionally fixed during execution, when - // a variable is first initialized we do not yet know its type, so we keep - // track of its type dynamically. - DataType type = DT_INVALID; - xla::ComputationDataHandle value; - - // Value of the variable at computation entry. Used to detect which - // variables have new values that need to be written back. - xla::ComputationDataHandle initial_value; - }; - // Creates a variable with variable `variable_id` and initial type `type` and // value `handle`. `name` is a descriptive name for use in error messages. // Fails if the variable already exists. - Status CreateVariable(int variable_id, string name, DataType type, - const xla::ComputationDataHandle& handle); + Status CreateVariable(int arg_num, string name, DataType type, + const xla::ComputationDataHandle& handle, + XlaVariable** variable); - // Retrieves variable `variable_id`. Fails if the variable does not exist. - Status GetVariable(int variable_id, Variable** variable); - - const std::unordered_map<int, Variable>& variables() { return variables_; } + const std::vector<std::unique_ptr<XlaVariable>>& variables() { + return variables_; + } // Get an XLA lambda to compute Max. This is cached in the // XlaContext since it may be used by multiple Ops. There is a @@ -182,8 +166,8 @@ class XlaContext : public ResourceBase { // Does the computation have side effects, i.e., Send() calls? bool has_side_effects_ = false; - // Map from variable ID to the current value of each variable. - std::unordered_map<int, Variable> variables_; + // Holds ownership of variables. The variables are not ordered. + std::vector<std::unique_ptr<XlaVariable>> variables_; // Cache of prebuilt computations indexed by their type. using ComputationMap = std::map<DataType, xla::Computation>; |