aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensorrt
diff options
context:
space:
mode:
authorGravatar Jie <jiej@nvidia.com>2018-07-24 23:40:39 -0700
committerGravatar Jie <jiej@nvidia.com>2018-07-24 23:40:39 -0700
commit1d4a8296b26150f7eabf5bbb981b9b2438a9fb2a (patch)
treee630ecacdb91f240e142fc508ab485687270069a /tensorflow/contrib/tensorrt
parentb372882d429ecff6c69fc18ac55efc94ad3a9501 (diff)
merge upstream master; addressing review comments per changes upstream
Diffstat (limited to 'tensorflow/contrib/tensorrt')
-rw-r--r--tensorflow/contrib/tensorrt/convert/convert_nodes.cc24
1 files changed, 6 insertions, 18 deletions
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
index 7782919566..9d881eda90 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
@@ -2693,13 +2693,6 @@ tensorflow::Status ConvertGraphDefToEngine(
VLOG(1) << "Converting op name=" << node_name << ", op=" << node_def.op();
if (tensorflow::str_util::StartsWith(node_name, kInputPHName) &&
(node_def.op() == "Placeholder")) {
- nvinfer1::DataType dtype(nvinfer1::DataType::kFLOAT);
- auto type_status =
- ConvertDType(node_def.attr().at("dtype").type(), &dtype);
- if (type_status != tensorflow::Status::OK()) {
- LOG(WARNING) << "Type conversion failed for " << node_name;
- return type_status;
- }
int32 slot_number = -1;
if (!tensorflow::strings::safe_strto32(
node_name.c_str() + strlen(kInputPHName), &slot_number)) {
@@ -2729,21 +2722,12 @@ tensorflow::Status ConvertGraphDefToEngine(
#if NV_TENSORRT_MAJOR == 3
nvinfer1::DimsCHW input_dim;
- // TRT 3.x only support 4 dimensional input tensor.
- if (shape.dims() != 4) {
- string err_str = "Require 4 dimensional input.";
- StrAppend(&err_str, " Got ", shape.dims(), " ",
- node_name);
- return tensorflow::errors::Unimplemented(err_str);
- }
#elif NV_TENSORRT_MAJOR > 3
nvinfer1::Dims input_dim;
#endif
-
for (int i = 1; i < shape.dims(); i++) {
input_dim.d[i - 1] = shape.dim_size(i);
}
-
input_dim.nbDims = shape.dims() - 1;
nvinfer1::ITensor* input_tensor = converter.network()->addInput(
node_name.c_str(), dtype, input_dim);
@@ -2920,12 +2904,16 @@ bool InputEdgeValidator::operator()(const tensorflow::Edge* in_edge) const {
<< ": " << status;
return false;
}
- if (shape.dims() < 3 && in_edge->src()->type_string() != "Const") {
+
+#if NV_TENSORRT_MAJOR == 3
+ // TRT 3.x only support 4 dimensional input tensor.
+ if (shape.dims() != 4 && in_edge->src()->type_string() != "Const") {
VLOG(2) << "--> Need to remove input node " << in_edge->dst()->name()
<< " which has an input at port " << in_edge->dst_input()
- << " with #dim<3 and is not a const: " << shape;
+ << " with #dim!=4 and is not a const: " << shape;
return false;
}
+#endif
return true;
}