diff options
17 files changed, 127 insertions, 41 deletions
diff --git a/tensorflow/contrib/cmake/tf_core_framework.cmake b/tensorflow/contrib/cmake/tf_core_framework.cmake index f1c8a30669..17c64b6242 100644 --- a/tensorflow/contrib/cmake/tf_core_framework.cmake +++ b/tensorflow/contrib/cmake/tf_core_framework.cmake @@ -93,6 +93,7 @@ set(tf_proto_text_srcs "tensorflow/core/framework/log_memory.proto" "tensorflow/core/framework/node_def.proto" "tensorflow/core/framework/op_def.proto" + "tensorflow/core/framework/remote_fused_graph_execute_info.proto" "tensorflow/core/framework/resource_handle.proto" "tensorflow/core/framework/step_stats.proto" "tensorflow/core/framework/summary.proto" diff --git a/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt b/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt index c17b26eccc..dfe0a70989 100644 --- a/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt +++ b/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt @@ -20,6 +20,7 @@ tensorflow/core/framework/tensor.pb.cc tensorflow/core/framework/summary.pb.cc tensorflow/core/framework/step_stats.pb.cc tensorflow/core/framework/resource_handle.pb.cc +tensorflow/core/framework/remote_fused_graph_execute_info.pb.cc tensorflow/core/framework/op_def.pb.cc tensorflow/core/framework/node_def.pb.cc tensorflow/core/framework/log_memory.pb.cc diff --git a/tensorflow/contrib/makefile/proto_text_pb_h_files.txt b/tensorflow/contrib/makefile/proto_text_pb_h_files.txt index f389742186..31ae64159a 100644 --- a/tensorflow/contrib/makefile/proto_text_pb_h_files.txt +++ b/tensorflow/contrib/makefile/proto_text_pb_h_files.txt @@ -21,6 +21,7 @@ tensorflow/core/framework/tensor.pb.h tensorflow/core/framework/summary.pb.h tensorflow/core/framework/step_stats.pb.h tensorflow/core/framework/resource_handle.pb.h +tensorflow/core/framework/remote_fused_graph_execute_info.pb.h tensorflow/core/framework/op_def.pb.h tensorflow/core/framework/node_def.pb.h tensorflow/core/framework/log_memory.pb.h diff --git 
a/tensorflow/contrib/makefile/tf_pb_text_files.txt b/tensorflow/contrib/makefile/tf_pb_text_files.txt index d4b34e7146..1a79da05d3 100644 --- a/tensorflow/contrib/makefile/tf_pb_text_files.txt +++ b/tensorflow/contrib/makefile/tf_pb_text_files.txt @@ -14,6 +14,7 @@ tensorflow/core/framework/tensor.pb_text.cc tensorflow/core/framework/summary.pb_text.cc tensorflow/core/framework/step_stats.pb_text.cc tensorflow/core/framework/resource_handle.pb_text.cc +tensorflow/core/framework/remote_fused_graph_execute_info.pb_text.cc tensorflow/core/framework/op_def.pb_text.cc tensorflow/core/framework/node_def.pb_text.cc tensorflow/core/framework/log_memory.pb_text.cc diff --git a/tensorflow/contrib/makefile/tf_proto_files.txt b/tensorflow/contrib/makefile/tf_proto_files.txt index d16b7ae10f..2e1e5a95c4 100644 --- a/tensorflow/contrib/makefile/tf_proto_files.txt +++ b/tensorflow/contrib/makefile/tf_proto_files.txt @@ -21,6 +21,7 @@ tensorflow/core/framework/tensor.proto tensorflow/core/framework/summary.proto tensorflow/core/framework/step_stats.proto tensorflow/core/framework/resource_handle.proto +tensorflow/core/framework/remote_fused_graph_execute_info.proto tensorflow/core/framework/reader_base.proto tensorflow/core/framework/op_def.proto tensorflow/core/framework/node_def.proto diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 67c5402785..a2d835d422 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -137,6 +137,7 @@ CORE_PROTO_SRCS = [ "framework/log_memory.proto", "framework/node_def.proto", "framework/op_def.proto", + "framework/remote_fused_graph_execute_info.proto", "framework/resource_handle.proto", "framework/step_stats.proto", "framework/summary.proto", diff --git a/tensorflow/core/framework/remote_fused_graph_execute_info.proto b/tensorflow/core/framework/remote_fused_graph_execute_info.proto new file mode 100644 index 0000000000..8b9ebbab12 --- /dev/null +++ b/tensorflow/core/framework/remote_fused_graph_execute_info.proto @@ -0,0 
+1,34 @@ +syntax = "proto3"; + +package tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "RemoteFusedGraphExecuteInfoProto"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "tensorflow/core/framework/node_def.proto"; + +// Protocol buffer describing a fused subgraph to be run by a remote executor: +// the fused nodes, the graph's input/output node names and shapes, the +// executor's name, and optional serialized executor-specific parameters. +message RemoteFusedGraphExecuteInfo { + message GraphIONodeInfo { + string name = 1; + repeated int64 shape = 2; + } + + // Nodes in remote fused graph + repeated NodeDef node = 1; + + // Remote fused graph input + repeated GraphIONodeInfo graph_input_node_info = 2; + + // Remote fused graph output + repeated GraphIONodeInfo graph_output_node_info = 3; + + // Executor's name + string executor_name = 4; + + // Optional parameters given to the executor + bytes serialized_executor_parameters = 5; +}; diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index f733f7b55f..147bccd2d2 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -4234,6 +4234,7 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", "//tensorflow/core/kernels/hexagon:graph_transferer", ], ) diff --git a/tensorflow/core/kernels/hexagon/graph_transfer_utils.cc b/tensorflow/core/kernels/hexagon/graph_transfer_utils.cc index c44ef2cc02..ce6f3f7e32 100644 --- a/tensorflow/core/kernels/hexagon/graph_transfer_utils.cc +++ b/tensorflow/core/kernels/hexagon/graph_transfer_utils.cc @@ -49,6 +49,36 @@ GraphTransferUtils::GetTopNFloatResults(const float* const data, } } +/* static */ RemoteFusedGraphExecuteInfo +GraphTransferUtils::BuildRemoteFusedGraphExecuteInfo( + const GraphTransferInfo& graph_transfer_info) { + RemoteFusedGraphExecuteInfo execute_info; + execute_info.set_executor_name("hexagon"); + for
(const GraphTransferInfo::GraphInputNodeInfo& input : + graph_transfer_info.graph_input_node_info()) { + RemoteFusedGraphExecuteInfo::GraphIONodeInfo& graph_input_node_info = + *execute_info.add_graph_input_node_info(); + graph_input_node_info.set_name(input.name()); + for (const int64 dim : input.shape()) { + graph_input_node_info.add_shape(dim); + } + } + + for (const GraphTransferInfo::GraphOutputNodeInfo& output : + graph_transfer_info.graph_output_node_info()) { + RemoteFusedGraphExecuteInfo::GraphIONodeInfo& graph_output_node_info = + *execute_info.add_graph_output_node_info(); + graph_output_node_info.set_name(output.name()); + for (const int64 dim : output.shape()) { + graph_output_node_info.add_shape(dim); + } + } + + execute_info.set_serialized_executor_parameters( + graph_transfer_info.SerializeAsString()); + return execute_info; +} + /* static */ GraphDef GraphTransferUtils::BuildFusedGraphDef( const IGraphTransferOpsDefinitions& ops_definitions, const string& remote_graph_execute_name, @@ -77,7 +107,9 @@ GraphTransferUtils::GetTopNFloatResults(const float* const data, CHECK(scope.ok()); output_list.emplace_back(Output(ret, 0)); } - string serialized_graph = gt->GetGraphTransferInfo().SerializeAsString(); + + const RemoteFusedGraphExecuteInfo execute_info = + BuildRemoteFusedGraphExecuteInfo(gt->GetGraphTransferInfo()); const Scope& scope = root.WithOpName(remote_graph_execute_name); CHECK(scope.ok()); @@ -88,7 +120,7 @@ GraphTransferUtils::GetTopNFloatResults(const float* const data, .Input(node_out_list) .Attr("N", static_cast<int64>(outputs.size())) .Attr("serialized_graph_transfer_info", - StringPiece(serialized_graph)); + StringPiece(execute_info.SerializeAsString())); CHECK(scope.ok()); scope.UpdateBuilder(&builder); scope.UpdateStatus(builder.Finalize(scope.graph(), &node)); diff --git a/tensorflow/core/kernels/hexagon/graph_transfer_utils.h b/tensorflow/core/kernels/hexagon/graph_transfer_utils.h index a9de914538..8ceca95ccb 100644 --- 
a/tensorflow/core/kernels/hexagon/graph_transfer_utils.h +++ b/tensorflow/core/kernels/hexagon/graph_transfer_utils.h @@ -19,6 +19,7 @@ limitations under the License. #include <queue> #include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/hexagon/graph_transferer.h" #include "tensorflow/core/platform/macros.h" @@ -30,10 +31,14 @@ class GraphTransferUtils { static std::priority_queue<std::tuple<float, int, string>> GetTopNFloatResults(const float* const data, const string* const labels, const int element_count); + static void DumpTopNFloatResults(const float* const data, const string* const labels, const int element_count, const int top_n); + static RemoteFusedGraphExecuteInfo BuildRemoteFusedGraphExecuteInfo( + const GraphTransferInfo& graph_transfer_info); + static GraphDef BuildFusedGraphDef( const IGraphTransferOpsDefinitions& ops_definitions, const string& remote_graph_execute_name, diff --git a/tensorflow/core/kernels/hexagon/graph_transferer.cc b/tensorflow/core/kernels/hexagon/graph_transferer.cc index 4b75b95839..b1f9de3adf 100644 --- a/tensorflow/core/kernels/hexagon/graph_transferer.cc +++ b/tensorflow/core/kernels/hexagon/graph_transferer.cc @@ -396,6 +396,10 @@ const GraphTransferInfo& GraphTransferer::GetGraphTransferInfo() const { return graph_transfer_info_; } +GraphTransferInfo& GraphTransferer::GetMutableGraphTransferInfo() { + return graph_transfer_info_; +} + int GraphTransferer::CacheNode(const Node& node) { if (node_name_to_id_cache_map_.count(node.name()) > 0) { VLOG(1) << "Emplace node to cache failed"; diff --git a/tensorflow/core/kernels/hexagon/graph_transferer.h b/tensorflow/core/kernels/hexagon/graph_transferer.h index 5c09ba5db8..b039ddc300 100644 --- a/tensorflow/core/kernels/hexagon/graph_transferer.h +++ b/tensorflow/core/kernels/hexagon/graph_transferer.h @@ -106,6 +106,9 @@ class 
GraphTransferer { // Return parameters for graph transfer const GraphTransferInfo& GetGraphTransferInfo() const; + // Return mutable GraphTransferInfo for graph transfer + GraphTransferInfo& GetMutableGraphTransferInfo(); + // Dump verification string of parameters to verify with offline tools void DumpVerificationStringOfNodeTransferParams() const; diff --git a/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.cc b/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.cc index a1a7fe5bdd..d5a0091669 100644 --- a/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.cc +++ b/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.cc @@ -50,20 +50,21 @@ int HexagonControlWrapper::GetVersion() { return soc_interface_GetSocControllerVersion(); } -bool HexagonControlWrapper::Init() { +bool HexagonControlWrapper::Init(const RemoteFusedGraphExecuteInfo& info) { soc_interface_SetLogLevel(SHOW_DBG_IN_SOC ? -1 /* debug */ : 0 /* info */); if (DBG_USE_SAMPLE_INPUT) { soc_interface_SetDebugFlag(FLAG_ENABLE_PANDA_BINARY_INPUT); } + graph_transferer_.SetSerializedGraphTransferInfo( + info.serialized_executor_parameters()); return soc_interface_Init(); } bool HexagonControlWrapper::Finalize() { return soc_interface_Finalize(); } -bool HexagonControlWrapper::SetupGraph( - const GraphTransferer& graph_transferer) { +bool HexagonControlWrapper::SetupGraph() { // Copy graph transfer info to modify to adapt hexnn library - GraphTransferInfo graph_transfer_info = - graph_transferer.GetGraphTransferInfo(); + GraphTransferInfo& graph_transfer_info = + graph_transferer_.GetMutableGraphTransferInfo(); // Overwrite op type of input nodes for hexagon for (const GraphTransferInfo::GraphInputNodeInfo& graph_input : @@ -309,11 +310,11 @@ bool HexagonControlWrapper::FillInputNode(const string& node_name, #else int HexagonControlWrapper::GetVersion() { return -1; } -bool HexagonControlWrapper::Init() { return false; } -bool HexagonControlWrapper::Finalize() { return false; } -bool 
HexagonControlWrapper::SetupGraph(const GraphTransferer &) { +bool HexagonControlWrapper::Init(const RemoteFusedGraphExecuteInfo&) { return false; } +bool HexagonControlWrapper::Finalize() { return false; } +bool HexagonControlWrapper::SetupGraph() { return false; } bool HexagonControlWrapper::ExecuteGraph() { return false; } bool HexagonControlWrapper::TeardownGraph() { return false; } bool HexagonControlWrapper::FillInputNode(const string&, const ConstByteArray) { diff --git a/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h b/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h index 86540d35f9..e91b094c83 100644 --- a/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h +++ b/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h @@ -34,9 +34,9 @@ class HexagonControlWrapper final : public ISocControlWrapper { public: HexagonControlWrapper() = default; int GetVersion() final; - bool Init() final; + bool Init(const RemoteFusedGraphExecuteInfo& info) final; bool Finalize() final; - bool SetupGraph(const GraphTransferer& graph_transferer) final; + bool SetupGraph() final; bool ExecuteGraph() final; bool TeardownGraph() final; bool FillInputNode(const string& node_name, const ConstByteArray bytes) final; @@ -50,6 +50,7 @@ class HexagonControlWrapper final : public ISocControlWrapper { static GraphTransferInfo::NodeInfo* FindNodeInfo( const string& node_name, GraphTransferInfo* graph_transfer_info); + GraphTransferer graph_transferer_; // Dummy float array for input node. 
// TODO(satok): Use actual data passed by FillInputNode and remove std::vector<float> dummy_input_float_; diff --git a/tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc b/tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc index 9da338fa71..bba798dc4c 100644 --- a/tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc +++ b/tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc @@ -171,12 +171,16 @@ static void RunInferenceByHexagonControlWrapper( std::make_tuple(reinterpret_cast<const uint8*>(img_floats.data()), img_floats.size() * sizeof(float), DT_FLOAT); + const RemoteFusedGraphExecuteInfo execute_info = + GraphTransferUtils::BuildRemoteFusedGraphExecuteInfo( + gt.GetGraphTransferInfo()); + HexagonControlWrapper hexagon_control_wrapper; // 1. Initialize hexagon - hexagon_control_wrapper.Init(); + hexagon_control_wrapper.Init(execute_info); // 2. Setup graph in hexagon - hexagon_control_wrapper.SetupGraph(gt); + hexagon_control_wrapper.SetupGraph(); // 3. Fill input node's output hexagon_control_wrapper.FillInputNode("Mul", ba); diff --git a/tensorflow/core/kernels/hexagon/i_soc_control_wrapper.h b/tensorflow/core/kernels/hexagon/i_soc_control_wrapper.h index 86d01b3dcb..15d0ea09a5 100644 --- a/tensorflow/core/kernels/hexagon/i_soc_control_wrapper.h +++ b/tensorflow/core/kernels/hexagon/i_soc_control_wrapper.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_HEXAGON_I_SOC_CONTROL_WRAPPER_H_ #define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_HEXAGON_I_SOC_CONTROL_WRAPPER_H_ +#include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/hexagon/graph_transferer.h" @@ -39,14 +40,14 @@ class ISocControlWrapper { // Initialize SOC. This function should be called before // starting graph transfer. 
- virtual bool Init() = 0; + virtual bool Init(const RemoteFusedGraphExecuteInfo& info) = 0; // Finalize SOC. This function should be called when all graph executions // are finished. virtual bool Finalize() = 0; // Setup graph on SOC - virtual bool SetupGraph(const GraphTransferer &graph_transferer) = 0; + virtual bool SetupGraph() = 0; // Execute graph on SOC virtual bool ExecuteGraph() = 0; diff --git a/tensorflow/core/kernels/remote_fused_graph_execute_op.cc b/tensorflow/core/kernels/remote_fused_graph_execute_op.cc index 5fcffa2042..0b77c0c0ad 100644 --- a/tensorflow/core/kernels/remote_fused_graph_execute_op.cc +++ b/tensorflow/core/kernels/remote_fused_graph_execute_op.cc @@ -16,7 +16,7 @@ limitations under the License. // See docs in ../ops/remote_fused_graph_ops.cc. #include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/kernels/hexagon/graph_transferer.h" +#include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h" #include "tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -25,30 +25,25 @@ namespace tensorflow { class RemoteFusedGraphExecuteOp : public OpKernel { public: explicit RemoteFusedGraphExecuteOp(OpKernelConstruction* const ctx) - : OpKernel(ctx), graph_transferer_() { + : OpKernel(ctx), execute_info_() { string serialized_proto; OP_REQUIRES_OK( ctx, ctx->GetAttr("serialized_graph_transfer_info", &serialized_proto)); - graph_transferer_.SetSerializedGraphTransferInfo(serialized_proto); - const GraphTransferInfo& gt_info = graph_transferer_.GetGraphTransferInfo(); - switch (gt_info.destination()) { - case GraphTransferInfo::NOP: - break; - case GraphTransferInfo::HEXAGON: - soc_control_wrapper_.reset(new HexagonControlWrapper()); - break; - default: - // Other destination is not supported yet. 
- CHECK(false); - break; + execute_info_.ParseFromString(serialized_proto); + // TODO(satok): Add a way to register executor. + if (execute_info_.executor_name() == "hexagon") { + soc_control_wrapper_.reset(new HexagonControlWrapper()); } if (soc_control_wrapper_) { // 1. Initialize remote processor - soc_control_wrapper_->Init(); + soc_control_wrapper_->Init(execute_info_); + // Explicitly clear serialized executor parameter after initialization + // to release unnecessary memory. + execute_info_.clear_serialized_executor_parameters(); // 2. Setup graph in remote processor - soc_control_wrapper_->SetupGraph(graph_transferer_); + soc_control_wrapper_->SetupGraph(); } } @@ -65,16 +60,15 @@ class RemoteFusedGraphExecuteOp : public OpKernel { void Compute(OpKernelContext* const ctx) final { CHECK(ctx != nullptr); const int input_count = ctx->num_inputs(); - const GraphTransferInfo& gt_info = graph_transferer_.GetGraphTransferInfo(); - CHECK(input_count == gt_info.graph_input_node_info_size()) + CHECK(input_count == execute_info_.graph_input_node_info_size()) << "input_count = " << input_count - << ", gt input count = " << gt_info.graph_input_node_info_size(); + << ", gt input count = " << execute_info_.graph_input_node_info_size(); // 3. Send inputs into remote processor for (int i = 0; i < input_count; ++i) { const Tensor& input_tensor = ctx->input(i); - const GraphTransferInfo::GraphInputNodeInfo& input_node_info = - gt_info.graph_input_node_info(i); + const RemoteFusedGraphExecuteInfo::GraphIONodeInfo& input_node_info = + execute_info_.graph_input_node_info(i); const string& input_node_name = input_node_info.name(); if (soc_control_wrapper_) { soc_control_wrapper_->FillInputNode(input_node_name, input_tensor); @@ -88,12 +82,12 @@ class RemoteFusedGraphExecuteOp : public OpKernel { // 5. 
Load outputs from remote processor const int output_count = ctx->num_outputs(); - CHECK(output_count == gt_info.graph_output_node_info_size()); + CHECK(output_count == execute_info_.graph_output_node_info_size()); for (int i = 0; i < output_count; ++i) { Tensor* output = nullptr; TensorShape output_shape; - const GraphTransferInfo::GraphOutputNodeInfo& output_node_info = - gt_info.graph_output_node_info(i); + const RemoteFusedGraphExecuteInfo::GraphIONodeInfo& output_node_info = + execute_info_.graph_output_node_info(i); for (const int64 dim : output_node_info.shape()) { output_shape.AddDim(dim); } @@ -117,7 +111,7 @@ class RemoteFusedGraphExecuteOp : public OpKernel { bool IsExpensive() final { return true; } private: - GraphTransferer graph_transferer_; + RemoteFusedGraphExecuteInfo execute_info_; std::unique_ptr<ISocControlWrapper> soc_control_wrapper_; TF_DISALLOW_COPY_AND_ASSIGN(RemoteFusedGraphExecuteOp); |