aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc22
1 files changed, 20 insertions, 2 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
index 00ab7cbaa9..9c22497d5e 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
@@ -62,6 +62,9 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
case OperatorType::kGreaterEqual:
case OperatorType::kEqual:
case OperatorType::kNotEqual:
+ case OperatorType::kAny:
+ case OperatorType::kLogicalAnd:
+ case OperatorType::kLogicalNot:
// These operators unconditionally produce bool outputs
SetDataTypeForAllOutputs(model, op, ArrayDataType::kBool);
break;
@@ -100,6 +103,13 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
model->GetArray(op->outputs[0]).data_type = argmax_op->output_data_type;
break;
}
+ case OperatorType::kArgMin: {
+ // Data type of the ArgMin op is specified.
+ CHECK_EQ(op->outputs.size(), 1);
+ auto* argmin_op = static_cast<ArgMinOperator*>(op);
+ model->GetArray(op->outputs[0]).data_type = argmin_op->output_data_type;
+ break;
+ }
case OperatorType::kRange: {
auto* range_op = static_cast<RangeOperator*>(op);
// Output type of the Range op can be set via an attribute
@@ -144,8 +154,8 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
return false;
}
for (int i = 0; i < op->outputs.size(); ++i) {
- auto output = op->outputs[i];
- auto data_type = unsupported_op->output_data_types[i];
+ const string& output = op->outputs[i];
+ const ArrayDataType data_type = unsupported_op->output_data_types[i];
model->GetArray(output).data_type = data_type;
}
break;
@@ -183,6 +193,14 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
SetDataTypeForAllOutputs(model, op, data_type);
break;
}
+ case OperatorType::kPack: {
+ const ArrayDataType data_type = model->GetArray(op->inputs[0]).data_type;
+ for (const auto& input : op->inputs) {
+ CHECK(data_type == model->GetArray(input).data_type);
+ }
+ SetDataTypeForAllOutputs(model, op, data_type);
+ break;
+ }
default: {
// These operators produce outputs with the same type as their 1st input
CHECK_GT(op->inputs.size(), 0);