diff options
author | 2018-04-22 09:26:15 -0700 | |
---|---|---|
committer | 2018-04-22 09:28:37 -0700 | |
commit | d481f07549470b4a03b41f9bb588d7f7ddc85082 (patch) | |
tree | 48270727c506275de83e39e562fb7e70fe6be17c | |
parent | 522e20ef9cff8a7a49322c6442d940aa556222c0 (diff) |
Remove proto header include in core/kernels.
The goal is to make kernels mostly independent of proto headers, which will let us lock down our .so import
PiperOrigin-RevId: 193843351
11 files changed, 52 insertions, 45 deletions
diff --git a/tensorflow/core/framework/remote_fused_graph_execute_info.proto b/tensorflow/core/framework/remote_fused_graph_execute_info.proto index 389a08ac2f..946da40d0e 100644 --- a/tensorflow/core/framework/remote_fused_graph_execute_info.proto +++ b/tensorflow/core/framework/remote_fused_graph_execute_info.proto @@ -14,14 +14,6 @@ import "tensorflow/core/framework/types.proto"; // not valid across executions, but can be serialized back and forth from within // a single run. message RemoteFusedGraphExecuteInfo { - enum NodeType { - UNUSED = 0; - GRAPH_INPUT = 1; - GRAPH_OUTPUT = 2; - FUSED_NODE = 3; - BORDER_INPUT = 4; - BORDER_OUTPUT = 5; - } message TensorShapeTypeProto { DataType dtype = 1; diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 7ef15da143..f7f6a9b505 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -5925,6 +5925,7 @@ tf_cc_test( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", "//tensorflow/core:tensorflow", "//tensorflow/core:test", "//tensorflow/core:test_main", diff --git a/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.cc b/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.cc index 66d24d171d..3810cbe5b5 100644 --- a/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.cc +++ b/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h" #include "tensorflow/core/framework/graph_transfer_info.pb.h" +#include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/kernels/hexagon/hexagon_ops_definitions.h" #include "tensorflow/core/kernels/hexagon/soc_interface.h" diff --git a/tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc b/tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc index 5fb6b9247f..d53977703e 100644 --- a/tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc +++ b/tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc @@ -30,6 +30,7 @@ adb push /tmp/imagenet_comp_graph_label_strings.txt /data/local/tmp #include <memory> #include "tensorflow/core/framework/graph_transfer_info.pb.h" +#include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/kernels/hexagon/graph_transfer_utils.h" diff --git a/tensorflow/core/kernels/i_remote_fused_graph_executor.h b/tensorflow/core/kernels/i_remote_fused_graph_executor.h index eb6b64da58..6072412689 100644 --- a/tensorflow/core/kernels/i_remote_fused_graph_executor.h +++ b/tensorflow/core/kernels/i_remote_fused_graph_executor.h @@ -16,13 +16,15 @@ limitations under the License. #ifndef TENSORFLOW_CORE_KERNELS_I_REMOTE_GRAPH_EXECUTOR_H_ #define TENSORFLOW_CORE_KERNELS_I_REMOTE_GRAPH_EXECUTOR_H_ -#include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/platform/macros.h" namespace tensorflow { +class GraphDef; +class RemoteFusedGraphExecuteInfo; + class IRemoteFusedGraphExecutor { public: using TensorAllocatorFunc = std::function<Tensor*(const TensorShape& shape)>; diff --git a/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc b/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc index e2709c117d..cc4d9a49a0 100644 --- a/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc +++ b/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc @@ -20,7 +20,9 @@ limitations under the License. #include <utility> #include "tensorflow/core/common_runtime/shape_refiner.h" +#include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/graph/algorithm.h" @@ -1125,46 +1127,43 @@ RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode( for (size_t i = 0; i < inputs.size(); ++i) { if (IsSameNodeName(node_def, inputs.at(i), &tid)) { AppendDeliminator(&attr_str); - attr_str += BuildNodeTypeAttr(RemoteFusedGraphExecuteInfo::GRAPH_INPUT, - tid.second, i, remote_graph_executor_name, + attr_str += BuildNodeTypeAttr(GRAPH_INPUT, tid.second, i, + remote_graph_executor_name, remote_fused_graph_node_name); } } for (size_t i = 0; i < outputs.size(); ++i) { if (IsSameNodeName(node_def, outputs.at(i), &tid)) { AppendDeliminator(&attr_str); - attr_str += BuildNodeTypeAttr(RemoteFusedGraphExecuteInfo::GRAPH_OUTPUT, - tid.second, i); + attr_str += BuildNodeTypeAttr(GRAPH_OUTPUT, tid.second, i); } } for (const string& fused_node_name : fused_node_names) { if (fused_node_name == node_def.name()) { AppendDeliminator(&attr_str); - attr_str += BuildNodeTypeAttr(RemoteFusedGraphExecuteInfo::FUSED_NODE); + attr_str += BuildNodeTypeAttr(FUSED_NODE); } } for (const string& fused_node_name : fused_nodes_filtered_by_op_types) { if (fused_node_name == node_def.name()) { AppendDeliminator(&attr_str); - attr_str += BuildNodeTypeAttr(RemoteFusedGraphExecuteInfo::FUSED_NODE); + attr_str += BuildNodeTypeAttr(FUSED_NODE); } } for (size_t i = 0; i < border_inputs.size(); ++i) { if (IsSameNodeName(node_def, border_inputs.at(i), &tid)) { AppendDeliminator(&attr_str); - attr_str += BuildNodeTypeAttr(RemoteFusedGraphExecuteInfo::BORDER_INPUT, - tid.second, i); + attr_str += BuildNodeTypeAttr(BORDER_INPUT, tid.second, i); } } for (size_t i = 0; i < border_outputs.size(); ++i) { if (IsSameNodeName(node_def, border_outputs.at(i), &tid)) { AppendDeliminator(&attr_str); - attr_str += BuildNodeTypeAttr( - RemoteFusedGraphExecuteInfo::BORDER_OUTPUT, tid.second, i); + attr_str += BuildNodeTypeAttr(BORDER_OUTPUT, tid.second, i); } } if (attr_str.empty()) { - attr_str += BuildNodeTypeAttr(RemoteFusedGraphExecuteInfo::UNUSED); + attr_str += BuildNodeTypeAttr(UNUSED); } AddNodeAttr(ATTR_NODE_TYPE, attr_str, &node_def); } @@ -1200,14 +1199,14 @@ RemoteFusedGraphExecuteUtils::FuseRemoteGraphByPlacedArguments( } int node_type_int; CHECK(strings::safe_strto32(attr.at(0), &node_type_int)) << attr.at(0); - const RemoteFusedGraphExecuteInfo::NodeType node_type = - static_cast<RemoteFusedGraphExecuteInfo::NodeType>(node_type_int); + const RemoteFusedGraphNodeType node_type = + static_cast<RemoteFusedGraphNodeType>(node_type_int); const string& name = node_def.name(); int port; int index; switch (node_type) { - case RemoteFusedGraphExecuteInfo::GRAPH_INPUT: + case GRAPH_INPUT: VLOG(2) << "Graph input: " << name; CHECK_EQ(5, attr.size()); CHECK(strings::safe_strto32(attr.at(1), &port)); @@ -1224,33 +1223,33 @@ RemoteFusedGraphExecuteUtils::FuseRemoteGraphByPlacedArguments( return Status::OK(); } break; - case RemoteFusedGraphExecuteInfo::GRAPH_OUTPUT: + case GRAPH_OUTPUT: VLOG(2) << "Graph output: " << name; CHECK_EQ(3, attr.size()); CHECK(strings::safe_strto32(attr.at(1), &port)); CHECK(strings::safe_strto32(attr.at(2), &index)); output_map.emplace(index, strings::StrCat(name, ":", port)); break; - case RemoteFusedGraphExecuteInfo::FUSED_NODE: + case FUSED_NODE: VLOG(2) << "Fused node: " << name; CHECK_EQ(1, attr.size()); fused_node_names.emplace(name); break; - case RemoteFusedGraphExecuteInfo::BORDER_INPUT: + case BORDER_INPUT: VLOG(2) << "Border input: " << name; CHECK_EQ(3, attr.size()); CHECK(strings::safe_strto32(attr.at(1), &port)); CHECK(strings::safe_strto32(attr.at(2), &index)); border_input_map.emplace(index, strings::StrCat(name, ":", port)); break; - case RemoteFusedGraphExecuteInfo::BORDER_OUTPUT: + case BORDER_OUTPUT: VLOG(2) << "Border output: " << name; CHECK_EQ(3, attr.size()); CHECK(strings::safe_strto32(attr.at(1), &port)); CHECK(strings::safe_strto32(attr.at(2), &index)); border_output_map.emplace(index, strings::StrCat(name, ":", port)); break; - case RemoteFusedGraphExecuteInfo::UNUSED: + case UNUSED: // do nothing break; default: @@ -1461,20 +1460,19 @@ RemoteFusedGraphExecuteUtils::BuildNodeMapFromOpsDefinitions( } /* static */ string RemoteFusedGraphExecuteUtils::BuildNodeTypeAttr( - const RemoteFusedGraphExecuteInfo::NodeType node_type, const int port, - const int index, const string& executor_name, const string& node_name) { + const RemoteFusedGraphNodeType node_type, const int port, const int index, + const string& executor_name, const string& node_name) { return strings::StrCat(static_cast<int>(node_type), ",", port, ",", index, ",", executor_name, ",", node_name); } /* static */ string RemoteFusedGraphExecuteUtils::BuildNodeTypeAttr( - const RemoteFusedGraphExecuteInfo::NodeType node_type, const int port, - const int index) { + const RemoteFusedGraphNodeType node_type, const int port, const int index) { return strings::StrCat(static_cast<int>(node_type), ",", port, ",", index); } /* static */ string RemoteFusedGraphExecuteUtils::BuildNodeTypeAttr( - const RemoteFusedGraphExecuteInfo::NodeType node_type) { + const RemoteFusedGraphNodeType node_type) { return strings::StrCat(static_cast<int>(node_type)); } diff --git a/tensorflow/core/kernels/remote_fused_graph_execute_utils.h b/tensorflow/core/kernels/remote_fused_graph_execute_utils.h index f047144278..ea6b6a1015 100644 --- a/tensorflow/core/kernels/remote_fused_graph_execute_utils.h +++ b/tensorflow/core/kernels/remote_fused_graph_execute_utils.h @@ -19,8 +19,6 @@ limitations under the License. #include <unordered_map> #include <unordered_set> -#include "tensorflow/core/framework/graph.pb.h" -#include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/kernels/i_remote_fused_graph_executor.h" @@ -30,6 +28,17 @@ limitations under the License. namespace tensorflow { +enum RemoteFusedGraphNodeType { + UNUSED = 0, + GRAPH_INPUT = 1, + GRAPH_OUTPUT = 2, + FUSED_NODE = 3, + BORDER_INPUT = 4, + BORDER_OUTPUT = 5, +}; + +class RemoteFusedGraphExecuteInfo; + // RemoteFusedGraphExecuteUtils provides APIs to register and get builder // functions for IRemoteFusedGraphExecutor. class RemoteFusedGraphExecuteUtils { @@ -297,16 +306,15 @@ class RemoteFusedGraphExecuteUtils { static ExecutorBuildRegistry* GetExecutorBuildRegistry(); - static string BuildNodeTypeAttr( - const RemoteFusedGraphExecuteInfo::NodeType node_type, const int port, - const int index, const string& executor_name, const string& node_name); + static string BuildNodeTypeAttr(const RemoteFusedGraphNodeType node_type, + const int port, const int index, + const string& executor_name, + const string& node_name); - static string BuildNodeTypeAttr( - const RemoteFusedGraphExecuteInfo::NodeType node_type, const int port, - const int index); + static string BuildNodeTypeAttr(const RemoteFusedGraphNodeType node_type, + const int port, const int index); - static string BuildNodeTypeAttr( - const RemoteFusedGraphExecuteInfo::NodeType node_type); + static string BuildNodeTypeAttr(const RemoteFusedGraphNodeType node_type); TF_DISALLOW_COPY_AND_ASSIGN(RemoteFusedGraphExecuteUtils); }; diff --git a/tensorflow/core/kernels/remote_fused_graph_execute_utils_test.cc b/tensorflow/core/kernels/remote_fused_graph_execute_utils_test.cc index aca8ddfae9..44251e6ff8 100644 --- a/tensorflow/core/kernels/remote_fused_graph_execute_utils_test.cc +++ b/tensorflow/core/kernels/remote_fused_graph_execute_utils_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/cc/framework/scope.h" #include "tensorflow/core/common_runtime/shape_refiner.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h" #include "tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" diff --git a/tensorflow/core/kernels/remote_fused_graph_rewriter_transform_test.cc b/tensorflow/core/kernels/remote_fused_graph_rewriter_transform_test.cc index 9217c25978..1e0731e540 100644 --- a/tensorflow/core/kernels/remote_fused_graph_rewriter_transform_test.cc +++ b/tensorflow/core/kernels/remote_fused_graph_rewriter_transform_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/cc/ops/nn_ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/graph/default_device.h" diff --git a/tensorflow/core/kernels/summary_interface.h b/tensorflow/core/kernels/summary_interface.h index 02391e967a..1854fe5526 100644 --- a/tensorflow/core/kernels/summary_interface.h +++ b/tensorflow/core/kernels/summary_interface.h @@ -17,14 +17,15 @@ limitations under the License. #include <memory> -#include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/event.pb.h" namespace tensorflow { +class Event; +class GraphDef; + // Main interface for the summary writer resource. class SummaryWriterInterface : public ResourceBase { public: diff --git a/tensorflow/core/kernels/summary_kernels.cc b/tensorflow/core/kernels/summary_kernels.cc index d317a8d33d..b287f0cc2f 100644 --- a/tensorflow/core/kernels/summary_kernels.cc +++ b/tensorflow/core/kernels/summary_kernels.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/lib/db/sqlite.h" #include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/util/event.pb.h" namespace tensorflow { |