aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/export_tensorflow.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/export_tensorflow.cc')
-rw-r--r--tensorflow/contrib/lite/toco/export_tensorflow.cc228
1 files changed, 183 insertions, 45 deletions
diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc
index 6be6b25f93..b79bb300f0 100644
--- a/tensorflow/contrib/lite/toco/export_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc
@@ -215,6 +215,30 @@ void ConvertFloatTensorConst(const Model& model, const string& name,
LegacyScalarPolicy::kAvoidLegacyScalars);
}
+void ConvertBoolTensorConst(const Model& model, const string& name,
+ GraphDef* tensorflow_graph) {
+ if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
+ return;
+ }
+ CHECK(model.HasArray(name));
+ const auto& array = model.GetArray(name);
+ tensorflow::NodeDef* const_op = tensorflow_graph->add_node();
+ const_op->set_op("Const");
+ const_op->set_name(name);
+ (*const_op->mutable_attr())["dtype"].set_type(DT_BOOL);
+ auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
+ tensor->set_dtype(DT_BOOL);
+ const auto& data = array.GetBuffer<ArrayDataType::kBool>().data;
+ for (auto index : data) {
+ tensor->add_bool_val(index);
+ }
+ const auto& array_shape = array.shape();
+ auto* shape = tensor->mutable_tensor_shape();
+ for (int i = 0; i < array_shape.dimensions_count(); i++) {
+ shape->add_dim()->set_size(array_shape.dims(i));
+ }
+}
+
void ConvertIntTensorConst(const Model& model, const string& name,
GraphDef* tensorflow_graph) {
if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
@@ -621,7 +645,8 @@ void ConvertAddOperator(const Model& model, const AddOperator& src_op,
CHECK_EQ(src_op.inputs.size(), 2);
*add_op->add_input() = src_op.inputs[0];
*add_op->add_input() = src_op.inputs[1];
- (*add_op->mutable_attr())["T"].set_type(DT_FLOAT);
+ (*add_op->mutable_attr())["T"].set_type(
+ GetTensorFlowDataType(model, src_op.outputs[0]));
}
void ConvertAddNOperator(const Model& model, const AddNOperator& src_op,
@@ -633,7 +658,8 @@ void ConvertAddNOperator(const Model& model, const AddNOperator& src_op,
*add_op->add_input() = input;
}
(*add_op->mutable_attr())["N"].set_i(src_op.inputs.size());
- (*add_op->mutable_attr())["T"].set_type(DT_FLOAT);
+ (*add_op->mutable_attr())["T"].set_type(
+ GetTensorFlowDataType(model, src_op.outputs[0]));
}
void ConvertMulOperator(const Model& model, const MulOperator& src_op,
@@ -644,16 +670,18 @@ void ConvertMulOperator(const Model& model, const MulOperator& src_op,
CHECK_EQ(src_op.inputs.size(), 2);
*add_op->add_input() = src_op.inputs[0];
*add_op->add_input() = src_op.inputs[1];
- (*add_op->mutable_attr())["T"].set_type(DT_FLOAT);
+ (*add_op->mutable_attr())["T"].set_type(
+ GetTensorFlowDataType(model, src_op.outputs[0]));
}
-void ConvertReluOperator(const ReluOperator& src_op,
+void ConvertReluOperator(const Model& model, const ReluOperator& src_op,
GraphDef* tensorflow_graph) {
tensorflow::NodeDef* relu_op = tensorflow_graph->add_node();
relu_op->set_op("Relu");
relu_op->set_name(src_op.outputs[0]);
*relu_op->add_input() = src_op.inputs[0];
- (*relu_op->mutable_attr())["T"].set_type(DT_FLOAT);
+ (*relu_op->mutable_attr())["T"].set_type(
+ GetTensorFlowDataType(model, src_op.outputs[0]));
}
void ConvertRelu1Operator(const Relu1Operator& src_op,
@@ -884,6 +912,9 @@ void ConvertFakeQuantOperator(const FakeQuantOperator& src_op,
if (src_op.num_bits) {
(*fakequant_op->mutable_attr())["num_bits"].set_i(src_op.num_bits);
}
+ if (src_op.narrow_range) {
+ (*fakequant_op->mutable_attr())["narrow_range"].set_b(src_op.narrow_range);
+ }
}
void ConvertMaxPoolOperator(const MaxPoolOperator& src_op,
@@ -1107,13 +1138,27 @@ void ConvertFloorOperator(const Model& model, const FloorOperator& src_op,
void ConvertGatherOperator(const Model& model, const GatherOperator& src_op,
GraphDef* tensorflow_graph) {
tensorflow::NodeDef* gather_op = tensorflow_graph->add_node();
- gather_op->set_op("Gather");
+ gather_op->set_op("GatherV2");
gather_op->set_name(src_op.outputs[0]);
- CHECK_EQ(src_op.inputs.size(), 2);
*gather_op->add_input() = src_op.inputs[0];
*gather_op->add_input() = src_op.inputs[1];
+ if (!src_op.axis) {
+ // Dynamic axis.
+ CHECK_EQ(src_op.inputs.size(), 3);
+ *gather_op->add_input() = src_op.inputs[2];
+ } else {
+ // Constant axis.
+ CHECK_EQ(src_op.inputs.size(), 2);
+ const string gather_axis =
+ AvailableArrayName(model, gather_op->name() + "/axis");
+ CreateIntTensorConst(gather_axis, {src_op.axis.value()}, {},
+ tensorflow_graph);
+ *gather_op->add_input() = gather_axis;
+ }
+
(*gather_op->mutable_attr())["Tindices"].set_type(DT_INT32);
+ (*gather_op->mutable_attr())["Taxis"].set_type(DT_INT32);
const tensorflow::DataType params_type =
GetTensorFlowDataType(model, src_op.inputs[0]);
(*gather_op->mutable_attr())["Tparams"].set_type(params_type);
@@ -1135,6 +1180,22 @@ void ConvertArgMaxOperator(const Model& model, const ArgMaxOperator& src_op,
GetTensorFlowDataType(model, src_op.outputs[0]));
}
+void ConvertArgMinOperator(const Model& model, const ArgMinOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ tensorflow::NodeDef* argmin_op = tensorflow_graph->add_node();
+ argmin_op->set_op("ArgMin");
+ argmin_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 2);
+ *argmin_op->add_input() = src_op.inputs[0];
+ *argmin_op->add_input() = src_op.inputs[1];
+ (*argmin_op->mutable_attr())["T"].set_type(
+ GetTensorFlowDataType(model, src_op.inputs[0]));
+ (*argmin_op->mutable_attr())["Tidx"].set_type(
+ GetTensorFlowDataType(model, src_op.inputs[1]));
+ (*argmin_op->mutable_attr())["output_type"].set_type(
+ GetTensorFlowDataType(model, src_op.outputs[0]));
+}
+
void ConvertTransposeOperator(const Model& model,
const TransposeOperator& src_op,
GraphDef* tensorflow_graph) {
@@ -1188,17 +1249,17 @@ void ConvertRangeOperator(const Model& model, const RangeOperator& src_op,
GetTensorFlowDataType(src_op.dtype));
}
-void ConvertStackOperator(const Model& model, const StackOperator& src_op,
- GraphDef* tensorflow_graph) {
- tensorflow::NodeDef* stack_op = tensorflow_graph->add_node();
- stack_op->set_op("Stack");
- stack_op->set_name(src_op.outputs[0]);
+void ConvertPackOperator(const Model& model, const PackOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ tensorflow::NodeDef* pack_op = tensorflow_graph->add_node();
+ pack_op->set_op("Pack");
+ pack_op->set_name(src_op.outputs[0]);
for (const auto& input : src_op.inputs) {
- *stack_op->add_input() = input;
+ *pack_op->add_input() = input;
}
- (*stack_op->mutable_attr())["elem_type"].set_type(
- GetTensorFlowDataType(model, src_op.outputs[0]));
- (*stack_op->mutable_attr())["axis"].set_i(src_op.axis);
+ (*pack_op->mutable_attr())["axis"].set_i(src_op.axis);
+ (*pack_op->mutable_attr())["N"].set_i(src_op.inputs.size());
+ (*pack_op->mutable_attr())["T"].set_type(GetTensorFlowDataType(src_op.dtype));
}
void ConvertFillOperator(const Model& model, const FillOperator& src_op,
@@ -1604,10 +1665,11 @@ void ConvertSliceOperator(const Model& model, const SliceOperator& src_op,
CreateSliceInput(src_op.inputs[2], src_op.size, tensorflow_graph);
}
-void ConvertMeanOperator(const Model& model, const MeanOperator& src_op,
- GraphDef* tensorflow_graph) {
+template <typename T>
+void ConvertReduceOperator(const Model& model, const T& src_op,
+ GraphDef* tensorflow_graph, const string& op_name) {
tensorflow::NodeDef* new_op = tensorflow_graph->add_node();
- new_op->set_op("Mean");
+ new_op->set_op(op_name);
new_op->set_name(src_op.outputs[0]);
CHECK_EQ(src_op.inputs.size(), 2);
*new_op->add_input() = src_op.inputs[0];
@@ -1616,6 +1678,9 @@ void ConvertMeanOperator(const Model& model, const MeanOperator& src_op,
const tensorflow::DataType params_type =
GetTensorFlowDataType(model, src_op.inputs[0]);
(*new_op->mutable_attr())["T"].set_type(params_type);
+ const tensorflow::DataType indices_type =
+ GetTensorFlowDataType(model, src_op.inputs[1]);
+ (*new_op->mutable_attr())["Tidx"].set_type(indices_type);
if (src_op.keep_dims) {
(*new_op->mutable_attr())["keep_dims"].set_b(true);
@@ -1672,43 +1737,43 @@ void ConvertSubOperator(const Model& model, const SubOperator& src_op,
void ConvertTensorFlowMinimumOperator(const Model& model,
const TensorFlowMinimumOperator& src_op,
GraphDef* tensorflow_graph) {
- tensorflow::NodeDef* sub_op = tensorflow_graph->add_node();
- sub_op->set_op("Minimum");
- sub_op->set_name(src_op.outputs[0]);
+ tensorflow::NodeDef* min_op = tensorflow_graph->add_node();
+ min_op->set_op("Minimum");
+ min_op->set_name(src_op.outputs[0]);
CHECK_EQ(src_op.inputs.size(), 2);
- *sub_op->add_input() = src_op.inputs[0];
- *sub_op->add_input() = src_op.inputs[1];
+ *min_op->add_input() = src_op.inputs[0];
+ *min_op->add_input() = src_op.inputs[1];
const tensorflow::DataType data_type =
GetTensorFlowDataType(model, src_op.inputs[0]);
- (*sub_op->mutable_attr())["T"].set_type(data_type);
+ (*min_op->mutable_attr())["T"].set_type(data_type);
}
void ConvertTensorFlowMaximumOperator(const Model& model,
const TensorFlowMaximumOperator& src_op,
GraphDef* tensorflow_graph) {
- tensorflow::NodeDef* sub_op = tensorflow_graph->add_node();
- sub_op->set_op("Maximum");
- sub_op->set_name(src_op.outputs[0]);
+ tensorflow::NodeDef* max_op = tensorflow_graph->add_node();
+ max_op->set_op("Maximum");
+ max_op->set_name(src_op.outputs[0]);
CHECK_EQ(src_op.inputs.size(), 2);
- *sub_op->add_input() = src_op.inputs[0];
- *sub_op->add_input() = src_op.inputs[1];
+ *max_op->add_input() = src_op.inputs[0];
+ *max_op->add_input() = src_op.inputs[1];
const tensorflow::DataType data_type =
GetTensorFlowDataType(model, src_op.inputs[0]);
- (*sub_op->mutable_attr())["T"].set_type(data_type);
+ (*max_op->mutable_attr())["T"].set_type(data_type);
}
void ConvertSelectOperator(const Model& model, const SelectOperator& src_op,
GraphDef* tensorflow_graph) {
- tensorflow::NodeDef* sub_op = tensorflow_graph->add_node();
- sub_op->set_op("Select");
- sub_op->set_name(src_op.outputs[0]);
+ tensorflow::NodeDef* select_op = tensorflow_graph->add_node();
+ select_op->set_op("Select");
+ select_op->set_name(src_op.outputs[0]);
CHECK_EQ(src_op.inputs.size(), 3);
- *sub_op->add_input() = src_op.inputs[0];
- *sub_op->add_input() = src_op.inputs[1];
- *sub_op->add_input() = src_op.inputs[2];
+ *select_op->add_input() = src_op.inputs[0];
+ *select_op->add_input() = src_op.inputs[1];
+ *select_op->add_input() = src_op.inputs[2];
const tensorflow::DataType data_type =
GetTensorFlowDataType(model, src_op.inputs[1]);
- (*sub_op->mutable_attr())["T"].set_type(data_type);
+ (*select_op->mutable_attr())["T"].set_type(data_type);
}
void ConvertTileOperator(const Model& model,
@@ -1731,11 +1796,14 @@ void ConvertTileOperator(const Model& model,
void ConvertTopKV2Operator(const Model& model, const TopKV2Operator& src_op,
GraphDef* tensorflow_graph) {
tensorflow::NodeDef* topk_op = tensorflow_graph->add_node();
- topk_op->set_op("TOPKV2");
+ topk_op->set_op("TopKV2");
topk_op->set_name(src_op.outputs[0]);
CHECK_EQ(src_op.inputs.size(), 2);
*topk_op->add_input() = src_op.inputs[0];
*topk_op->add_input() = src_op.inputs[1];
+ const tensorflow::DataType data_type =
+ GetTensorFlowDataType(model, src_op.inputs[0]);
+ (*topk_op->mutable_attr())["T"].set_type(data_type);
(*topk_op->mutable_attr())["sorted"].set_b(true);
}
@@ -1806,6 +1874,43 @@ void ConvertPowOperator(const Model& model, const PowOperator& src_op,
(*pow_op->mutable_attr())["T"].set_type(data_type);
}
+void ConvertAnyOperator(const Model& model, const AnyOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ tensorflow::NodeDef* any_op = tensorflow_graph->add_node();
+ any_op->set_op("Any");
+ any_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 2);
+ for (int i = 0; i < 2; ++i) {
+ *any_op->add_input() = src_op.inputs[i];
+ }
+ const tensorflow::DataType data_type =
+ GetTensorFlowDataType(model, src_op.inputs[1]);
+ (*any_op->mutable_attr())["Tidx"].set_type(data_type);
+ (*any_op->mutable_attr())["keep_dims"].set_b(src_op.keep_dims);
+}
+
+void ConvertLogicalAndOperator(const Model& model,
+ const LogicalAndOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ tensorflow::NodeDef* logical_op = tensorflow_graph->add_node();
+ logical_op->set_op("LogicalAnd");
+ logical_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 2);
+ for (int i = 0; i < 2; ++i) {
+ *logical_op->add_input() = src_op.inputs[i];
+ }
+}
+
+void ConvertLogicalNotOperator(const Model& model,
+ const LogicalNotOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ tensorflow::NodeDef* logical_op = tensorflow_graph->add_node();
+ logical_op->set_op("LogicalNot");
+ logical_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 1);
+ *logical_op->add_input() = src_op.inputs[0];
+}
+
void ConvertOperator(const Model& model, const Operator& src_op,
GraphDef* tensorflow_graph) {
if (src_op.fused_activation_function != FusedActivationFunctionType::kNone) {
@@ -1842,7 +1947,7 @@ void ConvertOperator(const Model& model, const Operator& src_op,
ConvertMulOperator(model, static_cast<const MulOperator&>(src_op),
tensorflow_graph);
} else if (src_op.type == OperatorType::kRelu) {
- ConvertReluOperator(static_cast<const ReluOperator&>(src_op),
+ ConvertReluOperator(model, static_cast<const ReluOperator&>(src_op),
tensorflow_graph);
} else if (src_op.type == OperatorType::kRelu1) {
ConvertRelu1Operator(static_cast<const Relu1Operator&>(src_op),
@@ -1942,8 +2047,24 @@ void ConvertOperator(const Model& model, const Operator& src_op,
model, static_cast<const StridedSliceOperator&>(src_op),
tensorflow_graph);
} else if (src_op.type == OperatorType::kMean) {
- ConvertMeanOperator(model, static_cast<const MeanOperator&>(src_op),
- tensorflow_graph);
+ ConvertReduceOperator(model, static_cast<const MeanOperator&>(src_op),
+ tensorflow_graph, "Mean");
+ } else if (src_op.type == OperatorType::kSum) {
+ ConvertReduceOperator(model,
+ static_cast<const TensorFlowSumOperator&>(src_op),
+ tensorflow_graph, "Sum");
+ } else if (src_op.type == OperatorType::kReduceProd) {
+ ConvertReduceOperator(model,
+ static_cast<const TensorFlowProdOperator&>(src_op),
+ tensorflow_graph, "Prod");
+ } else if (src_op.type == OperatorType::kReduceMin) {
+ ConvertReduceOperator(model,
+ static_cast<const TensorFlowMaxOperator&>(src_op),
+ tensorflow_graph, "Min");
+ } else if (src_op.type == OperatorType::kReduceMax) {
+ ConvertReduceOperator(model,
+ static_cast<const TensorFlowMaxOperator&>(src_op),
+ tensorflow_graph, "Max");
} else if (src_op.type == OperatorType::kSub) {
ConvertSubOperator(model, static_cast<const SubOperator&>(src_op),
tensorflow_graph);
@@ -1964,6 +2085,9 @@ void ConvertOperator(const Model& model, const Operator& src_op,
} else if (src_op.type == OperatorType::kArgMax) {
ConvertArgMaxOperator(model, static_cast<const ArgMaxOperator&>(src_op),
tensorflow_graph);
+ } else if (src_op.type == OperatorType::kArgMin) {
+ ConvertArgMinOperator(model, static_cast<const ArgMinOperator&>(src_op),
+ tensorflow_graph);
} else if (src_op.type == OperatorType::kTopK_V2) {
ConvertTopKV2Operator(model, static_cast<const TopKV2Operator&>(src_op),
tensorflow_graph);
@@ -1980,9 +2104,9 @@ void ConvertOperator(const Model& model, const Operator& src_op,
} else if (src_op.type == OperatorType::kRange) {
ConvertRangeOperator(model, static_cast<const RangeOperator&>(src_op),
tensorflow_graph);
- } else if (src_op.type == OperatorType::kStack) {
- ConvertStackOperator(model, static_cast<const StackOperator&>(src_op),
- tensorflow_graph);
+ } else if (src_op.type == OperatorType::kPack) {
+ ConvertPackOperator(model, static_cast<const PackOperator&>(src_op),
+ tensorflow_graph);
} else if (src_op.type == OperatorType::kFill) {
ConvertFillOperator(model, static_cast<const FillOperator&>(src_op),
tensorflow_graph);
@@ -2023,6 +2147,17 @@ void ConvertOperator(const Model& model, const Operator& src_op,
} else if (src_op.type == OperatorType::kPow) {
ConvertPowOperator(model, static_cast<const PowOperator&>(src_op), "Pow",
tensorflow_graph);
+ } else if (src_op.type == OperatorType::kAny) {
+ ConvertAnyOperator(model, static_cast<const AnyOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kLogicalAnd) {
+ ConvertLogicalAndOperator(model,
+ static_cast<const LogicalAndOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kLogicalNot) {
+ ConvertLogicalNotOperator(model,
+ static_cast<const LogicalNotOperator&>(src_op),
+ tensorflow_graph);
} else {
LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(src_op.type);
}
@@ -2101,6 +2236,9 @@ void ExportTensorFlowGraphDefImplementation(const Model& model,
const auto& array = *array_pair.second;
if (array.buffer) {
switch (array.data_type) {
+ case ArrayDataType::kBool:
+ ConvertBoolTensorConst(model, array_name, tensorflow_graph);
+ break;
case ArrayDataType::kFloat:
ConvertFloatTensorConst(model, array_name, tensorflow_graph);
break;