diff options
author | 2018-07-17 08:02:04 -0700 | |
---|---|---|
committer | 2018-07-17 08:02:04 -0700 | |
commit | e02fbb25784498b44e73d9370da65a3f23f6de15 (patch) | |
tree | 0c24c8dca458c6f68e9c0ee5b2e0415ab22842e9 /tensorflow/contrib/tensorrt/convert | |
parent | f340242952de5c4ef2ae78c891490248e5948a1f (diff) |
Fix review comments and formatting issues.
Diffstat (limited to 'tensorflow/contrib/tensorrt/convert')
-rw-r--r-- | tensorflow/contrib/tensorrt/convert/convert_graph.cc | 4 | ||||
-rw-r--r-- | tensorflow/contrib/tensorrt/convert/convert_nodes.cc | 17 |
2 files changed, 11 insertions, 10 deletions
diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc index 3b42a5ee96..8a0e4caa9c 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc @@ -49,9 +49,9 @@ limitations under the License. #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/protobuf/config.pb.h" // NOLINT +#include "tensorflow/core/protobuf/config.pb.h" // NOLINT #include "tensorflow/core/protobuf/device_properties.pb.h" // NOLINT -#include "tensorflow/core/protobuf/rewriter_config.pb.h" // NOLINT +#include "tensorflow/core/protobuf/rewriter_config.pb.h" // NOLINT #include "tensorflow/core/util/device_name_utils.h" #if GOOGLE_CUDA diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc index e4ffc230e4..4dee51e1e8 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc @@ -125,8 +125,8 @@ void GetInputProperties(const grappler::GraphProperties& graph_properties, void GetOutputProperties(const grappler::GraphProperties& graph_properties, const Node* outside_node, const int in_port, - PartialTensorShape* shape, - tensorflow::DataType* dtype) { + PartialTensorShape* shape, + tensorflow::DataType* dtype) { if (graph_properties.HasInputProperties(outside_node->name())) { auto input_params = graph_properties.GetInputProperties(outside_node->name()); @@ -141,10 +141,11 @@ void GetOutputProperties(const grappler::GraphProperties& graph_properties, tensorflow::Status ValidateInputProperties(const PartialTensorShape& shape, const tensorflow::DataType dtype, nvinfer1::DataType* trt_dtype) { + // TODO(aaroey): some of these checks also apply to IsTensorRTCandidate(), so + // put them there instead. TF_RETURN_IF_ERROR(ConvertDType(dtype, trt_dtype)); if (shape.dims() < 0) { - return tensorflow::errors::InvalidArgument( - "Input tensor rank is unknown."); + return tensorflow::errors::InvalidArgument("Input tensor rank is unknown."); } if (shape.dims() > 8) { return tensorflow::errors::OutOfRange( @@ -153,7 +154,7 @@ tensorflow::Status ValidateInputProperties(const PartialTensorShape& shape, for (int d = 1; d < shape.dims(); ++d) { if (shape.dim_size(d) < 0) { return tensorflow::errors::InvalidArgument( - "Input tensor has a unknow non-batch dimemension at dim ", d); + "Input tensor has a unknown non-batch dimemension at dim ", d); } } return Status::OK(); @@ -2703,9 +2704,9 @@ tensorflow::Status ConvertGraphDefToEngine( auto status = ValidateInputProperties( shape, node_def.attr().at("dtype").type(), &dtype); if (!status.ok()) { - const string error_message = StrCat( - "Validation failed for ", node_name, " and input slot ", - slot_number, ": ", status.error_message()); + const string error_message = + StrCat("Validation failed for ", node_name, " and input slot ", + slot_number, ": ", status.error_message()); LOG(WARNING) << error_message; return Status(status.code(), error_message); } |