aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensorrt
diff options
context:
space:
mode:
authorGravatar gracehoney <31743510+aaroey@users.noreply.github.com>2018-08-09 11:44:07 -0700
committerGravatar gracehoney <31743510+aaroey@users.noreply.github.com>2018-08-09 11:44:07 -0700
commit728422d1eee62374b3221676a1826660473897bc (patch)
tree161115f7ba7d544d86bd491da9a18ae2c4556c17 /tensorflow/contrib/tensorrt
parent1d4a8296b26150f7eabf5bbb981b9b2438a9fb2a (diff)
parentfd9fc4b4b69f7fce60497bbaf5cbd958f12ead8d (diff)
Fix conflicts with upstream/master
Diffstat (limited to 'tensorflow/contrib/tensorrt')
-rw-r--r--tensorflow/contrib/tensorrt/BUILD35
-rw-r--r--tensorflow/contrib/tensorrt/convert/convert_graph.cc577
-rw-r--r--tensorflow/contrib/tensorrt/convert/convert_nodes.cc52
-rw-r--r--tensorflow/contrib/tensorrt/convert/convert_nodes.h40
-rw-r--r--tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc8
-rw-r--r--tensorflow/contrib/tensorrt/convert/utils.cc36
-rw-r--r--tensorflow/contrib/tensorrt/convert/utils.h11
-rw-r--r--tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc159
-rw-r--r--tensorflow/contrib/tensorrt/kernels/trt_engine_op.h8
-rw-r--r--tensorflow/contrib/tensorrt/python/__init__.py4
-rw-r--r--tensorflow/contrib/tensorrt/python/trt_convert.py90
-rw-r--r--tensorflow/contrib/tensorrt/segment/segment.cc47
-rw-r--r--tensorflow/contrib/tensorrt/segment/segment_test.cc4
-rw-r--r--tensorflow/contrib/tensorrt/tensorrt_test.cc9
-rw-r--r--tensorflow/contrib/tensorrt/test/base_test.py252
-rw-r--r--tensorflow/contrib/tensorrt/test/batch_matmul_test.py2
-rw-r--r--tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py5
-rw-r--r--tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py19
-rw-r--r--tensorflow/contrib/tensorrt/test/concatenation_test.py2
-rw-r--r--tensorflow/contrib/tensorrt/test/const_broadcast_test.py2
-rw-r--r--tensorflow/contrib/tensorrt/test/memory_alignment_test.py72
-rw-r--r--tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py2
-rw-r--r--tensorflow/contrib/tensorrt/test/neighboring_engine_test.py13
-rw-r--r--tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py349
-rw-r--r--tensorflow/contrib/tensorrt/test/unary_test.py5
-rw-r--r--tensorflow/contrib/tensorrt/test/utils.cc101
-rw-r--r--tensorflow/contrib/tensorrt/test/utils.h44
-rw-r--r--tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py2
-rw-r--r--tensorflow/contrib/tensorrt/test/vgg_block_test.py2
-rw-r--r--tensorflow/contrib/tensorrt/trt_conversion.i114
30 files changed, 1424 insertions, 642 deletions
diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD
index b0337c3fe9..5b54cb76b4 100644
--- a/tensorflow/contrib/tensorrt/BUILD
+++ b/tensorflow/contrib/tensorrt/BUILD
@@ -3,7 +3,7 @@
# and provide TensorRT operators and converter package.
# APIs are meant to change over time.
-package(default_visibility = ["//tensorflow:__subpackages__"])
+package(default_visibility = ["//visibility:public"])
licenses(["notice"]) # Apache 2.0
@@ -37,7 +37,9 @@ tf_cuda_cc_test(
"nomac",
],
deps = [
+ "//tensorflow/core:gpu_init",
"//tensorflow/core:lib",
+ "//tensorflow/core:stream_executor",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
] + if_tensorrt([
@@ -83,11 +85,12 @@ cc_library(
copts = tf_copts(),
visibility = ["//visibility:public"],
deps = [
+ ":test_utils",
":trt_allocator",
+ ":trt_conversion",
":trt_logging",
":trt_plugins",
":trt_resources",
- ":trt_conversion",
":utils",
"//tensorflow/core:gpu_headers_lib",
"//tensorflow/core:lib_proto_parsing",
@@ -120,7 +123,6 @@ tf_cuda_library(
tf_gen_op_wrapper_py(
name = "trt_engine_op",
- gen_locally = True,
deps = [
":trt_engine_op_op_lib",
":trt_logging",
@@ -183,6 +185,8 @@ py_library(
],
)
+# TODO(aaroey): this wrapper has been causing troubles of double linking, so
+# either get rid of it, or split to make it contain minimum dependencies.
tf_py_wrap_cc(
name = "wrap_conversion",
srcs = ["trt_conversion.i"],
@@ -191,6 +195,7 @@ tf_py_wrap_cc(
"//tensorflow/python:platform/base.i",
],
deps = [
+ ":test_utils",
":trt_conversion",
":trt_engine_op_kernel",
"//third_party/python_runtime:headers",
@@ -263,6 +268,7 @@ tf_cuda_library(
],
deps = [
":segment",
+ ":test_utils",
":trt_allocator",
":trt_plugins",
":trt_logging",
@@ -273,7 +279,6 @@ tf_cuda_library(
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:utils",
- "//tensorflow/core:gpu_runtime",
"//tensorflow/core:framework_lite",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
@@ -384,15 +389,16 @@ cuda_py_tests(
"test/base_test.py",
# "test/batch_matmul_test.py",
# "test/biasadd_matmul_test.py",
- "test/binary_tensor_weight_broadcast_test.py",
- "test/concatenation_test.py",
+ # "test/binary_tensor_weight_broadcast_test.py", # Blocked by trt4 installation
+ # "test/concatenation_test.py", # Blocked by trt4 installation
"test/const_broadcast_test.py",
"test/multi_connection_neighbor_engine_test.py",
"test/neighboring_engine_test.py",
- "test/unary_test.py",
- # "test/rank_two_test.py",
+ "test/rank_two_test.py",
+ # "test/unary_test.py", # Blocked by trt4 installation
# "test/vgg_block_nchw_test.py",
# "test/vgg_block_test.py",
+ "test/memory_alignment_test.py",
],
additional_deps = [
":tf_trt_integration_test_base",
@@ -411,4 +417,17 @@ cc_library(
srcs = ["convert/utils.cc"],
hdrs = ["convert/utils.h"],
copts = tf_copts(),
+ deps = [
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "test_utils",
+ srcs = ["test/utils.cc"],
+ hdrs = ["test/utils.h"],
+ deps = [
+ "//tensorflow/core:lib",
+ "@com_googlesource_code_re2//:re2",
+ ],
)
diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
index 3383f6bc9b..21ec8b0b30 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include <map>
#include <set>
#include <unordered_map>
+#include <unordered_set>
#include <utility>
#include <vector>
@@ -29,9 +30,7 @@ limitations under the License.
#include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h"
#include "tensorflow/contrib/tensorrt/resources/trt_resources.h"
#include "tensorflow/contrib/tensorrt/segment/segment.h"
-#include "tensorflow/core/common_runtime/gpu/gpu_id.h"
-#include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h"
-#include "tensorflow/core/common_runtime/gpu/gpu_process_state.h"
+#include "tensorflow/contrib/tensorrt/test/utils.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def_builder.h"
@@ -195,20 +194,44 @@ tensorflow::Status ConvertCalibGraphToInferGraph(
return tensorflow::Status::OK();
}
-// Entry function from Python.
tensorflow::Status ConvertGraphDefToTensorRT(
const tensorflow::GraphDef& graph_def,
const std::vector<string>& output_names, size_t max_batch_size,
size_t max_workspace_size_bytes, tensorflow::GraphDef* new_graph_def,
int precision_mode, int minimum_segment_size, bool is_dyn_op,
int max_cached_engines, std::vector<int> cached_engine_batches) {
- // optimization pass
+ // Create GrapplerItem.
tensorflow::grappler::GrapplerItem item;
item.fetch = output_names;
item.graph = graph_def;
- // grappler requires a virtual cluster with a proper GPU device
- // in order to calculate flops>0 or fails with FATAL
- // We add numbers from a Pascal card here to have flops>0
+
+ // TODO(aaroey): we should have used single machine cluster like the
+ // following, but the problem is then wrap_conversion will depend on
+ // direct_session and cause double linking problems. To fix this we need to
+ // fix or get rid of the swig dependency. Here we use VirtualCluster
+ // as a work around, and we need to create a session to initialize the
+ // underlying device before calling this method.
+#if 0
+ // Create single machine cluster. Note that this will create a session and
+ // initialize the gpu devices.
+ const int num_cpu_cores =
+ tensorflow::grappler::GetNumAvailableLogicalCPUCores();
+ const int num_gpus = tensorflow::grappler::GetNumAvailableGPUs();
+ VLOG(2) << "cpu_cores: " << num_cpu_cores;
+ VLOG(2) << "gpus: " << num_gpus;
+ const int timeout_s = 60 * 10;
+ std::unique_ptr<tensorflow::grappler::Cluster> cluster(
+ new tensorflow::grappler::SingleMachine(
+ timeout_s, num_cpu_cores, num_gpus));
+ // These settings are the defaults in tensorflow/python/grappler/cluster.py.
+ cluster->DisableDetailedStats(true);
+ cluster->AllowSoftPlacement(true);
+ cluster->SetNumWarmupSteps(10);
+ TF_RETURN_IF_ERROR(cluster->Provision());
+#else
+ // Create virtual cluster. Grappler requires a virtual cluster with a proper
+ // GPU device in order to calculate flops>0 or fails with FATAL in dbg mode.
+ // We add numbers from a Pascal card here to have flops>0.
tensorflow::DeviceProperties device_properties;
device_properties.set_type("GPU");
device_properties.mutable_environment()->insert({"architecture", "6"});
@@ -217,47 +240,43 @@ tensorflow::Status ConvertGraphDefToTensorRT(
std::unique_ptr<tensorflow::grappler::Cluster> cluster(
new tensorflow::grappler::VirtualCluster(
{{"/GPU:0", device_properties}}));
+#endif
- // single machine
- int num_cpu_cores = tensorflow::grappler::GetNumAvailableLogicalCPUCores();
- int num_gpus = tensorflow::grappler::GetNumAvailableGPUs();
- VLOG(2) << "cpu_cores: " << num_cpu_cores;
- VLOG(2) << "gpus: " << num_gpus;
+ // Create RewriterConfig.
tensorflow::RewriterConfig rw_cfg;
- // use only const folding and layout for the time being since new optimizers
- // break the graph for us
+ // TODO(aaroey): use only const folding and layout for the time being since
+ // new optimizers break the graph for trt.
rw_cfg.add_optimizers("constfold");
rw_cfg.add_optimizers("layout");
- rw_cfg.set_meta_optimizer_iterations(tensorflow::RewriterConfig::ONE);
+ auto optimizer = rw_cfg.add_custom_optimizers();
+ optimizer->set_name("TensorRTOptimizer");
+ auto& parameters = *(optimizer->mutable_parameter_map());
+ parameters["minimum_segment_size"].set_i(minimum_segment_size);
+ parameters["max_batch_size"].set_i(max_batch_size);
+ parameters["is_dynamic_op"].set_b(is_dyn_op);
+ parameters["max_workspace_size_bytes"].set_i(max_workspace_size_bytes);
+ TF_RETURN_IF_ERROR(GetPrecisionModeName(
+ precision_mode, parameters["precision_mode"].mutable_s()));
+ parameters["maximum_cached_engines"].set_i(max_cached_engines);
+ if (!cached_engine_batches.empty()) {
+ auto list = parameters["cached_engine_batches"].mutable_list();
+ for (const int batch : cached_engine_batches) {
+ list->add_i(batch);
+ }
+ }
+
+ // Run optimizer.
tensorflow::grappler::MetaOptimizer meta_opt(nullptr, rw_cfg);
- tensorflow::GraphDef gdef;
- TF_RETURN_IF_ERROR(meta_opt.Optimize(cluster.get(), item, &gdef));
- item.graph = gdef;
-
- // AJ refactoring shape inference through grappler/GraphProperties.
- tensorflow::grappler::GraphProperties static_graph_properties(item);
- TF_RETURN_IF_ERROR(static_graph_properties.InferStatically(true));
- // Build full graph
- ConversionParams cp;
- cp.input_graph_def = &gdef;
- cp.output_names = &output_names;
- cp.max_batch_size = max_batch_size;
- cp.output_graph_def = new_graph_def;
- cp.precision_mode = precision_mode;
- cp.is_dyn_op = is_dyn_op;
- cp.max_cached_engines = max_cached_engines;
- cp.cached_engine_batches = cached_engine_batches;
- cp.minimum_segment_size = minimum_segment_size;
- cp.graph_properties = &static_graph_properties;
- cp.max_workspace_size_bytes = max_workspace_size_bytes;
+ TF_RETURN_IF_ERROR(meta_opt.Optimize(cluster.get(), item, new_graph_def));
+
if (VLOG_IS_ON(5)) {
std::fstream f;
f.open("TRTConversionInput.pb",
std::fstream::out | std::fstream::binary | std::fstream::trunc);
- f << gdef.SerializeAsString();
+ f << new_graph_def->SerializeAsString();
f.close();
}
- return ConvertAfterShapes(cp);
+ return Status::OK();
}
// Function to get subsegment information structure.
@@ -268,11 +287,10 @@ tensorflow::Status GetEngineInfo(
const std::unordered_map<string, tensorflow::Node*>& node_map,
const std::vector<tensorflow::Node*>& reverse_topo_order,
EngineInfo* info) {
- std::vector<int> subgraph_node_ids;
+ std::vector<int> subgraph_node_ids; // Topologically sorted node ids.
+ std::set<string> subgraph_node_names = segment_nodes;
std::set<int> added_const_node_ids; // Used to prevent double insertion.
std::set<string> segment_devices;
- int input_port = 0;
- int output_port = 0;
// Map from src_node_name+port to the unique port numbers of the TRT op, where
// the src_node_name is the name of the source node of the input/output
@@ -280,13 +298,12 @@ tensorflow::Status GetEngineInfo(
// input/output edges must be in different split of the graph.
// TODO(aaroey): consider using node id and port instead.
// TODO(aaroey): using topo order instead of reverting reverse topo order.
- std::unordered_map<string, int> created_edges;
+ std::unordered_map<string, int> input_to_engine_port, output_to_engine_port;
for (auto it = reverse_topo_order.rbegin(); it != reverse_topo_order.rend();
++it) {
const auto& node_name = (*it)->name();
-
if (segment_nodes.count(node_name) == 0) continue;
- auto node = node_map.at(node_name);
+ auto node = *it;
auto node_device = node->requested_device();
if (!node_device.empty()) {
segment_devices.insert(node_device);
@@ -299,64 +316,93 @@ tensorflow::Status GetEngineInfo(
}
}
const int node_id = node->id();
+ subgraph_node_ids.push_back(node_id);
+ // Create input connections.
for (const auto edge : node->in_edges()) {
auto input_node = edge->src();
- if (segment_nodes.count(input_node->name()) == 0 &&
- !edge->IsControlEdge() && !input_node->IsSource()) {
- // Add constant input node into the segment. We don't care if it has
- // other output edges going into other engines or TF nodes. Since we add
- // it only to the subsegment node list, not the subsegment itself, it
- // won't be removed from the graph. If it doesn't have any edges, TF
- // will prune it out.
- if (input_node->type_string() == "Const") {
- if (added_const_node_ids.count(input_node->id()) == 0) {
- added_const_node_ids.insert(input_node->id());
- subgraph_node_ids.push_back(input_node->id());
- }
+ if (input_node->IsSource() || segment_nodes.count(input_node->name())) {
+ continue;
+ }
+ if (edge->IsControlEdge()) {
+ // Control input.
+ info->connections.emplace_back(input_node->name(), input_node->id(),
+ node_name, node_id,
+ /*input_edge=*/true);
+ } else if (input_node->type_string() == "Const") {
+ // Add constant data input nodes into the segment graphdef (thus also in
+ // the engine). We don't care if it has other output edges going into
+ // other engines or TF nodes. Since we add it only to the segment
+ // graphdef, not the segment itself, it won't be removed from the graph.
+ // If it doesn't have any edges, TF will prune it out.
+ //
+ // Note that the segmenter already ensure that the constant data input
+ // is valid and suppported by the engine.
+ if (!added_const_node_ids.insert(input_node->id()).second) {
+ // Already added before.
+ continue;
+ }
+ VLOG(1) << "Adding const node " << input_node->name();
+ QCHECK(subgraph_node_names.insert(input_node->name()).second);
+ // Since we already add (duplicate) the const input node to the segment
+ // graphdef, it's now not a data dependency any more, but to make the
+ // dependency correct we still add a control dependency.
+ info->connections.emplace_back(input_node->name(), input_node->id(),
+ node_name, node_id,
+ /*input_edge=*/true);
+ } else {
+ // Non-const data input.
+ int port = Graph::kControlSlot - 1;
+ // Use the source non-segment node name/port as key.
+ const string s = StrCat(input_node->name(), ":", edge->src_output());
+ VLOG(1) << "Input edge = " << s;
+ if (input_to_engine_port.count(s)) {
+ port = input_to_engine_port.at(s);
} else {
- string s(input_node->name());
- StrAppend(&s, ":", edge->src_output());
- VLOG(1) << "Input edge = " << s;
- int port = input_port;
- if (created_edges.count(s)) {
- port = created_edges.at(s);
- } else {
- created_edges.insert({s, port});
- input_port++;
- }
- info->connections.emplace_back(input_node->name(), input_node->id(),
- edge->src_output(), node_name, node_id,
- edge->dst_input(), true, port);
+ port = input_to_engine_port.size();
+ input_to_engine_port.insert({s, port});
}
+ info->connections.emplace_back(
+ input_node->name(), input_node->id(), edge->src_output(), node_name,
+ node_id, edge->dst_input(), /*input_edge=*/true, port);
}
}
- // We need to add possible const input nodes before adding this node in
- // order to keep the topological order.
- subgraph_node_ids.push_back(node_id);
+ // Create output connections.
for (const auto edge : node->out_edges()) {
auto output_node = edge->dst();
- if (segment_nodes.count(output_node->name()) == 0 &&
- !edge->IsControlEdge() && !output_node->IsSink()) {
- string s(node_name);
- StrAppend(&s, ":", edge->src_output());
+ if (output_node->IsSink() || segment_nodes.count(output_node->name())) {
+ continue;
+ }
+ if (edge->IsControlEdge()) {
+ // Control output.
+ info->connections.emplace_back(output_node->name(), output_node->id(),
+ node_name, node_id,
+ /*input_edge=*/false);
+ } else {
+ // Data output.
+ int port = Graph::kControlSlot - 1;
+ // Use the source segment node name/port as key.
+ const string s = StrCat(node_name, ":", edge->src_output());
VLOG(1) << "Output edge = " << s;
- int port = output_port;
- if (created_edges.count(s)) {
- port = created_edges.at(s);
+ if (output_to_engine_port.count(s)) {
+ port = output_to_engine_port.at(s);
} else {
- created_edges.insert({s, port});
- output_port++;
+ port = output_to_engine_port.size();
+ output_to_engine_port.insert({s, port});
}
- info->connections.emplace_back(output_node->name(), output_node->id(),
- edge->dst_input(), node_name, node_id,
- edge->src_output(), false, port);
+ info->connections.emplace_back(
+ output_node->name(), output_node->id(), edge->dst_input(),
+ node_name, node_id, edge->src_output(), /*input_edge=*/false, port);
}
}
- }
+ } // For each segment node in topological order.
+ // Construct the const nodes first.
+ subgraph_node_ids.insert(subgraph_node_ids.begin(),
+ added_const_node_ids.begin(),
+ added_const_node_ids.end());
TF_RETURN_IF_ERROR(ConvertSegmentToGraphDef(
- g, graph_properties, subgraph_node_ids, &info->connections,
- &info->segment_graph_def, &info->engine_name));
+ g, graph_properties, subgraph_node_names, subgraph_node_ids,
+ &info->connections, &info->segment_graph_def, &info->engine_name));
// TODO(sami): This should not happen once segmenter is updated.
if (segment_devices.size() == 1) {
info->device = *segment_devices.begin();
@@ -366,94 +412,137 @@ tensorflow::Status GetEngineInfo(
<< "but this shouldn't have happened";
info->device = *segment_devices.begin();
} else {
- VLOG(1) << "Segment devices size is 0";
+ LOG(ERROR) << "Can't find a device placement for the op!";
}
return Status::OK();
}
-// Function to insert a TRT node into the graph. The graph is not modified if
-// the returned status is not ok.
-// 'alloc' is only used for creating static engine.
-tensorflow::Status CreateTRTNode(tensorflow::Graph* graph,
- const std::vector<EngineInfo>& infos, int pos,
+// Helper function to update edge connection from the removed node to the
+// engine node. If an outside node is gone, it must have been absorbed into
+// an engine node. Find the engine node.
+void UpdateToEngineNode(const std::vector<EngineInfo>& infos,
+ const size_t my_engine_id,
+ const std::vector<Node*>& engine_nodes,
+ const bool is_input_edge, const string& node_name,
+ tensorflow::Node** node, int* port) {
+ for (size_t t = 0; t < infos.size(); ++t) {
+ if (t == my_engine_id) {
+ continue;
+ }
+ const auto& info = infos.at(t);
+ for (const auto& eng_conn : info.connections) {
+ // If the connection being updated is an input connection, the source of
+ // the connection must be an output connection of another engine. And vise
+ // versa.
+ if (is_input_edge == eng_conn.is_input_edge) continue;
+ if (eng_conn.inside_node_name == node_name &&
+ eng_conn.inside_port == *port) {
+ *node = CHECK_NOTNULL(engine_nodes[t]);
+ QCHECK_EQ(info.engine_name, (**node).name())
+ << "Engine name mismatch: " << info.engine_name << " vs "
+ << (**node).name();
+ *port = eng_conn.port_number;
+ return;
+ }
+ }
+ }
+ LOG(FATAL) << "Node " << (**node).name() << " not found in any engine.";
+}
+
+// Function to insert a TRT engine node into the graph.
+// Create engine nodes in the following way:
+// 1. Each invocation of CreateTRTNode creates an engine node for infos[pos]
+// 2. When an engine node is created, add it into the graph with necessary
+// re-wiring.
+// 2.1. If the outside connected node is existing, connect the engine
+// node to it.
+// 2.2. If the outside connected node is gone, it must have been absorted
+// into another engine node (which was processed before the processing
+// one). Connect to the pre-existing engine node instead.
+// 3. In this way, we ensure the graph is topologically sort-able after each
+// invocation of CreateTRTNode().
+tensorflow::Status CreateTRTNode(const std::vector<EngineInfo>& infos, int pos,
+ int max_batch_size, tensorflow::Graph* graph,
nvinfer1::IGpuAllocator* alloc,
- int max_batch_size) {
+ std::vector<Node*>* engine_nodes) {
const auto& info = infos.at(pos);
+ TRT_RETURN_IF_TEST_VALUE(StrCat(info.engine_name, ":CreateTRTNode"), "fail");
std::vector<tensorflow::TensorShapeProto> output_shape_protos;
std::vector<tensorflow::TensorShapeProto> input_shape_protos;
std::vector<tensorflow::PartialTensorShape> input_shapes;
std::vector<tensorflow::NodeDefBuilder::NodeOut> inputs;
+ std::vector<tensorflow::Node*> input_nodes;
+ std::vector<tensorflow::Node*> control_input_nodes;
+ std::unordered_set<string> control_input_names;
std::vector<tensorflow::DataType> out_types;
- VLOG(1) << "Processing " << info.engine_name;
- // Update the shape and data types of input/output nodes, and find all unique
- // inputs.
+ VLOG(1) << "Processing " << info.engine_name;
+ // Collect needed info for creating the engine node in the graph
for (const auto& conn : info.connections) {
- if (!conn.is_input_edge) {
- // Set the shapes and data types of output edge.
- tensorflow::TensorShapeProto out_shape;
- // shape of the output node inside segment
- conn.inside_shape.AsProto(&out_shape);
- if (output_shape_protos.size() <= conn.port_number) {
- output_shape_protos.resize(conn.port_number + 1);
- out_types.resize(conn.port_number + 1);
+ // Control edges
+ if (conn.is_control_edge()) {
+ // Skip control outputs for now. control output info are not needed for
+ // node creation and will be processed later.
+ if (!conn.is_input_edge) continue;
+
+ // Rewrire control input if it's not found in original graph.
+ tensorflow::Node* input_node = graph->FindNodeId(conn.outside_id);
+ int port = tensorflow::Graph::kControlSlot;
+ if (!input_node) {
+ UpdateToEngineNode(infos, pos, *engine_nodes, /*is_input_edge=*/true,
+ conn.outside_node_name, &input_node, &port);
+ QCHECK_EQ(Graph::kControlSlot, port);
}
- output_shape_protos.at(conn.port_number) = out_shape;
- out_types.at(conn.port_number) = conn.connection_type;
- continue;
- }
-
- // Set the shapes and data types of input edge.
- tensorflow::TensorShapeProto in_shape;
- conn.outside_shape.AsProto(&in_shape);
- if (input_shape_protos.size() <= conn.port_number) {
- input_shape_protos.resize(conn.port_number + 1);
- input_shapes.resize(conn.port_number + 1);
- }
- input_shape_protos.at(conn.port_number) = in_shape;
- input_shapes.at(conn.port_number) = conn.outside_shape;
-
- string input_node = conn.outside_node_name;
- int input_port = conn.outside_port;
- bool found_engine = false;
- // Rewire the inputs to other engines if they contain original input node.
- // Note that we use the information of the engine here, not the information
- // of the created TRT nodes, so we're able to find all the connections to
- // any other engines beforehand.
- for (size_t t = 0; t < infos.size(); ++t) {
- if (t == pos) continue;
- auto& engine_info = infos.at(t);
- for (const auto& eng_conn : engine_info.connections) {
- if (eng_conn.is_input_edge) continue;
- if (eng_conn.inside_node_name == input_node) {
- input_node = engine_info.engine_name;
- if (eng_conn.inside_port == input_port) {
- input_port = eng_conn.port_number;
- found_engine = true;
- break;
- }
- }
+ if (!control_input_names.insert(input_node->name()).second) {
+ continue;
}
- if (found_engine) break;
- }
- VLOG(1) << "Engine Input " << input_node << ":" << input_port << " -> "
- << info.engine_name << ":" << inputs.size();
- // Skip duplicate inputs.
- // TODO(aaroey): use std::find instead. GetEngineInfo already remove
- // duplicate connections, so here we should never find any duplicate?
- bool new_input = true;
- for (const auto& inp : inputs) {
- if (inp.node == input_node && inp.index == input_port) {
- new_input = false;
- break;
+ control_input_nodes.push_back(input_node);
+ VLOG(1) << "Engine Control Input " << input_node->name() << " -> "
+ << info.engine_name;
+ } else {
+ // Data edges
+ if (!conn.is_input_edge) {
+ // Set the shapes and data types of output edge.
+ tensorflow::TensorShapeProto out_shape;
+ // shape of the output node inside segment
+ conn.inside_shape.AsProto(&out_shape);
+ if (output_shape_protos.size() <= conn.port_number) {
+ output_shape_protos.resize(conn.port_number + 1);
+ out_types.resize(conn.port_number + 1);
+ }
+ output_shape_protos.at(conn.port_number) = out_shape;
+ out_types.at(conn.port_number) = conn.connection_type;
+ } else {
+ // Set the shapes and data types of input edge.
+ tensorflow::TensorShapeProto in_shape;
+ conn.outside_shape.AsProto(&in_shape);
+ if (input_shape_protos.size() <= conn.port_number) {
+ input_shape_protos.resize(conn.port_number + 1);
+ input_shapes.resize(conn.port_number + 1);
+ }
+ input_shape_protos.at(conn.port_number) = in_shape;
+ input_shapes.at(conn.port_number) = conn.outside_shape;
+
+ // Rewrire data input if it's not found in original graph.
+ tensorflow::Node* input_node = graph->FindNodeId(conn.outside_id);
+ int port = conn.outside_port;
+ if (!input_node) {
+ UpdateToEngineNode(infos, pos, *engine_nodes, /*is_input_edge=*/true,
+ conn.outside_node_name, &input_node, &port);
+ }
+ if (std::find_if(
+ std::begin(inputs), std::end(inputs),
+ [input_node, &port](const NodeDefBuilder::NodeOut& inp) {
+ return inp.node == input_node->name() && inp.index == port;
+ }) == std::end(inputs)) {
+ inputs.emplace_back(input_node->name(), port, conn.connection_type);
+ input_nodes.push_back(CHECK_NOTNULL(input_node));
+ VLOG(1) << "Engine Input " << input_node->name() << ":" << port
+ << " -> " << info.engine_name << ":" << inputs.size() - 1;
+ }
}
}
- if (new_input) {
- inputs.emplace_back(input_node, input_port, conn.connection_type);
- }
}
-
- // Build the engine and get its serialized representation.
string segment_string;
if (info.engine_type == EngineInfo::EngineType::TRTStatic ||
info.precision_mode == INT8MODE) {
@@ -485,21 +574,10 @@ tensorflow::Status CreateTRTNode(tensorflow::Graph* graph,
// TODO(aaroey): use enum instead, and add a helper method to do the
// conversion.
string prec_string;
- switch (info.precision_mode) {
- case FP32MODE:
- prec_string = "FP32";
- break;
- case FP16MODE:
- prec_string = "FP16";
- break;
- case INT8MODE:
- prec_string = "INT8";
- if (!TRTResourceManager::instance()->getManager("TRTCalibration")) {
- LOG(ERROR) << "Failed to construct calibration storage";
- }
- break;
- default:
- return tensorflow::errors::OutOfRange("Unknown precision mode");
+ TF_RETURN_IF_ERROR(GetPrecisionModeName(info.precision_mode, &prec_string));
+ if (info.precision_mode == INT8MODE &&
+ !TRTResourceManager::instance()->getManager("TRTCalibration")) {
+ LOG(ERROR) << "Failed to construct calibration storage";
}
tensorflow::NodeDefBuilder node_builder(info.engine_name, "TRTEngineOp");
if (!info.device.empty()) node_builder.Device(info.device);
@@ -511,6 +589,10 @@ tensorflow::Status CreateTRTNode(tensorflow::Graph* graph,
VLOG(1) << ins;
}
node_builder.Input(inputs);
+ for (const string& c : control_input_names) {
+ node_builder.ControlInput(c);
+ }
+
if (info.engine_type == EngineInfo::EngineType::TRTStatic &&
info.cached_engine_batches.size()) {
LOG(WARNING) << "Cached engine batches are ignored for static engines";
@@ -539,34 +621,55 @@ tensorflow::Status CreateTRTNode(tensorflow::Graph* graph,
// Up until this point, graph is not modified. If we return !status.ok() from
// here, this segment will be skipped
+ // TODO(aaroey): let it return proper error status for the following logic
+ // instead of checking fail.
tensorflow::Node* engine_node = graph->AddNode(trt_node, &status);
+ (*engine_nodes)[pos] = engine_node;
if (!status.ok()) {
LOG(ERROR) << "Adding node failed " << status;
return status;
}
+ // Add control input and input edges to the engine node.
+ for (const auto in : control_input_nodes) {
+ VLOG(1) << "Connecting control edge from " << in->name() << " to "
+ << engine_node->name();
+ graph->AddControlEdge(in, engine_node);
+ }
+ VLOG(1) << "input_nodes size = " << input_nodes.size();
+ for (int i = 0; i < input_nodes.size(); ++i) {
+ Node* n = CHECK_NOTNULL(input_nodes[i]);
+ const auto& in = inputs[i];
+ VLOG(1) << "Connecting data edge from " << n->name() << ":" << in.index
+ << " to " << engine_node->name() << ":" << i;
+ graph->AddEdge(n, in.index, engine_node, i);
+ }
+
// Updates the inputs of output edges destination nodes, and point them to the
// engine node.
for (auto& conn : info.connections) {
- if (conn.is_input_edge) continue;
- VLOG(1) << " Updating DBG " << engine_node->name() << " out_port "
- << conn.port_number << " out_id " << conn.outside_id
- << " name=" << conn.outside_node_name;
- auto dst_node = graph->FindNodeId(conn.outside_id);
- // dst_node can only be removed if it is an input node of another engine.
- // In this case, other engines input edge is updated in nodedef to point to
- // this engine. Even though edge doesn't exists in the graph, when it is
- // deserialized again, correct edges will be constructed. This is a problem
- // of graph->AddNode().
- if (!dst_node) continue;
+ if (conn.is_input_edge) {
+ continue;
+ }
+ tensorflow::Node* output_node = graph->FindNodeId(conn.outside_id);
+ int port = conn.outside_port;
+ if (!output_node) {
+ UpdateToEngineNode(infos, pos, *engine_nodes, /*is_input_edge=*/false,
+ conn.outside_node_name, &output_node, &port);
+ }
VLOG(1) << "Updating " << engine_node->name() << ":" << conn.port_number
- << " to " << dst_node->name() << ":" << conn.outside_port;
- auto new_edge = graph->AddEdge(engine_node, conn.port_number, dst_node,
- conn.outside_port);
- CHECK(new_edge) << "Adding a new edge failed " << engine_node->name() << ":"
- << conn.port_number << " -> " << dst_node->name() << ":"
- << conn.outside_port;
+ << " to " << output_node->name() << ":" << port;
+ if (conn.is_control_edge()) {
+ QCHECK_EQ(Graph::kControlSlot, port);
+ graph->AddControlEdge(engine_node, output_node);
+ } else {
+ auto new_edge =
+ graph->AddEdge(engine_node, conn.port_number, output_node, port);
+ QCHECK(new_edge) << "Adding a new edge failed " << engine_node->name()
+ << ":" << conn.port_number << " -> "
+ << output_node->name() << ":" << conn.outside_port;
+ }
}
- return status;
+ return Status::OK();
}
// Function to construct a funcdef from the segment and add it to the graph.
@@ -666,72 +769,36 @@ tensorflow::Status RegisterSegmentFunctionToFunctionLibrary(
}
std::pair<int, tensorflow::Allocator*> GetDeviceAndAllocator(
- ConversionParams& params, EngineInfo& engine) {
+ const ConversionParams& params, const EngineInfo& engine) {
int cuda_device_id = -1;
- auto check_device_id = [](int tfid) -> int {
- tensorflow::TfGpuId tf_gpu_id(tfid);
- CudaGpuId cuda_gpu_id;
- Status s = GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id);
- if (s.ok()) {
- VLOG(1) << "Found TF GPU " << tf_gpu_id.value() << " at cuda device "
- << cuda_gpu_id.value();
- return cuda_gpu_id.value();
- }
- VLOG(2) << "TF GPU with id " << tfid << " do not exist " << s;
- return -1;
- };
tensorflow::Allocator* dev_allocator = nullptr;
- // we need to us PM here since in python path there is no way to get
- // to allocators.
- // TODO(sami): when grappler devices become available else path will not be
- // necessary
- auto pm = tensorflow::GPUProcessState::singleton();
- if (params.cluster) { // get allocator
- tensorflow::Device* device = nullptr;
- if (params.cluster->GetDeviceSet()) {
- device = params.cluster->GetDeviceSet()->FindDeviceByName(engine.device);
+ if (params.cluster) {
+ std::vector<tensorflow::Device*> devices;
+ if (!engine.device.empty() && params.cluster->GetDeviceSet()) {
+ DeviceNameUtils::ParsedName parsed_name;
+ if (DeviceNameUtils::ParseFullName(engine.device, &parsed_name) &&
+ parsed_name.has_id) {
+ params.cluster->GetDeviceSet()->FindMatchingDevices(parsed_name,
+ &devices);
+ }
}
- if (device) {
+ if (!devices.empty()) {
+ if (devices.size() > 1) {
+ string msg = "Found multiple matching devices using name '";
+ StrAppend(&msg, engine.device, "': ");
+ for (auto d : devices) StrAppend(&msg, d->name(), ", ");
+ StrAppend(&msg, ". Will get the allocator from first one.");
+ LOG(WARNING) << msg;
+ }
tensorflow::AllocatorAttributes alloc_attr;
- dev_allocator = device->GetAllocator(alloc_attr);
- VLOG(1) << "Using allocator " << dev_allocator->Name();
+ cuda_device_id = devices[0]->tensorflow_gpu_device_info()->gpu_id;
+ dev_allocator = devices[0]->GetAllocator(alloc_attr);
+ VLOG(1) << "Using allocator " << dev_allocator->Name()
+ << " and cuda_device_id " << cuda_device_id;
} else {
LOG(WARNING) << "Cluster is set but device '" << engine.device
<< "' is not found in the cluster";
}
- } else { // cluster not found, possibly a python call
- VLOG(1) << "Cluster is not set, probably called from python";
- int found_device = 0;
- bool try_gpu_ids = true;
- // if device is set, try to find the device. Might be a problem for multi
- // host case but TensorRT do not support multi host setups yet.
- if (!engine.device.empty()) {
- DeviceNameUtils::ParsedName parsed_name;
- if (DeviceNameUtils::ParseFullName(engine.device, &parsed_name)) {
- cuda_device_id = parsed_name.has_id ? parsed_name.id : -1;
- }
- try_gpu_ids = !parsed_name.has_id;
- }
- if (try_gpu_ids) {
- while (found_device < 100) {
- cuda_device_id = check_device_id(found_device);
- if (cuda_device_id >= 0) break;
- found_device++;
- }
- }
- if (found_device == 100) {
- LOG(ERROR) << " Can't find a GPU device to work with. Please "
- "instantiate a session to initialize devices";
- return std::make_pair(cuda_device_id, dev_allocator);
- }
- LOG(WARNING)
- << "Can't determine the device, constructing an allocator at device "
- << found_device;
- tensorflow::GPUOptions gpuoptions;
- // this will be a noop if device is already initialized
- gpuoptions.set_allow_growth(true);
- tensorflow::TfGpuId tf_gpu_id(found_device);
- dev_allocator = pm->GetGPUAllocator(gpuoptions, tf_gpu_id, 1);
}
return std::make_pair(cuda_device_id, dev_allocator);
}
@@ -824,6 +891,8 @@ tensorflow::Status ConvertAfterShapes(ConversionParams& params) {
LOG(ERROR) << "Couldn't get current device: " << cudaGetErrorString(err);
}
VLOG(1) << "Current cuda device is " << old_cuda_device;
+ std::vector<Node*> engine_nodes;
+ engine_nodes.resize(engine_segments.size());
for (int i = 0; i < engine_segments.size(); ++i) {
auto& engine = engine_segments.at(i);
// Partition the workspace size by the average of node ratio and segment
@@ -847,19 +916,21 @@ tensorflow::Status ConvertAfterShapes(ConversionParams& params) {
LOG(WARNING) << "Can't identify the cuda device. Running on device 0 ";
}
cudaSetDevice(cuda_device_id);
- auto status = CreateTRTNode(&graph, engine_segments, i, alloc.get(),
- params.max_batch_size);
+ auto status = CreateTRTNode(engine_segments, i, params.max_batch_size,
+ &graph, alloc.get(), &engine_nodes);
// If status is ok, we successfully added the node to the graph and can
// remove segment ops. Otherwise graph is not modified.
+ const string msg = StrCat("Engine ", engine.engine_name,
+ " creation for segment ", i, ", composed of ",
+ converted_segments.at(i).first.size(), " nodes");
if (status.ok()) {
+ LOG(INFO) << msg << " succeeded.";
for (auto node_name : converted_segments.at(i).first) {
graph.RemoveNode(node_map.at(node_name));
}
} else {
// Graph is not modified.
- LOG(WARNING) << "Engine creation for segment " << i << ", composed of "
- << converted_segments.at(i).first.size()
- << " nodes failed: " << status << ". Skipping...";
+ LOG(WARNING) << msg << " failed: " << status << ". Skipping...";
}
}
cudaSetDevice(old_cuda_device);
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
index 9d881eda90..7aec963aa3 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include <memory>
#include <set>
#include <unordered_map>
+#include <unordered_set>
#include <utility>
#include <vector>
@@ -2690,7 +2691,7 @@ tensorflow::Status ConvertGraphDefToEngine(
// Graph nodes are already topologically sorted during construction
for (const auto& node_def : gdef.node()) {
string node_name = node_def.name();
- VLOG(1) << "Converting op name=" << node_name << ", op=" << node_def.op();
+ VLOG(2) << "Converting op name=" << node_name << ", op=" << node_def.op();
if (tensorflow::str_util::StartsWith(node_name, kInputPHName) &&
(node_def.op() == "Placeholder")) {
int32 slot_number = -1;
@@ -2791,6 +2792,7 @@ tensorflow::Status ConvertGraphDefToEngine(
tensorflow::Status ConvertSegmentToGraphDef(
const tensorflow::Graph* graph,
const tensorflow::grappler::GraphProperties& graph_properties,
+ const std::set<string>& subgraph_node_names,
const std::vector<int>& subgraph_node_ids, // In topological order
std::vector<EngineConnection>* connections,
tensorflow::GraphDef* segment_def, string* common_scope) {
@@ -2799,6 +2801,7 @@ tensorflow::Status ConvertSegmentToGraphDef(
// nodes in the segment graphdef.
for (size_t i = 0; i < connections->size(); ++i) {
auto& connection = connections->at(i);
+ if (connection.is_control_edge()) continue;
auto outside_node = graph->FindNodeId(connection.outside_id);
if (!outside_node) {
// This should never happen, unless the original graph is problematic.
@@ -2812,13 +2815,13 @@ tensorflow::Status ConvertSegmentToGraphDef(
GetInputProperties(graph_properties,
graph->FindNodeId(connection.outside_id),
connection.outside_port, &partial_shape, &dtype);
-
+ connection.outside_shape = partial_shape;
} else {
GetOutputProperties(graph_properties,
graph->FindNodeId(connection.outside_id),
connection.outside_port, &partial_shape, &dtype);
+ connection.inside_shape = partial_shape;
}
- connection.outside_shape = partial_shape;
connection.connection_type = dtype;
// Add dummy input/output nodes to the segment graphdef.
@@ -2871,12 +2874,12 @@ tensorflow::Status ConvertSegmentToGraphDef(
old_to_new_id_map[node_id] = segment_def->node_size();
auto snode = segment_def->add_node();
snode->CopyFrom(node->def());
- VLOG(1) << "Copying " << snode->name() << " to subgraph";
+ VLOG(2) << "Copying " << snode->name() << " to subgraph";
}
// Update the inputs of the new input nodes to point to placeholder nodes.
for (int i = 0; i < connections->size(); ++i) {
auto& connection = connections->at(i);
- if (!connection.is_input_edge) continue;
+ if (connection.is_control_edge() || !connection.is_input_edge) continue;
auto snode =
segment_def->mutable_node(old_to_new_id_map[connection.inside_id]);
const string placeholder_name =
@@ -2886,6 +2889,39 @@ tensorflow::Status ConvertSegmentToGraphDef(
<< placeholder_name;
snode->set_input(connection.inside_port, placeholder_name);
}
+ // Remove control inputs that are not inside the segment.
+ for (int i = 0; i < segment_def->node_size(); ++i) {
+ auto snode = segment_def->mutable_node(i);
+ const int input_size = snode->input_size();
+ int input_idx = 0;
+ int actual_input_idx = 0;
+ while (input_idx < input_size) {
+ TensorId input = ParseTensorName(snode->input(input_idx));
+ if (!subgraph_node_names.count(
+ string(input.first.data(), input.first.size())) &&
+ !str_util::StartsWith(input.first, kInputPHName)) {
+ if (input.second == Graph::kControlSlot) {
+ VLOG(1) << "... removing control inputs " << input.first
+ << " from subgraph.";
+ ++input_idx;
+ continue;
+ } else {
+ return tensorflow::errors::InvalidArgument(
+ "Found non control input outside the segment that is not an "
+ "engine connection to ",
+ snode->name(), ": ", input.first);
+ }
+ }
+ if (actual_input_idx != input_idx) {
+ snode->set_input(actual_input_idx, snode->input(input_idx));
+ }
+ ++input_idx;
+ ++actual_input_idx;
+ }
+ for (int remove = input_size - actual_input_idx; remove > 0; --remove) {
+ snode->mutable_input()->RemoveLast();
+ }
+ }
*common_scope = local_scope;
VLOG(0) << "Segment @scope '" << local_scope << "', converted to graph";
return tensorflow::Status::OK();
@@ -2900,7 +2936,7 @@ bool InputEdgeValidator::operator()(const tensorflow::Edge* in_edge) const {
nvinfer1::DataType trt_dtype;
Status status = ValidateInputProperties(shape, dtype, &trt_dtype);
if (!status.ok()) {
- VLOG(2) << "--> Need to remove input node " << in_edge->dst()->name()
+ VLOG(1) << "--> Need to remove input node " << in_edge->dst()->name()
<< ": " << status;
return false;
}
@@ -2908,7 +2944,7 @@ bool InputEdgeValidator::operator()(const tensorflow::Edge* in_edge) const {
#if NV_TENSORRT_MAJOR == 3
// TRT 3.x only support 4 dimensional input tensor.
if (shape.dims() != 4 && in_edge->src()->type_string() != "Const") {
- VLOG(2) << "--> Need to remove input node " << in_edge->dst()->name()
+ VLOG(1) << "--> Need to remove input node " << in_edge->dst()->name()
<< " which has an input at port " << in_edge->dst_input()
<< " with #dim!=4 and is not a const: " << shape;
return false;
@@ -2920,7 +2956,7 @@ bool InputEdgeValidator::operator()(const tensorflow::Edge* in_edge) const {
bool OutputEdgeValidator::operator()(const tensorflow::Edge* out_edge) const {
if (out_edge->IsControlEdge()) return true;
if (out_edge->src()->type_string() == "Const") {
- VLOG(2) << "--> Need to remove output node " << out_edge->src()->name()
+ VLOG(1) << "--> Need to remove output node " << out_edge->src()->name()
<< " which is a Const.";
return false;
}
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.h b/tensorflow/contrib/tensorrt/convert/convert_nodes.h
index 6ae60ec352..a60253740f 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_nodes.h
+++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.h
@@ -36,16 +36,12 @@ limitations under the License.
namespace tensorflow {
namespace tensorrt {
-static const char* kInputPHName = "InputPH_";
-static const char* kOutputPHName = "OutputPH_";
+static const char* kInputPHName = "TensorRTInputPH_";
+static const char* kOutputPHName = "TensorRTOutputPH_";
namespace convert {
-// TODO(aaroey): use an enum instead.
-const int FP32MODE = 0;
-const int FP16MODE = 1;
-const int INT8MODE = 2;
-
struct EngineConnection {
+ // Constructs a non-control edge.
EngineConnection(const string& outside, int out_id, int out_port,
const string& inside, int in_id, int in_port,
bool input_edge, int port)
@@ -58,21 +54,35 @@ struct EngineConnection {
is_input_edge(input_edge),
port_number(port) {}
+ // Constructs a control edge.
+ EngineConnection(const string& outside, int out_id, const string& inside,
+ int in_id, bool input_edge)
+ : outside_node_name(outside),
+ outside_id(out_id),
+ outside_port(Graph::kControlSlot),
+ inside_node_name(inside),
+ inside_id(in_id),
+ inside_port(Graph::kControlSlot),
+ is_input_edge(input_edge),
+ port_number(Graph::kControlSlot) {}
+
+ bool is_control_edge() const { return port_number == Graph::kControlSlot; }
+
const string outside_node_name;
const int outside_id;
const int outside_port;
- tensorflow::PartialTensorShape outside_shape;
+ tensorflow::PartialTensorShape outside_shape; // Only set for input edge.
const string inside_node_name;
const int inside_id;
const int inside_port;
- tensorflow::PartialTensorShape inside_shape;
+ tensorflow::PartialTensorShape inside_shape; // Only set for output edge.
tensorflow::DataType connection_type;
- bool is_input_edge;
+ const bool is_input_edge;
- // The port number of the TRT node connecting to this edge.
- int port_number;
+ // The port number of the TRT node connected with this edge.
+ const int port_number;
};
struct EngineInfo {
@@ -85,7 +95,9 @@ struct EngineInfo {
string device;
tensorflow::GraphDef segment_graph_def;
- // The segment nodes that are on one side of the edges are topological sorted.
+ // Non-control input connections inside this vector are sorted in a way such
+ // that, the segment nodes connecting to them are topological sorted.
+ // In addition, for non-control connections, there must be no duplicates.
std::vector<EngineConnection> connections;
enum class EngineType { TRTStatic = 0, TRTDynamic = 1 };
@@ -101,6 +113,7 @@ struct EngineInfo {
// (OutputPH_*). This function needs to be called before TensorRT nodes
// inserted in order to correctly get sizes from the original graph.
//
+// - subgraph_node_names: the node names of the subgraph.
// - subgraph_node_ids: the node ids of the subgraph, must be sorted in
// topological order.
// - segment_def: the output GraphDef, whose non-input/output nodedefs will be
@@ -110,6 +123,7 @@ struct EngineInfo {
tensorflow::Status ConvertSegmentToGraphDef(
const tensorflow::Graph* graph,
const tensorflow::grappler::GraphProperties& graph_properties,
+ const std::set<string>& subgraph_node_names,
const std::vector<int>& subgraph_node_ids,
std::vector<EngineConnection>* connections,
tensorflow::GraphDef* segment_def, string* common_scope);
diff --git a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc
index 044c736c03..f33f2cc4d6 100644
--- a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc
+++ b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/stacktrace.h"
#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
@@ -189,9 +190,6 @@ tensorflow::Status TRTOptimizationPass::Optimize(
tensorflow::grappler::Cluster* cluster,
const tensorflow::grappler::GrapplerItem& item, GraphDef* optimized_graph) {
VLOG(1) << "Called TRTOptimization Pass " << name_;
- if (VLOG_IS_ON(1)) {
- PrintDebugInfo(cluster, item);
- }
// This is a hack to workaround optimizer issue. MetaOptimizer calls
// optimization passes on function objects as well, we should not modify
// generated funcdefs! This is fragile but we don't have any other option
@@ -203,6 +201,10 @@ tensorflow::Status TRTOptimizationPass::Optimize(
*optimized_graph = item.graph;
return tensorflow::Status::OK();
}
+ if (VLOG_IS_ON(1)) {
+ VLOG(2) << CurrentStackTrace();
+ PrintDebugInfo(cluster, item);
+ }
int max_dim = -1;
if (item.feed.size()) {
for (const auto& f : item.feed) {
diff --git a/tensorflow/contrib/tensorrt/convert/utils.cc b/tensorflow/contrib/tensorrt/convert/utils.cc
index 24591cf84b..e7a1febb8c 100644
--- a/tensorflow/contrib/tensorrt/convert/utils.cc
+++ b/tensorflow/contrib/tensorrt/convert/utils.cc
@@ -15,6 +15,9 @@ limitations under the License.
#include "tensorflow/contrib/tensorrt/convert/utils.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+
namespace tensorflow {
namespace tensorrt {
@@ -24,12 +27,43 @@ bool IsGoogleTensorRTEnabled() {
// safely write code that uses tensorrt conditionally. E.g. if it does not
// check for for tensorrt, and user mistakenly uses tensorrt, they will just
// crash and burn.
-#ifdef GOOGLE_TENSORRT
+#if GOOGLE_CUDA && GOOGLE_TENSORRT
return true;
#else
return false;
#endif
}
+Status GetPrecisionModeName(const int precision_mode, string* name) {
+ switch (precision_mode) {
+ case FP32MODE:
+ *name = "FP32";
+ break;
+ case FP16MODE:
+ *name = "FP16";
+ break;
+ case INT8MODE:
+ *name = "INT8";
+ break;
+ default:
+ return tensorflow::errors::OutOfRange("Unknown precision mode");
+ }
+ return Status::OK();
+}
+
+Status GetPrecisionMode(const string& name, int* precision_mode) {
+ if (name == "FP32") {
+ *precision_mode = FP32MODE;
+ } else if (name == "FP16") {
+ *precision_mode = FP16MODE;
+ } else if (name == "INT8") {
+ *precision_mode = INT8MODE;
+ } else {
+ return tensorflow::errors::InvalidArgument("Invalid precision mode name: ",
+ name);
+ }
+ return Status::OK();
+}
+
} // namespace tensorrt
} // namespace tensorflow
diff --git a/tensorflow/contrib/tensorrt/convert/utils.h b/tensorflow/contrib/tensorrt/convert/utils.h
index 8b5f4d614a..0592f31462 100644
--- a/tensorflow/contrib/tensorrt/convert/utils.h
+++ b/tensorflow/contrib/tensorrt/convert/utils.h
@@ -18,6 +18,8 @@ limitations under the License.
#include <memory>
+#include "tensorflow/core/lib/core/status.h"
+
namespace tensorflow {
namespace tensorrt {
@@ -33,6 +35,15 @@ using TrtUniquePtrType = std::unique_ptr<T, TrtDestroyer<T>>;
bool IsGoogleTensorRTEnabled();
+// TODO(aaroey): use an enum instead.
+const int FP32MODE = 0;
+const int FP16MODE = 1;
+const int INT8MODE = 2;
+
+Status GetPrecisionModeName(const int precision_mode, string* name);
+
+Status GetPrecisionMode(const string& name, int* precision_mode);
+
} // namespace tensorrt
} // namespace tensorflow
diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
index 646d62483f..2b42d81f47 100644
--- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
+++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h"
#include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h"
#include "tensorflow/contrib/tensorrt/resources/trt_resources.h"
+#include "tensorflow/contrib/tensorrt/test/utils.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/strings/str_util.h"
@@ -45,11 +46,11 @@ using ::tensorflow::strings::StrCat;
// Helps simultaneous execution of native and TRT engines.
class AsyncHelper : public tensorflow::core::RefCounted {
public:
- AsyncHelper(tensorflow::AsyncOpKernel::DoneCallback done) { done_ = done; }
+ AsyncHelper(AsyncOpKernel::DoneCallback done) { done_ = done; }
~AsyncHelper() override { done_(); }
private:
- tensorflow::AsyncOpKernel::DoneCallback done_;
+ AsyncOpKernel::DoneCallback done_;
};
#define TYPECASE(dt, X, Y) \
@@ -122,15 +123,9 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context)
context->GetAttr("calibration_data", &calibration_data));
OP_REQUIRES_OK(context,
context->GetAttr("segment_funcdef_name", &funcdef_name_));
- if (precision_string == "FP32") {
- precision_mode_ = convert::FP32MODE;
- } else if (precision_string == "FP16") {
- precision_mode_ = convert::FP16MODE;
- } else if (precision_string == "INT8") {
- precision_mode_ = convert::INT8MODE;
- }
+ OP_REQUIRES_OK(context, GetPrecisionMode(precision_string, &precision_mode_));
calibration_mode_ =
- (precision_mode_ == convert::INT8MODE && calibration_data.size() == 0);
+ (precision_mode_ == INT8MODE && calibration_data.size() == 0);
if (calibration_data.size()) {
calibrator_.reset(new TRTInt8Calibrator(calibration_data));
calibration_data.resize(0);
@@ -152,7 +147,7 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context)
}
}
-void TRTEngineOp::ExecuteNativeSegment(tensorflow::OpKernelContext* ctx,
+void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx,
AsyncHelper* helper) {
if (!calibration_mode_) {
VLOG(1) << "Executing native engine";
@@ -179,7 +174,7 @@ void TRTEngineOp::ExecuteNativeSegment(tensorflow::OpKernelContext* ctx,
helper->Ref(); // Increment count for calculating native graph
VLOG(1) << "Executing native segment " << name();
lib->Run(opts, native_func_, inputs, outputs,
- [ctx, outputs, helper](const tensorflow::Status& s) {
+ [this, ctx, outputs, helper](const tensorflow::Status& s) {
tensorflow::core::ScopedUnref sc(helper);
VLOG(1) << "Native Segment completed";
if (!s.ok()) {
@@ -189,11 +184,13 @@ void TRTEngineOp::ExecuteNativeSegment(tensorflow::OpKernelContext* ctx,
for (size_t t = 0; t < outputs->size(); ++t) {
ctx->set_output(t, outputs->at(t));
}
+ test::AddTestValue(StrCat(this->name(), ":ExecuteNativeSegment"),
+ "done");
delete outputs;
});
}
-void TRTEngineOp::ExecuteCalibration(tensorflow::OpKernelContext* ctx,
+void TRTEngineOp::ExecuteCalibration(OpKernelContext* ctx,
AsyncHelper* helper) {
helper->Ref();
tensorflow::core::ScopedUnref sc(helper);
@@ -234,11 +231,12 @@ void TRTEngineOp::ExecuteCalibration(tensorflow::OpKernelContext* ctx,
->implementation()
->GpuStreamMemberHack()));
calib_res->calibrator_->setBatch(input_data, *stream);
+ test::AddTestValue(StrCat(name(), ":ExecuteCalibration"), "done");
VLOG(2) << "Passed calibration data";
ExecuteNativeSegment(ctx, helper);
}
-int TRTEngineOp::GetEngineBatch(tensorflow::OpKernelContext* ctx) {
+int TRTEngineOp::GetEngineBatch(OpKernelContext* ctx) {
int num_batch = ctx->input(0).shape().dim_size(0);
int smallest_engine = 0;
for (const auto i : cached_engine_batches_) {
@@ -254,21 +252,20 @@ int TRTEngineOp::GetEngineBatch(tensorflow::OpKernelContext* ctx) {
cached_engine_batches_.push_back(num_batch);
VLOG(1) << "Running with batch size " << num_batch;
} else {
- string s("Engine buffer is full. buffer limit= ");
- StrAppend(&s, max_cached_engines_, ", current entries= ");
- for (auto i : cached_engine_batches_) StrAppend(&s, i, ", ");
- StrAppend(&s, "Requested batch= ", num_batch);
- LOG(ERROR) << s;
- ctx->SetStatus(tensorflow::errors::ResourceExhausted(
- "Requested batch size is not available and engine cache is full"));
+ string msg =
+ StrCat("Engine buffer is full. buffer limit=", max_cached_engines_,
+ ", current entries=");
+ for (auto i : cached_engine_batches_) StrAppend(&msg, i, ",");
+ StrAppend(&msg, " requested batch=", num_batch);
+ LOG(WARNING) << msg;
return -1;
}
}
return smallest_engine;
}
-void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx,
- tensorflow::AsyncOpKernel::DoneCallback done) {
+void TRTEngineOp::ComputeAsync(OpKernelContext* ctx,
+ AsyncOpKernel::DoneCallback done) {
auto helper = new AsyncHelper(done);
tensorflow::core::ScopedUnref sc(helper);
if (calibration_mode_) {
@@ -276,32 +273,54 @@ void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx,
return;
}
const int smallest_engine = GetEngineBatch(ctx);
- if (smallest_engine < 0) return; // GetEngineBatch already set the status.
+ if (smallest_engine < 0) {
+ LOG(WARNING) << "Failed to get engine batch, running native segment for "
+ << name();
+ ExecuteNativeSegment(ctx, helper);
+ return;
+ }
const int num_batch = ctx->input(0).shape().dim_size(0);
auto& engine_ctx_pair = GetEngine(smallest_engine, ctx);
auto& trt_engine_ptr = engine_ctx_pair.first;
if (!trt_engine_ptr) {
LOG(WARNING) << "Engine retrieval for batch size " << num_batch
- << " failed Running native segment";
+ << " failed. Running native segment for " << name();
ExecuteNativeSegment(ctx, helper);
return;
}
+ const bool retry = ExecuteTrtEngine(ctx, num_batch, trt_engine_ptr.get(),
+ engine_ctx_pair.second.get());
+ if (retry) {
+ LOG(WARNING) << "Failed to execute engine, "
+ << "retrying with native segment for " << name();
+ ExecuteNativeSegment(ctx, helper);
+ return;
+ }
+}
+bool TRTEngineOp::ExecuteTrtEngine(
+ OpKernelContext* ctx, const int num_batch,
+ nvinfer1::ICudaEngine* trt_engine_ptr,
+ nvinfer1::IExecutionContext* trt_execution_context_ptr) {
+ const bool kRetry = true;
const int num_binding = ctx->num_inputs() + ctx->num_outputs();
std::vector<void*> buffers(num_binding);
for (int i = 0; i < ctx->num_inputs(); i++) {
- const string inp_name = StrCat(kInputPHName, i);
+ const string input_name = StrCat(kInputPHName, i);
const size_t binding_index =
- trt_engine_ptr->getBindingIndex(inp_name.c_str());
+ trt_engine_ptr->getBindingIndex(input_name.c_str());
+ if (binding_index == -1) {
+ LOG(ERROR) << "Input node not found, at " << input_name;
+ return kRetry;
+ }
const Tensor& input_tensor = ctx->input(i);
const TensorShape& input_shape = input_tensor.shape();
if (num_batch != input_shape.dim_size(0)) {
- LOG(ERROR) << "input data inconsistent batch size";
- ctx->SetStatus(tensorflow::errors::FailedPrecondition(
- "Different batch sizes between input tensors"));
- return;
+ LOG(ERROR) << "Input data has inconsistent batch size: " << num_batch
+ << " vs " << input_shape.dim_size(0);
+ return kRetry;
}
auto dtype = trt_engine_ptr->getBindingDataType(binding_index);
switch (dtype) {
@@ -310,14 +329,10 @@ void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx,
break;
case nvinfer1::DataType::kHALF:
LOG(ERROR) << "FP16 inputs are not supported yet!";
- ctx->SetStatus(tensorflow::errors::InvalidArgument(
- "FP16 inputs are not supported!"));
- return;
+ return kRetry;
case nvinfer1::DataType::kINT8:
LOG(ERROR) << "INT8 inputs are not supported yet!";
- ctx->SetStatus(tensorflow::errors::InvalidArgument(
- "INT8 inputs are not supported!"));
- return;
+ return kRetry;
#if NV_TENSORRT_MAJOR > 3
case nvinfer1::DataType::kINT32:
buffers[binding_index] = (void*)(input_tensor.flat<int32>().data());
@@ -325,9 +340,7 @@ void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx,
#endif
default:
LOG(ERROR) << "Unknown TRT data type: " << int(dtype);
- ctx->SetStatus(tensorflow::errors::InvalidArgument(
- "Unknown output TRT data type! ", static_cast<int>(dtype)));
- return;
+ return kRetry;
}
}
@@ -344,20 +357,23 @@ void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx,
std::vector<int> trt_shape(dims.nbDims + 1);
trt_shape[0] = num_batch;
for (int j = 0; j < dims.nbDims; j++) trt_shape[j + 1] = dims.d[j];
- OP_REQUIRES_OK(
- ctx, TensorShapeUtils::MakeShape(trt_shape.data(), trt_shape.size(),
- &output_shape));
+ auto status = TensorShapeUtils::MakeShape(
+ trt_shape.data(), trt_shape.size(), &output_shape);
+ if (!status.ok()) {
+ LOG(ERROR) << "Failed to get output shape: " << status;
+ return kRetry;
+ }
} else {
- LOG(ERROR) << "output node not found, at " << output_name;
- ctx->SetStatus(tensorflow::errors::Internal("output ", output_name,
- " couldn't be found!"));
- return;
+ LOG(ERROR) << "Output node not found, at " << output_name;
+ return kRetry;
}
auto status = ctx->allocate_output(i, output_shape, &output_tensor);
if (!status.ok()) {
LOG(ERROR) << "Allocating output failed with " << status;
ctx->SetStatus(status);
- return;
+ // Do not retry since we cannot allocate the same output twice.
+ // TODO(aaroey): ideally we should retry, fix this.
+ return !kRetry;
}
auto dtype = trt_engine_ptr->getBindingDataType(binding_index);
switch (dtype) {
@@ -366,15 +382,11 @@ void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx,
reinterpret_cast<void*>(output_tensor->flat<float>().data());
break;
case nvinfer1::DataType::kHALF:
- LOG(ERROR) << "half size is not supported yet!";
- ctx->SetStatus(tensorflow::errors::InvalidArgument(
- "Half outputs are not supported!"));
- return;
+ LOG(WARNING) << "half size is not supported yet!";
+ return kRetry;
case nvinfer1::DataType::kINT8:
- LOG(ERROR) << "int8 is not supported yet!";
- ctx->SetStatus(tensorflow::errors::InvalidArgument(
- "INT8 outputs are not supported!"));
- return;
+ LOG(WARNING) << "int8 is not supported yet!";
+ return kRetry;
#if NV_TENSORRT_MAJOR > 3
case nvinfer1::DataType::kINT32:
buffers[binding_index] =
@@ -382,13 +394,11 @@ void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx,
break;
#endif
default:
- LOG(ERROR) << "Unknown TRT data type: " << static_cast<int>(dtype);
- ctx->SetStatus(tensorflow::errors::InvalidArgument(
- "Unsupported output data type! ", static_cast<int>(dtype)));
- return;
+ LOG(WARNING) << "Unknown TRT data type: " << static_cast<int>(dtype);
+ return kRetry;
}
}
- // copied from cuda_kernel_helper since it seems only valid in *.cu.cc files
+ // Copied from cuda_kernel_helper since it seems only valid in *.cu.cc files
const cudaStream_t* stream = CHECK_NOTNULL(
reinterpret_cast<const cudaStream_t*>(ctx->op_device_context()
->stream()
@@ -396,15 +406,15 @@ void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx,
->GpuStreamMemberHack()));
// TODO(jie): trt enqueue does not return error
- auto& trt_execution_context_ptr = engine_ctx_pair.second;
auto ret = trt_execution_context_ptr->enqueue(num_batch, &buffers[0], *stream,
nullptr);
if (!ret) {
- LOG(ERROR) << "Failed to enqueue batch for TRT engine: " << name();
- ctx->SetStatus(tensorflow::errors::Internal(
- "Failed to enqueue batch for TRT engine: ", name()));
+ LOG(WARNING) << "Failed to enqueue batch for TRT engine: " << name();
+ return kRetry;
}
- // sync should be done by TF.
+ test::AddTestValue(StrCat(name(), ":ExecuteTrtEngine"), "done");
+ // Synchronization will be done by TF.
+ return !kRetry;
}
TRTEngineOp::~TRTEngineOp() {
@@ -424,8 +434,6 @@ nvinfer1::IGpuAllocator* TRTEngineOp::GetAllocator(OpKernelContext* ctx) {
if (!alloc) {
LOG(ERROR) << "Can't find device allocator for gpu device "
<< device->name();
- ctx->SetStatus(tensorflow::errors::Internal(
- "Can't get device allocator for device ", device->name()));
return nullptr;
}
allocator_.reset(new TRTDeviceAllocator(alloc));
@@ -452,7 +460,6 @@ TRTEngineOp::EngineCtxPair& TRTEngineOp::GetEngine(int batch_size,
#if NV_TENSORRT_MAJOR > 3
auto allocator = GetAllocator(ctx);
if (allocator == nullptr) {
- // GetAllocator already set the Status.
return null_pair;
}
infer->setGpuAllocator(allocator);
@@ -469,7 +476,9 @@ TRTEngineOp::EngineCtxPair& TRTEngineOp::GetEngine(int batch_size,
raw_static_engine->createExecutionContext())};
// Runtime is safe to delete after engine creation
serialized_segment_.clear();
- if (max_batch_size < batch_size) return null_pair;
+ if (max_batch_size < batch_size) {
+ return null_pair;
+ }
return engine_map_.at(max_batch_size);
} // static_engine_
@@ -481,7 +490,6 @@ TRTEngineOp::EngineCtxPair& TRTEngineOp::GetEngine(int batch_size,
#if NV_TENSORRT_MAJOR > 3
allocator = GetAllocator(ctx);
if (allocator == nullptr) {
- // GetAllocator already set the Status.
return null_pair;
}
#endif
@@ -505,9 +513,8 @@ TRTEngineOp::EngineCtxPair& TRTEngineOp::GetEngine(int batch_size,
// retry in the future.
engine_map_[batch_size] = {nullptr, nullptr};
}
- LOG(ERROR) << "Engine creation for batch size " << batch_size
- << " failed " << status;
- ctx->SetStatus(tensorflow::errors::Internal("Engine creation failed!"));
+ LOG(WARNING) << "Engine creation for batch size " << batch_size
+ << " failed " << status;
return null_pair;
}
VLOG(1) << "Conversion is done";
@@ -519,7 +526,7 @@ TRTEngineOp::EngineCtxPair& TRTEngineOp::GetEngine(int batch_size,
}
tensorflow::Status TRTEngineOp::AllocateCalibrationResources(
- tensorflow::OpKernelContext* ctx, TRTCalibrationResource** cr) {
+ OpKernelContext* ctx, TRTCalibrationResource** cr) {
auto cres = new TRTCalibrationResource();
*cr = cres;
// Get the allocator.
@@ -583,7 +590,7 @@ tensorflow::Status TRTEngineOp::AllocateCalibrationResources(
// TODO(aaroey): maybe setting the max batch size using the python
// calibration wrapper class.
auto s = convert::ConvertGraphDefToEngine(
- *segment_graph, convert::INT8MODE, cres->calibrator_->getBatchSize(),
+ *segment_graph, INT8MODE, cres->calibrator_->getBatchSize(),
workspace_size_bytes, shapes, &cres->logger_, cres->allocator_.get(),
cres->calibrator_.get(), &cres->engine_,
/*convert_successfully=*/nullptr);
diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h
index 9265250605..8fe0675891 100644
--- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h
+++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h
@@ -35,7 +35,7 @@ limitations under the License.
namespace tensorflow {
namespace tensorrt {
-class TRTInt8Calibrator;
+struct TRTInt8Calibrator;
class TRTCalibrationResource;
class AsyncHelper;
// TODO(Sami): Remove this file?
@@ -60,6 +60,12 @@ class TRTEngineOp : public AsyncOpKernel {
// Execute replaced native segment as function Op.
void ExecuteNativeSegment(OpKernelContext* ctx, AsyncHelper* helper);
+ // Execute the tensorrt engine. Returns whether we need to retry by running
+ // the native segment.
+ bool ExecuteTrtEngine(OpKernelContext* ctx, const int num_batch,
+ nvinfer1::ICudaEngine* trt_engine_ptr,
+ nvinfer1::IExecutionContext* trt_execution_context_ptr);
+
// Allocate necessary resources for calibration
Status AllocateCalibrationResources(OpKernelContext* ctx,
TRTCalibrationResource** cr);
diff --git a/tensorflow/contrib/tensorrt/python/__init__.py b/tensorflow/contrib/tensorrt/python/__init__.py
index fe4fa166a1..7cdfe2b1a6 100644
--- a/tensorflow/contrib/tensorrt/python/__init__.py
+++ b/tensorflow/contrib/tensorrt/python/__init__.py
@@ -20,7 +20,11 @@ from __future__ import print_function
# pylint: disable=unused-import,line-too-long
from tensorflow.contrib.tensorrt.python.ops import trt_engine_op
+from tensorflow.contrib.tensorrt.python.trt_convert import add_test_value
from tensorflow.contrib.tensorrt.python.trt_convert import calib_graph_to_infer_graph
+from tensorflow.contrib.tensorrt.python.trt_convert import clear_test_values
from tensorflow.contrib.tensorrt.python.trt_convert import create_inference_graph
+from tensorflow.contrib.tensorrt.python.trt_convert import enable_test_value
+from tensorflow.contrib.tensorrt.python.trt_convert import get_test_value
from tensorflow.contrib.tensorrt.python.trt_convert import is_tensorrt_enabled
# pylint: enable=unused-import,line-too-long
diff --git a/tensorflow/contrib/tensorrt/python/trt_convert.py b/tensorflow/contrib/tensorrt/python/trt_convert.py
index 2b67931661..4116f2fe30 100644
--- a/tensorflow/contrib/tensorrt/python/trt_convert.py
+++ b/tensorflow/contrib/tensorrt/python/trt_convert.py
@@ -20,26 +20,26 @@ from __future__ import print_function
# pylint: disable=unused-import,line-too-long
import six as _six
+from tensorflow.contrib.tensorrt.wrap_conversion import add_test_value
from tensorflow.contrib.tensorrt.wrap_conversion import calib_convert
+from tensorflow.contrib.tensorrt.wrap_conversion import clear_test_values
+from tensorflow.contrib.tensorrt.wrap_conversion import enable_test_value
from tensorflow.contrib.tensorrt.wrap_conversion import get_linked_tensorrt_version
from tensorflow.contrib.tensorrt.wrap_conversion import get_loaded_tensorrt_version
+from tensorflow.contrib.tensorrt.wrap_conversion import get_test_value
from tensorflow.contrib.tensorrt.wrap_conversion import is_tensorrt_enabled
-from tensorflow.contrib.tensorrt.wrap_conversion import trt_convert
from tensorflow.core.framework import graph_pb2
+from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
-from tensorflow.python.framework import errors
from tensorflow.python.framework import errors_impl as _impl
-from tensorflow.python.framework import meta_graph
+from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.platform import tf_logging
-from tensorflow.python.util import compat
-
+from tensorflow.python.training import saver
# pylint: enable=unused-import,line-too-long
-# TODO(skama): get outputs from session when implemented as c++
-# optimization pass
def create_inference_graph(input_graph_def,
outputs,
max_batch_size=1,
@@ -48,7 +48,7 @@ def create_inference_graph(input_graph_def,
minimum_segment_size=3,
is_dynamic_op=False,
maximum_cached_engines=1,
- cached_engine_batches=[]):
+ cached_engine_batches=None):
"""Python wrapper for the TRT transformation.
Args:
@@ -87,8 +87,7 @@ def create_inference_graph(input_graph_def,
(".".join([str(x) for x in compiled_version]),
".".join([str(x) for x in loaded_version])) +
". Please make sure that correct version of TensorRT " +
- "is available in the system and added to ldconfig or LD_LIBRARY_PATH"
- )
+ "is available in the system and added to ldconfig or LD_LIBRARY_PATH")
raise RuntimeError("Incompatible TensorRT library version")
for i in zip(loaded_version, compiled_version):
if i[0] != i[1]:
@@ -121,41 +120,42 @@ def create_inference_graph(input_graph_def,
to_bytes = py3bytes
to_string = py3string
- out_names = []
- for i in outputs:
- if isinstance(i, ops.Tensor):
- out_names.append(to_bytes(i.name))
- else:
- out_names.append(to_bytes(i))
-
- input_graph_def_str = input_graph_def.SerializeToString()
-
- # TODO(sami): Fix this when we can return status from C++ library
- # There is a problem with the TF internal library setup that doesn't
- # allow us to return a status object from C++. Thus we return a
- # pair or strings where first one is encoded status and the second
- # one is the transformed graphs protobuf string.
- out = trt_convert(input_graph_def_str, out_names, max_batch_size,
- max_workspace_size_bytes, mode, minimum_segment_size,
- is_dynamic_op, maximum_cached_engines,
- cached_engine_batches)
- status = to_string(out[0])
- output_graph_def_string = out[1]
- del input_graph_def_str # Save some memory
- if len(status) < 2:
- raise _impl.UnknownError(None, None, status)
- if status[:2] != "OK":
- msg = status.split(";")
- if len(msg) == 1:
- raise RuntimeError("Status message is malformed {}".format(status))
- # pylint: disable=protected-access
- raise _impl._make_specific_exception(None, None, ";".join(msg[1:]),
- int(msg[0]))
- # pylint: enable=protected-access
- output_graph_def = graph_pb2.GraphDef()
- output_graph_def.ParseFromString(output_graph_def_string)
- del output_graph_def_string # Save some memory
- return output_graph_def
+ # Create MetaGraphDef
+ graph = ops.Graph()
+ with graph.as_default():
+ importer.import_graph_def(input_graph_def, name="")
+ meta_graph = saver.export_meta_graph(
+ graph_def=graph.as_graph_def(), graph=graph)
+ if outputs:
+ output_collection = meta_graph_pb2.CollectionDef()
+ output_list = output_collection.node_list.value
+ for i in outputs:
+ if isinstance(i, ops.Tensor):
+ output_list.append(to_bytes(i.name))
+ else:
+ output_list.append(to_bytes(i))
+ meta_graph.collection_def["train_op"].CopyFrom(output_collection)
+
+ # Create RewriterConfig.
+ rewriter_cfg = rewriter_config_pb2.RewriterConfig()
+ rewriter_cfg.optimizers.extend(["constfold", "layout"])
+ optimizer = rewriter_cfg.custom_optimizers.add()
+ optimizer.name = "TensorRTOptimizer"
+ optimizer.parameter_map["minimum_segment_size"].i = minimum_segment_size
+ optimizer.parameter_map["max_batch_size"].i = max_batch_size
+ optimizer.parameter_map["is_dynamic_op"].b = is_dynamic_op
+ optimizer.parameter_map[
+ "max_workspace_size_bytes"].i = max_workspace_size_bytes
+ optimizer.parameter_map["precision_mode"].s = to_bytes(precision_mode)
+ optimizer.parameter_map["maximum_cached_engines"].i = maximum_cached_engines
+ if cached_engine_batches:
+ if not isinstance(cached_engine_batches, list):
+ raise TypeError("cached_engine_batches should be a list.")
+ optimizer.parameter_map["cached_engine_batches"].list.i.extend(
+ cached_engine_batches)
+
+ return tf_optimizer.OptimizeGraph(
+ rewriter_cfg, meta_graph, graph_id=b"tf_graph")
def calib_graph_to_infer_graph(calibration_graph_def, is_dynamic_op=False):
diff --git a/tensorflow/contrib/tensorrt/segment/segment.cc b/tensorflow/contrib/tensorrt/segment/segment.cc
index 008fffc954..b43f1b190f 100644
--- a/tensorflow/contrib/tensorrt/segment/segment.cc
+++ b/tensorflow/contrib/tensorrt/segment/segment.cc
@@ -414,10 +414,10 @@ tensorflow::Status SegmentGraph(
}
for (const SimpleNode* node : order) {
// All output nodes of 'node' have been visited...
- VLOG(2) << "Trying node " << node->name() << " id=" << node->id();
+ VLOG(3) << "Trying node " << node->name() << " id=" << node->id();
// 'node' must be a TRT candidate...
if (node_segments[node->id()].Value() == nullptr) {
- VLOG(2) << "... not a TRT candidate";
+ VLOG(3) << "... not a TRT candidate";
continue;
}
// Contract output edges to combine 'node' with output
@@ -426,22 +426,22 @@ tensorflow::Status SegmentGraph(
while (true) {
std::set<const SimpleEdge*> contract_edges;
for (const SimpleEdge* out_edge : node->out_edges()) {
- VLOG(2) << "... out node " << out_edge->dst()->name() << " ( "
+ VLOG(3) << "... out node " << out_edge->dst()->name() << " ( "
<< out_edge->dst()->id() << " <- " << node->id() << " )";
if (out_edge->IsControlEdge()) {
- VLOG(2) << "... ... Control Edge, Skipping";
+ VLOG(3) << "... ... Control Edge, Skipping";
continue;
}
// Out node must be TRT candidate...
if (node_segments[out_edge->dst()->id()].Value() == nullptr) {
- VLOG(2) << "... ... not a TRT candidate";
+ VLOG(3) << "... ... not a TRT candidate";
continue;
}
if (CanContractEdge(out_edge, graph)) {
- VLOG(2) << "... ... can contract";
+ VLOG(3) << "... ... can contract";
contract_edges.insert(out_edge);
} else {
- VLOG(2) << "... ... cannot contract, would form cycle";
+ VLOG(3) << "... ... cannot contract, would form cycle";
}
}
if (contract_edges.empty()) {
@@ -454,7 +454,7 @@ tensorflow::Status SegmentGraph(
const SimpleNode* src = contract_edge->src();
const SimpleNode* dst = contract_edge->dst();
- VLOG(2) << "Merge " << src->name() << " <- " << dst->name() << " ("
+ VLOG(3) << "Merge " << src->name() << " <- " << dst->name() << " ("
<< src->id() << " <- " << dst->id();
node_segments[src->id()].Merge(&node_segments[dst->id()]);
@@ -478,7 +478,7 @@ tensorflow::Status SegmentGraph(
// A map from the segment identifier (currently the name of the root node of
// the segment tree) to the segment nodes set.
- std::unordered_map<string, std::set<const tensorflow::Node*>> sg_map;
+ std::map<string, std::set<const tensorflow::Node*>> sg_map;
// A map from the segment identifier (currently the name of the root node of
// the segment tree) to the device names that the nodes in the segment are
@@ -558,27 +558,36 @@ tensorflow::Status SegmentGraph(
// then after doing this operation the resulting subgraph will keep the
// same properties 1 and 2.
//
- // For simplicity we use heuristics: for input nodes remove all its
- // input, for output nodes remove all its output. In this way, for common
- // cases the number of removed nodes should be minimum.
+ // For simplicity we use heuristics: for input and const output nodes
+ // remove all their inputs, and for non-const output nodes remove all
+ // their outputs. In this way, for common cases the number of removed
+ // nodes should be minimum.
auto remove_nodes = [&segment_nodes](
bool is_input_nodes,
std::deque<const tensorflow::Node*>* que) {
// Run a BFS on the queue to find all the input/output nodes.
std::set<const tensorflow::Node*> visited;
+ std::set<const tensorflow::Node*> logged(que->begin(), que->end());
while (!que->empty()) {
auto node = que->front();
que->pop_front();
if (!visited.insert(node).second) continue;
segment_nodes.erase(node);
- for (auto in :
- is_input_nodes ? node->in_nodes() : node->out_nodes()) {
+ for (auto in : (is_input_nodes || node->type_string() == "Const")
+ ? node->in_nodes()
+ : node->out_nodes()) {
if (segment_nodes.count(in)) {
que->push_back(in);
- VLOG(2) << "Need to remove node " << in->name()
- << " because one of its "
- << (is_input_nodes ? "output" : "input")
- << " nodes in the graph was removed: " << node->name();
+ if (VLOG_IS_ON(2)) {
+ if (!logged.count(in)) {
+ VLOG(2) << "----> Need to remove node " << in->name()
+ << " because one of its "
+ << (is_input_nodes ? "output" : "input")
+ << " nodes in the graph was removed: "
+ << node->name();
+ logged.insert(in);
+ }
+ }
}
}
}
@@ -594,7 +603,7 @@ tensorflow::Status SegmentGraph(
for (const auto& itr : sg_map) {
const std::set<const tensorflow::Node*>& segment_nodes = itr.second;
if (VLOG_IS_ON(1)) {
- string s;
+ string s = "parent=" + itr.first + ":";
for (auto node : segment_nodes) s += " " + node->name();
VLOG(1) << "Segment " << segments->size() << ": " << s;
}
diff --git a/tensorflow/contrib/tensorrt/segment/segment_test.cc b/tensorflow/contrib/tensorrt/segment/segment_test.cc
index 432e7b1c04..5937fa8259 100644
--- a/tensorflow/contrib/tensorrt/segment/segment_test.cc
+++ b/tensorflow/contrib/tensorrt/segment/segment_test.cc
@@ -206,7 +206,7 @@ TEST_F(SegmentTest, Multiple) {
// Make add5 not a TRT candidate, and we expect two segments.
auto without_add5 = all_adds - "add5";
RunTest(&g, without_add5, without_add5, without_add5,
- {{"add6", "add8"}, {"add0", "add1", "add2", "add3"}});
+ {{"add0", "add1", "add2", "add3"}, {"add6", "add8"}});
// Make add8 not a candidate and add6 not an input candidate, then all direct
// and indirect inputs of add6 will be removed from the segment.
@@ -252,7 +252,7 @@ TEST_F(SegmentTest, BigIfElse) {
const std::set<string> all_adds = {"add0", "add1", "add2", "add3",
"add4", "add5", "add6", "add7"};
RunTest(&g, all_adds - "add2", all_adds, all_adds,
- {{"add3", "add4", "add5", "add6", "add7"}, {"add0", "add1"}});
+ {{"add0", "add1"}, {"add3", "add4", "add5", "add6", "add7"}});
}
} // namespace test
diff --git a/tensorflow/contrib/tensorrt/tensorrt_test.cc b/tensorflow/contrib/tensorrt/tensorrt_test.cc
index 3712a9a6fe..769982c645 100644
--- a/tensorflow/contrib/tensorrt/tensorrt_test.cc
+++ b/tensorflow/contrib/tensorrt/tensorrt_test.cc
@@ -13,7 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/test.h"
#if GOOGLE_CUDA
@@ -130,6 +132,13 @@ void Execute(nvinfer1::IExecutionContext* context, const float* input,
}
TEST(TensorrtTest, BasicFunctions) {
+ // Handle the case where the test is run on machine with no gpu available.
+ if (CHECK_NOTNULL(GPUMachineManager())->VisibleDeviceCount() <= 0) {
+ LOG(WARNING) << "No gpu device available, probably not being run on a gpu "
+ "machine. Skipping...";
+ return;
+ }
+
// Create the network model.
nvinfer1::IHostMemory* model = CreateNetwork();
// Use the model to create an engine and then an execution context.
diff --git a/tensorflow/contrib/tensorrt/test/base_test.py b/tensorflow/contrib/tensorrt/test/base_test.py
index edd30ad7a9..8ea5a63735 100644
--- a/tensorflow/contrib/tensorrt/test/base_test.py
+++ b/tensorflow/contrib/tensorrt/test/base_test.py
@@ -20,17 +20,19 @@ from __future__ import print_function
import numpy as np
+from tensorflow.contrib.tensorrt.python import trt_convert
from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import test
-class SimpleSingleEngineGraphDefTest(trt_test.TfTrtIntegrationTestBase):
+class SimpleSingleEngineTest(trt_test.TfTrtIntegrationTestBase):
def GetParams(self):
"""Create a graph containing single segment."""
@@ -65,13 +67,17 @@ class SimpleSingleEngineGraphDefTest(trt_test.TfTrtIntegrationTestBase):
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- num_expected_engines=1,
+ # TODO(aaroey): LayoutOptimizer adds additional nodes to the graph which
+ # breaks the connection check, fix it.
+ # - my_trt_op_0 should have ["weights", "conv", "bias", "bias_add",
+ # "relu", "identity", "max_pool"]
+ expected_engines=["my_trt_op_0"],
expected_output_dims=(100, 6, 6, 6),
allclose_atol=1.e-03,
allclose_rtol=1.e-03)
-class SimpleMultiEngineGraphDefTest(trt_test.TfTrtIntegrationTestBase):
+class SimpleMultiEnginesTest(trt_test.TfTrtIntegrationTestBase):
def GetParams(self):
"""Create a graph containing multiple segment."""
@@ -95,32 +101,246 @@ class SimpleMultiEngineGraphDefTest(trt_test.TfTrtIntegrationTestBase):
padding="SAME",
name="conv")
c1 = constant_op.constant(
- np.random.randn(input_dims[0], 12, 12, 6), dtype=dtype)
- p = conv * c1
+ np.random.randn(input_dims[0], 12, 12, 6), dtype=dtype, name="c1")
+ p = math_ops.mul(conv, c1, name="mul")
c2 = constant_op.constant(
- np.random.randn(input_dims[0], 12, 12, 6), dtype=dtype)
- q = conv / c2
+ np.random.randn(input_dims[0], 12, 12, 6), dtype=dtype, name="c2")
+ q = math_ops.div(conv, c2, name="div")
- edge = self.trt_incompatible_op(q)
- edge /= edge
- r = edge + edge
+ edge = self.trt_incompatible_op(q, name="incompatible")
+ edge = math_ops.div(edge, edge, name="div1")
+ r = math_ops.add(edge, edge, name="add")
- p -= edge
- q *= edge
- s = p + q
- s -= r
+ p = math_ops.sub(p, edge, name="sub")
+ q = math_ops.mul(q, edge, name="mul1")
+ s = math_ops.add(p, q, name="add1")
+ s = math_ops.sub(s, r, name="sub1")
array_ops.squeeze(s, name=self.output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- num_expected_engines=2,
+ # TODO(aaroey): LayoutOptimizer adds additional nodes to the graph which
+ # breaks the connection check, fix it.
+ # - my_trt_op_0 should have ["mul", "sub", "div1", "mul1", "add1",
+ # "add", "sub1"];
+ # - my_trt_op_1 should have ["weights","conv", "div"]
+ expected_engines=["my_trt_op_0", "my_trt_op_1"],
expected_output_dims=(100, 12, 12, 6),
allclose_atol=1.e-03,
allclose_rtol=1.e-03)
-# TODO(aaroey): add a large complex graph to test.
+class PartiallyConvertedTestA(trt_test.TfTrtIntegrationTestBase):
+
+ def setUp(self):
+ """Setup method."""
+ super(PartiallyConvertedTestA, self).setUp()
+ # Let it fail to build the second engine.
+ trt_convert.add_test_value("my_trt_op_1:CreateTRTNode", "fail")
+
+ def GetParams(self):
+ """Create a graph containing two segment."""
+ input_name = "input"
+ input_dims = [2, 32, 32, 3]
+ g = ops.Graph()
+ with g.as_default():
+ inp = array_ops.placeholder(
+ dtype=dtypes.float32, shape=input_dims, name=input_name)
+ with g.device("/GPU:0"):
+ n = inp
+ for i in range(2):
+ c = constant_op.constant(1.0, name="c%d" % i)
+ n = math_ops.add(n, c, name="add%d" % i)
+ n = math_ops.mul(n, n, name="mul%d" % i)
+ edge = self.trt_incompatible_op(n, name="incompatible")
+ with g.control_dependencies([edge]):
+ c = constant_op.constant(1.0, name="c2")
+ n = math_ops.add(n, c, name="add2")
+ n = math_ops.mul(n, n, name="mul2")
+ c = constant_op.constant(1.0, name="c3")
+ n = math_ops.add(n, c, name="add3")
+ n = math_ops.mul(n, n, name="mul3")
+ array_ops.squeeze(n, name=self.output_name)
+ return trt_test.TfTrtIntegrationTestParams(
+ gdef=g.as_graph_def(),
+ input_names=[input_name],
+ input_dims=[input_dims],
+ expected_engines={
+ # Only the first engine is built.
+ "my_trt_op_0": ["c0", "c1", "add0", "add1", "mul0", "mul1"]
+ },
+ expected_output_dims=tuple(input_dims),
+ allclose_atol=1.e-06,
+ allclose_rtol=1.e-06)
+
+
+class PartiallyConvertedTestB(PartiallyConvertedTestA):
+
+ def setUp(self):
+ """Setup method."""
+ super(PartiallyConvertedTestB, self).setUp()
+ # Let it fail to build the first engine.
+ trt_convert.clear_test_values("")
+ trt_convert.add_test_value("my_trt_op_0:CreateTRTNode", "fail")
+
+ def GetParams(self):
+ """Create a graph containing two segment."""
+ return super(PartiallyConvertedTestB, self).GetParams()._replace(
+ expected_engines={
+ # Only the second engine is built.
+ "my_trt_op_1": ["c2", "c3", "add2", "add3", "mul2", "mul3"]
+ })
+
+
+class ConstInputTest(trt_test.TfTrtIntegrationTestBase):
+
+ def GetParams(self):
+ """Create a graph containing multiple segment."""
+ input_name = "input"
+ input_dims = [2, 32, 32, 3]
+ g = ops.Graph()
+ with g.as_default():
+ inp = array_ops.placeholder(
+ dtype=dtypes.float32, shape=input_dims, name=input_name)
+ with g.device("/GPU:0"):
+ n = inp
+ c = constant_op.constant(1.0, name="c")
+ # Adds control dependency from the constant op to a trt incompatible op,
+ # and adds control dependency from the trt incompatible op to all other
+ # ops, to make sure the constant op cannot be contracted with any trt
+ # segment that depends on it.
+ with g.control_dependencies([c]):
+ d = self.trt_incompatible_op(n, name="incompatible")
+ with g.control_dependencies([d]):
+ n = math_ops.add(n, c, name="add")
+ n = math_ops.mul(n, n, name="mul")
+ n = math_ops.add(n, n, name="add1")
+ n = self.trt_incompatible_op(n, name="incompatible1")
+ with g.control_dependencies([d]):
+ n = math_ops.add(n, c, name="add2")
+ n = math_ops.mul(n, n, name="mul1")
+ n = math_ops.add(n, n, name="add3")
+ array_ops.squeeze(n, name=self.output_name)
+ return trt_test.TfTrtIntegrationTestParams(
+ gdef=g.as_graph_def(),
+ input_names=[input_name],
+ input_dims=[input_dims],
+ expected_engines={
+ "my_trt_op_0": ["add", "add1", "mul"],
+ "my_trt_op_1": ["add2", "add3", "mul1"]
+ },
+ expected_output_dims=tuple(input_dims),
+ allclose_atol=1.e-06,
+ allclose_rtol=1.e-06)
+
+
+class ConstDataInputSingleEngineTest(trt_test.TfTrtIntegrationTestBase):
+
+ def GetParams(self):
+ """Create a graph containing single segment."""
+ input_name = "input"
+ input_dims = [2, 32, 32, 3]
+ g = ops.Graph()
+ with g.as_default():
+ inp = array_ops.placeholder(
+ dtype=dtypes.float32, shape=input_dims, name=input_name)
+ with g.device("/GPU:0"):
+ n = inp
+ c = constant_op.constant(1.0, name="c")
+ n = math_ops.add(n, c, name="add")
+ n = math_ops.mul(n, n, name="mul")
+ n = math_ops.add(n, n, name="add1")
+ array_ops.squeeze(n, name=self.output_name)
+ return trt_test.TfTrtIntegrationTestParams(
+ gdef=g.as_graph_def(),
+ input_names=[input_name],
+ input_dims=[input_dims],
+ expected_engines={"my_trt_op_0": ["c", "add", "add1", "mul"]},
+ expected_output_dims=tuple(input_dims),
+ allclose_atol=1.e-06,
+ allclose_rtol=1.e-06)
+
+
+class ConstDataInputMultipleEnginesTest(trt_test.TfTrtIntegrationTestBase):
+
+ def GetParams(self):
+ """Create a graph containing multiple segment."""
+ input_name = "input"
+ input_dims = [2, 32, 32, 3]
+ g = ops.Graph()
+ with g.as_default():
+ inp = array_ops.placeholder(
+ dtype=dtypes.float32, shape=input_dims, name=input_name)
+ with g.device("/GPU:0"):
+ n = inp
+ c = constant_op.constant(1.0, name="c")
+ n = math_ops.add(n, c, name="add")
+ n = math_ops.mul(n, n, name="mul")
+ n = math_ops.add(n, n, name="add1")
+ n = self.trt_incompatible_op(n, name="incompatible1")
+ n = math_ops.add(n, c, name="add2")
+ n = math_ops.mul(n, n, name="mul1")
+ n = math_ops.add(n, n, name="add3")
+ array_ops.squeeze(n, name=self.output_name)
+ return trt_test.TfTrtIntegrationTestParams(
+ gdef=g.as_graph_def(),
+ input_names=[input_name],
+ input_dims=[input_dims],
+ expected_engines={
+ "my_trt_op_0": ["add2", "add3", "mul1"],
+ # Why segment ["add", "add1", "mul"] was assigned segment id 1
+ # instead of 0: the parent node of this segment is actually const
+ # node 'c', but it's removed later since it's const output of the
+ # segment which is not allowed.
+ "my_trt_op_1": ["add", "add1", "mul"]
+ },
+ expected_output_dims=tuple(input_dims),
+ allclose_atol=1.e-06,
+ allclose_rtol=1.e-06)
+
+
+class ControlDependencyTest(trt_test.TfTrtIntegrationTestBase):
+
+ def GetParams(self):
+ """Create a graph containing multiple segment."""
+ input_name = "input"
+ input_dims = [2, 32, 32, 3]
+ g = ops.Graph()
+ with g.as_default():
+ inp = array_ops.placeholder(
+ dtype=dtypes.float32, shape=input_dims, name=input_name)
+ with g.device("/GPU:0"):
+ c1 = constant_op.constant(1.0, name="c1")
+ c2 = constant_op.constant(1.0, name="c2")
+ d1 = constant_op.constant(1.0, name="d1")
+ d2 = self.trt_incompatible_op(inp, name="d2")
+ with g.control_dependencies([d1, d2]):
+ add = math_ops.add(inp, c1, name="add")
+ with g.control_dependencies([d1, d2]):
+ mul = math_ops.mul(add, add, name="mul")
+ with g.control_dependencies([d1, d2]):
+ add1 = math_ops.add(mul, mul, name="add1")
+ edge = self.trt_incompatible_op(add1, name="incompatible")
+ with g.control_dependencies([d1, d2, add, mul]):
+ add2 = math_ops.add(edge, c2, name="add2")
+ with g.control_dependencies([d1, d2, add1, mul]):
+ mul1 = math_ops.mul(add2, add2, name="mul1")
+ with g.control_dependencies([d1, d2, add, add1]):
+ add3 = math_ops.add(mul1, mul1, name="add3")
+ array_ops.squeeze(add3, name=self.output_name)
+ return trt_test.TfTrtIntegrationTestParams(
+ gdef=g.as_graph_def(),
+ input_names=[input_name],
+ input_dims=[input_dims],
+ expected_engines={
+ "my_trt_op_0": ["c1", "add", "add1", "mul"],
+ "my_trt_op_1": ["c2", "add2", "add3", "mul1"]
+ },
+ expected_output_dims=tuple(input_dims),
+ allclose_atol=1.e-06,
+ allclose_rtol=1.e-06)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/tensorrt/test/batch_matmul_test.py b/tensorflow/contrib/tensorrt/test/batch_matmul_test.py
index 730b6843fb..2e1107e303 100644
--- a/tensorflow/contrib/tensorrt/test/batch_matmul_test.py
+++ b/tensorflow/contrib/tensorrt/test/batch_matmul_test.py
@@ -66,7 +66,7 @@ class BatchMatMulTest(trt_test.TfTrtIntegrationTestBase):
gdef=g.as_graph_def(),
input_names=[input_name, w1_name, w2_name],
input_dims=[input_dims, w1_dims, w2_dims],
- num_expected_engines=1,
+ expected_engines=["my_trt_op_0"],
expected_output_dims=(12, 5, 8, 7),
allclose_atol=1.e-03,
allclose_rtol=1.e-03)
diff --git a/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py b/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py
index 0c03a10b64..8be32f59b4 100644
--- a/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py
+++ b/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py
@@ -102,7 +102,10 @@ class BiasaddMatMulTest(trt_test.TfTrtIntegrationTestBase):
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- num_expected_engines=7,
+ expected_engines=[
+ "my_trt_op_0", "my_trt_op_1", "my_trt_op_2", "my_trt_op_3",
+ "my_trt_op_4", "my_trt_op_5", "my_trt_op_6"
+ ],
expected_output_dims=(48, 89),
allclose_atol=1.e-03,
allclose_rtol=1.e-03)
diff --git a/tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py b/tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py
index dd673463a5..9316b14da0 100644
--- a/tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py
+++ b/tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py
@@ -109,7 +109,24 @@ class BinaryTensorWeightBroadcastTest(trt_test.TfTrtIntegrationTestBase):
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- num_expected_engines=16,
+ expected_engines=[
+ "my_trt_op_0",
+ "my_trt_op_1",
+ "my_trt_op_2",
+ "my_trt_op_3",
+ "my_trt_op_4",
+ "my_trt_op_5",
+ "my_trt_op_6",
+ "my_trt_op_7",
+ "my_trt_op_8",
+ "my_trt_op_9",
+ "my_trt_op_10",
+ "my_trt_op_11",
+ "my_trt_op_12",
+ "my_trt_op_13",
+ "my_trt_op_14",
+ "my_trt_op_15",
+ ],
expected_output_dims=(5, 23040),
allclose_atol=1.e-03,
allclose_rtol=1.e-03)
diff --git a/tensorflow/contrib/tensorrt/test/concatenation_test.py b/tensorflow/contrib/tensorrt/test/concatenation_test.py
index 8c51c45b0a..1874b9dd45 100644
--- a/tensorflow/contrib/tensorrt/test/concatenation_test.py
+++ b/tensorflow/contrib/tensorrt/test/concatenation_test.py
@@ -73,7 +73,7 @@ class ConcatenationTest(trt_test.TfTrtIntegrationTestBase):
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- num_expected_engines=1,
+ expected_engines=["my_trt_op_0"],
expected_output_dims=(2, 126),
allclose_atol=1.e-03,
allclose_rtol=1.e-03)
diff --git a/tensorflow/contrib/tensorrt/test/const_broadcast_test.py b/tensorflow/contrib/tensorrt/test/const_broadcast_test.py
index 97b29bf05d..8c59000b70 100644
--- a/tensorflow/contrib/tensorrt/test/const_broadcast_test.py
+++ b/tensorflow/contrib/tensorrt/test/const_broadcast_test.py
@@ -58,7 +58,7 @@ class ConstBroadcastTest(trt_test.TfTrtIntegrationTestBase):
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- num_expected_engines=1,
+ expected_engines=['my_trt_op_0'],
expected_output_dims=(5, 12, 12, 1),
allclose_atol=1.e-02,
allclose_rtol=1.e-02)
diff --git a/tensorflow/contrib/tensorrt/test/memory_alignment_test.py b/tensorflow/contrib/tensorrt/test/memory_alignment_test.py
new file mode 100644
index 0000000000..66eb6be757
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/test/memory_alignment_test.py
@@ -0,0 +1,72 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Model script to test TF-TensorRT integration."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import nn
+from tensorflow.python.platform import test
+
+
+class MemoryAlignmentTest(trt_test.TfTrtIntegrationTestBase):
+
+ def GetParams(self):
+ """Testing conversion of BatchMatMul in TF-TRT conversion."""
+ dtype = dtypes.float32
+ input_name = "input"
+ input_dims = [2, 15, 15, 3]
+ g = ops.Graph()
+ with g.as_default():
+ inp = array_ops.placeholder(
+ dtype=dtype, shape=[None] + input_dims[1:], name=input_name)
+ with g.device("/GPU:0"):
+ e1 = constant_op.constant(
+ np.random.randn(1, 1, 3, 5), name="kernel_1", dtype=dtype)
+ e2 = constant_op.constant(
+ np.random.randn(1, 1, 5, 10), name="kernel_2", dtype=dtype)
+ conv = nn.conv2d(
+ input=inp,
+ filter=e1,
+ strides=[1, 1, 1, 1],
+ padding="VALID",
+ name="conv")
+ out = nn.conv2d(
+ input=conv,
+ filter=e2,
+ strides=[1, 1, 1, 1],
+ padding="VALID",
+ name="conv_2")
+ array_ops.squeeze(out, name=self.output_name)
+ return trt_test.TfTrtIntegrationTestParams(
+ gdef=g.as_graph_def(),
+ input_names=[input_name],
+ input_dims=[input_dims],
+ expected_engines=["my_trt_op_0"],
+ expected_output_dims=(2, 15, 15, 10),
+ allclose_atol=1.e-02,
+ allclose_rtol=1.e-02)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py b/tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py
index 734ccf6345..fd55b8cd99 100644
--- a/tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py
+++ b/tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py
@@ -77,7 +77,7 @@ class MultiConnectionNeighborEngineTest(trt_test.TfTrtIntegrationTestBase):
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- num_expected_engines=2,
+ expected_engines=["my_trt_op_0", "my_trt_op_1"],
expected_output_dims=(2, 4, 5, 4),
allclose_atol=1.e-03,
allclose_rtol=1.e-03)
diff --git a/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py b/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py
index 50265c0845..51c905a50b 100644
--- a/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py
+++ b/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py
@@ -25,7 +25,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import gen_math_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.platform import test
@@ -51,15 +51,18 @@ class NeighboringEngineTest(trt_test.TfTrtIntegrationTestBase):
name="conv")
b = constant_op.constant(
np.random.normal(1.0, 1.0, [1, 4, 1, 1]), name="bias", dtype=dtype)
- t = conv * b
- e = gen_math_ops.tan(conv)
- t = t - e
+ t = math_ops.mul(conv, b, name="mul")
+ e = self.trt_incompatible_op(conv, name="incompatible")
+ t = math_ops.sub(t, e, name="sub")
array_ops.squeeze(t, name=self.output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- num_expected_engines=2,
+ expected_engines={
+ "my_trt_op_0": ["bias", "mul", "sub"],
+ "my_trt_op_1": ["weights", "conv"]
+ },
expected_output_dims=(2, 4, 5, 4),
allclose_atol=1.e-03,
allclose_rtol=1.e-03)
diff --git a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py
index 60b8eb6e81..6f85ada464 100644
--- a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py
+++ b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py
@@ -20,6 +20,7 @@ from __future__ import print_function
from collections import namedtuple
import itertools
+import os
import warnings
import numpy as np
import six
@@ -30,6 +31,7 @@ from tensorflow.contrib.tensorrt.python.ops import trt_engine_op
# pylint: enable=unused-import
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
+from tensorflow.python.framework import graph_io
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
@@ -37,10 +39,14 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import tf_logging as logging
TfTrtIntegrationTestParams = namedtuple("TfTrtIntegrationTestParams", [
- "gdef", "input_names", "input_dims", "num_expected_engines",
+ "gdef", "input_names", "input_dims", "expected_engines",
"expected_output_dims", "allclose_atol", "allclose_rtol"
])
+RunParams = namedtuple(
+ "RunParams",
+ ["use_optimizer", "precision_mode", "dynamic_engine", "test_name"])
+
PRECISION_MODES = ["FP32", "FP16", "INT8"]
@@ -48,6 +54,12 @@ def _IsQuantizationMode(mode):
return mode == "INT8"
+class GraphState(object):
+ ORIGINAL = 0
+ CALIBRATE = 1
+ INFERENCE = 2
+
+
class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
"""Class to test Tensorflow-TensorRT integration."""
@@ -63,50 +75,96 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
def precision_modes(self):
return ["FP32", "FP16", "INT8"]
+ # str is bytes in py2, but unicode in py3.
+ def _ToUnicode(self, s):
+ if six.PY2:
+ if isinstance(s, unicode):
+ return s
+ return s.decode("utf-8")
+ else:
+ if isinstance(s, str):
+ return s
+ return s.decode("utf-8")
+
def _ToBytes(self, s):
if six.PY2:
+ if isinstance(s, unicode):
+ return s.encode("utf-8")
return s
else:
- return s.encode("utf-8")
+ if isinstance(s, str):
+ return s.encode("utf-8")
+ return s
def _ToString(self, s):
if six.PY2:
+ if isinstance(s, unicode):
+ return s.encode("utf-8")
return s
else:
+ if isinstance(s, str):
+ return s
return s.decode("utf-8")
+ @classmethod
+ def setUpClass(cls):
+ """Setup method for the module."""
+ super(TfTrtIntegrationTestBase, cls).setUpClass()
+ trt_convert.enable_test_value()
+
def setUp(self):
"""Setup method."""
super(TfTrtIntegrationTestBase, self).setUp()
warnings.simplefilter("always")
+ trt_convert.clear_test_values("")
def GetParams(self):
"""Return a TfTrtIntegrationTestParams for test, implemented by subclass."""
raise NotImplementedError()
- def _GetConfigProto(self,
- params,
- use_optimizer,
- precision_mode=None,
- is_dynamic_op=None):
+ def _PrepareRun(self, params, graph_state):
+ """Set up necessary testing environment before calling sess.run()."""
+ # Clear test values added by TRTEngineOp.
+ trt_convert.clear_test_values("my_trt_op_.*:ExecuteTrtEngine")
+ trt_convert.clear_test_values("my_trt_op_.*:ExecuteCalibration")
+ trt_convert.clear_test_values("my_trt_op_.*:ExecuteNativeSegment")
+
+ def _VerifyRun(self, params, graph_state):
+ """Verify the state after sess.run()."""
+ for engine_name in params.expected_engines:
+ if graph_state == GraphState.ORIGINAL:
+ self._ExpectCalibration(engine_name, "")
+ self._ExpectNativeSegment(engine_name, "")
+ self._ExpectTrtEngine(engine_name, "")
+ elif graph_state == GraphState.CALIBRATE:
+ self._ExpectCalibration(engine_name, "done")
+ self._ExpectNativeSegment(engine_name, "done")
+ self._ExpectTrtEngine(engine_name, "")
+ elif graph_state == GraphState.INFERENCE:
+ self._ExpectCalibration(engine_name, "")
+ self._ExpectNativeSegment(engine_name, "")
+ self._ExpectTrtEngine(engine_name, "done")
+
+ def _GetConfigProto(self, params, run_params, graph_state):
"""Get config proto based on specific settings."""
- if use_optimizer:
+ if graph_state != GraphState.ORIGINAL and run_params.use_optimizer:
rewriter_cfg = rewriter_config_pb2.RewriterConfig()
rewriter_cfg.optimizers.extend(["constfold", "layout"])
custom_op = rewriter_cfg.custom_optimizers.add()
custom_op.name = "TensorRTOptimizer"
- custom_op.parameter_map["minimum_segment_size"].i = 3
+ custom_op.parameter_map["minimum_segment_size"].i = 2
custom_op.parameter_map["max_batch_size"].i = max(
[dims[0] for dims in params.input_dims])
- custom_op.parameter_map["is_dynamic_op"].b = is_dynamic_op
+ custom_op.parameter_map["is_dynamic_op"].b = run_params.dynamic_engine
custom_op.parameter_map["max_workspace_size_bytes"].i = 1 << 25
custom_op.parameter_map["precision_mode"].s = self._ToBytes(
- precision_mode)
+ run_params.precision_mode)
graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_cfg)
else:
graph_options = config_pb2.GraphOptions()
gpu_options = config_pb2.GPUOptions()
+ gpu_options.allow_growth = True
if trt_convert.get_linked_tensorrt_version()[0] == 3:
gpu_options.per_process_gpu_memory_fraction = 0.50
@@ -114,7 +172,26 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
gpu_options=gpu_options, graph_options=graph_options)
return config
- def _RunGraph(self, params, gdef, input_data, config, num_runs=2):
+ def _ExpectTestValue(self, engine_name, method, expected_value):
+ label = "%s:%s" % (engine_name, method)
+ actual_value = trt_convert.get_test_value(label)
+ self.assertEqual(
+ expected_value,
+ actual_value,
+ msg="Unexpected test value with label %s. Actual: %s; expected: %s" %
+ (label, actual_value, expected_value))
+
+ def _ExpectCalibration(self, engine_name, value):
+ self._ExpectTestValue(engine_name, "ExecuteCalibration", value)
+
+ def _ExpectTrtEngine(self, engine_name, value):
+ self._ExpectTestValue(engine_name, "ExecuteTrtEngine", value)
+
+ def _ExpectNativeSegment(self, engine_name, value):
+ self._ExpectTestValue(engine_name, "ExecuteNativeSegment", value)
+
+ def _RunGraph(self, params, gdef, input_data, config, graph_state,
+ num_runs=2):
"""Run given graphdef multiple times."""
assert len(params.input_names) == len(input_data)
g = ops.Graph()
@@ -131,93 +208,170 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
val = None
# Defaults to 2 runs to verify result across multiple runs is same.
for _ in range(num_runs):
+ self._PrepareRun(params, graph_state)
new_val = sess.run(out,
{inp[i]: input_data[i] for i in range(len(inp))})
self.assertEqual(params.expected_output_dims, new_val.shape)
if val is not None:
self.assertAllEqual(val, new_val)
val = new_val
+ self._VerifyRun(params, graph_state)
return val
# Use real data that is representative of the inference dataset
# for calibration. For this test script it is random data.
def _RunCalibration(self, params, gdef, input_data, config):
"""Run calibration on given graph."""
- return self._RunGraph(params, gdef, input_data, config, 30)
+ return self._RunGraph(
+ params, gdef, input_data, config, GraphState.CALIBRATE, num_runs=5)
- def _GetTrtGraphDef(self, params, gdef, precision_mode, is_dynamic_op):
+ def _GetTrtGraphDef(self, params, run_params, gdef):
"""Return trt converted graphdef."""
return trt_convert.create_inference_graph(
input_graph_def=gdef,
outputs=[self.output_name],
max_batch_size=max([dims[0] for dims in params.input_dims]),
max_workspace_size_bytes=1 << 25,
- precision_mode=precision_mode,
+ precision_mode=run_params.precision_mode,
minimum_segment_size=2,
- is_dynamic_op=is_dynamic_op)
-
- def _VerifyGraphDef(self,
- params,
- gdef,
- precision_mode=None,
- is_calibrated=None,
- dynamic_engine=None):
+ is_dynamic_op=run_params.dynamic_engine)
+
+ def _WriteGraph(self, params, run_params, gdef, graph_state):
+ if graph_state == GraphState.ORIGINAL:
+ label = "Original"
+ elif graph_state == GraphState.CALIBRATE:
+ label = "CalibEngine"
+ elif graph_state == GraphState.INFERENCE:
+ label = "InferEngine"
+ graph_name = (
+ self.__class__.__name__ + "_" + run_params.test_name + "_" + label +
+ ".pbtxt")
+ temp_dir = os.getenv("TRT_TEST_TMPDIR", self.get_temp_dir())
+ logging.info("Writing graph to %s/%s", temp_dir, graph_name)
+ graph_io.write_graph(gdef, temp_dir, graph_name)
+
+ def _VerifyConnections(self, params, converted_gdef):
+ old_to_new_node_map = {
+ self._ToString(node.name): self._ToString(node.name)
+ for node in params.gdef.node
+ }
+ for engine_name, node_names in params.expected_engines.items():
+ for node_name in node_names:
+ old_to_new_node_map[node_name] = engine_name
+ name_to_node_map = {
+ self._ToString(node.name): node for node in params.gdef.node
+ }
+
+ def _InputName(inp):
+ inp = self._ToString(inp)
+ prefix = ""
+ if inp[0] == "^":
+ prefix = "^"
+ inp = inp[1:]
+ parts = inp.split(":")
+ if len(parts) > 1 and parts[-1].isdigit():
+ inp = inp[:-len(parts[-1]) - 1]
+ return (prefix, inp)
+
+ expected_input_map = {}
+ for node in params.gdef.node:
+ name_str = self._ToString(node.name)
+ target_node_name = old_to_new_node_map[name_str]
+ is_engine_op = (target_node_name != name_str)
+ if target_node_name not in expected_input_map:
+ expected_input_map[target_node_name] = set()
+ input_set = expected_input_map[target_node_name]
+ for inp in node.input:
+ (prefix, inp_name) = _InputName(inp)
+ # Add the input only if it's outside the segment (note that it could be
+ # in a different engine).
+ if (not is_engine_op or
+ old_to_new_node_map[inp_name] != target_node_name):
+ if is_engine_op and name_to_node_map[inp_name].op == "Const":
+ # Const data input nodes to the segment has been copied to the
+ # segment graphdef and the engine, and the dependency has been
+ # converted to control dependendy.
+ input_set.add("^" + old_to_new_node_map[inp_name])
+ else:
+ input_set.add(prefix + old_to_new_node_map[inp_name])
+
+ actual_input_map = {}
+ for node in converted_gdef.node:
+ name_str = self._ToString(node.name)
+ actual_input_map[name_str] = set()
+ input_set = actual_input_map[name_str]
+ for inp in node.input:
+ (prefix, node_name) = _InputName(inp)
+ input_set.add(prefix + node_name)
+
+ self.assertEqual(
+ expected_input_map,
+ actual_input_map,
+ msg="expected:\n%s\nvs actual:\n%s" % (sorted(
+ expected_input_map.items()), sorted(actual_input_map.items())))
+
+ def _VerifyGraphDef(self, params, run_params, gdef, graph_state):
+ self._WriteGraph(params, run_params, gdef, graph_state)
+
num_engines = 0
- for n in gdef.node:
- # TODO(jie): we should have coverage for failed conversion (TF fallback).
- # where the conversion will fail and we shouldn't count this engine as the
- # converted engines.
- if n.op == "TRTEngineOp":
+ for node in gdef.node:
+ if node.op == "TRTEngineOp":
num_engines += 1
- self.assertNotEqual(self._ToBytes(""), n.attr["serialized_segment"].s)
- self.assertNotEqual(self._ToBytes(""), n.attr["segment_funcdef_name"].s)
+ self.assertTrue(node.name in params.expected_engines)
+ self.assertTrue(len(node.attr["serialized_segment"].s))
+ self.assertTrue(len(node.attr["segment_funcdef_name"].s))
self.assertEqual(
- self._ToBytes(precision_mode), n.attr["precision_mode"].s)
- self.assertEqual(not dynamic_engine, n.attr["static_engine"].b)
- if _IsQuantizationMode(precision_mode) and is_calibrated:
- self.assertNotEqual(self._ToBytes(""), n.attr["calibration_data"].s)
+ self._ToBytes(run_params.precision_mode),
+ node.attr["precision_mode"].s)
+
+ is_dynamic_engine = not node.attr["static_engine"].b
+ self.assertEqual(run_params.dynamic_engine, is_dynamic_engine)
+
+ has_calibration_data = len(node.attr["calibration_data"].s)
+ if (_IsQuantizationMode(run_params.precision_mode) and
+ graph_state == GraphState.INFERENCE):
+ self.assertTrue(has_calibration_data)
else:
- self.assertEqual(self._ToBytes(""), n.attr["calibration_data"].s)
- if precision_mode is None: # This means gdef is the original GraphDef.
+ self.assertFalse(has_calibration_data)
+ if graph_state == GraphState.ORIGINAL:
self.assertEqual(0, num_engines)
else:
- self.assertEqual(num_engines, params.num_expected_engines)
+ self.assertEqual(num_engines, len(params.expected_engines))
+ if isinstance(params.expected_engines, dict):
+ self._VerifyConnections(params, gdef)
+ # TODO(aaroey): consider verifying the corresponding TF function.
- def RunTest(self, params, use_optimizer, precision_mode,
- dynamic_infer_engine, dynamic_calib_engine):
- assert precision_mode in PRECISION_MODES
+ def RunTest(self, params, run_params):
+ assert run_params.precision_mode in PRECISION_MODES
input_data = [np.random.random_sample(dims) for dims in params.input_dims]
input_gdef = params.gdef
- self._VerifyGraphDef(params, input_gdef)
+ self._VerifyGraphDef(params, run_params, input_gdef, GraphState.ORIGINAL)
# Get reference result without running trt.
- config_no_trt = self._GetConfigProto(params, False)
+ config_no_trt = self._GetConfigProto(params, run_params,
+ GraphState.ORIGINAL)
logging.info("Running original graph w/o trt, config:\n%s",
str(config_no_trt))
- ref_result = self._RunGraph(params, input_gdef, input_data, config_no_trt)
+ ref_result = self._RunGraph(params, input_gdef, input_data, config_no_trt,
+ GraphState.ORIGINAL)
# Run calibration if necessary.
- if _IsQuantizationMode(precision_mode):
+ if _IsQuantizationMode(run_params.precision_mode):
- calib_config = self._GetConfigProto(params, use_optimizer, precision_mode,
- dynamic_calib_engine)
+ calib_config = self._GetConfigProto(params, run_params,
+ GraphState.CALIBRATE)
logging.info("Running calibration graph, config:\n%s", str(calib_config))
- if use_optimizer:
- self.assertTrue(False)
- # TODO(aaroey): uncomment this and get infer_gdef when this mode is
- # supported.
- # result = self._RunCalibration(params, input_gdef, input_data,
- # calib_config)
+ if run_params.use_optimizer:
+ result = self._RunCalibration(params, input_gdef, input_data,
+ calib_config)
else:
- calib_gdef = self._GetTrtGraphDef(params, input_gdef, precision_mode,
- dynamic_calib_engine)
- self._VerifyGraphDef(params, calib_gdef, precision_mode, False,
- dynamic_calib_engine)
+ calib_gdef = self._GetTrtGraphDef(params, run_params, input_gdef)
+ self._VerifyGraphDef(params, run_params, calib_gdef,
+ GraphState.CALIBRATE)
result = self._RunCalibration(params, calib_gdef, input_data,
calib_config)
- infer_gdef = trt_convert.calib_graph_to_infer_graph(calib_gdef)
- self._VerifyGraphDef(params, infer_gdef, precision_mode, True,
- dynamic_calib_engine)
+ infer_gdef = trt_convert.calib_graph_to_infer_graph(calib_gdef)
+ self._VerifyGraphDef(params, run_params, infer_gdef, GraphState.INFERENCE)
self.assertAllClose(
ref_result,
@@ -228,18 +382,19 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
infer_gdef = input_gdef
# Run inference.
- infer_config = self._GetConfigProto(params, use_optimizer, precision_mode,
- dynamic_infer_engine)
+ infer_config = self._GetConfigProto(params, run_params,
+ GraphState.INFERENCE)
logging.info("Running final inference graph, config:\n%s",
str(infer_config))
- if use_optimizer:
- result = self._RunGraph(params, infer_gdef, input_data, infer_config)
+ if run_params.use_optimizer:
+ result = self._RunGraph(params, infer_gdef, input_data, infer_config,
+ GraphState.INFERENCE)
else:
- trt_infer_gdef = self._GetTrtGraphDef(params, infer_gdef, precision_mode,
- dynamic_infer_engine)
- self._VerifyGraphDef(params, trt_infer_gdef, precision_mode, True,
- dynamic_infer_engine)
- result = self._RunGraph(params, trt_infer_gdef, input_data, infer_config)
+ trt_infer_gdef = self._GetTrtGraphDef(params, run_params, infer_gdef)
+ self._VerifyGraphDef(params, run_params, trt_infer_gdef,
+ GraphState.INFERENCE)
+ result = self._RunGraph(params, trt_infer_gdef, input_data, infer_config,
+ GraphState.INFERENCE)
self.assertAllClose(
ref_result,
@@ -262,66 +417,44 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
def _AddTests(test_class):
"""Adds test methods to TfTrtIntegrationTestBase."""
- def _GetTest(use_optimizer, precision_mode, dynamic_infer_engine,
- dynamic_calib_engine):
+ def _GetTest(run_params):
"""Gets a single test method based on the parameters."""
def _Test(self):
params = self.GetParams()
logging.info(
- "Running test with parameters: use_optimizer=%s, precision_mode=%s, "
- "dynamic_infer_engine=%s, dynamic_calib_engine=%s", use_optimizer,
- precision_mode, dynamic_infer_engine, dynamic_calib_engine)
- self.RunTest(params, use_optimizer, precision_mode, dynamic_infer_engine,
- dynamic_calib_engine)
+ "Running test %s with parameters: use_optimizer=%s, "
+ "precision_mode=%s, dynamic_engine=%s",
+ "testTfTrt_" + run_params.test_name, run_params.use_optimizer,
+ run_params.precision_mode, run_params.dynamic_engine)
+ self.RunTest(params, run_params)
return _Test
use_optimizer_options = [False, True]
- dynamic_infer_engine_options = [False, True]
- dynamic_calib_engine_options = [False, True]
- for (use_optimizer, precision_mode,
- dynamic_infer_engine, dynamic_calib_engine) in itertools.product(
- use_optimizer_options, PRECISION_MODES, dynamic_infer_engine_options,
- dynamic_calib_engine_options):
+ dynamic_engine_options = [False, True]
+ for (use_optimizer, precision_mode, dynamic_engine) in itertools.product(
+ use_optimizer_options, PRECISION_MODES, dynamic_engine_options):
if _IsQuantizationMode(precision_mode):
- if not dynamic_calib_engine and dynamic_infer_engine:
- # TODO(aaroey): test this case, the conversion from static calibration
- # engine to dynamic inference engine should be a noop.
- continue
if use_optimizer:
# TODO(aaroey): if use_optimizer is True we need to get the inference
# graphdef using custom python wrapper class, which is not currently
# supported yet.
continue
- if not dynamic_calib_engine:
+ if not dynamic_engine:
# TODO(aaroey): construction of static calibration engine is not
# supported yet.
continue
- if dynamic_calib_engine and not dynamic_infer_engine:
- # TODO(aaroey): construction of static inference engine using dynamic
- # calibration engine is not supported yet.
- continue
- else: # In non int8 mode.
- if dynamic_calib_engine:
- # dynamic_calib_engine doesn't affect non-int8 modes, so just let
- # related tests run once on dynamic_calib_engine=False.
- continue
conversion = "OptimizerConversion" if use_optimizer else "ToolConversion"
- infer_engine_type = ("DynamicInferEngine"
- if dynamic_infer_engine else "StaticInferEngine")
- calib_engine_type = ""
- if precision_mode == "INT8":
- calib_engine_type = ("DynamicCalibEngine"
- if dynamic_calib_engine else "StaticCalibEngine")
- test_name = "%s_%s_%s%s" % (conversion, precision_mode, infer_engine_type,
- ("_" + calib_engine_type)
- if len(calib_engine_type) else "")
- setattr(
- test_class, "testTfTRT_" + test_name,
- _GetTest(use_optimizer, precision_mode, dynamic_infer_engine,
- dynamic_calib_engine))
+ engine_type = ("DynamicEngine" if dynamic_engine else "StaticEngine")
+ test_name = "%s_%s_%s" % (conversion, precision_mode, engine_type)
+ run_params = RunParams(
+ use_optimizer=use_optimizer,
+ precision_mode=precision_mode,
+ dynamic_engine=dynamic_engine,
+ test_name=test_name)
+ setattr(test_class, "testTfTrt_" + test_name, _GetTest(run_params))
if trt_convert.is_tensorrt_enabled():
diff --git a/tensorflow/contrib/tensorrt/test/unary_test.py b/tensorflow/contrib/tensorrt/test/unary_test.py
index b9e977cf67..500057a36d 100644
--- a/tensorflow/contrib/tensorrt/test/unary_test.py
+++ b/tensorflow/contrib/tensorrt/test/unary_test.py
@@ -100,7 +100,10 @@ class UnaryTest(trt_test.TfTrtIntegrationTestBase):
gdef=g.as_graph_def(),
input_names=[input_name, input2_name],
input_dims=[input_dims, input2_dims],
- num_expected_engines=5,
+ expected_engines=[
+ "my_trt_op_0", "my_trt_op_1", "my_trt_op_2", "my_trt_op_3",
+ "my_trt_op_4"
+ ],
expected_output_dims=(12, 5, 8, 12),
allclose_atol=1.e-03,
allclose_rtol=1.e-03)
diff --git a/tensorflow/contrib/tensorrt/test/utils.cc b/tensorflow/contrib/tensorrt/test/utils.cc
new file mode 100644
index 0000000000..276308b3a0
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/test/utils.cc
@@ -0,0 +1,101 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/tensorrt/test/utils.h"
+
+#include <unordered_map>
+#include <vector>
+
+#include "re2/re2.h"
+#include "tensorflow/core/platform/macros.h"
+
+namespace tensorflow {
+namespace tensorrt {
+namespace test {
+
+// TODO(aaroey): make this class thread-safe.
+class TestValueManager {
+ public:
+ static TestValueManager* singleton() {
+ static TestValueManager* manager = new TestValueManager();
+ return manager;
+ }
+
+ void Enable() {
+ VLOG(1) << "Enabling test value";
+ enabled_ = true;
+ }
+
+ void Add(const string& label, const string& value) {
+ if (TF_PREDICT_FALSE(enabled_)) {
+ QCHECK_NE("", value);
+ VLOG(1) << "Adding test value: " << label << " -> " << value;
+ values_.insert({label, value});
+ }
+ }
+
+ string Get(const string& label) {
+ if (TF_PREDICT_FALSE(enabled_)) {
+ VLOG(1) << "Getting test value by " << label;
+ auto itr = values_.find(label);
+ if (itr == values_.end()) return "";
+ return itr->second;
+ }
+ return "";
+ }
+
+ void Clear(const string& pattern) {
+ if (TF_PREDICT_FALSE(enabled_)) {
+ VLOG(1) << "Clearing test values";
+ if (pattern.empty()) {
+ values_.clear();
+ return;
+ }
+ std::vector<string> keys_to_clear;
+ for (const auto& kv : values_) {
+ if (RE2::FullMatch(kv.first, pattern)) {
+ keys_to_clear.push_back(kv.first);
+ }
+ }
+ for (const string& key : keys_to_clear) {
+ values_.erase(key);
+ }
+ }
+ }
+
+ private:
+ TestValueManager() : enabled_(false) {}
+
+ bool enabled_;
+ std::unordered_map<string, string> values_;
+};
+
+void EnableTestValue() { TestValueManager::singleton()->Enable(); }
+
+void ClearTestValues(const string& pattern) {
+ TestValueManager::singleton()->Clear(pattern);
+}
+
+void AddTestValue(const string& label, const string& value) {
+ TestValueManager::singleton()->Add(label, value);
+}
+
+string GetTestValue(const string& label) {
+ return TestValueManager::singleton()->Get(label);
+}
+
+} // namespace test
+} // namespace tensorrt
+} // namespace tensorflow
diff --git a/tensorflow/contrib/tensorrt/test/utils.h b/tensorflow/contrib/tensorrt/test/utils.h
new file mode 100644
index 0000000000..4bb4120206
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/test/utils.h
@@ -0,0 +1,44 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_TENSORRT_TEST_UTILS_H_
+#define TENSORFLOW_CONTRIB_TENSORRT_TEST_UTILS_H_
+
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+namespace tensorrt {
+namespace test {
+
+// Helper methods to inject values used by testing tools.
+void EnableTestValue();
+void ClearTestValues(const string& pattern);
+void AddTestValue(const string& label, const string& value);
+string GetTestValue(const string& label);
+
+#define TRT_RETURN_IF_TEST_VALUE(label, value_to_return) \
+ do { \
+ if (::tensorflow::tensorrt::test::GetTestValue(label) == \
+ value_to_return) { \
+ return errors::Internal("Injected manually"); \
+ } \
+ } while (0)
+
+} // namespace test
+} // namespace tensorrt
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CONTRIB_TENSORRT_TEST_UTILS_H_
diff --git a/tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py b/tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py
index 2b134c3bce..ab4d224db4 100644
--- a/tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py
+++ b/tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py
@@ -72,7 +72,7 @@ class VGGBlockNCHWTest(trt_test.TfTrtIntegrationTestBase):
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- num_expected_engines=1,
+ expected_engines=["my_trt_op_0"],
expected_output_dims=(5, 6, 2, 2),
allclose_atol=1.e-03,
allclose_rtol=1.e-03)
diff --git a/tensorflow/contrib/tensorrt/test/vgg_block_test.py b/tensorflow/contrib/tensorrt/test/vgg_block_test.py
index bec2f23eff..56bdf848ea 100644
--- a/tensorflow/contrib/tensorrt/test/vgg_block_test.py
+++ b/tensorflow/contrib/tensorrt/test/vgg_block_test.py
@@ -63,7 +63,7 @@ class VGGBlockTest(trt_test.TfTrtIntegrationTestBase):
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- num_expected_engines=1,
+ expected_engines=["my_trt_op_0"],
expected_output_dims=(5, 2, 2, 6),
allclose_atol=1.e-03,
allclose_rtol=1.e-03)
diff --git a/tensorflow/contrib/tensorrt/trt_conversion.i b/tensorflow/contrib/tensorrt/trt_conversion.i
index 422740fdf6..6ea15fb8ef 100644
--- a/tensorflow/contrib/tensorrt/trt_conversion.i
+++ b/tensorflow/contrib/tensorrt/trt_conversion.i
@@ -101,82 +101,22 @@ _LIST_OUTPUT_TYPEMAP(int, PyLong_FromLong);
#include "tensorflow/core/util/stat_summarizer.h"
#include "tensorflow/contrib/tensorrt/convert/convert_graph.h"
#include "tensorflow/contrib/tensorrt/convert/utils.h"
+#include "tensorflow/contrib/tensorrt/test/utils.h"
%}
%ignoreall
%unignore tensorflow;
-%unignore trt_convert;
%unignore calib_convert;
%unignore get_linked_tensorrt_version;
%unignore get_loaded_tensorrt_version;
%unignore is_tensorrt_enabled;
+%unignore enable_test_value;
+%unignore clear_test_values;
+%unignore add_test_value;
+%unignore get_test_value;
%{
-std::pair<string, string> trt_convert(
- string graph_def_string, // The serialized GraphDef string.
- std::vector<string> output_names,
- size_t max_batch_size,
- size_t max_workspace_size_bytes,
- int precision_mode,
- int minimum_segment_size,
- bool is_dyn_op,
- int max_cached_engines,
- std::vector<int> cached_engine_batches
- // Unfortunately we can't use TF_Status here since it
- // is in c/c_api and brings in a lot of other libraries
- // which in turn declare ops. These ops are included
- // statically in our library and cause an abort when
- // module is loaded due to double registration
- // until Tensorflow properly exposes these headers
- // we have to work around this by returning a string
- // and converting it to exception on python side.
- //,TF_Status* out_status) {
-) {
-#if GOOGLE_CUDA && GOOGLE_TENSORRT
- string out_status;
-
- tensorflow::GraphDef graph_def;
- if (!graph_def.ParseFromString(graph_def_string)) {
- out_status = "InvalidArgument;Couldn't interpret input as a GraphDef";
- return std::pair<string, string>{out_status, ""};
- }
-
- if (precision_mode < 0 || precision_mode > 2) {
- out_status = "InvalidArgument;Invalid precision_mode";
- return std::pair<string, string>{out_status, ""};
- }
- if (!output_names.size()) {
- out_status = "InvalidArgument;Size of the output_names vector is 0";
- return std::pair<string, string>{out_status, ""};
- }
- tensorflow::GraphDef out_graph;
- tensorflow::Status conversion_status =
- tensorflow::tensorrt::convert::ConvertGraphDefToTensorRT(
- graph_def, output_names, max_batch_size, max_workspace_size_bytes,
- &out_graph, precision_mode, minimum_segment_size,
- is_dyn_op, max_cached_engines, cached_engine_batches);
- if (!conversion_status.ok()) {
- auto retCode = (int)conversion_status.code();
- char buff[2000];
- snprintf(buff, 2000, "%d;%s", retCode,
- conversion_status.error_message().c_str());
- out_status = buff;
- return std::pair<string, string>{out_status, ""};
- }
- string result;
- if (!out_graph.SerializeToString(&result)) {
- out_status = "InvalidArgument;Couldn't serialize output as a GraphDef";
- return std::pair<string, string>{out_status, ""};
- }
- out_status = "OK;All good!";
- return std::pair<string, string>{out_status, result};
-#else
- // Returns FAILED_PRECONDITION.
- return std::pair<string, string>{"9;TensorRT is not enabled!", ""};
-#endif // GOOGLE_CUDA && GOOGLE_TENSORRT
-}
-
std::pair<string, string> calib_convert(
string graph_def_string, bool is_dyn_op
// unfortunately we can't use TF_Status here since it
@@ -251,20 +191,44 @@ bool is_tensorrt_enabled() {
return tensorflow::tensorrt::IsGoogleTensorRTEnabled();
}
-%}
+void enable_test_value() {
+ tensorflow::tensorrt::test::EnableTestValue();
+}
+
+#if PY_MAJOR_VERSION < 3
+#define TRT_PY_TO_CPP_STRING PyString_AsString
+#define TRT_CPP_TO_PY_STRING PyString_FromString
+#else
+#define TRT_PY_TO_CPP_STRING PyUnicode_AsUTF8
+#define TRT_CPP_TO_PY_STRING PyUnicode_FromString
+#endif
+
+void clear_test_values(PyObject* pattern) {
+ tensorflow::tensorrt::test::ClearTestValues(
+ string(TRT_PY_TO_CPP_STRING(pattern)));
+}
+
+void add_test_value(PyObject* label, PyObject* value) {
+ tensorflow::tensorrt::test::AddTestValue(
+ string(TRT_PY_TO_CPP_STRING(label)), string(TRT_PY_TO_CPP_STRING(value)));
+}
-std::pair<string, string> calib_convert(string graph_def_string, bool is_dyn_op);
+PyObject* get_test_value(PyObject* label) {
+ string value = tensorflow::tensorrt::test::GetTestValue(
+ string(TRT_PY_TO_CPP_STRING(label)));
+ return TRT_CPP_TO_PY_STRING(value.c_str());
+}
-std::pair<string, string> trt_convert(string graph_def_string,
- std::vector<string> output_names,
- size_t max_batch_size,
- size_t max_workspace_size_bytes,
- int precision_mode, int minimum_segment_size,
- bool is_dyn_op,
- int max_cached_engines,
- std::vector<int> cached_engine_batches);
+%}
+
+std::pair<string, string> calib_convert(
+ string graph_def_string, bool is_dyn_op);
version_struct get_linked_tensorrt_version();
version_struct get_loaded_tensorrt_version();
bool is_tensorrt_enabled();
+void enable_test_value();
+void clear_test_values(PyObject* pattern);
+void add_test_value(PyObject* label, PyObject* value);
+PyObject* get_test_value(PyObject* label);
%unignoreall