diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-01-23 12:24:21 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-01-23 12:28:02 -0800 |
commit | ba4aec48268d02f111cd7e2c2666f4e7b077e68a (patch) | |
tree | b5b2e456b9499cc460381bdb471e864b8b889dfa /tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc | |
parent | b47119257b76bdec542121aaed056c8b010fd19f (diff) |
Internal Change
PiperOrigin-RevId: 182974191
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc | 20 |
1 files changed, 10 insertions, 10 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc index 29b55d9bfc..f0d107232b 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc @@ -27,7 +27,7 @@ namespace { void SetDataTypeForAllOutputs(Model* model, Operator* op, ArrayDataType data_type) { for (const auto& output : op->outputs) { - model->arrays[output]->data_type = data_type; + model->GetArray(output).data_type = data_type; } } } // namespace @@ -39,7 +39,7 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { // If the data type of some input is unknown, we need to yield. for (const auto& input : op->inputs) { if (!model->IsOptionalArray(input) && - model->arrays[input]->data_type == ArrayDataType::kNone) { + model->GetArray(input).data_type == ArrayDataType::kNone) { return false; } } @@ -47,7 +47,7 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { // end if we changed anything, and return the correct boolean value. std::unordered_map<string, ArrayDataType> old_output_data_types; for (const auto& output : op->outputs) { - old_output_data_types[output] = model->arrays[output]->data_type; + old_output_data_types[output] = model->GetArray(output).data_type; } // Do the actual output data types propagation. if (op->type == OperatorType::kDequantize || @@ -69,18 +69,18 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { op->type == OperatorType::kFill) { // These operators produce an output with the same type as their 2nd input CHECK_GE(op->inputs.size(), 2); - const ArrayDataType data_type = model->arrays[op->inputs[1]]->data_type; + const ArrayDataType data_type = model->GetArray(op->inputs[1]).data_type; SetDataTypeForAllOutputs(model, op, data_type); } else if (op->type == OperatorType::kCast) { // Data type of the Cast op is specified. CHECK_EQ(op->outputs.size(), 1); auto* cast_op = static_cast<CastOperator*>(op); - model->arrays[op->outputs[0]]->data_type = cast_op->dst_data_type; + model->GetArray(op->outputs[0]).data_type = cast_op->dst_data_type; } else if (op->type == OperatorType::kArgMax) { // Data type of the ArgMax op is specified. CHECK_EQ(op->outputs.size(), 1); auto* argmax_op = static_cast<ArgMaxOperator*>(op); - model->arrays[op->outputs[0]]->data_type = argmax_op->output_data_type; + model->GetArray(op->outputs[0]).data_type = argmax_op->output_data_type; } else if (op->type == OperatorType::kRange) { auto* range_op = static_cast<RangeOperator*>(op); // Output type of the Range op can be set via an attribute @@ -91,7 +91,7 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { } else { // Otherwise use the first input CHECK_GE(op->inputs.size(), 1); - data_type = model->arrays[op->inputs[0]]->data_type; + data_type = model->GetArray(op->inputs[0]).data_type; } CHECK_EQ(op->outputs.size(), 1); SetDataTypeForAllOutputs(model, op, data_type); @@ -103,7 +103,7 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { for (int i = 0; i < unsupported_op->output_data_types.size(); ++i) { auto output = op->outputs[i]; auto data_type = unsupported_op->output_data_types[i]; - model->arrays[output]->data_type = data_type; + model->GetArray(output).data_type = data_type; } } else if (op->type == OperatorType::kExpandDims) { // Yield on ExpandDim until it is converted to Reshape @@ -111,12 +111,12 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { } else { // These operators produce outputs with the same type as their 1st input CHECK_GT(op->inputs.size(), 0); - const ArrayDataType data_type = model->arrays[op->inputs[0]]->data_type; + const ArrayDataType data_type = model->GetArray(op->inputs[0]).data_type; SetDataTypeForAllOutputs(model, op, data_type); } // Return true if any output data type changed, false if none changed. for (const auto& output : op->outputs) { - if (old_output_data_types[output] != model->arrays[output]->data_type) { + if (old_output_data_types[output] != model->GetArray(output).data_type) { return true; } } |