author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-05-01 16:33:03 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-05-01 16:35:52 -0700
commit f5dbc1e16622f433f41f195bb33f56d674a004ce (patch)
tree   8a08ec5c43192415056e0695337dd26e61256fcb
parent fb8f040f2a927c6df149238da7c4278cf781d081 (diff)
Check for overflow in shape calculation.
PiperOrigin-RevId: 195017114
-rw-r--r--  tensorflow/contrib/lite/toco/BUILD                      |  12
-rw-r--r--  tensorflow/contrib/lite/toco/import_tensorflow.cc       | 505
-rw-r--r--  tensorflow/contrib/lite/toco/import_tensorflow_test.cc  | 160
-rw-r--r--  tensorflow/contrib/lite/toco/toco_port.h                |   5
-rw-r--r--  tensorflow/contrib/lite/toco/tooling_util.h             |  29
-rw-r--r--  tensorflow/contrib/lite/toco/tooling_util_test.cc       |  81
6 files changed, 562 insertions(+), 230 deletions(-)
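The core of this change is an overflow-checked element count: instead of multiplying shape dimensions into an `int` and silently wrapping, the new `toco::NumElements` helper (added to tooling_util.h below) rejects negative dimensions, dimensions that do not fit in the destination type, and products that would overflow, and `ImportShape` now propagates those errors as a `Status`. The following is a minimal standalone sketch of that check; `NumElementsChecked` and its bool/string error reporting are simplified stand-ins for the real `toco::NumElements` and `toco::port::Status`, not the committed code.

```cpp
// Standalone sketch of the overflow check introduced by this commit.
#include <cstdint>
#include <iostream>
#include <limits>
#include <string>
#include <vector>

// Computes the flat size of a shape, or fails instead of wrapping around
// when the result would not fit in `int`.
bool NumElementsChecked(const std::vector<int64_t>& shape, int* num_elements,
                        std::string* error) {
  int64_t count = 1;
  for (int64_t dim : shape) {
    if (dim < 0) {
      *error = "Tensor shape should not include negative values";
      return false;
    }
    // Mirrors the d.size() > std::numeric_limits<int>::max() check that
    // ImportShape() now performs on each TensorFlow (int64) dimension.
    if (dim > std::numeric_limits<int>::max()) {
      *error = "Shape element overflows";
      return false;
    }
    // Divide-before-multiply guard: if count * dim would exceed INT_MAX,
    // report an error rather than overflowing, as NumElements() does.
    if (dim != 0 && count > std::numeric_limits<int>::max() / dim) {
      *error = "Tensor shape is too large";
      return false;
    }
    count *= dim;
  }
  *num_elements = static_cast<int>(count);
  return true;
}

int main() {
  int n = 0;
  std::string error;
  // 1024 * 1024 * 2048 == 2^31, one past INT_MAX, so it must be rejected.
  if (!NumElementsChecked({1024, 1024, 2048}, &n, &error)) {
    std::cout << "rejected: " << error << "\n";
  }
  // A benign shape still yields its flat size.
  if (NumElementsChecked({2, 3, 4}, &n, &error)) {
    std::cout << "flat size: " << n << "\n";  // prints 24
  }
  return 0;
}
```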
diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD
index f92e546ab8..f16225fd66 100644
--- a/tensorflow/contrib/lite/toco/BUILD
+++ b/tensorflow/contrib/lite/toco/BUILD
@@ -364,6 +364,18 @@ cc_library(
}),
)
+tf_cc_test(
+ name = "import_tensorflow_test",
+ srcs = ["import_tensorflow_test.cc"],
+ deps = [
+ ":toco_tooling",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:graph",
+ "//tensorflow/core:protos_all_cc",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
cc_library(
name = "tooling_util",
srcs = [
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index fa8b26bce0..453ff29b0d 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -62,6 +62,9 @@ using tensorflow::TensorProto;
using tensorflow::TensorShapeProto;
namespace toco {
+
+using port::Status;
+
namespace {
bool HasAttr(const NodeDef& node, const string& attr_name) {
return node.attr().count(attr_name) > 0;
@@ -113,7 +116,7 @@ const TensorShapeProto& GetShapeAttr(const NodeDef& node,
}
const TensorProto& GetTensorAttr(const NodeDef& node, const string& attr_name) {
- CHECK(HasAttr(node, attr_name));
+ CHECK(HasAttr(node, attr_name)) << "No attr named '" << attr_name << "'";
const auto& attr = node.attr().at(attr_name);
CHECK_EQ(attr.value_case(), AttrValue::kTensor);
return attr.tensor();
@@ -145,9 +148,9 @@ ArrayDataType ConvertDataType(tensorflow::DataType dtype) {
return ArrayDataType::kNone;
}
-void ImportShape(const TFLITE_PROTO_NS::RepeatedPtrField<
- tensorflow::TensorShapeProto_Dim>& input_dims,
- Shape* shape) {
+Status ImportShape(const TFLITE_PROTO_NS::RepeatedPtrField<
+ tensorflow::TensorShapeProto_Dim>& input_dims,
+ int* input_flat_size, Shape* shape) {
std::vector<int> input_dims_only_sizes;
for (auto& d : input_dims) {
if (d.size() == 0) {
@@ -155,23 +158,33 @@ void ImportShape(const TFLITE_PROTO_NS::RepeatedPtrField<
// them of flat size 0 even though they have other nonzero dims.
// This breaks our invariant, that array dims can't be 0.
// For now, tweaking this to record a 0-D shape instead.
- input_dims_only_sizes.clear();
- break;
+ shape->mutable_dims()->clear();
+ if (input_flat_size != nullptr) *input_flat_size = 0;
+ return Status::OK();
+ }
+ // TensorFlow's shapes use int64s, while TOCO uses ints.
+ if (d.size() > std::numeric_limits<int>::max()) {
+ return Status(false, "Shape element overflows");
}
+
input_dims_only_sizes.push_back(d.size());
}
*shape->mutable_dims() = input_dims_only_sizes;
+
+ if (input_flat_size == nullptr) return Status::OK();
+
+ return NumElements(input_dims_only_sizes, input_flat_size);
}
-void ImportFloatArray(const TensorProto& input_tensor, Array* output_array) {
+Status ImportFloatArray(const TensorProto& input_tensor, Array* output_array) {
CHECK_EQ(input_tensor.dtype(), DT_FLOAT);
const auto& input_shape = input_tensor.tensor_shape();
CHECK_LE(input_shape.dim_size(), 4);
- ImportShape(input_shape.dim(), output_array->mutable_shape());
- int input_flat_size = 1;
- for (int k = 0; k < input_shape.dim_size(); k++) {
- input_flat_size *= input_shape.dim(k).size();
- }
+ int input_flat_size;
+ auto status = ImportShape(input_shape.dim(), &input_flat_size,
+ output_array->mutable_shape());
+ if (!status.ok()) return status;
+
auto& output_float_data =
output_array->GetMutableBuffer<ArrayDataType::kFloat>().data;
output_float_data.resize(RequiredBufferSizeForShape(output_array->shape()),
@@ -189,20 +202,22 @@ void ImportFloatArray(const TensorProto& input_tensor, Array* output_array) {
toco::port::CopyToBuffer(input_tensor.tensor_content(),
reinterpret_cast<char*>(output_float_data.data()));
} else {
- LOG(FATAL) << "Neither input_content nor float_val have the right "
- "dimensions for this float tensor.";
+ return Status(false,
+ "Neither input_content nor float_val have the right "
+ "dimensions for this float tensor");
}
+ return Status::OK();
}
-void ImportQuint8Array(const TensorProto& input_tensor, Array* output_array) {
+Status ImportQuint8Array(const TensorProto& input_tensor, Array* output_array) {
CHECK_EQ(input_tensor.dtype(), DT_QUINT8);
const auto& input_shape = input_tensor.tensor_shape();
CHECK_LE(input_shape.dim_size(), 4);
- ImportShape(input_shape.dim(), output_array->mutable_shape());
- int input_flat_size = 1;
- for (int k = 0; k < input_shape.dim_size(); k++) {
- input_flat_size *= input_shape.dim(k).size();
- }
+ int input_flat_size;
+ auto status = ImportShape(input_shape.dim(), &input_flat_size,
+ output_array->mutable_shape());
+ if (!status.ok()) return status;
+
auto& output_int_data =
output_array->GetMutableBuffer<ArrayDataType::kUint8>().data;
output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0);
@@ -215,20 +230,22 @@ void ImportQuint8Array(const TensorProto& input_tensor, Array* output_array) {
toco::port::CopyToBuffer(input_tensor.tensor_content(),
reinterpret_cast<char*>(output_int_data.data()));
} else {
- LOG(FATAL) << "Neither input_content nor int_val have the right "
- "dimensions for this uint8 tensor.";
+ return Status(false,
+ "Neither input_content nor int_val have the right dimensions "
+ "for this uint8 tensor");
}
+ return Status::OK();
}
-void ImportInt32Array(const TensorProto& input_tensor, Array* output_array) {
+Status ImportInt32Array(const TensorProto& input_tensor, Array* output_array) {
CHECK_EQ(input_tensor.dtype(), DT_INT32);
const auto& input_shape = input_tensor.tensor_shape();
CHECK_LE(input_shape.dim_size(), 4);
- ImportShape(input_shape.dim(), output_array->mutable_shape());
- int input_flat_size = 1;
- for (int k = 0; k < input_shape.dim_size(); k++) {
- input_flat_size *= input_shape.dim(k).size();
- }
+ int input_flat_size;
+ auto status = ImportShape(input_shape.dim(), &input_flat_size,
+ output_array->mutable_shape());
+ if (!status.ok()) return status;
+
auto& output_int_data =
output_array->GetMutableBuffer<ArrayDataType::kInt32>().data;
output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0);
@@ -241,20 +258,22 @@ void ImportInt32Array(const TensorProto& input_tensor, Array* output_array) {
toco::port::CopyToBuffer(input_tensor.tensor_content(),
reinterpret_cast<char*>(output_int_data.data()));
} else {
- LOG(FATAL) << "Neither input_content nor int_val have the right "
- "dimensions for this int32 tensor.";
+ return Status(false,
+ "Neither input_content nor int_val have the right dimensions "
+ "for this int32 tensor");
}
+ return Status::OK();
}
-void ImportInt64Array(const TensorProto& input_tensor, Array* output_array) {
+Status ImportInt64Array(const TensorProto& input_tensor, Array* output_array) {
CHECK_EQ(input_tensor.dtype(), DT_INT64);
const auto& input_shape = input_tensor.tensor_shape();
CHECK_LE(input_shape.dim_size(), 4);
- ImportShape(input_shape.dim(), output_array->mutable_shape());
- int input_flat_size = 1;
- for (int k = 0; k < input_shape.dim_size(); k++) {
- input_flat_size *= input_shape.dim(k).size();
- }
+ int input_flat_size;
+ auto status = ImportShape(input_shape.dim(), &input_flat_size,
+ output_array->mutable_shape());
+ if (!status.ok()) return status;
+
auto& output_int_data =
output_array->GetMutableBuffer<ArrayDataType::kInt64>().data;
output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0);
@@ -267,20 +286,22 @@ void ImportInt64Array(const TensorProto& input_tensor, Array* output_array) {
toco::port::CopyToBuffer(input_tensor.tensor_content(),
reinterpret_cast<char*>(output_int_data.data()));
} else {
- LOG(FATAL) << "Neither input_content nor int64_val have the right "
- "dimensions for this int64 tensor.";
+ return Status(false,
+ "Neither input_content nor int64_val have the right "
+ "dimensions for this int64 tensor");
}
+ return Status::OK();
}
-void ImportBoolArray(const TensorProto& input_tensor, Array* output_array) {
+Status ImportBoolArray(const TensorProto& input_tensor, Array* output_array) {
CHECK_EQ(input_tensor.dtype(), DT_BOOL);
const auto& input_shape = input_tensor.tensor_shape();
CHECK_LE(input_shape.dim_size(), 4);
- ImportShape(input_shape.dim(), output_array->mutable_shape());
- int input_flat_size = 1;
- for (int k = 0; k < input_shape.dim_size(); k++) {
- input_flat_size *= input_shape.dim(k).size();
- }
+ int input_flat_size;
+ auto status = ImportShape(input_shape.dim(), &input_flat_size,
+ output_array->mutable_shape());
+ if (!status.ok()) return status;
+
auto& output_bool_data =
output_array->GetMutableBuffer<ArrayDataType::kBool>().data;
output_bool_data.resize(RequiredBufferSizeForShape(output_array->shape()),
@@ -300,20 +321,25 @@ void ImportBoolArray(const TensorProto& input_tensor, Array* output_array) {
// assuming that 'false' is implied.
// So far only encountered that in an array with 1 entry, let's
// require that until we encounter a graph where that's not the case.
- CHECK_EQ(output_bool_data.size(), 1);
+ if (output_bool_data.size() != 1) {
+ return Status(false,
+ "Neither input_content nor bool_val have the right "
+ "dimensions for this bool tensor");
+ }
output_bool_data[0] = false;
}
+ return Status::OK();
}
-void ImportStringArray(const TensorProto& input_tensor, Array* output_array) {
+Status ImportStringArray(const TensorProto& input_tensor, Array* output_array) {
CHECK_EQ(input_tensor.dtype(), DT_STRING);
const auto& input_shape = input_tensor.tensor_shape();
CHECK_LE(input_shape.dim_size(), 4);
- ImportShape(input_shape.dim(), output_array->mutable_shape());
- int input_flat_size = 1;
- for (int k = 0; k < input_shape.dim_size(); k++) {
- input_flat_size *= input_shape.dim(k).size();
- }
+ int input_flat_size;
+ auto status = ImportShape(input_shape.dim(), &input_flat_size,
+ output_array->mutable_shape());
+ if (!status.ok()) return status;
+
auto& output_string_data =
output_array->GetMutableBuffer<ArrayDataType::kString>().data;
output_string_data.resize(RequiredBufferSizeForShape(output_array->shape()));
@@ -324,6 +350,7 @@ void ImportStringArray(const TensorProto& input_tensor, Array* output_array) {
for (int i = 0; i < input_flat_size; ++i) {
output_string_data[i] = input_tensor.string_val(i);
}
+ return Status::OK();
}
// Count the number of inputs of a given node. If
@@ -363,38 +390,40 @@ string CreateConstArray(Model* model, string const& name,
return array_name;
}
-void ConvertConstOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+Status ConvertConstOperator(const NodeDef& node,
+ const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK_EQ(node.op(), "Const");
const auto& tensor = GetTensorAttr(node, "value");
const auto dtype = GetDataTypeAttr(node, "dtype");
+ Status status = Status::OK();
+
auto& array = model->GetOrCreateArray(node.name());
switch (dtype) {
case DT_FLOAT:
array.data_type = ArrayDataType::kFloat;
- ImportFloatArray(tensor, &array);
+ status = ImportFloatArray(tensor, &array);
break;
case DT_INT32:
array.data_type = ArrayDataType::kInt32;
- ImportInt32Array(tensor, &array);
+ status = ImportInt32Array(tensor, &array);
break;
case DT_QUINT8:
array.data_type = ArrayDataType::kUint8;
- ImportQuint8Array(tensor, &array);
+ status = ImportQuint8Array(tensor, &array);
break;
case DT_INT64:
array.data_type = ArrayDataType::kInt64;
- ImportInt64Array(tensor, &array);
+ status = ImportInt64Array(tensor, &array);
break;
case DT_STRING:
array.data_type = ArrayDataType::kString;
- ImportStringArray(tensor, &array);
+ status = ImportStringArray(tensor, &array);
break;
case DT_BOOL:
array.data_type = ArrayDataType::kBool;
- ImportBoolArray(tensor, &array);
+ status = ImportBoolArray(tensor, &array);
break;
default:
array.data_type = ArrayDataType::kNone;
@@ -404,6 +433,10 @@ void ConvertConstOperator(const NodeDef& node,
array.GetMutableBuffer<ArrayDataType::kNone>();
break;
}
+ if (!status.ok()) {
+ status.AppendMessage(" (while processing node '" + node.name() + "')");
+ }
+ return status;
}
void ConvertConvOperator(const NodeDef& node,
@@ -2033,6 +2066,186 @@ void ConvertDynamicStitchOperator(const NodeDef& node,
} // namespace
+namespace internal {
+Status ImportTensorFlowNode(const tensorflow::NodeDef& node,
+ const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
+ // TODO(ahentz): Historically these functions all CHECK-fail on error. We've
+ // been slowly converting them to return Status.
+ if (node.op() == "Const") {
+ return ConvertConstOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Conv2D") {
+ ConvertConvOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Conv2DBackpropInput") {
+ ConvertTransposeConvOperator(node, tf_import_flags, model);
+ } else if (node.op() == "DepthwiseConv2dNative") {
+ ConvertDepthwiseConvOperator(node, tf_import_flags, model);
+ } else if (node.op() == "DepthToSpace") {
+ ConvertDepthToSpaceOperator(node, tf_import_flags, model);
+ } else if (node.op() == "SpaceToDepth") {
+ ConvertSpaceToDepthOperator(node, tf_import_flags, model);
+ } else if (node.op() == "BiasAdd") {
+ ConvertBiasAddOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Relu") {
+ ConvertReluOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Relu6") {
+ ConvertRelu6Operator(node, tf_import_flags, model);
+ } else if (node.op() == "Sigmoid") {
+ ConvertLogisticOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Tanh") {
+ ConvertTanhOperator(node, tf_import_flags, model);
+ } else if (node.op() == "MaxPool") {
+ ConvertMaxPoolOperator(node, tf_import_flags, model);
+ } else if (node.op() == "AvgPool") {
+ ConvertAvgPoolOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Reshape") {
+ ConvertReshapeOperator(node, tf_import_flags, model);
+ } else if (node.op() == "BatchMatMul") {
+ ConvertBatchMatMulOperator(node, tf_import_flags, model);
+ } else if (node.op() == "MatMul") {
+ ConvertMatMulOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Div" || node.op() == "RealDiv") {
+ ConvertDivOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Identity" || node.op() == "CheckNumerics" ||
+ node.op() == "StopGradient") {
+ ConvertIdentityOperator(node, tf_import_flags, model);
+ } else if (node.op() == "FakeQuantWithMinMaxVars") {
+ ConvertFakeQuantWithMinMaxVars(node, tf_import_flags, model);
+ } else if (node.op() == "FakeQuantWithMinMaxArgs") {
+ ConvertFakeQuantWithMinMaxArgs(node, tf_import_flags, model);
+ } else if (node.op() == "Neg") {
+ ConvertNegOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Rsqrt") {
+ ConvertRsqrtOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Squeeze") {
+ ConvertSqueezeOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Sqrt") {
+ ConvertSqrtOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Square") {
+ ConvertSquareOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Add") {
+ ConvertAddOperator(node, tf_import_flags, model);
+ } else if (node.op() == "AddN") {
+ ConvertAddNOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Mul") {
+ ConvertMulOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Sub") {
+ ConvertSubOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Sum") {
+ ConvertSumOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Tile") {
+ ConvertTileOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Concat" || node.op() == "ConcatV2") {
+ ConvertConcatOperator(node, tf_import_flags, model);
+ } else if (node.op() == "LRN") {
+ ConvertLRNOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Softmax") {
+ ConvertSoftmaxOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Log") {
+ ConvertLogOperator(node, tf_import_flags, model);
+ } else if (node.op() == "LogSoftmax") {
+ ConvertLogSoftmaxOperator(node, tf_import_flags, model);
+ } else if (node.op() == "All") {
+ ConvertAllOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Assert") {
+ ConvertAssertOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Less") {
+ ConvertLessOperator(node, tf_import_flags, model);
+ } else if (node.op() == "LessEqual") {
+ ConvertLessEqualOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Greater") {
+ ConvertGreaterOperator(node, tf_import_flags, model);
+ } else if (node.op() == "GreaterEqual") {
+ ConvertGreaterEqualOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Max") {
+ ConvertMaxOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Min") {
+ ConvertMinOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Maximum") {
+ ConvertMaximumOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Minimum") {
+ ConvertMinimumOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Merge") {
+ ConvertMergeOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Pad") {
+ ConvertPadOperator(node, tf_import_flags, model);
+ } else if (node.op() == "StridedSlice") {
+ ConvertStridedSliceOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Shape") {
+ ConvertShapeOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Slice") {
+ ConvertSliceOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Split") {
+ ConvertSplitOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Switch") {
+ ConvertSwitchOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Placeholder") {
+ ConvertPlaceholderOperator(node, tf_import_flags, model);
+ } else if (node.op() == "PlaceholderWithDefault") {
+ ConvertIdentityOperator(node, tf_import_flags, model);
+ } else if (node.op() == "LegacyFedInput") {
+ ConvertPlaceholderOperator(node, tf_import_flags, model);
+ } else if (node.op() == "NoOp") {
+ ConvertNoOpOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Cast") {
+ ConvertCastOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Floor") {
+ ConvertFloorOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Gather" || node.op() == "GatherV2") {
+ ConvertGatherOperator(node, tf_import_flags, model);
+ } else if (node.op() == "ResizeBilinear") {
+ ConvertResizeBilinearOperator(node, tf_import_flags, model);
+ } else if (node.op() == "BatchNormWithGlobalNormalization") {
+ ConvertBatchNormWithGlobalNormalizationOperator(node, tf_import_flags,
+ model);
+ } else if (node.op() == "FusedBatchNorm") {
+ ConvertFusedBatchNormOperator(node, tf_import_flags, model);
+ } else if (node.op() == "SpaceToBatchND") {
+ ConvertSpaceToBatchNDOperator(node, tf_import_flags, model);
+ } else if (node.op() == "BatchToSpaceND") {
+ ConvertBatchToSpaceNDOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Mean") {
+ ConvertMeanOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Svdf") {
+ ConvertSvdfOperator(node, tf_import_flags, model);
+ } else if (node.op() == "NextIteration") {
+ ConvertOperatorSpecialCasedAsRNNBackEdge(node, tf_import_flags, model);
+ } else if (node.op() == "ExpandDims") {
+ ConvertExpandDimsOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Fill") {
+ ConvertFillOperator(node, tf_import_flags, model);
+ } else if (node.op() == "FloorDiv") {
+ ConvertFloorDivOperator(node, tf_import_flags, model);
+ } else if (node.op() == "FloorMod") {
+ ConvertFloorModOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Range") {
+ ConvertRangeOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Rank") {
+ ConvertRankOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Stack" || node.op() == "Pack") {
+ ConvertStackOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Transpose") {
+ ConvertTransposeOperator(node, tf_import_flags, model);
+ } else if (node.op() == "ArgMax") {
+ ConvertArgMaxOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Exp") {
+ ConvertExpOperator(node, tf_import_flags, model);
+ } else if (node.op() == "TopK" || node.op() == "TopKV2") {
+ ConvertTopKV2Operator(node, tf_import_flags, model);
+ } else if (node.op() == "DynamicPartition") {
+ ConvertDynamicPartitionOperator(node, tf_import_flags, model);
+ } else if (node.op() == "DynamicStitch" ||
+ node.op() == "ParallelDynamicStitch") {
+ ConvertDynamicStitchOperator(node, tf_import_flags, model);
+ } else if (node.op() == "RandomUniform") {
+ ConvertRandomUniform(node, tf_import_flags, model);
+ } else {
+ ConvertUnsupportedOperator(node, tf_import_flags, model);
+ }
+ return Status::OK();
+}
+} // namespace internal
+
std::unique_ptr<Model> ImportTensorFlowGraphDef(
const ModelFlags& model_flags, const TensorFlowImportFlags& tf_import_flags,
const GraphDef& tf_graph) {
@@ -2058,176 +2271,8 @@ std::unique_ptr<Model> ImportTensorFlowGraphDef(
for (auto node : inlined_graph.node()) {
StripZeroOutputIndexFromInputs(&node);
- if (node.op() == "Const") {
- ConvertConstOperator(node, tf_import_flags, model);
- } else if (node.op() == "Conv2D") {
- ConvertConvOperator(node, tf_import_flags, model);
- } else if (node.op() == "Conv2DBackpropInput") {
- ConvertTransposeConvOperator(node, tf_import_flags, model);
- } else if (node.op() == "DepthwiseConv2dNative") {
- ConvertDepthwiseConvOperator(node, tf_import_flags, model);
- } else if (node.op() == "DepthToSpace") {
- ConvertDepthToSpaceOperator(node, tf_import_flags, model);
- } else if (node.op() == "SpaceToDepth") {
- ConvertSpaceToDepthOperator(node, tf_import_flags, model);
- } else if (node.op() == "BiasAdd") {
- ConvertBiasAddOperator(node, tf_import_flags, model);
- } else if (node.op() == "Relu") {
- ConvertReluOperator(node, tf_import_flags, model);
- } else if (node.op() == "Relu6") {
- ConvertRelu6Operator(node, tf_import_flags, model);
- } else if (node.op() == "Sigmoid") {
- ConvertLogisticOperator(node, tf_import_flags, model);
- } else if (node.op() == "Tanh") {
- ConvertTanhOperator(node, tf_import_flags, model);
- } else if (node.op() == "MaxPool") {
- ConvertMaxPoolOperator(node, tf_import_flags, model);
- } else if (node.op() == "AvgPool") {
- ConvertAvgPoolOperator(node, tf_import_flags, model);
- } else if (node.op() == "Reshape") {
- ConvertReshapeOperator(node, tf_import_flags, model);
- } else if (node.op() == "BatchMatMul") {
- ConvertBatchMatMulOperator(node, tf_import_flags, model);
- } else if (node.op() == "MatMul") {
- ConvertMatMulOperator(node, tf_import_flags, model);
- } else if (node.op() == "Div" || node.op() == "RealDiv") {
- ConvertDivOperator(node, tf_import_flags, model);
- } else if (node.op() == "Identity" || node.op() == "CheckNumerics" ||
- node.op() == "StopGradient") {
- ConvertIdentityOperator(node, tf_import_flags, model);
- } else if (node.op() == "FakeQuantWithMinMaxVars") {
- ConvertFakeQuantWithMinMaxVars(node, tf_import_flags, model);
- } else if (node.op() == "FakeQuantWithMinMaxArgs") {
- ConvertFakeQuantWithMinMaxArgs(node, tf_import_flags, model);
- } else if (node.op() == "Neg") {
- ConvertNegOperator(node, tf_import_flags, model);
- } else if (node.op() == "Rsqrt") {
- ConvertRsqrtOperator(node, tf_import_flags, model);
- } else if (node.op() == "Squeeze") {
- ConvertSqueezeOperator(node, tf_import_flags, model);
- } else if (node.op() == "Sqrt") {
- ConvertSqrtOperator(node, tf_import_flags, model);
- } else if (node.op() == "Square") {
- ConvertSquareOperator(node, tf_import_flags, model);
- } else if (node.op() == "Add") {
- ConvertAddOperator(node, tf_import_flags, model);
- } else if (node.op() == "AddN") {
- ConvertAddNOperator(node, tf_import_flags, model);
- } else if (node.op() == "Mul") {
- ConvertMulOperator(node, tf_import_flags, model);
- } else if (node.op() == "Sub") {
- ConvertSubOperator(node, tf_import_flags, model);
- } else if (node.op() == "Sum") {
- ConvertSumOperator(node, tf_import_flags, model);
- } else if (node.op() == "Tile") {
- ConvertTileOperator(node, tf_import_flags, model);
- } else if (node.op() == "Concat" || node.op() == "ConcatV2") {
- ConvertConcatOperator(node, tf_import_flags, model);
- } else if (node.op() == "LRN") {
- ConvertLRNOperator(node, tf_import_flags, model);
- } else if (node.op() == "Softmax") {
- ConvertSoftmaxOperator(node, tf_import_flags, model);
- } else if (node.op() == "Log") {
- ConvertLogOperator(node, tf_import_flags, model);
- } else if (node.op() == "LogSoftmax") {
- ConvertLogSoftmaxOperator(node, tf_import_flags, model);
- } else if (node.op() == "All") {
- ConvertAllOperator(node, tf_import_flags, model);
- } else if (node.op() == "Assert") {
- ConvertAssertOperator(node, tf_import_flags, model);
- } else if (node.op() == "Less") {
- ConvertLessOperator(node, tf_import_flags, model);
- } else if (node.op() == "LessEqual") {
- ConvertLessEqualOperator(node, tf_import_flags, model);
- } else if (node.op() == "Greater") {
- ConvertGreaterOperator(node, tf_import_flags, model);
- } else if (node.op() == "GreaterEqual") {
- ConvertGreaterEqualOperator(node, tf_import_flags, model);
- } else if (node.op() == "Max") {
- ConvertMaxOperator(node, tf_import_flags, model);
- } else if (node.op() == "Min") {
- ConvertMinOperator(node, tf_import_flags, model);
- } else if (node.op() == "Maximum") {
- ConvertMaximumOperator(node, tf_import_flags, model);
- } else if (node.op() == "Minimum") {
- ConvertMinimumOperator(node, tf_import_flags, model);
- } else if (node.op() == "Merge") {
- ConvertMergeOperator(node, tf_import_flags, model);
- } else if (node.op() == "Pad") {
- ConvertPadOperator(node, tf_import_flags, model);
- } else if (node.op() == "StridedSlice") {
- ConvertStridedSliceOperator(node, tf_import_flags, model);
- } else if (node.op() == "Shape") {
- ConvertShapeOperator(node, tf_import_flags, model);
- } else if (node.op() == "Slice") {
- ConvertSliceOperator(node, tf_import_flags, model);
- } else if (node.op() == "Split") {
- ConvertSplitOperator(node, tf_import_flags, model);
- } else if (node.op() == "Switch") {
- ConvertSwitchOperator(node, tf_import_flags, model);
- } else if (node.op() == "Placeholder") {
- ConvertPlaceholderOperator(node, tf_import_flags, model);
- } else if (node.op() == "PlaceholderWithDefault") {
- ConvertIdentityOperator(node, tf_import_flags, model);
- } else if (node.op() == "LegacyFedInput") {
- ConvertPlaceholderOperator(node, tf_import_flags, model);
- } else if (node.op() == "NoOp") {
- ConvertNoOpOperator(node, tf_import_flags, model);
- } else if (node.op() == "Cast") {
- ConvertCastOperator(node, tf_import_flags, model);
- } else if (node.op() == "Floor") {
- ConvertFloorOperator(node, tf_import_flags, model);
- } else if (node.op() == "Gather" || node.op() == "GatherV2") {
- ConvertGatherOperator(node, tf_import_flags, model);
- } else if (node.op() == "ResizeBilinear") {
- ConvertResizeBilinearOperator(node, tf_import_flags, model);
- } else if (node.op() == "BatchNormWithGlobalNormalization") {
- ConvertBatchNormWithGlobalNormalizationOperator(node, tf_import_flags,
- model);
- } else if (node.op() == "FusedBatchNorm") {
- ConvertFusedBatchNormOperator(node, tf_import_flags, model);
- } else if (node.op() == "SpaceToBatchND") {
- ConvertSpaceToBatchNDOperator(node, tf_import_flags, model);
- } else if (node.op() == "BatchToSpaceND") {
- ConvertBatchToSpaceNDOperator(node, tf_import_flags, model);
- } else if (node.op() == "Mean") {
- ConvertMeanOperator(node, tf_import_flags, model);
- } else if (node.op() == "Svdf") {
- ConvertSvdfOperator(node, tf_import_flags, model);
- } else if (node.op() == "NextIteration") {
- ConvertOperatorSpecialCasedAsRNNBackEdge(node, tf_import_flags, model);
- } else if (node.op() == "ExpandDims") {
- ConvertExpandDimsOperator(node, tf_import_flags, model);
- } else if (node.op() == "Fill") {
- ConvertFillOperator(node, tf_import_flags, model);
- } else if (node.op() == "FloorDiv") {
- ConvertFloorDivOperator(node, tf_import_flags, model);
- } else if (node.op() == "FloorMod") {
- ConvertFloorModOperator(node, tf_import_flags, model);
- } else if (node.op() == "Range") {
- ConvertRangeOperator(node, tf_import_flags, model);
- } else if (node.op() == "Rank") {
- ConvertRankOperator(node, tf_import_flags, model);
- } else if (node.op() == "Stack" || node.op() == "Pack") {
- ConvertStackOperator(node, tf_import_flags, model);
- } else if (node.op() == "Transpose") {
- ConvertTransposeOperator(node, tf_import_flags, model);
- } else if (node.op() == "ArgMax") {
- ConvertArgMaxOperator(node, tf_import_flags, model);
- } else if (node.op() == "Exp") {
- ConvertExpOperator(node, tf_import_flags, model);
- } else if (node.op() == "TopK" || node.op() == "TopKV2") {
- ConvertTopKV2Operator(node, tf_import_flags, model);
- } else if (node.op() == "DynamicPartition") {
- ConvertDynamicPartitionOperator(node, tf_import_flags, model);
- } else if (node.op() == "DynamicStitch" ||
- node.op() == "ParallelDynamicStitch") {
- ConvertDynamicStitchOperator(node, tf_import_flags, model);
- } else if (node.op() == "RandomUniform") {
- ConvertRandomUniform(node, tf_import_flags, model);
- } else {
- ConvertUnsupportedOperator(node, tf_import_flags, model);
- }
+ auto status = internal::ImportTensorFlowNode(node, tf_import_flags, model);
+ CHECK(status.ok()) << status.error_message();
}
ResolveModelFlags(model_flags, model);
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc
new file mode 100644
index 0000000000..5dc78f73ad
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc
@@ -0,0 +1,160 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/toco/import_tensorflow.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+
+namespace toco {
+
+using port::Status;
+using tensorflow::AttrValue;
+using tensorflow::DT_BOOL;
+using tensorflow::DT_FLOAT;
+using tensorflow::DT_INT32;
+using tensorflow::DT_INT64;
+using tensorflow::DT_QUINT8;
+using tensorflow::DT_STRING;
+using tensorflow::NodeDef;
+
+namespace internal {
+Status ImportTensorFlowNode(const NodeDef&, const TensorFlowImportFlags&,
+ Model*);
+} // namespace internal
+
+namespace {
+
+class ShapeImportTest : public ::testing::TestWithParam<tensorflow::DataType> {
+ protected:
+ ShapeImportTest() {}
+
+ void BuildConstNode(std::initializer_list<int64_t> shape,
+ tensorflow::DataType dtype, int64_t num_elements,
+ NodeDef* node) {
+ node->set_op("Const");
+ node->set_name("Node1");
+
+ // An attribute describing the type of this const node.
+ AttrValue dtype_attr;
+ SetAttrValue(dtype, &dtype_attr);
+ (*node->mutable_attr())["dtype"] = dtype_attr;
+
+ // An attribute describing the content of this const node.
+ tensorflow::TensorProto t;
+ t.set_dtype(dtype);
+ auto* s = t.mutable_tensor_shape();
+ for (auto d : shape) {
+ s->add_dim()->set_size(d);
+ }
+
+ // TODO(ahentz): also need to test via tensor_content()
+ switch (dtype) {
+ case DT_FLOAT:
+ for (int64_t i = 0; i < num_elements; ++i) {
+ t.add_float_val(i / 10000.0);
+ }
+ break;
+ case DT_INT32:
+ for (int64_t i = 0; i < num_elements; ++i) {
+ t.add_int_val(i % std::numeric_limits<int>::max());
+ }
+ break;
+ case DT_QUINT8:
+ for (int64_t i = 0; i < num_elements; ++i) {
+ t.add_int_val(i % std::numeric_limits<uint8_t>::max());
+ }
+ break;
+ case DT_INT64:
+ for (int64_t i = 0; i < num_elements; ++i) {
+ t.add_int64_val(i);
+ }
+ break;
+ case DT_STRING:
+ break;
+ case DT_BOOL:
+ for (int64_t i = 0; i < num_elements; ++i) {
+ t.add_bool_val(i % 2);
+ }
+ break;
+ default:
+ break;
+ }
+
+ AttrValue value_attr;
+ SetAttrValue(t, &value_attr);
+ (*node->mutable_attr())["value"] = value_attr;
+ }
+
+ Status ImportNode(const NodeDef& node) {
+ Model model;
+ return internal::ImportTensorFlowNode(node, TensorFlowImportFlags(),
+ &model);
+ }
+};
+
+std::vector<tensorflow::DataType> TestTypes() {
+ return {DT_FLOAT, DT_INT32, DT_INT64, DT_BOOL, DT_QUINT8};
+}
+
+TEST_P(ShapeImportTest, ShapeElementIsNegative) {
+ NodeDef node;
+ BuildConstNode({1, -2, 10}, GetParam(), 0, &node);
+ auto status = ImportNode(node);
+ EXPECT_EQ(status.error_message(),
+ "Tensor shape should not include negative values (while processing "
+ "node 'Node1')");
+}
+INSTANTIATE_TEST_CASE_P(ShapeElementIsNegative, ShapeImportTest,
+ ::testing::ValuesIn(TestTypes()));
+
+TEST_P(ShapeImportTest, ShapeElementTooLarge) {
+ NodeDef node;
+ BuildConstNode({3000000000}, GetParam(), 0, &node);
+ auto status = ImportNode(node);
+ EXPECT_EQ(status.error_message(),
+ "Shape element overflows (while processing node 'Node1')");
+}
+INSTANTIATE_TEST_CASE_P(ShapeElementTooLarge, ShapeImportTest,
+ ::testing::ValuesIn(TestTypes()));
+
+TEST_P(ShapeImportTest, ShapeTooLarge) {
+ NodeDef node;
+ BuildConstNode({1000000, 2000000, 2000000, 2000000}, GetParam(), 0, &node);
+ auto status = ImportNode(node);
+ EXPECT_EQ(status.error_message(),
+ "Tensor shape is too large (while processing node 'Node1')");
+}
+INSTANTIATE_TEST_CASE_P(ShapeTooLarge, ShapeImportTest,
+ ::testing::ValuesIn(TestTypes()));
+
+TEST_P(ShapeImportTest, ValidShapeButZeroElements) {
+ NodeDef node;
+ BuildConstNode({1, 2, 2, 2}, GetParam(), 0, &node);
+ auto status = ImportNode(node);
+ EXPECT_THAT(status.error_message(),
+ ::testing::MatchesRegex(
+ "Neither input_content nor .*_val have the right dimensions "
+ "for this .* tensor .while processing node 'Node1'."));
+}
+INSTANTIATE_TEST_CASE_P(ValidShapeButZeroElements, ShapeImportTest,
+ ::testing::ValuesIn(TestTypes()));
+
+} // namespace
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/toco_port.h b/tensorflow/contrib/lite/toco/toco_port.h
index 2d5c231bef..906792ef56 100644
--- a/tensorflow/contrib/lite/toco/toco_port.h
+++ b/tensorflow/contrib/lite/toco/toco_port.h
@@ -38,10 +38,15 @@ namespace port {
class Status {
public:
+ static Status OK() { return Status(true, ""); }
+
+ // Create a failed status with no message.
Status() {}
Status(bool ok, const string& message) : ok_(ok), message_(message) {}
+ void AppendMessage(const string& message) { message_ += message; }
+
bool ok() const { return ok_; }
const string error_message() const { return message_; }
diff --git a/tensorflow/contrib/lite/toco/tooling_util.h b/tensorflow/contrib/lite/toco/tooling_util.h
index 5cc15fa57b..f5b596df0f 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.h
+++ b/tensorflow/contrib/lite/toco/tooling_util.h
@@ -294,6 +294,35 @@ void FinishBuildingRNNStates(Model* model);
void UseArraysExtraInfo(Model* model, bool quantize_output);
+// Calculates the number of elements in tensor given a shape. Shape elements
+// are assumed to be of type T, while the result total is of type U. If U
+// doesn't have enough range to represent the sum of elements, an error is
+// returned.
+template <typename T, typename U>
+port::Status NumElements(const std::vector<T>& shape, U* num_elements) {
+ static_assert(
+ std::numeric_limits<T>::max() <= std::numeric_limits<uint64_t>::max(),
+ "vector type exceed capabilities of NumElements");
+
+ *num_elements = 1;
+ for (const T& dim : shape) {
+ if (dim < 0) {
+ // TensorFlow's shapes sometimes include -1 to represent an "unknown"
+ // size but TOCO isn't able to create arrays of unknown sizes and will
+ // crash in RequiredBufferSizeForShape().
+ return port::Status(false,
+ "Tensor shape should not include negative values");
+ }
+ if (static_cast<uint64_t>(dim) >
+ std::numeric_limits<U>::max() / *num_elements) {
+ *num_elements = 0;
+ return port::Status(false, "Tensor shape is too large");
+ }
+ *num_elements *= dim;
+ }
+ return port::Status::OK();
+}
+
} // namespace toco
#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TOOLING_UTIL_H_
diff --git a/tensorflow/contrib/lite/toco/tooling_util_test.cc b/tensorflow/contrib/lite/toco/tooling_util_test.cc
index 22955ce956..87fd30db2c 100644
--- a/tensorflow/contrib/lite/toco/tooling_util_test.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util_test.cc
@@ -93,4 +93,85 @@ TEST_P(ShapeTest, Agrees) {
INSTANTIATE_TEST_CASE_P(AgreeBroadcast, ShapeTest,
::testing::ValuesIn(CreateShapePairs()));
+static const char kNegativeValuesMessage[] =
+ "Tensor shape should not include negative values";
+static const char kLargeTensorMessage[] = "Tensor shape is too large";
+
+TEST(NumElementsTest, Int) {
+ int count;
+ port::Status status = port::Status::OK();
+
+ status = NumElements(std::vector<int>{1024, 1024, 2047}, &count);
+ EXPECT_TRUE(status.ok());
+ EXPECT_EQ(count, 2146435072);
+
+ status = NumElements(std::vector<int>{1, 2, -3}, &count);
+ EXPECT_EQ(status.error_message(), kNegativeValuesMessage);
+
+ status = NumElements(std::vector<int>{1024, 1024, 2048}, &count);
+ EXPECT_EQ(status.error_message(), kLargeTensorMessage);
+}
+
+TEST(NumElementsTest, Int32) {
+ int32_t count;
+ port::Status status = port::Status::OK();
+
+ status = NumElements(std::vector<int32_t>{1024, 1024, 2047}, &count);
+ EXPECT_TRUE(status.ok());
+ EXPECT_EQ(count, 2146435072);
+
+ status = NumElements(std::vector<int32_t>{1, 2, -3}, &count);
+ EXPECT_EQ(status.error_message(), kNegativeValuesMessage);
+
+ status = NumElements(std::vector<int32_t>{1024, 1024, 2048}, &count);
+ EXPECT_EQ(status.error_message(), kLargeTensorMessage);
+}
+
+TEST(NumElementsTest, Int64) {
+ int64_t count;
+ port::Status status = port::Status::OK();
+
+ status = NumElements(std::vector<int64_t>{16777216, 16777216, 32767}, &count);
+ EXPECT_TRUE(status.ok());
+ EXPECT_EQ(count, 9223090561878065152LL);
+
+ status = NumElements(std::vector<int64_t>{1, 2, -3}, &count);
+ EXPECT_EQ(status.error_message(), kNegativeValuesMessage);
+
+ status = NumElements(std::vector<int64_t>{16777216, 16777216, 32768}, &count);
+ EXPECT_EQ(status.error_message(), kLargeTensorMessage);
+}
+
+TEST(NumElementsTest, UnsignedInt32) {
+ uint32_t count;
+ port::Status status = port::Status::OK();
+
+ status = NumElements(std::vector<uint32_t>{1024, 2048, 2047}, &count);
+ EXPECT_TRUE(status.ok());
+ EXPECT_EQ(count, 4292870144);
+
+ status = NumElements(std::vector<int>{1, 2, -3}, &count);
+ EXPECT_EQ(status.error_message(), kNegativeValuesMessage);
+
+ status = NumElements(std::vector<uint32_t>{1024, 2048, 2048}, &count);
+ EXPECT_EQ(status.error_message(), kLargeTensorMessage);
+}
+
+TEST(NumElementsTest, UnsignedInt64) {
+ uint64_t count;
+ port::Status status = port::Status::OK();
+
+ status =
+ NumElements(std::vector<uint64_t>{16777216, 16777216, 65535}, &count);
+ EXPECT_TRUE(status.ok());
+ EXPECT_EQ(count, 18446462598732840960ULL);
+
+ status = NumElements(std::vector<int>{1, 2, -3}, &count);
+ EXPECT_EQ(status.error_message(), kNegativeValuesMessage);
+
+ status =
+ NumElements(std::vector<uint64_t>{16777216, 16777216, 65536}, &count);
+ EXPECT_EQ(status.error_message(), kLargeTensorMessage);
+}
+
} // namespace toco