aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/export_tensorflow.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/export_tensorflow.cc')
-rw-r--r--tensorflow/contrib/lite/toco/export_tensorflow.cc28
1 files changed, 28 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc
index b52a79282c..61e9106783 100644
--- a/tensorflow/contrib/lite/toco/export_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc
@@ -470,6 +470,17 @@ void ConvertDepthwiseConvOperator(const Model& model,
strides.mutable_list()->add_i(src_op.stride_height);
strides.mutable_list()->add_i(src_op.stride_width);
strides.mutable_list()->add_i(1);
+ // TODO(b/116063589): To return a working TF GraphDef, we should be returning
+ // the correct SpaceToBatchNd and BatchToSpaceND operation before and after
+ // the conv since TF doesn't support dilations.
+ if ((src_op.dilation_width_factor != 1) ||
+ (src_op.dilation_height_factor != 1)) {
+ auto& dilations = (*dc2d_op->mutable_attr())["dilations"];
+ dilations.mutable_list()->add_i(1);
+ dilations.mutable_list()->add_i(src_op.dilation_height_factor);
+ dilations.mutable_list()->add_i(src_op.dilation_width_factor);
+ dilations.mutable_list()->add_i(1);
+ }
string padding;
if (src_op.padding.type == PaddingType::kSame) {
padding = "SAME";
@@ -1968,6 +1979,19 @@ void ConvertUnpackOperator(const Model& model, const UnpackOperator& src_op,
(*unpack_op->mutable_attr())["axis"].set_i(src_op.axis);
}
+void ConvertZerosLikeOperator(const Model& model,
+ const TensorFlowZerosLikeOperator& src_op,
+ const char* op_name, GraphDef* tensorflow_graph) {
+ tensorflow::NodeDef* zeros_like_op = tensorflow_graph->add_node();
+ zeros_like_op->set_op(op_name);
+ zeros_like_op->set_name(src_op.outputs[0]);
+ DCHECK_EQ(src_op.inputs.size(), 1);
+ *zeros_like_op->add_input() = src_op.inputs[0];
+ const tensorflow::DataType data_type =
+ GetTensorFlowDataType(model, src_op.inputs[0]);
+ (*zeros_like_op->mutable_attr())["T"].set_type(data_type);
+}
+
void ConvertOperator(const Model& model, const Operator& src_op,
GraphDef* tensorflow_graph) {
if (src_op.fused_activation_function != FusedActivationFunctionType::kNone) {
@@ -2233,6 +2257,10 @@ void ConvertOperator(const Model& model, const Operator& src_op,
} else if (src_op.type == OperatorType::kUnpack) {
ConvertUnpackOperator(model, static_cast<const UnpackOperator&>(src_op),
"Unpack", tensorflow_graph);
+ } else if (src_op.type == OperatorType::kZerosLike) {
+ ConvertZerosLikeOperator(
+ model, static_cast<const TensorFlowZerosLikeOperator&>(src_op),
+ "ZerosLike", tensorflow_graph);
} else {
LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(src_op.type);
}