author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-06-17 05:31:55 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-06-17 05:34:37 -0700
commit    5cb77a7ac4741df72e1739c4fda3f552afc9c47c (patch)
tree      872a7673e5969140cd26bca56dfa5ccfed60cfab
parent    17d3bff7d575f8082142b0d96ee7a1719eabdb85 (diff)
Convert the ImportTensorFlow method from switch-based to table-based dispatch.
PiperOrigin-RevId: 200892708
-rw-r--r--  tensorflow/contrib/lite/toco/import_tensorflow.cc       | 632
-rw-r--r--  tensorflow/contrib/lite/toco/import_tensorflow_test.cc  |  13
2 files changed, 305 insertions(+), 340 deletions(-)
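
For readers skimming the patch: the commit replaces a long if/else-if chain on node.op() in ImportTensorFlowNode with a static lookup table from op name to converter function, and changes every converter to return tensorflow::Status. The sketch below is a minimal, self-contained toy of that dispatch pattern; the toco types (tensorflow::Status, NodeDef, Model) are stubbed out for illustration, and the real signatures appear in the diff itself.

// Self-contained sketch of the table-based dispatch pattern this commit
// adopts. The stand-in types below are illustrative, not the real classes.
#include <iostream>
#include <string>
#include <unordered_map>

struct Status { static Status OK() { return Status(); } };  // stand-in for tensorflow::Status
struct NodeDef {                                            // stand-in for tensorflow::NodeDef
  std::string op_name;
  const std::string& op() const { return op_name; }
};
struct Model {};

// All converters share one signature, so a function-pointer table can
// replace per-op string comparisons.
using ConverterType = Status (*)(const NodeDef&, Model*);
using ConverterMapType = std::unordered_map<std::string, ConverterType>;

Status ConvertAddOperator(const NodeDef&, Model*) { return Status::OK(); }
Status ConvertUnsupportedOperator(const NodeDef& node, Model*) {
  std::cout << "Converting unsupported operation: " << node.op() << "\n";
  return Status::OK();
}

// One map entry per supported op replaces one else-if branch per op.
ConverterMapType GetConverterMap() {
  return ConverterMapType({{"Add", ConvertAddOperator}});
}

// Dispatch is a single hash lookup; unknown ops take the generic fallback,
// mirroring the new ImportTensorFlowNode in the diff below.
Status ImportNode(const NodeDef& node, Model* model,
                  const ConverterMapType& converter_map) {
  auto converter = converter_map.find(node.op());
  if (converter == converter_map.end()) {
    return ConvertUnsupportedOperator(node, model);
  }
  return converter->second(node, model);
}

int main() {
  Model model;
  const ConverterMapType map = GetConverterMap();
  ImportNode(NodeDef{"Add"}, &model, map);   // dispatches through the table
  ImportNode(NodeDef{"Quux"}, &model, map);  // falls back to unsupported
  return 0;
}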
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index 120e858717..e33b430937 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -574,9 +574,9 @@ tensorflow::Status ConvertConvOperator(
return tensorflow::Status::OK();
}
-void ConvertDepthwiseConvOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertDepthwiseConvOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK_EQ(node.op(), "DepthwiseConv2dNative");
CheckInputsCount(node, tf_import_flags, 2);
@@ -625,11 +625,12 @@ void ConvertDepthwiseConvOperator(const NodeDef& node,
LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
}
model->operators.emplace_back(conv);
+ return tensorflow::Status::OK();
}
-void ConvertDepthToSpaceOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertDepthToSpaceOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK_EQ(node.op(), "DepthToSpace");
CheckInputsCount(node, tf_import_flags, 1);
@@ -640,11 +641,12 @@ void ConvertDepthToSpaceOperator(const NodeDef& node,
op->block_size = GetIntAttr(node, "block_size");
QCHECK_GE(op->block_size, 2);
model->operators.emplace_back(op);
+ return tensorflow::Status::OK();
}
-void ConvertSpaceToDepthOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertSpaceToDepthOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK_EQ(node.op(), "SpaceToDepth");
CheckInputsCount(node, tf_import_flags, 1);
@@ -662,11 +664,12 @@ void ConvertSpaceToDepthOperator(const NodeDef& node,
op->block_size = GetIntAttr(node, "block_size");
QCHECK_GE(op->block_size, 2);
model->operators.emplace_back(op);
+ return tensorflow::Status::OK();
}
-void ConvertBiasAddOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertBiasAddOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK_EQ(node.op(), "BiasAdd");
CheckInputsCount(node, tf_import_flags, 2);
@@ -678,11 +681,12 @@ void ConvertBiasAddOperator(const NodeDef& node,
biasadd->inputs.push_back(bias_name);
biasadd->outputs.push_back(node.name());
model->operators.emplace_back(biasadd);
+ return tensorflow::Status::OK();
}
-void ConvertRandomUniform(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertRandomUniform(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK_EQ(node.op(), "RandomUniform");
CheckInputsCount(node, tf_import_flags, 1);
@@ -695,11 +699,12 @@ void ConvertRandomUniform(const NodeDef& node,
op->seed2 = GetIntAttr(node, "seed2");
CHECK(model != nullptr);
model->operators.emplace_back(std::move(op));
+ return tensorflow::Status::OK();
}
-void ConvertIdentityOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertIdentityOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK(node.op() == "Identity" || node.op() == "CheckNumerics" ||
node.op() == "PlaceholderWithDefault" || node.op() == "StopGradient");
auto* op = new TensorFlowIdentityOperator;
@@ -716,9 +721,10 @@ void ConvertIdentityOperator(const NodeDef& node,
op->inputs.push_back(input_name);
op->outputs.push_back(node.name());
model->operators.emplace_back(op);
+ return tensorflow::Status::OK();
}
-void ConvertFakeQuantWithMinMaxArgs(
+tensorflow::Status ConvertFakeQuantWithMinMaxArgs(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
CHECK_EQ(node.op(), "FakeQuantWithMinMaxArgs");
@@ -733,9 +739,10 @@ void ConvertFakeQuantWithMinMaxArgs(
// tf.fake_quant_with_min_max_args num_bits defaults to 8.
op->num_bits = HasAttr(node, "num_bits") ? GetIntAttr(node, "num_bits") : 8;
model->operators.emplace_back(op);
+ return tensorflow::Status::OK();
}
-void ConvertFakeQuantWithMinMaxVars(
+tensorflow::Status ConvertFakeQuantWithMinMaxVars(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
CHECK_EQ(node.op(), "FakeQuantWithMinMaxVars");
@@ -751,12 +758,12 @@ void ConvertFakeQuantWithMinMaxVars(
op->outputs.push_back(node.name());
op->num_bits = HasAttr(node, "num_bits") ? GetIntAttr(node, "num_bits") : 8;
model->operators.emplace_back(op);
+ return tensorflow::Status::OK();
}
-
-void ConvertSqueezeOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertSqueezeOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK_EQ(node.op(), "Squeeze");
CheckInputsCount(node, tf_import_flags, 1);
auto* op = new SqueezeOperator;
@@ -772,11 +779,12 @@ void ConvertSqueezeOperator(const NodeDef& node,
}
model->operators.emplace_back(op);
+ return tensorflow::Status::OK();
}
-void ConvertSumOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertSumOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK_EQ(node.op(), "Sum");
CheckInputsCount(node, tf_import_flags, 2);
auto* op = new TensorFlowSumOperator;
@@ -787,11 +795,12 @@ void ConvertSumOperator(const NodeDef& node,
if (HasAttr(node, "keep_dims")) {
op->keep_dims = GetBoolAttr(node, "keep_dims");
}
+ return tensorflow::Status::OK();
}
-void ConvertSplitOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertSplitOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK_EQ(node.op(), "Split");
CheckInputsCount(node, tf_import_flags, 2);
auto* op = new TensorFlowSplitOperator;
@@ -804,11 +813,12 @@ void ConvertSplitOperator(const NodeDef& node,
}
op->num_split = num_split;
model->operators.emplace_back(op);
+ return tensorflow::Status::OK();
}
-void ConvertSwitchOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertSwitchOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK_EQ(node.op(), "Switch");
CheckInputsCount(node, tf_import_flags, 2);
auto* op = new TensorFlowSwitchOperator;
@@ -818,11 +828,12 @@ void ConvertSwitchOperator(const NodeDef& node,
// Switch operators have two outputs: "name" and "name:1".
op->outputs.push_back(node.name() + ":1");
model->operators.emplace_back(op);
+ return tensorflow::Status::OK();
}
-void ConvertSoftmaxOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertSoftmaxOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK_EQ(node.op(), "Softmax");
CheckInputsCount(node, tf_import_flags, 1);
const auto& input_name = node.input(0);
@@ -833,11 +844,12 @@ void ConvertSoftmaxOperator(const NodeDef& node,
CHECK(!node.attr().count("beta")); // Stab in the dark, just in case.
softmax->beta = 1.f;
model->operators.emplace_back(softmax);
+ return tensorflow::Status::OK();
}
-void ConvertLRNOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertLRNOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK_EQ(node.op(), "LRN");
CheckInputsCount(node, tf_import_flags, 1);
const auto& input_name = node.input(0);
@@ -849,11 +861,12 @@ void ConvertLRNOperator(const NodeDef& node,
lrn->alpha = GetFloatAttr(node, "alpha");
lrn->beta = GetFloatAttr(node, "beta");
model->operators.emplace_back(lrn);
+ return tensorflow::Status::OK();
}
-void ConvertMaxPoolOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertMaxPoolOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK_EQ(node.op(), "MaxPool");
CheckInputsCount(node, tf_import_flags, 1);
const auto& input_name = node.input(0);
@@ -891,11 +904,12 @@ void ConvertMaxPoolOperator(const NodeDef& node,
LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
}
model->operators.emplace_back(maxpool);
+ return tensorflow::Status::OK();
}
-void ConvertAvgPoolOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertAvgPoolOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK_EQ(node.op(), "AvgPool");
CheckInputsCount(node, tf_import_flags, 1);
const auto& input_name = node.input(0);
@@ -929,12 +943,12 @@ void ConvertAvgPoolOperator(const NodeDef& node,
LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
}
model->operators.emplace_back(avgpool);
+ return tensorflow::Status::OK();
}
-
-void ConvertBatchMatMulOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertBatchMatMulOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CheckInputsCount(node, tf_import_flags, 2);
// https://www.tensorflow.org/versions/r0.12/api_docs/python/math_ops/matrix_math_functions
@@ -945,11 +959,12 @@ void ConvertBatchMatMulOperator(const NodeDef& node,
batch_matmul->inputs = {node.input(0), node.input(1)};
batch_matmul->outputs = {node.name()};
model->operators.emplace_back(batch_matmul);
+ return tensorflow::Status::OK();
}
-void ConvertMatMulOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertMatMulOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CheckInputsCount(node, tf_import_flags, 2);
// Transpose flags should be easy to support, but we don't have a
@@ -967,11 +982,12 @@ void ConvertMatMulOperator(const NodeDef& node,
matmul->inputs = {node.input(0), node.input(1)};
matmul->outputs = {node.name()};
model->operators.emplace_back(matmul);
+ return tensorflow::Status::OK();
}
-void ConvertConcatOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertConcatOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
Operator* op = nullptr;
if (node.op() == "Concat") {
op = new TensorFlowConcatOperator;
@@ -991,13 +1007,14 @@ void ConvertConcatOperator(const NodeDef& node,
}
op->outputs.push_back(node.name());
model->operators.emplace_back(op);
+ return tensorflow::Status::OK();
}
// This method supports simple operators without additional attributes.
template <typename Op>
-void ConvertSimpleOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertSimpleOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
auto* op = new Op;
const int num_inputs = GetInputsCount(node, tf_import_flags);
for (int i = 0; i < num_inputs; ++i) {
@@ -1005,20 +1022,21 @@ void ConvertSimpleOperator(const NodeDef& node,
}
op->outputs.push_back(node.name());
model->operators.emplace_back(op);
+ return tensorflow::Status::OK();
}
// This method supports simple operators without additional attributes.
template <typename Op, unsigned int NumInputs>
-void ConvertSimpleOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertSimpleOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CheckInputsCount(node, tf_import_flags, NumInputs);
- ConvertSimpleOperator<Op>(node, tf_import_flags, model);
+ return ConvertSimpleOperator<Op>(node, tf_import_flags, model);
}
-void ConvertMaxOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertMaxOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK_EQ(node.op(), "Max");
CheckInputsCount(node, tf_import_flags, 2);
auto* op = new TensorFlowMaxOperator;
@@ -1029,11 +1047,12 @@ void ConvertMaxOperator(const NodeDef& node,
if (HasAttr(node, "keep_dims")) {
op->keep_dims = GetBoolAttr(node, "keep_dims");
}
+ return tensorflow::Status::OK();
}
-void ConvertMinOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertMinOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK_EQ(node.op(), "Min");
CheckInputsCount(node, tf_import_flags, 2);
auto* op = new TensorFlowMinOperator;
@@ -1044,12 +1063,12 @@ void ConvertMinOperator(const NodeDef& node,
if (HasAttr(node, "keep_dims")) {
op->keep_dims = GetBoolAttr(node, "keep_dims");
}
+ return tensorflow::Status::OK();
}
-
-void ConvertUnsupportedOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertUnsupportedOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
LOG(INFO) << "Converting unsupported operation: " << node.op();
auto* op = new TensorFlowUnsupportedOperator;
const int num_inputs = GetInputsCount(node, tf_import_flags);
@@ -1072,11 +1091,12 @@ void ConvertUnsupportedOperator(const NodeDef& node,
const auto& output_type = GetDataTypeAttr(node, "Tout");
op->output_data_types.push_back(ConvertDataType(output_type));
}
+ return tensorflow::Status::OK();
}
-void ConvertStridedSliceOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertStridedSliceOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK_EQ(node.op(), "StridedSlice");
// TODO(soroosh): The 4th input (strides) should be optional, to be
// consistent with TF.
@@ -1100,11 +1120,12 @@ void ConvertStridedSliceOperator(const NodeDef& node,
: 0;
model->operators.emplace_back(op);
+ return tensorflow::Status::OK();
}
-void ConvertPlaceholderOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertPlaceholderOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK(node.op() == "Placeholder" || node.op() == "LegacyFedInput");
if (node.op() == "Placeholder") {
CheckInputsCount(node, tf_import_flags, 0);
@@ -1132,15 +1153,18 @@ void ConvertPlaceholderOperator(const NodeDef& node,
}
}
}
+ return tensorflow::Status::OK();
}
-void ConvertNoOpOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {}
+tensorflow::Status ConvertNoOpOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
+ return tensorflow::Status::OK();
+}
-void ConvertCastOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertCastOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK_EQ(node.op(), "Cast");
CheckInputsCount(node, tf_import_flags, 1);
const auto tf_src_dtype = GetDataTypeAttr(node, "SrcT");
@@ -1151,11 +1175,12 @@ void ConvertCastOperator(const NodeDef& node,
op->inputs.push_back(node.input(0));
op->outputs.push_back(node.name());
model->operators.emplace_back(op);
+ return tensorflow::Status::OK();
}
-void ConvertFloorOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertFloorOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK_EQ(node.op(), "Floor");
CheckInputsCount(node, tf_import_flags, 1);
const auto data_type = GetDataTypeAttr(node, "T");
@@ -1164,11 +1189,12 @@ void ConvertFloorOperator(const NodeDef& node,
op->inputs.push_back(node.input(0));
op->outputs.push_back(node.name());
model->operators.emplace_back(op);
+ return tensorflow::Status::OK();
}
-void ConvertGatherOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertGatherOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK(node.op() == "Gather" || node.op() == "GatherV2");
if (node.op() == "Gather") CheckInputsCount(node, tf_import_flags, 2);
if (node.op() == "GatherV2") CheckInputsCount(node, tf_import_flags, 3);
@@ -1181,11 +1207,12 @@ void ConvertGatherOperator(const NodeDef& node,
// should read it and pass it on to the TF Lite Interpreter.
op->outputs.push_back(node.name());
model->operators.emplace_back(op);
+ return tensorflow::Status::OK();
}
-void ConvertArgMaxOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertArgMaxOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK_EQ(node.op(), "ArgMax");
CheckInputsCount(node, tf_import_flags, 2);
const auto axis_data_type =
@@ -1201,11 +1228,12 @@ void ConvertArgMaxOperator(const NodeDef& node,
op->inputs.push_back(node.input(1));
op->outputs.push_back(node.name());
model->operators.emplace_back(op);
+ return tensorflow::Status::OK();
}
-void ConvertResizeBilinearOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertResizeBilinearOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK_EQ(node.op(), "ResizeBilinear");
CheckInputsCount(node, tf_import_flags, 2);
auto* op = new ResizeBilinearOperator;
@@ -1219,9 +1247,10 @@ void ConvertResizeBilinearOperator(const NodeDef& node,
op->inputs.push_back(node.input(1));
op->outputs.push_back(node.name());
model->operators.emplace_back(op);
+ return tensorflow::Status::OK();
}
-void ConvertBatchNormWithGlobalNormalizationOperator(
+tensorflow::Status ConvertBatchNormWithGlobalNormalizationOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
CHECK_EQ(node.op(), "BatchNormWithGlobalNormalization");
@@ -1268,11 +1297,12 @@ void ConvertBatchNormWithGlobalNormalizationOperator(
op->outputs.push_back(node.name());
model->operators.emplace_back(op);
+ return tensorflow::Status::OK();
}
-void ConvertFusedBatchNormOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertFusedBatchNormOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK_EQ(node.op(), "FusedBatchNorm");
CheckInputsCount(node, tf_import_flags, 5);
@@ -1320,11 +1350,12 @@ void ConvertFusedBatchNormOperator(const NodeDef& node,
op->outputs.push_back(node.name());
model->operators.emplace_back(op);
+ return tensorflow::Status::OK();
}
-void ConvertSpaceToBatchNDOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertSpaceToBatchNDOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK_EQ(node.op(), "SpaceToBatchND");
CheckInputsCount(node, tf_import_flags, 3);
CHECK_EQ(GetDataTypeAttr(node, "Tblock_shape"), DT_INT32);
@@ -1335,11 +1366,12 @@ void ConvertSpaceToBatchNDOperator(const NodeDef& node,
op->inputs.push_back(node.input(2));
op->outputs.push_back(node.name());
model->operators.emplace_back(op);
+ return tensorflow::Status::OK();
}
-void ConvertBatchToSpaceNDOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertBatchToSpaceNDOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK_EQ(node.op(), "BatchToSpaceND");
CheckInputsCount(node, tf_import_flags, 3);
CHECK_EQ(GetDataTypeAttr(node, "Tblock_shape"), DT_INT32);
@@ -1350,11 +1382,12 @@ void ConvertBatchToSpaceNDOperator(const NodeDef& node,
op->inputs.push_back(node.input(2));
op->outputs.push_back(node.name());
model->operators.emplace_back(op);
+ return tensorflow::Status::OK();
}
-void ConvertMeanOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertMeanOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK_EQ(node.op(), "Mean");
CheckInputsCount(node, tf_import_flags, 2);
auto* op = new MeanOperator;
@@ -1367,11 +1400,12 @@ void ConvertMeanOperator(const NodeDef& node,
} else if (HasAttr(node, "keep_dims")) {
op->keep_dims = GetBoolAttr(node, "keep_dims");
}
+ return tensorflow::Status::OK();
}
-void ConvertSvdfOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertSvdfOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK_EQ(node.op(), "Svdf");
const int input_size = GetInputsCount(node, tf_import_flags);
QCHECK(input_size == 3 || input_size == 4)
@@ -1394,12 +1428,13 @@ void ConvertSvdfOperator(const NodeDef& node,
}
op->rank = node.attr().at("Rank").i();
model->operators.emplace_back(op);
+ return tensorflow::Status::OK();
}
// This is just bare bones support to get the shapes to propagate.
-void ConvertTransposeConvOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertTransposeConvOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK_EQ(node.op(), "Conv2DBackpropInput");
CheckInputsCount(node, tf_import_flags, 3);
auto* op = new TransposeConvOperator;
@@ -1465,12 +1500,12 @@ void ConvertTransposeConvOperator(const NodeDef& node,
"Conv2DBackpropInput nodes.";
}
model->operators.emplace_back(op);
+ return tensorflow::Status::OK();
}
-
-void ConvertRangeOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertRangeOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK_EQ(node.op(), "Range");
CheckInputsCount(node, tf_import_flags, 3);
auto* op = new RangeOperator;
@@ -1485,11 +1520,12 @@ void ConvertRangeOperator(const NodeDef& node,
op->inputs.push_back(node.input(2));
op->outputs.push_back(node.name());
model->operators.emplace_back(op);
+ return tensorflow::Status::OK();
}
-void ConvertStackOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertStackOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK((node.op() == "Stack") || (node.op() == "Pack"));
auto* op = new StackOperator;
const int num_inputs = GetInputsCount(node, tf_import_flags);
@@ -1505,9 +1541,9 @@ void ConvertStackOperator(const NodeDef& node,
op->axis = HasAttr(node, "axis") ? GetIntAttr(node, "axis") : 0;
op->outputs.push_back(node.name());
model->operators.emplace_back(op);
+ return tensorflow::Status::OK();
}
-
// Some TensorFlow ops only occur in graph cycles, representing
// control flow. We do not currently support control flow, so we wouldn't
// be able to fully support such graphs, including performing inference,
@@ -1518,7 +1554,7 @@ void ConvertStackOperator(const NodeDef& node,
// such ops as RNN back-edges, which is technically incorrect (does not
// allow representing the op's semantics) but good enough to get a
// graph visualization.
-void ConvertOperatorSpecialCasedAsRNNBackEdge(
+tensorflow::Status ConvertOperatorSpecialCasedAsRNNBackEdge(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
// At the moment, the only type of operator special-cased in this way is
@@ -1531,6 +1567,7 @@ void ConvertOperatorSpecialCasedAsRNNBackEdge(
rnn_state->set_discardable(true);
rnn_state->set_state_array(node.name());
rnn_state->set_back_edge_source_array(node.input(0));
+ return tensorflow::Status::OK();
}
void StripCaretFromArrayNames(Model* model) {
@@ -1673,9 +1710,9 @@ bool InlineAllFunctions(GraphDef* graphdef) {
return graph_modified;
}
-void ConvertTopKV2Operator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertTopKV2Operator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK((node.op() == "TopK") || (node.op() == "TopKV2"));
auto op = absl::make_unique<TopKV2Operator>();
op->inputs.push_back(node.input(0));
@@ -1692,9 +1729,10 @@ void ConvertTopKV2Operator(const NodeDef& node,
op->outputs.push_back(node.name());
op->outputs.push_back(node.name() + ":1");
model->operators.emplace_back(op.release());
+ return tensorflow::Status::OK();
}
-void ConvertDynamicPartitionOperator(
+tensorflow::Status ConvertDynamicPartitionOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
auto op = absl::make_unique<DynamicPartitionOperator>();
@@ -1709,11 +1747,12 @@ void ConvertDynamicPartitionOperator(
op->outputs.push_back(node.name() + ":" + std::to_string(i));
}
model->operators.emplace_back(op.release());
+ return tensorflow::Status::OK();
}
-void ConvertDynamicStitchOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertDynamicStitchOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
// The parallel and non-parallel variants are the same besides whether they
// have a parallel loop; there are no behavioral differences.
CHECK(node.op() == "DynamicStitch" || node.op() == "ParallelDynamicStitch");
@@ -1727,11 +1766,12 @@ void ConvertDynamicStitchOperator(const NodeDef& node,
}
op->outputs.push_back(node.name());
model->operators.emplace_back(op.release());
+ return tensorflow::Status::OK();
}
-void ConvertSparseToDenseOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertSparseToDenseOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK_EQ(node.op(), "SparseToDense");
CheckInputsCount(node, tf_import_flags, 4);
@@ -1745,217 +1785,132 @@ void ConvertSparseToDenseOperator(const NodeDef& node,
? GetBoolAttr(node, "validate_indices")
: true;
model->operators.emplace_back(op);
+ return tensorflow::Status::OK();
}
} // namespace
namespace internal {
+
+using ConverterType = tensorflow::Status (*)(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model);
+using ConverterMapType = std::unordered_map<std::string, ConverterType>;
+
+ConverterMapType GetTensorFlowNodeConverterMap() {
+ return std::unordered_map<std::string, ConverterType>({
+ {"Add", ConvertSimpleOperator<AddOperator, 2>},
+ {"AddN", ConvertSimpleOperator<AddNOperator>},
+ {"All", ConvertSimpleOperator<TensorFlowAllOperator>},
+ {"ArgMax", ConvertArgMaxOperator},
+ {"Assert", ConvertSimpleOperator<TensorFlowAssertOperator>},
+ {"AvgPool", ConvertAvgPoolOperator},
+ {"BatchMatMul", ConvertBatchMatMulOperator},
+ {"BatchNormWithGlobalNormalization",
+ ConvertBatchNormWithGlobalNormalizationOperator},
+ {"BatchToSpaceND", ConvertBatchToSpaceNDOperator},
+ {"BiasAdd", ConvertBiasAddOperator},
+ {"Cast", ConvertCastOperator},
+ {"CheckNumerics", ConvertIdentityOperator},
+ {"Concat", ConvertConcatOperator},
+ {"ConcatV2", ConvertConcatOperator},
+ {"Const", ConvertConstOperator},
+ {"Conv2D", ConvertConvOperator},
+ {"Conv2DBackpropInput", ConvertTransposeConvOperator},
+ {"DepthToSpace", ConvertDepthToSpaceOperator},
+ {"DepthwiseConv2dNative", ConvertDepthwiseConvOperator},
+ {"Div", ConvertSimpleOperator<DivOperator, 2>},
+ {"DynamicPartition", ConvertDynamicPartitionOperator},
+ {"DynamicStitch", ConvertDynamicStitchOperator},
+ {"Equal", ConvertSimpleOperator<TensorFlowEqualOperator, 2>},
+ {"Exp", ConvertSimpleOperator<ExpOperator, 1>},
+ {"ExpandDims", ConvertSimpleOperator<ExpandDimsOperator, 2>},
+ {"FakeQuantWithMinMaxArgs", ConvertFakeQuantWithMinMaxArgs},
+ {"FakeQuantWithMinMaxVars", ConvertFakeQuantWithMinMaxVars},
+ {"Fill", ConvertSimpleOperator<FillOperator, 2>},
+ {"Floor", ConvertFloorOperator},
+ {"FloorDiv", ConvertSimpleOperator<FloorDivOperator, 2>},
+ {"FloorMod", ConvertSimpleOperator<FloorModOperator, 2>},
+ {"FusedBatchNorm", ConvertFusedBatchNormOperator},
+ {"Gather", ConvertGatherOperator},
+ {"GatherV2", ConvertGatherOperator},
+ {"Greater", ConvertSimpleOperator<TensorFlowGreaterOperator, 2>},
+ {"GreaterEqual",
+ ConvertSimpleOperator<TensorFlowGreaterEqualOperator, 2>},
+ {"Identity", ConvertIdentityOperator},
+ {"LRN", ConvertLRNOperator},
+ {"LegacyFedInput", ConvertPlaceholderOperator},
+ {"Less", ConvertSimpleOperator<TensorFlowLessOperator, 2>},
+ {"LessEqual", ConvertSimpleOperator<TensorFlowLessEqualOperator, 2>},
+ {"Log", ConvertSimpleOperator<LogOperator, 1>},
+ {"Log", ConvertSimpleOperator<LogOperator, 1>},
+ {"LogSoftmax", ConvertSimpleOperator<LogSoftmaxOperator, 1>},
+ {"MatMul", ConvertMatMulOperator},
+ {"Max", ConvertMaxOperator},
+ {"MaxPool", ConvertMaxPoolOperator},
+ {"Maximum", ConvertSimpleOperator<TensorFlowMaximumOperator, 2>},
+ {"Mean", ConvertMeanOperator},
+ {"Merge", ConvertSimpleOperator<TensorFlowMergeOperator, 2>},
+ {"Min", ConvertMinOperator},
+ {"Minimum", ConvertSimpleOperator<TensorFlowMinimumOperator, 2>},
+ {"Mul", ConvertSimpleOperator<MulOperator, 2>},
+ {"Neg", ConvertSimpleOperator<NegOperator, 1>},
+ {"NextIteration", ConvertOperatorSpecialCasedAsRNNBackEdge},
+ {"NoOp", ConvertNoOpOperator},
+ {"NotEqual", ConvertSimpleOperator<TensorFlowNotEqualOperator, 2>},
+ {"Pack", ConvertStackOperator},
+ {"Pad", ConvertSimpleOperator<PadOperator, 2>},
+ {"PadV2", ConvertSimpleOperator<PadV2Operator, 3>},
+ {"ParallelDynamicStitch", ConvertDynamicStitchOperator},
+ {"Placeholder", ConvertPlaceholderOperator},
+ {"PlaceholderWithDefault", ConvertIdentityOperator},
+ {"RandomUniform", ConvertRandomUniform},
+ {"Range", ConvertRangeOperator},
+ {"Rank", ConvertSimpleOperator<RankOperator, 1>},
+ {"RealDiv", ConvertSimpleOperator<DivOperator, 2>},
+ {"Relu", ConvertSimpleOperator<ReluOperator, 1>},
+ {"Relu6", ConvertSimpleOperator<Relu6Operator, 1>},
+ {"Reshape", ConvertSimpleOperator<TensorFlowReshapeOperator, 2>},
+ {"ResizeBilinear", ConvertResizeBilinearOperator},
+ {"Rsqrt", ConvertSimpleOperator<TensorFlowRsqrtOperator, 1>},
+ {"Select", ConvertSimpleOperator<SelectOperator, 3>},
+ {"Shape", ConvertSimpleOperator<TensorFlowShapeOperator, 1>},
+ {"Sigmoid", ConvertSimpleOperator<LogisticOperator, 1>},
+ {"Sin", ConvertSimpleOperator<SinOperator, 1>},
+ {"Slice", ConvertSimpleOperator<SliceOperator, 3>},
+ {"Softmax", ConvertSoftmaxOperator},
+ {"SpaceToBatchND", ConvertSpaceToBatchNDOperator},
+ {"SpaceToDepth", ConvertSpaceToDepthOperator},
+ {"SparseToDense", ConvertSparseToDenseOperator},
+ {"Split", ConvertSplitOperator},
+ {"Sqrt", ConvertSimpleOperator<TensorFlowSqrtOperator, 1>},
+ {"Square", ConvertSimpleOperator<TensorFlowSquareOperator, 1>},
+ {"Squeeze", ConvertSqueezeOperator},
+ {"Stack", ConvertStackOperator},
+ {"StopGradient", ConvertIdentityOperator},
+ {"StridedSlice", ConvertStridedSliceOperator},
+ {"Sub", ConvertSimpleOperator<SubOperator, 2>},
+ {"Sum", ConvertSumOperator},
+ {"Svdf", ConvertSvdfOperator},
+ {"Switch", ConvertSwitchOperator},
+ {"Tanh", ConvertSimpleOperator<TanhOperator, 1>},
+ {"Tile", ConvertSimpleOperator<TensorFlowTileOperator, 2>},
+ {"TopK", ConvertTopKV2Operator},
+ {"TopKV2", ConvertTopKV2Operator},
+ {"Transpose", ConvertSimpleOperator<TransposeOperator, 2>},
+ });
+}
+
tensorflow::Status ImportTensorFlowNode(
const tensorflow::NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags, Model* model) {
- // TODO(ahentz): Historically these functions all CHECK-fail on error. We've
- // been slowly converting them to return Status.
- if (node.op() == "Const") {
- return ConvertConstOperator(node, tf_import_flags, model);
- } else if (node.op() == "Conv2D") {
- return ConvertConvOperator(node, tf_import_flags, model);
- } else if (node.op() == "Conv2DBackpropInput") {
- ConvertTransposeConvOperator(node, tf_import_flags, model);
- } else if (node.op() == "DepthwiseConv2dNative") {
- ConvertDepthwiseConvOperator(node, tf_import_flags, model);
- } else if (node.op() == "DepthToSpace") {
- ConvertDepthToSpaceOperator(node, tf_import_flags, model);
- } else if (node.op() == "SpaceToDepth") {
- ConvertSpaceToDepthOperator(node, tf_import_flags, model);
- } else if (node.op() == "BiasAdd") {
- ConvertBiasAddOperator(node, tf_import_flags, model);
- } else if (node.op() == "Relu") {
- ConvertSimpleOperator<ReluOperator, 1>(node, tf_import_flags, model);
- } else if (node.op() == "Relu6") {
- ConvertSimpleOperator<Relu6Operator, 1>(node, tf_import_flags, model);
- } else if (node.op() == "Sigmoid") {
- ConvertSimpleOperator<LogisticOperator, 1>(node, tf_import_flags, model);
- } else if (node.op() == "Tanh") {
- ConvertSimpleOperator<TanhOperator, 1>(node, tf_import_flags, model);
- } else if (node.op() == "MaxPool") {
- ConvertMaxPoolOperator(node, tf_import_flags, model);
- } else if (node.op() == "AvgPool") {
- ConvertAvgPoolOperator(node, tf_import_flags, model);
- } else if (node.op() == "Reshape") {
- ConvertSimpleOperator<TensorFlowReshapeOperator, 2>(node, tf_import_flags,
- model);
- } else if (node.op() == "BatchMatMul") {
- ConvertBatchMatMulOperator(node, tf_import_flags, model);
- } else if (node.op() == "MatMul") {
- ConvertMatMulOperator(node, tf_import_flags, model);
- } else if (node.op() == "Div" || node.op() == "RealDiv") {
- ConvertSimpleOperator<DivOperator, 2>(node, tf_import_flags, model);
- } else if (node.op() == "Identity" || node.op() == "CheckNumerics" ||
- node.op() == "StopGradient") {
- ConvertIdentityOperator(node, tf_import_flags, model);
- } else if (node.op() == "FakeQuantWithMinMaxVars") {
- ConvertFakeQuantWithMinMaxVars(node, tf_import_flags, model);
- } else if (node.op() == "FakeQuantWithMinMaxArgs") {
- ConvertFakeQuantWithMinMaxArgs(node, tf_import_flags, model);
- } else if (node.op() == "Neg") {
- ConvertSimpleOperator<NegOperator, 1>(node, tf_import_flags, model);
- } else if (node.op() == "Rsqrt") {
- ConvertSimpleOperator<TensorFlowRsqrtOperator, 1>(node, tf_import_flags,
- model);
- } else if (node.op() == "Squeeze") {
- ConvertSqueezeOperator(node, tf_import_flags, model);
- } else if (node.op() == "Sqrt") {
- ConvertSimpleOperator<TensorFlowSqrtOperator, 1>(node, tf_import_flags,
- model);
- } else if (node.op() == "Square") {
- ConvertSimpleOperator<TensorFlowSquareOperator, 1>(node, tf_import_flags,
- model);
- } else if (node.op() == "Add") {
- ConvertSimpleOperator<AddOperator, 2>(node, tf_import_flags, model);
- } else if (node.op() == "AddN") {
- ConvertSimpleOperator<AddNOperator>(node, tf_import_flags, model);
- } else if (node.op() == "Mul") {
- ConvertSimpleOperator<MulOperator, 2>(node, tf_import_flags, model);
- } else if (node.op() == "Sub") {
- ConvertSimpleOperator<SubOperator, 2>(node, tf_import_flags, model);
- } else if (node.op() == "Sum") {
- ConvertSumOperator(node, tf_import_flags, model);
- } else if (node.op() == "Tile") {
- ConvertSimpleOperator<TensorFlowTileOperator, 2>(node, tf_import_flags,
- model);
- } else if (node.op() == "Concat" || node.op() == "ConcatV2") {
- ConvertConcatOperator(node, tf_import_flags, model);
- } else if (node.op() == "LRN") {
- ConvertLRNOperator(node, tf_import_flags, model);
- } else if (node.op() == "Softmax") {
- ConvertSoftmaxOperator(node, tf_import_flags, model);
- } else if (node.op() == "Log") {
- ConvertSimpleOperator<LogOperator, 1>(node, tf_import_flags, model);
- } else if (node.op() == "LogSoftmax") {
- ConvertSimpleOperator<LogSoftmaxOperator, 1>(node, tf_import_flags, model);
- } else if (node.op() == "All") {
- ConvertSimpleOperator<TensorFlowAllOperator>(node, tf_import_flags, model);
- } else if (node.op() == "Assert") {
- ConvertSimpleOperator<TensorFlowAssertOperator>(node, tf_import_flags,
- model);
- } else if (node.op() == "Less") {
- ConvertSimpleOperator<TensorFlowLessOperator, 2>(node, tf_import_flags,
- model);
- } else if (node.op() == "LessEqual") {
- ConvertSimpleOperator<TensorFlowLessEqualOperator, 2>(node, tf_import_flags,
- model);
- } else if (node.op() == "Greater") {
- ConvertSimpleOperator<TensorFlowGreaterOperator, 2>(node, tf_import_flags,
- model);
- } else if (node.op() == "GreaterEqual") {
- ConvertSimpleOperator<TensorFlowGreaterEqualOperator, 2>(
- node, tf_import_flags, model);
- } else if (node.op() == "Max") {
- ConvertMaxOperator(node, tf_import_flags, model);
- } else if (node.op() == "Min") {
- ConvertMinOperator(node, tf_import_flags, model);
- } else if (node.op() == "Maximum") {
- ConvertSimpleOperator<TensorFlowMaximumOperator, 2>(node, tf_import_flags,
- model);
- } else if (node.op() == "Minimum") {
- ConvertSimpleOperator<TensorFlowMinimumOperator, 2>(node, tf_import_flags,
- model);
- } else if (node.op() == "Merge") {
- ConvertSimpleOperator<TensorFlowMergeOperator, 2>(node, tf_import_flags,
- model);
- } else if (node.op() == "Pad") {
- ConvertSimpleOperator<PadOperator, 2>(node, tf_import_flags, model);
- } else if (node.op() == "PadV2") {
- ConvertSimpleOperator<PadV2Operator, 3>(node, tf_import_flags, model);
- } else if (node.op() == "StridedSlice") {
- ConvertStridedSliceOperator(node, tf_import_flags, model);
- } else if (node.op() == "Shape") {
- ConvertSimpleOperator<TensorFlowShapeOperator, 1>(node, tf_import_flags,
- model);
- } else if (node.op() == "Slice") {
- ConvertSimpleOperator<SliceOperator, 3>(node, tf_import_flags, model);
- } else if (node.op() == "Split") {
- ConvertSplitOperator(node, tf_import_flags, model);
- } else if (node.op() == "Switch") {
- ConvertSwitchOperator(node, tf_import_flags, model);
- } else if (node.op() == "Placeholder") {
- ConvertPlaceholderOperator(node, tf_import_flags, model);
- } else if (node.op() == "PlaceholderWithDefault") {
- ConvertIdentityOperator(node, tf_import_flags, model);
- } else if (node.op() == "LegacyFedInput") {
- ConvertPlaceholderOperator(node, tf_import_flags, model);
- } else if (node.op() == "NoOp") {
- ConvertNoOpOperator(node, tf_import_flags, model);
- } else if (node.op() == "Cast") {
- ConvertCastOperator(node, tf_import_flags, model);
- } else if (node.op() == "Floor") {
- ConvertFloorOperator(node, tf_import_flags, model);
- } else if (node.op() == "Gather" || node.op() == "GatherV2") {
- ConvertGatherOperator(node, tf_import_flags, model);
- } else if (node.op() == "ResizeBilinear") {
- ConvertResizeBilinearOperator(node, tf_import_flags, model);
- } else if (node.op() == "BatchNormWithGlobalNormalization") {
- ConvertBatchNormWithGlobalNormalizationOperator(node, tf_import_flags,
- model);
- } else if (node.op() == "FusedBatchNorm") {
- ConvertFusedBatchNormOperator(node, tf_import_flags, model);
- } else if (node.op() == "SpaceToBatchND") {
- ConvertSpaceToBatchNDOperator(node, tf_import_flags, model);
- } else if (node.op() == "BatchToSpaceND") {
- ConvertBatchToSpaceNDOperator(node, tf_import_flags, model);
- } else if (node.op() == "Mean") {
- ConvertMeanOperator(node, tf_import_flags, model);
- } else if (node.op() == "Svdf") {
- ConvertSvdfOperator(node, tf_import_flags, model);
- } else if (node.op() == "NextIteration") {
- ConvertOperatorSpecialCasedAsRNNBackEdge(node, tf_import_flags, model);
- } else if (node.op() == "ExpandDims") {
- ConvertSimpleOperator<ExpandDimsOperator, 2>(node, tf_import_flags, model);
- } else if (node.op() == "Fill") {
- ConvertSimpleOperator<FillOperator, 2>(node, tf_import_flags, model);
- } else if (node.op() == "FloorDiv") {
- ConvertSimpleOperator<FloorDivOperator, 2>(node, tf_import_flags, model);
- } else if (node.op() == "FloorMod") {
- ConvertSimpleOperator<FloorModOperator, 2>(node, tf_import_flags, model);
- } else if (node.op() == "Range") {
- ConvertRangeOperator(node, tf_import_flags, model);
- } else if (node.op() == "Rank") {
- ConvertSimpleOperator<RankOperator, 1>(node, tf_import_flags, model);
- } else if (node.op() == "Stack" || node.op() == "Pack") {
- ConvertStackOperator(node, tf_import_flags, model);
- } else if (node.op() == "Transpose") {
- ConvertSimpleOperator<TransposeOperator, 2>(node, tf_import_flags, model);
- } else if (node.op() == "ArgMax") {
- ConvertArgMaxOperator(node, tf_import_flags, model);
- } else if (node.op() == "Exp") {
- ConvertSimpleOperator<ExpOperator, 1>(node, tf_import_flags, model);
- } else if (node.op() == "TopK" || node.op() == "TopKV2") {
- ConvertTopKV2Operator(node, tf_import_flags, model);
- } else if (node.op() == "DynamicPartition") {
- ConvertDynamicPartitionOperator(node, tf_import_flags, model);
- } else if (node.op() == "DynamicStitch" ||
- node.op() == "ParallelDynamicStitch") {
- ConvertDynamicStitchOperator(node, tf_import_flags, model);
- } else if (node.op() == "RandomUniform") {
- ConvertRandomUniform(node, tf_import_flags, model);
- } else if (node.op() == "Sin") {
- ConvertSimpleOperator<SinOperator, 1>(node, tf_import_flags, model);
- } else if (node.op() == "Log") {
- ConvertSimpleOperator<LogOperator, 1>(node, tf_import_flags, model);
- } else if (node.op() == "Select") {
- ConvertSimpleOperator<SelectOperator, 3>(node, tf_import_flags, model);
- } else if (node.op() == "SparseToDense") {
- ConvertSparseToDenseOperator(node, tf_import_flags, model);
- } else if (node.op() == "Equal") {
- ConvertSimpleOperator<TensorFlowEqualOperator, 2>(node, tf_import_flags,
- model);
- } else if (node.op() == "NotEqual") {
- ConvertSimpleOperator<TensorFlowNotEqualOperator, 2>(node, tf_import_flags,
- model);
+ const TensorFlowImportFlags& tf_import_flags, Model* model,
+ const ConverterMapType& converter_map) {
+ auto converter = converter_map.find(node.op());
+ if (converter == converter_map.end()) {
+ return ConvertUnsupportedOperator(node, tf_import_flags, model);
} else {
- ConvertUnsupportedOperator(node, tf_import_flags, model);
+ return converter->second(node, tf_import_flags, model);
}
- return tensorflow::Status::OK();
}
} // namespace internal
@@ -1981,10 +1936,13 @@ std::unique_ptr<Model> ImportTensorFlowGraphDef(
}
Model* model = new Model;
+ const internal::ConverterMapType& converter_map =
+ internal::GetTensorFlowNodeConverterMap();
for (auto node : inlined_graph.node()) {
StripZeroOutputIndexFromInputs(&node);
- auto status = internal::ImportTensorFlowNode(node, tf_import_flags, model);
+ auto status = internal::ImportTensorFlowNode(node, tf_import_flags, model,
+ converter_map);
CHECK(status.ok()) << status.error_message();
}
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc
index d18c329a43..90e6f698ef 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc
@@ -36,8 +36,14 @@ using tensorflow::NodeDef;
using tensorflow::Status;
namespace internal {
+using ConverterType = tensorflow::Status (*)(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model);
+using ConverterMapType = std::unordered_map<std::string, ConverterType>;
+
+ConverterMapType GetTensorFlowNodeConverterMap();
Status ImportTensorFlowNode(const NodeDef&, const TensorFlowImportFlags&,
- Model*);
+ Model*, const ConverterMapType&);
} // namespace internal
namespace {
@@ -105,8 +111,9 @@ class ShapeImportTest : public ::testing::TestWithParam<tensorflow::DataType> {
Status ImportNode(const NodeDef& node) {
Model model;
- return internal::ImportTensorFlowNode(node, TensorFlowImportFlags(),
- &model);
+ const auto converter = internal::GetTensorFlowNodeConverterMap();
+ return internal::ImportTensorFlowNode(node, TensorFlowImportFlags(), &model,
+ converter);
}
};