aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/tf2xla_util.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tf2xla/tf2xla_util.h')
-rw-r--r--tensorflow/compiler/tf2xla/tf2xla_util.h63
1 files changed, 62 insertions, 1 deletions
diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.h b/tensorflow/compiler/tf2xla/tf2xla_util.h
index dcddef8418..6065d0bb9a 100644
--- a/tensorflow/compiler/tf2xla/tf2xla_util.h
+++ b/tensorflow/compiler/tf2xla/tf2xla_util.h
@@ -18,8 +18,8 @@ limitations under the License.
#include <unordered_map>
-#include "absl/strings/string_view.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
+#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/op.h"
@@ -60,6 +60,67 @@ void AddDtypeToKernalDefConstraint(absl::string_view name, DataType dtype,
// Returns the next random seed to use for seeding xla rng.
uint32 GetXLARandomSeed();
+// Indicates how a FunctionDef is associated with a graph node (e.g. the node is
+// a function call, or the node has function attrs).
+class AssociatedFunctionInfo {
+ public:
+ enum AssociatedFunctionType {
+ kFunctionCallNode = 0,
+ kFunctionAttr = 1,
+ };
+
+ // The node is a function call.
+ AssociatedFunctionInfo(const string& func_name, const AttrValueMap& attrs)
+ : type_(kFunctionCallNode), func_name_(func_name), attrs_(attrs) {}
+
+ // The function is an attr of the node.
+ AssociatedFunctionInfo(const string& func_name, const AttrValueMap& attrs,
+ const string& attr_name)
+ : type_(kFunctionAttr),
+ func_name_(func_name),
+ attrs_(attrs),
+ attr_name_(attr_name) {}
+
+ AssociatedFunctionType type() const { return type_; }
+
+ const string& func_name() const { return func_name_; }
+
+ const string& attr_name() const { return attr_name_; }
+
+ const AttrValueMap& attrs() const { return attrs_; }
+
+ private:
+ // Available for all instances.
+ AssociatedFunctionType type_;
+ string func_name_;
+ AttrValueMap attrs_;
+
+ // Only available if the function is defined in an attr.
+ string attr_name_;
+};
+
+// Returns if the NodeDef has associated function.
+bool HasAssociatedFunction(const NodeDef& node_def,
+ FunctionLibraryRuntime* flr);
+
+// Gets functions associated with the node. Current cases:
+// 1. For function call node, its function name;
+// 2. For nodes like XlaWhile/XlaIf, all their function attributes.
+std::vector<AssociatedFunctionInfo> GetAssociatedFunctions(
+ const Node& node, FunctionLibraryRuntime* flr);
+
+// Changes associated functions for the node. Current cases:
+// 1. For function call node, creates a new node with the new function name and
+// remove the old node;
+// 2. For nodes like XlaWhile/XlaIf, modify their function attributes.
+Status RewriteAssociatedFunction(
+ Graph* graph, Node* node, FunctionLibraryDefinition* fld,
+ const AssociatedFunctionInfo& associated_function,
+ const string& rewritten_function_name);
+
+// Attribute to mark nodes to be executed on host.
+extern const char kXlaOutsideCompilationAttrName[];
+
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_