diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/import_tensorflow.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/import_tensorflow.cc | 122 |
1 files changed, 57 insertions, 65 deletions
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc index bc439a2feb..032c863945 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -755,6 +755,9 @@ tensorflow::Status ConvertFakeQuantWithMinMaxArgs( op->outputs.push_back(node.name()); // tf.fake_quant_with_min_max_args num_bits defaults to 8. op->num_bits = HasAttr(node, "num_bits") ? GetIntAttr(node, "num_bits") : 8; + if (HasAttr(node, "narrow_range")) { + op->narrow_range = GetBoolAttr(node, "narrow_range"); + } model->operators.emplace_back(op); return tensorflow::Status::OK(); } @@ -774,6 +777,9 @@ tensorflow::Status ConvertFakeQuantWithMinMaxVars( } op->outputs.push_back(node.name()); op->num_bits = HasAttr(node, "num_bits") ? GetIntAttr(node, "num_bits") : 8; + if (HasAttr(node, "narrow_range")) { + op->narrow_range = GetBoolAttr(node, "narrow_range"); + } model->operators.emplace_back(op); return tensorflow::Status::OK(); } @@ -799,22 +805,6 @@ tensorflow::Status ConvertSqueezeOperator( return tensorflow::Status::OK(); } -tensorflow::Status ConvertSumOperator( - const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Sum"); - TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); - auto* op = new TensorFlowSumOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); - if (HasAttr(node, "keep_dims")) { - op->keep_dims = GetBoolAttr(node, "keep_dims"); - } - return tensorflow::Status::OK(); -} - tensorflow::Status ConvertSplitOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { @@ -1052,38 +1042,6 @@ tensorflow::Status ConvertSimpleOperator( return ConvertSimpleOperator<Op>(node, tf_import_flags, model); } -tensorflow::Status ConvertMaxOperator( - const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Max"); - TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); - auto* op = new TensorFlowMaxOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); - if (HasAttr(node, "keep_dims")) { - op->keep_dims = GetBoolAttr(node, "keep_dims"); - } - return tensorflow::Status::OK(); -} - -tensorflow::Status ConvertMinOperator( - const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Min"); - TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); - auto* op = new TensorFlowMinOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); - if (HasAttr(node, "keep_dims")) { - op->keep_dims = GetBoolAttr(node, "keep_dims"); - } - return tensorflow::Status::OK(); -} - tensorflow::Status ConvertUnsupportedOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { @@ -1223,8 +1181,17 @@ tensorflow::Status ConvertGatherOperator( auto* op = new GatherOperator; op->inputs.push_back(node.input(0)); op->inputs.push_back(node.input(1)); - // TODO(ahentz): we currently ignore the third tensor in GatherV2 but we - // should read it an pass it on to the TF Lite Interpreter. + if (node.input_size() >= 3) { + // GatherV2 form where we are provided an axis. It may be either a constant + // or runtime defined value, so we just wire up the array and let + // ResolveGatherAttributes take care of it later on. + const auto axis_data_type = GetDataTypeAttr(node, "Taxis"); + CHECK(axis_data_type == DT_INT32 || axis_data_type == DT_INT64); + op->inputs.push_back(node.input(2)); + } else { + // Gather form that assumes axis=0. + op->axis = {0}; + } op->outputs.push_back(node.name()); model->operators.emplace_back(op); return tensorflow::Status::OK(); @@ -1406,12 +1373,12 @@ tensorflow::Status ConvertBatchToSpaceNDOperator( return tensorflow::Status::OK(); } -tensorflow::Status ConvertMeanOperator( +template <typename T> +tensorflow::Status ConvertReduceOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { - CHECK_EQ(node.op(), "Mean"); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); - auto* op = new MeanOperator; + auto* op = new T; op->inputs.push_back(node.input(0)); op->inputs.push_back(node.input(1)); op->outputs.push_back(node.name()); @@ -1544,11 +1511,15 @@ tensorflow::Status ConvertRangeOperator( return tensorflow::Status::OK(); } -tensorflow::Status ConvertStackOperator( +// Note that it's easy to confuse/conflate "Stack" and "Pack" operators, but +// they aren't the same thing. tf.stack results in a "Pack" operator. "Stack" +// operators also exist, but involve manipulating the TF runtime stack, and are +// not directly related to tf.stack() usage. +tensorflow::Status ConvertPackOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { - CHECK((node.op() == "Stack") || (node.op() == "Pack")); - auto* op = new StackOperator; + CHECK_EQ(node.op(), "Pack"); + auto op = absl::make_unique<PackOperator>(); const int num_inputs = GetInputsCount(node, tf_import_flags); QCHECK_GE(num_inputs, 1) << node.op() @@ -1558,10 +1529,11 @@ tensorflow::Status ConvertStackOperator( for (int i = 0; i < num_inputs; ++i) { op->inputs.push_back(node.input(i)); } - // Both "Stack" and "Pack" have the "axis" attribute. + op->values_count = HasAttr(node, "N") ? GetIntAttr(node, "N") : num_inputs; op->axis = HasAttr(node, "axis") ? GetIntAttr(node, "axis") : 0; + op->dtype = ConvertDataType(toco::GetDataTypeAttr(node, "T")); op->outputs.push_back(node.name()); - model->operators.emplace_back(op); + model->operators.emplace_back(std::move(op)); return tensorflow::Status::OK(); } @@ -1607,6 +1579,24 @@ tensorflow::Status ConvertShapeOperator( return tensorflow::Status::OK(); } +tensorflow::Status ConvertAnyOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { + CHECK_EQ(node.op(), "Any"); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); + const auto idx_type = + HasAttr(node, "Tidx") ? GetDataTypeAttr(node, "Tidx") : DT_INT32; + CHECK(idx_type == DT_INT32); + auto op = absl::make_unique<AnyOperator>(); + op->inputs.push_back(node.input(0)); + op->inputs.push_back(node.input(1)); + op->outputs.push_back(node.name()); + op->keep_dims = + HasAttr(node, "keep_dims") ? GetBoolAttr(node, "keep_dims") : false; + model->operators.push_back(std::move(op)); + return tensorflow::Status::OK(); +} + void StripCaretFromArrayNames(Model* model) { for (auto& op : model->operators) { for (auto& input : op->inputs) { @@ -1842,6 +1832,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() { {"Add", ConvertSimpleOperator<AddOperator, 2>}, {"AddN", ConvertSimpleOperator<AddNOperator>}, {"All", ConvertSimpleOperator<TensorFlowAllOperator>}, + {"Any", ConvertAnyOperator}, {"ArgMax", ConvertArgMinMaxOperator<ArgMaxOperator, kArgMax>}, {"ArgMin", ConvertArgMinMaxOperator<ArgMinOperator, kArgMin>}, {"Assert", ConvertSimpleOperator<TensorFlowAssertOperator>}, @@ -1884,28 +1875,30 @@ ConverterMapType GetTensorFlowNodeConverterMap() { {"Less", ConvertSimpleOperator<TensorFlowLessOperator, 2>}, {"LessEqual", ConvertSimpleOperator<TensorFlowLessEqualOperator, 2>}, {"Log", ConvertSimpleOperator<LogOperator, 1>}, - {"Log", ConvertSimpleOperator<LogOperator, 1>}, {"LogSoftmax", ConvertSimpleOperator<LogSoftmaxOperator, 1>}, + {"LogicalAnd", ConvertSimpleOperator<LogicalAndOperator, 2>}, + {"LogicalNot", ConvertSimpleOperator<LogicalNotOperator, 1>}, {"MatMul", ConvertMatMulOperator}, - {"Max", ConvertMaxOperator}, + {"Max", ConvertReduceOperator<TensorFlowMaxOperator>}, {"MaxPool", ConvertMaxPoolOperator}, {"Maximum", ConvertSimpleOperator<TensorFlowMaximumOperator, 2>}, - {"Mean", ConvertMeanOperator}, + {"Mean", ConvertReduceOperator<MeanOperator>}, {"Merge", ConvertSimpleOperator<TensorFlowMergeOperator, 2>}, - {"Min", ConvertMinOperator}, + {"Min", ConvertReduceOperator<TensorFlowMinOperator>}, {"Minimum", ConvertSimpleOperator<TensorFlowMinimumOperator, 2>}, {"Mul", ConvertSimpleOperator<MulOperator, 2>}, {"Neg", ConvertSimpleOperator<NegOperator, 1>}, {"NextIteration", ConvertOperatorSpecialCasedAsRNNBackEdge}, {"NoOp", ConvertNoOpOperator}, {"NotEqual", ConvertSimpleOperator<TensorFlowNotEqualOperator, 2>}, - {"Pack", ConvertStackOperator}, + {"Pack", ConvertPackOperator}, {"Pad", ConvertSimpleOperator<PadOperator, 2>}, {"PadV2", ConvertSimpleOperator<PadV2Operator, 3>}, {"ParallelDynamicStitch", ConvertDynamicStitchOperator}, {"Placeholder", ConvertPlaceholderOperator}, {"PlaceholderWithDefault", ConvertIdentityOperator}, {"Pow", ConvertSimpleOperator<PowOperator, 2>}, + {"Prod", ConvertReduceOperator<TensorFlowProdOperator>}, {"RandomUniform", ConvertRandomUniform}, {"Range", ConvertRangeOperator}, {"Rank", ConvertSimpleOperator<RankOperator, 1>}, @@ -1928,11 +1921,10 @@ ConverterMapType GetTensorFlowNodeConverterMap() { {"Sqrt", ConvertSimpleOperator<TensorFlowSqrtOperator, 1>}, {"Square", ConvertSimpleOperator<TensorFlowSquareOperator, 1>}, {"Squeeze", ConvertSqueezeOperator}, - {"Stack", ConvertStackOperator}, {"StopGradient", ConvertIdentityOperator}, {"StridedSlice", ConvertStridedSliceOperator}, {"Sub", ConvertSimpleOperator<SubOperator, 2>}, - {"Sum", ConvertSumOperator}, + {"Sum", ConvertReduceOperator<TensorFlowSumOperator>}, {"Svdf", ConvertSvdfOperator}, {"Switch", ConvertSwitchOperator}, {"Tanh", ConvertSimpleOperator<TanhOperator, 1>}, |