aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/import_tensorflow.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-01 21:20:58 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-01 21:23:32 -0700
commitd077fb3bcc0483f6326714161bb4b3f51a078332 (patch)
tree2dd0627ce4885a09e7c9be26333e7235211f770e /tensorflow/contrib/lite/toco/import_tensorflow.cc
parentdbdd276a05c417963b3f06f71e801540bde9ab7c (diff)
Replace boilerplate code with function template.
PiperOrigin-RevId: 198963930
Diffstat (limited to 'tensorflow/contrib/lite/toco/import_tensorflow.cc')
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow.cc561
1 files changed, 64 insertions, 497 deletions
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index 94ec7c24d4..0a57015d29 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -656,81 +656,6 @@ void ConvertRandomUniform(const NodeDef& node,
model->operators.emplace_back(std::move(op));
}
-void ConvertReluOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "Relu");
- CheckInputsCount(node, tf_import_flags, 1);
- const auto& input_name = node.input(0);
- auto* relu = new ReluOperator;
- relu->inputs.push_back(input_name);
- relu->outputs.push_back(node.name());
- model->operators.emplace_back(relu);
-}
-
-void ConvertRelu6Operator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "Relu6");
- CheckInputsCount(node, tf_import_flags, 1);
-
- const auto& input_name = node.input(0);
- auto* op = new Relu6Operator;
- op->inputs.push_back(input_name);
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
-
-void ConvertLogOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "Log");
- CheckInputsCount(node, tf_import_flags, 1);
-
- auto op = absl::make_unique<LogOperator>();
- op->inputs.push_back(node.input(0));
- op->outputs.push_back(node.name());
- model->operators.emplace_back(std::move(op));
-}
-
-void ConvertLogisticOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "Sigmoid");
- CheckInputsCount(node, tf_import_flags, 1);
-
- const auto& input_name = node.input(0);
- auto* op = new LogisticOperator;
- op->inputs.push_back(input_name);
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
-
-void ConvertTanhOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "Tanh");
- CheckInputsCount(node, tf_import_flags, 1);
-
- const auto& input_name = node.input(0);
- auto* op = new TanhOperator;
- op->inputs.push_back(input_name);
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
-
-void ConvertDivOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK(node.op() == "Div" || node.op() == "RealDiv");
- CheckInputsCount(node, tf_import_flags, 2);
- auto* op = new DivOperator;
- op->inputs.push_back(node.input(0));
- op->inputs.push_back(node.input(1));
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
-
void ConvertIdentityOperator(const NodeDef& node,
const TensorFlowImportFlags& tf_import_flags,
Model* model) {
@@ -787,38 +712,6 @@ void ConvertFakeQuantWithMinMaxVars(
model->operators.emplace_back(op);
}
-void ConvertNegOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "Neg");
- CheckInputsCount(node, tf_import_flags, 1);
- auto* op = new NegOperator;
- op->inputs.push_back(node.input(0));
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
-
-void ConvertRsqrtOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "Rsqrt");
- CheckInputsCount(node, tf_import_flags, 1);
- auto* op = new TensorFlowRsqrtOperator;
- op->inputs.push_back(node.input(0));
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
-
-void ConvertSqrtOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "Sqrt");
- CheckInputsCount(node, tf_import_flags, 1);
- auto* op = new TensorFlowSqrtOperator;
- op->inputs.push_back(node.input(0));
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
void ConvertSqueezeOperator(const NodeDef& node,
const TensorFlowImportFlags& tf_import_flags,
@@ -840,66 +733,6 @@ void ConvertSqueezeOperator(const NodeDef& node,
model->operators.emplace_back(op);
}
-void ConvertSquareOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "Square");
- CheckInputsCount(node, tf_import_flags, 1);
- auto* op = new TensorFlowSquareOperator;
- op->inputs.push_back(node.input(0));
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
-
-void ConvertAddOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "Add");
- CheckInputsCount(node, tf_import_flags, 2);
- auto* op = new AddOperator;
- op->inputs.push_back(node.input(0));
- op->inputs.push_back(node.input(1));
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
-
-void ConvertAddNOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "AddN");
- const int num_inputs = GetInputsCount(node, tf_import_flags);
- auto* op = new AddNOperator;
- for (int i = 0; i < num_inputs; ++i) {
- op->inputs.push_back(node.input(i));
- }
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
-
-void ConvertMulOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "Mul");
- CheckInputsCount(node, tf_import_flags, 2);
- auto* op = new MulOperator;
- op->inputs.push_back(node.input(0));
- op->inputs.push_back(node.input(1));
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
-
-void ConvertSubOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "Sub");
- CheckInputsCount(node, tf_import_flags, 2);
- auto* op = new SubOperator;
- op->inputs.push_back(node.input(0));
- op->inputs.push_back(node.input(1));
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
-
void ConvertSumOperator(const NodeDef& node,
const TensorFlowImportFlags& tf_import_flags,
Model* model) {
@@ -915,67 +748,6 @@ void ConvertSumOperator(const NodeDef& node,
}
}
-void ConvertTileOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "Tile");
- CheckInputsCount(node, tf_import_flags, 2);
- auto* op = new TensorFlowTileOperator;
- op->inputs.push_back(node.input(0));
- op->inputs.push_back(node.input(1));
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
-
-void ConvertSliceOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "Slice");
- CheckInputsCount(node, tf_import_flags, 3);
- auto* op = new SliceOperator;
- for (int i = 0; i < 3; ++i) {
- op->inputs.push_back(node.input(i));
- }
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
-
-void ConvertPadOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "Pad");
- CheckInputsCount(node, tf_import_flags, 2);
- auto* op = new PadOperator;
- op->inputs.push_back(node.input(0));
- op->inputs.push_back(node.input(1));
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
-
-void ConvertPadV2Operator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "PadV2");
- CheckInputsCount(node, tf_import_flags, 3);
- auto* op = new PadV2Operator;
- op->inputs.push_back(node.input(0));
- op->inputs.push_back(node.input(1));
- op->inputs.push_back(node.input(2));
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
-
-void ConvertShapeOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "Shape");
- CheckInputsCount(node, tf_import_flags, 1);
- auto* op = new TensorFlowShapeOperator;
- op->inputs.push_back(node.input(0));
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
-
void ConvertSplitOperator(const NodeDef& node,
const TensorFlowImportFlags& tf_import_flags,
Model* model) {
@@ -993,18 +765,6 @@ void ConvertSplitOperator(const NodeDef& node,
model->operators.emplace_back(op);
}
-void ConvertMergeOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "Merge");
- CheckInputsCount(node, tf_import_flags, 2);
- auto* op = new TensorFlowMergeOperator;
- op->inputs.push_back(node.input(0));
- op->inputs.push_back(node.input(1));
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
-
void ConvertSwitchOperator(const NodeDef& node,
const TensorFlowImportFlags& tf_import_flags,
Model* model) {
@@ -1034,18 +794,6 @@ void ConvertSoftmaxOperator(const NodeDef& node,
model->operators.emplace_back(softmax);
}
-void ConvertLogSoftmaxOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "LogSoftmax");
- CheckInputsCount(node, tf_import_flags, 1);
- const auto& input_name = node.input(0);
- auto* log_softmax = new LogSoftmaxOperator;
- log_softmax->inputs.push_back(input_name);
- log_softmax->outputs.push_back(node.name());
- model->operators.emplace_back(log_softmax);
-}
-
void ConvertLRNOperator(const NodeDef& node,
const TensorFlowImportFlags& tf_import_flags,
Model* model) {
@@ -1142,17 +890,6 @@ void ConvertAvgPoolOperator(const NodeDef& node,
model->operators.emplace_back(avgpool);
}
-void ConvertReshapeOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "Reshape");
- CheckInputsCount(node, tf_import_flags, 2);
- auto* op = new TensorFlowReshapeOperator;
- op->inputs.push_back(node.input(0));
- op->inputs.push_back(node.input(1));
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
void ConvertBatchMatMulOperator(const NodeDef& node,
const TensorFlowImportFlags& tf_import_flags,
@@ -1215,24 +952,12 @@ void ConvertConcatOperator(const NodeDef& node,
model->operators.emplace_back(op);
}
-void ConvertAllOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "All");
- auto* op = new TensorFlowAllOperator;
- const int num_inputs = GetInputsCount(node, tf_import_flags);
- for (int i = 0; i < num_inputs; ++i) {
- op->inputs.push_back(node.input(i));
- }
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
-
-void ConvertAssertOperator(const NodeDef& node,
+// This method supports simple operators without additional attributes.
+template <typename Op>
+void ConvertSimpleOperator(const NodeDef& node,
const TensorFlowImportFlags& tf_import_flags,
Model* model) {
- CHECK_EQ(node.op(), "Assert");
- auto* op = new TensorFlowAssertOperator;
+ auto* op = new Op;
const int num_inputs = GetInputsCount(node, tf_import_flags);
for (int i = 0; i < num_inputs; ++i) {
op->inputs.push_back(node.input(i));
@@ -1241,69 +966,13 @@ void ConvertAssertOperator(const NodeDef& node,
model->operators.emplace_back(op);
}
-void ConvertLessOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "Less");
- auto* op = new TensorFlowLessOperator;
- const int num_inputs = GetInputsCount(node, tf_import_flags);
- for (int i = 0; i < num_inputs; ++i) {
- op->inputs.push_back(node.input(i));
- }
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
-
-void ConvertLessEqualOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "LessEqual");
- auto* op = new TensorFlowLessEqualOperator;
- const int num_inputs = GetInputsCount(node, tf_import_flags);
- for (int i = 0; i < num_inputs; ++i) {
- op->inputs.push_back(node.input(i));
- }
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
-
-void ConvertSinOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "Sin");
- auto* op = new SinOperator;
- const int num_inputs = GetInputsCount(node, tf_import_flags);
- for (int i = 0; i < num_inputs; ++i) {
- op->inputs.push_back(node.input(i));
- }
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
-
-void ConvertGreaterOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "Greater");
- auto* op = new TensorFlowGreaterOperator;
- const int num_inputs = GetInputsCount(node, tf_import_flags);
- for (int i = 0; i < num_inputs; ++i) {
- op->inputs.push_back(node.input(i));
- }
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
-
-void ConvertGreaterEqualOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "GreaterEqual");
- auto* op = new TensorFlowGreaterEqualOperator;
- const int num_inputs = GetInputsCount(node, tf_import_flags);
- for (int i = 0; i < num_inputs; ++i) {
- op->inputs.push_back(node.input(i));
- }
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
+// This method supports simple operators without additional attributes.
+template <typename Op, unsigned int NumInputs>
+void ConvertSimpleOperator(const NodeDef& node,
+ const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
+ CheckInputsCount(node, tf_import_flags, NumInputs);
+ ConvertSimpleOperator<Op>(node, tf_import_flags, model);
}
void ConvertMaxOperator(const NodeDef& node,
@@ -1336,29 +1005,6 @@ void ConvertMinOperator(const NodeDef& node,
}
}
-void ConvertMaximumOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "Maximum");
- CheckInputsCount(node, tf_import_flags, 2);
- auto* op = new TensorFlowMaximumOperator;
- op->inputs.push_back(node.input(0));
- op->inputs.push_back(node.input(1));
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
-
-void ConvertMinimumOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "Minimum");
- CheckInputsCount(node, tf_import_flags, 2);
- auto* op = new TensorFlowMinimumOperator;
- op->inputs.push_back(node.input(0));
- op->inputs.push_back(node.input(1));
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
void ConvertUnsupportedOperator(const NodeDef& node,
const TensorFlowImportFlags& tf_import_flags,
@@ -1387,19 +1033,6 @@ void ConvertUnsupportedOperator(const NodeDef& node,
}
}
-void ConvertSelectOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CheckInputsCount(node, tf_import_flags, 3);
-
- auto* op = new SelectOperator;
- for (const auto& input : node.input()) {
- op->inputs.push_back(input);
- }
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
-
void ConvertStridedSliceOperator(const NodeDef& node,
const TensorFlowImportFlags& tf_import_flags,
Model* model) {
@@ -1678,17 +1311,6 @@ void ConvertBatchToSpaceNDOperator(const NodeDef& node,
model->operators.emplace_back(op);
}
-void ConvertExpOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "Exp");
- CheckInputsCount(node, tf_import_flags, 1);
- auto* op = new ExpOperator;
- op->inputs.push_back(node.input(0));
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
-
void ConvertMeanOperator(const NodeDef& node,
const TensorFlowImportFlags& tf_import_flags,
Model* model) {
@@ -1802,53 +1424,6 @@ void ConvertTransposeConvOperator(const NodeDef& node,
model->operators.emplace_back(op);
}
-void ConvertExpandDimsOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "ExpandDims");
- CheckInputsCount(node, tf_import_flags, 2);
- auto* op = new ExpandDimsOperator;
- op->inputs.push_back(node.input(0));
- op->inputs.push_back(node.input(1));
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
-
-void ConvertFillOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "Fill");
- CheckInputsCount(node, tf_import_flags, 2);
- auto* op = new FillOperator;
- op->inputs.push_back(node.input(0));
- op->inputs.push_back(node.input(1));
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
-
-void ConvertFloorDivOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "FloorDiv");
- CheckInputsCount(node, tf_import_flags, 2);
- auto* op = new FloorDivOperator;
- op->inputs.push_back(node.input(0));
- op->inputs.push_back(node.input(1));
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
-
-void ConvertFloorModOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "FloorMod");
- CheckInputsCount(node, tf_import_flags, 2);
- auto* op = new FloorModOperator;
- op->inputs.push_back(node.input(0));
- op->inputs.push_back(node.input(1));
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
void ConvertRangeOperator(const NodeDef& node,
const TensorFlowImportFlags& tf_import_flags,
@@ -1869,17 +1444,6 @@ void ConvertRangeOperator(const NodeDef& node,
model->operators.emplace_back(op);
}
-void ConvertRankOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "Rank");
- CheckInputsCount(node, tf_import_flags, 1);
- auto* op = new RankOperator;
- op->inputs.push_back(node.input(0));
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
-
void ConvertStackOperator(const NodeDef& node,
const TensorFlowImportFlags& tf_import_flags,
Model* model) {
@@ -1900,17 +1464,6 @@ void ConvertStackOperator(const NodeDef& node,
model->operators.emplace_back(op);
}
-void ConvertTransposeOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "Transpose");
- CheckInputsCount(node, tf_import_flags, 2);
- auto* op = new TransposeOperator;
- op->inputs.push_back(node.input(0));
- op->inputs.push_back(node.input(1));
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
// Some TensorFlow ops only occur in graph cycles, representing
// control flow. We do not currently support control flow, so we wouldn't
@@ -2174,25 +1727,26 @@ Status ImportTensorFlowNode(const tensorflow::NodeDef& node,
} else if (node.op() == "BiasAdd") {
ConvertBiasAddOperator(node, tf_import_flags, model);
} else if (node.op() == "Relu") {
- ConvertReluOperator(node, tf_import_flags, model);
+ ConvertSimpleOperator<ReluOperator, 1>(node, tf_import_flags, model);
} else if (node.op() == "Relu6") {
- ConvertRelu6Operator(node, tf_import_flags, model);
+ ConvertSimpleOperator<Relu6Operator, 1>(node, tf_import_flags, model);
} else if (node.op() == "Sigmoid") {
- ConvertLogisticOperator(node, tf_import_flags, model);
+ ConvertSimpleOperator<LogisticOperator, 1>(node, tf_import_flags, model);
} else if (node.op() == "Tanh") {
- ConvertTanhOperator(node, tf_import_flags, model);
+ ConvertSimpleOperator<TanhOperator, 1>(node, tf_import_flags, model);
} else if (node.op() == "MaxPool") {
ConvertMaxPoolOperator(node, tf_import_flags, model);
} else if (node.op() == "AvgPool") {
ConvertAvgPoolOperator(node, tf_import_flags, model);
} else if (node.op() == "Reshape") {
- ConvertReshapeOperator(node, tf_import_flags, model);
+ ConvertSimpleOperator<TensorFlowReshapeOperator, 2>(node, tf_import_flags,
+ model);
} else if (node.op() == "BatchMatMul") {
ConvertBatchMatMulOperator(node, tf_import_flags, model);
} else if (node.op() == "MatMul") {
ConvertMatMulOperator(node, tf_import_flags, model);
} else if (node.op() == "Div" || node.op() == "RealDiv") {
- ConvertDivOperator(node, tf_import_flags, model);
+ ConvertSimpleOperator<DivOperator, 2>(node, tf_import_flags, model);
} else if (node.op() == "Identity" || node.op() == "CheckNumerics" ||
node.op() == "StopGradient") {
ConvertIdentityOperator(node, tf_import_flags, model);
@@ -2201,27 +1755,31 @@ Status ImportTensorFlowNode(const tensorflow::NodeDef& node,
} else if (node.op() == "FakeQuantWithMinMaxArgs") {
ConvertFakeQuantWithMinMaxArgs(node, tf_import_flags, model);
} else if (node.op() == "Neg") {
- ConvertNegOperator(node, tf_import_flags, model);
+ ConvertSimpleOperator<NegOperator, 1>(node, tf_import_flags, model);
} else if (node.op() == "Rsqrt") {
- ConvertRsqrtOperator(node, tf_import_flags, model);
+ ConvertSimpleOperator<TensorFlowRsqrtOperator, 1>(node, tf_import_flags,
+ model);
} else if (node.op() == "Squeeze") {
ConvertSqueezeOperator(node, tf_import_flags, model);
} else if (node.op() == "Sqrt") {
- ConvertSqrtOperator(node, tf_import_flags, model);
+ ConvertSimpleOperator<TensorFlowSqrtOperator, 1>(node, tf_import_flags,
+ model);
} else if (node.op() == "Square") {
- ConvertSquareOperator(node, tf_import_flags, model);
+ ConvertSimpleOperator<TensorFlowSquareOperator, 1>(node, tf_import_flags,
+ model);
} else if (node.op() == "Add") {
- ConvertAddOperator(node, tf_import_flags, model);
+ ConvertSimpleOperator<AddOperator, 2>(node, tf_import_flags, model);
} else if (node.op() == "AddN") {
- ConvertAddNOperator(node, tf_import_flags, model);
+ ConvertSimpleOperator<AddNOperator>(node, tf_import_flags, model);
} else if (node.op() == "Mul") {
- ConvertMulOperator(node, tf_import_flags, model);
+ ConvertSimpleOperator<MulOperator, 2>(node, tf_import_flags, model);
} else if (node.op() == "Sub") {
- ConvertSubOperator(node, tf_import_flags, model);
+ ConvertSimpleOperator<SubOperator, 2>(node, tf_import_flags, model);
} else if (node.op() == "Sum") {
ConvertSumOperator(node, tf_import_flags, model);
} else if (node.op() == "Tile") {
- ConvertTileOperator(node, tf_import_flags, model);
+ ConvertSimpleOperator<TensorFlowTileOperator, 2>(node, tf_import_flags,
+ model);
} else if (node.op() == "Concat" || node.op() == "ConcatV2") {
ConvertConcatOperator(node, tf_import_flags, model);
} else if (node.op() == "LRN") {
@@ -2229,41 +1787,50 @@ Status ImportTensorFlowNode(const tensorflow::NodeDef& node,
} else if (node.op() == "Softmax") {
ConvertSoftmaxOperator(node, tf_import_flags, model);
} else if (node.op() == "Log") {
- ConvertLogOperator(node, tf_import_flags, model);
+ ConvertSimpleOperator<LogOperator, 1>(node, tf_import_flags, model);
} else if (node.op() == "LogSoftmax") {
- ConvertLogSoftmaxOperator(node, tf_import_flags, model);
+ ConvertSimpleOperator<LogSoftmaxOperator, 1>(node, tf_import_flags, model);
} else if (node.op() == "All") {
- ConvertAllOperator(node, tf_import_flags, model);
+ ConvertSimpleOperator<TensorFlowAllOperator>(node, tf_import_flags, model);
} else if (node.op() == "Assert") {
- ConvertAssertOperator(node, tf_import_flags, model);
+ ConvertSimpleOperator<TensorFlowAssertOperator>(node, tf_import_flags,
+ model);
} else if (node.op() == "Less") {
- ConvertLessOperator(node, tf_import_flags, model);
+ ConvertSimpleOperator<TensorFlowLessOperator, 2>(node, tf_import_flags,
+ model);
} else if (node.op() == "LessEqual") {
- ConvertLessEqualOperator(node, tf_import_flags, model);
+ ConvertSimpleOperator<TensorFlowLessEqualOperator, 2>(node, tf_import_flags,
+ model);
} else if (node.op() == "Greater") {
- ConvertGreaterOperator(node, tf_import_flags, model);
+ ConvertSimpleOperator<TensorFlowGreaterOperator, 2>(node, tf_import_flags,
+ model);
} else if (node.op() == "GreaterEqual") {
- ConvertGreaterEqualOperator(node, tf_import_flags, model);
+ ConvertSimpleOperator<TensorFlowGreaterEqualOperator, 2>(
+ node, tf_import_flags, model);
} else if (node.op() == "Max") {
ConvertMaxOperator(node, tf_import_flags, model);
} else if (node.op() == "Min") {
ConvertMinOperator(node, tf_import_flags, model);
} else if (node.op() == "Maximum") {
- ConvertMaximumOperator(node, tf_import_flags, model);
+ ConvertSimpleOperator<TensorFlowMaximumOperator, 2>(node, tf_import_flags,
+ model);
} else if (node.op() == "Minimum") {
- ConvertMinimumOperator(node, tf_import_flags, model);
+ ConvertSimpleOperator<TensorFlowMinimumOperator, 2>(node, tf_import_flags,
+ model);
} else if (node.op() == "Merge") {
- ConvertMergeOperator(node, tf_import_flags, model);
+ ConvertSimpleOperator<TensorFlowMergeOperator, 2>(node, tf_import_flags,
+ model);
} else if (node.op() == "Pad") {
- ConvertPadOperator(node, tf_import_flags, model);
+ ConvertSimpleOperator<PadOperator, 2>(node, tf_import_flags, model);
} else if (node.op() == "PadV2") {
- ConvertPadV2Operator(node, tf_import_flags, model);
+ ConvertSimpleOperator<PadV2Operator, 3>(node, tf_import_flags, model);
} else if (node.op() == "StridedSlice") {
ConvertStridedSliceOperator(node, tf_import_flags, model);
} else if (node.op() == "Shape") {
- ConvertShapeOperator(node, tf_import_flags, model);
+ ConvertSimpleOperator<TensorFlowShapeOperator, 1>(node, tf_import_flags,
+ model);
} else if (node.op() == "Slice") {
- ConvertSliceOperator(node, tf_import_flags, model);
+ ConvertSimpleOperator<SliceOperator, 3>(node, tf_import_flags, model);
} else if (node.op() == "Split") {
ConvertSplitOperator(node, tf_import_flags, model);
} else if (node.op() == "Switch") {
@@ -2300,25 +1867,25 @@ Status ImportTensorFlowNode(const tensorflow::NodeDef& node,
} else if (node.op() == "NextIteration") {
ConvertOperatorSpecialCasedAsRNNBackEdge(node, tf_import_flags, model);
} else if (node.op() == "ExpandDims") {
- ConvertExpandDimsOperator(node, tf_import_flags, model);
+ ConvertSimpleOperator<ExpandDimsOperator, 2>(node, tf_import_flags, model);
} else if (node.op() == "Fill") {
- ConvertFillOperator(node, tf_import_flags, model);
+ ConvertSimpleOperator<FillOperator, 2>(node, tf_import_flags, model);
} else if (node.op() == "FloorDiv") {
- ConvertFloorDivOperator(node, tf_import_flags, model);
+ ConvertSimpleOperator<FloorDivOperator, 2>(node, tf_import_flags, model);
} else if (node.op() == "FloorMod") {
- ConvertFloorModOperator(node, tf_import_flags, model);
+ ConvertSimpleOperator<FloorModOperator, 2>(node, tf_import_flags, model);
} else if (node.op() == "Range") {
ConvertRangeOperator(node, tf_import_flags, model);
} else if (node.op() == "Rank") {
- ConvertRankOperator(node, tf_import_flags, model);
+ ConvertSimpleOperator<RankOperator, 1>(node, tf_import_flags, model);
} else if (node.op() == "Stack" || node.op() == "Pack") {
ConvertStackOperator(node, tf_import_flags, model);
} else if (node.op() == "Transpose") {
- ConvertTransposeOperator(node, tf_import_flags, model);
+ ConvertSimpleOperator<TransposeOperator, 2>(node, tf_import_flags, model);
} else if (node.op() == "ArgMax") {
ConvertArgMaxOperator(node, tf_import_flags, model);
} else if (node.op() == "Exp") {
- ConvertExpOperator(node, tf_import_flags, model);
+ ConvertSimpleOperator<ExpOperator, 1>(node, tf_import_flags, model);
} else if (node.op() == "TopK" || node.op() == "TopKV2") {
ConvertTopKV2Operator(node, tf_import_flags, model);
} else if (node.op() == "DynamicPartition") {
@@ -2329,9 +1896,9 @@ Status ImportTensorFlowNode(const tensorflow::NodeDef& node,
} else if (node.op() == "RandomUniform") {
ConvertRandomUniform(node, tf_import_flags, model);
} else if (node.op() == "Sin") {
- ConvertSinOperator(node, tf_import_flags, model);
+ ConvertSimpleOperator<SinOperator, 1>(node, tf_import_flags, model);
} else if (node.op() == "Select") {
- ConvertSelectOperator(node, tf_import_flags, model);
+ ConvertSimpleOperator<SelectOperator, 3>(node, tf_import_flags, model);
} else if (node.op() == "SparseToDense") {
ConvertSparseToDenseOperator(node, tf_import_flags, model);
} else {