diff options
58 files changed, 288 insertions, 270 deletions
diff --git a/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc b/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc index d4da8f5dfe..5961d30bf5 100644 --- a/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc +++ b/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc @@ -148,7 +148,7 @@ std::size_t TransientArraySize(const Model& model, const string& array_name, if (!IsAllocatableTransientArray(model, array_name)) { return 0; } - const auto& array = model.arrays.at(array_name); + const auto& array = &model.GetArray(array_name); CHECK(array->has_shape()) << "Array '" << array_name << "' doesn't have a shape"; if (array->data_type == ArrayDataType::kNone) { @@ -185,7 +185,7 @@ void AllocateTransientArray(const Model& model, const string& array_name, } const std::size_t size = TransientArraySize(model, array_name, transient_data_alignment); - const auto& array = model.arrays.at(array_name); + const auto& array = &model.GetArray(array_name); CHECK(!array->alloc); allocator->Allocate(size, &array->GetOrCreateAlloc()); } @@ -197,7 +197,7 @@ void DeallocateTransientArray(const Model& model, const string& array_name, if (!IsAllocatableTransientArray(model, array_name)) { return; } - const auto& array = model.arrays.at(array_name); + const auto& array = &model.GetArray(array_name); CHECK(!!array->alloc); allocator->Deallocate(*array->alloc); } @@ -231,7 +231,7 @@ void AllocateTransientArrays(Model* model, // Construct a sorted map of array names, so that other layout engines can // match exactly. std::map<string, const Array*> ordered_arrays_map; - for (const auto& pair : model->arrays) { + for (const auto& pair : model->GetArrayMap()) { ordered_arrays_map[pair.first] = pair.second.get(); } diff --git a/tensorflow/contrib/lite/toco/dump_graphviz.cc b/tensorflow/contrib/lite/toco/dump_graphviz.cc index 39809216c7..c726eb6d86 100644 --- a/tensorflow/contrib/lite/toco/dump_graphviz.cc +++ b/tensorflow/contrib/lite/toco/dump_graphviz.cc @@ -278,8 +278,8 @@ std::vector<const Operator*> OperatorsToDump(const Model& model) { if (last_specified) { // Return only the part of the graph between graphviz_first_array // and graphviz_last_array. - CHECK(model.arrays.count(dump_options.graphviz_first_array)); - CHECK(model.arrays.count(dump_options.graphviz_last_array)); + CHECK(model.HasArray(dump_options.graphviz_first_array)); + CHECK(model.HasArray(dump_options.graphviz_last_array)); std::unordered_set<string> arrays_already_produced; std::vector<string> arrays_to_produce; arrays_to_produce.push_back(dump_options.graphviz_last_array); @@ -336,7 +336,7 @@ void DumpGraphviz(const Model& model, string* output_file_contents) { op_properties.color.TextColorString().c_str()); // Add nodes and edges for all inputs of the operator. for (const auto& input : op.inputs) { - if (model.arrays.count(input) == 0) { + if (!model.HasArray(input)) { // Arrays should _always_ exist. Except, perhaps, during development. continue; } @@ -352,7 +352,7 @@ void DumpGraphviz(const Model& model, string* output_file_contents) { } // Add nodes and edges for all outputs of the operator. for (const auto& output : op.outputs) { - if (model.arrays.count(output) == 0) { + if (!model.HasArray(output)) { // Arrays should _always_ exist. Except, perhaps, during development. continue; } diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc index 90fa442746..4fc01dbc20 100644 --- a/tensorflow/contrib/lite/toco/export_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc @@ -156,8 +156,8 @@ void ConvertFloatTensorConst(const Model& model, const string& name, const_op->set_name(name); (*const_op->mutable_attr())["dtype"].set_type(DT_FLOAT); auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor(); - CHECK(model.arrays.count(name)); - const auto& input_array = *model.arrays.at(name); + CHECK(model.HasArray(name)); + const auto& input_array = model.GetArray(name); const auto& input_shape = input_array.shape(); CHECK(input_array.buffer); CHECK(input_array.buffer->type == ArrayDataType::kFloat); @@ -177,8 +177,8 @@ void ConvertFloatTensorConst(const Model& model, const string& name, const_op->set_name(name); (*const_op->mutable_attr())["dtype"].set_type(DT_FLOAT); auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor(); - CHECK(model.arrays.count(name)); - const auto& input_array = *model.arrays.at(name); + CHECK(model.HasArray(name)); + const auto& input_array = model.GetArray(name); const auto& input_shape = input_array.shape(); CHECK(input_array.buffer); CHECK(input_array.buffer->type == ArrayDataType::kFloat); @@ -193,8 +193,8 @@ void ConvertIntTensorConst(const Model& model, const string& name, if (HasAlreadyExportedConst(name, *tensorflow_graph)) { return; } - CHECK(model.arrays.count(name)); - const auto& array = *model.arrays.at(name); + CHECK(model.HasArray(name)); + const auto& array = model.GetArray(name); auto* const_op = tensorflow_graph->add_node(); const_op->set_op("Const"); const_op->set_name(name); @@ -324,7 +324,7 @@ void ConvertConvOperator(const Model& model, const ConvOperator& src_op, biasadd_op->add_input(conv_output); biasadd_op->add_input(src_op.inputs[2]); (*biasadd_op->mutable_attr())["T"].set_type(DT_FLOAT); - CHECK(model.arrays.count(src_op.inputs[2])); + CHECK(model.HasArray(src_op.inputs[2])); const string& bias_array_name = WalkUpToConstantArray(model, src_op.inputs[2]); const auto& bias_array = model.GetArray(bias_array_name); @@ -361,7 +361,7 @@ void ConvertDepthwiseConvOperator(const Model& model, // We need to convert that to H x W x InputDepth x Multiplier. // That's only a matter of constructing a Dims object; the actual // array layout is the same. - CHECK(model.arrays.count(src_op.inputs[1])); + CHECK(model.HasArray(src_op.inputs[1])); const string& src_weights_name = WalkUpToConstantArray(model, src_op.inputs[1]); const auto& src_weights_array = model.GetArray(src_weights_name); @@ -404,7 +404,7 @@ void ConvertDepthwiseConvOperator(const Model& model, biasadd_op->add_input(conv_output); biasadd_op->add_input(src_op.inputs[2]); (*biasadd_op->mutable_attr())["T"].set_type(DT_FLOAT); - CHECK(model.arrays.count(src_op.inputs[2])); + CHECK(model.HasArray(src_op.inputs[2])); const string& bias_name = WalkUpToConstantArray(model, src_op.inputs[2]); const auto& bias_array = model.GetArray(bias_name); // TODO(b/62904716) Bias arrays should be 1-D, and used directly. @@ -469,10 +469,10 @@ void ConvertFullyConnectedOperator(const Model& model, (*matmul_op->mutable_attr())["T"].set_type(DT_FLOAT); (*matmul_op->mutable_attr())["transpose_a"].set_b(false); (*matmul_op->mutable_attr())["transpose_b"].set_b(false); - CHECK(model.arrays.count(src_op.inputs[1])); + CHECK(model.HasArray(src_op.inputs[1])); const string& fc_weights_name = WalkUpToConstantArray(model, src_op.inputs[1]); - const auto& fc_weights_array = *model.arrays.at(fc_weights_name); + const auto& fc_weights_array = model.GetArray(fc_weights_name); const auto& fc_weights_shape = fc_weights_array.shape(); CHECK_EQ(fc_weights_shape.dimensions_count(), 2); CreateMatrixShapeTensorConst(reshape_shape, fc_weights_shape.dims(1), -1, @@ -492,8 +492,8 @@ void ConvertFullyConnectedOperator(const Model& model, biasadd_op->add_input(matmul_output); biasadd_op->add_input(src_op.inputs[2]); (*biasadd_op->mutable_attr())["T"].set_type(DT_FLOAT); - CHECK(model.arrays.count(src_op.inputs[2])); - const auto& bias_array = *model.arrays.at(src_op.inputs[2]); + CHECK(model.HasArray(src_op.inputs[2])); + const auto& bias_array = model.GetArray(src_op.inputs[2]); // TODO(b/62904716) Bias arrays should be 1-D, and used directly. Shape bias_shape_1d = bias_array.shape(); UnextendShape(&bias_shape_1d, 1); @@ -625,7 +625,7 @@ void ConvertSoftmaxOperator(const Model& model, const SoftmaxOperator& src_op, *reshape_op->add_input() = softmax_size; (*reshape_op->mutable_attr())["T"].set_type(DT_FLOAT); - const auto& input_shape = model.arrays.at(src_op.inputs[0])->shape(); + const auto& input_shape = model.GetArray(src_op.inputs[0]).shape(); int32 flattened_size = 1; for (int i = 0; i < input_shape.dimensions_count() - 1; ++i) { flattened_size *= input_shape.dims(i); @@ -1013,8 +1013,8 @@ void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op, // Op names have been chosen to match the tf.slim LSTM naming // as closely as possible. const int axis = - model.arrays.at(src_op.inputs[LstmCellOperator::PREV_ACTIV_INPUT]) - ->shape() + model.GetArray(src_op.inputs[LstmCellOperator::PREV_ACTIV_INPUT]) + .shape() .dimensions_count() - 1; // Note that DATA_INPUT may have extra size 1 dimensions, but TF concat @@ -1033,9 +1033,9 @@ void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op, // Write weights const string weights_output = base + "weights"; - CHECK(model.arrays.count(src_op.inputs[LstmCellOperator::WEIGHTS_INPUT])); + CHECK(model.HasArray(src_op.inputs[LstmCellOperator::WEIGHTS_INPUT])); const auto& weights_array = - *model.arrays.at(src_op.inputs[LstmCellOperator::WEIGHTS_INPUT]); + model.GetArray(src_op.inputs[LstmCellOperator::WEIGHTS_INPUT]); // Convert 4D FullyConnected weights into 2D matrix const auto& weights_shape = weights_array.shape(); CHECK_EQ(weights_shape.dimensions_count(), 2); @@ -1059,9 +1059,9 @@ void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op, // Write biases const string biases_output = base + "biases"; - CHECK(model.arrays.count(src_op.inputs[LstmCellOperator::BIASES_INPUT])); + CHECK(model.HasArray(src_op.inputs[LstmCellOperator::BIASES_INPUT])); const auto& bias_array = - *model.arrays.at(src_op.inputs[LstmCellOperator::BIASES_INPUT]); + model.GetArray(src_op.inputs[LstmCellOperator::BIASES_INPUT]); // TODO(b/62904716) Bias arrays should be 1-D, and used directly. Shape bias_shape_1d = bias_array.shape(); UnextendShape(&bias_shape_1d, 1); @@ -1557,7 +1557,7 @@ void AddPlaceholderForRNNState(const Model& model, const string& name, int size, (*placeholder->mutable_attr())["dtype"].set_type(DT_FLOAT); auto* shape = (*placeholder->mutable_attr())["shape"].mutable_shape(); - const auto& state_array = *model.arrays.at(name); + const auto& state_array = model.GetArray(name); if (state_array.has_shape()) { const auto& state_shape = state_array.shape(); const int kDims = state_shape.dimensions_count(); @@ -1574,7 +1574,7 @@ void ExportTensorFlowGraphDefImplementation(const Model& model, GraphDef* tensorflow_graph) { for (const auto& input_array : model.flags.input_arrays()) { AddPlaceholder(input_array.name(), - model.arrays.at(input_array.name())->data_type, + model.GetArray(input_array.name()).data_type, tensorflow_graph); } for (const auto& rnn_state : model.flags.rnn_states()) { @@ -1588,7 +1588,7 @@ void ExportTensorFlowGraphDefImplementation(const Model& model, // by the above operators export. It's important that this comes // after, as some operators need to export arrays that they reference // in a specific way, rather than in the generic way done below. - for (const auto& array_pair : model.arrays) { + for (const auto& array_pair : model.GetArrayMap()) { const string& array_name = array_pair.first; const auto& array = *array_pair.second; if (array.buffer) { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc index 3bde9b0169..56f48d47de 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc @@ -35,7 +35,7 @@ bool ConvertExpandDimsToReshape::Run(Model* model, std::size_t op_index) { CHECK_EQ(expand_op->inputs.size(), 2); CHECK_EQ(expand_op->outputs.size(), 1); - const auto& input_array = *model->arrays[expand_op->inputs[0]]; + const auto& input_array = model->GetArray(expand_op->inputs[0]); if (!input_array.has_shape()) { // Yield until input dims have been resolved. return false; @@ -46,7 +46,7 @@ bool ConvertExpandDimsToReshape::Run(Model* model, std::size_t op_index) { return false; } - const auto& axis_array = *model->arrays[expand_op->inputs[1]]; + const auto& axis_array = model->GetArray(expand_op->inputs[1]); if (!axis_array.has_shape()) { // Yield until input axis array shape has been resolved. return false; @@ -86,7 +86,7 @@ bool ConvertExpandDimsToReshape::Run(Model* model, std::size_t op_index) { if (IsDiscardableArray(*model, axis_array_name) && CountOpsWithInput(*model, axis_array_name) == 1 && !GetOpWithOutput(*model, axis_array_name)) { - model->arrays.erase(axis_array_name); + model->EraseArray(axis_array_name); } // Replace the operator in the graph. diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc index bf454c40c7..d38db85280 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc @@ -58,7 +58,7 @@ bool ConvertPureConvToDepthwise::Run(Model* model, std::size_t op_index) { depthwiseconv_op->outputs = {conv_op->outputs[0]}; if (conv_op->outputs.size() > 1) { // delete the im2col array. - model->arrays.erase(conv_op->outputs[1]); + model->EraseArray(conv_op->outputs[1]); } depthwiseconv_op->fused_activation_function = conv_op->fused_activation_function; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc index a234c20924..c2b166033c 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc @@ -29,7 +29,7 @@ bool ConvertTrivialTransposeToReshape::Run(Model* model, std::size_t op_index) { TransposeOperator* transpose_op = static_cast<TransposeOperator*>(transpose_it->get()); - const auto& output_array = *model->arrays[transpose_op->outputs[0]]; + const auto& output_array = model->GetArray(transpose_op->outputs[0]); if (!output_array.has_shape()) { // Yield until PropagateFixedSizes has been run on this op. return false; @@ -70,7 +70,7 @@ bool ConvertTrivialTransposeToReshape::Run(Model* model, std::size_t op_index) { // Delete perm array if unused if (IsDiscardableArray(*model, perm_array_name) && CountOpsWithInput(*model, perm_array_name) == 1) { - model->arrays.erase(perm_array_name); + model->EraseArray(perm_array_name); } // Replace the operator in the graph. diff --git a/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc b/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc index 1735b51e5b..076415ece8 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc @@ -35,7 +35,7 @@ bool CreateIm2colArrays::Run(Model* model, std::size_t op_index) { // We already have an im2col array return false; } - const auto& weights_array = *model->arrays[conv_op->inputs[1]]; + const auto& weights_array = model->GetArray(conv_op->inputs[1]); if (!weights_array.has_shape()) { // We need to yield until weights dims have been resolved, because // from the weights dims we determine whether an im2col array is diff --git a/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc index 79854cba34..498c864bde 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc @@ -53,7 +53,7 @@ std::vector<std::unique_ptr<Operator>>::iterator FindFirstOpWithInput( } void ClearArrayQuantizationParams(const string& array_name, Model* model) { - auto* array = model->arrays.at(array_name).get(); + auto* array = &model->GetArray(array_name); CHECK(array->quantization_params); for (auto& input_array : *model->flags.mutable_input_arrays()) { if (input_array.name() == array_name) { @@ -77,7 +77,7 @@ void ClearArrayQuantizationParams(const string& array_name, Model* model) { bool DequantizeArray(const string& array_name, GraphTransformation* transformation, Model* model) { - auto* array = model->arrays.at(array_name).get(); + auto* array = &model->GetArray(array_name); if (!array->quantization_params) { return false; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/drop_fake_quant.cc b/tensorflow/contrib/lite/toco/graph_transformations/drop_fake_quant.cc index fea360740f..95558ef5ec 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/drop_fake_quant.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/drop_fake_quant.cc @@ -45,7 +45,7 @@ bool DropFakeQuant::Run(Model* model, std::size_t op_index) { // Drop min/max inputs for (int i = 1; i < fakequant_op->inputs.size(); i++) { if (CountOpsWithInput(*model, fakequant_op->inputs[i]) == 1) { - model->arrays.erase(fakequant_op->inputs[i]); + model->EraseArray(fakequant_op->inputs[i]); } } fakequant_op->inputs.resize(1); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/drop_im2col_arrays.cc b/tensorflow/contrib/lite/toco/graph_transformations/drop_im2col_arrays.cc index a3ed6663bc..f7fd878b7e 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/drop_im2col_arrays.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/drop_im2col_arrays.cc @@ -32,7 +32,7 @@ bool DropIm2colArrays::Run(Model* model, std::size_t op_index) { // Drop the im2col array. CHECK_EQ(conv_op->outputs.size(), 2); - model->arrays.erase(conv_op->outputs[1]); + model->EraseArray(conv_op->outputs[1]); conv_op->outputs.resize(1); AddMessageF("Dropped an im2col array for %s", LogName(*conv_op)); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc b/tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc index ad4a6f9b78..88e59664ec 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc @@ -91,7 +91,7 @@ bool FuseActivationFunctions::Run(Model* model, std::size_t op_index) { } else { LOG(FATAL) << "Unhandled activation function type"; } - model->arrays.erase(ac_op->inputs[0]); + model->EraseArray(ac_op->inputs[0]); op->outputs[0] = ac_op->outputs[0]; model->operators.erase(ac_it); return true; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc index 4619d8bbee..dcbbead517 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc @@ -285,13 +285,13 @@ bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) { AddMessageF("Fusing %s into the following %s", LogName(*binary_op), LogName(*following_op)); - model->arrays.erase(binary_op->outputs[0]); + model->EraseArray(binary_op->outputs[0]); following_op->inputs[0] = binary_op->inputs[index_of_variable_input]; const auto& old_constant_param_name = binary_op->inputs[index_of_constant_input]; CHECK(IsConstantParameterArray(*model, old_constant_param_name)); if (CountOpsWithInput(*model, old_constant_param_name) == 1) { - model->arrays.erase(old_constant_param_name); + model->EraseArray(old_constant_param_name); } model->operators.erase(binary_it); return true; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc index 8948653ec3..5b57178b18 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc @@ -309,7 +309,7 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) { LOG(FATAL) << "should not get here"; } - model->arrays.erase(preceding_op->outputs[0]); + model->EraseArray(preceding_op->outputs[0]); preceding_op->outputs[0] = binary_op->outputs[0]; preceding_op->fused_activation_function = binary_op->fused_activation_function; @@ -317,7 +317,7 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) { binary_op->inputs[index_of_constant_input]; CHECK(IsConstantParameterArray(*model, old_constant_param_name)); if (CountOpsWithInput(*model, old_constant_param_name) == 1) { - model->arrays.erase(old_constant_param_name); + model->EraseArray(old_constant_param_name); } model->operators.erase(binary_it); return true; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.cc b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.cc index f861c4147a..2340f0e850 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.cc @@ -31,13 +31,13 @@ namespace { void PrintModelStats(const string& label, const Model& model) { int quantized_arrays = 0; - for (const auto& array : model.arrays) { + for (const auto& array : model.GetArrayMap()) { if (array.second->quantization_params) { quantized_arrays++; } } LOG(INFO) << label << ": " << model.operators.size() << " operators, " - << model.arrays.size() << " arrays (" << quantized_arrays + << model.GetArrayMap().size() << " arrays (" << quantized_arrays << " quantized)"; } @@ -91,14 +91,9 @@ void DiscardUselessConnectedComponentsAndRNNBackEdges(Model* model) { } } while (found_new_useful_arrays); // Erase arrays that aren't useful, and that are discardable. - for (auto it = model->arrays.begin(); it != model->arrays.end();) { - if (useful_arrays.count(it->first) || - !IsDiscardableArray(*model, it->first)) { - ++it; - } else { - it = model->arrays.erase(it); - } - } + model->EraseArrays([&](const string& name) { + return (!useful_arrays.count(name) && IsDiscardableArray(*model, name)); + }); // Erase operators that do not produce a useful output array. for (auto it = model->operators.begin(); it != model->operators.end();) { // Only need to test the first output, as we simultaneously added all of @@ -118,8 +113,8 @@ void DiscardUselessConnectedComponentsAndRNNBackEdges(Model* model) { std::vector<RnnState> rnn_states_to_keep; for (const auto& rnn_state : model->flags.rnn_states()) { const bool dangling = - !model->arrays.count(rnn_state.back_edge_source_array()) || - !model->arrays.count(rnn_state.state_array()); + !model->HasArray(rnn_state.back_edge_source_array()) || + !model->HasArray(rnn_state.state_array()); if (dangling) { CHECK(rnn_state.discardable()); } else { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc index 01b75e37c6..419a0776a6 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc @@ -150,19 +150,19 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) { // Erase the subgraph that is now replaced by L2Normalization model->operators.erase(FindOperator(model, square_op)); - model->arrays.erase(sum_op->inputs[0]); + model->EraseArray(sum_op->inputs[0]); if (sum_op->inputs.size() > 1) { - model->arrays.erase(sum_op->inputs[1]); + model->EraseArray(sum_op->inputs[1]); } model->operators.erase(FindOperator(model, sum_op)); if (add_op) { - model->arrays.erase(add_op->inputs[0]); - model->arrays.erase(add_op->inputs[1]); + model->EraseArray(add_op->inputs[0]); + model->EraseArray(add_op->inputs[1]); model->operators.erase(FindOperator(model, add_op)); } - model->arrays.erase(sqrt_or_rsqrt_op->inputs[0]); + model->EraseArray(sqrt_or_rsqrt_op->inputs[0]); model->operators.erase(FindOperator(model, sqrt_or_rsqrt_op)); - model->arrays.erase(div_or_mul_op->inputs[1]); + model->EraseArray(div_or_mul_op->inputs[1]); model->operators.erase(FindOperator(model, div_or_mul_op)); return true; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc index 1865416fc2..e4d52476c6 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc @@ -92,8 +92,8 @@ bool IdentifyL2Pool::Run(Model* model, std::size_t op_index) { AddMessageF("Creating %s replacing equivalent subgraph", LogName(*l2pool_op)); // Erase intermediate arrays, keeping input to square op. - model->arrays.erase(avpool_op->inputs[0]); - model->arrays.erase(sqrt_op->inputs[0]); + model->EraseArray(avpool_op->inputs[0]); + model->EraseArray(sqrt_op->inputs[0]); // Erase three operators being replaced. model->operators.erase(FindOperator(model, square_op)); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc index cfc77024e7..d36e950609 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc @@ -89,12 +89,12 @@ bool IdentifyRelu1::Run(Model* model, std::size_t op_index) { AddMessageF("Creating %s replacing equivalent subgraph", LogName(*relu1_op)); // Erase Maximum scalar input & operator - model->arrays.erase(maximum_op->inputs[scalar_input_index]); + model->EraseArray(maximum_op->inputs[scalar_input_index]); model->operators.erase(FindOperator(model, maximum_op)); // Erase Minimum inputs & operator - model->arrays.erase(minimum_op->inputs[0]); - model->arrays.erase(minimum_op->inputs[1]); + model->EraseArray(minimum_op->inputs[0]); + model->EraseArray(minimum_op->inputs[1]); model->operators.erase(FindOperator(model, minimum_op)); return true; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc index 29b55d9bfc..f0d107232b 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc @@ -27,7 +27,7 @@ namespace { void SetDataTypeForAllOutputs(Model* model, Operator* op, ArrayDataType data_type) { for (const auto& output : op->outputs) { - model->arrays[output]->data_type = data_type; + model->GetArray(output).data_type = data_type; } } } // namespace @@ -39,7 +39,7 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { // If the data type of some input is unknown, we need to yield. for (const auto& input : op->inputs) { if (!model->IsOptionalArray(input) && - model->arrays[input]->data_type == ArrayDataType::kNone) { + model->GetArray(input).data_type == ArrayDataType::kNone) { return false; } } @@ -47,7 +47,7 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { // end if we changed anything, and return the correct boolean value. std::unordered_map<string, ArrayDataType> old_output_data_types; for (const auto& output : op->outputs) { - old_output_data_types[output] = model->arrays[output]->data_type; + old_output_data_types[output] = model->GetArray(output).data_type; } // Do the actual output data types propagation. if (op->type == OperatorType::kDequantize || @@ -69,18 +69,18 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { op->type == OperatorType::kFill) { // These operators produce an output with the same type as their 2nd input CHECK_GE(op->inputs.size(), 2); - const ArrayDataType data_type = model->arrays[op->inputs[1]]->data_type; + const ArrayDataType data_type = model->GetArray(op->inputs[1]).data_type; SetDataTypeForAllOutputs(model, op, data_type); } else if (op->type == OperatorType::kCast) { // Data type of the Cast op is specified. CHECK_EQ(op->outputs.size(), 1); auto* cast_op = static_cast<CastOperator*>(op); - model->arrays[op->outputs[0]]->data_type = cast_op->dst_data_type; + model->GetArray(op->outputs[0]).data_type = cast_op->dst_data_type; } else if (op->type == OperatorType::kArgMax) { // Data type of the ArgMax op is specified. CHECK_EQ(op->outputs.size(), 1); auto* argmax_op = static_cast<ArgMaxOperator*>(op); - model->arrays[op->outputs[0]]->data_type = argmax_op->output_data_type; + model->GetArray(op->outputs[0]).data_type = argmax_op->output_data_type; } else if (op->type == OperatorType::kRange) { auto* range_op = static_cast<RangeOperator*>(op); // Output type of the Range op can be set via an attribute @@ -91,7 +91,7 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { } else { // Otherwise use the first input CHECK_GE(op->inputs.size(), 1); - data_type = model->arrays[op->inputs[0]]->data_type; + data_type = model->GetArray(op->inputs[0]).data_type; } CHECK_EQ(op->outputs.size(), 1); SetDataTypeForAllOutputs(model, op, data_type); @@ -103,7 +103,7 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { for (int i = 0; i < unsupported_op->output_data_types.size(); ++i) { auto output = op->outputs[i]; auto data_type = unsupported_op->output_data_types[i]; - model->arrays[output]->data_type = data_type; + model->GetArray(output).data_type = data_type; } } else if (op->type == OperatorType::kExpandDims) { // Yield on ExpandDim until it is converted to Reshape @@ -111,12 +111,12 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { } else { // These operators produce outputs with the same type as their 1st input CHECK_GT(op->inputs.size(), 0); - const ArrayDataType data_type = model->arrays[op->inputs[0]]->data_type; + const ArrayDataType data_type = model->GetArray(op->inputs[0]).data_type; SetDataTypeForAllOutputs(model, op, data_type); } // Return true if any output data type changed, false if none changed. for (const auto& output : op->outputs) { - if (old_output_data_types[output] != model->arrays[output]->data_type) { + if (old_output_data_types[output] != model->GetArray(output).data_type) { return true; } } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc index a939efb4db..ff0a3bd881 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -85,7 +85,7 @@ void ComputeBinaryOperatorOutputSize(const Shape& input_shape1, int GetOutputDepthFromWeights(const Model& model, const Operator& op) { const string& weights_name = op.inputs[1]; - const auto& weights_shape = model.arrays.at(weights_name)->shape(); + const auto& weights_shape = model.GetArray(weights_name).shape(); if (op.type == OperatorType::kConv || op.type == OperatorType::kFullyConnected) { return weights_shape.dims(0); @@ -98,7 +98,7 @@ int GetOutputDepthFromWeights(const Model& model, const Operator& op) { bool EnsureBiasVectorShape(Model* model, Operator* op) { const string& weights_name = op->inputs[1]; - const auto& weights_array = *model->arrays[weights_name]; + const auto& weights_array = model->GetArray(weights_name); // Yield until weights shape has been resolved. if (!weights_array.has_shape()) { return false; @@ -107,7 +107,7 @@ bool EnsureBiasVectorShape(Model* model, Operator* op) { if (op->inputs.size() < 3) { return false; } - auto& bias_array = *model->arrays[op->inputs[2]]; + auto& bias_array = model->GetArray(op->inputs[2]); if (bias_array.has_shape()) { return true; } @@ -126,7 +126,7 @@ void ProcessConvOperator(Model* model, ConvOperator* op) { return; } - const auto& input_array = *model->arrays[op->inputs[0]]; + const auto& input_array = model->GetArray(op->inputs[0]); // Yield until input dims have been resolved. if (!input_array.has_shape()) { return; @@ -134,7 +134,7 @@ void ProcessConvOperator(Model* model, ConvOperator* op) { const auto& input_shape = input_array.shape(); CHECK_EQ(input_shape.dimensions_count(), 4); - const auto& weights_array = *model->arrays[op->inputs[1]]; + const auto& weights_array = model->GetArray(op->inputs[1]); // Yield until weights dims have been resolved. if (!weights_array.has_shape()) { return; @@ -156,7 +156,7 @@ void ProcessConvOperator(Model* model, ConvOperator* op) { if (op->outputs.size() == 2) { const auto& output_shape = output_array.shape(); const int input_depth = weights_shape.dims(3); - auto& im2col_array = *model->arrays[op->outputs[1]]; + auto& im2col_array = model->GetArray(op->outputs[1]); im2col_array.copy_shape(Shape{output_shape.dims(0), output_shape.dims(1), output_shape.dims(2), input_depth * kheight * kwidth}); @@ -168,7 +168,7 @@ void ProcessDepthwiseConvOperator(Model* model, DepthwiseConvOperator* op) { return; } - const auto& input_array = *model->arrays[op->inputs[0]]; + const auto& input_array = model->GetArray(op->inputs[0]); // Yield until input dims have been resolved. if (!input_array.has_shape()) { return; @@ -176,7 +176,7 @@ void ProcessDepthwiseConvOperator(Model* model, DepthwiseConvOperator* op) { const auto& input_shape = input_array.shape(); CHECK_EQ(input_shape.dimensions_count(), 4); - const auto& weights_array = *model->arrays[op->inputs[1]]; + const auto& weights_array = model->GetArray(op->inputs[1]); // Yield until weights dims have been resolved. if (!weights_array.has_shape()) { return; @@ -209,7 +209,7 @@ void ProcessDepthwiseConvOperator(Model* model, DepthwiseConvOperator* op) { } void ProcessDepthToSpaceOperator(Model* model, DepthToSpaceOperator* op) { - const auto& input_array = *model->arrays[op->inputs[0]]; + const auto& input_array = model->GetArray(op->inputs[0]); // Yield until input dims have been resolved. if (!input_array.has_shape()) { return; @@ -232,7 +232,7 @@ void ProcessDepthToSpaceOperator(Model* model, DepthToSpaceOperator* op) { } void ProcessSpaceToDepthOperator(Model* model, SpaceToDepthOperator* op) { - const auto& input_array = *model->arrays[op->inputs[0]]; + const auto& input_array = model->GetArray(op->inputs[0]); // Yield until input dims have been resolved. if (!input_array.has_shape()) { return; @@ -258,7 +258,7 @@ void ProcessSpaceToDepthOperator(Model* model, SpaceToDepthOperator* op) { void ProcessFillOperator(Model* model, FillOperator* op) { CHECK_EQ(op->inputs.size(), 2); CHECK_EQ(op->outputs.size(), 1); - auto& output_array = *model->arrays[op->outputs[0]]; + auto& output_array = model->GetArray(op->outputs[0]); if (output_array.has_shape()) { // We have already run return; @@ -287,7 +287,7 @@ void ProcessFullyConnectedOperator(Model* model, FullyConnectedOperator* op) { return; } - const auto& input_array = *model->arrays[op->inputs[0]]; + const auto& input_array = model->GetArray(op->inputs[0]); // Yield until input dims have been resolved. if (!input_array.has_shape()) { return; @@ -295,7 +295,7 @@ void ProcessFullyConnectedOperator(Model* model, FullyConnectedOperator* op) { const auto& input_shape = input_array.shape(); CHECK_GE(input_shape.dimensions_count(), 1); - const auto& weights_array = *model->arrays[op->inputs[1]]; + const auto& weights_array = model->GetArray(op->inputs[1]); // Yield until weights dims have been resolved. if (!weights_array.has_shape()) { return; @@ -315,13 +315,13 @@ void ProcessFullyConnectedOperator(Model* model, FullyConnectedOperator* op) { void ProcessTensorFlowReshapeOperator(Model* model, TensorFlowReshapeOperator* op) { - auto& output_array = *model->arrays[op->outputs[0]]; + auto& output_array = model->GetArray(op->outputs[0]); if (output_array.has_shape()) { // We have already run return; } - const auto& input_array = *model->arrays[op->inputs[0]]; + const auto& input_array = model->GetArray(op->inputs[0]); if (!input_array.has_shape()) { // Yield until input dims have been resolved. return; @@ -377,14 +377,14 @@ void ProcessTensorFlowReshapeOperator(Model* model, } void ProcessSimpleOperator(Model* model, Operator* op) { - const auto& input_array = *model->arrays[op->inputs[0]]; + const auto& input_array = model->GetArray(op->inputs[0]); // Yield until input dims have been resolved. if (!input_array.has_shape()) { return; } const string& output_name = op->outputs[0]; - auto& output_array = *model->arrays[output_name]; + auto& output_array = model->GetArray(output_name); if (output_array.has_shape()) { return; } @@ -394,14 +394,14 @@ void ProcessSimpleOperator(Model* model, Operator* op) { void ProcessSimpleBinaryOperator(Model* model, Operator* op) { CHECK_EQ(op->inputs.size(), 2); - const auto& input0_array = *model->arrays[op->inputs[0]]; - const auto& input1_array = *model->arrays[op->inputs[1]]; + const auto& input0_array = model->GetArray(op->inputs[0]); + const auto& input1_array = model->GetArray(op->inputs[1]); // Yield until input dims have been resolved. if (!input0_array.has_shape() || !input1_array.has_shape()) { return; } const string& output_name = op->outputs[0]; - auto& output_array = *model->arrays[output_name]; + auto& output_array = model->GetArray(output_name); ComputeBinaryOperatorOutputSize(input0_array.shape(), input1_array.shape(), &output_array); } @@ -424,11 +424,11 @@ bool KeepDims(const Operator& op) { void ProcessTensorFlowReductionOperator(Model* model, Operator* op) { CHECK_LE(op->inputs.size(), 2); - auto& output_array = *model->arrays[op->outputs[0]]; + auto& output_array = model->GetArray(op->outputs[0]); if (output_array.has_shape()) { return; } - const auto& input_array = *model->arrays[op->inputs[0]]; + const auto& input_array = model->GetArray(op->inputs[0]); if (!input_array.has_shape()) { return; } @@ -436,7 +436,7 @@ void ProcessTensorFlowReductionOperator(Model* model, Operator* op) { const bool keep_dims = KeepDims(*op); if (op->inputs.size() == 2) { // There is a reduction_indices input. - const auto& reduction_array = *model->arrays[op->inputs[1]]; + const auto& reduction_array = model->GetArray(op->inputs[1]); if (!reduction_array.buffer) { return; } @@ -476,11 +476,11 @@ void ProcessSliceOperator(Model* model, SliceOperator* op) { if (op->begin.empty()) return; // Yield until input dims have been resolved. - const auto& input_array = *model->arrays[op->inputs[0]]; + const auto& input_array = model->GetArray(op->inputs[0]); if (!input_array.has_shape()) return; const Shape& input_shape = input_array.shape(); - auto& output_array = *model->arrays[op->outputs[0]]; + auto& output_array = model->GetArray(op->outputs[0]); if (output_array.has_shape()) return; CHECK_EQ(input_shape.dims().size(), op->size.size()); @@ -500,7 +500,7 @@ void ProcessSliceOperator(Model* model, SliceOperator* op) { void ProcessReorderAxesOperator(Model* model, ReorderAxesOperator* op) { const string& input_name = op->inputs[0]; - const auto& input_array = *model->arrays[input_name]; + const auto& input_array = model->GetArray(input_name); // Yield until input dims have been resolved. if (!input_array.has_shape()) { return; @@ -515,20 +515,20 @@ void ProcessReorderAxesOperator(Model* model, ReorderAxesOperator* op) { void ProcessConcatenationOperator(Model* model, ConcatenationOperator* op) { // Yield until input dims have been resolved. for (const auto& input_name : op->inputs) { - auto& input_array = *model->arrays[input_name]; + auto& input_array = model->GetArray(input_name); if (!input_array.has_shape()) { return; } } auto& output_array = model->GetArray(op->outputs[0]); // Use 0 input as basis for output dimensions. - const auto& first_input_array = *model->arrays[op->inputs[0]]; + const auto& first_input_array = model->GetArray(op->inputs[0]); output_array.copy_shape(first_input_array.shape()); // Determine the concat size, and enfore that all inputs have // the same dimensions count. int concat_size = 0; for (const auto& input_name : op->inputs) { - auto& input_array = *model->arrays[input_name]; + auto& input_array = model->GetArray(input_name); CHECK(input_array.has_shape()); if (input_array.shape().dimensions_count() == 0) { continue; @@ -548,16 +548,16 @@ void ProcessConcatenationOperator(Model* model, ConcatenationOperator* op) { void ProcessRangeOperator(Model* model, RangeOperator* op) { CHECK_EQ(op->inputs.size(), 3); - const auto& start_array = *model->arrays[op->inputs[0]]; + const auto& start_array = model->GetArray(op->inputs[0]); if (!start_array.has_shape()) { // Yield until input dims have been resolved. return; } - const auto& limit_array = *model->arrays[op->inputs[1]]; + const auto& limit_array = model->GetArray(op->inputs[1]); if (!limit_array.has_shape()) { return; } - const auto& delta_array = *model->arrays[op->inputs[2]]; + const auto& delta_array = model->GetArray(op->inputs[2]); if (!delta_array.has_shape()) { return; } @@ -599,7 +599,7 @@ void ProcessRangeOperator(Model* model, RangeOperator* op) { void ProcessTensorFlowSplitOperator(Model* model, TensorFlowSplitOperator* op) { CHECK_EQ(op->inputs.size(), 2); const string& input_name = op->inputs[1]; - const auto& input_array = *model->arrays[input_name]; + const auto& input_array = model->GetArray(input_name); // Yield until input dims have been resolved. if (!input_array.has_shape()) { return; @@ -618,13 +618,13 @@ void ProcessTensorFlowSplitOperator(Model* model, TensorFlowSplitOperator* op) { CHECK_EQ(op->outputs.size(), op->num_split); for (const auto& output : op->outputs) { - model->arrays[output]->copy_shape(output_shape); + model->GetArray(output).copy_shape(output_shape); } } void ProcessAveragePoolOperator(Model* model, AveragePoolOperator* op) { const string& input_name = op->inputs[0]; - const auto& input_array = *model->arrays[input_name]; + const auto& input_array = model->GetArray(input_name); // Yield until input dims have been resolved. if (!input_array.has_shape()) { return; @@ -641,7 +641,7 @@ void ProcessAveragePoolOperator(Model* model, AveragePoolOperator* op) { void ProcessMaxPoolOperator(Model* model, MaxPoolOperator* op) { const string& input_name = op->inputs[0]; - const auto& input_array = *model->arrays[input_name]; + const auto& input_array = model->GetArray(input_name); // Yield until input dims have been resolved. if (!input_array.has_shape()) { return; @@ -658,7 +658,7 @@ void ProcessMaxPoolOperator(Model* model, MaxPoolOperator* op) { void ProcessL2PoolOperator(Model* model, L2PoolOperator* op) { const string& input_name = op->inputs[0]; - const auto& input_array = *model->arrays[input_name]; + const auto& input_array = model->GetArray(input_name); // Yield until input dims have been resolved. if (!input_array.has_shape()) { return; @@ -679,14 +679,14 @@ void ProcessResizeBilinearOperator(Model* model, ResizeBilinearOperator* op) { CHECK_EQ(op->inputs.size(), 2); CHECK_EQ(op->outputs.size(), 1); - if (!model->arrays[op->inputs[0]]->has_shape() || - !model->arrays[op->inputs[1]]->has_shape()) { + if (!model->GetArray(op->inputs[0]).has_shape() || + !model->GetArray(op->inputs[1]).has_shape()) { return; } - const auto& input_data_shape = model->arrays[op->inputs[0]]->shape(); + const auto& input_data_shape = model->GetArray(op->inputs[0]).shape(); const string& output_size_name = op->inputs[1]; - const auto& output_size_array = *model->arrays[output_size_name]; + const auto& output_size_array = model->GetArray(output_size_name); CHECK(output_size_array.data_type == ArrayDataType::kInt32); CHECK(output_size_array.has_shape()); const auto& output_size_shape = output_size_array.shape(); @@ -697,9 +697,9 @@ void ProcessResizeBilinearOperator(Model* model, ResizeBilinearOperator* op) { } std::vector<int32> output_shape = output_size_array.GetBuffer<ArrayDataType::kInt32>().data; - model->arrays[op->outputs[0]]->copy_shape( - Shape({input_data_shape.dims(0), output_shape[0], output_shape[1], - input_data_shape.dims(3)})); + model->GetArray(op->outputs[0]) + .copy_shape(Shape({input_data_shape.dims(0), output_shape[0], + output_shape[1], input_data_shape.dims(3)})); } void ProcessLstmCellOperator(Model* model, LstmCellOperator* op) { @@ -708,7 +708,7 @@ void ProcessLstmCellOperator(Model* model, LstmCellOperator* op) { QCHECK_EQ(op->outputs.size(), LstmCellOperator::NUM_OUTPUTS); const auto& input_array = - *model->arrays[op->inputs[LstmCellOperator::DATA_INPUT]]; + model->GetArray(op->inputs[LstmCellOperator::DATA_INPUT]); // Yield until all input dims have been resolved. if (!input_array.has_shape()) { return; @@ -717,7 +717,7 @@ void ProcessLstmCellOperator(Model* model, LstmCellOperator* op) { CHECK_GE(input_shape.dimensions_count(), 2); const auto& prev_activ_array = - *model->arrays[op->inputs[LstmCellOperator::PREV_ACTIV_INPUT]]; + model->GetArray(op->inputs[LstmCellOperator::PREV_ACTIV_INPUT]); // Yield until all input dims have been resolved. if (!prev_activ_array.has_shape()) { return; @@ -726,7 +726,7 @@ void ProcessLstmCellOperator(Model* model, LstmCellOperator* op) { CHECK_GE(prev_activ_shape.dimensions_count(), 2); const auto& weights_array = - *model->arrays[op->inputs[LstmCellOperator::WEIGHTS_INPUT]]; + model->GetArray(op->inputs[LstmCellOperator::WEIGHTS_INPUT]); // Yield until weights dims have been resolved. if (!weights_array.has_shape()) { return; @@ -735,7 +735,7 @@ void ProcessLstmCellOperator(Model* model, LstmCellOperator* op) { CHECK_EQ(weights_shape.dimensions_count(), 2); const auto& bias_array = - *model->arrays[op->inputs[LstmCellOperator::BIASES_INPUT]]; + model->GetArray(op->inputs[LstmCellOperator::BIASES_INPUT]); // Yield until bias dims have been resolved. if (!bias_array.has_shape()) { return; @@ -744,7 +744,7 @@ void ProcessLstmCellOperator(Model* model, LstmCellOperator* op) { CHECK_GE(bias_shape.dimensions_count(), 1); const auto& prev_state_array = - *model->arrays[op->inputs[LstmCellOperator::PREV_STATE_INPUT]]; + model->GetArray(op->inputs[LstmCellOperator::PREV_STATE_INPUT]); // Yield until all input dims have been resolved. if (!prev_state_array.has_shape()) { return; @@ -784,7 +784,7 @@ void ProcessLstmCellOperator(Model* model, LstmCellOperator* op) { } void ProcessSpaceToBatchNDOperator(Model* model, SpaceToBatchNDOperator* op) { - const auto& input_array = *model->arrays[op->inputs[0]]; + const auto& input_array = model->GetArray(op->inputs[0]); // Yield until input dims have been resolved. if (!input_array.has_shape()) { return; @@ -797,8 +797,8 @@ void ProcessSpaceToBatchNDOperator(Model* model, SpaceToBatchNDOperator* op) { const auto input_height = input_shape.dims(1); const auto input_width = input_shape.dims(2); - const auto& block_shape_array = *model->arrays[op->inputs[1]]; - const auto& paddings_array = *model->arrays[op->inputs[2]]; + const auto& block_shape_array = model->GetArray(op->inputs[1]); + const auto& paddings_array = model->GetArray(op->inputs[2]); const auto& block_shape_array_shape = block_shape_array.shape(); const auto& paddings_array_shape = paddings_array.shape(); QCHECK_EQ(block_shape_array_shape.dimensions_count(), 1); @@ -830,13 +830,13 @@ void ProcessSpaceToBatchNDOperator(Model* model, SpaceToBatchNDOperator* op) { int output_height = height_with_paddings / block_height; int output_width = width_with_paddings / block_width; - model->arrays[op->outputs[0]]->copy_shape( - Shape({input_shape.dims(0) * block_height * block_width, output_height, - output_width, input_shape.dims(3)})); + model->GetArray(op->outputs[0]) + .copy_shape(Shape({input_shape.dims(0) * block_height * block_width, + output_height, output_width, input_shape.dims(3)})); } void ProcessBatchToSpaceNDOperator(Model* model, BatchToSpaceNDOperator* op) { - const auto& input_array = *model->arrays[op->inputs[0]]; + const auto& input_array = model->GetArray(op->inputs[0]); // Yield until input dims have been resolved. if (!input_array.has_shape()) { return; @@ -846,8 +846,8 @@ void ProcessBatchToSpaceNDOperator(Model* model, BatchToSpaceNDOperator* op) { const auto input_height = input_shape.dims(1); const auto input_width = input_shape.dims(2); - const auto& block_shape_array = *model->arrays[op->inputs[1]]; - const auto& crops_array = *model->arrays[op->inputs[2]]; + const auto& block_shape_array = model->GetArray(op->inputs[1]); + const auto& crops_array = model->GetArray(op->inputs[2]); const auto& block_shape_array_shape = block_shape_array.shape(); const auto& crops_array_shape = crops_array.shape(); QCHECK_EQ(block_shape_array_shape.dimensions_count(), 1); @@ -882,15 +882,15 @@ void ProcessBatchToSpaceNDOperator(Model* model, BatchToSpaceNDOperator* op) { int output_height = input_height * block_height; int output_width = input_width * block_width; - model->arrays[op->outputs[0]]->copy_shape( - Shape({input_shape.dims(0) / (block_height * block_width), output_height, - output_width, input_shape.dims(3)})); + model->GetArray(op->outputs[0]) + .copy_shape(Shape({input_shape.dims(0) / (block_height * block_width), + output_height, output_width, input_shape.dims(3)})); } void ProcessGatherOperator(Model* model, GatherOperator* op) { - const auto& input_array = *model->arrays[op->inputs[0]]; - const auto& indices_array = *model->arrays[op->inputs[1]]; - auto& output_array = *model->arrays[op->outputs[0]]; + const auto& input_array = model->GetArray(op->inputs[0]); + const auto& indices_array = model->GetArray(op->inputs[1]); + auto& output_array = model->GetArray(op->outputs[0]); // Bail if we already know the output shape. if (output_array.has_shape()) { @@ -924,7 +924,7 @@ void ProcessPadOperator(Model* model, PadOperator* op) { CHECK_EQ(op->inputs.size(), 2); CHECK_EQ(op->outputs.size(), 1); - const auto& input_array = *model->arrays[op->inputs[0]]; + const auto& input_array = model->GetArray(op->inputs[0]); // Yield until input dims have been resolved. if (!input_array.has_shape()) return; @@ -932,7 +932,7 @@ void ProcessPadOperator(Model* model, PadOperator* op) { if (op->left_padding.empty()) return; CHECK_EQ(op->left_padding.size(), op->right_padding.size()); - auto& output_array = *model->arrays[op->outputs[0]]; + auto& output_array = model->GetArray(op->outputs[0]); if (output_array.has_shape()) return; Shape output_shape = input_array.shape(); @@ -949,13 +949,13 @@ void ProcessPadOperator(Model* model, PadOperator* op) { void ProcessRankOperator(Model* model, RankOperator* op) { CHECK_GE(op->inputs.size(), 1); CHECK_EQ(op->outputs.size(), 1); - auto& output_array = *model->arrays[op->outputs[0]]; + auto& output_array = model->GetArray(op->outputs[0]); if (output_array.has_shape()) { // Shape already propagated return; } - const auto& input_array = *model->arrays[op->inputs[0]]; + const auto& input_array = model->GetArray(op->inputs[0]); if (!input_array.has_shape()) { // Yield until input dims have been resolved. return; @@ -970,13 +970,13 @@ void ProcessRankOperator(Model* model, RankOperator* op) { void ProcessShapeOperator(Model* model, TensorFlowShapeOperator* op) { CHECK_GE(op->inputs.size(), 1); CHECK_EQ(op->outputs.size(), 1); - auto& output_array = *model->arrays[op->outputs[0]]; + auto& output_array = model->GetArray(op->outputs[0]); if (output_array.has_shape()) { // Shape already propagated return; } - const auto& input_array = *model->arrays[op->inputs[0]]; + const auto& input_array = model->GetArray(op->inputs[0]); if (!input_array.has_shape()) { // Yield until input dims have been resolved. return; @@ -991,7 +991,7 @@ void ProcessShapeOperator(Model* model, TensorFlowShapeOperator* op) { void ProcessStackOperator(Model* model, StackOperator* op) { CHECK_GE(op->inputs.size(), 1); CHECK_EQ(op->outputs.size(), 1); - auto& output_array = *model->arrays[op->outputs[0]]; + auto& output_array = model->GetArray(op->outputs[0]); if (output_array.has_shape()) { // Shape already propagated return; @@ -1032,7 +1032,7 @@ void ProcessStackOperator(Model* model, StackOperator* op) { void ProcessStridedSliceOperator(Model* model, StridedSliceOperator* op) { CHECK_GE(op->inputs.size(), 1); CHECK_EQ(op->outputs.size(), 1); - auto& output_array = *model->arrays[op->outputs[0]]; + auto& output_array = model->GetArray(op->outputs[0]); if (output_array.has_shape()) { // Shape already propagated return; @@ -1112,12 +1112,12 @@ void ProcessSqueezeOperator(Model* model, SqueezeOperator* op) { CHECK_EQ(op->inputs.size(), 1); CHECK_EQ(op->outputs.size(), 1); - const auto& input_array = *model->arrays[op->inputs[0]]; + const auto& input_array = model->GetArray(op->inputs[0]); // Yield until input dims have been resolved. if (!input_array.has_shape()) return; - auto& output_array = *model->arrays[op->outputs[0]]; + auto& output_array = model->GetArray(op->outputs[0]); if (output_array.has_shape()) return; const std::vector<int>& input_dims = input_array.shape().dims(); @@ -1136,18 +1136,18 @@ void ProcessSqueezeOperator(Model* model, SqueezeOperator* op) { void ProcessSvdfOperator(Model* model, SvdfOperator* op) { CHECK(op->inputs.size() == 3 || op->inputs.size() == 4); - const auto& input_array = *model->arrays[op->inputs[0]]; + const auto& input_array = model->GetArray(op->inputs[0]); if (!input_array.has_shape()) return; - auto& weights_feature_array = *model->arrays[op->inputs[1]]; + auto& weights_feature_array = model->GetArray(op->inputs[1]); if (!weights_feature_array.has_shape()) return; - const auto& weights_time_array = *model->arrays[op->inputs[2]]; + const auto& weights_time_array = model->GetArray(op->inputs[2]); if (!weights_time_array.has_shape()) return; const bool has_bias = (op->inputs.size() == 4); if (has_bias) { - const auto& bias_array = *model->arrays[op->inputs[3]]; + const auto& bias_array = model->GetArray(op->inputs[3]); if (!bias_array.has_shape()) return; } @@ -1164,13 +1164,13 @@ void ProcessSvdfOperator(Model* model, SvdfOperator* op) { } void ProcessTransposeOperator(Model* model, TransposeOperator* op) { - auto& output_array = *model->arrays[op->outputs[0]]; + auto& output_array = model->GetArray(op->outputs[0]); if (output_array.has_shape()) { // We have already run return; } - const auto& input_array = *model->arrays[op->inputs[0]]; + const auto& input_array = model->GetArray(op->inputs[0]); if (!input_array.has_shape()) { // Yield until input dims have been resolved. return; @@ -1204,7 +1204,7 @@ void ProcessTransposeOperator(Model* model, TransposeOperator* op) { void ProcessArgMaxOperator(Model* model, ArgMaxOperator* op) { CHECK_EQ(op->inputs.size(), 2); - const auto& input_array = *model->arrays[op->inputs[0]]; + const auto& input_array = model->GetArray(op->inputs[0]); // Yield until input dims have been resolved. if (!input_array.has_shape()) { return; @@ -1222,7 +1222,7 @@ void ProcessArgMaxOperator(Model* model, ArgMaxOperator* op) { } output_dims.push_back(1); const string& output_name = op->outputs[0]; - auto& output_array = *model->arrays[output_name]; + auto& output_array = model->GetArray(output_name); if (output_array.has_shape()) { return; } @@ -1236,8 +1236,8 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { auto* op = it->get(); std::unordered_map<string, std::vector<int>> old_output_dims; for (const auto& output : op->outputs) { - if (model->arrays[output]->has_shape()) { - old_output_dims[output] = model->arrays[output]->shape().dims(); + if (model->GetArray(output).has_shape()) { + old_output_dims[output] = model->GetArray(output).shape().dims(); } } @@ -1433,10 +1433,10 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { // Return true if any output dim changed, false if none changed. // Assumption: no transformation clears an output shape, they only add shapes. for (const auto& output : op->outputs) { - if (model->arrays[output]->has_shape() && - (old_output_dims[output] != model->arrays[output]->shape().dims())) { + if (model->GetArray(output).has_shape() && + (old_output_dims[output] != model->GetArray(output).shape().dims())) { AddMessageF("Set shape of %s to [%s]", output, - absl::StrJoin(model->arrays[output]->shape().dims(), ",")); + absl::StrJoin(model->GetArray(output).shape().dims(), ",")); return true; } } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc index 56082b965a..b973b2b813 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc @@ -412,7 +412,7 @@ bool Quantize::Run(Model* model, std::size_t op_index) { model->flags.set_output_arrays(i, dequantize_op->inputs[0]); } } - model->arrays.erase(dequantize_op->outputs[0]); + model->EraseArray(dequantize_op->outputs[0]); model->operators.erase(dequantize_it); } } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/read_fake_quant_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/read_fake_quant_min_max.cc index 371ced388a..11f8d4b6ee 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/read_fake_quant_min_max.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/read_fake_quant_min_max.cc @@ -80,7 +80,7 @@ bool ReadFakeQuantMinMax::Run(Model* model, std::size_t op_index) { // else. for (int i = 1; i <= 2; i++) { if (CountOpsWithInput(*model, fakequant_op->inputs[i]) == 1) { - model->arrays.erase(fakequant_op->inputs[i]); + model->EraseArray(fakequant_op->inputs[i]); } } fakequant_op->inputs.resize(1); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_final_dequantize_op.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_final_dequantize_op.cc index 3992e7d1ef..c3b2709a33 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_final_dequantize_op.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_final_dequantize_op.cc @@ -51,7 +51,7 @@ bool RemoveFinalDequantizeOp::Run(Model* model, std::size_t op_index) { // Remove the node and its output array. AddMessageF("Removed final %s", LogName(*dequantize_op)); - model->arrays.erase(output); + model->EraseArray(output); model->operators.erase(dequantize_it); return true; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc index 6add443f2d..8512e6bb5a 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc @@ -81,7 +81,7 @@ bool RemoveTrivialBinaryOperator::Run(Model* model, std::size_t op_index) { // Now check if the constant operand makes this binary // operator trivial. const auto& constant_input_array = - *model->arrays[binary_op->inputs[index_of_constant_input]]; + model->GetArray(binary_op->inputs[index_of_constant_input]); // For now, we only handle floats here. if (constant_input_array.data_type != ArrayDataType::kFloat) { return false; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation_input.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation_input.cc index 23a5c857e8..936854a04f 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation_input.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation_input.cc @@ -59,7 +59,7 @@ bool RemoveTrivialConcatenationInput::Run(Model* model, std::size_t op_index) { for (const string& input : trivial_inputs) { if (IsDiscardableArray(*model, input) && CountOpsWithInput(*model, input) == 1) { - model->arrays.erase(input); + model->EraseArray(input); } } concat_op->inputs = nontrivial_inputs; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc index 047389f69a..587f171bbf 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc @@ -124,7 +124,7 @@ bool RemoveTrivialPassthroughOp(GraphTransformation* transformation, } } if (!is_referenced) { - model->arrays.erase(removal_candidate); + model->EraseArray(removal_candidate); } } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc index e6cca8acf3..aa2c293382 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc @@ -33,7 +33,7 @@ bool RemoveUnusedOp::Run(Model* model, std::size_t op_index) { // the model. We allow specifying an arbitrary input_array, // treating the part of the graph leading up to it as unused. for (const auto& output : op->outputs) { - CHECK(model->arrays.count(output)); + CHECK(model->HasArray(output)); // If this output is provided as the model's input array, // then we don't need this operator to produce its contents. if (IsInputArray(*model, output)) { @@ -93,7 +93,7 @@ bool RemoveUnusedOp::Run(Model* model, std::size_t op_index) { if (IsDiscardableArray(*model, input) && CountOpsWithInput(*model, input) == 1 && !GetOpWithOutput(*model, input)) { - model->arrays.erase(input); + model->EraseArray(input); } } @@ -116,7 +116,7 @@ bool RemoveUnusedOp::Run(Model* model, std::size_t op_index) { continue; } // Generic case: do delete this output array. - model->arrays.erase(output); + model->EraseArray(output); } model->operators.erase(it); return true; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc index 3eb7fa3896..fb109eb91b 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc @@ -121,9 +121,9 @@ bool ResolveBatchNormalization::Run(Model* model, std::size_t op_index) { } // Remove the old param arrays - model->arrays.erase(bn_op->inputs[1]); - model->arrays.erase(bn_op->inputs[2]); - model->arrays.erase(bn_op->inputs[3]); + model->EraseArray(bn_op->inputs[1]); + model->EraseArray(bn_op->inputs[2]); + model->EraseArray(bn_op->inputs[3]); // Remove the old operator DCHECK_EQ(bn_it->get(), bn_op); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc index 7777d4f543..a06919e228 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc @@ -42,7 +42,7 @@ bool ResolveBatchToSpaceNDAttributes::Run(Model* model, std::size_t op_index) { return false; // Handle crops - const auto& crops_array = *model->arrays[op->inputs[2]]; + const auto& crops_array = model->GetArray(op->inputs[2]); if (!crops_array.has_shape()) return false; const std::vector<int>& crops_dims = crops_array.shape().dims(); if (crops_dims.size() != 2) { @@ -58,7 +58,7 @@ bool ResolveBatchToSpaceNDAttributes::Run(Model* model, std::size_t op_index) { } // Handle block_shape - const auto& block_shape_array = *model->arrays[op->inputs[1]]; + const auto& block_shape_array = model->GetArray(op->inputs[1]); if (!block_shape_array.has_shape()) return false; const std::vector<int>& block_shape_dims = block_shape_array.shape().dims(); CHECK_EQ(block_shape_dims.size(), 1); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc index fd51df4058..5e779f6765 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc @@ -166,8 +166,9 @@ void EvaluateBinaryOperatorOnConstantInputs(Model* model, void EvaluateBinaryOperatorOnConstantInputs(Model* model, const Operator* binary_op) { - const auto inputs_data_type = model->arrays[binary_op->inputs[0]]->data_type; - const auto output_data_type = model->arrays[binary_op->outputs[0]]->data_type; + const auto inputs_data_type = model->GetArray(binary_op->inputs[0]).data_type; + const auto output_data_type = + model->GetArray(binary_op->outputs[0]).data_type; #define TOCO_HANDLE_CASE(InputsDataType, OutputDataType) \ if (inputs_data_type == InputsDataType && \ output_data_type == OutputDataType) { \ @@ -214,7 +215,7 @@ bool ResolveConstantBinaryOperator::Run(Model* model, std::size_t op_index) { return false; } - auto& output_array = *model->arrays[binary_op->outputs[0]]; + auto& output_array = model->GetArray(binary_op->outputs[0]); // Yield until the output array dims have been resolved. if (!output_array.has_shape()) { return false; @@ -239,10 +240,10 @@ bool ResolveConstantBinaryOperator::Run(Model* model, std::size_t op_index) { // Remove the binary operator and its inputs if (CountOpsWithInput(*model, binary_op->inputs[0]) == 1) { - model->arrays.erase(binary_op->inputs[0]); + model->EraseArray(binary_op->inputs[0]); } if (CountOpsWithInput(*model, binary_op->inputs[1]) == 1) { - model->arrays.erase(binary_op->inputs[1]); + model->EraseArray(binary_op->inputs[1]); } AddMessageF("Resolved constant %s to the equivalent constant array", LogName(*binary_op)); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc index 9835f86398..833c97c758 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc @@ -189,7 +189,7 @@ bool ResolveConstantConcatenation::Run(Model* model, std::size_t op_index) { // Remove all the resolved arrays. for (const string& input_name : concat_op->inputs) { - model->arrays.erase(input_name); + model->EraseArray(input_name); } // Remove concatenate operator diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc index 244adcc4c4..81fe37d7e0 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc @@ -66,7 +66,7 @@ bool ResolveConstantFakeQuant::Run(Model* model, std::size_t op_index) { output_buffer.data[i] = dst_val; } if (CountOpsWithInput(*model, fakequant_op->inputs[0]) == 1) { - model->arrays.erase(fakequant_op->inputs[0]); + model->EraseArray(fakequant_op->inputs[0]); } model->operators.erase(fakequant_it); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fill.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fill.cc index 9da51d9147..f6f95481b5 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fill.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fill.cc @@ -104,11 +104,11 @@ bool ResolveConstantFill::Run(Model* model, std::size_t op_index) { // Erase input arrays if no longer used if (IsDiscardableArray(*model, op->inputs[0]) && CountOpsWithInput(*model, op->inputs[0]) == 1) { - model->arrays.erase(op->inputs[0]); + model->EraseArray(op->inputs[0]); } if (IsDiscardableArray(*model, op->inputs[1]) && CountOpsWithInput(*model, op->inputs[1]) == 1) { - model->arrays.erase(op->inputs[1]); + model->EraseArray(op->inputs[1]); } // Erase the operator diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_range.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_range.cc index 383d54aa5a..1a0ba9e2bc 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_range.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_range.cc @@ -28,17 +28,17 @@ bool ResolveConstantRange::Run(Model* model, std::size_t op_index) { auto* op = static_cast<RangeOperator*>(base_op); CHECK_EQ(op->inputs.size(), 3); - const auto& start_array = *model->arrays[op->inputs[0]]; + const auto& start_array = model->GetArray(op->inputs[0]); if (!start_array.has_shape()) { // Yield until all input dims have been resolved. return false; } - const auto& limit_array = *model->arrays[op->inputs[1]]; + const auto& limit_array = model->GetArray(op->inputs[1]); if (!limit_array.has_shape()) { // Yield until all input dims have been resolved. return false; } - const auto& delta_array = *model->arrays[op->inputs[2]]; + const auto& delta_array = model->GetArray(op->inputs[2]); if (!delta_array.has_shape()) { // Yield until all input dims have been resolved. return false; @@ -52,7 +52,7 @@ bool ResolveConstantRange::Run(Model* model, std::size_t op_index) { } CHECK_EQ(op->outputs.size(), 1); - auto& output_array = *model->arrays[op->outputs[0]]; + auto& output_array = model->GetArray(op->outputs[0]); if (output_array.data_type == ArrayDataType::kNone) { // Yield until the output type has been set by PropagateArrayDataTypes return false; @@ -87,15 +87,15 @@ bool ResolveConstantRange::Run(Model* model, std::size_t op_index) { // Delete the input array if no longer used if (IsDiscardableArray(*model, op->inputs[0]) && CountOpsWithInput(*model, op->inputs[0]) == 1) { - model->arrays.erase(op->inputs[0]); + model->EraseArray(op->inputs[0]); } if (IsDiscardableArray(*model, op->inputs[1]) && CountOpsWithInput(*model, op->inputs[1]) == 1) { - model->arrays.erase(op->inputs[1]); + model->EraseArray(op->inputs[1]); } if (IsDiscardableArray(*model, op->inputs[2]) && CountOpsWithInput(*model, op->inputs[2]) == 1) { - model->arrays.erase(op->inputs[2]); + model->EraseArray(op->inputs[2]); } // Delete the operator diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc index 35b81dd550..9ea01acd05 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc @@ -62,7 +62,7 @@ bool ResolveConstantShapeOrRank::Run(Model* model, std::size_t op_index) { // Delete the input array if no longer used if (IsDiscardableArray(*model, op->inputs[0]) && CountOpsWithInput(*model, op->inputs[0]) == 1) { - model->arrays.erase(op->inputs[0]); + model->EraseArray(op->inputs[0]); } model->operators.erase(it); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_stack.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_stack.cc index 86c76141a4..ea0d6dc820 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_stack.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_stack.cc @@ -101,7 +101,7 @@ bool ResolveConstantStack::Run(Model* model, std::size_t op_index) { for (const auto& input : op->inputs) { if (IsDiscardableArray(*model, input) && CountOpsWithInput(*model, input) == 1) { - model->arrays.erase(input); + model->EraseArray(input); } } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc index 3976d9cbb4..a0cfc3d597 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc @@ -186,7 +186,7 @@ bool ResolveConstantStridedSlice::Run(Model* model, std::size_t op_index) { // Erase input array if no longer used if (IsDiscardableArray(*model, op->inputs[0]) && CountOpsWithInput(*model, op->inputs[0]) == 1) { - model->arrays.erase(op->inputs[0]); + model->EraseArray(op->inputs[0]); } // Erase the operator diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc index 26ff9d887b..1cd2aff28c 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc @@ -199,7 +199,7 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) { } for (const auto& input : unary_op->inputs) { if (CountOpsWithInput(*model, input) == 1) { - model->arrays.erase(input); + model->EraseArray(input); } } AddMessageF("Resolved constant %s to the equivalent constant array", diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_mean_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_mean_attributes.cc index b77be3f5c0..013b50ac9b 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_mean_attributes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_mean_attributes.cc @@ -36,7 +36,7 @@ bool ResolveMeanAttributes::Run(Model* model, std::size_t op_index) { if (op->inputs.size() != 2) return false; if (!IsConstantParameterArray(*model, op->inputs[1])) return false; - const auto& indices_array = *model->arrays[op->inputs[1]]; + const auto& indices_array = model->GetArray(op->inputs[1]); if (!indices_array.has_shape()) return false; op->axis = indices_array.GetBuffer<ArrayDataType::kInt32>().data; return true; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_pad_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_pad_attributes.cc index d5f5869c62..8a8e723cf7 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_pad_attributes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_pad_attributes.cc @@ -35,7 +35,7 @@ bool ResolvePadAttributes::Run(Model* model, std::size_t op_index) { CHECK_EQ(op->inputs.size(), 2); if (!IsConstantParameterArray(*model, op->inputs[1])) return false; - const auto& array = *model->arrays[op->inputs[1]]; + const auto& array = model->GetArray(op->inputs[1]); if (!array.has_shape()) return false; const std::vector<int>& dims = array.shape().dims(); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc index b5093bc4c7..5c68f87f6c 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc @@ -103,7 +103,7 @@ bool ResolveReorderAxes::Run(Model* model, std::size_t op_index) { AddMessageF("Reordered axes for array %s", input_array_name); // Remove the op and output array. - model->arrays.erase(output_array_name); + model->EraseArray(output_array_name); model->operators.erase(reorder_it); return true; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc index bed2a85bd2..2e063e3554 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc @@ -37,7 +37,7 @@ bool ResolveReshapeAttributes::Run(Model* model, std::size_t op_index) { if (!op->shape.empty()) return false; if (IsConstantParameterArray(*model, reshape_op->inputs[1])) { - const auto& constant_input_array = *model->arrays[reshape_op->inputs[1]]; + const auto& constant_input_array = model->GetArray(reshape_op->inputs[1]); op->shape = constant_input_array.GetBuffer<ArrayDataType::kInt32>().data; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_slice_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_slice_attributes.cc index 1d0a2ec8f6..e760d08e5a 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_slice_attributes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_slice_attributes.cc @@ -36,10 +36,10 @@ bool ResolveSliceAttributes::Run(Model* model, std::size_t op_index) { if (!IsConstantParameterArray(*model, op->inputs[1])) return false; if (!IsConstantParameterArray(*model, op->inputs[2])) return false; - const auto& begin_array = *model->arrays[op->inputs[1]]; + const auto& begin_array = model->GetArray(op->inputs[1]); if (!begin_array.has_shape()) return false; - const auto& size_array = *model->arrays[op->inputs[2]]; + const auto& size_array = model->GetArray(op->inputs[2]); if (!size_array.has_shape()) return false; op->begin = begin_array.GetBuffer<ArrayDataType::kInt32>().data; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_space_to_batch_nd_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_space_to_batch_nd_attributes.cc index a73f16735c..dad6aceccf 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_space_to_batch_nd_attributes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_space_to_batch_nd_attributes.cc @@ -45,7 +45,7 @@ bool ResolveSpaceToBatchNDAttributes::Run(Model* model, std::size_t op_index) { return false; // Handle paddings. - const auto& paddings_array = *model->arrays[op->inputs[paddings_index]]; + const auto& paddings_array = model->GetArray(op->inputs[paddings_index]); if (!paddings_array.has_shape()) return false; const std::vector<int>& paddings_dims = paddings_array.shape().dims(); if (paddings_dims.size() != 2) { @@ -61,7 +61,8 @@ bool ResolveSpaceToBatchNDAttributes::Run(Model* model, std::size_t op_index) { } // Handle block_shape. - const auto& block_shape_array = *model->arrays[op->inputs[block_shape_index]]; + const auto& block_shape_array = + model->GetArray(op->inputs[block_shape_index]); if (!block_shape_array.has_shape()) return false; const std::vector<int>& block_shape_dims = block_shape_array.shape().dims(); CHECK_EQ(block_shape_dims.size(), 1); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc index dbe69adcbd..de4d06be2a 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc @@ -31,13 +31,13 @@ bool ResolveStridedSliceAttributes::Run(Model* model, std::size_t op_index) { } CHECK_EQ(op->inputs.size(), 4); - const auto& start_array = *model->arrays[op->inputs[1]]; + const auto& start_array = model->GetArray(op->inputs[1]); if (!start_array.has_shape()) return false; - const auto& stop_array = *model->arrays[op->inputs[2]]; + const auto& stop_array = model->GetArray(op->inputs[2]); if (!stop_array.has_shape()) return false; - const auto& stride_array = *model->arrays[op->inputs[3]]; + const auto& stride_array = model->GetArray(op->inputs[3]); if (!stride_array.has_shape()) return false; if (!IsConstantParameterArray(*model, op->inputs[1])) return false; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc index c6723a880e..5c0c1e3478 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc @@ -75,7 +75,7 @@ bool ResolveTensorFlowConcat::Run(Model* model, std::size_t op_index) { // Remove the axis array if it is not used by anything else. if (CountOpsWithInput(*model, axis_name) == 1) { - model->arrays.erase(axis_name); + model->EraseArray(axis_name); } // Remove the TensorFlowConcat op model->operators.erase(concat_it); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc index bea7487051..ad1e56888e 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc @@ -69,7 +69,7 @@ bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) { LogName(*matmul_op), LogName(*fc_op)); const auto& previous_op_output = previous_op->outputs[0]; if (CountOpsWithInput(*model, previous_op_output) == 1) { - model->arrays.erase(previous_op_output); + model->EraseArray(previous_op_output); } CHECK_EQ(previous_op->inputs.size(), 2); fc_op->inputs = {previous_op->inputs[0], matmul_op->inputs[1]}; @@ -78,7 +78,7 @@ bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) { const auto& previous_op_shape = previous_op->inputs[1]; if (CountOpsWithInput(*model, previous_op_shape) == 1 && !GetOpWithOutput(*model, previous_op_shape)) { - model->arrays.erase(previous_op_shape); + model->EraseArray(previous_op_shape); } model->operators.erase(previous_op_it); } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc index cfa5ce0716..477e7f13da 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc @@ -55,7 +55,7 @@ bool ResolveTensorFlowMerge::Run(Model* model, std::size_t op_index) { // Remove the node and its output array. AddMessageF("Removing already-resolved %s", LogName(*merge_op)); - model->arrays.erase(merge_op->outputs[0]); + model->EraseArray(merge_op->outputs[0]); model->operators.erase(merge_it); return true; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc index 150cf53da3..a418073441 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc @@ -103,7 +103,7 @@ bool ResolveTensorFlowSwitch::Run(Model* model, std::size_t op_index) { // Remove the output arrays if they are now unused. for (int i = 0; i < 2; i++) { if (!GetOpWithInput(*model, switch_op->outputs[i])) { - model->arrays.erase(switch_op->outputs[i]); + model->EraseArray(switch_op->outputs[i]); } } // Remove input arrays if they are only used by the switch itself and aren't @@ -111,7 +111,7 @@ bool ResolveTensorFlowSwitch::Run(Model* model, std::size_t op_index) { for (const auto& input : switch_op->inputs) { if (CountOpsWithInput(*model, input) == 1 && !GetOpWithOutput(*model, input)) { - model->arrays.erase(input); + model->EraseArray(input); } } // Remove the switch node itself. diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_tile.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_tile.cc index 9f7e7c42a2..1ddf54c778 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_tile.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_tile.cc @@ -45,10 +45,10 @@ void RemoveTileOperator(Model* model, Operator* tile_op, Operator* binary_op, model->operators.erase(tile_it); if (!CountOpsWithInput(*model, tile_multiplier_array) && !GetOpWithOutput(*model, tile_multiplier_array)) { - model->arrays.erase(tile_multiplier_array); + model->EraseArray(tile_multiplier_array); } if (!CountOpsWithInput(*model, tile_output_array)) { - model->arrays.erase(tile_output_array); + model->EraseArray(tile_output_array); } } } // namespace diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_transpose_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_transpose_attributes.cc index 12d966b261..a657ee00af 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_transpose_attributes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_transpose_attributes.cc @@ -35,7 +35,7 @@ bool ResolveTransposeAttributes::Run(Model* model, std::size_t op_index) { if (!IsConstantParameterArray(*model, op->inputs[1])) return false; // Handling perm. - const auto& perm_array = *model->arrays[op->inputs[1]]; + const auto& perm_array = model->GetArray(op->inputs[1]); if (!perm_array.has_shape()) return false; const std::vector<int>& perm_dims = perm_array.shape().dims(); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc b/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc index a14016e8e2..3a1d175b98 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc @@ -19,7 +19,6 @@ limitations under the License. #include <gmock/gmock.h> #include <gtest/gtest.h> -//#include "tensorflow/contrib/lite/kernels/test_util.h" #include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/contrib/lite/toco/model.h" #include "tensorflow/contrib/lite/toco/tooling_util.h" @@ -168,11 +167,11 @@ TEST_F(ResolveConstantConcatenationTest, ConcatAtAxis0) { GraphTransformationsSet graph_transformation_set; graph_transformation_set.Add(new toco::ResolveConstantConcatenation); - EXPECT_THAT(model.arrays.size(), 5); + EXPECT_THAT(model.GetArrayMap().size(), 5); (*graph_transformation_set.begin())->Run(&model, /*op_index=*/0); - EXPECT_THAT(model.arrays.size(), 1); + EXPECT_THAT(model.GetArrayMap().size(), 1); - auto& concatenated_array = (*model.arrays.begin()).second; + auto& concatenated_array = (*model.GetArrayMap().begin()).second; EXPECT_THAT(concatenated_array->GetBuffer<toco::ArrayDataType::kFloat>().data, ElementsAreArray(ArrayFloatNear( {0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., @@ -187,11 +186,11 @@ TEST_F(ResolveConstantConcatenationTest, ConcatAtAxis1) { GraphTransformationsSet graph_transformation_set; graph_transformation_set.Add(new toco::ResolveConstantConcatenation); - EXPECT_THAT(model.arrays.size(), 5); + EXPECT_THAT(model.GetArrayMap().size(), 5); (*graph_transformation_set.begin())->Run(&model, /*op_index=*/0); - EXPECT_THAT(model.arrays.size(), 1); + EXPECT_THAT(model.GetArrayMap().size(), 1); - auto& concatenated_array = (*model.arrays.begin()).second; + auto& concatenated_array = (*model.GetArrayMap().begin()).second; EXPECT_THAT(concatenated_array->GetBuffer<toco::ArrayDataType::kFloat>().data, ElementsAreArray(ArrayFloatNear( {0., 1., 2., 3., 10., 11., 12., 13., 20., 21., 22., @@ -206,11 +205,11 @@ TEST_F(ResolveConstantConcatenationTest, ConcatAtAxis2) { GraphTransformationsSet graph_transformation_set; graph_transformation_set.Add(new toco::ResolveConstantConcatenation); - EXPECT_THAT(model.arrays.size(), 5); + EXPECT_THAT(model.GetArrayMap().size(), 5); (*graph_transformation_set.begin())->Run(&model, /*op_index=*/0); - EXPECT_THAT(model.arrays.size(), 1); + EXPECT_THAT(model.GetArrayMap().size(), 1); - auto& concatenated_array = (*model.arrays.begin()).second; + auto& concatenated_array = (*model.GetArrayMap().begin()).second; EXPECT_THAT(concatenated_array->GetBuffer<toco::ArrayDataType::kFloat>().data, ElementsAreArray(ArrayFloatNear( {0., 1., 10., 11., 20., 21., 30., 31., 2., 3., 12., diff --git a/tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc b/tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc index 4e273343df..2c7046c8c7 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc @@ -63,7 +63,7 @@ bool UnfuseActivationFunctions::Run(Model* model, std::size_t op_index) { ac_op->outputs = op->outputs; const string& tmp_array_name = AvailableArrayName(*model, op->outputs[0] + "_unfused"); - CHECK(!model->arrays.count(tmp_array_name)); + CHECK(!model->HasArray(tmp_array_name)); model->GetOrCreateArray(tmp_array_name); ac_op->inputs = {tmp_array_name}; op->outputs = {tmp_array_name}; diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc index 995e9d67ca..1947271f55 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -1652,7 +1652,7 @@ void StripCaretFromArrayNames(Model* model) { output = string(absl::StripPrefix(output, "^")); } } - for (auto& array : model->arrays) { + for (auto& array : model->GetArrayMap()) { if (absl::StartsWith(array.first, "^")) { LOG(FATAL) << "What?"; } diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h index b079750bed..54fbba7381 100644 --- a/tensorflow/contrib/lite/toco/model.h +++ b/tensorflow/contrib/lite/toco/model.h @@ -1521,15 +1521,19 @@ struct Array { // Our Model struct, represents an entire model (our "top-level" struct). // Owns everything. -struct Model { +class Model { + public: + using ArrayMap = std::unordered_map<string, std::unique_ptr<Array>>; + + bool HasArray(const string& name) const { return arrays.count(name) > 0; } Array& GetArray(const string& name) const { - DCHECK(arrays.count(name)); + DCHECK(HasArray(name)); return *arrays.at(name); } Array& GetOrCreateArray(const string& name) { // Make sure name is not used by an optional array DCHECK(!optional_arrays.count(name)); - if (!arrays.count(name)) { + if (!HasArray(name)) { Array* ptr = new Array; arrays[name] = std::unique_ptr<Array>(ptr); } @@ -1544,18 +1548,27 @@ struct Model { return optional_arrays.count(name); } + // Note that this invalidates all array iterators. + void EraseArray(const string& name) { arrays.erase(name); } + void EraseArrays(std::function<bool(const string&)> discardable) { + for (auto it = arrays.begin(); it != arrays.end();) { + if (discardable(it->first)) { + it = arrays.erase(it); + } else { + ++it; + } + } + } + const ArrayMap& GetArrayMap() const { return arrays; } + // Optional arrays are used for optional tensors, // these tensors do not have data, but with reserved names as op inputs. std::set<string> optional_arrays; + // The list of operators. Notice how it's a list of unique_ptr's, implying // that the Model is what owns Operator's and keeps them alive. std::vector<std::unique_ptr<Operator>> operators; - // The associative array mapping names to Array's. - // Notice how it's a container of unique_ptr's, implying - // that the Model is what owns Array's and keeps them alive. - // The Operator's refer to these Array's by their name strings, not by their - // addresses. See Operator::inputs, Operator::outputs. - std::unordered_map<string, std::unique_ptr<Array>> arrays; + // Generic flags, a place where we combine information passed to us via // command-line parameters (e.g. --input_width=N) with information that // we may or may not find in the input model file. @@ -1564,6 +1577,14 @@ struct Model { std::size_t transient_data_size = 0; // For code-generation only: required alignment of the transient_data buffer std::size_t transient_data_alignment = 0; + + private: + // The associative array mapping names to Array's. + // Notice how it's a container of unique_ptr's, implying + // that the Model is what owns Array's and keeps them alive. + // The Operator's refer to these Array's by their name strings, not by their + // addresses. See Operator::inputs, Operator::outputs. + std::unordered_map<string, std::unique_ptr<Array>> arrays; }; } // namespace toco diff --git a/tensorflow/contrib/lite/toco/tflite/export.cc b/tensorflow/contrib/lite/toco/tflite/export.cc index 440353203e..391ef87029 100644 --- a/tensorflow/contrib/lite/toco/tflite/export.cc +++ b/tensorflow/contrib/lite/toco/tflite/export.cc @@ -62,7 +62,7 @@ namespace details { void LoadTensorsMap(const Model& model, TensorsMap* tensors_map) { // First find a list of unique array names. std::set<string> names; - for (const auto& array_pair : model.arrays) { + for (const auto& array_pair : model.GetArrayMap()) { names.insert(array_pair.first); } @@ -96,7 +96,7 @@ Offset<Vector<Offset<Tensor>>> ExportTensors( // tensors in the tensors_map. std::map<int, Offset<Tensor>> ordered_tensors; - for (const auto& array_pair : model.arrays) { + for (const auto& array_pair : model.GetArrayMap()) { const string& tensor_name = array_pair.first; const toco::Array& array = *array_pair.second; diff --git a/tensorflow/contrib/lite/toco/tflite/import_test.cc b/tensorflow/contrib/lite/toco/tflite/import_test.cc index 309fa6d7f6..aad6e780d5 100644 --- a/tensorflow/contrib/lite/toco/tflite/import_test.cc +++ b/tensorflow/contrib/lite/toco/tflite/import_test.cc @@ -114,7 +114,7 @@ TEST_F(ImportTest, Tensors) { auto model = Import(ModelFlags(), InputModelAsString()); - ASSERT_GT(model->arrays.count("tensor_one"), 0); + ASSERT_GT(model->HasArray("tensor_one"), 0); Array& a1 = model->GetArray("tensor_one"); EXPECT_EQ(ArrayDataType::kFloat, a1.data_type); EXPECT_THAT(a1.GetBuffer<ArrayDataType::kFloat>().data, diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc index 94b4d14696..afaa0fd0c7 100644 --- a/tensorflow/contrib/lite/toco/toco_tooling.cc +++ b/tensorflow/contrib/lite/toco/toco_tooling.cc @@ -133,7 +133,7 @@ void SetFinalDataTypeOnInputs(const TocoFlags& toco_flags, Model* model) { for (int i = 0; i < model->flags.input_arrays_size(); i++) { string const& array_name = model->flags.input_arrays(i).name(); - auto* array = model->arrays[array_name].get(); + auto* array = &model->GetArray(array_name); // Note that the notion of changing data types only applies to real-numbers // arrays (see the documentation for inference_input_type). // TODO(benoitjacob) this is assuming that uint8 arrays are quantized, diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc index f9093ab973..900c60bd90 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.cc +++ b/tensorflow/contrib/lite/toco/tooling_util.cc @@ -93,7 +93,7 @@ int CountOpsWithInput(const Model& model, const string& array_name) { bool DeleteArrayIfUnused(const string& array_name, Model* model) { if (CountOpsWithInput(*model, array_name) == 0) { - model->arrays.erase(array_name); + model->EraseArray(array_name); return true; } return false; @@ -566,11 +566,11 @@ int RequiredBufferSizeForShape(const Shape& shape) { } bool IsConstantParameterArray(const Model& model, const string& name) { - if (!model.arrays.count(name)) { + if (!model.HasArray(name)) { return false; } - return !!model.arrays.at(name)->buffer; + return !!model.GetArray(name).buffer; } namespace { @@ -633,17 +633,17 @@ void CheckNonExistentIOArrays(const Model& model) { return; } for (const auto& input_array : model.flags.input_arrays()) { - CHECK(model.arrays.count(input_array.name())) + CHECK(model.HasArray(input_array.name())) << "Input array not found: " << input_array.name(); } for (const string& output_array : model.flags.output_arrays()) { - CHECK(model.arrays.count(output_array)) + CHECK(model.HasArray(output_array)) << "Output array not found: " << output_array; } for (const auto& rnn_state : model.flags.rnn_states()) { if (!rnn_state.discardable()) { - CHECK(model.arrays.count(rnn_state.state_array())); - CHECK(model.arrays.count(rnn_state.back_edge_source_array())); + CHECK(model.HasArray(rnn_state.state_array())); + CHECK(model.HasArray(rnn_state.back_edge_source_array())); } } } @@ -652,10 +652,10 @@ void CheckNonExistentIOArrays(const Model& model) { void CheckNoMissingArray(const Model& model) { for (const auto& op : model.operators) { for (const auto& input : op->inputs) { - CHECK(model.arrays.count(input) || model.optional_arrays.count(input)); + CHECK(model.HasArray(input) || model.optional_arrays.count(input)); } for (const auto& output : op->outputs) { - CHECK(model.arrays.count(output)); + CHECK(model.HasArray(output)); } } CheckNonExistentIOArrays(model); @@ -664,12 +664,12 @@ void CheckNoMissingArray(const Model& model) { void FixNoMissingArray(Model* model) { for (const auto& op : model->operators) { for (const auto& input : op->inputs) { - if (!model->arrays.count(input)) { + if (!model->HasArray(input)) { model->GetOrCreateArray(input); } } for (const auto& output : op->outputs) { - if (!model->arrays.count(output)) { + if (!model->HasArray(output)) { model->GetOrCreateArray(output); } } @@ -687,7 +687,7 @@ void FixNoMissingArray(Model* model) { void CheckNoOrphanedArray(const Model& model) { std::unordered_set<string> arrays_without_known_use; - for (const auto& array : model.arrays) { + for (const auto& array : model.GetArrayMap()) { if (IsDiscardableArray(model, array.first)) { arrays_without_known_use.insert(array.first); } @@ -714,7 +714,7 @@ void CheckNoOrphanedArray(const Model& model) { void FixNoOrphanedArray(Model* model) { std::unordered_set<string> arrays_without_known_use; - for (const auto& array : model->arrays) { + for (const auto& array : model->GetArrayMap()) { arrays_without_known_use.insert(array.first); } for (const auto& op : model->operators) { @@ -731,13 +731,13 @@ void FixNoOrphanedArray(Model* model) { } for (const auto& array : arrays_without_known_use) { if (IsDiscardableArray(*model, array)) { - model->arrays.erase(array); + model->EraseArray(array); } } } void CheckArrayFieldsConsistent(const Model& model) { - for (const auto& array_entry : model.arrays) { + for (const auto& array_entry : model.GetArrayMap()) { const auto& array = array_entry.second; if (array->has_shape()) { for (int d : array->shape().dims()) { @@ -756,7 +756,7 @@ void CheckArrayFieldsConsistent(const Model& model) { void CheckOperatorOrdering(const Model& model) { std::unordered_set<string> arrays_behind_us; - for (const auto& array_entry : model.arrays) { + for (const auto& array_entry : model.GetArrayMap()) { if (!GetOpWithOutput(model, array_entry.first)) { arrays_behind_us.insert(array_entry.first); } @@ -781,7 +781,7 @@ void CheckOperatorOrdering(const Model& model) { void FixOperatorOrdering(Model* model) { std::unordered_set<string> arrays_behind_us; - for (const auto& array_entry : model->arrays) { + for (const auto& array_entry : model->GetArrayMap()) { if (!GetOpWithOutput(*model, array_entry.first)) { arrays_behind_us.insert(array_entry.first); } @@ -936,7 +936,8 @@ void CheckModelCounts(const Model& model) { if (count_type == "None") { continue; } else if (count_type == "Arrays") { - CheckCountInRange(model_check, model.arrays.size(), "count of arrays"); + CheckCountInRange(model_check, model.GetArrayMap().size(), + "count of arrays"); } else if (count_type == "Total") { CheckCountInRange(model_check, model.operators.size(), "count of all operator instances"); @@ -1297,7 +1298,7 @@ bool IsAllocatableTransientArray(const Model& model, const string& array_name) { return false; } } - const auto& array = model.arrays.at(array_name); + const auto& array = &model.GetArray(array_name); // An array with a constant buffer isn't a transient array. if (!!array->buffer) { return false; @@ -1310,13 +1311,13 @@ bool IsAllocatableTransientArray(const Model& model, const string& array_name) { } string AvailableArrayName(const Model& model, const string& name) { - if (!model.arrays.count(name) && !model.optional_arrays.count(name)) { + if (!model.HasArray(name) && !model.optional_arrays.count(name)) { return name; } const int kNumSuffixesToTry = 1000; for (int i = 0; i < kNumSuffixesToTry; i++) { const string& name_with_suffix = toco::port::StringF("%s_%d", name, i); - if (!model.arrays.count(name_with_suffix)) { + if (!model.HasArray(name_with_suffix)) { return name_with_suffix; } } @@ -1334,12 +1335,12 @@ string ShapeToString(const Shape& shape) { } void PrintArrayShape(Model* model, const string& name) { - if (!model->arrays[name]->has_shape()) { + if (!model->GetArray(name).has_shape()) { LOG(INFO) << name << " has no shape"; return; } LOG(INFO) << name - << " has shape: " << ShapeToString(model->arrays[name]->shape()); + << " has shape: " << ShapeToString(model->GetArray(name).shape()); } bool IsArrayFullyConnectedWeights(const Model& model, const string& name) { @@ -1673,7 +1674,7 @@ bool IsDiscardableArray(const Model& model, const string& array_name) { } void CheckFinalDataTypesSatisfied(const Model& model) { - for (const auto& array_entry : model.arrays) { + for (const auto& array_entry : model.GetArrayMap()) { const auto& array = *array_entry.second; if (array.final_data_type != ArrayDataType::kNone) { CHECK(array.final_data_type == array.data_type) |