diff options
author | Derek Murray <mrry@google.com> | 2018-10-01 16:09:45 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-01 16:16:41 -0700 |
commit | 24333d8e55bdd995089e93122750340bf8d1ddba (patch) | |
tree | c65a77995eca6be1994b1af7c1f902a8798f7398 /tensorflow/compiler | |
parent | 55f561e6740d61b3665594babce4be72ad955bc6 (diff) |
[TF/XLA] Optimize `Encapsulator::GetFunctionNameAttr()`.
The previous version was hitting a very slow path in `GetNodeAttr()`, which is expensive when the named attr is not found. This change inlines the logic of finding the two relevant attrs inside `GetFunctionNameAttr()` and avoids constructing a status object with a serialized `NodeDef` when the attr can't be found.
PiperOrigin-RevId: 215298411
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r-- | tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc | 43 |
1 files changed, 23 insertions, 20 deletions
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index 15faf31077..d165341f21 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -1363,28 +1363,31 @@ void Encapsulator::Subgraph::GetOutsideCompilationSubgraphNames( Status Encapsulator::GetFunctionNameAttr( Node const* node, string* attr, string* outside_compilation_attr) const { - Status s = GetNodeAttr(node->attrs(), group_attribute_, attr); - if (s.code() == error::Code::NOT_FOUND) { - // Return empty attr if there's no group_attribute. - attr->clear(); - } else { - TF_RETURN_IF_ERROR(s); - } - bool has_group_attr = s.ok(); - s = GetNodeAttr(node->attrs(), outside_compilation_attribute_, - outside_compilation_attr); - if (s.code() == error::Code::NOT_FOUND) { - // Return empty attr if there's no outside_compilation attribute. - outside_compilation_attr->clear(); - } else { - TF_RETURN_IF_ERROR(s); - if (!has_group_attr) { - return errors::InvalidArgument( - "Node ", node->name(), " has ", outside_compilation_attribute_, - " attribute but no ", group_attribute_, " attribute."); + AttrSlice attrs = node->attrs(); + attr->clear(); + outside_compilation_attr->clear(); + bool found_group_attribute = false; + bool found_outside_compilation_attribute = false; + for (const auto& node_attr : attrs) { + if (node_attr.first == group_attribute_) { + TF_RETURN_IF_ERROR(AttrValueHasType(node_attr.second, "string")); + *attr = node_attr.second.s(); + found_group_attribute = true; + } else if (node_attr.first == outside_compilation_attribute_) { + TF_RETURN_IF_ERROR(AttrValueHasType(node_attr.second, "string")); + *outside_compilation_attr = node_attr.second.s(); + found_outside_compilation_attribute = true; } + if (found_group_attribute && found_outside_compilation_attribute) break; + } + + if (found_outside_compilation_attribute && !found_group_attribute) { + return errors::InvalidArgument( + "Node ", node->name(), " has ", outside_compilation_attribute_, + " attribute but no ", group_attribute_, " attribute."); + } else { + return Status::OK(); } - return Status::OK(); } bool IsInSubgraph(const string& func_id, const string& outside_compilation_id) { |