Diffstat (limited to 'tensorflow/contrib/lite/toco/import_tensorflow.cc')
 tensorflow/contrib/lite/toco/import_tensorflow.cc | 174 +++++++++--------
 1 file changed, 95 insertions(+), 79 deletions(-)
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index 55e39d963f..f36f720857 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -755,6 +755,9 @@ tensorflow::Status ConvertFakeQuantWithMinMaxArgs(
op->outputs.push_back(node.name());
// tf.fake_quant_with_min_max_args num_bits defaults to 8.
op->num_bits = HasAttr(node, "num_bits") ? GetIntAttr(node, "num_bits") : 8;
+ if (HasAttr(node, "narrow_range")) {
+ op->narrow_range = GetBoolAttr(node, "narrow_range");
+ }
model->operators.emplace_back(op);
return tensorflow::Status::OK();
}
@@ -774,6 +777,9 @@ tensorflow::Status ConvertFakeQuantWithMinMaxVars(
}
op->outputs.push_back(node.name());
op->num_bits = HasAttr(node, "num_bits") ? GetIntAttr(node, "num_bits") : 8;
+ if (HasAttr(node, "narrow_range")) {
+ op->narrow_range = GetBoolAttr(node, "narrow_range");
+ }
model->operators.emplace_back(op);
return tensorflow::Status::OK();
}
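
The two hunks above wire the optional narrow_range attribute through for both FakeQuant variants. For reference, a minimal sketch of how num_bits and narrow_range determine the quantized integer range under standard fake-quant semantics (the helper name is ours, not Toco's):

#include <cstdint>

struct QuantRange { int32_t qmin; int32_t qmax; };

// With num_bits = 8 this is [0, 255] normally and [1, 255] when
// narrow_range is set; reserving the lowest code is commonly used so
// that weight ranges can be made symmetric around zero.
QuantRange FakeQuantRange(int num_bits, bool narrow_range) {
  const int32_t qmax = (1 << num_bits) - 1;
  const int32_t qmin = narrow_range ? 1 : 0;
  return {qmin, qmax};
}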
@@ -799,22 +805,6 @@ tensorflow::Status ConvertSqueezeOperator(
return tensorflow::Status::OK();
}
-tensorflow::Status ConvertSumOperator(
- const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "Sum");
- TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
- auto* op = new TensorFlowSumOperator;
- op->inputs.push_back(node.input(0));
- op->inputs.push_back(node.input(1));
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
- if (HasAttr(node, "keep_dims")) {
- op->keep_dims = GetBoolAttr(node, "keep_dims");
- }
- return tensorflow::Status::OK();
-}
-
tensorflow::Status ConvertSplitOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
@@ -984,18 +974,19 @@ tensorflow::Status ConvertMatMulOperator(
Model* model) {
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
- // Transpose flags should be easy to support, but we don't have a
- // GraphDef with them to test on at the moment.
- CHECK_EQ(HasAttr(node, "transpose_a") && GetBoolAttr(node, "transpose_a"),
- false);
- CHECK_EQ(HasAttr(node, "transpose_b") && GetBoolAttr(node, "transpose_b"),
- false);
CHECK(!HasAttr(node, "adjoint_a") ||
(GetBoolAttr(node, "adjoint_a") == false));
CHECK(!HasAttr(node, "adjoint_b") ||
(GetBoolAttr(node, "adjoint_b") == false));
auto* matmul = new TensorFlowMatMulOperator;
+ if (HasAttr(node, "transpose_a")) {
+ matmul->transpose_a = GetBoolAttr(node, "transpose_a");
+ }
+ if (HasAttr(node, "transpose_b")) {
+ matmul->transpose_b = GetBoolAttr(node, "transpose_b");
+ }
+
matmul->inputs = {node.input(0), node.input(1)};
matmul->outputs = {node.name()};
model->operators.emplace_back(matmul);
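
With transpose_a and transpose_b now imported instead of rejected, the operand shapes follow the usual 2-D MatMul convention. A minimal sketch of the resulting output shape (our helper, for illustration only):

#include <array>

// a is [m, k] (or [k, m] when transposed); b is [k, n] (or [n, k] when
// transposed). The contracted dimensions must agree.
std::array<int, 2> MatMulOutputShape(const std::array<int, 2>& a,
                                     const std::array<int, 2>& b,
                                     bool transpose_a, bool transpose_b) {
  const int m = transpose_a ? a[1] : a[0];
  const int n = transpose_b ? b[0] : b[1];
  return {m, n};
}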
@@ -1051,41 +1042,14 @@ tensorflow::Status ConvertSimpleOperator(
return ConvertSimpleOperator<Op>(node, tf_import_flags, model);
}
-tensorflow::Status ConvertMaxOperator(
- const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "Max");
- TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
- auto* op = new TensorFlowMaxOperator;
- op->inputs.push_back(node.input(0));
- op->inputs.push_back(node.input(1));
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
- if (HasAttr(node, "keep_dims")) {
- op->keep_dims = GetBoolAttr(node, "keep_dims");
- }
- return tensorflow::Status::OK();
-}
-
-tensorflow::Status ConvertMinOperator(
- const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "Min");
- TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
- auto* op = new TensorFlowMinOperator;
- op->inputs.push_back(node.input(0));
- op->inputs.push_back(node.input(1));
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
- if (HasAttr(node, "keep_dims")) {
- op->keep_dims = GetBoolAttr(node, "keep_dims");
- }
- return tensorflow::Status::OK();
-}
-
tensorflow::Status ConvertUnsupportedOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
+ // Names of special attributes in the TF graph that are used by Toco.
+ static constexpr char kAttrOutputQuantized[] = "_output_quantized";
+ static constexpr char kAttrOutputTypes[] = "_output_types";
+ static constexpr char kAttrOutputShapes[] = "_output_shapes";
+
LOG(INFO) << "Converting unsupported operation: " << node.op();
auto* op = new TensorFlowUnsupportedOperator;
const int num_inputs = GetInputsCount(node, tf_import_flags);
@@ -1096,11 +1060,11 @@ tensorflow::Status ConvertUnsupportedOperator(
op->tensorflow_op = node.op();
node.SerializeToString(&op->tensorflow_node_def);
model->operators.emplace_back(op);
- if (HasAttr(node, "_output_quantized")) {
- op->quantized = GetBoolAttr(node, "_output_quantized");
+ if (HasAttr(node, kAttrOutputQuantized)) {
+ op->quantized = GetBoolAttr(node, kAttrOutputQuantized);
}
- if (HasAttr(node, "_output_types")) {
- const auto& output_types = GetListAttr(node, "_output_types");
+ if (HasAttr(node, kAttrOutputTypes)) {
+ const auto& output_types = GetListAttr(node, kAttrOutputTypes);
for (int i = 0; i < output_types.type_size(); ++i) {
op->output_data_types.push_back(ConvertDataType(output_types.type(i)));
}
@@ -1108,6 +1072,19 @@ tensorflow::Status ConvertUnsupportedOperator(
const auto& output_type = GetDataTypeAttr(node, "Tout");
op->output_data_types.push_back(ConvertDataType(output_type));
}
+ if (HasAttr(node, kAttrOutputShapes)) {
+ const auto& output_shapes = GetListAttr(node, kAttrOutputShapes);
+ Shape output_shape;
+ for (int i = 0; i < output_shapes.shape_size(); ++i) {
+ const auto status =
+ ImportShape(output_shapes.shape(i).dim(), /*input_flat_size=*/nullptr,
+ &output_shape);
+ if (!status.ok()) {
+ return status;
+ }
+ op->output_shapes.push_back(output_shape);
+ }
+ }
return tensorflow::Status::OK();
}
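
The _output_quantized, _output_types, and _output_shapes attributes let a graph author annotate an op that Toco cannot convert so that type and shape propagation still work. A hedged sketch of how such a NodeDef could be annotated via the protobuf API (op name and shape values are made up for illustration):

#include "tensorflow/core/framework/node_def.pb.h"

void AnnotateCustomOp(tensorflow::NodeDef* node) {
  node->set_op("MyCustomOp");  // hypothetical unsupported op
  (*node->mutable_attr())["_output_quantized"].set_b(false);
  // One list entry per output tensor.
  auto* types = (*node->mutable_attr())["_output_types"].mutable_list();
  types->add_type(tensorflow::DT_FLOAT);
  auto* shapes = (*node->mutable_attr())["_output_shapes"].mutable_list();
  auto* shape = shapes->add_shape();
  shape->add_dim()->set_size(1);
  shape->add_dim()->set_size(64);
}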
@@ -1222,17 +1199,27 @@ tensorflow::Status ConvertGatherOperator(
auto* op = new GatherOperator;
op->inputs.push_back(node.input(0));
op->inputs.push_back(node.input(1));
- // TODO(ahentz): we currently ignore the third tensor in GatherV2 but we
- // should read it an pass it on to the TF Lite Interpreter.
+ if (node.input_size() >= 3) {
+ // GatherV2 form where we are provided an axis. It may be either a
+ // constant or a runtime-defined value, so we just wire up the array and
+ // let ResolveGatherAttributes take care of it later on.
+ const auto axis_data_type = GetDataTypeAttr(node, "Taxis");
+ CHECK(axis_data_type == DT_INT32 || axis_data_type == DT_INT64);
+ op->inputs.push_back(node.input(2));
+ } else {
+ // Gather form that assumes axis=0.
+ op->axis = {0};
+ }
op->outputs.push_back(node.name());
model->operators.emplace_back(op);
return tensorflow::Status::OK();
}
-tensorflow::Status ConvertArgMaxOperator(
+template <typename Op, const char* op_name>
+tensorflow::Status ConvertArgMinMaxOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
- CHECK_EQ(node.op(), "ArgMax");
+ CHECK_EQ(node.op(), op_name);
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
const auto axis_data_type =
HasAttr(node, "Tidx") ? GetDataTypeAttr(node, "Tidx") : DT_INT32;
@@ -1241,7 +1228,7 @@ tensorflow::Status ConvertArgMaxOperator(
: DT_INT64;
CHECK(axis_data_type == DT_INT64 || axis_data_type == DT_INT32);
CHECK(output_type == DT_INT64 || output_type == DT_INT32);
- auto* op = new ArgMaxOperator;
+ auto* op = new Op;
op->output_data_type = ConvertDataType(output_type);
op->inputs.push_back(node.input(0));
op->inputs.push_back(node.input(1));
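
ConvertArgMinMaxOperator folds ArgMax and the newly supported ArgMin into one template, parameterized on both the operator class and the expected op name. The name is a non-type template parameter, which is why the converter map below passes namespace-scope char arrays (kArgMax, kArgMin) rather than string literals: string literals have no linkage and cannot be template arguments. A stripped-down sketch of the pattern:

constexpr char kExample[] = "Example";  // array with linkage, usable as a
                                        // template argument

template <const char* Name>
const char* OpName() { return Name; }

// OpName<kExample>() returns "Example"; OpName<"Example">() would not
// compile.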
@@ -1404,12 +1391,12 @@ tensorflow::Status ConvertBatchToSpaceNDOperator(
return tensorflow::Status::OK();
}
-tensorflow::Status ConvertMeanOperator(
+template <typename T>
+tensorflow::Status ConvertReduceOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
- CHECK_EQ(node.op(), "Mean");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
- auto* op = new MeanOperator;
+ auto* op = new T;
op->inputs.push_back(node.input(0));
op->inputs.push_back(node.input(1));
op->outputs.push_back(node.name());
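
The reduce converters (Max, Min, Mean, Sum, and the new Prod) differ only in the operator type they instantiate, so they collapse into ConvertReduceOperator<T>. All of them share the keep_dims attribute, whose effect on the output shape is sketched below (our helper, for illustration):

#include <set>
#include <vector>

// Output shape of a reduction over `axes`: reduced dimensions are
// dropped, or kept as size 1 when keep_dims is true. E.g. reducing
// [2, 3] over axis 1 gives [2] without keep_dims and [2, 1] with it.
std::vector<int> ReducedShape(const std::vector<int>& shape,
                              const std::set<int>& axes, bool keep_dims) {
  std::vector<int> out;
  for (int i = 0; i < static_cast<int>(shape.size()); ++i) {
    if (axes.count(i) == 0) {
      out.push_back(shape[i]);
    } else if (keep_dims) {
      out.push_back(1);
    }
  }
  return out;
}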
@@ -1542,11 +1529,15 @@ tensorflow::Status ConvertRangeOperator(
return tensorflow::Status::OK();
}
-tensorflow::Status ConvertStackOperator(
+// Note that it's easy to confuse/conflate "Stack" and "Pack" operators, but
+// they aren't the same thing. tf.stack results in a "Pack" operator. "Stack"
+// operators also exist, but involve manipulating the TF runtime stack, and are
+// not directly related to tf.stack() usage.
+tensorflow::Status ConvertPackOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
- CHECK((node.op() == "Stack") || (node.op() == "Pack"));
- auto* op = new StackOperator;
+ CHECK_EQ(node.op(), "Pack");
+ auto op = absl::make_unique<PackOperator>();
const int num_inputs = GetInputsCount(node, tf_import_flags);
QCHECK_GE(num_inputs, 1)
<< node.op()
@@ -1556,10 +1547,11 @@ tensorflow::Status ConvertStackOperator(
for (int i = 0; i < num_inputs; ++i) {
op->inputs.push_back(node.input(i));
}
- // Both "Stack" and "Pack" have the "axis" attribute.
+ op->values_count = HasAttr(node, "N") ? GetIntAttr(node, "N") : num_inputs;
op->axis = HasAttr(node, "axis") ? GetIntAttr(node, "axis") : 0;
+ op->dtype = ConvertDataType(toco::GetDataTypeAttr(node, "T"));
op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
+ model->operators.emplace_back(std::move(op));
return tensorflow::Status::OK();
}
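
As the comment in the hunk notes, tf.stack() produces "Pack" nodes, while the unrelated "Stack" op manipulates the TF runtime stack; the converter therefore now checks for "Pack" only. Pack joins values_count tensors of identical shape along a new axis, growing the rank by one. A minimal shape sketch:

#include <vector>

// Shape of packing `values_count` tensors of `input_shape` along `axis`.
// E.g. packing two [3] tensors with axis = 0 yields [2, 3].
std::vector<int> PackedShape(const std::vector<int>& input_shape,
                             int values_count, int axis) {
  std::vector<int> out = input_shape;
  out.insert(out.begin() + axis, values_count);
  return out;
}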
@@ -1605,6 +1597,24 @@ tensorflow::Status ConvertShapeOperator(
return tensorflow::Status::OK();
}
+tensorflow::Status ConvertAnyOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
+ CHECK_EQ(node.op(), "Any");
+ TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
+ const auto idx_type =
+ HasAttr(node, "Tidx") ? GetDataTypeAttr(node, "Tidx") : DT_INT32;
+ CHECK(idx_type == DT_INT32);
+ auto op = absl::make_unique<AnyOperator>();
+ op->inputs.push_back(node.input(0));
+ op->inputs.push_back(node.input(1));
+ op->outputs.push_back(node.name());
+ op->keep_dims =
+ HasAttr(node, "keep_dims") ? GetBoolAttr(node, "keep_dims") : false;
+ model->operators.push_back(std::move(op));
+ return tensorflow::Status::OK();
+}
+
void StripCaretFromArrayNames(Model* model) {
for (auto& op : model->operators) {
for (auto& input : op->inputs) {
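
ConvertAnyOperator follows the same input layout as the reduce converters (data plus reduction indices) and the same keep_dims default. Semantically, "Any" is a logical-OR reduction, sketched here for the reduce-over-everything case:

#include <vector>

// "Any" reduced over all elements: true if any input element is true.
bool ReduceAny(const std::vector<bool>& values) {
  for (const bool v : values) {
    if (v) return true;
  }
  return false;
}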
@@ -1832,12 +1842,17 @@ using ConverterType = tensorflow::Status (*)(
Model* model);
using ConverterMapType = std::unordered_map<std::string, ConverterType>;
+constexpr char kArgMax[] = "ArgMax";
+constexpr char kArgMin[] = "ArgMin";
+
ConverterMapType GetTensorFlowNodeConverterMap() {
return std::unordered_map<std::string, ConverterType>({
{"Add", ConvertSimpleOperator<AddOperator, 2>},
{"AddN", ConvertSimpleOperator<AddNOperator>},
{"All", ConvertSimpleOperator<TensorFlowAllOperator>},
- {"ArgMax", ConvertArgMaxOperator},
+ {"Any", ConvertAnyOperator},
+ {"ArgMax", ConvertArgMinMaxOperator<ArgMaxOperator, kArgMax>},
+ {"ArgMin", ConvertArgMinMaxOperator<ArgMinOperator, kArgMin>},
{"Assert", ConvertSimpleOperator<TensorFlowAssertOperator>},
{"AvgPool", ConvertAvgPoolOperator},
{"BatchMatMul", ConvertBatchMatMulOperator},
@@ -1878,28 +1893,30 @@ ConverterMapType GetTensorFlowNodeConverterMap() {
{"Less", ConvertSimpleOperator<TensorFlowLessOperator, 2>},
{"LessEqual", ConvertSimpleOperator<TensorFlowLessEqualOperator, 2>},
{"Log", ConvertSimpleOperator<LogOperator, 1>},
- {"Log", ConvertSimpleOperator<LogOperator, 1>},
{"LogSoftmax", ConvertSimpleOperator<LogSoftmaxOperator, 1>},
+ {"LogicalAnd", ConvertSimpleOperator<LogicalAndOperator, 2>},
+ {"LogicalNot", ConvertSimpleOperator<LogicalNotOperator, 1>},
{"MatMul", ConvertMatMulOperator},
- {"Max", ConvertMaxOperator},
+ {"Max", ConvertReduceOperator<TensorFlowMaxOperator>},
{"MaxPool", ConvertMaxPoolOperator},
{"Maximum", ConvertSimpleOperator<TensorFlowMaximumOperator, 2>},
- {"Mean", ConvertMeanOperator},
+ {"Mean", ConvertReduceOperator<MeanOperator>},
{"Merge", ConvertSimpleOperator<TensorFlowMergeOperator, 2>},
- {"Min", ConvertMinOperator},
+ {"Min", ConvertReduceOperator<TensorFlowMinOperator>},
{"Minimum", ConvertSimpleOperator<TensorFlowMinimumOperator, 2>},
{"Mul", ConvertSimpleOperator<MulOperator, 2>},
{"Neg", ConvertSimpleOperator<NegOperator, 1>},
{"NextIteration", ConvertOperatorSpecialCasedAsRNNBackEdge},
{"NoOp", ConvertNoOpOperator},
{"NotEqual", ConvertSimpleOperator<TensorFlowNotEqualOperator, 2>},
- {"Pack", ConvertStackOperator},
+ {"Pack", ConvertPackOperator},
{"Pad", ConvertSimpleOperator<PadOperator, 2>},
{"PadV2", ConvertSimpleOperator<PadV2Operator, 3>},
{"ParallelDynamicStitch", ConvertDynamicStitchOperator},
{"Placeholder", ConvertPlaceholderOperator},
{"PlaceholderWithDefault", ConvertIdentityOperator},
{"Pow", ConvertSimpleOperator<PowOperator, 2>},
+ {"Prod", ConvertReduceOperator<TensorFlowProdOperator>},
{"RandomUniform", ConvertRandomUniform},
{"Range", ConvertRangeOperator},
{"Rank", ConvertSimpleOperator<RankOperator, 1>},
@@ -1922,11 +1939,10 @@ ConverterMapType GetTensorFlowNodeConverterMap() {
{"Sqrt", ConvertSimpleOperator<TensorFlowSqrtOperator, 1>},
{"Square", ConvertSimpleOperator<TensorFlowSquareOperator, 1>},
{"Squeeze", ConvertSqueezeOperator},
- {"Stack", ConvertStackOperator},
{"StopGradient", ConvertIdentityOperator},
{"StridedSlice", ConvertStridedSliceOperator},
{"Sub", ConvertSimpleOperator<SubOperator, 2>},
- {"Sum", ConvertSumOperator},
+ {"Sum", ConvertReduceOperator<TensorFlowSumOperator>},
{"Svdf", ConvertSvdfOperator},
{"Switch", ConvertSwitchOperator},
{"Tanh", ConvertSimpleOperator<TanhOperator, 1>},