diff options
author | Guangda Lai <laigd@google.com> | 2018-08-21 12:35:42 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-21 12:47:00 -0700 |
commit | d648d7e6e12774d5c60418a899d15b81a387c770 (patch) | |
tree | 4380806c2e4996a9b429ca882a0d15f7975a533a /tensorflow/contrib/tensorrt | |
parent | 4f41091f88cca9c87a627864ccd6962e7bb44313 (diff) |
Initialize TRTOptimizationPass members in the constructor, and use a util
function to get the precision mode.
PiperOrigin-RevId: 209641428
Diffstat (limited to 'tensorflow/contrib/tensorrt')
-rw-r--r-- | tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc | 23 | ||||
-rw-r--r-- | tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h | 8 |
2 files changed, 10 insertions, 21 deletions
diff --git a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc index f33f2cc4d6..ff4fba58bf 100644 --- a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc +++ b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc @@ -14,6 +14,7 @@ limitations under the License. #include "tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h" #include "tensorflow/contrib/tensorrt/convert/convert_graph.h" +#include "tensorflow/contrib/tensorrt/convert/utils.h" #include "tensorflow/core/grappler/clusters/cluster.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" @@ -37,7 +38,6 @@ tensorflow::Status TRTOptimizationPass::Init( const tensorflow::RewriterConfig_CustomGraphOptimizer* config) { VLOG(1) << "Called INIT for " << name_ << " with config = " << config; if (config == nullptr) { - maximum_workspace_size_ = 2 << 30; return tensorflow::Status::OK(); } const auto params = config->parameter_map(); @@ -47,7 +47,6 @@ tensorflow::Status TRTOptimizationPass::Init( if (params.count("max_batch_size")) { maximum_batch_size_ = params.at("max_batch_size").i(); } - is_dynamic_op_ = false; if (params.count("is_dynamic_op")) { is_dynamic_op_ = params.at("is_dynamic_op").b(); } @@ -58,27 +57,15 @@ tensorflow::Status TRTOptimizationPass::Init( batches_.push_back(i); } } - max_cached_batches_ = 1; if (params.count("maximum_cached_engines")) { max_cached_batches_ = params.at("maximum_cached_engines").i(); } if (params.count("max_workspace_size_bytes")) { - maximum_workspace_size_ = params.at("max_workspace_size_bytes").i(); + max_workspace_size_bytes_ = params.at("max_workspace_size_bytes").i(); } if (params.count("precision_mode")) { - string pm = Uppercase(params.at("precision_mode").s()); - if (pm == "FP32") { - precision_mode_ = 0; - } else if (pm == "FP16") { - precision_mode_ = 1; - } else if (pm == "INT8") { - precision_mode_ = 2; - } else { - LOG(ERROR) << "Unknown precision mode '" << pm << "'"; - return tensorflow::errors::InvalidArgument( - "Unknown precision mode argument" + pm + - " Valid values are FP32, FP16, INT8"); - } + TF_RETURN_IF_ERROR(GetPrecisionMode( + Uppercase(params.at("precision_mode").s()), &precision_mode_)); } return tensorflow::Status::OK(); } @@ -255,7 +242,7 @@ tensorflow::Status TRTOptimizationPass::Optimize( cp.input_graph_def = &item.graph; cp.output_names = &nodes_to_preserve; cp.max_batch_size = maximum_batch_size_; - cp.max_workspace_size_bytes = maximum_workspace_size_; + cp.max_workspace_size_bytes = max_workspace_size_bytes_; cp.output_graph_def = optimized_graph; cp.precision_mode = precision_mode_; cp.minimum_segment_size = minimum_segment_size_; diff --git a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h index 463ed3883e..71b51d1368 100644 --- a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h +++ b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h @@ -36,7 +36,9 @@ class TRTOptimizationPass : public tensorflow::grappler::CustomGraphOptimizer { minimum_segment_size_(3), precision_mode_(0), maximum_batch_size_(-1), - maximum_workspace_size_(-1) { + is_dynamic_op_(false), + max_cached_batches_(1), + max_workspace_size_bytes_(256LL << 20) { VLOG(1) << "Constructing " << name_; } @@ -57,14 +59,14 @@ class TRTOptimizationPass : public tensorflow::grappler::CustomGraphOptimizer { const tensorflow::grappler::GrapplerItem& item); private: - string name_; + const string name_; int minimum_segment_size_; int precision_mode_; int maximum_batch_size_; bool is_dynamic_op_; std::vector<int> batches_; int max_cached_batches_; - int64_t maximum_workspace_size_; + int64_t max_workspace_size_bytes_; }; } // namespace convert |