aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/tooling_util.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-05-02 18:11:25 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-02 18:13:39 -0700
commitdb329cfe2dee382033ad3b3f5e1d906ff489a24d (patch)
tree9a2ac77e9d32a1a793c457f51c5b3e363b2c651a /tensorflow/contrib/lite/toco/tooling_util.cc
parent8f0a90b711480c12716d1a3b1094cc8b34939f2d (diff)
Automated g4 rollback of changelist 195091587
PiperOrigin-RevId: 195184798
Diffstat (limited to 'tensorflow/contrib/lite/toco/tooling_util.cc')
-rw-r--r--tensorflow/contrib/lite/toco/tooling_util.cc79
1 files changed, 50 insertions, 29 deletions
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc
index f334c51bbb..11293a5fe5 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "absl/strings/str_join.h"
#include "absl/strings/str_replace.h"
#include "absl/strings/str_split.h"
+#include "re2/re2.h"
#include "tensorflow/contrib/lite/toco/dump_graphviz.h"
#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
#include "tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h"
@@ -1983,38 +1984,58 @@ void FinishBuildingRNNStates(Model* model) {
}
}
+// Returns the array names that match the ArraysExtraInfo's name and
+// name_regexp. The regexp match is for a full match.
+std::unordered_set<string> ScanArrayNames(
+ const Model& model, const toco::ArraysExtraInfo_Entry& entry) {
+ std::unordered_set<string> matches;
+ if (model.HasArray(entry.name())) {
+ matches.insert(entry.name());
+ }
+ if (!entry.name_regexp().empty()) {
+ const auto& arrays = model.GetArrayMap();
+ const RE2 name_regexp = {entry.name_regexp()};
+ for (auto it = arrays.begin(); it != arrays.end(); ++it) {
+ if (RE2::FullMatch(it->first, name_regexp)) {
+ matches.insert(it->first);
+ }
+ }
+ }
+ return matches;
+}
+
void UseArraysExtraInfo(Model* model, bool quantize_output) {
for (const auto& entry : model->flags.arrays_extra_info().entries()) {
- if (!model->HasArray(entry.name())) {
- continue;
- }
- auto& array = model->GetArray(entry.name());
- if (entry.has_min() || entry.has_max()) {
- CHECK_EQ(entry.has_min(), entry.has_max());
- auto& minmax = array.GetOrCreateMinMax();
- minmax.min = entry.min();
- minmax.max = entry.max();
- }
- if (entry.has_data_type() && quantize_output) {
- array.final_data_type =
- ConvertIODataTypeToArrayDataType(entry.data_type());
- }
- if (entry.has_shape()) {
- array.clear_shape();
- // Make sure to create the shape even if there are no dims, to
- // correctly record 0-D shapes.
- array.mutable_shape();
- for (int dim : entry.shape().dims()) {
- array.mutable_shape()->mutable_dims()->push_back(dim);
+ const auto matches = ScanArrayNames(*model, entry);
+ for (const auto& matched_name : matches) {
+ auto& array = model->GetArray(matched_name);
+ if (entry.has_min() || entry.has_max()) {
+ CHECK_EQ(entry.has_min(), entry.has_max());
+ auto& minmax = array.GetOrCreateMinMax();
+ minmax.min = entry.min();
+ minmax.max = entry.max();
}
- }
- if (entry.has_constant_float_value()) {
- CHECK(array.has_shape());
- if (array.data_type == ArrayDataType::kFloat) {
- auto& data = array.GetMutableBuffer<ArrayDataType::kFloat>().data;
- data.resize(RequiredBufferSizeForShape(array.shape()));
- for (float& f : data) {
- f = entry.constant_float_value();
+ if (entry.has_data_type() && quantize_output) {
+ array.final_data_type =
+ ConvertIODataTypeToArrayDataType(entry.data_type());
+ }
+ if (entry.has_shape()) {
+ array.clear_shape();
+ // Make sure to create the shape even if there are no dims, to
+ // correctly record 0-D shapes.
+ array.mutable_shape();
+ for (int dim : entry.shape().dims()) {
+ array.mutable_shape()->mutable_dims()->push_back(dim);
+ }
+ }
+ if (entry.has_constant_float_value()) {
+ CHECK(array.has_shape());
+ if (array.data_type == ArrayDataType::kFloat) {
+ auto& data = array.GetMutableBuffer<ArrayDataType::kFloat>().data;
+ data.resize(RequiredBufferSizeForShape(array.shape()));
+ for (float& f : data) {
+ f = entry.constant_float_value();
+ }
}
}
}