diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-06-01 21:20:58 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-01 21:23:32 -0700 |
commit | d077fb3bcc0483f6326714161bb4b3f51a078332 (patch) | |
tree | 2dd0627ce4885a09e7c9be26333e7235211f770e /tensorflow/contrib/lite/toco/import_tensorflow.cc | |
parent | dbdd276a05c417963b3f06f71e801540bde9ab7c (diff) |
Replace boilerplate code with function template.
PiperOrigin-RevId: 198963930
Diffstat (limited to 'tensorflow/contrib/lite/toco/import_tensorflow.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/import_tensorflow.cc | 561 |
1 file changed, 64 insertions, 497 deletions
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc index 94ec7c24d4..0a57015d29 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -656,81 +656,6 @@ void ConvertRandomUniform(const NodeDef& node, model->operators.emplace_back(std::move(op)); } -void ConvertReluOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Relu"); - CheckInputsCount(node, tf_import_flags, 1); - const auto& input_name = node.input(0); - auto* relu = new ReluOperator; - relu->inputs.push_back(input_name); - relu->outputs.push_back(node.name()); - model->operators.emplace_back(relu); -} - -void ConvertRelu6Operator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Relu6"); - CheckInputsCount(node, tf_import_flags, 1); - - const auto& input_name = node.input(0); - auto* op = new Relu6Operator; - op->inputs.push_back(input_name); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertLogOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Log"); - CheckInputsCount(node, tf_import_flags, 1); - - auto op = absl::make_unique<LogOperator>(); - op->inputs.push_back(node.input(0)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(std::move(op)); -} - -void ConvertLogisticOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Sigmoid"); - CheckInputsCount(node, tf_import_flags, 1); - - const auto& input_name = node.input(0); - auto* op = new LogisticOperator; - op->inputs.push_back(input_name); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertTanhOperator(const NodeDef& node, - const TensorFlowImportFlags& 
tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Tanh"); - CheckInputsCount(node, tf_import_flags, 1); - - const auto& input_name = node.input(0); - auto* op = new TanhOperator; - op->inputs.push_back(input_name); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertDivOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK(node.op() == "Div" || node.op() == "RealDiv"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new DivOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - void ConvertIdentityOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { @@ -787,38 +712,6 @@ void ConvertFakeQuantWithMinMaxVars( model->operators.emplace_back(op); } -void ConvertNegOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Neg"); - CheckInputsCount(node, tf_import_flags, 1); - auto* op = new NegOperator; - op->inputs.push_back(node.input(0)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertRsqrtOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Rsqrt"); - CheckInputsCount(node, tf_import_flags, 1); - auto* op = new TensorFlowRsqrtOperator; - op->inputs.push_back(node.input(0)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertSqrtOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Sqrt"); - CheckInputsCount(node, tf_import_flags, 1); - auto* op = new TensorFlowSqrtOperator; - op->inputs.push_back(node.input(0)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} void ConvertSqueezeOperator(const NodeDef& node, 
const TensorFlowImportFlags& tf_import_flags, @@ -840,66 +733,6 @@ void ConvertSqueezeOperator(const NodeDef& node, model->operators.emplace_back(op); } -void ConvertSquareOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Square"); - CheckInputsCount(node, tf_import_flags, 1); - auto* op = new TensorFlowSquareOperator; - op->inputs.push_back(node.input(0)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertAddOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Add"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new AddOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertAddNOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "AddN"); - const int num_inputs = GetInputsCount(node, tf_import_flags); - auto* op = new AddNOperator; - for (int i = 0; i < num_inputs; ++i) { - op->inputs.push_back(node.input(i)); - } - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertMulOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Mul"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new MulOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertSubOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Sub"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new SubOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - 
model->operators.emplace_back(op); -} - void ConvertSumOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { @@ -915,67 +748,6 @@ void ConvertSumOperator(const NodeDef& node, } } -void ConvertTileOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Tile"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new TensorFlowTileOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertSliceOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Slice"); - CheckInputsCount(node, tf_import_flags, 3); - auto* op = new SliceOperator; - for (int i = 0; i < 3; ++i) { - op->inputs.push_back(node.input(i)); - } - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertPadOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Pad"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new PadOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertPadV2Operator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "PadV2"); - CheckInputsCount(node, tf_import_flags, 3); - auto* op = new PadV2Operator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->inputs.push_back(node.input(2)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertShapeOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Shape"); - CheckInputsCount(node, tf_import_flags, 1); - auto* op = new 
TensorFlowShapeOperator; - op->inputs.push_back(node.input(0)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - void ConvertSplitOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { @@ -993,18 +765,6 @@ void ConvertSplitOperator(const NodeDef& node, model->operators.emplace_back(op); } -void ConvertMergeOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Merge"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new TensorFlowMergeOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - void ConvertSwitchOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { @@ -1034,18 +794,6 @@ void ConvertSoftmaxOperator(const NodeDef& node, model->operators.emplace_back(softmax); } -void ConvertLogSoftmaxOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "LogSoftmax"); - CheckInputsCount(node, tf_import_flags, 1); - const auto& input_name = node.input(0); - auto* log_softmax = new LogSoftmaxOperator; - log_softmax->inputs.push_back(input_name); - log_softmax->outputs.push_back(node.name()); - model->operators.emplace_back(log_softmax); -} - void ConvertLRNOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { @@ -1142,17 +890,6 @@ void ConvertAvgPoolOperator(const NodeDef& node, model->operators.emplace_back(avgpool); } -void ConvertReshapeOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Reshape"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new TensorFlowReshapeOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - 
model->operators.emplace_back(op); -} void ConvertBatchMatMulOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, @@ -1215,24 +952,12 @@ void ConvertConcatOperator(const NodeDef& node, model->operators.emplace_back(op); } -void ConvertAllOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "All"); - auto* op = new TensorFlowAllOperator; - const int num_inputs = GetInputsCount(node, tf_import_flags); - for (int i = 0; i < num_inputs; ++i) { - op->inputs.push_back(node.input(i)); - } - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertAssertOperator(const NodeDef& node, +// This method supports simple operators without additional attributes. +template <typename Op> +void ConvertSimpleOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { - CHECK_EQ(node.op(), "Assert"); - auto* op = new TensorFlowAssertOperator; + auto* op = new Op; const int num_inputs = GetInputsCount(node, tf_import_flags); for (int i = 0; i < num_inputs; ++i) { op->inputs.push_back(node.input(i)); @@ -1241,69 +966,13 @@ void ConvertAssertOperator(const NodeDef& node, model->operators.emplace_back(op); } -void ConvertLessOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Less"); - auto* op = new TensorFlowLessOperator; - const int num_inputs = GetInputsCount(node, tf_import_flags); - for (int i = 0; i < num_inputs; ++i) { - op->inputs.push_back(node.input(i)); - } - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertLessEqualOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "LessEqual"); - auto* op = new TensorFlowLessEqualOperator; - const int num_inputs = GetInputsCount(node, tf_import_flags); - for (int i = 0; i < num_inputs; ++i) { - 
op->inputs.push_back(node.input(i)); - } - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertSinOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Sin"); - auto* op = new SinOperator; - const int num_inputs = GetInputsCount(node, tf_import_flags); - for (int i = 0; i < num_inputs; ++i) { - op->inputs.push_back(node.input(i)); - } - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertGreaterOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Greater"); - auto* op = new TensorFlowGreaterOperator; - const int num_inputs = GetInputsCount(node, tf_import_flags); - for (int i = 0; i < num_inputs; ++i) { - op->inputs.push_back(node.input(i)); - } - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertGreaterEqualOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "GreaterEqual"); - auto* op = new TensorFlowGreaterEqualOperator; - const int num_inputs = GetInputsCount(node, tf_import_flags); - for (int i = 0; i < num_inputs; ++i) { - op->inputs.push_back(node.input(i)); - } - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); +// This method supports simple operators without additional attributes. 
+template <typename Op, unsigned int NumInputs> +void ConvertSimpleOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + Model* model) { + CheckInputsCount(node, tf_import_flags, NumInputs); + ConvertSimpleOperator<Op>(node, tf_import_flags, model); } void ConvertMaxOperator(const NodeDef& node, @@ -1336,29 +1005,6 @@ void ConvertMinOperator(const NodeDef& node, } } -void ConvertMaximumOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Maximum"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new TensorFlowMaximumOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertMinimumOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Minimum"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new TensorFlowMinimumOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} void ConvertUnsupportedOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, @@ -1387,19 +1033,6 @@ void ConvertUnsupportedOperator(const NodeDef& node, } } -void ConvertSelectOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CheckInputsCount(node, tf_import_flags, 3); - - auto* op = new SelectOperator; - for (const auto& input : node.input()) { - op->inputs.push_back(input); - } - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - void ConvertStridedSliceOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { @@ -1678,17 +1311,6 @@ void ConvertBatchToSpaceNDOperator(const NodeDef& node, model->operators.emplace_back(op); } -void ConvertExpOperator(const NodeDef& node, 
- const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Exp"); - CheckInputsCount(node, tf_import_flags, 1); - auto* op = new ExpOperator; - op->inputs.push_back(node.input(0)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - void ConvertMeanOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { @@ -1802,53 +1424,6 @@ void ConvertTransposeConvOperator(const NodeDef& node, model->operators.emplace_back(op); } -void ConvertExpandDimsOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "ExpandDims"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new ExpandDimsOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertFillOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Fill"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new FillOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertFloorDivOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "FloorDiv"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new FloorDivOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertFloorModOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "FloorMod"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new FloorModOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - 
op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} void ConvertRangeOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, @@ -1869,17 +1444,6 @@ void ConvertRangeOperator(const NodeDef& node, model->operators.emplace_back(op); } -void ConvertRankOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Rank"); - CheckInputsCount(node, tf_import_flags, 1); - auto* op = new RankOperator; - op->inputs.push_back(node.input(0)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - void ConvertStackOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { @@ -1900,17 +1464,6 @@ void ConvertStackOperator(const NodeDef& node, model->operators.emplace_back(op); } -void ConvertTransposeOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Transpose"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new TransposeOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} // Some TensorFlow ops only occur in graph cycles, representing // control flow. 
We do not currently support control flow, so we wouldn't @@ -2174,25 +1727,26 @@ Status ImportTensorFlowNode(const tensorflow::NodeDef& node, } else if (node.op() == "BiasAdd") { ConvertBiasAddOperator(node, tf_import_flags, model); } else if (node.op() == "Relu") { - ConvertReluOperator(node, tf_import_flags, model); + ConvertSimpleOperator<ReluOperator, 1>(node, tf_import_flags, model); } else if (node.op() == "Relu6") { - ConvertRelu6Operator(node, tf_import_flags, model); + ConvertSimpleOperator<Relu6Operator, 1>(node, tf_import_flags, model); } else if (node.op() == "Sigmoid") { - ConvertLogisticOperator(node, tf_import_flags, model); + ConvertSimpleOperator<LogisticOperator, 1>(node, tf_import_flags, model); } else if (node.op() == "Tanh") { - ConvertTanhOperator(node, tf_import_flags, model); + ConvertSimpleOperator<TanhOperator, 1>(node, tf_import_flags, model); } else if (node.op() == "MaxPool") { ConvertMaxPoolOperator(node, tf_import_flags, model); } else if (node.op() == "AvgPool") { ConvertAvgPoolOperator(node, tf_import_flags, model); } else if (node.op() == "Reshape") { - ConvertReshapeOperator(node, tf_import_flags, model); + ConvertSimpleOperator<TensorFlowReshapeOperator, 2>(node, tf_import_flags, + model); } else if (node.op() == "BatchMatMul") { ConvertBatchMatMulOperator(node, tf_import_flags, model); } else if (node.op() == "MatMul") { ConvertMatMulOperator(node, tf_import_flags, model); } else if (node.op() == "Div" || node.op() == "RealDiv") { - ConvertDivOperator(node, tf_import_flags, model); + ConvertSimpleOperator<DivOperator, 2>(node, tf_import_flags, model); } else if (node.op() == "Identity" || node.op() == "CheckNumerics" || node.op() == "StopGradient") { ConvertIdentityOperator(node, tf_import_flags, model); @@ -2201,27 +1755,31 @@ Status ImportTensorFlowNode(const tensorflow::NodeDef& node, } else if (node.op() == "FakeQuantWithMinMaxArgs") { ConvertFakeQuantWithMinMaxArgs(node, tf_import_flags, model); } else if (node.op() == 
"Neg") { - ConvertNegOperator(node, tf_import_flags, model); + ConvertSimpleOperator<NegOperator, 1>(node, tf_import_flags, model); } else if (node.op() == "Rsqrt") { - ConvertRsqrtOperator(node, tf_import_flags, model); + ConvertSimpleOperator<TensorFlowRsqrtOperator, 1>(node, tf_import_flags, + model); } else if (node.op() == "Squeeze") { ConvertSqueezeOperator(node, tf_import_flags, model); } else if (node.op() == "Sqrt") { - ConvertSqrtOperator(node, tf_import_flags, model); + ConvertSimpleOperator<TensorFlowSqrtOperator, 1>(node, tf_import_flags, + model); } else if (node.op() == "Square") { - ConvertSquareOperator(node, tf_import_flags, model); + ConvertSimpleOperator<TensorFlowSquareOperator, 1>(node, tf_import_flags, + model); } else if (node.op() == "Add") { - ConvertAddOperator(node, tf_import_flags, model); + ConvertSimpleOperator<AddOperator, 2>(node, tf_import_flags, model); } else if (node.op() == "AddN") { - ConvertAddNOperator(node, tf_import_flags, model); + ConvertSimpleOperator<AddNOperator>(node, tf_import_flags, model); } else if (node.op() == "Mul") { - ConvertMulOperator(node, tf_import_flags, model); + ConvertSimpleOperator<MulOperator, 2>(node, tf_import_flags, model); } else if (node.op() == "Sub") { - ConvertSubOperator(node, tf_import_flags, model); + ConvertSimpleOperator<SubOperator, 2>(node, tf_import_flags, model); } else if (node.op() == "Sum") { ConvertSumOperator(node, tf_import_flags, model); } else if (node.op() == "Tile") { - ConvertTileOperator(node, tf_import_flags, model); + ConvertSimpleOperator<TensorFlowTileOperator, 2>(node, tf_import_flags, + model); } else if (node.op() == "Concat" || node.op() == "ConcatV2") { ConvertConcatOperator(node, tf_import_flags, model); } else if (node.op() == "LRN") { @@ -2229,41 +1787,50 @@ Status ImportTensorFlowNode(const tensorflow::NodeDef& node, } else if (node.op() == "Softmax") { ConvertSoftmaxOperator(node, tf_import_flags, model); } else if (node.op() == "Log") { - 
ConvertLogOperator(node, tf_import_flags, model); + ConvertSimpleOperator<LogOperator, 1>(node, tf_import_flags, model); } else if (node.op() == "LogSoftmax") { - ConvertLogSoftmaxOperator(node, tf_import_flags, model); + ConvertSimpleOperator<LogSoftmaxOperator, 1>(node, tf_import_flags, model); } else if (node.op() == "All") { - ConvertAllOperator(node, tf_import_flags, model); + ConvertSimpleOperator<TensorFlowAllOperator>(node, tf_import_flags, model); } else if (node.op() == "Assert") { - ConvertAssertOperator(node, tf_import_flags, model); + ConvertSimpleOperator<TensorFlowAssertOperator>(node, tf_import_flags, + model); } else if (node.op() == "Less") { - ConvertLessOperator(node, tf_import_flags, model); + ConvertSimpleOperator<TensorFlowLessOperator, 2>(node, tf_import_flags, + model); } else if (node.op() == "LessEqual") { - ConvertLessEqualOperator(node, tf_import_flags, model); + ConvertSimpleOperator<TensorFlowLessEqualOperator, 2>(node, tf_import_flags, + model); } else if (node.op() == "Greater") { - ConvertGreaterOperator(node, tf_import_flags, model); + ConvertSimpleOperator<TensorFlowGreaterOperator, 2>(node, tf_import_flags, + model); } else if (node.op() == "GreaterEqual") { - ConvertGreaterEqualOperator(node, tf_import_flags, model); + ConvertSimpleOperator<TensorFlowGreaterEqualOperator, 2>( + node, tf_import_flags, model); } else if (node.op() == "Max") { ConvertMaxOperator(node, tf_import_flags, model); } else if (node.op() == "Min") { ConvertMinOperator(node, tf_import_flags, model); } else if (node.op() == "Maximum") { - ConvertMaximumOperator(node, tf_import_flags, model); + ConvertSimpleOperator<TensorFlowMaximumOperator, 2>(node, tf_import_flags, + model); } else if (node.op() == "Minimum") { - ConvertMinimumOperator(node, tf_import_flags, model); + ConvertSimpleOperator<TensorFlowMinimumOperator, 2>(node, tf_import_flags, + model); } else if (node.op() == "Merge") { - ConvertMergeOperator(node, tf_import_flags, model); + 
ConvertSimpleOperator<TensorFlowMergeOperator, 2>(node, tf_import_flags, + model); } else if (node.op() == "Pad") { - ConvertPadOperator(node, tf_import_flags, model); + ConvertSimpleOperator<PadOperator, 2>(node, tf_import_flags, model); } else if (node.op() == "PadV2") { - ConvertPadV2Operator(node, tf_import_flags, model); + ConvertSimpleOperator<PadV2Operator, 3>(node, tf_import_flags, model); } else if (node.op() == "StridedSlice") { ConvertStridedSliceOperator(node, tf_import_flags, model); } else if (node.op() == "Shape") { - ConvertShapeOperator(node, tf_import_flags, model); + ConvertSimpleOperator<TensorFlowShapeOperator, 1>(node, tf_import_flags, + model); } else if (node.op() == "Slice") { - ConvertSliceOperator(node, tf_import_flags, model); + ConvertSimpleOperator<SliceOperator, 3>(node, tf_import_flags, model); } else if (node.op() == "Split") { ConvertSplitOperator(node, tf_import_flags, model); } else if (node.op() == "Switch") { @@ -2300,25 +1867,25 @@ Status ImportTensorFlowNode(const tensorflow::NodeDef& node, } else if (node.op() == "NextIteration") { ConvertOperatorSpecialCasedAsRNNBackEdge(node, tf_import_flags, model); } else if (node.op() == "ExpandDims") { - ConvertExpandDimsOperator(node, tf_import_flags, model); + ConvertSimpleOperator<ExpandDimsOperator, 2>(node, tf_import_flags, model); } else if (node.op() == "Fill") { - ConvertFillOperator(node, tf_import_flags, model); + ConvertSimpleOperator<FillOperator, 2>(node, tf_import_flags, model); } else if (node.op() == "FloorDiv") { - ConvertFloorDivOperator(node, tf_import_flags, model); + ConvertSimpleOperator<FloorDivOperator, 2>(node, tf_import_flags, model); } else if (node.op() == "FloorMod") { - ConvertFloorModOperator(node, tf_import_flags, model); + ConvertSimpleOperator<FloorModOperator, 2>(node, tf_import_flags, model); } else if (node.op() == "Range") { ConvertRangeOperator(node, tf_import_flags, model); } else if (node.op() == "Rank") { - ConvertRankOperator(node, 
tf_import_flags, model); + ConvertSimpleOperator<RankOperator, 1>(node, tf_import_flags, model); } else if (node.op() == "Stack" || node.op() == "Pack") { ConvertStackOperator(node, tf_import_flags, model); } else if (node.op() == "Transpose") { - ConvertTransposeOperator(node, tf_import_flags, model); + ConvertSimpleOperator<TransposeOperator, 2>(node, tf_import_flags, model); } else if (node.op() == "ArgMax") { ConvertArgMaxOperator(node, tf_import_flags, model); } else if (node.op() == "Exp") { - ConvertExpOperator(node, tf_import_flags, model); + ConvertSimpleOperator<ExpOperator, 1>(node, tf_import_flags, model); } else if (node.op() == "TopK" || node.op() == "TopKV2") { ConvertTopKV2Operator(node, tf_import_flags, model); } else if (node.op() == "DynamicPartition") { @@ -2329,9 +1896,9 @@ Status ImportTensorFlowNode(const tensorflow::NodeDef& node, } else if (node.op() == "RandomUniform") { ConvertRandomUniform(node, tf_import_flags, model); } else if (node.op() == "Sin") { - ConvertSinOperator(node, tf_import_flags, model); + ConvertSimpleOperator<SinOperator, 1>(node, tf_import_flags, model); } else if (node.op() == "Select") { - ConvertSelectOperator(node, tf_import_flags, model); + ConvertSimpleOperator<SelectOperator, 3>(node, tf_import_flags, model); } else if (node.op() == "SparseToDense") { ConvertSparseToDenseOperator(node, tf_import_flags, model); } else { |