author    TensorFlower Gardener <gardener@tensorflow.org>  2018-08-17 16:41:39 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-08-17 16:41:39 -0700
commit    32e6e608f6329fda19f9eec560a634725303a749 (patch)
tree      46b81cb20b8410d5960dc2b1afcd1b009681aa7f /tensorflow/contrib/tensorrt
parent    f1ad54b58b7ce2e08b5f4e38a1631dc667e3e7af (diff)
parent    be645259c251e9b81e2d36efdd7b403bedaffe03 (diff)
Merge pull request #21075 from jjsjann123:trt4_input_patch
PiperOrigin-RevId: 209226085
Diffstat (limited to 'tensorflow/contrib/tensorrt')
-rw-r--r--  tensorflow/contrib/tensorrt/BUILD |  18
-rw-r--r--  tensorflow/contrib/tensorrt/convert/convert_nodes.cc | 261
-rw-r--r--  tensorflow/contrib/tensorrt/segment/segment.cc |  78
-rw-r--r--  tensorflow/contrib/tensorrt/test/base_test.py | 154
-rw-r--r--  tensorflow/contrib/tensorrt/test/batch_matmul_test.py |  42
-rw-r--r--  tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py |  58
-rw-r--r--  tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py |  47
-rw-r--r--  tensorflow/contrib/tensorrt/test/concatenation_test.py |  13
-rw-r--r--  tensorflow/contrib/tensorrt/test/const_broadcast_test.py |  21
-rw-r--r--  tensorflow/contrib/tensorrt/test/manual_test.py | 114
-rw-r--r--  tensorflow/contrib/tensorrt/test/memory_alignment_test.py |  21
-rw-r--r--  tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py |  13
-rw-r--r--  tensorflow/contrib/tensorrt/test/neighboring_engine_test.py |  19
-rw-r--r--  tensorflow/contrib/tensorrt/test/rank_two_test.py |  89
-rw-r--r--  tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py | 304
-rw-r--r--  tensorflow/contrib/tensorrt/test/unary_test.py |  19
-rw-r--r--  tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py |  21
-rw-r--r--  tensorflow/contrib/tensorrt/test/vgg_block_test.py |  21
18 files changed, 903 insertions, 410 deletions
diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD
index fc0d22d112..26236a0435 100644
--- a/tensorflow/contrib/tensorrt/BUILD
+++ b/tensorflow/contrib/tensorrt/BUILD
@@ -387,17 +387,19 @@ cuda_py_tests(
name = "tf_trt_integration_test",
srcs = [
"test/base_test.py",
- # "test/batch_matmul_test.py",
- # "test/biasadd_matmul_test.py",
- # "test/binary_tensor_weight_broadcast_test.py", # Blocked by trt4 installation
- # "test/concatenation_test.py", # Blocked by trt4 installation
+ "test/batch_matmul_test.py",
+ "test/biasadd_matmul_test.py",
+ "test/binary_tensor_weight_broadcast_test.py",
+ "test/concatenation_test.py",
"test/const_broadcast_test.py",
+ "test/manual_test.py",
+ "test/memory_alignment_test.py",
"test/multi_connection_neighbor_engine_test.py",
"test/neighboring_engine_test.py",
- # "test/unary_test.py", # Blocked by trt4 installation
- # "test/vgg_block_nchw_test.py",
- # "test/vgg_block_test.py",
- "test/memory_alignment_test.py",
+ "test/rank_two_test.py",
+ "test/unary_test.py",
+ "test/vgg_block_nchw_test.py",
+ "test/vgg_block_test.py",
],
additional_deps = [
":tf_trt_integration_test_base",
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
index 35fa590254..863074e773 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
@@ -155,12 +155,22 @@ tensorflow::Status ValidateInputProperties(const PartialTensorShape& shape,
for (int d = 1; d < shape.dims(); ++d) {
if (shape.dim_size(d) < 0) {
return tensorflow::errors::InvalidArgument(
- "Input tensor has a unknown non-batch dimemension at dim ", d);
+ "Input tensor with shape ", shape.DebugString(),
+ " has an unknown non-batch dimemension at dim ", d);
}
}
return Status::OK();
}
+string DebugString(const nvinfer1::Dims& dims) {
+ string out = StrCat("nvinfer1::Dims(nbDims=", dims.nbDims, ", d=");
+ for (int i = 0; i < nvinfer1::Dims::MAX_DIMS; ++i) {
+ StrAppend(&out, dims.d[i], ",");
+ }
+ StrAppend(&out, ")");
+ return out;
+}
+
// Return whether or not the broadcast is feasible;
bool TensorRTGetBroadcastShape(const nvinfer1::Dims& operand_l,
const bool operand_l_is_tensor,
@@ -353,6 +363,13 @@ class TRT_ShapedWeights {
// Default converter
operator nvinfer1::Weights() const { return GetWeightsForTRT(); }
+ string DebugString() const {
+ return StrCat(
+ "TRT_ShapedWeights(shape=", convert::DebugString(shape_), ", type=",
+ type_, ", values=", reinterpret_cast<uintptr_t>(values_),
+ ", empty_weight_flag=", empty_weight_flag_, ")");
+ }
+
// TODO(aaroey): make these private.
nvinfer1::Dims shape_;
tensorflow::DataType type_;
@@ -367,11 +384,14 @@ class TRT_TensorOrWeights {
public:
explicit TRT_TensorOrWeights(nvinfer1::ITensor* tensor)
: tensor_(tensor), weights_(DT_FLOAT), variant_(TRT_NODE_TENSOR) {}
+
explicit TRT_TensorOrWeights(const TRT_ShapedWeights& weights)
: tensor_(nullptr), weights_(weights), variant_(TRT_NODE_WEIGHTS) {}
+
// TODO(aaroey): use rvalue reference.
TRT_TensorOrWeights(const TRT_TensorOrWeights& rhs)
: tensor_(rhs.tensor_), weights_(rhs.weights_), variant_(rhs.variant_) {}
+
~TRT_TensorOrWeights() {}
bool is_tensor() const { return variant_ == TRT_NODE_TENSOR; }
@@ -381,18 +401,22 @@ class TRT_TensorOrWeights {
CHECK(is_tensor());
return tensor_;
}
+
const nvinfer1::ITensor* tensor() const {
CHECK(is_tensor());
return tensor_;
}
+
TRT_ShapedWeights& weights() {
CHECK(is_weights());
return weights_;
}
+
const TRT_ShapedWeights& weights() const {
CHECK(is_weights());
return weights_;
}
+
nvinfer1::Dims shape() const {
if (is_tensor()) {
return tensor()->getDimensions();
@@ -401,6 +425,18 @@ class TRT_TensorOrWeights {
}
}
+ string DebugString() const {
+ string output = "TRT_TensorOrWeights(type=";
+ if (is_tensor()) {
+ StrAppend(&output, "tensor @", reinterpret_cast<uintptr_t>(tensor_),
+ ", shape=", convert::DebugString(tensor_->getDimensions()));
+ } else {
+ StrAppend(&output, "weights=", weights_.DebugString());
+ }
+ StrAppend(&output, ")");
+ return output;
+ }
+
private:
nvinfer1::ITensor* tensor_;
TRT_ShapedWeights weights_;
@@ -555,7 +591,7 @@ void ReorderCKtoKC(const TRT_ShapedWeights& iweights,
}
void ReorderRSCKToKCRS(const TRT_ShapedWeights& iweights,
- TRT_ShapedWeights* oweights, int num_groups) {
+ TRT_ShapedWeights* oweights, const int num_groups) {
CHECK_EQ(iweights.type_, oweights->type_);
CHECK_EQ(iweights.size_bytes(), oweights->size_bytes());
// K indexes over output channels, C over input channels, and R and S over the
@@ -563,13 +599,13 @@ void ReorderRSCKToKCRS(const TRT_ShapedWeights& iweights,
const int r = iweights.shape_.d[0];
const int s = iweights.shape_.d[1];
// TRT requires GKcRS, while TF depthwise has RSCK where c=1, C=G
- VLOG(2) << "num_groups: " << num_groups;
const int c = iweights.shape_.d[2] / num_groups;
- VLOG(2) << "c" << iweights.shape_.d[2] << " then " << c;
const int k = iweights.shape_.d[3] * num_groups;
- VLOG(2) << "k" << iweights.shape_.d[3] << " then " << k;
- VLOG(2) << "r" << iweights.shape_.d[0] << " then " << r;
- VLOG(2) << "s" << iweights.shape_.d[1] << " then " << s;
+ VLOG(2) << "num_groups: " << num_groups
+ << "c" << iweights.shape_.d[2] << " then " << c
+ << "k" << iweights.shape_.d[3] << " then " << k
+ << "r" << iweights.shape_.d[0] << " then " << r
+ << "s" << iweights.shape_.d[1] << " then " << s;
oweights->shape_.d[0] = k / num_groups;
oweights->shape_.d[1] = c * num_groups;
oweights->shape_.d[2] = r;
@@ -607,63 +643,15 @@ using OpConverter =
std::vector<TRT_TensorOrWeights>*)>;
class Converter {
- // TODO(aaroey): fix the order of members.
- std::unordered_map<string, TRT_TensorOrWeights> trt_tensors_;
- std::unordered_map<string, OpConverter> op_registry_;
- OpConverter plugin_converter_;
- nvinfer1::INetworkDefinition* trt_network_;
- std::list<std::vector<uint8_t>> temp_bufs_;
- // TODO(aaroey): inline the definition of TRTWeightStore here, and add APIs to
- // operate the stored weights instead of operating it directly.
- TRTWeightStore* weight_store_;
- bool fp16_;
- void register_op_converters();
- tensorflow::Status get_inputs(const tensorflow::NodeDef& node_def,
- std::vector<TRT_TensorOrWeights>* inputs) {
- for (auto const& input_name : node_def.input()) {
- /*************************************************************************
- * TODO(jie): handle case 1) here.
- * Normalizes the inputs and extracts associated metadata:
- * 1) Inputs can contain a colon followed by a suffix of characters.
- * That suffix may be a single number (e.g. inputName:1) or several
- * word characters separated from a number by a colon
- * (e.g. inputName:foo:1). The
- * latter case is used to denote inputs and outputs of functions.
- * 2) Control dependency inputs contain caret at the beginning and we
- * remove this and annotate the edge as a control dependency.
- ************************************************************************/
- // skip control nodes
- if (input_name[0] == '^') continue;
- string name = input_name;
- auto first = name.find_first_of(':');
- // TODO(aaroey): why removing the colon but not the zero? A bug?
- if (first != string::npos && first + 2 == name.size() &&
- name[first + 1] == '0')
- name.erase(first);
-
- VLOG(2) << "retrieve input: " << name;
- if (trt_tensors_.count(name)) {
- inputs->push_back(trt_tensors_.at(name));
- } else {
- // TODO(aaroey): this should not happen, make it a CHECK.
- // TODO(aaroey): use StrCat for pattern like this.
- string msg("Node ");
- StrAppend(&msg, node_def.name(), " should have an input named '", name,
- "' but it is not available");
- LOG(ERROR) << msg;
- return tensorflow::errors::InvalidArgument(msg);
- }
- }
- return tensorflow::Status::OK();
- }
-
public:
explicit Converter(nvinfer1::INetworkDefinition* trt_network,
TRTWeightStore* ws, bool fp16)
: trt_network_(trt_network), weight_store_(ws), fp16_(fp16) {
this->register_op_converters();
}
+
TRTWeightStore* weight_store() { return weight_store_; }
+
TRT_ShapedWeights get_temp_weights(tensorflow::DataType type,
nvinfer1::Dims shape) {
TRT_ShapedWeights weights(type, nullptr, shape);
@@ -672,8 +660,10 @@ class Converter {
weights.SetValues(weight_store_->store_.back().data());
return weights;
}
+
// TODO(aaroey): fix all the namings.
bool isFP16() { return fp16_; }
+
TRT_ShapedWeights get_temp_weights_like(const TRT_ShapedWeights& weights) {
return this->get_temp_weights(weights.type_, weights.shape_);
}
@@ -684,7 +674,6 @@ class Converter {
const string& op = node_def.op();
std::vector<TRT_TensorOrWeights> outputs;
if (PluginFactoryTensorRT::GetInstance()->IsPlugin(op)) {
- // TODO(aaroey): plugin_converter_ is not set, fix it.
TF_RETURN_IF_ERROR(plugin_converter_(*this, node_def, inputs, &outputs));
} else {
if (!op_registry_.count(op)) {
@@ -702,7 +691,8 @@ class Converter {
if (output.is_tensor()) {
output.tensor()->setName(output_name.c_str());
}
- VLOG(2) << "Write out tensor: " << output_name;
+ VLOG(2) << "Adding out tensor " << output_name << ": "
+ << output.DebugString();
if (!trt_tensors_.insert({output_name, output}).second) {
return tensorflow::errors::AlreadyExists(
"Output tensor already exists for op: " + op);
@@ -751,6 +741,63 @@ class Converter {
layer->setReshapeDimensions(reshape_dims);
return layer->getOutput(0);
}
+
+ private:
+ std::unordered_map<string, TRT_TensorOrWeights> trt_tensors_;
+ std::unordered_map<string, OpConverter> op_registry_;
+ OpConverter plugin_converter_;
+ nvinfer1::INetworkDefinition* trt_network_;
+ std::list<std::vector<uint8_t>> temp_bufs_;
+
+ // TODO(aaroey): inline the definition of TRTWeightStore here, and add APIs to
+ // operate the stored weights instead of operating it directly.
+ TRTWeightStore* weight_store_;
+
+ bool fp16_;
+
+ void register_op_converters();
+
+ tensorflow::Status get_inputs(const tensorflow::NodeDef& node_def,
+ std::vector<TRT_TensorOrWeights>* inputs) {
+ for (auto const& input_name : node_def.input()) {
+ /*************************************************************************
+ * TODO(jie): handle case 1) here.
+ * Normalizes the inputs and extracts associated metadata:
+ * 1) Inputs can contain a colon followed by a suffix of characters.
+ * That suffix may be a single number (e.g. inputName:1) or several
+ * word characters separated from a number by a colon
+ * (e.g. inputName:foo:1). The
+ * latter case is used to denote inputs and outputs of functions.
+ * 2) Control dependency inputs contain caret at the beginning and we
+ * remove this and annotate the edge as a control dependency.
+ ************************************************************************/
+ // skip control nodes
+ if (input_name[0] == '^') continue;
+ string name = input_name;
+ auto first = name.find_first_of(':');
+ // TODO(aaroey): why removing the colon but not the zero? A bug?
+ // TODO(aaroey): use TensorId
+ if (first != string::npos && first + 2 == name.size() &&
+ name[first + 1] == '0') {
+ name.erase(first);
+ }
+
+ if (trt_tensors_.count(name)) {
+ TRT_TensorOrWeights& input = trt_tensors_.at(name);
+ inputs->push_back(input);
+ VLOG(2) << "Retrieved input " << name << ": " << input.DebugString();
+ } else {
+ // TODO(aaroey): this should not happen, make it a CHECK.
+ // TODO(aaroey): use StrCat for pattern like this.
+ string msg("Node ");
+ StrAppend(&msg, node_def.name(), " should have an input named '", name,
+ "' but it is not available");
+ LOG(ERROR) << msg;
+ return tensorflow::errors::InvalidArgument(msg);
+ }
+ }
+ return tensorflow::Status::OK();
+ }
};
TRT_ShapedWeights ConvertFP32ToFP16(Converter& ctx,
@@ -1187,17 +1234,11 @@ tensorflow::Status ConvertConv2DHelper(
VLOG(2) << "groups count: " << num_groups;
TRT_ShapedWeights weights_rsck = inputs.at(1).weights();
-
- VLOG(2) << "weight shape: " << weights_rsck.shape_.nbDims;
- for (int i = 0; i < weights_rsck.shape_.nbDims; i++) {
- VLOG(2) << weights_rsck.shape_.d[i];
- }
-
+ VLOG(2) << "weight shape: " << weights_rsck.DebugString();
if (weights_rsck.shape_.nbDims != 4) {
return tensorflow::errors::Internal(
"Conv2D expects kernel of dimension 4, at: " + node_def.name());
}
-
if (ctx.isFP16()) {
weights_rsck = ConvertFP32ToFP16(ctx, inputs.at(1).weights());
}
@@ -1209,16 +1250,13 @@ tensorflow::Status ConvertConv2DHelper(
nvinfer1::DimsHW kernel_size;
kernel_size.h() = weights.shape_.d[2];
kernel_size.w() = weights.shape_.d[3];
- VLOG(2) << "RSCK: ";
- for (int i = 0; i < 4; i++) {
- VLOG(2) << " " << weights.shape_.d[i];
- }
+ VLOG(2) << "RSCK: " << weights.DebugString();
VLOG(2) << "kernel size: " << kernel_size.h() << ", " << kernel_size.w();
// TODO(jie): stride. (NHWC/NCHW)
const auto tf_stride = attrs.get<std::vector<int>>("strides");
VLOG(2) << "h_INDEX" << h_index << ", w_index " << w_index;
- VLOG(2) << "stride!!!: " << tf_stride[0] << tf_stride[1] << tf_stride[2]
+ VLOG(2) << "stride: " << tf_stride[0] << tf_stride[1] << tf_stride[2]
<< tf_stride[3];
const nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]);
@@ -1240,10 +1278,7 @@ tensorflow::Status ConvertConv2DHelper(
// TODO(jie): handle asymmetric padding
VLOG(2) << "Padding!!!: " << padding[0].first << padding[0].second
<< padding[1].first << padding[1].second;
-
- auto dim_before = tensor->getDimensions();
- VLOG(2) << "TENSOR before: " << dim_before.d[0] << ", " << dim_before.d[1]
- << dim_before.d[2] << ", " << dim_before.d[3];
+ VLOG(2) << "TENSOR before: " << DebugString(tensor->getDimensions());
auto pad_layer = ctx.network()->addPadding(
*const_cast<nvinfer1::ITensor*>(tensor),
nvinfer1::DimsHW(padding[0].first, padding[1].first),
@@ -1251,9 +1286,7 @@ tensorflow::Status ConvertConv2DHelper(
TFTRT_RETURN_ERROR_IF_NULLPTR(pad_layer, node_def.name());
padding = {{0, 0}, {0, 0}};
tensor = pad_layer->getOutput(0);
- auto dim_after = tensor->getDimensions();
- VLOG(2) << "TENSOR after: " << dim_after.d[0] << ", " << dim_after.d[1]
- << dim_after.d[2] << ", " << dim_after.d[3];
+ VLOG(2) << "TENSOR after: " << DebugString(tensor->getDimensions());
}
nvinfer1::IConvolutionLayer* layer =
@@ -1266,17 +1299,12 @@ tensorflow::Status ConvertConv2DHelper(
layer->setName(node_def.name().c_str());
layer->setNbGroups(num_groups);
nvinfer1::ITensor* output_tensor = layer->getOutput(0);
-
- auto dim_after = output_tensor->getDimensions();
- VLOG(2) << "TENSOR out: " << dim_after.d[0] << ", " << dim_after.d[1] << ", "
- << dim_after.d[2] << ", " << dim_after.d[3];
-
+ VLOG(2) << "TENSOR out: " << DebugString(output_tensor->getDimensions());
+ VLOG(2) << "data_format: " << data_format;
if (data_format == "NHWC") {
// TODO(jie): transpose it back!
output_tensor = ctx.TransposeTensor(output_tensor, {0, 2, 3, 1});
TFTRT_RETURN_ERROR_IF_NULLPTR(output_tensor, node_def.name());
- } else {
- VLOG(2) << "NCHW !!!!";
}
outputs->push_back(TRT_TensorOrWeights(output_tensor));
return tensorflow::Status::OK();
@@ -1990,22 +2018,22 @@ tensorflow::Status ConvertReduce(Converter& ctx,
return tensorflow::errors::Unimplemented("Tidx supports only DT_INT32");
}
- const auto keep_dims = attrs.get<bool>("keep_dims");
- auto index_list_data =
- static_cast<int*>(const_cast<void*>(index_list.GetValues()));
-
int axes = 0;
if (index_list.count() == 0) {
return tensorflow::errors::InvalidArgument(
"TRT cannot support reduce on all (batch) dimensions, at",
node_def.name());
} else {
+ auto index_list_data =
+ static_cast<int*>(const_cast<void*>(index_list.GetValues()));
for (int i = 0; i < index_list.count(); i++) {
- if (index_list_data[i] == 0) {
+ int axis = index_list_data[i];
+ if (axis < 0) axis += tensor->getDimensions().nbDims + 1;
+ if (axis == 0) {
return tensorflow::errors::InvalidArgument(
"TRT cannot reduce at batch dimension, at", node_def.name());
}
- axes |= (1 << (index_list_data[i] - 1));
+ axes |= (1 << (axis - 1));
}
}
@@ -2025,6 +2053,7 @@ tensorflow::Status ConvertReduce(Converter& ctx,
" , at ", node_def.name());
}
+ const auto keep_dims = attrs.get<bool>("keep_dims");
nvinfer1::ILayer* layer =
ctx.network()->addReduce(*const_cast<nvinfer1::ITensor*>(tensor),
reduce_operation, axes, keep_dims);
@@ -2694,8 +2723,6 @@ tensorflow::Status ConvertGraphDefToEngine(
VLOG(2) << "Converting op name=" << node_name << ", op=" << node_def.op();
if (tensorflow::str_util::StartsWith(node_name, kInputPHName) &&
(node_def.op() == "Placeholder")) {
- nvinfer1::DimsCHW input_dim_pseudo_chw;
- for (int i = 0; i < 8; i++) input_dim_pseudo_chw.d[i] = 0;
int32 slot_number = -1;
if (!tensorflow::strings::safe_strto32(
node_name.c_str() + strlen(kInputPHName), &slot_number)) {
@@ -2713,28 +2740,25 @@ tensorflow::Status ConvertGraphDefToEngine(
LOG(WARNING) << error_message;
return Status(status.code(), error_message);
}
- if (VLOG_IS_ON(1)) {
- string dim_str("dims=");
- StrAppend(&dim_str, "[ ", shape.dim_size(0));
- for (int i = 1; i < shape.dims(); i++) {
- StrAppend(&dim_str, ", ", shape.dim_size(i));
- }
- StrAppend(&dim_str, " ]");
- VLOG(1) << dim_str;
- }
+
+#if NV_TENSORRT_MAJOR == 3
+ nvinfer1::DimsCHW input_dim;
+#elif NV_TENSORRT_MAJOR > 3
+ nvinfer1::Dims input_dim;
+#endif
for (int i = 1; i < shape.dims(); i++) {
- input_dim_pseudo_chw.d[i - 1] = shape.dim_size(i);
+ input_dim.d[i - 1] = shape.dim_size(i);
}
-
- input_dim_pseudo_chw.nbDims = shape.dims() - 1;
- nvinfer1::ITensor* input_tensor = converter.network()->addInput(
- node_name.c_str(), dtype, input_dim_pseudo_chw);
+ input_dim.nbDims = shape.dims() - 1;
+ nvinfer1::ITensor* input_tensor =
+ converter.network()->addInput(node_name.c_str(), dtype, input_dim);
if (!input_tensor) {
return tensorflow::errors::InvalidArgument(
"Failed to create Input layer tensor ", node_name,
" rank=", shape.dims() - 1);
}
- VLOG(1) << "Input tensor name :" << node_name;
+ VLOG(2) << "Adding engine input tensor " << node_name << " with shape "
+ << DebugString(input_dim);
if (!converter.insert_input_tensor(node_name, input_tensor)) {
return tensorflow::errors::AlreadyExists(
"Output tensor already exists for op: " + node_name);
@@ -2937,10 +2961,25 @@ bool InputEdgeValidator::operator()(const tensorflow::Edge* in_edge) const {
<< ": " << status;
return false;
}
- if (shape.dims() < 3 && in_edge->src()->type_string() != "Const") {
+
+
+ if (in_edge->src()->type_string() != "Const" &&
+#if NV_TENSORRT_MAJOR == 3
+ // TRT 3.x only support 4 dimensional input tensor.
+ shape.dims() != 4) {
+#else
+ // Single dimensional input tensor is not supported since the first
+ // dimension is treated as batch dimension.
+ shape.dims() < 2) {
+#endif
VLOG(1) << "--> Need to remove input node " << in_edge->dst()->name()
- << " which has an input at port " << in_edge->dst_input()
- << " with #dim<3 and is not a const: " << shape;
+ << " which has an input at port " << in_edge->dst_input() << " with"
+#if NV_TENSORRT_MAJOR == 3
+ << " #dim!=4"
+#else
+ << " #dim<2"
+#endif
+ << " and is not a const: " << shape;
return false;
}
return true;
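The ConvertReduce hunk above normalizes negative reduce axes against the full TensorFlow rank (the TensorRT rank plus the implicit batch dimension), rejects any reduction over the batch dimension, and packs the remaining axes into a bitmask shifted down by one. A minimal Python sketch of that rule, with an illustrative helper name and shapes that are not part of this patch:

def trt_reduce_axes(tf_axes, trt_rank_without_batch):
  """Map TF reduce axes to a TensorRT axis bitmask (implicit batch dim)."""
  axes = 0
  for axis in tf_axes:
    if axis < 0:
      # Negative axes count from the full TF rank, which includes the
      # batch dimension that TensorRT keeps implicit.
      axis += trt_rank_without_batch + 1
    if axis == 0:
      raise ValueError("TRT cannot reduce over the batch dimension")
    # TensorRT bit i corresponds to TF axis i + 1 (batch dim excluded).
    axes |= 1 << (axis - 1)
  return axes

# Reducing over the last two axes of an NHWC tensor (TRT rank 3).
assert trt_reduce_axes([-1, -2], 3) == 0b110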
diff --git a/tensorflow/contrib/tensorrt/segment/segment.cc b/tensorflow/contrib/tensorrt/segment/segment.cc
index b43f1b190f..c82d4a0183 100644
--- a/tensorflow/contrib/tensorrt/segment/segment.cc
+++ b/tensorflow/contrib/tensorrt/segment/segment.cc
@@ -74,6 +74,7 @@ class SimpleNode {
const std::vector<SimpleEdge*>& in_edges() const { return in_edges_; }
const std::vector<SimpleEdge*>& out_edges() const { return out_edges_; }
+
std::vector<SimpleNode*> in_nodes() const {
std::vector<SimpleNode*> res;
res.reserve(in_edges_.size());
@@ -82,6 +83,16 @@ class SimpleNode {
}
return res;
}
+
+ std::vector<SimpleNode*> out_nodes() const {
+ std::vector<SimpleNode*> res;
+ res.reserve(out_edges_.size());
+ for (const auto e : out_edges_) {
+ if (e) res.push_back(e->dst());
+ }
+ return res;
+ }
+
const string& name() const { return node_->name(); }
const tensorflow::Node* tf_node() const { return node_; }
int id() const { return id_; }
@@ -215,45 +226,53 @@ SimpleGraph::~SimpleGraph() {
namespace {
-bool CheckCycles(const std::unique_ptr<SimpleGraph>& g, const SimpleNode* src,
- const std::vector<SimpleNode*>& start) {
- // Copied from TF ReverseDFS, which only works for tensorflow::Graph.
+// Copied from TF ReverseDFS, which only works for tensorflow::Graph.
+void StableDFS(const SimpleGraph& g, bool reverse,
+ const std::vector<const SimpleNode*>& start,
+ const std::function<bool(const SimpleNode*)>& enter,
+ const std::function<bool(const SimpleNode*)>& leave) {
+ // Stack of work to do.
struct Work {
- SimpleNode* node;
+ const SimpleNode* node;
bool leave; // Are we entering or leaving n?
};
-
std::vector<Work> stack(start.size());
for (int i = 0; i < start.size(); ++i) {
stack[i] = Work{start[i], false};
}
- std::vector<bool> visited(g->num_node_ids(), false);
+ auto get_nodes = reverse ? [](const SimpleNode* n) { return n->in_nodes(); }
+ : [](const SimpleNode* n) { return n->out_nodes(); };
+ std::vector<bool> visited(g.num_node_ids(), false);
while (!stack.empty()) {
Work w = stack.back();
stack.pop_back();
auto n = w.node;
if (w.leave) {
- if (n == src) {
- return true;
- }
+ if (leave && !leave(n)) return;
continue;
}
if (visited[n->id()]) continue;
visited[n->id()] = true;
- // Arrange to call leave(n) when all done with descendants.
- stack.push_back(Work{n, true});
+ if (enter && !enter(n)) return;
- auto nodes = n->in_nodes();
- for (const auto node : nodes) {
+ // Arrange to call leave(n) when all done with descendants.
+ if (leave) stack.push_back(Work{n, true});
+
+ auto nodes = get_nodes(n);
+ std::vector<const SimpleNode*> nodes_sorted(nodes.begin(), nodes.end());
+ std::sort(nodes_sorted.begin(), nodes_sorted.end(),
+ [](const SimpleNode* lhs, const SimpleNode* rhs) {
+ return lhs->name() < rhs->name();
+ });
+ for (const SimpleNode* node : nodes_sorted) {
if (!visited[node->id()]) {
stack.push_back(Work{node, false});
}
}
}
- return false;
}
bool CanContractEdge(const SimpleEdge* edge,
@@ -289,14 +308,21 @@ bool CanContractEdge(const SimpleEdge* edge,
// To achieve this goal, the correct way seems to be:
// 1. remove any direct edge from src->dst;
// 2. detect if src can reach dst, if so they cannot be merged.
- std::vector<SimpleNode*> dfs_start_nodes;
- for (SimpleNode* node : dst->in_nodes()) {
+ std::vector<const SimpleNode*> dfs_start_nodes;
+ for (const SimpleNode* node : dst->in_nodes()) {
if (node != src) {
dfs_start_nodes.push_back(node);
}
}
-
- const bool has_cycle = CheckCycles(graph, src, dfs_start_nodes);
+ bool has_cycle = false;
+ StableDFS(*graph, /*reverse=*/true, dfs_start_nodes, /*enter=*/nullptr,
+ [&has_cycle, src](const SimpleNode* n) {
+ if (n == src) {
+ has_cycle = true;
+ return false;
+ }
+ return true;
+ });
return !has_cycle;
}
} // namespace
@@ -403,15 +429,13 @@ tensorflow::Status SegmentGraph(
// In the future if we have a measure of how beneficial it is to include a
// given node in a TRT subgraph then we can revisit this algorithm to take
// advantage of that information.
- std::vector<tensorflow::Node*> tforder;
- tensorflow::GetPostOrder(*tf_graph, &tforder);
- // use postorder implementation from tensorflow and construct mirror in
- // internal format
- std::vector<SimpleNode*> order;
- order.reserve(tforder.size());
- for (const auto tfnode : tforder) {
- order.push_back(graph->FindNodeId(tfnode->id()));
- }
+ std::vector<const SimpleNode*> order;
+ order.reserve(graph->num_node_ids());
+ StableDFS(*graph, /*reverse=*/false, {graph->source_node()},
+ /*enter=*/nullptr, [&order](const SimpleNode* n) {
+ order.push_back(n);
+ return true;
+ });
for (const SimpleNode* node : order) {
// All output nodes of 'node' have been visited...
VLOG(3) << "Trying node " << node->name() << " id=" << node->id();
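The new StableDFS helper above replaces both CheckCycles and the TF post-order walk with a single traversal that expands neighbors in name-sorted order, so cycle detection and node ordering are deterministic across runs. A rough Python sketch of that traversal, assuming hashable node objects with a name attribute rather than the segmenter's actual types:

def stable_dfs(start_nodes, get_neighbors, enter=None, leave=None):
  """Iterative DFS; neighbors are expanded in a fixed, name-sorted order."""
  stack = [(n, False) for n in start_nodes]  # (node, is_leave_visit)
  visited = set()
  while stack:
    node, is_leave_visit = stack.pop()
    if is_leave_visit:
      if leave and not leave(node):
        return  # a False return stops the whole traversal
      continue
    if node in visited:
      continue
    visited.add(node)
    if enter and not enter(node):
      return
    if leave:
      stack.append((node, True))  # come back after all descendants
    for nbr in sorted(get_neighbors(node), key=lambda n: n.name):
      if nbr not in visited:
        stack.append((nbr, False))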
diff --git a/tensorflow/contrib/tensorrt/test/base_test.py b/tensorflow/contrib/tensorrt/test/base_test.py
index 8ea5a63735..e9ac833d55 100644
--- a/tensorflow/contrib/tensorrt/test/base_test.py
+++ b/tensorflow/contrib/tensorrt/test/base_test.py
@@ -40,6 +40,7 @@ class SimpleSingleEngineTest(trt_test.TfTrtIntegrationTestBase):
dtype = dtypes.float32
input_name = "input"
input_dims = [100, 24, 24, 2]
+ output_name = "output"
g = ops.Graph()
with g.as_default():
inp = array_ops.placeholder(
@@ -62,19 +63,21 @@ class SimpleSingleEngineTest(trt_test.TfTrtIntegrationTestBase):
identity = array_ops.identity(relu, "identity")
pool = nn_ops.max_pool(
identity, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool")
- array_ops.squeeze(pool, name=self.output_name)
+ array_ops.squeeze(pool, name=output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- # TODO(aaroey): LayoutOptimizer adds additional nodes to the graph which
- # breaks the connection check, fix it.
- # - my_trt_op_0 should have ["weights", "conv", "bias", "bias_add",
- # "relu", "identity", "max_pool"]
- expected_engines=["my_trt_op_0"],
- expected_output_dims=(100, 6, 6, 6),
- allclose_atol=1.e-03,
- allclose_rtol=1.e-03)
+ output_names=[output_name],
+ expected_output_dims=[(100, 6, 6, 6)])
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ # TODO(aaroey): LayoutOptimizer adds additional nodes to the graph which
+ # breaks the connection check, fix it.
+ # - my_trt_op_0 should have ["weights", "conv", "bias", "bias_add",
+ # "relu", "identity", "max_pool"]
+ return ["my_trt_op_0"]
class SimpleMultiEnginesTest(trt_test.TfTrtIntegrationTestBase):
@@ -85,6 +88,7 @@ class SimpleMultiEnginesTest(trt_test.TfTrtIntegrationTestBase):
dtype = dtypes.float32
input_name = "input"
input_dims = [100, 24, 24, 2]
+ output_name = "output"
g = ops.Graph()
with g.as_default():
inp = array_ops.placeholder(
@@ -115,20 +119,22 @@ class SimpleMultiEnginesTest(trt_test.TfTrtIntegrationTestBase):
q = math_ops.mul(q, edge, name="mul1")
s = math_ops.add(p, q, name="add1")
s = math_ops.sub(s, r, name="sub1")
- array_ops.squeeze(s, name=self.output_name)
+ array_ops.squeeze(s, name=output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- # TODO(aaroey): LayoutOptimizer adds additional nodes to the graph which
- # breaks the connection check, fix it.
- # - my_trt_op_0 should have ["mul", "sub", "div1", "mul1", "add1",
- # "add", "sub1"];
- # - my_trt_op_1 should have ["weights","conv", "div"]
- expected_engines=["my_trt_op_0", "my_trt_op_1"],
- expected_output_dims=(100, 12, 12, 6),
- allclose_atol=1.e-03,
- allclose_rtol=1.e-03)
+ output_names=[output_name],
+ expected_output_dims=[(100, 12, 12, 6)])
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ # TODO(aaroey): LayoutOptimizer adds additional nodes to the graph which
+ # breaks the connection check, fix it.
+ # - my_trt_op_0 should have ["mul", "sub", "div1", "mul1", "add1",
+ # "add", "sub1"];
+ # - my_trt_op_1 should have ["weights","conv", "div"]
+ return ["my_trt_op_0", "my_trt_op_1"]
class PartiallyConvertedTestA(trt_test.TfTrtIntegrationTestBase):
@@ -143,6 +149,7 @@ class PartiallyConvertedTestA(trt_test.TfTrtIntegrationTestBase):
"""Create a graph containing two segment."""
input_name = "input"
input_dims = [2, 32, 32, 3]
+ output_name = "output"
g = ops.Graph()
with g.as_default():
inp = array_ops.placeholder(
@@ -161,18 +168,20 @@ class PartiallyConvertedTestA(trt_test.TfTrtIntegrationTestBase):
c = constant_op.constant(1.0, name="c3")
n = math_ops.add(n, c, name="add3")
n = math_ops.mul(n, n, name="mul3")
- array_ops.squeeze(n, name=self.output_name)
+ array_ops.squeeze(n, name=output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- expected_engines={
- # Only the first engine is built.
- "my_trt_op_0": ["c0", "c1", "add0", "add1", "mul0", "mul1"]
- },
- expected_output_dims=tuple(input_dims),
- allclose_atol=1.e-06,
- allclose_rtol=1.e-06)
+ output_names=[output_name],
+ expected_output_dims=[tuple(input_dims)])
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ return {
+ # Only the first engine is built.
+ "my_trt_op_0": ["c0", "c1", "add0", "add1", "mul0", "mul1"]
+ }
class PartiallyConvertedTestB(PartiallyConvertedTestA):
@@ -184,13 +193,12 @@ class PartiallyConvertedTestB(PartiallyConvertedTestA):
trt_convert.clear_test_values("")
trt_convert.add_test_value("my_trt_op_0:CreateTRTNode", "fail")
- def GetParams(self):
- """Create a graph containing two segment."""
- return super(PartiallyConvertedTestB, self).GetParams()._replace(
- expected_engines={
- # Only the second engine is built.
- "my_trt_op_1": ["c2", "c3", "add2", "add3", "mul2", "mul3"]
- })
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ return {
+ # Only the second engine is built.
+ "my_trt_op_1": ["c2", "c3", "add2", "add3", "mul2", "mul3"]
+ }
class ConstInputTest(trt_test.TfTrtIntegrationTestBase):
@@ -199,6 +207,7 @@ class ConstInputTest(trt_test.TfTrtIntegrationTestBase):
"""Create a graph containing multiple segment."""
input_name = "input"
input_dims = [2, 32, 32, 3]
+ output_name = "output"
g = ops.Graph()
with g.as_default():
inp = array_ops.placeholder(
@@ -221,18 +230,20 @@ class ConstInputTest(trt_test.TfTrtIntegrationTestBase):
n = math_ops.add(n, c, name="add2")
n = math_ops.mul(n, n, name="mul1")
n = math_ops.add(n, n, name="add3")
- array_ops.squeeze(n, name=self.output_name)
+ array_ops.squeeze(n, name=output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- expected_engines={
- "my_trt_op_0": ["add", "add1", "mul"],
- "my_trt_op_1": ["add2", "add3", "mul1"]
- },
- expected_output_dims=tuple(input_dims),
- allclose_atol=1.e-06,
- allclose_rtol=1.e-06)
+ output_names=[output_name],
+ expected_output_dims=[tuple(input_dims)])
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ return {
+ "my_trt_op_0": ["add", "add1", "mul"],
+ "my_trt_op_1": ["add2", "add3", "mul1"]
+ }
class ConstDataInputSingleEngineTest(trt_test.TfTrtIntegrationTestBase):
@@ -241,6 +252,7 @@ class ConstDataInputSingleEngineTest(trt_test.TfTrtIntegrationTestBase):
"""Create a graph containing single segment."""
input_name = "input"
input_dims = [2, 32, 32, 3]
+ output_name = "output"
g = ops.Graph()
with g.as_default():
inp = array_ops.placeholder(
@@ -251,15 +263,17 @@ class ConstDataInputSingleEngineTest(trt_test.TfTrtIntegrationTestBase):
n = math_ops.add(n, c, name="add")
n = math_ops.mul(n, n, name="mul")
n = math_ops.add(n, n, name="add1")
- array_ops.squeeze(n, name=self.output_name)
+ array_ops.squeeze(n, name=output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- expected_engines={"my_trt_op_0": ["c", "add", "add1", "mul"]},
- expected_output_dims=tuple(input_dims),
- allclose_atol=1.e-06,
- allclose_rtol=1.e-06)
+ output_names=[output_name],
+ expected_output_dims=[tuple(input_dims)])
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ return {"my_trt_op_0": ["c", "add", "add1", "mul"]}
class ConstDataInputMultipleEnginesTest(trt_test.TfTrtIntegrationTestBase):
@@ -268,6 +282,7 @@ class ConstDataInputMultipleEnginesTest(trt_test.TfTrtIntegrationTestBase):
"""Create a graph containing multiple segment."""
input_name = "input"
input_dims = [2, 32, 32, 3]
+ output_name = "output"
g = ops.Graph()
with g.as_default():
inp = array_ops.placeholder(
@@ -282,22 +297,24 @@ class ConstDataInputMultipleEnginesTest(trt_test.TfTrtIntegrationTestBase):
n = math_ops.add(n, c, name="add2")
n = math_ops.mul(n, n, name="mul1")
n = math_ops.add(n, n, name="add3")
- array_ops.squeeze(n, name=self.output_name)
+ array_ops.squeeze(n, name=output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- expected_engines={
- "my_trt_op_0": ["add2", "add3", "mul1"],
- # Why segment ["add", "add1", "mul"] was assigned segment id 1
- # instead of 0: the parent node of this segment is actually const
- # node 'c', but it's removed later since it's const output of the
- # segment which is not allowed.
- "my_trt_op_1": ["add", "add1", "mul"]
- },
- expected_output_dims=tuple(input_dims),
- allclose_atol=1.e-06,
- allclose_rtol=1.e-06)
+ output_names=[output_name],
+ expected_output_dims=[tuple(input_dims)])
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ return {
+ "my_trt_op_0": ["add2", "add3", "mul1"],
+ # Why segment ["add", "add1", "mul"] was assigned segment id 1
+ # instead of 0: the parent node of this segment is actually const
+ # node 'c', but it's removed later since it's const output of the
+ # segment which is not allowed.
+ "my_trt_op_1": ["add", "add1", "mul"]
+ }
class ControlDependencyTest(trt_test.TfTrtIntegrationTestBase):
@@ -306,6 +323,7 @@ class ControlDependencyTest(trt_test.TfTrtIntegrationTestBase):
"""Create a graph containing multiple segment."""
input_name = "input"
input_dims = [2, 32, 32, 3]
+ output_name = "output"
g = ops.Graph()
with g.as_default():
inp = array_ops.placeholder(
@@ -328,18 +346,20 @@ class ControlDependencyTest(trt_test.TfTrtIntegrationTestBase):
mul1 = math_ops.mul(add2, add2, name="mul1")
with g.control_dependencies([d1, d2, add, add1]):
add3 = math_ops.add(mul1, mul1, name="add3")
- array_ops.squeeze(add3, name=self.output_name)
+ array_ops.squeeze(add3, name=output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- expected_engines={
- "my_trt_op_0": ["c1", "add", "add1", "mul"],
- "my_trt_op_1": ["c2", "add2", "add3", "mul1"]
- },
- expected_output_dims=tuple(input_dims),
- allclose_atol=1.e-06,
- allclose_rtol=1.e-06)
+ output_names=[output_name],
+ expected_output_dims=[tuple(input_dims)])
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ return {
+ "my_trt_op_0": ["c1", "add", "add1", "mul"],
+ "my_trt_op_1": ["c2", "add2", "add3", "mul1"]
+ }
if __name__ == "__main__":
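The base_test.py changes show the reworked test API: GetParams now carries output_names and a list of expected_output_dims, while the expected engines (and, in other tests, tolerances) move into overridable methods. A minimal hypothetical test written against that pattern; the graph, op names, and engine mapping below are made up for illustration:

from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test


class TinySingleEngineTest(trt_test.TfTrtIntegrationTestBase):

  def GetParams(self):
    """Create a graph with one small TRT-compatible segment."""
    input_name = "input"
    input_dims = [2, 8, 8, 3]
    output_name = "output"
    g = ops.Graph()
    with g.as_default():
      x = array_ops.placeholder(
          dtype=dtypes.float32, shape=input_dims, name=input_name)
      n = math_ops.add(x, x, name="add")
      n = math_ops.mul(n, n, name="mul")
      array_ops.squeeze(n, name=output_name)
    return trt_test.TfTrtIntegrationTestParams(
        gdef=g.as_graph_def(),
        input_names=[input_name],
        input_dims=[input_dims],
        output_names=[output_name],
        expected_output_dims=[tuple(input_dims)])

  def ExpectedEnginesToBuild(self, run_params):
    """Return the expected engines to build."""
    return {"my_trt_op_0": ["add", "mul"]}


if __name__ == "__main__":
  test.main()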
diff --git a/tensorflow/contrib/tensorrt/test/batch_matmul_test.py b/tensorflow/contrib/tensorrt/test/batch_matmul_test.py
index 2e1107e303..2f153c6f2f 100644
--- a/tensorflow/contrib/tensorrt/test/batch_matmul_test.py
+++ b/tensorflow/contrib/tensorrt/test/batch_matmul_test.py
@@ -37,6 +37,7 @@ class BatchMatMulTest(trt_test.TfTrtIntegrationTestBase):
dtype = dtypes.float32
input_name = "input"
input_dims = [12, 5, 8, 12]
+ output_name = "output"
w1_name = "matmul_w1"
w1_dims = [12, 5, 12, 7]
w2_name = "matmul_w2"
@@ -61,15 +62,46 @@ class BatchMatMulTest(trt_test.TfTrtIntegrationTestBase):
x3 = x3 + f
x3 = gen_array_ops.reshape(x3, [12, 5, 8, 7])
out = x1 + x2 + x3
- array_ops.squeeze(out, name=self.output_name)
+ array_ops.squeeze(out, name=output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
input_names=[input_name, w1_name, w2_name],
input_dims=[input_dims, w1_dims, w2_dims],
- expected_engines=["my_trt_op_0"],
- expected_output_dims=(12, 5, 8, 7),
- allclose_atol=1.e-03,
- allclose_rtol=1.e-03)
+ output_names=[output_name],
+ expected_output_dims=[(12, 5, 8, 7)])
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ if (run_params.dynamic_engine and
+ not trt_test.IsQuantizationMode(run_params.precision_mode)):
+ return ["my_trt_op_0", "my_trt_op_1"]
+ return ["my_trt_op_1"]
+
+ def ExpectedEnginesToRun(self, run_params):
+ """Return the expected engines to run."""
+ return ["my_trt_op_1"]
+
+ def ShouldRunTest(self, run_params):
+ """Whether to run the test."""
+ # TODO(aaroey): Trt library will fail like:
+ #
+ # ../builder/cudnnBuilder2.cpp:685:
+ # virtual std::vector<nvinfer1::query::Ports<
+ # nvinfer1::query::TensorRequirements>>
+ # nvinfer1::builder::Node::getSupportedFormats(
+ # const nvinfer1::query::Ports<nvinfer1::query::AbstractTensor>&,
+ # const nvinfer1::cudnn::HardwareContext&,
+ # nvinfer1::builder::Format::Type,
+ # const nvinfer1::builder::FormatTypeHack&) const:
+ # Assertion `sf' failed.
+ #
+ # To reproduce, run:
+ # bazel test -c opt --copt=-mavx \
+ # --test_arg=BatchMatMulTest.testTfTrt_ToolConversion_INT8_DynamicEngine \
+ # tensorflow/contrib/tensorrt:batch_matmul_test
+ #
+ # Investigate and fix it.
+ return not trt_test.IsQuantizationMode(run_params.precision_mode)
if __name__ == "__main__":
diff --git a/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py b/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py
index 8be32f59b4..62f4e525f7 100644
--- a/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py
+++ b/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py
@@ -38,6 +38,7 @@ class BiasaddMatMulTest(trt_test.TfTrtIntegrationTestBase):
dtype = dtypes.float32
input_name = "input"
input_dims = [48, 12]
+ output_name = "output"
g = ops.Graph()
with g.as_default():
x = array_ops.placeholder(dtype=dtype, shape=input_dims, name=input_name)
@@ -97,18 +98,59 @@ class BiasaddMatMulTest(trt_test.TfTrtIntegrationTestBase):
out = array_ops.concat(
[x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11], axis=-1)
- out = array_ops.squeeze(out, name=self.output_name)
+ out = array_ops.squeeze(out, name=output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- expected_engines=[
- "my_trt_op_0", "my_trt_op_1", "my_trt_op_2", "my_trt_op_3",
- "my_trt_op_4", "my_trt_op_5", "my_trt_op_6"
- ],
- expected_output_dims=(48, 89),
- allclose_atol=1.e-03,
- allclose_rtol=1.e-03)
+ output_names=[output_name],
+ expected_output_dims=[(48, 89)])
+
+ def GetConversionParams(self, run_params):
+ """Return a ConversionParams for test."""
+ return super(BiasaddMatMulTest,
+ self).GetConversionParams(run_params)._replace(
+ max_batch_size=48, maximum_cached_engines=2)
+
+ def _ValidEngines(self):
+ """Engines expected to build and run."""
+ return [
+ "my_trt_op_0", "my_trt_op_1", "my_trt_op_2", "my_trt_op_6",
+ "my_trt_op_7", "my_trt_op_8", "my_trt_op_9"
+ ]
+
+ def _InvalidEngines(self):
+ """Engines that will cause conversion error at building time."""
+ return ["my_trt_op_3", "my_trt_op_4", "my_trt_op_5"]
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ # In dynamic engine mode the engines are built in execution time, not in
+ # conversion time, so build errors occurs later. Here three of the engines
+ # will be failed to built but the corresponding engine op are still created.
+ # TODO(aaroey, jjsjann123): fix this.
+ if (run_params.dynamic_engine and
+ not trt_test.IsQuantizationMode(run_params.precision_mode)):
+ return self._ValidEngines() + self._InvalidEngines()
+ return self._ValidEngines()
+
+ def ExpectedEnginesToRun(self, run_params):
+ """Return the expected engines to run."""
+ return self._ValidEngines()
+
+ def ShouldRunTest(self, run_params):
+ """Whether to run the test."""
+ # TODO(aaroey): Trt 4.0 forbids conversion for tensors with rank <3 in int8
+ # mode, which is a bug. Re-enable this when trt library is fixed.
+ return not trt_test.IsQuantizationMode(run_params.precision_mode)
+
+ def ExpectedAbsoluteTolerance(self, run_params):
+ """The absolute tolerance to compare floating point results."""
+ return 1.e-05 if run_params.precision_mode == "FP32" else 1.e-03
+
+ def ExpectedRelativeTolerance(self, run_params):
+ """The relative tolerance to compare floating point results."""
+ return 1.e-05 if run_params.precision_mode == "FP32" else 1.e-03
if __name__ == "__main__":
diff --git a/tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py b/tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py
index 9316b14da0..f126ed4238 100644
--- a/tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py
+++ b/tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py
@@ -37,6 +37,7 @@ class BinaryTensorWeightBroadcastTest(trt_test.TfTrtIntegrationTestBase):
dtype = dtypes.float32
input_name = "input"
input_dims = [10, 24, 24, 20]
+ output_name = "output"
g = ops.Graph()
with g.as_default():
x = array_ops.placeholder(dtype=dtype, shape=input_dims, name=input_name)
@@ -104,32 +105,34 @@ class BinaryTensorWeightBroadcastTest(trt_test.TfTrtIntegrationTestBase):
a = constant_op.constant(np.random.randn(24, 20), dtype=dtype)
f = x + a
x = math_ops.sigmoid(f)
- gen_array_ops.reshape(x, [5, -1], name=self.output_name)
+ gen_array_ops.reshape(x, [5, -1], name=output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- expected_engines=[
- "my_trt_op_0",
- "my_trt_op_1",
- "my_trt_op_2",
- "my_trt_op_3",
- "my_trt_op_4",
- "my_trt_op_5",
- "my_trt_op_6",
- "my_trt_op_7",
- "my_trt_op_8",
- "my_trt_op_9",
- "my_trt_op_10",
- "my_trt_op_11",
- "my_trt_op_12",
- "my_trt_op_13",
- "my_trt_op_14",
- "my_trt_op_15",
- ],
- expected_output_dims=(5, 23040),
- allclose_atol=1.e-03,
- allclose_rtol=1.e-03)
+ output_names=[output_name],
+ expected_output_dims=[(5, 23040)])
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ return [
+ "my_trt_op_0",
+ "my_trt_op_1",
+ "my_trt_op_2",
+ "my_trt_op_3",
+ "my_trt_op_4",
+ "my_trt_op_5",
+ "my_trt_op_6",
+ "my_trt_op_7",
+ "my_trt_op_8",
+ "my_trt_op_9",
+ "my_trt_op_10",
+ "my_trt_op_11",
+ "my_trt_op_12",
+ "my_trt_op_13",
+ "my_trt_op_14",
+ "my_trt_op_15",
+ ]
if __name__ == "__main__":
diff --git a/tensorflow/contrib/tensorrt/test/concatenation_test.py b/tensorflow/contrib/tensorrt/test/concatenation_test.py
index 1874b9dd45..465cb02296 100644
--- a/tensorflow/contrib/tensorrt/test/concatenation_test.py
+++ b/tensorflow/contrib/tensorrt/test/concatenation_test.py
@@ -37,6 +37,7 @@ class ConcatenationTest(trt_test.TfTrtIntegrationTestBase):
dtype = dtypes.float32
input_name = "input"
input_dims = [2, 3, 3, 1]
+ output_name = "output"
g = ops.Graph()
with g.as_default():
x = array_ops.placeholder(dtype=dtype, shape=input_dims, name=input_name)
@@ -68,15 +69,17 @@ class ConcatenationTest(trt_test.TfTrtIntegrationTestBase):
concat1 = array_ops.concat([r1, r2, r3, r4, r5, r6], axis=-1)
concat2 = array_ops.concat([r7, r8, r9, r10, r11, r12], axis=3)
x = array_ops.concat([concat1, concat2], axis=-1)
- gen_array_ops.reshape(x, [2, -1], name=self.output_name)
+ gen_array_ops.reshape(x, [2, -1], name=output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- expected_engines=["my_trt_op_0"],
- expected_output_dims=(2, 126),
- allclose_atol=1.e-03,
- allclose_rtol=1.e-03)
+ output_names=[output_name],
+ expected_output_dims=[(2, 126)])
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ return ["my_trt_op_0"]
if __name__ == "__main__":
diff --git a/tensorflow/contrib/tensorrt/test/const_broadcast_test.py b/tensorflow/contrib/tensorrt/test/const_broadcast_test.py
index 8c59000b70..e32f047866 100644
--- a/tensorflow/contrib/tensorrt/test/const_broadcast_test.py
+++ b/tensorflow/contrib/tensorrt/test/const_broadcast_test.py
@@ -36,6 +36,7 @@ class ConstBroadcastTest(trt_test.TfTrtIntegrationTestBase):
dtype = dtypes.float32
input_name = 'input'
input_dims = [5, 12, 12, 2]
+ output_name = 'output'
g = ops.Graph()
with g.as_default():
x = array_ops.placeholder(dtype=dtype, shape=input_dims, name=input_name)
@@ -53,15 +54,25 @@ class ConstBroadcastTest(trt_test.TfTrtIntegrationTestBase):
dtype=dtype,
name='filt3')
y3 = nn.conv2d(z2, filt3, strides=[1, 1, 1, 1], padding='SAME', name='y3')
- nn.relu(y3, name='output')
+ nn.relu(y3, name=output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- expected_engines=['my_trt_op_0'],
- expected_output_dims=(5, 12, 12, 1),
- allclose_atol=1.e-02,
- allclose_rtol=1.e-02)
+ output_names=[output_name],
+ expected_output_dims=[(5, 12, 12, 1)])
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ return ['my_trt_op_0']
+
+ def ExpectedAbsoluteTolerance(self, run_params):
+ """The absolute tolerance to compare floating point results."""
+ return 1.e-04 if run_params.precision_mode == 'FP32' else 1.e-02
+
+ def ExpectedRelativeTolerance(self, run_params):
+ """The relative tolerance to compare floating point results."""
+ return 1.e-04 if run_params.precision_mode == 'FP32' else 1.e-02
if __name__ == '__main__':
diff --git a/tensorflow/contrib/tensorrt/test/manual_test.py b/tensorflow/contrib/tensorrt/test/manual_test.py
new file mode 100644
index 0000000000..1187c759b4
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/test/manual_test.py
@@ -0,0 +1,114 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Basic tests for TF-TensorRT integration."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import ast
+import os
+
+from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test
+from tensorflow.core.framework import graph_pb2
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import test
+
+
+class ManualTest(trt_test.TfTrtIntegrationTestBase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ super(ManualTest, self).__init__(methodName)
+ self._params_map = None
+
+ def _GetEnv(self):
+ """Get an environment variable specifying the manual test parameters.
+
+ The value of the environment variable is the string representation of a dict
+ which should contain the following keys:
+ - 'graph_path': the file path to the serialized frozen graphdef
+ - 'input_names': TfTrtIntegrationTestParams.input_names
+ - 'input_dims': TfTrtIntegrationTestParams.input_dims
+ - 'expected_output_dims': TfTrtIntegrationTestParams.expected_output_dims
+ - 'output_name': the name of op to fetch
+ - 'expected_engines_to_run': ExpectedEnginesToRun() will return this
+ - 'expected_engines_to_build': ExpectedEnginesToBuild() will return this
+ - 'max_batch_size': ConversionParams.max_batch_size
+
+ Returns:
+ The value of the environment variable.
+ """
+ return os.getenv('TRT_MANUAL_TEST_PARAMS', '')
+
+ def _GetParamsMap(self):
+ """Parse the environment variable as a dict and return it."""
+ if self._params_map is None:
+ self._params_map = ast.literal_eval(self._GetEnv())
+ return self._params_map
+
+ def GetParams(self):
+ """Testing conversion of manually provided frozen graph."""
+ params_map = self._GetParamsMap()
+ gdef = graph_pb2.GraphDef()
+ with gfile.Open(params_map['graph_path'], 'rb') as f:
+ gdef.ParseFromString(f.read())
+ return trt_test.TfTrtIntegrationTestParams(
+ gdef=gdef,
+ input_names=params_map['input_names'],
+ input_dims=params_map['input_dims'],
+ output_names=params_map['output_names'],
+ expected_output_dims=params_map['expected_output_dims'])
+
+ def GetConversionParams(self, run_params):
+ """Return a ConversionParams for test."""
+ conversion_params = super(ManualTest, self).GetConversionParams(run_params)
+ params_map = self._GetParamsMap()
+ if 'max_batch_size' in params_map:
+ conversion_params = conversion_params._replace(
+ max_batch_size=params_map['max_batch_size'])
+ return conversion_params
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ return self._GetParamsMap()['expected_engines_to_build']
+
+ def ExpectedEnginesToRun(self, run_params):
+ """Return the expected engines to run."""
+ params_map = self._GetParamsMap()
+ if 'expected_engines_to_run' in params_map:
+ return params_map['expected_engines_to_run']
+ return self.ExpectedEnginesToBuild(run_params)
+
+ def ExpectedAbsoluteTolerance(self, run_params):
+ """The absolute tolerance to compare floating point results."""
+ params_map = self._GetParamsMap()
+ if 'atol' in params_map:
+ return params_map['atol']
+ return 1.e-3
+
+ def ExpectedRelativeTolerance(self, run_params):
+ """The relative tolerance to compare floating point results."""
+ params_map = self._GetParamsMap()
+ if 'rtol' in params_map:
+ return params_map['rtol']
+ return 1.e-3
+
+ def ShouldRunTest(self, run_params):
+ """Whether to run the test."""
+ return len(self._GetEnv())
+
+
+if __name__ == '__main__':
+ test.main()
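ManualTest is driven entirely by the TRT_MANUAL_TEST_PARAMS environment variable, parsed with ast.literal_eval. A hypothetical way to populate it before running the test; the graph path, shapes, and engine name are placeholders, and the keys follow GetParams and GetConversionParams above:

import os

# Hypothetical parameters for ManualTest; 'graph_path' must point at a real
# frozen GraphDef, and the names/shapes must match that graph.
params = {
    'graph_path': '/tmp/frozen_graph.pb',
    'input_names': ['input'],
    'input_dims': [[8, 224, 224, 3]],
    'output_names': ['output'],
    'expected_output_dims': [(8, 1001)],
    'expected_engines_to_build': ['my_trt_op_0'],
    'max_batch_size': 8,  # optional, see GetConversionParams
}
os.environ['TRT_MANUAL_TEST_PARAMS'] = repr(params)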
diff --git a/tensorflow/contrib/tensorrt/test/memory_alignment_test.py b/tensorflow/contrib/tensorrt/test/memory_alignment_test.py
index 66eb6be757..bc7c90081f 100644
--- a/tensorflow/contrib/tensorrt/test/memory_alignment_test.py
+++ b/tensorflow/contrib/tensorrt/test/memory_alignment_test.py
@@ -36,6 +36,7 @@ class MemoryAlignmentTest(trt_test.TfTrtIntegrationTestBase):
dtype = dtypes.float32
input_name = "input"
input_dims = [2, 15, 15, 3]
+ output_name = "output"
g = ops.Graph()
with g.as_default():
inp = array_ops.placeholder(
@@ -57,15 +58,25 @@ class MemoryAlignmentTest(trt_test.TfTrtIntegrationTestBase):
strides=[1, 1, 1, 1],
padding="VALID",
name="conv_2")
- array_ops.squeeze(out, name=self.output_name)
+ array_ops.squeeze(out, name=output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- expected_engines=["my_trt_op_0"],
- expected_output_dims=(2, 15, 15, 10),
- allclose_atol=1.e-02,
- allclose_rtol=1.e-02)
+ output_names=[output_name],
+ expected_output_dims=[(2, 15, 15, 10)])
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ return ["my_trt_op_0"]
+
+ def ExpectedAbsoluteTolerance(self, run_params):
+ """The absolute tolerance to compare floating point results."""
+ return 1.e-06 if run_params.precision_mode == "FP32" else 1.e-02
+
+ def ExpectedRelativeTolerance(self, run_params):
+ """The relative tolerance to compare floating point results."""
+ return 0.1
if __name__ == "__main__":
diff --git a/tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py b/tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py
index fd55b8cd99..11be4feaf7 100644
--- a/tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py
+++ b/tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py
@@ -38,6 +38,7 @@ class MultiConnectionNeighborEngineTest(trt_test.TfTrtIntegrationTestBase):
dtype = dtypes.float32
input_name = "input"
input_dims = [2, 3, 7, 5]
+ output_name = "output"
g = ops.Graph()
with g.as_default():
x = array_ops.placeholder(dtype=dtype, shape=input_dims, name=input_name)
@@ -72,15 +73,17 @@ class MultiConnectionNeighborEngineTest(trt_test.TfTrtIntegrationTestBase):
t = t + q
t = t + d
t = t - edge3
- array_ops.squeeze(t, name=self.output_name)
+ array_ops.squeeze(t, name=output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- expected_engines=["my_trt_op_0", "my_trt_op_1"],
- expected_output_dims=(2, 4, 5, 4),
- allclose_atol=1.e-03,
- allclose_rtol=1.e-03)
+ output_names=[output_name],
+ expected_output_dims=[(2, 4, 5, 4)])
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ return ["my_trt_op_0", "my_trt_op_1"]
if __name__ == "__main__":
diff --git a/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py b/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py
index 51c905a50b..eddeafa38b 100644
--- a/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py
+++ b/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py
@@ -37,6 +37,7 @@ class NeighboringEngineTest(trt_test.TfTrtIntegrationTestBase):
dtype = dtypes.float32
input_name = "input"
input_dims = [2, 3, 7, 5]
+ output_name = "output"
g = ops.Graph()
with g.as_default():
x = array_ops.placeholder(dtype=dtype, shape=input_dims, name=input_name)
@@ -54,18 +55,20 @@ class NeighboringEngineTest(trt_test.TfTrtIntegrationTestBase):
t = math_ops.mul(conv, b, name="mul")
e = self.trt_incompatible_op(conv, name="incompatible")
t = math_ops.sub(t, e, name="sub")
- array_ops.squeeze(t, name=self.output_name)
+ array_ops.squeeze(t, name=output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- expected_engines={
- "my_trt_op_0": ["bias", "mul", "sub"],
- "my_trt_op_1": ["weights", "conv"]
- },
- expected_output_dims=(2, 4, 5, 4),
- allclose_atol=1.e-03,
- allclose_rtol=1.e-03)
+ output_names=[output_name],
+ expected_output_dims=[(2, 4, 5, 4)])
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ return {
+ "my_trt_op_0": ["bias", "mul", "sub"],
+ "my_trt_op_1": ["weights", "conv"]
+ }
if __name__ == "__main__":
diff --git a/tensorflow/contrib/tensorrt/test/rank_two_test.py b/tensorflow/contrib/tensorrt/test/rank_two_test.py
new file mode 100644
index 0000000000..74a4a05925
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/test/rank_two_test.py
@@ -0,0 +1,89 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Model script to test TF-TensorRT integration."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_math_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class RankTwoTest(trt_test.TfTrtIntegrationTestBase):
+
+ def GetParams(self):
+ """Test for rank 2 input in TF-TRT."""
+ input_names = ["input", "input2"]
+ # Two paths: first with rank 2 input, second with rank 4 input.
+ input_dims = [[12, 5], [12, 5, 2, 2]]
+ output_name = "output"
+ g = ops.Graph()
+ with g.as_default():
+ outputs = []
+ for i in range(2):
+ x = array_ops.placeholder(
+ dtype=dtypes.float32, shape=input_dims[i], name=input_names[i])
+ c = constant_op.constant(1.0, name="c%d_1" % i)
+ q = math_ops.add(x, c, name="add%d_1" % i)
+ q = math_ops.abs(q, name="abs%d_1" % i)
+ c = constant_op.constant(2.2, name="c%d_2" % i)
+ q = math_ops.add(q, c, name="add%d_2" % i)
+ q = math_ops.abs(q, name="abs%d_2" % i)
+ c = constant_op.constant(3.0, name="c%d_3" % i)
+ q = math_ops.add(q, c, name="add%d_3" % i)
+ if i == 0:
+ for j in range(2):
+ q = array_ops.expand_dims(q, -1, name="expand%d_%d" % (i, j))
+ q = gen_math_ops.reciprocal(q, name="reciprocal%d" % i)
+ outputs.append(q)
+ # Combine both paths
+ q = math_ops.add(outputs[0], outputs[1], name="add")
+ array_ops.squeeze(q, name=output_name)
+ return trt_test.TfTrtIntegrationTestParams(
+ gdef=g.as_graph_def(),
+ input_names=input_names,
+ input_dims=input_dims,
+ output_names=[output_name],
+ expected_output_dims=[tuple(input_dims[1])])
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ return {
+ "my_trt_op_0": [
+ "add0_1", "add0_2", "add0_3", "c0_1", "c0_2", "c0_3", "abs0_1",
+ "abs0_2"
+ ],
+ "my_trt_op_1": [
+ "add", "add1_1", "add1_2", "add1_3", "c1_1", "c1_2", "c1_3",
+ "abs1_1", "abs1_2", "reciprocal0", "reciprocal1"
+ ],
+ }
+
+ def ShouldRunTest(self, run_params):
+ """Whether to run the test."""
+ # TODO(aaroey): TRT 4.0 forbids conversion of tensors with rank < 3 in INT8
+ # mode, which is a bug. Re-enable this when the TRT library is fixed.
+ return not trt_test.IsQuantizationMode(run_params.precision_mode)
+
+
+if __name__ == "__main__":
+ test.main()
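
The expected output shape equals tuple(input_dims[1]) because the rank-2 path is expanded to rank 4 before the two paths are added, and the final squeeze finds no size-1 dimensions left to remove. A standalone numpy sketch of the shape arithmetic, for illustration only:

import numpy as np

a = np.random.random_sample([12, 5])        # rank-2 path
b = np.random.random_sample([12, 5, 2, 2])  # rank-4 path
for _ in range(2):                          # mirrors the two expand_dims ops
  a = np.expand_dims(a, -1)                 # (12, 5) -> (12, 5, 1, 1)
out = np.squeeze(a + b)                     # broadcast add gives (12, 5, 2, 2)
assert out.shape == (12, 5, 2, 2)           # == tuple(input_dims[1])
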
diff --git a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py
index 6f85ada464..65ca21cf37 100644
--- a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py
+++ b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py
@@ -31,6 +31,7 @@ from tensorflow.contrib.tensorrt.python.ops import trt_engine_op
# pylint: enable=unused-import
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import graph_io
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
@@ -39,18 +40,23 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import tf_logging as logging
TfTrtIntegrationTestParams = namedtuple("TfTrtIntegrationTestParams", [
- "gdef", "input_names", "input_dims", "expected_engines",
- "expected_output_dims", "allclose_atol", "allclose_rtol"
+ "gdef", "input_names", "input_dims", "output_names", "expected_output_dims"
])
RunParams = namedtuple(
"RunParams",
["use_optimizer", "precision_mode", "dynamic_engine", "test_name"])
+ConversionParams = namedtuple("ConversionParams", [
+ "max_batch_size", "max_workspace_size_bytes", "precision_mode",
+ "minimum_segment_size", "is_dynamic_op", "maximum_cached_engines",
+ "cached_engine_batches"
+])
+
PRECISION_MODES = ["FP32", "FP16", "INT8"]
-def _IsQuantizationMode(mode):
+def IsQuantizationMode(mode):
return mode == "INT8"
@@ -64,10 +70,6 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
"""Class to test Tensorflow-TensorRT integration."""
@property
- def output_name(self):
- return "output"
-
- @property
def trt_incompatible_op(self):
return math_ops.sin
@@ -112,6 +114,10 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
super(TfTrtIntegrationTestBase, cls).setUpClass()
trt_convert.enable_test_value()
+ def __init__(self, methodName="runTest"): # pylint: disable=invalid-name
+ super(TfTrtIntegrationTestBase, self).__init__(methodName)
+ self._trt_test_params = None
+
def setUp(self):
"""Setup method."""
super(TfTrtIntegrationTestBase, self).setUp()
@@ -122,43 +128,97 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
"""Return a TfTrtIntegrationTestParams for test, implemented by subclass."""
raise NotImplementedError()
- def _PrepareRun(self, params, graph_state):
+ def GetConversionParams(self, run_params):
+ """Return a ConversionParams for test."""
+ return ConversionParams(
+ max_batch_size=max([
+ dims[0] for dims in self._GetParamsCached().input_dims if len(dims)
+ ]),
+ max_workspace_size_bytes=1 << 25,
+ precision_mode=self._ToBytes(run_params.precision_mode),
+ minimum_segment_size=2,
+ is_dynamic_op=run_params.dynamic_engine,
+ maximum_cached_engines=1,
+ cached_engine_batches=None)
+
+ def ShouldRunTest(self, run_params):
+ """Whether to run the test."""
+ return True
+
+ def VerifyRunForEngine(self, engine_name, graph_state, expect_run=True):
+ """Verify the state of a particular engine after sess.run()."""
+ if graph_state == GraphState.ORIGINAL:
+ self._ExpectCalibration(engine_name, "")
+ self._ExpectNativeSegment(engine_name, "")
+ self._ExpectTrtEngine(engine_name, "")
+ elif graph_state == GraphState.CALIBRATE:
+ self._ExpectCalibration(engine_name, "done")
+ self._ExpectNativeSegment(engine_name, "done")
+ self._ExpectTrtEngine(engine_name, "")
+ elif graph_state == GraphState.INFERENCE:
+ self._ExpectCalibration(engine_name, "")
+ if expect_run:
+ self._ExpectNativeSegment(engine_name, "")
+ self._ExpectTrtEngine(engine_name, "done")
+ else:
+ self._ExpectNativeSegment(engine_name, "done")
+ self._ExpectTrtEngine(engine_name, "")
+
+ def VerifyRun(self, run_params, graph_state):
+ """Verify the state of all engines after sess.run()."""
+ for engine_name in self.ExpectedEnginesToBuild(run_params):
+ expect_run = (engine_name in self.ExpectedEnginesToRun(run_params))
+ self.VerifyRunForEngine(engine_name, graph_state, expect_run)
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build, implemented by subclass."""
+ raise NotImplementedError()
+
+ def ExpectedEnginesToRun(self, run_params):
+ """Return the expected engines to run."""
+ return self.ExpectedEnginesToBuild(run_params)
+
+ def ExpectedAbsoluteTolerance(self, run_params):
+ """The absolute tolerance to compare floating point results."""
+ return 1.e-06 if run_params.precision_mode == "FP32" else 1.e-03
+
+ def ExpectedRelativeTolerance(self, run_params):
+ """The relative tolerance to compare floating point results."""
+ return 1.e-06 if run_params.precision_mode == "FP32" else 1.e-03
+
+ def _GetParamsCached(self):
+ if self._trt_test_params is None:
+ self._trt_test_params = self.GetParams()
+ return self._trt_test_params
+
+ def _PrepareRun(self, graph_state):
"""Set up necessary testing environment before calling sess.run()."""
# Clear test values added by TRTEngineOp.
trt_convert.clear_test_values("my_trt_op_.*:ExecuteTrtEngine")
trt_convert.clear_test_values("my_trt_op_.*:ExecuteCalibration")
trt_convert.clear_test_values("my_trt_op_.*:ExecuteNativeSegment")
- def _VerifyRun(self, params, graph_state):
- """Verify the state after sess.run()."""
- for engine_name in params.expected_engines:
- if graph_state == GraphState.ORIGINAL:
- self._ExpectCalibration(engine_name, "")
- self._ExpectNativeSegment(engine_name, "")
- self._ExpectTrtEngine(engine_name, "")
- elif graph_state == GraphState.CALIBRATE:
- self._ExpectCalibration(engine_name, "done")
- self._ExpectNativeSegment(engine_name, "done")
- self._ExpectTrtEngine(engine_name, "")
- elif graph_state == GraphState.INFERENCE:
- self._ExpectCalibration(engine_name, "")
- self._ExpectNativeSegment(engine_name, "")
- self._ExpectTrtEngine(engine_name, "done")
-
- def _GetConfigProto(self, params, run_params, graph_state):
+ def _GetConfigProto(self, run_params, graph_state):
"""Get config proto based on specific settings."""
if graph_state != GraphState.ORIGINAL and run_params.use_optimizer:
rewriter_cfg = rewriter_config_pb2.RewriterConfig()
rewriter_cfg.optimizers.extend(["constfold", "layout"])
custom_op = rewriter_cfg.custom_optimizers.add()
custom_op.name = "TensorRTOptimizer"
- custom_op.parameter_map["minimum_segment_size"].i = 2
- custom_op.parameter_map["max_batch_size"].i = max(
- [dims[0] for dims in params.input_dims])
- custom_op.parameter_map["is_dynamic_op"].b = run_params.dynamic_engine
- custom_op.parameter_map["max_workspace_size_bytes"].i = 1 << 25
- custom_op.parameter_map["precision_mode"].s = self._ToBytes(
- run_params.precision_mode)
+ trt_params = self.GetConversionParams(run_params)
+ custom_op.parameter_map["max_batch_size"].i = trt_params.max_batch_size
+ custom_op.parameter_map["max_workspace_size_bytes"].i = (
+ trt_params.max_workspace_size_bytes)
+ custom_op.parameter_map["precision_mode"].s = trt_params.precision_mode
+ custom_op.parameter_map["minimum_segment_size"].i = (
+ trt_params.minimum_segment_size)
+ custom_op.parameter_map["is_dynamic_op"].b = trt_params.is_dynamic_op
+ custom_op.parameter_map["maximum_cached_engines"].i = (
+ trt_params.maximum_cached_engines)
+ if trt_params.cached_engine_batches:
+ custom_op.parameter_map["cached_engine_batches"].list.i.extend(
+ trt_params.cached_engine_batches)
+
graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_cfg)
else:
graph_options = config_pb2.GraphOptions()
@@ -190,53 +250,67 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
def _ExpectNativeSegment(self, engine_name, value):
self._ExpectTestValue(engine_name, "ExecuteNativeSegment", value)
- def _RunGraph(self, params, gdef, input_data, config, graph_state,
+ def _RunGraph(self,
+ run_params,
+ gdef,
+ input_data,
+ config,
+ graph_state,
num_runs=2):
"""Run given graphdef multiple times."""
+ params = self._GetParamsCached()
assert len(params.input_names) == len(input_data)
g = ops.Graph()
with g.as_default():
io_ops = importer.import_graph_def(
graph_def=gdef,
- return_elements=params.input_names + [self.output_name],
+ return_elements=params.input_names + params.output_names,
name="")
- inp = [i.outputs[0] for i in io_ops[:-1]]
- assert len(inp) == len(input_data)
- out = io_ops[-1].outputs[0]
+ inputs = [op.outputs[0] for op in io_ops[:len(params.input_names)]]
+ assert len(inputs) == len(input_data)
+ outputs = [op.outputs[0] for op in io_ops[len(params.input_names):]]
with self.test_session(
graph=g, config=config, use_gpu=True, force_gpu=True) as sess:
val = None
# Defaults to 2 runs to verify that the result is the same across runs.
for _ in range(num_runs):
- self._PrepareRun(params, graph_state)
- new_val = sess.run(out,
- {inp[i]: input_data[i] for i in range(len(inp))})
- self.assertEqual(params.expected_output_dims, new_val.shape)
+ self._PrepareRun(graph_state)
+ new_val = sess.run(
+ outputs, {inputs[i]: input_data[i] for i in range(len(inputs))})
+ output_len = len(params.expected_output_dims)
+ self.assertEqual(output_len, len(new_val))
+ for i in range(output_len):
+ self.assertEqual(params.expected_output_dims[i], new_val[i].shape)
if val is not None:
- self.assertAllEqual(val, new_val)
+ self.assertAllClose(val, new_val, atol=1.e-06, rtol=1.e-06)
val = new_val
- self._VerifyRun(params, graph_state)
+ self.VerifyRun(run_params, graph_state)
return val
# Use real data that is representative of the inference dataset
# for calibration. For this test script it is random data.
- def _RunCalibration(self, params, gdef, input_data, config):
+ def _RunCalibration(self, run_params, gdef, input_data, config):
"""Run calibration on given graph."""
return self._RunGraph(
- params, gdef, input_data, config, GraphState.CALIBRATE, num_runs=5)
+ run_params, gdef, input_data, config, GraphState.CALIBRATE, num_runs=5)
- def _GetTrtGraphDef(self, params, run_params, gdef):
+ def _GetTrtGraphDef(self, run_params, gdef):
"""Return trt converted graphdef."""
+ params = self._GetParamsCached()
+ trt_params = self.GetConversionParams(run_params)
+ logging.info(trt_params)
return trt_convert.create_inference_graph(
input_graph_def=gdef,
- outputs=[self.output_name],
- max_batch_size=max([dims[0] for dims in params.input_dims]),
- max_workspace_size_bytes=1 << 25,
- precision_mode=run_params.precision_mode,
- minimum_segment_size=2,
- is_dynamic_op=run_params.dynamic_engine)
-
- def _WriteGraph(self, params, run_params, gdef, graph_state):
+ outputs=params.input_names + params.output_names,
+ max_batch_size=trt_params.max_batch_size,
+ max_workspace_size_bytes=trt_params.max_workspace_size_bytes,
+ precision_mode=trt_params.precision_mode,
+ minimum_segment_size=trt_params.minimum_segment_size,
+ is_dynamic_op=trt_params.is_dynamic_op,
+ maximum_cached_engines=trt_params.maximum_cached_engines,
+ cached_engine_batches=trt_params.cached_engine_batches)
+
+ def _WriteGraph(self, run_params, gdef, graph_state):
if graph_state == GraphState.ORIGINAL:
label = "Original"
elif graph_state == GraphState.CALIBRATE:
@@ -247,15 +321,17 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
self.__class__.__name__ + "_" + run_params.test_name + "_" + label +
".pbtxt")
temp_dir = os.getenv("TRT_TEST_TMPDIR", self.get_temp_dir())
- logging.info("Writing graph to %s/%s", temp_dir, graph_name)
- graph_io.write_graph(gdef, temp_dir, graph_name)
+ if temp_dir:
+ logging.info("Writing graph to %s/%s", temp_dir, graph_name)
+ graph_io.write_graph(gdef, temp_dir, graph_name)
- def _VerifyConnections(self, params, converted_gdef):
+ def _VerifyConnections(self, expected_engines, converted_gdef):
+ params = self._GetParamsCached()
old_to_new_node_map = {
self._ToString(node.name): self._ToString(node.name)
for node in params.gdef.node
}
- for engine_name, node_names in params.expected_engines.items():
+ for engine_name, node_names in expected_engines.items():
for node_name in node_names:
old_to_new_node_map[node_name] = engine_name
name_to_node_map = {
@@ -310,97 +386,114 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
msg="expected:\n%s\nvs actual:\n%s" % (sorted(
expected_input_map.items()), sorted(actual_input_map.items())))
- def _VerifyGraphDef(self, params, run_params, gdef, graph_state):
- self._WriteGraph(params, run_params, gdef, graph_state)
+ def _VerifyGraphDef(self, run_params, gdef, graph_state):
+ self._WriteGraph(run_params, gdef, graph_state)
+ expected_engines = self.ExpectedEnginesToBuild(run_params)
num_engines = 0
for node in gdef.node:
if node.op == "TRTEngineOp":
+ logging.info("Found TRTEngineOp: " + node.name)
+ for node in gdef.node:
+ if node.op == "TRTEngineOp":
num_engines += 1
- self.assertTrue(node.name in params.expected_engines)
- self.assertTrue(len(node.attr["serialized_segment"].s))
- self.assertTrue(len(node.attr["segment_funcdef_name"].s))
+ self.assertTrue(node.name in expected_engines, node.name)
+ self.assertTrue(len(node.attr["serialized_segment"].s), node.name)
+ self.assertTrue(len(node.attr["segment_funcdef_name"].s), node.name)
self.assertEqual(
self._ToBytes(run_params.precision_mode),
- node.attr["precision_mode"].s)
+ node.attr["precision_mode"].s, node.name)
is_dynamic_engine = not node.attr["static_engine"].b
- self.assertEqual(run_params.dynamic_engine, is_dynamic_engine)
+ self.assertEqual(run_params.dynamic_engine, is_dynamic_engine,
+ node.name)
has_calibration_data = len(node.attr["calibration_data"].s)
- if (_IsQuantizationMode(run_params.precision_mode) and
+ if (IsQuantizationMode(run_params.precision_mode) and
graph_state == GraphState.INFERENCE):
- self.assertTrue(has_calibration_data)
+ self.assertTrue(has_calibration_data, node.name)
else:
- self.assertFalse(has_calibration_data)
+ self.assertFalse(has_calibration_data, node.name)
if graph_state == GraphState.ORIGINAL:
self.assertEqual(0, num_engines)
else:
- self.assertEqual(num_engines, len(params.expected_engines))
- if isinstance(params.expected_engines, dict):
- self._VerifyConnections(params, gdef)
+ self.assertEqual(num_engines, len(expected_engines))
+ if isinstance(expected_engines, dict):
+ self._VerifyConnections(expected_engines, gdef)
# TODO(aaroey): consider verifying the corresponding TF function.
- def RunTest(self, params, run_params):
+ def RunTest(self, run_params):
+ if not self.ShouldRunTest(run_params):
+ return
assert run_params.precision_mode in PRECISION_MODES
- input_data = [np.random.random_sample(dims) for dims in params.input_dims]
+
+ params = self._GetParamsCached()
input_gdef = params.gdef
- self._VerifyGraphDef(params, run_params, input_gdef, GraphState.ORIGINAL)
+ input_dtypes = {}
+ for node in input_gdef.node:
+ if self._ToString(node.name) in params.input_names:
+ assert self._ToString(node.op) == "Placeholder"
+ input_dtypes[self._ToString(node.name)] = (
+ dtypes.as_dtype(node.attr["dtype"].type).as_numpy_dtype())
+ assert len(params.input_names) == len(input_dtypes)
+
+ input_data = []
+ for i in range(len(params.input_names)):
+ dtype = input_dtypes[params.input_names[i]]
+ # Multiply the input by some constant to avoid an all-zero input for
+ # integer types.
+ scale = 10.0 if np.issubdtype(dtype, np.integer) else 1.0
+ dims = params.input_dims[i]
+ input_data.append((scale * np.random.random_sample(dims)).astype(dtype))
+ self._VerifyGraphDef(run_params, input_gdef, GraphState.ORIGINAL)
# Get reference result without running trt.
- config_no_trt = self._GetConfigProto(params, run_params,
- GraphState.ORIGINAL)
+ config_no_trt = self._GetConfigProto(run_params, GraphState.ORIGINAL)
logging.info("Running original graph w/o trt, config:\n%s",
str(config_no_trt))
- ref_result = self._RunGraph(params, input_gdef, input_data, config_no_trt,
- GraphState.ORIGINAL)
+ ref_result = self._RunGraph(run_params, input_gdef, input_data,
+ config_no_trt, GraphState.ORIGINAL)
# Run calibration if necessary.
- if _IsQuantizationMode(run_params.precision_mode):
+ if IsQuantizationMode(run_params.precision_mode):
- calib_config = self._GetConfigProto(params, run_params,
- GraphState.CALIBRATE)
+ calib_config = self._GetConfigProto(run_params, GraphState.CALIBRATE)
logging.info("Running calibration graph, config:\n%s", str(calib_config))
if run_params.use_optimizer:
- result = self._RunCalibration(params, input_gdef, input_data,
+ result = self._RunCalibration(run_params, input_gdef, input_data,
calib_config)
else:
- calib_gdef = self._GetTrtGraphDef(params, run_params, input_gdef)
- self._VerifyGraphDef(params, run_params, calib_gdef,
- GraphState.CALIBRATE)
- result = self._RunCalibration(params, calib_gdef, input_data,
+ calib_gdef = self._GetTrtGraphDef(run_params, input_gdef)
+ self._VerifyGraphDef(run_params, calib_gdef, GraphState.CALIBRATE)
+ result = self._RunCalibration(run_params, calib_gdef, input_data,
calib_config)
- infer_gdef = trt_convert.calib_graph_to_infer_graph(calib_gdef)
- self._VerifyGraphDef(params, run_params, infer_gdef, GraphState.INFERENCE)
+ infer_gdef = trt_convert.calib_graph_to_infer_graph(
+ calib_gdef, run_params.dynamic_engine)
+ self._VerifyGraphDef(run_params, infer_gdef, GraphState.INFERENCE)
self.assertAllClose(
ref_result,
result,
- atol=params.allclose_atol,
- rtol=params.allclose_rtol)
+ atol=self.ExpectedAbsoluteTolerance(run_params),
+ rtol=self.ExpectedRelativeTolerance(run_params))
else:
infer_gdef = input_gdef
# Run inference.
- infer_config = self._GetConfigProto(params, run_params,
- GraphState.INFERENCE)
+ infer_config = self._GetConfigProto(run_params, GraphState.INFERENCE)
logging.info("Running final inference graph, config:\n%s",
str(infer_config))
- if run_params.use_optimizer:
- result = self._RunGraph(params, infer_gdef, input_data, infer_config,
- GraphState.INFERENCE)
- else:
- trt_infer_gdef = self._GetTrtGraphDef(params, run_params, infer_gdef)
- self._VerifyGraphDef(params, run_params, trt_infer_gdef,
- GraphState.INFERENCE)
- result = self._RunGraph(params, trt_infer_gdef, input_data, infer_config,
- GraphState.INFERENCE)
+ if not run_params.use_optimizer:
+ infer_gdef = self._GetTrtGraphDef(run_params, infer_gdef)
+ self._VerifyGraphDef(run_params, infer_gdef, GraphState.INFERENCE)
+ result = self._RunGraph(run_params, infer_gdef, input_data, infer_config,
+ GraphState.INFERENCE)
self.assertAllClose(
ref_result,
result,
- atol=params.allclose_atol,
- rtol=params.allclose_rtol)
+ atol=self.ExpectedAbsoluteTolerance(run_params),
+ rtol=self.ExpectedRelativeTolerance(run_params))
def testIdempotence(self):
# Test that applying tensorrt optimizer or offline conversion tools multiple
@@ -421,13 +514,12 @@ def _AddTests(test_class):
"""Gets a single test method based on the parameters."""
def _Test(self):
- params = self.GetParams()
logging.info(
"Running test %s with parameters: use_optimizer=%s, "
"precision_mode=%s, dynamic_engine=%s",
"testTfTrt_" + run_params.test_name, run_params.use_optimizer,
run_params.precision_mode, run_params.dynamic_engine)
- self.RunTest(params, run_params)
+ self.RunTest(run_params)
return _Test
@@ -435,7 +527,7 @@ def _AddTests(test_class):
dynamic_engine_options = [False, True]
for (use_optimizer, precision_mode, dynamic_engine) in itertools.product(
use_optimizer_options, PRECISION_MODES, dynamic_engine_options):
- if _IsQuantizationMode(precision_mode):
+ if IsQuantizationMode(precision_mode):
if use_optimizer:
# TODO(aaroey): if use_optimizer is True we need to get the inference
# graphdef using custom python wrapper class, which is not currently
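
Beyond ExpectedEnginesToBuild(), the reworked base class lets a subclass tune conversion and comparison behavior through overrides rather than through namedtuple fields. A minimal sketch of what such overrides might look like, assuming GetParams() and ExpectedEnginesToBuild() are defined elsewhere in the class; the concrete values are examples only, and _replace is the standard namedtuple method.

from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test


class TunedConversionTest(trt_test.TfTrtIntegrationTestBase):

  def GetConversionParams(self, run_params):
    """Start from the base defaults and adjust selected fields."""
    conversion_params = super(TunedConversionTest,
                              self).GetConversionParams(run_params)
    return conversion_params._replace(
        max_workspace_size_bytes=1 << 28,
        maximum_cached_engines=2,
        cached_engine_batches=[2, 4])

  def ShouldRunTest(self, run_params):
    """Skip INT8 runs for this graph, e.g. while calibration is unsupported."""
    return not trt_test.IsQuantizationMode(run_params.precision_mode)

  def ExpectedAbsoluteTolerance(self, run_params):
    """Loosen the comparison threshold for FP16/INT8 results."""
    return 1.e-06 if run_params.precision_mode == "FP32" else 1.e-02
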
diff --git a/tensorflow/contrib/tensorrt/test/unary_test.py b/tensorflow/contrib/tensorrt/test/unary_test.py
index 500057a36d..8736bfb644 100644
--- a/tensorflow/contrib/tensorrt/test/unary_test.py
+++ b/tensorflow/contrib/tensorrt/test/unary_test.py
@@ -38,6 +38,7 @@ class UnaryTest(trt_test.TfTrtIntegrationTestBase):
dtype = dtypes.float32
input_name = "input"
input_dims = [12, 5, 8, 1, 1, 12]
+ output_name = "output"
input2_name = "input_2"
input2_dims = [12, 5, 8, 1, 12, 1, 1]
g = ops.Graph()
@@ -95,18 +96,20 @@ class UnaryTest(trt_test.TfTrtIntegrationTestBase):
q = a * b
q = q / c
- array_ops.squeeze(q, name=self.output_name)
+ array_ops.squeeze(q, name=output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
input_names=[input_name, input2_name],
input_dims=[input_dims, input2_dims],
- expected_engines=[
- "my_trt_op_0", "my_trt_op_1", "my_trt_op_2", "my_trt_op_3",
- "my_trt_op_4"
- ],
- expected_output_dims=(12, 5, 8, 12),
- allclose_atol=1.e-03,
- allclose_rtol=1.e-03)
+ output_names=[output_name],
+ expected_output_dims=[(12, 5, 8, 12)])
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ return [
+ "my_trt_op_0", "my_trt_op_1", "my_trt_op_2", "my_trt_op_3",
+ "my_trt_op_4"
+ ]
if __name__ == "__main__":
diff --git a/tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py b/tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py
index ab4d224db4..b0271a04b3 100644
--- a/tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py
+++ b/tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py
@@ -38,15 +38,14 @@ class VGGBlockNCHWTest(trt_test.TfTrtIntegrationTestBase):
dtype = dtypes.float32
input_name = "input"
input_dims = [5, 2, 8, 8]
+ output_name = "output"
g = ops.Graph()
with g.as_default():
x = array_ops.placeholder(dtype=dtype, shape=input_dims, name=input_name)
x, _, _ = nn_impl.fused_batch_norm(
- x,
- np.random.randn(2).astype(np.float32),
- np.random.randn(2).astype(np.float32),
- mean=np.random.randn(2).astype(np.float32),
- variance=np.random.randn(2).astype(np.float32),
+ x, [1.0, 1.0], [0.0, 0.0],
+ mean=[0.5, 0.5],
+ variance=[1.0, 1.0],
data_format="NCHW",
is_training=False)
e = constant_op.constant(
@@ -67,15 +66,17 @@ class VGGBlockNCHWTest(trt_test.TfTrtIntegrationTestBase):
"VALID",
data_format="NCHW",
name="max_pool")
- array_ops.squeeze(v, name="output")
+ array_ops.squeeze(v, name=output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- expected_engines=["my_trt_op_0"],
- expected_output_dims=(5, 6, 2, 2),
- allclose_atol=1.e-03,
- allclose_rtol=1.e-03)
+ output_names=[output_name],
+ expected_output_dims=[(5, 6, 2, 2)])
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ return ["my_trt_op_0"]
if __name__ == "__main__":
diff --git a/tensorflow/contrib/tensorrt/test/vgg_block_test.py b/tensorflow/contrib/tensorrt/test/vgg_block_test.py
index 56bdf848ea..d7c165784b 100644
--- a/tensorflow/contrib/tensorrt/test/vgg_block_test.py
+++ b/tensorflow/contrib/tensorrt/test/vgg_block_test.py
@@ -38,15 +38,14 @@ class VGGBlockTest(trt_test.TfTrtIntegrationTestBase):
dtype = dtypes.float32
input_name = "input"
input_dims = [5, 8, 8, 2]
+ output_name = "output"
g = ops.Graph()
with g.as_default():
x = array_ops.placeholder(dtype=dtype, shape=input_dims, name=input_name)
x, _, _ = nn_impl.fused_batch_norm(
- x,
- np.random.randn(2).astype(np.float32),
- np.random.randn(2).astype(np.float32),
- mean=np.random.randn(2).astype(np.float32),
- variance=np.random.randn(2).astype(np.float32),
+ x, [1.0, 1.0], [0.0, 0.0],
+ mean=[0.5, 0.5],
+ variance=[1.0, 1.0],
is_training=False)
e = constant_op.constant(
np.random.randn(1, 1, 2, 6), name="weights", dtype=dtype)
@@ -58,15 +57,17 @@ class VGGBlockTest(trt_test.TfTrtIntegrationTestBase):
idty = array_ops.identity(relu, "ID")
v = nn_ops.max_pool(
idty, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool")
- array_ops.squeeze(v, name="output")
+ array_ops.squeeze(v, name=output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- expected_engines=["my_trt_op_0"],
- expected_output_dims=(5, 2, 2, 6),
- allclose_atol=1.e-03,
- allclose_rtol=1.e-03)
+ output_names=[output_name],
+ expected_output_dims=[(5, 2, 2, 6)])
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ return ["my_trt_op_0"]
if __name__ == "__main__":
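
For orientation, _AddTests (see the base-class diff above) generates one testTfTrt_* method per combination of use_optimizer, precision_mode and dynamic_engine. A rough standalone sketch of the resulting sweep, assuming, as the truncated hunk suggests, that INT8 combined with the optimizer path is skipped:

import itertools

PRECISION_MODES = ["FP32", "FP16", "INT8"]

combos = []
for use_optimizer, precision_mode, dynamic_engine in itertools.product(
    [False, True], PRECISION_MODES, [False, True]):
  if precision_mode == "INT8" and use_optimizer:
    # Assumed skipped: the inference graphdef needed for calibration cannot
    # yet be produced via the optimizer path (per the TODO in _AddTests).
    continue
  combos.append((use_optimizer, precision_mode, dynamic_engine))

print("%d run-parameter combinations" % len(combos))  # 10 under this assumption
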