aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-05-18 10:38:29 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-05-18 10:43:12 -0700
commit9ad851e54d014532dd3b3c8308396769f9a7aeee (patch)
treeff777c97939a6f4d9de6ebd7226b19e9ffb2ff2f /tensorflow/core/kernels/remote_fused_graph_execute_utils.cc
parent7916e22e954fc893e673f74b4088b9e9c3a9be97 (diff)
Add graph transform rewriter for remote fused graph
PiperOrigin-RevId: 156448934
Diffstat (limited to 'tensorflow/core/kernels/remote_fused_graph_execute_utils.cc')
-rw-r--r--tensorflow/core/kernels/remote_fused_graph_execute_utils.cc32
1 files changed, 32 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc b/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc
index 2174098bde..d0ffcb1064 100644
--- a/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc
+++ b/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc
@@ -104,6 +104,22 @@ string DumpCluster(const RemoteFusedGraphExecuteUtils::ClusterInfo& cluster) {
RemoteFusedGraphExecuteUtils::ATTR_OUTPUT_DATA_TYPES;
/* static */ constexpr const char* const
RemoteFusedGraphExecuteUtils::ATTR_OUTPUT_SHAPES;
+/* static */ constexpr const char* const RemoteFusedGraphExecuteUtils::
+ ATTR_SERIALIZED_REMOTE_FUSED_GRAPH_EXECUTE_INFO;
+/* static */ constexpr const char* const RemoteFusedGraphExecuteUtils::
+ TRANSFORM_ARG_REMOTE_FUSED_GRAPH_EXECUTOR_NAME;
+/* static */ constexpr const char* const
+ RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_REMOTE_FUSED_GRAPH_NODE_NAME;
+/* static */ constexpr const char* const
+ RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSED_NODES;
+/* static */ constexpr const char* const
+ RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_BORDER_INPUTS;
+/* static */ constexpr const char* const
+ RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_BORDER_OUTPUTS;
+/* static */ constexpr const char* const
+ RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_INPUT_TYPES;
+/* static */ constexpr const char* const
+ RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_INPUT_SHAPES;
RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar::ExecutorBuildRegistrar(
const string& name, ExecutorBuildFunc executor_build_func) {
@@ -797,6 +813,22 @@ RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode(
TF_RETURN_IF_ERROR(ReplaceInputNodeByPlaceHolder(subgraph_input_name, dt,
shape, subgraph_def));
}
+
+ // sort subgraph_def to align order in graph_def
+ std::unordered_map<string, int> name_to_id_map;
+ for (int i = 0; i < graph_def.node_size(); ++i) {
+ name_to_id_map.emplace(graph_def.node(i).name(), i);
+ }
+ std::sort(subgraph_def->mutable_node()->begin(),
+ subgraph_def->mutable_node()->end(),
+ [&name_to_id_map](const NodeDef& node0, const NodeDef& node1) {
+ CHECK(name_to_id_map.count(node0.name()) > 0);
+ CHECK(name_to_id_map.count(node1.name()) > 0);
+ const int id0 = name_to_id_map.at(node0.name());
+ const int id1 = name_to_id_map.at(node1.name());
+ return id0 < id1;
+ });
+
VLOG(1) << DumpGraphDef(*subgraph_def);
return Status::OK();
}