aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensorrt
diff options
context:
space:
mode:
authorGravatar gracehoney <31743510+aaroey@users.noreply.github.com>2018-07-30 00:27:58 -0700
committerGravatar gracehoney <31743510+aaroey@users.noreply.github.com>2018-07-30 00:27:58 -0700
commit1009f9de414365d0f2401c51b6e023374ad11ad6 (patch)
tree81b51f02196562a9f9f7e4075643c972aff93a0f /tensorflow/contrib/tensorrt
parentb1e7f284443b6e0220ffd1d5ba728340c768649f (diff)
Fix control dependency problems and add corresponding tests.
Diffstat (limited to 'tensorflow/contrib/tensorrt')
-rw-r--r--tensorflow/contrib/tensorrt/BUILD14
-rw-r--r--tensorflow/contrib/tensorrt/convert/convert_graph.cc506
-rw-r--r--tensorflow/contrib/tensorrt/convert/convert_nodes.cc41
-rw-r--r--tensorflow/contrib/tensorrt/convert/convert_nodes.h42
-rw-r--r--tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc17
-rw-r--r--tensorflow/contrib/tensorrt/python/__init__.py4
-rw-r--r--tensorflow/contrib/tensorrt/python/trt_convert.py4
-rw-r--r--tensorflow/contrib/tensorrt/test/base_test.py144
-rw-r--r--tensorflow/contrib/tensorrt/test/batch_matmul_test.py2
-rw-r--r--tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py5
-rw-r--r--tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py19
-rw-r--r--tensorflow/contrib/tensorrt/test/concatenation_test.py2
-rw-r--r--tensorflow/contrib/tensorrt/test/const_broadcast_test.py2
-rw-r--r--tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py2
-rw-r--r--tensorflow/contrib/tensorrt/test/neighboring_engine_test.py2
-rw-r--r--tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py349
-rw-r--r--tensorflow/contrib/tensorrt/test/unary_test.py5
-rw-r--r--tensorflow/contrib/tensorrt/test/utils.cc101
-rw-r--r--tensorflow/contrib/tensorrt/test/utils.h43
-rw-r--r--tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py2
-rw-r--r--tensorflow/contrib/tensorrt/test/vgg_block_test.py2
-rw-r--r--tensorflow/contrib/tensorrt/trt_conversion.i25
22 files changed, 876 insertions, 457 deletions
diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD
index 033d5207f6..a1071d6749 100644
--- a/tensorflow/contrib/tensorrt/BUILD
+++ b/tensorflow/contrib/tensorrt/BUILD
@@ -85,11 +85,12 @@ cc_library(
copts = tf_copts(),
visibility = ["//visibility:public"],
deps = [
+ ":test_utils",
":trt_allocator",
+ ":trt_conversion",
":trt_logging",
":trt_plugins",
":trt_resources",
- ":trt_conversion",
":utils",
"//tensorflow/core:gpu_headers_lib",
"//tensorflow/core:lib_proto_parsing",
@@ -192,6 +193,7 @@ tf_py_wrap_cc(
"//tensorflow/python:platform/base.i",
],
deps = [
+ ":test_utils",
":trt_conversion",
":trt_engine_op_kernel",
"//third_party/python_runtime:headers",
@@ -264,6 +266,7 @@ tf_cuda_library(
],
deps = [
":segment",
+ ":test_utils",
":trt_allocator",
":trt_plugins",
":trt_logging",
@@ -412,3 +415,12 @@ cc_library(
hdrs = ["convert/utils.h"],
copts = tf_copts(),
)
+
+cc_library(
+ name = "test_utils",
+ srcs = ["test/utils.cc"],
+ hdrs = ["test/utils.h"],
+ deps = [
+ "//tensorflow/core:lib",
+ ],
+)
diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
index 22909a199d..1e6300578d 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include <map>
#include <set>
#include <unordered_map>
+#include <unordered_set>
#include <utility>
#include <vector>
@@ -29,6 +30,7 @@ limitations under the License.
#include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h"
#include "tensorflow/contrib/tensorrt/resources/trt_resources.h"
#include "tensorflow/contrib/tensorrt/segment/segment.h"
+#include "tensorflow/contrib/tensorrt/test/utils.h"
#include "tensorflow/core/common_runtime/gpu/gpu_id.h"
#include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h"
#include "tensorflow/core/common_runtime/gpu/gpu_process_state.h"
@@ -49,9 +51,9 @@ limitations under the License.
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/protobuf/config.pb.h" // NOLINT
+#include "tensorflow/core/protobuf/config.pb.h" // NOLINT
#include "tensorflow/core/protobuf/device_properties.pb.h" // NOLINT
-#include "tensorflow/core/protobuf/rewriter_config.pb.h" // NOLINT
+#include "tensorflow/core/protobuf/rewriter_config.pb.h" // NOLINT
#include "tensorflow/core/util/device_name_utils.h"
#if GOOGLE_CUDA
@@ -260,63 +262,6 @@ tensorflow::Status ConvertGraphDefToTensorRT(
return ConvertAfterShapes(cp);
}
-bool IsUniformTensorValue(const tensorflow::TensorProto& tensor) {
- using tensorflow::DataType;
- switch (tensor.dtype()) {
- case DataType::DT_HALF: // fall-through
- case DataType::DT_BFLOAT16:
- return tensor.half_val_size() == 1;
- case DataType::DT_FLOAT:
- return tensor.float_val_size() == 1;
- case DataType::DT_DOUBLE:
- return tensor.double_val_size() == 1;
- case DataType::DT_INT32: // fall-through
- case DataType::DT_INT16: // fall-through
- case DataType::DT_INT8: // fall-through
- case DataType::DT_UINT8:
- return tensor.int_val_size() == 1;
- case DataType::DT_STRING:
- return tensor.string_val_size() == 1;
- case DataType::DT_COMPLEX64:
- return tensor.scomplex_val_size() == 1;
- case DataType::DT_INT64:
- return tensor.int64_val_size() == 1;
- case DataType::DT_BOOL:
- return tensor.bool_val_size() == 1;
- case DataType::DT_COMPLEX128:
- return tensor.dcomplex_val_size() == 1;
- case DataType::DT_RESOURCE:
- return tensor.resource_handle_val_size() == 1;
- case DataType::DT_VARIANT:
- return tensor.variant_val_size() == 1;
- case DataType::DT_UINT32:
- return tensor.uint32_val_size() == 1;
- case DataType::DT_UINT64:
- return tensor.uint64_val_size() == 1;
- default:
- return false;
- }
-}
-
-std::unordered_set<int> GetAttributeInputs(const tensorflow::Node* node) {
- typedef std::unordered_map<string, std::unordered_set<int>> InputMap;
- static const InputMap attribute_inputs = {
- {"Concat", {0}}, {"ConcatV2", {-1}}, {"Reshape", {1}}};
- auto iter = attribute_inputs.find(node->type_string());
- if (iter != attribute_inputs.end()) {
- // Apply reverse indexing
- std::unordered_set<int> result;
- for (int idx : iter->second) {
- if (idx < 0) {
- idx += node->num_inputs();
- }
- result.insert(idx);
- }
- return result;
- }
- return {};
-}
-
// Function to get subsegment information structure.
tensorflow::Status GetEngineInfo(
const tensorflow::Graph* g,
@@ -325,13 +270,10 @@ tensorflow::Status GetEngineInfo(
const std::unordered_map<string, tensorflow::Node*>& node_map,
const std::vector<tensorflow::Node*>& reverse_topo_order,
EngineInfo* info) {
- std::vector<int> subgraph_node_ids;
+ std::vector<int> subgraph_node_ids; // Topologically sorted node ids.
+ std::set<string> subgraph_node_names = segment_nodes;
std::set<int> added_const_node_ids; // Used to prevent double insertion.
std::set<string> segment_devices;
- std::unordered_set<string> segment_consts;
- std::vector<int> const_node_ids;
- int input_port = 0;
- int output_port = 0;
// Map from src_node_name+port to the unique port numbers of the TRT op, where
// the src_node_name is the name of the source node of the input/output
@@ -339,7 +281,7 @@ tensorflow::Status GetEngineInfo(
// input/output edges must be in different split of the graph.
// TODO(aaroey): consider using node id and port instead.
// TODO(aaroey): using topo order instead of reverting reverse topo order.
- std::unordered_map<string, int> created_edges;
+ std::unordered_map<string, int> input_to_engine_port, output_to_engine_port;
for (auto it = reverse_topo_order.rbegin(); it != reverse_topo_order.rend();
++it) {
const auto& node_name = (*it)->name();
@@ -358,133 +300,114 @@ tensorflow::Status GetEngineInfo(
}
const int node_id = node->id();
subgraph_node_ids.push_back(node_id);
+ // Create input connections.
for (const auto edge : node->in_edges()) {
auto input_node = edge->src();
- if (input_node->IsSource()) continue;
- if (segment_nodes.count(input_node->name()) == 0) {
- // Add constant input node into the segment. We don't care if it has
- // other output edges going into other engines or TF nodes. Since we add
- // it only to the subsegment node list, not the subsegment itself, it
- // won't be removed from the graph. If it doesn't have any edges, TF
- // will prune it out.
- if (input_node->type_string() == "Const") {
- bool is_supported = input_node->output_type(0) == DT_FLOAT ||
- input_node->output_type(0) == DT_INT32;
- bool is_attribute_input =
- GetAttributeInputs(node).count(edge->dst_input()) != 0;
- const tensorflow::TensorProto& tensor_proto =
- input_node->def().attr().at("value").tensor();
- bool is_uniform = IsUniformTensorValue(tensor_proto);
-
- // Const can be absorbed
- if (is_supported && is_attribute_input && is_uniform) {
- if (segment_consts.count(input_node->name()) != 0) {
- continue; // skip if already added
- }
- VLOG(0) << "Adding const node " << input_node->name();
- const_node_ids.push_back(input_node->id());
- segment_consts.insert(input_node->name());
- int conn_count = 0;
- for (auto cinp_e :
- input_node->in_edges()) { // must be Control edges
- if (!cinp_e->IsControlEdge() || cinp_e->src()->IsSource()) {
- conn_count++;
- continue;
- }
- VLOG(0) << info->engine_name << ": Control edge " << conn_count
- << " from node " << input_node->name()
- << " edge= " << cinp_e->src()->name();
- auto cinp = cinp_e->src();
- EngineConnection ec(cinp->name(), cinp->id(),
- cinp_e->src_output(), input_node->name(),
- input_node->id(), cinp_e->dst_input(), true,
- -1, true);
- info->connections.emplace_back(std::move(ec));
- }
- continue;
- }
+ if (input_node->IsSource() || segment_nodes.count(input_node->name())) {
+ continue;
+ }
+ if (edge->IsControlEdge()) {
+ // Control input.
+ info->connections.emplace_back(
+ input_node->name(), input_node->id(), node_name, node_id,
+ /*input_edge=*/true);
+ } else if (input_node->type_string() == "Const") {
+ // Add constant data input nodes into the segment graphdef (thus also in
+ // the engine). We don't care if it has other output edges going into
+ // other engines or TF nodes. Since we add it only to the segment
+ // graphdef, not the segment itself, it won't be removed from the graph.
+ // If it doesn't have any edges, TF will prune it out.
+ // Note that the constant data input must be supported by the engine
+ // regardless of the datatype, since the segmenter already removed
+ // unsupported data input nodes.
+ if (!added_const_node_ids.insert(input_node->id()).second) {
+ // Already added before.
+ continue;
}
-
- // Non-const data/control edge
- if (!edge->IsControlEdge()) {
- string s(input_node->name());
- StrAppend(&s, ":", edge->src_output());
- VLOG(1) << "Input edge = " << s;
- int port = input_port;
- if (created_edges.count(s)) {
- port = created_edges.at(s);
- } else {
- created_edges.insert({s, port});
- input_port++;
- }
- EngineConnection ec(input_node->name(), input_node->id(),
- edge->src_output(), node_name, node_id,
- edge->dst_input(), true, port);
- ec.connection_type = input_node->output_type(edge->src_output());
- info->connections.emplace_back(std::move(ec));
+ VLOG(1) << "Adding const node " << input_node->name();
+ QCHECK(subgraph_node_names.insert(input_node->name()).second);
+#if 1
+ // Since we duplicate the const input node in both the segment graphdef
+ // and the engine, the segment node doesn't depend on it anymore, so we
+ // add a control dependency instead.
+ info->connections.emplace_back(
+ input_node->name(), input_node->id(), node_name, node_id,
+ /*input_edge=*/true);
+#else
+ // Add control inputs to the const node as control input connections to
+ // the engine.
+ for (const auto const_in_edge : input_node->in_edges()) {
+ QCHECK(const_in_edge->IsControlEdge()); // Must be control edge.
+ auto const_in_node = const_in_edge->src();
+ QCHECK(!segment_nodes.count(const_in_node->name()))
+ << "Loop found between segment and non-segment nodes, from "
+ "segment node "
+ << const_in_node->name() << " to non-segment node "
+ << input_node->name() << " to segment node " << node->name();
+ if (const_in_node->IsSource()) continue;
+ VLOG(1) << "Control edge from node " << const_in_node->name()
+ << " to " << input_node->name();
+ info->connections.emplace_back(
+ const_in_node->name(), const_in_node->id(), input_node->name(),
+ input_node->id(), /*input_edge=*/true);
+ }
+#endif
+ } else {
+ // Non-const data input.
+ int port = Graph::kControlSlot - 1;
+ // Use the source non-segment node name/port as key.
+ const string s = StrCat(input_node->name(), ":", edge->src_output());
+ VLOG(1) << "Input edge = " << s;
+ if (input_to_engine_port.count(s)) {
+ port = input_to_engine_port.at(s);
} else {
- EngineConnection ec(input_node->name(), input_node->id(),
- edge->src_output(), node_name, node_id,
- edge->dst_input(), true, -1, true);
- ec.connection_type = input_node->output_type(edge->src_output());
- info->connections.emplace_back(std::move(ec));
+ port = input_to_engine_port.size();
+ input_to_engine_port.insert({s, port});
}
+ info->connections.emplace_back(input_node->name(), input_node->id(),
+ edge->src_output(), node_name, node_id,
+ edge->dst_input(), /*input_edge=*/true,
+ port);
}
}
-
+ // Create output connections.
for (const auto edge : node->out_edges()) {
auto output_node = edge->dst();
- if (output_node->IsSink()) continue;
- if (segment_nodes.count(output_node->name()) == 0) {
- if (!edge->IsControlEdge()) {
- string s(node_name);
- StrAppend(&s, ":", edge->src_output());
- VLOG(1) << "Output edge = " << s;
- int port = output_port;
- if (created_edges.count(s)) {
- port = created_edges.at(s);
- } else {
- created_edges.insert({s, port});
- output_port++;
- }
- info->connections.emplace_back(output_node->name(), output_node->id(),
- edge->dst_input(), node_name, node_id,
- edge->src_output(), false, port);
- } else {
- info->connections.emplace_back(output_node->name(), output_node->id(),
- edge->dst_input(), node_name, node_id,
- edge->src_output(), false, -1, true);
- }
+ if (output_node->IsSink() || segment_nodes.count(output_node->name())) {
+ continue;
}
- }
- }
-
- // Fix control edges
- for (size_t t = 0; t < info->connections.size(); t++) {
- auto& conn = info->connections.at(t);
- if (conn.is_control_edge) {
- for (size_t k = 0; k < info->connections.size(); k++) {
- if (k == t) continue;
- const auto& other = info->connections.at(k);
- if (conn.outside_id == other.outside_id && other.port_number != -1) {
- VLOG(0) << "Updating control edge " << conn.outside_node_name
- << " -> " << conn.inside_node_name << " to input port "
- << other.port_number;
- conn.port_number = other.port_number;
- break;
+ if (edge->IsControlEdge()) {
+ // Control output.
+ info->connections.emplace_back(
+ output_node->name(), output_node->id(), node_name, node_id,
+ /*input_edge=*/false);
+ } else {
+ // Data output.
+ int port = Graph::kControlSlot - 1;
+ // Use the source segment node name/port as key.
+ const string s = StrCat(node_name, ":", edge->src_output());
+ VLOG(1) << "Output edge = " << s;
+ if (output_to_engine_port.count(s)) {
+ port = output_to_engine_port.at(s);
+ } else {
+ port = output_to_engine_port.size();
+ output_to_engine_port.insert({s, port});
}
+ info->connections.emplace_back(output_node->name(), output_node->id(),
+ edge->dst_input(), node_name, node_id,
+ edge->src_output(), /*input_edge=*/false,
+ port);
}
}
- }
+ } // For each segment node in topological order.
- // Construct the const nodes first
- subgraph_node_ids.insert(subgraph_node_ids.begin(), const_node_ids.begin(),
- const_node_ids.end());
+ // Construct the const nodes first.
+ subgraph_node_ids.insert(subgraph_node_ids.begin(),
+ added_const_node_ids.begin(),
+ added_const_node_ids.end());
TF_RETURN_IF_ERROR(ConvertSegmentToGraphDef(
- g, graph_properties, subgraph_node_ids, &info->connections,
- &info->segment_graph_def, &info->engine_name));
- info->engine_type = EngineInfo::EngineType::TRTStatic;
-
+ g, graph_properties, subgraph_node_names, subgraph_node_ids,
+ &info->connections, &info->segment_graph_def, &info->engine_name));
// TODO(sami): This should not happen once segmenter is updated.
if (segment_devices.size() == 1) {
info->device = *segment_devices.begin();
@@ -502,36 +425,34 @@ tensorflow::Status GetEngineInfo(
// Helper function to update edge connection from the removed node to the
// engine node. If an outside node is gone, it must have been absorbed into
// an engine node. Find the engine node.
-void UpdateToEngineNode(tensorflow::Node*& node, string& node_name, int& port,
- const std::vector<EngineInfo>& infos,
- size_t my_engine_id,
+void UpdateToEngineNode(const std::vector<EngineInfo>& infos,
+ const size_t my_engine_id,
const std::vector<Node*>& engine_nodes,
- bool update_input_edge) {
- bool found_engine = false;
+ const bool is_input_edge,
+ const string& node_name,
+ tensorflow::Node** node, int* port) {
for (size_t t = 0; t < infos.size(); ++t) {
if (t == my_engine_id) {
continue;
}
- auto& connected_eng_info = infos.at(t);
- for (const auto& eng_conn : connected_eng_info.connections) {
- if (update_input_edge && eng_conn.is_input_edge) {
- continue;
- } else if (!update_input_edge && !eng_conn.is_input_edge) {
- continue;
- }
+ const auto& info = infos.at(t);
+ for (const auto& eng_conn : info.connections) {
+ // If the connection being updated is an input connection, the source of
+ // the connection must be an output connection of another engine. And vice
+ // versa.
+ if (is_input_edge == eng_conn.is_input_edge) continue;
if (eng_conn.inside_node_name == node_name &&
- eng_conn.inside_port == port) {
- node = engine_nodes[t];
- node_name = connected_eng_info.engine_name;
- port = eng_conn.port_number;
- found_engine = true;
- break;
+ eng_conn.inside_port == *port) {
+ *node = CHECK_NOTNULL(engine_nodes[t]);
+ QCHECK_EQ(info.engine_name, (**node).name())
+ << "Engine name mismatch: " << info.engine_name << " vs "
+ << (**node).name();
+ *port = eng_conn.port_number;
+ return;
}
}
- if (found_engine) break;
}
- CHECK(found_engine);
- CHECK(node != nullptr);
+ LOG(FATAL) << "Node " << (**node).name() << " not found in any engine.";
}
// Function to insert a TRT engine node into the graph.
@@ -539,114 +460,91 @@ void UpdateToEngineNode(tensorflow::Node*& node, string& node_name, int& port,
// 1. Each invocation of CreateTRTNode creates an engine node for infos[pos]
// 2. When an engine node is created, add it into the graph with necessary
// re-wiring.
-// 2.1. If the outside connected node is existing, connect the engine
-// node to it.
-// 2.2. If the outside connected node is gone, it must have been absorted
-// into another engine node (which was processed before the processing
-// one). Connect to the pre-existing engine node instead.
+// 2.1. If the outside connected node is existing, connect the engine
+// node to it.
+// 2.2. If the outside connected node is gone, it must have been absorbed
+// into another engine node (which was processed before the processing
+// one). Connect to the pre-existing engine node instead.
// 3. In this way, we ensure the graph is topologically sort-able after each
// invocation of CreateTRTNode().
-
-tensorflow::Status CreateTRTNode(tensorflow::Graph* graph,
- const std::vector<EngineInfo>& infos, int pos,
- tensorflow::Allocator* alloc,
- int max_batch_size,
- std::vector<Node*>& engine_nodes) {
- auto& info = infos.at(pos);
+tensorflow::Status CreateTRTNode(const std::vector<EngineInfo>& infos, int pos,
+ int max_batch_size, tensorflow::Graph* graph,
+ nvinfer1::IGpuAllocator* alloc,
+ std::vector<Node*>* engine_nodes) {
+ const auto& info = infos.at(pos);
+ TRT_RETURN_IF_TEST_VALUE(StrCat(info.engine_name, ":CreateTRTNode"), "fail");
std::vector<tensorflow::TensorShapeProto> output_shape_protos;
std::vector<tensorflow::TensorShapeProto> input_shape_protos;
- std::vector<tensorflow::PartialTensorShape> shapes;
+ std::vector<tensorflow::PartialTensorShape> input_shapes;
std::vector<tensorflow::NodeDefBuilder::NodeOut> inputs;
std::vector<tensorflow::Node*> input_nodes;
std::vector<tensorflow::Node*> control_input_nodes;
- std::vector<string> control_input_names;
+ std::unordered_set<string> control_input_names;
std::vector<tensorflow::DataType> out_types;
VLOG(1) << "Processing " << info.engine_name;
-
- // -- Preprocessing -- //
- // collect needed info for creating the engine node in the graph
- for (const auto conn : info.connections) {
- // control edges
- if (conn.is_control_edge) {
- // skip control outputs for now. control output info are not needed for
+ // Collect needed info for creating the engine node in the graph
+ for (const auto& conn : info.connections) {
+ // Control edges
+ if (conn.is_control_edge()) {
+ // Skip control outputs for now. Control output info is not needed for
// node creation and will be processed later.
- if (!conn.is_input_edge) {
- continue;
- }
+ if (!conn.is_input_edge) continue;
- // control inputs
+ // Rewire control input if it's not found in original graph.
tensorflow::Node* input_node = graph->FindNodeId(conn.outside_id);
- string input_node_name = conn.outside_node_name;
int port = tensorflow::Graph::kControlSlot;
if (!input_node) {
- UpdateToEngineNode(input_node, input_node_name, port, infos, pos,
- engine_nodes, true);
- }
- bool new_input = true;
- for (const auto& name : control_input_names) {
- if (name == input_node_name) {
- new_input = false;
- break;
- }
+ UpdateToEngineNode(infos, pos, *engine_nodes, /*is_input_edge=*/true,
+ conn.outside_node_name, &input_node, &port);
+ QCHECK_EQ(Graph::kControlSlot, port);
}
- if (new_input) {
- control_input_nodes.push_back(input_node);
- control_input_names.push_back(input_node_name);
-
- VLOG(1) << "Engine Control Input " << input_node_name << ":" << port
- << " -> " << info.engine_name << ":"
- << tensorflow::Graph::kControlSlot;
+ if (!control_input_names.insert(input_node->name()).second) {
+ continue;
}
-
- // data edges
+ control_input_nodes.push_back(input_node);
+ VLOG(1) << "Engine Control Input " << input_node->name()
+ << " -> " << info.engine_name;
} else {
- // data outputs
+ // Data edges
if (!conn.is_input_edge) {
+ // Set the shapes and data types of output edge.
tensorflow::TensorShapeProto out_shape;
- conn.inside_shape.AsProto(
- &out_shape); // shape of the output node inside segment
+ // shape of the output node inside segment
+ conn.inside_shape.AsProto(&out_shape);
if (output_shape_protos.size() <= conn.port_number) {
output_shape_protos.resize(conn.port_number + 1);
out_types.resize(conn.port_number + 1);
}
output_shape_protos.at(conn.port_number) = out_shape;
out_types.at(conn.port_number) = conn.connection_type;
-
- // data input
} else {
+ // Set the shapes and data types of input edge.
tensorflow::TensorShapeProto in_shape;
conn.outside_shape.AsProto(&in_shape);
-
if (input_shape_protos.size() <= conn.port_number) {
input_shape_protos.resize(conn.port_number + 1);
- shapes.resize(conn.port_number + 1);
+ input_shapes.resize(conn.port_number + 1);
}
input_shape_protos.at(conn.port_number) = in_shape;
- shapes.at(conn.port_number) = conn.outside_shape;
+ input_shapes.at(conn.port_number) = conn.outside_shape;
+ // Rewire data input if it's not found in original graph.
tensorflow::Node* input_node = graph->FindNodeId(conn.outside_id);
- string input_node_name = conn.outside_node_name;
- int input_port = conn.outside_port;
- auto dtype = conn.connection_type;
-
+ int port = conn.outside_port;
if (!input_node) {
- UpdateToEngineNode(input_node, input_node_name, input_port, infos,
- pos, engine_nodes, true);
- }
- bool new_input = true;
- for (const auto& inp : inputs) {
- if (inp.node == input_node_name && inp.index == input_port) {
- new_input = false;
- break;
- }
+ UpdateToEngineNode(infos, pos, *engine_nodes, /*is_input_edge=*/true,
+ conn.outside_node_name, &input_node, &port);
}
- if (new_input) {
- inputs.emplace_back(input_node_name, input_port, dtype);
- CHECK(input_node != nullptr);
- input_nodes.push_back(input_node);
-
- VLOG(1) << "Engine Input " << input_node_name << ":" << input_port
+ if (std::find_if(std::begin(inputs), std::end(inputs),
+ [input_node, &port](
+ const NodeDefBuilder::NodeOut& inp) {
+ return inp.node == input_node->name() &&
+ inp.index == port;
+ }) == std::end(inputs)) {
+ inputs.emplace_back(input_node->name(), port, conn.connection_type);
+ input_nodes.push_back(CHECK_NOTNULL(input_node));
+ VLOG(1) << "Engine Input " << input_node->name() << ":" << port
<< " -> " << info.engine_name << ":" << inputs.size() - 1;
}
}
@@ -662,14 +560,12 @@ tensorflow::Status CreateTRTNode(tensorflow::Graph* graph,
// Otherwise we skip node creation for this engine.
Logger trt_logger;
TrtUniquePtrType<nvinfer1::ICudaEngine> engine;
- std::unique_ptr<TRTDeviceAllocator> allocator(
- new TRTDeviceAllocator(alloc));
// TODO(sami): What happens if 1st dim is not batch?
TF_RETURN_IF_ERROR(ConvertGraphDefToEngine(
info.segment_graph_def,
info.precision_mode == INT8MODE ? FP32MODE : info.precision_mode,
- max_batch_size, info.max_workspace_size_bytes, shapes, &trt_logger,
- allocator.get(), /*calibrator=*/nullptr, &engine,
+ max_batch_size, info.max_workspace_size_bytes, input_shapes,
+ &trt_logger, alloc, /*calibrator=*/nullptr, &engine,
/*convert_successfully=*/nullptr));
TrtUniquePtrType<nvinfer1::IHostMemory> engine_data(engine->serialize());
segment_string =
@@ -711,7 +607,7 @@ tensorflow::Status CreateTRTNode(tensorflow::Graph* graph,
VLOG(1) << ins;
}
node_builder.Input(inputs);
- for (auto& c : control_input_names) {
+ for (const string& c : control_input_names) {
node_builder.ControlInput(c);
}
@@ -744,54 +640,50 @@ tensorflow::Status CreateTRTNode(tensorflow::Graph* graph,
// Up until this point, graph is not modified. If we return !status.ok() from
// here, this segment will be skipped
tensorflow::Node* engine_node = graph->AddNode(trt_node, &status);
- engine_nodes[pos] = engine_node;
+ (*engine_nodes)[pos] = engine_node;
if (!status.ok()) {
LOG(ERROR) << "Adding node failed " << status;
return status;
}
- // input edges of the engine node
- for (auto in : control_input_nodes) {
+ // Add control input and input edges to the engine node.
+ for (const auto in : control_input_nodes) {
VLOG(1) << "Connecting control edge from " << in->name() << " to "
<< engine_node->name();
graph->AddControlEdge(in, engine_node);
}
- int idx = 0;
VLOG(1) << "input_nodes size = " << input_nodes.size();
- for (auto in : inputs) {
- Node* n = input_nodes[idx];
- CHECK(n != nullptr);
+ for (int i = 0; i < input_nodes.size(); ++i) {
+ Node* n = input_nodes[i];
+ const auto& in = inputs[i];
+ CHECK_NOTNULL(n);
VLOG(1) << "Connecting data edge from " << n->name() << ":" << in.index
- << " to " << engine_node->name() << ":" << idx;
- graph->AddEdge(n, in.index, engine_node, idx++);
+ << " to " << engine_node->name() << ":" << i;
+ graph->AddEdge(n, in.index, engine_node, i);
}
+
// Updates the inputs of output edges destination nodes, and point them to the
// engine node.
-
for (auto& conn : info.connections) {
if (conn.is_input_edge) {
continue;
}
-
- string out_name = conn.outside_node_name;
- auto out_node = graph->FindNodeId(conn.outside_id);
- int out_port = conn.outside_port;
-
- if (!out_node) {
- UpdateToEngineNode(out_node, out_name, out_port, infos, pos, engine_nodes,
- false);
+ tensorflow::Node* output_node = graph->FindNodeId(conn.outside_id);
+ int port = conn.outside_port;
+ if (!output_node) {
+ UpdateToEngineNode(infos, pos, *engine_nodes, /*is_input_edge=*/false,
+ conn.outside_node_name, &output_node, &port);
}
-
VLOG(1) << "Updating " << engine_node->name() << ":" << conn.port_number
- << " to " << out_node->name() << ":" << out_port;
-
- if (conn.is_control_edge) {
- graph->AddControlEdge(engine_node, out_node);
+ << " to " << output_node->name() << ":" << port;
+ if (conn.is_control_edge()) {
+ QCHECK_EQ(Graph::kControlSlot, port);
+ graph->AddControlEdge(engine_node, output_node);
} else {
auto new_edge =
- graph->AddEdge(engine_node, conn.port_number, out_node, out_port);
- CHECK(new_edge) << "Adding a new edge failed " << engine_node->name()
- << ":" << conn.port_number << " -> " << out_node->name()
- << ":" << conn.outside_port;
+ graph->AddEdge(engine_node, conn.port_number, output_node, port);
+ QCHECK(new_edge) << "Adding a new edge failed " << engine_node->name()
+ << ":" << conn.port_number << " -> "
+ << output_node->name() << ":" << conn.outside_port;
}
}
return status;
@@ -1077,19 +969,21 @@ tensorflow::Status ConvertAfterShapes(ConversionParams& params) {
LOG(WARNING) << "Can't identify the cuda device. Running on device 0 ";
}
cudaSetDevice(cuda_device_id);
- auto status = CreateTRTNode(&graph, engine_segments, i, device_alloc.second,
- params.max_batch_size, engine_nodes);
+ auto status = CreateTRTNode(engine_segments, i, params.max_batch_size,
+ &graph, alloc.get(), &engine_nodes);
// If status is ok, we successfully added the node to the graph and can
// remove segment ops. Otherwise graph is not modified.
+ const string msg = StrCat(
+ "Engine ", engine.engine_name, " creation for segment ", i,
+ ", composed of ", converted_segments.at(i).first.size(), " nodes");
if (status.ok()) {
+ LOG(INFO) << msg << " succeeded.";
for (auto node_name : converted_segments.at(i).first) {
graph.RemoveNode(node_map.at(node_name));
}
} else {
// Graph is not modified.
- LOG(WARNING) << "Engine creation for segment " << i << ", composed of "
- << converted_segments.at(i).first.size()
- << " nodes failed: " << status << ". Skipping...";
+ LOG(WARNING) << msg << " failed: " << status << ". Skipping...";
}
}
cudaSetDevice(old_cuda_device);
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
index 451d6fe698..3b0ac43061 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include <memory>
#include <set>
#include <unordered_map>
+#include <unordered_set>
#include <utility>
#include <vector>
@@ -2788,6 +2789,7 @@ tensorflow::Status ConvertGraphDefToEngine(
tensorflow::Status ConvertSegmentToGraphDef(
const tensorflow::Graph* graph,
const tensorflow::grappler::GraphProperties& graph_properties,
+ const std::set<string>& subgraph_node_names,
const std::vector<int>& subgraph_node_ids, // In topological order
std::vector<EngineConnection>* connections,
tensorflow::GraphDef* segment_def, string* common_scope) {
@@ -2796,6 +2798,7 @@ tensorflow::Status ConvertSegmentToGraphDef(
// nodes in the segment graphdef.
for (size_t i = 0; i < connections->size(); ++i) {
auto& connection = connections->at(i);
+ if (connection.is_control_edge()) continue;
auto outside_node = graph->FindNodeId(connection.outside_id);
if (!outside_node) {
// This should never happen, unless the original graph is problematic.
@@ -2809,13 +2812,13 @@ tensorflow::Status ConvertSegmentToGraphDef(
GetInputProperties(graph_properties,
graph->FindNodeId(connection.outside_id),
connection.outside_port, &partial_shape, &dtype);
-
+ connection.outside_shape = partial_shape;
} else {
GetOutputProperties(graph_properties,
graph->FindNodeId(connection.outside_id),
connection.outside_port, &partial_shape, &dtype);
+ connection.inside_shape = partial_shape;
}
- connection.outside_shape = partial_shape;
connection.connection_type = dtype;
// Add dummy input/output nodes to the segment graphdef.
@@ -2873,7 +2876,7 @@ tensorflow::Status ConvertSegmentToGraphDef(
// Update the inputs of the new input nodes to point to placeholder nodes.
for (int i = 0; i < connections->size(); ++i) {
auto& connection = connections->at(i);
- if (!connection.is_input_edge) continue;
+ if (connection.is_control_edge() || !connection.is_input_edge) continue;
auto snode =
segment_def->mutable_node(old_to_new_id_map[connection.inside_id]);
const string placeholder_name =
@@ -2883,6 +2886,38 @@ tensorflow::Status ConvertSegmentToGraphDef(
<< placeholder_name;
snode->set_input(connection.inside_port, placeholder_name);
}
+ // Remove control inputs that are not inside the segment.
+ for (int i = 0; i < segment_def->node_size(); ++i) {
+ auto snode = segment_def->mutable_node(i);
+ const int input_size = snode->input_size();
+ int input_idx = 0;
+ int actual_input_idx = 0;
+ while (input_idx < input_size) {
+ TensorId input = ParseTensorName(snode->input(input_idx));
+ if (!subgraph_node_names.count(
+ string(input.first.data(), input.first.size())) &&
+ !str_util::StartsWith(input.first, kInputPHName)) {
+ if (input.second == Graph::kControlSlot) {
+ VLOG(2) << "... removing control inputs " << input.first
+ << " from subgraph.";
+ ++input_idx;
+ continue;
+ } else {
+ return tensorflow::errors::InvalidArgument(
+ "Found non control input outside the segment that is not an "
+ "engine connection to ", snode->name(), ": ", input.first);
+ }
+ }
+ if (actual_input_idx != input_idx) {
+ snode->set_input(actual_input_idx, snode->input(input_idx));
+ }
+ ++input_idx;
+ ++actual_input_idx;
+ }
+ for (int remove = input_size - actual_input_idx; remove > 0; --remove) {
+ snode->mutable_input()->RemoveLast();
+ }
+ }
*common_scope = local_scope;
VLOG(0) << "Segment @scope '" << local_scope << "', converted to graph";
return tensorflow::Status::OK();
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.h b/tensorflow/contrib/tensorrt/convert/convert_nodes.h
index d41a886b30..328efbf50c 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_nodes.h
+++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.h
@@ -36,8 +36,8 @@ limitations under the License.
namespace tensorflow {
namespace tensorrt {
-static const char* kInputPHName = "InputPH_";
-static const char* kOutputPHName = "OutputPH_";
+static const char* kInputPHName = "TensorRTInputPH_";
+static const char* kOutputPHName = "TensorRTOutputPH_";
namespace convert {
// TODO(aaroey): use an enum instead.
@@ -46,9 +46,10 @@ const int FP16MODE = 1;
const int INT8MODE = 2;
struct EngineConnection {
+ // Constructs a non-control edge.
EngineConnection(const string& outside, int out_id, int out_port,
const string& inside, int in_id, int in_port,
- bool input_edge, int port, bool control_edge = false)
+ bool input_edge, int port)
: outside_node_name(outside),
outside_id(out_id),
outside_port(out_port),
@@ -56,24 +57,39 @@ struct EngineConnection {
inside_id(in_id),
inside_port(in_port),
is_input_edge(input_edge),
- is_control_edge(control_edge),
port_number(port) {}
+ // Constructs a control edge.
+ EngineConnection(const string& outside, int out_id, const string& inside,
+ int in_id, bool input_edge)
+ : outside_node_name(outside),
+ outside_id(out_id),
+ outside_port(Graph::kControlSlot),
+ inside_node_name(inside),
+ inside_id(in_id),
+ inside_port(Graph::kControlSlot),
+ is_input_edge(input_edge),
+ port_number(Graph::kControlSlot) {}
+
+ bool is_control_edge() const {
+ return port_number == Graph::kControlSlot;
+ }
+
const string outside_node_name;
const int outside_id;
const int outside_port;
- tensorflow::PartialTensorShape outside_shape;
+ tensorflow::PartialTensorShape outside_shape; // Only set for input edge.
const string inside_node_name;
const int inside_id;
const int inside_port;
- tensorflow::PartialTensorShape inside_shape;
+ tensorflow::PartialTensorShape inside_shape; // Only set for output edge.
tensorflow::DataType connection_type;
- bool is_input_edge;
- bool is_control_edge;
- // The port number of the TRT node connecting to this edge.
- int port_number;
+ const bool is_input_edge;
+
+ // The port number of the TRT node connected with this edge.
+ const int port_number;
};
struct EngineInfo {
@@ -86,7 +102,9 @@ struct EngineInfo {
string device;
tensorflow::GraphDef segment_graph_def;
- // The segment nodes that are on one side of the edges are topological sorted.
+ // Non-control input connections in this vector are ordered such that the
+ // segment nodes they connect to are topologically sorted.
+ // In addition, non-control connections must contain no duplicates.
std::vector<EngineConnection> connections;
enum class EngineType { TRTStatic = 0, TRTDynamic = 1 };
@@ -102,6 +120,7 @@ struct EngineInfo {
// (OutputPH_*). This function needs to be called before TensorRT nodes
// inserted in order to correctly get sizes from the original graph.
//
+// - subgraph_node_names: the node names of the subgraph.
// - subgraph_node_ids: the node ids of the subgraph, must be sorted in
// topological order.
// - segment_def: the output GraphDef, whose non-input/output nodedefs will be
@@ -111,6 +130,7 @@ struct EngineInfo {
tensorflow::Status ConvertSegmentToGraphDef(
const tensorflow::Graph* graph,
const tensorflow::grappler::GraphProperties& graph_properties,
+ const std::set<string>& subgraph_node_names,
const std::vector<int>& subgraph_node_ids,
std::vector<EngineConnection>* connections,
tensorflow::GraphDef* segment_def, string* common_scope);
diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
index 6699b71d28..a19cd24c94 100644
--- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
+++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h"
#include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h"
#include "tensorflow/contrib/tensorrt/resources/trt_resources.h"
+#include "tensorflow/contrib/tensorrt/test/utils.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/strings/str_util.h"
@@ -179,7 +180,7 @@ void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx,
helper->Ref(); // Increment count for calculating native graph
VLOG(1) << "Executing native segment " << name();
lib->Run(opts, native_func_, inputs, outputs,
- [ctx, outputs, helper](const tensorflow::Status& s) {
+ [this, ctx, outputs, helper](const tensorflow::Status& s) {
tensorflow::core::ScopedUnref sc(helper);
VLOG(1) << "Native Segment completed";
if (!s.ok()) {
@@ -189,6 +190,8 @@ void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx,
for (size_t t = 0; t < outputs->size(); ++t) {
ctx->set_output(t, outputs->at(t));
}
+ test::AddTestValue(StrCat(this->name(), ":ExecuteNativeSegment"),
+ "done");
delete outputs;
});
}
@@ -234,6 +237,7 @@ void TRTEngineOp::ExecuteCalibration(OpKernelContext* ctx,
->implementation()
->GpuStreamMemberHack()));
calib_res->calibrator_->setBatch(input_data, *stream);
+ test::AddTestValue(StrCat(name(), ":ExecuteCalibration"), "done");
VLOG(2) << "Passed calibration data";
ExecuteNativeSegment(ctx, helper);
}
@@ -258,7 +262,7 @@ int TRTEngineOp::GetEngineBatch(OpKernelContext* ctx) {
StrCat("Engine buffer is full. buffer limit=", max_cached_engines_,
", current entries=");
for (auto i : cached_engine_batches_) StrAppend(&msg, i, ",");
- StrAppend(&msg, "Requested batch=", num_batch);
+ StrAppend(&msg, " requested batch=", num_batch);
LOG(WARNING) << msg;
return -1;
}
@@ -276,7 +280,8 @@ void TRTEngineOp::ComputeAsync(OpKernelContext* ctx,
}
const int smallest_engine = GetEngineBatch(ctx);
if (smallest_engine < 0) {
- LOG(WARNING) << "Failed to get engine batch, running native segment";
+ LOG(WARNING) << "Failed to get engine batch, running native segment for "
+ << name();
ExecuteNativeSegment(ctx, helper);
return;
}
@@ -286,14 +291,15 @@ void TRTEngineOp::ComputeAsync(OpKernelContext* ctx,
auto& trt_engine_ptr = engine_ctx_pair.first;
if (!trt_engine_ptr) {
LOG(WARNING) << "Engine retrieval for batch size " << num_batch
- << " failed. Running native segment";
+ << " failed. Running native segment for " << name();
ExecuteNativeSegment(ctx, helper);
return;
}
const bool retry = ExecuteTrtEngine(ctx, num_batch, trt_engine_ptr.get(),
engine_ctx_pair.second.get());
if (retry) {
- LOG(WARNING) << "Failed to execute engine, retrying with native segment";
+ LOG(WARNING) << "Failed to execute engine, "
+ << "retrying with native segment for " << name();
ExecuteNativeSegment(ctx, helper);
return;
}
@@ -412,6 +418,7 @@ bool TRTEngineOp::ExecuteTrtEngine(
LOG(WARNING) << "Failed to enqueue batch for TRT engine: " << name();
return kRetry;
}
+ test::AddTestValue(StrCat(name(), ":ExecuteTrtEngine"), "done");
// Synchronization will be done by TF.
return !kRetry;
}
diff --git a/tensorflow/contrib/tensorrt/python/__init__.py b/tensorflow/contrib/tensorrt/python/__init__.py
index fe4fa166a1..7cdfe2b1a6 100644
--- a/tensorflow/contrib/tensorrt/python/__init__.py
+++ b/tensorflow/contrib/tensorrt/python/__init__.py
@@ -20,7 +20,11 @@ from __future__ import print_function
# pylint: disable=unused-import,line-too-long
from tensorflow.contrib.tensorrt.python.ops import trt_engine_op
+from tensorflow.contrib.tensorrt.python.trt_convert import add_test_value
from tensorflow.contrib.tensorrt.python.trt_convert import calib_graph_to_infer_graph
+from tensorflow.contrib.tensorrt.python.trt_convert import clear_test_values
from tensorflow.contrib.tensorrt.python.trt_convert import create_inference_graph
+from tensorflow.contrib.tensorrt.python.trt_convert import enable_test_value
+from tensorflow.contrib.tensorrt.python.trt_convert import get_test_value
from tensorflow.contrib.tensorrt.python.trt_convert import is_tensorrt_enabled
# pylint: enable=unused-import,line-too-long
diff --git a/tensorflow/contrib/tensorrt/python/trt_convert.py b/tensorflow/contrib/tensorrt/python/trt_convert.py
index 2b67931661..5c1f4a466e 100644
--- a/tensorflow/contrib/tensorrt/python/trt_convert.py
+++ b/tensorflow/contrib/tensorrt/python/trt_convert.py
@@ -20,9 +20,13 @@ from __future__ import print_function
# pylint: disable=unused-import,line-too-long
import six as _six
+from tensorflow.contrib.tensorrt.wrap_conversion import add_test_value
from tensorflow.contrib.tensorrt.wrap_conversion import calib_convert
+from tensorflow.contrib.tensorrt.wrap_conversion import clear_test_values
+from tensorflow.contrib.tensorrt.wrap_conversion import enable_test_value
from tensorflow.contrib.tensorrt.wrap_conversion import get_linked_tensorrt_version
from tensorflow.contrib.tensorrt.wrap_conversion import get_loaded_tensorrt_version
+from tensorflow.contrib.tensorrt.wrap_conversion import get_test_value
from tensorflow.contrib.tensorrt.wrap_conversion import is_tensorrt_enabled
from tensorflow.contrib.tensorrt.wrap_conversion import trt_convert
from tensorflow.core.framework import graph_pb2
diff --git a/tensorflow/contrib/tensorrt/test/base_test.py b/tensorflow/contrib/tensorrt/test/base_test.py
index edd30ad7a9..9d14e635f4 100644
--- a/tensorflow/contrib/tensorrt/test/base_test.py
+++ b/tensorflow/contrib/tensorrt/test/base_test.py
@@ -20,17 +20,19 @@ from __future__ import print_function
import numpy as np
+from tensorflow.contrib.tensorrt.python import trt_convert
from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import test
-class SimpleSingleEngineGraphDefTest(trt_test.TfTrtIntegrationTestBase):
+class SimpleSingleEngineTest(trt_test.TfTrtIntegrationTestBase):
def GetParams(self):
"""Create a graph containing single segment."""
@@ -65,13 +67,17 @@ class SimpleSingleEngineGraphDefTest(trt_test.TfTrtIntegrationTestBase):
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- num_expected_engines=1,
+ # TODO(aaroey): LayoutOptimizer adds additional nodes to the graph which
+ # breaks the connection check, fix it.
+ # - my_trt_op_0 should have ["weights", "conv", "bias", "bias_add",
+ # "relu", "identity", "max_pool"]
+ expected_engines=["my_trt_op_0"],
expected_output_dims=(100, 6, 6, 6),
allclose_atol=1.e-03,
allclose_rtol=1.e-03)
-class SimpleMultiEngineGraphDefTest(trt_test.TfTrtIntegrationTestBase):
+class SimpleMultiEnginesTest(trt_test.TfTrtIntegrationTestBase):
def GetParams(self):
"""Create a graph containing multiple segment."""
@@ -95,32 +101,138 @@ class SimpleMultiEngineGraphDefTest(trt_test.TfTrtIntegrationTestBase):
padding="SAME",
name="conv")
c1 = constant_op.constant(
- np.random.randn(input_dims[0], 12, 12, 6), dtype=dtype)
- p = conv * c1
+ np.random.randn(input_dims[0], 12, 12, 6), dtype=dtype, name="c1")
+ p = math_ops.mul(conv, c1, name="mul")
c2 = constant_op.constant(
- np.random.randn(input_dims[0], 12, 12, 6), dtype=dtype)
- q = conv / c2
+ np.random.randn(input_dims[0], 12, 12, 6), dtype=dtype, name="c2")
+ q = math_ops.div(conv, c2, name="div")
- edge = self.trt_incompatible_op(q)
- edge /= edge
- r = edge + edge
+ edge = self.trt_incompatible_op(q, name="incompatible")
+ edge = math_ops.div(edge, edge, name="div1")
+ r = math_ops.add(edge, edge, name="add")
- p -= edge
- q *= edge
- s = p + q
- s -= r
+ p = math_ops.sub(p, edge, name="sub")
+ q = math_ops.mul(q, edge, name="mul1")
+ s = math_ops.add(p, q, name="add1")
+ s = math_ops.sub(s, r, name="sub1")
array_ops.squeeze(s, name=self.output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- num_expected_engines=2,
+ # TODO(aaroey): LayoutOptimizer adds additional nodes to the graph which
+ # breaks the connection check, fix it.
+ # - my_trt_op_0 should have ["mul", "sub", "div1", "mul1", "add1", "add",
+ # "sub1"];
+ # - my_trt_op_1 should have ["weights","conv", "div"]
+ expected_engines=["my_trt_op_0", "my_trt_op_1"],
expected_output_dims=(100, 12, 12, 6),
allclose_atol=1.e-03,
allclose_rtol=1.e-03)
-# TODO(aaroey): add a large complex graph to test.
+class PartiallyConvertedTestA(trt_test.TfTrtIntegrationTestBase):
+
+ def setUp(self):
+ """Setup method."""
+ super(PartiallyConvertedTestA, self).setUp()
+ # Let it fail to build the first engine.
+ trt_convert.add_test_value("my_trt_op_0:CreateTRTNode", "fail")
+
+ def GetParams(self):
+ """Create a graph containing two segments."""
+ input_name = "input"
+ input_dims = [2, 32, 32, 3]
+ g = ops.Graph()
+ with g.as_default():
+ inp = array_ops.placeholder(
+ dtype=dtypes.float32, shape=input_dims, name=input_name)
+ with g.device("/GPU:0"):
+ n = inp
+ for i in range(2):
+ c = constant_op.constant(1.0, name="c%d" % i)
+ n = math_ops.add(n, c, name="add%d" % i)
+ n = math_ops.mul(n, n, name="mul%d" % i)
+ edge = self.trt_incompatible_op(n, name="incompatible")
+ with g.control_dependencies([edge]):
+ c = constant_op.constant(1.0, name="c2")
+ n = math_ops.add(n, c, name="add2")
+ n = math_ops.mul(n, n, name="mul2")
+ c = constant_op.constant(1.0, name="c3")
+ n = math_ops.add(n, c, name="add3")
+ n = math_ops.mul(n, n, name="mul3")
+ array_ops.squeeze(n, name=self.output_name)
+ return trt_test.TfTrtIntegrationTestParams(
+ gdef=g.as_graph_def(),
+ input_names=[input_name],
+ input_dims=[input_dims],
+ expected_engines={
+ # Only the second engine is built.
+ "my_trt_op_1": ["c0", "c1", "add0", "add1", "mul0", "mul1"]
+ },
+ expected_output_dims=tuple(input_dims),
+ allclose_atol=1.e-06,
+ allclose_rtol=1.e-06)
+
+
+class PartiallyConvertedTestB(PartiallyConvertedTestA):
+
+ def setUp(self):
+ """Setup method."""
+ super(PartiallyConvertedTestB, self).setUp()
+ # Let it fail to build the second engine.
+ trt_convert.clear_test_values("")
+ trt_convert.add_test_value("my_trt_op_1:CreateTRTNode", "fail")
+
+ def GetParams(self):
+ """Create a graph containing two segments."""
+ return super(PartiallyConvertedTestB, self).GetParams()._replace(
+ expected_engines={
+ # Only the first engine is built.
+ "my_trt_op_0": ["c2", "c3", "add2", "add3", "mul2", "mul3"]
+ })
+
+
+class ConstInputTest(trt_test.TfTrtIntegrationTestBase):
+
+ def GetParams(self):
+ """Create a graph containing multiple segments."""
+ input_name = "input"
+ input_dims = [2, 32, 32, 3]
+ g = ops.Graph()
+ with g.as_default():
+ inp = array_ops.placeholder(
+ dtype=dtypes.float32, shape=input_dims, name=input_name)
+ with g.device("/GPU:0"):
+ n = inp
+ c = constant_op.constant(1.0, name="c")
+ # Adds control dependency from the constant op to a trt incompatible op,
+ # and adds control dependency from the trt incompatible op to all other
+ # ops, to make sure the constant op cannot be contracted with any trt
+ # segment that depends on it.
+ with g.control_dependencies([c]):
+ d = self.trt_incompatible_op(n, name="incompatible")
+ with g.control_dependencies([d]):
+ n = math_ops.add(n, c, name="add")
+ n = math_ops.mul(n, n, name="mul")
+ n = math_ops.add(n, n, name="add1")
+ n = self.trt_incompatible_op(n, name="incompatible1")
+ with g.control_dependencies([d]):
+ n = math_ops.add(n, c, name="add2")
+ n = math_ops.mul(n, n, name="mul1")
+ n = math_ops.add(n, n, name="add3")
+ array_ops.squeeze(n, name=self.output_name)
+ return trt_test.TfTrtIntegrationTestParams(
+ gdef=g.as_graph_def(),
+ input_names=[input_name],
+ input_dims=[input_dims],
+ expected_engines={
+ "my_trt_op_0": ["add2", "add3", "mul1"],
+ "my_trt_op_1": ["add", "add1", "mul"]
+ },
+ expected_output_dims=tuple(input_dims),
+ allclose_atol=1.e-06,
+ allclose_rtol=1.e-06)
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/tensorrt/test/batch_matmul_test.py b/tensorflow/contrib/tensorrt/test/batch_matmul_test.py
index 730b6843fb..2e1107e303 100644
--- a/tensorflow/contrib/tensorrt/test/batch_matmul_test.py
+++ b/tensorflow/contrib/tensorrt/test/batch_matmul_test.py
@@ -66,7 +66,7 @@ class BatchMatMulTest(trt_test.TfTrtIntegrationTestBase):
gdef=g.as_graph_def(),
input_names=[input_name, w1_name, w2_name],
input_dims=[input_dims, w1_dims, w2_dims],
- num_expected_engines=1,
+ expected_engines=["my_trt_op_0"],
expected_output_dims=(12, 5, 8, 7),
allclose_atol=1.e-03,
allclose_rtol=1.e-03)
diff --git a/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py b/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py
index 0c03a10b64..8be32f59b4 100644
--- a/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py
+++ b/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py
@@ -102,7 +102,10 @@ class BiasaddMatMulTest(trt_test.TfTrtIntegrationTestBase):
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- num_expected_engines=7,
+ expected_engines=[
+ "my_trt_op_0", "my_trt_op_1", "my_trt_op_2", "my_trt_op_3",
+ "my_trt_op_4", "my_trt_op_5", "my_trt_op_6"
+ ],
expected_output_dims=(48, 89),
allclose_atol=1.e-03,
allclose_rtol=1.e-03)
diff --git a/tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py b/tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py
index dd673463a5..9316b14da0 100644
--- a/tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py
+++ b/tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py
@@ -109,7 +109,24 @@ class BinaryTensorWeightBroadcastTest(trt_test.TfTrtIntegrationTestBase):
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- num_expected_engines=16,
+ expected_engines=[
+ "my_trt_op_0",
+ "my_trt_op_1",
+ "my_trt_op_2",
+ "my_trt_op_3",
+ "my_trt_op_4",
+ "my_trt_op_5",
+ "my_trt_op_6",
+ "my_trt_op_7",
+ "my_trt_op_8",
+ "my_trt_op_9",
+ "my_trt_op_10",
+ "my_trt_op_11",
+ "my_trt_op_12",
+ "my_trt_op_13",
+ "my_trt_op_14",
+ "my_trt_op_15",
+ ],
expected_output_dims=(5, 23040),
allclose_atol=1.e-03,
allclose_rtol=1.e-03)
diff --git a/tensorflow/contrib/tensorrt/test/concatenation_test.py b/tensorflow/contrib/tensorrt/test/concatenation_test.py
index 8c51c45b0a..1874b9dd45 100644
--- a/tensorflow/contrib/tensorrt/test/concatenation_test.py
+++ b/tensorflow/contrib/tensorrt/test/concatenation_test.py
@@ -73,7 +73,7 @@ class ConcatenationTest(trt_test.TfTrtIntegrationTestBase):
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- num_expected_engines=1,
+ expected_engines=["my_trt_op_0"],
expected_output_dims=(2, 126),
allclose_atol=1.e-03,
allclose_rtol=1.e-03)
diff --git a/tensorflow/contrib/tensorrt/test/const_broadcast_test.py b/tensorflow/contrib/tensorrt/test/const_broadcast_test.py
index 97b29bf05d..8c59000b70 100644
--- a/tensorflow/contrib/tensorrt/test/const_broadcast_test.py
+++ b/tensorflow/contrib/tensorrt/test/const_broadcast_test.py
@@ -58,7 +58,7 @@ class ConstBroadcastTest(trt_test.TfTrtIntegrationTestBase):
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- num_expected_engines=1,
+ expected_engines=['my_trt_op_0'],
expected_output_dims=(5, 12, 12, 1),
allclose_atol=1.e-02,
allclose_rtol=1.e-02)
diff --git a/tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py b/tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py
index 734ccf6345..fd55b8cd99 100644
--- a/tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py
+++ b/tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py
@@ -77,7 +77,7 @@ class MultiConnectionNeighborEngineTest(trt_test.TfTrtIntegrationTestBase):
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- num_expected_engines=2,
+ expected_engines=["my_trt_op_0", "my_trt_op_1"],
expected_output_dims=(2, 4, 5, 4),
allclose_atol=1.e-03,
allclose_rtol=1.e-03)
diff --git a/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py b/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py
index 50265c0845..97e0d23b18 100644
--- a/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py
+++ b/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py
@@ -59,7 +59,7 @@ class NeighboringEngineTest(trt_test.TfTrtIntegrationTestBase):
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- num_expected_engines=2,
+ expected_engines=["my_trt_op_0", "my_trt_op_1"],
expected_output_dims=(2, 4, 5, 4),
allclose_atol=1.e-03,
allclose_rtol=1.e-03)
diff --git a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py
index bb7f5a77f0..5968af28ae 100644
--- a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py
+++ b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py
@@ -30,6 +30,7 @@ from tensorflow.contrib.tensorrt.python.ops import trt_engine_op
# pylint: enable=unused-import
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
+from tensorflow.python.framework import graph_io
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
@@ -37,10 +38,14 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import tf_logging as logging
TfTrtIntegrationTestParams = namedtuple("TfTrtIntegrationTestParams", [
- "gdef", "input_names", "input_dims", "num_expected_engines",
+ "gdef", "input_names", "input_dims", "expected_engines",
"expected_output_dims", "allclose_atol", "allclose_rtol"
])
+RunParams = namedtuple(
+ "RunParams",
+ ["use_optimizer", "precision_mode", "dynamic_engine", "test_name"])
+
PRECISION_MODES = ["FP32", "FP16", "INT8"]
@@ -48,6 +53,12 @@ def _IsQuantizationMode(mode):
return mode == "INT8"
+class GraphState:
+ ORIGINAL = 0
+ CALIBRATE = 1
+ INFERENCE = 2
+
+
class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
"""Class to test Tensorflow-TensorRT integration."""
@@ -63,34 +74,79 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
def precision_modes(self):
return ["FP32", "FP16", "INT8"]
+ # str is bytes in py2, but unicode in py3.
+ def _ToUnicode(self, s):
+ if six.PY2:
+ if isinstance(s, unicode):
+ return s
+ return s.decode("utf-8")
+ else:
+ if isinstance(s, str):
+ return s
+ return s.decode("utf-8")
+
def _ToBytes(self, s):
if six.PY2:
+ if isinstance(s, unicode):
+ return s.encode("utf-8")
return s
else:
- return s.encode("utf-8")
+ if isinstance(s, str):
+ return s.encode("utf-8")
+ return s
def _ToString(self, s):
if six.PY2:
+ if isinstance(s, unicode):
+ return s.encode("utf-8")
return s
else:
+ if isinstance(s, str):
+ return s
return s.decode("utf-8")
+ @classmethod
+ def setUpClass(cls):
+ """Setup method for the test class."""
+ super(TfTrtIntegrationTestBase, cls).setUpClass()
+ trt_convert.enable_test_value()
+
def setUp(self):
"""Setup method."""
super(TfTrtIntegrationTestBase, self).setUp()
warnings.simplefilter("always")
+ trt_convert.clear_test_values("")
def GetParams(self):
"""Return a TfTrtIntegrationTestParams for test, implemented by subclass."""
raise NotImplementedError()
- def _GetConfigProto(self,
- params,
- use_optimizer,
- precision_mode=None,
- is_dynamic_op=None):
+ def _PrepareRun(self, params, graph_state):
+ """Set up necessary testing environment before calling sess.run()."""
+ # Clear test values added by TRTEngineOp.
+ trt_convert.clear_test_values("my_trt_op_.*:ExecuteTrtEngine")
+ trt_convert.clear_test_values("my_trt_op_.*:ExecuteCalibration")
+ trt_convert.clear_test_values("my_trt_op_.*:ExecuteNativeSegment")
+
+ def _VerifyRun(self, params, graph_state):
+ """Verify the state after sess.run()."""
+ for engine_name in params.expected_engines:
+ if graph_state == GraphState.ORIGINAL:
+ self._ExpectCalibration(engine_name, "")
+ self._ExpectNativeSegment(engine_name, "")
+ self._ExpectTrtEngine(engine_name, "")
+ elif graph_state == GraphState.CALIBRATE:
+ self._ExpectCalibration(engine_name, "done")
+ self._ExpectNativeSegment(engine_name, "done")
+ self._ExpectTrtEngine(engine_name, "")
+ elif graph_state == GraphState.INFERENCE:
+ self._ExpectCalibration(engine_name, "")
+ self._ExpectNativeSegment(engine_name, "")
+ self._ExpectTrtEngine(engine_name, "done")
+
+ def _GetConfigProto(self, params, run_params, graph_state):
"""Get config proto based on specific settings."""
- if use_optimizer:
+ if graph_state != GraphState.ORIGINAL and run_params.use_optimizer:
rewriter_cfg = rewriter_config_pb2.RewriterConfig()
rewriter_cfg.optimizers.extend(["constfold", "layout"])
custom_op = rewriter_cfg.custom_optimizers.add()
@@ -98,14 +154,31 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
custom_op.parameter_map["minimum_segment_size"].i = 3
custom_op.parameter_map["max_batch_size"].i = max(
[dims[0] for dims in params.input_dims])
- custom_op.parameter_map["is_dynamic_op"].b = is_dynamic_op
+ custom_op.parameter_map["is_dynamic_op"].b = run_params.dynamic_engine
custom_op.parameter_map["max_workspace_size_bytes"].i = 1 << 25
custom_op.parameter_map["precision_mode"].s = self._ToBytes(
- precision_mode)
+ run_params.precision_mode)
graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_cfg)
else:
graph_options = config_pb2.GraphOptions()
+ # Disable all other optimizations which can affect the converted graph.
+ off = rewriter_config_pb2.RewriterConfig.OFF
+ graph_options.optimizer_options.opt_level = config_pb2.OptimizerOptions.L0
+ graph_options.rewrite_options.layout_optimizer = off
+ graph_options.rewrite_options.constant_folding = off
+ graph_options.rewrite_options.shape_optimization = off
+ graph_options.rewrite_options.remapping = off
+ graph_options.rewrite_options.arithmetic_optimization = off
+ graph_options.rewrite_options.dependency_optimization = off
+ graph_options.rewrite_options.loop_optimization = off
+ graph_options.rewrite_options.function_optimization = off
+ graph_options.rewrite_options.debug_stripper = off
+ graph_options.rewrite_options.disable_model_pruning = True
+ graph_options.rewrite_options.scoped_allocator_optimization = off
+ graph_options.rewrite_options.memory_optimization = (
+ rewriter_config_pb2.RewriterConfig.NO_MEM_OPT)
+
gpu_options = config_pb2.GPUOptions()
gpu_options.allow_growth = True
if trt_convert.get_linked_tensorrt_version()[0] == 3:
@@ -115,7 +188,21 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
gpu_options=gpu_options, graph_options=graph_options)
return config
- def _RunGraph(self, params, gdef, input_data, config, num_runs=2):
+ def _ExpectTestValue(self, engine_name, method, value):
+ self.assertEqual(
+ value, trt_convert.get_test_value("%s:%s" % (engine_name, method)))
+
+ def _ExpectCalibration(self, engine_name, value):
+ self._ExpectTestValue(engine_name, "ExecuteCalibration", value)
+
+ def _ExpectTrtEngine(self, engine_name, value):
+ self._ExpectTestValue(engine_name, "ExecuteTrtEngine", value)
+
+ def _ExpectNativeSegment(self, engine_name, value):
+ self._ExpectTestValue(engine_name, "ExecuteNativeSegment", value)
+
+ def _RunGraph(self, params, gdef, input_data, config, graph_state,
+ num_runs=2):
"""Run given graphdef multiple times."""
assert len(params.input_names) == len(input_data)
g = ops.Graph()
@@ -132,93 +219,166 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
val = None
# Defaults to 2 runs to verify result across multiple runs is same.
for _ in range(num_runs):
+ self._PrepareRun(params, graph_state)
new_val = sess.run(out,
{inp[i]: input_data[i] for i in range(len(inp))})
self.assertEqual(params.expected_output_dims, new_val.shape)
if val is not None:
self.assertAllEqual(val, new_val)
val = new_val
+ self._VerifyRun(params, graph_state)
return val
# Use real data that is representative of the inference dataset
# for calibration. For this test script it is random data.
def _RunCalibration(self, params, gdef, input_data, config):
"""Run calibration on given graph."""
- return self._RunGraph(params, gdef, input_data, config, 30)
+ return self._RunGraph(
+ params, gdef, input_data, config, GraphState.CALIBRATE, num_runs=5)
- def _GetTrtGraphDef(self, params, gdef, precision_mode, is_dynamic_op):
+ def _GetTrtGraphDef(self, params, run_params, gdef):
"""Return trt converted graphdef."""
return trt_convert.create_inference_graph(
input_graph_def=gdef,
outputs=[self.output_name],
max_batch_size=max([dims[0] for dims in params.input_dims]),
max_workspace_size_bytes=1 << 25,
- precision_mode=precision_mode,
+ precision_mode=run_params.precision_mode,
minimum_segment_size=2,
- is_dynamic_op=is_dynamic_op)
-
- def _VerifyGraphDef(self,
- params,
- gdef,
- precision_mode=None,
- is_calibrated=None,
- dynamic_engine=None):
+ is_dynamic_op=run_params.dynamic_engine)
+
+  def _WriteGraph(self, params, run_params, gdef, graph_state):
+    """Writes `gdef` to a .pbtxt file in the test's temp directory.
+
+    The file is named "<TestClassName>_<test_name>_<label>.pbtxt", where the
+    label is derived from `graph_state`.
+
+    Args:
+      params: not used by this method.
+      run_params: object whose `test_name` attribute is embedded in the name.
+      gdef: the graph definition handed to graph_io.write_graph.
+      graph_state: GraphState.ORIGINAL, GraphState.CALIBRATE or
+        GraphState.INFERENCE.
+
+    Raises:
+      ValueError: if `graph_state` is none of the values listed above.
+    """
+    if graph_state == GraphState.ORIGINAL:
+      label = "Original"
+    elif graph_state == GraphState.CALIBRATE:
+      label = "CalibEngine"
+    elif graph_state == GraphState.INFERENCE:
+      label = "InferEngine"
+    else:
+      # Fail with a clear message instead of an UnboundLocalError on `label`.
+      raise ValueError("Unknown graph_state: %s" % (graph_state,))
+    graph_name = (
+        self.__class__.__name__ + "_" + run_params.test_name + "_" + label +
+        ".pbtxt")
+    temp_dir = self.get_temp_dir()
+    logging.info("Writing graph to %s/%s", temp_dir, graph_name)
+    graph_io.write_graph(gdef, temp_dir, graph_name)
+
+  def _VerifyConnections(self, params, converted_gdef):
+    """Checks that node connections in `converted_gdef` match expectations.
+
+    The expected connections are derived from the original graph
+    (`params.gdef`) by collapsing every node listed in
+    `params.expected_engines` into its engine node, and are compared with the
+    inputs actually present in `converted_gdef`.
+
+    Args:
+      params: object providing `gdef` (the original graph) and
+        `expected_engines`, a dict mapping an engine name to the names of the
+        original nodes that the engine replaces.
+      converted_gdef: the converted graph to verify.
+    """
+    old_to_new_node_map = {
+        self._ToString(n.name): self._ToString(n.name) for n in params.gdef.node
+    }
+    # items() rather than iteritems(): the latter only exists on Python 2.
+    for engine_name, node_names in params.expected_engines.items():
+      for n in node_names:
+        old_to_new_node_map[n] = engine_name
+    name_to_node_map = {self._ToString(n.name): n for n in params.gdef.node}
+
+    def input_name(inp):
+      """Splits an input string into ("^" or "", node name without ":<port>")."""
+      inp = self._ToString(inp)
+      prefix = ""
+      if inp.startswith("^"):
+        prefix = "^"
+        inp = inp[1:]
+      parts = inp.split(":")
+      if len(parts) > 1 and parts[-1].isdigit():
+        # Drop the trailing ":<port>" suffix.
+        inp = inp[:-len(parts[-1]) - 1]
+      return (prefix, inp)
+
+    expected_input_map = {}
+    for n in params.gdef.node:
+      name_str = self._ToString(n.name)
+      target_node_name = old_to_new_node_map[name_str]
+      is_engine_op = (target_node_name != name_str)
+      if target_node_name not in expected_input_map:
+        expected_input_map[target_node_name] = set()
+      input_set = expected_input_map[target_node_name]
+      for inp in n.input:
+        (prefix, inp_name) = input_name(inp)
+        # Add the input only if it's outside the segment (note that it could be
+        # in a different engine).
+        if (not is_engine_op or
+            old_to_new_node_map[inp_name] != target_node_name):
+          if is_engine_op and name_to_node_map[inp_name].op == "Const":
+            # Const data inputs to the segment have been copied into the
+            # segment graphdef and the engine, and the dependency has been
+            # converted to a control dependency.
+            input_set.add("^" + old_to_new_node_map[inp_name])
+          else:
+            input_set.add(prefix + old_to_new_node_map[inp_name])
+
+    actual_input_map = {}
+    for n in converted_gdef.node:
+      name_str = self._ToString(n.name)
+      actual_input_map[name_str] = set()
+      input_set = actual_input_map[name_str]
+      for inp in n.input:
+        (prefix, node_name) = input_name(inp)
+        input_set.add(prefix + node_name)
+
+    self.assertEqual(
+        expected_input_map,
+        actual_input_map,
+        msg="expected:\n%s\nvs actual:\n%s" % (expected_input_map,
+                                               actual_input_map))
+
+ def _VerifyGraphDef(self, params, run_params, gdef, graph_state):
+ self._WriteGraph(params, run_params, gdef, graph_state)
+
num_engines = 0
for n in gdef.node:
- # TODO(jie): we should have coverage for failed conversion (TF fallback).
- # where the conversion will fail and we shouldn't count this engine as the
- # converted engines.
if n.op == "TRTEngineOp":
num_engines += 1
- self.assertNotEqual(self._ToBytes(""), n.attr["serialized_segment"].s)
- self.assertNotEqual(self._ToBytes(""), n.attr["segment_funcdef_name"].s)
+ self.assertTrue(n.name in params.expected_engines)
+ self.assertTrue(len(n.attr["serialized_segment"].s))
+ self.assertTrue(len(n.attr["segment_funcdef_name"].s))
self.assertEqual(
- self._ToBytes(precision_mode), n.attr["precision_mode"].s)
- self.assertEqual(not dynamic_engine, n.attr["static_engine"].b)
- if _IsQuantizationMode(precision_mode) and is_calibrated:
- self.assertNotEqual(self._ToBytes(""), n.attr["calibration_data"].s)
+ self._ToBytes(run_params.precision_mode),
+ n.attr["precision_mode"].s)
+
+ is_dynamic_engine = not n.attr["static_engine"].b
+ self.assertEqual(run_params.dynamic_engine, is_dynamic_engine)
+
+ has_calibration_data = len(n.attr["calibration_data"].s)
+ if (_IsQuantizationMode(run_params.precision_mode) and
+ graph_state == GraphState.INFERENCE):
+ self.assertTrue(has_calibration_data)
else:
- self.assertEqual(self._ToBytes(""), n.attr["calibration_data"].s)
- if precision_mode is None: # This means gdef is the original GraphDef.
+ self.assertFalse(has_calibration_data)
+ if graph_state == GraphState.ORIGINAL:
self.assertEqual(0, num_engines)
else:
- self.assertEqual(num_engines, params.num_expected_engines)
+ self.assertEqual(num_engines, len(params.expected_engines))
+ if isinstance(params.expected_engines, dict):
+ self._VerifyConnections(params, gdef)
+ # TODO(aaroey): consider verifying the corresponding TF function.
- def RunTest(self, params, use_optimizer, precision_mode,
- dynamic_infer_engine, dynamic_calib_engine):
- assert precision_mode in PRECISION_MODES
+ def RunTest(self, params, run_params):
+ assert run_params.precision_mode in PRECISION_MODES
input_data = [np.random.random_sample(dims) for dims in params.input_dims]
input_gdef = params.gdef
- self._VerifyGraphDef(params, input_gdef)
+ self._VerifyGraphDef(params, run_params, input_gdef, GraphState.ORIGINAL)
# Get reference result without running trt.
- config_no_trt = self._GetConfigProto(params, False)
+ config_no_trt = self._GetConfigProto(params, run_params,
+ GraphState.ORIGINAL)
logging.info("Running original graph w/o trt, config:\n%s",
str(config_no_trt))
- ref_result = self._RunGraph(params, input_gdef, input_data, config_no_trt)
+ ref_result = self._RunGraph(params, input_gdef, input_data, config_no_trt,
+ GraphState.ORIGINAL)
# Run calibration if necessary.
- if _IsQuantizationMode(precision_mode):
+ if _IsQuantizationMode(run_params.precision_mode):
- calib_config = self._GetConfigProto(params, use_optimizer, precision_mode,
- dynamic_calib_engine)
+ calib_config = self._GetConfigProto(params, run_params,
+ GraphState.CALIBRATE)
logging.info("Running calibration graph, config:\n%s", str(calib_config))
- if use_optimizer:
- self.assertTrue(False)
- # TODO(aaroey): uncomment this and get infer_gdef when this mode is
- # supported.
- # result = self._RunCalibration(params, input_gdef, input_data,
- # calib_config)
+ if run_params.use_optimizer:
+ result = self._RunCalibration(params, input_gdef, input_data,
+ calib_config)
else:
- calib_gdef = self._GetTrtGraphDef(params, input_gdef, precision_mode,
- dynamic_calib_engine)
- self._VerifyGraphDef(params, calib_gdef, precision_mode, False,
- dynamic_calib_engine)
+ calib_gdef = self._GetTrtGraphDef(params, run_params, input_gdef)
+ self._VerifyGraphDef(params, run_params, calib_gdef,
+ GraphState.CALIBRATE)
result = self._RunCalibration(params, calib_gdef, input_data,
calib_config)
- infer_gdef = trt_convert.calib_graph_to_infer_graph(calib_gdef)
- self._VerifyGraphDef(params, infer_gdef, precision_mode, True,
- dynamic_calib_engine)
+ infer_gdef = trt_convert.calib_graph_to_infer_graph(calib_gdef)
+ self._VerifyGraphDef(params, run_params, infer_gdef, GraphState.INFERENCE)
self.assertAllClose(
ref_result,
@@ -229,18 +389,19 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
infer_gdef = input_gdef
# Run inference.
- infer_config = self._GetConfigProto(params, use_optimizer, precision_mode,
- dynamic_infer_engine)
+ infer_config = self._GetConfigProto(params, run_params,
+ GraphState.INFERENCE)
logging.info("Running final inference graph, config:\n%s",
str(infer_config))
- if use_optimizer:
- result = self._RunGraph(params, infer_gdef, input_data, infer_config)
+ if run_params.use_optimizer:
+ result = self._RunGraph(params, infer_gdef, input_data, infer_config,
+ GraphState.INFERENCE)
else:
- trt_infer_gdef = self._GetTrtGraphDef(params, infer_gdef, precision_mode,
- dynamic_infer_engine)
- self._VerifyGraphDef(params, trt_infer_gdef, precision_mode, True,
- dynamic_infer_engine)
- result = self._RunGraph(params, trt_infer_gdef, input_data, infer_config)
+ trt_infer_gdef = self._GetTrtGraphDef(params, run_params, infer_gdef)
+ self._VerifyGraphDef(params, run_params, trt_infer_gdef,
+ GraphState.INFERENCE)
+ result = self._RunGraph(params, trt_infer_gdef, input_data, infer_config,
+ GraphState.INFERENCE)
self.assertAllClose(
ref_result,
@@ -263,66 +424,44 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
def _AddTests(test_class):
"""Adds test methods to TfTrtIntegrationTestBase."""
- def _GetTest(use_optimizer, precision_mode, dynamic_infer_engine,
- dynamic_calib_engine):
+ def _GetTest(run_params):
"""Gets a single test method based on the parameters."""
def _Test(self):
params = self.GetParams()
logging.info(
- "Running test with parameters: use_optimizer=%s, precision_mode=%s, "
- "dynamic_infer_engine=%s, dynamic_calib_engine=%s", use_optimizer,
- precision_mode, dynamic_infer_engine, dynamic_calib_engine)
- self.RunTest(params, use_optimizer, precision_mode, dynamic_infer_engine,
- dynamic_calib_engine)
+ "Running test %s with parameters: use_optimizer=%s, "
+ "precision_mode=%s, dynamic_engine=%s",
+ "testTfTRT_" + run_params.test_name, run_params.use_optimizer,
+ run_params.precision_mode, run_params.dynamic_engine)
+ self.RunTest(params, run_params)
return _Test
use_optimizer_options = [False, True]
- dynamic_infer_engine_options = [False, True]
- dynamic_calib_engine_options = [False, True]
- for (use_optimizer, precision_mode,
- dynamic_infer_engine, dynamic_calib_engine) in itertools.product(
- use_optimizer_options, PRECISION_MODES, dynamic_infer_engine_options,
- dynamic_calib_engine_options):
+ dynamic_engine_options = [False, True]
+ for (use_optimizer, precision_mode, dynamic_engine) in itertools.product(
+ use_optimizer_options, PRECISION_MODES, dynamic_engine_options):
if _IsQuantizationMode(precision_mode):
- if not dynamic_calib_engine and dynamic_infer_engine:
- # TODO(aaroey): test this case, the conversion from static calibration
- # engine to dynamic inference engine should be a noop.
- continue
if use_optimizer:
# TODO(aaroey): if use_optimizer is True we need to get the inference
# graphdef using custom python wrapper class, which is not currently
# supported yet.
continue
- if not dynamic_calib_engine:
+ if not dynamic_engine:
# TODO(aaroey): construction of static calibration engine is not
# supported yet.
continue
- if dynamic_calib_engine and not dynamic_infer_engine:
- # TODO(aaroey): construction of static inference engine using dynamic
- # calibration engine is not supported yet.
- continue
- else: # In non int8 mode.
- if dynamic_calib_engine:
- # dynamic_calib_engine doesn't affect non-int8 modes, so just let
- # related tests run once on dynamic_calib_engine=False.
- continue
conversion = "OptimizerConversion" if use_optimizer else "ToolConversion"
- infer_engine_type = ("DynamicInferEngine"
- if dynamic_infer_engine else "StaticInferEngine")
- calib_engine_type = ""
- if precision_mode == "INT8":
- calib_engine_type = ("DynamicCalibEngine"
- if dynamic_calib_engine else "StaticCalibEngine")
- test_name = "%s_%s_%s%s" % (conversion, precision_mode, infer_engine_type,
- ("_" + calib_engine_type)
- if len(calib_engine_type) else "")
- setattr(
- test_class, "testTfTRT_" + test_name,
- _GetTest(use_optimizer, precision_mode, dynamic_infer_engine,
- dynamic_calib_engine))
+ engine_type = ("DynamicEngine" if dynamic_engine else "StaticEngine")
+ test_name = "%s_%s_%s" % (conversion, precision_mode, engine_type)
+ run_params = RunParams(
+ use_optimizer=use_optimizer,
+ precision_mode=precision_mode,
+ dynamic_engine=dynamic_engine,
+ test_name=test_name)
+ setattr(test_class, "testTfTRT_" + test_name, _GetTest(run_params))
if trt_convert.is_tensorrt_enabled():
diff --git a/tensorflow/contrib/tensorrt/test/unary_test.py b/tensorflow/contrib/tensorrt/test/unary_test.py
index b9e977cf67..500057a36d 100644
--- a/tensorflow/contrib/tensorrt/test/unary_test.py
+++ b/tensorflow/contrib/tensorrt/test/unary_test.py
@@ -100,7 +100,10 @@ class UnaryTest(trt_test.TfTrtIntegrationTestBase):
gdef=g.as_graph_def(),
input_names=[input_name, input2_name],
input_dims=[input_dims, input2_dims],
- num_expected_engines=5,
+ expected_engines=[
+ "my_trt_op_0", "my_trt_op_1", "my_trt_op_2", "my_trt_op_3",
+ "my_trt_op_4"
+ ],
expected_output_dims=(12, 5, 8, 12),
allclose_atol=1.e-03,
allclose_rtol=1.e-03)
diff --git a/tensorflow/contrib/tensorrt/test/utils.cc b/tensorflow/contrib/tensorrt/test/utils.cc
new file mode 100644
index 0000000000..319ddea1b7
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/test/utils.cc
@@ -0,0 +1,101 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/tensorrt/test/utils.h"
+
+#include <unordered_map>
+#include <vector>
+
+#include "re2/re2.h"
+#include "tensorflow/core/platform/macros.h"
+
+namespace tensorflow {
+namespace tensorrt {
+namespace test {
+
+// TODO(aaroey): make this class thread-safe.
+// Process-wide store of label -> value strings used by tests to observe or
+// steer the code under test. Every operation except Enable() is a no-op
+// until Enable() has been called.
+class TestValueManager {
+ public:
+  // Returns the single shared instance. It is created on first use and never
+  // deleted.
+  static TestValueManager* singleton() {
+    static TestValueManager* manager = new TestValueManager();
+    return manager;
+  }
+
+  // Turns the store on; there is no way to turn it off again.
+  void Enable() {
+    VLOG(1) << "Enabling test value";
+    enabled_ = true;
+  }
+
+  // Records `value` under `label`. `value` must not be empty, because Get()
+  // uses the empty string to signal "not found". If `label` already has a
+  // value, the existing one is kept.
+  void Add(const string& label, const string& value) {
+    if (TF_PREDICT_FALSE(enabled_)) {
+      QCHECK_NE("", value);
+      VLOG(1) << "Adding test value: " << label << " -> " << value;
+      values_.insert({label, value});
+    }
+  }
+
+  // Returns the value stored under `label`, or "" when there is none or the
+  // store is not enabled.
+  string Get(const string& label) {
+    if (TF_PREDICT_FALSE(enabled_)) {
+      VLOG(1) << "Getting test value by " << label;
+      auto itr = values_.find(label);
+      if (itr == values_.end()) return "";
+      return itr->second;
+    }
+    return "";
+  }
+
+  // Removes every entry when `pattern` is empty; otherwise removes the
+  // entries whose label fully matches the RE2 regular expression `pattern`.
+  void Clear(const string& pattern) {
+    if (TF_PREDICT_FALSE(enabled_)) {
+      VLOG(1) << "Clearing test values";
+      if (pattern.empty()) {
+        values_.clear();
+        return;
+      }
+      // Compile the pattern once; passing the string to FullMatch would
+      // re-compile it for every stored label.
+      const RE2 matcher(pattern);
+      for (auto itr = values_.begin(); itr != values_.end();) {
+        if (RE2::FullMatch(itr->first, matcher)) {
+          itr = values_.erase(itr);
+        } else {
+          ++itr;
+        }
+      }
+    }
+  }
+
+ private:
+  TestValueManager() : enabled_(false) {}
+
+  bool enabled_;
+  std::unordered_map<string, string> values_;
+};
+
+// Enables the shared TestValueManager.
+void EnableTestValue() { TestValueManager::singleton()->Enable(); }
+
+// Forwards to TestValueManager::Clear on the shared instance.
+void ClearTestValues(const string& pattern) {
+ TestValueManager::singleton()->Clear(pattern);
+}
+
+// Forwards to TestValueManager::Add on the shared instance.
+void AddTestValue(const string& label, const string& value) {
+ TestValueManager::singleton()->Add(label, value);
+}
+
+// Forwards to TestValueManager::Get on the shared instance.
+string GetTestValue(const string& label) {
+ return TestValueManager::singleton()->Get(label);
+}
+
+} // namespace test
+} // namespace tensorrt
+} // namespace tensorflow
diff --git a/tensorflow/contrib/tensorrt/test/utils.h b/tensorflow/contrib/tensorrt/test/utils.h
new file mode 100644
index 0000000000..625cd3d799
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/test/utils.h
@@ -0,0 +1,43 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_TENSORRT_TEST_UTILS_H_
+#define TENSORFLOW_CONTRIB_TENSORRT_TEST_UTILS_H_
+
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+namespace tensorrt {
+namespace test {
+
+// Helper methods to inject values used by testing tools.
+//
+// Until EnableTestValue() has been called, ClearTestValues() and
+// AddTestValue() do nothing and GetTestValue() returns "".
+void EnableTestValue();
+// Removes all stored values when `pattern` is empty, otherwise the values
+// whose label fully matches the regular expression `pattern`.
+void ClearTestValues(const string& pattern);
+// Stores `value` (which must be non-empty) under `label`; a value already
+// stored under `label` is left unchanged.
+void AddTestValue(const string& label, const string& value);
+// Returns the value stored under `label`, or "" if there is none.
+string GetTestValue(const string& label);
+
+// Makes the enclosing function return errors::Internal("Injected manually")
+// when the test value stored under `label` equals `value_to_return`. The
+// enclosing function must therefore return a Status-compatible type.
+// All names are fully qualified and the argument is parenthesized so the
+// macro expands correctly in any namespace and with any expression.
+#define TRT_RETURN_IF_TEST_VALUE(label, value_to_return)          \
+  do {                                                            \
+    if (::tensorflow::tensorrt::test::GetTestValue(label) ==      \
+        (value_to_return)) {                                      \
+      return ::tensorflow::errors::Internal("Injected manually"); \
+    }                                                             \
+  } while (0)
+
+} // namespace test
+} // namespace tensorrt
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CONTRIB_TENSORRT_TEST_UTILS_H_
diff --git a/tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py b/tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py
index 2b134c3bce..ab4d224db4 100644
--- a/tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py
+++ b/tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py
@@ -72,7 +72,7 @@ class VGGBlockNCHWTest(trt_test.TfTrtIntegrationTestBase):
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- num_expected_engines=1,
+ expected_engines=["my_trt_op_0"],
expected_output_dims=(5, 6, 2, 2),
allclose_atol=1.e-03,
allclose_rtol=1.e-03)
diff --git a/tensorflow/contrib/tensorrt/test/vgg_block_test.py b/tensorflow/contrib/tensorrt/test/vgg_block_test.py
index bec2f23eff..56bdf848ea 100644
--- a/tensorflow/contrib/tensorrt/test/vgg_block_test.py
+++ b/tensorflow/contrib/tensorrt/test/vgg_block_test.py
@@ -63,7 +63,7 @@ class VGGBlockTest(trt_test.TfTrtIntegrationTestBase):
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- num_expected_engines=1,
+ expected_engines=["my_trt_op_0"],
expected_output_dims=(5, 2, 2, 6),
allclose_atol=1.e-03,
allclose_rtol=1.e-03)
diff --git a/tensorflow/contrib/tensorrt/trt_conversion.i b/tensorflow/contrib/tensorrt/trt_conversion.i
index 422740fdf6..921c263dfe 100644
--- a/tensorflow/contrib/tensorrt/trt_conversion.i
+++ b/tensorflow/contrib/tensorrt/trt_conversion.i
@@ -101,6 +101,7 @@ _LIST_OUTPUT_TYPEMAP(int, PyLong_FromLong);
#include "tensorflow/core/util/stat_summarizer.h"
#include "tensorflow/contrib/tensorrt/convert/convert_graph.h"
#include "tensorflow/contrib/tensorrt/convert/utils.h"
+#include "tensorflow/contrib/tensorrt/test/utils.h"
%}
%ignoreall
@@ -110,6 +111,10 @@ _LIST_OUTPUT_TYPEMAP(int, PyLong_FromLong);
%unignore get_linked_tensorrt_version;
%unignore get_loaded_tensorrt_version;
%unignore is_tensorrt_enabled;
+%unignore enable_test_value;
+%unignore clear_test_values;
+%unignore add_test_value;
+%unignore get_test_value;
%{
@@ -251,6 +256,22 @@ bool is_tensorrt_enabled() {
return tensorflow::tensorrt::IsGoogleTensorRTEnabled();
}
+// Python-visible wrapper around tensorflow::tensorrt::test::EnableTestValue.
+void enable_test_value() {
+ tensorflow::tensorrt::test::EnableTestValue();
+}
+
+// Python-visible wrapper around tensorflow::tensorrt::test::ClearTestValues.
+void clear_test_values(string pattern) {
+ tensorflow::tensorrt::test::ClearTestValues(pattern);
+}
+
+// Python-visible wrapper around tensorflow::tensorrt::test::AddTestValue.
+void add_test_value(string label, string value) {
+ tensorflow::tensorrt::test::AddTestValue(label, value);
+}
+
+// Python-visible wrapper around tensorflow::tensorrt::test::GetTestValue.
+string get_test_value(string label) {
+ return tensorflow::tensorrt::test::GetTestValue(label);
+}
+
%}
std::pair<string, string> calib_convert(string graph_def_string, bool is_dyn_op);
@@ -266,5 +287,9 @@ std::pair<string, string> trt_convert(string graph_def_string,
version_struct get_linked_tensorrt_version();
version_struct get_loaded_tensorrt_version();
bool is_tensorrt_enabled();
+void enable_test_value();
+void clear_test_values(string pattern);
+void add_test_value(string label, string value);
+string get_test_value(string label);
%unignoreall