-rw-r--r--  configure.py                                                    7
-rw-r--r--  tensorflow/contrib/tensorrt/BUILD                             391
-rw-r--r--  tensorflow/contrib/tensorrt/convert/convert_graph.cc          255
-rw-r--r--  tensorflow/contrib/tensorrt/convert/convert_graph.h            14
-rw-r--r--  tensorflow/contrib/tensorrt/convert/convert_nodes.cc         1114
-rw-r--r--  tensorflow/contrib/tensorrt/convert/convert_nodes.h            56
-rw-r--r--  tensorflow/contrib/tensorrt/kernels/trt_calib_op.cc           111
-rw-r--r--  tensorflow/contrib/tensorrt/kernels/trt_calib_op.h             36
-rw-r--r--  tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc           42
-rw-r--r--  tensorflow/contrib/tensorrt/ops/trt_calib_op.cc                36
-rw-r--r--  tensorflow/contrib/tensorrt/python/__init__.py                  1
-rw-r--r--  tensorflow/contrib/tensorrt/python/trt_convert.py              60
-rw-r--r--  tensorflow/contrib/tensorrt/resources/TRTInt8Calibrator.cc    132
-rw-r--r--  tensorflow/contrib/tensorrt/resources/TRTInt8Calibrator.h      39
-rw-r--r--  tensorflow/contrib/tensorrt/resources/TRTResourceManager.cc    21
-rw-r--r--  tensorflow/contrib/tensorrt/resources/TRTResourceManager.h     37
-rw-r--r--  tensorflow/contrib/tensorrt/resources/TRTResources.h           59
-rw-r--r--  tensorflow/contrib/tensorrt/trt_conversion.i                   62
18 files changed, 2040 insertions, 433 deletions
diff --git a/configure.py b/configure.py
index 3aa1a3e956..68c9bbfb1c 100644
--- a/configure.py
+++ b/configure.py
@@ -1050,12 +1050,15 @@ def set_tf_tensorrt_install_path(environ_cp):
cuda_ver = convert_version_to_int(environ_cp['TF_CUDA_VERSION'])
cudnn_ver = convert_version_to_int(environ_cp['TF_CUDNN_VERSION'])
- nvinfer_pattern = re.compile('.*libnvinfer.so.?(.*)$')
+ nvinfer_pattern = re.compile('.*libnvinfer(?:_debug)?.so.?(.*)$')
highest_ver = [0, None, None]
for lib_file in possible_files:
if is_compatible(lib_file, cuda_ver, cudnn_ver):
- ver_str = nvinfer_pattern.search(lib_file).group(1)
+      matches = nvinfer_pattern.search(lib_file)
+      if not matches:
+        continue
+ ver_str = matches.group(1)
ver = convert_version_to_int(ver_str) if len(ver_str) else 0
if ver > highest_ver[0]:
highest_ver = [ver, ver_str, lib_file]
diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD
index cf67c27b70..dd83c34dfb 100644
--- a/tensorflow/contrib/tensorrt/BUILD
+++ b/tensorflow/contrib/tensorrt/BUILD
@@ -3,244 +3,279 @@
# and provide TensorRT operators and converter package.
# APIs are meant to change over time.
-package(default_visibility = ["//tensorflow:__subpackages__"])
+package(default_visibility=["//tensorflow:__subpackages__"])
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
load(
- "//tensorflow:tensorflow.bzl",
- "tf_cc_test",
- "tf_copts",
- "tf_cuda_library",
- "tf_custom_op_library",
- "tf_custom_op_library_additional_deps",
- "tf_gen_op_libs",
- "tf_gen_op_wrapper_py",
+ "//tensorflow:tensorflow.bzl",
+ "tf_cc_test",
+ "tf_copts",
+ "tf_cuda_library",
+ "tf_custom_op_library",
+ "tf_custom_op_library_additional_deps",
+ "tf_gen_op_libs",
+ "tf_gen_op_wrapper_py",
)
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")
load(
- "@local_config_tensorrt//:build_defs.bzl",
- "if_tensorrt",
+ "@local_config_tensorrt//:build_defs.bzl",
+ "if_tensorrt",
)
tf_cuda_cc_test(
- name = "tensorrt_test_cc",
- size = "small",
- srcs = ["tensorrt_test.cc"],
- tags = [
- "manual",
- "notap",
- ],
- deps = [
- "//tensorflow/core:lib",
- "//tensorflow/core:test",
- "//tensorflow/core:test_main",
- ] + if_tensorrt([
- "@local_config_cuda//cuda:cuda_headers",
- "@local_config_tensorrt//:nv_infer",
- ]),
+ name="tensorrt_test_cc",
+ size="small",
+ srcs=["tensorrt_test.cc"],
+ tags=[
+ "manual",
+ "notap",
+ ],
+ deps=[
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ] + if_tensorrt([
+ "@local_config_cuda//cuda:cuda_headers",
+ "@local_config_tensorrt//:nv_infer",
+ ]),
)
tf_custom_op_library(
- name = "python/ops/_trt_engine_op.so",
- srcs = ["ops/trt_engine_op.cc"],
- deps = [
- ":trt_engine_op_kernel",
- ":trt_shape_function",
- "//tensorflow/core:lib_proto_parsing",
- ] + if_tensorrt([
- "@local_config_tensorrt//:nv_infer",
- ]),
+ name="python/ops/_trt_engine_op.so",
+ srcs=[
+ "ops/trt_calib_op.cc",
+ "ops/trt_engine_op.cc",
+ ],
+ deps=[
+ ":trt_engine_op_kernel",
+ ":trt_shape_function",
+ "//tensorflow/core:lib_proto_parsing",
+ ] + if_tensorrt([
+ "@local_config_tensorrt//:nv_infer",
+ ]),
)
tf_cuda_library(
- name = "trt_shape_function",
- srcs = ["shape_fn/trt_shfn.cc"],
- hdrs = ["shape_fn/trt_shfn.h"],
- visibility = ["//visibility:public"],
- deps = [
- ":trt_logging",
- ] + if_tensorrt([
- "@local_config_tensorrt//:nv_infer",
- ]) + tf_custom_op_library_additional_deps(),
+ name="trt_shape_function",
+ srcs=["shape_fn/trt_shfn.cc"],
+ hdrs=["shape_fn/trt_shfn.h"],
+ visibility=["//visibility:public"],
+ deps=[
+ ":trt_logging",
+ ] + if_tensorrt([
+ "@local_config_tensorrt//:nv_infer",
+ ]) + tf_custom_op_library_additional_deps(),
)
cc_library(
- name = "trt_engine_op_kernel",
- srcs = ["kernels/trt_engine_op.cc"],
- hdrs = ["kernels/trt_engine_op.h"],
- copts = tf_copts(),
- deps = [
- ":trt_logging",
- "//tensorflow/core:gpu_headers_lib",
- "//tensorflow/core:lib_proto_parsing",
- "//tensorflow/core:stream_executor_headers_lib",
- ] + if_tensorrt([
- "@local_config_tensorrt//:nv_infer",
- ]) + tf_custom_op_library_additional_deps(),
- alwayslink = 1,
+ name="trt_engine_op_kernel",
+ srcs=[
+ "kernels/trt_calib_op.cc",
+ "kernels/trt_engine_op.cc",
+ ],
+ hdrs=[
+ "kernels/trt_calib_op.h",
+ "kernels/trt_engine_op.h",
+ ],
+ copts=tf_copts(),
+ deps=[
+ ":trt_logging",
+ ":trt_resources",
+ "//tensorflow/core:gpu_headers_lib",
+ "//tensorflow/core:lib_proto_parsing",
+ "//tensorflow/core:stream_executor_headers_lib",
+ ] + if_tensorrt([
+ "@local_config_tensorrt//:nv_infer",
+ ]) + tf_custom_op_library_additional_deps(),
+ alwayslink=1,
)
tf_gen_op_libs(
- op_lib_names = ["trt_engine_op"],
- deps = if_tensorrt([
- "@local_config_tensorrt//:nv_infer",
- ]),
+ op_lib_names=[
+ "trt_engine_op",
+ "trt_calib_op",
+ ],
+ deps=if_tensorrt([
+ "@local_config_tensorrt//:nv_infer",
+ ]),
)
tf_cuda_library(
- name = "trt_logging",
- srcs = ["log/trt_logger.cc"],
- hdrs = ["log/trt_logger.h"],
- visibility = ["//visibility:public"],
- deps = [
- "//tensorflow/core:lib_proto_parsing",
- ] + if_tensorrt([
- "@local_config_tensorrt//:nv_infer",
- ]),
+ name="trt_logging",
+ srcs=["log/trt_logger.cc"],
+ hdrs=["log/trt_logger.h"],
+ visibility=["//visibility:public"],
+ deps=[
+ "//tensorflow/core:lib_proto_parsing",
+ ] + if_tensorrt([
+ "@local_config_tensorrt//:nv_infer",
+ ]),
)
tf_gen_op_wrapper_py(
- name = "trt_engine_op",
- deps = [
- ":trt_engine_op_op_lib",
- ":trt_logging",
- ":trt_shape_function",
- ],
+ name="trt_engine_op",
+ deps=[
+ ":trt_engine_op_op_lib",
+ ":trt_calib_op_op_lib",
+ ":trt_logging",
+ ":trt_shape_function",
+ ],
)
tf_custom_op_py_library(
- name = "trt_engine_op_loader",
- srcs = ["python/ops/trt_engine_op.py"],
- dso = [
+ name="trt_engine_op_loader",
+ srcs=["python/ops/trt_engine_op.py"],
+ dso=[
":python/ops/_trt_engine_op.so",
- ] + if_tensorrt([
- "@local_config_tensorrt//:nv_infer",
- ]),
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/python:framework_for_generated_wrappers",
- "//tensorflow/python:resources",
- ],
+ ] + if_tensorrt([
+ "@local_config_tensorrt//:nv_infer",
+ ]),
+ srcs_version="PY2AND3",
+ deps=[
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:resources",
+ ],
)
py_library(
- name = "init_py",
- srcs = [
- "__init__.py",
- "python/__init__.py",
- ],
- srcs_version = "PY2AND3",
- deps = [
- ":trt_convert_py",
- ":trt_ops_py",
- ],
+ name="init_py",
+ srcs=[
+ "__init__.py",
+ "python/__init__.py",
+ ],
+ srcs_version="PY2AND3",
+ deps=[
+ ":trt_convert_py",
+ ":trt_ops_py",
+ ],
)
py_library(
- name = "trt_ops_py",
- srcs_version = "PY2AND3",
- deps = [
- ":trt_engine_op",
- ":trt_engine_op_loader",
- ],
+ name="trt_ops_py",
+ srcs_version="PY2AND3",
+ deps=[
+ ":trt_engine_op",
+ ":trt_engine_op_loader",
+ ],
)
py_library(
- name = "trt_convert_py",
- srcs = ["python/trt_convert.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":wrap_conversion",
- ],
+ name="trt_convert_py",
+ srcs=["python/trt_convert.py"],
+ srcs_version="PY2AND3",
+ deps=[
+ ":wrap_conversion",
+ ],
)
tf_py_wrap_cc(
- name = "wrap_conversion",
- srcs = ["trt_conversion.i"],
- copts = tf_copts(),
- deps = [
- ":trt_conversion",
- "//tensorflow/core:framework_lite",
- "//util/python:python_headers",
- ],
+ name="wrap_conversion",
+ srcs=["trt_conversion.i"],
+ copts=tf_copts(),
+ deps=[
+ ":trt_conversion",
+ "//tensorflow/core:framework_lite",
+ "//util/python:python_headers",
+ ],
+)
+
+tf_cuda_library(
+ name="trt_resources",
+ srcs=[
+ "resources/TRTInt8Calibrator.cc",
+ "resources/TRTResourceManager.cc",
+ ],
+ hdrs=[
+ "resources/TRTInt8Calibrator.h",
+ "resources/TRTResourceManager.h",
+ "resources/TRTResources.h",
+ ],
+ deps=[
+ "@local_config_tensorrt//:nv_infer",
+ "//tensorflow/core:framework_headers_lib",
+ "//tensorflow/core:framework_lite",
+ "//tensorflow/core:lib_proto_parsing",
+
+ ],
)
# Library for the node-level conversion portion of TensorRT operation creation
tf_cuda_library(
- name = "trt_conversion",
- srcs = [
- "convert/convert_graph.cc",
- "convert/convert_nodes.cc",
- ],
- hdrs = [
- "convert/convert_graph.h",
- "convert/convert_nodes.h",
- ],
- deps = [
- ":segment",
- ":trt_logging",
- "//tensorflow/core/grappler:grappler_item",
- "//tensorflow/core/grappler:utils",
- "//tensorflow/core:framework",
- "//tensorflow/core:framework_lite",
- "//tensorflow/core:graph",
- "//tensorflow/core:lib",
- "//tensorflow/core:lib_internal",
- "//tensorflow/core:protos_all_cc",
- "//tensorflow/core/grappler:devices",
- "//tensorflow/core/grappler/clusters:virtual_cluster",
- "//tensorflow/core/grappler/costs:graph_properties",
- "//tensorflow/core/grappler/optimizers:constant_folding",
- "//tensorflow/core/grappler/optimizers:layout_optimizer",
- ] + if_tensorrt([
- "@local_config_tensorrt//:nv_infer",
- ]) + tf_custom_op_library_additional_deps(),
+ name="trt_conversion",
+ srcs=[
+ "convert/convert_graph.cc",
+ "convert/convert_nodes.cc",
+ ],
+ hdrs=[
+ "convert/convert_graph.h",
+ "convert/convert_nodes.h",
+ ],
+ deps=[
+ ":segment",
+ ":trt_logging",
+ ":trt_resources",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:utils",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_lite",
+ "//tensorflow/core:graph",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core/grappler:devices",
+ "//tensorflow/core/grappler/clusters:virtual_cluster",
+ "//tensorflow/core/grappler/costs:graph_properties",
+ "//tensorflow/core/grappler/optimizers:constant_folding",
+ "//tensorflow/core/grappler/optimizers:layout_optimizer",
+ ] + if_tensorrt([
+ "@local_config_tensorrt//:nv_infer",
+ ]) + tf_custom_op_library_additional_deps(),
)
# Library for the segmenting portion of TensorRT operation creation
cc_library(
- name = "segment",
- srcs = ["segment/segment.cc"],
- hdrs = [
- "segment/segment.h",
- "segment/union_find.h",
- ],
- linkstatic = 1,
- deps = [
- "//tensorflow/core:graph",
- "//tensorflow/core:lib_proto_parsing",
- "//tensorflow/core:protos_all_cc",
- "@protobuf_archive//:protobuf_headers",
- ],
+ name="segment",
+ srcs=["segment/segment.cc"],
+ hdrs=[
+ "segment/segment.h",
+ "segment/union_find.h",
+ ],
+ linkstatic=1,
+ deps=[
+ "//tensorflow/core:graph",
+ "//tensorflow/core:lib_proto_parsing",
+ "//tensorflow/core:protos_all_cc",
+ "@protobuf_archive//:protobuf_headers",
+ ],
)
tf_cc_test(
- name = "segment_test",
- size = "small",
- srcs = ["segment/segment_test.cc"],
- deps = [
- ":segment",
- "//tensorflow/c:c_api",
- "//tensorflow/core:lib",
- "//tensorflow/core:protos_all_cc",
- "//tensorflow/core:test",
- "//tensorflow/core:test_main",
- ],
+ name="segment_test",
+ size="small",
+ srcs=["segment/segment_test.cc"],
+ deps=[
+ ":segment",
+ "//tensorflow/c:c_api",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
)
filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
+ name="all_files",
+ srcs=glob(
+ ["**/*"],
+ exclude=[
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility=["//tensorflow:__subpackages__"],
)
diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
index 899448004f..31ba30b2d9 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/contrib/tensorrt/convert/convert_graph.h"
+#include <list>
#include <map>
#include <set>
#include <unordered_map>
@@ -39,8 +40,8 @@ limitations under the License.
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/device_properties.pb.h"
-#if GOOGLE_CUDA
-#if GOOGLE_TENSORRT
+//#if GOOGLE_CUDA
+//#if GOOGLE_TENSORRT
#include "tensorrt/include/NvInfer.h"
namespace tensorflow {
@@ -48,13 +49,28 @@ namespace tensorrt {
namespace convert {
namespace {
-static bool IsTensorRTCandidate(const tensorflow::NodeDef& node_def) {
+bool IsTensorRTCandidate(const tensorflow::NodeDef& node_def) {
// LINT.IfChange
// TODO(jie): Segmentation shouldn't associated with op name.
// Split it into a registration for each kernel.
static const std::set<string> candidate_ops = {
- "Identity", "Const", "Conv2D", "MaxPool", "BiasAdd", "Relu",
- "Add", "Mul", "Sub", "Rsqrt", "Pad" // "Placeholder" ,"Mean"
+ "Identity",
+ "Const",
+ "Conv2D",
+ "MaxPool",
+ "BiasAdd",
+ "Relu",
+ "Add",
+ "Mul",
+ "Sub",
+ "Rsqrt",
+ "Pad",
+ "Mean",
+ "AvgPool",
+ "ConcatV2",
+ "DepthwiseConv2dNative" //, "MatMul",
+ //"Reshape"
+ // TODO(ben,jie): ...
};
// LINT.ThenChange(//tensorflow/contrib/tensorrt/convert/convert_nodes.h)
return candidate_ops.count(node_def.op());
@@ -69,6 +85,8 @@ void GetSubGraphIncomingEdges(const tensorflow::Graph& graph,
if (!subgraph_node_ids.count(edge->src()->id()) &&
!edge->src()->IsSource()) {
incoming_edges->insert(edge);
+ } else {
+ VLOG(2) << edge->src()->name() << " N, ";
}
}
}
@@ -82,7 +100,10 @@ void GetSubGraphOutgoingEdges(const tensorflow::Graph& graph,
for (const tensorflow::Edge* edge : node->out_edges()) {
if (!subgraph_node_ids.count(edge->dst()->id()) &&
!edge->dst()->IsSink()) {
+ VLOG(2) << edge->dst()->name() << " Y, ";
outgoing_edges->insert(edge);
+ } else {
+ VLOG(2) << edge->dst()->name() << " N, ";
}
}
}
@@ -110,73 +131,147 @@ std::unordered_map<string, std::vector<int>> BuildTensorNameMap(
return result;
}
-tensorflow::Status ConvertSubGraphToTensorRT(
- const std::vector<string>& output_names,
- const std::set<int>& subgraph_node_ids,
- size_t max_batch_size, // Max batch size that engine will be created for
- // Max amount of memory that engine will be allowed to consume, in bytes
- size_t max_workspace_size_bytes,
- const tensorflow::grappler::GraphProperties& graph_properties,
- tensorflow::Graph* graph) {
- tensorflow::EdgeSet subgraph_incoming_edges;
- GetSubGraphIncomingEdges(*graph, subgraph_node_ids, &subgraph_incoming_edges);
-
+struct ConvertGraphParams {
+ ConvertGraphParams(
+ tensorflow::Graph& graph_, const std::vector<string>& output_names_,
+ const std::set<int>& subgraph_node_ids_, size_t max_batch_size_,
+ size_t max_workspace_size_bytes_,
+ const tensorflow::grappler::GraphProperties& graph_properties_,
+ std::unordered_map<string, std::pair<int, string>>*
+ output_edge_map_,
+ int precision_mode_)
+ : graph(graph_),
+ output_names(output_names_),
+ subgraph_node_ids(subgraph_node_ids_),
+ max_batch_size(max_batch_size_),
+ max_workspace_size_bytes(max_workspace_size_bytes_),
+ graph_properties(graph_properties_),
+ output_edge_map(output_edge_map_),
+ precision_mode(precision_mode_) {}
+ tensorflow::Graph& graph;
+ const std::vector<string>& output_names;
+ const std::set<int>& subgraph_node_ids;
+ size_t max_batch_size;
+ size_t max_workspace_size_bytes;
+ const tensorflow::grappler::GraphProperties& graph_properties;
+ std::unordered_map<string, std::pair<int, string>>* output_edge_map;
+ int precision_mode;
std::vector<std::pair<int, int>> subgraph_inputs;
+ std::vector<std::pair<int, int>> subgraph_outputs;
+ tensorflow::EdgeSet subgraph_incoming_edges;
+ tensorflow::EdgeSet subgraph_outgoing_edges;
+};
- // Collect inputs by looking for incoming edges
- for (const tensorflow::Edge* edge : subgraph_incoming_edges) {
- subgraph_inputs.push_back({edge->src()->id(), edge->src_output()});
+tensorflow::Status FillSubGraphEdgeSets(ConvertGraphParams& p) {
+ GetSubGraphIncomingEdges(p.graph, p.subgraph_node_ids,
+ &p.subgraph_incoming_edges);
+ for (tensorflow::Edge const* edge : p.subgraph_incoming_edges) {
+ p.subgraph_inputs.push_back({edge->src()->id(), edge->src_output()});
}
+ auto output_name_to_index_map = BuildTensorNameMap(p.output_names);
std::set<std::pair<int, int>> subgraph_outputs_set;
- // Collect outputs referenced from output_names
- auto output_name_to_index_map = BuildTensorNameMap(output_names);
- for (int node_id : subgraph_node_ids) {
- tensorflow::Node* node = graph->FindNodeId(node_id);
+ for (int node_id : p.subgraph_node_ids) {
+ tensorflow::Node* node = p.graph.FindNodeId(node_id);
if (output_name_to_index_map.count(node->name())) {
for (int index : output_name_to_index_map.at(node->name())) {
subgraph_outputs_set.insert({node_id, index});
}
}
}
- // Collect outputs referenced from outgoing edges
- tensorflow::EdgeSet subgraph_outgoing_edges;
- GetSubGraphOutgoingEdges(*graph, subgraph_node_ids, &subgraph_outgoing_edges);
- for (const tensorflow::Edge* edge : subgraph_outgoing_edges) {
+ GetSubGraphOutgoingEdges(p.graph, p.subgraph_node_ids,
+ &p.subgraph_outgoing_edges);
+ for (const tensorflow::Edge *edge : p.subgraph_outgoing_edges) {
subgraph_outputs_set.insert({edge->src()->id(), edge->src_output()});
}
- // Impose an ordering on the outputs
- std::vector<std::pair<int, int>> subgraph_outputs(
- subgraph_outputs_set.begin(), subgraph_outputs_set.end());
- // Build TensorRT node and add it to the graph
+ p.subgraph_outputs.reserve(subgraph_outputs_set.size());
+ p.subgraph_outputs.insert(p.subgraph_outputs.begin(),
+ subgraph_outputs_set.begin(),
+ subgraph_outputs_set.end());
+ return tensorflow::Status::OK();
+};
+
+tensorflow::Status GetCalibNode(ConvertGraphParams* params) {
+ FillSubGraphEdgeSets(*params);
+ tensorflow::NodeDef trt_node_def;
+ SubGraphParams s(params->graph, params->subgraph_node_ids,
+ params->subgraph_inputs, params->subgraph_outputs,
+ params->max_batch_size, params->max_workspace_size_bytes,
+ params->graph_properties, params->output_edge_map,
+ &trt_node_def, params->precision_mode);
+ TF_RETURN_IF_ERROR(InjectCalibrationNode(s));
+ tensorflow::Status status;
+ tensorflow::Node* trt_node = params->graph.AddNode(trt_node_def, &status);
+
+ TF_RETURN_IF_ERROR(status);
+
+ for (auto in_edge :
+ params->subgraph_incoming_edges) { // loop over incoming edges and
+ // attach them to calib node
+ // tensorflow::Node* src_node = in_edge->src();
+ auto src_output = in_edge->src_output();
+ auto dst_node = in_edge->dst();
+ auto dst_input = in_edge->dst_input();
+ VLOG(1) << " update edge " << trt_node->name() << ":" << src_output
+ << " -> " << dst_node->name() << ":" << dst_input;
+ params->graph.UpdateEdge(trt_node, src_output, dst_node, dst_input);
+ }
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status ConvertSubGraphToTensorRT(ConvertGraphParams* params) {
+ FillSubGraphEdgeSets(*params);
tensorflow::NodeDef trt_node_def;
- TF_RETURN_IF_ERROR(ConvertSubGraphToTensorRTNodeDef(
- *graph, subgraph_node_ids, subgraph_inputs, subgraph_outputs,
- max_batch_size, max_workspace_size_bytes, graph_properties,
- &trt_node_def));
+
+ SubGraphParams s(params->graph, params->subgraph_node_ids,
+ params->subgraph_inputs, params->subgraph_outputs,
+ params->max_batch_size, params->max_workspace_size_bytes,
+ params->graph_properties, params->output_edge_map,
+ &trt_node_def,params->precision_mode);
+ TF_RETURN_IF_ERROR(ConvertSubGraphToTensorRTNodeDef(s));
tensorflow::Status status;
- tensorflow::Node* trt_node = graph->AddNode(trt_node_def, &status);
+ tensorflow::Node* trt_node = params->graph.AddNode(trt_node_def, &status);
+
+ // AddNode does not wire edges.
+ // Re-map incoming edges to use the new TRT node instead of the orig subgraph
+ std::map<std::pair<int, int>, int> subgraph_edge_to_input_map;
+ for (size_t i = 0; i < params->subgraph_inputs.size(); ++i) {
+ subgraph_edge_to_input_map.insert({params->subgraph_inputs.at(i), i});
+ }
+ for (const tensorflow::Edge* edge : params->subgraph_incoming_edges) {
+ std::pair<int, int> old_src = {edge->src()->id(), edge->src_output()};
+ int new_src_output = subgraph_edge_to_input_map.at(old_src);
+ params->graph.AddEdge(edge->src(), edge->src_output(), trt_node,
+ new_src_output);
+ params->graph.RemoveEdge(edge);
+ }
+
+ VLOG(2) << "new wiring edges: " << trt_node->in_edges().size();
+ for (const tensorflow::Edge* edge : trt_node->in_edges()) {
+ VLOG(2) << edge->src()->name() << " port: " << edge->src_output();
+ }
+
TF_RETURN_IF_ERROR(status);
// Re-map outgoing edges to use the new TRT node instead of the orig subgraph
std::map<std::pair<int, int>, int> subgraph_edge_to_output_map;
- for (size_t i = 0; i < subgraph_outputs.size(); ++i) {
- subgraph_edge_to_output_map.insert({subgraph_outputs.at(i), i});
+ for (size_t i = 0; i < params->subgraph_outputs.size(); ++i) {
+ subgraph_edge_to_output_map.insert({params->subgraph_outputs.at(i), i});
}
TF_RETURN_IF_ERROR(status);
- for (const tensorflow::Edge* edge : subgraph_outgoing_edges) {
+ for (const tensorflow::Edge* edge : params->subgraph_outgoing_edges) {
std::pair<int, int> old_src = {edge->src()->id(), edge->src_output()};
int new_src_output = subgraph_edge_to_output_map.at(old_src);
- TF_RETURN_IF_ERROR(graph->UpdateEdge(trt_node, new_src_output, edge->dst(),
- edge->dst_input()));
+ TF_RETURN_IF_ERROR(params->graph.UpdateEdge(
+ trt_node, new_src_output, edge->dst(), edge->dst_input()));
}
// Remove the original subgraph
- for (int node_id : subgraph_node_ids) {
- tensorflow::Node* node = graph->FindNodeId(node_id);
+ for (int node_id : params->subgraph_node_ids) {
+ tensorflow::Node* node = params->graph.FindNodeId(node_id);
// Don't remove the input placeholders
if (node->type_string() == "Placeholder") {
continue;
}
- graph->RemoveNode(node);
+ params->graph.RemoveNode(node);
}
return tensorflow::Status::OK();
}
@@ -194,31 +289,64 @@ tensorflow::Status BuildNodeMap(
}
} // namespace
+tensorflow::Status ConvertCalibGraphToInferGraph(
+ const tensorflow::GraphDef& graph_def, tensorflow::GraphDef* infer_graph) {
+ VLOG(0) << "Starting Calib Conversion";
+ tensorflow::Graph graph(tensorflow::OpRegistry::Global());
+ TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToGraph(
+ tensorflow::GraphConstructorOptions(), graph_def, &graph));
+ // get calib nodes
+ std::vector<tensorflow::Node*> calibNodes;
+ for (auto node : graph.op_nodes()) {
+ if (node->type_string() == "TRTCalibOp") {
+ VLOG(1) << "Found Calib Node";
+ calibNodes.push_back(node);
+ }
+ }
+ VLOG(0) << "Num Calib nodes in graph= " << calibNodes.size();
+ if (calibNodes.size() == 0)
+ return tensorflow::errors::FailedPrecondition(
+ "Graph doesn't contain any calibration nodes!."
+ " Please generate calibration graph and run calibration first");
+ for (auto n : calibNodes) {
+ TF_RETURN_IF_ERROR(
+ tensorrt::convert::ConvertCalibrationNodeToEngineNode(graph, n));
+ }
+ return tensorflow::Status::OK();
+}
tensorflow::Status ConvertGraphDefToTensorRT(
const tensorflow::GraphDef& graph_def,
const std::vector<string>& output_names, size_t max_batch_size,
- size_t max_workspace_size_bytes, tensorflow::GraphDef* new_graph_def) {
- // Optimization pass
+ size_t max_workspace_size, tensorflow::GraphDef* new_graph_def,
+ int precision_mode = 0) {
+ // optimization pass
tensorflow::grappler::GrapplerItem item;
item.fetch = output_names;
tensorflow::GraphDef gdef;
- // Layout optimization
+ // layout optimization
item.graph = graph_def;
tensorflow::grappler::LayoutOptimizer optimizer;
- tensorflow::grappler::Cluster* cluster;
+ tensorflow::grappler::Cluster* gCluster;
- // Virtual cluster
+ // virtual cluster
tensorflow::DeviceProperties device_properties;
+
device_properties.set_type("GPU");
device_properties.mutable_environment()->insert({"architecture", "6"});
- cluster =
+ gCluster =
new tensorflow::grappler::VirtualCluster({{"/GPU:0", device_properties}});
- TF_RETURN_IF_ERROR(optimizer.Optimize(cluster, item, &gdef));
+ // single machine
+ int num_cpu_cores = tensorflow::grappler::GetNumAvailableLogicalCPUCores();
+ int num_gpus = tensorflow::grappler::GetNumAvailableGPUs();
+ VLOG(2) << "cpu_cores: " << num_cpu_cores;
+ VLOG(2) << "gpus: " << num_gpus;
+
+ TF_RETURN_IF_ERROR(optimizer.Optimize(gCluster, item, &gdef));
- // Constant folding
+ // constant folding
item.graph = gdef;
tensorflow::grappler::ConstantFolding fold(nullptr);
TF_RETURN_IF_ERROR(fold.Optimize(nullptr, item, &gdef));
@@ -252,14 +380,27 @@ tensorflow::Status ConvertGraphDefToTensorRT(
}
std::unordered_map<string, tensorflow::Node*> node_map;
TF_RETURN_IF_ERROR(BuildNodeMap(graph, &node_map));
+ std::unordered_map<string, std::pair<int, string>> output_edge_map;
+ int count = 0;
for (const std::set<string>& subgraph_node_names : segments) {
std::set<int> subgraph_node_ids;
for (const string& node_name : subgraph_node_names) {
subgraph_node_ids.insert(node_map.at(node_name)->id());
}
- TF_RETURN_IF_ERROR(ConvertSubGraphToTensorRT(
- output_names, subgraph_node_ids, max_batch_size,
- max_workspace_size_bytes, static_graph_properties, &graph));
+ ConvertGraphParams p(graph, output_names, subgraph_node_ids, max_batch_size,
+ max_workspace_size, static_graph_properties,
+ &output_edge_map, precision_mode);
+ if (precision_mode == 2) {
+ TF_RETURN_IF_ERROR(GetCalibNode(&p));
+ } else {
+ tensorflow::Status status = ConvertSubGraphToTensorRT(&p);
+ if (status != tensorflow::Status::OK()) {
+ LOG(WARNING) << "subgraph conversion error for subgraph_index:" << count
+ << " due to: \n"
+ << status.ToString() << "SKIPPING......";
+ }
+ count++;
+ }
}
graph.ToGraphDef(new_graph_def);
return tensorflow::Status::OK();
@@ -269,5 +410,5 @@ tensorflow::Status ConvertGraphDefToTensorRT(
} // namespace tensorrt
} // namespace tensorflow
-#endif // GOOGLE_TENSORRT
-#endif // GOOGLE_CUDA
+//#endif // GOOGLE_TENSORRT
+//#endif // GOOGLE_CUDA
diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.h b/tensorflow/contrib/tensorrt/convert/convert_graph.h
index 154ad3f2e8..846d7f2721 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_graph.h
+++ b/tensorflow/contrib/tensorrt/convert/convert_graph.h
@@ -21,12 +21,14 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"
-#if GOOGLE_CUDA
-#if GOOGLE_TENSORRT
+//#if GOOGLE_CUDA
+//#if GOOGLE_TENSORRT
namespace tensorflow {
namespace tensorrt {
namespace convert {
+tensorflow::Status ConvertCalibGraphToInferGraph(
+ const tensorflow::GraphDef& graph_def, tensorflow::GraphDef* new_graph_def);
// max_batch_size: maximum batch size which can be used for inference for
// optimization targets inference run with max batch size.
@@ -35,13 +37,13 @@ namespace convert {
tensorflow::Status ConvertGraphDefToTensorRT(
const tensorflow::GraphDef& graph_def,
const std::vector<string>& output_names, size_t max_batch_size,
- size_t max_workspace_size_bytes, tensorflow::GraphDef* new_graph_def);
-
+ size_t max_workspace_size_bytes, tensorflow::GraphDef* new_graph_def,
+ int precision_mode);
} // namespace convert
} // namespace tensorrt
} // namespace tensorflow
-#endif // GOOGLE_TENSORRT
-#endif // GOOGLE_CUDA
+//#endif // GOOGLE_TENSORRT
+//#endif // GOOGLE_CUDA
#endif // TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_GRAPH_H_
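For reference, a minimal C++ sketch of how the two entry points declared above might be driven end to end. The fetch name, batch size, and workspace size are placeholders, not part of this change; precision_mode == 2 selects the calibration path, per convert_graph.cc above.

// Hedged usage sketch (not part of the patch): build a calibration graph,
// run it on representative data elsewhere, then turn it into an inference graph.
#include <string>
#include <vector>
#include "tensorflow/contrib/tensorrt/convert/convert_graph.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"

tensorflow::Status BuildInt8Graphs(const tensorflow::GraphDef& frozen_graph,
                                   tensorflow::GraphDef* calib_graph,
                                   tensorflow::GraphDef* infer_graph) {
  const std::vector<std::string> outputs = {"logits"};  // placeholder fetch name
  // precision_mode == 2 wraps each TensorRT segment in a TRTCalibOp
  // (see GetCalibNode in convert_graph.cc) instead of a TRTEngineOp.
  TF_RETURN_IF_ERROR(tensorflow::tensorrt::convert::ConvertGraphDefToTensorRT(
      frozen_graph, outputs, /*max_batch_size=*/8,
      /*max_workspace_size_bytes=*/1 << 30, calib_graph,
      /*precision_mode=*/2));
  // ... run calib_graph on calibration data here so the TRTCalibOp kernels
  // can feed the INT8 calibrator ...
  return tensorflow::tensorrt::convert::ConvertCalibGraphToInferGraph(
      *calib_graph, infer_graph);
}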
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
index 9ee717dd7f..ea0eb480f2 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
@@ -24,6 +24,11 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
+#include "tensorflow/contrib/tensorrt/resources/TRTResourceManager.h"
+#include "tensorflow/contrib/tensorrt/resources/TRTResources.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor_shape.pb.h" // NOLINT
#include "tensorflow/core/framework/types.h"
@@ -32,14 +37,15 @@ limitations under the License.
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/tensor_coding.h"
#include "tensorflow/core/platform/types.h"
-#if GOOGLE_CUDA
-#if GOOGLE_TENSORRT
-#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
+//#if GOOGLE_CUDA
+//#if GOOGLE_TENSORRT
+// #include "tensorflow/contrib/tensorrt/log/trt_logger.h"
#include "tensorrt/include/NvInfer.h"
// Check if the types are equal. Cast to int first so that failure log message
@@ -245,6 +251,11 @@ std::vector<int> TFAttrs::get<std::vector<int>>(string key) const {
}
template <>
+std::vector<string> TFAttrs::get<std::vector<string>>(string key) const {
+ auto attr = this->at(key)->list().s();
+ return std::vector<string>(attr.begin(), attr.end());
+}
+template <>
nvinfer1::Dims TFAttrs::get<nvinfer1::Dims>(string key) const {
auto values = this->get<std::vector<int>>(key);
nvinfer1::Dims dims;
@@ -265,7 +276,7 @@ template <>
tensorflow::DataType TFAttrs::get<tensorflow::DataType>(string key) const {
return this->at(key)->type();
}
-
+// TODO(jie): reorder4 & reorder2 should be merged?
template <typename T>
void Reorder4(nvinfer1::DimsNCHW shape, const T* idata,
nvinfer1::DimsNCHW istrides, T* odata,
@@ -283,16 +294,53 @@ void Reorder4(nvinfer1::DimsNCHW shape, const T* idata,
}
}
+
+template <typename T>
+void reorder2(nvinfer1::DimsHW shape, T const* idata, nvinfer1::DimsHW istrides,
+ T* odata, nvinfer1::DimsHW ostrides) {
+ for (int h = 0; h < shape.h(); ++h) {
+ for (int w = 0; w < shape.w(); ++w) {
+ odata[h * ostrides.h() + w * ostrides.w()] =
+          idata[h * istrides.h() + w * istrides.w()];
+ }
+ }
+}
+
+// TODO(jie): fail to tensorflow!!
+void reorder_ck_to_kc(TRT_ShapedWeights const& iweights,
+ TRT_ShapedWeights* oweights) {
+ int c = iweights.shape_.d[0];
+ int k = iweights.shape_.d[1];
+ oweights->shape_.d[0] = k;
+ oweights->shape_.d[1] = c;
+ nvinfer1::DimsHW istrides = {1, k};
+ nvinfer1::DimsHW ostrides = {c, 1};
+ switch (iweights.type_) {
+ case tensorflow::DataType::DT_FLOAT:
+ reorder2({k, c}, static_cast<float const*>(iweights.GetValues()), istrides,
+ static_cast<float*>(const_cast<void*>(oweights->GetValues())),
+ ostrides);
+ break;
+ default:
+ LOG(FATAL) << "!!!!!!!!!!!!!!!!!!!!!!!!broke!!!!!!!!!!!!";
+ }
+}
+
void ReorderRSCKToKCRS(const TRT_ShapedWeights& iweights,
- TRT_ShapedWeights* oweights) {
+ TRT_ShapedWeights* oweights, int nbGroups) {
CHECK_EQ(iweights.type_, oweights->type_);
CHECK_EQ(iweights.size_bytes(), oweights->size_bytes());
int r = iweights.shape_.d[0];
int s = iweights.shape_.d[1];
- int c = iweights.shape_.d[2];
- int k = iweights.shape_.d[3];
- oweights->shape_.d[0] = k;
- oweights->shape_.d[1] = c;
+ // TRT requires GKcRS, while TF depthwise has RSCK
+ // where c=1, C=G
+ VLOG(2) << "nbGroups: " << nbGroups;
+ int c = iweights.shape_.d[2] / nbGroups;
+ VLOG(2) << "c" << iweights.shape_.d[2] << " then " << c;
+ int k = iweights.shape_.d[3] * nbGroups;
+ VLOG(2) << "k" << iweights.shape_.d[3] << " then " << k;
+ oweights->shape_.d[0] = k / nbGroups;
+ oweights->shape_.d[1] = c * nbGroups;
oweights->shape_.d[2] = r;
oweights->shape_.d[3] = s;
nvinfer1::DimsNCHW istrides = {1, k, s * k * c, c * k};
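A concrete instance of the arithmetic above (illustration only, not part of the patch): for a TensorFlow depthwise filter stored as RSCK with R = S = 3, C = 32 and channel multiplier K = 1, ConvertConv2DHelper passes nbGroups equal to the input channel count, so the reorder produces the output dimensions shown below.

// Worked example of the depthwise weight reorder, mirroring the code above.
#include <cassert>

int main() {
  const int r = 3, s = 3, C = 32, K = 1;  // TF RSCK filter {3, 3, 32, 1}
  const int nbGroups = 32;                // depthwise: groups = input channels
  const int c = C / nbGroups;             // 1 channel per group
  const int k = K * nbGroups;             // 32 output maps in total
  const int out[4] = {k / nbGroups, c * nbGroups, r, s};  // {1, 32, 3, 3}
  assert(out[0] == 1 && out[1] == 32 && out[2] == 3 && out[3] == 3);
  return 0;
}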
@@ -342,9 +390,32 @@ class Converter {
std::vector<TRT_TensorOrWeights> get_inputs(
const tensorflow::NodeDef& node_def) {
std::vector<TRT_TensorOrWeights> inputs;
- for (const auto& input_name : node_def.input()) {
- VLOG(2) << "Retrieve input: " << input_name;
- inputs.push_back(trt_tensors_.at(input_name));
+ for (auto const& input_name : node_def.input()) {
+ /*************************************************************************
+ * TODO(jie) handle case 1) here
+ * Normalizes the inputs and extracts associated metadata:
+ * 1) Inputs can contain a colon followed by a suffix of characters.
+ * That suffix may be a single number (e.g. inputName:1) or several
+ * word characters separated from a number by a colon
+ * (e.g. inputName:foo:1). The
+ * latter case is used to denote inputs and outputs of functions.
+ * 2) Control dependency inputs contain caret at the beginning and we
+ * remove this and annotate the edge as a control dependency.
+ ************************************************************************/
+ string name =
+ input_name[0] == '^' ? input_name.substr(1) : input_name;
+ auto first = name.find_first_of(':');
+ if (first != string::npos && first + 2 == name.size() &&
+ name[first + 1] == '0')
+ name.erase(first);
+
+ VLOG(2) << "retrieve input: " << name;
+ if (trt_tensors_.count(name)) {
+ inputs.push_back(trt_tensors_.at(name));
+ } else {
+ LOG(FATAL) << "input: " << name << "not availabled for node at, "
+ << node_def.name();
+ }
}
return inputs;
}
@@ -722,38 +793,83 @@ tensorflow::Status BinaryTensorOpWeight(
auto dims_w = weights.shape_;
auto dims_t = tensor->getDimensions();
- // Default to channel-wise
+ // default to element-wise
auto scale_mode = nvinfer1::ScaleMode::kELEMENTWISE;
+  // TODO(jie): maybe use a permutation instead to support more cases;
+ bool permutation_flag = false;
+
if (weights.count() == 1) {
VLOG(2) << "UNIFORM";
scale_mode = nvinfer1::ScaleMode::kUNIFORM;
} else {
- // No broadcasting on Batch dimension;
- assert(dims_w.d[0] == 1);
-
- // Broadcasting on Channel dimension only allowed in kUNIFORM
- assert(dims_w.d[1] == dims_t.d[0]);
- assert(dims_w.nbDims == dims_t.nbDims);
-
- // Default is element;
- for (int i = 2; i < dims_w.nbDims; i++) {
- if (dims_w.d[i] != dims_t.d[i - 1]) {
- scale_mode = nvinfer1::ScaleMode::kCHANNEL;
- break;
+ // no broadcasting on Batch dimension;
+ VLOG(2) << "WEIGHTS DIM: " << dims_w.nbDims
+ << " tensor DIM: " << dims_t.nbDims;
+ if (dims_w.nbDims == dims_t.nbDims + 1) {
+ if (dims_w.d[0] == 1) {
+ for (int i = 1; i < dims_w.nbDims; i++){
+ dims_w.d[i - 1] = dims_w.d[i];
+ }
+ dims_w.nbDims--;
+ } else {
+ return tensorflow::errors::InvalidArgument(
+ "Binary op cannot operate on batch, " + node_def.name());
}
}
- if (scale_mode == nvinfer1::ScaleMode::kELEMENTWISE) {
+
+ if (dims_w.nbDims == dims_t.nbDims && dims_w.d[0] == dims_t.d[0]) {
scale_mode = nvinfer1::ScaleMode::kELEMENTWISE;
- for (int i = 2; i < dims_w.nbDims; i++) {
- if (dims_w.d[i] != 1)
- return tensorflow::errors::InvalidArgument(
- "Weight shape not compatible at, " + node_def.name());
+ // default is element;
+ for (int i = 1; i < dims_w.nbDims; i++) {
+ if (dims_w.d[i] != dims_t.d[i]) {
+ // if dimension does not match, switch back to channel;
+ VLOG(2) << "channel";
+ scale_mode = nvinfer1::ScaleMode::kCHANNEL;
+ break;
+ }
+ }
+ // if channel as candidate, validate it
+ if (scale_mode == nvinfer1::ScaleMode::kCHANNEL) {
+ for (int i = 1; i < dims_w.nbDims; i++) {
+ if (dims_w.d[i] != 1)
+ return tensorflow::errors::InvalidArgument(
+ "Weight shape not compatible at, " + node_def.name());
+ }
+ } else {
+ VLOG(2) << "elementwise";
}
+ } else if (dims_w.nbDims == 1 &&
+ dims_w.d[0] == dims_t.d[dims_t.nbDims - 1]) {
+ // channel wise and broadcast required;
+ permutation_flag = true;
+ scale_mode = nvinfer1::ScaleMode::kCHANNEL;
+ } else {
+ return tensorflow::errors::InvalidArgument(
+ "Weight shape not compatible at, " + node_def.name());
}
}
- // Prepare weights
+ // transpose last dimension
+ std::vector<int> permutation(dims_t.nbDims + 1);
+ if (permutation_flag) {
+ if (scale_mode == nvinfer1::ScaleMode::kCHANNEL && dims_t.nbDims > 1) {
+      // Swap the last dimension into the channel slot for TRT,
+      // because of TensorFlow's default broadcasting rules.
+ for (int i = 0; i < static_cast<int>(permutation.size()); i++) {
+ permutation[i] = i;
+ }
+ permutation[1] = dims_t.nbDims;
+ permutation[dims_t.nbDims] = 1;
+ tensor = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
+ permutation);
+ } else {
+ return tensorflow::errors::InvalidArgument(
+ "Transpose cannot be applied, " + node_def.name());
+ }
+ }
+
+ // prepare weights
TRT_ShapedWeights shift_weights(weights.type_);
TRT_ShapedWeights scale_weights(weights.type_);
TRT_ShapedWeights power_weights(weights.type_);
@@ -779,90 +895,29 @@ tensorflow::Status BinaryTensorOpWeight(
scale_weights, power_weights);
nvinfer1::ITensor* output_tensor = layer->getOutput(0);
+ // transpose back dimension
+ if (permutation_flag) {
+ output_tensor = ctx.TransposeTensor(output_tensor, permutation);
+ }
// Pass the output
outputs->push_back(TRT_TensorOrWeights(output_tensor));
return tensorflow::Status::OK();
}
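The branch structure above can be summarized with a small, self-contained model (plain integer vectors instead of nvinfer1::Dims; the per-dimension validation of the kCHANNEL case is omitted for brevity, and nothing here is part of the patch):

// Simplified model of the scale-mode selection in BinaryTensorOpWeight.
#include <string>
#include <vector>

std::string PickScaleMode(std::vector<int> w, const std::vector<int>& t) {
  if (w.size() == 1 && w[0] == 1) return "kUNIFORM";  // single scalar weight
  // A leading batch dimension of 1 on the weights is squeezed away first.
  if (w.size() == t.size() + 1 && w[0] == 1) w.erase(w.begin());
  if (w.size() == t.size() && w[0] == t[0]) {
    for (size_t i = 1; i < w.size(); ++i)
      if (w[i] != t[i]) return "kCHANNEL";  // mismatch beyond dim 0
    return "kELEMENTWISE";                  // full shape match
  }
  if (w.size() == 1 && !t.empty() && w[0] == t.back())
    return "kCHANNEL, after transposing the last dim into the channel slot";
  return "unsupported";
}
// PickScaleMode({1}, {64, 14, 14})           -> kUNIFORM
// PickScaleMode({64, 14, 14}, {64, 14, 14})  -> kELEMENTWISE
// PickScaleMode({64}, {14, 14, 64})          -> kCHANNEL with transpose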
-tensorflow::Status BinaryTensorOpTensor(
- Converter& ctx, const tensorflow::NodeDef& node_def,
- const nvinfer1::ITensor* tensor_l, const nvinfer1::ITensor* tensor_r,
- std::vector<TRT_TensorOrWeights>* outputs) {
- static const std::unordered_map<string, nvinfer1::ElementWiseOperation> ops{
- {"Add", nvinfer1::ElementWiseOperation::kSUM},
- {"Mul", nvinfer1::ElementWiseOperation::kPROD},
- // {"max", nvinfer1::ElementWiseOperation::kMAX},
- // {"min", nvinfer1::ElementWiseOperation::kMIN},
- {"Sub", nvinfer1::ElementWiseOperation::kSUB},
- {"Div", nvinfer1::ElementWiseOperation::kDIV},
- };
-
- // FIXME assume type matches input weights
- // Get trt type & shape
- TFAttrs attrs(node_def);
- // Maybe this part has to be moved into the block of rsqrt later
- nvinfer1::DataType dtype = attrs.get<nvinfer1::DataType>("T");
+enum class ConvolutionType { DEFAULT, DEPTHWISE_CONV };
- // Check type consistency
- CHECK_EQ_TYPE(tensor_l->getType(), dtype);
- CHECK_EQ_TYPE(tensor_r->getType(), dtype);
- auto op_pair = ops.find(node_def.op());
- if (op_pair == ops.end())
- return tensorflow::errors::Unimplemented("binary op: " + node_def.op() +
- " not supported at: " +
- node_def.name());
-
- nvinfer1::IElementWiseLayer* layer = ctx.network()->addElementWise(
- *const_cast<nvinfer1::ITensor*>(tensor_l),
- *const_cast<nvinfer1::ITensor*>(tensor_r), op_pair->second);
-
- nvinfer1::ITensor* output_tensor = layer->getOutput(0);
-
- // Pass the output
- outputs->push_back(TRT_TensorOrWeights(output_tensor));
- return tensorflow::Status::OK();
-}
-
-tensorflow::Status ConvertPlaceholder(
- Converter& ctx, const tensorflow::NodeDef& node_def,
+tensorflow::Status ConvertConv2DHelper(
+ Converter& ctx, tensorflow::NodeDef const& node_def,
std::vector<TRT_TensorOrWeights> const& inputs,
- std::vector<TRT_TensorOrWeights>* outputs) {
- VLOG(2) << "Placeholder should have been replace already";
- return tensorflow::errors::Unimplemented(", cannot convert Placeholder op");
- // OK this make sense since we are supposed to replace it with input
- TFAttrs attrs(node_def);
- nvinfer1::DataType dtype = attrs.get<nvinfer1::DataType>("dtype");
- nvinfer1::Dims dims = attrs.get<nvinfer1::Dims>("shape");
-
- dims.nbDims--;
- for (int i = 0; i < dims.nbDims; i++) dims.d[i] = dims.d[i + 1];
-
- nvinfer1::ITensor* output =
- ctx.network()->addInput(node_def.name().c_str(), dtype, dims);
- if (!output) {
- return tensorflow::errors::InvalidArgument("Failed to create Input layer");
- }
- outputs->push_back(TRT_TensorOrWeights(output));
- return tensorflow::Status::OK();
-}
-
-tensorflow::Status ConvertConv2D(Converter& ctx,
- const tensorflow::NodeDef& node_def,
- const std::vector<TRT_TensorOrWeights>& inputs,
- std::vector<TRT_TensorOrWeights>* outputs) {
+ std::vector<TRT_TensorOrWeights>* outputs,
+ int group // group ==0 specifies depthwise conv
+) {
nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
- // TODO(jie): handle NHWC/NCHW transpose;
- TRT_ShapedWeights weights_rsck = inputs.at(1).weights();
- TRT_ShapedWeights weights = ctx.get_temp_weights_like(weights_rsck);
- ReorderRSCKToKCRS(weights_rsck, &weights);
- TRT_ShapedWeights biases(weights.type_);
- int noutput = weights.shape_.d[0];
- nvinfer1::DimsHW kernel_size;
- kernel_size.h() = weights.shape_.d[2];
- kernel_size.w() = weights.shape_.d[3];
+
TFAttrs attrs(node_def);
+ int c_index = 1;
int h_index = 2;
int w_index = 3;
auto data_format = attrs.get<string>("data_format");
@@ -871,14 +926,35 @@ tensorflow::Status ConvertConv2D(Converter& ctx,
{0, 3, 1, 2});
h_index = 1;
w_index = 2;
+ c_index = 3;
// TODO(jie): transpose it
}
+ // tensor after transpose (NCHW)
+ auto tensor_dim = tensor->getDimensions();
+
+ int nbGroups = group;
+ if (nbGroups == 0) // depthwise convolution
+ nbGroups = tensor_dim.d[0];
+ VLOG(2) << "groups count: " << nbGroups;
+
+ TRT_ShapedWeights weights_rsck = inputs.at(1).weights();
+ TRT_ShapedWeights weights = ctx.get_temp_weights_like(weights_rsck);
+ ReorderRSCKToKCRS(weights_rsck, &weights, nbGroups);
+ TRT_ShapedWeights biases(weights.type_);
+ int noutput = weights.shape_.d[0] * nbGroups;
+ nvinfer1::DimsHW kernel_size;
+ kernel_size.h() = weights.shape_.d[2];
+ kernel_size.w() = weights.shape_.d[3];
+ VLOG(2) << "kernel size: " << kernel_size.h() << ", " << kernel_size.w();
+
// TODO(jie): stride. (NHWC/NCHW)
auto tf_stride = attrs.get<std::vector<int>>("strides");
+ VLOG(2) << "h_INDEX" << h_index << ", w_index " << w_index;
+ VLOG(2) << "stride!!!: " << tf_stride[0] << tf_stride[1] << tf_stride[2]
+ << tf_stride[3];
nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]);
- auto tensor_dim = tensor->getDimensions();
std::vector<std::pair<int, int>> padding;
// TODO(jie): padding.
if (attrs.get<string>("padding") == "SAME") {
@@ -919,6 +995,7 @@ tensorflow::Status ConvertConv2D(Converter& ctx,
layer->setStride(stride);
layer->setPadding({padding[0].first, padding[1].first});
layer->setName(node_def.name().c_str());
+ layer->setNbGroups(nbGroups);
nvinfer1::ITensor* output_tensor = layer->getOutput(0);
auto dim_after = output_tensor->getDimensions();
@@ -935,6 +1012,99 @@ tensorflow::Status ConvertConv2D(Converter& ctx,
return tensorflow::Status::OK();
}
+tensorflow::Status ConvertConv2DHelper(
+ Converter& ctx, tensorflow::NodeDef const& node_def,
+ std::vector<TRT_TensorOrWeights> const& inputs,
+ std::vector<TRT_TensorOrWeights>* outputs, ConvolutionType type) {
+ switch (type) {
+ case ConvolutionType::DEFAULT:
+ return ConvertConv2DHelper(ctx, node_def, inputs, outputs, 1);
+ case ConvolutionType::DEPTHWISE_CONV:
+ return ConvertConv2DHelper(ctx, node_def, inputs, outputs, 0);
+ }
+ return tensorflow::errors::Unimplemented("unsupported convolution type at, " +
+ node_def.name());
+}
+
+tensorflow::Status BinaryTensorOpTensor(
+ Converter& ctx, tensorflow::NodeDef const& node_def,
+ const nvinfer1::ITensor* tensor_l, const nvinfer1::ITensor* tensor_r,
+ std::vector<TRT_TensorOrWeights>* outputs) {
+ static const std::unordered_map<string, nvinfer1::ElementWiseOperation>
+ ops{
+ {"Add", nvinfer1::ElementWiseOperation::kSUM},
+ {"Mul", nvinfer1::ElementWiseOperation::kPROD},
+ // {"max", nvinfer1::ElementWiseOperation::kMAX},
+ // {"min", nvinfer1::ElementWiseOperation::kMIN},
+ {"Sub", nvinfer1::ElementWiseOperation::kSUB},
+ {"Div", nvinfer1::ElementWiseOperation::kDIV},
+ };
+
+ // FIXME assume type matches input weights
+ // get trt type & shape
+ TFAttrs attrs(node_def);
+ // maybe this part has to be moved into the block of rsqrt later
+ nvinfer1::DataType dtype = attrs.get<nvinfer1::DataType>("T");
+
+ // check type consistency
+ CHECK_EQ_TYPE(tensor_l->getType(), dtype);
+ CHECK_EQ_TYPE(tensor_r->getType(), dtype);
+ auto op_pair = ops.find(node_def.op());
+ if (op_pair == ops.end())
+ return tensorflow::errors::Unimplemented(
+ "binary op: " + node_def.op() +
+ " not supported at: " + node_def.name());
+
+ nvinfer1::IElementWiseLayer* layer = ctx.network()->addElementWise(
+ *const_cast<nvinfer1::ITensor*>(tensor_l),
+ *const_cast<nvinfer1::ITensor*>(tensor_r), op_pair->second);
+
+ nvinfer1::ITensor* output_tensor = layer->getOutput(0);
+
+ // pass the output
+ outputs->push_back(TRT_TensorOrWeights(output_tensor));
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status ConvertPlaceholder(
+ Converter& ctx, tensorflow::NodeDef const& node_def,
+ std::vector<TRT_TensorOrWeights> const& inputs,
+ std::vector<TRT_TensorOrWeights>* outputs) {
+ VLOG(2) << "Placeholder should have been replace already";
+ return tensorflow::errors::Unimplemented("cannot convert Placeholder op");
+ // OK this make sense since we are supposed to replace it with input
+ TFAttrs attrs(node_def);
+ nvinfer1::DataType dtype = attrs.get<nvinfer1::DataType>("dtype");
+ nvinfer1::Dims dims = attrs.get<nvinfer1::Dims>("shape");
+
+ dims.nbDims--;
+ for (int i = 0; i < dims.nbDims; i++) dims.d[i] = dims.d[i + 1];
+
+ nvinfer1::ITensor* output =
+ ctx.network()->addInput(node_def.name().c_str(), dtype, dims);
+ if (!output) {
+ return tensorflow::errors::InvalidArgument("Failed to create Input layer");
+ }
+ outputs->push_back(TRT_TensorOrWeights(output));
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status ConvertConv2D(Converter& ctx,
+ tensorflow::NodeDef const& node_def,
+ std::vector<TRT_TensorOrWeights> const& inputs,
+ std::vector<TRT_TensorOrWeights>* outputs) {
+ return ConvertConv2DHelper(ctx, node_def, inputs, outputs,
+ ConvolutionType::DEFAULT);
+}
+
+tensorflow::Status ConvertConv2DDepthwise(
+ Converter& ctx, tensorflow::NodeDef const& node_def,
+ std::vector<TRT_TensorOrWeights> const& inputs,
+ std::vector<TRT_TensorOrWeights>* outputs) {
+ return ConvertConv2DHelper(ctx, node_def, inputs, outputs,
+ ConvolutionType::DEPTHWISE_CONV);
+}
+
tensorflow::Status ConvertPool(Converter& ctx,
const tensorflow::NodeDef& node_def,
std::vector<TRT_TensorOrWeights> const& inputs,
@@ -957,6 +1127,8 @@ tensorflow::Status ConvertPool(Converter& ctx,
// TODO(jie): support other pooling type
if (node_def.op() == "MaxPool")
type = nvinfer1::PoolingType::kMAX;
+ else if (node_def.op() == "AvgPool")
+ type = nvinfer1::PoolingType::kAVERAGE;
else
return tensorflow::errors::Unimplemented("Only supports Max pool");
@@ -1055,9 +1227,26 @@ tensorflow::Status ConvertScale(Converter& ctx,
} else {
VLOG(2) << "NCHW !!!!";
}
- nvinfer1::IScaleLayer* layer = ctx.network()->addScale(
- *const_cast<nvinfer1::ITensor*>(tensor), nvinfer1::ScaleMode::kCHANNEL,
- weights, empty_weights, empty_weights);
+
+ auto dims = tensor->getDimensions();
+ VLOG(2) << "tensor dimensions: " << dims.nbDims;
+ for (int i = 0; i < dims.nbDims; i++) {
+ VLOG(2) << "i: " << dims.d[i];
+ }
+ dims = weights.shape_;
+ VLOG(2) << "tensor dimensions: " << dims.nbDims;
+ for (int i = 0; i < dims.nbDims; i++) {
+ VLOG(2) << "i: " << dims.d[i];
+ }
+
+ nvinfer1::ScaleMode mode = nvinfer1::ScaleMode::kCHANNEL;
+ if (weights.shape_.d[0] == 1) {
+ mode = nvinfer1::ScaleMode::kUNIFORM;
+ }
+
+ nvinfer1::IScaleLayer* layer =
+ ctx.network()->addScale(*const_cast<nvinfer1::ITensor*>(tensor), mode,
+ weights, empty_weights, empty_weights);
nvinfer1::ITensor* output_tensor = layer->getOutput(0);
if (data_format == "NHWC") {
@@ -1091,21 +1280,73 @@ tensorflow::Status ConvertConst(Converter& ctx,
VLOG(2) << "SCALAR!!!" << node_def.name();
nvinfer1::Dims scalar_shape;
if (tensor.dims() > 0) {
- VLOG(2) << "Dimensions: " << tensor.dims();
- weights = TRT_ShapedWeights(dtype, weights_tensor.float_val().data(),
- GetTensorShape(tensor));
+ VLOG(2) << "dimensions: " << tensor.dims();
+ VLOG(2) << "size: " << weights_tensor.float_val_size();
+ scalar_shape = GetTensorShape(tensor);
+ for (int i = 0; i < scalar_shape.nbDims; i++)
+ VLOG(2) << scalar_shape.d[i];
+ if (GetShapeSize(scalar_shape) != weights_tensor.float_val_size()) {
+ if (weights_tensor.float_val_size() == 1 ||
+ scalar_shape.d[0] == weights_tensor.float_val_size()) {
+ scalar_shape.nbDims = 1;
+ // no dimension provided. flatten it
+ scalar_shape.d[0] = weights_tensor.float_val_size();
+ scalar_shape.type[0] = nvinfer1::DimensionType::kSPATIAL;
+ } else {
+ LOG(FATAL) << "Broadcast on weights only supports kCHANNEL and"
+ << " kUNIFORM, at: " << node_def.name();
+ }
+ }
} else {
VLOG(2) << "Dimensions: " << tensor.dims();
scalar_shape.nbDims = 1;
- scalar_shape.d[0] = 1;
+ // no dimension provided. flatten it
+ scalar_shape.d[0] = weights_tensor.float_val_size();
scalar_shape.type[0] = nvinfer1::DimensionType::kSPATIAL;
for (int i = 1; i < nvinfer1::Dims::MAX_DIMS; i++) {
scalar_shape.d[i] = 0;
scalar_shape.type[i] = nvinfer1::DimensionType::kSPATIAL;
}
- weights = TRT_ShapedWeights(dtype, weights_tensor.float_val().data(),
- scalar_shape);
}
+ weights = TRT_ShapedWeights(dtype, weights_tensor.float_val().data(),
+ scalar_shape);
+ // LOG(INFO) << " add: " << weights_tensor.float_val().data();
+ // LOG(INFO) << " value: " << (*weights_tensor.float_val().data());
+
+ // weights = ctx.get_temp_weights(dtype, scalar_shape);
+ // std::memcpy(const_cast<void*>(weights.values),
+ // weights_tensor.float_val().data(), weights.size_bytes());
+ } else if (!weights_tensor.int_val().empty()) {
+ VLOG(2) << "int!!!" << node_def.name();
+ nvinfer1::Dims scalar_shape;
+ if (tensor.dims() > 0) {
+ VLOG(2) << "dimensions: " << tensor.dims();
+ scalar_shape = GetTensorShape(tensor);
+ if (GetShapeSize(scalar_shape) != weights_tensor.int_val_size()) {
+ if (weights_tensor.int_val_size() == 1 ||
+ scalar_shape.d[0] == weights_tensor.int_val_size()) {
+ scalar_shape.nbDims = 1;
+ // no dimension provided. flatten it
+ scalar_shape.d[0] = weights_tensor.int_val_size();
+ scalar_shape.type[0] = nvinfer1::DimensionType::kSPATIAL;
+ } else {
+ LOG(FATAL) << "Broadcast on weights only supports kCHANNEL and"
+ << " kUNIFORM, at: " << node_def.name();
+ }
+ }
+ } else {
+ VLOG(2) << "dimensions: " << tensor.dims();
+ scalar_shape.nbDims = 1;
+ // no dimension provided. flatten it
+ scalar_shape.d[0] = weights_tensor.int_val_size();
+ scalar_shape.type[0] = nvinfer1::DimensionType::kSPATIAL;
+ for (int i = 1; i < nvinfer1::Dims::MAX_DIMS; i++) {
+ scalar_shape.d[i] = 0;
+ scalar_shape.type[i] = nvinfer1::DimensionType::kSPATIAL;
+ }
+ }
+ weights =
+ TRT_ShapedWeights(dtype, weights_tensor.int_val().data(), scalar_shape);
} else if (!weights_tensor.tensor_content().empty()) {
VLOG(2) << "TENSOR!!!" << node_def.name();
const auto& content = weights_tensor.tensor_content();
@@ -1229,6 +1470,7 @@ tensorflow::Status ConvertReduce(Converter& ctx,
return tensorflow::errors::InvalidArgument("TRT cannot reduce at 0, at" +
node_def.name());
if (index_list_data[i] == 1) permuted_index = 1;
+
idx_set.emplace(index_list_data[i]);
}
@@ -1236,7 +1478,7 @@ tensorflow::Status ConvertReduce(Converter& ctx,
nvinfer1::DimsHW pool_kernel;
if (permuted_index == 1) {
for (int i = 2; i < nb_dims; i++) {
- if (idx_set.count(i)) {
+      if (idx_set.count(i) == 0) {
permuted_index = i;
break;
}
@@ -1271,6 +1513,7 @@ tensorflow::Status ConvertReduce(Converter& ctx,
output_tensor = ctx.TransposeTensor(
const_cast<nvinfer1::ITensor*>(output_tensor), permutation_order);
}
+ outputs->push_back(TRT_TensorOrWeights(output_tensor));
return tensorflow::Status::OK();
}
@@ -1371,19 +1614,206 @@ tensorflow::Status ConvertPad(Converter& ctx,
return tensorflow::Status::OK();
}
+tensorflow::Status ConvertConcat(Converter& ctx,
+ const tensorflow::NodeDef& node_def,
+ const std::vector<TRT_TensorOrWeights>& inputs,
+ std::vector<TRT_TensorOrWeights>* outputs) {
+ // not including the last input (axis) here
+ int input_size = static_cast<int>(inputs.size()) - 1;
+
+ if (!inputs.at(0).is_tensor())
+ return tensorflow::errors::InvalidArgument(
+ "Concat in TRT support only Tensor input, at " + node_def.name());
+
+ // We are retrieving the axis
+ TRT_ShapedWeights axis = inputs.at(input_size).weights();
+
+ TFAttrs attrs(node_def);
+ // auto attr_size = attrs.at("N")->i();
+ // auto data_type = attrs.get<nvinfer1::DataType>("T");
+ auto index_type = attrs.get<tensorflow::DataType>("Tidx");
+
+ // TODO(jie): handle data type
+ // Only expect to handle INT32 as index attributes for now
+ if (index_type != tensorflow::DataType::DT_INT32)
+ return tensorflow::errors::Unimplemented(
+ "Tidx supports only DT_INT32, at " + node_def.name());
+
+ int index = *(static_cast<int*>(const_cast<void*>(axis.GetValues())));
+
+ // TODO(jie): early termination with no-op (attr_size==1)
+
+ auto dim = inputs.at(0).tensor()->getDimensions();
+ // dimension check
+ if (index > dim.nbDims + 1)
+ return tensorflow::errors::InvalidArgument(
+ "Concatenate on axis out of dimension range, at " + node_def.name());
+
+ if (index == 0)
+ return tensorflow::errors::InvalidArgument(
+ "Concatenate on batch dimension not supported, at " + node_def.name());
+
+  // in case we need permutation;
+ std::vector<int> permutation_order(dim.nbDims + 1);
+
+ for (int i = 0; i < dim.nbDims + 1; i++) permutation_order[i] = i;
+
+ if (index != 1) {
+ permutation_order[1] = index - 1;
+ permutation_order[index - 1] = 1;
+ }
+
+ std::vector<nvinfer1::ITensor const*> inputs_vec;
+  // Shape check (all input tensors should have the same shape)
+ // starting from 0 since we are probably also doing transpose here;
+ for (int i = 0; i < input_size; i++) {
+ auto tensor_i = inputs.at(i).tensor();
+ auto dim_i = tensor_i->getDimensions();
+ if (dim_i.nbDims != dim.nbDims)
+ return tensorflow::errors::InvalidArgument(
+ "Concatenate receives inputs with inconsistent dimensions, at " +
+ node_def.name());
+
+ for (int j = 0; j < dim.nbDims; j++) {
+ // check dimension consistency on non-concatenate axis
+ if (j != index - 1 && dim_i.d[j] != dim.d[j])
+ return tensorflow::errors::InvalidArgument(
+ "Concatenate receives inputs with inconsistent shape, at" +
+ node_def.name());
+ }
+
+ // TRT does concatenation only on channel!
+ if (index != 1)
+ tensor_i = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor_i),
+ permutation_order);
+
+ inputs_vec.push_back(tensor_i);
+ }
+
+ // nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
+ nvinfer1::IConcatenationLayer* layer = ctx.network()->addConcatenation(
+ const_cast<nvinfer1::ITensor* const*>(inputs_vec.data()),
+ inputs_vec.size());
+ nvinfer1::ITensor* output_tensor = layer->getOutput(0);
+
+ if (index != 1) {
+ output_tensor = ctx.TransposeTensor(output_tensor, permutation_order);
+ }
+ outputs->push_back(TRT_TensorOrWeights(output_tensor));
+ return tensorflow::Status::OK();
+}
+
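The permutation built above can be illustrated in isolation (simplified sketch with plain vectors instead of nvinfer1 types; not part of the patch): TensorRT concatenates only along the channel dimension, so the requested TF axis is swapped into position 1, the concatenation runs there, and the output is permuted back with the same order.

// Simplified illustration of the axis swap used by ConvertConcat above.
#include <vector>

std::vector<int> ConcatPermutation(int tf_axis, int trt_nb_dims) {
  // Identity permutation over batch + TRT dims, then swap the concat axis
  // (offset by the implicit batch dimension) with the channel position.
  std::vector<int> perm(trt_nb_dims + 1);
  for (int i = 0; i <= trt_nb_dims; ++i) perm[i] = i;
  if (tf_axis != 1) {
    perm[1] = tf_axis - 1;
    perm[tf_axis - 1] = 1;
  }
  return perm;
}
// e.g. a tensor with 3 TRT dims concatenated on TF axis 3:
// ConcatPermutation(3, 3) -> {0, 2, 1, 3}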
+tensorflow::Status ConvertMatMul(Converter& ctx,
+ tensorflow::NodeDef const& node_def,
+ std::vector<TRT_TensorOrWeights> const& inputs,
+ std::vector<TRT_TensorOrWeights>* outputs) {
+ nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
+
+ // TODO(jie): transpose!
+ TFAttrs attrs(node_def);
+ // bool transpose_w = bool(attrs->at("transpose_b")->i());
+
+ // tensor after transpose (NCHW)
+ auto tensor_dim = tensor->getDimensions();
+
+ TRT_ShapedWeights weights_ck = inputs.at(1).weights();
+ TRT_ShapedWeights weights = ctx.get_temp_weights_like(weights_ck);
+ reorder_ck_to_kc(weights_ck, &weights);
+ TRT_ShapedWeights biases(weights.type_);
+
+ int noutput = weights.shape_.d[0];
+
+ nvinfer1::IFullyConnectedLayer* layer = ctx.network()->addFullyConnected(
+ *const_cast<nvinfer1::ITensor*>(tensor), noutput, weights, biases);
+
+ nvinfer1::ITensor* output_tensor = layer->getOutput(0);
+ outputs->push_back(TRT_TensorOrWeights(output_tensor));
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status ConvertReshape(
+ Converter& ctx, tensorflow::NodeDef const& node_def,
+ std::vector<TRT_TensorOrWeights> const& inputs,
+ std::vector<TRT_TensorOrWeights>* outputs) {
+ if (inputs.size() != 2 || !inputs.at(0).is_tensor() ||
+ !inputs.at(1).is_weights())
+ return tensorflow::errors::InvalidArgument(
+ "Input expects tensor and weights, at" + node_def.name());
+
+  // Reshape keeps the implicit batch dimension; only the remaining dims change.
+ nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
+ auto dims = tensor->getDimensions();
+ // restore implicit batch dimension
+ int nbDims = dims.nbDims + 1;
+
+ TRT_ShapedWeights shape = inputs.at(1).weights();
+
+ TFAttrs attrs(node_def);
+
+  auto shape_dtype = attrs.get<tensorflow::DataType>("Tshape");
+
+  if (shape.shape_.nbDims != 1)
+    return tensorflow::errors::InvalidArgument(
+        "reshape new shape is not 1 dimensional, at " + node_def.name());
+
+  // Only expect to handle INT32 as attributes for now
+  if (shape_dtype != tensorflow::DataType::DT_INT32)
+    return tensorflow::errors::Unimplemented(
+        "reshape new shape supports only DT_INT32, at " + node_def.name());
+
+ auto shape_data = static_cast<int*>(const_cast<void*>(shape.GetValues()));
+
+ if (shape_data[0] != -1)
+ return tensorflow::errors::InvalidArgument(
+ "reshape new shape first dimension is not -1, at " + node_def.name());
+
+ auto shape_num_dims = shape.shape_.d[0];
+ VLOG(2) << "shape dimensions: " << shape_num_dims;
+ int volume_w = 1;
+ for (int i = 1; i < shape.shape_.d[0]; i++) volume_w *= shape_data[i];
+
+ int volume_t = 1;
+ for (int i = 0; i < dims.nbDims; i++) volume_t *= dims.d[i];
+
+ VLOG(2) << "volume: " << volume_t << " volume weights: " << volume_w;
+ if (volume_w != volume_t)
+ return tensorflow::errors::InvalidArgument(
+ "volume does not agree between tensor and new shape, at " +
+ node_def.name());
+
+ nvinfer1::IShuffleLayer* layer =
+ ctx.network()->addShuffle(*const_cast<nvinfer1::ITensor*>(tensor));
+
+ nvinfer1::Dims reshapeDims;
+ VLOG(2) << "new dimension: " << shape_num_dims - 1;
+ reshapeDims.nbDims = shape_num_dims - 1;
+ for (int32_t i = 0; i < reshapeDims.nbDims; ++i) {
+ reshapeDims.d[i] = shape_data[i + 1];
+ }
+ layer->setReshapeDimensions(reshapeDims);
+ VLOG(2) << "new dimension: " << shape_num_dims - 1;
+
+ nvinfer1::ITensor* output_tensor = layer->getOutput(0);
+ auto dims_output = output_tensor->getDimensions();
+ VLOG(2) << "output tensor dimension:" << dims_output.nbDims;
+ outputs->push_back(TRT_TensorOrWeights(output_tensor));
+ return tensorflow::Status::OK();
+}
+
void Converter::register_op_converters() {
// vgg_16 slim implementation
op_registry_["Placeholder"] = ConvertPlaceholder;
op_registry_["Conv2D"] = ConvertConv2D;
+ op_registry_["DepthwiseConv2dNative"] = ConvertConv2DDepthwise;
op_registry_["Relu"] = ConvertActivation;
op_registry_["MaxPool"] = ConvertPool;
+ op_registry_["AvgPool"] = ConvertPool;
// This could be really handled as ConvertBinary
op_registry_["BiasAdd"] = ConvertScale;
op_registry_["Const"] = ConvertConst;
- // op_registry_["MatMul"] = ConvertFullyConnected; // Not used in vgg
+ // op_registry_["MatMul"] = ConvertFullyConnected; // not used in vgg
// TODO(ben,jie): this is a temp hack.
op_registry_["Identity"] = ConvertIdentity; // Identity should be removed
- // op_registry_["AvgPool"] = ConvertPool;
// resnet_50_v1 slim implementation
op_registry_["Add"] = ConvertBinary;
@@ -1393,26 +1823,355 @@ void Converter::register_op_converters() {
op_registry_["Mean"] = ConvertReduce;
op_registry_["Pad"] = ConvertPad;
// TODO(ben,jie): Add more ops
+
+ op_registry_["ConcatV2"] = ConvertConcat;
+ op_registry_["MatMul"] = ConvertMatMul;
+ op_registry_["Reshape"] = ConvertReshape;
}
} // namespace
+tensorflow::Status GetTensorRTGraph(tensorrt::convert::SubGraphParams& s) {
+ return tensorflow::errors::Unimplemented("Not implemented yet");
+}
+tensorflow::Status ConvertCalibrationNodeToEngineNode(tensorflow::Graph& graph,
+                                                      tensorflow::Node* c_node) {
+  const auto ndef = c_node->def();
+
+  TFAttrs attrs(ndef);
+  std::vector<string> segment_nodes(
+      attrs.get<std::vector<string>>("segment_nodes"));
+  std::vector<string> output_nodes(
+      attrs.get<std::vector<string>>("segment_output_names"));
+  std::vector<string> input_names(
+      attrs.get<std::vector<string>>("input_names"));
+  string res_name = attrs.get<string>("resource_name");
+  VLOG(1) << "Node name " << c_node->name() << " res_name " << res_name;
+  string engine_name = "my_trt_op";
+  {
+    const auto node_id = tensorflow::str_util::Split(res_name, "_");
+    engine_name += node_id.back();
+  }
+  std::map<string, tensorflow::Node*> nodeMaps;
+
+  for (auto n : graph.op_nodes()) {
+    nodeMaps.insert({n->name(), n});
+  }
+  VLOG(1) << "Output Nodes:";
+  std::vector<tensorflow::DataType> out_types;
+  std::vector<const tensorflow::Edge*> out_edges;
+  for (auto& i : output_nodes) {
+    auto node_port = tensorflow::str_util::Split(i, ":");
+    VLOG(1) << " " << i << " in graph " << nodeMaps.count(i);
+    auto out_node_name = node_port.at(0);
+    if (node_port.size() > 1) {
+      VLOG(1) << "Multi port output " << node_port.at(0) << " "
+              << node_port.at(1) << " size=" << node_port.size();
+    }
+    auto nodeIt = nodeMaps.find(out_node_name);
+    if (nodeIt != nodeMaps.end()) {
+      tensorflow::Node* outNode = nodeIt->second;
+      int port = 0;
+      if (node_port.size() == 2) {
+        port = std::strtoul(node_port.at(1).c_str(), nullptr, 10);
+        out_types.push_back(outNode->output_type(port));
+      } else {
+        out_types.push_back(outNode->output_type(0));
+      }
+      for (auto outEdge : outNode->out_edges()) {
+        if (outEdge->src_output() == port) {
+          out_edges.push_back(outEdge);
+          break;
+        }
+      }
+    } else {
+      LOG(WARNING) << " couldn't find output node " << out_node_name;
+    }
+  }
+  VLOG(1) << "Input Nodes:";
+  for (auto& i : input_names) {
+    VLOG(1) << " " << i << " in graph " << nodeMaps.count(i);
+  }
+  auto trt_rm = tensorflow::trt::TRTResourceManager::instance();
+  auto resmgr = trt_rm->getManager("TRTCalibOps");
+  tensorflow::trt::TRTCalibrationResource* calibRes = nullptr;
+  auto status = resmgr->Lookup(res_name, res_name, &calibRes);
+  if (!status.ok() || !calibRes->calibrator) {
+    return tensorflow::errors::FailedPrecondition(
+        "You must run calibration and inference conversion in the same "
+        "process");
+  }
+
+  calibRes->calibrator->setDone();
+  VLOG(1) << "Waiting for calibration thread to join";
+  calibRes->thr->join();
+  delete calibRes->thr;
+  if (!calibRes->engine) {
+    LOG(FATAL) << "Calibration failed! Engine is nullptr";
+  }
+  auto engine_plan = calibRes->engine->serialize();
+  const char* plan_data = static_cast<const char*>(engine_plan->data());
+  string engine_plan_string(plan_data, plan_data + engine_plan->size());
+  engine_plan->destroy();
+  calibRes->engine->destroy();
+  calibRes->network->destroy();
+  calibRes->builder->destroy();
+  calibRes->thr = nullptr;
+  calibRes->engine = nullptr;
+  calibRes->builder = nullptr;
+  tensorflow::NodeDefBuilder op_builder(engine_name, "TRTEngineOp");
+  std::vector<tensorflow::NodeDefBuilder::NodeOut> income_edges;
+  for (const auto in_edge : c_node->in_edges()) {
+    auto src = in_edge->src();
+    int dest_port = in_edge->dst_input();
+    income_edges.emplace_back(src->name(), in_edge->src_output(),
+                              c_node->input_type(dest_port));
+  }
+  tensorflow::gtl::ArraySlice<tensorflow::NodeDefBuilder::NodeOut> input_list(
+      income_edges);
+  op_builder.Input(input_list);
+  tensorflow::NodeDef engine_node;
+  status = op_builder.Attr("serialized_engine", engine_plan_string)
+               .Attr("input_nodes", input_names)
+               .Attr("output_nodes", output_nodes)
+               .Attr("OutT", out_types)
+               .Finalize(&engine_node);
+  if (!status.ok()) {
+    LOG(ERROR) << "Engine node creation failed";
+    return status;
+  }
+  auto trt_engine_node = graph.AddNode(engine_node, &status);
+  TF_CHECK_OK(status);
+  for (size_t i = 0; i < out_edges.size(); i++) {
+    VLOG(1) << "Connecting trt_engine_node output " << i << " with "
+            << out_edges.at(i)->dst()->name() << " port "
+            << out_edges.at(i)->dst_input();
+    TF_RETURN_IF_ERROR(graph.UpdateEdge(trt_engine_node, i,
+                                        out_edges.at(i)->dst(),
+                                        out_edges.at(i)->dst_input()));
+  }
+  VLOG(1) << "Segment nodes:";
+  for (auto& i : segment_nodes) {
+    VLOG(1) << " " << i << " in graph " << nodeMaps.count(i);
+    auto it = nodeMaps.find(i);
+    if (it != nodeMaps.end()) {
+      graph.RemoveNode(it->second);
+    }
+  }
+  return tensorflow::Status::OK();
+}
+
+tensorflow::Status InjectCalibrationNode(tensorrt::convert::SubGraphParams& s) {
+ // Visit nodes in reverse topological order and construct the TRT network.
+
+ // Toposort
+ std::vector<tensorflow::Node*> order_vec;
+ tensorflow::GetPostOrder(s.graph, &order_vec);
+ // Select just the subgraph
+ std::list<tensorflow::Node*> order;
+ for (tensorflow::Node* node : order_vec) {
+ if (s.subgraph_node_ids.count(node->id())) {
+ // order.push_back(node);
+      order.push_front(node);  // we want topological order to construct the
+                               // network layer by layer
+ }
+ }
+ // topological order is needed to build TRT network
+ VLOG(2) << "BUILDING 1";
+ static int static_id = 0;
+ string calib_op_name =
+ tensorflow::strings::StrCat("my_trt_calib_op_", static_id);
+ string engine_name = tensorflow::strings::StrCat("my_trt_op", static_id);
+
+ static_id++;
+ VLOG(2) << "BUILDING 2";
+ auto trt_rmgr = tensorflow::trt::TRTResourceManager::instance();
+ auto op_rmgr = trt_rmgr->getManager("TRTCalibOps");
+ auto op_res = new tensorflow::trt::TRTCalibrationResource();
+  VLOG(1) << "Creating calibration resource " << calib_op_name << " @ "
+          << op_res;
+ TF_CHECK_OK(op_rmgr->Create(calib_op_name, calib_op_name, op_res));
+ op_res->logger = new tensorflow::tensorrt::Logger();
+ op_res->builder = nvinfer1::createInferBuilder(*(op_res->logger));
+
+ if (!op_res->builder) {
+ return tensorflow::errors::Internal(
+ "failed to create TensorRT builder object");
+ }
+
+ VLOG(2) << "BUILDING 3";
+
+ op_res->network = op_res->builder->createNetwork();
+ if (!op_res->network) {
+ return tensorflow::errors::Internal(
+ "failed to create TensorRT network object");
+ }
+
+ VLOG(2) << "BUILDING 4";
+
+ // Build the network
+ Converter converter(op_res->network);
+
+ VLOG(2) << "BUILDING 5";
+ std::vector<string> input_names;
+ std::vector<tensorflow::DataType> input_dtypes;
+ for (std::pair<int, int> const& input : s.input_inds) {
+ VLOG(2) << "parsing input!!!!!";
+ int node_id = input.first;
+ int output_idx = input.second;
+ tensorflow::Node* node = s.graph.FindNodeId(node_id);
+ auto node_name = node->name();
+ input_names.push_back(node_name); // insert original node name without port
+ // TODO(jie): alternative :)
+ // tensorflow::DataType tf_dtype = node->output_type(output_idx);
+ if (!s.graph_properties.HasOutputProperties(node_name))
+ return tensorflow::errors::Internal("failed to find input node: " +
+ node_name);
+
+ auto op_info_vec = s.graph_properties.GetOutputProperties(node_name);
+ if (static_cast<int>(op_info_vec.size()) < output_idx)
+ return tensorflow::errors::Internal(
+ "accessing output index of: " + std::to_string(output_idx) +
+          ", at node: " + node_name + " with output entry from shape_map: " +
+ std::to_string(op_info_vec.size()));
+
+ auto op_info = op_info_vec.at(output_idx);
+
+ tensorflow::DataType tf_dtype = op_info.dtype();
+ input_dtypes.push_back(tf_dtype);
+
+ nvinfer1::DataType dtype(nvinfer1::DataType::kFLOAT);
+ TF_CHECK_OK(ConvertDType(tf_dtype, &dtype));
+
+ VLOG(2) << "accessing output index of: " << std::to_string(output_idx)
+ << ", at node: " << node_name
+            << " with output entry from shape_map: "
+ << std::to_string(op_info_vec.size());
+
+ // TODO(ben,jie): update TRT input format/dimension
+ nvinfer1::DimsCHW input_dim_psuedo_chw;
+ for (int i = 0; i < 3; i++) input_dim_psuedo_chw.d[i] = 1;
+
+ for (int i = 1; i < op_info.shape().dim_size(); i++) {
+ VLOG(2) << "dimension: " << i
+ << " , size: " << op_info.shape().dim(i).size();
+ input_dim_psuedo_chw.d[i - 1] = op_info.shape().dim(i).size();
+ }
+
+ // TODO(ben,jie): proper way to restore input tensor name?
+ auto input_tensor_name = node_name;
+ if (output_idx != 0)
+ input_tensor_name = node_name + ":" + std::to_string(output_idx);
+
+ nvinfer1::ITensor* input_tensor = converter.network()->addInput(
+ input_tensor_name.c_str(), dtype, input_dim_psuedo_chw);
+
+ if (!input_tensor)
+ return tensorflow::errors::InvalidArgument(
+ "Failed to create Input layer");
+ VLOG(2) << "input tensor name :" << input_tensor_name;
+
+ if (!converter.insert_input_tensor(input_tensor_name, input_tensor))
+ return tensorflow::errors::AlreadyExists(
+ "output tensor already exists for op: " + input_tensor_name);
+ }
+
+ VLOG(2) << "finished sorting";
+
+ for (const tensorflow::Node* node : order) {
+ tensorflow::NodeDef const& node_def = node->def();
+ VLOG(2) << "converting node: " << node_def.name() << " , "
+ << node_def.op();
+ TF_RETURN_IF_ERROR(converter.convert_node(node_def));
+ }
+
+ VLOG(2) << "finished conversion";
+
+ // Gather output metadata
+ std::vector<string> output_names;
+ std::vector<tensorflow::DataType> output_dtypes;
+ int trt_engine_op_output_idx = 0;
+ for (std::pair<int, int> const& output : s.output_inds) {
+ int node_id = output.first;
+ int output_idx = output.second;
+ tensorflow::Node* node = s.graph.FindNodeId(node_id);
+ string op_name = node->name();
+ string tensor_name = op_name;
+
+ s.output_edge_map->insert(
+ {trt_engine_op_output_idx == 0
+ ? engine_name
+ : engine_name + ":" + std::to_string(trt_engine_op_output_idx),
+ {output_idx, tensor_name}});
+ trt_engine_op_output_idx++;
+ if (output_idx != 0)
+ tensor_name = tensor_name + ":" + std::to_string(output_idx);
+ VLOG(1) << "output tensor name: " << tensor_name;
+ output_names.push_back(tensor_name);
+ auto tensor_or_weights = converter.get_tensor(tensor_name);
+ if (!tensor_or_weights.is_tensor()) {
+ return tensorflow::errors::InvalidArgument(
+ "Output node is weights not tensor");
+ }
+ nvinfer1::ITensor* tensor = tensor_or_weights.tensor();
+ if (!tensor) {
+ return tensorflow::errors::NotFound("Output tensor not found: " +
+ tensor_name);
+ }
+ converter.network()->markOutput(*tensor);
+ tensorflow::DataType tf_dtype = node->output_type(output_idx);
+ output_dtypes.push_back(tf_dtype);
+ nvinfer1::DataType trt_dtype = nvinfer1::DataType::kFLOAT;
+ TF_RETURN_IF_ERROR(ConvertDType(tf_dtype, &trt_dtype));
+ tensor->setType(trt_dtype);
+ }
+
+ VLOG(2) << "finished output";
+
+ // Build the engine
+ op_res->builder->setMaxBatchSize(s.max_batch_size);
+ op_res->builder->setMaxWorkspaceSize(s.max_workspace_size_bytes);
+
+ // Build the TRT op
+ // TODO(sami,ben,jie): proper naming!
+ tensorflow::NodeDefBuilder op_builder(calib_op_name, "TRTCalibOp");
+ std::vector<tensorflow::NodeDefBuilder::NodeOut> income_edges;
+ for (size_t i = 0; i < input_names.size(); ++i) {
+ int output_idx = s.input_inds.at(i).second;
+ // we wired up the input here already, it is redundant to do it again in
+ // ConvertSubGraphToTensorRT(convert_graph.cc)
+ auto incoming_edge = tensorflow::NodeDefBuilder::NodeOut(
+ input_names.at(i), output_idx, input_dtypes.at(i));
+ VLOG(1) << calib_op_name << " input " << i << " = " << input_names.at(i)
+ << ":" << output_idx
+ <<" dType= "<< tensorflow::DataTypeString(input_dtypes.at(i));
+ income_edges.push_back(incoming_edge);
+ }
+ tensorflow::gtl::ArraySlice<tensorflow::NodeDefBuilder::NodeOut> input_list(
+ income_edges);
+ op_builder.Input(input_list);
+ std::vector<string> segment_names;
+ segment_names.reserve(s.subgraph_node_ids.size());
+ for (int i : s.subgraph_node_ids) {
+ auto node = s.graph.FindNodeId(i);
+ segment_names.push_back(node->name());
+ }
+ LOG(INFO) << "finished op preparation";
+
+  auto status = op_builder.Attr("segment_nodes", segment_names)
+                    .Attr("input_names", input_names)
+                    .Attr("segment_output_names", output_names)
+                    .Attr("resource_name", calib_op_name)
+                    .Finalize(s.trt_node);
+
+ LOG(INFO) << status.ToString();
+ LOG(INFO) << "finished op building";
+
+ return tensorflow::Status::OK();
+}
tensorflow::Status ConvertSubGraphToTensorRTNodeDef(
- const tensorflow::Graph& graph, const std::set<int>& subgraph_node_ids,
- const std::vector<std::pair<int, int>>& input_inds,
- const std::vector<std::pair<int, int>>& output_inds, size_t max_batch_size,
- size_t max_workspace_size_bytes,
- const tensorflow::grappler::GraphProperties& graph_properties,
- tensorflow::NodeDef* trt_node) {
+ tensorrt::convert::SubGraphParams& s) {
// Visit nodes in reverse topological order and construct the TRT network.
// Toposort
std::vector<tensorflow::Node*> order_vec;
- tensorflow::GetPostOrder(graph, &order_vec);
+ tensorflow::GetPostOrder(s.graph, &order_vec);
// Select just the subgraph
std::list<tensorflow::Node*> order;
for (tensorflow::Node* node : order_vec) {
- if (subgraph_node_ids.count(node->id())) {
+ if (s.subgraph_node_ids.count(node->id())) {
// We want topological order to contstruct the
// network layer by layer
order.push_front(node);
@@ -1439,26 +2198,46 @@ tensorflow::Status ConvertSubGraphToTensorRTNodeDef(
std::vector<string> input_names;
std::vector<tensorflow::DataType> input_dtypes;
- for (std::pair<int, int> const& input : input_inds) {
+ for (const std::pair<int, int>& input : s.input_inds) {
+ VLOG(2) << "parsing input!!!!!";
int node_id = input.first;
int output_idx = input.second;
- tensorflow::Node* node = graph.FindNodeId(node_id);
+ tensorflow::Node* node = s.graph.FindNodeId(node_id);
auto node_name = node->name();
- input_names.push_back(node_name); // Insert original node name without port
- // TODO(jie): alternative :)
- if (!graph_properties.HasOutputProperties(node_name))
- return tensorflow::errors::Internal("Failed to find input node: " +
- node_name);
+ // input_names should use the node name in the graph
+ // here it should be the input tensor name -> matching the binding
+ // insert original node name without port
+ auto tensor_name = node_name;
+ if (output_idx != 0)
+ tensor_name = tensor_name + ":" + std::to_string(output_idx);
- auto op_info_vec = graph_properties.GetOutputProperties(node_name);
- if (static_cast<int>(op_info_vec.size()) < output_idx)
+ VLOG(2) << "input name: " << node_name << " tensor_name: " << tensor_name
+ << " idx: " << output_idx;
+
+ auto shape_inference_node_name = node_name;
+ auto shape_inference_output_idx = output_idx;
+ // rewire the shape inference to original node in the graph
+ if (s.output_edge_map->count(tensor_name)) {
+ shape_inference_node_name = s.output_edge_map->at(tensor_name).second;
+ shape_inference_output_idx = s.output_edge_map->at(tensor_name).first;
+ }
+ VLOG(2) << "shapeinference name: " << shape_inference_node_name
+ << " idx: " << shape_inference_output_idx;
+
+ if (!s.graph_properties.HasOutputProperties(shape_inference_node_name))
+ return tensorflow::errors::Internal("failed to find input node: " +
+ shape_inference_node_name);
+
+ auto op_info_vec =
+ s.graph_properties.GetOutputProperties(shape_inference_node_name);
+ if (static_cast<int>(op_info_vec.size()) <= shape_inference_output_idx)
return tensorflow::errors::Internal(
- "Accessing output index of: " + std::to_string(output_idx) +
- ", at node: " + node_name + " with output entry from shape_map: " +
+ "accessing output index of: " +
+ std::to_string(shape_inference_output_idx) + ", at node: " +
+ shape_inference_node_name + " with output entry from shape_map: " +
std::to_string(op_info_vec.size()));
- auto op_info = op_info_vec.at(output_idx);
-
+ auto op_info = op_info_vec.at(shape_inference_output_idx);
tensorflow::DataType tf_dtype = op_info.dtype();
input_dtypes.push_back(tf_dtype);
@@ -1469,14 +2248,18 @@ tensorflow::Status ConvertSubGraphToTensorRTNodeDef(
<< ", at node: " << node_name
<< " with output entry from shape_map: "
<< std::to_string(op_info_vec.size());
-
// TODO(ben,jie): update TRT input format/dimension
nvinfer1::DimsCHW input_dim_psuedo_chw;
for (int i = 0; i < 3; i++) input_dim_psuedo_chw.d[i] = 1;
+ // TODO(jie): TRT 3.x only support 4 dimensional input tensor.
+ // update the code once TRT 4.0 comes out.
+ if (op_info.shape().dim_size() != 4)
+ return tensorflow::errors::Unimplemented("require 4 dimensional input");
+
for (int i = 1; i < op_info.shape().dim_size(); i++) {
VLOG(2) << "dimension: " << i
- << " , size: " << op_info.shape().dim(i).size();
+ << " , size: " << op_info.shape().dim(i).size();
input_dim_psuedo_chw.d[i - 1] = op_info.shape().dim(i).size();
}
@@ -1485,6 +2268,7 @@ tensorflow::Status ConvertSubGraphToTensorRTNodeDef(
if (output_idx != 0)
input_tensor_name = node_name + ":" + std::to_string(output_idx);
+ input_names.push_back(input_tensor_name);
nvinfer1::ITensor* input_tensor = converter.network()->addInput(
input_tensor_name.c_str(), dtype, input_dim_psuedo_chw);
@@ -1508,17 +2292,29 @@ tensorflow::Status ConvertSubGraphToTensorRTNodeDef(
VLOG(2) << "Finished conversion";
+ // TODO(sami,ben,jie): proper naming!
+ static int static_id = 0;
+ string engine_name = tensorflow::strings::StrCat("my_trt_op", static_id++);
+
// Gather output metadata
std::vector<string> output_names;
std::vector<tensorflow::DataType> output_dtypes;
- for (std::pair<int, int> const& output : output_inds) {
+ int trt_engine_op_output_idx = 0;
+ for (std::pair<int, int> const& output : s.output_inds) {
int node_id = output.first;
int output_idx = output.second;
- tensorflow::Node* node = graph.FindNodeId(node_id);
+ tensorflow::Node* node = s.graph.FindNodeId(node_id);
string op_name = node->name();
string tensor_name = op_name;
+
+ s.output_edge_map->insert(
+ {trt_engine_op_output_idx == 0
+ ? engine_name
+ : tensorflow::strings::StrCat(engine_name,":",trt_engine_op_output_idx),
+ {output_idx, tensor_name}});
+ trt_engine_op_output_idx++;
if (output_idx != 0)
- tensor_name = tensor_name + ":" + std::to_string(output_idx);
+      tensorflow::strings::StrAppend(&tensor_name, ":",
+                                     std::to_string(output_idx));
VLOG(2) << "Output tensor name: " << tensor_name;
output_names.push_back(tensor_name);
auto tensor_or_weights = converter.get_tensor(tensor_name);
@@ -1541,12 +2337,15 @@ tensorflow::Status ConvertSubGraphToTensorRTNodeDef(
VLOG(2) << "Finished output";
// TODO(jie): static_id is not thread safe.
- static int static_id = 0;
+
// Build the engine
- trt_builder->setMaxBatchSize(max_batch_size);
- trt_builder->setMaxWorkspaceSize(max_workspace_size_bytes);
- VLOG(0) << "Starting build engine " << static_id;
+ trt_builder->setMaxBatchSize(s.max_batch_size);
+ trt_builder->setMaxWorkspaceSize(s.max_workspace_size_bytes);
+  if (s.precision_mode == 1) {
+    trt_builder->setHalf2Mode(true);
+  }
+ LOG(INFO) << "starting build engine";
// TODO(ben,jie): half2 and int8 mode support
string engine_plan_string;
{
@@ -1561,17 +2360,18 @@ tensorflow::Status ConvertSubGraphToTensorRTNodeDef(
string(engine_plan_data, engine_plan_data + engine_plan->size());
}
- VLOG(0) << "Finished engine";
+ LOG(INFO) << "finished engine " << engine_name;
// Build the TRT op
- // TODO(sami,ben,jie): proper naming!
- tensorflow::NodeDefBuilder op_builder(
- tensorflow::strings::StrCat("my_trt_op", static_id++), "TRTEngineOp");
+ tensorflow::NodeDefBuilder op_builder(engine_name, "TRTEngineOp");
std::vector<tensorflow::NodeDefBuilder::NodeOut> income_edges;
+ VLOG(2) << "input edge size: " << input_names.size();
for (size_t i = 0; i < input_names.size(); ++i) {
- int output_idx = input_inds.at(i).second;
- // We wired up the input here already, it is redundant to do it again in
- // ConvertSubGraphToTensorRT(convert_graph.cc)
+ VLOG(2) << "input edges: " << std::to_string(i) << " "
+ << input_names.at(i);
+ int output_idx = s.input_inds.at(i).second;
+ // we wired up the input here already, it is redundant to do it again in
+ // ConvertSubGraphToTensorRT(convert_graph.cc)
auto incoming_edge = tensorflow::NodeDefBuilder::NodeOut(
input_names.at(i), output_idx, input_dtypes.at(i));
income_edges.push_back(incoming_edge);
@@ -1586,7 +2386,7 @@ tensorflow::Status ConvertSubGraphToTensorRTNodeDef(
.Attr("input_nodes", input_names)
.Attr("output_nodes", output_names)
.Attr("OutT", output_dtypes)
- .Finalize(trt_node);
+ .Finalize(s.trt_node);
VLOG(0) << status.ToString() << " finished op building";
@@ -1597,5 +2397,5 @@ tensorflow::Status ConvertSubGraphToTensorRTNodeDef(
} // namespace tensorrt
} // namespace tensorflow
-#endif // GOOGLE_TENSORRT
-#endif // GOOGLE_CUDA
+//#endif // GOOGLE_TENSORRT
+//#endif // GOOGLE_CUDA
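
Editorial note on the ConvertReshape validation added above: it rejects any new shape whose leading entry is not -1 (the implicit batch) and whose remaining entries do not multiply out to the element count of the input tensor. Below is a minimal standalone sketch of that invariant, with hypothetical names and no TensorRT dependency; it is illustrative only, not part of the patch.

#include <vector>

// Returns true when a tensor with TRT dims `dims` (implicit batch already
// stripped) can be reshaped to `new_shape`, whose first entry must be -1 for
// the batch dimension; mirrors the volume check in ConvertReshape.
bool ReshapeVolumesAgree(const std::vector<int>& dims,
                         const std::vector<int>& new_shape) {
  if (new_shape.empty() || new_shape[0] != -1) return false;
  long long volume_tensor = 1;
  for (int d : dims) volume_tensor *= d;
  long long volume_shape = 1;
  for (size_t i = 1; i < new_shape.size(); ++i) volume_shape *= new_shape[i];
  return volume_tensor == volume_shape;
}
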
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.h b/tensorflow/contrib/tensorrt/convert/convert_nodes.h
index 2e7fd19566..49e060a553 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_nodes.h
+++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.h
@@ -17,6 +17,8 @@ limitations under the License.
#define TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_NODES_H_
#include <set>
+#include <string>
+#include <unordered_map>
#include <utility>
#include <vector>
@@ -25,28 +27,56 @@ limitations under the License.
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/lib/core/status.h"
-#if GOOGLE_CUDA
-#if GOOGLE_TENSORRT
+//#if GOOGLE_CUDA
+//#if GOOGLE_TENSORRT
namespace tensorflow {
namespace tensorrt {
namespace convert {
-tensorflow::Status ConvertSubGraphToTensorRTNodeDef(
- const tensorflow::Graph& graph, const std::set<int>& subgraph_node_ids,
- const std::vector<std::pair<int, int>>&
- input_inds, // {node_id, output_idx}
- const std::vector<std::pair<int, int>>&
- output_inds, // {node_id, output_idx}
- size_t max_batch_size, size_t max_workspace_size_bytes,
- const tensorflow::grappler::GraphProperties& graph_prop,
- tensorflow::NodeDef* trt_node);
+struct SubGraphParams {
+ SubGraphParams(tensorflow::Graph& graph_,
+ const std::set<int>& subgraph_node_ids_,
+ const std::vector<std::pair<int, int>>& input_inds_,
+ const std::vector<std::pair<int, int>>& output_inds_,
+ size_t max_batch_size_, size_t max_workspace_size_bytes_,
+ const tensorflow::grappler::GraphProperties& graph_properties_,
+ std::unordered_map<string, std::pair<int, string>>*
+ output_edge_map_,
+ tensorflow::NodeDef* trt_node_,
+ int precision_mode_ = 0)
+ : graph(graph_),
+ subgraph_node_ids(subgraph_node_ids_),
+ input_inds(input_inds_),
+ output_inds(output_inds_),
+ max_batch_size(max_batch_size_),
+ max_workspace_size_bytes(max_workspace_size_bytes_),
+ graph_properties(graph_properties_),
+ output_edge_map(output_edge_map_),
+ trt_node(trt_node_),
+ precision_mode(precision_mode_) {}
+ tensorflow::Graph& graph;
+ const std::set<int>& subgraph_node_ids;
+ const std::vector<std::pair<int, int>>& input_inds; // {node_id, output_idx}
+ const std::vector<std::pair<int, int>>& output_inds; // {node_id, output_idx}
+ size_t max_batch_size;
+ size_t max_workspace_size_bytes;
+ const tensorflow::grappler::GraphProperties& graph_properties;
+ std::unordered_map<string, std::pair<int, string>>* output_edge_map;
+ tensorflow::NodeDef* trt_node;
+ const int precision_mode;
+};
+
+tensorflow::Status ConvertSubGraphToTensorRTNodeDef(SubGraphParams& params);
+tensorflow::Status InjectCalibrationNode(SubGraphParams& params);
+tensorflow::Status ConvertCalibrationNodeToEngineNode(tensorflow::Graph& graph,
+ tensorflow::Node* c_node);
} // namespace convert
} // namespace tensorrt
} // namespace tensorflow
-#endif // GOOGLE_TENSORRT
-#endif // GOOGLE_CUDA
+//#endif // GOOGLE_TENSORRT
+//#endif // GOOGLE_CUDA
#endif // TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_NODES_H_
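
For orientation, a hedged sketch of how a caller might bundle its segment data into SubGraphParams and choose between the two entry points declared above. The wrapper name, the fixed batch/workspace values, and the assumption that precision_mode == 2 (INT8) routes through InjectCalibrationNode are illustrative, not stated by this header.

#include <set>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h"
#include "tensorflow/core/graph/graph.h"

// Hypothetical wrapper: builds SubGraphParams and dispatches on precision.
tensorflow::Status BuildTRTNodeForSegment(
    tensorflow::Graph& graph, const std::set<int>& subgraph_node_ids,
    const std::vector<std::pair<int, int>>& input_inds,
    const std::vector<std::pair<int, int>>& output_inds,
    const tensorflow::grappler::GraphProperties& graph_properties,
    std::unordered_map<std::string, std::pair<int, std::string>>* output_edge_map,
    tensorflow::NodeDef* trt_node, int precision_mode) {
  tensorflow::tensorrt::convert::SubGraphParams params(
      graph, subgraph_node_ids, input_inds, output_inds,
      /*max_batch_size_=*/8, /*max_workspace_size_bytes_=*/1 << 26,
      graph_properties, output_edge_map, trt_node, precision_mode);
  if (precision_mode == 2)  // INT8: insert a calibration op first (assumed)
    return tensorflow::tensorrt::convert::InjectCalibrationNode(params);
  return tensorflow::tensorrt::convert::ConvertSubGraphToTensorRTNodeDef(params);
}
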
diff --git a/tensorflow/contrib/tensorrt/kernels/trt_calib_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_calib_op.cc
new file mode 100644
index 0000000000..7cd41c4933
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/kernels/trt_calib_op.cc
@@ -0,0 +1,111 @@
+//
+// Created by skama on 1/25/18.
+//
+
+#include "tensorflow/contrib/tensorrt/kernels/trt_calib_op.h"
+#include "tensorrt/include/NvInfer.h"
+#include <cuda_runtime_api.h>
+#include "tensorflow/contrib/tensorrt/resources/TRTInt8Calibrator.h"
+#include "tensorflow/contrib/tensorrt/resources/TRTResourceManager.h"
+#include "tensorflow/contrib/tensorrt/resources/TRTResources.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/framework/types.h"
+
+namespace tensorflow {
+namespace trt {
+TRTCalibOp::TRTCalibOp(OpKernelConstruction* context) : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("segment_nodes", &segment_nodes_));
+ OP_REQUIRES_OK(context, context->GetAttr("input_names", &input_names_));
+ OP_REQUIRES_OK(context, context->GetAttr("resource_name", &repo_name));
+}
+
+#define TYPECASE(dt, X, Y) \
+ case dt: { \
+ Y = (void*)X->flat<tensorflow::EnumToDataType<dt>::Type>().data(); \
+ break; \
+ }
+#define GET_TENSOR_ADDRESS(tensor_ptr, dest_ptr) \
+ { \
+ auto TENSOR_TYPE = tensor_ptr->dtype(); \
+ switch (TENSOR_TYPE) { \
+ TYPECASE(tensorflow::DT_FLOAT, tensor_ptr, dest_ptr); \
+ TYPECASE(tensorflow::DT_HALF, tensor_ptr, dest_ptr); \
+ TYPECASE(tensorflow::DT_INT8, tensor_ptr, dest_ptr); \
+ default: { \
+ LOG(FATAL) << "Unsupported Data type " \
+ << tensorflow::DataTypeString(TENSOR_TYPE); \
+ break; \
+ } \
+ } \
+ }
+void TRTCalibOp::Compute(tensorflow::OpKernelContext* ctx) {
+ auto trt_rm = tensorflow::trt::TRTResourceManager::instance();
+ VLOG(2) << "Op Name= " << name() << " nodedef name= " << repo_name;
+ auto resmgr = trt_rm->getManager("TRTCalibOps");
+ tensorflow::trt::TRTCalibrationResource* calibRes = nullptr;
+ auto status = resmgr->Lookup(repo_name, repo_name, &calibRes);
+ if (status.ok()) {
+ int batchSize = ctx->input(0).dim_size(0);
+    VLOG(2) << "batch size = " << batchSize;
+    int numInputs = ctx->num_inputs();
+    VLOG(2) << "num inputs = " << numInputs;
+ dev_tensors_.resize(numInputs);
+ if (calibRes->calibrator == nullptr) {
+ VLOG(1) << " Constructing calibrator";
+ // first run
+ for (int i = 0; i < numInputs; i++) {
+ const tensorflow::Tensor& t = ctx->input(i);
+ VLOG(1) << "Tensor " << i << " " << t.shape().DebugString();
+ OP_REQUIRES_OK(ctx,
+ ctx->allocate_persistent(t.dtype(), t.shape(),
+ &dev_tensors_.at(i), nullptr));
+ const auto dTensor = dev_tensors_.at(i).AccessTensor(ctx);
+ CHECK_EQ(t.TotalBytes(), dTensor->TotalBytes());
+ void* devAddr = nullptr;
+ GET_TENSOR_ADDRESS(dTensor, devAddr)
+ device_buffers_.emplace(
+ input_names_.at(i),
+ std::pair<void*, size_t>(devAddr, dTensor->TotalBytes()));
+ }
+ calibRes->calibrator = new TRTInt8Calibrator(device_buffers_, batchSize);
+ calibRes->thr = new std::thread([calibRes]() {
+        VLOG(0) << "Starting calibration thread, calibration resource @ "
+                << calibRes;
+        calibRes->builder->setInt8Calibrator(calibRes->calibrator);
+        calibRes->builder->setInt8Mode(true);
+        calibRes->engine = calibRes->builder->buildCudaEngine(
+            *calibRes->network);  // will loop until we terminate calibrator
+        VLOG(0) << "Calibration loop terminated";
+      });
+      VLOG(0) << "Initialized calibrator resource";
+ }
+
+ std::unordered_map<string, void*> input_data;
+ for (int i = 0; i < numInputs; i++) {
+ const Tensor& t = ctx->input(i);
+ void* data_address = nullptr;
+ const Tensor* t_ptr = &t;
+ GET_TENSOR_ADDRESS(t_ptr, data_address);
+ const auto dTensor = dev_tensors_.at(i).AccessTensor(ctx);
+ CHECK_EQ(t.TotalBytes(),
+ dTensor->TotalBytes()); // use the tensor so FW keeps it
+ input_data.emplace(input_names_.at(i), data_address);
+ ctx->set_output(i, t);
+ }
+ VLOG(1) << "Filled map for sending";
+ calibRes->calibrator->setBatch(input_data);
+ VLOG(1) << "Passed calibration data";
+ } else {
+ ctx->SetStatus(status);
+ return;
+ }
+}
+
+#undef TYPECASE
+
+REGISTER_KERNEL_BUILDER(Name("TRTCalibOp").Device(DEVICE_GPU), TRTCalibOp);
+
+} // namespace trt
+} // namespace tensorflow \ No newline at end of file
diff --git a/tensorflow/contrib/tensorrt/kernels/trt_calib_op.h b/tensorflow/contrib/tensorrt/kernels/trt_calib_op.h
new file mode 100644
index 0000000000..792e7bae4c
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/kernels/trt_calib_op.h
@@ -0,0 +1,36 @@
+//
+// Created by skama on 1/25/18.
+//
+
+#ifndef TFGITHUB_TRT_CALIB_OP_H
+#define TFGITHUB_TRT_CALIB_OP_H
+
+#include <memory>
+#include <string>
+#include <vector>
+#include <utility>
+#include <unordered_map>
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+
+namespace tensorflow {
+namespace trt {
+class TRTCalibOp : public OpKernel {
+ public:
+ explicit TRTCalibOp(OpKernelConstruction* context);
+
+ void Compute(OpKernelContext* context) override;
+
+ private:
+ std::string repo_name;
+ std::vector<std::string> segment_nodes_;
+ std::vector<std::string> input_names_;
+ std::vector<tensorflow::TensorShape> shapes_;
+ std::unordered_map<std::string, std::pair<void*, size_t>> device_buffers_;
+ std::vector<tensorflow::PersistentTensor> dev_tensors_;
+
+};
+}  // namespace trt
+}  // namespace tensorflow
+#endif  // TFGITHUB_TRT_CALIB_OP_H
diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
index 8efdf63ebe..e4e8ab9e0a 100644
--- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
+++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
@@ -24,8 +24,11 @@ limitations under the License.
#include "cuda/include/cuda_runtime_api.h"
namespace tensorflow {
+static ::tensorflow::tensorrt::Logger gLogger;
+
+using IRuntime = nvinfer1::IRuntime;
+using Dims = nvinfer1::Dims;
namespace tensorrt {
-static ::tensorflow::tensorrt::Logger logger;
TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) : OpKernel(context) {
// read serialized_engine
@@ -38,14 +41,22 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) : OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("output_nodes", &output_nodes_));
// TODO(samikama) runtime should be taken from a resourcemanager as well.
- // Only engine should be in the op and context and runtime should be taken
- // from resourcemanager
- nvinfer1::IRuntime* infer = nvinfer1::createInferRuntime(logger);
+ // Only engine should be in the op and context and runtime should be taken
+ // from resourcemanager
+
+ // TODO(jie): cudaSetDevice make sure trt engine is allocated on the same
+ // gpu where the input/output is also located.
+ int gpu_id = context->device()->tensorflow_gpu_device_info()->gpu_id;
+ cudaSetDevice(gpu_id);
+ int device;
+ cudaGetDevice(&device);
+ if (gpu_id != device)
+ LOG(FATAL) << "set device failed!";
+
+ IRuntime* infer = nvinfer1::createInferRuntime(gLogger);
trt_engine_ptr_.reset(infer->deserializeCudaEngine(
serialized_engine.c_str(), serialized_engine.size(), nullptr));
-
trt_execution_context_ptr_.reset(trt_engine_ptr_->createExecutionContext());
- // Runtime is safe to delete after engine creation
infer->destroy();
}
@@ -64,10 +75,16 @@ void TRTEngineOp::Compute(OpKernelContext* context) {
const TensorShape& input_shape = input_tensor.shape();
if (i == 0) {
num_batch = input_shape.dim_size(0);
+ if (num_batch > trt_engine_ptr_->getMaxBatchSize())
+ LOG(FATAL) << "input tensor batch larger than max_batch_size: "
+ << trt_engine_ptr_->getMaxBatchSize();
} else if (num_batch != input_shape.dim_size(0)) {
valid = false;
break;
}
+ // int64 input_shape.dim_size(int d)
+ // int input_shape.dims()
+    LOG(INFO) << "INPUT BINDING index: " << binding_index
+              << " with name: " << input_nodes_[i];
switch (trt_engine_ptr_->getBindingDataType(binding_index)) {
case nvinfer1::DataType::kFLOAT:
buffers[binding_index] = (void*)(input_tensor.flat<float>().data());
@@ -81,9 +98,7 @@ void TRTEngineOp::Compute(OpKernelContext* context) {
}
}
- // Might want a different way to inform the user of batch size inconsistency
- if (!valid) LOG(WARNING) << "input data inconsistent batch size";
-
+ if (!valid) LOG(FATAL) << "input data inconsistent batch size";
for (int i = 0; i < static_cast<int>(output_nodes_.size()); i++) {
// This is bad that we have to reallocate output buffer every run.
// Create an output tensor
@@ -119,6 +134,7 @@ void TRTEngineOp::Compute(OpKernelContext* context) {
break;
}
}
+ LOG(INFO) << "getting stream";
// copied from cuda_kernel_helper since it seems only valid in *.cu.cc files
const cudaStream_t* stream = CHECK_NOTNULL(
reinterpret_cast<const cudaStream_t*>(context->op_device_context()
@@ -126,9 +142,11 @@ void TRTEngineOp::Compute(OpKernelContext* context) {
->implementation()
->CudaStreamMemberHack()));
- // execution handled by TF since we are getting stream from TF.
- // it is safe for CPU pointer array (buffers) to go out of scope after enqueue
- trt_execution_context_ptr_->enqueue(num_batch, &buffers[0], *stream, nullptr);
+ // TODO(jie): trt enqueue does not return error
+  auto ret = trt_execution_context_ptr_->enqueue(num_batch, &buffers[0],
+                                                 *stream, nullptr);
+ VLOG(2) << "enqueue returns: " << ret;
+ // sync should be done by TF.
+ // cudaStreamSynchronize(*stream);
}
REGISTER_KERNEL_BUILDER(Name("TRTEngineOp").Device(DEVICE_GPU), TRTEngineOp);
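
The constructor change above switches to the GPU that owns the op's inputs before deserializing the engine and aborts if the runtime did not honour the request. A minimal standalone sketch of that pattern, assuming only the CUDA runtime and TF logging; the helper name is hypothetical.

#include <cuda_runtime_api.h>
#include "tensorflow/core/platform/logging.h"

// Pin the calling thread to `gpu_id` and verify the switch took effect,
// mirroring the check added to the TRTEngineOp constructor.
void PinToDevice(int gpu_id) {
  cudaSetDevice(gpu_id);
  int device = -1;
  cudaGetDevice(&device);
  if (device != gpu_id) LOG(FATAL) << "set device failed!";
}
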
diff --git a/tensorflow/contrib/tensorrt/ops/trt_calib_op.cc b/tensorflow/contrib/tensorrt/ops/trt_calib_op.cc
new file mode 100644
index 0000000000..dab5a3e0e8
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/ops/trt_calib_op.cc
@@ -0,0 +1,36 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+namespace tensorflow {
+
+
+REGISTER_OP("TRTCalibOp")
+ .Attr("segment_nodes: list(string)") // names of the ops in segment
+ .Attr("segment_output_names: list(string)") // names of the output ops in segment
+ .Attr("input_names: list(string)") // names of the inputs for passing into tensorrt
+ .Attr("resource_name: string")
+ .Attr("InT: list({int8, float16, float32})")
+ .Input("in_tensor: InT")
+ .Output("out_tensor: InT")
+ .SetShapeFn([](tensorflow::shape_inference::InferenceContext* c) {
+ for (int i = 0; i < c->num_inputs(); i++){
+ c->set_output(i, c->input(i));
+ }
+ return Status::OK();
+ });
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/tensorrt/python/__init__.py b/tensorflow/contrib/tensorrt/python/__init__.py
index 7e050a768c..3941d150d1 100644
--- a/tensorflow/contrib/tensorrt/python/__init__.py
+++ b/tensorflow/contrib/tensorrt/python/__init__.py
@@ -21,4 +21,5 @@ from __future__ import print_function
# pylint: disable=unused-import,line-too-long
from tensorflow.contrib.tensorrt.python.ops import trt_engine_op
from tensorflow.contrib.tensorrt.python.trt_convert import create_inference_graph
+from tensorflow.contrib.tensorrt.python.trt_convert import calib_graph_to_infer_graph
# pylint: enable=unused-import,line-too-long
diff --git a/tensorflow/contrib/tensorrt/python/trt_convert.py b/tensorflow/contrib/tensorrt/python/trt_convert.py
index 9454862f85..94afb75897 100644
--- a/tensorflow/contrib/tensorrt/python/trt_convert.py
+++ b/tensorflow/contrib/tensorrt/python/trt_convert.py
@@ -20,10 +20,18 @@ from __future__ import print_function
# pylint: disable=unused-import,line-too-long
import six as _six
-from tensorflow.contrib.tensorrt.wrap_conversion import trt_convert
from tensorflow.core.framework import graph_pb2
from tensorflow.python.framework import errors
from tensorflow.python.framework import errors_impl as _impl
+from tensorflow.contrib.tensorrt.wrap_conversion import trt_convert, calib_convert
+from tensorflow.python.util import compat
+import tensorflow as tf
+from tensorflow.python.grappler import tf_optimizer
+from tensorflow.core.protobuf import rewriter_config_pb2
+from tensorflow.python.framework import meta_graph
+
+
from tensorflow.python.framework import ops
@@ -32,7 +40,8 @@ from tensorflow.python.framework import ops
def create_inference_graph(input_graph_def,
outputs,
max_batch_size=1,
- max_workspace_size_bytes=2 << 20):
+ max_workspace_size_bytes=2 << 20,
+ precision_mode="FP32"):
"""Python wrapper for the TRT transormation.
@@ -48,7 +57,13 @@ def create_inference_graph(input_graph_def,
Raises:
RuntimeError: if the returned status message is malformed.
"""
-
+  supported_precision_modes = {"FP32": 0, "FP16": 1, "INT8": 2}
+  if precision_mode.upper() not in supported_precision_modes:
+    raise ValueError(("precision mode '{}' is not supported. "
+                      "It should be one of {}").format(
+                          precision_mode, "{'FP32', 'FP16', 'INT8'}"))
+  mode = supported_precision_modes[precision_mode.upper()]
def py2bytes(inp):
return inp
@@ -83,7 +98,7 @@ def create_inference_graph(input_graph_def,
# pair or strings where first one is encoded status and the second
# one is the transformed graphs protobuf string.
out = trt_convert(input_graph_def_str, out_names, max_batch_size,
- max_workspace_size_bytes)
+                    max_workspace_size_bytes, mode)
status = to_string(out[0])
output_graph_def_string = out[1]
del input_graph_def_str # Save some memory
@@ -101,3 +116,40 @@ def create_inference_graph(input_graph_def,
output_graph_def.ParseFromString(output_graph_def_string)
del output_graph_def_string # Save some memory
return output_graph_def
+
+def calib_graph_to_infer_graph(calibration_graph_def):
+ def py2bytes(inp):
+ return inp
+
+ def py3bytes(inp):
+ return inp.encode("utf-8", errors="surrogateescape")
+
+ def py2string(inp):
+ return inp
+
+ def py3string(inp):
+ return inp.decode("utf-8")
+
+ if _six.PY2:
+ to_bytes = py2bytes
+ to_string = py2string
+ else:
+ to_bytes = py3bytes
+ to_string = py3string
+
+  graph_str = calibration_graph_def.SerializeToString()
+  out = calib_convert(graph_str)
+  status = to_string(out[0])
+  output_graph_def_string = out[1]
+  del graph_str  # Save some memory
+  if len(status) < 2:
+    raise _impl.UnknownError(None, None, status)
+  if status[:2] != "OK":
+    msg = status.split(";")
+    if len(msg) == 1:
+      raise RuntimeError("Status message is malformed {}".format(status))
+    raise _impl._make_specific_exception(None, None, ";".join(msg[1:]),
+                                         int(msg[0]))
+  output_graph_def = graph_pb2.GraphDef()
+  output_graph_def.ParseFromString(output_graph_def_string)
+  del output_graph_def_string  # Save some memory
+  return output_graph_def
diff --git a/tensorflow/contrib/tensorrt/resources/TRTInt8Calibrator.cc b/tensorflow/contrib/tensorrt/resources/TRTInt8Calibrator.cc
new file mode 100644
index 0000000000..f5dc4886af
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/resources/TRTInt8Calibrator.cc
@@ -0,0 +1,132 @@
+//
+// Created by skama on 1/24/18.
+//
+
+#include "tensorflow/contrib/tensorrt/resources/TRTInt8Calibrator.h"
+
+#include <cuda_runtime_api.h>
+#include <atomic>
+#include <chrono>
+#include <unordered_map>
+
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+namespace trt {
+// set the batch size before constructing the thread to execute engine
+int TRTInt8Calibrator::getBatchSize() const { return batch_size_; }
+
+TRTInt8Calibrator::TRTInt8Calibrator(const std::unordered_map<
+ string, std::pair<void*, size_t>>& dev_buffers,
+ int batch_size)
+ : batch_size_(batch_size),
+ done_(false),
+ dev_buffers_(dev_buffers),
+      calib_running_(false) {
+  cudaPointerAttributes pa;
+  int devid = -1;
+  cudaGetDevice(&devid);
+  VLOG(0) << "Constructing calibrator with batch size " << batch_size
+          << " on device " << devid;
+  for (auto b : dev_buffers_) {
+    if (cudaPointerGetAttributes(&pa, b.second.first) == cudaSuccess) {
+      VLOG(1) << "CALIBRATOR device buffer name " << b.first << " size "
+              << b.second.second << " @ " << b.second.first << " on device "
+              << ((pa.memoryType == cudaMemoryTypeHost) ? "HOST" : "DEVICE");
+    } else {
+      VLOG(1) << "CALIBRATOR device buffer name " << b.first << " size "
+              << b.second.second << " @ " << b.second.first;
+    }
+  }
+}
+
+bool TRTInt8Calibrator::setBatch(
+ const std::unordered_map<string, void*>& data) {
+  VLOG(1) << "Waiting to set new batch";
+  if (done_) return false;
+  while (calib_running_.load(
+      std::memory_order_acquire)) {  // wait while calibration is running
+    tensorflow::mutex_lock l(cond_mtx_);
+    cond_.wait_for(l, std::chrono::milliseconds(50));
+    if (done_) return false;
+  }
+  VLOG(1) << "Set batch waiting finished";
+ for (const auto it : data) {
+
+ auto devptr = dev_buffers_.find(it.first);
+ if (devptr == dev_buffers_.end()) {
+ LOG(FATAL) << "FATAL input name '" << it.first
+ << "' does not match with the buffer names";
+ }
+ cudaPointerAttributes pa;
+ const auto& d = devptr->second;
+    VLOG(1) << "cuda memcpy buff name= " << it.first << " dst= " << d.first
+            << " size= " << d.second << " inp= " << it.second;
+    if (cudaPointerGetAttributes(&pa, it.second) == cudaSuccess) {
+      VLOG(1) << "CALIBRATOR device buffer name " << it.first << " size "
+              << d.second << " @ " << d.first << " on device "
+              << ((pa.memoryType == cudaMemoryTypeHost) ? "HOST" : "DEVICE");
+    }
+
+ auto status =
+ cudaMemcpy(d.first, it.second, d.second, cudaMemcpyDeviceToDevice);
+ if (status != cudaSuccess) {
+ LOG(FATAL) << "cudaMemcpy for '" << it.first << "' failed with "
+ << status;
+ }
+    // Debug readback of the first two floats of the persistent buffer.
+    float f[2];
+    f[0] = 3.;
+    f[1] = 0.14159;
+    status = cudaMemcpy(f, d.first, sizeof(float) * 2, cudaMemcpyDeviceToHost);
+    int devid = -1;
+    cudaGetDevice(&devid);
+    VLOG(0) << "Data in perm storage [0]=" << f[0] << " [1]=" << f[1]
+            << " current device=" << devid;
+ }
+ calib_running_.store(true, std::memory_order_release); // release builder
+ cond_.notify_all();
+ return true;
+}
+
+bool TRTInt8Calibrator::getBatch(void** bindings, const char** names,
+ int nbBindings) {
+ calib_running_.store(false, std::memory_order_release); // wait for new batch
+  VLOG(1) << "Calibrator is waiting for new batch";
+  cond_.notify_all();
+  while (!calib_running_.load(
+      std::memory_order_acquire)) {  // wait until new batch arrives
+    tensorflow::mutex_lock l(cond_mtx_);
+    cond_.wait_for(l, std::chrono::milliseconds(50));
+    if (done_) return false;
+  }
+ if (done_) {
+ return false;
+ }
+
+ for (int i = 0; i < nbBindings; i++) {
+ auto it = dev_buffers_.find(names[i]);
+ if (it == dev_buffers_.end()) {
+ LOG(FATAL) << "Calibration engine asked for unknown tensor name '"
+ << names[i] << "' at position " << i;
+ }
+    VLOG(1) << "Setting buffer " << i << " named=" << names[i] << " @ "
+            << it->second.first;
+    bindings[i] = it->second.first;
+    // Debug readback of the first two floats of the bound buffer.
+    float f[2];
+    f[0] = 3.;
+    f[1] = 0.14159;
+    auto status = cudaMemcpy(f, bindings[i], sizeof(float) * 2,
+                             cudaMemcpyDeviceToHost);
+    int devid = -1;
+    cudaGetDevice(&devid);
+    VLOG(0) << "Data in perm storage [0]=" << f[0] << " [1]=" << f[1]
+            << " on device=" << devid;
+
+ }
+ return true;
+}
+const void* TRTInt8Calibrator::readCalibrationCache(std::size_t& length) {
+  return nullptr;
+}
+void TRTInt8Calibrator::writeCalibrationCache(const void* ptr,
+                                              std::size_t length) {}
+TRTInt8Calibrator::~TRTInt8Calibrator() {
+  VLOG(1) << "Destroying calibrator";
+}
+
+} // namespace trt
+} // namespace tensorflow \ No newline at end of file
diff --git a/tensorflow/contrib/tensorrt/resources/TRTInt8Calibrator.h b/tensorflow/contrib/tensorrt/resources/TRTInt8Calibrator.h
new file mode 100644
index 0000000000..b8bf55f56e
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/resources/TRTInt8Calibrator.h
@@ -0,0 +1,39 @@
+//
+// Created by skama on 1/24/18.
+//
+
+#ifndef TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRTINT8CALIBRATOR_H_
+#define TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRTINT8CALIBRATOR_H_
+
+#include "tensorrt/include/NvInfer.h"
+#include <atomic>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include "tensorflow/core/platform/mutex.h"
+namespace tensorflow {
+namespace trt {
+
+struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator {
+ public:
+ TRTInt8Calibrator(const std::unordered_map<
+ string, std::pair<void*, size_t>>& dev_buffers,
+ int batch_size);
+  int getBatchSize() const override;
+  bool getBatch(void* bindings[], const char* names[],
+                int nbBindings) override;
+  bool setBatch(const std::unordered_map<string, void*>& data);
+  void setDone() { done_ = true; }
+  const void* readCalibrationCache(std::size_t& length) override;
+  void writeCalibrationCache(const void* ptr, std::size_t length) override;
+  ~TRTInt8Calibrator();
+ private:
+ int batch_size_;
+ tensorflow::mutex cond_mtx_;
+ tensorflow::condition_variable cond_;
+ bool done_;
+ const std::unordered_map<string, std::pair<void*, size_t>> dev_buffers_;
+ std::atomic_bool calib_running_;
+};
+} // namespace trt
+} // namespace tensorflow
+#endif // TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRTINT8CALIBRATOR_H_
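
A hedged sketch of how the two sides are expected to drive this calibrator: the feeding side (the TRTCalibOp) calls setBatch() once per step with device pointers, the TensorRT builder thread blocks inside getBatch() until that data arrives, and setDone() releases both so buildCudaEngine() can return. Buffer setup, sizes, the loop count and the function name are illustrative only; engine ownership and error handling are omitted.

#include <string>
#include <thread>
#include <unordered_map>
#include <utility>
#include "tensorflow/contrib/tensorrt/resources/TRTInt8Calibrator.h"

void CalibrationHandshakeSketch(nvinfer1::IBuilder* builder,
                                nvinfer1::INetworkDefinition* network,
                                void* device_input, size_t input_bytes) {
  // One persistent device buffer per engine input, keyed by binding name.
  std::unordered_map<std::string, std::pair<void*, size_t>> buffers{
      {"input", {device_input, input_bytes}}};
  tensorflow::trt::TRTInt8Calibrator calibrator(buffers, /*batch_size=*/8);
  builder->setInt8Calibrator(&calibrator);
  builder->setInt8Mode(true);
  // Builder thread: buildCudaEngine() keeps calling getBatch() until done.
  std::thread builder_thread([&] { builder->buildCudaEngine(*network); });
  // Feeding side: hand over each calibration batch (device pointers).
  for (int step = 0; step < 10; ++step) {
    std::unordered_map<std::string, void*> batch{{"input", device_input}};
    calibrator.setBatch(batch);
  }
  calibrator.setDone();  // unblocks getBatch() so the builder can finish
  builder_thread.join();
}
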
diff --git a/tensorflow/contrib/tensorrt/resources/TRTResourceManager.cc b/tensorflow/contrib/tensorrt/resources/TRTResourceManager.cc
new file mode 100644
index 0000000000..62d27c1104
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/resources/TRTResourceManager.cc
@@ -0,0 +1,21 @@
+//
+// Created by skama on 1/23/18.
+//
+
+#include "tensorflow/contrib/tensorrt/resources/TRTResourceManager.h"
+#include "tensorflow/core/platform/logging.h"
+
+
+std::shared_ptr<tensorflow::ResourceMgr>
+tensorflow::trt::TRTResourceManager::getManager(const std::string& mgr_name) {
+  // The mutex is held for the lookup only. Most instantiations, where the
+  // mutex would be held longer, happen during op creation and should be OK.
+  tensorflow::mutex_lock lock(map_mutex_);
+  auto s = managers_.find(mgr_name);
+  if (s == managers_.end()) {
+    auto it = managers_.emplace(
+        mgr_name, std::make_shared<tensorflow::ResourceMgr>(mgr_name));
+    VLOG(0) << "Returning a new manager " << mgr_name;
+    return it.first->second;
+  }
+  VLOG(1) << "Returning old manager " << mgr_name;
+  return s->second;
+}
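
A short sketch of the lookup pattern the calibration path uses against this manager: InjectCalibrationNode creates a TRTCalibrationResource under a per-segment name, and both TRTCalibOp and the later engine conversion look it up again. The helper below is hypothetical; the "TRTCalibOps" manager name and the container-equals-name convention come from the patch itself.

#include <string>
#include "tensorflow/contrib/tensorrt/resources/TRTResourceManager.h"
#include "tensorflow/contrib/tensorrt/resources/TRTResources.h"

// Look up a previously created calibration resource by its resource name.
tensorflow::Status LookupCalibResource(
    const std::string& resource_name,
    tensorflow::trt::TRTCalibrationResource** calib_res) {
  auto trt_rm = tensorflow::trt::TRTResourceManager::instance();
  // One shared ResourceMgr for all TRTCalibOps; created lazily on first use.
  auto resmgr = trt_rm->getManager("TRTCalibOps");
  // Container and resource name are the same string in this patch.
  return resmgr->Lookup(resource_name, resource_name, calib_res);
}
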
diff --git a/tensorflow/contrib/tensorrt/resources/TRTResourceManager.h b/tensorflow/contrib/tensorrt/resources/TRTResourceManager.h
new file mode 100644
index 0000000000..e3b50093e7
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/resources/TRTResourceManager.h
@@ -0,0 +1,37 @@
+//
+// Created by skama on 1/23/18.
+//
+
+#ifndef TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRTRESOURCEMANAGER_H_
+
+#define TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRTRESOURCEMANAGER_H_
+#include <memory>
+
+#include <string>
+#include <unordered_map>
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/platform/mutex.h"
+
+namespace tensorflow {
+namespace trt {
+class TRTResourceManager {
+ TRTResourceManager() = default;
+
+ public:
+ static std::shared_ptr<TRTResourceManager> instance() {
+ static std::shared_ptr<TRTResourceManager> instance_(
+ new TRTResourceManager);
+ return instance_;
+ }
+  // Returns a manager for the given op; creates one if it doesn't exist.
+ std::shared_ptr<tensorflow::ResourceMgr> getManager(
+ const string& op_name);
+
+ private:
+ std::unordered_map<string, std::shared_ptr<tensorflow::ResourceMgr>>
+ managers_;
+ tensorflow::mutex map_mutex_;
+};
+} // namespace trt
+} // namespace tensorflow
+#endif // TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRTRESOURCEMANAGER_H_
diff --git a/tensorflow/contrib/tensorrt/resources/TRTResources.h b/tensorflow/contrib/tensorrt/resources/TRTResources.h
new file mode 100644
index 0000000000..cd23100af8
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/resources/TRTResources.h
@@ -0,0 +1,59 @@
+//
+// Created by skama on 1/23/18.
+//
+
+#ifndef TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRTRESOURCES_H_
+
+#define TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRTRESOURCES_H_
+
+#include <string>
+#include <sstream>
+#include "tensorrt/include/NvInfer.h"
+#include <thread>
+#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
+#include "tensorflow/contrib/tensorrt/resources/TRTInt8Calibrator.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+
+namespace tensorflow {
+namespace trt {
+
+struct TRTCalibrationResource : public tensorflow::ResourceBase {
+ TRTCalibrationResource()
+ : calibrator(nullptr),
+ builder(nullptr),
+ network(nullptr),
+ engine(nullptr),
+ logger(nullptr),
+ thr(nullptr) {}
+  string DebugString() override {
+    std::stringstream oss;
+    oss << " Calibrator = " << std::hex << calibrator << std::dec << std::endl
+        << " Builder = " << std::hex << builder << std::dec << std::endl
+        << " Network = " << std::hex << network << std::dec << std::endl
+        << " Engine = " << std::hex << engine << std::dec << std::endl
+        << " Logger = " << std::hex << logger << std::dec << std::endl
+        << " Thread = " << std::hex << thr << std::dec << std::endl;
+    return oss.str();
+  }
+  ~TRTCalibrationResource() {
+    VLOG(0) << "Destroying calibration resource " << std::endl
+            << DebugString();
+  }
+ TRTInt8Calibrator* calibrator;
+ nvinfer1::IBuilder* builder;
+ nvinfer1::INetworkDefinition* network;
+ nvinfer1::ICudaEngine* engine;
+ tensorflow::tensorrt::Logger* logger;
+ std::thread* thr;
+};
+
+struct TRTEngineResource : public tensorflow::ResourceBase {
+ TRTEngineResource() : runtime(nullptr), ctx(nullptr){};
+ nvinfer1::IRuntime* runtime;
+ nvinfer1::IExecutionContext* ctx;
+};
+
+} // namespace trt
+} // namespace tensorflow
+#endif  // TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRTRESOURCES_H_
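
Since TRTCalibrationResource holds raw pointers, teardown order matters once calibration ends. A hedged sketch of the sequence used by ConvertCalibrationNodeToEngineNode earlier in this patch (the function name and return value here are illustrative): stop the calibrator, join the builder thread, serialize the engine, then release the TensorRT objects.

#include "tensorflow/contrib/tensorrt/resources/TRTResources.h"

// Finish calibration and return the serialized engine plan; the caller owns
// the returned IHostMemory and should destroy() it after copying the bytes.
nvinfer1::IHostMemory* FinishCalibration(
    tensorflow::trt::TRTCalibrationResource* res) {
  res->calibrator->setDone();  // unblocks getBatch() in the builder thread
  res->thr->join();
  delete res->thr;
  res->thr = nullptr;
  nvinfer1::IHostMemory* plan = res->engine->serialize();
  res->engine->destroy();
  res->network->destroy();
  res->builder->destroy();
  res->engine = nullptr;
  res->network = nullptr;
  res->builder = nullptr;
  return plan;
}
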
diff --git a/tensorflow/contrib/tensorrt/trt_conversion.i b/tensorflow/contrib/tensorrt/trt_conversion.i
index d679945d56..0ae3c91a63 100644
--- a/tensorflow/contrib/tensorrt/trt_conversion.i
+++ b/tensorflow/contrib/tensorrt/trt_conversion.i
@@ -64,13 +64,16 @@ PyObject* pair_helper(std::pair<string, string>* in) {
%ignoreall
%unignore tensorflow;
%unignore trt_convert;
+%unignore calib_convert;
%{
+
std::pair<string, string> trt_convert(
string graph_def_string, // The serialized GraphDef string.
std::vector<string> output_names,
size_t max_batch_size,
- size_t max_workspace_size_bytes
+ size_t max_workspace_size_bytes,
+ int precision_mode
// Unfortunately we can't use TF_Status here since it
// is in c/c_api and brings in a lot of other libraries
// which in turn declare ops. These ops are included
@@ -90,16 +93,19 @@ std::pair<string, string> trt_convert(
return std::pair<string, string>{out_status, ""};
}
+  if (precision_mode < 0 || precision_mode > 2) {
+ out_status = "InvalidArgument;Invalid precision_mode";
+ return std::pair<string, string>{out_status, ""};
+ }
if (!output_names.size()) {
out_status = "InvalidArgument;Size of the output_names vector is 0";
return std::pair<string, string>{out_status, ""};
- // return "";
}
tensorflow::GraphDef outGraph;
tensorflow::Status conversion_status =
tensorflow::tensorrt::convert::ConvertGraphDefToTensorRT(
graph_def, output_names, max_batch_size, max_workspace_size_bytes,
- &outGraph);
+ &outGraph, precision_mode);
if (!conversion_status.ok()) {
auto retCode = (int)conversion_status.code();
char buff[2000];
@@ -120,12 +126,60 @@ std::pair<string, string> trt_convert(
return std::pair<string, string>{"9;TensorRT is not enabled!", ""};
#endif // GOOGLE_CUDA && GOOGLE_TENSORRT
}
+
+std::pair<string, string> calib_convert(string graph_def_string // const tensorflow::GraphDef&
+ // unfortunately we can't use TF_Status here since it
+ // is in c/c_api and brings in a lot of other libraries
+ // which in turn declare ops. These ops are included
+ // statically in our library and cause an abort when
+ // module is loaded due to double registration
+ // until Tensorflow properly exposes these headers
+ // we have to work around this by returning a string
+ // and converting it to exception on python side.
+ //,TF_Status* out_status) {
+) {
+#if GOOGLE_CUDA && GOOGLE_TENSORRT
+ string out_status;
+
+ tensorflow::GraphDef graph_def;
+ if (!graph_def.ParseFromString(graph_def_string)) {
+ out_status = "InvalidArgument;Couldn't interpret input as a GraphDef";
+ return std::pair<string, string>{out_status, ""};
+ }
+
+ tensorflow::GraphDef outGraph;
+ tensorflow::Status conversion_status =
+ tensorflow::tensorrt::convert::ConvertCalibGraphToInferGraph(graph_def,
+ &outGraph);
+ if (!conversion_status.ok()) {
+ auto retCode = (int)conversion_status.code();
+ char buff[2000];
+ snprintf(buff, 2000, "%d;%s", retCode,
+ conversion_status.error_message().c_str());
+    out_status = buff;
+ return std::pair<string, string>{out_status, ""};
+ }
+ string result;
+ if (!outGraph.SerializeToString(&result)) {
+ out_status = "InvalidArgument;Couldn't serialize output as a GraphDef";
+ return std::pair<string, string>{out_status, ""};
+ }
+  out_status = "OK;All good!";
+ return std::pair<string, string>{out_status, result};
+#else
+ // Returns FAILED_PRECONDITION.
+ return std::pair<string, string>{"9;TensorRT is not enabled!", ""};
+#endif // GOOGLE_CUDA && GOOGLE_TENSORRT
+}
%}
+std::pair<string, string> calib_convert(string graph_def_string);
+
std::pair<string, string> trt_convert(string graph_def_string,
std::vector<string> output_names,
size_t max_batch_size,
- size_t max_workspace_size_bytes);
+ size_t max_workspace_size_bytes,
+ int precision_mode);
%unignoreall