author     A. Unique TensorFlower <gardener@tensorflow.org>  2018-07-25 10:12:24 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2018-07-25 10:15:57 -0700
commit     ec33cb09255dc88fb5fc3403cbfb9e0c48805eb3
tree       60f77cdf38433e38b306e0f01c829f1c5d4e54f2 /tensorflow
parent     21f139075de212ccaab69bb89bb96d8b98282523
Support for shape attributes in custom ops for Toco
PiperOrigin-RevId: 206012140
Diffstat (limited to 'tensorflow')
-rw-r--r--  tensorflow/contrib/lite/g3doc/custom_operators.md                                 | 36
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc  |  4
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc       | 13
-rw-r--r--  tensorflow/contrib/lite/toco/import_tensorflow.cc                                 | 26
-rw-r--r--  tensorflow/contrib/lite/toco/model.h                                              | 82
5 files changed, 114 insertions, 47 deletions
diff --git a/tensorflow/contrib/lite/g3doc/custom_operators.md b/tensorflow/contrib/lite/g3doc/custom_operators.md
index 2296f5a064..d979353bb3 100644
--- a/tensorflow/contrib/lite/g3doc/custom_operators.md
+++ b/tensorflow/contrib/lite/g3doc/custom_operators.md
@@ -136,3 +136,39 @@ operations instead of a single operator.
6. Use TF_LITE_ENSURE(context, condition) to check for a specific condition.
   Your code must not leave memory hanging when TF_LITE_ENSURE is done, i.e.,
   these should be done before any resources are allocated that will leak.
+
+## Special TF Graph Attributes
+
+When Toco converts a TF graph into the TFLite format, it makes some assumptions
+about custom operations that might not be correct. In this case, the generated
+graph may not be executable.
+
+It is possible to add additional information about your custom op's outputs to
+the TF graph before it is converted. The following attributes are supported:
+
+- **_output_quantized**: a boolean attribute, true if the operation's outputs
+  are quantized
+- **_output_types**: a list of types for the output tensors
+- **_output_shapes**: a list of shapes for the output tensors
+
+### Setting the Attributes
+
+This is an example of how the attributes can be set:
+
+```python
+frozen_graph_def = tf.graph_util.convert_variables_to_constants(...)
+for node in frozen_graph_def.node:
+  if node.op == 'sin':
+    node.attr['_output_types'].list.type.extend([
+        types_pb2.DT_FLOAT,
+    ])
+    node.attr['_output_shapes'].list.shape.extend([
+        tf.TensorShape([10]).as_proto(),
+    ])
+    node.attr['_output_quantized'].b = False
+tflite_model = tf.contrib.lite.toco_convert(
+    frozen_graph_def,...)
+```
+
+**Note:** After the attributes are set, the graph cannot be executed by
+TensorFlow, so they should be set just before the conversion.
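
A note on the example above: it assumes `types_pb2` has been imported from
`tensorflow.core.framework` (i.e. `from tensorflow.core.framework import
types_pb2`), and each shape is appended as a serialized `TensorShapeProto` via
`tf.TensorShape(...).as_proto()`, since a protobuf repeated message field only
accepts proto messages.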
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
index 9848d55c83..9c22497d5e 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
@@ -154,8 +154,8 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
        return false;
      }
      for (int i = 0; i < op->outputs.size(); ++i) {
-        auto output = op->outputs[i];
-        auto data_type = unsupported_op->output_data_types[i];
+        const string& output = op->outputs[i];
+        const ArrayDataType data_type = unsupported_op->output_data_types[i];
        model->GetArray(output).data_type = data_type;
      }
      break;
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index 62ed5c46e9..a03b589bae 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -1786,8 +1786,19 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
      ProcessArgMinMaxOperator<ArgMinOperator>(
          model, static_cast<ArgMinOperator*>(op));
      break;
-    case OperatorType::kUnsupported:
+    case OperatorType::kUnsupported: {
+      const auto* unsupported_op =
+          static_cast<TensorFlowUnsupportedOperator*>(op);
+      // The attribute may be left unspecified; then leave shapes unresolved.
+      if (unsupported_op->output_shapes.size() < op->outputs.size()) {
+        return false;
+      }
+      for (int i = 0; i < op->outputs.size(); ++i) {
+        const string& output = op->outputs[i];
+        model->GetArray(output).copy_shape(unsupported_op->output_shapes.at(i));
+      }
      break;
+    }
    case OperatorType::kSvdf:
      ProcessSvdfOperator(model, static_cast<SvdfOperator*>(op));
      break;
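
To see the new kUnsupported case in isolation, here is a minimal, self-contained
sketch of the same copy-the-declared-shapes step. `Shape`, `Array`, and
`PropagateUnsupportedOpShapes` below are simplified stand-ins invented for
illustration, not the real Toco types:

```c++
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Simplified stand-ins for toco's Shape and Array (the real ones live in
// model.h); hypothetical, for illustration only.
struct Shape {
  std::vector<int> dims;
};

struct Array {
  Shape shape;
  void copy_shape(const Shape& s) { shape = s; }
};

// Mirrors the kUnsupported case above: if the user declared at least one
// shape per output via _output_shapes, copy each shape onto the matching
// output array; otherwise report that nothing could be propagated.
bool PropagateUnsupportedOpShapes(const std::vector<std::string>& outputs,
                                  const std::vector<Shape>& output_shapes,
                                  std::map<std::string, Array>* arrays) {
  if (output_shapes.size() < outputs.size()) {
    return false;  // attribute missing or incomplete: leave shapes unresolved
  }
  for (size_t i = 0; i < outputs.size(); ++i) {
    (*arrays)[outputs[i]].copy_shape(output_shapes.at(i));
  }
  return true;
}

int main() {
  std::map<std::string, Array> arrays;
  const std::vector<std::string> outputs = {"sin_out"};
  const std::vector<Shape> declared = {{{10}}};
  assert(PropagateUnsupportedOpShapes(outputs, declared, &arrays));
  assert(arrays["sin_out"].shape.dims.size() == 1);
  assert(arrays["sin_out"].shape.dims[0] == 10);
  return 0;
}
```

Returning false when shapes are missing matches the pass above: Toco leaves the
output arrays unresolved rather than guessing.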
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index 032c863945..f36f720857 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -1045,6 +1045,11 @@ tensorflow::Status ConvertSimpleOperator(
tensorflow::Status ConvertUnsupportedOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    Model* model) {
+  // Names of special attributes in TF graph that are used by Toco.
+  static constexpr char kAttrOutputQuantized[] = "_output_quantized";
+  static constexpr char kAttrOutputTypes[] = "_output_types";
+  static constexpr char kAttrOutputShapes[] = "_output_shapes";
+
  LOG(INFO) << "Converting unsupported operation: " << node.op();
  auto* op = new TensorFlowUnsupportedOperator;
  const int num_inputs = GetInputsCount(node, tf_import_flags);
@@ -1055,11 +1060,11 @@ tensorflow::Status ConvertUnsupportedOperator(
  op->tensorflow_op = node.op();
  node.SerializeToString(&op->tensorflow_node_def);
  model->operators.emplace_back(op);
-  if (HasAttr(node, "_output_quantized")) {
-    op->quantized = GetBoolAttr(node, "_output_quantized");
+  if (HasAttr(node, kAttrOutputQuantized)) {
+    op->quantized = GetBoolAttr(node, kAttrOutputQuantized);
  }
-  if (HasAttr(node, "_output_types")) {
-    const auto& output_types = GetListAttr(node, "_output_types");
+  if (HasAttr(node, kAttrOutputTypes)) {
+    const auto& output_types = GetListAttr(node, kAttrOutputTypes);
    for (int i = 0; i < output_types.type_size(); ++i) {
      op->output_data_types.push_back(ConvertDataType(output_types.type(i)));
    }
@@ -1067,6 +1072,19 @@ tensorflow::Status ConvertUnsupportedOperator(
    const auto& output_type = GetDataTypeAttr(node, "Tout");
    op->output_data_types.push_back(ConvertDataType(output_type));
  }
+  if (HasAttr(node, kAttrOutputShapes)) {
+    const auto& output_shapes = GetListAttr(node, kAttrOutputShapes);
+    Shape output_shape;
+    for (int i = 0; i < output_shapes.shape_size(); ++i) {
+      const auto status =
+          ImportShape(output_shapes.shape(i).dim(), /*input_flat_size=*/nullptr,
+                      &output_shape);
+      if (!status.ok()) {
+        return status;
+      }
+      op->output_shapes.push_back(output_shape);
+    }
+  }
  return tensorflow::Status::OK();
}
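
For context on what happens to each entry of the _output_shapes list: the hunk
above hands every TensorShapeProto to ImportShape. Below is a rough,
self-contained sketch of that per-dimension conversion, under the assumption
that unknown (negative) dimensions cannot yield a fixed size; the real
ImportShape returns a tensorflow::Status and takes the proto dim list directly,
so `ImportShapeSketch` here is a hypothetical simplification:

```c++
#include <cstdint>
#include <iostream>
#include <vector>

// Simplified stand-in for toco's Shape; for illustration only.
struct Shape {
  std::vector<int> dims;
};

// Sketch of the assumed dim-by-dim conversion: each TensorShapeProto dim
// size becomes one entry of the toco Shape; unknown (negative) dims are
// rejected, since no fixed size can be propagated from them.
bool ImportShapeSketch(const std::vector<int64_t>& proto_dim_sizes,
                       Shape* output_shape) {
  output_shape->dims.clear();
  for (const int64_t size : proto_dim_sizes) {
    if (size < 0) {
      return false;  // unknown dimension
    }
    output_shape->dims.push_back(static_cast<int>(size));
  }
  return true;
}

int main() {
  Shape shape;
  if (ImportShapeSketch({1, 224, 224, 3}, &shape)) {
    for (const int d : shape.dims) {
      std::cout << d << ' ';  // prints: 1 224 224 3
    }
    std::cout << '\n';
  }
  return 0;
}
```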
diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h
index d629787939..6459dccf64 100644
--- a/tensorflow/contrib/lite/toco/model.h
+++ b/tensorflow/contrib/lite/toco/model.h
@@ -292,6 +292,46 @@ struct Buffer : GenericBuffer {
  std::vector<DataType<A>> data;
};
+class Shape {
+ public:
+  // For Shape, we stick to half-way encapsulation for now:
+  // we hide the raw dims_ member, but expose it raw by accessors
+  // because from some brainstorming, it's not at all easy to
+  // anticipate which flavor of more hermetic encapsulation would
+  // actually buy us future-proof-ness without being needlessly
+  // cumbersome.
+  Shape() {}
+  Shape(std::initializer_list<int> dim_list) : dims_(dim_list) {}
+
+  void ReplaceDims(std::initializer_list<int> dim_list) {
+    dims_ = std::vector<int>(dim_list);
+  }
+
+  const std::vector<int>& dims() const { return dims_; }
+  std::vector<int>* mutable_dims() { return &dims_; }
+  const int dimensions_count() const { return dims_.size(); }
+
+  // We still have that one convenience accessor to avoid
+  // the awkward double bracket issue: shape.dims()[i].
+  int dims(int i) const {
+    // Always check for out-of-bounds accesses, even in optimized builds where
+    // standard assertions are disabled. Out-of-bounds access here is a common
+    // occurrence.
+    CHECK_GE(i, 0);
+    CHECK_GT(dims_.size(), i);
+    return dims_[i];
+  }
+
+  bool operator==(const Shape& comp) const {
+    return (this->dims_ == comp.dims());
+  }
+
+  bool operator!=(const Shape& comp) const { return !((*this) == comp); }
+
+ private:
+  std::vector<int> dims_;
+};
+
// Base class for all operator classes.
struct Operator {
// Non-default-constructible: only OperatorType-specific subclass
@@ -1469,6 +1509,8 @@ struct TensorFlowUnsupportedOperator : Operator {
  bool quantized = false;
  // Output data types
  std::vector<ArrayDataType> output_data_types;
+  // Output shapes.
+  std::vector<Shape> output_shapes;
};

// Softmax activation function.
@@ -1739,46 +1781,6 @@ inline bool operator<(const Alloc& a, const Alloc& b) {
  return a.start < b.start;
}
-class Shape {
- public:
-  // For Shape, we stick to half-way encapsulation for now:
-  // we hide the raw dims_ member, but expose it raw by accessors
-  // because from some brainstorming, it's not at all easy to
-  // anticipate which flavor of more hermetic encapsulation would
-  // actually buy us future-proof-ness without being needlessly
-  // cumbersome.
-  Shape() {}
-  Shape(std::initializer_list<int> dim_list) : dims_(dim_list) {}
-
-  void ReplaceDims(std::initializer_list<int> dim_list) {
-    dims_ = std::vector<int>(dim_list);
-  }
-
-  const std::vector<int>& dims() const { return dims_; }
-  std::vector<int>* mutable_dims() { return &dims_; }
-  const int dimensions_count() const { return dims_.size(); }
-
-  // We still have that one convenience accessor to avoid
-  // the awkward double bracket issue: shape.dims()[i].
-  int dims(int i) const {
-    // Always check for out-of-bounds accesses, even in optimized builds where
-    // standard assertions are disabled. Out-of-bounds access here is a common
-    // occurrence.
-    CHECK_GE(i, 0);
-    CHECK_GT(dims_.size(), i);
-    return dims_[i];
-  }
-
-  bool operator==(const Shape& comp) const {
-    return (this->dims_ == comp.dims());
-  }
-
-  bool operator!=(const Shape& comp) const { return !((*this) == comp); }
-
- private:
-  std::vector<int> dims_;
-};
-
// Array represents an array (either a constant parameter array or an
// activations array) in a Model.
struct Array {
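
Finally, since the diff moves the Shape class earlier in model.h so that
TensorFlowUnsupportedOperator can hold a std::vector<Shape>, here is a short
usage sketch of its public API, assuming the toco headers and the CHECK macros
they use are available; `ShapeApiExample` is just an illustrative name:

```c++
#include "tensorflow/contrib/lite/toco/model.h"

namespace toco {

// Exercises the Shape API shown in the diff: initializer-list construction,
// ReplaceDims, the bounds-checked dims(i) accessor, and (in)equality.
void ShapeApiExample() {
  Shape a({1, 224, 224, 3});
  Shape b;
  b.ReplaceDims({1, 224, 224, 3});
  CHECK(a == b);                     // equal dims compare equal
  CHECK_EQ(a.dimensions_count(), 4);
  CHECK_EQ(a.dims(3), 3);            // CHECK-guarded element access
  a.mutable_dims()->push_back(1);    // raw access to the dims vector
  CHECK(a != b);
}

}  // namespace toco
```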