about summary refs log tree commit diff homepage
diff options
context:
space:
mode:
-rw-r--r-- tensorflow/contrib/cmake/tf_core_kernels.cmake 1
-rw-r--r-- tensorflow/contrib/cmake/tf_tests.cmake 1
-rw-r--r-- tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD 1
-rw-r--r-- tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc 55
-rw-r--r-- tensorflow/core/kernels/BUILD 37
-rw-r--r-- tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc 9
-rw-r--r-- tensorflow/core/kernels/remote_fused_graph_execute_op.cc 7
-rw-r--r-- tensorflow/core/kernels/remote_fused_graph_execute_op_test.cc 8
-rw-r--r-- tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.cc 53
-rw-r--r-- tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.h 28
-rw-r--r-- tensorflow/core/kernels/remote_fused_graph_execute_utils.cc 32
-rw-r--r-- tensorflow/core/kernels/remote_fused_graph_execute_utils.h 19
-rw-r--r-- tensorflow/core/kernels/remote_fused_graph_execute_utils_test.cc 131
-rw-r--r-- tensorflow/core/kernels/remote_fused_graph_rewriter_transform.cc 163
-rw-r--r-- tensorflow/core/kernels/remote_fused_graph_rewriter_transform_test.cc 206
-rw-r--r-- tensorflow/tools/graph_transforms/BUILD 1
16 files changed, 641 insertions, 111 deletions
diff --git a/tensorflow/contrib/cmake/tf_core_kernels.cmake b/tensorflow/contrib/cmake/tf_core_kernels.cmake
index e93680fb4d..c95cd068cd 100644
--- a/tensorflow/contrib/cmake/tf_core_kernels.cmake
+++ b/tensorflow/contrib/cmake/tf_core_kernels.cmake
@@ -101,6 +101,7 @@ file(GLOB_RECURSE tf_core_kernels_exclude_srcs
"${tensorflow_source_dir}/tensorflow/core/kernels/*.cu.cc"
"${tensorflow_source_dir}/tensorflow/core/kernels/hexagon/*"
"${tensorflow_source_dir}/tensorflow/core/kernels/remote_fused_graph_execute*.cc"
+ "${tensorflow_source_dir}/tensorflow/core/kernels/remote_fused_graph_rewriter_transform*.cc"
)
list(REMOVE_ITEM tf_core_kernels_srcs ${tf_core_kernels_exclude_srcs})
diff --git a/tensorflow/contrib/cmake/tf_tests.cmake b/tensorflow/contrib/cmake/tf_tests.cmake
index 2568ca3bad..6392a5e044 100644
--- a/tensorflow/contrib/cmake/tf_tests.cmake
+++ b/tensorflow/contrib/cmake/tf_tests.cmake
@@ -282,6 +282,7 @@ if (tensorflow_BUILD_CC_TESTS)
"${tensorflow_source_dir}/tensorflow/core/distributed_runtime/call_options_test.cc"
"${tensorflow_source_dir}/tensorflow/core/distributed_runtime/tensor_coding_test.cc"
"${tensorflow_source_dir}/tensorflow/core/kernels/remote_fused_graph_execute_utils_test.cc"
+ "${tensorflow_source_dir}/tensorflow/core/kernels/remote_fused_graph_rewriter_transform_test.cc"
"${tensorflow_source_dir}/tensorflow/core/kernels/hexagon/graph_transferer_test.cc"
"${tensorflow_source_dir}/tensorflow/core/kernels/hexagon/quantized_matmul_op_for_hexagon_test.cc"
)
diff --git a/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD b/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD
index 922996a686..fa75943d78 100644
--- a/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD
+++ b/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD
@@ -29,6 +29,7 @@ cc_binary(
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensorflow",
+ "//tensorflow/core/kernels:remote_fused_graph_execute_utils",
"//tensorflow/core/kernels/hexagon:graph_transferer",
"//tensorflow/tools/graph_transforms:transform_utils",
],
diff --git a/tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc b/tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc
index 3a219bb3e6..6ae7c4a742 100644
--- a/tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc
+++ b/tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc
@@ -22,7 +22,9 @@ limitations under the License.
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h"
#include "tensorflow/core/kernels/hexagon/hexagon_ops_definitions.h"
+#include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/init_main.h"
@@ -46,6 +48,47 @@ static int ParseFlags(int argc, char* argv[], string* in_graph) {
return 0;
}
+// Logs a short description of `node_def` at INFO level: the node name, its op
+// type, and then one line per input the node consumes.
+static void SummarizeNode(const NodeDef& node_def) {
+  LOG(INFO) << "Node(" << node_def.name() << ")";
+  LOG(INFO) << " op: " << node_def.op();
+  const int input_count = node_def.input_size();
+  for (int input_index = 0; input_index < input_count; ++input_index) {
+    LOG(INFO) << " Input: " << node_def.input(input_index);
+  }
+}
+
+// Logs the contents of a RemoteFusedGraphExecute node at INFO level.
+//
+// The fused graph description is read from the node attribute named by
+// RemoteFusedGraphExecuteUtils::ATTR_SERIALIZED_REMOTE_FUSED_GRAPH_EXECUTE_INFO
+// and parsed as a RemoteFusedGraphExecuteInfo proto.  On success the node
+// name, executor name, graph inputs/outputs with their default shape/type
+// entries, and the names of all nodes of the embedded remote graph are logged.
+//
+// If the attribute cannot be read or its payload cannot be parsed, an error is
+// logged and nothing else is dumped, so that a broken node is not reported as
+// an (apparently valid) empty fused graph.
+static void DumpRemoteFusedGraph(const NodeDef& node_def) {
+  LOG(INFO) << "Remote fused graph found.";
+  string serialized_proto;
+  const Status attr_status =
+      GetNodeAttr(node_def,
+                  RemoteFusedGraphExecuteUtils::
+                      ATTR_SERIALIZED_REMOTE_FUSED_GRAPH_EXECUTE_INFO,
+                  &serialized_proto);
+  if (!attr_status.ok()) {
+    LOG(ERROR) << "Failed to read remote fused graph info from node "
+               << node_def.name() << ": " << attr_status.ToString();
+    return;
+  }
+  RemoteFusedGraphExecuteInfo info;
+  // ParseFromString reports malformed input through its bool return value.
+  if (!info.ParseFromString(serialized_proto)) {
+    LOG(ERROR) << "Failed to parse remote fused graph info of node "
+               << node_def.name();
+    return;
+  }
+  LOG(INFO) << "Node name: " << node_def.name();
+  LOG(INFO) << "Executor name: " << info.executor_name();
+  for (const string& input : info.graph_input_node_name()) {
+    LOG(INFO) << "Input: " << input;
+  }
+  for (const RemoteFusedGraphExecuteInfo::TensorShapeTypeProto& shape_type :
+       info.default_graph_input_tensor_shape()) {
+    LOG(INFO) << "Input shape type: " << shape_type.DebugString();
+  }
+  for (const string& output : info.graph_output_node_name()) {
+    LOG(INFO) << "Output: " << output;
+  }
+  for (const RemoteFusedGraphExecuteInfo::TensorShapeTypeProto& shape_type :
+       info.default_graph_output_tensor_shape()) {
+    LOG(INFO) << "Output shape type: " << shape_type.DebugString();
+  }
+  const int subgraph_node_size = info.remote_graph().node_size();
+  LOG(INFO) << "Nodes in the graph: " << subgraph_node_size;
+  for (int i = 0; i < subgraph_node_size; ++i) {
+    LOG(INFO) << "node(" << i << "): " << info.remote_graph().node(i).name();
+  }
+}
+
static void CheckOpsSupport(const GraphDef& graph_def) {
const IGraphTransferOpsDefinitions& ops_definition =
HexagonOpsDefinitions::getInstance();
@@ -53,7 +96,13 @@ static void CheckOpsSupport(const GraphDef& graph_def) {
std::unordered_set<string> unsupported_ops;
bool all_supported = true;
+ bool contains_remote_graph = false;
for (const NodeDef& node : graph_def.node()) {
+ if (node.op() == "RemoteFusedGraphExecute") {
+ contains_remote_graph = true;
+ DumpRemoteFusedGraph(node);
+ continue;
+ }
// TODO(satok): Set correct data type if it's given.
const int op_id = ops_definition.GetOpIdFor(node.op(), {});
if (op_id == IGraphTransferOpsDefinitions::INVALID_OP_ID) {
@@ -75,6 +124,12 @@ static void CheckOpsSupport(const GraphDef& graph_def) {
} else {
LOG(INFO) << count << " ops are not supported.";
}
+
+ if (contains_remote_graph) {
+ for (const NodeDef& node : graph_def.node()) {
+ SummarizeNode(node);
+ }
+ }
}
} // namespace
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index b2eaaa3492..590d334edc 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -4297,6 +4297,7 @@ filegroup(
"encode_jpeg_op.*",
"identity_reader_op.*",
"remote_fused_graph_execute_op.*",
+ "remote_fused_graph_rewriter_transform.*",
"fixed_length_record_reader_op.*",
"whole_file_read_ops.*",
"sample_distorted_bounding_box_op.*",
@@ -4777,6 +4778,7 @@ cc_library(
cc_library(
name = "remote_fused_graph_execute_op_test_utils",
+ testonly = 1,
srcs = ["remote_fused_graph_execute_op_test_utils.cc"],
hdrs = ["remote_fused_graph_execute_op_test_utils.h"],
deps = [
@@ -4785,6 +4787,7 @@ cc_library(
"//tensorflow/cc:scope",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
+ "//tensorflow/core:testlib",
"//tensorflow/core/kernels:cwise_op",
],
)
@@ -4842,6 +4845,40 @@ tf_cc_test(
],
)
+cc_library(  # Library built from remote_fused_graph_rewriter_transform.cc.
+ name = "remote_fused_graph_rewriter_transform",
+ srcs = [
+ "remote_fused_graph_rewriter_transform.cc",
+ ],
+ deps = [
+ ":remote_fused_graph_execute_utils",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/cc:remote_fused_graph_ops",
+ "//tensorflow/core",
+ "//tensorflow/tools/graph_transforms:transform_utils",
+ ],
+ alwayslink = 1,  # NOTE(review): presumably keeps the transform's static registration linked in - confirm.
+)
+
+tf_cc_test(  # Test target for the :remote_fused_graph_rewriter_transform library.
+ name = "remote_fused_graph_rewriter_transform_test",
+ size = "small",
+ srcs = ["remote_fused_graph_rewriter_transform_test.cc"],
+ deps = [
+ ":remote_fused_graph_execute_op_test_utils",
+ ":remote_fused_graph_execute_utils",
+ ":remote_fused_graph_rewriter_transform",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/tools/graph_transforms:transform_utils",
+ ],
+)
+
tf_mkl_kernel_library(
name = "mkl_conv_op",
prefix = "mkl_conv",
diff --git a/tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc b/tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc
index a383cc8199..54ba101501 100644
--- a/tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc
+++ b/tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc
@@ -35,6 +35,7 @@ adb push /tmp/imagenet_comp_graph_label_strings.txt /data/local/tmp
#include "tensorflow/core/kernels/i_remote_fused_graph_executor.h"
#include "tensorflow/core/lib/core/casts.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/env.h"
@@ -268,8 +269,7 @@ static void RunFusedGraph(const GraphDef& fused_graph_def) {
session_options.env = Env::Default();
std::unique_ptr<Session> session =
std::unique_ptr<Session>(NewSession(session_options));
- Status status = session->Create(fused_graph_def);
- ASSERT_TRUE(status.ok());
+ TF_ASSERT_OK(session->Create(fused_graph_def));
// Setup session arguments
RunOptions run_options;
@@ -283,9 +283,8 @@ static void RunFusedGraph(const GraphDef& fused_graph_def) {
LOG(INFO) << "Run graph";
// Run inference with all node as output
- status = session->Run(run_options, input_tensors, output_node_names, {},
- &output_tensors, &run_metadata);
- ASSERT_TRUE(status.ok());
+ TF_ASSERT_OK(session->Run(run_options, input_tensors, output_node_names, {},
+ &output_tensors, &run_metadata));
ASSERT_EQ(1, output_tensors.size());
const Tensor& output_tensor = output_tensors.at(0);
LOG(INFO) << "Output byte size = " << output_tensor.TotalBytes();
diff --git a/tensorflow/core/kernels/remote_fused_graph_execute_op.cc b/tensorflow/core/kernels/remote_fused_graph_execute_op.cc
index 101d0e694b..aa3835ecc5 100644
--- a/tensorflow/core/kernels/remote_fused_graph_execute_op.cc
+++ b/tensorflow/core/kernels/remote_fused_graph_execute_op.cc
@@ -29,9 +29,10 @@ class RemoteFusedGraphExecuteOp : public OpKernel {
explicit RemoteFusedGraphExecuteOp(OpKernelConstruction* const ctx)
: OpKernel(ctx), execute_info_() {
string serialized_proto;
- OP_REQUIRES_OK(ctx,
- ctx->GetAttr("serialized_remote_fused_graph_execute_info",
- &serialized_proto));
+ OP_REQUIRES_OK(
+ ctx, ctx->GetAttr(RemoteFusedGraphExecuteUtils::
+ ATTR_SERIALIZED_REMOTE_FUSED_GRAPH_EXECUTE_INFO,
+ &serialized_proto));
OP_REQUIRES_OK(ctx, ctx->GetAttr("Tinputs", &input_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("Toutputs", &output_types_));
execute_info_.ParseFromString(serialized_proto);
diff --git a/tensorflow/core/kernels/remote_fused_graph_execute_op_test.cc b/tensorflow/core/kernels/remote_fused_graph_execute_op_test.cc
index 925af1f79e..112168b195 100644
--- a/tensorflow/core/kernels/remote_fused_graph_execute_op_test.cc
+++ b/tensorflow/core/kernels/remote_fused_graph_execute_op_test.cc
@@ -269,13 +269,13 @@ static Status RewriteGraphToFusedGraph(const GraphDef& original_graph,
// 5. Fuse the original graph and run the inference the new fused graph
TEST(RemoteFusedExecuteGraphOp, EndToEndTest) {
// 5.1 Load original graph
- const GraphDef original_graph =
- RemoteFusedGraphExecuteOpTestUtils::BuildAddGraph(
- NAME_A, NODE_A_VAL, NAME_B, NODE_B_VAL, NAME_A_PLUS_B);
+ GraphDef original_graph;
+ TF_ASSERT_OK(RemoteFusedGraphExecuteOpTestUtils::BuildAddGraph(
+ NAME_A, NODE_A_VAL, NAME_B, NODE_B_VAL, NAME_A_PLUS_B, &original_graph));
// 5.2 Fuse graph
GraphDef fused_graph;
- TF_CHECK_OK(RewriteGraphToFusedGraph(original_graph, &fused_graph));
+ TF_ASSERT_OK(RewriteGraphToFusedGraph(original_graph, &fused_graph));
// 5.3 Setup session
std::vector<Tensor> output_tensors;
diff --git a/tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.cc b/tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.cc
index 6e7d4b73d2..31c48082dd 100644
--- a/tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.cc
+++ b/tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.cc
@@ -15,7 +15,10 @@ limitations under the License.
#include "tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.h"
+#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/const_op.h"
+#include "tensorflow/cc/ops/math_ops.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
@@ -36,17 +39,57 @@ namespace tensorflow {
return Output(ret, 0);
}
-/* static */ GraphDef RemoteFusedGraphExecuteOpTestUtils::BuildAddGraph(
+/* static */ Status RemoteFusedGraphExecuteOpTestUtils::BuildAddGraph(
const string& name0, const float val0, const string& name1,
- const float val1, const string& name_out) {
+ const float val1, const string& name_out, GraphDef* graph_def) {
Scope root = Scope::NewRootScope();
Output node0 = ops::Const(root.WithOpName(name0), val0);
Output node1 = ops::Const(root.WithOpName(name1), val1);
RemoteFusedGraphExecuteOpTestUtils::BuildAddOp(root.WithOpName(name_out),
node0, node1);
- GraphDef def;
- TF_CHECK_OK(root.ToGraphDef(&def));
- return def;
+ TF_RETURN_IF_ERROR(root.ToGraphDef(graph_def));
+ return Status::OK();
+}
+
+// Builds the eleven-node test graph made of five constants (A..E) and six Add
+// nodes (F..K) and serializes it into `graph_def`:
+//   F = A + B, G = D + E, H = F + C, I = C + G, J = H + I, K = J + G.
+// Returns the status of converting the scope to a GraphDef.
+/* static */ Status RemoteFusedGraphExecuteOpTestUtils::BuildMultipleAddGraph(
+    GraphDef* graph_def) {
+  Scope root = tensorflow::Scope::NewRootScope();
+
+  // Every leaf is a DT_FLOAT constant of shape {1, 1, 1, 1} whose contents are
+  // produced by test::FillIota<float> with 1.0f.
+  const auto make_leaf = [&root](const char* const leaf_name) -> Output {
+    Tensor leaf_data(DT_FLOAT, TensorShape({1, 1, 1, 1}));
+    test::FillIota<float>(&leaf_data, 1.0f);
+    return ops::Const(root.WithOpName(leaf_name),
+                      Input::Initializer(leaf_data));
+  };
+
+  // Leaves are created in the same order as the node names: A, B, C, D, E.
+  Output a_const = make_leaf("A");
+  Output b_const = make_leaf("B");
+  Output c_const = make_leaf("C");
+  Output d_const = make_leaf("D");
+  Output e_const = make_leaf("E");
+
+  // Interior Add nodes, again created in name order F..K.
+  Output f_add = ops::Add(root.WithOpName("F"), a_const, b_const);
+  Output g_add = ops::Add(root.WithOpName("G"), d_const, e_const);
+  Output h_add = ops::Add(root.WithOpName("H"), f_add, c_const);
+  Output i_add = ops::Add(root.WithOpName("I"), c_const, g_add);
+  Output j_add = ops::Add(root.WithOpName("J"), h_add, i_add);
+  Output k_add = ops::Add(root.WithOpName("K"), j_add, g_add);
+
+  TF_RETURN_IF_ERROR(root.ToGraphDef(graph_def));
+  return Status::OK();
}
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.h b/tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.h
index 70d758ea6a..a0df50162b 100644
--- a/tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.h
+++ b/tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.h
@@ -28,9 +28,31 @@ namespace tensorflow {
class RemoteFusedGraphExecuteOpTestUtils {
public:
static Output BuildAddOp(const Scope& scope, const Input& x, const Input& y);
- static GraphDef BuildAddGraph(const string& name0, const float val0,
- const string& name1, const float val1,
- const string& name_out);
+ static Status BuildAddGraph(const string& name0, const float val0,
+ const string& name1, const float val1,
+ const string& name_out, GraphDef* graph_def);
+
+ // BuildMultipleAddGraph builds the following graph
+ //
+ // A B C D E
+ // | | | | |
+ // +----+----+ | +----+----+
+ // | | |
+ // F / \ G
+ // | | | / \
+ // +-----+------+ +-----+----+ +
+ // | | |
+ // H I |
+ // | | |
+ // +-------+--------+ |
+ // | |
+ // J |
+ // | |
+ // +--------+--------+
+ // |
+ // K
+ //
+ static Status BuildMultipleAddGraph(GraphDef* graph_def);
private:
RemoteFusedGraphExecuteOpTestUtils() = delete;
diff --git a/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc b/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc
index 2174098bde..d0ffcb1064 100644
--- a/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc
+++ b/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc
@@ -104,6 +104,22 @@ string DumpCluster(const RemoteFusedGraphExecuteUtils::ClusterInfo& cluster) {
RemoteFusedGraphExecuteUtils::ATTR_OUTPUT_DATA_TYPES;
/* static */ constexpr const char* const
RemoteFusedGraphExecuteUtils::ATTR_OUTPUT_SHAPES;
+/* static */ constexpr const char* const RemoteFusedGraphExecuteUtils::
+ ATTR_SERIALIZED_REMOTE_FUSED_GRAPH_EXECUTE_INFO;
+/* static */ constexpr const char* const RemoteFusedGraphExecuteUtils::
+ TRANSFORM_ARG_REMOTE_FUSED_GRAPH_EXECUTOR_NAME;
+/* static */ constexpr const char* const
+ RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_REMOTE_FUSED_GRAPH_NODE_NAME;
+/* static */ constexpr const char* const
+ RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSED_NODES;
+/* static */ constexpr const char* const
+ RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_BORDER_INPUTS;
+/* static */ constexpr const char* const
+ RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_BORDER_OUTPUTS;
+/* static */ constexpr const char* const
+ RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_INPUT_TYPES;
+/* static */ constexpr const char* const
+ RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_INPUT_SHAPES;
RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar::ExecutorBuildRegistrar(
const string& name, ExecutorBuildFunc executor_build_func) {
@@ -797,6 +813,22 @@ RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode(
TF_RETURN_IF_ERROR(ReplaceInputNodeByPlaceHolder(subgraph_input_name, dt,
shape, subgraph_def));
}
+
+ // sort subgraph_def to align order in graph_def
+ std::unordered_map<string, int> name_to_id_map;
+ for (int i = 0; i < graph_def.node_size(); ++i) {
+ name_to_id_map.emplace(graph_def.node(i).name(), i);
+ }
+ std::sort(subgraph_def->mutable_node()->begin(),
+ subgraph_def->mutable_node()->end(),
+ [&name_to_id_map](const NodeDef& node0, const NodeDef& node1) {
+ CHECK(name_to_id_map.count(node0.name()) > 0);
+ CHECK(name_to_id_map.count(node1.name()) > 0);
+ const int id0 = name_to_id_map.at(node0.name());
+ const int id1 = name_to_id_map.at(node1.name());
+ return id0 < id1;
+ });
+
VLOG(1) << DumpGraphDef(*subgraph_def);
return Status::OK();
}
diff --git a/tensorflow/core/kernels/remote_fused_graph_execute_utils.h b/tensorflow/core/kernels/remote_fused_graph_execute_utils.h
index 97b0c2008a..3a792824c5 100644
--- a/tensorflow/core/kernels/remote_fused_graph_execute_utils.h
+++ b/tensorflow/core/kernels/remote_fused_graph_execute_utils.h
@@ -39,6 +39,25 @@ class RemoteFusedGraphExecuteUtils {
// TODO(satok): Use "_output_shapes" to share a spec with other ops
static constexpr const char* const ATTR_OUTPUT_SHAPES =
"_default_remote_output_shapes";
+ static constexpr const char* const
+ ATTR_SERIALIZED_REMOTE_FUSED_GRAPH_EXECUTE_INFO =
+ "serialized_remote_fused_graph_execute_info";
+
+ // Argument key strings to fuse a subgraph into RemoteFusedGraphExecuteOp.
+ static constexpr const char* const
+ TRANSFORM_ARG_REMOTE_FUSED_GRAPH_EXECUTOR_NAME =
+ "remote_fused_graph_executor_name";
+ static constexpr const char* const
+ TRANSFORM_ARG_REMOTE_FUSED_GRAPH_NODE_NAME =
+ "remote_fused_graph_node_name";
+ static constexpr const char* const TRANSFORM_ARG_FUSED_NODES = "fused_nodes";
+ static constexpr const char* const TRANSFORM_ARG_BORDER_INPUTS =
+ "border_inputs";
+ static constexpr const char* const TRANSFORM_ARG_BORDER_OUTPUTS =
+ "border_outputs";
+ static constexpr const char* const TRANSFORM_ARG_INPUT_TYPES = "input_types";
+ static constexpr const char* const TRANSFORM_ARG_INPUT_SHAPES =
+ "input_shapes";
using ExecutorBuildFunc = std::function<Status(
std::unique_ptr<IRemoteFusedGraphExecutor>* executor)>;
diff --git a/tensorflow/core/kernels/remote_fused_graph_execute_utils_test.cc b/tensorflow/core/kernels/remote_fused_graph_execute_utils_test.cc
index 8bd63d996a..581b61a625 100644
--- a/tensorflow/core/kernels/remote_fused_graph_execute_utils_test.cc
+++ b/tensorflow/core/kernels/remote_fused_graph_execute_utils_test.cc
@@ -15,11 +15,7 @@ limitations under the License.
#include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h"
#include "tensorflow/cc/framework/scope.h"
-#include "tensorflow/cc/ops/array_ops.h"
-#include "tensorflow/cc/ops/const_op.h"
-#include "tensorflow/cc/ops/math_ops.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
-#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -47,69 +43,12 @@ static NodeDef* GetNodeDef(const string& name, GraphDef* def) {
return nullptr;
}
-// This function builds the following graph
-//
-// A B C D E
-// | | | | |
-// +----+----+ | +----+----+
-// | | |
-// F / \ G
-// | | | / \
-// +-----+------+ +-----+----+ +
-// | | |
-// H I |
-// | | |
-// +-------+--------+ |
-// | |
-// J |
-// | |
-// +--------+--------+
-// |
-// K
-//
-Status BuildMultipleAddGraph(GraphDef* graph_def) {
- Scope root = tensorflow::Scope::NewRootScope();
-
- Tensor a_data(DT_FLOAT, TensorShape({1, 1, 1, 1}));
- test::FillIota<float>(&a_data, 1.0f);
- Output a_const = ops::Const(root.WithOpName("A"), Input::Initializer(a_data));
-
- Tensor b_data(DT_FLOAT, TensorShape({1, 1, 1, 1}));
- test::FillIota<float>(&b_data, 1.0f);
- Output b_const = ops::Const(root.WithOpName("B"), Input::Initializer(b_data));
-
- Tensor c_data(DT_FLOAT, TensorShape({1, 1, 1, 1}));
- test::FillIota<float>(&c_data, 1.0f);
- Output c_const = ops::Const(root.WithOpName("C"), Input::Initializer(c_data));
-
- Tensor d_data(DT_FLOAT, TensorShape({1, 1, 1, 1}));
- test::FillIota<float>(&d_data, 1.0f);
- Output d_const = ops::Const(root.WithOpName("D"), Input::Initializer(d_data));
-
- Tensor e_data(DT_FLOAT, TensorShape({1, 1, 1, 1}));
- test::FillIota<float>(&e_data, 1.0f);
- Output e_const = ops::Const(root.WithOpName("E"), Input::Initializer(e_data));
-
- Output f_add = ops::Add(root.WithOpName("F"), a_const, b_const);
-
- Output g_add = ops::Add(root.WithOpName("G"), d_const, e_const);
-
- Output h_add = ops::Add(root.WithOpName("H"), f_add, c_const);
-
- Output i_add = ops::Add(root.WithOpName("I"), c_const, g_add);
-
- Output j_add = ops::Add(root.WithOpName("J"), h_add, i_add);
-
- Output k_add = ops::Add(root.WithOpName("K"), j_add, g_add);
-
- TF_RETURN_IF_ERROR(root.ToGraphDef(graph_def));
-
- return Status::OK();
-}
-
class FuseRemoteGraphMultipleAddOpsTest : public ::testing::Test {
protected:
- void SetUp() final { TF_ASSERT_OK(BuildMultipleAddGraph(&graph_def_)); }
+ void SetUp() final {
+ TF_ASSERT_OK(
+ RemoteFusedGraphExecuteOpTestUtils::BuildMultipleAddGraph(&graph_def_));
+ }
void TearDown() final {}
@@ -196,8 +135,9 @@ static void ClearCluster(ClusterInfo* cluster) {
}
TEST(RemoteFusedGraphExecuteUtils, DryRunAddGraphA) {
- GraphDef def = RemoteFusedGraphExecuteOpTestUtils::BuildAddGraph(
- NAME_A, NODE_A_VAL, NAME_B, NODE_B_VAL, NAME_A_PLUS_B);
+ GraphDef def;
+ TF_ASSERT_OK(RemoteFusedGraphExecuteOpTestUtils::BuildAddGraph(
+ NAME_A, NODE_A_VAL, NAME_B, NODE_B_VAL, NAME_A_PLUS_B, &def));
std::pair<string, Tensor> input_node_info;
input_node_info.first = NAME_A;
input_node_info.second = Tensor(DT_FLOAT, {});
@@ -216,8 +156,9 @@ TEST(RemoteFusedGraphExecuteUtils, DryRunAddGraphA) {
}
TEST(RemoteFusedGraphExecuteUtils, DryRunAddGraphAUninitialized) {
- GraphDef def = RemoteFusedGraphExecuteOpTestUtils::BuildAddGraph(
- NAME_A, NODE_A_VAL, NAME_B, NODE_B_VAL, NAME_A_PLUS_B);
+ GraphDef def;
+ TF_ASSERT_OK(RemoteFusedGraphExecuteOpTestUtils::BuildAddGraph(
+ NAME_A, NODE_A_VAL, NAME_B, NODE_B_VAL, NAME_A_PLUS_B, &def));
std::pair<string, Tensor> input_node_info;
input_node_info.first = NAME_A;
input_node_info.second = Tensor(DT_FLOAT, {});
@@ -235,8 +176,9 @@ TEST(RemoteFusedGraphExecuteUtils, DryRunAddGraphAUninitialized) {
}
TEST(RemoteFusedGraphExecuteUtils, DryRunAddGraphAB) {
- GraphDef def = RemoteFusedGraphExecuteOpTestUtils::BuildAddGraph(
- NAME_A, NODE_A_VAL, NAME_B, NODE_B_VAL, NAME_A_PLUS_B);
+ GraphDef def;
+ TF_ASSERT_OK(RemoteFusedGraphExecuteOpTestUtils::BuildAddGraph(
+ NAME_A, NODE_A_VAL, NAME_B, NODE_B_VAL, NAME_A_PLUS_B, &def));
std::pair<string, Tensor> input_node_info_a;
input_node_info_a.first = NAME_A;
input_node_info_a.second = Tensor(DT_FLOAT, {});
@@ -268,8 +210,9 @@ TEST(RemoteFusedGraphExecuteUtils, DryRunAddGraphForAllNodes) {
const std::vector<std::pair<string, Tensor>> inputs{input_node_info_a};
RemoteFusedGraphExecuteUtils::TensorShapeMap tensor_shape_map;
- GraphDef def = RemoteFusedGraphExecuteOpTestUtils::BuildAddGraph(
- NAME_A, NODE_A_VAL, NAME_B, NODE_B_VAL, NAME_A_PLUS_B);
+ GraphDef def;
+ TF_ASSERT_OK(RemoteFusedGraphExecuteOpTestUtils::BuildAddGraph(
+ NAME_A, NODE_A_VAL, NAME_B, NODE_B_VAL, NAME_A_PLUS_B, &def));
// dryrun
const Status status = RemoteFusedGraphExecuteUtils::DryRunInferenceForAllNode(
@@ -310,8 +253,9 @@ TEST(RemoteFusedGraphExecuteUtils, PropagateAndBuildTensorShapeMap) {
input_node_info_b};
RemoteFusedGraphExecuteUtils::TensorShapeMap tensor_shape_map;
- GraphDef def = RemoteFusedGraphExecuteOpTestUtils::BuildAddGraph(
- NAME_A, NODE_A_VAL, NAME_B, NODE_B_VAL, NAME_A_PLUS_B);
+ GraphDef def;
+ TF_ASSERT_OK(RemoteFusedGraphExecuteOpTestUtils::BuildAddGraph(
+ NAME_A, NODE_A_VAL, NAME_B, NODE_B_VAL, NAME_A_PLUS_B, &def));
ImportGraphDefOptions opts;
Graph graph(OpRegistry::Global());
ShapeRefiner shape_refiner(graph.versions().producer(), graph.op_registry());
@@ -398,8 +342,9 @@ TEST(RemoteFusedGraphExecuteUtils,
// Build outputs
const std::vector<string> outputs = {NAME_A_PLUS_B};
- GraphDef def = RemoteFusedGraphExecuteOpTestUtils::BuildAddGraph(
- NAME_A, NODE_A_VAL, NAME_B, NODE_B_VAL, NAME_A_PLUS_B);
+ GraphDef def;
+ TF_ASSERT_OK(RemoteFusedGraphExecuteOpTestUtils::BuildAddGraph(
+ NAME_A, NODE_A_VAL, NAME_B, NODE_B_VAL, NAME_A_PLUS_B, &def));
TF_ASSERT_OK(RemoteFusedGraphExecuteUtils::BuildAndAddTensorShapes(
input_tensors, /*dry_run_inference*/ true, &def));
@@ -427,8 +372,9 @@ TEST(RemoteFusedGraphExecuteUtils, BuildRemoteFusedGraphExecuteOpNode) {
// Build outputs
const std::vector<string> outputs = {NAME_A_PLUS_B};
- const GraphDef def = RemoteFusedGraphExecuteOpTestUtils::BuildAddGraph(
- NAME_A, NODE_A_VAL, NAME_B, NODE_B_VAL, NAME_A_PLUS_B);
+ GraphDef def;
+ TF_ASSERT_OK(RemoteFusedGraphExecuteOpTestUtils::BuildAddGraph(
+ NAME_A, NODE_A_VAL, NAME_B, NODE_B_VAL, NAME_A_PLUS_B, &def));
Graph graph(OpRegistry::Global());
ShapeRefiner shape_refiner(graph.versions().producer(), graph.op_registry());
@@ -442,7 +388,8 @@ TEST(RemoteFusedGraphExecuteUtils, BuildRemoteFusedGraphExecuteOpNode) {
TEST(RemoteFusedGraphExecuteUtils, ExtractSubgraphNodes) {
GraphDef graph_def;
- TF_ASSERT_OK(BuildMultipleAddGraph(&graph_def));
+ TF_ASSERT_OK(
+ RemoteFusedGraphExecuteOpTestUtils::BuildMultipleAddGraph(&graph_def));
ClusterInfo cluster;
const std::unordered_set<string>& node_names = std::get<0>(cluster);
TF_ASSERT_OK(RemoteFusedGraphExecuteUtils::BuildClusterByBorder(
@@ -472,7 +419,8 @@ TEST(RemoteFusedGraphExecuteUtils, ExtractSubgraphNodes) {
TEST(RemoteFusedGraphExecuteUtils, ClusterizeNodes) {
GraphDef graph_def;
- TF_ASSERT_OK(BuildMultipleAddGraph(&graph_def));
+ TF_ASSERT_OK(
+ RemoteFusedGraphExecuteOpTestUtils::BuildMultipleAddGraph(&graph_def));
std::vector<ClusterInfo> ci_vec;
TF_ASSERT_OK(
@@ -508,7 +456,8 @@ TEST(RemoteFusedGraphExecuteUtils, ClusterizeNodes) {
TEST(RemoteFusedGraphExecuteUtils, BuildSubgraphDefByInOut) {
GraphDef graph_def;
- TF_ASSERT_OK(BuildMultipleAddGraph(&graph_def));
+ TF_ASSERT_OK(
+ RemoteFusedGraphExecuteOpTestUtils::BuildMultipleAddGraph(&graph_def));
ClusterInfo cluster;
GraphDef subgraph_def;
@@ -556,7 +505,7 @@ TEST(RemoteFusedGraphExecuteUtils, BuildSubgraphDefByInOut) {
EXPECT_EQ(3, subgraph_def.node_size());
}
-TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByInOut_hi_j) {
+TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByInOut_HI_J) {
SetSubgraphArguments(std::vector<string>{"H", "I"}, std::vector<string>{"J"},
this);
@@ -569,7 +518,7 @@ TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByInOut_hi_j) {
<< SummarizeGraphDef(result_graph_def_);
}
-TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByInOut_fcg_j) {
+TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByInOut_FCG_J) {
SetSubgraphArguments(std::vector<string>{"F", "C", "G"},
std::vector<string>{"J"}, this);
@@ -582,7 +531,7 @@ TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByInOut_fcg_j) {
<< SummarizeGraphDef(result_graph_def_);
}
-TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByInOut_abcde_j) {
+TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByInOut_ABCDE_J) {
SetSubgraphArguments(std::vector<string>{"A", "B", "C", "D", "E"},
std::vector<string>{"J"}, this);
@@ -595,7 +544,7 @@ TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByInOut_abcde_j) {
<< SummarizeGraphDef(result_graph_def_);
}
-TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByInOut_abcde_k) {
+TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByInOut_ABCDE_K) {
SetSubgraphArguments(std::vector<string>{"A", "B", "C", "D", "E"},
std::vector<string>{"K"}, this);
@@ -608,7 +557,7 @@ TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByInOut_abcde_k) {
<< SummarizeGraphDef(result_graph_def_);
}
-TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByNodes_h) {
+TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByNodes_H) {
subgraph_node_names_ = {"H"};
TF_ASSERT_OK(FuseByNodes());
@@ -620,7 +569,7 @@ TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByNodes_h) {
<< SummarizeGraphDef(result_graph_def_);
}
-TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByNodes_hij) {
+TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByNodes_HIJ) {
subgraph_node_names_ = {"H", "I", "J"};
TF_ASSERT_OK(FuseByNodes());
@@ -632,7 +581,7 @@ TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByNodes_hij) {
<< SummarizeGraphDef(result_graph_def_);
}
-TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByNodes_cfghij) {
+TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByNodes_CFGHIJ) {
subgraph_node_names_ = {"C", "F", "G", "H", "I", "J"};
TF_ASSERT_OK(FuseByNodes());
@@ -644,7 +593,7 @@ TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByNodes_cfghij) {
<< SummarizeGraphDef(result_graph_def_);
}
-TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByNodes_abcdefghij) {
+TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByNodes_ABCDEFGHIJ) {
subgraph_node_names_ = {"A", "B", "C", "D", "E", "F", "G", "H", "I", "J"};
TF_ASSERT_OK(FuseByNodes());
@@ -656,7 +605,7 @@ TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByNodes_abcdefghij) {
<< SummarizeGraphDef(result_graph_def_);
}
-TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByNodes_abcdefghijk) {
+TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByNodes_ABCDEFGHIJK) {
subgraph_node_names_ = {"A", "B", "C", "D", "E", "F",
"G", "H", "I", "J", "K"};
diff --git a/tensorflow/core/kernels/remote_fused_graph_rewriter_transform.cc b/tensorflow/core/kernels/remote_fused_graph_rewriter_transform.cc
new file mode 100644
index 0000000000..8742214e17
--- /dev/null
+++ b/tensorflow/core/kernels/remote_fused_graph_rewriter_transform.cc
@@ -0,0 +1,163 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Wraps the remote fused graph rewriter in a transform so it can be used as
+// part of the graph transform tool.
+// A usage example, based on inception v3 model:
+/*
+bazel build tensorflow/tools/graph_transforms:transform_graph
+
+
+// Specify remote graph by node names
+bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
+--in_graph=/tmp/tensorflow_inception_v3_stripped_optimized_quantized.pb \
+--out_graph=\
+/tmp/tensorflow_inception_v3_stripped_optimized_quantized_fused_hexagon.pb \
+--inputs='Mul' \
+--outputs='softmax' \
+--transforms='\
+fuse_remote_graph(
+input_types="float" \
+input_shapes="1,299,299,3" \
+fused_nodes="NodeA,NodeB,NodeC",
+remote_fused_graph_executor_name="executor" \
+remote_fused_graph_node_name="node_name" \
+)'
+
+// Specify remote graph by border inputs and outputs
+bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
+--in_graph=/tmp/tensorflow_inception_v3_stripped_optimized_quantized.pb \
+--out_graph=\
+/tmp/tensorflow_inception_v3_stripped_optimized_quantized_fused_hexagon.pb \
+--inputs='Mul' \
+--outputs='softmax' \
+--transforms='\
+fuse_remote_graph(
+input_types="float" \
+input_shapes="1,299,299,3" \
+border_inputs="NodeA:0,NodeB:0" \
+border_outputs="NodeC" \
+remote_fused_graph_executor_name="executor" \
+remote_fused_graph_node_name="node_name" \
+)'
+*/
+
+#include <unordered_set>
+
+#include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h"
+#include "tensorflow/tools/graph_transforms/transform_utils.h"
+
+namespace tensorflow {
+namespace graph_transforms {
+// Builds one tensor per graph input from the user-supplied type and shape
+// strings, for use when annotating the graph with shapes.
+//
+// `input_types_str` is a comma-separated list of data type names and
+// `input_shapes_str` is a colon-separated list of shapes, each shape being a
+// comma-separated list of integer dimensions. Both lists must contain exactly
+// one entry per element of `inputs`, in the same order.
+//
+// Returns InvalidArgument (instead of aborting the process) when the lists do
+// not line up with `inputs` or when an entry cannot be parsed.
+static Status BuildInputTensorsFromArgs(
+    const std::vector<string>& inputs, const string& input_types_str,
+    const string& input_shapes_str,
+    std::vector<std::pair<string, Tensor>>* input_tensors) {
+  const std::vector<string> type_names = str_util::Split(input_types_str, ",");
+  const std::vector<string> shape_specs =
+      str_util::Split(input_shapes_str, ":");
+  if (type_names.size() != inputs.size()) {
+    return errors::InvalidArgument(
+        "Expected ", inputs.size(), " input type(s) but got ",
+        type_names.size(), ": \"", input_types_str, "\"");
+  }
+  if (shape_specs.size() != inputs.size()) {
+    return errors::InvalidArgument(
+        "Expected ", inputs.size(), " input shape(s) but got ",
+        shape_specs.size(), ": \"", input_shapes_str, "\"");
+  }
+
+  input_tensors->reserve(input_tensors->size() + inputs.size());
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    std::vector<int64> dims;
+    if (!str_util::SplitAndParseAsInts(shape_specs[i], ',', &dims)) {
+      return errors::InvalidArgument("\"", shape_specs[i],
+                                     "\" was an invalid shape");
+    }
+    DataType data_type;
+    if (!DataTypeFromString(type_names[i], &data_type)) {
+      return errors::InvalidArgument("\"", type_names[i],
+                                     "\" was an invalid type");
+    }
+    input_tensors->emplace_back(inputs[i],
+                                Tensor(data_type, TensorShape(dims)));
+  }
+  return Status::OK();
+}
+
+// Graph transform that fuses part of `input_graph_def` into a single remote
+// fused graph node and writes the result to `output_graph_def`.
+//
+// The sub-graph to fuse is selected through transform parameters read from
+// `context`:
+//   * the fused-nodes parameter: comma-separated node names, or
+//   * the border-inputs and border-outputs parameters (both required):
+//     comma-separated names delimiting the sub-graph.
+// When the fused-nodes parameter is set it takes precedence over the border
+// parameters. The executor-name parameter is mandatory.
+//
+// If the input-types parameter is non-empty, the input-types and input-shapes
+// parameters are used to add tensor shape information to a copy of the input
+// graph before fusing, and the fusing step is told to require shape/type
+// information.
+//
+// Returns InvalidArgument for malformed or missing parameters, or the first
+// error reported by the underlying RemoteFusedGraphExecuteUtils calls.
+Status FuseRemoteGraph(const GraphDef& input_graph_def,
+                       const TransformFuncContext& context,
+                       GraphDef* output_graph_def) {
+  // Work on a copy: shape annotation below mutates the graph.
+  GraphDef mutable_input_graph_def = input_graph_def;
+
+  const std::vector<string>& inputs = context.input_names;
+  const std::vector<string>& outputs = context.output_names;
+
+  string input_types_str;
+  string input_shapes_str;
+  TF_RETURN_IF_ERROR(context.GetOneStringParameter(
+      RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_INPUT_TYPES, "",
+      &input_types_str));
+  TF_RETURN_IF_ERROR(context.GetOneStringParameter(
+      RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_INPUT_SHAPES, "",
+      &input_shapes_str));
+
+  // Shape/type information is optional; it is only processed when the caller
+  // supplied input types.
+  const bool require_shape_type = !input_types_str.empty();
+  if (require_shape_type) {
+    std::vector<std::pair<string, Tensor>> input_tensors;
+    TF_RETURN_IF_ERROR(BuildInputTensorsFromArgs(
+        inputs, input_types_str, input_shapes_str, &input_tensors));
+    TF_RETURN_IF_ERROR(RemoteFusedGraphExecuteUtils::BuildAndAddTensorShapes(
+        input_tensors, /*dry_run_inference=*/true, &mutable_input_graph_def));
+  }
+
+  string fused_nodes_str;
+  string border_inputs_str;
+  string border_outputs_str;
+  string remote_graph_executor_name;
+  string remote_fused_graph_node_name;
+
+  TF_RETURN_IF_ERROR(context.GetOneStringParameter(
+      RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSED_NODES, "",
+      &fused_nodes_str));
+  TF_RETURN_IF_ERROR(context.GetOneStringParameter(
+      RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_BORDER_INPUTS, "",
+      &border_inputs_str));
+  TF_RETURN_IF_ERROR(context.GetOneStringParameter(
+      RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_BORDER_OUTPUTS, "",
+      &border_outputs_str));
+  TF_RETURN_IF_ERROR(context.GetOneStringParameter(
+      RemoteFusedGraphExecuteUtils::
+          TRANSFORM_ARG_REMOTE_FUSED_GRAPH_EXECUTOR_NAME,
+      "", &remote_graph_executor_name));
+  TF_RETURN_IF_ERROR(context.GetOneStringParameter(
+      RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_REMOTE_FUSED_GRAPH_NODE_NAME,
+      "", &remote_fused_graph_node_name));
+
+  if (remote_graph_executor_name.empty()) {
+    return errors::InvalidArgument(
+        "Parameter \"",
+        RemoteFusedGraphExecuteUtils::
+            TRANSFORM_ARG_REMOTE_FUSED_GRAPH_EXECUTOR_NAME,
+        "\" must be specified.");
+  }
+
+  if (!fused_nodes_str.empty()) {
+    // Fuse an explicit set of nodes.
+    const std::vector<string> fused_node_name_vector =
+        str_util::Split(fused_nodes_str, ",");
+    const std::unordered_set<string> fused_node_names(
+        fused_node_name_vector.begin(), fused_node_name_vector.end());
+    TF_RETURN_IF_ERROR(RemoteFusedGraphExecuteUtils::FuseRemoteGraphByNodeNames(
+        mutable_input_graph_def, inputs, outputs, remote_fused_graph_node_name,
+        fused_node_names, remote_graph_executor_name, require_shape_type,
+        output_graph_def));
+  } else if (!border_inputs_str.empty() && !border_outputs_str.empty()) {
+    // Fuse the sub-graph delimited by the given border inputs and outputs.
+    const std::vector<string> border_inputs =
+        str_util::Split(border_inputs_str, ",");
+    const std::vector<string> border_outputs =
+        str_util::Split(border_outputs_str, ",");
+    for (size_t i = 0; i < border_inputs.size(); ++i) {
+      VLOG(2) << "Border Input(" << i << "): " << border_inputs[i];
+    }
+    for (size_t i = 0; i < border_outputs.size(); ++i) {
+      VLOG(2) << "Border Output(" << i << "): " << border_outputs[i];
+    }
+    TF_RETURN_IF_ERROR(RemoteFusedGraphExecuteUtils::FuseRemoteGraphByBorder(
+        mutable_input_graph_def, inputs, outputs, remote_fused_graph_node_name,
+        border_inputs, border_outputs, remote_graph_executor_name,
+        require_shape_type, output_graph_def));
+  } else {
+    return errors::InvalidArgument("Fuse targets are not specified.");
+  }
+
+  return Status::OK();
+}
+
+// Registers FuseRemoteGraph with the graph transform tool under the name
+// "fuse_remote_graph", the name used in the --transforms examples at the top
+// of this file.
+REGISTER_GRAPH_TRANSFORM("fuse_remote_graph", FuseRemoteGraph);
+
+} // namespace graph_transforms
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/remote_fused_graph_rewriter_transform_test.cc b/tensorflow/core/kernels/remote_fused_graph_rewriter_transform_test.cc
new file mode 100644
index 0000000000..9e061437a9
--- /dev/null
+++ b/tensorflow/core/kernels/remote_fused_graph_rewriter_transform_test.cc
@@ -0,0 +1,206 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/cc/ops/const_op.h"
+#include "tensorflow/cc/ops/image_ops.h"
+#include "tensorflow/cc/ops/nn_ops.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/graph/default_device.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/graph/testlib.h"
+#include "tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.h"
+#include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/public/session.h"
+#include "tensorflow/tools/graph_transforms/transform_utils.h"
+
+namespace tensorflow {
+namespace graph_transforms {
+
+// Forward declaration of the transform under test. It is declared here, rather
+// than included, so that it does not have to be put in a public header.
+Status FuseRemoteGraph(const GraphDef& input_graph_def,
+ const TransformFuncContext& context,
+ GraphDef* output_graph_def);
+
+namespace {
+
+// Executor name passed to the transform by the fixture's Fuse(); CheckGraph()
+// expects to read it back from every fused node.
+constexpr const char* const REMOTE_FUSED_GRAPH_EXECUTOR_NAME =
+ "remote_fused_graph_executor_name";
+// Node name passed to the transform by the fixture's Fuse(); CheckGraph()
+// counts output nodes whose name starts with this prefix.
+constexpr const char* const REMOTE_FUSED_GRAPH_NODE_NAME =
+ "remote_fused_graph_node_name";
+
+// Test fixture that runs the FuseRemoteGraph transform on the graph produced by
+// RemoteFusedGraphExecuteOpTestUtils::BuildMultipleAddGraph.
+//
+// A test configures the string members below (fused node names or border
+// inputs/outputs, and optionally input types/shapes), calls Fuse(), and then
+// inspects `output_graph_def_` through CheckGraph().
+class FuseRemoteGraphMultipleAddOpsRewriterTest : public ::testing::Test {
+ protected:
+  void SetUp() final {
+    TF_ASSERT_OK(RemoteFusedGraphExecuteOpTestUtils::BuildMultipleAddGraph(
+        &input_graph_def_));
+  }
+
+  void TearDown() final {}
+
+  // Adds the single-valued transform parameter `key` = `value` to `context`.
+  static void AddParam(const string& key, const string& value,
+                       TransformFuncContext* context) {
+    context->params.insert(std::make_pair(key, std::vector<string>{value}));
+  }
+
+  // Same as AddParam(), but leaves the parameter unset when `value` is empty.
+  static void AddParamIfNotEmpty(const string& key, const string& value,
+                                 TransformFuncContext* context) {
+    if (!value.empty()) {
+      AddParam(key, value, context);
+    }
+  }
+
+  // Builds a transform context from the fixture members and runs
+  // FuseRemoteGraph on `input_graph_def_`, storing the result in
+  // `output_graph_def_`. Returns the transform's status.
+  Status Fuse() {
+    TransformFuncContext context;
+    context.input_names = inputs_;
+    context.output_names = outputs_;
+
+    // Optional parameters are only passed when the test configured them.
+    AddParamIfNotEmpty(RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_INPUT_TYPES,
+                       input_types_, &context);
+    AddParamIfNotEmpty(RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_INPUT_SHAPES,
+                       input_shapes_, &context);
+    AddParamIfNotEmpty(RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSED_NODES,
+                       fused_node_names_str_, &context);
+    AddParamIfNotEmpty(
+        RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_BORDER_INPUTS,
+        border_inputs_str_, &context);
+    AddParamIfNotEmpty(
+        RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_BORDER_OUTPUTS,
+        border_outputs_str_, &context);
+
+    // The executor name and fused node name are always passed.
+    AddParam(RemoteFusedGraphExecuteUtils::
+                 TRANSFORM_ARG_REMOTE_FUSED_GRAPH_EXECUTOR_NAME,
+             REMOTE_FUSED_GRAPH_EXECUTOR_NAME, &context);
+    AddParam(RemoteFusedGraphExecuteUtils::
+                 TRANSFORM_ARG_REMOTE_FUSED_GRAPH_NODE_NAME,
+             REMOTE_FUSED_GRAPH_NODE_NAME, &context);
+
+    return FuseRemoteGraph(input_graph_def_, context, &output_graph_def_);
+  }
+
+  // Configures the input type and shape parameters passed by Fuse().
+  void SetInputShapeType() {
+    input_types_ = "float";
+    input_shapes_ = "1,1,1,1";
+  }
+
+  // Verifies that the output graph has `expected_node_count` nodes in total
+  // and `expected_cluster_count` nodes whose name starts with
+  // REMOTE_FUSED_GRAPH_NODE_NAME, and that each such node carries a parseable
+  // RemoteFusedGraphExecuteInfo naming the expected executor.
+  void CheckGraph(int expected_node_count, int expected_cluster_count) {
+    EXPECT_EQ(expected_node_count, output_graph_def_.node_size());
+
+    int cluster_count = 0;
+    for (const NodeDef& node_def : output_graph_def_.node()) {
+      const string& name = node_def.name();
+      if (!StringPiece(name).starts_with(REMOTE_FUSED_GRAPH_NODE_NAME)) {
+        continue;
+      }
+      ++cluster_count;
+      string serialized_proto;
+      TF_ASSERT_OK(
+          GetNodeAttr(node_def,
+                      RemoteFusedGraphExecuteUtils::
+                          ATTR_SERIALIZED_REMOTE_FUSED_GRAPH_EXECUTE_INFO,
+                      &serialized_proto));
+      RemoteFusedGraphExecuteInfo info;
+      // Fail the test, rather than silently comparing against an empty
+      // message, when the attribute does not hold a valid proto.
+      ASSERT_TRUE(info.ParseFromString(serialized_proto))
+          << "Failed to parse execute info of node " << name;
+      // EXPECT_EQ reports a test failure; CHECK_EQ would abort the binary.
+      EXPECT_EQ(string(REMOTE_FUSED_GRAPH_EXECUTOR_NAME),
+                info.executor_name());
+    }
+    EXPECT_EQ(expected_cluster_count, cluster_count);
+  }
+
+ public:
+  // Graph input and output node names handed to the transform context.
+  const std::vector<string> inputs_{"A"};
+  const std::vector<string> outputs_{"K"};
+  // Graph built in SetUp() and fed to the transform.
+  GraphDef input_graph_def_;
+  // Optional input type/shape parameters; empty means "not passed".
+  string input_types_;
+  string input_shapes_;
+  // Result of the most recent Fuse() call.
+  GraphDef output_graph_def_;
+  // Fuse-target parameters; empty means "not passed".
+  string fused_node_names_str_;
+  string border_inputs_str_;
+  string border_outputs_str_;
+};
+
+// Fuses nodes H, I and J by name, with input type/shape parameters supplied.
+TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
+       FuseRemoteGraphByNodesWithShapeType_HIJ) {
+  const int kExpectedNodeCount = 9;
+  const int kExpectedClusterCount = 1;
+
+  fused_node_names_str_ = "H,I,J";
+  SetInputShapeType();
+
+  TF_ASSERT_OK(Fuse());
+  CheckGraph(kExpectedNodeCount, kExpectedClusterCount);
+}
+
+// Fuses nodes H, I and J by name, without input type/shape parameters.
+TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
+       FuseRemoteGraphByNodesWithoutShapeType_HIJ) {
+  const int kExpectedNodeCount = 9;
+  const int kExpectedClusterCount = 1;
+
+  fused_node_names_str_ = "H,I,J";
+
+  TF_ASSERT_OK(Fuse());
+  CheckGraph(kExpectedNodeCount, kExpectedClusterCount);
+}
+
+// Fuses nodes A through K by name, with input type/shape parameters supplied.
+TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
+       FuseRemoteGraphByNodesWithShapeType_ABCDEFGHIJK) {
+  const int kExpectedNodeCount = 3;
+  const int kExpectedClusterCount = 1;
+
+  fused_node_names_str_ = "A,B,C,D,E,F,G,H,I,J,K";
+  SetInputShapeType();
+
+  TF_ASSERT_OK(Fuse());
+  CheckGraph(kExpectedNodeCount, kExpectedClusterCount);
+}
+
+// Fuses nodes A through K by name, without input type/shape parameters.
+TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
+       FuseRemoteGraphByNodesWithoutShapeType_ABCDEFGHIJK) {
+  const int kExpectedNodeCount = 3;
+  const int kExpectedClusterCount = 1;
+
+  fused_node_names_str_ = "A,B,C,D,E,F,G,H,I,J,K";
+
+  TF_ASSERT_OK(Fuse());
+  CheckGraph(kExpectedNodeCount, kExpectedClusterCount);
+}
+
+// Fuses the sub-graph bordered by inputs F:0, C:0, G and output J:0, with
+// input type/shape parameters supplied.
+TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
+       FuseRemoteGraphByBorderWithShapeType_FCG_J) {
+  const int kExpectedNodeCount = 9;
+  const int kExpectedClusterCount = 1;
+
+  border_outputs_str_ = "J:0";
+  border_inputs_str_ = "F:0,C:0,G";
+  SetInputShapeType();
+
+  TF_ASSERT_OK(Fuse());
+  CheckGraph(kExpectedNodeCount, kExpectedClusterCount);
+}
+
+// Fuses the sub-graph bordered by inputs F:0, C:0, G and output J:0, without
+// input type/shape parameters.
+TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
+       FuseRemoteGraphByBorderWithoutShapeType_FCG_J) {
+  const int kExpectedNodeCount = 9;
+  const int kExpectedClusterCount = 1;
+
+  border_outputs_str_ = "J:0";
+  border_inputs_str_ = "F:0,C:0,G";
+
+  TF_ASSERT_OK(Fuse());
+  CheckGraph(kExpectedNodeCount, kExpectedClusterCount);
+}
+
+// Fuses the sub-graph bordered by inputs A, B, C, D, E and output K, with
+// input type/shape parameters supplied.
+TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
+       FuseRemoteGraphByBorderWithShapeType_ABCDE_K) {
+  const int kExpectedNodeCount = 7;
+  const int kExpectedClusterCount = 1;
+
+  border_outputs_str_ = "K";
+  border_inputs_str_ = "A,B,C,D,E";
+  SetInputShapeType();
+
+  TF_ASSERT_OK(Fuse());
+  CheckGraph(kExpectedNodeCount, kExpectedClusterCount);
+}
+
+// Fuses the sub-graph bordered by inputs A, B, C, D, E and output K, without
+// input type/shape parameters.
+TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
+       FuseRemoteGraphByBorderWithoutShapeType_ABCDE_K) {
+  const int kExpectedNodeCount = 7;
+  const int kExpectedClusterCount = 1;
+
+  border_outputs_str_ = "K";
+  border_inputs_str_ = "A,B,C,D,E";
+
+  TF_ASSERT_OK(Fuse());
+  CheckGraph(kExpectedNodeCount, kExpectedClusterCount);
+}
+
+} // namespace
+} // namespace graph_transforms
+} // namespace tensorflow
diff --git a/tensorflow/tools/graph_transforms/BUILD b/tensorflow/tools/graph_transforms/BUILD
index 89068c6d01..20f958f640 100644
--- a/tensorflow/tools/graph_transforms/BUILD
+++ b/tensorflow/tools/graph_transforms/BUILD
@@ -97,6 +97,7 @@ cc_library(
"//tensorflow/core:tensorflow",
] + if_not_windows([
"//tensorflow/core/kernels:quantized_ops",
+ "//tensorflow/core/kernels:remote_fused_graph_rewriter_transform",
"//tensorflow/core/kernels/hexagon:hexagon_rewriter_transform",
]),
alwayslink = 1,