diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-05-02 07:51:53 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-02 07:54:26 -0700 |
commit | ba1c33faeb6df1ae363888e2e7330e219f0679ea (patch) | |
tree | 7d46859699b76ffbb6e0232e1eaec729c5729844 /tensorflow/contrib/lite/toco/tooling_util.cc | |
parent | 5e1448f691afe6e9ba57bb67497311c45b855b82 (diff) |
ArraysExtraInfo: Add name_regexp field and regexp name matching.
PiperOrigin-RevId: 195091587
Diffstat (limited to 'tensorflow/contrib/lite/toco/tooling_util.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/tooling_util.cc | 79 |
1 files changed, 50 insertions, 29 deletions
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc index f334c51bbb..36f38ba8b0 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.cc +++ b/tensorflow/contrib/lite/toco/tooling_util.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/strings/str_replace.h" #include "absl/strings/str_split.h" +#include "third_party/re2/re2.h" #include "tensorflow/contrib/lite/toco/dump_graphviz.h" #include "tensorflow/contrib/lite/toco/model_flags.pb.h" #include "tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h" @@ -1983,38 +1984,58 @@ void FinishBuildingRNNStates(Model* model) { } } +// Returns the array names that match the ArraysExtraInfo's name and +// name_regexp. The regexp match is for a full match. +std::unordered_set<string> ScanArrayNames( + const Model& model, const toco::ArraysExtraInfo_Entry& entry) { + std::unordered_set<string> matches; + if (model.HasArray(entry.name())) { + matches.insert(entry.name()); + } + if (!entry.name_regexp().empty()) { + const auto& arrays = model.GetArrayMap(); + const RE2 name_regexp = {entry.name_regexp()}; + for (auto it = arrays.begin(); it != arrays.end(); ++it) { + if (RE2::FullMatch(it->first, name_regexp)) { + matches.insert(it->first); + } + } + } + return matches; +} + void UseArraysExtraInfo(Model* model, bool quantize_output) { for (const auto& entry : model->flags.arrays_extra_info().entries()) { - if (!model->HasArray(entry.name())) { - continue; - } - auto& array = model->GetArray(entry.name()); - if (entry.has_min() || entry.has_max()) { - CHECK_EQ(entry.has_min(), entry.has_max()); - auto& minmax = array.GetOrCreateMinMax(); - minmax.min = entry.min(); - minmax.max = entry.max(); - } - if (entry.has_data_type() && quantize_output) { - array.final_data_type = - ConvertIODataTypeToArrayDataType(entry.data_type()); - } - if (entry.has_shape()) { - array.clear_shape(); - // Make sure to create the shape even if there are no dims, to - // correctly record 0-D shapes. - array.mutable_shape(); - for (int dim : entry.shape().dims()) { - array.mutable_shape()->mutable_dims()->push_back(dim); + const auto matches = ScanArrayNames(*model, entry); + for (const auto& matched_name : matches) { + auto& array = model->GetArray(matched_name); + if (entry.has_min() || entry.has_max()) { + CHECK_EQ(entry.has_min(), entry.has_max()); + auto& minmax = array.GetOrCreateMinMax(); + minmax.min = entry.min(); + minmax.max = entry.max(); } - } - if (entry.has_constant_float_value()) { - CHECK(array.has_shape()); - if (array.data_type == ArrayDataType::kFloat) { - auto& data = array.GetMutableBuffer<ArrayDataType::kFloat>().data; - data.resize(RequiredBufferSizeForShape(array.shape())); - for (float& f : data) { - f = entry.constant_float_value(); + if (entry.has_data_type() && quantize_output) { + array.final_data_type = + ConvertIODataTypeToArrayDataType(entry.data_type()); + } + if (entry.has_shape()) { + array.clear_shape(); + // Make sure to create the shape even if there are no dims, to + // correctly record 0-D shapes. + array.mutable_shape(); + for (int dim : entry.shape().dims()) { + array.mutable_shape()->mutable_dims()->push_back(dim); + } + } + if (entry.has_constant_float_value()) { + CHECK(array.has_shape()); + if (array.data_type == ArrayDataType::kFloat) { + auto& data = array.GetMutableBuffer<ArrayDataType::kFloat>().data; + data.resize(RequiredBufferSizeForShape(array.shape())); + for (float& f : data) { + f = entry.constant_float_value(); + } } } } |