diff options
author | 2018-07-25 10:12:24 -0700 | |
---|---|---|
committer | 2018-07-25 10:15:57 -0700 | |
commit | ec33cb09255dc88fb5fc3403cbfb9e0c48805eb3 (patch) | |
tree | 60f77cdf38433e38b306e0f01c829f1c5d4e54f2 /tensorflow | |
parent | 21f139075de212ccaab69bb89bb96d8b98282523 (diff) |
Support for shape attributes in custom ops for Toco
PiperOrigin-RevId: 206012140
Diffstat (limited to 'tensorflow')
5 files changed, 114 insertions, 47 deletions
diff --git a/tensorflow/contrib/lite/g3doc/custom_operators.md b/tensorflow/contrib/lite/g3doc/custom_operators.md index 2296f5a064..d979353bb3 100644 --- a/tensorflow/contrib/lite/g3doc/custom_operators.md +++ b/tensorflow/contrib/lite/g3doc/custom_operators.md @@ -136,3 +136,39 @@ operations instead of a single operator. 6. Use TF_LITE_ENSURE(context, condition) to check for a specific condition. Your code must not leave memory hanging when TF_LITE_ENSURE is done, i.e., these should be done before any resources are allocated that will leak. + +## Special TF Graph Attributes + +When Toco convertes a TF graph into TFLite format, it makes some assumption +about custom operations that might be not correct. In this case, the generated +graph can be not executable. + +It is possible to add aditional information about your custom op output to TF +graph before it is converted. The following attributes are supported: + +- **_output_quantized** a boolean attribute, true if the operation outputs are + quantized +- **_output_types** a list of types for output tensors +- **_output_shapes** a list of shapes for output tensors + +### Setting the Attributes + +This is an example how the attributes can be set: + +```python +frozen_graph_def = tf.graph_util.convert_variables_to_constants(...) +for node in frozen_graph_def.node: + if node.op == 'sin': + node.attr['_output_types'].list.type.extend([ + types_pb2.DT_FLOAT, + ]) + node.attr['_output_shapes'].list.shape.extend([ + tf.TensorShape([10]), + ]) + node.attr['_output_quantized'].b = False +tflite_model = tf.contrib.lite.toco_convert( + frozen_graph_def,...) +``` + +**Note:** After the attributes are set, the graph can not be executed by +Tensorflow, therefore it should be done just before the conversion. diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc index 9848d55c83..9c22497d5e 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc @@ -154,8 +154,8 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { return false; } for (int i = 0; i < op->outputs.size(); ++i) { - auto output = op->outputs[i]; - auto data_type = unsupported_op->output_data_types[i]; + const string& output = op->outputs[i]; + const ArrayDataType data_type = unsupported_op->output_data_types[i]; model->GetArray(output).data_type = data_type; } break; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc index 62ed5c46e9..a03b589bae 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -1786,8 +1786,19 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { ProcessArgMinMaxOperator<ArgMinOperator>( model, static_cast<ArgMinOperator*>(op)); break; - case OperatorType::kUnsupported: + case OperatorType::kUnsupported: { + const auto* unsupported_op = + static_cast<TensorFlowUnsupportedOperator*>(op); + // Attribute can be not specified, ignore it. + if (unsupported_op->output_shapes.size() < op->outputs.size()) { + return false; + } + for (int i = 0; i < op->outputs.size(); ++i) { + const string& output = op->outputs[i]; + model->GetArray(output).copy_shape(unsupported_op->output_shapes.at(i)); + } break; + } case OperatorType::kSvdf: ProcessSvdfOperator(model, static_cast<SvdfOperator*>(op)); break; diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc index 032c863945..f36f720857 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -1045,6 +1045,11 @@ tensorflow::Status ConvertSimpleOperator( tensorflow::Status ConvertUnsupportedOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { + // Names of special attributes in TF graph that are used by Toco. + static constexpr char kAttrOutputQuantized[] = "_output_quantized"; + static constexpr char kAttrOutputTypes[] = "_output_types"; + static constexpr char kAttrOutputShapes[] = "_output_shapes"; + LOG(INFO) << "Converting unsupported operation: " << node.op(); auto* op = new TensorFlowUnsupportedOperator; const int num_inputs = GetInputsCount(node, tf_import_flags); @@ -1055,11 +1060,11 @@ tensorflow::Status ConvertUnsupportedOperator( op->tensorflow_op = node.op(); node.SerializeToString(&op->tensorflow_node_def); model->operators.emplace_back(op); - if (HasAttr(node, "_output_quantized")) { - op->quantized = GetBoolAttr(node, "_output_quantized"); + if (HasAttr(node, kAttrOutputQuantized)) { + op->quantized = GetBoolAttr(node, kAttrOutputQuantized); } - if (HasAttr(node, "_output_types")) { - const auto& output_types = GetListAttr(node, "_output_types"); + if (HasAttr(node, kAttrOutputTypes)) { + const auto& output_types = GetListAttr(node, kAttrOutputTypes); for (int i = 0; i < output_types.type_size(); ++i) { op->output_data_types.push_back(ConvertDataType(output_types.type(i))); } @@ -1067,6 +1072,19 @@ tensorflow::Status ConvertUnsupportedOperator( const auto& output_type = GetDataTypeAttr(node, "Tout"); op->output_data_types.push_back(ConvertDataType(output_type)); } + if (HasAttr(node, kAttrOutputShapes)) { + const auto& output_shapes = GetListAttr(node, kAttrOutputShapes); + Shape output_shape; + for (int i = 0; i < output_shapes.shape_size(); ++i) { + const auto status = + ImportShape(output_shapes.shape(i).dim(), /*input_flat_size=*/nullptr, + &output_shape); + if (!status.ok()) { + return status; + } + op->output_shapes.push_back(output_shape); + } + } return tensorflow::Status::OK(); } diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h index d629787939..6459dccf64 100644 --- a/tensorflow/contrib/lite/toco/model.h +++ b/tensorflow/contrib/lite/toco/model.h @@ -292,6 +292,46 @@ struct Buffer : GenericBuffer { std::vector<DataType<A>> data; }; +class Shape { + public: + // For Shape, we stick to half-way encapsulation for now: + // we hide the raw dims_ member, but expose it raw by accessors + // because from some brainstorming, it's not at all easy to + // anticipate which flavor of more hermetic encapsulation would + // actually buy us future-proof-ness without being needlessly + // cumbersome. + Shape() {} + Shape(std::initializer_list<int> dim_list) : dims_(dim_list) {} + + void ReplaceDims(std::initializer_list<int> dim_list) { + dims_ = std::vector<int>(dim_list); + } + + const std::vector<int>& dims() const { return dims_; } + std::vector<int>* mutable_dims() { return &dims_; } + const int dimensions_count() const { return dims_.size(); } + + // We still have that one convenience accessor to avoid + // the awkward double bracket issue: shape.dims()[i]. + int dims(int i) const { + // Always check for out-of-bounds accesses, even in optimized builds where + // standard assertions are disabled. Out-of-bounds access here is a common + // occurrence. + CHECK_GE(i, 0); + CHECK_GT(dims_.size(), i); + return dims_[i]; + } + + bool operator==(const Shape& comp) const { + return (this->dims_ == comp.dims()); + } + + bool operator!=(const Shape& comp) const { return !((*this) == comp); } + + private: + std::vector<int> dims_; +}; + // Base class for all operator classes. struct Operator { // Non-default-constructible: only OperatorType-specific subclass @@ -1469,6 +1509,8 @@ struct TensorFlowUnsupportedOperator : Operator { bool quantized = false; // Output data types std::vector<ArrayDataType> output_data_types; + // Output shapes. + std::vector<Shape> output_shapes; }; // Softmax activation function. @@ -1739,46 +1781,6 @@ inline bool operator<(const Alloc& a, const Alloc& b) { return a.start < b.start; } -class Shape { - public: - // For Shape, we stick to half-way encapsulation for now: - // we hide the raw dims_ member, but expose it raw by accessors - // because from some brainstorming, it's not at all easy to - // anticipate which flavor of more hermetic encapsulation would - // actually buy us future-proof-ness without being needlessly - // cumbersome. - Shape() {} - Shape(std::initializer_list<int> dim_list) : dims_(dim_list) {} - - void ReplaceDims(std::initializer_list<int> dim_list) { - dims_ = std::vector<int>(dim_list); - } - - const std::vector<int>& dims() const { return dims_; } - std::vector<int>* mutable_dims() { return &dims_; } - const int dimensions_count() const { return dims_.size(); } - - // We still have that one convenience accessor to avoid - // the awkward double bracket issue: shape.dims()[i]. - int dims(int i) const { - // Always check for out-of-bounds accesses, even in optimized builds where - // standard assertions are disabled. Out-of-bounds access here is a common - // occurrence. - CHECK_GE(i, 0); - CHECK_GT(dims_.size(), i); - return dims_[i]; - } - - bool operator==(const Shape& comp) const { - return (this->dims_ == comp.dims()); - } - - bool operator!=(const Shape& comp) const { return !((*this) == comp); } - - private: - std::vector<int> dims_; -}; - // Array represents an array (either a constant parameter array or an // activations array) in a Model. struct Array { |