aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/tooling_util.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-05-02 07:51:53 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-02 07:54:26 -0700
commitba1c33faeb6df1ae363888e2e7330e219f0679ea (patch)
tree7d46859699b76ffbb6e0232e1eaec729c5729844 /tensorflow/contrib/lite/toco/tooling_util.cc
parent5e1448f691afe6e9ba57bb67497311c45b855b82 (diff)
ArraysExtraInfo: Add name_regexp field and regexp name matching.
PiperOrigin-RevId: 195091587
Diffstat (limited to 'tensorflow/contrib/lite/toco/tooling_util.cc')
-rw-r--r--tensorflow/contrib/lite/toco/tooling_util.cc79
1 files changed, 50 insertions, 29 deletions
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc
index f334c51bbb..36f38ba8b0 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "absl/strings/str_join.h"
#include "absl/strings/str_replace.h"
#include "absl/strings/str_split.h"
+#include "third_party/re2/re2.h"
#include "tensorflow/contrib/lite/toco/dump_graphviz.h"
#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
#include "tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h"
@@ -1983,38 +1984,58 @@ void FinishBuildingRNNStates(Model* model) {
}
}
+// Returns the array names that match the ArraysExtraInfo's name and
+// name_regexp. The regexp match is for a full match.
+std::unordered_set<string> ScanArrayNames(
+ const Model& model, const toco::ArraysExtraInfo_Entry& entry) {
+ std::unordered_set<string> matches;
+ if (model.HasArray(entry.name())) {
+ matches.insert(entry.name());
+ }
+ if (!entry.name_regexp().empty()) {
+ const auto& arrays = model.GetArrayMap();
+ const RE2 name_regexp = {entry.name_regexp()};
+ for (auto it = arrays.begin(); it != arrays.end(); ++it) {
+ if (RE2::FullMatch(it->first, name_regexp)) {
+ matches.insert(it->first);
+ }
+ }
+ }
+ return matches;
+}
+
void UseArraysExtraInfo(Model* model, bool quantize_output) {
for (const auto& entry : model->flags.arrays_extra_info().entries()) {
- if (!model->HasArray(entry.name())) {
- continue;
- }
- auto& array = model->GetArray(entry.name());
- if (entry.has_min() || entry.has_max()) {
- CHECK_EQ(entry.has_min(), entry.has_max());
- auto& minmax = array.GetOrCreateMinMax();
- minmax.min = entry.min();
- minmax.max = entry.max();
- }
- if (entry.has_data_type() && quantize_output) {
- array.final_data_type =
- ConvertIODataTypeToArrayDataType(entry.data_type());
- }
- if (entry.has_shape()) {
- array.clear_shape();
- // Make sure to create the shape even if there are no dims, to
- // correctly record 0-D shapes.
- array.mutable_shape();
- for (int dim : entry.shape().dims()) {
- array.mutable_shape()->mutable_dims()->push_back(dim);
+ const auto matches = ScanArrayNames(*model, entry);
+ for (const auto& matched_name : matches) {
+ auto& array = model->GetArray(matched_name);
+ if (entry.has_min() || entry.has_max()) {
+ CHECK_EQ(entry.has_min(), entry.has_max());
+ auto& minmax = array.GetOrCreateMinMax();
+ minmax.min = entry.min();
+ minmax.max = entry.max();
}
- }
- if (entry.has_constant_float_value()) {
- CHECK(array.has_shape());
- if (array.data_type == ArrayDataType::kFloat) {
- auto& data = array.GetMutableBuffer<ArrayDataType::kFloat>().data;
- data.resize(RequiredBufferSizeForShape(array.shape()));
- for (float& f : data) {
- f = entry.constant_float_value();
+ if (entry.has_data_type() && quantize_output) {
+ array.final_data_type =
+ ConvertIODataTypeToArrayDataType(entry.data_type());
+ }
+ if (entry.has_shape()) {
+ array.clear_shape();
+ // Make sure to create the shape even if there are no dims, to
+ // correctly record 0-D shapes.
+ array.mutable_shape();
+ for (int dim : entry.shape().dims()) {
+ array.mutable_shape()->mutable_dims()->push_back(dim);
+ }
+ }
+ if (entry.has_constant_float_value()) {
+ CHECK(array.has_shape());
+ if (array.data_type == ArrayDataType::kFloat) {
+ auto& data = array.GetMutableBuffer<ArrayDataType::kFloat>().data;
+ data.resize(RequiredBufferSizeForShape(array.shape()));
+ for (float& f : data) {
+ f = entry.constant_float_value();
+ }
}
}
}