diff options
author | Jie <jiej@nvidia.com> | 2018-07-24 23:40:39 -0700 |
---|---|---|
committer | Jie <jiej@nvidia.com> | 2018-07-24 23:40:39 -0700 |
commit | 1d4a8296b26150f7eabf5bbb981b9b2438a9fb2a (patch) | |
tree | e630ecacdb91f240e142fc508ab485687270069a /tensorflow/contrib/tensorrt | |
parent | b372882d429ecff6c69fc18ac55efc94ad3a9501 (diff) |
merge upstream master; addressing review comments per changes upstream
Diffstat (limited to 'tensorflow/contrib/tensorrt')
-rw-r--r-- | tensorflow/contrib/tensorrt/convert/convert_nodes.cc | 24 |
1 files changed, 6 insertions, 18 deletions
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc index 7782919566..9d881eda90 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc @@ -2693,13 +2693,6 @@ tensorflow::Status ConvertGraphDefToEngine( VLOG(1) << "Converting op name=" << node_name << ", op=" << node_def.op(); if (tensorflow::str_util::StartsWith(node_name, kInputPHName) && (node_def.op() == "Placeholder")) { - nvinfer1::DataType dtype(nvinfer1::DataType::kFLOAT); - auto type_status = - ConvertDType(node_def.attr().at("dtype").type(), &dtype); - if (type_status != tensorflow::Status::OK()) { - LOG(WARNING) << "Type conversion failed for " << node_name; - return type_status; - } int32 slot_number = -1; if (!tensorflow::strings::safe_strto32( node_name.c_str() + strlen(kInputPHName), &slot_number)) { @@ -2729,21 +2722,12 @@ tensorflow::Status ConvertGraphDefToEngine( #if NV_TENSORRT_MAJOR == 3 nvinfer1::DimsCHW input_dim; - // TRT 3.x only support 4 dimensional input tensor. - if (shape.dims() != 4) { - string err_str = "Require 4 dimensional input."; - StrAppend(&err_str, " Got ", shape.dims(), " ", - node_name); - return tensorflow::errors::Unimplemented(err_str); - } #elif NV_TENSORRT_MAJOR > 3 nvinfer1::Dims input_dim; #endif - for (int i = 1; i < shape.dims(); i++) { input_dim.d[i - 1] = shape.dim_size(i); } - input_dim.nbDims = shape.dims() - 1; nvinfer1::ITensor* input_tensor = converter.network()->addInput( node_name.c_str(), dtype, input_dim); @@ -2920,12 +2904,16 @@ bool InputEdgeValidator::operator()(const tensorflow::Edge* in_edge) const { << ": " << status; return false; } - if (shape.dims() < 3 && in_edge->src()->type_string() != "Const") { + +#if NV_TENSORRT_MAJOR == 3 + // TRT 3.x only support 4 dimensional input tensor. + if (shape.dims() != 4 && in_edge->src()->type_string() != "Const") { VLOG(2) << "--> Need to remove input node " << in_edge->dst()->name() << " which has an input at port " << in_edge->dst_input() - << " with #dim<3 and is not a const: " << shape; + << " with #dim!=4 and is not a const: " << shape; return false; } +#endif return true; } |