diff options
Diffstat (limited to 'tensorflow/compiler/tf2xla/xla_compiler.h')
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_compiler.h | 23 |
1 files changed, 23 insertions, 0 deletions
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index 8f4a9858ed..2cc603a580 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_ #define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_ +#include <stack> + #include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h" #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -26,6 +28,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/function.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/notification.h" @@ -106,6 +109,9 @@ class XlaCompiler { // Argument is a run-time parameter. kParameter, + + // Argument is an XLA token. + kToken, }; Kind kind = kInvalid; @@ -179,6 +185,9 @@ class XlaCompiler { // True when compiling the entry computation, false for subcomputations // (while, call, etc.) bool is_entry_computation = true; + + // True when we should add XLA input & output to the graph/function. + bool add_token_input_output = false; }; struct OutputDescription { @@ -384,6 +393,11 @@ class XlaCompiler { xla::Client* client() const { return options_.client; } FunctionLibraryRuntime* flib_runtime() const { return flib_runtime_; } + void PushNodeTokenMapping(); + Status PopNodeTokenMapping(); + Status SetNodeToken(const string& node_name, const xla::XlaOp& op); + xla::StatusOr<xla::XlaOp> GetNodeToken(const string& node_name); + private: // Sets the function body `fbody` to the one registered as `function`. Status FindFunctionBody(const NameAttrList& function, @@ -448,6 +462,15 @@ class XlaCompiler { std::unordered_map<string, xla::XlaOp> host_compute_control_output_; + // This is used to store <node name, token output> mapping. Side-effecting + // ops call SetNodeToken() to record its token output, so later side-effecting + // ops can use GetNodeToken() to get it and use it as token input. + // + // It's a stack because we need a mapping like this for each level of nested + // CompileGraph() call. In CompileGraph(), we will push a new mapping to the + // stack, and pop the mapping before returning. + std::stack<std::map<string, xla::XlaOp>> node_token_mapping_stack_; + TF_DISALLOW_COPY_AND_ASSIGN(XlaCompiler); }; |