aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/contrib/lite/toco/allocate_transient_arrays.cc8
-rw-r--r--tensorflow/contrib/lite/toco/dump_graphviz.cc8
-rw-r--r--tensorflow/contrib/lite/toco/export_tensorflow.cc46
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc6
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc4
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc4
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/drop_fake_quant.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/drop_im2col_arrays.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc4
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc4
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.cc19
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc12
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc4
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc6
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc20
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc174
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/quantize.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/read_fake_quant_min_max.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/remove_final_dequantize_op.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation_input.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc6
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc6
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc4
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc11
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fill.cc4
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_range.cc14
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_stack.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_mean_attributes.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_pad_attributes.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_slice_attributes.cc4
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_space_to_batch_nd_attributes.cc5
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc6
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc4
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc4
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_tile.cc4
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_transpose_attributes.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc19
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc2
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow.cc2
-rw-r--r--tensorflow/contrib/lite/toco/model.h39
-rw-r--r--tensorflow/contrib/lite/toco/tflite/export.cc4
-rw-r--r--tensorflow/contrib/lite/toco/tflite/import_test.cc2
-rw-r--r--tensorflow/contrib/lite/toco/toco_tooling.cc2
-rw-r--r--tensorflow/contrib/lite/toco/tooling_util.cc49
58 files changed, 288 insertions, 270 deletions
diff --git a/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc b/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc
index d4da8f5dfe..5961d30bf5 100644
--- a/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc
+++ b/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc
@@ -148,7 +148,7 @@ std::size_t TransientArraySize(const Model& model, const string& array_name,
if (!IsAllocatableTransientArray(model, array_name)) {
return 0;
}
- const auto& array = model.arrays.at(array_name);
+ const auto& array = &model.GetArray(array_name);
CHECK(array->has_shape())
<< "Array '" << array_name << "' doesn't have a shape";
if (array->data_type == ArrayDataType::kNone) {
@@ -185,7 +185,7 @@ void AllocateTransientArray(const Model& model, const string& array_name,
}
const std::size_t size =
TransientArraySize(model, array_name, transient_data_alignment);
- const auto& array = model.arrays.at(array_name);
+ const auto& array = &model.GetArray(array_name);
CHECK(!array->alloc);
allocator->Allocate(size, &array->GetOrCreateAlloc());
}
@@ -197,7 +197,7 @@ void DeallocateTransientArray(const Model& model, const string& array_name,
if (!IsAllocatableTransientArray(model, array_name)) {
return;
}
- const auto& array = model.arrays.at(array_name);
+ const auto& array = &model.GetArray(array_name);
CHECK(!!array->alloc);
allocator->Deallocate(*array->alloc);
}
@@ -231,7 +231,7 @@ void AllocateTransientArrays(Model* model,
// Construct a sorted map of array names, so that other layout engines can
// match exactly.
std::map<string, const Array*> ordered_arrays_map;
- for (const auto& pair : model->arrays) {
+ for (const auto& pair : model->GetArrayMap()) {
ordered_arrays_map[pair.first] = pair.second.get();
}
diff --git a/tensorflow/contrib/lite/toco/dump_graphviz.cc b/tensorflow/contrib/lite/toco/dump_graphviz.cc
index 39809216c7..c726eb6d86 100644
--- a/tensorflow/contrib/lite/toco/dump_graphviz.cc
+++ b/tensorflow/contrib/lite/toco/dump_graphviz.cc
@@ -278,8 +278,8 @@ std::vector<const Operator*> OperatorsToDump(const Model& model) {
if (last_specified) {
// Return only the part of the graph between graphviz_first_array
// and graphviz_last_array.
- CHECK(model.arrays.count(dump_options.graphviz_first_array));
- CHECK(model.arrays.count(dump_options.graphviz_last_array));
+ CHECK(model.HasArray(dump_options.graphviz_first_array));
+ CHECK(model.HasArray(dump_options.graphviz_last_array));
std::unordered_set<string> arrays_already_produced;
std::vector<string> arrays_to_produce;
arrays_to_produce.push_back(dump_options.graphviz_last_array);
@@ -336,7 +336,7 @@ void DumpGraphviz(const Model& model, string* output_file_contents) {
op_properties.color.TextColorString().c_str());
// Add nodes and edges for all inputs of the operator.
for (const auto& input : op.inputs) {
- if (model.arrays.count(input) == 0) {
+ if (!model.HasArray(input)) {
// Arrays should _always_ exist. Except, perhaps, during development.
continue;
}
@@ -352,7 +352,7 @@ void DumpGraphviz(const Model& model, string* output_file_contents) {
}
// Add nodes and edges for all outputs of the operator.
for (const auto& output : op.outputs) {
- if (model.arrays.count(output) == 0) {
+ if (!model.HasArray(output)) {
// Arrays should _always_ exist. Except, perhaps, during development.
continue;
}
diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc
index 90fa442746..4fc01dbc20 100644
--- a/tensorflow/contrib/lite/toco/export_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc
@@ -156,8 +156,8 @@ void ConvertFloatTensorConst(const Model& model, const string& name,
const_op->set_name(name);
(*const_op->mutable_attr())["dtype"].set_type(DT_FLOAT);
auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
- CHECK(model.arrays.count(name));
- const auto& input_array = *model.arrays.at(name);
+ CHECK(model.HasArray(name));
+ const auto& input_array = model.GetArray(name);
const auto& input_shape = input_array.shape();
CHECK(input_array.buffer);
CHECK(input_array.buffer->type == ArrayDataType::kFloat);
@@ -177,8 +177,8 @@ void ConvertFloatTensorConst(const Model& model, const string& name,
const_op->set_name(name);
(*const_op->mutable_attr())["dtype"].set_type(DT_FLOAT);
auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
- CHECK(model.arrays.count(name));
- const auto& input_array = *model.arrays.at(name);
+ CHECK(model.HasArray(name));
+ const auto& input_array = model.GetArray(name);
const auto& input_shape = input_array.shape();
CHECK(input_array.buffer);
CHECK(input_array.buffer->type == ArrayDataType::kFloat);
@@ -193,8 +193,8 @@ void ConvertIntTensorConst(const Model& model, const string& name,
if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
return;
}
- CHECK(model.arrays.count(name));
- const auto& array = *model.arrays.at(name);
+ CHECK(model.HasArray(name));
+ const auto& array = model.GetArray(name);
auto* const_op = tensorflow_graph->add_node();
const_op->set_op("Const");
const_op->set_name(name);
@@ -324,7 +324,7 @@ void ConvertConvOperator(const Model& model, const ConvOperator& src_op,
biasadd_op->add_input(conv_output);
biasadd_op->add_input(src_op.inputs[2]);
(*biasadd_op->mutable_attr())["T"].set_type(DT_FLOAT);
- CHECK(model.arrays.count(src_op.inputs[2]));
+ CHECK(model.HasArray(src_op.inputs[2]));
const string& bias_array_name =
WalkUpToConstantArray(model, src_op.inputs[2]);
const auto& bias_array = model.GetArray(bias_array_name);
@@ -361,7 +361,7 @@ void ConvertDepthwiseConvOperator(const Model& model,
// We need to convert that to H x W x InputDepth x Multiplier.
// That's only a matter of constructing a Dims object; the actual
// array layout is the same.
- CHECK(model.arrays.count(src_op.inputs[1]));
+ CHECK(model.HasArray(src_op.inputs[1]));
const string& src_weights_name =
WalkUpToConstantArray(model, src_op.inputs[1]);
const auto& src_weights_array = model.GetArray(src_weights_name);
@@ -404,7 +404,7 @@ void ConvertDepthwiseConvOperator(const Model& model,
biasadd_op->add_input(conv_output);
biasadd_op->add_input(src_op.inputs[2]);
(*biasadd_op->mutable_attr())["T"].set_type(DT_FLOAT);
- CHECK(model.arrays.count(src_op.inputs[2]));
+ CHECK(model.HasArray(src_op.inputs[2]));
const string& bias_name = WalkUpToConstantArray(model, src_op.inputs[2]);
const auto& bias_array = model.GetArray(bias_name);
// TODO(b/62904716) Bias arrays should be 1-D, and used directly.
@@ -469,10 +469,10 @@ void ConvertFullyConnectedOperator(const Model& model,
(*matmul_op->mutable_attr())["T"].set_type(DT_FLOAT);
(*matmul_op->mutable_attr())["transpose_a"].set_b(false);
(*matmul_op->mutable_attr())["transpose_b"].set_b(false);
- CHECK(model.arrays.count(src_op.inputs[1]));
+ CHECK(model.HasArray(src_op.inputs[1]));
const string& fc_weights_name =
WalkUpToConstantArray(model, src_op.inputs[1]);
- const auto& fc_weights_array = *model.arrays.at(fc_weights_name);
+ const auto& fc_weights_array = model.GetArray(fc_weights_name);
const auto& fc_weights_shape = fc_weights_array.shape();
CHECK_EQ(fc_weights_shape.dimensions_count(), 2);
CreateMatrixShapeTensorConst(reshape_shape, fc_weights_shape.dims(1), -1,
@@ -492,8 +492,8 @@ void ConvertFullyConnectedOperator(const Model& model,
biasadd_op->add_input(matmul_output);
biasadd_op->add_input(src_op.inputs[2]);
(*biasadd_op->mutable_attr())["T"].set_type(DT_FLOAT);
- CHECK(model.arrays.count(src_op.inputs[2]));
- const auto& bias_array = *model.arrays.at(src_op.inputs[2]);
+ CHECK(model.HasArray(src_op.inputs[2]));
+ const auto& bias_array = model.GetArray(src_op.inputs[2]);
// TODO(b/62904716) Bias arrays should be 1-D, and used directly.
Shape bias_shape_1d = bias_array.shape();
UnextendShape(&bias_shape_1d, 1);
@@ -625,7 +625,7 @@ void ConvertSoftmaxOperator(const Model& model, const SoftmaxOperator& src_op,
*reshape_op->add_input() = softmax_size;
(*reshape_op->mutable_attr())["T"].set_type(DT_FLOAT);
- const auto& input_shape = model.arrays.at(src_op.inputs[0])->shape();
+ const auto& input_shape = model.GetArray(src_op.inputs[0]).shape();
int32 flattened_size = 1;
for (int i = 0; i < input_shape.dimensions_count() - 1; ++i) {
flattened_size *= input_shape.dims(i);
@@ -1013,8 +1013,8 @@ void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op,
// Op names have been chosen to match the tf.slim LSTM naming
// as closely as possible.
const int axis =
- model.arrays.at(src_op.inputs[LstmCellOperator::PREV_ACTIV_INPUT])
- ->shape()
+ model.GetArray(src_op.inputs[LstmCellOperator::PREV_ACTIV_INPUT])
+ .shape()
.dimensions_count() -
1;
// Note that DATA_INPUT may have extra size 1 dimensions, but TF concat
@@ -1033,9 +1033,9 @@ void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op,
// Write weights
const string weights_output = base + "weights";
- CHECK(model.arrays.count(src_op.inputs[LstmCellOperator::WEIGHTS_INPUT]));
+ CHECK(model.HasArray(src_op.inputs[LstmCellOperator::WEIGHTS_INPUT]));
const auto& weights_array =
- *model.arrays.at(src_op.inputs[LstmCellOperator::WEIGHTS_INPUT]);
+ model.GetArray(src_op.inputs[LstmCellOperator::WEIGHTS_INPUT]);
// Convert 4D FullyConnected weights into 2D matrix
const auto& weights_shape = weights_array.shape();
CHECK_EQ(weights_shape.dimensions_count(), 2);
@@ -1059,9 +1059,9 @@ void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op,
// Write biases
const string biases_output = base + "biases";
- CHECK(model.arrays.count(src_op.inputs[LstmCellOperator::BIASES_INPUT]));
+ CHECK(model.HasArray(src_op.inputs[LstmCellOperator::BIASES_INPUT]));
const auto& bias_array =
- *model.arrays.at(src_op.inputs[LstmCellOperator::BIASES_INPUT]);
+ model.GetArray(src_op.inputs[LstmCellOperator::BIASES_INPUT]);
// TODO(b/62904716) Bias arrays should be 1-D, and used directly.
Shape bias_shape_1d = bias_array.shape();
UnextendShape(&bias_shape_1d, 1);
@@ -1557,7 +1557,7 @@ void AddPlaceholderForRNNState(const Model& model, const string& name, int size,
(*placeholder->mutable_attr())["dtype"].set_type(DT_FLOAT);
auto* shape = (*placeholder->mutable_attr())["shape"].mutable_shape();
- const auto& state_array = *model.arrays.at(name);
+ const auto& state_array = model.GetArray(name);
if (state_array.has_shape()) {
const auto& state_shape = state_array.shape();
const int kDims = state_shape.dimensions_count();
@@ -1574,7 +1574,7 @@ void ExportTensorFlowGraphDefImplementation(const Model& model,
GraphDef* tensorflow_graph) {
for (const auto& input_array : model.flags.input_arrays()) {
AddPlaceholder(input_array.name(),
- model.arrays.at(input_array.name())->data_type,
+ model.GetArray(input_array.name()).data_type,
tensorflow_graph);
}
for (const auto& rnn_state : model.flags.rnn_states()) {
@@ -1588,7 +1588,7 @@ void ExportTensorFlowGraphDefImplementation(const Model& model,
// by the above operators export. It's important that this comes
// after, as some operators need to export arrays that they reference
// in a specific way, rather than in the generic way done below.
- for (const auto& array_pair : model.arrays) {
+ for (const auto& array_pair : model.GetArrayMap()) {
const string& array_name = array_pair.first;
const auto& array = *array_pair.second;
if (array.buffer) {
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc
index 3bde9b0169..56f48d47de 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc
@@ -35,7 +35,7 @@ bool ConvertExpandDimsToReshape::Run(Model* model, std::size_t op_index) {
CHECK_EQ(expand_op->inputs.size(), 2);
CHECK_EQ(expand_op->outputs.size(), 1);
- const auto& input_array = *model->arrays[expand_op->inputs[0]];
+ const auto& input_array = model->GetArray(expand_op->inputs[0]);
if (!input_array.has_shape()) {
// Yield until input dims have been resolved.
return false;
@@ -46,7 +46,7 @@ bool ConvertExpandDimsToReshape::Run(Model* model, std::size_t op_index) {
return false;
}
- const auto& axis_array = *model->arrays[expand_op->inputs[1]];
+ const auto& axis_array = model->GetArray(expand_op->inputs[1]);
if (!axis_array.has_shape()) {
// Yield until input axis array shape has been resolved.
return false;
@@ -86,7 +86,7 @@ bool ConvertExpandDimsToReshape::Run(Model* model, std::size_t op_index) {
if (IsDiscardableArray(*model, axis_array_name) &&
CountOpsWithInput(*model, axis_array_name) == 1 &&
!GetOpWithOutput(*model, axis_array_name)) {
- model->arrays.erase(axis_array_name);
+ model->EraseArray(axis_array_name);
}
// Replace the operator in the graph.
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc
index bf454c40c7..d38db85280 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc
@@ -58,7 +58,7 @@ bool ConvertPureConvToDepthwise::Run(Model* model, std::size_t op_index) {
depthwiseconv_op->outputs = {conv_op->outputs[0]};
if (conv_op->outputs.size() > 1) {
// delete the im2col array.
- model->arrays.erase(conv_op->outputs[1]);
+ model->EraseArray(conv_op->outputs[1]);
}
depthwiseconv_op->fused_activation_function =
conv_op->fused_activation_function;
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc
index a234c20924..c2b166033c 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc
@@ -29,7 +29,7 @@ bool ConvertTrivialTransposeToReshape::Run(Model* model, std::size_t op_index) {
TransposeOperator* transpose_op =
static_cast<TransposeOperator*>(transpose_it->get());
- const auto& output_array = *model->arrays[transpose_op->outputs[0]];
+ const auto& output_array = model->GetArray(transpose_op->outputs[0]);
if (!output_array.has_shape()) {
// Yield until PropagateFixedSizes has been run on this op.
return false;
@@ -70,7 +70,7 @@ bool ConvertTrivialTransposeToReshape::Run(Model* model, std::size_t op_index) {
// Delete perm array if unused
if (IsDiscardableArray(*model, perm_array_name) &&
CountOpsWithInput(*model, perm_array_name) == 1) {
- model->arrays.erase(perm_array_name);
+ model->EraseArray(perm_array_name);
}
// Replace the operator in the graph.
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc b/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc
index 1735b51e5b..076415ece8 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc
@@ -35,7 +35,7 @@ bool CreateIm2colArrays::Run(Model* model, std::size_t op_index) {
// We already have an im2col array
return false;
}
- const auto& weights_array = *model->arrays[conv_op->inputs[1]];
+ const auto& weights_array = model->GetArray(conv_op->inputs[1]);
if (!weights_array.has_shape()) {
// We need to yield until weights dims have been resolved, because
// from the weights dims we determine whether an im2col array is
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc
index 79854cba34..498c864bde 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc
@@ -53,7 +53,7 @@ std::vector<std::unique_ptr<Operator>>::iterator FindFirstOpWithInput(
}
void ClearArrayQuantizationParams(const string& array_name, Model* model) {
- auto* array = model->arrays.at(array_name).get();
+ auto* array = &model->GetArray(array_name);
CHECK(array->quantization_params);
for (auto& input_array : *model->flags.mutable_input_arrays()) {
if (input_array.name() == array_name) {
@@ -77,7 +77,7 @@ void ClearArrayQuantizationParams(const string& array_name, Model* model) {
bool DequantizeArray(const string& array_name,
GraphTransformation* transformation, Model* model) {
- auto* array = model->arrays.at(array_name).get();
+ auto* array = &model->GetArray(array_name);
if (!array->quantization_params) {
return false;
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/drop_fake_quant.cc b/tensorflow/contrib/lite/toco/graph_transformations/drop_fake_quant.cc
index fea360740f..95558ef5ec 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/drop_fake_quant.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/drop_fake_quant.cc
@@ -45,7 +45,7 @@ bool DropFakeQuant::Run(Model* model, std::size_t op_index) {
// Drop min/max inputs
for (int i = 1; i < fakequant_op->inputs.size(); i++) {
if (CountOpsWithInput(*model, fakequant_op->inputs[i]) == 1) {
- model->arrays.erase(fakequant_op->inputs[i]);
+ model->EraseArray(fakequant_op->inputs[i]);
}
}
fakequant_op->inputs.resize(1);
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/drop_im2col_arrays.cc b/tensorflow/contrib/lite/toco/graph_transformations/drop_im2col_arrays.cc
index a3ed6663bc..f7fd878b7e 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/drop_im2col_arrays.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/drop_im2col_arrays.cc
@@ -32,7 +32,7 @@ bool DropIm2colArrays::Run(Model* model, std::size_t op_index) {
// Drop the im2col array.
CHECK_EQ(conv_op->outputs.size(), 2);
- model->arrays.erase(conv_op->outputs[1]);
+ model->EraseArray(conv_op->outputs[1]);
conv_op->outputs.resize(1);
AddMessageF("Dropped an im2col array for %s", LogName(*conv_op));
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc b/tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc
index ad4a6f9b78..88e59664ec 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc
@@ -91,7 +91,7 @@ bool FuseActivationFunctions::Run(Model* model, std::size_t op_index) {
} else {
LOG(FATAL) << "Unhandled activation function type";
}
- model->arrays.erase(ac_op->inputs[0]);
+ model->EraseArray(ac_op->inputs[0]);
op->outputs[0] = ac_op->outputs[0];
model->operators.erase(ac_it);
return true;
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc
index 4619d8bbee..dcbbead517 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc
@@ -285,13 +285,13 @@ bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) {
AddMessageF("Fusing %s into the following %s", LogName(*binary_op),
LogName(*following_op));
- model->arrays.erase(binary_op->outputs[0]);
+ model->EraseArray(binary_op->outputs[0]);
following_op->inputs[0] = binary_op->inputs[index_of_variable_input];
const auto& old_constant_param_name =
binary_op->inputs[index_of_constant_input];
CHECK(IsConstantParameterArray(*model, old_constant_param_name));
if (CountOpsWithInput(*model, old_constant_param_name) == 1) {
- model->arrays.erase(old_constant_param_name);
+ model->EraseArray(old_constant_param_name);
}
model->operators.erase(binary_it);
return true;
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc
index 8948653ec3..5b57178b18 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc
@@ -309,7 +309,7 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) {
LOG(FATAL) << "should not get here";
}
- model->arrays.erase(preceding_op->outputs[0]);
+ model->EraseArray(preceding_op->outputs[0]);
preceding_op->outputs[0] = binary_op->outputs[0];
preceding_op->fused_activation_function =
binary_op->fused_activation_function;
@@ -317,7 +317,7 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) {
binary_op->inputs[index_of_constant_input];
CHECK(IsConstantParameterArray(*model, old_constant_param_name));
if (CountOpsWithInput(*model, old_constant_param_name) == 1) {
- model->arrays.erase(old_constant_param_name);
+ model->EraseArray(old_constant_param_name);
}
model->operators.erase(binary_it);
return true;
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.cc b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.cc
index f861c4147a..2340f0e850 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.cc
@@ -31,13 +31,13 @@ namespace {
void PrintModelStats(const string& label, const Model& model) {
int quantized_arrays = 0;
- for (const auto& array : model.arrays) {
+ for (const auto& array : model.GetArrayMap()) {
if (array.second->quantization_params) {
quantized_arrays++;
}
}
LOG(INFO) << label << ": " << model.operators.size() << " operators, "
- << model.arrays.size() << " arrays (" << quantized_arrays
+ << model.GetArrayMap().size() << " arrays (" << quantized_arrays
<< " quantized)";
}
@@ -91,14 +91,9 @@ void DiscardUselessConnectedComponentsAndRNNBackEdges(Model* model) {
}
} while (found_new_useful_arrays);
// Erase arrays that aren't useful, and that are discardable.
- for (auto it = model->arrays.begin(); it != model->arrays.end();) {
- if (useful_arrays.count(it->first) ||
- !IsDiscardableArray(*model, it->first)) {
- ++it;
- } else {
- it = model->arrays.erase(it);
- }
- }
+ model->EraseArrays([&](const string& name) {
+ return (!useful_arrays.count(name) && IsDiscardableArray(*model, name));
+ });
// Erase operators that do not produce a useful output array.
for (auto it = model->operators.begin(); it != model->operators.end();) {
// Only need to test the first output, as we simultaneously added all of
@@ -118,8 +113,8 @@ void DiscardUselessConnectedComponentsAndRNNBackEdges(Model* model) {
std::vector<RnnState> rnn_states_to_keep;
for (const auto& rnn_state : model->flags.rnn_states()) {
const bool dangling =
- !model->arrays.count(rnn_state.back_edge_source_array()) ||
- !model->arrays.count(rnn_state.state_array());
+ !model->HasArray(rnn_state.back_edge_source_array()) ||
+ !model->HasArray(rnn_state.state_array());
if (dangling) {
CHECK(rnn_state.discardable());
} else {
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc
index 01b75e37c6..419a0776a6 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc
@@ -150,19 +150,19 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) {
// Erase the subgraph that is now replaced by L2Normalization
model->operators.erase(FindOperator(model, square_op));
- model->arrays.erase(sum_op->inputs[0]);
+ model->EraseArray(sum_op->inputs[0]);
if (sum_op->inputs.size() > 1) {
- model->arrays.erase(sum_op->inputs[1]);
+ model->EraseArray(sum_op->inputs[1]);
}
model->operators.erase(FindOperator(model, sum_op));
if (add_op) {
- model->arrays.erase(add_op->inputs[0]);
- model->arrays.erase(add_op->inputs[1]);
+ model->EraseArray(add_op->inputs[0]);
+ model->EraseArray(add_op->inputs[1]);
model->operators.erase(FindOperator(model, add_op));
}
- model->arrays.erase(sqrt_or_rsqrt_op->inputs[0]);
+ model->EraseArray(sqrt_or_rsqrt_op->inputs[0]);
model->operators.erase(FindOperator(model, sqrt_or_rsqrt_op));
- model->arrays.erase(div_or_mul_op->inputs[1]);
+ model->EraseArray(div_or_mul_op->inputs[1]);
model->operators.erase(FindOperator(model, div_or_mul_op));
return true;
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc
index 1865416fc2..e4d52476c6 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc
@@ -92,8 +92,8 @@ bool IdentifyL2Pool::Run(Model* model, std::size_t op_index) {
AddMessageF("Creating %s replacing equivalent subgraph", LogName(*l2pool_op));
// Erase intermediate arrays, keeping input to square op.
- model->arrays.erase(avpool_op->inputs[0]);
- model->arrays.erase(sqrt_op->inputs[0]);
+ model->EraseArray(avpool_op->inputs[0]);
+ model->EraseArray(sqrt_op->inputs[0]);
// Erase three operators being replaced.
model->operators.erase(FindOperator(model, square_op));
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc
index cfc77024e7..d36e950609 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc
@@ -89,12 +89,12 @@ bool IdentifyRelu1::Run(Model* model, std::size_t op_index) {
AddMessageF("Creating %s replacing equivalent subgraph", LogName(*relu1_op));
// Erase Maximum scalar input & operator
- model->arrays.erase(maximum_op->inputs[scalar_input_index]);
+ model->EraseArray(maximum_op->inputs[scalar_input_index]);
model->operators.erase(FindOperator(model, maximum_op));
// Erase Minimum inputs & operator
- model->arrays.erase(minimum_op->inputs[0]);
- model->arrays.erase(minimum_op->inputs[1]);
+ model->EraseArray(minimum_op->inputs[0]);
+ model->EraseArray(minimum_op->inputs[1]);
model->operators.erase(FindOperator(model, minimum_op));
return true;
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
index 29b55d9bfc..f0d107232b 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
@@ -27,7 +27,7 @@ namespace {
void SetDataTypeForAllOutputs(Model* model, Operator* op,
ArrayDataType data_type) {
for (const auto& output : op->outputs) {
- model->arrays[output]->data_type = data_type;
+ model->GetArray(output).data_type = data_type;
}
}
} // namespace
@@ -39,7 +39,7 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
// If the data type of some input is unknown, we need to yield.
for (const auto& input : op->inputs) {
if (!model->IsOptionalArray(input) &&
- model->arrays[input]->data_type == ArrayDataType::kNone) {
+ model->GetArray(input).data_type == ArrayDataType::kNone) {
return false;
}
}
@@ -47,7 +47,7 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
// end if we changed anything, and return the correct boolean value.
std::unordered_map<string, ArrayDataType> old_output_data_types;
for (const auto& output : op->outputs) {
- old_output_data_types[output] = model->arrays[output]->data_type;
+ old_output_data_types[output] = model->GetArray(output).data_type;
}
// Do the actual output data types propagation.
if (op->type == OperatorType::kDequantize ||
@@ -69,18 +69,18 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
op->type == OperatorType::kFill) {
// These operators produce an output with the same type as their 2nd input
CHECK_GE(op->inputs.size(), 2);
- const ArrayDataType data_type = model->arrays[op->inputs[1]]->data_type;
+ const ArrayDataType data_type = model->GetArray(op->inputs[1]).data_type;
SetDataTypeForAllOutputs(model, op, data_type);
} else if (op->type == OperatorType::kCast) {
// Data type of the Cast op is specified.
CHECK_EQ(op->outputs.size(), 1);
auto* cast_op = static_cast<CastOperator*>(op);
- model->arrays[op->outputs[0]]->data_type = cast_op->dst_data_type;
+ model->GetArray(op->outputs[0]).data_type = cast_op->dst_data_type;
} else if (op->type == OperatorType::kArgMax) {
// Data type of the ArgMax op is specified.
CHECK_EQ(op->outputs.size(), 1);
auto* argmax_op = static_cast<ArgMaxOperator*>(op);
- model->arrays[op->outputs[0]]->data_type = argmax_op->output_data_type;
+ model->GetArray(op->outputs[0]).data_type = argmax_op->output_data_type;
} else if (op->type == OperatorType::kRange) {
auto* range_op = static_cast<RangeOperator*>(op);
// Output type of the Range op can be set via an attribute
@@ -91,7 +91,7 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
} else {
// Otherwise use the first input
CHECK_GE(op->inputs.size(), 1);
- data_type = model->arrays[op->inputs[0]]->data_type;
+ data_type = model->GetArray(op->inputs[0]).data_type;
}
CHECK_EQ(op->outputs.size(), 1);
SetDataTypeForAllOutputs(model, op, data_type);
@@ -103,7 +103,7 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
for (int i = 0; i < unsupported_op->output_data_types.size(); ++i) {
auto output = op->outputs[i];
auto data_type = unsupported_op->output_data_types[i];
- model->arrays[output]->data_type = data_type;
+ model->GetArray(output).data_type = data_type;
}
} else if (op->type == OperatorType::kExpandDims) {
// Yield on ExpandDim until it is converted to Reshape
@@ -111,12 +111,12 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
} else {
// These operators produce outputs with the same type as their 1st input
CHECK_GT(op->inputs.size(), 0);
- const ArrayDataType data_type = model->arrays[op->inputs[0]]->data_type;
+ const ArrayDataType data_type = model->GetArray(op->inputs[0]).data_type;
SetDataTypeForAllOutputs(model, op, data_type);
}
// Return true if any output data type changed, false if none changed.
for (const auto& output : op->outputs) {
- if (old_output_data_types[output] != model->arrays[output]->data_type) {
+ if (old_output_data_types[output] != model->GetArray(output).data_type) {
return true;
}
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index a939efb4db..ff0a3bd881 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -85,7 +85,7 @@ void ComputeBinaryOperatorOutputSize(const Shape& input_shape1,
int GetOutputDepthFromWeights(const Model& model, const Operator& op) {
const string& weights_name = op.inputs[1];
- const auto& weights_shape = model.arrays.at(weights_name)->shape();
+ const auto& weights_shape = model.GetArray(weights_name).shape();
if (op.type == OperatorType::kConv ||
op.type == OperatorType::kFullyConnected) {
return weights_shape.dims(0);
@@ -98,7 +98,7 @@ int GetOutputDepthFromWeights(const Model& model, const Operator& op) {
bool EnsureBiasVectorShape(Model* model, Operator* op) {
const string& weights_name = op->inputs[1];
- const auto& weights_array = *model->arrays[weights_name];
+ const auto& weights_array = model->GetArray(weights_name);
// Yield until weights shape has been resolved.
if (!weights_array.has_shape()) {
return false;
@@ -107,7 +107,7 @@ bool EnsureBiasVectorShape(Model* model, Operator* op) {
if (op->inputs.size() < 3) {
return false;
}
- auto& bias_array = *model->arrays[op->inputs[2]];
+ auto& bias_array = model->GetArray(op->inputs[2]);
if (bias_array.has_shape()) {
return true;
}
@@ -126,7 +126,7 @@ void ProcessConvOperator(Model* model, ConvOperator* op) {
return;
}
- const auto& input_array = *model->arrays[op->inputs[0]];
+ const auto& input_array = model->GetArray(op->inputs[0]);
// Yield until input dims have been resolved.
if (!input_array.has_shape()) {
return;
@@ -134,7 +134,7 @@ void ProcessConvOperator(Model* model, ConvOperator* op) {
const auto& input_shape = input_array.shape();
CHECK_EQ(input_shape.dimensions_count(), 4);
- const auto& weights_array = *model->arrays[op->inputs[1]];
+ const auto& weights_array = model->GetArray(op->inputs[1]);
// Yield until weights dims have been resolved.
if (!weights_array.has_shape()) {
return;
@@ -156,7 +156,7 @@ void ProcessConvOperator(Model* model, ConvOperator* op) {
if (op->outputs.size() == 2) {
const auto& output_shape = output_array.shape();
const int input_depth = weights_shape.dims(3);
- auto& im2col_array = *model->arrays[op->outputs[1]];
+ auto& im2col_array = model->GetArray(op->outputs[1]);
im2col_array.copy_shape(Shape{output_shape.dims(0), output_shape.dims(1),
output_shape.dims(2),
input_depth * kheight * kwidth});
@@ -168,7 +168,7 @@ void ProcessDepthwiseConvOperator(Model* model, DepthwiseConvOperator* op) {
return;
}
- const auto& input_array = *model->arrays[op->inputs[0]];
+ const auto& input_array = model->GetArray(op->inputs[0]);
// Yield until input dims have been resolved.
if (!input_array.has_shape()) {
return;
@@ -176,7 +176,7 @@ void ProcessDepthwiseConvOperator(Model* model, DepthwiseConvOperator* op) {
const auto& input_shape = input_array.shape();
CHECK_EQ(input_shape.dimensions_count(), 4);
- const auto& weights_array = *model->arrays[op->inputs[1]];
+ const auto& weights_array = model->GetArray(op->inputs[1]);
// Yield until weights dims have been resolved.
if (!weights_array.has_shape()) {
return;
@@ -209,7 +209,7 @@ void ProcessDepthwiseConvOperator(Model* model, DepthwiseConvOperator* op) {
}
void ProcessDepthToSpaceOperator(Model* model, DepthToSpaceOperator* op) {
- const auto& input_array = *model->arrays[op->inputs[0]];
+ const auto& input_array = model->GetArray(op->inputs[0]);
// Yield until input dims have been resolved.
if (!input_array.has_shape()) {
return;
@@ -232,7 +232,7 @@ void ProcessDepthToSpaceOperator(Model* model, DepthToSpaceOperator* op) {
}
void ProcessSpaceToDepthOperator(Model* model, SpaceToDepthOperator* op) {
- const auto& input_array = *model->arrays[op->inputs[0]];
+ const auto& input_array = model->GetArray(op->inputs[0]);
// Yield until input dims have been resolved.
if (!input_array.has_shape()) {
return;
@@ -258,7 +258,7 @@ void ProcessSpaceToDepthOperator(Model* model, SpaceToDepthOperator* op) {
void ProcessFillOperator(Model* model, FillOperator* op) {
CHECK_EQ(op->inputs.size(), 2);
CHECK_EQ(op->outputs.size(), 1);
- auto& output_array = *model->arrays[op->outputs[0]];
+ auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.has_shape()) {
// We have already run
return;
@@ -287,7 +287,7 @@ void ProcessFullyConnectedOperator(Model* model, FullyConnectedOperator* op) {
return;
}
- const auto& input_array = *model->arrays[op->inputs[0]];
+ const auto& input_array = model->GetArray(op->inputs[0]);
// Yield until input dims have been resolved.
if (!input_array.has_shape()) {
return;
@@ -295,7 +295,7 @@ void ProcessFullyConnectedOperator(Model* model, FullyConnectedOperator* op) {
const auto& input_shape = input_array.shape();
CHECK_GE(input_shape.dimensions_count(), 1);
- const auto& weights_array = *model->arrays[op->inputs[1]];
+ const auto& weights_array = model->GetArray(op->inputs[1]);
// Yield until weights dims have been resolved.
if (!weights_array.has_shape()) {
return;
@@ -315,13 +315,13 @@ void ProcessFullyConnectedOperator(Model* model, FullyConnectedOperator* op) {
void ProcessTensorFlowReshapeOperator(Model* model,
TensorFlowReshapeOperator* op) {
- auto& output_array = *model->arrays[op->outputs[0]];
+ auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.has_shape()) {
// We have already run
return;
}
- const auto& input_array = *model->arrays[op->inputs[0]];
+ const auto& input_array = model->GetArray(op->inputs[0]);
if (!input_array.has_shape()) {
// Yield until input dims have been resolved.
return;
@@ -377,14 +377,14 @@ void ProcessTensorFlowReshapeOperator(Model* model,
}
void ProcessSimpleOperator(Model* model, Operator* op) {
- const auto& input_array = *model->arrays[op->inputs[0]];
+ const auto& input_array = model->GetArray(op->inputs[0]);
// Yield until input dims have been resolved.
if (!input_array.has_shape()) {
return;
}
const string& output_name = op->outputs[0];
- auto& output_array = *model->arrays[output_name];
+ auto& output_array = model->GetArray(output_name);
if (output_array.has_shape()) {
return;
}
@@ -394,14 +394,14 @@ void ProcessSimpleOperator(Model* model, Operator* op) {
void ProcessSimpleBinaryOperator(Model* model, Operator* op) {
CHECK_EQ(op->inputs.size(), 2);
- const auto& input0_array = *model->arrays[op->inputs[0]];
- const auto& input1_array = *model->arrays[op->inputs[1]];
+ const auto& input0_array = model->GetArray(op->inputs[0]);
+ const auto& input1_array = model->GetArray(op->inputs[1]);
// Yield until input dims have been resolved.
if (!input0_array.has_shape() || !input1_array.has_shape()) {
return;
}
const string& output_name = op->outputs[0];
- auto& output_array = *model->arrays[output_name];
+ auto& output_array = model->GetArray(output_name);
ComputeBinaryOperatorOutputSize(input0_array.shape(), input1_array.shape(),
&output_array);
}
@@ -424,11 +424,11 @@ bool KeepDims(const Operator& op) {
void ProcessTensorFlowReductionOperator(Model* model, Operator* op) {
CHECK_LE(op->inputs.size(), 2);
- auto& output_array = *model->arrays[op->outputs[0]];
+ auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.has_shape()) {
return;
}
- const auto& input_array = *model->arrays[op->inputs[0]];
+ const auto& input_array = model->GetArray(op->inputs[0]);
if (!input_array.has_shape()) {
return;
}
@@ -436,7 +436,7 @@ void ProcessTensorFlowReductionOperator(Model* model, Operator* op) {
const bool keep_dims = KeepDims(*op);
if (op->inputs.size() == 2) {
// There is a reduction_indices input.
- const auto& reduction_array = *model->arrays[op->inputs[1]];
+ const auto& reduction_array = model->GetArray(op->inputs[1]);
if (!reduction_array.buffer) {
return;
}
@@ -476,11 +476,11 @@ void ProcessSliceOperator(Model* model, SliceOperator* op) {
if (op->begin.empty()) return;
// Yield until input dims have been resolved.
- const auto& input_array = *model->arrays[op->inputs[0]];
+ const auto& input_array = model->GetArray(op->inputs[0]);
if (!input_array.has_shape()) return;
const Shape& input_shape = input_array.shape();
- auto& output_array = *model->arrays[op->outputs[0]];
+ auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.has_shape()) return;
CHECK_EQ(input_shape.dims().size(), op->size.size());
@@ -500,7 +500,7 @@ void ProcessSliceOperator(Model* model, SliceOperator* op) {
void ProcessReorderAxesOperator(Model* model, ReorderAxesOperator* op) {
const string& input_name = op->inputs[0];
- const auto& input_array = *model->arrays[input_name];
+ const auto& input_array = model->GetArray(input_name);
// Yield until input dims have been resolved.
if (!input_array.has_shape()) {
return;
@@ -515,20 +515,20 @@ void ProcessReorderAxesOperator(Model* model, ReorderAxesOperator* op) {
void ProcessConcatenationOperator(Model* model, ConcatenationOperator* op) {
// Yield until input dims have been resolved.
for (const auto& input_name : op->inputs) {
- auto& input_array = *model->arrays[input_name];
+ auto& input_array = model->GetArray(input_name);
if (!input_array.has_shape()) {
return;
}
}
auto& output_array = model->GetArray(op->outputs[0]);
// Use 0 input as basis for output dimensions.
- const auto& first_input_array = *model->arrays[op->inputs[0]];
+ const auto& first_input_array = model->GetArray(op->inputs[0]);
output_array.copy_shape(first_input_array.shape());
// Determine the concat size, and enfore that all inputs have
// the same dimensions count.
int concat_size = 0;
for (const auto& input_name : op->inputs) {
- auto& input_array = *model->arrays[input_name];
+ auto& input_array = model->GetArray(input_name);
CHECK(input_array.has_shape());
if (input_array.shape().dimensions_count() == 0) {
continue;
@@ -548,16 +548,16 @@ void ProcessConcatenationOperator(Model* model, ConcatenationOperator* op) {
void ProcessRangeOperator(Model* model, RangeOperator* op) {
CHECK_EQ(op->inputs.size(), 3);
- const auto& start_array = *model->arrays[op->inputs[0]];
+ const auto& start_array = model->GetArray(op->inputs[0]);
if (!start_array.has_shape()) {
// Yield until input dims have been resolved.
return;
}
- const auto& limit_array = *model->arrays[op->inputs[1]];
+ const auto& limit_array = model->GetArray(op->inputs[1]);
if (!limit_array.has_shape()) {
return;
}
- const auto& delta_array = *model->arrays[op->inputs[2]];
+ const auto& delta_array = model->GetArray(op->inputs[2]);
if (!delta_array.has_shape()) {
return;
}
@@ -599,7 +599,7 @@ void ProcessRangeOperator(Model* model, RangeOperator* op) {
void ProcessTensorFlowSplitOperator(Model* model, TensorFlowSplitOperator* op) {
CHECK_EQ(op->inputs.size(), 2);
const string& input_name = op->inputs[1];
- const auto& input_array = *model->arrays[input_name];
+ const auto& input_array = model->GetArray(input_name);
// Yield until input dims have been resolved.
if (!input_array.has_shape()) {
return;
@@ -618,13 +618,13 @@ void ProcessTensorFlowSplitOperator(Model* model, TensorFlowSplitOperator* op) {
CHECK_EQ(op->outputs.size(), op->num_split);
for (const auto& output : op->outputs) {
- model->arrays[output]->copy_shape(output_shape);
+ model->GetArray(output).copy_shape(output_shape);
}
}
void ProcessAveragePoolOperator(Model* model, AveragePoolOperator* op) {
const string& input_name = op->inputs[0];
- const auto& input_array = *model->arrays[input_name];
+ const auto& input_array = model->GetArray(input_name);
// Yield until input dims have been resolved.
if (!input_array.has_shape()) {
return;
@@ -641,7 +641,7 @@ void ProcessAveragePoolOperator(Model* model, AveragePoolOperator* op) {
void ProcessMaxPoolOperator(Model* model, MaxPoolOperator* op) {
const string& input_name = op->inputs[0];
- const auto& input_array = *model->arrays[input_name];
+ const auto& input_array = model->GetArray(input_name);
// Yield until input dims have been resolved.
if (!input_array.has_shape()) {
return;
@@ -658,7 +658,7 @@ void ProcessMaxPoolOperator(Model* model, MaxPoolOperator* op) {
void ProcessL2PoolOperator(Model* model, L2PoolOperator* op) {
const string& input_name = op->inputs[0];
- const auto& input_array = *model->arrays[input_name];
+ const auto& input_array = model->GetArray(input_name);
// Yield until input dims have been resolved.
if (!input_array.has_shape()) {
return;
@@ -679,14 +679,14 @@ void ProcessResizeBilinearOperator(Model* model, ResizeBilinearOperator* op) {
CHECK_EQ(op->inputs.size(), 2);
CHECK_EQ(op->outputs.size(), 1);
- if (!model->arrays[op->inputs[0]]->has_shape() ||
- !model->arrays[op->inputs[1]]->has_shape()) {
+ if (!model->GetArray(op->inputs[0]).has_shape() ||
+ !model->GetArray(op->inputs[1]).has_shape()) {
return;
}
- const auto& input_data_shape = model->arrays[op->inputs[0]]->shape();
+ const auto& input_data_shape = model->GetArray(op->inputs[0]).shape();
const string& output_size_name = op->inputs[1];
- const auto& output_size_array = *model->arrays[output_size_name];
+ const auto& output_size_array = model->GetArray(output_size_name);
CHECK(output_size_array.data_type == ArrayDataType::kInt32);
CHECK(output_size_array.has_shape());
const auto& output_size_shape = output_size_array.shape();
@@ -697,9 +697,9 @@ void ProcessResizeBilinearOperator(Model* model, ResizeBilinearOperator* op) {
}
std::vector<int32> output_shape =
output_size_array.GetBuffer<ArrayDataType::kInt32>().data;
- model->arrays[op->outputs[0]]->copy_shape(
- Shape({input_data_shape.dims(0), output_shape[0], output_shape[1],
- input_data_shape.dims(3)}));
+ model->GetArray(op->outputs[0])
+ .copy_shape(Shape({input_data_shape.dims(0), output_shape[0],
+ output_shape[1], input_data_shape.dims(3)}));
}
void ProcessLstmCellOperator(Model* model, LstmCellOperator* op) {
@@ -708,7 +708,7 @@ void ProcessLstmCellOperator(Model* model, LstmCellOperator* op) {
QCHECK_EQ(op->outputs.size(), LstmCellOperator::NUM_OUTPUTS);
const auto& input_array =
- *model->arrays[op->inputs[LstmCellOperator::DATA_INPUT]];
+ model->GetArray(op->inputs[LstmCellOperator::DATA_INPUT]);
// Yield until all input dims have been resolved.
if (!input_array.has_shape()) {
return;
@@ -717,7 +717,7 @@ void ProcessLstmCellOperator(Model* model, LstmCellOperator* op) {
CHECK_GE(input_shape.dimensions_count(), 2);
const auto& prev_activ_array =
- *model->arrays[op->inputs[LstmCellOperator::PREV_ACTIV_INPUT]];
+ model->GetArray(op->inputs[LstmCellOperator::PREV_ACTIV_INPUT]);
// Yield until all input dims have been resolved.
if (!prev_activ_array.has_shape()) {
return;
@@ -726,7 +726,7 @@ void ProcessLstmCellOperator(Model* model, LstmCellOperator* op) {
CHECK_GE(prev_activ_shape.dimensions_count(), 2);
const auto& weights_array =
- *model->arrays[op->inputs[LstmCellOperator::WEIGHTS_INPUT]];
+ model->GetArray(op->inputs[LstmCellOperator::WEIGHTS_INPUT]);
// Yield until weights dims have been resolved.
if (!weights_array.has_shape()) {
return;
@@ -735,7 +735,7 @@ void ProcessLstmCellOperator(Model* model, LstmCellOperator* op) {
CHECK_EQ(weights_shape.dimensions_count(), 2);
const auto& bias_array =
- *model->arrays[op->inputs[LstmCellOperator::BIASES_INPUT]];
+ model->GetArray(op->inputs[LstmCellOperator::BIASES_INPUT]);
// Yield until bias dims have been resolved.
if (!bias_array.has_shape()) {
return;
@@ -744,7 +744,7 @@ void ProcessLstmCellOperator(Model* model, LstmCellOperator* op) {
CHECK_GE(bias_shape.dimensions_count(), 1);
const auto& prev_state_array =
- *model->arrays[op->inputs[LstmCellOperator::PREV_STATE_INPUT]];
+ model->GetArray(op->inputs[LstmCellOperator::PREV_STATE_INPUT]);
// Yield until all input dims have been resolved.
if (!prev_state_array.has_shape()) {
return;
@@ -784,7 +784,7 @@ void ProcessLstmCellOperator(Model* model, LstmCellOperator* op) {
}
void ProcessSpaceToBatchNDOperator(Model* model, SpaceToBatchNDOperator* op) {
- const auto& input_array = *model->arrays[op->inputs[0]];
+ const auto& input_array = model->GetArray(op->inputs[0]);
// Yield until input dims have been resolved.
if (!input_array.has_shape()) {
return;
@@ -797,8 +797,8 @@ void ProcessSpaceToBatchNDOperator(Model* model, SpaceToBatchNDOperator* op) {
const auto input_height = input_shape.dims(1);
const auto input_width = input_shape.dims(2);
- const auto& block_shape_array = *model->arrays[op->inputs[1]];
- const auto& paddings_array = *model->arrays[op->inputs[2]];
+ const auto& block_shape_array = model->GetArray(op->inputs[1]);
+ const auto& paddings_array = model->GetArray(op->inputs[2]);
const auto& block_shape_array_shape = block_shape_array.shape();
const auto& paddings_array_shape = paddings_array.shape();
QCHECK_EQ(block_shape_array_shape.dimensions_count(), 1);
@@ -830,13 +830,13 @@ void ProcessSpaceToBatchNDOperator(Model* model, SpaceToBatchNDOperator* op) {
int output_height = height_with_paddings / block_height;
int output_width = width_with_paddings / block_width;
- model->arrays[op->outputs[0]]->copy_shape(
- Shape({input_shape.dims(0) * block_height * block_width, output_height,
- output_width, input_shape.dims(3)}));
+ model->GetArray(op->outputs[0])
+ .copy_shape(Shape({input_shape.dims(0) * block_height * block_width,
+ output_height, output_width, input_shape.dims(3)}));
}
void ProcessBatchToSpaceNDOperator(Model* model, BatchToSpaceNDOperator* op) {
- const auto& input_array = *model->arrays[op->inputs[0]];
+ const auto& input_array = model->GetArray(op->inputs[0]);
// Yield until input dims have been resolved.
if (!input_array.has_shape()) {
return;
@@ -846,8 +846,8 @@ void ProcessBatchToSpaceNDOperator(Model* model, BatchToSpaceNDOperator* op) {
const auto input_height = input_shape.dims(1);
const auto input_width = input_shape.dims(2);
- const auto& block_shape_array = *model->arrays[op->inputs[1]];
- const auto& crops_array = *model->arrays[op->inputs[2]];
+ const auto& block_shape_array = model->GetArray(op->inputs[1]);
+ const auto& crops_array = model->GetArray(op->inputs[2]);
const auto& block_shape_array_shape = block_shape_array.shape();
const auto& crops_array_shape = crops_array.shape();
QCHECK_EQ(block_shape_array_shape.dimensions_count(), 1);
@@ -882,15 +882,15 @@ void ProcessBatchToSpaceNDOperator(Model* model, BatchToSpaceNDOperator* op) {
int output_height = input_height * block_height;
int output_width = input_width * block_width;
- model->arrays[op->outputs[0]]->copy_shape(
- Shape({input_shape.dims(0) / (block_height * block_width), output_height,
- output_width, input_shape.dims(3)}));
+ model->GetArray(op->outputs[0])
+ .copy_shape(Shape({input_shape.dims(0) / (block_height * block_width),
+ output_height, output_width, input_shape.dims(3)}));
}
void ProcessGatherOperator(Model* model, GatherOperator* op) {
- const auto& input_array = *model->arrays[op->inputs[0]];
- const auto& indices_array = *model->arrays[op->inputs[1]];
- auto& output_array = *model->arrays[op->outputs[0]];
+ const auto& input_array = model->GetArray(op->inputs[0]);
+ const auto& indices_array = model->GetArray(op->inputs[1]);
+ auto& output_array = model->GetArray(op->outputs[0]);
// Bail if we already know the output shape.
if (output_array.has_shape()) {
@@ -924,7 +924,7 @@ void ProcessPadOperator(Model* model, PadOperator* op) {
CHECK_EQ(op->inputs.size(), 2);
CHECK_EQ(op->outputs.size(), 1);
- const auto& input_array = *model->arrays[op->inputs[0]];
+ const auto& input_array = model->GetArray(op->inputs[0]);
// Yield until input dims have been resolved.
if (!input_array.has_shape()) return;
@@ -932,7 +932,7 @@ void ProcessPadOperator(Model* model, PadOperator* op) {
if (op->left_padding.empty()) return;
CHECK_EQ(op->left_padding.size(), op->right_padding.size());
- auto& output_array = *model->arrays[op->outputs[0]];
+ auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.has_shape()) return;
Shape output_shape = input_array.shape();
@@ -949,13 +949,13 @@ void ProcessPadOperator(Model* model, PadOperator* op) {
void ProcessRankOperator(Model* model, RankOperator* op) {
CHECK_GE(op->inputs.size(), 1);
CHECK_EQ(op->outputs.size(), 1);
- auto& output_array = *model->arrays[op->outputs[0]];
+ auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.has_shape()) {
// Shape already propagated
return;
}
- const auto& input_array = *model->arrays[op->inputs[0]];
+ const auto& input_array = model->GetArray(op->inputs[0]);
if (!input_array.has_shape()) {
// Yield until input dims have been resolved.
return;
@@ -970,13 +970,13 @@ void ProcessRankOperator(Model* model, RankOperator* op) {
void ProcessShapeOperator(Model* model, TensorFlowShapeOperator* op) {
CHECK_GE(op->inputs.size(), 1);
CHECK_EQ(op->outputs.size(), 1);
- auto& output_array = *model->arrays[op->outputs[0]];
+ auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.has_shape()) {
// Shape already propagated
return;
}
- const auto& input_array = *model->arrays[op->inputs[0]];
+ const auto& input_array = model->GetArray(op->inputs[0]);
if (!input_array.has_shape()) {
// Yield until input dims have been resolved.
return;
@@ -991,7 +991,7 @@ void ProcessShapeOperator(Model* model, TensorFlowShapeOperator* op) {
void ProcessStackOperator(Model* model, StackOperator* op) {
CHECK_GE(op->inputs.size(), 1);
CHECK_EQ(op->outputs.size(), 1);
- auto& output_array = *model->arrays[op->outputs[0]];
+ auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.has_shape()) {
// Shape already propagated
return;
@@ -1032,7 +1032,7 @@ void ProcessStackOperator(Model* model, StackOperator* op) {
void ProcessStridedSliceOperator(Model* model, StridedSliceOperator* op) {
CHECK_GE(op->inputs.size(), 1);
CHECK_EQ(op->outputs.size(), 1);
- auto& output_array = *model->arrays[op->outputs[0]];
+ auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.has_shape()) {
// Shape already propagated
return;
@@ -1112,12 +1112,12 @@ void ProcessSqueezeOperator(Model* model, SqueezeOperator* op) {
CHECK_EQ(op->inputs.size(), 1);
CHECK_EQ(op->outputs.size(), 1);
- const auto& input_array = *model->arrays[op->inputs[0]];
+ const auto& input_array = model->GetArray(op->inputs[0]);
// Yield until input dims have been resolved.
if (!input_array.has_shape()) return;
- auto& output_array = *model->arrays[op->outputs[0]];
+ auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.has_shape()) return;
const std::vector<int>& input_dims = input_array.shape().dims();
@@ -1136,18 +1136,18 @@ void ProcessSqueezeOperator(Model* model, SqueezeOperator* op) {
void ProcessSvdfOperator(Model* model, SvdfOperator* op) {
CHECK(op->inputs.size() == 3 || op->inputs.size() == 4);
- const auto& input_array = *model->arrays[op->inputs[0]];
+ const auto& input_array = model->GetArray(op->inputs[0]);
if (!input_array.has_shape()) return;
- auto& weights_feature_array = *model->arrays[op->inputs[1]];
+ auto& weights_feature_array = model->GetArray(op->inputs[1]);
if (!weights_feature_array.has_shape()) return;
- const auto& weights_time_array = *model->arrays[op->inputs[2]];
+ const auto& weights_time_array = model->GetArray(op->inputs[2]);
if (!weights_time_array.has_shape()) return;
const bool has_bias = (op->inputs.size() == 4);
if (has_bias) {
- const auto& bias_array = *model->arrays[op->inputs[3]];
+ const auto& bias_array = model->GetArray(op->inputs[3]);
if (!bias_array.has_shape()) return;
}
@@ -1164,13 +1164,13 @@ void ProcessSvdfOperator(Model* model, SvdfOperator* op) {
}
void ProcessTransposeOperator(Model* model, TransposeOperator* op) {
- auto& output_array = *model->arrays[op->outputs[0]];
+ auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.has_shape()) {
// We have already run
return;
}
- const auto& input_array = *model->arrays[op->inputs[0]];
+ const auto& input_array = model->GetArray(op->inputs[0]);
if (!input_array.has_shape()) {
// Yield until input dims have been resolved.
return;
@@ -1204,7 +1204,7 @@ void ProcessTransposeOperator(Model* model, TransposeOperator* op) {
void ProcessArgMaxOperator(Model* model, ArgMaxOperator* op) {
CHECK_EQ(op->inputs.size(), 2);
- const auto& input_array = *model->arrays[op->inputs[0]];
+ const auto& input_array = model->GetArray(op->inputs[0]);
// Yield until input dims have been resolved.
if (!input_array.has_shape()) {
return;
@@ -1222,7 +1222,7 @@ void ProcessArgMaxOperator(Model* model, ArgMaxOperator* op) {
}
output_dims.push_back(1);
const string& output_name = op->outputs[0];
- auto& output_array = *model->arrays[output_name];
+ auto& output_array = model->GetArray(output_name);
if (output_array.has_shape()) {
return;
}
@@ -1236,8 +1236,8 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
auto* op = it->get();
std::unordered_map<string, std::vector<int>> old_output_dims;
for (const auto& output : op->outputs) {
- if (model->arrays[output]->has_shape()) {
- old_output_dims[output] = model->arrays[output]->shape().dims();
+ if (model->GetArray(output).has_shape()) {
+ old_output_dims[output] = model->GetArray(output).shape().dims();
}
}
@@ -1433,10 +1433,10 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
// Return true if any output dim changed, false if none changed.
// Assumption: no transformation clears an output shape, they only add shapes.
for (const auto& output : op->outputs) {
- if (model->arrays[output]->has_shape() &&
- (old_output_dims[output] != model->arrays[output]->shape().dims())) {
+ if (model->GetArray(output).has_shape() &&
+ (old_output_dims[output] != model->GetArray(output).shape().dims())) {
AddMessageF("Set shape of %s to [%s]", output,
- absl::StrJoin(model->arrays[output]->shape().dims(), ","));
+ absl::StrJoin(model->GetArray(output).shape().dims(), ","));
return true;
}
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
index 56082b965a..b973b2b813 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
@@ -412,7 +412,7 @@ bool Quantize::Run(Model* model, std::size_t op_index) {
model->flags.set_output_arrays(i, dequantize_op->inputs[0]);
}
}
- model->arrays.erase(dequantize_op->outputs[0]);
+ model->EraseArray(dequantize_op->outputs[0]);
model->operators.erase(dequantize_it);
}
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/read_fake_quant_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/read_fake_quant_min_max.cc
index 371ced388a..11f8d4b6ee 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/read_fake_quant_min_max.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/read_fake_quant_min_max.cc
@@ -80,7 +80,7 @@ bool ReadFakeQuantMinMax::Run(Model* model, std::size_t op_index) {
// else.
for (int i = 1; i <= 2; i++) {
if (CountOpsWithInput(*model, fakequant_op->inputs[i]) == 1) {
- model->arrays.erase(fakequant_op->inputs[i]);
+ model->EraseArray(fakequant_op->inputs[i]);
}
}
fakequant_op->inputs.resize(1);
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_final_dequantize_op.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_final_dequantize_op.cc
index 3992e7d1ef..c3b2709a33 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/remove_final_dequantize_op.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_final_dequantize_op.cc
@@ -51,7 +51,7 @@ bool RemoveFinalDequantizeOp::Run(Model* model, std::size_t op_index) {
// Remove the node and its output array.
AddMessageF("Removed final %s", LogName(*dequantize_op));
- model->arrays.erase(output);
+ model->EraseArray(output);
model->operators.erase(dequantize_it);
return true;
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc
index 6add443f2d..8512e6bb5a 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc
@@ -81,7 +81,7 @@ bool RemoveTrivialBinaryOperator::Run(Model* model, std::size_t op_index) {
// Now check if the constant operand makes this binary
// operator trivial.
const auto& constant_input_array =
- *model->arrays[binary_op->inputs[index_of_constant_input]];
+ model->GetArray(binary_op->inputs[index_of_constant_input]);
// For now, we only handle floats here.
if (constant_input_array.data_type != ArrayDataType::kFloat) {
return false;
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation_input.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation_input.cc
index 23a5c857e8..936854a04f 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation_input.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation_input.cc
@@ -59,7 +59,7 @@ bool RemoveTrivialConcatenationInput::Run(Model* model, std::size_t op_index) {
for (const string& input : trivial_inputs) {
if (IsDiscardableArray(*model, input) &&
CountOpsWithInput(*model, input) == 1) {
- model->arrays.erase(input);
+ model->EraseArray(input);
}
}
concat_op->inputs = nontrivial_inputs;
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc
index 047389f69a..587f171bbf 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc
@@ -124,7 +124,7 @@ bool RemoveTrivialPassthroughOp(GraphTransformation* transformation,
}
}
if (!is_referenced) {
- model->arrays.erase(removal_candidate);
+ model->EraseArray(removal_candidate);
}
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc
index e6cca8acf3..aa2c293382 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc
@@ -33,7 +33,7 @@ bool RemoveUnusedOp::Run(Model* model, std::size_t op_index) {
// the model. We allow specifying an arbitrary input_array,
// treating the part of the graph leading up to it as unused.
for (const auto& output : op->outputs) {
- CHECK(model->arrays.count(output));
+ CHECK(model->HasArray(output));
// If this output is provided as the model's input array,
// then we don't need this operator to produce its contents.
if (IsInputArray(*model, output)) {
@@ -93,7 +93,7 @@ bool RemoveUnusedOp::Run(Model* model, std::size_t op_index) {
if (IsDiscardableArray(*model, input) &&
CountOpsWithInput(*model, input) == 1 &&
!GetOpWithOutput(*model, input)) {
- model->arrays.erase(input);
+ model->EraseArray(input);
}
}
@@ -116,7 +116,7 @@ bool RemoveUnusedOp::Run(Model* model, std::size_t op_index) {
continue;
}
// Generic case: do delete this output array.
- model->arrays.erase(output);
+ model->EraseArray(output);
}
model->operators.erase(it);
return true;
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc
index 3eb7fa3896..fb109eb91b 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc
@@ -121,9 +121,9 @@ bool ResolveBatchNormalization::Run(Model* model, std::size_t op_index) {
}
// Remove the old param arrays
- model->arrays.erase(bn_op->inputs[1]);
- model->arrays.erase(bn_op->inputs[2]);
- model->arrays.erase(bn_op->inputs[3]);
+ model->EraseArray(bn_op->inputs[1]);
+ model->EraseArray(bn_op->inputs[2]);
+ model->EraseArray(bn_op->inputs[3]);
// Remove the old operator
DCHECK_EQ(bn_it->get(), bn_op);
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc
index 7777d4f543..a06919e228 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc
@@ -42,7 +42,7 @@ bool ResolveBatchToSpaceNDAttributes::Run(Model* model, std::size_t op_index) {
return false;
// Handle crops
- const auto& crops_array = *model->arrays[op->inputs[2]];
+ const auto& crops_array = model->GetArray(op->inputs[2]);
if (!crops_array.has_shape()) return false;
const std::vector<int>& crops_dims = crops_array.shape().dims();
if (crops_dims.size() != 2) {
@@ -58,7 +58,7 @@ bool ResolveBatchToSpaceNDAttributes::Run(Model* model, std::size_t op_index) {
}
// Handle block_shape
- const auto& block_shape_array = *model->arrays[op->inputs[1]];
+ const auto& block_shape_array = model->GetArray(op->inputs[1]);
if (!block_shape_array.has_shape()) return false;
const std::vector<int>& block_shape_dims = block_shape_array.shape().dims();
CHECK_EQ(block_shape_dims.size(), 1);
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc
index fd51df4058..5e779f6765 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc
@@ -166,8 +166,9 @@ void EvaluateBinaryOperatorOnConstantInputs(Model* model,
void EvaluateBinaryOperatorOnConstantInputs(Model* model,
const Operator* binary_op) {
- const auto inputs_data_type = model->arrays[binary_op->inputs[0]]->data_type;
- const auto output_data_type = model->arrays[binary_op->outputs[0]]->data_type;
+ const auto inputs_data_type = model->GetArray(binary_op->inputs[0]).data_type;
+ const auto output_data_type =
+ model->GetArray(binary_op->outputs[0]).data_type;
#define TOCO_HANDLE_CASE(InputsDataType, OutputDataType) \
if (inputs_data_type == InputsDataType && \
output_data_type == OutputDataType) { \
@@ -214,7 +215,7 @@ bool ResolveConstantBinaryOperator::Run(Model* model, std::size_t op_index) {
return false;
}
- auto& output_array = *model->arrays[binary_op->outputs[0]];
+ auto& output_array = model->GetArray(binary_op->outputs[0]);
// Yield until the output array dims have been resolved.
if (!output_array.has_shape()) {
return false;
@@ -239,10 +240,10 @@ bool ResolveConstantBinaryOperator::Run(Model* model, std::size_t op_index) {
// Remove the binary operator and its inputs
if (CountOpsWithInput(*model, binary_op->inputs[0]) == 1) {
- model->arrays.erase(binary_op->inputs[0]);
+ model->EraseArray(binary_op->inputs[0]);
}
if (CountOpsWithInput(*model, binary_op->inputs[1]) == 1) {
- model->arrays.erase(binary_op->inputs[1]);
+ model->EraseArray(binary_op->inputs[1]);
}
AddMessageF("Resolved constant %s to the equivalent constant array",
LogName(*binary_op));
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc
index 9835f86398..833c97c758 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc
@@ -189,7 +189,7 @@ bool ResolveConstantConcatenation::Run(Model* model, std::size_t op_index) {
// Remove all the resolved arrays.
for (const string& input_name : concat_op->inputs) {
- model->arrays.erase(input_name);
+ model->EraseArray(input_name);
}
// Remove concatenate operator
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc
index 244adcc4c4..81fe37d7e0 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc
@@ -66,7 +66,7 @@ bool ResolveConstantFakeQuant::Run(Model* model, std::size_t op_index) {
output_buffer.data[i] = dst_val;
}
if (CountOpsWithInput(*model, fakequant_op->inputs[0]) == 1) {
- model->arrays.erase(fakequant_op->inputs[0]);
+ model->EraseArray(fakequant_op->inputs[0]);
}
model->operators.erase(fakequant_it);
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fill.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fill.cc
index 9da51d9147..f6f95481b5 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fill.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fill.cc
@@ -104,11 +104,11 @@ bool ResolveConstantFill::Run(Model* model, std::size_t op_index) {
// Erase input arrays if no longer used
if (IsDiscardableArray(*model, op->inputs[0]) &&
CountOpsWithInput(*model, op->inputs[0]) == 1) {
- model->arrays.erase(op->inputs[0]);
+ model->EraseArray(op->inputs[0]);
}
if (IsDiscardableArray(*model, op->inputs[1]) &&
CountOpsWithInput(*model, op->inputs[1]) == 1) {
- model->arrays.erase(op->inputs[1]);
+ model->EraseArray(op->inputs[1]);
}
// Erase the operator
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_range.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_range.cc
index 383d54aa5a..1a0ba9e2bc 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_range.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_range.cc
@@ -28,17 +28,17 @@ bool ResolveConstantRange::Run(Model* model, std::size_t op_index) {
auto* op = static_cast<RangeOperator*>(base_op);
CHECK_EQ(op->inputs.size(), 3);
- const auto& start_array = *model->arrays[op->inputs[0]];
+ const auto& start_array = model->GetArray(op->inputs[0]);
if (!start_array.has_shape()) {
// Yield until all input dims have been resolved.
return false;
}
- const auto& limit_array = *model->arrays[op->inputs[1]];
+ const auto& limit_array = model->GetArray(op->inputs[1]);
if (!limit_array.has_shape()) {
// Yield until all input dims have been resolved.
return false;
}
- const auto& delta_array = *model->arrays[op->inputs[2]];
+ const auto& delta_array = model->GetArray(op->inputs[2]);
if (!delta_array.has_shape()) {
// Yield until all input dims have been resolved.
return false;
@@ -52,7 +52,7 @@ bool ResolveConstantRange::Run(Model* model, std::size_t op_index) {
}
CHECK_EQ(op->outputs.size(), 1);
- auto& output_array = *model->arrays[op->outputs[0]];
+ auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.data_type == ArrayDataType::kNone) {
// Yield until the output type has been set by PropagateArrayDataTypes
return false;
@@ -87,15 +87,15 @@ bool ResolveConstantRange::Run(Model* model, std::size_t op_index) {
// Delete the input array if no longer used
if (IsDiscardableArray(*model, op->inputs[0]) &&
CountOpsWithInput(*model, op->inputs[0]) == 1) {
- model->arrays.erase(op->inputs[0]);
+ model->EraseArray(op->inputs[0]);
}
if (IsDiscardableArray(*model, op->inputs[1]) &&
CountOpsWithInput(*model, op->inputs[1]) == 1) {
- model->arrays.erase(op->inputs[1]);
+ model->EraseArray(op->inputs[1]);
}
if (IsDiscardableArray(*model, op->inputs[2]) &&
CountOpsWithInput(*model, op->inputs[2]) == 1) {
- model->arrays.erase(op->inputs[2]);
+ model->EraseArray(op->inputs[2]);
}
// Delete the operator
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc
index 35b81dd550..9ea01acd05 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc
@@ -62,7 +62,7 @@ bool ResolveConstantShapeOrRank::Run(Model* model, std::size_t op_index) {
// Delete the input array if no longer used
if (IsDiscardableArray(*model, op->inputs[0]) &&
CountOpsWithInput(*model, op->inputs[0]) == 1) {
- model->arrays.erase(op->inputs[0]);
+ model->EraseArray(op->inputs[0]);
}
model->operators.erase(it);
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_stack.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_stack.cc
index 86c76141a4..ea0d6dc820 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_stack.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_stack.cc
@@ -101,7 +101,7 @@ bool ResolveConstantStack::Run(Model* model, std::size_t op_index) {
for (const auto& input : op->inputs) {
if (IsDiscardableArray(*model, input) &&
CountOpsWithInput(*model, input) == 1) {
- model->arrays.erase(input);
+ model->EraseArray(input);
}
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc
index 3976d9cbb4..a0cfc3d597 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc
@@ -186,7 +186,7 @@ bool ResolveConstantStridedSlice::Run(Model* model, std::size_t op_index) {
// Erase input array if no longer used
if (IsDiscardableArray(*model, op->inputs[0]) &&
CountOpsWithInput(*model, op->inputs[0]) == 1) {
- model->arrays.erase(op->inputs[0]);
+ model->EraseArray(op->inputs[0]);
}
// Erase the operator
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc
index 26ff9d887b..1cd2aff28c 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc
@@ -199,7 +199,7 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
}
for (const auto& input : unary_op->inputs) {
if (CountOpsWithInput(*model, input) == 1) {
- model->arrays.erase(input);
+ model->EraseArray(input);
}
}
AddMessageF("Resolved constant %s to the equivalent constant array",
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_mean_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_mean_attributes.cc
index b77be3f5c0..013b50ac9b 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_mean_attributes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_mean_attributes.cc
@@ -36,7 +36,7 @@ bool ResolveMeanAttributes::Run(Model* model, std::size_t op_index) {
if (op->inputs.size() != 2) return false;
if (!IsConstantParameterArray(*model, op->inputs[1])) return false;
- const auto& indices_array = *model->arrays[op->inputs[1]];
+ const auto& indices_array = model->GetArray(op->inputs[1]);
if (!indices_array.has_shape()) return false;
op->axis = indices_array.GetBuffer<ArrayDataType::kInt32>().data;
return true;
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_pad_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_pad_attributes.cc
index d5f5869c62..8a8e723cf7 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_pad_attributes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_pad_attributes.cc
@@ -35,7 +35,7 @@ bool ResolvePadAttributes::Run(Model* model, std::size_t op_index) {
CHECK_EQ(op->inputs.size(), 2);
if (!IsConstantParameterArray(*model, op->inputs[1])) return false;
- const auto& array = *model->arrays[op->inputs[1]];
+ const auto& array = model->GetArray(op->inputs[1]);
if (!array.has_shape()) return false;
const std::vector<int>& dims = array.shape().dims();
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc
index b5093bc4c7..5c68f87f6c 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc
@@ -103,7 +103,7 @@ bool ResolveReorderAxes::Run(Model* model, std::size_t op_index) {
AddMessageF("Reordered axes for array %s", input_array_name);
// Remove the op and output array.
- model->arrays.erase(output_array_name);
+ model->EraseArray(output_array_name);
model->operators.erase(reorder_it);
return true;
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc
index bed2a85bd2..2e063e3554 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc
@@ -37,7 +37,7 @@ bool ResolveReshapeAttributes::Run(Model* model, std::size_t op_index) {
if (!op->shape.empty()) return false;
if (IsConstantParameterArray(*model, reshape_op->inputs[1])) {
- const auto& constant_input_array = *model->arrays[reshape_op->inputs[1]];
+ const auto& constant_input_array = model->GetArray(reshape_op->inputs[1]);
op->shape = constant_input_array.GetBuffer<ArrayDataType::kInt32>().data;
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_slice_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_slice_attributes.cc
index 1d0a2ec8f6..e760d08e5a 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_slice_attributes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_slice_attributes.cc
@@ -36,10 +36,10 @@ bool ResolveSliceAttributes::Run(Model* model, std::size_t op_index) {
if (!IsConstantParameterArray(*model, op->inputs[1])) return false;
if (!IsConstantParameterArray(*model, op->inputs[2])) return false;
- const auto& begin_array = *model->arrays[op->inputs[1]];
+ const auto& begin_array = model->GetArray(op->inputs[1]);
if (!begin_array.has_shape()) return false;
- const auto& size_array = *model->arrays[op->inputs[2]];
+ const auto& size_array = model->GetArray(op->inputs[2]);
if (!size_array.has_shape()) return false;
op->begin = begin_array.GetBuffer<ArrayDataType::kInt32>().data;
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_space_to_batch_nd_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_space_to_batch_nd_attributes.cc
index a73f16735c..dad6aceccf 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_space_to_batch_nd_attributes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_space_to_batch_nd_attributes.cc
@@ -45,7 +45,7 @@ bool ResolveSpaceToBatchNDAttributes::Run(Model* model, std::size_t op_index) {
return false;
// Handle paddings.
- const auto& paddings_array = *model->arrays[op->inputs[paddings_index]];
+ const auto& paddings_array = model->GetArray(op->inputs[paddings_index]);
if (!paddings_array.has_shape()) return false;
const std::vector<int>& paddings_dims = paddings_array.shape().dims();
if (paddings_dims.size() != 2) {
@@ -61,7 +61,8 @@ bool ResolveSpaceToBatchNDAttributes::Run(Model* model, std::size_t op_index) {
}
// Handle block_shape.
- const auto& block_shape_array = *model->arrays[op->inputs[block_shape_index]];
+ const auto& block_shape_array =
+ model->GetArray(op->inputs[block_shape_index]);
if (!block_shape_array.has_shape()) return false;
const std::vector<int>& block_shape_dims = block_shape_array.shape().dims();
CHECK_EQ(block_shape_dims.size(), 1);
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc
index dbe69adcbd..de4d06be2a 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc
@@ -31,13 +31,13 @@ bool ResolveStridedSliceAttributes::Run(Model* model, std::size_t op_index) {
}
CHECK_EQ(op->inputs.size(), 4);
- const auto& start_array = *model->arrays[op->inputs[1]];
+ const auto& start_array = model->GetArray(op->inputs[1]);
if (!start_array.has_shape()) return false;
- const auto& stop_array = *model->arrays[op->inputs[2]];
+ const auto& stop_array = model->GetArray(op->inputs[2]);
if (!stop_array.has_shape()) return false;
- const auto& stride_array = *model->arrays[op->inputs[3]];
+ const auto& stride_array = model->GetArray(op->inputs[3]);
if (!stride_array.has_shape()) return false;
if (!IsConstantParameterArray(*model, op->inputs[1])) return false;
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc
index c6723a880e..5c0c1e3478 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc
@@ -75,7 +75,7 @@ bool ResolveTensorFlowConcat::Run(Model* model, std::size_t op_index) {
// Remove the axis array if it is not used by anything else.
if (CountOpsWithInput(*model, axis_name) == 1) {
- model->arrays.erase(axis_name);
+ model->EraseArray(axis_name);
}
// Remove the TensorFlowConcat op
model->operators.erase(concat_it);
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc
index bea7487051..ad1e56888e 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc
@@ -69,7 +69,7 @@ bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) {
LogName(*matmul_op), LogName(*fc_op));
const auto& previous_op_output = previous_op->outputs[0];
if (CountOpsWithInput(*model, previous_op_output) == 1) {
- model->arrays.erase(previous_op_output);
+ model->EraseArray(previous_op_output);
}
CHECK_EQ(previous_op->inputs.size(), 2);
fc_op->inputs = {previous_op->inputs[0], matmul_op->inputs[1]};
@@ -78,7 +78,7 @@ bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) {
const auto& previous_op_shape = previous_op->inputs[1];
if (CountOpsWithInput(*model, previous_op_shape) == 1 &&
!GetOpWithOutput(*model, previous_op_shape)) {
- model->arrays.erase(previous_op_shape);
+ model->EraseArray(previous_op_shape);
}
model->operators.erase(previous_op_it);
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc
index cfa5ce0716..477e7f13da 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc
@@ -55,7 +55,7 @@ bool ResolveTensorFlowMerge::Run(Model* model, std::size_t op_index) {
// Remove the node and its output array.
AddMessageF("Removing already-resolved %s", LogName(*merge_op));
- model->arrays.erase(merge_op->outputs[0]);
+ model->EraseArray(merge_op->outputs[0]);
model->operators.erase(merge_it);
return true;
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc
index 150cf53da3..a418073441 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc
@@ -103,7 +103,7 @@ bool ResolveTensorFlowSwitch::Run(Model* model, std::size_t op_index) {
// Remove the output arrays if they are now unused.
for (int i = 0; i < 2; i++) {
if (!GetOpWithInput(*model, switch_op->outputs[i])) {
- model->arrays.erase(switch_op->outputs[i]);
+ model->EraseArray(switch_op->outputs[i]);
}
}
// Remove input arrays if they are only used by the switch itself and aren't
@@ -111,7 +111,7 @@ bool ResolveTensorFlowSwitch::Run(Model* model, std::size_t op_index) {
for (const auto& input : switch_op->inputs) {
if (CountOpsWithInput(*model, input) == 1 &&
!GetOpWithOutput(*model, input)) {
- model->arrays.erase(input);
+ model->EraseArray(input);
}
}
// Remove the switch node itself.
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_tile.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_tile.cc
index 9f7e7c42a2..1ddf54c778 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_tile.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_tile.cc
@@ -45,10 +45,10 @@ void RemoveTileOperator(Model* model, Operator* tile_op, Operator* binary_op,
model->operators.erase(tile_it);
if (!CountOpsWithInput(*model, tile_multiplier_array) &&
!GetOpWithOutput(*model, tile_multiplier_array)) {
- model->arrays.erase(tile_multiplier_array);
+ model->EraseArray(tile_multiplier_array);
}
if (!CountOpsWithInput(*model, tile_output_array)) {
- model->arrays.erase(tile_output_array);
+ model->EraseArray(tile_output_array);
}
}
} // namespace
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_transpose_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_transpose_attributes.cc
index 12d966b261..a657ee00af 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_transpose_attributes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_transpose_attributes.cc
@@ -35,7 +35,7 @@ bool ResolveTransposeAttributes::Run(Model* model, std::size_t op_index) {
if (!IsConstantParameterArray(*model, op->inputs[1])) return false;
// Handling perm.
- const auto& perm_array = *model->arrays[op->inputs[1]];
+ const auto& perm_array = model->GetArray(op->inputs[1]);
if (!perm_array.has_shape()) return false;
const std::vector<int>& perm_dims = perm_array.shape().dims();
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc b/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc
index a14016e8e2..3a1d175b98 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc
@@ -19,7 +19,6 @@ limitations under the License.
#include <gmock/gmock.h>
#include <gtest/gtest.h>
-//#include "tensorflow/contrib/lite/kernels/test_util.h"
#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/contrib/lite/toco/model.h"
#include "tensorflow/contrib/lite/toco/tooling_util.h"
@@ -168,11 +167,11 @@ TEST_F(ResolveConstantConcatenationTest, ConcatAtAxis0) {
GraphTransformationsSet graph_transformation_set;
graph_transformation_set.Add(new toco::ResolveConstantConcatenation);
- EXPECT_THAT(model.arrays.size(), 5);
+ EXPECT_THAT(model.GetArrayMap().size(), 5);
(*graph_transformation_set.begin())->Run(&model, /*op_index=*/0);
- EXPECT_THAT(model.arrays.size(), 1);
+ EXPECT_THAT(model.GetArrayMap().size(), 1);
- auto& concatenated_array = (*model.arrays.begin()).second;
+ auto& concatenated_array = (*model.GetArrayMap().begin()).second;
EXPECT_THAT(concatenated_array->GetBuffer<toco::ArrayDataType::kFloat>().data,
ElementsAreArray(ArrayFloatNear(
{0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12.,
@@ -187,11 +186,11 @@ TEST_F(ResolveConstantConcatenationTest, ConcatAtAxis1) {
GraphTransformationsSet graph_transformation_set;
graph_transformation_set.Add(new toco::ResolveConstantConcatenation);
- EXPECT_THAT(model.arrays.size(), 5);
+ EXPECT_THAT(model.GetArrayMap().size(), 5);
(*graph_transformation_set.begin())->Run(&model, /*op_index=*/0);
- EXPECT_THAT(model.arrays.size(), 1);
+ EXPECT_THAT(model.GetArrayMap().size(), 1);
- auto& concatenated_array = (*model.arrays.begin()).second;
+ auto& concatenated_array = (*model.GetArrayMap().begin()).second;
EXPECT_THAT(concatenated_array->GetBuffer<toco::ArrayDataType::kFloat>().data,
ElementsAreArray(ArrayFloatNear(
{0., 1., 2., 3., 10., 11., 12., 13., 20., 21., 22.,
@@ -206,11 +205,11 @@ TEST_F(ResolveConstantConcatenationTest, ConcatAtAxis2) {
GraphTransformationsSet graph_transformation_set;
graph_transformation_set.Add(new toco::ResolveConstantConcatenation);
- EXPECT_THAT(model.arrays.size(), 5);
+ EXPECT_THAT(model.GetArrayMap().size(), 5);
(*graph_transformation_set.begin())->Run(&model, /*op_index=*/0);
- EXPECT_THAT(model.arrays.size(), 1);
+ EXPECT_THAT(model.GetArrayMap().size(), 1);
- auto& concatenated_array = (*model.arrays.begin()).second;
+ auto& concatenated_array = (*model.GetArrayMap().begin()).second;
EXPECT_THAT(concatenated_array->GetBuffer<toco::ArrayDataType::kFloat>().data,
ElementsAreArray(ArrayFloatNear(
{0., 1., 10., 11., 20., 21., 30., 31., 2., 3., 12.,
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc b/tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc
index 4e273343df..2c7046c8c7 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc
@@ -63,7 +63,7 @@ bool UnfuseActivationFunctions::Run(Model* model, std::size_t op_index) {
ac_op->outputs = op->outputs;
const string& tmp_array_name =
AvailableArrayName(*model, op->outputs[0] + "_unfused");
- CHECK(!model->arrays.count(tmp_array_name));
+ CHECK(!model->HasArray(tmp_array_name));
model->GetOrCreateArray(tmp_array_name);
ac_op->inputs = {tmp_array_name};
op->outputs = {tmp_array_name};
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index 995e9d67ca..1947271f55 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -1652,7 +1652,7 @@ void StripCaretFromArrayNames(Model* model) {
output = string(absl::StripPrefix(output, "^"));
}
}
- for (auto& array : model->arrays) {
+ for (auto& array : model->GetArrayMap()) {
if (absl::StartsWith(array.first, "^")) {
LOG(FATAL) << "What?";
}
diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h
index b079750bed..54fbba7381 100644
--- a/tensorflow/contrib/lite/toco/model.h
+++ b/tensorflow/contrib/lite/toco/model.h
@@ -1521,15 +1521,19 @@ struct Array {
// Our Model struct, represents an entire model (our "top-level" struct).
// Owns everything.
-struct Model {
+class Model {
+ public:
+ using ArrayMap = std::unordered_map<string, std::unique_ptr<Array>>;
+
+ bool HasArray(const string& name) const { return arrays.count(name) > 0; }
Array& GetArray(const string& name) const {
- DCHECK(arrays.count(name));
+ DCHECK(HasArray(name));
return *arrays.at(name);
}
Array& GetOrCreateArray(const string& name) {
// Make sure name is not used by an optional array
DCHECK(!optional_arrays.count(name));
- if (!arrays.count(name)) {
+ if (!HasArray(name)) {
Array* ptr = new Array;
arrays[name] = std::unique_ptr<Array>(ptr);
}
@@ -1544,18 +1548,27 @@ struct Model {
return optional_arrays.count(name);
}
+ // Note that this invalidates all array iterators.
+ void EraseArray(const string& name) { arrays.erase(name); }
+ void EraseArrays(std::function<bool(const string&)> discardable) {
+ for (auto it = arrays.begin(); it != arrays.end();) {
+ if (discardable(it->first)) {
+ it = arrays.erase(it);
+ } else {
+ ++it;
+ }
+ }
+ }
+ const ArrayMap& GetArrayMap() const { return arrays; }
+
// Optional arrays are used for optional tensors,
// these tensors do not have data, but with reserved names as op inputs.
std::set<string> optional_arrays;
+
// The list of operators. Notice how it's a list of unique_ptr's, implying
// that the Model is what owns Operator's and keeps them alive.
std::vector<std::unique_ptr<Operator>> operators;
- // The associative array mapping names to Array's.
- // Notice how it's a container of unique_ptr's, implying
- // that the Model is what owns Array's and keeps them alive.
- // The Operator's refer to these Array's by their name strings, not by their
- // addresses. See Operator::inputs, Operator::outputs.
- std::unordered_map<string, std::unique_ptr<Array>> arrays;
+
// Generic flags, a place where we combine information passed to us via
// command-line parameters (e.g. --input_width=N) with information that
// we may or may not find in the input model file.
@@ -1564,6 +1577,14 @@ struct Model {
std::size_t transient_data_size = 0;
// For code-generation only: required alignment of the transient_data buffer
std::size_t transient_data_alignment = 0;
+
+ private:
+ // The associative array mapping names to Array's.
+ // Notice how it's a container of unique_ptr's, implying
+ // that the Model is what owns Array's and keeps them alive.
+ // The Operator's refer to these Array's by their name strings, not by their
+ // addresses. See Operator::inputs, Operator::outputs.
+ std::unordered_map<string, std::unique_ptr<Array>> arrays;
};
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/tflite/export.cc b/tensorflow/contrib/lite/toco/tflite/export.cc
index 440353203e..391ef87029 100644
--- a/tensorflow/contrib/lite/toco/tflite/export.cc
+++ b/tensorflow/contrib/lite/toco/tflite/export.cc
@@ -62,7 +62,7 @@ namespace details {
void LoadTensorsMap(const Model& model, TensorsMap* tensors_map) {
// First find a list of unique array names.
std::set<string> names;
- for (const auto& array_pair : model.arrays) {
+ for (const auto& array_pair : model.GetArrayMap()) {
names.insert(array_pair.first);
}
@@ -96,7 +96,7 @@ Offset<Vector<Offset<Tensor>>> ExportTensors(
// tensors in the tensors_map.
std::map<int, Offset<Tensor>> ordered_tensors;
- for (const auto& array_pair : model.arrays) {
+ for (const auto& array_pair : model.GetArrayMap()) {
const string& tensor_name = array_pair.first;
const toco::Array& array = *array_pair.second;
diff --git a/tensorflow/contrib/lite/toco/tflite/import_test.cc b/tensorflow/contrib/lite/toco/tflite/import_test.cc
index 309fa6d7f6..aad6e780d5 100644
--- a/tensorflow/contrib/lite/toco/tflite/import_test.cc
+++ b/tensorflow/contrib/lite/toco/tflite/import_test.cc
@@ -114,7 +114,7 @@ TEST_F(ImportTest, Tensors) {
auto model = Import(ModelFlags(), InputModelAsString());
- ASSERT_GT(model->arrays.count("tensor_one"), 0);
+  ASSERT_TRUE(model->HasArray("tensor_one"));
Array& a1 = model->GetArray("tensor_one");
EXPECT_EQ(ArrayDataType::kFloat, a1.data_type);
EXPECT_THAT(a1.GetBuffer<ArrayDataType::kFloat>().data,
diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc
index 94b4d14696..afaa0fd0c7 100644
--- a/tensorflow/contrib/lite/toco/toco_tooling.cc
+++ b/tensorflow/contrib/lite/toco/toco_tooling.cc
@@ -133,7 +133,7 @@ void SetFinalDataTypeOnInputs(const TocoFlags& toco_flags, Model* model) {
for (int i = 0; i < model->flags.input_arrays_size(); i++) {
string const& array_name = model->flags.input_arrays(i).name();
- auto* array = model->arrays[array_name].get();
+ auto* array = &model->GetArray(array_name);
// Note that the notion of changing data types only applies to real-numbers
// arrays (see the documentation for inference_input_type).
// TODO(benoitjacob) this is assuming that uint8 arrays are quantized,
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc
index f9093ab973..900c60bd90 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util.cc
@@ -93,7 +93,7 @@ int CountOpsWithInput(const Model& model, const string& array_name) {
bool DeleteArrayIfUnused(const string& array_name, Model* model) {
if (CountOpsWithInput(*model, array_name) == 0) {
- model->arrays.erase(array_name);
+ model->EraseArray(array_name);
return true;
}
return false;
@@ -566,11 +566,11 @@ int RequiredBufferSizeForShape(const Shape& shape) {
}
bool IsConstantParameterArray(const Model& model, const string& name) {
- if (!model.arrays.count(name)) {
+ if (!model.HasArray(name)) {
return false;
}
- return !!model.arrays.at(name)->buffer;
+ return !!model.GetArray(name).buffer;
}
namespace {
@@ -633,17 +633,17 @@ void CheckNonExistentIOArrays(const Model& model) {
return;
}
for (const auto& input_array : model.flags.input_arrays()) {
- CHECK(model.arrays.count(input_array.name()))
+ CHECK(model.HasArray(input_array.name()))
<< "Input array not found: " << input_array.name();
}
for (const string& output_array : model.flags.output_arrays()) {
- CHECK(model.arrays.count(output_array))
+ CHECK(model.HasArray(output_array))
<< "Output array not found: " << output_array;
}
for (const auto& rnn_state : model.flags.rnn_states()) {
if (!rnn_state.discardable()) {
- CHECK(model.arrays.count(rnn_state.state_array()));
- CHECK(model.arrays.count(rnn_state.back_edge_source_array()));
+ CHECK(model.HasArray(rnn_state.state_array()));
+ CHECK(model.HasArray(rnn_state.back_edge_source_array()));
}
}
}
@@ -652,10 +652,10 @@ void CheckNonExistentIOArrays(const Model& model) {
void CheckNoMissingArray(const Model& model) {
for (const auto& op : model.operators) {
for (const auto& input : op->inputs) {
- CHECK(model.arrays.count(input) || model.optional_arrays.count(input));
+ CHECK(model.HasArray(input) || model.optional_arrays.count(input));
}
for (const auto& output : op->outputs) {
- CHECK(model.arrays.count(output));
+ CHECK(model.HasArray(output));
}
}
CheckNonExistentIOArrays(model);
@@ -664,12 +664,12 @@ void CheckNoMissingArray(const Model& model) {
void FixNoMissingArray(Model* model) {
for (const auto& op : model->operators) {
for (const auto& input : op->inputs) {
- if (!model->arrays.count(input)) {
+ if (!model->HasArray(input)) {
model->GetOrCreateArray(input);
}
}
for (const auto& output : op->outputs) {
- if (!model->arrays.count(output)) {
+ if (!model->HasArray(output)) {
model->GetOrCreateArray(output);
}
}
@@ -687,7 +687,7 @@ void FixNoMissingArray(Model* model) {
void CheckNoOrphanedArray(const Model& model) {
std::unordered_set<string> arrays_without_known_use;
- for (const auto& array : model.arrays) {
+ for (const auto& array : model.GetArrayMap()) {
if (IsDiscardableArray(model, array.first)) {
arrays_without_known_use.insert(array.first);
}
@@ -714,7 +714,7 @@ void CheckNoOrphanedArray(const Model& model) {
void FixNoOrphanedArray(Model* model) {
std::unordered_set<string> arrays_without_known_use;
- for (const auto& array : model->arrays) {
+ for (const auto& array : model->GetArrayMap()) {
arrays_without_known_use.insert(array.first);
}
for (const auto& op : model->operators) {
@@ -731,13 +731,13 @@ void FixNoOrphanedArray(Model* model) {
}
for (const auto& array : arrays_without_known_use) {
if (IsDiscardableArray(*model, array)) {
- model->arrays.erase(array);
+ model->EraseArray(array);
}
}
}
void CheckArrayFieldsConsistent(const Model& model) {
- for (const auto& array_entry : model.arrays) {
+ for (const auto& array_entry : model.GetArrayMap()) {
const auto& array = array_entry.second;
if (array->has_shape()) {
for (int d : array->shape().dims()) {
@@ -756,7 +756,7 @@ void CheckArrayFieldsConsistent(const Model& model) {
void CheckOperatorOrdering(const Model& model) {
std::unordered_set<string> arrays_behind_us;
- for (const auto& array_entry : model.arrays) {
+ for (const auto& array_entry : model.GetArrayMap()) {
if (!GetOpWithOutput(model, array_entry.first)) {
arrays_behind_us.insert(array_entry.first);
}
@@ -781,7 +781,7 @@ void CheckOperatorOrdering(const Model& model) {
void FixOperatorOrdering(Model* model) {
std::unordered_set<string> arrays_behind_us;
- for (const auto& array_entry : model->arrays) {
+ for (const auto& array_entry : model->GetArrayMap()) {
if (!GetOpWithOutput(*model, array_entry.first)) {
arrays_behind_us.insert(array_entry.first);
}
@@ -936,7 +936,8 @@ void CheckModelCounts(const Model& model) {
if (count_type == "None") {
continue;
} else if (count_type == "Arrays") {
- CheckCountInRange(model_check, model.arrays.size(), "count of arrays");
+ CheckCountInRange(model_check, model.GetArrayMap().size(),
+ "count of arrays");
} else if (count_type == "Total") {
CheckCountInRange(model_check, model.operators.size(),
"count of all operator instances");
@@ -1297,7 +1298,7 @@ bool IsAllocatableTransientArray(const Model& model, const string& array_name) {
return false;
}
}
- const auto& array = model.arrays.at(array_name);
+  const auto* array = &model.GetArray(array_name);
// An array with a constant buffer isn't a transient array.
if (!!array->buffer) {
return false;
@@ -1310,13 +1311,13 @@ bool IsAllocatableTransientArray(const Model& model, const string& array_name) {
}
string AvailableArrayName(const Model& model, const string& name) {
- if (!model.arrays.count(name) && !model.optional_arrays.count(name)) {
+ if (!model.HasArray(name) && !model.optional_arrays.count(name)) {
return name;
}
const int kNumSuffixesToTry = 1000;
for (int i = 0; i < kNumSuffixesToTry; i++) {
const string& name_with_suffix = toco::port::StringF("%s_%d", name, i);
- if (!model.arrays.count(name_with_suffix)) {
+ if (!model.HasArray(name_with_suffix)) {
return name_with_suffix;
}
}
@@ -1334,12 +1335,12 @@ string ShapeToString(const Shape& shape) {
}
void PrintArrayShape(Model* model, const string& name) {
- if (!model->arrays[name]->has_shape()) {
+ if (!model->GetArray(name).has_shape()) {
LOG(INFO) << name << " has no shape";
return;
}
LOG(INFO) << name
- << " has shape: " << ShapeToString(model->arrays[name]->shape());
+ << " has shape: " << ShapeToString(model->GetArray(name).shape());
}
bool IsArrayFullyConnectedWeights(const Model& model, const string& name) {
@@ -1673,7 +1674,7 @@ bool IsDiscardableArray(const Model& model, const string& array_name) {
}
void CheckFinalDataTypesSatisfied(const Model& model) {
- for (const auto& array_entry : model.arrays) {
+ for (const auto& array_entry : model.GetArrayMap()) {
const auto& array = *array_entry.second;
if (array.final_data_type != ArrayDataType::kNone) {
CHECK(array.final_data_type == array.data_type)