author | Jared Duke <jdduke@google.com> | 2018-09-14 11:42:02 -0700
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-14 11:45:56 -0700
commit | 39f50af5634b8a4d2132b57bad2152308a0fd41c (patch)
tree | 5a5d0b0a9722067b702995dc84a1c4d8156d36a4 /tensorflow/contrib/lite/toco/import_tensorflow.cc
parent | c20a7b81d79d30db9e990309ddb419bcb48120cc (diff)
Improve output parsing for unsupported ops
PiperOrigin-RevId: 213017532
Diffstat (limited to 'tensorflow/contrib/lite/toco/import_tensorflow.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/import_tensorflow.cc | 82
1 file changed, 52 insertions(+), 30 deletions(-)
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index efc1007925..2ccfd36b7c 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -69,6 +69,13 @@ bool HasAttr(const NodeDef& node, const string& attr_name) {
   return node.attr().count(attr_name) > 0;
 }
 
+bool HasWildcardDimension(const TensorShapeProto& shape) {
+  for (const auto& dim : shape.dim()) {
+    if (dim.size() == -1) return true;
+  }
+  return false;
+}
+
 const string& GetStringAttr(const NodeDef& node, const string& attr_name) {
   CHECK(HasAttr(node, attr_name));
   const auto& attr = node.attr().at(attr_name);
@@ -1054,15 +1061,27 @@ tensorflow::Status ConvertUnsupportedOperator(
       "_support_output_type_float_in_quantized_op";
 
   LOG(INFO) << "Converting unsupported operation: " << node.op();
+
   auto* op = new TensorFlowUnsupportedOperator;
+  op->tensorflow_op = node.op();
+  node.SerializeToString(&op->tensorflow_node_def);
+  model->operators.emplace_back(op);
+
+  // Parse inputs.
   const int num_inputs = GetInputsCount(node, tf_import_flags);
   for (int i = 0; i < num_inputs; ++i) {
     op->inputs.push_back(node.input(i));
   }
-  op->outputs.push_back(node.name());
-  op->tensorflow_op = node.op();
-  node.SerializeToString(&op->tensorflow_node_def);
-  model->operators.emplace_back(op);
+
+  // Parse outputs.
+  op->outputs.push_back(node.name());  // Implicit :0.
+  const tensorflow::OpDef* op_def = nullptr;
+  if (tensorflow::OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok()) {
+    for (int i = 1; i < op_def->output_arg_size(); ++i) {
+      op->outputs.push_back(absl::StrCat(node.name(), ":", i));
+    }
+  }
+  // Parse if the op supports quantization
   if (HasAttr(node, kAttrOutputQuantized)) {
     op->quantized = GetBoolAttr(node, kAttrOutputQuantized);
   }
@@ -1072,6 +1091,8 @@ tensorflow::Status ConvertUnsupportedOperator(
     op->support_output_type_float_in_quantized_op =
         GetBoolAttr(node, kAttrSupportOutputTypeFloatInQuantizedOp);
   }
+
+  // Parse output type(s).
   if (HasAttr(node, kAttrOutputTypes)) {
     const auto& output_types = GetListAttr(node, kAttrOutputTypes);
     for (int i = 0; i < output_types.type_size(); ++i) {
@@ -1080,33 +1101,40 @@ tensorflow::Status ConvertUnsupportedOperator(
   } else if (HasAttr(node, "Tout")) {
     const auto& output_type = GetDataTypeAttr(node, "Tout");
     op->output_data_types.push_back(ConvertDataType(output_type));
-  } else {
-    const tensorflow::OpDef* op_def = nullptr;
-    if (OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok()) {
-      for (const auto& output_arg : op_def->output_arg()) {
-        if (HasAttr(node, output_arg.type_attr())) {
-          op->output_data_types.push_back(
-              ConvertDataType(GetDataTypeAttr(node, output_arg.type_attr())));
-        } else {
-          LOG(INFO) << "Op node missing output type attribute: " << node.name();
-          op->output_data_types.clear();
-          break;
-        }
+  } else if (op_def != nullptr) {
+    for (const auto& output_arg : op_def->output_arg()) {
+      if (HasAttr(node, output_arg.type_attr())) {
+        op->output_data_types.push_back(
+            ConvertDataType(GetDataTypeAttr(node, output_arg.type_attr())));
+      } else {
+        LOG(INFO) << "Op node missing output type attribute: " << node.name();
+        op->output_data_types.clear();
+        break;
       }
     }
-    if (op->output_data_types.empty()) {
-      // TODO(b/113613439): Figure out how to propagate types for custom ops
-      // that have no OpDef.
-      LOG(INFO) << "Unable to determine output type for op: " << node.op();
-    }
+  } else {
+    // TODO(b/113613439): Figure out how to propagate types for custom ops
+    // that have no OpDef.
+    LOG(INFO) << "Unable to determine output type for op: " << node.op();
   }
+
+  // Parse output shape(s).
   if (HasAttr(node, kAttrOutputShapes)) {
     const auto& output_shapes = GetListAttr(node, kAttrOutputShapes);
     Shape output_shape;
     for (int i = 0; i < output_shapes.shape_size(); ++i) {
+      const auto& shape = output_shapes.shape(i);
+      // TOCO doesn't yet properly handle shapes with wildcard dimensions.
+      // TODO(b/113613439): Handle shape inference for unsupported ops that have
+      // shapes with wildcard dimensions.
+      if (HasWildcardDimension(shape)) {
+        LOG(INFO) << "Skipping wildcard output shape(s) for node: "
+                  << node.name();
+        op->output_shapes.clear();
+        break;
+      }
       const auto status =
-          ImportShape(output_shapes.shape(i).dim(), /*input_flat_size=*/nullptr,
-                      &output_shape);
+          ImportShape(shape.dim(), /*input_flat_size=*/nullptr, &output_shape);
       if (!status.ok()) {
         return status;
       }
@@ -1159,15 +1187,9 @@ tensorflow::Status ConvertPlaceholderOperator(
   if (node.attr().count("shape")) {
     const auto& shape = GetShapeAttr(node, "shape");
     auto num_dims = shape.dim_size();
-    bool has_wildcard = false;
-    for (std::size_t i = 0; i < num_dims; i++) {
-      if (shape.dim(i).size() == -1) {
-        has_wildcard = true;
-      }
-    }
     // TODO(b/62716978): This logic needs to be revisted. During dims
     // refactoring it is an interim fix.
-    if (num_dims > 0 && !has_wildcard) {
+    if (num_dims > 0 && !HasWildcardDimension(shape)) {
       auto& dst_array_dims = *array.mutable_shape()->mutable_dims();
       dst_array_dims.resize(num_dims);
       for (std::size_t i = 0; i < num_dims; i++) {