diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-06-12 19:18:30 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-12 19:21:07 -0700 |
commit | 87861251a5773315c7c2e36f85366c82cf64ad28 (patch) | |
tree | cd220ca6382802039caaa34fc0c9382723eaf0d9 /tensorflow/contrib/lite/toco/import_tensorflow.cc | |
parent | a7dbbab3868437b1c4f6297dc7d6294c227a277a (diff) |
Leverage the standard error space by using tensorflow::Status
PiperOrigin-RevId: 200322035
Diffstat (limited to 'tensorflow/contrib/lite/toco/import_tensorflow.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/import_tensorflow.cc | 189 |
1 files changed, 93 insertions, 96 deletions
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc index a2241c85a7..120e858717 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -31,7 +31,6 @@ limitations under the License. #include "tensorflow/contrib/lite/toco/model_flags.pb.h" #include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.h" #include "tensorflow/contrib/lite/toco/tensorflow_util.h" -#include "tensorflow/contrib/lite/toco/toco_port.h" #include "tensorflow/contrib/lite/toco/tooling_util.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/function.h" @@ -44,16 +43,11 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/public/session_options.h" #include "tensorflow/core/public/version.h" -#define TOCO_RETURN_IF_ERROR(...) \ - do { \ - const ::toco::port::Status _status = (__VA_ARGS__); \ - if (!_status.ok()) return _status; \ - } while (0) - using tensorflow::AttrValue; using tensorflow::DT_BOOL; using tensorflow::DT_FLOAT; @@ -69,8 +63,6 @@ using tensorflow::TensorShapeProto; namespace toco { -using port::Status; - namespace { bool HasAttr(const NodeDef& node, const string& attr_name) { return node.attr().count(attr_name) > 0; @@ -136,35 +128,40 @@ const AttrValue::ListValue& GetListAttr(const NodeDef& node, return attr.list(); } -Status CheckOptionalAttr(const NodeDef& node, const string& attr_name, - const string& expected_value) { +tensorflow::Status CheckOptionalAttr(const NodeDef& node, + const string& attr_name, + const string& expected_value) { if (HasAttr(node, attr_name)) { const string& value = GetStringAttr(node, attr_name); if (value != expected_value) { - return Status(false, "Unexpected value for attribute '" + attr_name + - "'. Expected '" + expected_value + "'"); + return tensorflow::errors::InvalidArgument( + "Unexpected value for attribute '" + attr_name + "'. Expected '" + + expected_value + "'"); } } - return Status::OK(); + return tensorflow::Status::OK(); } -Status CheckOptionalAttr(const NodeDef& node, const string& attr_name, - const tensorflow::DataType& expected_value) { + +tensorflow::Status CheckOptionalAttr( + const NodeDef& node, const string& attr_name, + const tensorflow::DataType& expected_value) { if (HasAttr(node, attr_name)) { const tensorflow::DataType& value = GetDataTypeAttr(node, attr_name); if (value != expected_value) { - return Status(false, "Unexpected value for attribute '" + attr_name + - "'. Expected '" + - tensorflow::DataType_Name(expected_value) + "'"); + return tensorflow::errors::InvalidArgument( + "Unexpected value for attribute '" + attr_name + "'. Expected '" + + tensorflow::DataType_Name(expected_value) + "'"); } } - return Status::OK(); + return tensorflow::Status::OK(); } template <typename T1, typename T2> -Status ExpectValue(const T1& v1, const T2& v2, const string& description) { - if (v1 == v2) return Status::OK(); - return Status(false, absl::StrCat("Unexpected ", description, ": got ", v1, - ", expected ", v2)); +tensorflow::Status ExpectValue(const T1& v1, const T2& v2, + const string& description) { + if (v1 == v2) return tensorflow::Status::OK(); + return tensorflow::errors::InvalidArgument(absl::StrCat( + "Unexpected ", description, ": got ", v1, ", expected ", v2)); } ArrayDataType ConvertDataType(tensorflow::DataType dtype) { @@ -185,9 +182,10 @@ ArrayDataType ConvertDataType(tensorflow::DataType dtype) { return ArrayDataType::kNone; } -Status ImportShape(const TFLITE_PROTO_NS::RepeatedPtrField< - tensorflow::TensorShapeProto_Dim>& input_dims, - int* input_flat_size, Shape* shape) { +tensorflow::Status ImportShape( + const TFLITE_PROTO_NS::RepeatedPtrField<tensorflow::TensorShapeProto_Dim>& + input_dims, + int* input_flat_size, Shape* shape) { std::vector<int> input_dims_only_sizes; for (auto& d : input_dims) { if (d.size() == 0) { @@ -197,23 +195,24 @@ Status ImportShape(const TFLITE_PROTO_NS::RepeatedPtrField< // For now, tweaking this to record a 0-D shape instead. shape->mutable_dims()->clear(); if (input_flat_size != nullptr) *input_flat_size = 0; - return Status::OK(); + return tensorflow::Status::OK(); } // TensorFlow's shapes use int64s, while TOCO uses ints. if (d.size() > std::numeric_limits<int>::max()) { - return Status(false, "Shape element overflows"); + return tensorflow::errors::InvalidArgument("Shape element overflows"); } input_dims_only_sizes.push_back(d.size()); } *shape->mutable_dims() = input_dims_only_sizes; - if (input_flat_size == nullptr) return Status::OK(); + if (input_flat_size == nullptr) return tensorflow::Status::OK(); return NumElements(input_dims_only_sizes, input_flat_size); } -Status ImportFloatArray(const TensorProto& input_tensor, Array* output_array) { +tensorflow::Status ImportFloatArray(const TensorProto& input_tensor, + Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_FLOAT); const auto& input_shape = input_tensor.tensor_shape(); CHECK_LE(input_shape.dim_size(), 4); @@ -240,18 +239,18 @@ Status ImportFloatArray(const TensorProto& input_tensor, Array* output_array) { toco::port::CopyToBuffer(input_tensor.tensor_content(), reinterpret_cast<char*>(output_float_data.data())); } else { - return Status( - false, + return tensorflow::errors::InvalidArgument( absl::StrCat("Neither input_content (", input_tensor.tensor_content().size() / sizeof(float), ") nor float_val (", input_tensor.float_val_size(), ") have the right dimensions (", input_flat_size, ") for this float tensor")); } - return Status::OK(); + return tensorflow::Status::OK(); } -Status ImportQuint8Array(const TensorProto& input_tensor, Array* output_array) { +tensorflow::Status ImportQuint8Array(const TensorProto& input_tensor, + Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_QUINT8); const auto& input_shape = input_tensor.tensor_shape(); CHECK_LE(input_shape.dim_size(), 4); @@ -273,18 +272,18 @@ Status ImportQuint8Array(const TensorProto& input_tensor, Array* output_array) { toco::port::CopyToBuffer(input_tensor.tensor_content(), reinterpret_cast<char*>(output_int_data.data())); } else { - return Status( - false, + return tensorflow::errors::InvalidArgument( absl::StrCat("Neither input_content (", input_tensor.tensor_content().size() / sizeof(uint8_t), ") nor int_val (", input_tensor.int_val_size(), ") have the right dimensions (", input_flat_size, ") for this uint8 tensor")); } - return Status::OK(); + return tensorflow::Status::OK(); } -Status ImportInt32Array(const TensorProto& input_tensor, Array* output_array) { +tensorflow::Status ImportInt32Array(const TensorProto& input_tensor, + Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_INT32); const auto& input_shape = input_tensor.tensor_shape(); CHECK_LE(input_shape.dim_size(), 4); @@ -306,18 +305,17 @@ Status ImportInt32Array(const TensorProto& input_tensor, Array* output_array) { toco::port::CopyToBuffer(input_tensor.tensor_content(), reinterpret_cast<char*>(output_int_data.data())); } else { - return Status( - false, - absl::StrCat("Neither input_content (", - input_tensor.tensor_content().size() / sizeof(int32), - ") nor int_val (", input_tensor.int_val_size(), - ") have the right dimensions (", input_flat_size, - ") for this int32 tensor")); + return tensorflow::errors::InvalidArgument(absl::StrCat( + "Neither input_content (", + input_tensor.tensor_content().size() / sizeof(int32), ") nor int_val (", + input_tensor.int_val_size(), ") have the right dimensions (", + input_flat_size, ") for this int32 tensor")); } - return Status::OK(); + return tensorflow::Status::OK(); } -Status ImportInt64Array(const TensorProto& input_tensor, Array* output_array) { +tensorflow::Status ImportInt64Array(const TensorProto& input_tensor, + Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_INT64); const auto& input_shape = input_tensor.tensor_shape(); CHECK_LE(input_shape.dim_size(), 4); @@ -339,18 +337,18 @@ Status ImportInt64Array(const TensorProto& input_tensor, Array* output_array) { toco::port::CopyToBuffer(input_tensor.tensor_content(), reinterpret_cast<char*>(output_int_data.data())); } else { - return Status( - false, + return tensorflow::errors::InvalidArgument( absl::StrCat("Neither input_content (", input_tensor.tensor_content().size() / sizeof(int64), ") nor int64_val (", input_tensor.int64_val_size(), ") have the right dimensions (", input_flat_size, ") for this int64 tensor")); } - return Status::OK(); + return tensorflow::Status::OK(); } -Status ImportBoolArray(const TensorProto& input_tensor, Array* output_array) { +tensorflow::Status ImportBoolArray(const TensorProto& input_tensor, + Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_BOOL); const auto& input_shape = input_tensor.tensor_shape(); CHECK_LE(input_shape.dim_size(), 4); @@ -380,19 +378,19 @@ Status ImportBoolArray(const TensorProto& input_tensor, Array* output_array) { // So far only encountered that in an array with 1 entry, let's // require that until we encounter a graph where that's not the case. if (output_bool_data.size() != 1) { - return Status( - false, absl::StrCat("Neither input_content (", - input_tensor.tensor_content().size(), - ") nor bool_val (", input_tensor.bool_val_size(), - ") have the right dimensions (", input_flat_size, - ") for this bool tensor")); + return tensorflow::errors::InvalidArgument(absl::StrCat( + "Neither input_content (", input_tensor.tensor_content().size(), + ") nor bool_val (", input_tensor.bool_val_size(), + ") have the right dimensions (", input_flat_size, + ") for this bool tensor")); } output_bool_data[0] = false; } - return Status::OK(); + return tensorflow::Status::OK(); } -Status ImportStringArray(const TensorProto& input_tensor, Array* output_array) { +tensorflow::Status ImportStringArray(const TensorProto& input_tensor, + Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_STRING); const auto& input_shape = input_tensor.tensor_shape(); CHECK_LE(input_shape.dim_size(), 4); @@ -402,9 +400,9 @@ Status ImportStringArray(const TensorProto& input_tensor, Array* output_array) { if (!status.ok()) return status; if (input_flat_size != input_tensor.string_val_size()) { - return Status(false, - "Input_content string_val doesn't have the right dimensions " - "for this string tensor"); + return tensorflow::errors::InvalidArgument( + "Input_content string_val doesn't have the right dimensions " + "for this string tensor"); } auto& output_string_data = @@ -414,7 +412,7 @@ Status ImportStringArray(const TensorProto& input_tensor, Array* output_array) { for (int i = 0; i < input_flat_size; ++i) { output_string_data[i] = input_tensor.string_val(i); } - return Status::OK(); + return tensorflow::Status::OK(); } // Count the number of inputs of a given node. If @@ -454,14 +452,14 @@ string CreateConstArray(Model* model, string const& name, return array_name; } -Status ConvertConstOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertConstOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "Const"); const auto& tensor = GetTensorAttr(node, "value"); const auto dtype = GetDataTypeAttr(node, "dtype"); - Status status = Status::OK(); + tensorflow::Status status = tensorflow::Status::OK(); auto& array = model->GetOrCreateArray(node.name()); switch (dtype) { @@ -497,22 +495,21 @@ Status ConvertConstOperator(const NodeDef& node, array.GetMutableBuffer<ArrayDataType::kNone>(); break; } - if (!status.ok()) { - status.AppendMessage(" (while processing node '" + node.name() + "')"); - } - return status; + TF_RETURN_WITH_CONTEXT_IF_ERROR( + status, " (while processing node '" + node.name() + "')"); + return tensorflow::Status::OK(); } -Status ConvertConvOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertConvOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "Conv2D"); CheckInputsCount(node, tf_import_flags, 2); // We only support NHWC, which is the default data_format. // So if data_format is not defined, we're all good. - TOCO_RETURN_IF_ERROR(CheckOptionalAttr(node, "data_format", "NHWC")); - TOCO_RETURN_IF_ERROR(CheckOptionalAttr(node, "T", DT_FLOAT)); + TF_RETURN_IF_ERROR(CheckOptionalAttr(node, "data_format", "NHWC")); + TF_RETURN_IF_ERROR(CheckOptionalAttr(node, "T", DT_FLOAT)); const auto& input_name = node.input(0); const auto& weights_name = node.input(1); @@ -537,26 +534,25 @@ Status ConvertConvOperator(const NodeDef& node, auto* conv = new ConvOperator; conv->inputs = {input_name, reordered_weights_name}; conv->outputs = {node.name()}; - TOCO_RETURN_IF_ERROR( - Status(HasAttr(node, "strides"), "Missing attribute 'strides'")); + if (!HasAttr(node, "strides")) { + return tensorflow::errors::InvalidArgument("Missing attribute 'strides'"); + } const auto& strides = GetListAttr(node, "strides"); - TOCO_RETURN_IF_ERROR(ExpectValue(strides.i_size(), 4, "number of strides")); - TOCO_RETURN_IF_ERROR(ExpectValue(strides.i(0), 1, "strides(0)")); - TOCO_RETURN_IF_ERROR(ExpectValue(strides.i(3), 1, "strides(3)")); + TF_RETURN_IF_ERROR(ExpectValue(strides.i_size(), 4, "number of strides")); + TF_RETURN_IF_ERROR(ExpectValue(strides.i(0), 1, "strides(0)")); + TF_RETURN_IF_ERROR(ExpectValue(strides.i(3), 1, "strides(3)")); conv->stride_height = strides.i(1); conv->stride_width = strides.i(2); if (HasAttr(node, "dilations")) { const auto& dilations = GetListAttr(node, "dilations"); - TOCO_RETURN_IF_ERROR( + TF_RETURN_IF_ERROR( ExpectValue(dilations.i_size(), 4, "number of dilations")); if (dilations.i(0) != 1 || dilations.i(3) != 1) { - return Status( - false, absl::StrCat( - "Can only import Conv ops with dilation along the height " - "(1st) or width (2nd) axis. TensorFlow op \"", - node.name(), "\" had dilations:[ ", dilations.i(0), ", ", - dilations.i(1), ", ", dilations.i(2), ", ", dilations.i(3), - "].")); + return tensorflow::errors::InvalidArgument(absl::StrCat( + "Can only import Conv ops with dilation along the height " + "(1st) or width (2nd) axis. TensorFlow op \"", + node.name(), "\" had dilations:[ ", dilations.i(0), ", ", + dilations.i(1), ", ", dilations.i(2), ", ", dilations.i(3), "].")); } conv->dilation_height_factor = dilations.i(1); conv->dilation_width_factor = dilations.i(2); @@ -570,11 +566,12 @@ Status ConvertConvOperator(const NodeDef& node, } else if (padding == "VALID") { conv->padding.type = PaddingType::kValid; } else { - return Status(false, "Bad padding (only SAME and VALID are supported)"); + return tensorflow::errors::InvalidArgument( + "Bad padding (only SAME and VALID are supported)"); } model->operators.emplace_back(conv); - return Status::OK(); + return tensorflow::Status::OK(); } void ConvertDepthwiseConvOperator(const NodeDef& node, @@ -1753,9 +1750,9 @@ void ConvertSparseToDenseOperator(const NodeDef& node, } // namespace namespace internal { -Status ImportTensorFlowNode(const tensorflow::NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ImportTensorFlowNode( + const tensorflow::NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, Model* model) { // TODO(ahentz): Historically these functions all CHECK-fail on error. We've // been slowly converting them to return Status. if (node.op() == "Const") { @@ -1958,7 +1955,7 @@ Status ImportTensorFlowNode(const tensorflow::NodeDef& node, } else { ConvertUnsupportedOperator(node, tf_import_flags, model); } - return Status::OK(); + return tensorflow::Status::OK(); } } // namespace internal |