diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/import_tensorflow.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/import_tensorflow.cc | 174 |
1 files changed, 95 insertions, 79 deletions
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc index 55e39d963f..f36f720857 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -755,6 +755,9 @@ tensorflow::Status ConvertFakeQuantWithMinMaxArgs( op->outputs.push_back(node.name()); // tf.fake_quant_with_min_max_args num_bits defaults to 8. op->num_bits = HasAttr(node, "num_bits") ? GetIntAttr(node, "num_bits") : 8; + if (HasAttr(node, "narrow_range")) { + op->narrow_range = GetBoolAttr(node, "narrow_range"); + } model->operators.emplace_back(op); return tensorflow::Status::OK(); } @@ -774,6 +777,9 @@ tensorflow::Status ConvertFakeQuantWithMinMaxVars( } op->outputs.push_back(node.name()); op->num_bits = HasAttr(node, "num_bits") ? GetIntAttr(node, "num_bits") : 8; + if (HasAttr(node, "narrow_range")) { + op->narrow_range = GetBoolAttr(node, "narrow_range"); + } model->operators.emplace_back(op); return tensorflow::Status::OK(); } @@ -799,22 +805,6 @@ tensorflow::Status ConvertSqueezeOperator( return tensorflow::Status::OK(); } -tensorflow::Status ConvertSumOperator( - const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Sum"); - TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); - auto* op = new TensorFlowSumOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); - if (HasAttr(node, "keep_dims")) { - op->keep_dims = GetBoolAttr(node, "keep_dims"); - } - return tensorflow::Status::OK(); -} - tensorflow::Status ConvertSplitOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { @@ -984,18 +974,19 @@ tensorflow::Status ConvertMatMulOperator( Model* model) { TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); - // Transpose flags should be easy to support, but we don't have a - // GraphDef with them to test on at the moment. - CHECK_EQ(HasAttr(node, "transpose_a") && GetBoolAttr(node, "transpose_a"), - false); - CHECK_EQ(HasAttr(node, "transpose_b") && GetBoolAttr(node, "transpose_b"), - false); CHECK(!HasAttr(node, "adjoint_a") || (GetBoolAttr(node, "adjoint_a") == false)); CHECK(!HasAttr(node, "adjoint_b") || (GetBoolAttr(node, "adjoint_b") == false)); auto* matmul = new TensorFlowMatMulOperator; + if (HasAttr(node, "transpose_a")) { + matmul->transpose_a = GetBoolAttr(node, "transpose_a"); + } + if (HasAttr(node, "transpose_b")) { + matmul->transpose_b = GetBoolAttr(node, "transpose_b"); + } + matmul->inputs = {node.input(0), node.input(1)}; matmul->outputs = {node.name()}; model->operators.emplace_back(matmul); @@ -1051,41 +1042,14 @@ tensorflow::Status ConvertSimpleOperator( return ConvertSimpleOperator<Op>(node, tf_import_flags, model); } -tensorflow::Status ConvertMaxOperator( - const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Max"); - TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); - auto* op = new TensorFlowMaxOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); - if (HasAttr(node, "keep_dims")) { - op->keep_dims = GetBoolAttr(node, "keep_dims"); - } - return tensorflow::Status::OK(); -} - -tensorflow::Status ConvertMinOperator( - const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Min"); - TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); - auto* op = new TensorFlowMinOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); - if (HasAttr(node, "keep_dims")) { - op->keep_dims = GetBoolAttr(node, "keep_dims"); - } - return tensorflow::Status::OK(); -} - tensorflow::Status ConvertUnsupportedOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { + // Names of special attributes in TF graph that are used by Toco. + static constexpr char kAttrOutputQuantized[] = "_output_quantized"; + static constexpr char kAttrOutputTypes[] = "_output_types"; + static constexpr char kAttrOutputShapes[] = "_output_shapes"; + LOG(INFO) << "Converting unsupported operation: " << node.op(); auto* op = new TensorFlowUnsupportedOperator; const int num_inputs = GetInputsCount(node, tf_import_flags); @@ -1096,11 +1060,11 @@ tensorflow::Status ConvertUnsupportedOperator( op->tensorflow_op = node.op(); node.SerializeToString(&op->tensorflow_node_def); model->operators.emplace_back(op); - if (HasAttr(node, "_output_quantized")) { - op->quantized = GetBoolAttr(node, "_output_quantized"); + if (HasAttr(node, kAttrOutputQuantized)) { + op->quantized = GetBoolAttr(node, kAttrOutputQuantized); } - if (HasAttr(node, "_output_types")) { - const auto& output_types = GetListAttr(node, "_output_types"); + if (HasAttr(node, kAttrOutputTypes)) { + const auto& output_types = GetListAttr(node, kAttrOutputTypes); for (int i = 0; i < output_types.type_size(); ++i) { op->output_data_types.push_back(ConvertDataType(output_types.type(i))); } @@ -1108,6 +1072,19 @@ tensorflow::Status ConvertUnsupportedOperator( const auto& output_type = GetDataTypeAttr(node, "Tout"); op->output_data_types.push_back(ConvertDataType(output_type)); } + if (HasAttr(node, kAttrOutputShapes)) { + const auto& output_shapes = GetListAttr(node, kAttrOutputShapes); + Shape output_shape; + for (int i = 0; i < output_shapes.shape_size(); ++i) { + const auto status = + ImportShape(output_shapes.shape(i).dim(), /*input_flat_size=*/nullptr, + &output_shape); + if (!status.ok()) { + return status; + } + op->output_shapes.push_back(output_shape); + } + } return tensorflow::Status::OK(); } @@ -1222,17 +1199,27 @@ tensorflow::Status ConvertGatherOperator( auto* op = new GatherOperator; op->inputs.push_back(node.input(0)); op->inputs.push_back(node.input(1)); - // TODO(ahentz): we currently ignore the third tensor in GatherV2 but we - // should read it an pass it on to the TF Lite Interpreter. + if (node.input_size() >= 3) { + // GatherV2 form where we are provided an axis. It may be either a constant + // or runtime defined value, so we just wire up the array and let + // ResolveGatherAttributes take care of it later on. + const auto axis_data_type = GetDataTypeAttr(node, "Taxis"); + CHECK(axis_data_type == DT_INT32 || axis_data_type == DT_INT64); + op->inputs.push_back(node.input(2)); + } else { + // Gather form that assumes axis=0. + op->axis = {0}; + } op->outputs.push_back(node.name()); model->operators.emplace_back(op); return tensorflow::Status::OK(); } -tensorflow::Status ConvertArgMaxOperator( +template <typename Op, const char* op_name> +tensorflow::Status ConvertArgMinMaxOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { - CHECK_EQ(node.op(), "ArgMax"); + CHECK_EQ(node.op(), op_name); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); const auto axis_data_type = HasAttr(node, "Tidx") ? GetDataTypeAttr(node, "Tidx") : DT_INT32; @@ -1241,7 +1228,7 @@ tensorflow::Status ConvertArgMaxOperator( : DT_INT64; CHECK(axis_data_type == DT_INT64 || axis_data_type == DT_INT32); CHECK(output_type == DT_INT64 || output_type == DT_INT32); - auto* op = new ArgMaxOperator; + auto* op = new Op; op->output_data_type = ConvertDataType(output_type); op->inputs.push_back(node.input(0)); op->inputs.push_back(node.input(1)); @@ -1404,12 +1391,12 @@ tensorflow::Status ConvertBatchToSpaceNDOperator( return tensorflow::Status::OK(); } -tensorflow::Status ConvertMeanOperator( +template <typename T> +tensorflow::Status ConvertReduceOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { - CHECK_EQ(node.op(), "Mean"); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); - auto* op = new MeanOperator; + auto* op = new T; op->inputs.push_back(node.input(0)); op->inputs.push_back(node.input(1)); op->outputs.push_back(node.name()); @@ -1542,11 +1529,15 @@ tensorflow::Status ConvertRangeOperator( return tensorflow::Status::OK(); } -tensorflow::Status ConvertStackOperator( +// Note that it's easy to confuse/conflate "Stack" and "Pack" operators, but +// they aren't the same thing. tf.stack results in a "Pack" operator. "Stack" +// operators also exist, but involve manipulating the TF runtime stack, and are +// not directly related to tf.stack() usage. +tensorflow::Status ConvertPackOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { - CHECK((node.op() == "Stack") || (node.op() == "Pack")); - auto* op = new StackOperator; + CHECK_EQ(node.op(), "Pack"); + auto op = absl::make_unique<PackOperator>(); const int num_inputs = GetInputsCount(node, tf_import_flags); QCHECK_GE(num_inputs, 1) << node.op() @@ -1556,10 +1547,11 @@ tensorflow::Status ConvertStackOperator( for (int i = 0; i < num_inputs; ++i) { op->inputs.push_back(node.input(i)); } - // Both "Stack" and "Pack" have the "axis" attribute. + op->values_count = HasAttr(node, "N") ? GetIntAttr(node, "N") : num_inputs; op->axis = HasAttr(node, "axis") ? GetIntAttr(node, "axis") : 0; + op->dtype = ConvertDataType(toco::GetDataTypeAttr(node, "T")); op->outputs.push_back(node.name()); - model->operators.emplace_back(op); + model->operators.emplace_back(std::move(op)); return tensorflow::Status::OK(); } @@ -1605,6 +1597,24 @@ tensorflow::Status ConvertShapeOperator( return tensorflow::Status::OK(); } +tensorflow::Status ConvertAnyOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { + CHECK_EQ(node.op(), "Any"); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); + const auto idx_type = + HasAttr(node, "Tidx") ? GetDataTypeAttr(node, "Tidx") : DT_INT32; + CHECK(idx_type == DT_INT32); + auto op = absl::make_unique<AnyOperator>(); + op->inputs.push_back(node.input(0)); + op->inputs.push_back(node.input(1)); + op->outputs.push_back(node.name()); + op->keep_dims = + HasAttr(node, "keep_dims") ? GetBoolAttr(node, "keep_dims") : false; + model->operators.push_back(std::move(op)); + return tensorflow::Status::OK(); +} + void StripCaretFromArrayNames(Model* model) { for (auto& op : model->operators) { for (auto& input : op->inputs) { @@ -1832,12 +1842,17 @@ using ConverterType = tensorflow::Status (*)( Model* model); using ConverterMapType = std::unordered_map<std::string, ConverterType>; +constexpr char kArgMax[] = "ArgMax"; +constexpr char kArgMin[] = "ArgMin"; + ConverterMapType GetTensorFlowNodeConverterMap() { return std::unordered_map<std::string, ConverterType>({ {"Add", ConvertSimpleOperator<AddOperator, 2>}, {"AddN", ConvertSimpleOperator<AddNOperator>}, {"All", ConvertSimpleOperator<TensorFlowAllOperator>}, - {"ArgMax", ConvertArgMaxOperator}, + {"Any", ConvertAnyOperator}, + {"ArgMax", ConvertArgMinMaxOperator<ArgMaxOperator, kArgMax>}, + {"ArgMin", ConvertArgMinMaxOperator<ArgMinOperator, kArgMin>}, {"Assert", ConvertSimpleOperator<TensorFlowAssertOperator>}, {"AvgPool", ConvertAvgPoolOperator}, {"BatchMatMul", ConvertBatchMatMulOperator}, @@ -1878,28 +1893,30 @@ ConverterMapType GetTensorFlowNodeConverterMap() { {"Less", ConvertSimpleOperator<TensorFlowLessOperator, 2>}, {"LessEqual", ConvertSimpleOperator<TensorFlowLessEqualOperator, 2>}, {"Log", ConvertSimpleOperator<LogOperator, 1>}, - {"Log", ConvertSimpleOperator<LogOperator, 1>}, {"LogSoftmax", ConvertSimpleOperator<LogSoftmaxOperator, 1>}, + {"LogicalAnd", ConvertSimpleOperator<LogicalAndOperator, 2>}, + {"LogicalNot", ConvertSimpleOperator<LogicalNotOperator, 1>}, {"MatMul", ConvertMatMulOperator}, - {"Max", ConvertMaxOperator}, + {"Max", ConvertReduceOperator<TensorFlowMaxOperator>}, {"MaxPool", ConvertMaxPoolOperator}, {"Maximum", ConvertSimpleOperator<TensorFlowMaximumOperator, 2>}, - {"Mean", ConvertMeanOperator}, + {"Mean", ConvertReduceOperator<MeanOperator>}, {"Merge", ConvertSimpleOperator<TensorFlowMergeOperator, 2>}, - {"Min", ConvertMinOperator}, + {"Min", ConvertReduceOperator<TensorFlowMinOperator>}, {"Minimum", ConvertSimpleOperator<TensorFlowMinimumOperator, 2>}, {"Mul", ConvertSimpleOperator<MulOperator, 2>}, {"Neg", ConvertSimpleOperator<NegOperator, 1>}, {"NextIteration", ConvertOperatorSpecialCasedAsRNNBackEdge}, {"NoOp", ConvertNoOpOperator}, {"NotEqual", ConvertSimpleOperator<TensorFlowNotEqualOperator, 2>}, - {"Pack", ConvertStackOperator}, + {"Pack", ConvertPackOperator}, {"Pad", ConvertSimpleOperator<PadOperator, 2>}, {"PadV2", ConvertSimpleOperator<PadV2Operator, 3>}, {"ParallelDynamicStitch", ConvertDynamicStitchOperator}, {"Placeholder", ConvertPlaceholderOperator}, {"PlaceholderWithDefault", ConvertIdentityOperator}, {"Pow", ConvertSimpleOperator<PowOperator, 2>}, + {"Prod", ConvertReduceOperator<TensorFlowProdOperator>}, {"RandomUniform", ConvertRandomUniform}, {"Range", ConvertRangeOperator}, {"Rank", ConvertSimpleOperator<RankOperator, 1>}, @@ -1922,11 +1939,10 @@ ConverterMapType GetTensorFlowNodeConverterMap() { {"Sqrt", ConvertSimpleOperator<TensorFlowSqrtOperator, 1>}, {"Square", ConvertSimpleOperator<TensorFlowSquareOperator, 1>}, {"Squeeze", ConvertSqueezeOperator}, - {"Stack", ConvertStackOperator}, {"StopGradient", ConvertIdentityOperator}, {"StridedSlice", ConvertStridedSliceOperator}, {"Sub", ConvertSimpleOperator<SubOperator, 2>}, - {"Sum", ConvertSumOperator}, + {"Sum", ConvertReduceOperator<TensorFlowSumOperator>}, {"Svdf", ConvertSvdfOperator}, {"Switch", ConvertSwitchOperator}, {"Tanh", ConvertSimpleOperator<TanhOperator, 1>}, |