aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-12 19:18:30 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-12 19:21:07 -0700
commit87861251a5773315c7c2e36f85366c82cf64ad28 (patch)
treecd220ca6382802039caaa34fc0c9382723eaf0d9
parenta7dbbab3868437b1c4f6297dc7d6294c227a277a (diff)
Leverage the standard error space by using tensorflow::Status
PiperOrigin-RevId: 200322035
-rw-r--r--tensorflow/contrib/lite/toco/BUILD2
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow.cc189
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow_test.cc24
-rw-r--r--tensorflow/contrib/lite/toco/toco_port.cc69
-rw-r--r--tensorflow/contrib/lite/toco/toco_port.h35
-rw-r--r--tensorflow/contrib/lite/toco/tooling_util.cc2
-rw-r--r--tensorflow/contrib/lite/toco/tooling_util.h13
-rw-r--r--tensorflow/contrib/lite/toco/tooling_util_test.cc11
8 files changed, 171 insertions, 174 deletions
diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD
index 7ea4f32ef6..0789dc9928 100644
--- a/tensorflow/contrib/lite/toco/BUILD
+++ b/tensorflow/contrib/lite/toco/BUILD
@@ -374,6 +374,7 @@ tf_cc_test(
":toco_tooling",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
+ "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"@com_google_googletest//:gtest_main",
],
@@ -411,6 +412,7 @@ tf_cc_test(
deps = [
":model",
":tooling_util",
+ "//tensorflow/core:lib",
"@com_google_googletest//:gtest_main",
],
)
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index a2241c85a7..120e858717 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -31,7 +31,6 @@ limitations under the License.
#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.h"
#include "tensorflow/contrib/lite/toco/tensorflow_util.h"
-#include "tensorflow/contrib/lite/toco/toco_port.h"
#include "tensorflow/contrib/lite/toco/tooling_util.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/function.h"
@@ -44,16 +43,11 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"
-#define TOCO_RETURN_IF_ERROR(...) \
- do { \
- const ::toco::port::Status _status = (__VA_ARGS__); \
- if (!_status.ok()) return _status; \
- } while (0)
-
using tensorflow::AttrValue;
using tensorflow::DT_BOOL;
using tensorflow::DT_FLOAT;
@@ -69,8 +63,6 @@ using tensorflow::TensorShapeProto;
namespace toco {
-using port::Status;
-
namespace {
bool HasAttr(const NodeDef& node, const string& attr_name) {
return node.attr().count(attr_name) > 0;
@@ -136,35 +128,40 @@ const AttrValue::ListValue& GetListAttr(const NodeDef& node,
return attr.list();
}
-Status CheckOptionalAttr(const NodeDef& node, const string& attr_name,
- const string& expected_value) {
+tensorflow::Status CheckOptionalAttr(const NodeDef& node,
+ const string& attr_name,
+ const string& expected_value) {
if (HasAttr(node, attr_name)) {
const string& value = GetStringAttr(node, attr_name);
if (value != expected_value) {
- return Status(false, "Unexpected value for attribute '" + attr_name +
- "'. Expected '" + expected_value + "'");
+ return tensorflow::errors::InvalidArgument(
+ "Unexpected value for attribute '" + attr_name + "'. Expected '" +
+ expected_value + "'");
}
}
- return Status::OK();
+ return tensorflow::Status::OK();
}
-Status CheckOptionalAttr(const NodeDef& node, const string& attr_name,
- const tensorflow::DataType& expected_value) {
+
+tensorflow::Status CheckOptionalAttr(
+ const NodeDef& node, const string& attr_name,
+ const tensorflow::DataType& expected_value) {
if (HasAttr(node, attr_name)) {
const tensorflow::DataType& value = GetDataTypeAttr(node, attr_name);
if (value != expected_value) {
- return Status(false, "Unexpected value for attribute '" + attr_name +
- "'. Expected '" +
- tensorflow::DataType_Name(expected_value) + "'");
+ return tensorflow::errors::InvalidArgument(
+ "Unexpected value for attribute '" + attr_name + "'. Expected '" +
+ tensorflow::DataType_Name(expected_value) + "'");
}
}
- return Status::OK();
+ return tensorflow::Status::OK();
}
template <typename T1, typename T2>
-Status ExpectValue(const T1& v1, const T2& v2, const string& description) {
- if (v1 == v2) return Status::OK();
- return Status(false, absl::StrCat("Unexpected ", description, ": got ", v1,
- ", expected ", v2));
+tensorflow::Status ExpectValue(const T1& v1, const T2& v2,
+ const string& description) {
+ if (v1 == v2) return tensorflow::Status::OK();
+ return tensorflow::errors::InvalidArgument(absl::StrCat(
+ "Unexpected ", description, ": got ", v1, ", expected ", v2));
}
ArrayDataType ConvertDataType(tensorflow::DataType dtype) {
@@ -185,9 +182,10 @@ ArrayDataType ConvertDataType(tensorflow::DataType dtype) {
return ArrayDataType::kNone;
}
-Status ImportShape(const TFLITE_PROTO_NS::RepeatedPtrField<
- tensorflow::TensorShapeProto_Dim>& input_dims,
- int* input_flat_size, Shape* shape) {
+tensorflow::Status ImportShape(
+ const TFLITE_PROTO_NS::RepeatedPtrField<tensorflow::TensorShapeProto_Dim>&
+ input_dims,
+ int* input_flat_size, Shape* shape) {
std::vector<int> input_dims_only_sizes;
for (auto& d : input_dims) {
if (d.size() == 0) {
@@ -197,23 +195,24 @@ Status ImportShape(const TFLITE_PROTO_NS::RepeatedPtrField<
// For now, tweaking this to record a 0-D shape instead.
shape->mutable_dims()->clear();
if (input_flat_size != nullptr) *input_flat_size = 0;
- return Status::OK();
+ return tensorflow::Status::OK();
}
// TensorFlow's shapes use int64s, while TOCO uses ints.
if (d.size() > std::numeric_limits<int>::max()) {
- return Status(false, "Shape element overflows");
+ return tensorflow::errors::InvalidArgument("Shape element overflows");
}
input_dims_only_sizes.push_back(d.size());
}
*shape->mutable_dims() = input_dims_only_sizes;
- if (input_flat_size == nullptr) return Status::OK();
+ if (input_flat_size == nullptr) return tensorflow::Status::OK();
return NumElements(input_dims_only_sizes, input_flat_size);
}
-Status ImportFloatArray(const TensorProto& input_tensor, Array* output_array) {
+tensorflow::Status ImportFloatArray(const TensorProto& input_tensor,
+ Array* output_array) {
CHECK_EQ(input_tensor.dtype(), DT_FLOAT);
const auto& input_shape = input_tensor.tensor_shape();
CHECK_LE(input_shape.dim_size(), 4);
@@ -240,18 +239,18 @@ Status ImportFloatArray(const TensorProto& input_tensor, Array* output_array) {
toco::port::CopyToBuffer(input_tensor.tensor_content(),
reinterpret_cast<char*>(output_float_data.data()));
} else {
- return Status(
- false,
+ return tensorflow::errors::InvalidArgument(
absl::StrCat("Neither input_content (",
input_tensor.tensor_content().size() / sizeof(float),
") nor float_val (", input_tensor.float_val_size(),
") have the right dimensions (", input_flat_size,
") for this float tensor"));
}
- return Status::OK();
+ return tensorflow::Status::OK();
}
-Status ImportQuint8Array(const TensorProto& input_tensor, Array* output_array) {
+tensorflow::Status ImportQuint8Array(const TensorProto& input_tensor,
+ Array* output_array) {
CHECK_EQ(input_tensor.dtype(), DT_QUINT8);
const auto& input_shape = input_tensor.tensor_shape();
CHECK_LE(input_shape.dim_size(), 4);
@@ -273,18 +272,18 @@ Status ImportQuint8Array(const TensorProto& input_tensor, Array* output_array) {
toco::port::CopyToBuffer(input_tensor.tensor_content(),
reinterpret_cast<char*>(output_int_data.data()));
} else {
- return Status(
- false,
+ return tensorflow::errors::InvalidArgument(
absl::StrCat("Neither input_content (",
input_tensor.tensor_content().size() / sizeof(uint8_t),
") nor int_val (", input_tensor.int_val_size(),
") have the right dimensions (", input_flat_size,
") for this uint8 tensor"));
}
- return Status::OK();
+ return tensorflow::Status::OK();
}
-Status ImportInt32Array(const TensorProto& input_tensor, Array* output_array) {
+tensorflow::Status ImportInt32Array(const TensorProto& input_tensor,
+ Array* output_array) {
CHECK_EQ(input_tensor.dtype(), DT_INT32);
const auto& input_shape = input_tensor.tensor_shape();
CHECK_LE(input_shape.dim_size(), 4);
@@ -306,18 +305,17 @@ Status ImportInt32Array(const TensorProto& input_tensor, Array* output_array) {
toco::port::CopyToBuffer(input_tensor.tensor_content(),
reinterpret_cast<char*>(output_int_data.data()));
} else {
- return Status(
- false,
- absl::StrCat("Neither input_content (",
- input_tensor.tensor_content().size() / sizeof(int32),
- ") nor int_val (", input_tensor.int_val_size(),
- ") have the right dimensions (", input_flat_size,
- ") for this int32 tensor"));
+ return tensorflow::errors::InvalidArgument(absl::StrCat(
+ "Neither input_content (",
+ input_tensor.tensor_content().size() / sizeof(int32), ") nor int_val (",
+ input_tensor.int_val_size(), ") have the right dimensions (",
+ input_flat_size, ") for this int32 tensor"));
}
- return Status::OK();
+ return tensorflow::Status::OK();
}
-Status ImportInt64Array(const TensorProto& input_tensor, Array* output_array) {
+tensorflow::Status ImportInt64Array(const TensorProto& input_tensor,
+ Array* output_array) {
CHECK_EQ(input_tensor.dtype(), DT_INT64);
const auto& input_shape = input_tensor.tensor_shape();
CHECK_LE(input_shape.dim_size(), 4);
@@ -339,18 +337,18 @@ Status ImportInt64Array(const TensorProto& input_tensor, Array* output_array) {
toco::port::CopyToBuffer(input_tensor.tensor_content(),
reinterpret_cast<char*>(output_int_data.data()));
} else {
- return Status(
- false,
+ return tensorflow::errors::InvalidArgument(
absl::StrCat("Neither input_content (",
input_tensor.tensor_content().size() / sizeof(int64),
") nor int64_val (", input_tensor.int64_val_size(),
") have the right dimensions (", input_flat_size,
") for this int64 tensor"));
}
- return Status::OK();
+ return tensorflow::Status::OK();
}
-Status ImportBoolArray(const TensorProto& input_tensor, Array* output_array) {
+tensorflow::Status ImportBoolArray(const TensorProto& input_tensor,
+ Array* output_array) {
CHECK_EQ(input_tensor.dtype(), DT_BOOL);
const auto& input_shape = input_tensor.tensor_shape();
CHECK_LE(input_shape.dim_size(), 4);
@@ -380,19 +378,19 @@ Status ImportBoolArray(const TensorProto& input_tensor, Array* output_array) {
// So far only encountered that in an array with 1 entry, let's
// require that until we encounter a graph where that's not the case.
if (output_bool_data.size() != 1) {
- return Status(
- false, absl::StrCat("Neither input_content (",
- input_tensor.tensor_content().size(),
- ") nor bool_val (", input_tensor.bool_val_size(),
- ") have the right dimensions (", input_flat_size,
- ") for this bool tensor"));
+ return tensorflow::errors::InvalidArgument(absl::StrCat(
+ "Neither input_content (", input_tensor.tensor_content().size(),
+ ") nor bool_val (", input_tensor.bool_val_size(),
+ ") have the right dimensions (", input_flat_size,
+ ") for this bool tensor"));
}
output_bool_data[0] = false;
}
- return Status::OK();
+ return tensorflow::Status::OK();
}
-Status ImportStringArray(const TensorProto& input_tensor, Array* output_array) {
+tensorflow::Status ImportStringArray(const TensorProto& input_tensor,
+ Array* output_array) {
CHECK_EQ(input_tensor.dtype(), DT_STRING);
const auto& input_shape = input_tensor.tensor_shape();
CHECK_LE(input_shape.dim_size(), 4);
@@ -402,9 +400,9 @@ Status ImportStringArray(const TensorProto& input_tensor, Array* output_array) {
if (!status.ok()) return status;
if (input_flat_size != input_tensor.string_val_size()) {
- return Status(false,
- "Input_content string_val doesn't have the right dimensions "
- "for this string tensor");
+ return tensorflow::errors::InvalidArgument(
+ "Input_content string_val doesn't have the right dimensions "
+ "for this string tensor");
}
auto& output_string_data =
@@ -414,7 +412,7 @@ Status ImportStringArray(const TensorProto& input_tensor, Array* output_array) {
for (int i = 0; i < input_flat_size; ++i) {
output_string_data[i] = input_tensor.string_val(i);
}
- return Status::OK();
+ return tensorflow::Status::OK();
}
// Count the number of inputs of a given node. If
@@ -454,14 +452,14 @@ string CreateConstArray(Model* model, string const& name,
return array_name;
}
-Status ConvertConstOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertConstOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK_EQ(node.op(), "Const");
const auto& tensor = GetTensorAttr(node, "value");
const auto dtype = GetDataTypeAttr(node, "dtype");
- Status status = Status::OK();
+ tensorflow::Status status = tensorflow::Status::OK();
auto& array = model->GetOrCreateArray(node.name());
switch (dtype) {
@@ -497,22 +495,21 @@ Status ConvertConstOperator(const NodeDef& node,
array.GetMutableBuffer<ArrayDataType::kNone>();
break;
}
- if (!status.ok()) {
- status.AppendMessage(" (while processing node '" + node.name() + "')");
- }
- return status;
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(
+ status, " (while processing node '" + node.name() + "')");
+ return tensorflow::Status::OK();
}
-Status ConvertConvOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertConvOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK_EQ(node.op(), "Conv2D");
CheckInputsCount(node, tf_import_flags, 2);
// We only support NHWC, which is the default data_format.
// So if data_format is not defined, we're all good.
- TOCO_RETURN_IF_ERROR(CheckOptionalAttr(node, "data_format", "NHWC"));
- TOCO_RETURN_IF_ERROR(CheckOptionalAttr(node, "T", DT_FLOAT));
+ TF_RETURN_IF_ERROR(CheckOptionalAttr(node, "data_format", "NHWC"));
+ TF_RETURN_IF_ERROR(CheckOptionalAttr(node, "T", DT_FLOAT));
const auto& input_name = node.input(0);
const auto& weights_name = node.input(1);
@@ -537,26 +534,25 @@ Status ConvertConvOperator(const NodeDef& node,
auto* conv = new ConvOperator;
conv->inputs = {input_name, reordered_weights_name};
conv->outputs = {node.name()};
- TOCO_RETURN_IF_ERROR(
- Status(HasAttr(node, "strides"), "Missing attribute 'strides'"));
+ if (!HasAttr(node, "strides")) {
+ return tensorflow::errors::InvalidArgument("Missing attribute 'strides'");
+ }
const auto& strides = GetListAttr(node, "strides");
- TOCO_RETURN_IF_ERROR(ExpectValue(strides.i_size(), 4, "number of strides"));
- TOCO_RETURN_IF_ERROR(ExpectValue(strides.i(0), 1, "strides(0)"));
- TOCO_RETURN_IF_ERROR(ExpectValue(strides.i(3), 1, "strides(3)"));
+ TF_RETURN_IF_ERROR(ExpectValue(strides.i_size(), 4, "number of strides"));
+ TF_RETURN_IF_ERROR(ExpectValue(strides.i(0), 1, "strides(0)"));
+ TF_RETURN_IF_ERROR(ExpectValue(strides.i(3), 1, "strides(3)"));
conv->stride_height = strides.i(1);
conv->stride_width = strides.i(2);
if (HasAttr(node, "dilations")) {
const auto& dilations = GetListAttr(node, "dilations");
- TOCO_RETURN_IF_ERROR(
+ TF_RETURN_IF_ERROR(
ExpectValue(dilations.i_size(), 4, "number of dilations"));
if (dilations.i(0) != 1 || dilations.i(3) != 1) {
- return Status(
- false, absl::StrCat(
- "Can only import Conv ops with dilation along the height "
- "(1st) or width (2nd) axis. TensorFlow op \"",
- node.name(), "\" had dilations:[ ", dilations.i(0), ", ",
- dilations.i(1), ", ", dilations.i(2), ", ", dilations.i(3),
- "]."));
+ return tensorflow::errors::InvalidArgument(absl::StrCat(
+ "Can only import Conv ops with dilation along the height "
+ "(1st) or width (2nd) axis. TensorFlow op \"",
+ node.name(), "\" had dilations:[ ", dilations.i(0), ", ",
+ dilations.i(1), ", ", dilations.i(2), ", ", dilations.i(3), "]."));
}
conv->dilation_height_factor = dilations.i(1);
conv->dilation_width_factor = dilations.i(2);
@@ -570,11 +566,12 @@ Status ConvertConvOperator(const NodeDef& node,
} else if (padding == "VALID") {
conv->padding.type = PaddingType::kValid;
} else {
- return Status(false, "Bad padding (only SAME and VALID are supported)");
+ return tensorflow::errors::InvalidArgument(
+ "Bad padding (only SAME and VALID are supported)");
}
model->operators.emplace_back(conv);
- return Status::OK();
+ return tensorflow::Status::OK();
}
void ConvertDepthwiseConvOperator(const NodeDef& node,
@@ -1753,9 +1750,9 @@ void ConvertSparseToDenseOperator(const NodeDef& node,
} // namespace
namespace internal {
-Status ImportTensorFlowNode(const tensorflow::NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ImportTensorFlowNode(
+ const tensorflow::NodeDef& node,
+ const TensorFlowImportFlags& tf_import_flags, Model* model) {
// TODO(ahentz): Historically these functions all CHECK-fail on error. We've
// been slowly converting them to return Status.
if (node.op() == "Const") {
@@ -1958,7 +1955,7 @@ Status ImportTensorFlowNode(const tensorflow::NodeDef& node,
} else {
ConvertUnsupportedOperator(node, tf_import_flags, model);
}
- return Status::OK();
+ return tensorflow::Status::OK();
}
} // namespace internal
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc
index 835676662b..d18c329a43 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc
@@ -21,10 +21,10 @@ limitations under the License.
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/lib/core/status.h"
namespace toco {
-using port::Status;
using tensorflow::AttrValue;
using tensorflow::DT_BOOL;
using tensorflow::DT_FLOAT;
@@ -33,6 +33,7 @@ using tensorflow::DT_INT64;
using tensorflow::DT_QUINT8;
using tensorflow::DT_STRING;
using tensorflow::NodeDef;
+using tensorflow::Status;
namespace internal {
Status ImportTensorFlowNode(const NodeDef&, const TensorFlowImportFlags&,
@@ -117,9 +118,10 @@ TEST_P(ShapeImportTest, ShapeElementIsNegative) {
NodeDef node;
BuildConstNode({1, -2, 10}, GetParam(), 0, &node);
auto status = ImportNode(node);
- EXPECT_EQ(status.error_message(),
- "Tensor shape should not include negative values (while processing "
- "node 'Node1')");
+ EXPECT_EQ(
+ status.error_message(),
+ "Tensor shape should not include negative values\n\t (while processing "
+ "node 'Node1')");
}
INSTANTIATE_TEST_CASE_P(ShapeElementIsNegative, ShapeImportTest,
::testing::ValuesIn(TestTypes()));
@@ -129,7 +131,7 @@ TEST_P(ShapeImportTest, ShapeElementTooLarge) {
BuildConstNode({3000000000}, GetParam(), 0, &node);
auto status = ImportNode(node);
EXPECT_EQ(status.error_message(),
- "Shape element overflows (while processing node 'Node1')");
+ "Shape element overflows\n\t (while processing node 'Node1')");
}
INSTANTIATE_TEST_CASE_P(ShapeElementTooLarge, ShapeImportTest,
::testing::ValuesIn(TestTypes()));
@@ -139,7 +141,7 @@ TEST_P(ShapeImportTest, ShapeTooLarge) {
BuildConstNode({1000000, 2000000, 2000000, 2000000}, GetParam(), 0, &node);
auto status = ImportNode(node);
EXPECT_EQ(status.error_message(),
- "Tensor shape is too large (while processing node 'Node1')");
+ "Tensor shape is too large\n\t (while processing node 'Node1')");
}
INSTANTIATE_TEST_CASE_P(ShapeTooLarge, ShapeImportTest,
::testing::ValuesIn(TestTypes()));
@@ -148,11 +150,11 @@ TEST_P(ShapeImportTest, ValidShapeButZeroElements) {
NodeDef node;
BuildConstNode({1, 2, 2, 2}, GetParam(), 0, &node);
auto status = ImportNode(node);
- EXPECT_THAT(
- status.error_message(),
- ::testing::MatchesRegex(
- "Neither input_content .0. nor .*_val .0. have the right "
- "dimensions .8. for this .* tensor .while processing node 'Node1'."));
+ EXPECT_THAT(status.error_message(),
+ ::testing::MatchesRegex(
+ "Neither input_content .0. nor .*_val .0. have the right "
+ "dimensions .8. for this .* tensor\n\t .while processing "
+ "node 'Node1'."));
}
INSTANTIATE_TEST_CASE_P(ValidShapeButZeroElements, ShapeImportTest,
::testing::ValuesIn(TestTypes()));
diff --git a/tensorflow/contrib/lite/toco/toco_port.cc b/tensorflow/contrib/lite/toco/toco_port.cc
index a1c8696cd0..1b21c8bc60 100644
--- a/tensorflow/contrib/lite/toco/toco_port.cc
+++ b/tensorflow/contrib/lite/toco/toco_port.cc
@@ -16,6 +16,8 @@ limitations under the License.
#include "tensorflow/contrib/lite/toco/toco_port.h"
#include "tensorflow/contrib/lite/toco/toco_types.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
namespace toco {
@@ -55,8 +57,12 @@ void CheckInitGoogleIsDone(const char* message) {
namespace file {
// Conversion to our wrapper Status.
-Status ToStatus(const ::util::Status& uts) {
- return Status(uts.ok(), uts.error_message());
+tensorflow::Status ToStatus(const ::util::Status& uts) {
+ if (!uts.ok()) {
+ return tensorflow::Status(tensorflow::errors::Code(uts.error_code()),
+ uts.error_message());
+ }
+ return tensorflow::Status::OK();
}
// Conversion to our wrapper Options.
@@ -65,7 +71,7 @@ toco::port::file::Options ToOptions(const ::file::Options& options) {
return Options();
}
-Status Writable(const string& filename) {
+tensorflow::Status Writable(const string& filename) {
File* f = nullptr;
const auto status = ::file::Open(filename, "w", &f, ::file::Defaults());
if (f) {
@@ -74,22 +80,24 @@ Status Writable(const string& filename) {
return ToStatus(status);
}
-Status Readable(const string& filename, const file::Options& options) {
+tensorflow::Status Readable(const string& filename,
+ const file::Options& options) {
return ToStatus(::file::Readable(filename, ::file::Defaults()));
}
-Status Exists(const string& filename, const file::Options& options) {
+tensorflow::Status Exists(const string& filename,
+ const file::Options& options) {
auto status = ::file::Exists(filename, ::file::Defaults());
return ToStatus(status);
}
-Status GetContents(const string& filename, string* contents,
- const file::Options& options) {
+tensorflow::Status GetContents(const string& filename, string* contents,
+ const file::Options& options) {
return ToStatus(::file::GetContents(filename, contents, ::file::Defaults()));
}
-Status SetContents(const string& filename, const string& contents,
- const file::Options& options) {
+tensorflow::Status SetContents(const string& filename, const string& contents,
+ const file::Options& options) {
return ToStatus(::file::SetContents(filename, contents, ::file::Defaults()));
}
@@ -133,37 +141,42 @@ void CheckInitGoogleIsDone(const char* message) {
namespace file {
-Status Writable(const string& filename) {
+tensorflow::Status Writable(const string& filename) {
FILE* f = fopen(filename.c_str(), "w");
if (f) {
fclose(f);
- return Status(true, "");
+ return tensorflow::Status::OK();
}
- return Status(false, "not writable");
+ return tensorflow::errors::NotFound("not writable");
}
-Status Readable(const string& filename, const file::Options& options) {
+tensorflow::Status Readable(const string& filename,
+ const file::Options& options) {
FILE* f = fopen(filename.c_str(), "r");
if (f) {
fclose(f);
- return Status(true, "");
+ return tensorflow::Status::OK();
}
- return Status(false, "not readable");
+ return tensorflow::errors::NotFound("not readable");
}
-Status Exists(const string& filename, const file::Options& options) {
+tensorflow::Status Exists(const string& filename,
+ const file::Options& options) {
struct stat statbuf;
int ret = stat(filename.c_str(), &statbuf);
- return Status(ret != -1, "");
+ if (ret == -1) {
+ return tensorflow::errors::NotFound("file doesn't exist");
+ }
+ return tensorflow::Status::OK();
}
-Status GetContents(const string& path, string* output,
- const file::Options& options) {
+tensorflow::Status GetContents(const string& path, string* output,
+ const file::Options& options) {
output->clear();
int fd = open(path.c_str(), O_RDONLY);
if (fd == -1) {
- return Status(false, "can't open() for read");
+ return tensorflow::errors::NotFound("can't open() for read");
}
// Direct read, for speed.
@@ -174,25 +187,25 @@ Status GetContents(const string& path, string* output,
if (size == 0) {
// Done.
close(fd);
- return Status(true, "");
+ return tensorflow::Status::OK();
} else if (size == -1) {
// Error.
close(fd);
- return Status(false, "error during read()");
+ return tensorflow::errors::Internal("error during read()");
} else {
output->append(buffer, size);
}
}
CHECK(0);
- return Status(false, "internal error");
+ return tensorflow::errors::Internal("internal error");
}
-Status SetContents(const string& filename, const string& contents,
- const file::Options& options) {
+tensorflow::Status SetContents(const string& filename, const string& contents,
+ const file::Options& options) {
int fd = open(filename.c_str(), O_WRONLY | O_CREAT, 0664);
if (fd == -1) {
- return Status(false, "can't open() for write");
+ return tensorflow::errors::Internal("can't open() for write");
}
size_t i = 0;
@@ -201,13 +214,13 @@ Status SetContents(const string& filename, const string& contents,
ssize_t written = write(fd, &contents[i], to_write);
if (written == -1) {
close(fd);
- return Status(false, "write() error");
+ return tensorflow::errors::Internal("write() error");
}
i += written;
}
close(fd);
- return Status(true, "");
+ return tensorflow::Status::OK();
}
string JoinPath(const string& base, const string& filename) {
diff --git a/tensorflow/contrib/lite/toco/toco_port.h b/tensorflow/contrib/lite/toco/toco_port.h
index 906792ef56..5c019cb2bf 100644
--- a/tensorflow/contrib/lite/toco/toco_port.h
+++ b/tensorflow/contrib/lite/toco/toco_port.h
@@ -21,6 +21,7 @@ limitations under the License.
#include <string>
#include "google/protobuf/text_format.h"
#include "tensorflow/contrib/lite/toco/format_port.h"
+#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/platform.h"
#if defined(PLATFORM_GOOGLE)
@@ -36,26 +37,6 @@ limitations under the License.
namespace toco {
namespace port {
-class Status {
- public:
- static Status OK() { return Status(true, ""); }
-
- // Create a failed status with no message.
- Status() {}
-
- Status(bool ok, const string& message) : ok_(ok), message_(message) {}
-
- void AppendMessage(const string& message) { message_ += message; }
-
- bool ok() const { return ok_; }
-
- const string error_message() const { return message_; }
-
- private:
- bool ok_ = false;
- string message_;
-};
-
void InitGoogle(const char* usage, int* argc, char*** argv, bool remove_flags);
void CheckInitGoogleIsDone(const char* message);
@@ -65,14 +46,14 @@ inline Options Defaults() {
Options o;
return o;
}
-Status GetContents(const string& filename, string* contents,
- const Options& options);
-Status SetContents(const string& filename, const string& contents,
- const Options& options);
+tensorflow::Status GetContents(const string& filename, string* contents,
+ const Options& options);
+tensorflow::Status SetContents(const string& filename, const string& contents,
+ const Options& options);
string JoinPath(const string& base, const string& filename);
-Status Writable(const string& filename);
-Status Readable(const string& filename, const Options& options);
-Status Exists(const string& filename, const Options& options);
+tensorflow::Status Writable(const string& filename);
+tensorflow::Status Readable(const string& filename, const Options& options);
+tensorflow::Status Exists(const string& filename, const Options& options);
} // namespace file
// Copy `src` string to `dest`. User must ensure `dest` has enough space.
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc
index 810718f610..5cb4caab3f 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util.cc
@@ -30,7 +30,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/toco/dump_graphviz.h"
#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
#include "tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h"
-#include "tensorflow/contrib/lite/toco/toco_port.h"
+#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
namespace toco {
diff --git a/tensorflow/contrib/lite/toco/tooling_util.h b/tensorflow/contrib/lite/toco/tooling_util.h
index 3b320e8013..ef8af4d112 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.h
+++ b/tensorflow/contrib/lite/toco/tooling_util.h
@@ -32,8 +32,9 @@ limitations under the License.
#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
#include "tensorflow/contrib/lite/toco/runtime/types.h"
#include "tensorflow/contrib/lite/toco/toco_flags.pb.h"
-#include "tensorflow/contrib/lite/toco/toco_port.h"
#include "tensorflow/contrib/lite/toco/types.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
// TODO(aselle): Replace with using a container specific hash override instead.
namespace std {
@@ -315,7 +316,7 @@ void UseArraysExtraInfo(Model* model, bool quantize_output);
// doesn't have enough range to represent the sum of elements, an error is
// returned.
template <typename T, typename U>
-port::Status NumElements(const std::vector<T>& shape, U* num_elements) {
+tensorflow::Status NumElements(const std::vector<T>& shape, U* num_elements) {
static_assert(
std::numeric_limits<T>::max() <= std::numeric_limits<uint64_t>::max(),
"vector type exceed capabilities of NumElements");
@@ -326,17 +327,17 @@ port::Status NumElements(const std::vector<T>& shape, U* num_elements) {
// TensorFlow's shapes sometimes include -1 to represent an "unknown"
// size but TOCO isn't able to create arrays of unknown sizes and will
// crash in RequiredBufferSizeForShape().
- return port::Status(false,
- "Tensor shape should not include negative values");
+ return tensorflow::errors::InvalidArgument(
+ "Tensor shape should not include negative values");
}
if (static_cast<uint64_t>(dim) >
std::numeric_limits<U>::max() / *num_elements) {
*num_elements = 0;
- return port::Status(false, "Tensor shape is too large");
+ return tensorflow::errors::InvalidArgument("Tensor shape is too large");
}
*num_elements *= dim;
}
- return port::Status::OK();
+ return tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/tooling_util_test.cc b/tensorflow/contrib/lite/toco/tooling_util_test.cc
index 87fd30db2c..a683867374 100644
--- a/tensorflow/contrib/lite/toco/tooling_util_test.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <gtest/gtest.h>
#include "tensorflow/contrib/lite/toco/model.h"
#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/lib/core/status.h"
namespace toco {
@@ -99,7 +100,7 @@ static const char kLargeTensorMessage[] = "Tensor shape is too large";
TEST(NumElementsTest, Int) {
int count;
- port::Status status = port::Status::OK();
+ tensorflow::Status status = tensorflow::Status::OK();
status = NumElements(std::vector<int>{1024, 1024, 2047}, &count);
EXPECT_TRUE(status.ok());
@@ -114,7 +115,7 @@ TEST(NumElementsTest, Int) {
TEST(NumElementsTest, Int32) {
int32_t count;
- port::Status status = port::Status::OK();
+ tensorflow::Status status = tensorflow::Status::OK();
status = NumElements(std::vector<int32_t>{1024, 1024, 2047}, &count);
EXPECT_TRUE(status.ok());
@@ -129,7 +130,7 @@ TEST(NumElementsTest, Int32) {
TEST(NumElementsTest, Int64) {
int64_t count;
- port::Status status = port::Status::OK();
+ tensorflow::Status status = tensorflow::Status::OK();
status = NumElements(std::vector<int64_t>{16777216, 16777216, 32767}, &count);
EXPECT_TRUE(status.ok());
@@ -144,7 +145,7 @@ TEST(NumElementsTest, Int64) {
TEST(NumElementsTest, UnsignedInt32) {
uint32_t count;
- port::Status status = port::Status::OK();
+ tensorflow::Status status = tensorflow::Status::OK();
status = NumElements(std::vector<uint32_t>{1024, 2048, 2047}, &count);
EXPECT_TRUE(status.ok());
@@ -159,7 +160,7 @@ TEST(NumElementsTest, UnsignedInt32) {
TEST(NumElementsTest, UnsignedInt64) {
uint64_t count;
- port::Status status = port::Status::OK();
+ tensorflow::Status status = tensorflow::Status::OK();
status =
NumElements(std::vector<uint64_t>{16777216, 16777216, 65535}, &count);