author    gracehoney <31743510+aaroey@users.noreply.github.com> 2018-05-02 15:50:03 -0700
committer GitHub <noreply@github.com> 2018-05-02 15:50:03 -0700
commit    187cd5da6bf7bf28a873cd043905f354f047c988 (patch)
tree      05354fded4c81f4b112dae9f67f9fbdfc0215d8c
parent    c7a00a3731e08a1b532f317f6ef8b9abc884f127 (diff)
parent    bf70368d36df3ee9a16f5285940d73fb54d911c0 (diff)
Merge pull request #18909 from samikama/optimization_pass
Optimization pass and Memory allocator integration
-rw-r--r--  tensorflow/contrib/tensorrt/BUILD                             |  11
-rw-r--r--  tensorflow/contrib/tensorrt/convert/convert_graph.cc          | 119
-rw-r--r--  tensorflow/contrib/tensorrt/convert/convert_graph.h           |  10
-rw-r--r--  tensorflow/contrib/tensorrt/convert/convert_nodes.cc          |  45
-rw-r--r--  tensorflow/contrib/tensorrt/convert/convert_nodes.h           |  14
-rw-r--r--  tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc  | 246
-rw-r--r--  tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h   |  73
-rw-r--r--  tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc          |  52
-rw-r--r--  tensorflow/contrib/tensorrt/kernels/trt_engine_op.h           |  11
-rw-r--r--  tensorflow/contrib/tensorrt/resources/trt_allocator.cc        |  62
-rw-r--r--  tensorflow/contrib/tensorrt/resources/trt_allocator.h         |  68
-rw-r--r--  tensorflow/contrib/tensorrt/resources/trt_resources.h         |  44
-rw-r--r--  tensorflow/contrib/tensorrt/segment/segment.cc                | 379
-rw-r--r--  tensorflow/contrib/tensorrt/segment/segment.h                 |  18
-rw-r--r--  tensorflow/contrib/tensorrt/segment/segment_test.cc           |  16
-rw-r--r--  tensorflow/contrib/tensorrt/test/test_tftrt.py                |  64
-rw-r--r--  tensorflow/contrib/tensorrt/test/tf_trt_integration_test.py   |  19
-rwxr-xr-x  tensorflow/tools/pip_package/build_pip_package.sh             |   2
18 files changed, 1069 insertions(+), 184 deletions(-)
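
A minimal sketch of how a client could turn the new pass on, assuming a
TensorRT-enabled build. The optimizer name and parameter keys come from
trt_optimization_pass.cc in this diff; the hand-built ConfigProto helper
itself is hypothetical:

// Hypothetical usage sketch, not part of this change.
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"

tensorflow::ConfigProto MakeTrtSessionConfig() {
  tensorflow::ConfigProto config;
  auto* rewriter = config.mutable_graph_options()->mutable_rewrite_options();
  // List the registered pass as a custom grappler optimizer.
  auto* trt = rewriter->add_custom_optimizers();
  trt->set_name("TensorRTOptimizer");
  // Parameter keys mirror those parsed in TRTOptimizationPass::Init().
  auto& params = *trt->mutable_parameter_map();
  params["minimum_segment_size"].set_i(3);
  params["max_batch_size"].set_i(128);
  params["max_workspace_size_bytes"].set_i(1LL << 30);
  params["precision_mode"].set_s("FP32");
  return config;
}

Passing the resulting ConfigProto to a session makes grappler's MetaOptimizer
invoke the registered pass over the graph.
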
diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD
index 742be7baf0..675f0b1fd6 100644
--- a/tensorflow/contrib/tensorrt/BUILD
+++ b/tensorflow/contrib/tensorrt/BUILD
@@ -197,10 +197,12 @@ tf_py_wrap_cc(
tf_cuda_library(
name = "trt_resources",
srcs = [
+ "resources/trt_allocator.cc",
"resources/trt_int8_calibrator.cc",
"resources/trt_resource_manager.cc",
],
hdrs = [
+ "resources/trt_allocator.h",
"resources/trt_int8_calibrator.h",
"resources/trt_resource_manager.h",
"resources/trt_resources.h",
@@ -221,18 +223,24 @@ tf_cuda_library(
srcs = [
"convert/convert_graph.cc",
"convert/convert_nodes.cc",
+ "convert/trt_optimization_pass.cc",
],
hdrs = [
"convert/convert_graph.h",
"convert/convert_nodes.h",
+ "convert/trt_optimization_pass.h",
],
deps = [
":segment",
":trt_logging",
":trt_resources",
+ "//tensorflow/core/grappler/clusters:cluster",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:utils",
"//tensorflow/core:framework",
+ "//tensorflow/core:gpu_runtime",
"//tensorflow/core:framework_lite",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
@@ -241,8 +249,7 @@ tf_cuda_library(
"//tensorflow/core/grappler:devices",
"//tensorflow/core/grappler/clusters:virtual_cluster",
"//tensorflow/core/grappler/costs:graph_properties",
- "//tensorflow/core/grappler/optimizers:constant_folding",
- "//tensorflow/core/grappler/optimizers:layout_optimizer",
+ "//tensorflow/core/grappler/optimizers:meta_optimizer",
] + if_tensorrt([
"@local_config_tensorrt//:nv_infer",
]) + tf_custom_op_library_additional_deps(),
diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
index 0774027711..4df54a749f 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
@@ -24,6 +24,9 @@ limitations under the License.
#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h"
#include "tensorflow/contrib/tensorrt/segment/segment.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_id.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h"
+#include "tensorflow/core/common_runtime/gpu/process_state.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
@@ -31,8 +34,7 @@ limitations under the License.
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/devices.h"
#include "tensorflow/core/grappler/grappler_item.h"
-#include "tensorflow/core/grappler/optimizers/constant_folding.h"
-#include "tensorflow/core/grappler/optimizers/layout_optimizer.h"
+#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
@@ -144,7 +146,8 @@ struct ConvertGraphParams {
size_t max_supported_batch_size, size_t max_consumed_workspace_size_bytes,
const tensorflow::grappler::GraphProperties& current_graph_properties,
std::unordered_map<string, std::pair<int, string>>* output_edges,
- int engine_precision_mode)
+ int engine_precision_mode, const string& device_name,
+ std::shared_ptr<nvinfer1::IGpuAllocator> allocator, int cuda_gpu_id)
: graph(inp_graph),
output_names(output_node_names),
subgraph_node_ids(subgraph_node_id_numbers),
@@ -152,7 +155,10 @@ struct ConvertGraphParams {
max_workspace_size_bytes(max_consumed_workspace_size_bytes),
graph_properties(current_graph_properties),
output_edge_map(output_edges),
- precision_mode(engine_precision_mode) {}
+ precision_mode(engine_precision_mode),
+ device_name_(device_name),
+ allocator_(allocator),
+ cuda_gpu_id_(cuda_gpu_id) {}
tensorflow::Graph& graph;
const std::vector<string>& output_names;
const std::set<int>& subgraph_node_ids;
@@ -161,6 +167,9 @@ struct ConvertGraphParams {
const tensorflow::grappler::GraphProperties& graph_properties;
std::unordered_map<string, std::pair<int, string>>* output_edge_map;
int precision_mode;
+ string device_name_;
+ std::shared_ptr<nvinfer1::IGpuAllocator> allocator_;
+ int cuda_gpu_id_;
std::vector<std::pair<int, int>> subgraph_inputs;
std::vector<std::pair<int, int>> subgraph_outputs;
tensorflow::EdgeSet subgraph_incoming_edges;
@@ -194,7 +203,7 @@ static tensorflow::Status FillSubGraphEdgeSets(ConvertGraphParams* p) {
subgraph_outputs_set.begin(),
subgraph_outputs_set.end());
return tensorflow::Status::OK();
-};
+}
tensorflow::Status GetCalibNode(ConvertGraphParams* params) {
TF_RETURN_IF_ERROR(FillSubGraphEdgeSets(params));
@@ -203,7 +212,8 @@ tensorflow::Status GetCalibNode(ConvertGraphParams* params) {
params->subgraph_inputs, params->subgraph_outputs,
params->max_batch_size, params->max_workspace_size_bytes,
params->graph_properties, params->output_edge_map,
- &trt_node_def, params->precision_mode);
+ &trt_node_def, params->precision_mode, params->device_name_,
+ params->allocator_, params->cuda_gpu_id_);
TF_RETURN_IF_ERROR(InjectCalibrationNode(s));
tensorflow::Status status;
tensorflow::Node* trt_node = params->graph.AddNode(trt_node_def, &status);
@@ -233,7 +243,8 @@ tensorflow::Status ConvertSubGraphToTensorRT(ConvertGraphParams* params) {
params->subgraph_inputs, params->subgraph_outputs,
params->max_batch_size, params->max_workspace_size_bytes,
params->graph_properties, params->output_edge_map,
- &trt_node_def, params->precision_mode);
+ &trt_node_def, params->precision_mode, params->device_name_,
+ params->allocator_, params->cuda_gpu_id_);
TF_RETURN_IF_ERROR(ConvertSubGraphToTensorRTNodeDef(s));
tensorflow::Status status;
tensorflow::Node* trt_node = params->graph.AddNode(trt_node_def, &status);
@@ -331,19 +342,12 @@ tensorflow::Status ConvertGraphDefToTensorRT(
// optimization pass
tensorflow::grappler::GrapplerItem item;
item.fetch = output_names;
- tensorflow::GraphDef gdef;
-
- // Layout optimization
item.graph = graph_def;
- tensorflow::grappler::LayoutOptimizer optimizer;
- tensorflow::grappler::Cluster* cluster;
- // virtual cluster
tensorflow::DeviceProperties device_properties;
-
device_properties.set_type("GPU");
device_properties.mutable_environment()->insert({"architecture", "6"});
- cluster =
+ tensorflow::grappler::Cluster* cluster =
new tensorflow::grappler::VirtualCluster({{"/GPU:0", device_properties}});
// single machine
@@ -351,27 +355,38 @@ tensorflow::Status ConvertGraphDefToTensorRT(
int num_gpus = tensorflow::grappler::GetNumAvailableGPUs();
VLOG(2) << "cpu_cores: " << num_cpu_cores;
VLOG(2) << "gpus: " << num_gpus;
-
- TF_RETURN_IF_ERROR(optimizer.Optimize(cluster, item, &gdef));
-
- // constant folding
+ tensorflow::RewriterConfig rw_cfg;
+ tensorflow::grappler::MetaOptimizer meta_opt(nullptr, rw_cfg);
+ tensorflow::GraphDef gdef;
+ TF_RETURN_IF_ERROR(meta_opt.Optimize(cluster, item, &gdef));
item.graph = gdef;
- tensorflow::grappler::ConstantFolding fold(nullptr);
- TF_RETURN_IF_ERROR(fold.Optimize(nullptr, item, &gdef));
// AJ refactoring shape inference through grappler/GraphProperties.
tensorflow::grappler::GraphProperties static_graph_properties(item);
- TF_RETURN_IF_ERROR(static_graph_properties.InferStatically(false));
+ TF_RETURN_IF_ERROR(static_graph_properties.InferStatically(true));
// Build full graph
+
+ return ConvertAfterShapes(gdef, output_names, max_batch_size,
+ max_workspace_size_bytes, new_graph_def,
+ precision_mode, minimum_segment_size,
+ static_graph_properties, nullptr);
+}
+
+tensorflow::Status ConvertAfterShapes(
+ const tensorflow::GraphDef& gdef, const std::vector<string>& output_names,
+ size_t max_batch_size, size_t max_workspace_size_bytes,
+ tensorflow::GraphDef* new_graph_def, int precision_mode,
+ int minimum_segment_size,
+ const tensorflow::grappler::GraphProperties& graph_properties,
+ const tensorflow::grappler::Cluster* cluster) {
+ // Segment the graph into subgraphs that can be converted to TensorRT
+ tensorflow::tensorrt::segment::SegmentOptions segment_options;
tensorflow::FunctionLibraryDefinition flib(tensorflow::OpRegistry::Global(),
gdef.library());
tensorflow::Graph graph(flib);
TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToGraph(
tensorflow::GraphConstructorOptions(), gdef, &graph));
- // Segment the graph into subgraphs that can be converted to TensorRT
- tensorflow::tensorrt::segment::SegmentOptions segment_options;
-
// TODO(ben,jie,sami): exclude output nodes (DISCUSS IT)
for (auto node : output_names) {
segment_options.exclude_node_list.insert(node);
@@ -381,7 +396,7 @@ tensorflow::Status ConvertGraphDefToTensorRT(
segment_options.minimum_segment_size = minimum_segment_size;
tensorflow::tensorrt::segment::SegmentNodesVector segments;
TF_RETURN_IF_ERROR(tensorrt::segment::SegmentGraph(
- gdef, IsTensorRTCandidate, segment_options, &segments));
+ &graph, IsTensorRTCandidate, segment_options, &segments));
if (segments.size() > 1) {
VLOG(0) << "MULTIPLE tensorrt candidate conversion: " << segments.size();
}
@@ -391,9 +406,21 @@ tensorflow::Status ConvertGraphDefToTensorRT(
int count = 0;
float total_num_nodes_in_segments = 0.;
for (auto s : segments) {
- total_num_nodes_in_segments += s.size();
+ total_num_nodes_in_segments += s.first.size();
}
- for (const std::set<string>& subgraph_node_names : segments) {
+ // We create the map here since cluster may not be available in all cases.
+ std::map<string, tensorflow::Device*> name_to_device_map;
+ if (cluster) {
+ // TODO(aaroey): consider using DeviceSet::FindDeviceByName(), as in a
+ // distributed environment, devices from different workers can have same
+ // short name.
+ for (const auto dm : cluster->GetDeviceSet()->devices()) {
+ name_to_device_map[dm->name()] = dm;
+ }
+ }
+ for (const auto& segment_nodes_and_device : segments) {
+ const std::set<string>& subgraph_node_names =
+ segment_nodes_and_device.first;
std::set<int> subgraph_node_ids;
size_t max_mem_per_engine =
max_workspace_size_bytes *
@@ -403,10 +430,40 @@ tensorflow::Status ConvertGraphDefToTensorRT(
oss << " " << node_name;
subgraph_node_ids.insert(node_map.at(node_name)->id());
}
- VLOG(2) << "Subgraph nodes" << oss.str();
+ VLOG(1) << "Subgraph nodes at device " << segment_nodes_and_device.second
+ << " : " << oss.str();
+ auto target_device =
+ name_to_device_map.find(segment_nodes_and_device.second);
+    std::shared_ptr<nvinfer1::IGpuAllocator> allocator(nullptr);
+
+ int cuda_device_id = 0;
+ if (target_device != name_to_device_map.end()) {
+ tensorflow::TfGpuId tf_gpu_id(target_device->second->parsed_name().id);
+ CudaGpuId cuda_gpu_id;
+ Status s = GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id);
+ if (!s.ok()) {
+ LOG(ERROR)
+ << "Cuda device identification failed, using device 0. Error= "
+ << s;
+ } else {
+ cuda_device_id = cuda_gpu_id.value();
+ }
+ tensorflow::GPUOptions gpuoptions;
+      // We need to use ProcessState here since in the Python path there is
+      // no way to get to the allocators.
+      auto pm = tensorflow::ProcessState::singleton();
+      // This should be instantiated by now.
+ auto dev_allocator = pm->GetGPUAllocator(gpuoptions, tf_gpu_id, 1);
+ VLOG(1) << "Got an allocator for device tf_device=" << tf_gpu_id.value()
+ << " cuda device= " << cuda_device_id << " at " << dev_allocator;
+ allocator = std::make_shared<TRTDeviceAllocator>(dev_allocator);
+ } else { // device unknown or not available
+ allocator = std::make_shared<TRTCudaAllocator>();
+ }
ConvertGraphParams p(graph, output_names, subgraph_node_ids, max_batch_size,
- max_mem_per_engine, static_graph_properties,
- &output_edge_map, precision_mode);
+ max_mem_per_engine, graph_properties, &output_edge_map,
+ precision_mode, segment_nodes_and_device.second,
+ allocator, cuda_device_id);
if (precision_mode == INT8MODE) {
tensorflow::Status status = GetCalibNode(&p);
if (status != tensorflow::Status::OK()) {
diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.h b/tensorflow/contrib/tensorrt/convert/convert_graph.h
index e01e4a5328..65a67d7e73 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_graph.h
+++ b/tensorflow/contrib/tensorrt/convert/convert_graph.h
@@ -18,6 +18,8 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/grappler/clusters/cluster.h"
+#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"
@@ -43,6 +45,14 @@ tensorflow::Status ConvertGraphDefToTensorRT(
size_t max_workspace_size_bytes, tensorflow::GraphDef* new_graph_def,
int precision_mode, int minimum_segment_size);
+// Method to call from optimization pass
+tensorflow::Status ConvertAfterShapes(
+ const tensorflow::GraphDef& graph, const std::vector<string>& output_names,
+ size_t max_batch_size, size_t max_workspace_size_bytes,
+ tensorflow::GraphDef* new_graph_def, int precision_mode,
+ int minimum_segment_size,
+ const tensorflow::grappler::GraphProperties& graph_properties,
+ const tensorflow::grappler::Cluster* cluster);
} // namespace convert
} // namespace tensorrt
} // namespace tensorflow
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
index b81ae9dc3e..4d3710a514 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
@@ -346,10 +346,11 @@ void ReorderCKtoKC(const TRT_ShapedWeights& iweights,
break;
}
case tensorflow::DataType::DT_HALF: {
- Reorder2({k, c}, static_cast<Eigen::half const*>(iweights.GetValues()),
- istrides, static_cast<Eigen::half*>(
- const_cast<void*>(oweights->GetValues())),
- ostrides);
+ Reorder2(
+ {k, c}, static_cast<Eigen::half const*>(iweights.GetValues()),
+ istrides,
+ static_cast<Eigen::half*>(const_cast<void*>(oweights->GetValues())),
+ ostrides);
break;
}
default:
@@ -481,7 +482,7 @@ class Converter {
weights.SetValues(weight_store_->store_.back().data());
return weights;
}
- bool isFP16() { return fp16_; };
+ bool isFP16() { return fp16_; }
TRT_ShapedWeights get_temp_weights_like(const TRT_ShapedWeights& weights) {
return this->get_temp_weights(weights.type_, weights.shape_);
}
@@ -672,7 +673,7 @@ std::function<Eigen::half(Eigen::half)> LambdaFactory::unary<Eigen::half>() {
case OP_CATEGORY::RSQRT: {
VLOG(2) << "RSQRT GETS DONE";
return [](Eigen::half t) -> Eigen::half {
- return Eigen::half(1.0 / sqrt(float(t)));
+ return Eigen::half(1.0 / sqrt(static_cast<float>(t)));
};
}
case OP_CATEGORY::NEG:
@@ -2246,8 +2247,12 @@ tensorflow::Status InjectCalibrationNode(tensorrt::convert::SubGraphParams& s) {
auto op_res = new tensorflow::tensorrt::TRTCalibrationResource();
TF_CHECK_OK(op_rmgr->Create(calib_op_name, calib_op_name, op_res));
op_res->logger_ = new tensorflow::tensorrt::Logger();
+ cudaSetDevice(s.cuda_gpu_id_);
op_res->builder_ = nvinfer1::createInferBuilder(*(op_res->logger_));
-
+ op_res->allocator_ = s.allocator_;
+#if NV_TENSORRT_MAJOR > 3
+ op_res->builder_->setGpuAllocator(s.allocator_.get());
+#endif
if (!op_res->builder_) {
return tensorflow::errors::Internal(
"failed to create TensorRT builder object");
@@ -2323,8 +2328,8 @@ tensorflow::Status InjectCalibrationNode(tensorrt::convert::SubGraphParams& s) {
<< ", at node: " << node_name
<< "with output entry from shape_map: " << op_info_vec.size();
// TODO(ben,jie): update TRT input format/dimension
- nvinfer1::DimsCHW input_dim_psuedo_chw;
- for (int i = 0; i < 3; i++) input_dim_psuedo_chw.d[i] = 1;
+ nvinfer1::DimsCHW input_dim_pseudo_chw;
+ for (int i = 0; i < 3; i++) input_dim_pseudo_chw.d[i] = 1;
// TODO(jie): TRT 3.x only support 4 dimensional input tensor.
// update the code once TRT 4.0 comes out.
@@ -2338,7 +2343,7 @@ tensorflow::Status InjectCalibrationNode(tensorrt::convert::SubGraphParams& s) {
for (int i = 1; i < op_info.shape().dim_size(); i++) {
VLOG(2) << "dimension: " << i
<< " , size: " << op_info.shape().dim(i).size();
- input_dim_psuedo_chw.d[i - 1] = op_info.shape().dim(i).size();
+ input_dim_pseudo_chw.d[i - 1] = op_info.shape().dim(i).size();
}
// TODO(ben,jie): proper way to restore input tensor name?
@@ -2349,7 +2354,7 @@ tensorflow::Status InjectCalibrationNode(tensorrt::convert::SubGraphParams& s) {
input_names.push_back(input_tensor_name);
nvinfer1::ITensor* input_tensor = converter.network()->addInput(
- input_tensor_name.c_str(), dtype, input_dim_psuedo_chw);
+ input_tensor_name.c_str(), dtype, input_dim_pseudo_chw);
if (!input_tensor)
return tensorflow::errors::InvalidArgument(
@@ -2476,13 +2481,15 @@ tensorflow::Status ConvertSubGraphToTensorRTNodeDef(
// Topological order is needed to build TRT network
tensorflow::tensorrt::Logger trt_logger;
-
+ cudaSetDevice(s.cuda_gpu_id_);
auto trt_builder = infer_object(nvinfer1::createInferBuilder(trt_logger));
if (!trt_builder) {
return tensorflow::errors::Internal(
"Failed to create TensorRT builder object");
}
-
+#if NV_TENSORRT_MAJOR > 3
+ trt_builder->setGpuAllocator(s.allocator_.get());
+#endif
auto trt_network = infer_object(trt_builder->createNetwork());
if (!trt_network) {
return tensorflow::errors::Internal(
@@ -2565,8 +2572,8 @@ tensorflow::Status ConvertSubGraphToTensorRTNodeDef(
<< ", at node: " << node_name
<< " with output entry from shape_map: " << op_info_vec.size();
// TODO(ben,jie): update TRT input format/dimension
- nvinfer1::DimsCHW input_dim_psuedo_chw;
- for (int i = 0; i < 3; i++) input_dim_psuedo_chw.d[i] = 1;
+ nvinfer1::DimsCHW input_dim_pseudo_chw;
+ for (int i = 0; i < 3; i++) input_dim_pseudo_chw.d[i] = 1;
// TODO(jie): TRT 3.x only support 4 dimensional input tensor.
// update the code once TRT 4.0 comes out.
@@ -2580,7 +2587,7 @@ tensorflow::Status ConvertSubGraphToTensorRTNodeDef(
for (int i = 1; i < op_info.shape().dim_size(); i++) {
VLOG(2) << "dimension: " << i
<< " , size: " << op_info.shape().dim(i).size();
- input_dim_psuedo_chw.d[i - 1] = op_info.shape().dim(i).size();
+ input_dim_pseudo_chw.d[i - 1] = op_info.shape().dim(i).size();
}
// TODO(ben,jie): proper way to restore input tensor name?
@@ -2591,7 +2598,7 @@ tensorflow::Status ConvertSubGraphToTensorRTNodeDef(
input_names.push_back(input_tensor_name);
nvinfer1::ITensor* input_tensor = converter.network()->addInput(
- input_tensor_name.c_str(), dtype, input_dim_psuedo_chw);
+ input_tensor_name.c_str(), dtype, input_dim_pseudo_chw);
if (!input_tensor)
return tensorflow::errors::InvalidArgument(
@@ -2707,9 +2714,11 @@ tensorflow::Status ConvertSubGraphToTensorRTNodeDef(
.Attr("input_nodes", input_names)
.Attr("output_nodes", output_names)
.Attr("OutT", output_dtypes)
+ .Device(s.device_name_)
.Finalize(s.trt_node);
- VLOG(0) << status.ToString() << " finished op building";
+ VLOG(0) << status.ToString() << " finished op building for " << engine_name
+ << " on device " << s.device_name_;
return tensorflow::Status::OK();
}
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.h b/tensorflow/contrib/tensorrt/convert/convert_nodes.h
index 954a1e72f8..3f6592cd25 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_nodes.h
+++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.h
@@ -22,11 +22,11 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "tensorflow/contrib/tensorrt/resources/trt_allocator.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/lib/core/status.h"
-
#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
@@ -48,7 +48,9 @@ struct SubGraphParams {
const tensorflow::grappler::GraphProperties& current_graph_properties,
std::unordered_map<string, std::pair<int, string>>* output_edges,
tensorflow::NodeDef* constructed_trt_node,
- int engine_precision_mode = FP32MODE)
+ int engine_precision_mode = FP32MODE, const string& device_name = "",
+ std::shared_ptr<nvinfer1::IGpuAllocator> allocator = nullptr,
+ int cuda_gpu_id = 0)
: graph(inp_graph),
subgraph_node_ids(subgraph_node_id_numbers),
input_inds(input_indices),
@@ -58,7 +60,10 @@ struct SubGraphParams {
graph_properties(current_graph_properties),
output_edge_map(output_edges),
trt_node(constructed_trt_node),
- precision_mode(engine_precision_mode) {}
+ precision_mode(engine_precision_mode),
+ device_name_(device_name),
+ allocator_(allocator),
+ cuda_gpu_id_(cuda_gpu_id) {}
tensorflow::Graph& graph;
const std::set<int>& subgraph_node_ids;
@@ -70,6 +75,9 @@ struct SubGraphParams {
std::unordered_map<string, std::pair<int, string>>* output_edge_map;
tensorflow::NodeDef* trt_node;
const int precision_mode;
+ const string device_name_;
+ std::shared_ptr<nvinfer1::IGpuAllocator> allocator_;
+ const int cuda_gpu_id_;
};
// TODO(sami): Replace references with const reference or pointers
diff --git a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc
new file mode 100644
index 0000000000..8f634b1f74
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc
@@ -0,0 +1,246 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h"
+#include "tensorflow/contrib/tensorrt/convert/convert_graph.h"
+#include "tensorflow/core/grappler/clusters/cluster.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+
+#if GOOGLE_CUDA
+#if GOOGLE_TENSORRT
+namespace tensorflow {
+namespace tensorrt {
+namespace convert {
+// TODO(sami): Remove VLOG messages once the code matures
+using tensorflow::str_util::Uppercase;
+using tensorflow::strings::StrAppend;
+using tensorflow::strings::StrCat;
+
+tensorflow::Status TRTOptimizationPass::Init(
+ const tensorflow::RewriterConfig_CustomGraphOptimizer* config) {
+ VLOG(1) << "Called INIT for " << name_ << " with config = " << config;
+ if (config == nullptr) {
+    maximum_workspace_size_ = 2LL << 30;  // 2 GiB default; 2 << 30 overflows int.
+ return tensorflow::Status::OK();
+ }
+ const auto params = config->parameter_map();
+ if (params.count("minimum_segment_size")) {
+ minimum_segment_size_ = params.at("minimum_segment_size").i();
+ }
+ if (params.count("max_batch_size")) {
+ maximum_batch_size_ = params.at("max_batch_size").i();
+ }
+ if (params.count("max_workspace_size_bytes"))
+ maximum_workspace_size_ = params.at("max_workspace_size_bytes").i();
+ if (params.count("precision_mode")) {
+ string pm = Uppercase(params.at("precision_mode").s());
+ if (pm == "FP32") {
+ precision_mode_ = 0;
+ } else if (pm == "FP16") {
+ precision_mode_ = 1;
+ } else if (pm == "INT8") {
+ precision_mode_ = 2;
+ } else {
+ LOG(ERROR) << "Unknown precision mode '" << pm << "'";
+ return tensorflow::errors::InvalidArgument(
+ "Unknown precision mode argument" + pm +
+ " Valid values are FP32, FP16, INT8");
+ }
+ }
+ return tensorflow::Status::OK();
+}
+
+void TRTOptimizationPass::PrintDebugInfo(
+ tensorflow::grappler::Cluster* cluster,
+ const tensorflow::grappler::GrapplerItem& item) {
+ VLOG(1) << "Cluster = " << cluster;
+ string offset(" ");
+ string offset2 = StrCat(offset, offset);
+ string offset3 = StrCat(offset2, offset);
+ string offset4 = StrCat(offset2, offset2);
+ if (cluster) {
+ VLOG(1) << offset << "type = " << cluster->type();
+ VLOG(1) << offset << "num warmup steps = " << cluster->NumWarmupSteps();
+ const auto dev_names = cluster->GetDeviceNames();
+ if (dev_names.size()) {
+ VLOG(1) << offset << " Device names:";
+ for (const auto s : dev_names) {
+ VLOG(1) << offset2 << s;
+ }
+ }
+ std::unordered_map<string, uint64> peak_mem;
+ auto status = cluster->GetPeakMemoryUsage(&peak_mem);
+ if (status == tensorflow::Status::OK()) {
+ VLOG(1) << offset << "Peak Memory Usage :";
+ for (auto s : peak_mem) {
+ VLOG(1) << offset2 << s.first << " = " << s.second;
+ }
+ }
+
+ const auto dev_props = cluster->GetDevices();
+ if (dev_props.size()) {
+ VLOG(1) << offset << "Device properties:";
+ for (auto k : dev_props) {
+ VLOG(1) << offset2 << k.first;
+ const auto& dt = k.second;
+ VLOG(1) << offset3 << "type = " << dt.type();
+ VLOG(1) << offset3 << "vendor = " << dt.vendor();
+ VLOG(1) << offset3 << "model = " << dt.model();
+ VLOG(1) << offset3 << "frequency = " << dt.frequency();
+ VLOG(1) << offset3 << "num cores = " << dt.num_cores();
+ VLOG(1) << offset3 << "num registers = " << dt.num_registers();
+ VLOG(1) << offset3 << "L1 cache size = " << dt.l1_cache_size();
+ VLOG(1) << offset3 << "L2 cache size = " << dt.l2_cache_size();
+ VLOG(1) << offset3 << "L3 cache size = " << dt.l3_cache_size();
+ VLOG(1) << offset3 << "SHMem per SMP = "
+ << dt.shared_memory_size_per_multiprocessor();
+ VLOG(1) << offset3 << "memory size = " << dt.memory_size();
+ VLOG(1) << offset3 << "bandwidth = " << dt.bandwidth();
+ if (dt.environment_size()) {
+ VLOG(1) << offset3 << "environment :";
+ for (const auto e : dt.environment()) {
+ VLOG(1) << offset4 << e.first << " = " << e.second;
+ }
+ }
+ }
+ }
+ }
+ VLOG(1) << "item: " << item.id;
+ if (item.feed.size()) {
+ VLOG(1) << offset << "Feeds :";
+ for (const auto& f : item.feed) {
+ const auto& shape = f.second.shape();
+ VLOG(1) << offset2 << f.first << " = shaped " << shape.DebugString();
+ }
+ } else {
+ VLOG(1) << offset << "No Feeds";
+ }
+ if (item.fetch.size()) {
+ VLOG(1) << offset << "Fetches :";
+ for (const auto& f : item.fetch) {
+ VLOG(1) << offset2 << f;
+ }
+ } else {
+ VLOG(1) << offset << "No Fetches";
+ }
+
+ if (item.init_ops.size()) {
+ VLOG(1) << offset << "init ops :";
+ for (const auto& f : item.init_ops) {
+ VLOG(1) << offset2 << f;
+ }
+ } else {
+ VLOG(1) << offset << "No init ops";
+ }
+ VLOG(1) << "Save Op = " << item.save_op;
+ VLOG(1) << "Restore Op = " << item.restore_op;
+ VLOG(1) << "save_restore_loc_tensor = " << item.save_restore_loc_tensor;
+ if (item.keep_ops.size()) {
+ VLOG(1) << offset << "keep ops :";
+ for (const auto& f : item.keep_ops) {
+ VLOG(1) << offset2 << f;
+ }
+ } else {
+ VLOG(1) << offset << "No keep ops";
+ }
+ VLOG(3) << item.graph.DebugString();
+  if (cluster == nullptr) return;  // Cluster may be null; see checks above.
+  for (const auto dev : cluster->GetDeviceSet()->devices()) {
+    const auto& pname = dev->parsed_name();
+    VLOG(1) << "Device name= " << dev->name()
+            << " parsedname job= " << pname.job << " id= " << pname.id
+            << " has_id: " << pname.has_id << " has_job: " << pname.has_job
+            << " has_type: " << pname.has_type << " type = " << pname.type;
+  }
+}
+
+tensorflow::Status TRTOptimizationPass::Optimize(
+ tensorflow::grappler::Cluster* cluster,
+ const tensorflow::grappler::GrapplerItem& item, GraphDef* optimized_graph) {
+ VLOG(1) << "Called TRTOptimization Pass " << name_;
+ if (VLOG_IS_ON(1)) {
+ PrintDebugInfo(cluster, item);
+ }
+ int max_dim = -1;
+ if (item.feed.size()) {
+ for (const auto& f : item.feed) {
+ const auto& shape = f.second.shape();
+ if (shape.dims() > 0) {
+ if (shape.dim_size(0) > max_dim) max_dim = shape.dim_size(0);
+ }
+ }
+ }
+ if (maximum_batch_size_ < 0) { // automatic batch size from input
+ if (max_dim > 0) {
+ maximum_batch_size_ = max_dim;
+ VLOG(1) << "Setting maximum batch size to " << max_dim;
+ } else {
+ maximum_batch_size_ = 128;
+      LOG(WARNING) << "Maximum batch size is not set"
+                      " and can't be deduced from inputs; setting it to "
+                   << maximum_batch_size_
+                   << ". Suggest configuring it via configuration parameters";
+ }
+ } else {
+ if (max_dim > maximum_batch_size_) {
+ LOG(WARNING) << "Configured batch size " << maximum_batch_size_
+ << " is less than input batch size " << max_dim
+ << " adjusting maximum batch size to match input batch size";
+ }
+ }
+ tensorflow::grappler::GraphProperties static_graph_properties(item);
+ TF_RETURN_IF_ERROR(static_graph_properties.InferStatically(true));
+ auto status = tensorflow::tensorrt::convert::ConvertAfterShapes(
+ item.graph, item.fetch, maximum_batch_size_, maximum_workspace_size_,
+ optimized_graph, precision_mode_, minimum_segment_size_,
+ static_graph_properties, cluster);
+ VLOG(2) << optimized_graph->DebugString();
+ return status;
+}
+
+void TRTOptimizationPass::Feedback(
+ tensorflow::grappler::Cluster* cluster,
+ const tensorflow::grappler::GrapplerItem& item,
+ const GraphDef& optimized_graph, double result) {}
+
+} // namespace convert
+} // namespace tensorrt
+} // namespace tensorflow
+
+class VerboseCustomGraphOptimizerRegistrar
+ : public tensorflow::grappler::CustomGraphOptimizerRegistrar {
+ public:
+ VerboseCustomGraphOptimizerRegistrar(
+ const tensorflow::grappler::CustomGraphOptimizerRegistry::Creator& cr,
+ const tensorflow::string& name)
+ : tensorflow::grappler::CustomGraphOptimizerRegistrar(cr, name) {
+ VLOG(1) << "Constructing a CustomOptimizationPass registration object for "
+ << name;
+ }
+};
+
+static VerboseCustomGraphOptimizerRegistrar TRTOptimizationPass_Registrar(
+ []() {
+ VLOG(1)
+ << "Instantiating CustomOptimizationPass object TensorRTOptimizer";
+ return new tensorflow::tensorrt::convert::TRTOptimizationPass(
+ "TensorRTOptimizer");
+ },
+ ("TensorRTOptimizer"));
+
+#endif  // GOOGLE_TENSORRT
+#endif  // GOOGLE_CUDA
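
A minimal sketch of exercising the pass registered above directly through the
grappler registry. Only the "TensorRTOptimizer" name and the Init/Optimize
interface come from this change; the driver itself is an assumption for
illustration:

// Hypothetical driver, not part of this change.
#include <string>
#include <vector>
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
#include "tensorflow/core/lib/core/errors.h"

tensorflow::Status RunTrtPassOnce(const tensorflow::GraphDef& gdef,
                                  const std::vector<std::string>& fetch,
                                  tensorflow::GraphDef* optimized) {
  auto pass = tensorflow::grappler::CustomGraphOptimizerRegistry::
      CreateByNameOrNull("TensorRTOptimizer");
  if (pass == nullptr) {
    return tensorflow::errors::NotFound("TensorRTOptimizer is not registered");
  }
  // A null config makes Init() fall back to its built-in defaults.
  TF_RETURN_IF_ERROR(pass->Init());
  tensorflow::grappler::GrapplerItem item;
  item.graph = gdef;
  item.fetch = fetch;
  // ConvertAfterShapes tolerates a null cluster and falls back to the
  // cudaMalloc-based TRTCudaAllocator; a real cluster enables the per-device
  // TF allocators.
  return pass->Optimize(/*cluster=*/nullptr, item, optimized);
}
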
diff --git a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h
new file mode 100644
index 0000000000..d8ecead23e
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h
@@ -0,0 +1,73 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_TENSORRT_CONVERT_TRT_OPTIMIZATION_PASS_H_
+#define TENSORFLOW_CONTRIB_TENSORRT_CONVERT_TRT_OPTIMIZATION_PASS_H_
+
+#include <string>
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
+#include "tensorflow/core/platform/logging.h"
+
+#if GOOGLE_CUDA
+#if GOOGLE_TENSORRT
+
+namespace tensorflow {
+namespace tensorrt {
+namespace convert {
+
+class TRTOptimizationPass : public tensorflow::grappler::CustomGraphOptimizer {
+ public:
+ TRTOptimizationPass(const string& name = "TRTOptimizationPass")
+ : name_(name),
+ minimum_segment_size_(3),
+ precision_mode_(0),
+ maximum_batch_size_(-1),
+ maximum_workspace_size_(-1) {
+ VLOG(1) << "Constructing " << name_;
+ }
+
+  string name() const override { return name_; }
+
+ tensorflow::Status Init(const tensorflow::RewriterConfig_CustomGraphOptimizer*
+ config = nullptr) override;
+
+ tensorflow::Status Optimize(tensorflow::grappler::Cluster* cluster,
+ const tensorflow::grappler::GrapplerItem& item,
+ GraphDef* optimized_graph) override;
+
+ void Feedback(tensorflow::grappler::Cluster* cluster,
+ const tensorflow::grappler::GrapplerItem& item,
+ const GraphDef& optimized_graph, double result) override;
+
+ void PrintDebugInfo(tensorflow::grappler::Cluster* cluster,
+ const tensorflow::grappler::GrapplerItem& item);
+
+ private:
+ string name_;
+ int minimum_segment_size_;
+ int precision_mode_;
+ int maximum_batch_size_;
+ int64_t maximum_workspace_size_;
+};
+
+} // namespace convert
+} // namespace tensorrt
+} // namespace tensorflow
+
+#endif  // GOOGLE_TENSORRT
+#endif  // GOOGLE_CUDA
+#endif // TENSORFLOW_CONTRIB_TENSORRT_CONVERT_TRT_OPTIMIZATION_PASS_H_
diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
index b8f881ceb1..5c5b2e3c07 100644
--- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
+++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
@@ -32,38 +32,39 @@ namespace tensorrt {
TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) : OpKernel(context) {
// read serialized_engine
- string serialized_engine;
OP_REQUIRES_OK(context,
- context->GetAttr("serialized_engine", &serialized_engine));
+ context->GetAttr("serialized_engine", &serialized_engine_));
// register input output node name in trt_sub_graph
OP_REQUIRES_OK(context, context->GetAttr("input_nodes", &input_nodes_));
OP_REQUIRES_OK(context, context->GetAttr("output_nodes", &output_nodes_));
+}
- // TODO(samikama) runtime should be taken from a resourcemanager as well.
- // Only engine should be in the op and context and runtime should be taken
- // from resourcemanager
- // TODO(jie): cudaSetDevice make sure trt engine is allocated on the same
- // gpu where the input/output is also located.
- int gpu_id = context->device()->tensorflow_gpu_device_info()->gpu_id;
- cudaSetDevice(gpu_id);
- int device;
- cudaGetDevice(&device);
- if (gpu_id != device) LOG(FATAL) << "set device failed!";
-
+void TRTEngineOp::Compute(OpKernelContext* context) {
// TODO(samikama) runtime should be taken from a resourcemanager as well.
// Only engine should be in the op and context and runtime should be taken
// from resourcemanager
- IRuntime* infer = nvinfer1::createInferRuntime(logger);
- trt_engine_ptr_.reset(infer->deserializeCudaEngine(
- serialized_engine.c_str(), serialized_engine.size(), nullptr));
- trt_execution_context_ptr_.reset(trt_engine_ptr_->createExecutionContext());
- // Runtime is safe to delete after engine creation
- infer->destroy();
-}
-
-void TRTEngineOp::Compute(OpKernelContext* context) {
+ if (!trt_execution_context_ptr_) {
+ IRuntime* infer = nvinfer1::createInferRuntime(logger);
+#if NV_TENSORRT_MAJOR > 3
+ auto device = context->device();
+ auto dev_allocator =
+ device->GetAllocator(tensorflow::AllocatorAttributes());
+ if (!dev_allocator) {
+ LOG(FATAL) << "Can't find device allocator for gpu device "
+ << device->name();
+ }
+ allocator_ = std::make_shared<TRTDeviceAllocator>(dev_allocator);
+ infer->setGpuAllocator(allocator_.get());
+#endif
+ trt_engine_ptr_.reset(infer->deserializeCudaEngine(
+ serialized_engine_.c_str(), serialized_engine_.size(), nullptr));
+ trt_execution_context_ptr_.reset(trt_engine_ptr_->createExecutionContext());
+ // Runtime is safe to delete after engine creation
+ infer->destroy();
+ serialized_engine_.clear();
+ }
int num_binding = context->num_inputs() + context->num_outputs();
std::vector<void*> buffers(num_binding);
@@ -154,7 +155,12 @@ void TRTEngineOp::Compute(OpKernelContext* context) {
VLOG(2) << "enqueue returns: " << ret;
// sync should be done by TF.
}
-
+TRTEngineOp::~TRTEngineOp() {
+ // Order matters!
+ trt_execution_context_ptr_.reset();
+ trt_engine_ptr_.reset();
+ allocator_.reset();
+}
REGISTER_KERNEL_BUILDER(Name("TRTEngineOp").Device(DEVICE_GPU), TRTEngineOp);
} // namespace tensorrt
diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h
index 0964b4b18a..e613a71422 100644
--- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h
+++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h
@@ -17,25 +17,28 @@ limitations under the License.
#define TENSORFLOW_CONTRIB_TENSORRT_KERNELS_TRT_ENGINE_OP_H_
#include <memory>
-#include <string>
#include <vector>
+#include "tensorflow/contrib/tensorrt/resources/trt_allocator.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
#include "cuda/include/cuda_runtime_api.h"
-#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/op_kernel.h"
#include "tensorrt/include/NvInfer.h"
namespace tensorflow {
namespace tensorrt {
class Logger;
+// TODO(Sami): Remove this file?
class TRTEngineOp : public OpKernel {
public:
explicit TRTEngineOp(OpKernelConstruction* context);
void Compute(OpKernelContext* context) override;
+ ~TRTEngineOp();
private:
template <typename T>
@@ -51,6 +54,8 @@ class TRTEngineOp : public OpKernel {
std::vector<string> input_nodes_;
std::vector<string> output_nodes_;
+ std::shared_ptr<nvinfer1::IGpuAllocator> allocator_;
+ string serialized_engine_;
};
} // namespace tensorrt
diff --git a/tensorflow/contrib/tensorrt/resources/trt_allocator.cc b/tensorflow/contrib/tensorrt/resources/trt_allocator.cc
new file mode 100644
index 0000000000..0f0508331c
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/resources/trt_allocator.cc
@@ -0,0 +1,62 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/tensorrt/resources/trt_allocator.h"
+
+#include "tensorflow/core/platform/logging.h"
+
+#if GOOGLE_CUDA
+#if GOOGLE_TENSORRT
+
+#if NV_TENSORRT_MAJOR > 2
+#include "cuda/include/cuda_runtime_api.h"
+
+namespace tensorflow {
+namespace tensorrt {
+void* TRTCudaAllocator::allocate(uint64_t size, uint64_t alignment,
+ uint32_t flags) {
+ assert((alignment & (alignment - 1)) == 0); // zero or a power of 2.
+ void* memory;
+ cudaMalloc(&memory, size);
+ return memory;
+}
+
+void TRTCudaAllocator::free(void* memory) { cudaFree(memory); }
+
+void* TRTDeviceAllocator::allocate(uint64_t size, uint64_t alignment,
+ uint32_t flags) {
+ assert((alignment & (alignment - 1)) == 0); // zero or a power of 2.
+ void* mem = allocator_->AllocateRaw(alignment, size);
+ VLOG(2) << "Allocated " << size << " bytes with alignment " << alignment
+ << " @ " << mem;
+ return mem;
+}
+
+TRTDeviceAllocator::TRTDeviceAllocator(tensorflow::Allocator* allocator)
+ : allocator_(allocator) {
+ VLOG(1) << "Using " << allocator->Name() << " allocator from TensorFlow";
+}
+
+void TRTDeviceAllocator::free(void* memory) {
+ VLOG(2) << "Deallocating " << memory;
+ allocator_->DeallocateRaw(memory);
+}
+
+} // namespace tensorrt
+} // namespace tensorflow
+
+#endif  // NV_TENSORRT_MAJOR > 2
+#endif  // GOOGLE_TENSORRT
+#endif  // GOOGLE_CUDA
diff --git a/tensorflow/contrib/tensorrt/resources/trt_allocator.h b/tensorflow/contrib/tensorrt/resources/trt_allocator.h
new file mode 100644
index 0000000000..a0c2540a76
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/resources/trt_allocator.h
@@ -0,0 +1,68 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_ALLOCATOR_H_
+#define TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_ALLOCATOR_H_
+
+
+#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
+#include "tensorflow/core/framework/allocator.h"
+
+#if GOOGLE_CUDA
+#if GOOGLE_TENSORRT
+#include "tensorrt/include/NvInfer.h"
+
+#if NV_TENSORRT_MAJOR == 3
+// Define interface here temporarily until TRT 4.0 is released
+namespace nvinfer1 {
+class IGpuAllocator {
+ public:
+ virtual void* allocate(uint64_t size, uint64_t alignment, uint32_t flags) = 0;
+ virtual void free(void* memory) = 0;
+};
+} // namespace nvinfer1
+#endif
+
+namespace tensorflow {
+namespace tensorrt {
+
+class TRTCudaAllocator : public nvinfer1::IGpuAllocator {
+  // Allocator implementation that uses cudaMalloc/cudaFree directly; used as
+  // a fallback when we can't get a device allocator from TF.
+ public:
+ TRTCudaAllocator() {}
+ virtual ~TRTCudaAllocator() {}
+ void* allocate(uint64_t size, uint64_t alignment, uint32_t flags) override;
+ void free(void* memory) override;
+};
+
+class TRTDeviceAllocator : public nvinfer1::IGpuAllocator {
+ // Allocator implementation wrapping TF device allocators.
+ public:
+ TRTDeviceAllocator(tensorflow::Allocator* allocator);
+ virtual ~TRTDeviceAllocator() {}
+ void* allocate(uint64_t size, uint64_t alignment, uint32_t flags) override;
+ void free(void* memory) override;
+
+ private:
+ tensorflow::Allocator* allocator_;
+};
+
+} // namespace tensorrt
+} // namespace tensorflow
+
+#endif // GOOGLE_TENSORRT
+#endif // GOOGLE_CUDA
+#endif // TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_ALLOCATOR_H_
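
A minimal usage sketch for the two allocators above, mirroring how
convert_nodes.cc and trt_engine_op.cc in this change wire them into TensorRT.
The helper itself is hypothetical; setGpuAllocator() only exists from
TensorRT 4 on, hence the version guard:

// Hypothetical helper, not part of this change.
#include <memory>
#include "tensorflow/contrib/tensorrt/resources/trt_allocator.h"
#include "tensorrt/include/NvInfer.h"

nvinfer1::IBuilder* CreateBuilderWithAllocator(
    nvinfer1::ILogger* logger, tensorflow::Allocator* tf_allocator,
    std::shared_ptr<nvinfer1::IGpuAllocator>* out_allocator) {
  if (tf_allocator != nullptr) {
    // Preferred: route TensorRT allocations through TF's device allocator.
    *out_allocator = std::make_shared<tensorflow::tensorrt::TRTDeviceAllocator>(
        tf_allocator);
  } else {
    // Fallback when no TF device allocator is reachable (e.g. Python path).
    *out_allocator = std::make_shared<tensorflow::tensorrt::TRTCudaAllocator>();
  }
  nvinfer1::IBuilder* builder = nvinfer1::createInferBuilder(*logger);
#if NV_TENSORRT_MAJOR > 3
  // The builder keeps a raw pointer, so *out_allocator must outlive it.
  builder->setGpuAllocator(out_allocator->get());
#endif
  return builder;
}
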
diff --git a/tensorflow/contrib/tensorrt/resources/trt_resources.h b/tensorflow/contrib/tensorrt/resources/trt_resources.h
index 3c85968ae7..e3469124ac 100644
--- a/tensorflow/contrib/tensorrt/resources/trt_resources.h
+++ b/tensorflow/contrib/tensorrt/resources/trt_resources.h
@@ -13,20 +13,23 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRTRESOURCES_H_
-#define TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRTRESOURCES_H_
+#ifndef TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_RESOURCES_H_
+#define TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_RESOURCES_H_
#include <list>
#include <sstream>
#include <string>
#include <thread>
#include <vector>
+
#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
+#include "tensorflow/contrib/tensorrt/resources/trt_allocator.h"
+#include "tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h"
#include "tensorflow/core/framework/resource_mgr.h"
#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
-#include "tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h"
+
#include "tensorrt/include/NvInfer.h"
namespace tensorflow {
@@ -40,6 +43,11 @@ class TRTCalibrationResource : public tensorflow::ResourceBase {
engine_(nullptr),
logger_(nullptr),
thr_(nullptr) {}
+
+ ~TRTCalibrationResource() {
+ VLOG(0) << "Destroying Calibration Resource " << std::endl << DebugString();
+ }
+
string DebugString() override {
std::stringstream oss;
oss << " Calibrator = " << std::hex << calibrator_ << std::dec << std::endl
@@ -47,16 +55,17 @@ class TRTCalibrationResource : public tensorflow::ResourceBase {
<< " Network = " << std::hex << network_ << std::dec << std::endl
<< " Engine = " << std::hex << engine_ << std::dec << std::endl
<< " Logger = " << std::hex << logger_ << std::dec << std::endl
+ << " Allocator = " << std::hex << allocator_.get() << std::dec
+ << std::endl
<< " Thread = " << std::hex << thr_ << std::dec << std::endl;
return oss.str();
}
- ~TRTCalibrationResource() {
- VLOG(0) << "Destroying Calibration Resource " << std::endl << DebugString();
- }
+
TRTInt8Calibrator* calibrator_;
nvinfer1::IBuilder* builder_;
nvinfer1::INetworkDefinition* network_;
nvinfer1::ICudaEngine* engine_;
+ std::shared_ptr<nvinfer1::IGpuAllocator> allocator_;
tensorflow::tensorrt::Logger* logger_;
// TODO(sami): Use threadpool threads!
std::thread* thr_;
@@ -65,31 +74,28 @@ class TRTCalibrationResource : public tensorflow::ResourceBase {
class TRTWeightStore : public tensorflow::ResourceBase {
public:
TRTWeightStore() {}
- std::list<std::vector<uint8_t>> store_;
+
+ virtual ~TRTWeightStore() { VLOG(1) << "Destroying store" << DebugString(); }
+
string DebugString() override {
std::stringstream oss;
- size_t lenBytes = 0;
+ size_t len_bytes = 0;
for (const auto& v : store_) {
- lenBytes += v.size() * sizeof(uint8_t);
+ len_bytes += v.size() * sizeof(uint8_t);
}
oss << " Number of entries = " << store_.size() << std::endl
<< " Total number of bytes = "
- << store_.size() * sizeof(std::vector<uint8_t>) + lenBytes << std::endl;
+ << store_.size() * sizeof(std::vector<uint8_t>) + len_bytes
+ << std::endl;
return oss.str();
}
- virtual ~TRTWeightStore() { VLOG(1) << "Destroying store" << DebugString(); }
-};
-class TRTEngineResource : public tensorflow::ResourceBase {
- public:
- TRTEngineResource() : runtime_(nullptr), ctx_(nullptr){};
- string DebugString() override { return string(""); }
- nvinfer1::IRuntime* runtime_;
- nvinfer1::IExecutionContext* ctx_;
+ std::list<std::vector<uint8_t>> store_;
};
} // namespace tensorrt
} // namespace tensorflow
-#endif // TENSORFLOW_CONTRIB_TENSORRT_RESOURCEMGR_TRTRESOURCES_H_
+
#endif
#endif
+#endif // TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_RESOURCES_H_
diff --git a/tensorflow/contrib/tensorrt/segment/segment.cc b/tensorflow/contrib/tensorrt/segment/segment.cc
index 8fc4697c51..cc42913eca 100644
--- a/tensorflow/contrib/tensorrt/segment/segment.cc
+++ b/tensorflow/contrib/tensorrt/segment/segment.cc
@@ -25,18 +25,239 @@ limitations under the License.
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace tensorrt {
namespace segment {
+using ::tensorflow::strings::StrAppend;
+// A simple graph representation to mirror tensorflow::Graph. This structure
+// helps save memory since the segmenter modifies the graph in place, avoiding
+// the need to create a copy of the graph. It is composed of edges and nodes.
+// Nodes keep pointers to original TF nodes.
+class SimpleNode;
+class SimpleGraph;
+class SimpleEdge {
+ public:
+ SimpleEdge(int id, SimpleNode* src, int src_port, SimpleNode* dst,
+ int dst_port, bool is_control = false)
+ : id_(id),
+ src_(src),
+ src_port_(src_port),
+ dst_(dst),
+ dst_port_(dst_port),
+ control_(is_control) {}
+ ~SimpleEdge() {}
+
+ SimpleNode* src() const { return src_; }
+ SimpleNode* dst() const { return dst_; }
+ int src_output() const { return src_port_; }
+ int dst_input() const { return dst_port_; }
+ int id() const { return id_; }
+ bool IsControlEdge() const { return control_; }
+
+ private:
+ int id_;
+ SimpleNode* src_;
+ int src_port_;
+ SimpleNode* dst_;
+ int dst_port_;
+ bool control_;
+};
+
+class SimpleNode {
+ public:
+ SimpleNode(const tensorflow::Node* node, const int id);
+
+ const std::vector<SimpleEdge*>& in_edges() const { return in_edges_; }
+ const std::vector<SimpleEdge*>& out_edges() const { return out_edges_; }
+ std::vector<SimpleNode*> in_nodes() const {
+ std::vector<SimpleNode*> res;
+ res.reserve(in_edges_.size());
+ for (const auto e : in_edges_) {
+ if (e) res.push_back(e->src());
+ }
+ return res;
+ }
+ const string& name() const { return node_->name(); }
+ const tensorflow::Node* tf_node() const { return node_; }
+ int id() const { return id_; }
+
+ private:
+ const tensorflow::Node* node_;
+ std::vector<SimpleEdge*> in_edges_;
+ std::vector<SimpleEdge*> out_edges_;
+ int id_;
+
+ friend class SimpleGraph;
+};
+
+class SimpleGraph {
+ public:
+ explicit SimpleGraph(const tensorflow::Graph* g);
+ ~SimpleGraph();
+
+ void AddControlEdge(SimpleNode* src, SimpleNode* dst);
+ void AddEdge(SimpleNode* src, int out_port, SimpleNode* dst, int in_port);
+ void RemoveEdge(const SimpleEdge*);
+ SimpleNode* FindNodeId(int node_id) {
+    if (node_id < 0 || node_id >= static_cast<int>(nodes_.size())) {
+ return nullptr;
+ }
+ return nodes_[node_id];
+ }
+ int num_node_ids() const { return nodes_.size(); }
+ const SimpleNode* source_node() const {
+ return nodes_[tensorflow::Graph::kSourceId];
+ }
+ const SimpleNode* sink_node() const {
+ return nodes_[tensorflow::Graph::kSinkId];
+ }
+
+ private:
+ const tensorflow::Graph* g_;
+ std::vector<SimpleNode*> nodes_;
+ std::vector<SimpleEdge*> edges_;
+ // free_edge_ids_ and free_node_ids_ contain freed indices.
+ std::set<int> free_edge_ids_;
+ std::set<int> free_node_ids_;
+};
+
+SimpleNode::SimpleNode(const tensorflow::Node* node, const int id)
+ : node_(node), id_(id) {
+ if (node_) {
+ in_edges_.reserve(node_->in_edges().size());
+ out_edges_.reserve(node_->out_edges().size());
+ }
+}
+
+SimpleGraph::SimpleGraph(const tensorflow::Graph* g) : g_(g) {
+ int n_nodes = g_->num_node_ids();
+ nodes_.resize(n_nodes, nullptr);
+ nodes_[g->kSourceId] = new SimpleNode(g->source_node(), g->kSourceId);
+ nodes_[g->kSinkId] = new SimpleNode(g->sink_node(), g->kSinkId);
+ int n_edges = g->num_edge_ids();
+ edges_.resize(n_edges, nullptr);
+ for (int i = 2; i < n_nodes; i++) {
+ const auto n = g->FindNodeId(i);
+ if (n) {
+ nodes_[i] = new SimpleNode(n, i);
+ } else {
+ free_node_ids_.insert(i);
+ }
+ }
+ for (int i = 0; i < n_edges; i++) {
+ const auto e = g->FindEdgeId(i);
+ if (e) {
+ const auto tfsrc = e->src();
+ const auto tfdst = e->dst();
+ bool is_control = e->IsControlEdge();
+ auto src = nodes_[tfsrc->id()];
+ auto dst = nodes_[tfdst->id()];
+ auto edge = new SimpleEdge(i, src, e->src_output(), dst, e->dst_input(),
+ is_control);
+ edges_[i] = edge;
+ src->out_edges_.push_back(edge);
+ dst->in_edges_.push_back(edge);
+ } else {
+ free_edge_ids_.insert(i);
+ }
+ }
+}
+
+void SimpleGraph::AddEdge(SimpleNode* src, int out_port, SimpleNode* dst,
+ int in_port) {
+ int i = edges_.size();
+ if (!free_edge_ids_.empty()) {
+ auto it = free_edge_ids_.begin();
+ i = *it;
+ free_edge_ids_.erase(it);
+ } else {
+ edges_.push_back(nullptr);
+ }
+ bool is_control = (out_port == tensorflow::Graph::kControlSlot);
+ is_control |= (in_port == tensorflow::Graph::kControlSlot);
+ auto edge = new SimpleEdge(i, src, out_port, dst, in_port, is_control);
+ edges_[i] = edge;
+ src->out_edges_.push_back(edge);
+ dst->in_edges_.push_back(edge);
+}
+
+void SimpleGraph::AddControlEdge(SimpleNode* src, SimpleNode* dst) {
+ AddEdge(src, tensorflow::Graph::kControlSlot, dst,
+ tensorflow::Graph::kControlSlot);
+}
+
+void SimpleGraph::RemoveEdge(const SimpleEdge* edge) {
+ auto src = edge->src();
+ auto dst = edge->dst();
+ for (auto it = src->out_edges_.begin(); it != src->out_edges_.end(); ++it) {
+ if (*it == edge) {
+ src->out_edges_.erase(it);
+ break;
+ }
+ }
+ for (auto it = dst->in_edges_.begin(); it != dst->in_edges_.end(); ++it) {
+ if (*it == edge) {
+ dst->in_edges_.erase(it);
+ break;
+ }
+ }
+}
+
+SimpleGraph::~SimpleGraph() {
+ for (auto x : nodes_) delete x;
+ for (auto x : edges_) delete x;
+}
namespace {
-bool CanContractEdge(const tensorflow::Edge* edge,
- const tensorflow::Graph& graph) {
- const tensorflow::Node* src = edge->src();
- const tensorflow::Node* dst = edge->dst();
+bool CheckCycles(const std::unique_ptr<SimpleGraph>& g, const SimpleNode* src,
+ const std::vector<SimpleNode*>& start) {
+ // copied from TF ReverseDFS.
+ struct Work {
+ SimpleNode* node;
+    bool leave;  // Are we entering or leaving the node?
+ };
+
+ std::vector<Work> stack(start.size());
+ for (int i = 0; i < start.size(); ++i) {
+ stack[i] = Work{start[i], false};
+ }
+
+ std::vector<bool> visited(g->num_node_ids(), false);
+ while (!stack.empty()) {
+ Work w = stack.back();
+ stack.pop_back();
+
+ auto n = w.node;
+ if (w.leave) {
+ if (n == src) {
+ return true;
+ }
+ continue;
+ }
+
+ if (visited[n->id()]) continue;
+ visited[n->id()] = true;
+ // Arrange to call leave(n) when all done with descendants.
+ stack.push_back(Work{n, true});
+
+ auto nodes = n->in_nodes();
+ for (const auto node : nodes) {
+ if (!visited[node->id()]) {
+ stack.push_back(Work{node, false});
+ }
+ }
+ }
+ return false;
+}
+
+bool CanContractEdge(const SimpleEdge* edge,
+ const std::unique_ptr<SimpleGraph>& graph) {
+ const auto src = edge->src();
+ const auto dst = edge->dst();
// Can't contract edge if doing so would cause a cycle in the
// graph. So, if there is a directed path from 'src' to 'dst', other
@@ -48,46 +269,38 @@ bool CanContractEdge(const tensorflow::Edge* edge,
// 1. Get all nodes incoming to 'dst', excluding 'src'
// 2. Reverse DFS from those nodes
// 3. If reverse DFS reaches 'src' then we have a cycle
- std::vector<tensorflow::Node*> dfs_start_nodes;
- for (tensorflow::Node* node : dst->in_nodes()) {
+ std::vector<SimpleNode*> dfs_start_nodes;
+ for (SimpleNode* node : dst->in_nodes()) {
if (node != src) {
dfs_start_nodes.push_back(node);
}
}
- bool is_cycle = false;
- if (!dfs_start_nodes.empty()) {
- tensorflow::ReverseDFSFrom(graph, dfs_start_nodes, {},
- [&is_cycle, src](tensorflow::Node* node) {
- if (node == src) {
- is_cycle = true;
- }
- });
- }
-
+ bool is_cycle = CheckCycles(graph, src, dfs_start_nodes);
return !is_cycle;
}
+} // namespace
-void ContractEdge(tensorflow::Edge* edge, tensorflow::Graph* graph,
- std::vector<const tensorflow::Edge*>* remove_edges) {
+void ContractEdge(SimpleEdge* edge, SimpleGraph* graph,
+ std::vector<const SimpleEdge*>* remove_edges) {
// Transfer all inputs and outputs of 'dst' to 'src' except edges
// connecting the two.
- tensorflow::Node* src = edge->src();
- tensorflow::Node* dst = edge->dst();
+ auto src = edge->src();
+ auto dst = edge->dst();
// We can use '0' for input/output index because we don't need them
// to be accurate for the way we are using the graph.
- std::vector<const tensorflow::Edge*> in_edges(dst->in_edges().begin(),
- dst->in_edges().end());
- for (const tensorflow::Edge* in_edge : in_edges) {
+ std::vector<const SimpleEdge*> in_edges(dst->in_edges().begin(),
+ dst->in_edges().end());
+ for (const SimpleEdge* in_edge : in_edges) {
if (in_edge->IsControlEdge()) {
if (in_edge->src() != src) {
- tensorflow::Edge* e = const_cast<tensorflow::Edge*>(in_edge);
+ SimpleEdge* e = const_cast<SimpleEdge*>(in_edge);
graph->AddControlEdge(e->src(), src);
}
} else {
if (in_edge->src() != src) {
- tensorflow::Edge* e = const_cast<tensorflow::Edge*>(in_edge);
+ SimpleEdge* e = const_cast<SimpleEdge*>(in_edge);
if (e->src() == graph->source_node()) {
graph->AddEdge(e->src(), e->src_output(), src,
tensorflow::Graph::kControlSlot);
@@ -98,14 +311,14 @@ void ContractEdge(tensorflow::Edge* edge, tensorflow::Graph* graph,
}
}
- std::vector<const tensorflow::Edge*> out_edges(dst->out_edges().begin(),
- dst->out_edges().end());
- for (const tensorflow::Edge* out_edge : out_edges) {
+ std::vector<const SimpleEdge*> out_edges(dst->out_edges().begin(),
+ dst->out_edges().end());
+ for (const SimpleEdge* out_edge : out_edges) {
if (out_edge->IsControlEdge()) {
- tensorflow::Edge* e = const_cast<tensorflow::Edge*>(out_edge);
+ SimpleEdge* e = const_cast<SimpleEdge*>(out_edge);
graph->AddControlEdge(src, e->dst());
} else {
- tensorflow::Edge* e = const_cast<tensorflow::Edge*>(out_edge);
+ SimpleEdge* e = const_cast<SimpleEdge*>(out_edge);
if (e->dst() == graph->sink_node()) {
VLOG(1) << " edge to sink node " << src->name() << " -> "
<< e->dst()->name();
@@ -128,8 +341,6 @@ void ContractEdge(tensorflow::Edge* edge, tensorflow::Graph* graph,
}
}
-} // namespace
-
tensorflow::Status SegmentGraph(
const tensorflow::GraphDef& gdef,
const std::function<bool(const tensorflow::Node*)>& candidate_fn,
@@ -140,17 +351,22 @@ tensorflow::Status SegmentGraph(
tensorflow::Graph graph(flib);
TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToGraph(
tensorflow::GraphConstructorOptions(), gdef, &graph));
+ return SegmentGraph(&graph, candidate_fn, options, segments);
+}
- // tensorflow::DumpGraph("Pre-Segment", &graph);
-
+tensorflow::Status SegmentGraph(
+ tensorflow::Graph* tf_graph,
+ const std::function<bool(const tensorflow::Node*)>& candidate_fn,
+ const SegmentOptions& options, SegmentNodesVector* segments) {
+ auto graph = std::unique_ptr<SimpleGraph>(new SimpleGraph(tf_graph));
// Use a union-find to collect the nodes that belong to the same
- // segment. A node value of nullptr indicates that the node is not a
- // candidate for TRT.
- std::vector<UnionFind<tensorflow::Node*>> node_segments;
- for (int i = 0; i < graph.num_node_ids(); ++i) {
- tensorflow::Node* node = graph.FindNodeId(i);
+ // segment. A node value of nullptr indicates that the node is not a candidate
+ // for TRT.
+ std::vector<UnionFind<SimpleNode*>> node_segments;
+ for (int i = 0; i < graph->num_node_ids(); ++i) {
+ SimpleNode* node = graph->FindNodeId(i);
if (options.exclude_node_list.count(node->name()) != 0 ||
- !candidate_fn(node)) {
+ !candidate_fn(node->tf_node())) {
node = nullptr;
}
node_segments.emplace_back(node);
@@ -164,10 +380,16 @@ tensorflow::Status SegmentGraph(
// a measure of how beneficial it is to include a given node in a
// TRT subgraph then we can revisit this algorithm to take advantage
// of that information.
- std::vector<tensorflow::Node*> order;
- tensorflow::GetPostOrder(graph, &order);
-
- for (const tensorflow::Node* node : order) {
+ std::vector<tensorflow::Node*> tforder;
+ tensorflow::GetPostOrder(*tf_graph, &tforder);
+  // Use the post-order implementation from TensorFlow and construct a
+  // mirror of the ordering in the internal SimpleNode format.
+ std::vector<SimpleNode*> order;
+ order.reserve(tforder.size());
+ for (const auto tfnode : tforder) {
+ order.push_back(graph->FindNodeId(tfnode->id()));
+ }
+ for (const SimpleNode* node : order) {
// All output nodes of 'node' have been visited...
VLOG(2) << "Trying node " << node->name() << " id=" << node->id();
@@ -181,8 +403,8 @@ tensorflow::Status SegmentGraph(
// nodes. Iterate since combining two nodes may unblock other
// combining.
while (true) {
- std::set<const tensorflow::Edge*> contract_edges;
- for (const tensorflow::Edge* out_edge : node->out_edges()) {
+ std::set<const SimpleEdge*> contract_edges;
+ for (const SimpleEdge* out_edge : node->out_edges()) {
VLOG(2) << "... out node " << out_edge->dst()->name() << " ( "
<< out_edge->dst()->id() << " <- " << node->id() << " )";
if (out_edge->IsControlEdge()) {
@@ -210,9 +432,9 @@ tensorflow::Status SegmentGraph(
// Contract edges and collect the adjacent nodes into the same
// segment/subgraph.
while (!contract_edges.empty()) {
- const tensorflow::Edge* contract_edge = *contract_edges.begin();
- const tensorflow::Node* src = contract_edge->src();
- const tensorflow::Node* dst = contract_edge->dst();
+ const SimpleEdge* contract_edge = *contract_edges.begin();
+ const SimpleNode* src = contract_edge->src();
+ const SimpleNode* dst = contract_edge->dst();
VLOG(2) << "Merge " << src->name() << " <- " << dst->name() << " ("
<< src->id() << " <- " << dst->id();
@@ -221,13 +443,13 @@ tensorflow::Status SegmentGraph(
// Contracting the edge leaves disconnected graph edges.
// Remove these from the graph and from 'contract_edges' so we
// don't visit them again.
- tensorflow::Edge* e = const_cast<tensorflow::Edge*>(contract_edge);
- std::vector<const tensorflow::Edge*> remove_edges;
- ContractEdge(e, &graph, &remove_edges);
+ SimpleEdge* e = const_cast<SimpleEdge*>(contract_edge);
+ std::vector<const SimpleEdge*> remove_edges;
+ ContractEdge(e, graph.get(), &remove_edges);
- for (const tensorflow::Edge* r : remove_edges) {
+ for (const SimpleEdge* r : remove_edges) {
contract_edges.erase(r);
- graph.RemoveEdge(r);
+ graph->RemoveEdge(r);
}
}
}
@@ -236,9 +458,27 @@ tensorflow::Status SegmentGraph(
// Collect the segments/subgraphs. Each subgraph is represented by a
// set of the names of the nodes in that subgraph.
std::unordered_map<string, std::set<string>> sg_map;
+ std::unordered_map<string, std::set<string>> device_maps;
for (auto& u : node_segments) {
if ((u.Value() != nullptr) && (u.ParentValue() != nullptr)) {
sg_map[u.ParentValue()->name()].insert(u.Value()->name());
+ auto tf_node = u.Value()->tf_node();
+      // has_assigned_device_name() is expected to return true when
+      // called from the optimization pass. However, since the graph is
+      // converted back and forth between Graph and GraphDef, assigned
+      // devices get demoted to requested devices. If the graph is passed
+      // directly to this module, assigned devices will be set.
+ if (tf_node->has_assigned_device_name()) {
+ device_maps[u.ParentValue()->name()].insert(
+ tf_node->assigned_device_name());
+ } else if (!tf_node->requested_device().empty()) {
+ device_maps[u.ParentValue()->name()].insert(
+ tf_node->requested_device());
+ } else {
+ VLOG(1) << "Node " << tf_node->name()
+ << " has no device assigned requested device is: "
+ << tf_node->requested_device();
+ }
}
}
@@ -260,10 +500,35 @@ tensorflow::Status SegmentGraph(
<< segment_node_names.size() << " nodes, dropping";
continue;
}
-
- segments->emplace_back(segment_node_names);
+ // TODO(sami): Make segmenter placement aware once trtscopes are in place
+ const auto& dev_itr = device_maps.find(itr.first);
+ if (dev_itr == device_maps.end() || dev_itr->second.empty()) {
+ VLOG(1) << "No device assigned to segment " << segments->size();
+ segments->emplace_back(std::make_pair(segment_node_names, string()));
+ } else if (dev_itr->second.size() > 1) {
+ string s("Segment ");
+ StrAppend(&s, segments->size(), " has multiple devices attached: ");
+ for (const auto& dev : dev_itr->second) {
+ StrAppend(&s, dev, ", ");
+ }
+      LOG(WARNING) << s << "choosing " << *(dev_itr->second.begin());
+ segments->emplace_back(
+ std::make_pair(segment_node_names, *(dev_itr->second.begin())));
+ } else {
+ segments->emplace_back(
+ std::make_pair(segment_node_names, *(dev_itr->second.begin())));
+ }
+ }
+ if (VLOG_IS_ON(1)) {
+ for (const auto& d : device_maps) {
+ string s("Segment ");
+ StrAppend(&s, ": '", d.first, "' ");
+ for (const auto& dd : d.second) {
+ StrAppend(&s, dd, ", ");
+ }
+ VLOG(1) << "Devices " << s;
+ }
}
-
return tensorflow::Status::OK();
}
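
The CheckCycles helper above replaces the tensorflow::ReverseDFSFrom call
with an explicit enter/leave work stack over SimpleGraph: it walks backwards
from the nodes feeding 'dst' (excluding 'src') and reports a cycle if the
walk reaches 'src'. A minimal Python sketch of the same pattern, with the
graph given as a node -> predecessors mapping (names here are illustrative,
not part of the patch):

    # Reverse DFS; reaching 'src' means contracting the edge src->dst
    # would create a cycle.
    def check_cycles(in_nodes, src, start):
        stack = [(n, False) for n in start]  # (node, leaving?) work items
        visited = set()
        while stack:
            node, leave = stack.pop()
            if leave:                        # all predecessors processed
                if node == src:
                    return True
                continue
            if node in visited:
                continue
            visited.add(node)
            stack.append((node, True))       # schedule the leave visit
            for pred in in_nodes.get(node, ()):
                if pred not in visited:
                    stack.append((pred, False))
        return False

    # Edge 1 -> 3 with a parallel path 1 -> 2 -> 3: contraction is unsafe.
    print(check_cycles({"3": ["2"], "2": ["1"]}, src="1", start=["2"]))  # True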
diff --git a/tensorflow/contrib/tensorrt/segment/segment.h b/tensorflow/contrib/tensorrt/segment/segment.h
index 7e8685f44a..1568dd9153 100644
--- a/tensorflow/contrib/tensorrt/segment/segment.h
+++ b/tensorflow/contrib/tensorrt/segment/segment.h
@@ -29,7 +29,9 @@ namespace tensorflow {
namespace tensorrt {
namespace segment {
-using SegmentNodesVector = std::vector<std::set<string>>;
+// Vector of segments; each entry contains the set of node names in the
+// segment and the device the segment is assigned to.
+using SegmentNodesVector = std::vector<std::pair<std::set<string>, string>>;
struct SegmentOptions {
// Segment must contain at least this many nodes.
@@ -51,6 +53,20 @@ tensorflow::Status SegmentGraph(
const std::function<bool(const tensorflow::Node*)>& candidate_fn,
const SegmentOptions& options, SegmentNodesVector* segments);
+// Get the subgraphs of a graph that can be handled by TensorRT.
+//
+// @param tf_graph tensorflow::Graph of the network
+// @param candidate_fn A function that returns true for a Node* if
+// that node can be handled by TensorRT.
+// @param segments Returns the TensorRT segments/subgraphs. Each entry
+// in the vector describes a subgraph by giving a set of the names of
+// all the NodeDefs in that subgraph and the device assigned to it.
+// @return the status.
+tensorflow::Status SegmentGraph(
+ tensorflow::Graph* tf_graph,
+ const std::function<bool(const tensorflow::Node*)>& candidate_fn,
+ const SegmentOptions& options, SegmentNodesVector* segments);
+
} // namespace segment
} // namespace tensorrt
} // namespace tensorflow
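
With this change each segment is a (node-name set, device string) pair
instead of a bare set of names; an empty device string means no device could
be determined for the segment. Roughly, in Python terms (the device string
below is made up for illustration):

    segments = [
        ({"add0", "add1", "add2"}, "/job:localhost/replica:0/task:0/device:GPU:0"),
        ({"add6", "add8"}, ""),  # empty string: no device was assigned
    ]
    for i, (node_names, device) in enumerate(segments):
        print("segment", i, "on", device or "<unassigned>", sorted(node_names))

The test updates below (segments[0] becoming segments[0].first) follow
directly from this new layout.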
diff --git a/tensorflow/contrib/tensorrt/segment/segment_test.cc b/tensorflow/contrib/tensorrt/segment/segment_test.cc
index 7ddabec268..8038085a06 100644
--- a/tensorflow/contrib/tensorrt/segment/segment_test.cc
+++ b/tensorflow/contrib/tensorrt/segment/segment_test.cc
@@ -35,7 +35,7 @@ class SegmentTest : public ::testing::Test {
TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
TF_Status* s, const char* name);
- std::function<bool(const Node*)> MakeCandidateFn(
+ std::function<bool(const tensorflow::Node*)> MakeCandidateFn(
const std::set<string>& node_names);
protected:
@@ -60,9 +60,9 @@ bool SegmentTest::GetGraphDef(TF_Graph* graph,
return ret;
}
-std::function<bool(const Node*)> SegmentTest::MakeCandidateFn(
+std::function<bool(const tensorflow::Node*)> SegmentTest::MakeCandidateFn(
const std::set<string>& node_names) {
- return [node_names](const Node* node) -> bool {
+ return [node_names](const tensorflow::Node* node) -> bool {
return node_names.find(node->name()) != node_names.end();
};
}
@@ -165,7 +165,7 @@ TEST_F(SegmentTest, Simple) {
ASSERT_EQ(segments.size(), 1);
std::vector<string> expected{"add0", "add1", "add2", "add3", "add4"};
for (const auto& ex : expected) {
- EXPECT_TRUE(segments[0].find(ex) != segments[0].end())
+ EXPECT_TRUE(segments[0].first.find(ex) != segments[0].first.end())
<< "Missing expected node " << ex;
}
TF_DeleteGraph(graph);
@@ -278,13 +278,13 @@ TEST_F(SegmentTest, Multiple) {
std::vector<string> expected0{"add0", "add1", "add2", "add3"};
for (const auto& ex : expected0) {
- EXPECT_TRUE(segments[0].find(ex) != segments[0].end())
+ EXPECT_TRUE(segments[0].first.find(ex) != segments[0].first.end())
<< "Missing expected node " << ex;
}
std::vector<string> expected1{"add6", "add8"};
for (const auto& ex : expected1) {
- EXPECT_TRUE(segments[1].find(ex) != segments[1].end())
+ EXPECT_TRUE(segments[1].first.find(ex) != segments[1].first.end())
<< "Missing expected node " << ex;
}
TF_DeleteGraph(graph);
@@ -348,13 +348,13 @@ TEST_F(SegmentTest, BigIfElse) {
std::vector<string> expected0{"add3", "add4", "add5", "add6", "add7"};
for (const auto& ex : expected0) {
- EXPECT_TRUE(segments[0].find(ex) != segments[0].end())
+ EXPECT_TRUE(segments[0].first.find(ex) != segments[0].first.end())
<< "Missing expected node " << ex;
}
std::vector<string> expected1{"add0", "add1"};
for (const auto& ex : expected1) {
- EXPECT_TRUE(segments[1].find(ex) != segments[1].end())
+ EXPECT_TRUE(segments[1].first.find(ex) != segments[1].first.end())
<< "Missing expected node " << ex;
}
TF_DeleteGraph(graph);
diff --git a/tensorflow/contrib/tensorrt/test/test_tftrt.py b/tensorflow/contrib/tensorrt/test/test_tftrt.py
index ad01bedd8f..175ccd8006 100644
--- a/tensorflow/contrib/tensorrt/test/test_tftrt.py
+++ b/tensorflow/contrib/tensorrt/test/test_tftrt.py
@@ -18,7 +18,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import argparse
import numpy as np
+
# normally we should do import tensorflow as tf and then
# tf.placeholder, tf.constant, tf.nn.conv2d etc but
# it looks like internal builds don't like it so
@@ -26,6 +28,7 @@ import numpy as np
from tensorflow.contrib import tensorrt as trt
from tensorflow.core.protobuf import config_pb2 as cpb2
+from tensorflow.core.protobuf import rewriter_config_pb2 as rwpb2
from tensorflow.python.client import session as csess
from tensorflow.python.framework import constant_op as cop
from tensorflow.python.framework import dtypes as dtypes
@@ -59,9 +62,11 @@ def get_simple_graph_def():
return g.as_graph_def()
-def run_graph(gdef, dumm_inp):
+def execute_graph(gdef, dumm_inp):
"""Run given graphdef once."""
+ print("executing")
gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50)
+ sessconfig = cpb2.ConfigProto(gpu_options=gpu_options)
ops.reset_default_graph()
g = ops.Graph()
with g.as_default():
@@ -69,15 +74,14 @@ def run_graph(gdef, dumm_inp):
graph_def=gdef, return_elements=["input", "output"])
inp = inp.outputs[0]
out = out.outputs[0]
- with csess.Session(
- config=cpb2.ConfigProto(gpu_options=gpu_options), graph=g) as sess:
+ with csess.Session(config=sessconfig, graph=g) as sess:
val = sess.run(out, {inp: dumm_inp})
return val
# Use real data that is representative of the inference dataset
# for calibration. For this test script it is random data.
-def run_calibration(gdef, dumm_inp):
+def execute_calibration(gdef, dumm_inp):
"""Run given calibration graph multiple times."""
gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50)
ops.reset_default_graph()
@@ -96,7 +100,9 @@ def run_calibration(gdef, dumm_inp):
return val
-if "__main__" in __name__:
+def user(run_graph=execute_graph, run_calibration=execute_calibration):
+ """Example function that converts a graph to TFTRT graph."""
+
inp_dims = (100, 24, 24, 2)
dummy_input = np.random.random_sample(inp_dims)
orig_graph = get_simple_graph_def() # use a frozen graph for inference
@@ -137,3 +143,51 @@ if "__main__" in __name__:
assert np.allclose(o1, o4)
assert np.allclose(o1, o5)
print("Pass")
+
+
+def auto():
+ """Run the conversion as an optimization pass."""
+ inp_dims = (100, 24, 24, 2)
+ dummy_input = np.random.random_sample(inp_dims)
+ orig_graph = get_simple_graph_def()
+ opt_config = rwpb2.RewriterConfig()
+ opt_config.optimizers.extend(["constfold", "layout"])
+ custom_op = opt_config.custom_optimizers.add()
+ custom_op.name = "TensorRTOptimizer"
+ custom_op.parameter_map["minimum_segment_size"].i = 3
+ custom_op.parameter_map["precision_mode"].s = "FP32"
+ custom_op.parameter_map["max_batch_size"].i = inp_dims[0]
+ custom_op.parameter_map["max_workspace_size_bytes"].i = 1 << 25
+ print(custom_op)
+ gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50)
+ graph_options = cpb2.GraphOptions(rewrite_options=opt_config)
+ sessconfig = cpb2.ConfigProto(
+ gpu_options=gpu_options, graph_options=graph_options)
+ print(sessconfig)
+ g = ops.Graph()
+ ops.reset_default_graph()
+ with g.as_default():
+ inp, out = importer.import_graph_def(
+ graph_def=orig_graph, return_elements=["input", "output"])
+ inp = inp.outputs[0]
+ out = out.outputs[0]
+ with csess.Session(config=sessconfig, graph=g) as sess:
+ val = sess.run(out, {inp: dummy_input})
+ print(val.shape)
+
+
+if "__main__" in __name__:
+ P = argparse.ArgumentParser(
+ prog="tftrt_test",
+ description="Example utilization of TensorFlow-TensorRT integration")
+ P.add_argument(
+ "--automatic",
+ "-a",
+ action="store_true",
+ help="Do TRT conversion automatically",
+ default=False)
+ flags, unparsed = P.parse_known_args()
+ if flags.automatic:
+ auto()
+ else:
+ user()
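
The script now exposes both integration styles: running it plainly executes
the explicit user() path, while --automatic drives the conversion through
the Grappler custom optimizer configured in auto(). For reference, a sketch
of the explicit path using this script's get_simple_graph_def() (parameter
values mirror the optimizer config above, not a recommendation):

    from tensorflow.contrib import tensorrt as trt

    frozen_graph_def = get_simple_graph_def()  # frozen example graph above
    trt_graph_def = trt.create_inference_graph(
        input_graph_def=frozen_graph_def,
        outputs=["output"],
        max_batch_size=100,                # must cover the largest batch fed
        max_workspace_size_bytes=1 << 25,
        precision_mode="FP32",             # "FP32", "FP16" or "INT8"
        minimum_segment_size=3)            # drop smaller segments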
diff --git a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test.py b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test.py
index 7a47328762..a5c00dd633 100644
--- a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test.py
+++ b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test.py
@@ -45,8 +45,7 @@ class IntegrationTest(test_util.TensorFlowTestCase):
inp_dims = (100, 24, 24, 2)
self._input = np.random.random_sample(inp_dims)
self._original_graph = self.get_simple_graph_def()
- self._gpu_options = cpb2.GPUOptions(
- per_process_gpu_memory_fraction=0.50)
+ self._gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50)
self._config = cpb2.ConfigProto(gpu_options=self._gpu_options)
self._reference = self.run_graph(self._original_graph, self._input)
@@ -61,11 +60,7 @@ class IntegrationTest(test_util.TensorFlowTestCase):
name="weights",
dtype=dtypes.float32)
conv = nn.conv2d(
- input=a,
- filter=e,
- strides=[1, 2, 2, 1],
- padding="SAME",
- name="conv")
+ input=a, filter=e, strides=[1, 2, 2, 1], padding="SAME", name="conv")
b = cop.constant(
[4., 1.5, 2., 3., 5., 7.], name="bias", dtype=dtypes.float32)
t = nn.bias_add(conv, b, name="biasAdd")
@@ -86,8 +81,7 @@ class IntegrationTest(test_util.TensorFlowTestCase):
inp = inp.outputs[0]
out = out.outputs[0]
with self.test_session(
- graph=g, config=self._config, use_gpu=True,
- force_gpu=True) as sess:
+ graph=g, config=self._config, use_gpu=True, force_gpu=True) as sess:
val = sess.run(out, {inp: dumm_inp})
return val
@@ -105,15 +99,14 @@ class IntegrationTest(test_util.TensorFlowTestCase):
# run over real calibration data here, we are mimicking a calibration
# set of 30 different batches. Use as much calibration data as you want
with self.test_session(
- graph=g, config=self._config, use_gpu=True,
- force_gpu=True) as sess:
+ graph=g, config=self._config, use_gpu=True, force_gpu=True) as sess:
for _ in range(30):
val = sess.run(out, {inp: dumm_inp})
return val
def get_trt_graph(self, mode):
"""Return trt converted graph."""
- if mode in ["FP32", "FP16", "INT8"]:
+ if mode in ["FP32", "FP16", "INT8"]:
return trt.create_inference_graph(
input_graph_def=self._original_graph,
outputs=["output"],
@@ -121,7 +114,7 @@ class IntegrationTest(test_util.TensorFlowTestCase):
max_workspace_size_bytes=1 << 25,
precision_mode=mode, # TRT Engine precision "FP32","FP16" or "INT8"
minimum_segment_size=2 # minimum number of nodes in an engine
- )
+ )
return None
def testFP32(self):
diff --git a/tensorflow/tools/pip_package/build_pip_package.sh b/tensorflow/tools/pip_package/build_pip_package.sh
index 8f0cf8c3d1..3af79ee170 100755
--- a/tensorflow/tools/pip_package/build_pip_package.sh
+++ b/tensorflow/tools/pip_package/build_pip_package.sh
@@ -24,7 +24,7 @@ function real_path() {
function cp_external() {
local src_dir=$1
local dest_dir=$2
- for f in `find "$src_dir" -maxdepth 1 -mindepth 1 ! -name '*local_config_cuda*' ! -name '*org_tensorflow*'`; do
+ for f in `find "$src_dir" -maxdepth 1 -mindepth 1 ! -name '*local_config_cuda*' ! -name '*local_config_tensorrt*' ! -name '*org_tensorflow*'`; do
cp -R "$f" "$dest_dir"
done
mkdir -p "${dest_dir}/local_config_cuda/cuda/cuda/"