aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/import_tensorflow.cc
diff options
context:
space:
mode:
authorGravatar Jared Duke <jdduke@google.com>2018-09-13 16:58:34 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-13 17:02:49 -0700
commit4b42a284683416ab6159f32c903321af9dc9a591 (patch)
tree9ed2dab1ec07a6713538bda1a6a47759d3055521 /tensorflow/contrib/lite/toco/import_tensorflow.cc
parent4137d84a3b41638d4048e45ab579662c18a06df5 (diff)
Reland "Add basic type propagation for unsupported ops in TFLite conversion"
The original CL was rolled back due to op registration conflicts in the pip. Resolve the issue by only including core:ops in the toco binary itself, not in intermediate libraries. PiperOrigin-RevId: 212902838
Diffstat (limited to 'tensorflow/contrib/lite/toco/import_tensorflow.cc')
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow.cc20
1 files changed, 20 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index 9bc23c4b3c..efc1007925 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -58,6 +58,7 @@ using tensorflow::DT_STRING;
using tensorflow::DT_UINT8;
using tensorflow::GraphDef;
using tensorflow::NodeDef;
+using tensorflow::OpRegistry;
using tensorflow::TensorProto;
using tensorflow::TensorShapeProto;
@@ -1079,6 +1080,25 @@ tensorflow::Status ConvertUnsupportedOperator(
} else if (HasAttr(node, "Tout")) {
const auto& output_type = GetDataTypeAttr(node, "Tout");
op->output_data_types.push_back(ConvertDataType(output_type));
+ } else {
+ const tensorflow::OpDef* op_def = nullptr;
+ if (OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok()) {
+ for (const auto& output_arg : op_def->output_arg()) {
+ if (HasAttr(node, output_arg.type_attr())) {
+ op->output_data_types.push_back(
+ ConvertDataType(GetDataTypeAttr(node, output_arg.type_attr())));
+ } else {
+ LOG(INFO) << "Op node missing output type attribute: " << node.name();
+ op->output_data_types.clear();
+ break;
+ }
+ }
+ }
+ if (op->output_data_types.empty()) {
+ // TODO(b/113613439): Figure out how to propagate types for custom ops
+ // that have no OpDef.
+ LOG(INFO) << "Unable to determine output type for op: " << node.op();
+ }
}
if (HasAttr(node, kAttrOutputShapes)) {
const auto& output_shapes = GetListAttr(node, kAttrOutputShapes);