diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/export_tensorflow.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/export_tensorflow.cc | 228 |
1 files changed, 183 insertions, 45 deletions
diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc index 6be6b25f93..b79bb300f0 100644 --- a/tensorflow/contrib/lite/toco/export_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc @@ -215,6 +215,30 @@ void ConvertFloatTensorConst(const Model& model, const string& name, LegacyScalarPolicy::kAvoidLegacyScalars); } +void ConvertBoolTensorConst(const Model& model, const string& name, + GraphDef* tensorflow_graph) { + if (HasAlreadyExportedConst(name, *tensorflow_graph)) { + return; + } + CHECK(model.HasArray(name)); + const auto& array = model.GetArray(name); + tensorflow::NodeDef* const_op = tensorflow_graph->add_node(); + const_op->set_op("Const"); + const_op->set_name(name); + (*const_op->mutable_attr())["dtype"].set_type(DT_BOOL); + auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor(); + tensor->set_dtype(DT_BOOL); + const auto& data = array.GetBuffer<ArrayDataType::kBool>().data; + for (auto index : data) { + tensor->add_bool_val(index); + } + const auto& array_shape = array.shape(); + auto* shape = tensor->mutable_tensor_shape(); + for (int i = 0; i < array_shape.dimensions_count(); i++) { + shape->add_dim()->set_size(array_shape.dims(i)); + } +} + void ConvertIntTensorConst(const Model& model, const string& name, GraphDef* tensorflow_graph) { if (HasAlreadyExportedConst(name, *tensorflow_graph)) { @@ -621,7 +645,8 @@ void ConvertAddOperator(const Model& model, const AddOperator& src_op, CHECK_EQ(src_op.inputs.size(), 2); *add_op->add_input() = src_op.inputs[0]; *add_op->add_input() = src_op.inputs[1]; - (*add_op->mutable_attr())["T"].set_type(DT_FLOAT); + (*add_op->mutable_attr())["T"].set_type( + GetTensorFlowDataType(model, src_op.outputs[0])); } void ConvertAddNOperator(const Model& model, const AddNOperator& src_op, @@ -633,7 +658,8 @@ void ConvertAddNOperator(const Model& model, const AddNOperator& src_op, *add_op->add_input() = input; } (*add_op->mutable_attr())["N"].set_i(src_op.inputs.size()); - (*add_op->mutable_attr())["T"].set_type(DT_FLOAT); + (*add_op->mutable_attr())["T"].set_type( + GetTensorFlowDataType(model, src_op.outputs[0])); } void ConvertMulOperator(const Model& model, const MulOperator& src_op, @@ -644,16 +670,18 @@ void ConvertMulOperator(const Model& model, const MulOperator& src_op, CHECK_EQ(src_op.inputs.size(), 2); *add_op->add_input() = src_op.inputs[0]; *add_op->add_input() = src_op.inputs[1]; - (*add_op->mutable_attr())["T"].set_type(DT_FLOAT); + (*add_op->mutable_attr())["T"].set_type( + GetTensorFlowDataType(model, src_op.outputs[0])); } -void ConvertReluOperator(const ReluOperator& src_op, +void ConvertReluOperator(const Model& model, const ReluOperator& src_op, GraphDef* tensorflow_graph) { tensorflow::NodeDef* relu_op = tensorflow_graph->add_node(); relu_op->set_op("Relu"); relu_op->set_name(src_op.outputs[0]); *relu_op->add_input() = src_op.inputs[0]; - (*relu_op->mutable_attr())["T"].set_type(DT_FLOAT); + (*relu_op->mutable_attr())["T"].set_type( + GetTensorFlowDataType(model, src_op.outputs[0])); } void ConvertRelu1Operator(const Relu1Operator& src_op, @@ -884,6 +912,9 @@ void ConvertFakeQuantOperator(const FakeQuantOperator& src_op, if (src_op.num_bits) { (*fakequant_op->mutable_attr())["num_bits"].set_i(src_op.num_bits); } + if (src_op.narrow_range) { + (*fakequant_op->mutable_attr())["narrow_range"].set_b(src_op.narrow_range); + } } void ConvertMaxPoolOperator(const MaxPoolOperator& src_op, @@ -1107,13 +1138,27 @@ void ConvertFloorOperator(const Model& model, const FloorOperator& src_op, void ConvertGatherOperator(const Model& model, const GatherOperator& src_op, GraphDef* tensorflow_graph) { tensorflow::NodeDef* gather_op = tensorflow_graph->add_node(); - gather_op->set_op("Gather"); + gather_op->set_op("GatherV2"); gather_op->set_name(src_op.outputs[0]); - CHECK_EQ(src_op.inputs.size(), 2); *gather_op->add_input() = src_op.inputs[0]; *gather_op->add_input() = src_op.inputs[1]; + if (!src_op.axis) { + // Dynamic axis. + CHECK_EQ(src_op.inputs.size(), 3); + *gather_op->add_input() = src_op.inputs[2]; + } else { + // Constant axis. + CHECK_EQ(src_op.inputs.size(), 2); + const string gather_axis = + AvailableArrayName(model, gather_op->name() + "/axis"); + CreateIntTensorConst(gather_axis, {src_op.axis.value()}, {}, + tensorflow_graph); + *gather_op->add_input() = gather_axis; + } + (*gather_op->mutable_attr())["Tindices"].set_type(DT_INT32); + (*gather_op->mutable_attr())["Taxis"].set_type(DT_INT32); const tensorflow::DataType params_type = GetTensorFlowDataType(model, src_op.inputs[0]); (*gather_op->mutable_attr())["Tparams"].set_type(params_type); @@ -1135,6 +1180,22 @@ void ConvertArgMaxOperator(const Model& model, const ArgMaxOperator& src_op, GetTensorFlowDataType(model, src_op.outputs[0])); } +void ConvertArgMinOperator(const Model& model, const ArgMinOperator& src_op, + GraphDef* tensorflow_graph) { + tensorflow::NodeDef* argmin_op = tensorflow_graph->add_node(); + argmin_op->set_op("ArgMin"); + argmin_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + *argmin_op->add_input() = src_op.inputs[0]; + *argmin_op->add_input() = src_op.inputs[1]; + (*argmin_op->mutable_attr())["T"].set_type( + GetTensorFlowDataType(model, src_op.inputs[0])); + (*argmin_op->mutable_attr())["Tidx"].set_type( + GetTensorFlowDataType(model, src_op.inputs[1])); + (*argmin_op->mutable_attr())["output_type"].set_type( + GetTensorFlowDataType(model, src_op.outputs[0])); +} + void ConvertTransposeOperator(const Model& model, const TransposeOperator& src_op, GraphDef* tensorflow_graph) { @@ -1188,17 +1249,17 @@ void ConvertRangeOperator(const Model& model, const RangeOperator& src_op, GetTensorFlowDataType(src_op.dtype)); } -void ConvertStackOperator(const Model& model, const StackOperator& src_op, - GraphDef* tensorflow_graph) { - tensorflow::NodeDef* stack_op = tensorflow_graph->add_node(); - stack_op->set_op("Stack"); - stack_op->set_name(src_op.outputs[0]); +void ConvertPackOperator(const Model& model, const PackOperator& src_op, + GraphDef* tensorflow_graph) { + tensorflow::NodeDef* pack_op = tensorflow_graph->add_node(); + pack_op->set_op("Pack"); + pack_op->set_name(src_op.outputs[0]); for (const auto& input : src_op.inputs) { - *stack_op->add_input() = input; + *pack_op->add_input() = input; } - (*stack_op->mutable_attr())["elem_type"].set_type( - GetTensorFlowDataType(model, src_op.outputs[0])); - (*stack_op->mutable_attr())["axis"].set_i(src_op.axis); + (*pack_op->mutable_attr())["axis"].set_i(src_op.axis); + (*pack_op->mutable_attr())["N"].set_i(src_op.inputs.size()); + (*pack_op->mutable_attr())["T"].set_type(GetTensorFlowDataType(src_op.dtype)); } void ConvertFillOperator(const Model& model, const FillOperator& src_op, @@ -1604,10 +1665,11 @@ void ConvertSliceOperator(const Model& model, const SliceOperator& src_op, CreateSliceInput(src_op.inputs[2], src_op.size, tensorflow_graph); } -void ConvertMeanOperator(const Model& model, const MeanOperator& src_op, - GraphDef* tensorflow_graph) { +template <typename T> +void ConvertReduceOperator(const Model& model, const T& src_op, + GraphDef* tensorflow_graph, const string& op_name) { tensorflow::NodeDef* new_op = tensorflow_graph->add_node(); - new_op->set_op("Mean"); + new_op->set_op(op_name); new_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); *new_op->add_input() = src_op.inputs[0]; @@ -1616,6 +1678,9 @@ void ConvertMeanOperator(const Model& model, const MeanOperator& src_op, const tensorflow::DataType params_type = GetTensorFlowDataType(model, src_op.inputs[0]); (*new_op->mutable_attr())["T"].set_type(params_type); + const tensorflow::DataType indices_type = + GetTensorFlowDataType(model, src_op.inputs[1]); + (*new_op->mutable_attr())["Tidx"].set_type(indices_type); if (src_op.keep_dims) { (*new_op->mutable_attr())["keep_dims"].set_b(true); @@ -1672,43 +1737,43 @@ void ConvertSubOperator(const Model& model, const SubOperator& src_op, void ConvertTensorFlowMinimumOperator(const Model& model, const TensorFlowMinimumOperator& src_op, GraphDef* tensorflow_graph) { - tensorflow::NodeDef* sub_op = tensorflow_graph->add_node(); - sub_op->set_op("Minimum"); - sub_op->set_name(src_op.outputs[0]); + tensorflow::NodeDef* min_op = tensorflow_graph->add_node(); + min_op->set_op("Minimum"); + min_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); - *sub_op->add_input() = src_op.inputs[0]; - *sub_op->add_input() = src_op.inputs[1]; + *min_op->add_input() = src_op.inputs[0]; + *min_op->add_input() = src_op.inputs[1]; const tensorflow::DataType data_type = GetTensorFlowDataType(model, src_op.inputs[0]); - (*sub_op->mutable_attr())["T"].set_type(data_type); + (*min_op->mutable_attr())["T"].set_type(data_type); } void ConvertTensorFlowMaximumOperator(const Model& model, const TensorFlowMaximumOperator& src_op, GraphDef* tensorflow_graph) { - tensorflow::NodeDef* sub_op = tensorflow_graph->add_node(); - sub_op->set_op("Maximum"); - sub_op->set_name(src_op.outputs[0]); + tensorflow::NodeDef* max_op = tensorflow_graph->add_node(); + max_op->set_op("Maximum"); + max_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); - *sub_op->add_input() = src_op.inputs[0]; - *sub_op->add_input() = src_op.inputs[1]; + *max_op->add_input() = src_op.inputs[0]; + *max_op->add_input() = src_op.inputs[1]; const tensorflow::DataType data_type = GetTensorFlowDataType(model, src_op.inputs[0]); - (*sub_op->mutable_attr())["T"].set_type(data_type); + (*max_op->mutable_attr())["T"].set_type(data_type); } void ConvertSelectOperator(const Model& model, const SelectOperator& src_op, GraphDef* tensorflow_graph) { - tensorflow::NodeDef* sub_op = tensorflow_graph->add_node(); - sub_op->set_op("Select"); - sub_op->set_name(src_op.outputs[0]); + tensorflow::NodeDef* select_op = tensorflow_graph->add_node(); + select_op->set_op("Select"); + select_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 3); - *sub_op->add_input() = src_op.inputs[0]; - *sub_op->add_input() = src_op.inputs[1]; - *sub_op->add_input() = src_op.inputs[2]; + *select_op->add_input() = src_op.inputs[0]; + *select_op->add_input() = src_op.inputs[1]; + *select_op->add_input() = src_op.inputs[2]; const tensorflow::DataType data_type = GetTensorFlowDataType(model, src_op.inputs[1]); - (*sub_op->mutable_attr())["T"].set_type(data_type); + (*select_op->mutable_attr())["T"].set_type(data_type); } void ConvertTileOperator(const Model& model, @@ -1731,11 +1796,14 @@ void ConvertTileOperator(const Model& model, void ConvertTopKV2Operator(const Model& model, const TopKV2Operator& src_op, GraphDef* tensorflow_graph) { tensorflow::NodeDef* topk_op = tensorflow_graph->add_node(); - topk_op->set_op("TOPKV2"); + topk_op->set_op("TopKV2"); topk_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); *topk_op->add_input() = src_op.inputs[0]; *topk_op->add_input() = src_op.inputs[1]; + const tensorflow::DataType data_type = + GetTensorFlowDataType(model, src_op.inputs[0]); + (*topk_op->mutable_attr())["T"].set_type(data_type); (*topk_op->mutable_attr())["sorted"].set_b(true); } @@ -1806,6 +1874,43 @@ void ConvertPowOperator(const Model& model, const PowOperator& src_op, (*pow_op->mutable_attr())["T"].set_type(data_type); } +void ConvertAnyOperator(const Model& model, const AnyOperator& src_op, + GraphDef* tensorflow_graph) { + tensorflow::NodeDef* any_op = tensorflow_graph->add_node(); + any_op->set_op("Any"); + any_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + for (int i = 0; i < 2; ++i) { + *any_op->add_input() = src_op.inputs[i]; + } + const tensorflow::DataType data_type = + GetTensorFlowDataType(model, src_op.inputs[1]); + (*any_op->mutable_attr())["Tidx"].set_type(data_type); + (*any_op->mutable_attr())["keep_dims"].set_b(src_op.keep_dims); +} + +void ConvertLogicalAndOperator(const Model& model, + const LogicalAndOperator& src_op, + GraphDef* tensorflow_graph) { + tensorflow::NodeDef* logical_op = tensorflow_graph->add_node(); + logical_op->set_op("LogicalAnd"); + logical_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + for (int i = 0; i < 2; ++i) { + *logical_op->add_input() = src_op.inputs[i]; + } +} + +void ConvertLogicalNotOperator(const Model& model, + const LogicalNotOperator& src_op, + GraphDef* tensorflow_graph) { + tensorflow::NodeDef* logical_op = tensorflow_graph->add_node(); + logical_op->set_op("LogicalNot"); + logical_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 1); + *logical_op->add_input() = src_op.inputs[0]; +} + void ConvertOperator(const Model& model, const Operator& src_op, GraphDef* tensorflow_graph) { if (src_op.fused_activation_function != FusedActivationFunctionType::kNone) { @@ -1842,7 +1947,7 @@ void ConvertOperator(const Model& model, const Operator& src_op, ConvertMulOperator(model, static_cast<const MulOperator&>(src_op), tensorflow_graph); } else if (src_op.type == OperatorType::kRelu) { - ConvertReluOperator(static_cast<const ReluOperator&>(src_op), + ConvertReluOperator(model, static_cast<const ReluOperator&>(src_op), tensorflow_graph); } else if (src_op.type == OperatorType::kRelu1) { ConvertRelu1Operator(static_cast<const Relu1Operator&>(src_op), @@ -1942,8 +2047,24 @@ void ConvertOperator(const Model& model, const Operator& src_op, model, static_cast<const StridedSliceOperator&>(src_op), tensorflow_graph); } else if (src_op.type == OperatorType::kMean) { - ConvertMeanOperator(model, static_cast<const MeanOperator&>(src_op), - tensorflow_graph); + ConvertReduceOperator(model, static_cast<const MeanOperator&>(src_op), + tensorflow_graph, "Mean"); + } else if (src_op.type == OperatorType::kSum) { + ConvertReduceOperator(model, + static_cast<const TensorFlowSumOperator&>(src_op), + tensorflow_graph, "Sum"); + } else if (src_op.type == OperatorType::kReduceProd) { + ConvertReduceOperator(model, + static_cast<const TensorFlowProdOperator&>(src_op), + tensorflow_graph, "Prod"); + } else if (src_op.type == OperatorType::kReduceMin) { + ConvertReduceOperator(model, + static_cast<const TensorFlowMaxOperator&>(src_op), + tensorflow_graph, "Min"); + } else if (src_op.type == OperatorType::kReduceMax) { + ConvertReduceOperator(model, + static_cast<const TensorFlowMaxOperator&>(src_op), + tensorflow_graph, "Max"); } else if (src_op.type == OperatorType::kSub) { ConvertSubOperator(model, static_cast<const SubOperator&>(src_op), tensorflow_graph); @@ -1964,6 +2085,9 @@ void ConvertOperator(const Model& model, const Operator& src_op, } else if (src_op.type == OperatorType::kArgMax) { ConvertArgMaxOperator(model, static_cast<const ArgMaxOperator&>(src_op), tensorflow_graph); + } else if (src_op.type == OperatorType::kArgMin) { + ConvertArgMinOperator(model, static_cast<const ArgMinOperator&>(src_op), + tensorflow_graph); } else if (src_op.type == OperatorType::kTopK_V2) { ConvertTopKV2Operator(model, static_cast<const TopKV2Operator&>(src_op), tensorflow_graph); @@ -1980,9 +2104,9 @@ void ConvertOperator(const Model& model, const Operator& src_op, } else if (src_op.type == OperatorType::kRange) { ConvertRangeOperator(model, static_cast<const RangeOperator&>(src_op), tensorflow_graph); - } else if (src_op.type == OperatorType::kStack) { - ConvertStackOperator(model, static_cast<const StackOperator&>(src_op), - tensorflow_graph); + } else if (src_op.type == OperatorType::kPack) { + ConvertPackOperator(model, static_cast<const PackOperator&>(src_op), + tensorflow_graph); } else if (src_op.type == OperatorType::kFill) { ConvertFillOperator(model, static_cast<const FillOperator&>(src_op), tensorflow_graph); @@ -2023,6 +2147,17 @@ void ConvertOperator(const Model& model, const Operator& src_op, } else if (src_op.type == OperatorType::kPow) { ConvertPowOperator(model, static_cast<const PowOperator&>(src_op), "Pow", tensorflow_graph); + } else if (src_op.type == OperatorType::kAny) { + ConvertAnyOperator(model, static_cast<const AnyOperator&>(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kLogicalAnd) { + ConvertLogicalAndOperator(model, + static_cast<const LogicalAndOperator&>(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kLogicalNot) { + ConvertLogicalNotOperator(model, + static_cast<const LogicalNotOperator&>(src_op), + tensorflow_graph); } else { LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(src_op.type); } @@ -2101,6 +2236,9 @@ void ExportTensorFlowGraphDefImplementation(const Model& model, const auto& array = *array_pair.second; if (array.buffer) { switch (array.data_type) { + case ArrayDataType::kBool: + ConvertBoolTensorConst(model, array_name, tensorflow_graph); + break; case ArrayDataType::kFloat: ConvertFloatTensorConst(model, array_name, tensorflow_graph); break; |