aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar gracehoney <31743510+aaroey@users.noreply.github.com>2018-07-12 15:22:03 -0700
committerGravatar gracehoney <31743510+aaroey@users.noreply.github.com>2018-07-12 15:22:03 -0700
commit86f632e29810fa93db559f882567b9569dabfad5 (patch)
tree542d32c630c461ebcdbc483644c8585b92743aec
parent571d3dc5747e04fe0a80be185e64532cf74e1fb0 (diff)
Implement the input/output edge validaters
-rw-r--r--tensorflow/contrib/tensorrt/convert/convert_graph.cc46
-rw-r--r--tensorflow/contrib/tensorrt/convert/convert_nodes.cc165
-rw-r--r--tensorflow/contrib/tensorrt/convert/convert_nodes.h26
-rw-r--r--tensorflow/contrib/tensorrt/segment/segment.cc8
-rw-r--r--tensorflow/contrib/tensorrt/segment/segment.h4
-rw-r--r--tensorflow/contrib/tensorrt/segment/segment_test.cc21
6 files changed, 182 insertions, 88 deletions
diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
index 359fac36f5..ba01eaabc2 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
@@ -107,10 +107,8 @@ bool IsTensorRTCandidate(const tensorflow::Node* node) {
// TODO(ben,jie): ...
};
// LINT.ThenChange(//tensorflow/contrib/tensorrt/convert/convert_nodes.h)
- if (!candidate_ops.count(node->type_string()) &&
- !PluginFactoryTensorRT::GetInstance()->IsPlugin(node->type_string())) {
- return false;
- }
+ return (candidate_ops.count(node->type_string()) ||
+ PluginFactoryTensorRT::GetInstance()->IsPlugin(node->type_string()));
}
tensorflow::Status BuildNodeMap(
@@ -280,7 +278,8 @@ tensorflow::Status GetEngineInfo(
subgraph_node_ids.push_back(node_id);
for (const auto edge : node->in_edges()) {
auto input_node = edge->src();
- if (segment_nodes.count(input_node->name()) == 0) {
+ if (segment_nodes.count(input_node->name()) == 0 &&
+ !edge->IsControlEdge() && !input_node->IsSource()) {
// Add constant input node into the segment. We don't care if it has
// other output edges going into other engines or TF nodes. Since we add
// it only to the subsegment node list, not the subsegment itself, it
@@ -288,7 +287,7 @@ tensorflow::Status GetEngineInfo(
// will prune it out.
if (input_node->type_string() == "Const") {
subgraph_node_ids.push_back(input_node->id());
- } else if (!edge->IsControlEdge() && !input_node->IsSource()) {
+ } else {
string s(input_node->name());
StrAppend(&s, ":", edge->src_output());
VLOG(1) << "Input edge = " << s;
@@ -351,9 +350,9 @@ tensorflow::Status CreateTRTNode(tensorflow::Graph* graph,
nvinfer1::IGpuAllocator* alloc,
int max_batch_size) {
const auto& info = infos.at(pos);
- std::vector<tensorflow::TensorShapeProto> out_shapes;
- std::vector<tensorflow::TensorShapeProto> input_shapes;
- std::vector<tensorflow::PartialTensorShape> shapes;
+ std::vector<tensorflow::TensorShapeProto> output_shape_protos;
+ std::vector<tensorflow::TensorShapeProto> input_shape_protos;
+ std::vector<tensorflow::PartialTensorShape> input_shapes;
std::vector<tensorflow::NodeDefBuilder::NodeOut> inputs;
std::vector<tensorflow::DataType> out_types;
VLOG(1) << "Processing " << info.engine_name;
@@ -366,11 +365,11 @@ tensorflow::Status CreateTRTNode(tensorflow::Graph* graph,
tensorflow::TensorShapeProto out_shape;
// shape of the output node inside segment
conn.inside_shape.AsProto(&out_shape);
- if (out_shapes.size() <= conn.port_number) {
- out_shapes.resize(conn.port_number + 1);
+ if (output_shape_protos.size() <= conn.port_number) {
+ output_shape_protos.resize(conn.port_number + 1);
out_types.resize(conn.port_number + 1);
}
- out_shapes.at(conn.port_number) = out_shape;
+ output_shape_protos.at(conn.port_number) = out_shape;
out_types.at(conn.port_number) = conn.connection_type;
continue;
}
@@ -378,12 +377,12 @@ tensorflow::Status CreateTRTNode(tensorflow::Graph* graph,
// Set the shapes and data types of input edge.
tensorflow::TensorShapeProto in_shape;
conn.outside_shape.AsProto(&in_shape);
- if (input_shapes.size() <= conn.port_number) {
+ if (input_shape_protos.size() <= conn.port_number) {
+ input_shape_protos.resize(conn.port_number + 1);
input_shapes.resize(conn.port_number + 1);
- shapes.resize(conn.port_number + 1);
}
- input_shapes.at(conn.port_number) = in_shape;
- shapes.at(conn.port_number) = conn.outside_shape;
+ input_shape_protos.at(conn.port_number) = in_shape;
+ input_shapes.at(conn.port_number) = conn.outside_shape;
string input_node = conn.outside_node_name;
int input_port = conn.outside_port;
@@ -411,6 +410,8 @@ tensorflow::Status CreateTRTNode(tensorflow::Graph* graph,
VLOG(1) << "Engine Input " << input_node << ":" << input_port << " -> "
<< info.engine_name << ":" << inputs.size();
// Skip duplicate inputs.
+ // TODO(aaroey): use std::find instead. GetEngineInfo already remove
+ // duplicate connections, so here we should never find any duplicate?
bool new_input = true;
for (const auto& inp : inputs) {
if (inp.node == input_node && inp.index == input_port) {
@@ -438,8 +439,8 @@ tensorflow::Status CreateTRTNode(tensorflow::Graph* graph,
TF_RETURN_IF_ERROR(ConvertGraphDefToEngine(
info.segment_graph_def,
info.precision_mode == INT8MODE ? FP32MODE : info.precision_mode,
- max_batch_size, info.max_workspace_size_bytes, shapes, &trt_logger,
- alloc, /*calibrator=*/nullptr, &engine,
+ max_batch_size, info.max_workspace_size_bytes, input_shapes,
+ &trt_logger, alloc, /*calibrator=*/nullptr, &engine,
/*convert_successfully=*/nullptr));
TrtUniquePtrType<nvinfer1::IHostMemory> engine_data(engine->serialize());
segment_string =
@@ -487,8 +488,8 @@ tensorflow::Status CreateTRTNode(tensorflow::Graph* graph,
}
tensorflow::NodeDef trt_node;
tensorflow::Status status =
- node_builder.Attr("input_shapes", input_shapes)
- .Attr("output_shapes", out_shapes)
+ node_builder.Attr("input_shapes", input_shape_protos)
+ .Attr("output_shapes", output_shape_protos)
.Attr("static_engine",
info.engine_type == EngineInfo::EngineType::TRTStatic)
.Attr("segment_funcdef_name",
@@ -705,6 +706,7 @@ std::pair<int, tensorflow::Allocator*> GetDeviceAndAllocator(
}
// Entry function from optimization pass.
+// TODO(aaeory): parameter should use pointer type.
tensorflow::Status ConvertAfterShapes(ConversionParams& params) {
// Convert graphdef to graph.
tensorflow::FunctionLibraryDefinition flib(tensorflow::OpRegistry::Global(),
@@ -722,8 +724,8 @@ tensorflow::Status ConvertAfterShapes(ConversionParams& params) {
segment_options.minimum_segment_size = params.minimum_segment_size;
tensorflow::tensorrt::segment::SegmentNodesVector initial_segments;
TF_RETURN_IF_ERROR(tensorrt::segment::SegmentGraph(
- &graph, IsTensorRTCandidate, IsTensorRTInputCandidate,
- IsTensorRTOutputCandidate, segment_options, &initial_segments));
+ &graph, IsTensorRTCandidate, InputEdgeValidator(*params.graph_properties),
+ OutputEdgeValidator(), segment_options, &initial_segments));
if (initial_segments.size() > 1) {
VLOG(0) << "MULTIPLE tensorrt candidate conversion: "
<< initial_segments.size();
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
index 8f6656e4ad..c49e26ea4e 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h"
#include <algorithm>
+#include <cstring>
#include <list>
#include <map>
#include <memory>
@@ -57,7 +58,6 @@ namespace tensorflow {
namespace tensorrt {
namespace convert {
using ::tensorflow::str_util::Split;
-
using ::tensorflow::strings::StrAppend;
using ::tensorflow::strings::StrCat;
@@ -77,11 +77,63 @@ inline tensorflow::Status ConvertDType(tensorflow::DataType tf_dtype,
break;
default:
return tensorflow::errors::InvalidArgument(
- "Unsupported data type " + tensorflow::DataTypeString(tf_dtype));
+ "Unsupported data type ", tensorflow::DataTypeString(tf_dtype));
}
return tensorflow::Status::OK();
}
+void GetInputProperties(const grappler::GraphProperties& graph_properties,
+ const Node* outside_node, const int out_port,
+ PartialTensorShape* shape,
+ tensorflow::DataType* dtype) {
+ if (graph_properties.HasOutputProperties(outside_node->name())) {
+ auto output_params =
+ graph_properties.GetOutputProperties(outside_node->name());
+ auto out_shape = output_params.at(out_port);
+ *dtype = out_shape.dtype();
+ *shape = out_shape.shape();
+ } else {
+ VLOG(0) << "Unknown output shape" << outside_node->name();
+ *dtype = outside_node->output_type(out_port);
+ }
+}
+
+void GetOutputProperties(const grappler::GraphProperties& graph_properties,
+ const Node* outside_node, const int in_port,
+ PartialTensorShape* shape,
+ tensorflow::DataType* dtype) {
+ if (graph_properties.HasInputProperties(outside_node->name())) {
+ auto input_params =
+ graph_properties.GetInputProperties(outside_node->name());
+ auto in_shape = input_params.at(in_port);
+ *dtype = in_shape.dtype();
+ *shape = in_shape.shape();
+ } else {
+ *dtype = outside_node->input_type(in_port);
+ }
+}
+
+tensorflow::Status ValidateInputProperties(const PartialTensorShape& shape,
+ const tensorflow::DataType dtype,
+ nvinfer1::DataType* trt_dtype) {
+ TF_RETURN_IF_ERROR(ConvertDType(dtype, trt_dtype));
+ if (shape.dims() < 0) {
+ return tensorflow::errors::InvalidArgument(
+ "Input tensor rank is unknown.");
+ }
+ if (shape.dims() > 8) {
+ return tensorflow::errors::OutOfRange(
+ "Input tensor rank is greater than 8.");
+ }
+ for (int d = 1; d < shape.dims(); ++d) {
+ if (shape.dim_size(d) < 0) {
+ return tensorflow::errors::InvalidArgument(
+ "Input tensor has a unknow non-batch dimemension at dim ", d);
+ }
+ }
+ return Status::OK();
+}
+
inline nvinfer1::Dims GetTensorShape(const tensorflow::Tensor& tensor) {
nvinfer1::Dims dims;
dims.nbDims = tensor.dims();
@@ -2177,25 +2229,22 @@ tensorflow::Status ConvertGraphDefToEngine(
(node_def.op() == "Placeholder")) {
nvinfer1::DimsCHW input_dim_pseudo_chw;
for (int i = 0; i < 8; i++) input_dim_pseudo_chw.d[i] = 0;
- nvinfer1::DataType dtype(nvinfer1::DataType::kFLOAT);
- auto type_status =
- ConvertDType(node_def.attr().at("dtype").type(), &dtype);
- if (type_status != tensorflow::Status::OK()) {
- LOG(WARNING) << "Type conversion failed for " << node_name;
- return type_status;
- }
int32 slot_number = -1;
- if (!tensorflow::strings::safe_strto32(node_name.c_str() + 8,
- &slot_number)) {
- LOG(ERROR) << "Failed to parse slot number from " << node_name
- << " +8= " << node_name.c_str() + 8;
+ if (!tensorflow::strings::safe_strto32(
+ node_name.c_str() + strlen(kInputPHName), &slot_number)) {
+ return tensorflow::errors::InvalidArgument(
+ "Failed to parse slot number from ", node_name);
}
+ nvinfer1::DataType dtype;
auto shape = input_shapes.at(slot_number);
- if (shape.dims() > 8) {
- LOG(ERROR) << "Tensor rank is greater than 8 for " << node_name
- << " at input slot " << slot_number;
- return tensorflow::errors::OutOfRange(
- "Input tensor rank is greater than 8");
+ auto status = ValidateInputProperties(
+ shape, node_def.attr().at("dtype").type(), &dtype);
+ if (!status.ok()) {
+ const string error_message = StrCat(
+ "Validation failed for ", node_name, " and input slot ",
+ slot_number, ": ", status.error_message());
+ LOG(WARNING) << error_message;
+ return Status(status.code(), error_message);
}
if (VLOG_IS_ON(1)) {
string dim_str("dims=");
@@ -2226,10 +2275,10 @@ tensorflow::Status ConvertGraphDefToEngine(
} else if (tensorflow::str_util::StartsWith(node_name, kOutputPHName) &&
(node_def.op() == "Identity")) {
int32 slot_number = -1;
- if (!tensorflow::strings::safe_strto32(node_name.c_str() + 9,
- &slot_number)) {
- LOG(ERROR) << "Failed to parse slot number from " << node_name
- << " +9=" << node_name.c_str() + 9;
+ if (!tensorflow::strings::safe_strto32(
+ node_name.c_str() + strlen(kOutputPHName), &slot_number)) {
+ return tensorflow::errors::InvalidArgument(
+ "Failed to parse slot number from ", node_name);
}
if (output_tensors.size() <= slot_number) {
output_tensors.resize(slot_number + 1);
@@ -2288,38 +2337,20 @@ tensorflow::Status ConvertSegmentToGraphDef(
"Cannot find node with id ", connection.outside_id, " in the graph.");
}
// Updates the shape and data types of input/output connections.
- tensorflow::DataType input_type = tensorflow::DT_FLOAT;
+ tensorflow::DataType dtype;
tensorflow::PartialTensorShape partial_shape;
if (connection.is_input_edge) {
- if (graph_properties.HasOutputProperties(connection.outside_node_name)) {
- auto output_params =
- graph_properties.GetOutputProperties(connection.outside_node_name);
- auto out_shape = output_params.at(connection.outside_port);
- input_type = out_shape.dtype();
- std::vector<tensorflow::int64> dims;
- partial_shape = out_shape.shape();
- connection.outside_shape = partial_shape;
- } else {
- VLOG(0) << "Unknown output shape" << outside_node->name();
- input_type = graph->FindNodeId(connection.outside_id)
- ->output_type(connection.outside_port);
- }
- connection.connection_type = input_type;
-
- } else { // output edge
- if (graph_properties.HasInputProperties(connection.outside_node_name)) {
- auto input_params =
- graph_properties.GetInputProperties(connection.outside_node_name);
- auto in_shape = input_params.at(connection.outside_port);
- input_type = in_shape.dtype();
- partial_shape = in_shape.shape();
- connection.inside_shape = partial_shape;
- } else {
- input_type = graph->FindNodeId(connection.inside_id)
- ->output_type(connection.outside_port);
- }
- connection.connection_type = input_type;
+ GetInputProperties(graph_properties,
+ graph->FindNodeId(connection.outside_id),
+ connection.outside_port, &partial_shape, &dtype);
+
+ } else {
+ GetOutputProperties(graph_properties,
+ graph->FindNodeId(connection.outside_id),
+ connection.outside_port, &partial_shape, &dtype);
}
+ connection.outside_shape = partial_shape;
+ connection.connection_type = dtype;
// Add dummy input/output nodes to the segment graphdef.
if (connection.is_input_edge) {
@@ -2335,7 +2366,7 @@ tensorflow::Status ConvertSegmentToGraphDef(
auto seg_node = segment_def->add_node();
tensorflow::NodeDefBuilder builder(node_name, "Placeholder");
auto status = builder.Attr("shape", partial_shape)
- .Attr("dtype", input_type)
+ .Attr("dtype", dtype)
.Finalize(seg_node);
VLOG(1) << "Constructing input " << node_name << " for the edge "
<< connection.outside_node_name << ":" << connection.outside_port
@@ -2353,7 +2384,7 @@ tensorflow::Status ConvertSegmentToGraphDef(
marker_nodes.insert(node_name);
auto seg_node = segment_def->add_node();
tensorflow::NodeDefBuilder builder(node_name, "Identity");
- auto status = builder.Input(connection.inside_node_name, 0, input_type)
+ auto status = builder.Input(connection.inside_node_name, 0, dtype)
.Finalize(seg_node);
VLOG(1) << "Constructing output " << node_name << " for the edge "
<< connection.inside_node_name << ":" << connection.inside_port
@@ -2391,11 +2422,35 @@ tensorflow::Status ConvertSegmentToGraphDef(
return tensorflow::Status::OK();
}
-bool IsTensorRTInputCandidate(const tensorflow::Node* node) {
+bool InputEdgeValidator::operator()(const tensorflow::Edge* in_edge) const {
+ if (in_edge->IsControlEdge()) return true;
+ PartialTensorShape shape;
+ tensorflow::DataType dtype;
+ GetInputProperties(graph_properties_, in_edge->src(), in_edge->src_output(),
+ &shape, &dtype);
+ nvinfer1::DataType trt_dtype;
+ Status status = ValidateInputProperties(shape, dtype, &trt_dtype);
+ if (!status.ok()) {
+ VLOG(2) << "--> Need to remove input node " << in_edge->dst()->name()
+ << ": " << status;
+ return false;
+ }
+ if (shape.dims() < 3 && in_edge->src()->type_string() != "Const") {
+ VLOG(2) << "--> Need to remove input node " << in_edge->dst()->name()
+ << " which has an input at port " << in_edge->dst_input()
+ << " with #dim<3 and is not a const: " << shape;
+ return false;
+ }
return true;
}
-bool IsTensorRTOutputCandidate(const tensorflow::Node* node) {
+bool OutputEdgeValidator::operator()(const tensorflow::Edge* out_edge) const {
+ if (out_edge->IsControlEdge()) return true;
+ if (out_edge->src()->type_string() == "Const") {
+ VLOG(2) << "--> Need to remove output node " << out_edge->src()->name()
+ << " which is a Const.";
+ return false;
+ }
return true;
}
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.h b/tensorflow/contrib/tensorrt/convert/convert_nodes.h
index 872ba6a080..64337eee84 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_nodes.h
+++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.h
@@ -104,6 +104,8 @@ struct EngineInfo {
// topological order.
// - segment_def: the output GraphDef, whose non-input/output nodedefs will be
// sorted in topological order.
+//
+// TODO(aaroey): add tests to validate these properties.
tensorflow::Status ConvertSegmentToGraphDef(
const tensorflow::Graph* graph,
const tensorflow::grappler::GraphProperties& graph_properties,
@@ -128,9 +130,29 @@ tensorflow::Status ConvertGraphDefToEngine(
TrtUniquePtrType<nvinfer1::ICudaEngine>* engine,
bool* convert_successfully);
-bool IsTensorRTInputCandidate(const tensorflow::Node* node);
+// Helper class for the segmenter to determine whether an input edge to the TRT
+// segment is valid.
+class InputEdgeValidator {
+ public:
+ InputEdgeValidator(const grappler::GraphProperties& graph_properties)
+ : graph_properties_(graph_properties) {}
+
+ // Return true if the specified edge is eligible to be an input edge of the
+ // TRT segment.
+ bool operator()(const tensorflow::Edge* in_edge) const;
-bool IsTensorRTOutputCandidate(const tensorflow::Node* node);
+ private:
+ const grappler::GraphProperties& graph_properties_;
+};
+
+// Helper class for the segmenter to determine whether an output edge from the
+// TRT segment is valid.
+class OutputEdgeValidator {
+ public:
+ // Return true if the specified edge is eligible to be an output edge of the
+ // TRT segment.
+ bool operator()(const tensorflow::Edge* out_edge) const;
+};
} // namespace convert
} // namespace tensorrt
diff --git a/tensorflow/contrib/tensorrt/segment/segment.cc b/tensorflow/contrib/tensorrt/segment/segment.cc
index 5c0898b29a..92807bed14 100644
--- a/tensorflow/contrib/tensorrt/segment/segment.cc
+++ b/tensorflow/contrib/tensorrt/segment/segment.cc
@@ -364,8 +364,8 @@ void ContractEdge(SimpleEdge* edge, SimpleGraph* graph,
tensorflow::Status SegmentGraph(
const tensorflow::Graph* tf_graph,
const std::function<bool(const tensorflow::Node*)>& candidate_fn,
- const std::function<bool(const tensorflow::Node*)>& input_candidate_fn,
- const std::function<bool(const tensorflow::Node*)>& output_candidate_fn,
+ const std::function<bool(const tensorflow::Edge*)>& input_candidate_fn,
+ const std::function<bool(const tensorflow::Edge*)>& output_candidate_fn,
const SegmentOptions& options, SegmentNodesVector* segments) {
// Steps:
// 1. run the segmentation algorithm to find all the segments, which uses
@@ -526,7 +526,7 @@ tensorflow::Status SegmentGraph(
for (const tensorflow::Edge* edge : node->in_edges()) {
if (!edge->IsControlEdge() && !edge->src()->IsSource() &&
!segment_nodes.count(edge->src())) { // 'node' is an input node.
- if (!input_candidate_fn(node)) {
+ if (!input_candidate_fn(edge)) {
in_nodes_que.push_back(node);
added = true;
break;
@@ -537,7 +537,7 @@ tensorflow::Status SegmentGraph(
for (const tensorflow::Edge* edge : node->out_edges()) {
if (!edge->dst()->IsSink() && !edge->IsControlEdge() &&
!segment_nodes.count(edge->dst())) { // 'node' is an output node.
- if (!output_candidate_fn(node)) {
+ if (!output_candidate_fn(edge)) {
out_nodes_que.push_back(node);
break;
}
diff --git a/tensorflow/contrib/tensorrt/segment/segment.h b/tensorflow/contrib/tensorrt/segment/segment.h
index ab75135054..8c44eb782a 100644
--- a/tensorflow/contrib/tensorrt/segment/segment.h
+++ b/tensorflow/contrib/tensorrt/segment/segment.h
@@ -52,8 +52,8 @@ struct SegmentOptions {
tensorflow::Status SegmentGraph(
const tensorflow::Graph* tf_graph,
const std::function<bool(const tensorflow::Node*)>& candidate_fn,
- const std::function<bool(const tensorflow::Node*)>& input_candidate_fn,
- const std::function<bool(const tensorflow::Node*)>& output_candidate_fn,
+ const std::function<bool(const tensorflow::Edge*)>& input_candidate_fn,
+ const std::function<bool(const tensorflow::Edge*)>& output_candidate_fn,
const SegmentOptions& options, SegmentNodesVector* segments);
} // namespace segment
diff --git a/tensorflow/contrib/tensorrt/segment/segment_test.cc b/tensorflow/contrib/tensorrt/segment/segment_test.cc
index a43cf4f416..432e7b1c04 100644
--- a/tensorflow/contrib/tensorrt/segment/segment_test.cc
+++ b/tensorflow/contrib/tensorrt/segment/segment_test.cc
@@ -41,15 +41,30 @@ class SegmentTest : public ::testing::Test {
};
}
+ std::function<bool(const tensorflow::Edge*)> MakeInputEdgeCandidateFn(
+ const std::set<string>& node_names) {
+ return [node_names](const tensorflow::Edge* in_edge) -> bool {
+ return node_names.find(in_edge->dst()->name()) != node_names.end();
+ };
+ }
+
+ std::function<bool(const tensorflow::Edge*)> MakeOutputEdgeCandidateFn(
+ const std::set<string>& node_names) {
+ return [node_names](const tensorflow::Edge* out_edge) -> bool {
+ return node_names.find(out_edge->src()->name()) != node_names.end();
+ };
+ }
+
void RunTest(const tensorflow::Graph* graph,
const std::set<string>& candidates,
const std::set<string>& input_candidates,
const std::set<string>& output_candidates,
const std::vector<std::set<string>>& expected_segments) {
SegmentNodesVector segments;
- TF_EXPECT_OK(SegmentGraph(
- graph, MakeCandidateFn(candidates), MakeCandidateFn(input_candidates),
- MakeCandidateFn(output_candidates), default_options_, &segments));
+ TF_EXPECT_OK(SegmentGraph(graph, MakeCandidateFn(candidates),
+ MakeInputEdgeCandidateFn(input_candidates),
+ MakeOutputEdgeCandidateFn(output_candidates),
+ default_options_, &segments));
ValidateSegment(segments, expected_segments);
}