diff options
Diffstat (limited to 'tensorflow/contrib/lite')
-rw-r--r-- | tensorflow/contrib/lite/toco/BUILD | 12 | ||||
-rw-r--r-- | tensorflow/contrib/lite/toco/import_tensorflow.cc | 505 | ||||
-rw-r--r-- | tensorflow/contrib/lite/toco/import_tensorflow_test.cc | 160 | ||||
-rw-r--r-- | tensorflow/contrib/lite/toco/toco_port.h | 5 | ||||
-rw-r--r-- | tensorflow/contrib/lite/toco/tooling_util.h | 29 | ||||
-rw-r--r-- | tensorflow/contrib/lite/toco/tooling_util_test.cc | 81 |
6 files changed, 562 insertions, 230 deletions
diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD index f92e546ab8..f16225fd66 100644 --- a/tensorflow/contrib/lite/toco/BUILD +++ b/tensorflow/contrib/lite/toco/BUILD @@ -364,6 +364,18 @@ cc_library( }), ) +tf_cc_test( + name = "import_tensorflow_test", + srcs = ["import_tensorflow_test.cc"], + deps = [ + ":toco_tooling", + "//tensorflow/core:framework", + "//tensorflow/core:graph", + "//tensorflow/core:protos_all_cc", + "@com_google_googletest//:gtest_main", + ], +) + cc_library( name = "tooling_util", srcs = [ diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc index fa8b26bce0..453ff29b0d 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -62,6 +62,9 @@ using tensorflow::TensorProto; using tensorflow::TensorShapeProto; namespace toco { + +using port::Status; + namespace { bool HasAttr(const NodeDef& node, const string& attr_name) { return node.attr().count(attr_name) > 0; @@ -113,7 +116,7 @@ const TensorShapeProto& GetShapeAttr(const NodeDef& node, } const TensorProto& GetTensorAttr(const NodeDef& node, const string& attr_name) { - CHECK(HasAttr(node, attr_name)); + CHECK(HasAttr(node, attr_name)) << "No attr named '" << attr_name << "'"; const auto& attr = node.attr().at(attr_name); CHECK_EQ(attr.value_case(), AttrValue::kTensor); return attr.tensor(); @@ -145,9 +148,9 @@ ArrayDataType ConvertDataType(tensorflow::DataType dtype) { return ArrayDataType::kNone; } -void ImportShape(const TFLITE_PROTO_NS::RepeatedPtrField< - tensorflow::TensorShapeProto_Dim>& input_dims, - Shape* shape) { +Status ImportShape(const TFLITE_PROTO_NS::RepeatedPtrField< + tensorflow::TensorShapeProto_Dim>& input_dims, + int* input_flat_size, Shape* shape) { std::vector<int> input_dims_only_sizes; for (auto& d : input_dims) { if (d.size() == 0) { @@ -155,23 +158,33 @@ void ImportShape(const TFLITE_PROTO_NS::RepeatedPtrField< // them of flat size 0 even though they have other nonzero dims. // This breaks our invariant, that array dims can't be 0. // For now, tweaking this to record a 0-D shape instead. - input_dims_only_sizes.clear(); - break; + shape->mutable_dims()->clear(); + if (input_flat_size != nullptr) *input_flat_size = 0; + return Status::OK(); + } + // TensorFlow's shapes use int64s, while TOCO uses ints. + if (d.size() > std::numeric_limits<int>::max()) { + return Status(false, "Shape element overflows"); } + input_dims_only_sizes.push_back(d.size()); } *shape->mutable_dims() = input_dims_only_sizes; + + if (input_flat_size == nullptr) return Status::OK(); + + return NumElements(input_dims_only_sizes, input_flat_size); } -void ImportFloatArray(const TensorProto& input_tensor, Array* output_array) { +Status ImportFloatArray(const TensorProto& input_tensor, Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_FLOAT); const auto& input_shape = input_tensor.tensor_shape(); CHECK_LE(input_shape.dim_size(), 4); - ImportShape(input_shape.dim(), output_array->mutable_shape()); - int input_flat_size = 1; - for (int k = 0; k < input_shape.dim_size(); k++) { - input_flat_size *= input_shape.dim(k).size(); - } + int input_flat_size; + auto status = ImportShape(input_shape.dim(), &input_flat_size, + output_array->mutable_shape()); + if (!status.ok()) return status; + auto& output_float_data = output_array->GetMutableBuffer<ArrayDataType::kFloat>().data; output_float_data.resize(RequiredBufferSizeForShape(output_array->shape()), @@ -189,20 +202,22 @@ void ImportFloatArray(const TensorProto& input_tensor, Array* output_array) { toco::port::CopyToBuffer(input_tensor.tensor_content(), reinterpret_cast<char*>(output_float_data.data())); } else { - LOG(FATAL) << "Neither input_content nor float_val have the right " - "dimensions for this float tensor."; + return Status(false, + "Neither input_content nor float_val have the right " + "dimensions for this float tensor"); } + return Status::OK(); } -void ImportQuint8Array(const TensorProto& input_tensor, Array* output_array) { +Status ImportQuint8Array(const TensorProto& input_tensor, Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_QUINT8); const auto& input_shape = input_tensor.tensor_shape(); CHECK_LE(input_shape.dim_size(), 4); - ImportShape(input_shape.dim(), output_array->mutable_shape()); - int input_flat_size = 1; - for (int k = 0; k < input_shape.dim_size(); k++) { - input_flat_size *= input_shape.dim(k).size(); - } + int input_flat_size; + auto status = ImportShape(input_shape.dim(), &input_flat_size, + output_array->mutable_shape()); + if (!status.ok()) return status; + auto& output_int_data = output_array->GetMutableBuffer<ArrayDataType::kUint8>().data; output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0); @@ -215,20 +230,22 @@ void ImportQuint8Array(const TensorProto& input_tensor, Array* output_array) { toco::port::CopyToBuffer(input_tensor.tensor_content(), reinterpret_cast<char*>(output_int_data.data())); } else { - LOG(FATAL) << "Neither input_content nor int_val have the right " - "dimensions for this uint8 tensor."; + return Status(false, + "Neither input_content nor int_val have the right dimensions " + "for this uint8 tensor"); } + return Status::OK(); } -void ImportInt32Array(const TensorProto& input_tensor, Array* output_array) { +Status ImportInt32Array(const TensorProto& input_tensor, Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_INT32); const auto& input_shape = input_tensor.tensor_shape(); CHECK_LE(input_shape.dim_size(), 4); - ImportShape(input_shape.dim(), output_array->mutable_shape()); - int input_flat_size = 1; - for (int k = 0; k < input_shape.dim_size(); k++) { - input_flat_size *= input_shape.dim(k).size(); - } + int input_flat_size; + auto status = ImportShape(input_shape.dim(), &input_flat_size, + output_array->mutable_shape()); + if (!status.ok()) return status; + auto& output_int_data = output_array->GetMutableBuffer<ArrayDataType::kInt32>().data; output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0); @@ -241,20 +258,22 @@ void ImportInt32Array(const TensorProto& input_tensor, Array* output_array) { toco::port::CopyToBuffer(input_tensor.tensor_content(), reinterpret_cast<char*>(output_int_data.data())); } else { - LOG(FATAL) << "Neither input_content nor int_val have the right " - "dimensions for this int32 tensor."; + return Status(false, + "Neither input_content nor int_val have the right dimensions " + "for this int32 tensor"); } + return Status::OK(); } -void ImportInt64Array(const TensorProto& input_tensor, Array* output_array) { +Status ImportInt64Array(const TensorProto& input_tensor, Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_INT64); const auto& input_shape = input_tensor.tensor_shape(); CHECK_LE(input_shape.dim_size(), 4); - ImportShape(input_shape.dim(), output_array->mutable_shape()); - int input_flat_size = 1; - for (int k = 0; k < input_shape.dim_size(); k++) { - input_flat_size *= input_shape.dim(k).size(); - } + int input_flat_size; + auto status = ImportShape(input_shape.dim(), &input_flat_size, + output_array->mutable_shape()); + if (!status.ok()) return status; + auto& output_int_data = output_array->GetMutableBuffer<ArrayDataType::kInt64>().data; output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0); @@ -267,20 +286,22 @@ void ImportInt64Array(const TensorProto& input_tensor, Array* output_array) { toco::port::CopyToBuffer(input_tensor.tensor_content(), reinterpret_cast<char*>(output_int_data.data())); } else { - LOG(FATAL) << "Neither input_content nor int64_val have the right " - "dimensions for this int64 tensor."; + return Status(false, + "Neither input_content nor int64_val have the right " + "dimensions for this int64 tensor"); } + return Status::OK(); } -void ImportBoolArray(const TensorProto& input_tensor, Array* output_array) { +Status ImportBoolArray(const TensorProto& input_tensor, Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_BOOL); const auto& input_shape = input_tensor.tensor_shape(); CHECK_LE(input_shape.dim_size(), 4); - ImportShape(input_shape.dim(), output_array->mutable_shape()); - int input_flat_size = 1; - for (int k = 0; k < input_shape.dim_size(); k++) { - input_flat_size *= input_shape.dim(k).size(); - } + int input_flat_size; + auto status = ImportShape(input_shape.dim(), &input_flat_size, + output_array->mutable_shape()); + if (!status.ok()) return status; + auto& output_bool_data = output_array->GetMutableBuffer<ArrayDataType::kBool>().data; output_bool_data.resize(RequiredBufferSizeForShape(output_array->shape()), @@ -300,20 +321,25 @@ void ImportBoolArray(const TensorProto& input_tensor, Array* output_array) { // assuming that 'false' is implied. // So far only encountered that in an array with 1 entry, let's // require that until we encounter a graph where that's not the case. - CHECK_EQ(output_bool_data.size(), 1); + if (output_bool_data.size() != 1) { + return Status(false, + "Neither input_content nor bool_val have the right " + "dimensions for this bool tensor"); + } output_bool_data[0] = false; } + return Status::OK(); } -void ImportStringArray(const TensorProto& input_tensor, Array* output_array) { +Status ImportStringArray(const TensorProto& input_tensor, Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_STRING); const auto& input_shape = input_tensor.tensor_shape(); CHECK_LE(input_shape.dim_size(), 4); - ImportShape(input_shape.dim(), output_array->mutable_shape()); - int input_flat_size = 1; - for (int k = 0; k < input_shape.dim_size(); k++) { - input_flat_size *= input_shape.dim(k).size(); - } + int input_flat_size; + auto status = ImportShape(input_shape.dim(), &input_flat_size, + output_array->mutable_shape()); + if (!status.ok()) return status; + auto& output_string_data = output_array->GetMutableBuffer<ArrayDataType::kString>().data; output_string_data.resize(RequiredBufferSizeForShape(output_array->shape())); @@ -324,6 +350,7 @@ void ImportStringArray(const TensorProto& input_tensor, Array* output_array) { for (int i = 0; i < input_flat_size; ++i) { output_string_data[i] = input_tensor.string_val(i); } + return Status::OK(); } // Count the number of inputs of a given node. If @@ -363,38 +390,40 @@ string CreateConstArray(Model* model, string const& name, return array_name; } -void ConvertConstOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +Status ConvertConstOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "Const"); const auto& tensor = GetTensorAttr(node, "value"); const auto dtype = GetDataTypeAttr(node, "dtype"); + Status status = Status::OK(); + auto& array = model->GetOrCreateArray(node.name()); switch (dtype) { case DT_FLOAT: array.data_type = ArrayDataType::kFloat; - ImportFloatArray(tensor, &array); + status = ImportFloatArray(tensor, &array); break; case DT_INT32: array.data_type = ArrayDataType::kInt32; - ImportInt32Array(tensor, &array); + status = ImportInt32Array(tensor, &array); break; case DT_QUINT8: array.data_type = ArrayDataType::kUint8; - ImportQuint8Array(tensor, &array); + status = ImportQuint8Array(tensor, &array); break; case DT_INT64: array.data_type = ArrayDataType::kInt64; - ImportInt64Array(tensor, &array); + status = ImportInt64Array(tensor, &array); break; case DT_STRING: array.data_type = ArrayDataType::kString; - ImportStringArray(tensor, &array); + status = ImportStringArray(tensor, &array); break; case DT_BOOL: array.data_type = ArrayDataType::kBool; - ImportBoolArray(tensor, &array); + status = ImportBoolArray(tensor, &array); break; default: array.data_type = ArrayDataType::kNone; @@ -404,6 +433,10 @@ void ConvertConstOperator(const NodeDef& node, array.GetMutableBuffer<ArrayDataType::kNone>(); break; } + if (!status.ok()) { + status.AppendMessage(" (while processing node '" + node.name() + "')"); + } + return status; } void ConvertConvOperator(const NodeDef& node, @@ -2033,6 +2066,186 @@ void ConvertDynamicStitchOperator(const NodeDef& node, } // namespace +namespace internal { +Status ImportTensorFlowNode(const tensorflow::NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + Model* model) { + // TODO(ahentz): Historically these functions all CHECK-fail on error. We've + // been slowly converting them to return Status. + if (node.op() == "Const") { + return ConvertConstOperator(node, tf_import_flags, model); + } else if (node.op() == "Conv2D") { + ConvertConvOperator(node, tf_import_flags, model); + } else if (node.op() == "Conv2DBackpropInput") { + ConvertTransposeConvOperator(node, tf_import_flags, model); + } else if (node.op() == "DepthwiseConv2dNative") { + ConvertDepthwiseConvOperator(node, tf_import_flags, model); + } else if (node.op() == "DepthToSpace") { + ConvertDepthToSpaceOperator(node, tf_import_flags, model); + } else if (node.op() == "SpaceToDepth") { + ConvertSpaceToDepthOperator(node, tf_import_flags, model); + } else if (node.op() == "BiasAdd") { + ConvertBiasAddOperator(node, tf_import_flags, model); + } else if (node.op() == "Relu") { + ConvertReluOperator(node, tf_import_flags, model); + } else if (node.op() == "Relu6") { + ConvertRelu6Operator(node, tf_import_flags, model); + } else if (node.op() == "Sigmoid") { + ConvertLogisticOperator(node, tf_import_flags, model); + } else if (node.op() == "Tanh") { + ConvertTanhOperator(node, tf_import_flags, model); + } else if (node.op() == "MaxPool") { + ConvertMaxPoolOperator(node, tf_import_flags, model); + } else if (node.op() == "AvgPool") { + ConvertAvgPoolOperator(node, tf_import_flags, model); + } else if (node.op() == "Reshape") { + ConvertReshapeOperator(node, tf_import_flags, model); + } else if (node.op() == "BatchMatMul") { + ConvertBatchMatMulOperator(node, tf_import_flags, model); + } else if (node.op() == "MatMul") { + ConvertMatMulOperator(node, tf_import_flags, model); + } else if (node.op() == "Div" || node.op() == "RealDiv") { + ConvertDivOperator(node, tf_import_flags, model); + } else if (node.op() == "Identity" || node.op() == "CheckNumerics" || + node.op() == "StopGradient") { + ConvertIdentityOperator(node, tf_import_flags, model); + } else if (node.op() == "FakeQuantWithMinMaxVars") { + ConvertFakeQuantWithMinMaxVars(node, tf_import_flags, model); + } else if (node.op() == "FakeQuantWithMinMaxArgs") { + ConvertFakeQuantWithMinMaxArgs(node, tf_import_flags, model); + } else if (node.op() == "Neg") { + ConvertNegOperator(node, tf_import_flags, model); + } else if (node.op() == "Rsqrt") { + ConvertRsqrtOperator(node, tf_import_flags, model); + } else if (node.op() == "Squeeze") { + ConvertSqueezeOperator(node, tf_import_flags, model); + } else if (node.op() == "Sqrt") { + ConvertSqrtOperator(node, tf_import_flags, model); + } else if (node.op() == "Square") { + ConvertSquareOperator(node, tf_import_flags, model); + } else if (node.op() == "Add") { + ConvertAddOperator(node, tf_import_flags, model); + } else if (node.op() == "AddN") { + ConvertAddNOperator(node, tf_import_flags, model); + } else if (node.op() == "Mul") { + ConvertMulOperator(node, tf_import_flags, model); + } else if (node.op() == "Sub") { + ConvertSubOperator(node, tf_import_flags, model); + } else if (node.op() == "Sum") { + ConvertSumOperator(node, tf_import_flags, model); + } else if (node.op() == "Tile") { + ConvertTileOperator(node, tf_import_flags, model); + } else if (node.op() == "Concat" || node.op() == "ConcatV2") { + ConvertConcatOperator(node, tf_import_flags, model); + } else if (node.op() == "LRN") { + ConvertLRNOperator(node, tf_import_flags, model); + } else if (node.op() == "Softmax") { + ConvertSoftmaxOperator(node, tf_import_flags, model); + } else if (node.op() == "Log") { + ConvertLogOperator(node, tf_import_flags, model); + } else if (node.op() == "LogSoftmax") { + ConvertLogSoftmaxOperator(node, tf_import_flags, model); + } else if (node.op() == "All") { + ConvertAllOperator(node, tf_import_flags, model); + } else if (node.op() == "Assert") { + ConvertAssertOperator(node, tf_import_flags, model); + } else if (node.op() == "Less") { + ConvertLessOperator(node, tf_import_flags, model); + } else if (node.op() == "LessEqual") { + ConvertLessEqualOperator(node, tf_import_flags, model); + } else if (node.op() == "Greater") { + ConvertGreaterOperator(node, tf_import_flags, model); + } else if (node.op() == "GreaterEqual") { + ConvertGreaterEqualOperator(node, tf_import_flags, model); + } else if (node.op() == "Max") { + ConvertMaxOperator(node, tf_import_flags, model); + } else if (node.op() == "Min") { + ConvertMinOperator(node, tf_import_flags, model); + } else if (node.op() == "Maximum") { + ConvertMaximumOperator(node, tf_import_flags, model); + } else if (node.op() == "Minimum") { + ConvertMinimumOperator(node, tf_import_flags, model); + } else if (node.op() == "Merge") { + ConvertMergeOperator(node, tf_import_flags, model); + } else if (node.op() == "Pad") { + ConvertPadOperator(node, tf_import_flags, model); + } else if (node.op() == "StridedSlice") { + ConvertStridedSliceOperator(node, tf_import_flags, model); + } else if (node.op() == "Shape") { + ConvertShapeOperator(node, tf_import_flags, model); + } else if (node.op() == "Slice") { + ConvertSliceOperator(node, tf_import_flags, model); + } else if (node.op() == "Split") { + ConvertSplitOperator(node, tf_import_flags, model); + } else if (node.op() == "Switch") { + ConvertSwitchOperator(node, tf_import_flags, model); + } else if (node.op() == "Placeholder") { + ConvertPlaceholderOperator(node, tf_import_flags, model); + } else if (node.op() == "PlaceholderWithDefault") { + ConvertIdentityOperator(node, tf_import_flags, model); + } else if (node.op() == "LegacyFedInput") { + ConvertPlaceholderOperator(node, tf_import_flags, model); + } else if (node.op() == "NoOp") { + ConvertNoOpOperator(node, tf_import_flags, model); + } else if (node.op() == "Cast") { + ConvertCastOperator(node, tf_import_flags, model); + } else if (node.op() == "Floor") { + ConvertFloorOperator(node, tf_import_flags, model); + } else if (node.op() == "Gather" || node.op() == "GatherV2") { + ConvertGatherOperator(node, tf_import_flags, model); + } else if (node.op() == "ResizeBilinear") { + ConvertResizeBilinearOperator(node, tf_import_flags, model); + } else if (node.op() == "BatchNormWithGlobalNormalization") { + ConvertBatchNormWithGlobalNormalizationOperator(node, tf_import_flags, + model); + } else if (node.op() == "FusedBatchNorm") { + ConvertFusedBatchNormOperator(node, tf_import_flags, model); + } else if (node.op() == "SpaceToBatchND") { + ConvertSpaceToBatchNDOperator(node, tf_import_flags, model); + } else if (node.op() == "BatchToSpaceND") { + ConvertBatchToSpaceNDOperator(node, tf_import_flags, model); + } else if (node.op() == "Mean") { + ConvertMeanOperator(node, tf_import_flags, model); + } else if (node.op() == "Svdf") { + ConvertSvdfOperator(node, tf_import_flags, model); + } else if (node.op() == "NextIteration") { + ConvertOperatorSpecialCasedAsRNNBackEdge(node, tf_import_flags, model); + } else if (node.op() == "ExpandDims") { + ConvertExpandDimsOperator(node, tf_import_flags, model); + } else if (node.op() == "Fill") { + ConvertFillOperator(node, tf_import_flags, model); + } else if (node.op() == "FloorDiv") { + ConvertFloorDivOperator(node, tf_import_flags, model); + } else if (node.op() == "FloorMod") { + ConvertFloorModOperator(node, tf_import_flags, model); + } else if (node.op() == "Range") { + ConvertRangeOperator(node, tf_import_flags, model); + } else if (node.op() == "Rank") { + ConvertRankOperator(node, tf_import_flags, model); + } else if (node.op() == "Stack" || node.op() == "Pack") { + ConvertStackOperator(node, tf_import_flags, model); + } else if (node.op() == "Transpose") { + ConvertTransposeOperator(node, tf_import_flags, model); + } else if (node.op() == "ArgMax") { + ConvertArgMaxOperator(node, tf_import_flags, model); + } else if (node.op() == "Exp") { + ConvertExpOperator(node, tf_import_flags, model); + } else if (node.op() == "TopK" || node.op() == "TopKV2") { + ConvertTopKV2Operator(node, tf_import_flags, model); + } else if (node.op() == "DynamicPartition") { + ConvertDynamicPartitionOperator(node, tf_import_flags, model); + } else if (node.op() == "DynamicStitch" || + node.op() == "ParallelDynamicStitch") { + ConvertDynamicStitchOperator(node, tf_import_flags, model); + } else if (node.op() == "RandomUniform") { + ConvertRandomUniform(node, tf_import_flags, model); + } else { + ConvertUnsupportedOperator(node, tf_import_flags, model); + } + return Status::OK(); +} +} // namespace internal + std::unique_ptr<Model> ImportTensorFlowGraphDef( const ModelFlags& model_flags, const TensorFlowImportFlags& tf_import_flags, const GraphDef& tf_graph) { @@ -2058,176 +2271,8 @@ std::unique_ptr<Model> ImportTensorFlowGraphDef( for (auto node : inlined_graph.node()) { StripZeroOutputIndexFromInputs(&node); - if (node.op() == "Const") { - ConvertConstOperator(node, tf_import_flags, model); - } else if (node.op() == "Conv2D") { - ConvertConvOperator(node, tf_import_flags, model); - } else if (node.op() == "Conv2DBackpropInput") { - ConvertTransposeConvOperator(node, tf_import_flags, model); - } else if (node.op() == "DepthwiseConv2dNative") { - ConvertDepthwiseConvOperator(node, tf_import_flags, model); - } else if (node.op() == "DepthToSpace") { - ConvertDepthToSpaceOperator(node, tf_import_flags, model); - } else if (node.op() == "SpaceToDepth") { - ConvertSpaceToDepthOperator(node, tf_import_flags, model); - } else if (node.op() == "BiasAdd") { - ConvertBiasAddOperator(node, tf_import_flags, model); - } else if (node.op() == "Relu") { - ConvertReluOperator(node, tf_import_flags, model); - } else if (node.op() == "Relu6") { - ConvertRelu6Operator(node, tf_import_flags, model); - } else if (node.op() == "Sigmoid") { - ConvertLogisticOperator(node, tf_import_flags, model); - } else if (node.op() == "Tanh") { - ConvertTanhOperator(node, tf_import_flags, model); - } else if (node.op() == "MaxPool") { - ConvertMaxPoolOperator(node, tf_import_flags, model); - } else if (node.op() == "AvgPool") { - ConvertAvgPoolOperator(node, tf_import_flags, model); - } else if (node.op() == "Reshape") { - ConvertReshapeOperator(node, tf_import_flags, model); - } else if (node.op() == "BatchMatMul") { - ConvertBatchMatMulOperator(node, tf_import_flags, model); - } else if (node.op() == "MatMul") { - ConvertMatMulOperator(node, tf_import_flags, model); - } else if (node.op() == "Div" || node.op() == "RealDiv") { - ConvertDivOperator(node, tf_import_flags, model); - } else if (node.op() == "Identity" || node.op() == "CheckNumerics" || - node.op() == "StopGradient") { - ConvertIdentityOperator(node, tf_import_flags, model); - } else if (node.op() == "FakeQuantWithMinMaxVars") { - ConvertFakeQuantWithMinMaxVars(node, tf_import_flags, model); - } else if (node.op() == "FakeQuantWithMinMaxArgs") { - ConvertFakeQuantWithMinMaxArgs(node, tf_import_flags, model); - } else if (node.op() == "Neg") { - ConvertNegOperator(node, tf_import_flags, model); - } else if (node.op() == "Rsqrt") { - ConvertRsqrtOperator(node, tf_import_flags, model); - } else if (node.op() == "Squeeze") { - ConvertSqueezeOperator(node, tf_import_flags, model); - } else if (node.op() == "Sqrt") { - ConvertSqrtOperator(node, tf_import_flags, model); - } else if (node.op() == "Square") { - ConvertSquareOperator(node, tf_import_flags, model); - } else if (node.op() == "Add") { - ConvertAddOperator(node, tf_import_flags, model); - } else if (node.op() == "AddN") { - ConvertAddNOperator(node, tf_import_flags, model); - } else if (node.op() == "Mul") { - ConvertMulOperator(node, tf_import_flags, model); - } else if (node.op() == "Sub") { - ConvertSubOperator(node, tf_import_flags, model); - } else if (node.op() == "Sum") { - ConvertSumOperator(node, tf_import_flags, model); - } else if (node.op() == "Tile") { - ConvertTileOperator(node, tf_import_flags, model); - } else if (node.op() == "Concat" || node.op() == "ConcatV2") { - ConvertConcatOperator(node, tf_import_flags, model); - } else if (node.op() == "LRN") { - ConvertLRNOperator(node, tf_import_flags, model); - } else if (node.op() == "Softmax") { - ConvertSoftmaxOperator(node, tf_import_flags, model); - } else if (node.op() == "Log") { - ConvertLogOperator(node, tf_import_flags, model); - } else if (node.op() == "LogSoftmax") { - ConvertLogSoftmaxOperator(node, tf_import_flags, model); - } else if (node.op() == "All") { - ConvertAllOperator(node, tf_import_flags, model); - } else if (node.op() == "Assert") { - ConvertAssertOperator(node, tf_import_flags, model); - } else if (node.op() == "Less") { - ConvertLessOperator(node, tf_import_flags, model); - } else if (node.op() == "LessEqual") { - ConvertLessEqualOperator(node, tf_import_flags, model); - } else if (node.op() == "Greater") { - ConvertGreaterOperator(node, tf_import_flags, model); - } else if (node.op() == "GreaterEqual") { - ConvertGreaterEqualOperator(node, tf_import_flags, model); - } else if (node.op() == "Max") { - ConvertMaxOperator(node, tf_import_flags, model); - } else if (node.op() == "Min") { - ConvertMinOperator(node, tf_import_flags, model); - } else if (node.op() == "Maximum") { - ConvertMaximumOperator(node, tf_import_flags, model); - } else if (node.op() == "Minimum") { - ConvertMinimumOperator(node, tf_import_flags, model); - } else if (node.op() == "Merge") { - ConvertMergeOperator(node, tf_import_flags, model); - } else if (node.op() == "Pad") { - ConvertPadOperator(node, tf_import_flags, model); - } else if (node.op() == "StridedSlice") { - ConvertStridedSliceOperator(node, tf_import_flags, model); - } else if (node.op() == "Shape") { - ConvertShapeOperator(node, tf_import_flags, model); - } else if (node.op() == "Slice") { - ConvertSliceOperator(node, tf_import_flags, model); - } else if (node.op() == "Split") { - ConvertSplitOperator(node, tf_import_flags, model); - } else if (node.op() == "Switch") { - ConvertSwitchOperator(node, tf_import_flags, model); - } else if (node.op() == "Placeholder") { - ConvertPlaceholderOperator(node, tf_import_flags, model); - } else if (node.op() == "PlaceholderWithDefault") { - ConvertIdentityOperator(node, tf_import_flags, model); - } else if (node.op() == "LegacyFedInput") { - ConvertPlaceholderOperator(node, tf_import_flags, model); - } else if (node.op() == "NoOp") { - ConvertNoOpOperator(node, tf_import_flags, model); - } else if (node.op() == "Cast") { - ConvertCastOperator(node, tf_import_flags, model); - } else if (node.op() == "Floor") { - ConvertFloorOperator(node, tf_import_flags, model); - } else if (node.op() == "Gather" || node.op() == "GatherV2") { - ConvertGatherOperator(node, tf_import_flags, model); - } else if (node.op() == "ResizeBilinear") { - ConvertResizeBilinearOperator(node, tf_import_flags, model); - } else if (node.op() == "BatchNormWithGlobalNormalization") { - ConvertBatchNormWithGlobalNormalizationOperator(node, tf_import_flags, - model); - } else if (node.op() == "FusedBatchNorm") { - ConvertFusedBatchNormOperator(node, tf_import_flags, model); - } else if (node.op() == "SpaceToBatchND") { - ConvertSpaceToBatchNDOperator(node, tf_import_flags, model); - } else if (node.op() == "BatchToSpaceND") { - ConvertBatchToSpaceNDOperator(node, tf_import_flags, model); - } else if (node.op() == "Mean") { - ConvertMeanOperator(node, tf_import_flags, model); - } else if (node.op() == "Svdf") { - ConvertSvdfOperator(node, tf_import_flags, model); - } else if (node.op() == "NextIteration") { - ConvertOperatorSpecialCasedAsRNNBackEdge(node, tf_import_flags, model); - } else if (node.op() == "ExpandDims") { - ConvertExpandDimsOperator(node, tf_import_flags, model); - } else if (node.op() == "Fill") { - ConvertFillOperator(node, tf_import_flags, model); - } else if (node.op() == "FloorDiv") { - ConvertFloorDivOperator(node, tf_import_flags, model); - } else if (node.op() == "FloorMod") { - ConvertFloorModOperator(node, tf_import_flags, model); - } else if (node.op() == "Range") { - ConvertRangeOperator(node, tf_import_flags, model); - } else if (node.op() == "Rank") { - ConvertRankOperator(node, tf_import_flags, model); - } else if (node.op() == "Stack" || node.op() == "Pack") { - ConvertStackOperator(node, tf_import_flags, model); - } else if (node.op() == "Transpose") { - ConvertTransposeOperator(node, tf_import_flags, model); - } else if (node.op() == "ArgMax") { - ConvertArgMaxOperator(node, tf_import_flags, model); - } else if (node.op() == "Exp") { - ConvertExpOperator(node, tf_import_flags, model); - } else if (node.op() == "TopK" || node.op() == "TopKV2") { - ConvertTopKV2Operator(node, tf_import_flags, model); - } else if (node.op() == "DynamicPartition") { - ConvertDynamicPartitionOperator(node, tf_import_flags, model); - } else if (node.op() == "DynamicStitch" || - node.op() == "ParallelDynamicStitch") { - ConvertDynamicStitchOperator(node, tf_import_flags, model); - } else if (node.op() == "RandomUniform") { - ConvertRandomUniform(node, tf_import_flags, model); - } else { - ConvertUnsupportedOperator(node, tf_import_flags, model); - } + auto status = internal::ImportTensorFlowNode(node, tf_import_flags, model); + CHECK(status.ok()) << status.error_message(); } ResolveModelFlags(model_flags, model); diff --git a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc new file mode 100644 index 0000000000..5dc78f73ad --- /dev/null +++ b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc @@ -0,0 +1,160 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/toco/import_tensorflow.h" + +#include <gmock/gmock.h> +#include <gtest/gtest.h> +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" + +namespace toco { + +using port::Status; +using tensorflow::AttrValue; +using tensorflow::DT_BOOL; +using tensorflow::DT_FLOAT; +using tensorflow::DT_INT32; +using tensorflow::DT_INT64; +using tensorflow::DT_QUINT8; +using tensorflow::DT_STRING; +using tensorflow::NodeDef; + +namespace internal { +Status ImportTensorFlowNode(const NodeDef&, const TensorFlowImportFlags&, + Model*); +} // namespace internal + +namespace { + +class ShapeImportTest : public ::testing::TestWithParam<tensorflow::DataType> { + protected: + ShapeImportTest() {} + + void BuildConstNode(std::initializer_list<int64_t> shape, + tensorflow::DataType dtype, int64_t num_elements, + NodeDef* node) { + node->set_op("Const"); + node->set_name("Node1"); + + // An attribute describing the type of this const node. + AttrValue dtype_attr; + SetAttrValue(dtype, &dtype_attr); + (*node->mutable_attr())["dtype"] = dtype_attr; + + // An attribute describing the content of this const node. + tensorflow::TensorProto t; + t.set_dtype(dtype); + auto* s = t.mutable_tensor_shape(); + for (auto d : shape) { + s->add_dim()->set_size(d); + } + + // TODO(ahentz): also need to test via tensor_content() + switch (dtype) { + case DT_FLOAT: + for (int64_t i = 0; i < num_elements; ++i) { + t.add_float_val(i / 10000.0); + } + break; + case DT_INT32: + for (int64_t i = 0; i < num_elements; ++i) { + t.add_int_val(i % std::numeric_limits<int>::max()); + } + break; + case DT_QUINT8: + for (int64_t i = 0; i < num_elements; ++i) { + t.add_int_val(i % std::numeric_limits<uint8_t>::max()); + } + break; + case DT_INT64: + for (int64_t i = 0; i < num_elements; ++i) { + t.add_int64_val(i); + } + break; + case DT_STRING: + break; + case DT_BOOL: + for (int64_t i = 0; i < num_elements; ++i) { + t.add_bool_val(i % 2); + } + break; + default: + break; + } + + AttrValue value_attr; + SetAttrValue(t, &value_attr); + (*node->mutable_attr())["value"] = value_attr; + } + + Status ImportNode(const NodeDef& node) { + Model model; + return internal::ImportTensorFlowNode(node, TensorFlowImportFlags(), + &model); + } +}; + +std::vector<tensorflow::DataType> TestTypes() { + return {DT_FLOAT, DT_INT32, DT_INT64, DT_BOOL, DT_QUINT8}; +} + +TEST_P(ShapeImportTest, ShapeElementIsNegative) { + NodeDef node; + BuildConstNode({1, -2, 10}, GetParam(), 0, &node); + auto status = ImportNode(node); + EXPECT_EQ(status.error_message(), + "Tensor shape should not include negative values (while processing " + "node 'Node1')"); +} +INSTANTIATE_TEST_CASE_P(ShapeElementIsNegative, ShapeImportTest, + ::testing::ValuesIn(TestTypes())); + +TEST_P(ShapeImportTest, ShapeElementTooLarge) { + NodeDef node; + BuildConstNode({3000000000}, GetParam(), 0, &node); + auto status = ImportNode(node); + EXPECT_EQ(status.error_message(), + "Shape element overflows (while processing node 'Node1')"); +} +INSTANTIATE_TEST_CASE_P(ShapeElementTooLarge, ShapeImportTest, + ::testing::ValuesIn(TestTypes())); + +TEST_P(ShapeImportTest, ShapeTooLarge) { + NodeDef node; + BuildConstNode({1000000, 2000000, 2000000, 2000000}, GetParam(), 0, &node); + auto status = ImportNode(node); + EXPECT_EQ(status.error_message(), + "Tensor shape is too large (while processing node 'Node1')"); +} +INSTANTIATE_TEST_CASE_P(ShapeTooLarge, ShapeImportTest, + ::testing::ValuesIn(TestTypes())); + +TEST_P(ShapeImportTest, ValidShapeButZeroElements) { + NodeDef node; + BuildConstNode({1, 2, 2, 2}, GetParam(), 0, &node); + auto status = ImportNode(node); + EXPECT_THAT(status.error_message(), + ::testing::MatchesRegex( + "Neither input_content nor .*_val have the right dimensions " + "for this .* tensor .while processing node 'Node1'.")); +} +INSTANTIATE_TEST_CASE_P(ValidShapeButZeroElements, ShapeImportTest, + ::testing::ValuesIn(TestTypes())); + +} // namespace +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/toco_port.h b/tensorflow/contrib/lite/toco/toco_port.h index 2d5c231bef..906792ef56 100644 --- a/tensorflow/contrib/lite/toco/toco_port.h +++ b/tensorflow/contrib/lite/toco/toco_port.h @@ -38,10 +38,15 @@ namespace port { class Status { public: + static Status OK() { return Status(true, ""); } + + // Create a failed status with no message. Status() {} Status(bool ok, const string& message) : ok_(ok), message_(message) {} + void AppendMessage(const string& message) { message_ += message; } + bool ok() const { return ok_; } const string error_message() const { return message_; } diff --git a/tensorflow/contrib/lite/toco/tooling_util.h b/tensorflow/contrib/lite/toco/tooling_util.h index 5cc15fa57b..f5b596df0f 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.h +++ b/tensorflow/contrib/lite/toco/tooling_util.h @@ -294,6 +294,35 @@ void FinishBuildingRNNStates(Model* model); void UseArraysExtraInfo(Model* model, bool quantize_output); +// Calculates the number of elements in tensor given a shape. Shape elements +// are assumed to be of type T, while the result total is of type U. If U +// doesn't have enough range to represent the sum of elements, an error is +// returned. +template <typename T, typename U> +port::Status NumElements(const std::vector<T>& shape, U* num_elements) { + static_assert( + std::numeric_limits<T>::max() <= std::numeric_limits<uint64_t>::max(), + "vector type exceed capabilities of NumElements"); + + *num_elements = 1; + for (const T& dim : shape) { + if (dim < 0) { + // TensorFlow's shapes sometimes include -1 to represent an "unknown" + // size but TOCO isn't able to create arrays of unknown sizes and will + // crash in RequiredBufferSizeForShape(). + return port::Status(false, + "Tensor shape should not include negative values"); + } + if (static_cast<uint64_t>(dim) > + std::numeric_limits<U>::max() / *num_elements) { + *num_elements = 0; + return port::Status(false, "Tensor shape is too large"); + } + *num_elements *= dim; + } + return port::Status::OK(); +} + } // namespace toco #endif // TENSORFLOW_CONTRIB_LITE_TOCO_TOOLING_UTIL_H_ diff --git a/tensorflow/contrib/lite/toco/tooling_util_test.cc b/tensorflow/contrib/lite/toco/tooling_util_test.cc index 22955ce956..87fd30db2c 100644 --- a/tensorflow/contrib/lite/toco/tooling_util_test.cc +++ b/tensorflow/contrib/lite/toco/tooling_util_test.cc @@ -93,4 +93,85 @@ TEST_P(ShapeTest, Agrees) { INSTANTIATE_TEST_CASE_P(AgreeBroadcast, ShapeTest, ::testing::ValuesIn(CreateShapePairs())); +static const char kNegativeValuesMessage[] = + "Tensor shape should not include negative values"; +static const char kLargeTensorMessage[] = "Tensor shape is too large"; + +TEST(NumElementsTest, Int) { + int count; + port::Status status = port::Status::OK(); + + status = NumElements(std::vector<int>{1024, 1024, 2047}, &count); + EXPECT_TRUE(status.ok()); + EXPECT_EQ(count, 2146435072); + + status = NumElements(std::vector<int>{1, 2, -3}, &count); + EXPECT_EQ(status.error_message(), kNegativeValuesMessage); + + status = NumElements(std::vector<int>{1024, 1024, 2048}, &count); + EXPECT_EQ(status.error_message(), kLargeTensorMessage); +} + +TEST(NumElementsTest, Int32) { + int32_t count; + port::Status status = port::Status::OK(); + + status = NumElements(std::vector<int32_t>{1024, 1024, 2047}, &count); + EXPECT_TRUE(status.ok()); + EXPECT_EQ(count, 2146435072); + + status = NumElements(std::vector<int32_t>{1, 2, -3}, &count); + EXPECT_EQ(status.error_message(), kNegativeValuesMessage); + + status = NumElements(std::vector<int32_t>{1024, 1024, 2048}, &count); + EXPECT_EQ(status.error_message(), kLargeTensorMessage); +} + +TEST(NumElementsTest, Int64) { + int64_t count; + port::Status status = port::Status::OK(); + + status = NumElements(std::vector<int64_t>{16777216, 16777216, 32767}, &count); + EXPECT_TRUE(status.ok()); + EXPECT_EQ(count, 9223090561878065152LL); + + status = NumElements(std::vector<int64_t>{1, 2, -3}, &count); + EXPECT_EQ(status.error_message(), kNegativeValuesMessage); + + status = NumElements(std::vector<int64_t>{16777216, 16777216, 32768}, &count); + EXPECT_EQ(status.error_message(), kLargeTensorMessage); +} + +TEST(NumElementsTest, UnsignedInt32) { + uint32_t count; + port::Status status = port::Status::OK(); + + status = NumElements(std::vector<uint32_t>{1024, 2048, 2047}, &count); + EXPECT_TRUE(status.ok()); + EXPECT_EQ(count, 4292870144); + + status = NumElements(std::vector<int>{1, 2, -3}, &count); + EXPECT_EQ(status.error_message(), kNegativeValuesMessage); + + status = NumElements(std::vector<uint32_t>{1024, 2048, 2048}, &count); + EXPECT_EQ(status.error_message(), kLargeTensorMessage); +} + +TEST(NumElementsTest, UnsignedInt64) { + uint64_t count; + port::Status status = port::Status::OK(); + + status = + NumElements(std::vector<uint64_t>{16777216, 16777216, 65535}, &count); + EXPECT_TRUE(status.ok()); + EXPECT_EQ(count, 18446462598732840960ULL); + + status = NumElements(std::vector<int>{1, 2, -3}, &count); + EXPECT_EQ(status.error_message(), kNegativeValuesMessage); + + status = + NumElements(std::vector<uint64_t>{16777216, 16777216, 65536}, &count); + EXPECT_EQ(status.error_message(), kLargeTensorMessage); +} + } // namespace toco |