diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/export_tensorflow.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/export_tensorflow.cc | 28 |
1 files changed, 28 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc index b52a79282c..61e9106783 100644 --- a/tensorflow/contrib/lite/toco/export_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc @@ -470,6 +470,17 @@ void ConvertDepthwiseConvOperator(const Model& model, strides.mutable_list()->add_i(src_op.stride_height); strides.mutable_list()->add_i(src_op.stride_width); strides.mutable_list()->add_i(1); + // TODO(b/116063589): To return a working TF GraphDef, we should be returning + // the correct SpaceToBatchNd and BatchToSpaceND operation before and after + // the conv since TF doesn't support dilations. + if ((src_op.dilation_width_factor != 1) || + (src_op.dilation_height_factor != 1)) { + auto& dilations = (*dc2d_op->mutable_attr())["dilations"]; + dilations.mutable_list()->add_i(1); + dilations.mutable_list()->add_i(src_op.dilation_height_factor); + dilations.mutable_list()->add_i(src_op.dilation_width_factor); + dilations.mutable_list()->add_i(1); + } string padding; if (src_op.padding.type == PaddingType::kSame) { padding = "SAME"; @@ -1968,6 +1979,19 @@ void ConvertUnpackOperator(const Model& model, const UnpackOperator& src_op, (*unpack_op->mutable_attr())["axis"].set_i(src_op.axis); } +void ConvertZerosLikeOperator(const Model& model, + const TensorFlowZerosLikeOperator& src_op, + const char* op_name, GraphDef* tensorflow_graph) { + tensorflow::NodeDef* zeros_like_op = tensorflow_graph->add_node(); + zeros_like_op->set_op(op_name); + zeros_like_op->set_name(src_op.outputs[0]); + DCHECK_EQ(src_op.inputs.size(), 1); + *zeros_like_op->add_input() = src_op.inputs[0]; + const tensorflow::DataType data_type = + GetTensorFlowDataType(model, src_op.inputs[0]); + (*zeros_like_op->mutable_attr())["T"].set_type(data_type); +} + void ConvertOperator(const Model& model, const Operator& src_op, GraphDef* tensorflow_graph) { if (src_op.fused_activation_function != FusedActivationFunctionType::kNone) { @@ -2233,6 +2257,10 @@ void ConvertOperator(const Model& model, const Operator& src_op, } else if (src_op.type == OperatorType::kUnpack) { ConvertUnpackOperator(model, static_cast<const UnpackOperator&>(src_op), "Unpack", tensorflow_graph); + } else if (src_op.type == OperatorType::kZerosLike) { + ConvertZerosLikeOperator( + model, static_cast<const TensorFlowZerosLikeOperator&>(src_op), + "ZerosLike", tensorflow_graph); } else { LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(src_op.type); } |