aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensorrt
diff options
context:
space:
mode:
authorGravatar Guangda Lai <laigd@google.com>2018-08-21 12:35:42 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-21 12:47:00 -0700
commitd648d7e6e12774d5c60418a899d15b81a387c770 (patch)
tree4380806c2e4996a9b429ca882a0d15f7975a533a /tensorflow/contrib/tensorrt
parent4f41091f88cca9c87a627864ccd6962e7bb44313 (diff)
Initialize TRTOptimizationPass members in the constructor, and use a util
function to get the precision mode. PiperOrigin-RevId: 209641428
Diffstat (limited to 'tensorflow/contrib/tensorrt')
-rw-r--r--tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc23
-rw-r--r--tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h8
2 files changed, 10 insertions, 21 deletions
diff --git a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc
index f33f2cc4d6..ff4fba58bf 100644
--- a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc
+++ b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc
@@ -14,6 +14,7 @@ limitations under the License.
#include "tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h"
#include "tensorflow/contrib/tensorrt/convert/convert_graph.h"
+#include "tensorflow/contrib/tensorrt/convert/utils.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
@@ -37,7 +38,6 @@ tensorflow::Status TRTOptimizationPass::Init(
const tensorflow::RewriterConfig_CustomGraphOptimizer* config) {
VLOG(1) << "Called INIT for " << name_ << " with config = " << config;
if (config == nullptr) {
- maximum_workspace_size_ = 2 << 30;
return tensorflow::Status::OK();
}
const auto params = config->parameter_map();
@@ -47,7 +47,6 @@ tensorflow::Status TRTOptimizationPass::Init(
if (params.count("max_batch_size")) {
maximum_batch_size_ = params.at("max_batch_size").i();
}
- is_dynamic_op_ = false;
if (params.count("is_dynamic_op")) {
is_dynamic_op_ = params.at("is_dynamic_op").b();
}
@@ -58,27 +57,15 @@ tensorflow::Status TRTOptimizationPass::Init(
batches_.push_back(i);
}
}
- max_cached_batches_ = 1;
if (params.count("maximum_cached_engines")) {
max_cached_batches_ = params.at("maximum_cached_engines").i();
}
if (params.count("max_workspace_size_bytes")) {
- maximum_workspace_size_ = params.at("max_workspace_size_bytes").i();
+ max_workspace_size_bytes_ = params.at("max_workspace_size_bytes").i();
}
if (params.count("precision_mode")) {
- string pm = Uppercase(params.at("precision_mode").s());
- if (pm == "FP32") {
- precision_mode_ = 0;
- } else if (pm == "FP16") {
- precision_mode_ = 1;
- } else if (pm == "INT8") {
- precision_mode_ = 2;
- } else {
- LOG(ERROR) << "Unknown precision mode '" << pm << "'";
- return tensorflow::errors::InvalidArgument(
- "Unknown precision mode argument" + pm +
- " Valid values are FP32, FP16, INT8");
- }
+ TF_RETURN_IF_ERROR(GetPrecisionMode(
+ Uppercase(params.at("precision_mode").s()), &precision_mode_));
}
return tensorflow::Status::OK();
}
@@ -255,7 +242,7 @@ tensorflow::Status TRTOptimizationPass::Optimize(
cp.input_graph_def = &item.graph;
cp.output_names = &nodes_to_preserve;
cp.max_batch_size = maximum_batch_size_;
- cp.max_workspace_size_bytes = maximum_workspace_size_;
+ cp.max_workspace_size_bytes = max_workspace_size_bytes_;
cp.output_graph_def = optimized_graph;
cp.precision_mode = precision_mode_;
cp.minimum_segment_size = minimum_segment_size_;
diff --git a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h
index 463ed3883e..71b51d1368 100644
--- a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h
+++ b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h
@@ -36,7 +36,9 @@ class TRTOptimizationPass : public tensorflow::grappler::CustomGraphOptimizer {
minimum_segment_size_(3),
precision_mode_(0),
maximum_batch_size_(-1),
- maximum_workspace_size_(-1) {
+ is_dynamic_op_(false),
+ max_cached_batches_(1),
+ max_workspace_size_bytes_(256LL << 20) {
VLOG(1) << "Constructing " << name_;
}
@@ -57,14 +59,14 @@ class TRTOptimizationPass : public tensorflow::grappler::CustomGraphOptimizer {
const tensorflow::grappler::GrapplerItem& item);
private:
- string name_;
+ const string name_;
int minimum_segment_size_;
int precision_mode_;
int maximum_batch_size_;
bool is_dynamic_op_;
std::vector<int> batches_;
int max_cached_batches_;
- int64_t maximum_workspace_size_;
+ int64_t max_workspace_size_bytes_;
};
} // namespace convert