aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/tooling_util.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-05-02 08:51:07 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-02 08:53:57 -0700
commit22eed5405906de6df8846bd9ce4ee0a57917aa3c (patch)
treeeebdb88caccc1202e67513193090d0f160cf1854 /tensorflow/contrib/lite/toco/tooling_util.cc
parent72fd2b8e97f301039ac0eb60435cbbddf36212f6 (diff)
Automated g4 rollback of changelist 195091587
PiperOrigin-RevId: 195098224
Diffstat (limited to 'tensorflow/contrib/lite/toco/tooling_util.cc')
-rw-r--r--tensorflow/contrib/lite/toco/tooling_util.cc79
1 files changed, 29 insertions, 50 deletions
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc
index 36f38ba8b0..f334c51bbb 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util.cc
@@ -26,7 +26,6 @@ limitations under the License.
#include "absl/strings/str_join.h"
#include "absl/strings/str_replace.h"
#include "absl/strings/str_split.h"
-#include "third_party/re2/re2.h"
#include "tensorflow/contrib/lite/toco/dump_graphviz.h"
#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
#include "tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h"
@@ -1984,58 +1983,38 @@ void FinishBuildingRNNStates(Model* model) {
}
}
-// Returns the array names that match the ArraysExtraInfo's name and
-// name_regexp. The regexp match is for a full match.
-std::unordered_set<string> ScanArrayNames(
- const Model& model, const toco::ArraysExtraInfo_Entry& entry) {
- std::unordered_set<string> matches;
- if (model.HasArray(entry.name())) {
- matches.insert(entry.name());
- }
- if (!entry.name_regexp().empty()) {
- const auto& arrays = model.GetArrayMap();
- const RE2 name_regexp = {entry.name_regexp()};
- for (auto it = arrays.begin(); it != arrays.end(); ++it) {
- if (RE2::FullMatch(it->first, name_regexp)) {
- matches.insert(it->first);
- }
- }
- }
- return matches;
-}
-
void UseArraysExtraInfo(Model* model, bool quantize_output) {
for (const auto& entry : model->flags.arrays_extra_info().entries()) {
- const auto matches = ScanArrayNames(*model, entry);
- for (const auto& matched_name : matches) {
- auto& array = model->GetArray(matched_name);
- if (entry.has_min() || entry.has_max()) {
- CHECK_EQ(entry.has_min(), entry.has_max());
- auto& minmax = array.GetOrCreateMinMax();
- minmax.min = entry.min();
- minmax.max = entry.max();
- }
- if (entry.has_data_type() && quantize_output) {
- array.final_data_type =
- ConvertIODataTypeToArrayDataType(entry.data_type());
- }
- if (entry.has_shape()) {
- array.clear_shape();
- // Make sure to create the shape even if there are no dims, to
- // correctly record 0-D shapes.
- array.mutable_shape();
- for (int dim : entry.shape().dims()) {
- array.mutable_shape()->mutable_dims()->push_back(dim);
- }
+ if (!model->HasArray(entry.name())) {
+ continue;
+ }
+ auto& array = model->GetArray(entry.name());
+ if (entry.has_min() || entry.has_max()) {
+ CHECK_EQ(entry.has_min(), entry.has_max());
+ auto& minmax = array.GetOrCreateMinMax();
+ minmax.min = entry.min();
+ minmax.max = entry.max();
+ }
+ if (entry.has_data_type() && quantize_output) {
+ array.final_data_type =
+ ConvertIODataTypeToArrayDataType(entry.data_type());
+ }
+ if (entry.has_shape()) {
+ array.clear_shape();
+ // Make sure to create the shape even if there are no dims, to
+ // correctly record 0-D shapes.
+ array.mutable_shape();
+ for (int dim : entry.shape().dims()) {
+ array.mutable_shape()->mutable_dims()->push_back(dim);
}
- if (entry.has_constant_float_value()) {
- CHECK(array.has_shape());
- if (array.data_type == ArrayDataType::kFloat) {
- auto& data = array.GetMutableBuffer<ArrayDataType::kFloat>().data;
- data.resize(RequiredBufferSizeForShape(array.shape()));
- for (float& f : data) {
- f = entry.constant_float_value();
- }
+ }
+ if (entry.has_constant_float_value()) {
+ CHECK(array.has_shape());
+ if (array.data_type == ArrayDataType::kFloat) {
+ auto& data = array.GetMutableBuffer<ArrayDataType::kFloat>().data;
+ data.resize(RequiredBufferSizeForShape(array.shape()));
+ for (float& f : data) {
+ f = entry.constant_float_value();
}
}
}