Diffstat (limited to 'tensorflow/contrib/tensorrt/convert/convert_nodes.cc')
-rw-r--r--  tensorflow/contrib/tensorrt/convert/convert_nodes.cc | 375
1 file changed, 200 insertions(+), 175 deletions(-)
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
index 92a692baa7..370911e4d9 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
@@ -53,8 +53,8 @@ limitations under the License.
namespace tensorflow {
namespace tensorrt {
namespace convert {
+using ::tensorflow::strings::StrAppend;
using ::tensorflow::strings::StrCat;
-
namespace {
inline tensorflow::Status ConvertDType(tensorflow::DataType tf_dtype,
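This change also imports StrAppend alongside StrCat; later hunks use the pair to build error strings incrementally. A minimal sketch of that pattern, with an illustrative function name not taken from this commit:

#include <string>
#include "tensorflow/core/lib/strings/strcat.h"

using ::tensorflow::strings::StrAppend;
using ::tensorflow::strings::StrCat;

// StrCat builds the initial message; StrAppend extends it in place.
std::string BuildInputError(const std::string& node, const std::string& input) {
  std::string msg = StrCat("Node ", node);
  StrAppend(&msg, " should have an input named '", input,
            "' but it is not available");
  return msg;
}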
@@ -430,9 +430,8 @@ class Converter {
tensorflow::tensorrt::TRTWeightStore* weight_store_;
bool fp16_;
void register_op_converters();
- std::vector<TRT_TensorOrWeights> get_inputs(
- const tensorflow::NodeDef& node_def) {
- std::vector<TRT_TensorOrWeights> inputs;
+ tensorflow::Status get_inputs(const tensorflow::NodeDef& node_def,
+ std::vector<TRT_TensorOrWeights>* inputs) {
for (auto const& input_name : node_def.input()) {
/*************************************************************************
* TODO(jie) handle case 1) here
@@ -453,13 +452,17 @@ class Converter {
VLOG(2) << "retrieve input: " << name;
if (trt_tensors_.count(name)) {
- inputs.push_back(trt_tensors_.at(name));
+ inputs->push_back(trt_tensors_.at(name));
} else {
- LOG(FATAL) << "input: " << name << " not available for node at, "
- << node_def.name();
+ string str("Node ");
+ StrAppend(&str, node_def.name(), " should have an input named '", name,
+ "' but it is not available");
+ LOG(WARNING) << "input: " << name << " not available for node at "
+ << node_def.name();
+ return tensorflow::errors::InvalidArgument(str);
}
}
- return inputs;
+ return tensorflow::Status::OK();
}
public:
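get_inputs now reports a missing input through its returned Status instead of killing the process with LOG(FATAL), and convert_node (next hunk) propagates the failure with TF_RETURN_IF_ERROR. A self-contained sketch of the same lookup-and-propagate pattern, with illustrative types in place of TRT_TensorOrWeights:

#include <string>
#include <unordered_map>
#include "tensorflow/core/lib/core/errors.h"

// Report a missing key through Status rather than crashing.
tensorflow::Status Lookup(const std::unordered_map<std::string, int>& table,
                          const std::string& name, int* out) {
  auto it = table.find(name);
  if (it == table.end()) {
    return tensorflow::errors::InvalidArgument("input '", name,
                                               "' not available");
  }
  *out = it->second;
  return tensorflow::Status::OK();
}

tensorflow::Status Caller(const std::unordered_map<std::string, int>& table) {
  int value = 0;
  TF_RETURN_IF_ERROR(Lookup(table, "weights", &value));  // early-returns on error
  return tensorflow::Status::OK();
}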
@@ -483,7 +486,8 @@ class Converter {
}
tensorflow::Status convert_node(const tensorflow::NodeDef& node_def) {
- std::vector<TRT_TensorOrWeights> inputs = this->get_inputs(node_def);
+ std::vector<TRT_TensorOrWeights> inputs;
+ TF_RETURN_IF_ERROR(this->get_inputs(node_def, &inputs));
string op = node_def.op();
if (!op_registry_.count(op)) {
return tensorflow::errors::Unimplemented(
@@ -548,6 +552,19 @@ class Converter {
}
};
+TRT_ShapedWeights ConvertFP32ToFP16(Converter& ctx,
+ const TRT_ShapedWeights& weights_src) {
+ auto dtype_new = tensorflow::DataType::DT_HALF;
+ TRT_ShapedWeights weights =
+ ctx.get_temp_weights(dtype_new, weights_src.shape_);
+ const float* src = static_cast<const float*>(weights_src.GetValues());
+ Eigen::half* dst = const_cast<Eigen::half*>(
+ static_cast<Eigen::half const*>(weights.GetValues()));
+ for (int64_t i = 0; i < weights_src.count(); i++) {
+ dst[i] = Eigen::half_impl::float_to_half_rtne(src[i]);
+ }
+ return weights;
+}
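The new ConvertFP32ToFP16 helper centralizes weight narrowing: an element-wise copy with round-to-nearest-even. The same loop over plain buffers, as a standalone sketch (assumes the Eigen bundled with TensorFlow, where float_to_half_rtne's result converts to Eigen::half exactly as above):

#include <cstddef>
#include <vector>
#include <Eigen/Core>

// Narrow an FP32 buffer to FP16 with round-to-nearest-even.
std::vector<Eigen::half> ToHalf(const float* src, size_t count) {
  std::vector<Eigen::half> dst(count);
  for (size_t i = 0; i < count; ++i) {
    dst[i] = Eigen::half_impl::float_to_half_rtne(src[i]);
  }
  return dst;
}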
// ****************************************************************************
// Constant folding functions
// TODO(jie): once optimizer kicks in, we should have done constant folding
@@ -875,7 +892,7 @@ tensorflow::Status BinaryTensorOpWeight(
// Check type consistency
nvinfer1::DataType ttype;
- TF_CHECK_OK(ConvertDType(weights.type_, &ttype));
+ TF_RETURN_IF_ERROR(ConvertDType(weights.type_, &ttype));
// Check scale mode
auto dims_w = weights.shape_;
@@ -957,6 +974,10 @@ tensorflow::Status BinaryTensorOpWeight(
}
}
+ if (ctx.isFP16()) {
+ weights = ConvertFP32ToFP16(ctx, weights);
+ }
+
// prepare weights
TRT_ShapedWeights shift_weights(weights.type_);
TRT_ShapedWeights scale_weights(weights.type_);
@@ -998,9 +1019,7 @@ enum class ConvolutionType { DEFAULT, DEPTHWISE_CONV };
tensorflow::Status ConvertConv2DHelper(
Converter& ctx, const tensorflow::NodeDef& node_def,
const std::vector<TRT_TensorOrWeights>& inputs,
- std::vector<TRT_TensorOrWeights>* outputs,
- int group // group ==0 specifies depthwise conv
-) {
+ std::vector<TRT_TensorOrWeights>* outputs, int group) {
const nvinfer1::ITensor* tensor = inputs.at(0).tensor();
TFAttrs attrs(node_def);
@@ -1025,6 +1044,10 @@ tensorflow::Status ConvertConv2DHelper(
VLOG(2) << "groups count: " << num_groups;
TRT_ShapedWeights weights_rsck = inputs.at(1).weights();
+ if (ctx.isFP16()) {
+ weights_rsck = ConvertFP32ToFP16(ctx, inputs.at(1).weights());
+ }
+
TRT_ShapedWeights weights = ctx.get_temp_weights_like(weights_rsck);
ReorderRSCKToKCRS(weights_rsck, &weights, num_groups);
TRT_ShapedWeights biases(weights.type_);
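ReorderRSCKToKCRS, called here, transposes TensorFlow's [R, S, C, K] filter layout (rows, cols, input channels, output channels) into the [K, C, R, S] layout TensorRT's convolution expects. A minimal sketch of that index transpose; the real helper additionally splits K across num_groups:

// Transpose conv filters from TF's RSCK layout to TensorRT's KCRS layout.
void ReorderRSCKToKCRSSketch(const float* src, float* dst,
                             int R, int S, int C, int K) {
  for (int r = 0; r < R; ++r)
    for (int s = 0; s < S; ++s)
      for (int c = 0; c < C; ++c)
        for (int k = 0; k < K; ++k)
          dst[((k * C + c) * R + r) * S + s] =
              src[((r * S + s) * C + c) * K + k];
}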
@@ -1134,9 +1157,9 @@ tensorflow::Status BinaryTensorOpTensor(
CHECK_EQ_TYPE(tensor_r->getType(), dtype);
auto op_pair = ops.find(node_def.op());
if (op_pair == ops.end())
- return tensorflow::errors::Unimplemented("binary op: " + node_def.op() +
- " not supported at: " +
- node_def.name());
+ return tensorflow::errors::Unimplemented(
+ "binary op: " + node_def.op() +
+ " not supported at: " + node_def.name());
nvinfer1::IElementWiseLayer* layer = ctx.network()->addElementWise(
*const_cast<nvinfer1::ITensor*>(tensor_l),
@@ -1295,8 +1318,11 @@ tensorflow::Status ConvertScale(Converter& ctx,
// Implement tensor binaryOp weight [channel wise] for now;
const nvinfer1::ITensor* tensor = inputs.at(0).tensor();
- // TODO(jie): handle NHWC/NCHW transpose;
TRT_ShapedWeights weights = inputs.at(1).weights();
+ if (ctx.isFP16()) {
+ weights = ConvertFP32ToFP16(ctx, inputs.at(1).weights());
+ }
+
TRT_ShapedWeights empty_weights(weights.type_);
TFAttrs attrs(node_def);
@@ -1376,8 +1402,11 @@ tensorflow::Status ConvertConst(Converter& ctx,
scalar_shape.d[0] = weights_tensor.float_val_size();
scalar_shape.type[0] = nvinfer1::DimensionType::kSPATIAL;
} else {
- LOG(FATAL) << "Broadcast on weights only supports kCHANNEL and"
- << " kUNIFORM, at: " << node_def.name();
+ LOG(WARNING) << "Broadcast on weights only supports kCHANNEL and"
+ << " kUNIFORM, at: " << node_def.name();
+ string err_str("Broadcast method is not supported for '");
+ StrAppend(&err_str, node_def.name(), "' of type ", node_def.op());
+ return tensorflow::errors::InvalidArgument(err_str);
}
}
} else {
@@ -1391,33 +1420,16 @@ tensorflow::Status ConvertConst(Converter& ctx,
scalar_shape.type[i] = nvinfer1::DimensionType::kSPATIAL;
}
}
- if (ctx.isFP16()) {
- auto dtype_new = tensorflow::DataType::DT_HALF;
- size_t len_data = tensorflow::DataTypeSize(dtype_new);
- for (int i = 0; i < scalar_shape.nbDims; i++)
- len_data *= scalar_shape.d[i];
- ctx.weight_store()->store_.push_back(std::vector<uint8_t>(len_data));
- void* dst = static_cast<void*>(&(ctx.weight_store()->store_.back()[0]));
- tensorflow::Tensor temp_tensor(tensorflow::DT_HALF, tensor.shape());
- auto half_tensor = temp_tensor.flat<Eigen::half>();
- Eigen::DefaultDevice defd;
- half_tensor.device(defd) =
- tensor.flat<float>().template cast<Eigen::half>();
- memcpy(dst, half_tensor.data(), len_data); // store into weight store
- weights = TRT_ShapedWeights(dtype_new, dst, scalar_shape);
- } else {
- size_t len_data = tensorflow::DataTypeSize(dtype);
- for (int i = 0; i < scalar_shape.nbDims; i++)
- len_data *= scalar_shape.d[i];
- ctx.weight_store()->store_.push_back(std::vector<uint8_t>(len_data));
- void* dst = static_cast<void*>(&(ctx.weight_store()->store_.back()[0]));
- std::vector<float> tensor_data(
- weights_tensor.float_val().begin(),
- weights_tensor.float_val()
- .end()); // make a local copy first to flatten
- memcpy(dst, tensor_data.data(), len_data); // store into weight store
- weights = TRT_ShapedWeights(dtype, dst, scalar_shape);
- }
+ size_t len_data = tensorflow::DataTypeSize(dtype);
+ for (int i = 0; i < scalar_shape.nbDims; i++) len_data *= scalar_shape.d[i];
+ ctx.weight_store()->store_.push_back(std::vector<uint8_t>(len_data));
+ void* dst = static_cast<void*>(&(ctx.weight_store()->store_.back()[0]));
+ std::vector<float> tensor_data(
+ weights_tensor.float_val().begin(),
+ weights_tensor.float_val()
+ .end()); // make a local copy first to flatten
+ memcpy(dst, tensor_data.data(), len_data); // store into weight store
+ weights = TRT_ShapedWeights(dtype, dst, scalar_shape);
} else if (!weights_tensor.int_val().empty()) {
VLOG(2) << "int!!!" << node_def.name();
nvinfer1::Dims scalar_shape;
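With the FP16 branch removed, ConvertConst stores float weights verbatim and leaves narrowing to the op converters via ConvertFP32ToFP16. The storage idiom itself, flattening possibly non-contiguous proto values into an owned byte buffer, as a standalone sketch:

#include <cstdint>
#include <cstring>
#include <vector>

// Copy values into an owned, flat byte buffer, as the weight store does.
std::vector<uint8_t> FlattenToBytes(const std::vector<float>& values) {
  std::vector<uint8_t> buf(values.size() * sizeof(float));
  std::memcpy(buf.data(), values.data(), buf.size());
  return buf;
}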
@@ -1432,8 +1444,11 @@ tensorflow::Status ConvertConst(Converter& ctx,
scalar_shape.d[0] = weights_tensor.int_val_size();
scalar_shape.type[0] = nvinfer1::DimensionType::kSPATIAL;
} else {
- LOG(FATAL) << "Broadcast on weights only supports kCHANNEL and"
- << " kUNIFORM, at: " << node_def.name();
+ LOG(WARNING) << "Broadcast on weights only supports kCHANNEL and"
+ << " kUNIFORM, at: " << node_def.name();
+ string err_str("Broadcast method is not supported for '");
+ StrAppend(&err_str, node_def.name(), "' of type ", node_def.op());
+ return tensorflow::errors::InvalidArgument(err_str);
}
}
} else {
@@ -1447,62 +1462,23 @@ tensorflow::Status ConvertConst(Converter& ctx,
scalar_shape.type[i] = nvinfer1::DimensionType::kSPATIAL;
}
}
- if (ctx.isFP16()) {
- auto dtype_new = tensorflow::DataType::DT_HALF;
- size_t len_data = tensorflow::DataTypeSize(dtype_new);
- for (int i = 0; i < scalar_shape.nbDims; i++)
- len_data *= scalar_shape.d[i];
- ctx.weight_store()->store_.push_back(std::vector<uint8_t>(len_data));
- void* dst = static_cast<void*>(&(ctx.weight_store()->store_.back()[0]));
- tensorflow::Tensor temp_tensor(tensorflow::DT_HALF, tensor.shape());
- TTypes<Eigen::half>::Flat half_tensor = temp_tensor.flat<Eigen::half>();
- Eigen::DefaultDevice defd;
- switch (dtype) {
- case (tensorflow::DT_INT32): {
- half_tensor.device(defd) =
- tensor.flat<int32>().template cast<Eigen::half>();
- break;
- }
- case (tensorflow::DT_INT16): {
- half_tensor.device(defd) =
- tensor.flat<int16>().template cast<Eigen::half>();
- break;
- }
- case (tensorflow::DT_INT8): {
- half_tensor.device(defd) =
- tensor.flat<int8>().template cast<Eigen::half>();
- break;
- }
- case (tensorflow::DT_UINT8): {
- half_tensor.device(defd) =
- tensor.flat<uint8>().template cast<Eigen::half>();
- break;
- }
- default:
- return tensorflow::errors::InvalidArgument(
- "Datatype " + tensorflow::DataTypeString(dtype) +
- " for FP16 conversion");
- break;
- };
- memcpy(dst, half_tensor.data(), len_data); // store into weight store
- weights = TRT_ShapedWeights(dtype_new, dst, scalar_shape);
- } else {
- size_t len_data = tensorflow::DataTypeSize(dtype);
- for (int i = 0; i < scalar_shape.nbDims; i++)
- len_data *= scalar_shape.d[i];
- size_t len_tensor = weights_tensor.int_val_size() * sizeof(int32);
- len_data = std::max(len_data, len_tensor);
- ctx.weight_store()->store_.push_back(std::vector<uint8_t>(len_data));
- void* dst = static_cast<void*>(&(ctx.weight_store()->store_.back()[0]));
- std::vector<int32> tensor_data(
- weights_tensor.int_val().begin(),
- weights_tensor.int_val()
- .end()); // make a local copy first to flatten
- // doesn't have to be contiguous
- memcpy(dst, tensor_data.data(), len_tensor); // store into weight store
- weights = TRT_ShapedWeights(dtype, dst, scalar_shape);
- }
+  // We should not have converted to FP16 here; the old isFP16() branch is removed.
+ size_t len_data = tensorflow::DataTypeSize(dtype);
+ for (int i = 0; i < scalar_shape.nbDims; i++) len_data *= scalar_shape.d[i];
+ size_t len_tensor = weights_tensor.int_val_size() * sizeof(int32);
+ len_data = std::max(len_data, len_tensor);
+ ctx.weight_store()->store_.push_back(std::vector<uint8_t>(len_data));
+ void* dst = static_cast<void*>(&(ctx.weight_store()->store_.back()[0]));
+ std::vector<int32> tensor_data(
+ weights_tensor.int_val().begin(),
+ weights_tensor.int_val().end()); // make a local copy first to flatten
+    // doesn't have to be contiguous
+ memcpy(dst, tensor_data.data(), len_tensor); // store into weight store
+ weights = TRT_ShapedWeights(dtype, dst, scalar_shape);
} else if (!weights_tensor.tensor_content().empty()) {
+  // Obsolete path: after the optimization pass, weights no longer appear in
+  // this format. FP16 conversion would technically be needed here as well.
VLOG(2) << "TENSOR!!!" << node_def.name();
const auto& content = weights_tensor.tensor_content();
@@ -1784,8 +1760,6 @@ tensorflow::Status ConvertConcat(Converter& ctx,
TRT_ShapedWeights axis = inputs.at(input_size).weights();
TFAttrs attrs(node_def);
- // auto attr_size = attrs.at("N")->i();
- // auto data_type = attrs.get<nvinfer1::DataType>("T");
auto index_type = attrs.get<tensorflow::DataType>("Tidx");
// TODO(jie): handle data type
@@ -1875,71 +1849,103 @@ tensorflow::Status ConvertFusedBatchNorm(
"only is_training=false is supported, at " + node_def.name());
}
nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
- TRT_ShapedWeights scale_weights = inputs.at(1).weights();
- TRT_ShapedWeights offset_weights = inputs.at(2).weights();
- TRT_ShapedWeights mean_weights = inputs.at(3).weights();
- TRT_ShapedWeights variance_weights = inputs.at(4).weights();
- TRT_ShapedWeights dummy_power_weights(scale_weights.type_);
- TRT_ShapedWeights combined_scale_weights =
- ctx.get_temp_weights_like(scale_weights);
- TRT_ShapedWeights combined_offset_weights =
- ctx.get_temp_weights_like(offset_weights);
- size_t nweight = scale_weights.count();
- if ((scale_weights.type_ == offset_weights.type_) &&
- (mean_weights.type_ == variance_weights.type_) &&
- (scale_weights.type_ == variance_weights.type_)) {
- if ((scale_weights.type_ != tensorflow::DataType::DT_FLOAT) &&
- (scale_weights.type_ != tensorflow::DataType::DT_HALF)) {
+
+ // Check parameter types
+ auto parameter_type = inputs.at(1).weights().type_;
+ if ((parameter_type != tensorflow::DataType::DT_FLOAT) &&
+ (parameter_type != tensorflow::DataType::DT_HALF)) {
+ return tensorflow::errors::Unimplemented(
+ "only float32 or float16 weight data type is supported, for node " +
+ node_def.name() + " got " + tensorflow::DataTypeString(parameter_type));
+ }
+ for (int i = 1; i < 5; i++) {
+ if (inputs.at(i).weights().type_ != parameter_type) {
return tensorflow::errors::Unimplemented(
- "only float32 or float16 weight data type is supported, for node " +
- node_def.name() + " got " +
-          tensorflow::DataTypeString(scale_weights.type_));
+          "Inconsistent parameter type for batchnorm is not supported, at: " +
+          node_def.name());
}
- if (scale_weights.type_ == tensorflow::DT_FLOAT) {
- for (size_t i = 0; i < nweight; ++i) {
- float scale = (static_cast<float const*>(scale_weights.GetValues()))[i];
- float offset =
- (static_cast<float const*>(offset_weights.GetValues()))[i];
- float mean = (static_cast<float const*>(mean_weights.GetValues()))[i];
- float variance =
- (static_cast<float const*>(variance_weights.GetValues()))[i];
- float& combined_scale_ref = const_cast<float*>(
- static_cast<float const*>(combined_scale_weights.GetValues()))[i];
- float& combined_offset_ref = const_cast<float*>(
- static_cast<float const*>(combined_offset_weights.GetValues()))[i];
- combined_scale_ref = scale / sqrtf(variance + epsilon);
- combined_offset_ref = offset - mean * combined_scale_ref;
- }
- } else {
- const Eigen::half* scale_vals =
- (static_cast<Eigen::half const*>(scale_weights.GetValues()));
- const Eigen::half* off_vals =
- (static_cast<Eigen::half const*>(offset_weights.GetValues()));
- const Eigen::half* mean_vals =
- (static_cast<Eigen::half const*>(mean_weights.GetValues()));
- const Eigen::half* variance_vals =
- (static_cast<Eigen::half const*>(variance_weights.GetValues()));
- Eigen::half* comb_scale_vals = const_cast<Eigen::half*>(
- static_cast<Eigen::half const*>(combined_scale_weights.GetValues()));
- Eigen::half* comb_off_vals = const_cast<Eigen::half*>(
- static_cast<Eigen::half const*>(combined_offset_weights.GetValues()));
- for (size_t i = 0; i < nweight; ++i) {
- float scale(scale_vals[i]);
- float offset(off_vals[i]);
- float mean(mean_vals[i]);
- float variance(variance_vals[i]);
- float combined_scale_ref = scale / sqrtf(variance + epsilon);
- comb_scale_vals[i] = Eigen::half(combined_scale_ref);
- float combined_offset_ref = offset - mean * combined_scale_ref;
- comb_off_vals[i] = Eigen::half(combined_offset_ref);
+ }
+
+ TRT_ShapedWeights dummy_power_weights(parameter_type);
+ size_t nweight = 0;
+ for (int i = 1; i < 5; i++) {
+ nweight = std::max(nweight, (size_t)inputs.at(i).weights().count());
+ }
+ TRT_ShapedWeights* ptr_shape_weights = nullptr;
+ for (int i = 1; i < 5; i++) {
+ if (inputs.at(i).weights().count() == nweight) {
+ ptr_shape_weights =
+ const_cast<TRT_ShapedWeights*>(&(inputs.at(i).weights()));
+ } else if (inputs.at(i).weights().count() != 1) {
+ return tensorflow::errors::InvalidArgument(
+ "Inconsistent batchnorm parameter count, at: " + node_def.name());
+ }
+ }
+  // We could technically have two weights with different shapes; that would
+  // require two addScale ops, which is arguably less performant.
+ TRT_ShapedWeights combined_scale_weights =
+ ctx.get_temp_weights_like(*ptr_shape_weights);
+ TRT_ShapedWeights combined_offset_weights =
+ ctx.get_temp_weights_like(*ptr_shape_weights);
+
+ const Eigen::half* cast_vals_array[4];
+ const float* vals_array[4];
+ for (int j = 0; j < 4; j++) {
+ cast_vals_array[j] =
+ static_cast<Eigen::half const*>(inputs.at(j + 1).weights().GetValues());
+ vals_array[j] =
+ static_cast<float const*>(inputs.at(j + 1).weights().GetValues());
+ }
+ Eigen::half* cast_combined_scale_vals = const_cast<Eigen::half*>(
+ static_cast<Eigen::half const*>(combined_scale_weights.GetValues()));
+ Eigen::half* cast_combined_offset_vals = const_cast<Eigen::half*>(
+ static_cast<Eigen::half const*>(combined_offset_weights.GetValues()));
+ float* combined_scale_vals = const_cast<float*>(
+ static_cast<float const*>(combined_scale_weights.GetValues()));
+ float* combined_offset_vals = const_cast<float*>(
+ static_cast<float const*>(combined_offset_weights.GetValues()));
+
+ for (size_t i = 0; i < nweight; ++i) {
+ float batchnorm_data[4];
+ for (int j = 0; j < 4; j++) {
+ if (inputs.at(j + 1).weights().count() != 1) {
+ if (parameter_type == tensorflow::DT_FLOAT) {
+ batchnorm_data[j] = vals_array[j][i];
+ } else if (parameter_type == tensorflow::DT_HALF) {
+ batchnorm_data[j] =
+ Eigen::half_impl::half_to_float(cast_vals_array[j][i]);
+ }
+ } else {
+ if (parameter_type == tensorflow::DT_FLOAT) {
+ batchnorm_data[j] = vals_array[j][0];
+ } else if (parameter_type == tensorflow::DT_HALF) {
+ batchnorm_data[j] =
+ Eigen::half_impl::half_to_float(cast_vals_array[j][0]);
+ }
}
}
+ float scale = batchnorm_data[0];
+ float offset = batchnorm_data[1];
+ float mean = batchnorm_data[2];
+ float variance = batchnorm_data[3];
+ float combined_scale_val = scale / sqrtf(variance + epsilon);
+ float combined_offset_val = offset - mean * combined_scale_val;
+ if (parameter_type == tensorflow::DT_FLOAT) {
+ combined_scale_vals[i] = combined_scale_val;
+ combined_offset_vals[i] = combined_offset_val;
+ } else if (parameter_type == tensorflow::DT_HALF) {
+ cast_combined_scale_vals[i] = Eigen::half(combined_scale_val);
+ cast_combined_offset_vals[i] = Eigen::half(combined_offset_val);
+ }
}
- nvinfer1::IScaleLayer* layer = ctx.network()->addScale(
- *const_cast<nvinfer1::ITensor*>(tensor), nvinfer1::ScaleMode::kCHANNEL,
- combined_offset_weights.GetWeightsForTRT(),
- combined_scale_weights.GetWeightsForTRT(),
- dummy_power_weights.GetWeightsForTRT());
+
+ nvinfer1::ScaleMode mode = nweight == 1 ? nvinfer1::ScaleMode::kUNIFORM
+ : nvinfer1::ScaleMode::kCHANNEL;
+ nvinfer1::IScaleLayer* layer =
+ ctx.network()->addScale(*const_cast<nvinfer1::ITensor*>(tensor), mode,
+ combined_offset_weights.GetWeightsForTRT(),
+ combined_scale_weights.GetWeightsForTRT(),
+ dummy_power_weights.GetWeightsForTRT());
nvinfer1::ITensor* output_tensor = layer->getOutput(0);
outputs->push_back(TRT_TensorOrWeights(output_tensor));
return tensorflow::Status::OK();
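The rewrite folds the four batchnorm parameters into one scale/offset pair so a single IScaleLayer suffices: y = scale * (x - mean) / sqrt(variance + epsilon) + offset collapses to y = combined_scale * x + combined_offset. A small self-contained check of that algebra, with illustrative values:

#include <cmath>
#include <cstdio>

int main() {
  const float epsilon = 1e-3f;
  const float scale = 1.5f, offset = 0.25f, mean = 0.1f, variance = 4.0f;
  const float combined_scale = scale / std::sqrt(variance + epsilon);
  const float combined_offset = offset - mean * combined_scale;
  const float x = 2.0f;
  const float bn = scale * (x - mean) / std::sqrt(variance + epsilon) + offset;
  const float folded = combined_scale * x + combined_offset;
  std::printf("batchnorm=%f folded=%f\n", bn, folded);  // prints equal values
  return 0;
}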
@@ -2050,6 +2056,7 @@ void Converter::register_op_converters() {
op_registry_["Const"] = ConvertConst;
// TODO(ben,jie): this is a temp hack.
op_registry_["Identity"] = ConvertIdentity; // Identity should be removed
+ op_registry_["Snapshot"] = ConvertIdentity; // Snapshot should be removed
// resnet_50_v1 slim implementation
op_registry_["Add"] = ConvertBinary;
@@ -2143,8 +2150,11 @@ tensorflow::Status ConvertCalibrationNodeToEngineNode(
calib_res->thr_->join();
delete calib_res->thr_;
if (!calib_res->engine_) {
- LOG(FATAL) << "Calibration failed!, engine is nullptr. Did you run "
+    LOG(ERROR) << "Calibration failed: engine does not exist. Did you run "
"calibration graph?";
+    return tensorflow::errors::FailedPrecondition(
+        "Calibration graph needs to be executed on"
+        " calibration data before conversion to inference graph");
}
auto weight_rmgr = trt_rm->getManager("WeightStore");
TF_CHECK_OK(weight_rmgr->Delete<tensorflow::tensorrt::TRTWeightStore>(
@@ -2181,7 +2191,7 @@ tensorflow::Status ConvertCalibrationNodeToEngineNode(
return status;
}
auto trt_engine_node = graph.AddNode(engine_node, &status);
- TF_CHECK_OK(status);
+ TF_RETURN_IF_ERROR(status);
for (size_t i = 0; i < out_edges.size(); i++) {
VLOG(1) << "Connecting trt_engine_node output " << i << " with "
<< out_edges.at(i)->dst()->name() << " port "
@@ -2279,6 +2289,12 @@ tensorflow::Status InjectCalibrationNode(tensorrt::convert::SubGraphParams& s) {
input_dtypes.push_back(tf_dtype);
nvinfer1::DataType dtype(nvinfer1::DataType::kFLOAT);
+ auto type_status = ConvertDType(tf_dtype, &dtype);
+ if (type_status != tensorflow::Status::OK()) {
+ LOG(WARNING) << "Data type conversion for input '" << node_name
+ << "' failed";
+ return type_status;
+ }
TF_CHECK_OK(ConvertDType(tf_dtype, &dtype));
VLOG(2) << "accessing output index of: " << output_idx
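Both subgraph-conversion paths now check ConvertDType's Status instead of aborting through TF_CHECK_OK. The helper (defined near the top of this file) maps TensorFlow dtypes onto TensorRT enums roughly as in this illustrative sketch:

#include "NvInfer.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"

tensorflow::Status ConvertDTypeSketch(tensorflow::DataType tf_dtype,
                                      nvinfer1::DataType* trt_dtype) {
  switch (tf_dtype) {
    case tensorflow::DT_FLOAT: *trt_dtype = nvinfer1::DataType::kFLOAT; break;
    case tensorflow::DT_HALF:  *trt_dtype = nvinfer1::DataType::kHALF;  break;
    case tensorflow::DT_INT8:  *trt_dtype = nvinfer1::DataType::kINT8;  break;
    default:
      return tensorflow::errors::InvalidArgument("Unsupported data type");
  }
  return tensorflow::Status::OK();
}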
@@ -2346,8 +2362,8 @@ tensorflow::Status InjectCalibrationNode(tensorrt::convert::SubGraphParams& s) {
output_names.push_back(tensor_name);
auto tensor_or_weights = converter.get_tensor(tensor_name);
if (!tensor_or_weights.is_tensor()) {
- return tensorflow::errors::InvalidArgument(
- "Output node is weights not tensor");
+      return tensorflow::errors::InvalidArgument("Output node '" + tensor_name +
+                                                 "' is weights not tensor");
}
nvinfer1::ITensor* tensor = tensor_or_weights.tensor();
if (!tensor) {
@@ -2504,7 +2520,11 @@ tensorflow::Status ConvertSubGraphToTensorRTNodeDef(
input_dtypes.push_back(tf_dtype);
nvinfer1::DataType dtype(nvinfer1::DataType::kFLOAT);
- TF_CHECK_OK(ConvertDType(tf_dtype, &dtype));
+ auto type_status = ConvertDType(tf_dtype, &dtype);
+ if (type_status != tensorflow::Status::OK()) {
+ LOG(WARNING) << "Type conversion failed for " << node_name;
+ return type_status;
+ }
VLOG(2) << "Accessing output index of: " << output_idx
<< ", at node: " << node_name
@@ -2515,8 +2535,12 @@ tensorflow::Status ConvertSubGraphToTensorRTNodeDef(
// TODO(jie): TRT 3.x only support 4 dimensional input tensor.
// update the code once TRT 4.0 comes out.
- if (op_info.shape().dim_size() != 4)
- return tensorflow::errors::Unimplemented("require 4 dimensional input");
+ if (op_info.shape().dim_size() != 4) {
+ string err_str = "Require 4 dimensional input.";
+      StrAppend(&err_str, " Got ", op_info.shape().dim_size(),
+                " dimensions, at node ", shape_inference_node_name);
+ return tensorflow::errors::Unimplemented(err_str);
+ }
for (int i = 1; i < op_info.shape().dim_size(); i++) {
VLOG(2) << "dimension: " << i
@@ -2577,8 +2601,8 @@ tensorflow::Status ConvertSubGraphToTensorRTNodeDef(
output_names.push_back(tensor_name);
auto tensor_or_weights = converter.get_tensor(tensor_name);
if (!tensor_or_weights.is_tensor()) {
- return tensorflow::errors::InvalidArgument(
- "Output node is weights not tensor");
+ return tensorflow::errors::InvalidArgument("Output node '" + tensor_name +
+ "' is weights not tensor");
}
nvinfer1::ITensor* tensor = tensor_or_weights.tensor();
if (!tensor) {
@@ -2622,7 +2646,8 @@ tensorflow::Status ConvertSubGraphToTensorRTNodeDef(
}
TF_RETURN_IF_ERROR(weight_rmgr->Delete<tensorflow::tensorrt::TRTWeightStore>(
engine_name, engine_name));
- LOG(INFO) << "finished engine " << engine_name;
+ LOG(INFO) << "finished engine " << engine_name << " containing "
+ << s.subgraph_node_ids.size() << " nodes";
// Build the TRT op
tensorflow::NodeDefBuilder op_builder(engine_name, "TRTEngineOp");
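The engine NodeDef is then assembled through NodeDefBuilder. A sketch of that final step; the attribute name is illustrative rather than a guaranteed match for TRTEngineOp's registered schema:

#include <string>
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"

// Wrap a serialized TensorRT engine in a NodeDef for graph insertion.
tensorflow::Status BuildEngineNode(const std::string& engine_name,
                                   const std::string& engine_data,
                                   tensorflow::NodeDef* trt_node) {
  return tensorflow::NodeDefBuilder(engine_name, "TRTEngineOp")
      .Attr("serialized_engine", engine_data)
      .Finalize(trt_node);
}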