author    A. Unique TensorFlower <gardener@tensorflow.org>    2018-06-12 19:18:30 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>     2018-06-12 19:21:07 -0700
commit    87861251a5773315c7c2e36f85366c82cf64ad28 (patch)
tree      cd220ca6382802039caaa34fc0c9382723eaf0d9 /tensorflow/contrib/lite/toco/import_tensorflow.cc
parent    a7dbbab3868437b1c4f6297dc7d6294c227a277a (diff)
Leverage the standard error space by using tensorflow::Status
PiperOrigin-RevId: 200322035
Diffstat (limited to 'tensorflow/contrib/lite/toco/import_tensorflow.cc')
-rw-r--r--    tensorflow/contrib/lite/toco/import_tensorflow.cc    189
1 file changed, 93 insertions(+), 96 deletions(-)
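
For reference, the migration pattern applied throughout this file is: the local toco::port::Status type, its Status(false, message) constructor, and the TOCO_RETURN_IF_ERROR macro are replaced by tensorflow::Status, the tensorflow::errors::* constructors, and the standard TF_RETURN_IF_ERROR macro from TensorFlow's core error space. A minimal sketch of the new style follows; ValidateStrides and ConvertSomeOperator are illustrative names, not functions from this patch.

#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"

// Returns an error in the standard TensorFlow error space instead of the
// old toco::port::Status(false, message) form.
tensorflow::Status ValidateStrides(int num_strides) {
  if (num_strides != 4) {
    // errors::InvalidArgument concatenates its arguments into the message.
    return tensorflow::errors::InvalidArgument(
        "Unexpected number of strides: got ", num_strides, ", expected 4");
  }
  return tensorflow::Status::OK();
}

tensorflow::Status ConvertSomeOperator(int num_strides) {
  // TF_RETURN_IF_ERROR replaces the removed TOCO_RETURN_IF_ERROR macro:
  // it propagates a non-OK status to the caller unchanged.
  TF_RETURN_IF_ERROR(ValidateStrides(num_strides));
  return tensorflow::Status::OK();
}
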
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index a2241c85a7..120e858717 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -31,7 +31,6 @@ limitations under the License.
#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.h"
#include "tensorflow/contrib/lite/toco/tensorflow_util.h"
-#include "tensorflow/contrib/lite/toco/toco_port.h"
#include "tensorflow/contrib/lite/toco/tooling_util.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/function.h"
@@ -44,16 +43,11 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"
-#define TOCO_RETURN_IF_ERROR(...) \
- do { \
- const ::toco::port::Status _status = (__VA_ARGS__); \
- if (!_status.ok()) return _status; \
- } while (0)
-
using tensorflow::AttrValue;
using tensorflow::DT_BOOL;
using tensorflow::DT_FLOAT;
@@ -69,8 +63,6 @@ using tensorflow::TensorShapeProto;
namespace toco {
-using port::Status;
-
namespace {
bool HasAttr(const NodeDef& node, const string& attr_name) {
return node.attr().count(attr_name) > 0;
@@ -136,35 +128,40 @@ const AttrValue::ListValue& GetListAttr(const NodeDef& node,
return attr.list();
}
-Status CheckOptionalAttr(const NodeDef& node, const string& attr_name,
- const string& expected_value) {
+tensorflow::Status CheckOptionalAttr(const NodeDef& node,
+ const string& attr_name,
+ const string& expected_value) {
if (HasAttr(node, attr_name)) {
const string& value = GetStringAttr(node, attr_name);
if (value != expected_value) {
- return Status(false, "Unexpected value for attribute '" + attr_name +
- "'. Expected '" + expected_value + "'");
+ return tensorflow::errors::InvalidArgument(
+ "Unexpected value for attribute '" + attr_name + "'. Expected '" +
+ expected_value + "'");
}
}
- return Status::OK();
+ return tensorflow::Status::OK();
}
-Status CheckOptionalAttr(const NodeDef& node, const string& attr_name,
- const tensorflow::DataType& expected_value) {
+
+tensorflow::Status CheckOptionalAttr(
+ const NodeDef& node, const string& attr_name,
+ const tensorflow::DataType& expected_value) {
if (HasAttr(node, attr_name)) {
const tensorflow::DataType& value = GetDataTypeAttr(node, attr_name);
if (value != expected_value) {
- return Status(false, "Unexpected value for attribute '" + attr_name +
- "'. Expected '" +
- tensorflow::DataType_Name(expected_value) + "'");
+ return tensorflow::errors::InvalidArgument(
+ "Unexpected value for attribute '" + attr_name + "'. Expected '" +
+ tensorflow::DataType_Name(expected_value) + "'");
}
}
- return Status::OK();
+ return tensorflow::Status::OK();
}
template <typename T1, typename T2>
-Status ExpectValue(const T1& v1, const T2& v2, const string& description) {
- if (v1 == v2) return Status::OK();
- return Status(false, absl::StrCat("Unexpected ", description, ": got ", v1,
- ", expected ", v2));
+tensorflow::Status ExpectValue(const T1& v1, const T2& v2,
+ const string& description) {
+ if (v1 == v2) return tensorflow::Status::OK();
+ return tensorflow::errors::InvalidArgument(absl::StrCat(
+ "Unexpected ", description, ": got ", v1, ", expected ", v2));
}
ArrayDataType ConvertDataType(tensorflow::DataType dtype) {
@@ -185,9 +182,10 @@ ArrayDataType ConvertDataType(tensorflow::DataType dtype) {
return ArrayDataType::kNone;
}
-Status ImportShape(const TFLITE_PROTO_NS::RepeatedPtrField<
- tensorflow::TensorShapeProto_Dim>& input_dims,
- int* input_flat_size, Shape* shape) {
+tensorflow::Status ImportShape(
+ const TFLITE_PROTO_NS::RepeatedPtrField<tensorflow::TensorShapeProto_Dim>&
+ input_dims,
+ int* input_flat_size, Shape* shape) {
std::vector<int> input_dims_only_sizes;
for (auto& d : input_dims) {
if (d.size() == 0) {
@@ -197,23 +195,24 @@ Status ImportShape(const TFLITE_PROTO_NS::RepeatedPtrField<
// For now, tweaking this to record a 0-D shape instead.
shape->mutable_dims()->clear();
if (input_flat_size != nullptr) *input_flat_size = 0;
- return Status::OK();
+ return tensorflow::Status::OK();
}
// TensorFlow's shapes use int64s, while TOCO uses ints.
if (d.size() > std::numeric_limits<int>::max()) {
- return Status(false, "Shape element overflows");
+ return tensorflow::errors::InvalidArgument("Shape element overflows");
}
input_dims_only_sizes.push_back(d.size());
}
*shape->mutable_dims() = input_dims_only_sizes;
- if (input_flat_size == nullptr) return Status::OK();
+ if (input_flat_size == nullptr) return tensorflow::Status::OK();
return NumElements(input_dims_only_sizes, input_flat_size);
}
-Status ImportFloatArray(const TensorProto& input_tensor, Array* output_array) {
+tensorflow::Status ImportFloatArray(const TensorProto& input_tensor,
+ Array* output_array) {
CHECK_EQ(input_tensor.dtype(), DT_FLOAT);
const auto& input_shape = input_tensor.tensor_shape();
CHECK_LE(input_shape.dim_size(), 4);
@@ -240,18 +239,18 @@ Status ImportFloatArray(const TensorProto& input_tensor, Array* output_array) {
toco::port::CopyToBuffer(input_tensor.tensor_content(),
reinterpret_cast<char*>(output_float_data.data()));
} else {
- return Status(
- false,
+ return tensorflow::errors::InvalidArgument(
absl::StrCat("Neither input_content (",
input_tensor.tensor_content().size() / sizeof(float),
") nor float_val (", input_tensor.float_val_size(),
") have the right dimensions (", input_flat_size,
") for this float tensor"));
}
- return Status::OK();
+ return tensorflow::Status::OK();
}
-Status ImportQuint8Array(const TensorProto& input_tensor, Array* output_array) {
+tensorflow::Status ImportQuint8Array(const TensorProto& input_tensor,
+ Array* output_array) {
CHECK_EQ(input_tensor.dtype(), DT_QUINT8);
const auto& input_shape = input_tensor.tensor_shape();
CHECK_LE(input_shape.dim_size(), 4);
@@ -273,18 +272,18 @@ Status ImportQuint8Array(const TensorProto& input_tensor, Array* output_array) {
toco::port::CopyToBuffer(input_tensor.tensor_content(),
reinterpret_cast<char*>(output_int_data.data()));
} else {
- return Status(
- false,
+ return tensorflow::errors::InvalidArgument(
absl::StrCat("Neither input_content (",
input_tensor.tensor_content().size() / sizeof(uint8_t),
") nor int_val (", input_tensor.int_val_size(),
") have the right dimensions (", input_flat_size,
") for this uint8 tensor"));
}
- return Status::OK();
+ return tensorflow::Status::OK();
}
-Status ImportInt32Array(const TensorProto& input_tensor, Array* output_array) {
+tensorflow::Status ImportInt32Array(const TensorProto& input_tensor,
+ Array* output_array) {
CHECK_EQ(input_tensor.dtype(), DT_INT32);
const auto& input_shape = input_tensor.tensor_shape();
CHECK_LE(input_shape.dim_size(), 4);
@@ -306,18 +305,17 @@ Status ImportInt32Array(const TensorProto& input_tensor, Array* output_array) {
toco::port::CopyToBuffer(input_tensor.tensor_content(),
reinterpret_cast<char*>(output_int_data.data()));
} else {
- return Status(
- false,
- absl::StrCat("Neither input_content (",
- input_tensor.tensor_content().size() / sizeof(int32),
- ") nor int_val (", input_tensor.int_val_size(),
- ") have the right dimensions (", input_flat_size,
- ") for this int32 tensor"));
+ return tensorflow::errors::InvalidArgument(absl::StrCat(
+ "Neither input_content (",
+ input_tensor.tensor_content().size() / sizeof(int32), ") nor int_val (",
+ input_tensor.int_val_size(), ") have the right dimensions (",
+ input_flat_size, ") for this int32 tensor"));
}
- return Status::OK();
+ return tensorflow::Status::OK();
}
-Status ImportInt64Array(const TensorProto& input_tensor, Array* output_array) {
+tensorflow::Status ImportInt64Array(const TensorProto& input_tensor,
+ Array* output_array) {
CHECK_EQ(input_tensor.dtype(), DT_INT64);
const auto& input_shape = input_tensor.tensor_shape();
CHECK_LE(input_shape.dim_size(), 4);
@@ -339,18 +337,18 @@ Status ImportInt64Array(const TensorProto& input_tensor, Array* output_array) {
toco::port::CopyToBuffer(input_tensor.tensor_content(),
reinterpret_cast<char*>(output_int_data.data()));
} else {
- return Status(
- false,
+ return tensorflow::errors::InvalidArgument(
absl::StrCat("Neither input_content (",
input_tensor.tensor_content().size() / sizeof(int64),
") nor int64_val (", input_tensor.int64_val_size(),
") have the right dimensions (", input_flat_size,
") for this int64 tensor"));
}
- return Status::OK();
+ return tensorflow::Status::OK();
}
-Status ImportBoolArray(const TensorProto& input_tensor, Array* output_array) {
+tensorflow::Status ImportBoolArray(const TensorProto& input_tensor,
+ Array* output_array) {
CHECK_EQ(input_tensor.dtype(), DT_BOOL);
const auto& input_shape = input_tensor.tensor_shape();
CHECK_LE(input_shape.dim_size(), 4);
@@ -380,19 +378,19 @@ Status ImportBoolArray(const TensorProto& input_tensor, Array* output_array) {
// So far only encountered that in an array with 1 entry, let's
// require that until we encounter a graph where that's not the case.
if (output_bool_data.size() != 1) {
- return Status(
- false, absl::StrCat("Neither input_content (",
- input_tensor.tensor_content().size(),
- ") nor bool_val (", input_tensor.bool_val_size(),
- ") have the right dimensions (", input_flat_size,
- ") for this bool tensor"));
+ return tensorflow::errors::InvalidArgument(absl::StrCat(
+ "Neither input_content (", input_tensor.tensor_content().size(),
+ ") nor bool_val (", input_tensor.bool_val_size(),
+ ") have the right dimensions (", input_flat_size,
+ ") for this bool tensor"));
}
output_bool_data[0] = false;
}
- return Status::OK();
+ return tensorflow::Status::OK();
}
-Status ImportStringArray(const TensorProto& input_tensor, Array* output_array) {
+tensorflow::Status ImportStringArray(const TensorProto& input_tensor,
+ Array* output_array) {
CHECK_EQ(input_tensor.dtype(), DT_STRING);
const auto& input_shape = input_tensor.tensor_shape();
CHECK_LE(input_shape.dim_size(), 4);
@@ -402,9 +400,9 @@ Status ImportStringArray(const TensorProto& input_tensor, Array* output_array) {
if (!status.ok()) return status;
if (input_flat_size != input_tensor.string_val_size()) {
- return Status(false,
- "Input_content string_val doesn't have the right dimensions "
- "for this string tensor");
+ return tensorflow::errors::InvalidArgument(
+ "Input_content string_val doesn't have the right dimensions "
+ "for this string tensor");
}
auto& output_string_data =
@@ -414,7 +412,7 @@ Status ImportStringArray(const TensorProto& input_tensor, Array* output_array) {
for (int i = 0; i < input_flat_size; ++i) {
output_string_data[i] = input_tensor.string_val(i);
}
- return Status::OK();
+ return tensorflow::Status::OK();
}
// Count the number of inputs of a given node. If
@@ -454,14 +452,14 @@ string CreateConstArray(Model* model, string const& name,
return array_name;
}
-Status ConvertConstOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertConstOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK_EQ(node.op(), "Const");
const auto& tensor = GetTensorAttr(node, "value");
const auto dtype = GetDataTypeAttr(node, "dtype");
- Status status = Status::OK();
+ tensorflow::Status status = tensorflow::Status::OK();
auto& array = model->GetOrCreateArray(node.name());
switch (dtype) {
@@ -497,22 +495,21 @@ Status ConvertConstOperator(const NodeDef& node,
array.GetMutableBuffer<ArrayDataType::kNone>();
break;
}
- if (!status.ok()) {
- status.AppendMessage(" (while processing node '" + node.name() + "')");
- }
- return status;
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(
+ status, " (while processing node '" + node.name() + "')");
+ return tensorflow::Status::OK();
}
-Status ConvertConvOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertConvOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK_EQ(node.op(), "Conv2D");
CheckInputsCount(node, tf_import_flags, 2);
// We only support NHWC, which is the default data_format.
// So if data_format is not defined, we're all good.
- TOCO_RETURN_IF_ERROR(CheckOptionalAttr(node, "data_format", "NHWC"));
- TOCO_RETURN_IF_ERROR(CheckOptionalAttr(node, "T", DT_FLOAT));
+ TF_RETURN_IF_ERROR(CheckOptionalAttr(node, "data_format", "NHWC"));
+ TF_RETURN_IF_ERROR(CheckOptionalAttr(node, "T", DT_FLOAT));
const auto& input_name = node.input(0);
const auto& weights_name = node.input(1);
@@ -537,26 +534,25 @@ Status ConvertConvOperator(const NodeDef& node,
auto* conv = new ConvOperator;
conv->inputs = {input_name, reordered_weights_name};
conv->outputs = {node.name()};
- TOCO_RETURN_IF_ERROR(
- Status(HasAttr(node, "strides"), "Missing attribute 'strides'"));
+ if (!HasAttr(node, "strides")) {
+ return tensorflow::errors::InvalidArgument("Missing attribute 'strides'");
+ }
const auto& strides = GetListAttr(node, "strides");
- TOCO_RETURN_IF_ERROR(ExpectValue(strides.i_size(), 4, "number of strides"));
- TOCO_RETURN_IF_ERROR(ExpectValue(strides.i(0), 1, "strides(0)"));
- TOCO_RETURN_IF_ERROR(ExpectValue(strides.i(3), 1, "strides(3)"));
+ TF_RETURN_IF_ERROR(ExpectValue(strides.i_size(), 4, "number of strides"));
+ TF_RETURN_IF_ERROR(ExpectValue(strides.i(0), 1, "strides(0)"));
+ TF_RETURN_IF_ERROR(ExpectValue(strides.i(3), 1, "strides(3)"));
conv->stride_height = strides.i(1);
conv->stride_width = strides.i(2);
if (HasAttr(node, "dilations")) {
const auto& dilations = GetListAttr(node, "dilations");
- TOCO_RETURN_IF_ERROR(
+ TF_RETURN_IF_ERROR(
ExpectValue(dilations.i_size(), 4, "number of dilations"));
if (dilations.i(0) != 1 || dilations.i(3) != 1) {
- return Status(
- false, absl::StrCat(
- "Can only import Conv ops with dilation along the height "
- "(1st) or width (2nd) axis. TensorFlow op \"",
- node.name(), "\" had dilations:[ ", dilations.i(0), ", ",
- dilations.i(1), ", ", dilations.i(2), ", ", dilations.i(3),
- "]."));
+ return tensorflow::errors::InvalidArgument(absl::StrCat(
+ "Can only import Conv ops with dilation along the height "
+ "(1st) or width (2nd) axis. TensorFlow op \"",
+ node.name(), "\" had dilations:[ ", dilations.i(0), ", ",
+ dilations.i(1), ", ", dilations.i(2), ", ", dilations.i(3), "]."));
}
conv->dilation_height_factor = dilations.i(1);
conv->dilation_width_factor = dilations.i(2);
@@ -570,11 +566,12 @@ Status ConvertConvOperator(const NodeDef& node,
} else if (padding == "VALID") {
conv->padding.type = PaddingType::kValid;
} else {
- return Status(false, "Bad padding (only SAME and VALID are supported)");
+ return tensorflow::errors::InvalidArgument(
+ "Bad padding (only SAME and VALID are supported)");
}
model->operators.emplace_back(conv);
- return Status::OK();
+ return tensorflow::Status::OK();
}
void ConvertDepthwiseConvOperator(const NodeDef& node,
@@ -1753,9 +1750,9 @@ void ConvertSparseToDenseOperator(const NodeDef& node,
} // namespace
namespace internal {
-Status ImportTensorFlowNode(const tensorflow::NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ImportTensorFlowNode(
+ const tensorflow::NodeDef& node,
+ const TensorFlowImportFlags& tf_import_flags, Model* model) {
// TODO(ahentz): Historically these functions all CHECK-fail on error. We've
// been slowly converting them to return Status.
if (node.op() == "Const") {
@@ -1958,7 +1955,7 @@ Status ImportTensorFlowNode(const tensorflow::NodeDef& node,
} else {
ConvertUnsupportedOperator(node, tf_import_flags, model);
}
- return Status::OK();
+ return tensorflow::Status::OK();
}
} // namespace internal
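
The ConvertConstOperator hunk above also switches from manually calling status.AppendMessage to the TF_RETURN_WITH_CONTEXT_IF_ERROR macro, which appends context to a failing status before returning it. A minimal sketch of that idiom, assuming a hypothetical ImportNodePayload step in place of the per-dtype import calls:

#include <string>

#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"

// Hypothetical import step; stands in for ImportFloatArray and friends.
tensorflow::Status ImportNodePayload() { return tensorflow::Status::OK(); }

tensorflow::Status ConvertNode(const std::string& node_name) {
  tensorflow::Status status = ImportNodePayload();
  // On failure this appends " (while processing node '<name>')" to the
  // error message and returns; on success execution falls through.
  TF_RETURN_WITH_CONTEXT_IF_ERROR(
      status, " (while processing node '" + node_name + "')");
  return tensorflow::Status::OK();
}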