aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc')
-rw-r--r--tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc17
1 files changed, 17 insertions, 0 deletions
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
index ae7a22f451..e0632ff7e4 100644
--- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
+++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include <unordered_map>
#include <vector>
+#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
@@ -58,6 +59,22 @@ const char* const kXlaNumResourceArgsAttr = "_XlaNumResourceArgs";
const char* const kXlaHostTransferSequencerAttr =
"_xla_host_transfer_sequencer";
+void SortControlInputs(GraphDef* gdef) {
+ int64 num_nodes = gdef->node_size();
+ for (int64 i = 0; i < num_nodes; ++i) {
+ NodeDef* node = gdef->mutable_node(i);
+ // Stable sort control inputs and leave the order of data inputs unchanged.
+ std::stable_sort(node->mutable_input()->begin(),
+ node->mutable_input()->end(),
+ [](const string& a, const string& b) {
+ bool a_is_control = absl::StartsWith(a, "^");
+ bool b_is_control = absl::StartsWith(b, "^");
+ return (!a_is_control && b_is_control) ||
+ (a_is_control && b_is_control && a < b);
+ });
+ }
+}
+
namespace {
bool AreAllParentsGuaranteedConst(