aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-02-24 12:28:14 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-02-24 12:50:52 -0800
commitcc8bfc47e31d32b0c27fb30414187754e95622e7 (patch)
treebec6ee77227d57d717f1a3ae82b06078af3d7940
parent1f42290dcd92058cac1314b2d75b3b06bdf53a27 (diff)
Add remote_fused_graph_execute_info to abstract hexagon specific parameters
Change: 148489487
-rw-r--r--tensorflow/contrib/cmake/tf_core_framework.cmake1
-rw-r--r--tensorflow/contrib/makefile/proto_text_pb_cc_files.txt1
-rw-r--r--tensorflow/contrib/makefile/proto_text_pb_h_files.txt1
-rw-r--r--tensorflow/contrib/makefile/tf_pb_text_files.txt1
-rw-r--r--tensorflow/contrib/makefile/tf_proto_files.txt1
-rw-r--r--tensorflow/core/BUILD1
-rw-r--r--tensorflow/core/framework/remote_fused_graph_execute_info.proto34
-rw-r--r--tensorflow/core/kernels/BUILD1
-rw-r--r--tensorflow/core/kernels/hexagon/graph_transfer_utils.cc36
-rw-r--r--tensorflow/core/kernels/hexagon/graph_transfer_utils.h5
-rw-r--r--tensorflow/core/kernels/hexagon/graph_transferer.cc4
-rw-r--r--tensorflow/core/kernels/hexagon/graph_transferer.h3
-rw-r--r--tensorflow/core/kernels/hexagon/hexagon_control_wrapper.cc17
-rw-r--r--tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h5
-rw-r--r--tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc8
-rw-r--r--tensorflow/core/kernels/hexagon/i_soc_control_wrapper.h5
-rw-r--r--tensorflow/core/kernels/remote_fused_graph_execute_op.cc44
17 files changed, 127 insertions, 41 deletions
diff --git a/tensorflow/contrib/cmake/tf_core_framework.cmake b/tensorflow/contrib/cmake/tf_core_framework.cmake
index f1c8a30669..17c64b6242 100644
--- a/tensorflow/contrib/cmake/tf_core_framework.cmake
+++ b/tensorflow/contrib/cmake/tf_core_framework.cmake
@@ -93,6 +93,7 @@ set(tf_proto_text_srcs
"tensorflow/core/framework/log_memory.proto"
"tensorflow/core/framework/node_def.proto"
"tensorflow/core/framework/op_def.proto"
+ "tensorflow/core/framework/remote_fused_graph_execute_info.proto"
"tensorflow/core/framework/resource_handle.proto"
"tensorflow/core/framework/step_stats.proto"
"tensorflow/core/framework/summary.proto"
diff --git a/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt b/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt
index c17b26eccc..dfe0a70989 100644
--- a/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt
+++ b/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt
@@ -20,6 +20,7 @@ tensorflow/core/framework/tensor.pb.cc
tensorflow/core/framework/summary.pb.cc
tensorflow/core/framework/step_stats.pb.cc
tensorflow/core/framework/resource_handle.pb.cc
+tensorflow/core/framework/remote_fused_graph_execute_info.pb.cc
tensorflow/core/framework/op_def.pb.cc
tensorflow/core/framework/node_def.pb.cc
tensorflow/core/framework/log_memory.pb.cc
diff --git a/tensorflow/contrib/makefile/proto_text_pb_h_files.txt b/tensorflow/contrib/makefile/proto_text_pb_h_files.txt
index f389742186..31ae64159a 100644
--- a/tensorflow/contrib/makefile/proto_text_pb_h_files.txt
+++ b/tensorflow/contrib/makefile/proto_text_pb_h_files.txt
@@ -21,6 +21,7 @@ tensorflow/core/framework/tensor.pb.h
tensorflow/core/framework/summary.pb.h
tensorflow/core/framework/step_stats.pb.h
tensorflow/core/framework/resource_handle.pb.h
+tensorflow/core/framework/remote_fused_graph_execute_info.pb.h
tensorflow/core/framework/op_def.pb.h
tensorflow/core/framework/node_def.pb.h
tensorflow/core/framework/log_memory.pb.h
diff --git a/tensorflow/contrib/makefile/tf_pb_text_files.txt b/tensorflow/contrib/makefile/tf_pb_text_files.txt
index d4b34e7146..1a79da05d3 100644
--- a/tensorflow/contrib/makefile/tf_pb_text_files.txt
+++ b/tensorflow/contrib/makefile/tf_pb_text_files.txt
@@ -14,6 +14,7 @@ tensorflow/core/framework/tensor.pb_text.cc
tensorflow/core/framework/summary.pb_text.cc
tensorflow/core/framework/step_stats.pb_text.cc
tensorflow/core/framework/resource_handle.pb_text.cc
+tensorflow/core/framework/remote_fused_graph_execute_info.pb_text.cc
tensorflow/core/framework/op_def.pb_text.cc
tensorflow/core/framework/node_def.pb_text.cc
tensorflow/core/framework/log_memory.pb_text.cc
diff --git a/tensorflow/contrib/makefile/tf_proto_files.txt b/tensorflow/contrib/makefile/tf_proto_files.txt
index d16b7ae10f..2e1e5a95c4 100644
--- a/tensorflow/contrib/makefile/tf_proto_files.txt
+++ b/tensorflow/contrib/makefile/tf_proto_files.txt
@@ -21,6 +21,7 @@ tensorflow/core/framework/tensor.proto
tensorflow/core/framework/summary.proto
tensorflow/core/framework/step_stats.proto
tensorflow/core/framework/resource_handle.proto
+tensorflow/core/framework/remote_fused_graph_execute_info.proto
tensorflow/core/framework/reader_base.proto
tensorflow/core/framework/op_def.proto
tensorflow/core/framework/node_def.proto
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 67c5402785..a2d835d422 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -137,6 +137,7 @@ CORE_PROTO_SRCS = [
"framework/log_memory.proto",
"framework/node_def.proto",
"framework/op_def.proto",
+ "framework/remote_fused_graph_execute_info.proto",
"framework/resource_handle.proto",
"framework/step_stats.proto",
"framework/summary.proto",
diff --git a/tensorflow/core/framework/remote_fused_graph_execute_info.proto b/tensorflow/core/framework/remote_fused_graph_execute_info.proto
new file mode 100644
index 0000000000..8b9ebbab12
--- /dev/null
+++ b/tensorflow/core/framework/remote_fused_graph_execute_info.proto
@@ -0,0 +1,34 @@
+syntax = "proto3";
+
+package tensorflow;
+option cc_enable_arenas = true;
+option java_outer_classname = "RemoteFusedGraphExecuteInfoProto";
+option java_multiple_files = true;
+option java_package = "org.tensorflow.framework";
+
+import "tensorflow/core/framework/node_def.proto";
+
+// Protocol buffer describing a remote fused graph to be executed: its nodes,
+// its graph input and output nodes (name and shape), the name of the executor,
+// and optional serialized parameters that are passed to that executor.
+message RemoteFusedGraphExecuteInfo {
+ message GraphIONodeInfo {
+ string name = 1;
+ repeated int64 shape = 2;
+ }
+
+ // Nodes in remote fused graph
+ repeated NodeDef node = 1;
+
+ // Remote fused graph input
+ repeated GraphIONodeInfo graph_input_node_info = 2;
+
+ // Remote fused graph output
+ repeated GraphIONodeInfo graph_output_node_info = 3;
+
+ // Executor's name
+ string executor_name = 4;
+
+ // Optional parameters given to the executor
+ bytes serialized_executor_parameters = 5;
+};
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index f733f7b55f..147bccd2d2 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -4234,6 +4234,7 @@ cc_library(
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
"//tensorflow/core/kernels/hexagon:graph_transferer",
],
)
diff --git a/tensorflow/core/kernels/hexagon/graph_transfer_utils.cc b/tensorflow/core/kernels/hexagon/graph_transfer_utils.cc
index c44ef2cc02..ce6f3f7e32 100644
--- a/tensorflow/core/kernels/hexagon/graph_transfer_utils.cc
+++ b/tensorflow/core/kernels/hexagon/graph_transfer_utils.cc
@@ -49,6 +49,36 @@ GraphTransferUtils::GetTopNFloatResults(const float* const data,
}
}
+/* static */ RemoteFusedGraphExecuteInfo
+GraphTransferUtils::BuildRemoteFusedGraphExecuteInfo(
+ const GraphTransferInfo& graph_transfer_info) {
+ RemoteFusedGraphExecuteInfo execute_info;
+ execute_info.set_executor_name("hexagon");
+ for (const GraphTransferInfo::GraphInputNodeInfo& input :
+ graph_transfer_info.graph_input_node_info()) {
+ RemoteFusedGraphExecuteInfo::GraphIONodeInfo& graph_input_node_info =
+ *execute_info.add_graph_input_node_info();
+ graph_input_node_info.set_name(input.name());
+ for (const int64 dim : input.shape()) {
+ graph_input_node_info.add_shape(dim);
+ }
+ }
+
+ for (const GraphTransferInfo::GraphOutputNodeInfo& output :
+ graph_transfer_info.graph_output_node_info()) {
+ RemoteFusedGraphExecuteInfo::GraphIONodeInfo& graph_output_node_info =
+ *execute_info.add_graph_output_node_info();
+ graph_output_node_info.set_name(output.name());
+ for (const int64 dim : output.shape()) {
+ graph_output_node_info.add_shape(dim);
+ }
+ }
+
+ execute_info.set_serialized_executor_parameters(
+ graph_transfer_info.SerializeAsString());
+ return execute_info;
+}
+
/* static */ GraphDef GraphTransferUtils::BuildFusedGraphDef(
const IGraphTransferOpsDefinitions& ops_definitions,
const string& remote_graph_execute_name,
@@ -77,7 +107,9 @@ GraphTransferUtils::GetTopNFloatResults(const float* const data,
CHECK(scope.ok());
output_list.emplace_back(Output(ret, 0));
}
- string serialized_graph = gt->GetGraphTransferInfo().SerializeAsString();
+
+ const RemoteFusedGraphExecuteInfo execute_info =
+ BuildRemoteFusedGraphExecuteInfo(gt->GetGraphTransferInfo());
const Scope& scope = root.WithOpName(remote_graph_execute_name);
CHECK(scope.ok());
@@ -88,7 +120,7 @@ GraphTransferUtils::GetTopNFloatResults(const float* const data,
.Input(node_out_list)
.Attr("N", static_cast<int64>(outputs.size()))
.Attr("serialized_graph_transfer_info",
- StringPiece(serialized_graph));
+ StringPiece(execute_info.SerializeAsString()));
CHECK(scope.ok());
scope.UpdateBuilder(&builder);
scope.UpdateStatus(builder.Finalize(scope.graph(), &node));
diff --git a/tensorflow/core/kernels/hexagon/graph_transfer_utils.h b/tensorflow/core/kernels/hexagon/graph_transfer_utils.h
index a9de914538..8ceca95ccb 100644
--- a/tensorflow/core/kernels/hexagon/graph_transfer_utils.h
+++ b/tensorflow/core/kernels/hexagon/graph_transfer_utils.h
@@ -19,6 +19,7 @@ limitations under the License.
#include <queue>
#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/hexagon/graph_transferer.h"
#include "tensorflow/core/platform/macros.h"
@@ -30,10 +31,14 @@ class GraphTransferUtils {
static std::priority_queue<std::tuple<float, int, string>>
GetTopNFloatResults(const float* const data, const string* const labels,
const int element_count);
+
static void DumpTopNFloatResults(const float* const data,
const string* const labels,
const int element_count, const int top_n);
+ static RemoteFusedGraphExecuteInfo BuildRemoteFusedGraphExecuteInfo(
+ const GraphTransferInfo& graph_transfer_info);
+
static GraphDef BuildFusedGraphDef(
const IGraphTransferOpsDefinitions& ops_definitions,
const string& remote_graph_execute_name,
diff --git a/tensorflow/core/kernels/hexagon/graph_transferer.cc b/tensorflow/core/kernels/hexagon/graph_transferer.cc
index 4b75b95839..b1f9de3adf 100644
--- a/tensorflow/core/kernels/hexagon/graph_transferer.cc
+++ b/tensorflow/core/kernels/hexagon/graph_transferer.cc
@@ -396,6 +396,10 @@ const GraphTransferInfo& GraphTransferer::GetGraphTransferInfo() const {
return graph_transfer_info_;
}
+GraphTransferInfo& GraphTransferer::GetMutableGraphTransferInfo() {
+ return graph_transfer_info_;
+}
+
int GraphTransferer::CacheNode(const Node& node) {
if (node_name_to_id_cache_map_.count(node.name()) > 0) {
VLOG(1) << "Emplace node to cache failed";
diff --git a/tensorflow/core/kernels/hexagon/graph_transferer.h b/tensorflow/core/kernels/hexagon/graph_transferer.h
index 5c09ba5db8..b039ddc300 100644
--- a/tensorflow/core/kernels/hexagon/graph_transferer.h
+++ b/tensorflow/core/kernels/hexagon/graph_transferer.h
@@ -106,6 +106,9 @@ class GraphTransferer {
// Return parameters for graph transfer
const GraphTransferInfo& GetGraphTransferInfo() const;
+ // Return mutable GraphTransferInfo for graph transfer
+ GraphTransferInfo& GetMutableGraphTransferInfo();
+
// Dump verification string of parameters to verify with offline tools
void DumpVerificationStringOfNodeTransferParams() const;
diff --git a/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.cc b/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.cc
index a1a7fe5bdd..d5a0091669 100644
--- a/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.cc
+++ b/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.cc
@@ -50,20 +50,21 @@ int HexagonControlWrapper::GetVersion() {
return soc_interface_GetSocControllerVersion();
}
-bool HexagonControlWrapper::Init() {
+bool HexagonControlWrapper::Init(const RemoteFusedGraphExecuteInfo& info) {
soc_interface_SetLogLevel(SHOW_DBG_IN_SOC ? -1 /* debug */ : 0 /* info */);
if (DBG_USE_SAMPLE_INPUT) {
soc_interface_SetDebugFlag(FLAG_ENABLE_PANDA_BINARY_INPUT);
}
+ graph_transferer_.SetSerializedGraphTransferInfo(
+ info.serialized_executor_parameters());
return soc_interface_Init();
}
bool HexagonControlWrapper::Finalize() { return soc_interface_Finalize(); }
-bool HexagonControlWrapper::SetupGraph(
- const GraphTransferer& graph_transferer) {
+bool HexagonControlWrapper::SetupGraph() {
// Copy graph transfer info to modify to adapt hexnn library
- GraphTransferInfo graph_transfer_info =
- graph_transferer.GetGraphTransferInfo();
+ GraphTransferInfo& graph_transfer_info =
+ graph_transferer_.GetMutableGraphTransferInfo();
// Overwrite op type of input nodes for hexagon
for (const GraphTransferInfo::GraphInputNodeInfo& graph_input :
@@ -309,11 +310,11 @@ bool HexagonControlWrapper::FillInputNode(const string& node_name,
#else
int HexagonControlWrapper::GetVersion() { return -1; }
-bool HexagonControlWrapper::Init() { return false; }
-bool HexagonControlWrapper::Finalize() { return false; }
-bool HexagonControlWrapper::SetupGraph(const GraphTransferer &) {
+bool HexagonControlWrapper::Init(const RemoteFusedGraphExecuteInfo&) {
return false;
}
+bool HexagonControlWrapper::Finalize() { return false; }
+bool HexagonControlWrapper::SetupGraph() { return false; }
bool HexagonControlWrapper::ExecuteGraph() { return false; }
bool HexagonControlWrapper::TeardownGraph() { return false; }
bool HexagonControlWrapper::FillInputNode(const string&, const ConstByteArray) {
diff --git a/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h b/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h
index 86540d35f9..e91b094c83 100644
--- a/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h
+++ b/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h
@@ -34,9 +34,9 @@ class HexagonControlWrapper final : public ISocControlWrapper {
public:
HexagonControlWrapper() = default;
int GetVersion() final;
- bool Init() final;
+ bool Init(const RemoteFusedGraphExecuteInfo& info) final;
bool Finalize() final;
- bool SetupGraph(const GraphTransferer& graph_transferer) final;
+ bool SetupGraph() final;
bool ExecuteGraph() final;
bool TeardownGraph() final;
bool FillInputNode(const string& node_name, const ConstByteArray bytes) final;
@@ -50,6 +50,7 @@ class HexagonControlWrapper final : public ISocControlWrapper {
static GraphTransferInfo::NodeInfo* FindNodeInfo(
const string& node_name, GraphTransferInfo* graph_transfer_info);
+ GraphTransferer graph_transferer_;
// Dummy float array for input node.
// TODO(satok): Use actual data passed by FillInputNode and remove
std::vector<float> dummy_input_float_;
diff --git a/tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc b/tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc
index 9da338fa71..bba798dc4c 100644
--- a/tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc
+++ b/tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc
@@ -171,12 +171,16 @@ static void RunInferenceByHexagonControlWrapper(
std::make_tuple(reinterpret_cast<const uint8*>(img_floats.data()),
img_floats.size() * sizeof(float), DT_FLOAT);
+ const RemoteFusedGraphExecuteInfo execute_info =
+ GraphTransferUtils::BuildRemoteFusedGraphExecuteInfo(
+ gt.GetGraphTransferInfo());
+
HexagonControlWrapper hexagon_control_wrapper;
// 1. Initialize hexagon
- hexagon_control_wrapper.Init();
+ hexagon_control_wrapper.Init(execute_info);
// 2. Setup graph in hexagon
- hexagon_control_wrapper.SetupGraph(gt);
+ hexagon_control_wrapper.SetupGraph();
// 3. Fill input node's output
hexagon_control_wrapper.FillInputNode("Mul", ba);
diff --git a/tensorflow/core/kernels/hexagon/i_soc_control_wrapper.h b/tensorflow/core/kernels/hexagon/i_soc_control_wrapper.h
index 86d01b3dcb..15d0ea09a5 100644
--- a/tensorflow/core/kernels/hexagon/i_soc_control_wrapper.h
+++ b/tensorflow/core/kernels/hexagon/i_soc_control_wrapper.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_HEXAGON_I_SOC_CONTROL_WRAPPER_H_
#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_HEXAGON_I_SOC_CONTROL_WRAPPER_H_
+#include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/hexagon/graph_transferer.h"
@@ -39,14 +40,14 @@ class ISocControlWrapper {
// Initialize SOC. This function should be called before
// starting graph transfer.
- virtual bool Init() = 0;
+ virtual bool Init(const RemoteFusedGraphExecuteInfo& info) = 0;
// Finalize SOC. This function should be called when all graph executions
// are finished.
virtual bool Finalize() = 0;
// Setup graph on SOC
- virtual bool SetupGraph(const GraphTransferer &graph_transferer) = 0;
+ virtual bool SetupGraph() = 0;
// Execute graph on SOC
virtual bool ExecuteGraph() = 0;
diff --git a/tensorflow/core/kernels/remote_fused_graph_execute_op.cc b/tensorflow/core/kernels/remote_fused_graph_execute_op.cc
index 5fcffa2042..0b77c0c0ad 100644
--- a/tensorflow/core/kernels/remote_fused_graph_execute_op.cc
+++ b/tensorflow/core/kernels/remote_fused_graph_execute_op.cc
@@ -16,7 +16,7 @@ limitations under the License.
// See docs in ../ops/remote_fused_graph_ops.cc.
#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/kernels/hexagon/graph_transferer.h"
+#include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h"
#include "tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
@@ -25,30 +25,25 @@ namespace tensorflow {
class RemoteFusedGraphExecuteOp : public OpKernel {
public:
explicit RemoteFusedGraphExecuteOp(OpKernelConstruction* const ctx)
- : OpKernel(ctx), graph_transferer_() {
+ : OpKernel(ctx), execute_info_() {
string serialized_proto;
OP_REQUIRES_OK(
ctx, ctx->GetAttr("serialized_graph_transfer_info", &serialized_proto));
- graph_transferer_.SetSerializedGraphTransferInfo(serialized_proto);
- const GraphTransferInfo& gt_info = graph_transferer_.GetGraphTransferInfo();
- switch (gt_info.destination()) {
- case GraphTransferInfo::NOP:
- break;
- case GraphTransferInfo::HEXAGON:
- soc_control_wrapper_.reset(new HexagonControlWrapper());
- break;
- default:
- // Other destination is not supported yet.
- CHECK(false);
- break;
+ execute_info_.ParseFromString(serialized_proto);
+ // TODO(satok): Add a way to register executor.
+ if (execute_info_.executor_name() == "hexagon") {
+ soc_control_wrapper_.reset(new HexagonControlWrapper());
}
if (soc_control_wrapper_) {
// 1. Initialize remote processor
- soc_control_wrapper_->Init();
+ soc_control_wrapper_->Init(execute_info_);
+ // Explicitly clear the serialized executor parameters after initialization
+ // to release memory that is no longer needed.
+ execute_info_.clear_serialized_executor_parameters();
// 2. Setup graph in remote processor
- soc_control_wrapper_->SetupGraph(graph_transferer_);
+ soc_control_wrapper_->SetupGraph();
}
}
@@ -65,16 +60,15 @@ class RemoteFusedGraphExecuteOp : public OpKernel {
void Compute(OpKernelContext* const ctx) final {
CHECK(ctx != nullptr);
const int input_count = ctx->num_inputs();
- const GraphTransferInfo& gt_info = graph_transferer_.GetGraphTransferInfo();
- CHECK(input_count == gt_info.graph_input_node_info_size())
+ CHECK(input_count == execute_info_.graph_input_node_info_size())
<< "input_count = " << input_count
- << ", gt input count = " << gt_info.graph_input_node_info_size();
+ << ", gt input count = " << execute_info_.graph_input_node_info_size();
// 3. Send inputs into remote processor
for (int i = 0; i < input_count; ++i) {
const Tensor& input_tensor = ctx->input(i);
- const GraphTransferInfo::GraphInputNodeInfo& input_node_info =
- gt_info.graph_input_node_info(i);
+ const RemoteFusedGraphExecuteInfo::GraphIONodeInfo& input_node_info =
+ execute_info_.graph_input_node_info(i);
const string& input_node_name = input_node_info.name();
if (soc_control_wrapper_) {
soc_control_wrapper_->FillInputNode(input_node_name, input_tensor);
@@ -88,12 +82,12 @@ class RemoteFusedGraphExecuteOp : public OpKernel {
// 5. Load outputs from remote processor
const int output_count = ctx->num_outputs();
- CHECK(output_count == gt_info.graph_output_node_info_size());
+ CHECK(output_count == execute_info_.graph_output_node_info_size());
for (int i = 0; i < output_count; ++i) {
Tensor* output = nullptr;
TensorShape output_shape;
- const GraphTransferInfo::GraphOutputNodeInfo& output_node_info =
- gt_info.graph_output_node_info(i);
+ const RemoteFusedGraphExecuteInfo::GraphIONodeInfo& output_node_info =
+ execute_info_.graph_output_node_info(i);
for (const int64 dim : output_node_info.shape()) {
output_shape.AddDim(dim);
}
@@ -117,7 +111,7 @@ class RemoteFusedGraphExecuteOp : public OpKernel {
bool IsExpensive() final { return true; }
private:
- GraphTransferer graph_transferer_;
+ RemoteFusedGraphExecuteInfo execute_info_;
std::unique_ptr<ISocControlWrapper> soc_control_wrapper_;
TF_DISALLOW_COPY_AND_ASSIGN(RemoteFusedGraphExecuteOp);