aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-01-23 12:24:21 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-23 12:28:02 -0800
commitba4aec48268d02f111cd7e2c2666f4e7b077e68a (patch)
treeb5b2e456b9499cc460381bdb471e864b8b889dfa /tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
parentb47119257b76bdec542121aaed056c8b010fd19f (diff)
Internal Change
PiperOrigin-RevId: 182974191
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc20
1 files changed, 10 insertions, 10 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
index 29b55d9bfc..f0d107232b 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
@@ -27,7 +27,7 @@ namespace {
void SetDataTypeForAllOutputs(Model* model, Operator* op,
ArrayDataType data_type) {
for (const auto& output : op->outputs) {
- model->arrays[output]->data_type = data_type;
+ model->GetArray(output).data_type = data_type;
}
}
} // namespace
@@ -39,7 +39,7 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
// If the data type of some input is unknown, we need to yield.
for (const auto& input : op->inputs) {
if (!model->IsOptionalArray(input) &&
- model->arrays[input]->data_type == ArrayDataType::kNone) {
+ model->GetArray(input).data_type == ArrayDataType::kNone) {
return false;
}
}
@@ -47,7 +47,7 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
// end if we changed anything, and return the correct boolean value.
std::unordered_map<string, ArrayDataType> old_output_data_types;
for (const auto& output : op->outputs) {
- old_output_data_types[output] = model->arrays[output]->data_type;
+ old_output_data_types[output] = model->GetArray(output).data_type;
}
// Do the actual output data types propagation.
if (op->type == OperatorType::kDequantize ||
@@ -69,18 +69,18 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
op->type == OperatorType::kFill) {
// These operators produce an output with the same type as their 2nd input
CHECK_GE(op->inputs.size(), 2);
- const ArrayDataType data_type = model->arrays[op->inputs[1]]->data_type;
+ const ArrayDataType data_type = model->GetArray(op->inputs[1]).data_type;
SetDataTypeForAllOutputs(model, op, data_type);
} else if (op->type == OperatorType::kCast) {
// Data type of the Cast op is specified.
CHECK_EQ(op->outputs.size(), 1);
auto* cast_op = static_cast<CastOperator*>(op);
- model->arrays[op->outputs[0]]->data_type = cast_op->dst_data_type;
+ model->GetArray(op->outputs[0]).data_type = cast_op->dst_data_type;
} else if (op->type == OperatorType::kArgMax) {
// Data type of the ArgMax op is specified.
CHECK_EQ(op->outputs.size(), 1);
auto* argmax_op = static_cast<ArgMaxOperator*>(op);
- model->arrays[op->outputs[0]]->data_type = argmax_op->output_data_type;
+ model->GetArray(op->outputs[0]).data_type = argmax_op->output_data_type;
} else if (op->type == OperatorType::kRange) {
auto* range_op = static_cast<RangeOperator*>(op);
// Output type of the Range op can be set via an attribute
@@ -91,7 +91,7 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
} else {
// Otherwise use the first input
CHECK_GE(op->inputs.size(), 1);
- data_type = model->arrays[op->inputs[0]]->data_type;
+ data_type = model->GetArray(op->inputs[0]).data_type;
}
CHECK_EQ(op->outputs.size(), 1);
SetDataTypeForAllOutputs(model, op, data_type);
@@ -103,7 +103,7 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
for (int i = 0; i < unsupported_op->output_data_types.size(); ++i) {
auto output = op->outputs[i];
auto data_type = unsupported_op->output_data_types[i];
- model->arrays[output]->data_type = data_type;
+ model->GetArray(output).data_type = data_type;
}
} else if (op->type == OperatorType::kExpandDims) {
// Yield on ExpandDim until it is converted to Reshape
@@ -111,12 +111,12 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
} else {
// These operators produce outputs with the same type as their 1st input
CHECK_GT(op->inputs.size(), 0);
- const ArrayDataType data_type = model->arrays[op->inputs[0]]->data_type;
+ const ArrayDataType data_type = model->GetArray(op->inputs[0]).data_type;
SetDataTypeForAllOutputs(model, op, data_type);
}
// Return true if any output data type changed, false if none changed.
for (const auto& output : op->outputs) {
- if (old_output_data_types[output] != model->arrays[output]->data_type) {
+ if (old_output_data_types[output] != model->GetArray(output).data_type) {
return true;
}
}