aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/xla_compiler.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tf2xla/xla_compiler.h')
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler.h23
1 files changed, 23 insertions, 0 deletions
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h
index 8f4a9858ed..2cc603a580 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.h
+++ b/tensorflow/compiler/tf2xla/xla_compiler.h
@@ -16,6 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_
#define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_
+#include <stack>
+
#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@@ -26,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/notification.h"
@@ -106,6 +109,9 @@ class XlaCompiler {
// Argument is a run-time parameter.
kParameter,
+
+ // Argument is an XLA token.
+ kToken,
};
Kind kind = kInvalid;
@@ -179,6 +185,9 @@ class XlaCompiler {
// True when compiling the entry computation, false for subcomputations
// (while, call, etc.)
bool is_entry_computation = true;
+
+ // True when we should add XLA input & output to the graph/function.
+ bool add_token_input_output = false;
};
struct OutputDescription {
@@ -384,6 +393,11 @@ class XlaCompiler {
xla::Client* client() const { return options_.client; }
FunctionLibraryRuntime* flib_runtime() const { return flib_runtime_; }
+ void PushNodeTokenMapping();
+ Status PopNodeTokenMapping();
+ Status SetNodeToken(const string& node_name, const xla::XlaOp& op);
+ xla::StatusOr<xla::XlaOp> GetNodeToken(const string& node_name);
+
private:
// Sets the function body `fbody` to the one registered as `function`.
Status FindFunctionBody(const NameAttrList& function,
@@ -448,6 +462,15 @@ class XlaCompiler {
std::unordered_map<string, xla::XlaOp> host_compute_control_output_;
+ // This is used to store <node name, token output> mapping. Side-effecting
+ // ops call SetNodeToken() to record its token output, so later side-effecting
+ // ops can use GetNodeToken() to get it and use it as token input.
+ //
+ // It's a stack because we need a mapping like this for each level of nested
+ // CompileGraph() call. In CompileGraph(), we will push a new mapping to the
+ // stack, and pop the mapping before returning.
+ std::stack<std::map<string, xla::XlaOp>> node_token_mapping_stack_;
+
TF_DISALLOW_COPY_AND_ASSIGN(XlaCompiler);
};