aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/import_tensorflow.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-25 10:12:24 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-25 10:15:57 -0700
commitec33cb09255dc88fb5fc3403cbfb9e0c48805eb3 (patch)
tree60f77cdf38433e38b306e0f01c829f1c5d4e54f2 /tensorflow/contrib/lite/toco/import_tensorflow.cc
parent21f139075de212ccaab69bb89bb96d8b98282523 (diff)
Support for shape attributes in custom ops for Toco
PiperOrigin-RevId: 206012140
Diffstat (limited to 'tensorflow/contrib/lite/toco/import_tensorflow.cc')
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow.cc26
1 files changed, 22 insertions, 4 deletions
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index 032c863945..f36f720857 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -1045,6 +1045,11 @@ tensorflow::Status ConvertSimpleOperator(
tensorflow::Status ConvertUnsupportedOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
+ // Names of special attributes in TF graph that are used by Toco.
+ static constexpr char kAttrOutputQuantized[] = "_output_quantized";
+ static constexpr char kAttrOutputTypes[] = "_output_types";
+ static constexpr char kAttrOutputShapes[] = "_output_shapes";
+
LOG(INFO) << "Converting unsupported operation: " << node.op();
auto* op = new TensorFlowUnsupportedOperator;
const int num_inputs = GetInputsCount(node, tf_import_flags);
@@ -1055,11 +1060,11 @@ tensorflow::Status ConvertUnsupportedOperator(
op->tensorflow_op = node.op();
node.SerializeToString(&op->tensorflow_node_def);
model->operators.emplace_back(op);
- if (HasAttr(node, "_output_quantized")) {
- op->quantized = GetBoolAttr(node, "_output_quantized");
+ if (HasAttr(node, kAttrOutputQuantized)) {
+ op->quantized = GetBoolAttr(node, kAttrOutputQuantized);
}
- if (HasAttr(node, "_output_types")) {
- const auto& output_types = GetListAttr(node, "_output_types");
+ if (HasAttr(node, kAttrOutputTypes)) {
+ const auto& output_types = GetListAttr(node, kAttrOutputTypes);
for (int i = 0; i < output_types.type_size(); ++i) {
op->output_data_types.push_back(ConvertDataType(output_types.type(i)));
}
@@ -1067,6 +1072,19 @@ tensorflow::Status ConvertUnsupportedOperator(
const auto& output_type = GetDataTypeAttr(node, "Tout");
op->output_data_types.push_back(ConvertDataType(output_type));
}
+ if (HasAttr(node, kAttrOutputShapes)) {
+ const auto& output_shapes = GetListAttr(node, kAttrOutputShapes);
+ Shape output_shape;
+ for (int i = 0; i < output_shapes.shape_size(); ++i) {
+ const auto status =
+ ImportShape(output_shapes.shape(i).dim(), /*input_flat_size=*/nullptr,
+ &output_shape);
+ if (!status.ok()) {
+ return status;
+ }
+ op->output_shapes.push_back(output_shape);
+ }
+ }
return tensorflow::Status::OK();
}