Diffstat (limited to 'tensorflow/contrib/tensorrt/convert/convert_nodes.cc')
-rw-r--r--  tensorflow/contrib/tensorrt/convert/convert_nodes.cc  1521
1 file changed, 1025 insertions(+), 496 deletions(-)
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
index 146b9c7344..451d6fe698 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h"
#include <algorithm>
+#include <cstring>
#include <list>
#include <map>
#include <memory>
@@ -49,15 +50,34 @@ limitations under the License.
#if GOOGLE_TENSORRT
#include "tensorrt/include/NvInfer.h"
-// Check if the types are equal. Cast to int first so that failure log message
-// would work!
-#define CHECK_EQ_TYPE(val1, val2) CHECK_EQ((int)val1, (int)val2)
+// Check if the types are equal. Cast to int first so that failure log message
+// would work!
+#define TFTRT_CHECK_EQ_TYPE(val1, val2) CHECK_EQ((int)val1, (int)val2)
+
+#define TFTRT_INTERNAL_ERROR_AT_NODE(node) \
+ do { \
+ return tensorflow::errors::Internal( \
+ "TFTRT::", __FUNCTION__, "failed to add TRT layer, at: ", node); \
+ } while (0)
+
+#define TFTRT_RETURN_ERROR_IF_FALSE(status, node) \
+ do { \
+ if (status == false) { \
+ TFTRT_INTERNAL_ERROR_AT_NODE(node); \
+ } \
+ } while (0)
+
+#define TFTRT_RETURN_ERROR_IF_NULLPTR(ptr, node) \
+ do { \
+ if (ptr == nullptr) { \
+ TFTRT_INTERNAL_ERROR_AT_NODE(node); \
+ } \
+ } while (0)
namespace tensorflow {
namespace tensorrt {
namespace convert {
using ::tensorflow::str_util::Split;
-
using ::tensorflow::strings::StrAppend;
using ::tensorflow::strings::StrCat;
@@ -75,13 +95,163 @@ inline tensorflow::Status ConvertDType(tensorflow::DataType tf_dtype,
case tensorflow::DataType::DT_HALF:
*trt_dtype = nvinfer1::DataType::kHALF;
break;
+#if NV_TENSORRT_MAJOR > 3
+ case tensorflow::DataType::DT_INT32:
+ *trt_dtype = nvinfer1::DataType::kINT32;
+ break;
+#endif
default:
return tensorflow::errors::InvalidArgument(
- "Unsupported data type " + tensorflow::DataTypeString(tf_dtype));
+ "Unsupported data type ", tensorflow::DataTypeString(tf_dtype));
}
return tensorflow::Status::OK();
}
+void GetInputProperties(const grappler::GraphProperties& graph_properties,
+ const Node* outside_node, const int out_port,
+ PartialTensorShape* shape,
+ tensorflow::DataType* dtype) {
+ if (graph_properties.HasOutputProperties(outside_node->name())) {
+ auto output_params =
+ graph_properties.GetOutputProperties(outside_node->name());
+ auto out_shape = output_params.at(out_port);
+ *dtype = out_shape.dtype();
+ *shape = out_shape.shape();
+ } else {
+ VLOG(0) << "Unknown output shape" << outside_node->name();
+ *dtype = outside_node->output_type(out_port);
+ }
+}
+
+void GetOutputProperties(const grappler::GraphProperties& graph_properties,
+ const Node* outside_node, const int in_port,
+ PartialTensorShape* shape,
+ tensorflow::DataType* dtype) {
+ if (graph_properties.HasInputProperties(outside_node->name())) {
+ auto input_params =
+ graph_properties.GetInputProperties(outside_node->name());
+ auto in_shape = input_params.at(in_port);
+ *dtype = in_shape.dtype();
+ *shape = in_shape.shape();
+ } else {
+ *dtype = outside_node->input_type(in_port);
+ }
+}
+
+tensorflow::Status ValidateInputProperties(const PartialTensorShape& shape,
+ const tensorflow::DataType dtype,
+ nvinfer1::DataType* trt_dtype) {
+ // TODO(aaroey): some of these checks also apply to IsTensorRTCandidate(), so
+ // put them there instead.
+ TF_RETURN_IF_ERROR(ConvertDType(dtype, trt_dtype));
+ if (shape.dims() < 0) {
+ return tensorflow::errors::InvalidArgument("Input tensor rank is unknown.");
+ }
+ if (shape.dims() > 9) {
+ return tensorflow::errors::OutOfRange(
+ "Input tensor rank is greater than 8.");
+ }
+ for (int d = 1; d < shape.dims(); ++d) {
+ if (shape.dim_size(d) < 0) {
+ return tensorflow::errors::InvalidArgument(
+ "Input tensor has a unknown non-batch dimemension at dim ", d);
+ }
+ }
+ return Status::OK();
+}
+
+// Return whether or not the broadcast is feasible;
+bool TensorRTGetBroadcastShape(const nvinfer1::Dims& operand_l,
+ const bool operand_l_is_tensor,
+ const nvinfer1::Dims& operand_r,
+ const bool operand_r_is_tensor,
+ nvinfer1::Dims* operand_l_new_shape,
+ nvinfer1::Dims* operand_r_new_shape) {
+ // ***************************************************************************
+ // TensorRT Elementwise op supports broadcast but requires both tensors to
+ // be of identical rank.
+ //
+ // We consider case of:
+ // 1. operand_l to be a Tensor & operand_r to be a Const;
+ // 2. operand_l to be a Tensor & operand_r to be a Tensor;
+ // note: const op const (constant folding) should fall back to TensorFlow
+ //
+ // broadcast scheme:
+ // T: 1 3 5 (tensor would not have batch dimension)
+ // W: 1 1 3 1 (weight would have all explicit dimensions)
+ // i. fill in explicit dimensions
+ // -> T: -1 1 3 5 (we put a -1 for batch dimension)
+ // -> W: 1 1 3 1
+ // ii. compare broadcast feasibility
+ //
+ // We cannot support the following case: TensorRT does not allow manipulation
+ // of the batch dimension, so we cannot generate an output with the proper shape.
+ // T: 3 5 1
+ // W: 1 1 1 1 3 5 1
+ // -> T: 1 1 1 -1 3 5 1
+ // -> W: 1 1 1 1 3 5 1
+ // ***************************************************************************
+ const int max_nb_dims = nvinfer1::Dims::MAX_DIMS + 1;
+ const size_t element_size = sizeof(operand_l.d[0]);
+
+ // fill in dimensions
+ int l_s[max_nb_dims];
+ std::fill(l_s, l_s + max_nb_dims, 1);
+ int l_d = operand_l_is_tensor ? operand_l.nbDims + 1 : operand_l.nbDims;
+ int r_s[max_nb_dims];
+ std::fill(r_s, r_s + max_nb_dims, 1);
+ int r_d = operand_r_is_tensor ? operand_r.nbDims + 1 : operand_r.nbDims;
+
+ int max_d = std::max(l_d, r_d);
+ std::memcpy(l_s + max_d - operand_l.nbDims, operand_l.d,
+ operand_l.nbDims * element_size);
+ std::memcpy(r_s + max_d - operand_r.nbDims, operand_r.d,
+ operand_r.nbDims * element_size);
+
+ // set -1 for batch dimension, since batch size is not supposed to be
+ // broadcasted
+ if (operand_l_is_tensor) {
+ if (max_d != l_d) { // if broadcast beyond batch dimension, fail
+ return false;
+ }
+ l_s[0] = -1;
+ }
+ if (operand_r_is_tensor) {
+ if (max_d != r_d) { // if broadcast beyond batch dimension, fail
+ return false;
+ }
+ r_s[0] = -1;
+ }
+
+ // compare broadcast feasibility
+ for (int i = max_d - 1; i >= 0; i--) {
+ if ((l_s[i] != r_s[i]) && (l_s[i] != 1) && (r_s[i] != 1)) {
+ return false;
+ }
+ }
+
+ // output new TensorRT Dimension (stripping the batch dimension)
+ operand_l_new_shape->nbDims = max_d - 1;
+ std::memcpy(operand_l_new_shape->d, l_s + 1, (max_d - 1) * element_size);
+ operand_r_new_shape->nbDims = max_d - 1;
+ std::memcpy(operand_r_new_shape->d, r_s + 1, (max_d - 1) * element_size);
+
+ return true;
+}
+
+inline bool DimsEqual(const nvinfer1::Dims& dim_l,
+ const nvinfer1::Dims& dim_r) {
+ if (dim_l.nbDims != dim_r.nbDims) {
+ return false;
+ }
+ for (int i = 0; i < dim_l.nbDims; i++) {
+ if (dim_l.d[i] != dim_r.d[i]) {
+ return false;
+ }
+ }
+ return true;
+}
+
inline nvinfer1::Dims GetTensorShape(const tensorflow::Tensor& tensor) {
nvinfer1::Dims dims;
dims.nbDims = tensor.dims();
@@ -91,7 +261,7 @@ inline nvinfer1::Dims GetTensorShape(const tensorflow::Tensor& tensor) {
return dims;
}
-inline int64_t GetShapeSize(nvinfer1::Dims shape) {
+inline int64_t GetShapeSize(const nvinfer1::Dims& shape) {
// Returns total number of elements in shape
int64_t count = 1;
for (int d = 0; d < shape.nbDims; ++d) {
@@ -104,7 +274,7 @@ static std::vector<std::pair<int, int>> CreateSamePadding(
const nvinfer1::DimsHW& stride, const nvinfer1::DimsHW& kernel,
const std::vector<int64_t>& input_dims) {
std::vector<std::pair<int, int>> padding(input_dims.size());
- CHECK_EQ((size_t)stride.nbDims, input_dims.size()); // TODO(jie): N+C? NC+?
+ CHECK_EQ(stride.nbDims, input_dims.size()); // TODO(jie): N+C? NC+?
for (size_t i = 0; i < input_dims.size(); ++i) {
// Formula to calculate the padding
@@ -134,6 +304,7 @@ string GetCommonNameScope(const string& op_name_a, const string& op_name_b) {
return op_name_a.substr(0, last_scope_separator);
}
+// Class to convert TF weight to TRT weight.
class TRT_ShapedWeights {
public:
TRT_ShapedWeights(tensorflow::DataType type, const void* values,
@@ -145,12 +316,14 @@ class TRT_ShapedWeights {
explicit TRT_ShapedWeights(tensorflow::DataType type)
: shape_(), type_(type), values_(nullptr), empty_weight_flag_(true) {}
+ // TODO(aaroey): use rvalue reference.
TRT_ShapedWeights(const TRT_ShapedWeights& rhs)
: shape_(rhs.shape_),
type_(rhs.type_),
values_(rhs.values_),
empty_weight_flag_(rhs.empty_weight_flag_) {}
+ // TODO(aaroey): use GetShapeSize() instead.
int64_t count() const {
int64_t c = 1;
for (int i = 0; i < shape_.nbDims; i++) c *= shape_.d[i];
@@ -168,6 +341,7 @@ class TRT_ShapedWeights {
const void* GetValues() const { return values_; }
+ // TODO(aaroey): get rid of this method.
void SetValues(const void* values) { values_ = values; }
size_t size_bytes() const {
@@ -178,10 +352,12 @@ class TRT_ShapedWeights {
// Default converter
operator nvinfer1::Weights() const { return GetWeightsForTRT(); }
+ // TODO(aaroey): make these private.
nvinfer1::Dims shape_;
tensorflow::DataType type_;
private:
+ // TODO(aaroey): this should not be const as it's always from TRTWeightStore.
const void* values_;
bool empty_weight_flag_;
};
@@ -192,6 +368,7 @@ class TRT_TensorOrWeights {
: tensor_(tensor), weights_(DT_FLOAT), variant_(TRT_NODE_TENSOR) {}
explicit TRT_TensorOrWeights(const TRT_ShapedWeights& weights)
: tensor_(nullptr), weights_(weights), variant_(TRT_NODE_WEIGHTS) {}
+ // TODO(aaroey): use rvalue reference.
TRT_TensorOrWeights(const TRT_TensorOrWeights& rhs)
: tensor_(rhs.tensor_), weights_(rhs.weights_), variant_(rhs.variant_) {}
~TRT_TensorOrWeights() {}
@@ -200,19 +377,19 @@ class TRT_TensorOrWeights {
bool is_weights() const { return variant_ == TRT_NODE_WEIGHTS; }
nvinfer1::ITensor* tensor() {
- CHECK_EQ(is_tensor(), true);
+ CHECK(is_tensor());
return tensor_;
}
const nvinfer1::ITensor* tensor() const {
- CHECK_EQ(is_tensor(), true);
+ CHECK(is_tensor());
return tensor_;
}
TRT_ShapedWeights& weights() {
- CHECK_EQ(is_weights(), true);
+ CHECK(is_weights());
return weights_;
}
const TRT_ShapedWeights& weights() const {
- CHECK_EQ(is_weights(), true);
+ CHECK(is_weights());
return weights_;
}
nvinfer1::Dims shape() const {
@@ -236,21 +413,25 @@ class TFAttrs {
attrs_.insert({attr.first, &attr.second});
}
}
- bool count(string key) const { return attrs_.count(key); }
- tensorflow::AttrValue const* at(string key) const {
+
+ bool count(const string& key) const { return attrs_.count(key); }
+
+ tensorflow::AttrValue const* at(const string& key) const {
if (!attrs_.count(key)) {
LOG(FATAL) << "Attribute not found: " << key;
}
return attrs_.at(key);
}
+
template <typename T>
T get(const string& key) const;
+
template <typename T>
T get(const string& key, const T& default_value) const {
return attrs_.count(key) ? this->get<T>(key) : default_value;
}
- std::vector<string> GetAllAttrKey() {
+ std::vector<string> GetAllAttrKeys() const {
std::vector<string> attr_list;
for (const auto& attr_item : attrs_) {
attr_list.emplace_back(attr_item.first);
@@ -285,15 +466,6 @@ std::vector<string> TFAttrs::get<std::vector<string>>(const string& key) const {
auto attr = this->at(key)->list().s();
return std::vector<string>(attr.begin(), attr.end());
}
-template <>
-nvinfer1::Dims TFAttrs::get<nvinfer1::Dims>(const string& key) const {
- auto values = this->get<std::vector<int>>(key);
- nvinfer1::Dims dims;
- dims.nbDims = values.size();
- std::copy(values.begin(), values.end(), dims.d);
- // Note: No dimension type information is included
- return dims;
-}
template <>
nvinfer1::DataType TFAttrs::get<nvinfer1::DataType>(const string& key) const {
@@ -319,10 +491,11 @@ bool TFAttrs::get<bool>(const string& key) const {
}
// TODO(jie): reorder4 & reorder2 should be merged?
+// TODO(aaroey): fix the order of parameters.
template <typename T>
-void Reorder4(nvinfer1::DimsNCHW shape, const T* idata,
- nvinfer1::DimsNCHW istrides, T* odata,
- nvinfer1::DimsNCHW ostrides) {
+void Reorder4(const nvinfer1::DimsNCHW& shape, const T* idata,
+ const nvinfer1::DimsNCHW& istrides, T* odata,
+ const nvinfer1::DimsNCHW& ostrides) {
for (int n = 0; n < shape.n(); ++n) {
for (int c = 0; c < shape.c(); ++c) {
for (int h = 0; h < shape.h(); ++h) {
@@ -337,12 +510,13 @@ void Reorder4(nvinfer1::DimsNCHW shape, const T* idata,
}
template <typename T>
-void Reorder2(nvinfer1::DimsHW shape, const T* idata, nvinfer1::DimsHW istrides,
- T* odata, nvinfer1::DimsHW ostrides) {
+void Reorder2(const nvinfer1::DimsHW& shape, const T* idata,
+ const nvinfer1::DimsHW& istrides, T* odata,
+ const nvinfer1::DimsHW& ostrides) {
for (int h = 0; h < shape.h(); ++h) {
for (int w = 0; w < shape.w(); ++w) {
odata[h * ostrides.h() + w * ostrides.w()] =
- idata[h * ostrides.h() + w * ostrides.w()];
+ idata[h * istrides.h() + w * istrides.w()];
}
}
}
@@ -350,16 +524,17 @@ void Reorder2(nvinfer1::DimsHW shape, const T* idata, nvinfer1::DimsHW istrides,
// TODO(jie): fallback to tensorflow!!
void ReorderCKtoKC(const TRT_ShapedWeights& iweights,
TRT_ShapedWeights* oweights) {
- int c = iweights.shape_.d[0];
- int k = iweights.shape_.d[1];
+ const int c = iweights.shape_.d[0];
+ const int k = iweights.shape_.d[1];
oweights->shape_.d[0] = k;
oweights->shape_.d[1] = c;
- nvinfer1::DimsHW istrides = {1, k};
- nvinfer1::DimsHW ostrides = {c, 1};
+ const nvinfer1::DimsHW istrides = {1, k};
+ const nvinfer1::DimsHW ostrides = {c, 1};
switch (iweights.type_) {
case tensorflow::DataType::DT_FLOAT: {
Reorder2({k, c}, static_cast<float const*>(iweights.GetValues()),
istrides,
+ // TODO(aaroey): get rid of all the const_cast like this.
static_cast<float*>(const_cast<void*>(oweights->GetValues())),
ostrides);
break;
@@ -382,21 +557,24 @@ void ReorderRSCKToKCRS(const TRT_ShapedWeights& iweights,
TRT_ShapedWeights* oweights, int num_groups) {
CHECK_EQ(iweights.type_, oweights->type_);
CHECK_EQ(iweights.size_bytes(), oweights->size_bytes());
- int r = iweights.shape_.d[0];
- int s = iweights.shape_.d[1];
- // TRT requires GKcRS, while TF depthwise has RSCK
- // where c=1, C=G
+ // K indexes over output channels, C over input channels, and R and S over the
+ // height and width of the convolution
+ const int r = iweights.shape_.d[0];
+ const int s = iweights.shape_.d[1];
+ // TRT requires GKcRS, while TF depthwise has RSCK where c=1, C=G
VLOG(2) << "num_groups: " << num_groups;
- int c = iweights.shape_.d[2] / num_groups;
+ const int c = iweights.shape_.d[2] / num_groups;
VLOG(2) << "c" << iweights.shape_.d[2] << " then " << c;
- int k = iweights.shape_.d[3] * num_groups;
+ const int k = iweights.shape_.d[3] * num_groups;
VLOG(2) << "k" << iweights.shape_.d[3] << " then " << k;
+ VLOG(2) << "r" << iweights.shape_.d[0] << " then " << r;
+ VLOG(2) << "s" << iweights.shape_.d[1] << " then " << s;
oweights->shape_.d[0] = k / num_groups;
oweights->shape_.d[1] = c * num_groups;
oweights->shape_.d[2] = r;
oweights->shape_.d[3] = s;
- nvinfer1::DimsNCHW istrides = {1, k, s * k * c, c * k};
- nvinfer1::DimsNCHW ostrides = {c * r * s, r * s, s, 1};
+ const nvinfer1::DimsNCHW istrides = {1, k, s * k * c, c * k};
+ const nvinfer1::DimsNCHW ostrides = {c * r * s, r * s, s, 1};
switch (iweights.type_) {
case tensorflow::DataType::DT_FLOAT: {
Reorder4({k, c, r, s}, static_cast<float const*>(iweights.GetValues()),
@@ -428,11 +606,14 @@ using OpConverter =
std::vector<TRT_TensorOrWeights>*)>;
class Converter {
+ // TODO(aaroey): fix the order of members.
std::unordered_map<string, TRT_TensorOrWeights> trt_tensors_;
std::unordered_map<string, OpConverter> op_registry_;
OpConverter plugin_converter_;
nvinfer1::INetworkDefinition* trt_network_;
std::list<std::vector<uint8_t>> temp_bufs_;
+ // TODO(aaroey): inline the definition of TRTWeightStore here, and add APIs to
+ // operate the stored weights instead of operating it directly.
TRTWeightStore* weight_store_;
bool fp16_;
void register_op_converters();
@@ -440,7 +621,7 @@ class Converter {
std::vector<TRT_TensorOrWeights>* inputs) {
for (auto const& input_name : node_def.input()) {
/*************************************************************************
- * TODO(jie) handle case 1) here
+ * TODO(jie): handle case 1) here.
* Normalizes the inputs and extracts associated metadata:
* 1) Inputs can contain a colon followed by a suffix of characters.
* That suffix may be a single number (e.g. inputName:1) or several
@@ -454,6 +635,7 @@ class Converter {
if (input_name[0] == '^') continue;
string name = input_name;
auto first = name.find_first_of(':');
+ // TODO(aaroey): why removing the colon but not the zero? A bug?
if (first != string::npos && first + 2 == name.size() &&
name[first + 1] == '0')
name.erase(first);
@@ -462,12 +644,13 @@ class Converter {
if (trt_tensors_.count(name)) {
inputs->push_back(trt_tensors_.at(name));
} else {
- string str("Node ");
- StrAppend(&str, node_def.name(), " should have an input named '", name,
+ // TODO(aaroey): this should not happen, make it a CHECK.
+ // TODO(aaroey): use StrCat for pattern like this.
+ string msg("Node ");
+ StrAppend(&msg, node_def.name(), " should have an input named '", name,
"' but it is not available");
- LOG(WARNING) << "input: " << name << " not available for node at "
- << node_def.name();
- return tensorflow::errors::InvalidArgument(str);
+ LOG(ERROR) << msg;
+ return tensorflow::errors::InvalidArgument(msg);
}
}
return tensorflow::Status::OK();
@@ -488,6 +671,7 @@ class Converter {
weights.SetValues(weight_store_->store_.back().data());
return weights;
}
+ // TODO(aaroey): fix all the namings.
bool isFP16() { return fp16_; }
TRT_ShapedWeights get_temp_weights_like(const TRT_ShapedWeights& weights) {
return this->get_temp_weights(weights.type_, weights.shape_);
@@ -496,9 +680,10 @@ class Converter {
tensorflow::Status convert_node(const tensorflow::NodeDef& node_def) {
std::vector<TRT_TensorOrWeights> inputs;
TF_RETURN_IF_ERROR(this->get_inputs(node_def, &inputs));
- string op = node_def.op();
+ const string& op = node_def.op();
std::vector<TRT_TensorOrWeights> outputs;
if (PluginFactoryTensorRT::GetInstance()->IsPlugin(op)) {
+ // TODO(aaroey): plugin_converter_ is not set, fix it.
TF_RETURN_IF_ERROR(plugin_converter_(*this, node_def, inputs, &outputs));
} else {
if (!op_registry_.count(op)) {
@@ -509,7 +694,7 @@ class Converter {
TF_RETURN_IF_ERROR(op_converter(*this, node_def, inputs, &outputs));
}
for (size_t i = 0; i < outputs.size(); ++i) {
- TRT_TensorOrWeights output = outputs.at(i);
+ TRT_TensorOrWeights& output = outputs[i];
// TODO(jie): tf protobuf seems to be omitting the :0 suffix
string output_name = node_def.name();
if (i != 0) output_name = StrCat(output_name, ":", i);
@@ -527,26 +712,29 @@ class Converter {
nvinfer1::INetworkDefinition* network() { return trt_network_; }
- TRT_TensorOrWeights get_tensor(string name) {
+ TRT_TensorOrWeights get_tensor(const string& name) {
if (!trt_tensors_.count(name)) {
return TRT_TensorOrWeights(nullptr);
}
return trt_tensors_.at(name);
}
- bool insert_input_tensor(string name, nvinfer1::ITensor* tensor) {
+ bool insert_input_tensor(const string& name, nvinfer1::ITensor* tensor) {
return trt_tensors_.insert({name, TRT_TensorOrWeights(tensor)}).second;
}
nvinfer1::ITensor* TransposeTensor(nvinfer1::ITensor* input_tensor,
- std::vector<int> order) {
- auto dims = input_tensor->getDimensions();
+ const std::vector<int>& order) {
+ const auto dims = input_tensor->getDimensions();
// TODO(jie): change the return to status and properly exit
if (order.size() - 1 != size_t(dims.nbDims))
LOG(ERROR) << "Dimension does not match, fail gracefully";
nvinfer1::IShuffleLayer* layer = this->network()->addShuffle(*input_tensor);
+ if (layer == nullptr) {
+ return nullptr;
+ }
nvinfer1::Permutation permutation;
for (int32_t i = 0; i < dims.nbDims; ++i) {
permutation.order[i] = order[i + 1] - 1;
@@ -577,13 +765,14 @@ TRT_ShapedWeights ConvertFP32ToFP16(Converter& ctx,
}
return weights;
}
+
// ****************************************************************************
// Constant folding functions
// TODO(jie): once optimizer kicks in, we should have done constant folding
// there.
-//*****************************************************************************/
+// *****************************************************************************
struct LambdaFactory {
- enum class OP_CATEGORY : int { RSQRT = 0, NEG, ADD, MUL, SUB };
+ enum class OP_CATEGORY : int { RSQRT = 0, NEG, ADD, MUL, SUB, RECIP };
OP_CATEGORY op;
template <typename T>
@@ -595,6 +784,8 @@ struct LambdaFactory {
}
case OP_CATEGORY::NEG:
return [](T t) -> T { return -t; };
+ case OP_CATEGORY::RECIP:
+ return [](T t) -> T { return 1.0 / t; };
default:
VLOG(2) << "Not supported op for unary: " << static_cast<int>(op);
return nullptr;
@@ -628,7 +819,6 @@ struct LambdaFactory {
VLOG(2) << "LAMBDA VAL : " << val;
return l + val;
};
- // Return [val](T l)-> T {return l+val;};
case OP_CATEGORY::SUB:
return [val](T l) -> T {
VLOG(2) << "LAMBDA VAL : " << val;
@@ -688,11 +878,13 @@ std::function<Eigen::half(Eigen::half)> LambdaFactory::unary<Eigen::half>() {
}
case OP_CATEGORY::NEG:
return [](Eigen::half t) -> Eigen::half { return -t; };
+ // TODO(aaroey): can we support RECIP?
default:
VLOG(2) << "Not supported op for unary: " << static_cast<int>(op);
return nullptr;
}
}
+
tensorflow::Status UnaryCompute(const TRT_ShapedWeights& iweights,
TRT_ShapedWeights* oweights,
LambdaFactory unary_op) {
@@ -738,6 +930,7 @@ tensorflow::Status BinaryCompute(const TRT_ShapedWeights& iweights_l,
if (iweights_l.count() != iweights_r.count()) {
// We only supports broadcast of RankZero
if (iweights_l.count() == 1) {
+ // TODO(aaroey): Remove loggings like this.
VLOG(2) << "I bet it is not working!" << (*inp_l);
std::transform(inp_r, inp_r + iweights_r.count(), oup,
binary_op.broadcast_l<float>(*inp_l));
@@ -790,117 +983,21 @@ tensorflow::Status BinaryCompute(const TRT_ShapedWeights& iweights_l,
return tensorflow::Status::OK();
}
-tensorflow::Status ConstantFoldUnary(
- Converter& ctx, const tensorflow::NodeDef& node_def,
- const std::vector<TRT_TensorOrWeights>& inputs,
- std::vector<TRT_TensorOrWeights>* outputs) {
- TRT_ShapedWeights weights_input = inputs.at(0).weights();
-
- // Allocate output weights
- TRT_ShapedWeights weights_output = ctx.get_temp_weights_like(weights_input);
-
- // FIXME assume type matches input weights
- // Get trt type & shape
- // Maybe this part has to be moved into the block of rsqrt later
- // Check type consistency
- CHECK_EQ(weights_input.type_,
- TFAttrs(node_def).get<tensorflow::DataType>("T"));
-
- LambdaFactory unary_op;
- if (node_def.op() == "Rsqrt") {
- // Compute rsqrt
- unary_op.op = LambdaFactory::OP_CATEGORY::RSQRT;
- auto ret = UnaryCompute(weights_input, &weights_output, unary_op);
- // Pass the output
- if (ret == tensorflow::Status::OK()) {
- outputs->push_back(TRT_TensorOrWeights(weights_output));
- }
- return ret;
- } else {
- return tensorflow::errors::Unimplemented("Binary op not supported: " +
- node_def.op());
- }
-}
-
-// TODO(jie,ben) broadcast is needed yet not implemented
-// Let's get the simple stuff working first. Maybe we should fall back to TF
-// approach for constant folding
-tensorflow::Status ConstantFoldBinary(
- Converter& ctx, const tensorflow::NodeDef& node_def,
- const std::vector<TRT_TensorOrWeights>& inputs,
- std::vector<TRT_TensorOrWeights>* outputs) {
- TRT_ShapedWeights weights_input_l = inputs.at(0).weights();
- TRT_ShapedWeights weights_input_r = inputs.at(1).weights();
-
- // Check type consistency
- CHECK_EQ(weights_input_l.type_, weights_input_r.type_);
-
- if (weights_input_l.shape_.nbDims != weights_input_r.shape_.nbDims)
- return tensorflow::errors::Unimplemented(
- "Binary op implicit broadcast not supported: " + node_def.op());
-
- // TODO(jie): constant fold should really fall back to TF.
- int num_dims = weights_input_l.shape_.nbDims;
- nvinfer1::Dims output_shape;
- output_shape.nbDims = num_dims;
- VLOG(2) << "nb_dims: " << num_dims
- << ", the other: " << weights_input_r.shape_.nbDims;
- for (int i = 0; i < num_dims; i++) {
- if (weights_input_l.shape_.d[i] == weights_input_r.shape_.d[i]) {
- output_shape.d[i] = weights_input_l.shape_.d[i];
- } else if (weights_input_l.shape_.d[i] == 1 ||
- weights_input_r.shape_.d[i] == 1) {
- output_shape.d[i] =
- std::max(weights_input_l.shape_.d[i], weights_input_r.shape_.d[i]);
- } else {
- return tensorflow::errors::Unimplemented(
- "Binary op with incompatible shape at, " + node_def.op());
- }
- VLOG(2) << "left: " << weights_input_l.shape_.d[i]
- << "right: " << weights_input_r.shape_.d[i]
- << "output: " << output_shape.d[i];
- }
-
- // FIXME assume type matches input weights
- // Get trt type & shape
- TFAttrs attrs(node_def);
- // Maybe this part has to be moved into the block of rsqrt later
- tensorflow::DataType dtype = attrs.get<tensorflow::DataType>("T");
-
- // Allocate output weights
- TRT_ShapedWeights weights_output = ctx.get_temp_weights(dtype, output_shape);
-
- LambdaFactory binary_op;
- if (node_def.op() == "Sub") {
- binary_op.op = LambdaFactory::OP_CATEGORY::SUB;
- } else if (node_def.op() == "Mul") {
- binary_op.op = LambdaFactory::OP_CATEGORY::MUL;
- } else if (node_def.op() == "Add") {
- binary_op.op = LambdaFactory::OP_CATEGORY::ADD;
- } else {
- return tensorflow::errors::Unimplemented("Binary op not supported: " +
- node_def.op());
- }
- auto ret = BinaryCompute(weights_input_l, weights_input_r, &weights_output,
- binary_op);
-
- // Pass the output
- if (ret == tensorflow::Status::OK()) {
- outputs->push_back(TRT_TensorOrWeights(weights_output));
- }
-
- return ret;
-}
-
// TODO(jie): broadcast is needed yet not implemented.
// Only implemented channel wise for the time being
tensorflow::Status BinaryTensorOpWeight(
Converter& ctx, const tensorflow::NodeDef& node_def,
const nvinfer1::ITensor* tensor, TRT_ShapedWeights weights,
- std::vector<TRT_TensorOrWeights>* outputs) {
- // FIXME assume type matches input weights
- // Get trt type & shape
- // Maybe this part has to be moved into the block of rsqrt later
+ bool swapped_inputs, std::vector<TRT_TensorOrWeights>* outputs) {
+ // tensor is the left operand while weights is the right operand;
+ // when swapped_inputs is set to true, the two are swapped.
+ // TODO(aaroey): use a set.
+ if (node_def.op() != "Sub" && node_def.op() != "Add" &&
+ node_def.op() != "Mul" && node_def.op() != "Div" &&
+ node_def.op() != "RealDiv") {
+ return tensorflow::errors::Unimplemented(
+ "op not supported: " + node_def.op() + ", at: " + node_def.name());
+ }
// Check type consistency
nvinfer1::DataType ttype;
@@ -910,6 +1007,12 @@ tensorflow::Status BinaryTensorOpWeight(
auto dims_w = weights.shape_;
auto dims_t = tensor->getDimensions();
+ // TODO(jie): addScale checks for input tensor dimension
+ if (dims_t.nbDims != 3) {
+ return tensorflow::errors::InvalidArgument(
+ "addScale requires tensor with rank 3, " + node_def.name());
+ }
+
// default to element-wise
auto scale_mode = nvinfer1::ScaleMode::kELEMENTWISE;
@@ -980,6 +1083,7 @@ tensorflow::Status BinaryTensorOpWeight(
permutation[dims_t.nbDims] = 1;
tensor = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
permutation);
+ TFTRT_RETURN_ERROR_IF_NULLPTR(tensor, node_def.name());
} else {
return tensorflow::errors::InvalidArgument(
"Transpose cannot be applied, " + node_def.name());
@@ -997,11 +1101,35 @@ tensorflow::Status BinaryTensorOpWeight(
// Maybe I should do a switch
if (node_def.op() == "Sub") {
- TRT_ShapedWeights neg_weights = ctx.get_temp_weights_like(weights);
- LambdaFactory unary_op;
- unary_op.op = LambdaFactory::OP_CATEGORY::NEG;
- TF_RETURN_IF_ERROR(UnaryCompute(weights, &neg_weights, unary_op));
- shift_weights = neg_weights;
+ if (swapped_inputs) {
+ shift_weights = weights;
+ nvinfer1::IUnaryLayer* layer =
+ ctx.network()->addUnary(*const_cast<nvinfer1::ITensor*>(tensor),
+ nvinfer1::UnaryOperation::kNEG);
+ TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
+ tensor = layer->getOutput(0);
+ } else {
+ TRT_ShapedWeights neg_weights = ctx.get_temp_weights_like(weights);
+ LambdaFactory unary_op;
+ unary_op.op = LambdaFactory::OP_CATEGORY::NEG;
+ TF_RETURN_IF_ERROR(UnaryCompute(weights, &neg_weights, unary_op));
+ shift_weights = neg_weights;
+ }
+ } else if (node_def.op() == "Div" || node_def.op() == "RealDiv") {
+ if (swapped_inputs) {
+ scale_weights = weights;
+ nvinfer1::IUnaryLayer* layer =
+ ctx.network()->addUnary(*const_cast<nvinfer1::ITensor*>(tensor),
+ nvinfer1::UnaryOperation::kRECIP);
+ TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
+ tensor = layer->getOutput(0);
+ } else {
+ TRT_ShapedWeights recip_weights = ctx.get_temp_weights_like(weights);
+ LambdaFactory unary_op;
+ unary_op.op = LambdaFactory::OP_CATEGORY::RECIP;
+ TF_RETURN_IF_ERROR(UnaryCompute(weights, &recip_weights, unary_op));
+ scale_weights = recip_weights;
+ }
} else if (node_def.op() == "Mul") {
scale_weights = weights;
} else if (node_def.op() == "Add") {
@@ -1014,11 +1142,13 @@ tensorflow::Status BinaryTensorOpWeight(
nvinfer1::IScaleLayer* layer = ctx.network()->addScale(
*const_cast<nvinfer1::ITensor*>(tensor), scale_mode, shift_weights,
scale_weights, power_weights);
+ TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
nvinfer1::ITensor* output_tensor = layer->getOutput(0);
// transpose back dimension
if (permutation_flag) {
output_tensor = ctx.TransposeTensor(output_tensor, permutation);
+ TFTRT_RETURN_ERROR_IF_NULLPTR(output_tensor, node_def.name());
}
// Pass the output
@@ -1042,20 +1172,31 @@ tensorflow::Status ConvertConv2DHelper(
if (data_format == "NHWC") {
tensor = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
{0, 3, 1, 2});
+ TFTRT_RETURN_ERROR_IF_NULLPTR(tensor, node_def.name());
h_index = 1;
w_index = 2;
// TODO(jie): transpose it
}
// tensor after transpose (NCHW)
- auto tensor_dim = tensor->getDimensions();
+ const auto tensor_dim = tensor->getDimensions();
int num_groups = group;
- if (num_groups == 0) // depthwise convolution
- num_groups = tensor_dim.d[0];
+ if (num_groups == 0) num_groups = tensor_dim.d[0]; // depthwise convolution
VLOG(2) << "groups count: " << num_groups;
TRT_ShapedWeights weights_rsck = inputs.at(1).weights();
+
+ VLOG(2) << "weight shape: " << weights_rsck.shape_.nbDims;
+ for (int i = 0; i < weights_rsck.shape_.nbDims; i++) {
+ VLOG(2) << weights_rsck.shape_.d[i];
+ }
+
+ if (weights_rsck.shape_.nbDims != 4) {
+ return tensorflow::errors::Internal(
+ "Conv2D expects kernel of dimension 4, at: " + node_def.name());
+ }
+
if (ctx.isFP16()) {
weights_rsck = ConvertFP32ToFP16(ctx, inputs.at(1).weights());
}
@@ -1063,18 +1204,22 @@ tensorflow::Status ConvertConv2DHelper(
TRT_ShapedWeights weights = ctx.get_temp_weights_like(weights_rsck);
ReorderRSCKToKCRS(weights_rsck, &weights, num_groups);
TRT_ShapedWeights biases(weights.type_);
- int noutput = weights.shape_.d[0] * num_groups;
+ const int noutput = weights.shape_.d[0] * num_groups;
nvinfer1::DimsHW kernel_size;
kernel_size.h() = weights.shape_.d[2];
kernel_size.w() = weights.shape_.d[3];
+ VLOG(2) << "RSCK: ";
+ for (int i = 0; i < 4; i++) {
+ VLOG(2) << " " << weights.shape_.d[i];
+ }
VLOG(2) << "kernel size: " << kernel_size.h() << ", " << kernel_size.w();
// TODO(jie): stride. (NHWC/NCHW)
- auto tf_stride = attrs.get<std::vector<int>>("strides");
+ const auto tf_stride = attrs.get<std::vector<int>>("strides");
VLOG(2) << "h_INDEX" << h_index << ", w_index " << w_index;
VLOG(2) << "stride!!!: " << tf_stride[0] << tf_stride[1] << tf_stride[2]
<< tf_stride[3];
- nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]);
+ const nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]);
std::vector<std::pair<int, int>> padding;
// TODO(jie): padding.
@@ -1102,6 +1247,7 @@ tensorflow::Status ConvertConv2DHelper(
*const_cast<nvinfer1::ITensor*>(tensor),
nvinfer1::DimsHW(padding[0].first, padding[1].first),
nvinfer1::DimsHW(padding[0].second, padding[1].second));
+ TFTRT_RETURN_ERROR_IF_NULLPTR(pad_layer, node_def.name());
padding = {{0, 0}, {0, 0}};
tensor = pad_layer->getOutput(0);
auto dim_after = tensor->getDimensions();
@@ -1112,6 +1258,7 @@ tensorflow::Status ConvertConv2DHelper(
nvinfer1::IConvolutionLayer* layer =
ctx.network()->addConvolution(*const_cast<nvinfer1::ITensor*>(tensor),
noutput, kernel_size, weights, biases);
+ TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
layer->setStride(stride);
layer->setPadding({padding[0].first, padding[1].first});
@@ -1126,6 +1273,7 @@ tensorflow::Status ConvertConv2DHelper(
if (data_format == "NHWC") {
// TODO(jie): transpose it back!
output_tensor = ctx.TransposeTensor(output_tensor, {0, 2, 3, 1});
+ TFTRT_RETURN_ERROR_IF_NULLPTR(output_tensor, node_def.name());
} else {
VLOG(2) << "NCHW !!!!";
}
@@ -1147,35 +1295,91 @@ tensorflow::Status ConvertConv2DHelper(
node_def.name());
}
+// Helper function converts input into tensor with shape specified by dims.
+bool PrepareTensorForShape(Converter& ctx, const TRT_TensorOrWeights& input,
+ const nvinfer1::Dims& dims,
+ const nvinfer1::ITensor** tensor) {
+ if (input.is_tensor()) {
+ if (DimsEqual(input.shape(), dims)) {
+ *tensor = input.tensor();
+ } else {
+ nvinfer1::IShuffleLayer* layer = ctx.network()->addShuffle(
+ *const_cast<nvinfer1::ITensor*>(input.tensor()));
+ if (layer != nullptr) {
+ layer->setReshapeDimensions(dims);
+ *tensor = layer->getOutput(0);
+ } else {
+ return false;
+ }
+ }
+ } else {
+#if NV_TENSORRT_MAJOR > 3
+ nvinfer1::IConstantLayer* layer =
+ ctx.network()->addConstant(dims, input.weights());
+ if (layer != nullptr) {
+ *tensor = layer->getOutput(0);
+ } else {
+ return false;
+ }
+#else
+ return false;
+#endif
+ }
+ return true;
+}
+
tensorflow::Status BinaryTensorOpTensor(
Converter& ctx, const tensorflow::NodeDef& node_def,
- const nvinfer1::ITensor* tensor_l, const nvinfer1::ITensor* tensor_r,
+ const TRT_TensorOrWeights& operand_l, const TRT_TensorOrWeights& operand_r,
std::vector<TRT_TensorOrWeights>* outputs) {
static const std::unordered_map<string, nvinfer1::ElementWiseOperation> ops{
{"Add", nvinfer1::ElementWiseOperation::kSUM},
{"Mul", nvinfer1::ElementWiseOperation::kPROD},
{"Sub", nvinfer1::ElementWiseOperation::kSUB},
{"Div", nvinfer1::ElementWiseOperation::kDIV},
+ {"RealDiv", nvinfer1::ElementWiseOperation::kDIV},
+ {"Minimum", nvinfer1::ElementWiseOperation::kMIN},
+ {"Maximum", nvinfer1::ElementWiseOperation::kMAX},
};
- // FIXME assume type matches input weights
+ const nvinfer1::ITensor* tensor_l;
+ const nvinfer1::ITensor* tensor_r;
+
+ nvinfer1::Dims dim_l;
+ nvinfer1::Dims dim_r;
+
+ if (!TensorRTGetBroadcastShape(operand_l.shape(), operand_l.is_tensor(),
+ operand_r.shape(), operand_r.is_tensor(),
+ &dim_l, &dim_r)) {
+ return tensorflow::errors::InvalidArgument(
+ "Binary op broadcast scheme not supported by TensorRT op: " +
+ node_def.op() + ", at: " + node_def.name());
+ }
+
+ TFTRT_RETURN_ERROR_IF_FALSE(
+ PrepareTensorForShape(ctx, operand_l, dim_l, &tensor_l), node_def.name());
+ TFTRT_RETURN_ERROR_IF_FALSE(
+ PrepareTensorForShape(ctx, operand_r, dim_r, &tensor_r), node_def.name());
+
// get trt type & shape
TFAttrs attrs(node_def);
// maybe this part has to be moved into the block of rsqrt later
nvinfer1::DataType dtype = attrs.get<nvinfer1::DataType>("T");
// check type consistency
- CHECK_EQ_TYPE(tensor_l->getType(), dtype);
- CHECK_EQ_TYPE(tensor_r->getType(), dtype);
+ TFTRT_CHECK_EQ_TYPE(tensor_l->getType(), dtype);
+ TFTRT_CHECK_EQ_TYPE(tensor_r->getType(), dtype);
auto op_pair = ops.find(node_def.op());
- if (op_pair == ops.end())
+ if (op_pair == ops.end()) {
return tensorflow::errors::Unimplemented(
- "binary op: " + node_def.op() +
- " not supported at: " + node_def.name());
+ "binary op: ", node_def.op(), " not supported at: ", node_def.name());
+ }
nvinfer1::IElementWiseLayer* layer = ctx.network()->addElementWise(
+ // TODO(aaroey): will tensor_l/tensor_r get modified?
*const_cast<nvinfer1::ITensor*>(tensor_l),
*const_cast<nvinfer1::ITensor*>(tensor_r), op_pair->second);
+ TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
nvinfer1::ITensor* output_tensor = layer->getOutput(0);
@@ -1202,7 +1406,7 @@ tensorflow::Status ConvertPlugin(Converter& ctx,
// passing attributes
// TODO(jie): support more general attribute
TFAttrs attrs(node_def);
- auto attr_key_vector = attrs.GetAllAttrKey();
+ auto attr_key_vector = attrs.GetAllAttrKeys();
for (auto attr_key : attr_key_vector) {
// TODO(jie): support only list of float for toy example here.
auto data = attrs.get<std::vector<float>>(attr_key);
@@ -1223,29 +1427,6 @@ tensorflow::Status ConvertPlugin(Converter& ctx,
return tensorflow::Status::OK();
}
-tensorflow::Status ConvertPlaceholder(
- Converter& ctx, const tensorflow::NodeDef& node_def,
- const std::vector<TRT_TensorOrWeights>& inputs,
- std::vector<TRT_TensorOrWeights>* outputs) {
- VLOG(2) << "Placeholder should have been replace already";
- return tensorflow::errors::Unimplemented("cannot convert Placeholder op");
- // OK this make sense since we are supposed to replace it with input
- TFAttrs attrs(node_def);
- nvinfer1::DataType dtype = attrs.get<nvinfer1::DataType>("dtype");
- nvinfer1::Dims dims = attrs.get<nvinfer1::Dims>("shape");
-
- dims.nbDims--;
- for (int i = 0; i < dims.nbDims; i++) dims.d[i] = dims.d[i + 1];
-
- nvinfer1::ITensor* output =
- ctx.network()->addInput(node_def.name().c_str(), dtype, dims);
- if (!output) {
- return tensorflow::errors::InvalidArgument("Failed to create Input layer");
- }
- outputs->push_back(TRT_TensorOrWeights(output));
- return tensorflow::Status::OK();
-}
-
tensorflow::Status ConvertConv2D(Converter& ctx,
const tensorflow::NodeDef& node_def,
const std::vector<TRT_TensorOrWeights>& inputs,
@@ -1271,65 +1452,64 @@ tensorflow::Status ConvertPool(Converter& ctx,
int h_index = 2;
int w_index = 3;
- auto data_format = attrs.get<string>("data_format");
+ const auto data_format = attrs.get<string>("data_format");
if (data_format == "NHWC") {
h_index = 1;
w_index = 2;
tensor = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
{0, 3, 1, 2});
- } else {
- VLOG(2) << "NCHW !!!!";
+ TFTRT_RETURN_ERROR_IF_NULLPTR(tensor, node_def.name());
}
+
nvinfer1::PoolingType type;
- // TODO(jie): support other pooling type
- if (node_def.op() == "MaxPool")
+ if (node_def.op() == "MaxPool") {
type = nvinfer1::PoolingType::kMAX;
- else if (node_def.op() == "AvgPool")
+ } else if (node_def.op() == "AvgPool") {
type = nvinfer1::PoolingType::kAVERAGE;
- else
- return tensorflow::errors::Unimplemented("Only supports Max pool");
+ } else {
+ return tensorflow::errors::Unimplemented("Unsupported pool type: ",
+ node_def.op());
+ }
- // TODO(jie): NCHW
- auto tf_stride = attrs.get<std::vector<int>>("strides");
- nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]);
+ const auto tf_stride = attrs.get<std::vector<int>>("strides");
+ const nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]);
- auto tf_kernel = attrs.get<std::vector<int>>("ksize");
- nvinfer1::DimsHW ksize(tf_kernel[h_index], tf_kernel[w_index]);
+ const auto tf_kernel = attrs.get<std::vector<int>>("ksize");
+ const nvinfer1::DimsHW ksize(tf_kernel[h_index], tf_kernel[w_index]);
auto tensor_dim = tensor->getDimensions();
std::vector<std::pair<int, int>> padding;
- // TODO(jie): padding.
- if (attrs.get<string>("padding") == "SAME") {
+ const string padding_type = attrs.get<string>("padding");
+ if (padding_type == "SAME") {
// This is NCHW tensor with no batch dimension.
// 1 -> h
// 2 -> w
padding = CreateSamePadding(
stride, ksize,
{static_cast<int>(tensor_dim.d[1]), static_cast<int>(tensor_dim.d[2])});
- } else if (attrs.get<string>("padding") == "VALID") {
- // No padding for valid padding here
- VLOG(2) << "No padding added for VALID padding in pool" << node_def.name();
+ } else if (padding_type == "VALID") {
padding = {{0, 0}, {0, 0}};
} else {
- return tensorflow::errors::Unimplemented(
- "Current MaxPool cannot support padding other than SAME");
+ return tensorflow::errors::Unimplemented("Unsupported padding type: ",
+ padding_type);
}
if (padding[0].first != padding[0].second ||
padding[1].first != padding[1].second) {
- // TODO(jie): handle asymmetric padding
VLOG(2) << "Padding!!!: " << padding[0].first << padding[0].second
<< padding[1].first << padding[1].second;
auto pad_layer = ctx.network()->addPadding(
*const_cast<nvinfer1::ITensor*>(tensor),
nvinfer1::DimsHW(padding[0].first, padding[1].first),
nvinfer1::DimsHW(padding[0].second, padding[1].second));
+ TFTRT_RETURN_ERROR_IF_NULLPTR(pad_layer, node_def.name());
padding = {{0, 0}, {0, 0}};
tensor = pad_layer->getOutput(0);
}
nvinfer1::IPoolingLayer* layer = ctx.network()->addPooling(
*const_cast<nvinfer1::ITensor*>(tensor), type, ksize);
+ TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
layer->setStride(stride);
layer->setPadding({padding[0].first, padding[1].first});
@@ -1337,10 +1517,8 @@ tensorflow::Status ConvertPool(Converter& ctx,
nvinfer1::ITensor* output_tensor = layer->getOutput(0);
if (data_format == "NHWC") {
- // TODO(jie): transpose it back!
output_tensor = ctx.TransposeTensor(output_tensor, {0, 2, 3, 1});
- } else {
- VLOG(2) << "NCHW !!!!";
+ TFTRT_RETURN_ERROR_IF_NULLPTR(output_tensor, node_def.name());
}
outputs->push_back(TRT_TensorOrWeights(output_tensor));
return tensorflow::Status::OK();
@@ -1353,6 +1531,7 @@ tensorflow::Status ConvertActivation(
const nvinfer1::ITensor* tensor = inputs.at(0).tensor();
nvinfer1::IActivationLayer* layer = ctx.network()->addActivation(
*const_cast<nvinfer1::ITensor*>(tensor), nvinfer1::ActivationType::kRELU);
+ TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
nvinfer1::ITensor* output_tensor = layer->getOutput(0);
outputs->push_back(TRT_TensorOrWeights(output_tensor));
return tensorflow::Status::OK();
@@ -1363,40 +1542,61 @@ tensorflow::Status ConvertScale(Converter& ctx,
const std::vector<TRT_TensorOrWeights>& inputs,
std::vector<TRT_TensorOrWeights>* outputs) {
if (inputs.size() != 2 || !inputs.at(0).is_tensor() ||
- !inputs.at(1).is_weights())
+ !inputs.at(1).is_weights()) {
return tensorflow::errors::Unimplemented(
- "Only supports tensor op weight for now, at " + node_def.name());
- // Implement tensor binaryOp weight [channel wise] for now;
- const nvinfer1::ITensor* tensor = inputs.at(0).tensor();
+ "ConvertScale only supports tensor<op>weight: ", node_def.name());
+ }
+ const nvinfer1::ITensor* tensor = inputs.at(0).tensor();
TRT_ShapedWeights weights = inputs.at(1).weights();
if (ctx.isFP16()) {
weights = ConvertFP32ToFP16(ctx, inputs.at(1).weights());
}
TRT_ShapedWeights empty_weights(weights.type_);
-
TFAttrs attrs(node_def);
- // Transpose NHWC
- auto data_format = attrs.get<string>("data_format");
+ const auto data_format = attrs.get<string>("data_format");
+ int channel_index;
+ const auto dims = tensor->getDimensions();
if (data_format == "NHWC") {
- tensor = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
- {0, 3, 1, 2});
- // TODO(jie): transpose it
+ // 1). NHWC is really N+C
+ channel_index = dims.nbDims - 1; // batch dimension is implicit here!
} else {
- VLOG(2) << "NCHW !!!!";
+ // 2). NCHW is really N+CHW
+ channel_index = dims.nbDims - 3; // batch dimension is implicit here!
}
- auto dims = tensor->getDimensions();
- VLOG(2) << "tensor dimensions: " << dims.nbDims;
- for (int i = 0; i < dims.nbDims; i++) {
- VLOG(2) << "i: " << dims.d[i];
+ nvinfer1::Permutation permutation;
+ for (int32_t i = 0; i < dims.nbDims; ++i) {
+ permutation.order[i] = i;
}
- dims = weights.shape_;
- VLOG(2) << "tensor dimensions: " << dims.nbDims;
- for (int i = 0; i < dims.nbDims; i++) {
- VLOG(2) << "i: " << dims.d[i];
+
+ if (channel_index >= 0) {
+ permutation.order[0] = channel_index;
+ permutation.order[channel_index] = 0;
+ } else {
+ return tensorflow::errors::Unimplemented(
+ "TFTRT::BiasAdd cannot apply on batch dimension, at ", node_def.name());
+ }
+
+ // TensorRT addScale requires input to be of rank 3, we need to apply
+ // transpose as well as reshape
+ if (channel_index != 0 || dims.nbDims != 3) {
+ nvinfer1::IShuffleLayer* shuffle_layer =
+ ctx.network()->addShuffle(*const_cast<nvinfer1::ITensor*>(tensor));
+ TFTRT_RETURN_ERROR_IF_NULLPTR(shuffle_layer, node_def.name());
+ nvinfer1::Dims reshape_dims;
+ reshape_dims.nbDims = 3;
+ reshape_dims.d[0] = 0; // 0 copy from the input
+ reshape_dims.d[1] = dims.nbDims >= 2 ? 0 : 1; // 0 copy from the input
+ reshape_dims.d[2] = dims.nbDims >= 3 ? -1 : 1; // -1 infer from the rest
+ if (channel_index != 0) {
+ // maybe we do not need this check. concerned about TRT optimization
+ shuffle_layer->setFirstTranspose(permutation);
+ }
+ shuffle_layer->setReshapeDimensions(reshape_dims);
+ tensor = shuffle_layer->getOutput(0);
}
nvinfer1::ScaleMode mode = nvinfer1::ScaleMode::kCHANNEL;
@@ -1407,14 +1607,26 @@ tensorflow::Status ConvertScale(Converter& ctx,
nvinfer1::IScaleLayer* layer =
ctx.network()->addScale(*const_cast<nvinfer1::ITensor*>(tensor), mode,
weights, empty_weights, empty_weights);
+ TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
nvinfer1::ITensor* output_tensor = layer->getOutput(0);
- if (data_format == "NHWC") {
- // TODO(jie): transpose it back!
- output_tensor = ctx.TransposeTensor(output_tensor, {0, 2, 3, 1});
- } else {
- VLOG(2) << "NCHW !!!!";
+
+ // restore transpose & reshape
+ if (channel_index != 0 || dims.nbDims != 3) {
+ nvinfer1::IShuffleLayer* shuffle_layer = ctx.network()->addShuffle(
+ *const_cast<nvinfer1::ITensor*>(output_tensor));
+ TFTRT_RETURN_ERROR_IF_NULLPTR(shuffle_layer, node_def.name());
+ nvinfer1::Dims reshape_dims = dims;
+ int tmp = reshape_dims.d[channel_index];
+ reshape_dims.d[channel_index] = reshape_dims.d[0];
+ reshape_dims.d[0] = tmp;
+ shuffle_layer->setReshapeDimensions(reshape_dims);
+ if (channel_index != 0) {
+ shuffle_layer->setSecondTranspose(permutation);
+ }
+ output_tensor = shuffle_layer->getOutput(0);
}
+
outputs->push_back(TRT_TensorOrWeights(output_tensor));
return tensorflow::Status::OK();
}
@@ -1431,11 +1643,13 @@ tensorflow::Status ConvertConst(Converter& ctx,
// Create shaped weights as output
tensorflow::Tensor tensor;
- if (!tensor.FromProto(weights_tensor))
- return tensorflow::errors::Internal("Cannot parse weight tensor proto: " +
+ if (!tensor.FromProto(weights_tensor)) {
+ return tensorflow::errors::Internal("Cannot parse weight tensor proto: ",
node_def.name());
+ }
TRT_ShapedWeights weights(dtype);
+ // TODO(aaroey): we should choose the array using dtype and shape.
if (!weights_tensor.float_val().empty()) {
VLOG(2) << "SCALAR!!!" << node_def.name();
nvinfer1::Dims scalar_shape;
@@ -1443,22 +1657,16 @@ tensorflow::Status ConvertConst(Converter& ctx,
VLOG(2) << "dimensions: " << tensor.dims();
VLOG(2) << "size: " << weights_tensor.float_val_size();
scalar_shape = GetTensorShape(tensor);
+ VLOG(2) << "details: ";
for (int i = 0; i < scalar_shape.nbDims; i++)
VLOG(2) << scalar_shape.d[i];
- if (GetShapeSize(scalar_shape) != weights_tensor.float_val_size()) {
- if (weights_tensor.float_val_size() == 1 ||
- scalar_shape.d[0] == weights_tensor.float_val_size()) {
- scalar_shape.nbDims = 1;
- // no dimension provided. flatten it
- scalar_shape.d[0] = weights_tensor.float_val_size();
- scalar_shape.type[0] = nvinfer1::DimensionType::kSPATIAL;
- } else {
- LOG(WARNING) << "Broadcast on weights only supports kCHANNEL and"
- << " kUNIFORM, at: " << node_def.name();
- string err_str("Broadcast method is not supported for '");
- StrAppend(&err_str, node_def.name(), "' of type ", node_def.op());
- return tensorflow::errors::InvalidArgument(err_str);
- }
+ if (GetShapeSize(scalar_shape) != weights_tensor.float_val_size() &&
+ weights_tensor.float_val_size() != 1) {
+ LOG(ERROR) << "Broadcast on weights only supports kCHANNEL and"
+ << " kUNIFORM, at: " << node_def.name();
+ string err_str("Broadcast method is not supported for '");
+ StrAppend(&err_str, node_def.name(), "' of type ", node_def.op());
+ return tensorflow::errors::InvalidArgument(err_str);
}
} else {
VLOG(2) << "Dimensions: " << tensor.dims();
@@ -1468,39 +1676,42 @@ tensorflow::Status ConvertConst(Converter& ctx,
scalar_shape.type[0] = nvinfer1::DimensionType::kSPATIAL;
for (int i = 1; i < nvinfer1::Dims::MAX_DIMS; i++) {
scalar_shape.d[i] = 0;
- scalar_shape.type[i] = nvinfer1::DimensionType::kSPATIAL;
}
}
+ // TODO(aaroey): use GetShapeSize().
size_t len_data = tensorflow::DataTypeSize(dtype);
for (int i = 0; i < scalar_shape.nbDims; i++) len_data *= scalar_shape.d[i];
ctx.weight_store()->store_.push_back(std::vector<uint8_t>(len_data));
void* dst = static_cast<void*>(&(ctx.weight_store()->store_.back()[0]));
- std::vector<float> tensor_data(
- weights_tensor.float_val().begin(),
- weights_tensor.float_val()
- .end()); // make a local copy first to flatten
- memcpy(dst, tensor_data.data(), len_data); // store into weight store
+ if (weights_tensor.float_val_size() == 1) {
+ std::fill_n((float*)dst, GetShapeSize(scalar_shape),
+ *weights_tensor.float_val().begin());
+ } else {
+ // TODO(aaroey): get rid of this copy, as RepeatedField is always contiguous.
+ // Make a local copy first to flatten; the source doesn't have to be
+ // contiguous.
+ std::vector<float> tensor_data(weights_tensor.float_val().begin(),
+ weights_tensor.float_val().end());
+ memcpy(dst, tensor_data.data(), len_data); // store into weight store
+ }
+ VLOG(2) << "create shape details: ";
+ for (int i = 0; i < scalar_shape.nbDims; i++) VLOG(2) << scalar_shape.d[i];
weights = TRT_ShapedWeights(dtype, dst, scalar_shape);
} else if (!weights_tensor.int_val().empty()) {
+ // TODO(aaroey): this is very similar to the above code for float, merge
+ // them.
VLOG(2) << "int!!!" << node_def.name();
nvinfer1::Dims scalar_shape;
if (tensor.dims() > 0) {
VLOG(2) << "dimensions: " << tensor.dims();
scalar_shape = GetTensorShape(tensor);
- if (GetShapeSize(scalar_shape) != weights_tensor.int_val_size()) {
- if (weights_tensor.int_val_size() == 1 ||
- scalar_shape.d[0] == weights_tensor.int_val_size()) {
- scalar_shape.nbDims = 1;
- // no dimension provided. flatten it
- scalar_shape.d[0] = weights_tensor.int_val_size();
- scalar_shape.type[0] = nvinfer1::DimensionType::kSPATIAL;
- } else {
- LOG(WARNING) << "Broadcast on weights only supports kCHANNEL and"
- << " kUNIFORM, at: " << node_def.name();
- string err_str("Broadcast method is not supported for '");
- StrAppend(&err_str, node_def.name(), "' of type ", node_def.op());
- return tensorflow::errors::InvalidArgument(err_str);
- }
+ if (GetShapeSize(scalar_shape) != weights_tensor.int_val_size() &&
+ weights_tensor.int_val_size() != 1) {
+ LOG(WARNING) << "Broadcast on weights only supports kCHANNEL and"
+ << " kUNIFORM, at: " << node_def.name();
+ string err_str("Broadcast method is not supported for '");
+ StrAppend(&err_str, node_def.name(), "' of type ", node_def.op());
+ return tensorflow::errors::InvalidArgument(err_str);
}
} else {
VLOG(2) << "dimensions: " << tensor.dims();
@@ -1513,23 +1724,30 @@ tensorflow::Status ConvertConst(Converter& ctx,
scalar_shape.type[i] = nvinfer1::DimensionType::kSPATIAL;
}
}
- // we should not have converted //if (ctx.isFP16()) {
+ // we should not have converted
size_t len_data = tensorflow::DataTypeSize(dtype);
for (int i = 0; i < scalar_shape.nbDims; i++) len_data *= scalar_shape.d[i];
size_t len_tensor = weights_tensor.int_val_size() * sizeof(int32);
len_data = std::max(len_data, len_tensor);
ctx.weight_store()->store_.push_back(std::vector<uint8_t>(len_data));
void* dst = static_cast<void*>(&(ctx.weight_store()->store_.back()[0]));
- std::vector<int32> tensor_data(
- weights_tensor.int_val().begin(),
- weights_tensor.int_val().end()); // make a local copy first to flatten
- // doesn't have to be contigous
- memcpy(dst, tensor_data.data(), len_tensor); // store into weight store
+ if (weights_tensor.int_val_size() == 1) {
+ std::fill_n((int*)dst, GetShapeSize(scalar_shape),
+ *weights_tensor.int_val().begin());
+ } else {
+ // TODO(aaroey): get rid of this copy, as RepeatedField is always contiguous.
+ // Make a local copy first to flatten; the source doesn't have to be
+ // contiguous.
+ std::vector<int32> tensor_data(weights_tensor.int_val().begin(),
+ weights_tensor.int_val().end());
+ memcpy(dst, tensor_data.data(), len_tensor); // store into weight store
+ }
weights = TRT_ShapedWeights(dtype, dst, scalar_shape);
} else if (!weights_tensor.tensor_content().empty()) {
- // obsolete method.
- // After optimization path, we do not see weights in this format.
- // fp16 conversion technically should be needed here.
+ // obsolete method.
+ // After optimization path, we do not see weights in this format.
+ // TODO(aaroey): why?
+ // fp16 conversion technically should be needed here.
VLOG(2) << "TENSOR!!!" << node_def.name();
const auto& content = weights_tensor.tensor_content();
@@ -1543,8 +1761,8 @@ tensorflow::Status ConvertConst(Converter& ctx,
content, static_cast<char*>(const_cast<void*>(weights.GetValues())));
}
} else {
- return tensorflow::errors::Unimplemented(
- "Not supported constant type, at " + node_def.name());
+ return tensorflow::errors::Unimplemented("Not supported constant type, at ",
+ node_def.name());
}
// Pass the output
outputs->push_back(TRT_TensorOrWeights(weights));
@@ -1563,96 +1781,144 @@ tensorflow::Status ConvertBinary(Converter& ctx,
const tensorflow::NodeDef& node_def,
const std::vector<TRT_TensorOrWeights>& inputs,
std::vector<TRT_TensorOrWeights>* outputs) {
- if (inputs.size() != 2)
+ if (inputs.size() != 2) {
return tensorflow::errors::FailedPrecondition(
- "Binary ops require two tensor input, at " + node_def.name());
-
- if (inputs.at(0).is_weights() && inputs.at(1).is_weights())
- return ConstantFoldBinary(ctx, node_def, inputs, outputs);
-
- if (inputs.at(0).is_tensor() && inputs.at(1).is_weights())
- return BinaryTensorOpWeight(ctx, node_def, inputs.at(0).tensor(),
- inputs.at(1).weights(), outputs);
+ "Binary ops require two tensor input, at ", node_def.name());
+ }
- if (inputs.at(0).is_weights() && inputs.at(1).is_tensor())
- return BinaryTensorOpWeight(ctx, node_def, inputs.at(1).tensor(),
- inputs.at(0).weights(), outputs);
+ // Constant folding should have been done by TensorFlow
- if (inputs.at(0).is_tensor() && inputs.at(1).is_tensor())
- return BinaryTensorOpTensor(ctx, node_def, inputs.at(0).tensor(),
- inputs.at(1).tensor(), outputs);
+ if (inputs.at(0).is_weights() && inputs.at(1).is_weights()) {
+ return tensorflow::errors::Unimplemented(
+ "Constant folding is falled back to TensorFlow, binary op received "
+ "both input as constant at: ",
+ node_def.name());
+ }
- return tensorflow::errors::Unknown("Binary op input error, at " +
- node_def.name());
+ // Try to convert into Scale layer first (for better performance)
+ // Since scale layer supports restricted broadcast policy and op types, we
+ // allow failure and try to handle it through Elementwise op
+ // (BinaryTensorOpTensor)
+ Status status = tensorflow::Status::OK();
+ if (inputs.at(0).is_tensor() && inputs.at(1).is_weights()) {
+ status = BinaryTensorOpWeight(ctx, node_def, inputs.at(0).tensor(),
+ inputs.at(1).weights(), false, outputs);
+ } else if (inputs.at(0).is_weights() && inputs.at(1).is_tensor()) {
+ status = BinaryTensorOpWeight(ctx, node_def, inputs.at(1).tensor(),
+ inputs.at(0).weights(), true, outputs);
+#if NV_TENSORRT_MAJOR == 3
+ } else {
+#else
+ }
+ if ((inputs.at(0).is_tensor() && inputs.at(1).is_tensor()) || !status.ok()) {
+#endif
+ status = BinaryTensorOpTensor(ctx, node_def, inputs.at(0), inputs.at(1),
+ outputs);
+ }
+ return status;
}
tensorflow::Status ConvertUnary(Converter& ctx,
const tensorflow::NodeDef& node_def,
const std::vector<TRT_TensorOrWeights>& inputs,
std::vector<TRT_TensorOrWeights>* outputs) {
- if (inputs.size() != 1)
+ static const std::unordered_map<string, nvinfer1::UnaryOperation> ops{
+ {"Neg", nvinfer1::UnaryOperation::kNEG},
+ {"Exp", nvinfer1::UnaryOperation::kEXP},
+ {"Log", nvinfer1::UnaryOperation::kLOG},
+ {"Sqrt", nvinfer1::UnaryOperation::kSQRT},
+ {"Abs", nvinfer1::UnaryOperation::kABS},
+ {"Reciprocal", nvinfer1::UnaryOperation::kRECIP},
+ };
+
+ if (inputs.size() != 1) {
return tensorflow::errors::FailedPrecondition(
- "Unary ops require single tensor input, at " + node_def.name());
+ "Unary ops require single tensor input, at ", node_def.name());
+ }
- if (inputs.at(0).is_weights())
- return ConstantFoldUnary(ctx, node_def, inputs, outputs);
- else if (inputs.at(0).is_tensor())
+#if NV_TENSORRT_MAJOR == 3
+ if (inputs.at(0).is_weights()) {
return tensorflow::errors::Unimplemented(
- "Unary op for tensor not supported, at " + node_def.name());
+ "Constant folding for unary op is not supported", node_def.name());
+ }
+#endif
- return tensorflow::errors::Unknown("Binary op input error, at " +
- node_def.name());
+ // TODO(jie): check type
+ const nvinfer1::ITensor* tensor;
+ TFTRT_RETURN_ERROR_IF_FALSE(
+ PrepareTensorForShape(ctx, inputs.at(0), inputs.at(0).shape(), &tensor),
+ node_def.name());
+
+ nvinfer1::IUnaryLayer* layer;
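+  // Rsqrt is lowered as two unary layers: kSQRT followed by kRECIP.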
+ if (node_def.op() == "Rsqrt") {
+ layer = ctx.network()->addUnary(*const_cast<nvinfer1::ITensor*>(tensor),
+ nvinfer1::UnaryOperation::kSQRT);
+ TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
+ tensor = layer->getOutput(0);
+ layer = ctx.network()->addUnary(*const_cast<nvinfer1::ITensor*>(tensor),
+ nvinfer1::UnaryOperation::kRECIP);
+ } else if (ops.count(node_def.op()) != 0) {
+ layer = ctx.network()->addUnary(*const_cast<nvinfer1::ITensor*>(tensor),
+ ops.at(node_def.op()));
+ } else {
+ return tensorflow::errors::InvalidArgument(
+ "Binary op: ", node_def.op(), " not supported, at ", node_def.name());
+ }
+
+ TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
+ nvinfer1::ITensor* output_tensor = layer->getOutput(0);
+ outputs->push_back(TRT_TensorOrWeights(output_tensor));
+ return tensorflow::Status::OK();
}
-tensorflow::Status ConvertReduce(Converter& ctx,
- const tensorflow::NodeDef& node_def,
- const std::vector<TRT_TensorOrWeights>& inputs,
- std::vector<TRT_TensorOrWeights>* outputs) {
+#if NV_TENSORRT_MAJOR == 3
+tensorflow::Status ConvertReducePool(
+ Converter& ctx, const tensorflow::NodeDef& node_def,
+ const std::vector<TRT_TensorOrWeights>& inputs,
+ std::vector<TRT_TensorOrWeights>* outputs) {
if (inputs.size() != 2 || !inputs.at(0).is_tensor() ||
- !inputs.at(1).is_weights())
+ !inputs.at(1).is_weights()) {
return tensorflow::errors::InvalidArgument(
- "Input expects tensor and weights, at" + node_def.name());
+ "Input expects tensor and weights, at", node_def.name());
+ }
// Implement tensor binaryOp weight [channel wise] for now;
const nvinfer1::ITensor* tensor = inputs.at(0).tensor();
- auto dims = tensor->getDimensions();
+ const auto dims = tensor->getDimensions();
// Restore implicit batch dimension
- int nb_dims = dims.nbDims + 1;
+ const int nb_dims = dims.nbDims + 1;
TRT_ShapedWeights index_list = inputs.at(1).weights();
-
TFAttrs attrs(node_def);
- // TODO(jie): handle data type.
- // Index type here is done through TF type, so I can leverage their
- // EnumToDataType for my cast
auto index_type = attrs.get<tensorflow::DataType>("Tidx");
// Only expect to handle INT32 as attributes for now
- if (index_type != tensorflow::DataType::DT_INT32)
+ if (index_type != tensorflow::DataType::DT_INT32) {
return tensorflow::errors::Unimplemented("Tidx supports only DT_INT32");
- auto index_list_data =
+ }
+ const auto index_list_data =
static_cast<int*>(const_cast<void*>(index_list.GetValues()));
- // Hack warning: have to fall back to pool layer since reduce is not in public
- // TRT yet.
- if (nb_dims != 4)
+ if (nb_dims != 4) {
return tensorflow::errors::InvalidArgument(
- "TRT only support reduce on 4 dimensional tensors, at" +
+ "TRT only support reduce on 4 dimensional tensors, at",
node_def.name());
- if (index_list.count() > 2)
+ }
+ if (index_list.count() > 2) {
return tensorflow::errors::InvalidArgument(
- "TRT cannot support reduce on more than 2 dimensions, at" +
+ "TRT cannot support reduce on more than 2 dimensions, at",
node_def.name());
+ }
std::set<int> idx_set;
   // We cannot operate on the channel dimension; the permutation flag marks
   // whether the tensor must be transposed first.
int permuted_index = -1;
for (int i = 0; i < index_list.count(); i++) {
- if (index_list_data[i] == 0)
- return tensorflow::errors::InvalidArgument("TRT cannot reduce at 0, at" +
+ if (index_list_data[i] == 0) {
+ return tensorflow::errors::InvalidArgument("TRT cannot reduce at 0, at",
node_def.name());
+ }
if (index_list_data[i] == 1) permuted_index = 1;
-
idx_set.emplace(index_list_data[i]);
}
@@ -1673,6 +1939,7 @@ tensorflow::Status ConvertReduce(Converter& ctx,
// Apply permutation before extracting dimension for pool_kernel
tensor = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
permutation_order);
+ TFTRT_RETURN_ERROR_IF_NULLPTR(tensor, node_def.name());
}
// Apply permutation before extracting dimension for pool_kernel
@@ -1685,34 +1952,104 @@ tensorflow::Status ConvertReduce(Converter& ctx,
nvinfer1::IPoolingLayer* layer =
ctx.network()->addPooling(*const_cast<nvinfer1::ITensor*>(tensor),
nvinfer1::PoolingType::kAVERAGE, pool_kernel);
+ TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
output_tensor = layer->getOutput(0);
} else {
- return tensorflow::errors::Unimplemented(
- "Op not supported " + node_def.op() + " , at " + node_def.name());
+ return tensorflow::errors::Unimplemented("Op not supported ", node_def.op(),
+ " , at ", node_def.name());
}
if (permuted_index != -1) {
// Apply permutation before extracting dimension for pool_kernel
output_tensor = ctx.TransposeTensor(
const_cast<nvinfer1::ITensor*>(output_tensor), permutation_order);
+ TFTRT_RETURN_ERROR_IF_NULLPTR(output_tensor, node_def.name());
}
outputs->push_back(TRT_TensorOrWeights(output_tensor));
return tensorflow::Status::OK();
}
+#elif NV_TENSORRT_MAJOR > 3
+tensorflow::Status ConvertReduce(Converter& ctx,
+ const tensorflow::NodeDef& node_def,
+ const std::vector<TRT_TensorOrWeights>& inputs,
+ std::vector<TRT_TensorOrWeights>* outputs) {
+ if (inputs.size() != 2 || !inputs.at(0).is_tensor() ||
+ !inputs.at(1).is_weights()) {
+ return tensorflow::errors::InvalidArgument(
+ "Input expects tensor and weights, at", node_def.name());
+ }
+
+ const nvinfer1::ITensor* tensor = inputs.at(0).tensor();
+ TRT_ShapedWeights index_list = inputs.at(1).weights();
+
+ TFAttrs attrs(node_def);
+ auto index_type = attrs.get<tensorflow::DataType>("Tidx");
+
+ // Only expect to handle INT32 as attributes for now
+ if (index_type != tensorflow::DataType::DT_INT32) {
+ return tensorflow::errors::Unimplemented("Tidx supports only DT_INT32");
+ }
+
+ const auto keep_dims = attrs.get<bool>("keep_dims");
+ auto index_list_data =
+ static_cast<int*>(const_cast<void*>(index_list.GetValues()));
+
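+  // Build the TRT reduce-axes bitmask. The TRT network excludes the implicit
+  // batch dimension, so TF axis i maps to bit (i - 1) and axis 0 is rejected.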
+ int axes = 0;
+ if (index_list.count() == 0) {
+ return tensorflow::errors::InvalidArgument(
+ "TRT cannot support reduce on all (batch) dimensions, at",
+ node_def.name());
+ } else {
+ for (int i = 0; i < index_list.count(); i++) {
+ if (index_list_data[i] == 0) {
+ return tensorflow::errors::InvalidArgument(
+ "TRT cannot reduce at batch dimension, at", node_def.name());
+ }
+ axes |= (1 << (index_list_data[i] - 1));
+ }
+ }
+
+ nvinfer1::ReduceOperation reduce_operation;
+ if (node_def.op() == "Sum") {
+ reduce_operation = nvinfer1::ReduceOperation::kSUM;
+ } else if (node_def.op() == "Prod") {
+ reduce_operation = nvinfer1::ReduceOperation::kPROD;
+ } else if (node_def.op() == "Max") {
+ reduce_operation = nvinfer1::ReduceOperation::kMAX;
+ } else if (node_def.op() == "Min") {
+ reduce_operation = nvinfer1::ReduceOperation::kMIN;
+ } else if (node_def.op() == "Mean") {
+ reduce_operation = nvinfer1::ReduceOperation::kAVG;
+ } else {
+ return tensorflow::errors::Unimplemented("Op not supported ", node_def.op(),
+ " , at ", node_def.name());
+ }
+
+ nvinfer1::ILayer* layer =
+ ctx.network()->addReduce(*const_cast<nvinfer1::ITensor*>(tensor),
+ reduce_operation, axes, keep_dims);
+ TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
+
+ outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0)));
+ return tensorflow::Status::OK();
+}
+#endif
tensorflow::Status ConvertPad(Converter& ctx,
const tensorflow::NodeDef& node_def,
const std::vector<TRT_TensorOrWeights>& inputs,
std::vector<TRT_TensorOrWeights>* outputs) {
+ // TODO(aaroey): make a routine for this check and reuse it.
if (inputs.size() != 2 || !inputs.at(0).is_tensor() ||
- !inputs.at(1).is_weights())
+ !inputs.at(1).is_weights()) {
return tensorflow::errors::InvalidArgument(
- "Input expects tensor and weights, at" + node_def.name());
+ "Input expects tensor and weights, at", node_def.name());
+ }
// Implement tensor binaryOp weight [channel wise] for now;
const nvinfer1::ITensor* tensor = inputs.at(0).tensor();
- auto dims = tensor->getDimensions();
+ const auto dims = tensor->getDimensions();
// Restore implicit batch dimension
- int nb_dims = dims.nbDims + 1;
+ const int nb_dims = dims.nbDims + 1;
TRT_ShapedWeights pads = inputs.at(1).weights();
@@ -1722,21 +2059,24 @@ tensorflow::Status ConvertPad(Converter& ctx,
auto padding_type = attrs.get<tensorflow::DataType>("Tpaddings");
// TODO(jie): handle data type conversion for TRT?
- if (pads.shape_.d[0] != nb_dims || pads.shape_.d[1] != 2)
+ if (pads.shape_.d[0] != nb_dims || pads.shape_.d[1] != 2) {
return tensorflow::errors::InvalidArgument(
- "Pad only supports explicit padding on 4 dimensional tensor, at " +
+ "Pad only supports explicit padding on 4 dimensional tensor, at ",
node_def.name());
+ }
// Only expect to handle INT32 as attributes for now
- if (padding_type != tensorflow::DataType::DT_INT32)
+ if (padding_type != tensorflow::DataType::DT_INT32) {
return tensorflow::errors::Unimplemented(
"Tpaddings supports only DT_INT32");
+ }
auto pad_data = static_cast<int*>(const_cast<void*>(pads.GetValues()));
std::vector<int32_t> pad_index;
for (int i = 0; i < nb_dims; i++) {
- if (pad_data[2 * i] != 0 || pad_data[2 * i + 1] != 0)
+ if (pad_data[2 * i] != 0 || pad_data[2 * i + 1] != 0) {
pad_index.push_back(i);
+ }
}
// No padding at all, we should exit
@@ -1746,20 +2086,23 @@ tensorflow::Status ConvertPad(Converter& ctx,
}
   // Only supports padding on at most 2 axes (GIE-2579).
- if (pad_index.size() > 2)
+ if (pad_index.size() > 2) {
return tensorflow::errors::InvalidArgument(
"Padding layer does not support padding on > 2");
+ }
// Padding on batch dimension is not supported
- if (pad_index[0] == 0)
+ if (pad_index[0] == 0) {
return tensorflow::errors::InvalidArgument(
"Padding layer does not support padding on batch dimension");
+ }
// Not doing the legit thing here. ignoring padding on dim 1 and 3;
// TODO(jie): implement pad as uff parser
- if (pad_index.size() == 2 && pad_index[0] == 0 && pad_index[1] == 3)
+ if (pad_index.size() == 2 && pad_index[0] == 0 && pad_index[1] == 3) {
return tensorflow::errors::Unimplemented(
"Padding layer does not support padding on dimension 1 and 3 yet");
+ }
bool legit_pad = true;
nvinfer1::DimsHW pre_padding(0, 0);
@@ -1770,6 +2113,7 @@ tensorflow::Status ConvertPad(Converter& ctx,
legit_pad = false;
tensor = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
{0, 3, 2, 1});
+ TFTRT_RETURN_ERROR_IF_NULLPTR(tensor, node_def.name());
permuted_pad_index[0] = 3;
}
@@ -1786,11 +2130,14 @@ tensorflow::Status ConvertPad(Converter& ctx,
nvinfer1::IPaddingLayer* layer = ctx.network()->addPadding(
*const_cast<nvinfer1::ITensor*>(tensor), pre_padding, post_padding);
+ TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
nvinfer1::ITensor* output_tensor = layer->getOutput(0);
- if (!legit_pad)
+ if (!legit_pad) {
output_tensor = ctx.TransposeTensor(
const_cast<nvinfer1::ITensor*>(output_tensor), {0, 3, 2, 1});
+ TFTRT_RETURN_ERROR_IF_NULLPTR(output_tensor, node_def.name());
+ }
outputs->push_back(TRT_TensorOrWeights(output_tensor));
return tensorflow::Status::OK();
@@ -1803,9 +2150,10 @@ tensorflow::Status ConvertConcat(Converter& ctx,
// not including the last input (axis) here
int input_size = static_cast<int>(inputs.size()) - 1;
- if (!inputs.at(0).is_tensor())
+ if (!inputs.at(0).is_tensor()) {
return tensorflow::errors::InvalidArgument(
- "Concat in TRT support only Tensor input, at " + node_def.name());
+ "Concat in TRT support only Tensor input, at ", node_def.name());
+ }
// We are retrieving the axis
TRT_ShapedWeights axis = inputs.at(input_size).weights();
@@ -1816,8 +2164,8 @@ tensorflow::Status ConvertConcat(Converter& ctx,
// TODO(jie): handle data type
// Only expect to handle INT32 as index attributes for now
if (index_type != tensorflow::DataType::DT_INT32)
- return tensorflow::errors::Unimplemented(
- "Tidx supports only DT_INT32, at " + node_def.name());
+ return tensorflow::errors::Unimplemented("Tidx supports only DT_INT32, at ",
+ node_def.name());
int index = *(static_cast<int*>(const_cast<void*>(axis.GetValues())));
@@ -1825,23 +2173,29 @@ tensorflow::Status ConvertConcat(Converter& ctx,
auto dim = inputs.at(0).tensor()->getDimensions();
// dimension check
- if (index > dim.nbDims + 1)
+ if (index > dim.nbDims + 1) {
return tensorflow::errors::InvalidArgument(
- "Concatenate on axis out of dimension range, at " + node_def.name());
-
- if (index == 0)
+ "Concatenate on axis out of dimension range, at ", node_def.name());
+ }
+ if (index == 0) {
return tensorflow::errors::InvalidArgument(
- "Concatenate on batch dimension not supported, at " + node_def.name());
+ "Concatenate on batch dimension not supported, at ", node_def.name());
+ }
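+  // Normalize a negative concat axis (TF semantics) to its positive index.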
+ if (index < 0) {
+ index = dim.nbDims + index + 1;
+ }
+#if NV_TENSORRT_MAJOR == 3
   // In case we need permutation.
std::vector<int> permutation_order(dim.nbDims + 1);
for (int i = 0; i < dim.nbDims + 1; i++) permutation_order[i] = i;
if (index != 1) {
- permutation_order[1] = index - 1;
- permutation_order[index - 1] = 1;
+ permutation_order[1] = index;
+ permutation_order[index] = 1;
}
+#endif
std::vector<nvinfer1::ITensor const*> inputs_vec;
   // Shape check (all input tensors should have the same shape).
@@ -1849,24 +2203,28 @@ tensorflow::Status ConvertConcat(Converter& ctx,
for (int i = 0; i < input_size; i++) {
auto tensor_i = inputs.at(i).tensor();
auto dim_i = tensor_i->getDimensions();
- if (dim_i.nbDims != dim.nbDims)
+ if (dim_i.nbDims != dim.nbDims) {
return tensorflow::errors::InvalidArgument(
- "Concatenate receives inputs with inconsistent dimensions, at " +
+ "Concatenate receives inputs with inconsistent dimensions, at ",
node_def.name());
-
+ }
for (int j = 0; j < dim.nbDims; j++) {
// check dimension consistency on non-concatenate axis
- if (j != index - 1 && dim_i.d[j] != dim.d[j])
+ if (j != index - 1 && dim_i.d[j] != dim.d[j]) {
return tensorflow::errors::InvalidArgument(
- "Concatenate receives inputs with inconsistent shape, at" +
+ "Concatenate receives inputs with inconsistent shape, at",
node_def.name());
+ }
}
- // TRT does concatenation only on channel!
- if (index != 1)
+#if NV_TENSORRT_MAJOR == 3
+ // TRT3 does concatenation only on channel!
+ if (index != 1) {
tensor_i = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor_i),
permutation_order);
-
+ TFTRT_RETURN_ERROR_IF_NULLPTR(tensor_i, node_def.name());
+ }
+#endif
inputs_vec.push_back(tensor_i);
}
@@ -1874,11 +2232,18 @@ tensorflow::Status ConvertConcat(Converter& ctx,
nvinfer1::IConcatenationLayer* layer = ctx.network()->addConcatenation(
const_cast<nvinfer1::ITensor* const*>(inputs_vec.data()),
inputs_vec.size());
+ TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
+#if NV_TENSORRT_MAJOR > 3
+ layer->setAxis(index - 1);
+#endif
nvinfer1::ITensor* output_tensor = layer->getOutput(0);
+#if NV_TENSORRT_MAJOR == 3
if (index != 1) {
output_tensor = ctx.TransposeTensor(output_tensor, permutation_order);
+ TFTRT_RETURN_ERROR_IF_NULLPTR(output_tensor, node_def.name());
}
+#endif
outputs->push_back(TRT_TensorOrWeights(output_tensor));
return tensorflow::Status::OK();
}
@@ -1997,112 +2362,243 @@ tensorflow::Status ConvertFusedBatchNorm(
combined_offset_weights.GetWeightsForTRT(),
combined_scale_weights.GetWeightsForTRT(),
dummy_power_weights.GetWeightsForTRT());
+ TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
nvinfer1::ITensor* output_tensor = layer->getOutput(0);
outputs->push_back(TRT_TensorOrWeights(output_tensor));
return tensorflow::Status::OK();
}
-tensorflow::Status ConvertMatMul(Converter& ctx,
- const tensorflow::NodeDef& node_def,
- const std::vector<TRT_TensorOrWeights>& inputs,
- std::vector<TRT_TensorOrWeights>* outputs) {
- const nvinfer1::ITensor* tensor = inputs.at(0).tensor();
-
- // TODO(jie): transpose!
- TFAttrs attrs(node_def);
+#if NV_TENSORRT_MAJOR > 3
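+// Helper that lowers a tensor x constant-weights MatMul onto a TRT
+// FullyConnected layer; shared by ConvertMatMul and ConvertBatchMatMul.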
+tensorflow::Status ConvertMatMulHelper(
+ Converter& ctx, TRT_TensorOrWeights tensor_input,
+ TRT_ShapedWeights weights_raw, bool transpose_weight, string node_name,
+ std::vector<TRT_TensorOrWeights>* outputs) {
+ nvinfer1::ITensor* output_tensor;
+ if (!tensor_input.is_tensor()) {
+ return tensorflow::errors::InvalidArgument("Input 0 expects tensor");
+ }
+ const nvinfer1::ITensor* tensor = tensor_input.tensor();
- TRT_ShapedWeights weights_ck = inputs.at(1).weights();
- TRT_ShapedWeights weights = ctx.get_temp_weights_like(weights_ck);
- ReorderCKtoKC(weights_ck, &weights);
+ TRT_ShapedWeights weights(weights_raw.type_);
+ if (transpose_weight) {
+ weights = weights_raw;
+ } else {
+ TRT_ShapedWeights weights_ck = weights_raw;
+ weights = ctx.get_temp_weights_like(weights_ck);
+ ReorderCKtoKC(weights_raw, &weights);
+ }
TRT_ShapedWeights biases(weights.type_);
int noutput = weights.shape_.d[0];
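+  // The FullyConnected layer expects a rank-3 (CHW) input; pad trailing
+  // dimensions with 1 until the input reaches rank 3.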
+ auto input_dim = tensor->getDimensions();
+ while (input_dim.nbDims != 3) {
+ input_dim.d[input_dim.nbDims++] = 1;
+ }
+ TFTRT_RETURN_ERROR_IF_FALSE(
+ PrepareTensorForShape(ctx, tensor_input, input_dim, &tensor), node_name);
+
nvinfer1::IFullyConnectedLayer* layer = ctx.network()->addFullyConnected(
*const_cast<nvinfer1::ITensor*>(tensor), noutput, weights, biases);
-
- nvinfer1::ITensor* output_tensor = layer->getOutput(0);
+ TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_name);
+ output_tensor = layer->getOutput(0);
+
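+  // Squeeze the FullyConnected output back down to a rank-1 tensor.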
+ const nvinfer1::ITensor* temp_tensor;
+ auto output_dim = output_tensor->getDimensions();
+ output_dim.nbDims = 1;
+ TFTRT_RETURN_ERROR_IF_FALSE(
+ PrepareTensorForShape(ctx, TRT_TensorOrWeights(output_tensor), output_dim,
+ &temp_tensor),
+ node_name);
+ output_tensor = const_cast<nvinfer1::ITensor*>(temp_tensor);
outputs->push_back(TRT_TensorOrWeights(output_tensor));
return tensorflow::Status::OK();
}
-tensorflow::Status ConvertReshape(
+// Inputs are both two-dimensional (tensorflow::ops::MatMul).
+tensorflow::Status ConvertMatMul(Converter& ctx,
+ const tensorflow::NodeDef& node_def,
+ const std::vector<TRT_TensorOrWeights>& inputs,
+ std::vector<TRT_TensorOrWeights>* outputs) {
+ if (!inputs.at(0).is_tensor()) {
+ return tensorflow::errors::InvalidArgument("Input 0 expects tensor, at" +
+ node_def.name());
+ }
+
+ TFAttrs attrs(node_def);
+ // TODO(jie): INT32 should be converted?
+ tensorflow::DataType tf_dtype = attrs.get<tensorflow::DataType>("T");
+ if (tf_dtype != tensorflow::DataType::DT_FLOAT &&
+ tf_dtype != tensorflow::DataType::DT_HALF) {
+ return tensorflow::errors::Unimplemented(
+ "data type is not supported, for node " + node_def.name() + " got " +
+ tensorflow::DataTypeString(tf_dtype));
+ }
+ bool transpose_a = attrs.get<bool>("transpose_a");
+ bool transpose_b = attrs.get<bool>("transpose_b");
+
+ // FullyConnected:
+ if (transpose_a) {
+ return tensorflow::errors::Internal(
+ "Transpose_a is not supported for TensorRT FullyConnected (op: " +
+ node_def.op() + "), at: " + node_def.name());
+ }
+ if (inputs.at(1).is_tensor()) {
+ return tensorflow::errors::Internal(
+ "Operand 1 must be constant for TensorRT FullyConnected (op: " +
+ node_def.op() + "), at: " + node_def.name());
+ }
+ return ConvertMatMulHelper(ctx, inputs.at(0), inputs.at(1).weights(),
+ transpose_b, node_def.name(), outputs);
+}
+
+tensorflow::Status ConvertBatchMatMul(
Converter& ctx, const tensorflow::NodeDef& node_def,
const std::vector<TRT_TensorOrWeights>& inputs,
std::vector<TRT_TensorOrWeights>* outputs) {
- if (inputs.size() != 2 || !inputs.at(0).is_tensor() ||
- !inputs.at(1).is_weights())
- return tensorflow::errors::InvalidArgument(
- "Input expects tensor and weights, at" + node_def.name());
+ TFAttrs attrs(node_def);
- // implement tensor binaryOp weight [channel wise] for now;
- const nvinfer1::ITensor* tensor = inputs.at(0).tensor();
- auto dims = tensor->getDimensions();
- // restore implicit batch dimension
+ // TODO(jie): INT32 should be converted?
+ tensorflow::DataType tf_dtype = attrs.get<tensorflow::DataType>("T");
+ if (tf_dtype != tensorflow::DataType::DT_FLOAT &&
+ tf_dtype != tensorflow::DataType::DT_HALF) {
+ return tensorflow::errors::Unimplemented(
+ "data type is not supported, for node " + node_def.name() + " got " +
+ tensorflow::DataTypeString(tf_dtype));
+ }
- TRT_ShapedWeights shape = inputs.at(1).weights();
+ bool transpose_a = attrs.get<bool>("adj_x");
+ bool transpose_b = attrs.get<bool>("adj_y");
- TFAttrs attrs(node_def);
+ auto dims = inputs.at(0).shape();
+ if (dims.nbDims == 1) { // NC * CK is only supported through fully connected
+ if (transpose_a == false && inputs.at(0).is_tensor() &&
+ inputs.at(1).is_weights()) {
+ return ConvertMatMulHelper(ctx, inputs.at(0), inputs.at(1).weights(),
+ transpose_b, node_def.name(), outputs);
+ } else {
+ return tensorflow::errors::InvalidArgument(
+ "Invalid configuration for MatMul, at: " + node_def.name());
+ }
+ }
- auto padding_type = attrs.get<tensorflow::DataType>("Tshape");
+ const nvinfer1::ITensor* tensor_l;
+ const nvinfer1::ITensor* tensor_r;
+ auto dims_l = inputs.at(0).shape();
+ auto dims_r = inputs.at(1).shape();
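+  // A constant (weights) operand must have an outer dimension of 1; that
+  // dimension is stripped so the value broadcasts across the batch.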
+ if (inputs.at(0).is_weights()) {
+ if (inputs.at(0).shape().d[0] != 1) {
+ return tensorflow::errors::InvalidArgument(
+ "Input 0 as weight assumes broadcast across batch for MatMul, at: " +
+ node_def.name());
+ } else {
+ for (int i = 0; i < dims_l.nbDims - 1; i++) {
+ dims_l.d[i] = dims_l.d[i + 1];
+ }
+ dims_l.nbDims--;
+ }
+ }
+ if (inputs.at(1).is_weights()) {
+ if (inputs.at(1).shape().d[0] != 1) {
+ return tensorflow::errors::InvalidArgument(
+ "Input 1 as weight assumes broadcast across batch for MatMul, at: " +
+ node_def.name());
+ } else {
+ for (int i = 0; i < dims_r.nbDims - 1; i++) {
+ dims_r.d[i] = dims_r.d[i + 1];
+ }
+ dims_r.nbDims--;
+ }
+ }
- if (shape.shape_.nbDims != 1)
- return tensorflow::errors::InvalidArgument(
- "reshape new shape is not 1 dimensional, at " + node_def.name());
+ TFTRT_RETURN_ERROR_IF_FALSE(
+ PrepareTensorForShape(ctx, inputs.at(0), dims_l, &tensor_l),
+ node_def.name());
+ TFTRT_RETURN_ERROR_IF_FALSE(
+ PrepareTensorForShape(ctx, inputs.at(1), dims_r, &tensor_r),
+ node_def.name());
- // Only expect to handle INT32 as attributes for now
- if (padding_type != tensorflow::DataType::DT_INT32)
- return tensorflow::errors::Unimplemented(
- "reshape new shape supports only DT_INT32, at " + node_def.name());
+ nvinfer1::IMatrixMultiplyLayer* layer = ctx.network()->addMatrixMultiply(
+ *const_cast<nvinfer1::ITensor*>(tensor_l), transpose_a,
+ *const_cast<nvinfer1::ITensor*>(tensor_r), transpose_b);
+ TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
+ nvinfer1::ITensor* output_tensor = layer->getOutput(0);
+ outputs->push_back(TRT_TensorOrWeights(output_tensor));
+ return tensorflow::Status::OK();
+}
+#endif
- auto shape_data = static_cast<int*>(const_cast<void*>(shape.GetValues()));
+#if NV_TENSORRT_MAJOR > 3
+tensorflow::Status ConvertSoftmax(
+ Converter& ctx, const tensorflow::NodeDef& node_def,
+ const std::vector<TRT_TensorOrWeights>& inputs,
+ std::vector<TRT_TensorOrWeights>* outputs) {
+ const nvinfer1::ITensor* tensor = inputs.at(0).tensor();
- if (shape_data[0] != -1)
+ int nbDims = tensor->getDimensions().nbDims;
+ if (nbDims == 0) {
return tensorflow::errors::InvalidArgument(
- "reshape new shape first dimension is not -1, at " + node_def.name());
+ "TensorRT Softmax cannot apply on batch dimension, at" +
+ node_def.name());
+ }
+ nvinfer1::ISoftMaxLayer* layer =
+ ctx.network()->addSoftMax(*const_cast<nvinfer1::ITensor*>(tensor));
+ TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
+ // Tensorflow SoftMax assumes applying softmax on the last dimension.
+ layer->setAxes(1 << (nbDims - 1));
- auto shape_num_dims = shape.shape_.d[0];
- VLOG(2) << "shape dimensions: " << shape_num_dims;
- int volume_w = 1;
- for (int i = 1; i < shape.shape_.d[0]; i++) volume_w *= shape_data[i];
+ nvinfer1::ITensor* output_tensor = layer->getOutput(0);
+ outputs->push_back(TRT_TensorOrWeights(output_tensor));
+ return tensorflow::Status::OK();
+}
+#endif
- int volume_t = 1;
- for (int i = 0; i < dims.nbDims; i++) volume_t *= dims.d[i];
+#if NV_TENSORRT_MAJOR > 3
+tensorflow::Status ConvertTopK(Converter& ctx,
+ const tensorflow::NodeDef& node_def,
+ const std::vector<TRT_TensorOrWeights>& inputs,
+ std::vector<TRT_TensorOrWeights>* outputs) {
+ const nvinfer1::ITensor* tensor = inputs.at(0).tensor();
- VLOG(2) << "volume: " << volume_t << " volume weights: " << volume_w;
- if (volume_w != volume_t)
+ int nbDims = tensor->getDimensions().nbDims;
+ if (nbDims == 0) {
return tensorflow::errors::InvalidArgument(
- "volume does not agree between tensor and new shape, at " +
- node_def.name());
+ "TensorRT TopK cannot apply on batch dimension, at" + node_def.name());
+ }
- nvinfer1::IShuffleLayer* layer =
- ctx.network()->addShuffle(*const_cast<nvinfer1::ITensor*>(tensor));
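+  // k is read from the constant (weights) input; TRT's addTopK takes k as a
+  // build-time parameter rather than a tensor.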
+ TRT_ShapedWeights k_w = inputs.at(1).weights();
+ int k = *(static_cast<int*>(const_cast<void*>(k_w.GetValues())));
- nvinfer1::Dims reshape_dims;
- VLOG(2) << "new dimension: " << shape_num_dims - 1;
- reshape_dims.nbDims = shape_num_dims - 1;
- for (int32_t i = 0; i < reshape_dims.nbDims; ++i) {
- reshape_dims.d[i] = shape_data[i + 1];
+ nvinfer1::TopKOperation op;
+ uint32_t reducedAxes = 0;
+ if (node_def.op() == "TopKV2") {
+ op = nvinfer1::TopKOperation::kMAX;
+ reducedAxes |= 1 << (nbDims - 1);
+ } else {
+ return tensorflow::errors::Unimplemented(
+ "Operation: " + node_def.op() +
+ " not implemented, at: " + node_def.name());
}
- layer->setReshapeDimensions(reshape_dims);
- VLOG(2) << "new dimension: " << shape_num_dims - 1;
- nvinfer1::ITensor* output_tensor = layer->getOutput(0);
- auto dims_output = output_tensor->getDimensions();
- VLOG(2) << "output tensor dimension:" << dims_output.nbDims;
- outputs->push_back(TRT_TensorOrWeights(output_tensor));
+ nvinfer1::ITopKLayer* layer = ctx.network()->addTopK(
+ *const_cast<nvinfer1::ITensor*>(tensor), op, k, reducedAxes);
+ TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
+
+ nvinfer1::ITensor* output_value_tensor = layer->getOutput(0);
+ nvinfer1::ITensor* output_indices_tensor = layer->getOutput(1);
+ outputs->push_back(TRT_TensorOrWeights(output_value_tensor));
+ outputs->push_back(TRT_TensorOrWeights(output_indices_tensor));
return tensorflow::Status::OK();
}
+#endif
void Converter::register_op_converters() {
// vgg_16 slim implementation
- op_registry_["Placeholder"] = ConvertPlaceholder;
op_registry_["Conv2D"] = ConvertConv2D;
op_registry_["DepthwiseConv2dNative"] = ConvertConv2DDepthwise;
op_registry_["Relu"] = ConvertActivation;
op_registry_["MaxPool"] = ConvertPool;
op_registry_["AvgPool"] = ConvertPool;
- // This could be really handled as ConvertBinary
op_registry_["BiasAdd"] = ConvertScale;
op_registry_["Const"] = ConvertConst;
// TODO(ben,jie): this is a temp hack.
@@ -2113,17 +2609,39 @@ void Converter::register_op_converters() {
op_registry_["Add"] = ConvertBinary;
op_registry_["Mul"] = ConvertBinary;
op_registry_["Sub"] = ConvertBinary;
- op_registry_["Rsqrt"] = ConvertUnary;
- op_registry_["Mean"] = ConvertReduce;
op_registry_["Pad"] = ConvertPad;
- // TODO(ben,jie): Add more ops
op_registry_["ConcatV2"] = ConvertConcat;
- op_registry_["MatMul"] = ConvertMatMul;
- op_registry_["Reshape"] = ConvertReshape;
op_registry_["FusedBatchNorm"] = ConvertFusedBatchNorm;
op_registry_["FusedBatchNormV2"] = ConvertFusedBatchNorm;
+ op_registry_["Div"] = ConvertBinary;
+ op_registry_["RealDiv"] = ConvertBinary;
+
+ op_registry_["Rsqrt"] = ConvertUnary;
+ op_registry_["Reciprocal"] = ConvertUnary;
+ op_registry_["Exp"] = ConvertUnary;
+ op_registry_["Log"] = ConvertUnary;
+ op_registry_["Sqrt"] = ConvertUnary;
+ op_registry_["Abs"] = ConvertUnary;
+ op_registry_["Neg"] = ConvertUnary;
+#if NV_TENSORRT_MAJOR == 3
+ op_registry_["Mean"] = ConvertReducePool;
+#endif
+#if NV_TENSORRT_MAJOR > 3
+ op_registry_["Sum"] = ConvertReduce;
+ op_registry_["Prod"] = ConvertReduce;
+ op_registry_["Max"] = ConvertReduce;
+ op_registry_["Min"] = ConvertReduce;
+ op_registry_["Mean"] = ConvertReduce;
+ op_registry_["Maximum"] = ConvertBinary;
+ op_registry_["Minimum"] = ConvertBinary;
+ op_registry_["Softmax"] = ConvertSoftmax;
+ op_registry_["MatMul"] = ConvertMatMul;
+ op_registry_["BatchMatMul"] = ConvertBatchMatMul;
+ op_registry_["TopKV2"] = ConvertTopK;
+#endif
+
plugin_converter_ = ConvertPlugin;
}
@@ -2177,25 +2695,22 @@ tensorflow::Status ConvertGraphDefToEngine(
(node_def.op() == "Placeholder")) {
nvinfer1::DimsCHW input_dim_pseudo_chw;
for (int i = 0; i < 8; i++) input_dim_pseudo_chw.d[i] = 0;
- nvinfer1::DataType dtype(nvinfer1::DataType::kFLOAT);
- auto type_status =
- ConvertDType(node_def.attr().at("dtype").type(), &dtype);
- if (type_status != tensorflow::Status::OK()) {
- LOG(WARNING) << "Type conversion failed for " << node_name;
- return type_status;
- }
int32 slot_number = -1;
- if (!tensorflow::strings::safe_strto32(node_name.c_str() + 8,
- &slot_number)) {
- LOG(ERROR) << "Failed to parse slot number from " << node_name
- << " +8= " << node_name.c_str() + 8;
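+      // The slot number is parsed from the suffix that follows kInputPHName in
+      // the placeholder's name.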
+ if (!tensorflow::strings::safe_strto32(
+ node_name.c_str() + strlen(kInputPHName), &slot_number)) {
+ return tensorflow::errors::InvalidArgument(
+ "Failed to parse slot number from ", node_name);
}
+ nvinfer1::DataType dtype;
auto shape = input_shapes.at(slot_number);
- if (shape.dims() > 8) {
- LOG(ERROR) << "Tensor rank is greater than 8 for " << node_name
- << " at input slot " << slot_number;
- return tensorflow::errors::OutOfRange(
- "Input tensor rank is greater than 8");
+ auto status = ValidateInputProperties(
+ shape, node_def.attr().at("dtype").type(), &dtype);
+ if (!status.ok()) {
+ const string error_message =
+ StrCat("Validation failed for ", node_name, " and input slot ",
+ slot_number, ": ", status.error_message());
+ LOG(WARNING) << error_message;
+ return Status(status.code(), error_message);
}
if (VLOG_IS_ON(1)) {
string dim_str("dims=");
@@ -2226,10 +2741,10 @@ tensorflow::Status ConvertGraphDefToEngine(
} else if (tensorflow::str_util::StartsWith(node_name, kOutputPHName) &&
(node_def.op() == "Identity")) {
int32 slot_number = -1;
- if (!tensorflow::strings::safe_strto32(node_name.c_str() + 9,
- &slot_number)) {
- LOG(ERROR) << "Failed to parse slot number from " << node_name
- << " +9=" << node_name.c_str() + 9;
+ if (!tensorflow::strings::safe_strto32(
+ node_name.c_str() + strlen(kOutputPHName), &slot_number)) {
+ return tensorflow::errors::InvalidArgument(
+ "Failed to parse slot number from ", node_name);
}
if (output_tensors.size() <= slot_number) {
output_tensors.resize(slot_number + 1);
@@ -2288,38 +2803,20 @@ tensorflow::Status ConvertSegmentToGraphDef(
"Cannot find node with id ", connection.outside_id, " in the graph.");
}
// Updates the shape and data types of input/output connections.
- tensorflow::DataType input_type = tensorflow::DT_FLOAT;
+ tensorflow::DataType dtype;
tensorflow::PartialTensorShape partial_shape;
if (connection.is_input_edge) {
- if (graph_properties.HasOutputProperties(connection.outside_node_name)) {
- auto output_params =
- graph_properties.GetOutputProperties(connection.outside_node_name);
- auto out_shape = output_params.at(connection.outside_port);
- input_type = out_shape.dtype();
- std::vector<tensorflow::int64> dims;
- partial_shape = out_shape.shape();
- connection.outside_shape = partial_shape;
- } else {
- VLOG(0) << "Unknown output shape" << outside_node->name();
- input_type = graph->FindNodeId(connection.outside_id)
- ->output_type(connection.outside_port);
- }
- connection.connection_type = input_type;
-
- } else { // output edge
- if (graph_properties.HasInputProperties(connection.outside_node_name)) {
- auto input_params =
- graph_properties.GetInputProperties(connection.outside_node_name);
- auto in_shape = input_params.at(connection.outside_port);
- input_type = in_shape.dtype();
- partial_shape = in_shape.shape();
- connection.inside_shape = partial_shape;
- } else {
- input_type = graph->FindNodeId(connection.inside_id)
- ->output_type(connection.outside_port);
- }
- connection.connection_type = input_type;
+ GetInputProperties(graph_properties,
+ graph->FindNodeId(connection.outside_id),
+ connection.outside_port, &partial_shape, &dtype);
+
+ } else {
+ GetOutputProperties(graph_properties,
+ graph->FindNodeId(connection.outside_id),
+ connection.outside_port, &partial_shape, &dtype);
}
+ connection.outside_shape = partial_shape;
+ connection.connection_type = dtype;
// Add dummy input/output nodes to the segment graphdef.
if (connection.is_input_edge) {
@@ -2335,7 +2832,7 @@ tensorflow::Status ConvertSegmentToGraphDef(
auto seg_node = segment_def->add_node();
tensorflow::NodeDefBuilder builder(node_name, "Placeholder");
auto status = builder.Attr("shape", partial_shape)
- .Attr("dtype", input_type)
+ .Attr("dtype", dtype)
.Finalize(seg_node);
VLOG(1) << "Constructing input " << node_name << " for the edge "
<< connection.outside_node_name << ":" << connection.outside_port
@@ -2353,7 +2850,7 @@ tensorflow::Status ConvertSegmentToGraphDef(
marker_nodes.insert(node_name);
auto seg_node = segment_def->add_node();
tensorflow::NodeDefBuilder builder(node_name, "Identity");
- auto status = builder.Input(connection.inside_node_name, 0, input_type)
+ auto status = builder.Input(connection.inside_node_name, 0, dtype)
.Finalize(seg_node);
VLOG(1) << "Constructing output " << node_name << " for the edge "
<< connection.inside_node_name << ":" << connection.inside_port
@@ -2391,6 +2888,38 @@ tensorflow::Status ConvertSegmentToGraphDef(
return tensorflow::Status::OK();
}
+bool InputEdgeValidator::operator()(const tensorflow::Edge* in_edge) const {
+ if (in_edge->IsControlEdge()) return true;
+ PartialTensorShape shape;
+ tensorflow::DataType dtype;
+ GetInputProperties(graph_properties_, in_edge->src(), in_edge->src_output(),
+ &shape, &dtype);
+ nvinfer1::DataType trt_dtype;
+ Status status = ValidateInputProperties(shape, dtype, &trt_dtype);
+ if (!status.ok()) {
+ VLOG(2) << "--> Need to remove input node " << in_edge->dst()->name()
+ << ": " << status;
+ return false;
+ }
+ if (shape.dims() < 3 && in_edge->src()->type_string() != "Const") {
+ VLOG(2) << "--> Need to remove input node " << in_edge->dst()->name()
+ << " which has an input at port " << in_edge->dst_input()
+ << " with #dim<3 and is not a const: " << shape;
+ return false;
+ }
+ return true;
+}
+
+bool OutputEdgeValidator::operator()(const tensorflow::Edge* out_edge) const {
+ if (out_edge->IsControlEdge()) return true;
+ if (out_edge->src()->type_string() == "Const") {
+ VLOG(2) << "--> Need to remove output node " << out_edge->src()->name()
+ << " which is a Const.";
+ return false;
+ }
+ return true;
+}
+
} // namespace convert
} // namespace tensorrt
} // namespace tensorflow