aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/jit/xla_cluster_util.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/jit/xla_cluster_util.cc')
-rw-r--r--tensorflow/compiler/jit/xla_cluster_util.cc22
1 files changed, 22 insertions, 0 deletions
diff --git a/tensorflow/compiler/jit/xla_cluster_util.cc b/tensorflow/compiler/jit/xla_cluster_util.cc
index a5628b12a2..0a025a1fc0 100644
--- a/tensorflow/compiler/jit/xla_cluster_util.cc
+++ b/tensorflow/compiler/jit/xla_cluster_util.cc
@@ -185,4 +185,26 @@ Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles) {
return Status::OK();
}
+gtl::optional<StringPiece> GetXlaClusterForNode(const Node& node) {
+ const AttrValue* attr_value = node.attrs().Find(kXlaClusterAttr);
+ if (attr_value == nullptr) {
+ return gtl::nullopt;
+ }
+ Status s = AttrValueHasType(*attr_value, "string");
+ if (!s.ok()) {
+ return gtl::nullopt;
+ }
+ return attr_value->s();
+}
+
+bool HasResourceInputOrOutput(const Node& node) {
+ return std::find(node.input_types().begin(), node.input_types().end(),
+ DT_RESOURCE) != node.input_types().end() ||
+ std::find(node.output_types().begin(), node.output_types().end(),
+ DT_RESOURCE) != node.output_types().end();
+}
+
+void RemoveFromXlaCluster(NodeDef* node_def) {
+ node_def->mutable_attr()->erase(kXlaClusterAttr);
+}
} // namespace tensorflow