aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/xla_context.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tf2xla/xla_context.h')
-rw-r--r--tensorflow/compiler/tf2xla/xla_context.h44
1 files changed, 14 insertions, 30 deletions
diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h
index 657ead5391..3978baaf63 100644
--- a/tensorflow/compiler/tf2xla/xla_context.h
+++ b/tensorflow/compiler/tf2xla/xla_context.h
@@ -21,7 +21,6 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
-#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -31,6 +30,8 @@ limitations under the License.
namespace tensorflow {
+class XlaOpKernelContext;
+
// The XlaContext is the data structure that holds the state of an XLA
// compilation, that is accessible from OpKernelContexts when compiling a
// subgraph of Ops using XLA.
@@ -55,16 +56,16 @@ class XlaContext : public ResourceBase {
string name;
// Is this a variable?
- bool is_variable;
+ bool is_variable = false;
HandleOrConstant value;
+
+ int64 tensor_array_size = -1;
};
// Retrieves the XlaContext of the current compilation.
static XlaContext& Get(const OpKernelContext* ctx);
- static XlaContext& Get(const XlaOpKernelContext* ctx) {
- return Get(ctx->op_kernel_context());
- }
+ static XlaContext& Get(const XlaOpKernelContext* ctx);
// Creates a new XlaContext.
XlaContext(XlaCompiler* compiler, xla::ComputationBuilder* builder,
@@ -105,33 +106,16 @@ class XlaContext : public ResourceBase {
bool has_side_effects() const { return has_side_effects_; }
- struct Variable {
- // A descriptive name for the variable, used in error messages.
- string name;
-
- // Current type and value of the variable. Uninitialized variables are
- // represented by a default (zero) handle and type DT_INVALID.
- // While the type of a variable is notionally fixed during execution, when
- // a variable is first initialized we do not yet know its type, so we keep
- // track of its type dynamically.
- DataType type = DT_INVALID;
- xla::ComputationDataHandle value;
-
- // Value of the variable at computation entry. Used to detect which
- // variables have new values that need to be written back.
- xla::ComputationDataHandle initial_value;
- };
-
// Creates a variable with variable `variable_id` and initial type `type` and
// value `handle`. `name` is a descriptive name for use in error messages.
// Fails if the variable already exists.
- Status CreateVariable(int variable_id, string name, DataType type,
- const xla::ComputationDataHandle& handle);
+ Status CreateVariable(int arg_num, string name, DataType type,
+ const xla::ComputationDataHandle& handle,
+ XlaVariable** variable);
- // Retrieves variable `variable_id`. Fails if the variable does not exist.
- Status GetVariable(int variable_id, Variable** variable);
-
- const std::unordered_map<int, Variable>& variables() { return variables_; }
+ const std::vector<std::unique_ptr<XlaVariable>>& variables() {
+ return variables_;
+ }
// Get an XLA lambda to compute Max. This is cached in the
// XlaContext since it may be used by multiple Ops. There is a
@@ -182,8 +166,8 @@ class XlaContext : public ResourceBase {
// Does the computation have side effects, i.e., Send() calls?
bool has_side_effects_ = false;
- // Map from variable ID to the current value of each variable.
- std::unordered_map<int, Variable> variables_;
+ // Holds ownership of variables. The variables are not ordered.
+ std::vector<std::unique_ptr<XlaVariable>> variables_;
// Cache of prebuilt computations indexed by their type.
using ComputationMap = std::map<DataType, xla::Computation>;