diff options
author | 2017-05-18 10:38:29 -0700 | |
---|---|---|
committer | 2017-05-18 10:43:12 -0700 | |
commit | 9ad851e54d014532dd3b3c8308396769f9a7aeee (patch) | |
tree | ff777c97939a6f4d9de6ebd7226b19e9ffb2ff2f /tensorflow/core/kernels/remote_fused_graph_execute_utils.cc | |
parent | 7916e22e954fc893e673f74b4088b9e9c3a9be97 (diff) |
Add graph transform rewriter for remote fused graph
PiperOrigin-RevId: 156448934
Diffstat (limited to 'tensorflow/core/kernels/remote_fused_graph_execute_utils.cc')
-rw-r--r-- | tensorflow/core/kernels/remote_fused_graph_execute_utils.cc | 32 |
1 files changed, 32 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc b/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc index 2174098bde..d0ffcb1064 100644 --- a/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc +++ b/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc @@ -104,6 +104,22 @@ string DumpCluster(const RemoteFusedGraphExecuteUtils::ClusterInfo& cluster) { RemoteFusedGraphExecuteUtils::ATTR_OUTPUT_DATA_TYPES; /* static */ constexpr const char* const RemoteFusedGraphExecuteUtils::ATTR_OUTPUT_SHAPES; +/* static */ constexpr const char* const RemoteFusedGraphExecuteUtils:: + ATTR_SERIALIZED_REMOTE_FUSED_GRAPH_EXECUTE_INFO; +/* static */ constexpr const char* const RemoteFusedGraphExecuteUtils:: + TRANSFORM_ARG_REMOTE_FUSED_GRAPH_EXECUTOR_NAME; +/* static */ constexpr const char* const + RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_REMOTE_FUSED_GRAPH_NODE_NAME; +/* static */ constexpr const char* const + RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSED_NODES; +/* static */ constexpr const char* const + RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_BORDER_INPUTS; +/* static */ constexpr const char* const + RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_BORDER_OUTPUTS; +/* static */ constexpr const char* const + RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_INPUT_TYPES; +/* static */ constexpr const char* const + RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_INPUT_SHAPES; RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar::ExecutorBuildRegistrar( const string& name, ExecutorBuildFunc executor_build_func) { @@ -797,6 +813,22 @@ RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode( TF_RETURN_IF_ERROR(ReplaceInputNodeByPlaceHolder(subgraph_input_name, dt, shape, subgraph_def)); } + + // sort subgraph_def to align order in graph_def + std::unordered_map<string, int> name_to_id_map; + for (int i = 0; i < graph_def.node_size(); ++i) { + name_to_id_map.emplace(graph_def.node(i).name(), i); + } + std::sort(subgraph_def->mutable_node()->begin(), + subgraph_def->mutable_node()->end(), + [&name_to_id_map](const NodeDef& node0, const NodeDef& node1) { + CHECK(name_to_id_map.count(node0.name()) > 0); + CHECK(name_to_id_map.count(node1.name()) > 0); + const int id0 = name_to_id_map.at(node0.name()); + const int id1 = name_to_id_map.at(node1.name()); + return id0 < id1; + }); + VLOG(1) << DumpGraphDef(*subgraph_def); return Status::OK(); } |