author     TensorFlower Gardener <gardener@tensorflow.org>  2018-07-23 14:57:52 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-07-23 14:58:02 -0700
commit  2a88a957ff19fdf9f20c3c5e98d9a3c7bde79a4f (patch)
tree    e8585a1e290bbb1212e3dff1b7f931e221a99f68 /tensorflow
parent  5acd07ada19a280edaaa68ae68c1feef54b9bfae (diff)
parent  482b056d3ba925f52ccad8e7166a81120f43a761 (diff)
Merge pull request #20755 from aaroey:segmentor_fix
PiperOrigin-RevId: 205728147
Diffstat (limited to 'tensorflow')
-rw-r--r--  tensorflow/contrib/tensorrt/BUILD                    |   6
-rw-r--r--  tensorflow/contrib/tensorrt/convert/convert_graph.cc |  39
-rw-r--r--  tensorflow/contrib/tensorrt/convert/convert_nodes.cc | 168
-rw-r--r--  tensorflow/contrib/tensorrt/convert/convert_nodes.h  |  26
-rw-r--r--  tensorflow/contrib/tensorrt/segment/segment.cc       | 188
-rw-r--r--  tensorflow/contrib/tensorrt/segment/segment.h        |  20
-rw-r--r--  tensorflow/contrib/tensorrt/segment/segment_test.cc  | 473
7 files changed, 504 insertions, 416 deletions
diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD
index 08b267c11a..2fe1f2c242 100644
--- a/tensorflow/contrib/tensorrt/BUILD
+++ b/tensorflow/contrib/tensorrt/BUILD
@@ -314,11 +314,15 @@ tf_cc_test(
],
deps = [
":segment",
- "//tensorflow/c:c_api",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/cc:scope",
+ "//tensorflow/core:core_cpu",
"//tensorflow/core:lib",
+ "//tensorflow/core:ops",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
],
)
diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
index 68c78e8301..3383f6bc9b 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
@@ -301,7 +301,8 @@ tensorflow::Status GetEngineInfo(
const int node_id = node->id();
for (const auto edge : node->in_edges()) {
auto input_node = edge->src();
- if (segment_nodes.count(input_node->name()) == 0) {
+ if (segment_nodes.count(input_node->name()) == 0 &&
+ !edge->IsControlEdge() && !input_node->IsSource()) {
// Add constant input node into the segment. We don't care if it has
// other output edges going into other engines or TF nodes. Since we add
// it only to the subsegment node list, not the subsegment itself, it
@@ -312,7 +313,7 @@ tensorflow::Status GetEngineInfo(
added_const_node_ids.insert(input_node->id());
subgraph_node_ids.push_back(input_node->id());
}
- } else if (!edge->IsControlEdge() && !input_node->IsSource()) {
+ } else {
string s(input_node->name());
StrAppend(&s, ":", edge->src_output());
VLOG(1) << "Input edge = " << s;
@@ -378,9 +379,9 @@ tensorflow::Status CreateTRTNode(tensorflow::Graph* graph,
nvinfer1::IGpuAllocator* alloc,
int max_batch_size) {
const auto& info = infos.at(pos);
- std::vector<tensorflow::TensorShapeProto> out_shapes;
- std::vector<tensorflow::TensorShapeProto> input_shapes;
- std::vector<tensorflow::PartialTensorShape> shapes;
+ std::vector<tensorflow::TensorShapeProto> output_shape_protos;
+ std::vector<tensorflow::TensorShapeProto> input_shape_protos;
+ std::vector<tensorflow::PartialTensorShape> input_shapes;
std::vector<tensorflow::NodeDefBuilder::NodeOut> inputs;
std::vector<tensorflow::DataType> out_types;
VLOG(1) << "Processing " << info.engine_name;
@@ -393,11 +394,11 @@ tensorflow::Status CreateTRTNode(tensorflow::Graph* graph,
tensorflow::TensorShapeProto out_shape;
// shape of the output node inside segment
conn.inside_shape.AsProto(&out_shape);
- if (out_shapes.size() <= conn.port_number) {
- out_shapes.resize(conn.port_number + 1);
+ if (output_shape_protos.size() <= conn.port_number) {
+ output_shape_protos.resize(conn.port_number + 1);
out_types.resize(conn.port_number + 1);
}
- out_shapes.at(conn.port_number) = out_shape;
+ output_shape_protos.at(conn.port_number) = out_shape;
out_types.at(conn.port_number) = conn.connection_type;
continue;
}
@@ -405,12 +406,12 @@ tensorflow::Status CreateTRTNode(tensorflow::Graph* graph,
// Set the shapes and data types of input edge.
tensorflow::TensorShapeProto in_shape;
conn.outside_shape.AsProto(&in_shape);
- if (input_shapes.size() <= conn.port_number) {
+ if (input_shape_protos.size() <= conn.port_number) {
+ input_shape_protos.resize(conn.port_number + 1);
input_shapes.resize(conn.port_number + 1);
- shapes.resize(conn.port_number + 1);
}
- input_shapes.at(conn.port_number) = in_shape;
- shapes.at(conn.port_number) = conn.outside_shape;
+ input_shape_protos.at(conn.port_number) = in_shape;
+ input_shapes.at(conn.port_number) = conn.outside_shape;
string input_node = conn.outside_node_name;
int input_port = conn.outside_port;
@@ -438,6 +439,8 @@ tensorflow::Status CreateTRTNode(tensorflow::Graph* graph,
VLOG(1) << "Engine Input " << input_node << ":" << input_port << " -> "
<< info.engine_name << ":" << inputs.size();
// Skip duplicate inputs.
+ // TODO(aaroey): use std::find instead. GetEngineInfo already removes
+ // duplicate connections, so here we should never find any duplicates?
bool new_input = true;
for (const auto& inp : inputs) {
if (inp.node == input_node && inp.index == input_port) {
@@ -465,8 +468,8 @@ tensorflow::Status CreateTRTNode(tensorflow::Graph* graph,
TF_RETURN_IF_ERROR(ConvertGraphDefToEngine(
info.segment_graph_def,
info.precision_mode == INT8MODE ? FP32MODE : info.precision_mode,
- max_batch_size, info.max_workspace_size_bytes, shapes, &trt_logger,
- alloc, /*calibrator=*/nullptr, &engine,
+ max_batch_size, info.max_workspace_size_bytes, input_shapes,
+ &trt_logger, alloc, /*calibrator=*/nullptr, &engine,
/*convert_successfully=*/nullptr));
TrtUniquePtrType<nvinfer1::IHostMemory> engine_data(engine->serialize());
segment_string =
@@ -514,8 +517,8 @@ tensorflow::Status CreateTRTNode(tensorflow::Graph* graph,
}
tensorflow::NodeDef trt_node;
tensorflow::Status status =
- node_builder.Attr("input_shapes", input_shapes)
- .Attr("output_shapes", out_shapes)
+ node_builder.Attr("input_shapes", input_shape_protos)
+ .Attr("output_shapes", output_shape_protos)
.Attr("static_engine",
info.engine_type == EngineInfo::EngineType::TRTStatic)
.Attr("segment_funcdef_name",
@@ -734,6 +737,7 @@ std::pair<int, tensorflow::Allocator*> GetDeviceAndAllocator(
}
// Entry function from optimization pass.
+// TODO(aaroey): the parameter should use a pointer type.
tensorflow::Status ConvertAfterShapes(ConversionParams& params) {
// Convert graphdef to graph.
tensorflow::FunctionLibraryDefinition flib(tensorflow::OpRegistry::Global(),
@@ -751,7 +755,8 @@ tensorflow::Status ConvertAfterShapes(ConversionParams& params) {
segment_options.minimum_segment_size = params.minimum_segment_size;
tensorflow::tensorrt::segment::SegmentNodesVector initial_segments;
TF_RETURN_IF_ERROR(tensorrt::segment::SegmentGraph(
- &graph, IsTensorRTCandidate, segment_options, &initial_segments));
+ &graph, IsTensorRTCandidate, InputEdgeValidator(*params.graph_properties),
+ OutputEdgeValidator(), segment_options, &initial_segments));
if (initial_segments.size() > 1) {
VLOG(0) << "MULTIPLE tensorrt candidate conversion: "
<< initial_segments.size();
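The TODO in CreateTRTNode above suggests replacing the hand-rolled duplicate-input loop with std::find. A minimal sketch of what that could look like, assuming `inputs` is the std::vector<tensorflow::NodeDefBuilder::NodeOut> built in CreateTRTNode; the helper name is hypothetical and not part of this change:

#include <algorithm>
#include <string>
#include <vector>

#include "tensorflow/core/framework/node_def_builder.h"

// Hypothetical helper: returns true if (input_node, input_port) already
// appears in `inputs`, expressed with std::find_if instead of a manual loop.
bool IsDuplicateInput(
    const std::vector<tensorflow::NodeDefBuilder::NodeOut>& inputs,
    const std::string& input_node, int input_port) {
  return std::find_if(inputs.begin(), inputs.end(),
                      [&](const tensorflow::NodeDefBuilder::NodeOut& out) {
                        return out.node == input_node &&
                               out.index == input_port;
                      }) != inputs.end();
}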
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
index 49e825151a..451d6fe698 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h"
#include <algorithm>
+#include <cstring>
#include <list>
#include <map>
#include <memory>
@@ -77,7 +78,6 @@ namespace tensorflow {
namespace tensorrt {
namespace convert {
using ::tensorflow::str_util::Split;
-
using ::tensorflow::strings::StrAppend;
using ::tensorflow::strings::StrCat;
@@ -107,6 +107,59 @@ inline tensorflow::Status ConvertDType(tensorflow::DataType tf_dtype,
return tensorflow::Status::OK();
}
+void GetInputProperties(const grappler::GraphProperties& graph_properties,
+ const Node* outside_node, const int out_port,
+ PartialTensorShape* shape,
+ tensorflow::DataType* dtype) {
+ if (graph_properties.HasOutputProperties(outside_node->name())) {
+ auto output_params =
+ graph_properties.GetOutputProperties(outside_node->name());
+ auto out_shape = output_params.at(out_port);
+ *dtype = out_shape.dtype();
+ *shape = out_shape.shape();
+ } else {
+ VLOG(0) << "Unknown output shape" << outside_node->name();
+ *dtype = outside_node->output_type(out_port);
+ }
+}
+
+void GetOutputProperties(const grappler::GraphProperties& graph_properties,
+ const Node* outside_node, const int in_port,
+ PartialTensorShape* shape,
+ tensorflow::DataType* dtype) {
+ if (graph_properties.HasInputProperties(outside_node->name())) {
+ auto input_params =
+ graph_properties.GetInputProperties(outside_node->name());
+ auto in_shape = input_params.at(in_port);
+ *dtype = in_shape.dtype();
+ *shape = in_shape.shape();
+ } else {
+ *dtype = outside_node->input_type(in_port);
+ }
+}
+
+tensorflow::Status ValidateInputProperties(const PartialTensorShape& shape,
+ const tensorflow::DataType dtype,
+ nvinfer1::DataType* trt_dtype) {
+ // TODO(aaroey): some of these checks also apply to IsTensorRTCandidate(), so
+ // put them there instead.
+ TF_RETURN_IF_ERROR(ConvertDType(dtype, trt_dtype));
+ if (shape.dims() < 0) {
+ return tensorflow::errors::InvalidArgument("Input tensor rank is unknown.");
+ }
+ if (shape.dims() > 9) {
+ return tensorflow::errors::OutOfRange(
+ "Input tensor rank is greater than 8.");
+ }
+ for (int d = 1; d < shape.dims(); ++d) {
+ if (shape.dim_size(d) < 0) {
+ return tensorflow::errors::InvalidArgument(
+ "Input tensor has a unknown non-batch dimemension at dim ", d);
+ }
+ }
+ return Status::OK();
+}
+
// Return whether or not the broadcast is feasible;
bool TensorRTGetBroadcastShape(const nvinfer1::Dims& operand_l,
const bool operand_l_is_tensor,
@@ -2642,25 +2695,22 @@ tensorflow::Status ConvertGraphDefToEngine(
(node_def.op() == "Placeholder")) {
nvinfer1::DimsCHW input_dim_pseudo_chw;
for (int i = 0; i < 8; i++) input_dim_pseudo_chw.d[i] = 0;
- nvinfer1::DataType dtype(nvinfer1::DataType::kFLOAT);
- auto type_status =
- ConvertDType(node_def.attr().at("dtype").type(), &dtype);
- if (type_status != tensorflow::Status::OK()) {
- LOG(WARNING) << "Type conversion failed for " << node_name;
- return type_status;
- }
int32 slot_number = -1;
- if (!tensorflow::strings::safe_strto32(node_name.c_str() + 8,
- &slot_number)) {
- LOG(ERROR) << "Failed to parse slot number from " << node_name
- << " +8= " << node_name.c_str() + 8;
+ if (!tensorflow::strings::safe_strto32(
+ node_name.c_str() + strlen(kInputPHName), &slot_number)) {
+ return tensorflow::errors::InvalidArgument(
+ "Failed to parse slot number from ", node_name);
}
+ nvinfer1::DataType dtype;
auto shape = input_shapes.at(slot_number);
- if (shape.dims() > 8) {
- LOG(ERROR) << "Tensor rank is greater than 8 for " << node_name
- << " at input slot " << slot_number;
- return tensorflow::errors::OutOfRange(
- "Input tensor rank is greater than 8");
+ auto status = ValidateInputProperties(
+ shape, node_def.attr().at("dtype").type(), &dtype);
+ if (!status.ok()) {
+ const string error_message =
+ StrCat("Validation failed for ", node_name, " and input slot ",
+ slot_number, ": ", status.error_message());
+ LOG(WARNING) << error_message;
+ return Status(status.code(), error_message);
}
if (VLOG_IS_ON(1)) {
string dim_str("dims=");
@@ -2691,10 +2741,10 @@ tensorflow::Status ConvertGraphDefToEngine(
} else if (tensorflow::str_util::StartsWith(node_name, kOutputPHName) &&
(node_def.op() == "Identity")) {
int32 slot_number = -1;
- if (!tensorflow::strings::safe_strto32(node_name.c_str() + 9,
- &slot_number)) {
- LOG(ERROR) << "Failed to parse slot number from " << node_name
- << " +9=" << node_name.c_str() + 9;
+ if (!tensorflow::strings::safe_strto32(
+ node_name.c_str() + strlen(kOutputPHName), &slot_number)) {
+ return tensorflow::errors::InvalidArgument(
+ "Failed to parse slot number from ", node_name);
}
if (output_tensors.size() <= slot_number) {
output_tensors.resize(slot_number + 1);
@@ -2753,38 +2803,20 @@ tensorflow::Status ConvertSegmentToGraphDef(
"Cannot find node with id ", connection.outside_id, " in the graph.");
}
// Updates the shape and data types of input/output connections.
- tensorflow::DataType input_type = tensorflow::DT_FLOAT;
+ tensorflow::DataType dtype;
tensorflow::PartialTensorShape partial_shape;
if (connection.is_input_edge) {
- if (graph_properties.HasOutputProperties(connection.outside_node_name)) {
- auto output_params =
- graph_properties.GetOutputProperties(connection.outside_node_name);
- auto out_shape = output_params.at(connection.outside_port);
- input_type = out_shape.dtype();
- std::vector<tensorflow::int64> dims;
- partial_shape = out_shape.shape();
- connection.outside_shape = partial_shape;
- } else {
- VLOG(0) << "Unknown output shape" << outside_node->name();
- input_type = graph->FindNodeId(connection.outside_id)
- ->output_type(connection.outside_port);
- }
- connection.connection_type = input_type;
-
- } else { // output edge
- if (graph_properties.HasInputProperties(connection.outside_node_name)) {
- auto input_params =
- graph_properties.GetInputProperties(connection.outside_node_name);
- auto in_shape = input_params.at(connection.outside_port);
- input_type = in_shape.dtype();
- partial_shape = in_shape.shape();
- connection.inside_shape = partial_shape;
- } else {
- input_type = graph->FindNodeId(connection.inside_id)
- ->output_type(connection.outside_port);
- }
- connection.connection_type = input_type;
+ GetInputProperties(graph_properties,
+ graph->FindNodeId(connection.outside_id),
+ connection.outside_port, &partial_shape, &dtype);
+
+ } else {
+ GetOutputProperties(graph_properties,
+ graph->FindNodeId(connection.outside_id),
+ connection.outside_port, &partial_shape, &dtype);
}
+ connection.outside_shape = partial_shape;
+ connection.connection_type = dtype;
// Add dummy input/output nodes to the segment graphdef.
if (connection.is_input_edge) {
@@ -2800,7 +2832,7 @@ tensorflow::Status ConvertSegmentToGraphDef(
auto seg_node = segment_def->add_node();
tensorflow::NodeDefBuilder builder(node_name, "Placeholder");
auto status = builder.Attr("shape", partial_shape)
- .Attr("dtype", input_type)
+ .Attr("dtype", dtype)
.Finalize(seg_node);
VLOG(1) << "Constructing input " << node_name << " for the edge "
<< connection.outside_node_name << ":" << connection.outside_port
@@ -2818,7 +2850,7 @@ tensorflow::Status ConvertSegmentToGraphDef(
marker_nodes.insert(node_name);
auto seg_node = segment_def->add_node();
tensorflow::NodeDefBuilder builder(node_name, "Identity");
- auto status = builder.Input(connection.inside_node_name, 0, input_type)
+ auto status = builder.Input(connection.inside_node_name, 0, dtype)
.Finalize(seg_node);
VLOG(1) << "Constructing output " << node_name << " for the edge "
<< connection.inside_node_name << ":" << connection.inside_port
@@ -2856,6 +2888,38 @@ tensorflow::Status ConvertSegmentToGraphDef(
return tensorflow::Status::OK();
}
+bool InputEdgeValidator::operator()(const tensorflow::Edge* in_edge) const {
+ if (in_edge->IsControlEdge()) return true;
+ PartialTensorShape shape;
+ tensorflow::DataType dtype;
+ GetInputProperties(graph_properties_, in_edge->src(), in_edge->src_output(),
+ &shape, &dtype);
+ nvinfer1::DataType trt_dtype;
+ Status status = ValidateInputProperties(shape, dtype, &trt_dtype);
+ if (!status.ok()) {
+ VLOG(2) << "--> Need to remove input node " << in_edge->dst()->name()
+ << ": " << status;
+ return false;
+ }
+ if (shape.dims() < 3 && in_edge->src()->type_string() != "Const") {
+ VLOG(2) << "--> Need to remove input node " << in_edge->dst()->name()
+ << " which has an input at port " << in_edge->dst_input()
+ << " with #dim<3 and is not a const: " << shape;
+ return false;
+ }
+ return true;
+}
+
+bool OutputEdgeValidator::operator()(const tensorflow::Edge* out_edge) const {
+ if (out_edge->IsControlEdge()) return true;
+ if (out_edge->src()->type_string() == "Const") {
+ VLOG(2) << "--> Need to remove output node " << out_edge->src()->name()
+ << " which is a Const.";
+ return false;
+ }
+ return true;
+}
+
} // namespace convert
} // namespace tensorrt
} // namespace tensorflow
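A brief usage sketch of the validators added above: count how many input edges of a node InputEdgeValidator would accept. The helper is hypothetical; it assumes `props` is a grappler::GraphProperties on which shape inference has already been run:

#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"

// Hypothetical helper: count the input edges of `node` that would be kept by
// InputEdgeValidator (control edges are accepted by the validator as-is).
int CountValidInputEdges(const tensorflow::Node* node,
                         const tensorflow::grappler::GraphProperties& props) {
  tensorflow::tensorrt::convert::InputEdgeValidator validator(props);
  int num_valid = 0;
  for (const tensorflow::Edge* edge : node->in_edges()) {
    if (validator(edge)) ++num_valid;
  }
  return num_valid;
}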
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.h b/tensorflow/contrib/tensorrt/convert/convert_nodes.h
index 81baf8e7c1..6ae60ec352 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_nodes.h
+++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.h
@@ -105,6 +105,8 @@ struct EngineInfo {
// topological order.
// - segment_def: the output GraphDef, whose non-input/output nodedefs will be
// sorted in topological order.
+//
+// TODO(aaroey): add tests to validate these properties.
tensorflow::Status ConvertSegmentToGraphDef(
const tensorflow::Graph* graph,
const tensorflow::grappler::GraphProperties& graph_properties,
@@ -129,6 +131,30 @@ tensorflow::Status ConvertGraphDefToEngine(
TrtUniquePtrType<nvinfer1::ICudaEngine>* engine,
bool* convert_successfully);
+// Helper class for the segmenter to determine whether an input edge to the TRT
+// segment is valid.
+class InputEdgeValidator {
+ public:
+ InputEdgeValidator(const grappler::GraphProperties& graph_properties)
+ : graph_properties_(graph_properties) {}
+
+ // Return true if the specified edge is eligible to be an input edge of the
+ // TRT segment.
+ bool operator()(const tensorflow::Edge* in_edge) const;
+
+ private:
+ const grappler::GraphProperties& graph_properties_;
+};
+
+// Helper class for the segmenter to determine whether an output edge from the
+// TRT segment is valid.
+class OutputEdgeValidator {
+ public:
+ // Return true if the specified edge is eligible to be an output edge of the
+ // TRT segment.
+ bool operator()(const tensorflow::Edge* out_edge) const;
+};
+
} // namespace convert
} // namespace tensorrt
} // namespace tensorflow
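Since both validators expose operator()(const tensorflow::Edge*) const, they can be passed directly where SegmentGraph expects std::function<bool(const tensorflow::Edge*)> predicates, mirroring the ConvertAfterShapes call site in convert_graph.cc above. A condensed, hypothetical wrapper for illustration:

#include <functional>

#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h"
#include "tensorflow/contrib/tensorrt/segment/segment.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"

// Hypothetical wrapper showing how the edge validators plug into the
// segmenter's predicate parameters.
tensorflow::Status SegmentForTRT(
    const tensorflow::Graph* graph,
    const std::function<bool(const tensorflow::Node*)>& candidate_fn,
    const tensorflow::grappler::GraphProperties& graph_properties,
    const tensorflow::tensorrt::segment::SegmentOptions& options,
    tensorflow::tensorrt::segment::SegmentNodesVector* segments) {
  using tensorflow::tensorrt::convert::InputEdgeValidator;
  using tensorflow::tensorrt::convert::OutputEdgeValidator;
  return tensorflow::tensorrt::segment::SegmentGraph(
      graph, candidate_fn, InputEdgeValidator(graph_properties),
      OutputEdgeValidator(), options, segments);
}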
diff --git a/tensorflow/contrib/tensorrt/segment/segment.cc b/tensorflow/contrib/tensorrt/segment/segment.cc
index cc42913eca..008fffc954 100644
--- a/tensorflow/contrib/tensorrt/segment/segment.cc
+++ b/tensorflow/contrib/tensorrt/segment/segment.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/contrib/tensorrt/segment/segment.h"
+#include <queue>
#include <set>
#include <unordered_map>
#include <vector>
@@ -32,6 +33,7 @@ namespace tensorflow {
namespace tensorrt {
namespace segment {
using ::tensorflow::strings::StrAppend;
+
// A simple graph representation to mirror tensorflow::Graph. This structure
// helps saving memory since segmenter modifies the graph in place, preventing
// the need to create a copy of the graph. It is composed of edges and nodes.
@@ -215,7 +217,7 @@ namespace {
bool CheckCycles(const std::unique_ptr<SimpleGraph>& g, const SimpleNode* src,
const std::vector<SimpleNode*>& start) {
- // copied from TF ReverseDFS.
+ // Copied from TF ReverseDFS, which only works for tensorflow::Graph.
struct Work {
SimpleNode* node;
bool leave; // Are we entering or leaving n?
@@ -269,6 +271,24 @@ bool CanContractEdge(const SimpleEdge* edge,
// 1. Get all nodes incoming to 'dst', excluding 'src'
// 2. Reverse DFS from those nodes
// 3. If reverse DFS reaches 'src' then we have a cycle
+ //
+ // TODO(aaroey): there are several problems with the current approach:
+ // 1. src->dst->src, this is not detected but it should be;
+ // 2. src->dst->...(any node sequence that doesn't contain src)...->dst, this
+ // is detected but it should not be.
+ //
+ // Note that it's fine that dst connects back to src indirectly (i.e. through
+ // a path with length > 1 that consists of intermediate nodes other than src).
+ // While loops are one example.
+ //
+ // The goal is to make sure that the trt subgraph:
+ // 1. has no loops (i.e. is a DAG), and
+ // 2. if there is a path in the subgraph from X to Y (X and Y are both nodes
+ // in the subgraph), then all paths from X to Y are in the subgraph.
+ //
+ // To achieve this goal, the correct way seems to be:
+ // 1. remove any direct edge from src->dst;
+ // 2. detect if src can reach dst, if so they cannot be merged.
std::vector<SimpleNode*> dfs_start_nodes;
for (SimpleNode* node : dst->in_nodes()) {
if (node != src) {
@@ -276,8 +296,8 @@ bool CanContractEdge(const SimpleEdge* edge,
}
}
- bool is_cycle = CheckCycles(graph, src, dfs_start_nodes);
- return !is_cycle;
+ const bool has_cycle = CheckCycles(graph, src, dfs_start_nodes);
+ return !has_cycle;
}
} // namespace
@@ -342,22 +362,20 @@ void ContractEdge(SimpleEdge* edge, SimpleGraph* graph,
}
tensorflow::Status SegmentGraph(
- const tensorflow::GraphDef& gdef,
- const std::function<bool(const tensorflow::Node*)>& candidate_fn,
- const SegmentOptions& options, SegmentNodesVector* segments) {
- // Create a Graph representation of the GraphDef.
- tensorflow::FunctionLibraryDefinition flib(tensorflow::OpRegistry::Global(),
- gdef.library());
- tensorflow::Graph graph(flib);
- TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToGraph(
- tensorflow::GraphConstructorOptions(), gdef, &graph));
- return SegmentGraph(&graph, candidate_fn, options, segments);
-}
-
-tensorflow::Status SegmentGraph(
- tensorflow::Graph* tf_graph,
+ const tensorflow::Graph* tf_graph,
const std::function<bool(const tensorflow::Node*)>& candidate_fn,
+ const std::function<bool(const tensorflow::Edge*)>& input_candidate_fn,
+ const std::function<bool(const tensorflow::Edge*)>& output_candidate_fn,
const SegmentOptions& options, SegmentNodesVector* segments) {
+ // Steps:
+ // 1. run the segmentation algorithm to find all the segments, which uses
+ // candidate_fn to determine the candidate segment nodes;
+ // 2. for each segment, remove the nodes that are inputs/outputs of the
+ // segment but are not eligible, using input/output_candidate_fn to
+ // determine eligibility;
+ // 3. convert the segments into the expected return format and return the
+ // result.
+
+ // --------------------------------- Step 1 ---------------------------------
auto graph = std::unique_ptr<SimpleGraph>(new SimpleGraph(tf_graph));
// Use a union-find to collect the nodes that belong to the same
// segment. A node value of nullptr indicates that the node is not a candidate
@@ -372,14 +390,19 @@ tensorflow::Status SegmentGraph(
node_segments.emplace_back(node);
}
- // The segmentation algorithm below visits nodes in reverse
- // topological order and attempts to merge nodes along output
- // edges. That means that subgraphs grow from the output-side of the
- // network towards the inputs. In general this is not guaranteed to
- // produce a globally optimal segmentation. In the future if we have
- // a measure of how beneficial it is to include a given node in a
- // TRT subgraph then we can revisit this algorithm to take advantage
- // of that information.
+ // The segmentation algorithm below visits nodes in reverse topological order
+ // and attempts to merge nodes along output edges. That means that subgraphs
+ // grow from the output-side of the network towards the inputs.
+ //
+ // In general this is not guaranteed to produce a globally optimal
+ // segmentation. For example, consider a graph with nodes {A, B, C, D} and
+ // edges {A->B, A->C, B->D, C->D}, where A, B, D are TRT compatible but C is
+ // not, so in theory we can choose to contract either A, B or B, D but not
+ // both, but here it always chooses to contract B, D.
+ //
+ // In the future if we have a measure of how beneficial it is to include a
+ // given node in a TRT subgraph then we can revisit this algorithm to take
+ // advantage of that information.
std::vector<tensorflow::Node*> tforder;
tensorflow::GetPostOrder(*tf_graph, &tforder);
// use postorder implementation from tensorflow and construct mirror in
@@ -392,13 +415,11 @@ tensorflow::Status SegmentGraph(
for (const SimpleNode* node : order) {
// All output nodes of 'node' have been visited...
VLOG(2) << "Trying node " << node->name() << " id=" << node->id();
-
// 'node' must be a TRT candidate...
if (node_segments[node->id()].Value() == nullptr) {
VLOG(2) << "... not a TRT candidate";
continue;
}
-
// Contract output edges to combine 'node' with output
// nodes. Iterate since combining two nodes may unblock other
// combining.
@@ -416,7 +437,6 @@ tensorflow::Status SegmentGraph(
VLOG(2) << "... ... not a TRT candidate";
continue;
}
-
if (CanContractEdge(out_edge, graph)) {
VLOG(2) << "... ... can contract";
contract_edges.insert(out_edge);
@@ -424,11 +444,9 @@ tensorflow::Status SegmentGraph(
VLOG(2) << "... ... cannot contract, would form cycle";
}
}
-
if (contract_edges.empty()) {
break;
}
-
// Contract edges and collect the adjacent nodes into the same
// segment/subgraph.
while (!contract_edges.empty()) {
@@ -457,11 +475,22 @@ tensorflow::Status SegmentGraph(
// Collect the segments/subgraphs. Each subgraph is represented by a
// set of the names of the nodes in that subgraph.
- std::unordered_map<string, std::set<string>> sg_map;
+
+ // A map from the segment identifier (currently the name of the root node of
+ // the segment tree) to the segment nodes set.
+ std::unordered_map<string, std::set<const tensorflow::Node*>> sg_map;
+
+ // A map from the segment identifier (currently the name of the root node of
+ // the segment tree) to the device names that the nodes in the segment are
+ // assigned to.
+ //
+ // TODO(aaroey): nodes assigned to different devices should not be merged,
+ // fix this.
std::unordered_map<string, std::set<string>> device_maps;
+
for (auto& u : node_segments) {
if ((u.Value() != nullptr) && (u.ParentValue() != nullptr)) {
- sg_map[u.ParentValue()->name()].insert(u.Value()->name());
+ sg_map[u.ParentValue()->name()].insert(u.Value()->tf_node());
auto tf_node = u.Value()->tf_node();
// has_assigned_device_name() is expected to return true
// when called from optimization pass. However, since graph
@@ -482,25 +511,104 @@ tensorflow::Status SegmentGraph(
}
}
+ // --------------------------------- Step 2 ---------------------------------
+ // Remove ineligible input/output nodes.
+ for (auto& itr : sg_map) {
+ std::set<const tensorflow::Node*>& segment_nodes = itr.second;
+ VLOG(1) << "Segment original size: " << segment_nodes.size();
+ while (true) {
+ std::deque<const tensorflow::Node*> in_nodes_que, out_nodes_que;
+ // Find an input node that is not eligible and add it to the queue.
+ // Nodes that have no incoming edges should not be treated as "input",
+ // as there are really no inputs to them. Similarly for output nodes.
+ for (auto node : segment_nodes) {
+ bool added = false;
+ for (const tensorflow::Edge* edge : node->in_edges()) {
+ if (!edge->IsControlEdge() && !edge->src()->IsSource() &&
+ !segment_nodes.count(edge->src())) { // 'node' is an input node.
+ if (!input_candidate_fn(edge)) {
+ in_nodes_que.push_back(node);
+ added = true;
+ break;
+ }
+ }
+ }
+ if (added) continue; // Only add the node once to either queue.
+ for (const tensorflow::Edge* edge : node->out_edges()) {
+ if (!edge->dst()->IsSink() && !edge->IsControlEdge() &&
+ !segment_nodes.count(edge->dst())) { // 'node' is an output node.
+ if (!output_candidate_fn(edge)) {
+ out_nodes_que.push_back(node);
+ break;
+ }
+ }
+ }
+ }
+ if (in_nodes_que.empty() && out_nodes_que.empty()) {
+ // No more ineligible input/output nodes.
+ break;
+ }
+ // Now for each ineligible node, remove all of its inputs or outputs from
+ // the subgraph.
+ //
+ // It can be proven that, if the original subgraph:
+ // 1. is a DAG, and
+ // 2. all paths between two nodes in the subgraph are inside the
+ // subgraph,
+ // then after doing this operation the resulting subgraph will keep the
+ // same properties 1 and 2.
+ //
+ // For simplicity we use a heuristic: for input nodes remove all of their
+ // inputs, and for output nodes remove all of their outputs. In this way,
+ // for common cases the number of removed nodes should be minimal.
+ auto remove_nodes = [&segment_nodes](
+ bool is_input_nodes,
+ std::deque<const tensorflow::Node*>* que) {
+ // Run a BFS on the queue to find all the input/output nodes.
+ std::set<const tensorflow::Node*> visited;
+ while (!que->empty()) {
+ auto node = que->front();
+ que->pop_front();
+ if (!visited.insert(node).second) continue;
+ segment_nodes.erase(node);
+ for (auto in :
+ is_input_nodes ? node->in_nodes() : node->out_nodes()) {
+ if (segment_nodes.count(in)) {
+ que->push_back(in);
+ VLOG(2) << "Need to remove node " << in->name()
+ << " because one of its "
+ << (is_input_nodes ? "output" : "input")
+ << " nodes in the graph was removed: " << node->name();
+ }
+ }
+ }
+ };
+ remove_nodes(true, &in_nodes_que);
+ remove_nodes(false, &out_nodes_que);
+ }
+ VLOG(1) << "Segment new size: " << segment_nodes.size();
+ }
+
+ // --------------------------------- Step 3 ---------------------------------
// Convert the segments into the expected return format
for (const auto& itr : sg_map) {
- const auto& segment_node_names = itr.second;
+ const std::set<const tensorflow::Node*>& segment_nodes = itr.second;
if (VLOG_IS_ON(1)) {
string s;
- for (const auto& name : segment_node_names) {
- s += " " + name;
- }
- VLOG(1) << "Segment " << segments->size() << ":" << s;
+ for (auto node : segment_nodes) s += " " + node->name();
+ VLOG(1) << "Segment " << segments->size() << ": " << s;
}
// Don't use small segments.
- if (static_cast<int>(segment_node_names.size()) <
- options.minimum_segment_size) {
+ if (static_cast<int>(segment_nodes.size()) < options.minimum_segment_size) {
VLOG(1) << "Segment " << segments->size() << " has only "
- << segment_node_names.size() << " nodes, dropping";
+ << segment_nodes.size() << " nodes, dropping";
continue;
}
+
// TODO(sami): Make segmenter placement aware once trtscopes are in place
+ std::set<string> segment_node_names;
+ for (auto node : itr.second) segment_node_names.insert(node->name());
const auto& dev_itr = device_maps.find(itr.first);
if (dev_itr == device_maps.end() || dev_itr->second.empty()) {
VLOG(1) << "No device assigned to segment " << segments->size();
diff --git a/tensorflow/contrib/tensorrt/segment/segment.h b/tensorflow/contrib/tensorrt/segment/segment.h
index 81b4bfe49f..8c44eb782a 100644
--- a/tensorflow/contrib/tensorrt/segment/segment.h
+++ b/tensorflow/contrib/tensorrt/segment/segment.h
@@ -42,22 +42,6 @@ struct SegmentOptions {
// Get the subgraphs of a graph that can be handled by TensorRT.
//
-// @param gdef The GraphDef describing the network
-// @param candidate_fn A function that returns true for a NodeDef if
-// that node can be handled by TensorRT.
-// @param segments Returns the TensorRT segments/subgraphs. Each entry
-// in the vector describes a subgraph by giving a set of the names of
-// all the NodeDefs in that subgraph.
-// @return the status.
-//
-// TODO(aaroey): remove this method.
-tensorflow::Status SegmentGraph(
- const tensorflow::GraphDef& gdef,
- const std::function<bool(const tensorflow::Node*)>& candidate_fn,
- const SegmentOptions& options, SegmentNodesVector* segments);
-
-// Get the subgraphs of a graph that can be handled by TensorRT.
-//
// @param graph tensorflow::Graph of the network
// @param candidate_fn A function that returns true for a Node* if
// that node can be handled by TensorRT.
@@ -66,8 +50,10 @@ tensorflow::Status SegmentGraph(
// all the NodeDefs in that subgraph.
// @return the status.
tensorflow::Status SegmentGraph(
- tensorflow::Graph* tf_graph,
+ const tensorflow::Graph* tf_graph,
const std::function<bool(const tensorflow::Node*)>& candidate_fn,
+ const std::function<bool(const tensorflow::Edge*)>& input_candidate_fn,
+ const std::function<bool(const tensorflow::Edge*)>& output_candidate_fn,
const SegmentOptions& options, SegmentNodesVector* segments);
} // namespace segment
diff --git a/tensorflow/contrib/tensorrt/segment/segment_test.cc b/tensorflow/contrib/tensorrt/segment/segment_test.cc
index f5b2d258d7..432e7b1c04 100644
--- a/tensorflow/contrib/tensorrt/segment/segment_test.cc
+++ b/tensorflow/contrib/tensorrt/segment/segment_test.cc
@@ -14,350 +14,245 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/tensorrt/segment/segment.h"
-#include "tensorflow/c/c_api.h"
-#include "tensorflow/core/framework/graph.pb.h"
+
+#include "tensorflow/cc/framework/scope.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/graph/testlib.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/public/session.h"
namespace tensorflow {
namespace tensorrt {
namespace segment {
namespace test {
+namespace ops = ::tensorflow::ops;
class SegmentTest : public ::testing::Test {
- public:
- bool GetGraphDef(TF_Graph* graph, tensorflow::GraphDef* graph_def);
-
- TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, const char* name);
- TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
- TF_Status* s, const char* name);
-
- std::function<bool(const tensorflow::Node*)> MakeCandidateFn(
- const std::set<string>& node_names);
-
protected:
- void PlaceholderHelper(TF_Graph* graph, TF_Status* s, const char* name,
- TF_Operation** op);
- void AddHelper(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
- TF_Status* s, const char* name, TF_Operation** op, bool check);
-
- SegmentOptions default_options_;
-};
-
-bool SegmentTest::GetGraphDef(TF_Graph* graph,
- tensorflow::GraphDef* graph_def) {
- TF_Status* s = TF_NewStatus();
- TF_Buffer* buffer = TF_NewBuffer();
- TF_GraphToGraphDef(graph, buffer, s);
- bool ret = TF_GetCode(s) == TF_OK;
- EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- if (ret) ret = graph_def->ParseFromArray(buffer->data, buffer->length);
- TF_DeleteBuffer(buffer);
- TF_DeleteStatus(s);
- return ret;
-}
+ std::function<bool(const tensorflow::Node*)> MakeCandidateFn(
+ const std::set<string>& node_names) {
+ return [node_names](const tensorflow::Node* node) -> bool {
+ return node_names.find(node->name()) != node_names.end();
+ };
+ }
-std::function<bool(const tensorflow::Node*)> SegmentTest::MakeCandidateFn(
- const std::set<string>& node_names) {
- return [node_names](const tensorflow::Node* node) -> bool {
- return node_names.find(node->name()) != node_names.end();
- };
-}
+ std::function<bool(const tensorflow::Edge*)> MakeInputEdgeCandidateFn(
+ const std::set<string>& node_names) {
+ return [node_names](const tensorflow::Edge* in_edge) -> bool {
+ return node_names.find(in_edge->dst()->name()) != node_names.end();
+ };
+ }
-void SegmentTest::PlaceholderHelper(TF_Graph* graph, TF_Status* s,
- const char* name, TF_Operation** op) {
- TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", name);
- TF_SetAttrType(desc, "dtype", TF_INT32);
- *op = TF_FinishOperation(desc, s);
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- ASSERT_NE(*op, nullptr);
-}
+ std::function<bool(const tensorflow::Edge*)> MakeOutputEdgeCandidateFn(
+ const std::set<string>& node_names) {
+ return [node_names](const tensorflow::Edge* out_edge) -> bool {
+ return node_names.find(out_edge->src()->name()) != node_names.end();
+ };
+ }
-TF_Operation* SegmentTest::Placeholder(TF_Graph* graph, TF_Status* s,
- const char* name) {
- TF_Operation* op;
- PlaceholderHelper(graph, s, name, &op);
- return op;
-}
+ void RunTest(const tensorflow::Graph* graph,
+ const std::set<string>& candidates,
+ const std::set<string>& input_candidates,
+ const std::set<string>& output_candidates,
+ const std::vector<std::set<string>>& expected_segments) {
+ SegmentNodesVector segments;
+ TF_EXPECT_OK(SegmentGraph(graph, MakeCandidateFn(candidates),
+ MakeInputEdgeCandidateFn(input_candidates),
+ MakeOutputEdgeCandidateFn(output_candidates),
+ default_options_, &segments));
+ ValidateSegment(segments, expected_segments);
+ }
-void SegmentTest::AddHelper(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
- TF_Status* s, const char* name, TF_Operation** op,
- bool check) {
- TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name);
- TF_Output add_inputs[2] = {{l, 0}, {r, 0}};
- TF_AddInputList(desc, add_inputs, 2);
- *op = TF_FinishOperation(desc, s);
- if (check) {
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- ASSERT_NE(*op, nullptr);
+ void ValidateSegment(const SegmentNodesVector& segments,
+ const std::vector<std::set<string>>& expected_segments) {
+ EXPECT_EQ(expected_segments.size(), segments.size());
+ for (int i = 0; i < segments.size(); ++i) {
+ const auto& segment_node_names = segments[i].first;
+ const auto& expected = expected_segments[i];
+ for (const auto& name : expected) {
+ EXPECT_TRUE(segment_node_names.count(name))
+ << "Segment " << i << " is missing expected node: " << name;
+ }
+ if (segment_node_names.size() == expected.size()) continue;
+ for (const auto& name : segment_node_names) {
+ EXPECT_TRUE(expected.count(name))
+ << "Unexpected node found in segment " << i << ": " << name;
+ }
+ }
}
-}
-TF_Operation* SegmentTest::Add(TF_Operation* l, TF_Operation* r,
- TF_Graph* graph, TF_Status* s,
- const char* name) {
- TF_Operation* op;
- AddHelper(l, r, graph, s, name, &op, true);
- return op;
+ SegmentOptions default_options_;
+};
+
+std::set<string> operator-(const std::set<string>& lhs, const string& rhs) {
+ std::set<string> result = lhs;
+ CHECK(result.erase(rhs));
+ return result;
}
TEST_F(SegmentTest, Empty) {
- TF_Graph* graph = TF_NewGraph();
-
- GraphDef graph_def;
- ASSERT_TRUE(GetGraphDef(graph, &graph_def));
-
- SegmentNodesVector segments;
- ASSERT_EQ(
- SegmentGraph(graph_def, MakeCandidateFn({}), default_options_, &segments),
- tensorflow::Status::OK());
-
+ Scope s = Scope::NewRootScope();
+ tensorflow::Graph g(OpRegistry::Global());
+ TF_EXPECT_OK(s.ToGraph(&g));
// Expect no segments/subgraphs.
- EXPECT_TRUE(segments.empty());
- TF_DeleteGraph(graph);
+ RunTest(&g, {}, {}, {}, {});
}
TEST_F(SegmentTest, Simple) {
- TF_Status* s = TF_NewStatus();
- TF_Graph* graph = TF_NewGraph();
-
// feed
- // // ||
+ // // \\
// add0 add1
- // | | /
+ // | \ /
// | add2
- // | / ||
+ // | / \\
// add3 add4
- // | /
+ // \ /
// <sink>
- //
- TF_Operation* feed = Placeholder(graph, s, "feed");
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- EXPECT_EQ(string("feed"), string(TF_OperationName(feed)));
-
- TF_Operation* add0 = Add(feed, feed, graph, s, "add0");
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- TF_Operation* add1 = Add(feed, feed, graph, s, "add1");
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- TF_Operation* add2 = Add(add0, add1, graph, s, "add2");
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- TF_Operation* add3 = Add(add0, add2, graph, s, "add3");
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- EXPECT_EQ(string("add3"), string(TF_OperationName(add3)));
- TF_Operation* add4 = Add(add2, add2, graph, s, "add4");
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- EXPECT_EQ(string("add4"), string(TF_OperationName(add4)));
-
- GraphDef graph_def;
- ASSERT_TRUE(GetGraphDef(graph, &graph_def));
-
- SegmentNodesVector segments;
- ASSERT_EQ(
- SegmentGraph(graph_def,
- MakeCandidateFn({"add0", "add1", "add2", "add3", "add4"}),
- default_options_, &segments),
- tensorflow::Status::OK());
-
- // Expect all Add operations to be collapsed into a single segment
- ASSERT_EQ(segments.size(), 1);
- std::vector<string> expected{"add0", "add1", "add2", "add3", "add4"};
- for (const auto& ex : expected) {
- EXPECT_TRUE(segments[0].first.find(ex) != segments[0].first.end())
- << "Missing expected node " << ex;
- }
- TF_DeleteGraph(graph);
- TF_DeleteStatus(s);
+ Scope s = Scope::NewRootScope();
+ auto feed = ops::Placeholder(s.WithOpName("feed"), DT_FLOAT);
+ auto add0 = ops::Add(s.WithOpName("add0"), feed, feed);
+ auto add1 = ops::Add(s.WithOpName("add1"), feed, feed);
+ auto add2 = ops::Add(s.WithOpName("add2"), add0, add1);
+ auto add3 = ops::Add(s.WithOpName("add3"), add0, add2);
+ auto add4 = ops::Add(s.WithOpName("add4"), add2, add2);
+ tensorflow::Graph g(OpRegistry::Global());
+ TF_EXPECT_OK(s.ToGraph(&g));
+
+ // All Add operations are candidates, and we expect all of them to be
+ // collapsed into a single segment
+ const std::set<string> all_adds = {"add0", "add1", "add2", "add3", "add4"};
+ RunTest(&g, all_adds, all_adds, all_adds, {all_adds});
+
+ // Make add1 not a candidate, and we expect all other Add operations to be
+ // collapsed into a single segment
+ auto without_add1 = all_adds - "add1";
+ RunTest(&g, without_add1, without_add1, without_add1, {without_add1});
+
+ // Make add1 not a candidate and add2 not an input candidate, and we expect
+ // add0 and add2 to be removed from the segment.
+ auto without_add2 = all_adds - "add2";
+ RunTest(&g, without_add1, without_add2, without_add1, {{"add3", "add4"}});
+
+ // Making add2 not an input candidate itself won't affect anything.
+ RunTest(&g, all_adds, without_add2, all_adds, {all_adds});
+
+ // Making add1 not an input candidate removes it from the segment.
+ RunTest(&g, all_adds, without_add1, all_adds, {without_add1});
+
+ // Making add3 not an output candidate doesn't affect anything, since its
+ // output goes to the sink.
+ auto without_add3 = all_adds - "add3";
+ RunTest(&g, all_adds, all_adds, without_add3, {all_adds});
}
TEST_F(SegmentTest, AvoidCycle) {
- TF_Status* s = TF_NewStatus();
- TF_Graph* graph = TF_NewGraph();
-
- // add2 is not a TRT candidate so add0/add3 cannot be formed as a
- // subgraph
- //
// feed
- // // ||
+ // // \\
// add0 add1
- // | | /
+ // | \ /
// | add2
- // | / ||
+ // | / \\
// add3 add4
- // | /
+ // \ /
// <sink>
- //
- TF_Operation* feed = Placeholder(graph, s, "feed");
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- EXPECT_EQ(string("feed"), string(TF_OperationName(feed)));
-
- TF_Operation* add0 = Add(feed, feed, graph, s, "add0");
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- TF_Operation* add1 = Add(feed, feed, graph, s, "add1");
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- TF_Operation* add2 = Add(add0, add1, graph, s, "add2");
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- TF_Operation* add3 = Add(add0, add2, graph, s, "add3");
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- EXPECT_EQ(string("add3"), string(TF_OperationName(add3)));
- TF_Operation* add4 = Add(add2, add2, graph, s, "add4");
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- EXPECT_EQ(string("add4"), string(TF_OperationName(add4)));
-
- GraphDef graph_def;
- ASSERT_TRUE(GetGraphDef(graph, &graph_def));
-
- SegmentNodesVector segments;
- ASSERT_EQ(
- SegmentGraph(graph_def, MakeCandidateFn({"add0", "add1", "add3", "add4"}),
- default_options_, &segments),
- tensorflow::Status::OK());
-
- // Expect no subgraphs
- EXPECT_EQ(segments.size(), 0);
- TF_DeleteGraph(graph);
- TF_DeleteStatus(s);
+ Scope s = Scope::NewRootScope();
+ auto feed = ops::Placeholder(s.WithOpName("feed"), DT_FLOAT);
+ auto add0 = ops::Add(s.WithOpName("add0"), feed, feed);
+ auto add1 = ops::Add(s.WithOpName("add1"), feed, feed);
+ auto add2 = ops::Add(s.WithOpName("add2"), add0, add1);
+ auto add3 = ops::Add(s.WithOpName("add3"), add0, add2);
+ auto add4 = ops::Add(s.WithOpName("add4"), add2, add2);
+ tensorflow::Graph g(OpRegistry::Global());
+ TF_EXPECT_OK(s.ToGraph(&g));
+
+ // add2 is not a TRT candidate so there should be no segments generated.
+ const std::set<string> without_add2 = {"add0", "add1", "add3", "add4"};
+ RunTest(&g, without_add2, without_add2, without_add2, {});
}
TEST_F(SegmentTest, Multiple) {
- TF_Status* s = TF_NewStatus();
- TF_Graph* graph = TF_NewGraph();
-
- // add5 is not a TRT candidate so two subgraphs should be formed
- //
- // feed
- // // || ||
- // add0 add1 add7
- // | | / / ||
- // | add2-----add5 add8
- // | / | | | |
- // add3 add4 add6
- // | | /
- // <sink>
- //
- TF_Operation* feed = Placeholder(graph, s, "feed");
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- EXPECT_EQ(string("feed"), string(TF_OperationName(feed)));
-
- TF_Operation* add0 = Add(feed, feed, graph, s, "add0");
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- TF_Operation* add1 = Add(feed, feed, graph, s, "add1");
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- TF_Operation* add7 = Add(feed, feed, graph, s, "add7");
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- TF_Operation* add2 = Add(add0, add1, graph, s, "add2");
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- TF_Operation* add5 = Add(add2, add7, graph, s, "add5");
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- TF_Operation* add8 = Add(add7, add7, graph, s, "add8");
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- TF_Operation* add3 = Add(add0, add2, graph, s, "add3");
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- EXPECT_EQ(string("add3"), string(TF_OperationName(add3)));
- TF_Operation* add4 = Add(add2, add5, graph, s, "add4");
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- EXPECT_EQ(string("add4"), string(TF_OperationName(add4)));
- TF_Operation* add6 = Add(add5, add8, graph, s, "add6");
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- EXPECT_EQ(string("add6"), string(TF_OperationName(add6)));
-
- GraphDef graph_def;
- ASSERT_TRUE(GetGraphDef(graph, &graph_def));
-
- SegmentNodesVector segments;
- ASSERT_EQ(SegmentGraph(graph_def,
- MakeCandidateFn({"add0", "add1", "add2", "add3",
- "add4", "add6", "add7", "add8"}),
- default_options_, &segments),
- tensorflow::Status::OK());
-
- // Expect two subgraphs
- EXPECT_EQ(segments.size(), 2);
-
- std::vector<string> expected0{"add6", "add8"};
- for (const auto& ex : expected0) {
- EXPECT_TRUE(segments[0].first.find(ex) != segments[0].first.end())
- << "Missing expected node " << ex;
- }
-
- std::vector<string> expected1{"add0", "add1", "add2", "add3"};
- for (const auto& ex : expected1) {
- EXPECT_TRUE(segments[1].first.find(ex) != segments[1].first.end())
- << "Missing expected node " << ex;
- }
- TF_DeleteGraph(graph);
- TF_DeleteStatus(s);
+ // feed
+ // // || \\
+ // add0 add1 add7
+ // | \ / / \\
+ // | add2 / \\
+ // | || \ | ||
+ // | || add5 add8
+ // | / \ / \ /
+ // add3 add4 add6
+ // \ | /
+ // <sink>
+ Scope s = Scope::NewRootScope();
+ auto feed = ops::Placeholder(s.WithOpName("feed"), DT_FLOAT);
+ auto add0 = ops::Add(s.WithOpName("add0"), feed, feed);
+ auto add1 = ops::Add(s.WithOpName("add1"), feed, feed);
+ auto add7 = ops::Add(s.WithOpName("add7"), feed, feed);
+ auto add2 = ops::Add(s.WithOpName("add2"), add0, add1);
+ auto add5 = ops::Add(s.WithOpName("add5"), add2, add7);
+ auto add8 = ops::Add(s.WithOpName("add8"), add7, add7);
+ auto add3 = ops::Add(s.WithOpName("add3"), add0, add2);
+ auto add4 = ops::Add(s.WithOpName("add4"), add2, add5);
+ auto add6 = ops::Add(s.WithOpName("add6"), add5, add8);
+ tensorflow::Graph g(OpRegistry::Global());
+ TF_EXPECT_OK(s.ToGraph(&g));
+
+ const std::set<string> all_adds = {"add0", "add1", "add2", "add3", "add4",
+ "add5", "add6", "add7", "add8"};
+ // Make add5 not a TRT candidate, and we expect two segments.
+ auto without_add5 = all_adds - "add5";
+ RunTest(&g, without_add5, without_add5, without_add5,
+ {{"add6", "add8"}, {"add0", "add1", "add2", "add3"}});
+
+ // Make add8 not a candidate and add6 not an input candidate, then all direct
+ // and indirect inputs of add6 will be removed from the segment.
+ auto without_add8 = all_adds - "add8";
+ auto without_add6 = all_adds - "add6";
+ RunTest(&g, without_add8, without_add6, all_adds, {{"add3", "add4"}});
+
+ // Make add3 not a candidate and add0 not an output candidate, then all
+ // direct and indirect outputs of add0 will be removed from the segment.
+ auto without_add3 = all_adds - "add3";
+ auto without_add0 = all_adds - "add0";
+ RunTest(&g, without_add3, all_adds, without_add0, {{"add1", "add7", "add8"}});
}
TEST_F(SegmentTest, BigIfElse) {
- TF_Status* s = TF_NewStatus();
- TF_Graph* graph = TF_NewGraph();
-
- // add2 is not a TRT candidate
- //
// feed
// ||
// add0
- // // ||
+ // // \\
// add1 add4
// || ||
// add2 add5
// || ||
// add3 add6
- // || //
+ // \\ //
// add7
// ||
// <sink>
- //
- TF_Operation* feed = Placeholder(graph, s, "feed");
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- EXPECT_EQ(string("feed"), string(TF_OperationName(feed)));
-
- TF_Operation* add0 = Add(feed, feed, graph, s, "add0");
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- TF_Operation* add1 = Add(add0, add0, graph, s, "add1");
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- TF_Operation* add2 = Add(add1, add1, graph, s, "add2");
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- TF_Operation* add3 = Add(add2, add2, graph, s, "add3");
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- TF_Operation* add4 = Add(add0, add0, graph, s, "add4");
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- TF_Operation* add5 = Add(add4, add4, graph, s, "add5");
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- TF_Operation* add6 = Add(add5, add5, graph, s, "add6");
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- TF_Operation* add7 = Add(add3, add6, graph, s, "add7");
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- EXPECT_EQ(string("add7"), string(TF_OperationName(add7)));
-
- GraphDef graph_def;
- ASSERT_TRUE(GetGraphDef(graph, &graph_def));
-
- SegmentNodesVector segments;
- ASSERT_EQ(SegmentGraph(graph_def,
- MakeCandidateFn({"add0", "add1", "add3", "add4",
- "add5", "add6", "add7"}),
- default_options_, &segments),
- tensorflow::Status::OK());
-
- // Expect 2 subgraphs
- EXPECT_EQ(segments.size(), 2);
-
- std::vector<string> expected0{"add3", "add4", "add5", "add6", "add7"};
- for (const auto& ex : expected0) {
- EXPECT_TRUE(segments[0].first.find(ex) != segments[0].first.end())
- << "Missing expected node " << ex;
- }
-
- std::vector<string> expected1{"add0", "add1"};
- for (const auto& ex : expected1) {
- EXPECT_TRUE(segments[1].first.find(ex) != segments[1].first.end())
- << "Missing expected node " << ex;
- }
- TF_DeleteGraph(graph);
- TF_DeleteStatus(s);
+ Scope s = Scope::NewRootScope();
+ auto feed = ops::Placeholder(s.WithOpName("feed"), DT_FLOAT);
+ auto add0 = ops::Add(s.WithOpName("add0"), feed, feed);
+ auto add1 = ops::Add(s.WithOpName("add1"), add0, add0);
+ auto add2 = ops::Add(s.WithOpName("add2"), add1, add1);
+ auto add3 = ops::Add(s.WithOpName("add3"), add2, add2);
+ auto add4 = ops::Add(s.WithOpName("add4"), add0, add0);
+ auto add5 = ops::Add(s.WithOpName("add5"), add4, add4);
+ auto add6 = ops::Add(s.WithOpName("add6"), add5, add5);
+ auto add7 = ops::Add(s.WithOpName("add7"), add3, add6);
+ tensorflow::Graph g(OpRegistry::Global());
+ TF_EXPECT_OK(s.ToGraph(&g));
+
+ // Make add2 not a TRT candidate, and we expect 2 segments.
+ const std::set<string> all_adds = {"add0", "add1", "add2", "add3",
+ "add4", "add5", "add6", "add7"};
+ RunTest(&g, all_adds - "add2", all_adds, all_adds,
+ {{"add3", "add4", "add5", "add6", "add7"}, {"add0", "add1"}});
}
} // namespace test