diff options
author | Jared Duke <jdduke@google.com> | 2018-09-12 09:23:42 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-12 09:28:12 -0700 |
commit | 6b507a6de855a6f988100904229b7f46a5652b88 (patch) | |
tree | cbb0c14a47f2da3dd0add9211f03641965b181f4 /tensorflow/contrib/lite/toco/import_tensorflow.cc | |
parent | 9e78991b5c380b7fba0444685e5c6ef40e3c5b26 (diff) |
Add basic type propagation for unsupported ops in TFLite conversion
PiperOrigin-RevId: 212651704
Diffstat (limited to 'tensorflow/contrib/lite/toco/import_tensorflow.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/import_tensorflow.cc | 18 |
1 files changed, 18 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc index 9bc23c4b3c..eb36b3411d 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -58,6 +58,7 @@ using tensorflow::DT_STRING; using tensorflow::DT_UINT8; using tensorflow::GraphDef; using tensorflow::NodeDef; +using tensorflow::OpRegistry; using tensorflow::TensorProto; using tensorflow::TensorShapeProto; @@ -1079,6 +1080,23 @@ tensorflow::Status ConvertUnsupportedOperator( } else if (HasAttr(node, "Tout")) { const auto& output_type = GetDataTypeAttr(node, "Tout"); op->output_data_types.push_back(ConvertDataType(output_type)); + } else { + const tensorflow::OpDef* op_def = nullptr; + if (OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok()) { + for (const auto& output_arg : op_def->output_arg()) { + if (HasAttr(node, output_arg.type_attr())) { + op->output_data_types.push_back( + ConvertDataType(GetDataTypeAttr(node, output_arg.type_attr()))); + } else { + LOG(INFO) << "Op node missing output type attribute: " << node.name(); + } + } + } + if (op->output_data_types.empty()) { + // TODO(b/113613439): Figure out how to propagate types for custom ops + // that have no OpDef. + LOG(INFO) << "Unable to determine output type for op: " << node.op(); + } } if (HasAttr(node, kAttrOutputShapes)) { const auto& output_shapes = GetListAttr(node, kAttrOutputShapes); |