diff options
author | 2017-05-18 10:38:29 -0700 | |
---|---|---|
committer | 2017-05-18 10:43:12 -0700 | |
commit | 9ad851e54d014532dd3b3c8308396769f9a7aeee (patch) | |
tree | ff777c97939a6f4d9de6ebd7226b19e9ffb2ff2f /tensorflow | |
parent | 7916e22e954fc893e673f74b4088b9e9c3a9be97 (diff) |
Add graph transform rewriter for remote fused graph
PiperOrigin-RevId: 156448934
Diffstat (limited to 'tensorflow')
16 files changed, 641 insertions, 111 deletions
diff --git a/tensorflow/contrib/cmake/tf_core_kernels.cmake b/tensorflow/contrib/cmake/tf_core_kernels.cmake index e93680fb4d..c95cd068cd 100644 --- a/tensorflow/contrib/cmake/tf_core_kernels.cmake +++ b/tensorflow/contrib/cmake/tf_core_kernels.cmake @@ -101,6 +101,7 @@ file(GLOB_RECURSE tf_core_kernels_exclude_srcs "${tensorflow_source_dir}/tensorflow/core/kernels/*.cu.cc" "${tensorflow_source_dir}/tensorflow/core/kernels/hexagon/*" "${tensorflow_source_dir}/tensorflow/core/kernels/remote_fused_graph_execute*.cc" + "${tensorflow_source_dir}/tensorflow/core/kernels/remote_fused_graph_rewriter_transform*.cc" ) list(REMOVE_ITEM tf_core_kernels_srcs ${tf_core_kernels_exclude_srcs}) diff --git a/tensorflow/contrib/cmake/tf_tests.cmake b/tensorflow/contrib/cmake/tf_tests.cmake index 2568ca3bad..6392a5e044 100644 --- a/tensorflow/contrib/cmake/tf_tests.cmake +++ b/tensorflow/contrib/cmake/tf_tests.cmake @@ -282,6 +282,7 @@ if (tensorflow_BUILD_CC_TESTS) "${tensorflow_source_dir}/tensorflow/core/distributed_runtime/call_options_test.cc" "${tensorflow_source_dir}/tensorflow/core/distributed_runtime/tensor_coding_test.cc" "${tensorflow_source_dir}/tensorflow/core/kernels/remote_fused_graph_execute_utils_test.cc" + "${tensorflow_source_dir}/tensorflow/core/kernels/remote_fused_graph_rewriter_transform_test.cc" "${tensorflow_source_dir}/tensorflow/core/kernels/hexagon/graph_transferer_test.cc" "${tensorflow_source_dir}/tensorflow/core/kernels/hexagon/quantized_matmul_op_for_hexagon_test.cc" ) diff --git a/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD b/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD index 922996a686..fa75943d78 100644 --- a/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD +++ b/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD @@ -29,6 +29,7 @@ cc_binary( "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:tensorflow", + "//tensorflow/core/kernels:remote_fused_graph_execute_utils", "//tensorflow/core/kernels/hexagon:graph_transferer", "//tensorflow/tools/graph_transforms:transform_utils", ], diff --git a/tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc b/tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc index 3a219bb3e6..6ae7c4a742 100644 --- a/tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc +++ b/tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc @@ -22,7 +22,9 @@ limitations under the License. #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h" #include "tensorflow/core/kernels/hexagon/hexagon_ops_definitions.h" +#include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/init_main.h" @@ -46,6 +48,47 @@ static int ParseFlags(int argc, char* argv[], string* in_graph) { return 0; } +static void SummarizeNode(const NodeDef& node_def) { + LOG(INFO) << "Node(" << node_def.name() << ")"; + LOG(INFO) << " op: " << node_def.op(); + for (const string& input : node_def.input()) { + LOG(INFO) << " Input: " << input; + } +} + +static void DumpRemoteFusedGraph(const NodeDef& node_def) { + LOG(INFO) << "Remote fused graph found."; + RemoteFusedGraphExecuteInfo info; + string serialized_proto; + GetNodeAttr(node_def, + RemoteFusedGraphExecuteUtils:: + ATTR_SERIALIZED_REMOTE_FUSED_GRAPH_EXECUTE_INFO, + &serialized_proto) + .IgnoreError(); + info.ParseFromString(serialized_proto); + LOG(INFO) << "Node name: " << node_def.name(); + LOG(INFO) << "Executor name: " << info.executor_name(); + for (const string& input : info.graph_input_node_name()) { + LOG(INFO) << "Input: " << input; + } + for (const RemoteFusedGraphExecuteInfo::TensorShapeTypeProto& shape_type : + info.default_graph_input_tensor_shape()) { + LOG(INFO) << "Input shape type: " << shape_type.DebugString(); + } + for (const string& output : info.graph_output_node_name()) { + LOG(INFO) << "Output: " << output; + } + for (const RemoteFusedGraphExecuteInfo::TensorShapeTypeProto& shape_type : + info.default_graph_output_tensor_shape()) { + LOG(INFO) << "Output shape type: " << shape_type.DebugString(); + } + const int subgraph_node_size = info.remote_graph().node_size(); + LOG(INFO) << "Nodes in the graph: " << subgraph_node_size; + for (int i = 0; i < subgraph_node_size; ++i) { + LOG(INFO) << "node(" << i << "): " << info.remote_graph().node(i).name(); + } +} + static void CheckOpsSupport(const GraphDef& graph_def) { const IGraphTransferOpsDefinitions& ops_definition = HexagonOpsDefinitions::getInstance(); @@ -53,7 +96,13 @@ static void CheckOpsSupport(const GraphDef& graph_def) { std::unordered_set<string> unsupported_ops; bool all_supported = true; + bool contains_remote_graph = false; for (const NodeDef& node : graph_def.node()) { + if (node.op() == "RemoteFusedGraphExecute") { + contains_remote_graph = true; + DumpRemoteFusedGraph(node); + continue; + } // TODO(satok): Set correct data type if it's given. const int op_id = ops_definition.GetOpIdFor(node.op(), {}); if (op_id == IGraphTransferOpsDefinitions::INVALID_OP_ID) { @@ -75,6 +124,12 @@ static void CheckOpsSupport(const GraphDef& graph_def) { } else { LOG(INFO) << count << " ops are not supported."; } + + if (contains_remote_graph) { + for (const NodeDef& node : graph_def.node()) { + SummarizeNode(node); + } + } } } // namespace diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index b2eaaa3492..590d334edc 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -4297,6 +4297,7 @@ filegroup( "encode_jpeg_op.*", "identity_reader_op.*", "remote_fused_graph_execute_op.*", + "remote_fused_graph_rewriter_transform.*", "fixed_length_record_reader_op.*", "whole_file_read_ops.*", "sample_distorted_bounding_box_op.*", @@ -4777,6 +4778,7 @@ cc_library( cc_library( name = "remote_fused_graph_execute_op_test_utils", + testonly = 1, srcs = ["remote_fused_graph_execute_op_test_utils.cc"], hdrs = ["remote_fused_graph_execute_op_test_utils.h"], deps = [ @@ -4785,6 +4787,7 @@ cc_library( "//tensorflow/cc:scope", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:testlib", "//tensorflow/core/kernels:cwise_op", ], ) @@ -4842,6 +4845,40 @@ tf_cc_test( ], ) +cc_library( + name = "remote_fused_graph_rewriter_transform", + srcs = [ + "remote_fused_graph_rewriter_transform.cc", + ], + deps = [ + ":remote_fused_graph_execute_utils", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:remote_fused_graph_ops", + "//tensorflow/core", + "//tensorflow/tools/graph_transforms:transform_utils", + ], + alwayslink = 1, +) + +tf_cc_test( + name = "remote_fused_graph_rewriter_transform_test", + size = "small", + srcs = ["remote_fused_graph_rewriter_transform_test.cc"], + deps = [ + ":remote_fused_graph_execute_op_test_utils", + ":remote_fused_graph_execute_utils", + ":remote_fused_graph_rewriter_transform", + "//tensorflow/cc:cc_ops", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/tools/graph_transforms:transform_utils", + ], +) + tf_mkl_kernel_library( name = "mkl_conv_op", prefix = "mkl_conv", diff --git a/tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc b/tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc index a383cc8199..54ba101501 100644 --- a/tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc +++ b/tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc @@ -35,6 +35,7 @@ adb push /tmp/imagenet_comp_graph_label_strings.txt /data/local/tmp #include "tensorflow/core/kernels/i_remote_fused_graph_executor.h" #include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/env.h" @@ -268,8 +269,7 @@ static void RunFusedGraph(const GraphDef& fused_graph_def) { session_options.env = Env::Default(); std::unique_ptr<Session> session = std::unique_ptr<Session>(NewSession(session_options)); - Status status = session->Create(fused_graph_def); - ASSERT_TRUE(status.ok()); + TF_ASSERT_OK(session->Create(fused_graph_def)); // Setup session arguments RunOptions run_options; @@ -283,9 +283,8 @@ static void RunFusedGraph(const GraphDef& fused_graph_def) { LOG(INFO) << "Run graph"; // Run inference with all node as output - status = session->Run(run_options, input_tensors, output_node_names, {}, - &output_tensors, &run_metadata); - ASSERT_TRUE(status.ok()); + TF_ASSERT_OK(session->Run(run_options, input_tensors, output_node_names, {}, + &output_tensors, &run_metadata)); ASSERT_EQ(1, output_tensors.size()); const Tensor& output_tensor = output_tensors.at(0); LOG(INFO) << "Output byte size = " << output_tensor.TotalBytes(); diff --git a/tensorflow/core/kernels/remote_fused_graph_execute_op.cc b/tensorflow/core/kernels/remote_fused_graph_execute_op.cc index 101d0e694b..aa3835ecc5 100644 --- a/tensorflow/core/kernels/remote_fused_graph_execute_op.cc +++ b/tensorflow/core/kernels/remote_fused_graph_execute_op.cc @@ -29,9 +29,10 @@ class RemoteFusedGraphExecuteOp : public OpKernel { explicit RemoteFusedGraphExecuteOp(OpKernelConstruction* const ctx) : OpKernel(ctx), execute_info_() { string serialized_proto; - OP_REQUIRES_OK(ctx, - ctx->GetAttr("serialized_remote_fused_graph_execute_info", - &serialized_proto)); + OP_REQUIRES_OK( + ctx, ctx->GetAttr(RemoteFusedGraphExecuteUtils:: + ATTR_SERIALIZED_REMOTE_FUSED_GRAPH_EXECUTE_INFO, + &serialized_proto)); OP_REQUIRES_OK(ctx, ctx->GetAttr("Tinputs", &input_types_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("Toutputs", &output_types_)); execute_info_.ParseFromString(serialized_proto); diff --git a/tensorflow/core/kernels/remote_fused_graph_execute_op_test.cc b/tensorflow/core/kernels/remote_fused_graph_execute_op_test.cc index 925af1f79e..112168b195 100644 --- a/tensorflow/core/kernels/remote_fused_graph_execute_op_test.cc +++ b/tensorflow/core/kernels/remote_fused_graph_execute_op_test.cc @@ -269,13 +269,13 @@ static Status RewriteGraphToFusedGraph(const GraphDef& original_graph, // 5. Fuse the original graph and run the inference the new fused graph TEST(RemoteFusedExecuteGraphOp, EndToEndTest) { // 5.1 Load original graph - const GraphDef original_graph = - RemoteFusedGraphExecuteOpTestUtils::BuildAddGraph( - NAME_A, NODE_A_VAL, NAME_B, NODE_B_VAL, NAME_A_PLUS_B); + GraphDef original_graph; + TF_ASSERT_OK(RemoteFusedGraphExecuteOpTestUtils::BuildAddGraph( + NAME_A, NODE_A_VAL, NAME_B, NODE_B_VAL, NAME_A_PLUS_B, &original_graph)); // 5.2 Fuse graph GraphDef fused_graph; - TF_CHECK_OK(RewriteGraphToFusedGraph(original_graph, &fused_graph)); + TF_ASSERT_OK(RewriteGraphToFusedGraph(original_graph, &fused_graph)); // 5.3 Setup session std::vector<Tensor> output_tensors; diff --git a/tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.cc b/tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.cc index 6e7d4b73d2..31c48082dd 100644 --- a/tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.cc +++ b/tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.cc @@ -15,7 +15,10 @@ limitations under the License. #include "tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.h" +#include "tensorflow/cc/ops/array_ops.h" #include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/cc/ops/math_ops.h" +#include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" @@ -36,17 +39,57 @@ namespace tensorflow { return Output(ret, 0); } -/* static */ GraphDef RemoteFusedGraphExecuteOpTestUtils::BuildAddGraph( +/* static */ Status RemoteFusedGraphExecuteOpTestUtils::BuildAddGraph( const string& name0, const float val0, const string& name1, - const float val1, const string& name_out) { + const float val1, const string& name_out, GraphDef* graph_def) { Scope root = Scope::NewRootScope(); Output node0 = ops::Const(root.WithOpName(name0), val0); Output node1 = ops::Const(root.WithOpName(name1), val1); RemoteFusedGraphExecuteOpTestUtils::BuildAddOp(root.WithOpName(name_out), node0, node1); - GraphDef def; - TF_CHECK_OK(root.ToGraphDef(&def)); - return def; + TF_RETURN_IF_ERROR(root.ToGraphDef(graph_def)); + return Status::OK(); +} + +/* static */ Status RemoteFusedGraphExecuteOpTestUtils::BuildMultipleAddGraph( + GraphDef* graph_def) { + Scope root = tensorflow::Scope::NewRootScope(); + + Tensor a_data(DT_FLOAT, TensorShape({1, 1, 1, 1})); + test::FillIota<float>(&a_data, 1.0f); + Output a_const = ops::Const(root.WithOpName("A"), Input::Initializer(a_data)); + + Tensor b_data(DT_FLOAT, TensorShape({1, 1, 1, 1})); + test::FillIota<float>(&b_data, 1.0f); + Output b_const = ops::Const(root.WithOpName("B"), Input::Initializer(b_data)); + + Tensor c_data(DT_FLOAT, TensorShape({1, 1, 1, 1})); + test::FillIota<float>(&c_data, 1.0f); + Output c_const = ops::Const(root.WithOpName("C"), Input::Initializer(c_data)); + + Tensor d_data(DT_FLOAT, TensorShape({1, 1, 1, 1})); + test::FillIota<float>(&d_data, 1.0f); + Output d_const = ops::Const(root.WithOpName("D"), Input::Initializer(d_data)); + + Tensor e_data(DT_FLOAT, TensorShape({1, 1, 1, 1})); + test::FillIota<float>(&e_data, 1.0f); + Output e_const = ops::Const(root.WithOpName("E"), Input::Initializer(e_data)); + + Output f_add = ops::Add(root.WithOpName("F"), a_const, b_const); + + Output g_add = ops::Add(root.WithOpName("G"), d_const, e_const); + + Output h_add = ops::Add(root.WithOpName("H"), f_add, c_const); + + Output i_add = ops::Add(root.WithOpName("I"), c_const, g_add); + + Output j_add = ops::Add(root.WithOpName("J"), h_add, i_add); + + Output k_add = ops::Add(root.WithOpName("K"), j_add, g_add); + + TF_RETURN_IF_ERROR(root.ToGraphDef(graph_def)); + + return Status::OK(); } } // namespace tensorflow diff --git a/tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.h b/tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.h index 70d758ea6a..a0df50162b 100644 --- a/tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.h +++ b/tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.h @@ -28,9 +28,31 @@ namespace tensorflow { class RemoteFusedGraphExecuteOpTestUtils { public: static Output BuildAddOp(const Scope& scope, const Input& x, const Input& y); - static GraphDef BuildAddGraph(const string& name0, const float val0, - const string& name1, const float val1, - const string& name_out); + static Status BuildAddGraph(const string& name0, const float val0, + const string& name1, const float val1, + const string& name_out, GraphDef* graph_def); + + // BuildMultipleAddGraph builds the following graph + // + // A B C D E + // | | | | | + // +----+----+ | +----+----+ + // | | | + // F / \ G + // | | | / \ + // +-----+------+ +-----+----+ + + // | | | + // H I | + // | | | + // +-------+--------+ | + // | | + // J | + // | | + // +--------+--------+ + // | + // K + // + static Status BuildMultipleAddGraph(GraphDef* graph_def); private: RemoteFusedGraphExecuteOpTestUtils() = delete; diff --git a/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc b/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc index 2174098bde..d0ffcb1064 100644 --- a/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc +++ b/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc @@ -104,6 +104,22 @@ string DumpCluster(const RemoteFusedGraphExecuteUtils::ClusterInfo& cluster) { RemoteFusedGraphExecuteUtils::ATTR_OUTPUT_DATA_TYPES; /* static */ constexpr const char* const RemoteFusedGraphExecuteUtils::ATTR_OUTPUT_SHAPES; +/* static */ constexpr const char* const RemoteFusedGraphExecuteUtils:: + ATTR_SERIALIZED_REMOTE_FUSED_GRAPH_EXECUTE_INFO; +/* static */ constexpr const char* const RemoteFusedGraphExecuteUtils:: + TRANSFORM_ARG_REMOTE_FUSED_GRAPH_EXECUTOR_NAME; +/* static */ constexpr const char* const + RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_REMOTE_FUSED_GRAPH_NODE_NAME; +/* static */ constexpr const char* const + RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSED_NODES; +/* static */ constexpr const char* const + RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_BORDER_INPUTS; +/* static */ constexpr const char* const + RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_BORDER_OUTPUTS; +/* static */ constexpr const char* const + RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_INPUT_TYPES; +/* static */ constexpr const char* const + RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_INPUT_SHAPES; RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar::ExecutorBuildRegistrar( const string& name, ExecutorBuildFunc executor_build_func) { @@ -797,6 +813,22 @@ RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode( TF_RETURN_IF_ERROR(ReplaceInputNodeByPlaceHolder(subgraph_input_name, dt, shape, subgraph_def)); } + + // sort subgraph_def to align order in graph_def + std::unordered_map<string, int> name_to_id_map; + for (int i = 0; i < graph_def.node_size(); ++i) { + name_to_id_map.emplace(graph_def.node(i).name(), i); + } + std::sort(subgraph_def->mutable_node()->begin(), + subgraph_def->mutable_node()->end(), + [&name_to_id_map](const NodeDef& node0, const NodeDef& node1) { + CHECK(name_to_id_map.count(node0.name()) > 0); + CHECK(name_to_id_map.count(node1.name()) > 0); + const int id0 = name_to_id_map.at(node0.name()); + const int id1 = name_to_id_map.at(node1.name()); + return id0 < id1; + }); + VLOG(1) << DumpGraphDef(*subgraph_def); return Status::OK(); } diff --git a/tensorflow/core/kernels/remote_fused_graph_execute_utils.h b/tensorflow/core/kernels/remote_fused_graph_execute_utils.h index 97b0c2008a..3a792824c5 100644 --- a/tensorflow/core/kernels/remote_fused_graph_execute_utils.h +++ b/tensorflow/core/kernels/remote_fused_graph_execute_utils.h @@ -39,6 +39,25 @@ class RemoteFusedGraphExecuteUtils { // TODO(satok): Use "_output_shapes" to share a spec with other ops static constexpr const char* const ATTR_OUTPUT_SHAPES = "_default_remote_output_shapes"; + static constexpr const char* const + ATTR_SERIALIZED_REMOTE_FUSED_GRAPH_EXECUTE_INFO = + "serialized_remote_fused_graph_execute_info"; + + // Argument key strings to fuse a subgraph into RemoteFusedGraphExecuteOp. + static constexpr const char* const + TRANSFORM_ARG_REMOTE_FUSED_GRAPH_EXECUTOR_NAME = + "remote_fused_graph_executor_name"; + static constexpr const char* const + TRANSFORM_ARG_REMOTE_FUSED_GRAPH_NODE_NAME = + "remote_fused_graph_node_name"; + static constexpr const char* const TRANSFORM_ARG_FUSED_NODES = "fused_nodes"; + static constexpr const char* const TRANSFORM_ARG_BORDER_INPUTS = + "border_inputs"; + static constexpr const char* const TRANSFORM_ARG_BORDER_OUTPUTS = + "border_outputs"; + static constexpr const char* const TRANSFORM_ARG_INPUT_TYPES = "input_types"; + static constexpr const char* const TRANSFORM_ARG_INPUT_SHAPES = + "input_shapes"; using ExecutorBuildFunc = std::function<Status( std::unique_ptr<IRemoteFusedGraphExecutor>* executor)>; diff --git a/tensorflow/core/kernels/remote_fused_graph_execute_utils_test.cc b/tensorflow/core/kernels/remote_fused_graph_execute_utils_test.cc index 8bd63d996a..581b61a625 100644 --- a/tensorflow/core/kernels/remote_fused_graph_execute_utils_test.cc +++ b/tensorflow/core/kernels/remote_fused_graph_execute_utils_test.cc @@ -15,11 +15,7 @@ limitations under the License. #include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h" #include "tensorflow/cc/framework/scope.h" -#include "tensorflow/cc/ops/array_ops.h" -#include "tensorflow/cc/ops/const_op.h" -#include "tensorflow/cc/ops/math_ops.h" #include "tensorflow/core/common_runtime/shape_refiner.h" -#include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -47,69 +43,12 @@ static NodeDef* GetNodeDef(const string& name, GraphDef* def) { return nullptr; } -// This function builds the following graph -// -// A B C D E -// | | | | | -// +----+----+ | +----+----+ -// | | | -// F / \ G -// | | | / \ -// +-----+------+ +-----+----+ + -// | | | -// H I | -// | | | -// +-------+--------+ | -// | | -// J | -// | | -// +--------+--------+ -// | -// K -// -Status BuildMultipleAddGraph(GraphDef* graph_def) { - Scope root = tensorflow::Scope::NewRootScope(); - - Tensor a_data(DT_FLOAT, TensorShape({1, 1, 1, 1})); - test::FillIota<float>(&a_data, 1.0f); - Output a_const = ops::Const(root.WithOpName("A"), Input::Initializer(a_data)); - - Tensor b_data(DT_FLOAT, TensorShape({1, 1, 1, 1})); - test::FillIota<float>(&b_data, 1.0f); - Output b_const = ops::Const(root.WithOpName("B"), Input::Initializer(b_data)); - - Tensor c_data(DT_FLOAT, TensorShape({1, 1, 1, 1})); - test::FillIota<float>(&c_data, 1.0f); - Output c_const = ops::Const(root.WithOpName("C"), Input::Initializer(c_data)); - - Tensor d_data(DT_FLOAT, TensorShape({1, 1, 1, 1})); - test::FillIota<float>(&d_data, 1.0f); - Output d_const = ops::Const(root.WithOpName("D"), Input::Initializer(d_data)); - - Tensor e_data(DT_FLOAT, TensorShape({1, 1, 1, 1})); - test::FillIota<float>(&e_data, 1.0f); - Output e_const = ops::Const(root.WithOpName("E"), Input::Initializer(e_data)); - - Output f_add = ops::Add(root.WithOpName("F"), a_const, b_const); - - Output g_add = ops::Add(root.WithOpName("G"), d_const, e_const); - - Output h_add = ops::Add(root.WithOpName("H"), f_add, c_const); - - Output i_add = ops::Add(root.WithOpName("I"), c_const, g_add); - - Output j_add = ops::Add(root.WithOpName("J"), h_add, i_add); - - Output k_add = ops::Add(root.WithOpName("K"), j_add, g_add); - - TF_RETURN_IF_ERROR(root.ToGraphDef(graph_def)); - - return Status::OK(); -} - class FuseRemoteGraphMultipleAddOpsTest : public ::testing::Test { protected: - void SetUp() final { TF_ASSERT_OK(BuildMultipleAddGraph(&graph_def_)); } + void SetUp() final { + TF_ASSERT_OK( + RemoteFusedGraphExecuteOpTestUtils::BuildMultipleAddGraph(&graph_def_)); + } void TearDown() final {} @@ -196,8 +135,9 @@ static void ClearCluster(ClusterInfo* cluster) { } TEST(RemoteFusedGraphExecuteUtils, DryRunAddGraphA) { - GraphDef def = RemoteFusedGraphExecuteOpTestUtils::BuildAddGraph( - NAME_A, NODE_A_VAL, NAME_B, NODE_B_VAL, NAME_A_PLUS_B); + GraphDef def; + TF_ASSERT_OK(RemoteFusedGraphExecuteOpTestUtils::BuildAddGraph( + NAME_A, NODE_A_VAL, NAME_B, NODE_B_VAL, NAME_A_PLUS_B, &def)); std::pair<string, Tensor> input_node_info; input_node_info.first = NAME_A; input_node_info.second = Tensor(DT_FLOAT, {}); @@ -216,8 +156,9 @@ TEST(RemoteFusedGraphExecuteUtils, DryRunAddGraphA) { } TEST(RemoteFusedGraphExecuteUtils, DryRunAddGraphAUninitialized) { - GraphDef def = RemoteFusedGraphExecuteOpTestUtils::BuildAddGraph( - NAME_A, NODE_A_VAL, NAME_B, NODE_B_VAL, NAME_A_PLUS_B); + GraphDef def; + TF_ASSERT_OK(RemoteFusedGraphExecuteOpTestUtils::BuildAddGraph( + NAME_A, NODE_A_VAL, NAME_B, NODE_B_VAL, NAME_A_PLUS_B, &def)); std::pair<string, Tensor> input_node_info; input_node_info.first = NAME_A; input_node_info.second = Tensor(DT_FLOAT, {}); @@ -235,8 +176,9 @@ TEST(RemoteFusedGraphExecuteUtils, DryRunAddGraphAUninitialized) { } TEST(RemoteFusedGraphExecuteUtils, DryRunAddGraphAB) { - GraphDef def = RemoteFusedGraphExecuteOpTestUtils::BuildAddGraph( - NAME_A, NODE_A_VAL, NAME_B, NODE_B_VAL, NAME_A_PLUS_B); + GraphDef def; + TF_ASSERT_OK(RemoteFusedGraphExecuteOpTestUtils::BuildAddGraph( + NAME_A, NODE_A_VAL, NAME_B, NODE_B_VAL, NAME_A_PLUS_B, &def)); std::pair<string, Tensor> input_node_info_a; input_node_info_a.first = NAME_A; input_node_info_a.second = Tensor(DT_FLOAT, {}); @@ -268,8 +210,9 @@ TEST(RemoteFusedGraphExecuteUtils, DryRunAddGraphForAllNodes) { const std::vector<std::pair<string, Tensor>> inputs{input_node_info_a}; RemoteFusedGraphExecuteUtils::TensorShapeMap tensor_shape_map; - GraphDef def = RemoteFusedGraphExecuteOpTestUtils::BuildAddGraph( - NAME_A, NODE_A_VAL, NAME_B, NODE_B_VAL, NAME_A_PLUS_B); + GraphDef def; + TF_ASSERT_OK(RemoteFusedGraphExecuteOpTestUtils::BuildAddGraph( + NAME_A, NODE_A_VAL, NAME_B, NODE_B_VAL, NAME_A_PLUS_B, &def)); // dryrun const Status status = RemoteFusedGraphExecuteUtils::DryRunInferenceForAllNode( @@ -310,8 +253,9 @@ TEST(RemoteFusedGraphExecuteUtils, PropagateAndBuildTensorShapeMap) { input_node_info_b}; RemoteFusedGraphExecuteUtils::TensorShapeMap tensor_shape_map; - GraphDef def = RemoteFusedGraphExecuteOpTestUtils::BuildAddGraph( - NAME_A, NODE_A_VAL, NAME_B, NODE_B_VAL, NAME_A_PLUS_B); + GraphDef def; + TF_ASSERT_OK(RemoteFusedGraphExecuteOpTestUtils::BuildAddGraph( + NAME_A, NODE_A_VAL, NAME_B, NODE_B_VAL, NAME_A_PLUS_B, &def)); ImportGraphDefOptions opts; Graph graph(OpRegistry::Global()); ShapeRefiner shape_refiner(graph.versions().producer(), graph.op_registry()); @@ -398,8 +342,9 @@ TEST(RemoteFusedGraphExecuteUtils, // Build outputs const std::vector<string> outputs = {NAME_A_PLUS_B}; - GraphDef def = RemoteFusedGraphExecuteOpTestUtils::BuildAddGraph( - NAME_A, NODE_A_VAL, NAME_B, NODE_B_VAL, NAME_A_PLUS_B); + GraphDef def; + TF_ASSERT_OK(RemoteFusedGraphExecuteOpTestUtils::BuildAddGraph( + NAME_A, NODE_A_VAL, NAME_B, NODE_B_VAL, NAME_A_PLUS_B, &def)); TF_ASSERT_OK(RemoteFusedGraphExecuteUtils::BuildAndAddTensorShapes( input_tensors, /*dry_run_inference*/ true, &def)); @@ -427,8 +372,9 @@ TEST(RemoteFusedGraphExecuteUtils, BuildRemoteFusedGraphExecuteOpNode) { // Build outputs const std::vector<string> outputs = {NAME_A_PLUS_B}; - const GraphDef def = RemoteFusedGraphExecuteOpTestUtils::BuildAddGraph( - NAME_A, NODE_A_VAL, NAME_B, NODE_B_VAL, NAME_A_PLUS_B); + GraphDef def; + TF_ASSERT_OK(RemoteFusedGraphExecuteOpTestUtils::BuildAddGraph( + NAME_A, NODE_A_VAL, NAME_B, NODE_B_VAL, NAME_A_PLUS_B, &def)); Graph graph(OpRegistry::Global()); ShapeRefiner shape_refiner(graph.versions().producer(), graph.op_registry()); @@ -442,7 +388,8 @@ TEST(RemoteFusedGraphExecuteUtils, BuildRemoteFusedGraphExecuteOpNode) { TEST(RemoteFusedGraphExecuteUtils, ExtractSubgraphNodes) { GraphDef graph_def; - TF_ASSERT_OK(BuildMultipleAddGraph(&graph_def)); + TF_ASSERT_OK( + RemoteFusedGraphExecuteOpTestUtils::BuildMultipleAddGraph(&graph_def)); ClusterInfo cluster; const std::unordered_set<string>& node_names = std::get<0>(cluster); TF_ASSERT_OK(RemoteFusedGraphExecuteUtils::BuildClusterByBorder( @@ -472,7 +419,8 @@ TEST(RemoteFusedGraphExecuteUtils, ExtractSubgraphNodes) { TEST(RemoteFusedGraphExecuteUtils, ClusterizeNodes) { GraphDef graph_def; - TF_ASSERT_OK(BuildMultipleAddGraph(&graph_def)); + TF_ASSERT_OK( + RemoteFusedGraphExecuteOpTestUtils::BuildMultipleAddGraph(&graph_def)); std::vector<ClusterInfo> ci_vec; TF_ASSERT_OK( @@ -508,7 +456,8 @@ TEST(RemoteFusedGraphExecuteUtils, ClusterizeNodes) { TEST(RemoteFusedGraphExecuteUtils, BuildSubgraphDefByInOut) { GraphDef graph_def; - TF_ASSERT_OK(BuildMultipleAddGraph(&graph_def)); + TF_ASSERT_OK( + RemoteFusedGraphExecuteOpTestUtils::BuildMultipleAddGraph(&graph_def)); ClusterInfo cluster; GraphDef subgraph_def; @@ -556,7 +505,7 @@ TEST(RemoteFusedGraphExecuteUtils, BuildSubgraphDefByInOut) { EXPECT_EQ(3, subgraph_def.node_size()); } -TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByInOut_hi_j) { +TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByInOut_HI_J) { SetSubgraphArguments(std::vector<string>{"H", "I"}, std::vector<string>{"J"}, this); @@ -569,7 +518,7 @@ TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByInOut_hi_j) { << SummarizeGraphDef(result_graph_def_); } -TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByInOut_fcg_j) { +TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByInOut_FCG_J) { SetSubgraphArguments(std::vector<string>{"F", "C", "G"}, std::vector<string>{"J"}, this); @@ -582,7 +531,7 @@ TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByInOut_fcg_j) { << SummarizeGraphDef(result_graph_def_); } -TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByInOut_abcde_j) { +TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByInOut_ABCDE_J) { SetSubgraphArguments(std::vector<string>{"A", "B", "C", "D", "E"}, std::vector<string>{"J"}, this); @@ -595,7 +544,7 @@ TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByInOut_abcde_j) { << SummarizeGraphDef(result_graph_def_); } -TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByInOut_abcde_k) { +TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByInOut_ABCDE_K) { SetSubgraphArguments(std::vector<string>{"A", "B", "C", "D", "E"}, std::vector<string>{"K"}, this); @@ -608,7 +557,7 @@ TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByInOut_abcde_k) { << SummarizeGraphDef(result_graph_def_); } -TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByNodes_h) { +TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByNodes_H) { subgraph_node_names_ = {"H"}; TF_ASSERT_OK(FuseByNodes()); @@ -620,7 +569,7 @@ TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByNodes_h) { << SummarizeGraphDef(result_graph_def_); } -TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByNodes_hij) { +TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByNodes_HIJ) { subgraph_node_names_ = {"H", "I", "J"}; TF_ASSERT_OK(FuseByNodes()); @@ -632,7 +581,7 @@ TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByNodes_hij) { << SummarizeGraphDef(result_graph_def_); } -TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByNodes_cfghij) { +TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByNodes_CFGHIJ) { subgraph_node_names_ = {"C", "F", "G", "H", "I", "J"}; TF_ASSERT_OK(FuseByNodes()); @@ -644,7 +593,7 @@ TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByNodes_cfghij) { << SummarizeGraphDef(result_graph_def_); } -TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByNodes_abcdefghij) { +TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByNodes_ABCDEFGHIJ) { subgraph_node_names_ = {"A", "B", "C", "D", "E", "F", "G", "H", "I", "J"}; TF_ASSERT_OK(FuseByNodes()); @@ -656,7 +605,7 @@ TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByNodes_abcdefghij) { << SummarizeGraphDef(result_graph_def_); } -TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByNodes_abcdefghijk) { +TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByNodes_ABCDEFGHIJK) { subgraph_node_names_ = {"A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K"}; diff --git a/tensorflow/core/kernels/remote_fused_graph_rewriter_transform.cc b/tensorflow/core/kernels/remote_fused_graph_rewriter_transform.cc new file mode 100644 index 0000000000..8742214e17 --- /dev/null +++ b/tensorflow/core/kernels/remote_fused_graph_rewriter_transform.cc @@ -0,0 +1,163 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Wraps the hexagon rewriter in a transform so it can be used as part of the +// graph transform tool. +// A usage example, based on inception v3 model: +/* +bazel build tensorflow/tools/graph_transforms:transform_graph + + +// Specify remote graph by node names +bazel-bin/tensorflow/tools/graph_transforms/transform_graph \ +--in_graph=/tmp/tensorflow_inception_v3_stripped_optimized_quantized.pb \ +--out_graph=\ +/tmp/tensorflow_inception_v3_stripped_optimized_quantized_fused_hexagon.pb \ +--inputs='Mul' \ +--outputs='softmax' \ +--transforms='\ +fuse_remote_graph( +input_types="float" \ +input_shapes="1,299,299,3" \ +fused_nodes="NodeA,NodeB,NodeC", +remote_fused_graph_executor_name="executor" \ +remote_fused_graph_node_name="node_name" \ +)' + +// Specify remote graph by border inputs and outputs +bazel-bin/tensorflow/tools/graph_transforms/transform_graph \ +--in_graph=/tmp/tensorflow_inception_v3_stripped_optimized_quantized.pb \ +--out_graph=\ +/tmp/tensorflow_inception_v3_stripped_optimized_quantized_fused_hexagon.pb \ +--inputs='Mul' \ +--outputs='softmax' \ +--transforms='\ +fuse_remote_graph( +input_types="float" \ +input_shapes="1,299,299,3" \ +border_inputs="NodeA:0,NodeB:0" \ +border_outputs="NodeC" \ +remote_fused_graph_executor_name="executor" \ +remote_fused_graph_node_name="node_name" \ +)' +*/ + +#include <unordered_set> + +#include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h" +#include "tensorflow/tools/graph_transforms/transform_utils.h" + +namespace tensorflow { +namespace graph_transforms { +Status FuseRemoteGraph(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def) { + GraphDef mutable_input_graph_def = input_graph_def; + + const std::vector<string>& inputs = context.input_names; + const std::vector<string>& outputs = context.output_names; + + string input_types_str; + string input_shapes_str; + TF_RETURN_IF_ERROR(context.GetOneStringParameter( + RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_INPUT_TYPES, "", + &input_types_str)); + TF_RETURN_IF_ERROR(context.GetOneStringParameter( + RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_INPUT_SHAPES, "", + &input_shapes_str)); + + if (!input_types_str.empty()) { + const std::vector<string> input_types_strs = + str_util::Split(input_types_str, ","); + const std::vector<string> input_shapes_strs = + str_util::Split(input_shapes_str, ":"); + CHECK_EQ(inputs.size(), input_types_strs.size()); + CHECK_EQ(inputs.size(), input_shapes_strs.size()); + std::vector<std::pair<string, Tensor>> input_tensors; + for (int i = 0; i < inputs.size(); ++i) { + const string& name = inputs.at(i); + std::vector<int64> dims; + CHECK(str_util::SplitAndParseAsInts(input_shapes_strs.at(i), ',', &dims)); + DataType data_type; + CHECK(DataTypeFromString(input_types_strs.at(i), &data_type)) + << "\"" << input_types_strs.at(i) << "\" was an invalid type"; + input_tensors.emplace_back( + std::make_pair(name, Tensor(data_type, TensorShape(dims)))); + } + TF_RETURN_IF_ERROR(RemoteFusedGraphExecuteUtils::BuildAndAddTensorShapes( + input_tensors, /*dry_run_inference=*/true, &mutable_input_graph_def)); + } + + string fused_nodes_str; + string border_inputs_str; + string border_outputs_str; + string remote_graph_executor_name; + string remote_fused_graph_node_name; + + TF_RETURN_IF_ERROR(context.GetOneStringParameter( + RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSED_NODES, "", + &fused_nodes_str)); + TF_RETURN_IF_ERROR(context.GetOneStringParameter( + RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_BORDER_INPUTS, "", + &border_inputs_str)); + TF_RETURN_IF_ERROR(context.GetOneStringParameter( + RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_BORDER_OUTPUTS, "", + &border_outputs_str)); + TF_RETURN_IF_ERROR(context.GetOneStringParameter( + RemoteFusedGraphExecuteUtils:: + TRANSFORM_ARG_REMOTE_FUSED_GRAPH_EXECUTOR_NAME, + "", &remote_graph_executor_name)); + TF_RETURN_IF_ERROR(context.GetOneStringParameter( + RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_REMOTE_FUSED_GRAPH_NODE_NAME, + "", &remote_fused_graph_node_name)); + + CHECK(!remote_graph_executor_name.empty()); + + const bool require_shape_type = !input_types_str.empty(); + if (!fused_nodes_str.empty()) { + const std::vector<string>& fused_node_name_vector = + str_util::Split(fused_nodes_str, ","); + const std::unordered_set<string> fused_node_names( + fused_node_name_vector.begin(), fused_node_name_vector.end()); + TF_RETURN_IF_ERROR(RemoteFusedGraphExecuteUtils::FuseRemoteGraphByNodeNames( + mutable_input_graph_def, inputs, outputs, remote_fused_graph_node_name, + fused_node_names, remote_graph_executor_name, require_shape_type, + output_graph_def)); + } else if (!border_inputs_str.empty() && !border_outputs_str.empty()) { + const std::vector<string> border_inputs = + str_util::Split(border_inputs_str, ","); + const std::vector<string> border_outputs = + str_util::Split(border_outputs_str, ","); + for (int i = 0; i < border_inputs.size(); ++i) { + VLOG(2) << "Border Input(" << i << "): " << border_inputs.at(i); + } + for (int i = 0; i < border_outputs.size(); ++i) { + VLOG(2) << "Border Output(" << i << "): " << border_outputs.at(i); + } + TF_RETURN_IF_ERROR(RemoteFusedGraphExecuteUtils::FuseRemoteGraphByBorder( + mutable_input_graph_def, inputs, outputs, remote_fused_graph_node_name, + border_inputs, border_outputs, remote_graph_executor_name, + require_shape_type, output_graph_def)); + } else { + CHECK(false) << "Fuse targets are not specified."; + } + + return Status::OK(); +} + +REGISTER_GRAPH_TRANSFORM("fuse_remote_graph", FuseRemoteGraph); + +} // namespace graph_transforms +} // namespace tensorflow diff --git a/tensorflow/core/kernels/remote_fused_graph_rewriter_transform_test.cc b/tensorflow/core/kernels/remote_fused_graph_rewriter_transform_test.cc new file mode 100644 index 0000000000..9e061437a9 --- /dev/null +++ b/tensorflow/core/kernels/remote_fused_graph_rewriter_transform_test.cc @@ -0,0 +1,206 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/cc/ops/image_ops.h" +#include "tensorflow/cc/ops/nn_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/graph/default_device.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/graph/testlib.h" +#include "tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.h" +#include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/tools/graph_transforms/transform_utils.h" + +namespace tensorflow { +namespace graph_transforms { + +// Declared here so we don't have to put it in a public header. +Status FuseRemoteGraph(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def); + +namespace { + +constexpr const char* const REMOTE_FUSED_GRAPH_EXECUTOR_NAME = + "remote_fused_graph_executor_name"; +constexpr const char* const REMOTE_FUSED_GRAPH_NODE_NAME = + "remote_fused_graph_node_name"; + +class FuseRemoteGraphMultipleAddOpsRewriterTest : public ::testing::Test { + protected: + void SetUp() final { + TF_ASSERT_OK(RemoteFusedGraphExecuteOpTestUtils::BuildMultipleAddGraph( + &input_graph_def_)); + } + + void TearDown() final {} + + Status Fuse() { + TransformFuncContext context; + context.input_names = inputs_; + context.output_names = outputs_; + + if (!input_types_.empty()) { + context.params.insert(std::pair<string, std::vector<string>>( + {RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_INPUT_TYPES, + {input_types_}})); + } + if (!input_shapes_.empty()) { + context.params.insert(std::pair<string, std::vector<string>>( + {RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_INPUT_SHAPES, + {input_shapes_}})); + } + if (!fused_node_names_str_.empty()) { + context.params.insert(std::pair<string, std::vector<string>>( + {RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSED_NODES, + {fused_node_names_str_}})); + } + + if (!border_inputs_str_.empty()) { + context.params.insert(std::pair<string, std::vector<string>>( + {RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_BORDER_INPUTS, + {border_inputs_str_}})); + } + if (!border_outputs_str_.empty()) { + context.params.insert(std::pair<string, std::vector<string>>( + {RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_BORDER_OUTPUTS, + {border_outputs_str_}})); + } + + context.params.insert(std::pair<string, std::vector<string>>( + {RemoteFusedGraphExecuteUtils:: + TRANSFORM_ARG_REMOTE_FUSED_GRAPH_EXECUTOR_NAME, + {REMOTE_FUSED_GRAPH_EXECUTOR_NAME}})); + context.params.insert(std::pair<string, std::vector<string>>( + {RemoteFusedGraphExecuteUtils:: + TRANSFORM_ARG_REMOTE_FUSED_GRAPH_NODE_NAME, + {REMOTE_FUSED_GRAPH_NODE_NAME}})); + + return FuseRemoteGraph(input_graph_def_, context, &output_graph_def_); + } + + void SetInputShapeType() { + input_types_ = "float"; + input_shapes_ = "1,1,1,1"; + } + + void CheckGraph(int expected_node_count, int expected_cluster_count) { + EXPECT_EQ(expected_node_count, output_graph_def_.node_size()); + + int cluster_count = 0; + for (const NodeDef& node_def : output_graph_def_.node()) { + const string& name = node_def.name(); + if (StringPiece(name).starts_with(REMOTE_FUSED_GRAPH_NODE_NAME)) { + ++cluster_count; + RemoteFusedGraphExecuteInfo info; + string serialized_proto; + TF_ASSERT_OK( + GetNodeAttr(node_def, + RemoteFusedGraphExecuteUtils:: + ATTR_SERIALIZED_REMOTE_FUSED_GRAPH_EXECUTE_INFO, + &serialized_proto)); + info.ParseFromString(serialized_proto); + CHECK_EQ(REMOTE_FUSED_GRAPH_EXECUTOR_NAME, info.executor_name()); + } + } + EXPECT_EQ(expected_cluster_count, cluster_count); + } + + public: + const std::vector<string> inputs_{"A"}; + const std::vector<string> outputs_{"K"}; + GraphDef input_graph_def_; + string input_types_; + string input_shapes_; + GraphDef output_graph_def_; + string fused_node_names_str_; + string border_inputs_str_; + string border_outputs_str_; +}; + +TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest, + FuseRemoteGraphByNodesWithShapeType_HIJ) { + SetInputShapeType(); + fused_node_names_str_ = "H,I,J"; + TF_ASSERT_OK(Fuse()); + CheckGraph(9, 1); +} + +TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest, + FuseRemoteGraphByNodesWithoutShapeType_HIJ) { + fused_node_names_str_ = "H,I,J"; + TF_ASSERT_OK(Fuse()); + CheckGraph(9, 1); +} + +TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest, + FuseRemoteGraphByNodesWithShapeType_ABCDEFGHIJK) { + SetInputShapeType(); + fused_node_names_str_ = "A,B,C,D,E,F,G,H,I,J,K"; + TF_ASSERT_OK(Fuse()); + CheckGraph(3, 1); +} + +TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest, + FuseRemoteGraphByNodesWithoutShapeType_ABCDEFGHIJK) { + fused_node_names_str_ = "A,B,C,D,E,F,G,H,I,J,K"; + TF_ASSERT_OK(Fuse()); + CheckGraph(3, 1); +} + +TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest, + FuseRemoteGraphByBorderWithShapeType_FCG_J) { + SetInputShapeType(); + border_inputs_str_ = "F:0,C:0,G"; + border_outputs_str_ = "J:0"; + TF_ASSERT_OK(Fuse()); + CheckGraph(9, 1); +} + +TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest, + FuseRemoteGraphByBorderWithoutShapeType_FCG_J) { + border_inputs_str_ = "F:0,C:0,G"; + border_outputs_str_ = "J:0"; + TF_ASSERT_OK(Fuse()); + CheckGraph(9, 1); +} + +TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest, + FuseRemoteGraphByBorderWithShapeType_ABCDE_K) { + SetInputShapeType(); + border_inputs_str_ = "A,B,C,D,E"; + border_outputs_str_ = "K"; + TF_ASSERT_OK(Fuse()); + CheckGraph(7, 1); +} + +TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest, + FuseRemoteGraphByBorderWithoutShapeType_ABCDE_K) { + border_inputs_str_ = "A,B,C,D,E"; + border_outputs_str_ = "K"; + TF_ASSERT_OK(Fuse()); + CheckGraph(7, 1); +} + +} // namespace +} // namespace graph_transforms +} // namespace tensorflow diff --git a/tensorflow/tools/graph_transforms/BUILD b/tensorflow/tools/graph_transforms/BUILD index 89068c6d01..20f958f640 100644 --- a/tensorflow/tools/graph_transforms/BUILD +++ b/tensorflow/tools/graph_transforms/BUILD @@ -97,6 +97,7 @@ cc_library( "//tensorflow/core:tensorflow", ] + if_not_windows([ "//tensorflow/core/kernels:quantized_ops", + "//tensorflow/core/kernels:remote_fused_graph_rewriter_transform", "//tensorflow/core/kernels/hexagon:hexagon_rewriter_transform", ]), alwayslink = 1, |