aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-05-07 15:41:22 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-07 17:27:20 -0700
commitfc7f0b296dd53d1b72af21d36d36b6bcc5291ea7 (patch)
tree46e76ead2391a3fb1232459189ad0b8d0d8066ac /tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
parent3a2f1cfb73fa6a21eba077485bdc08aa05646ad1 (diff)
Add support for select (via tf.where) support to tflite.
PiperOrigin-RevId: 195734246
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc11
1 files changed, 11 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
index c1cf79f626..6342cf3e8a 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
@@ -152,6 +152,17 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
// Yield on ExpandDim until it is converted to Reshape
return false;
}
+ case OperatorType::kSelect: {
+ // Select produces outputs with the same type as their 2nd input
+ CHECK_EQ(op->inputs.size(), 3);
+ const ArrayDataType data_type_x =
+ model->GetArray(op->inputs[1]).data_type;
+ const ArrayDataType data_type_y =
+ model->GetArray(op->inputs[2]).data_type;
+ CHECK(data_type_x == data_type_y);
+ SetDataTypeForAllOutputs(model, op, data_type_x);
+ break;
+ }
default: {
// These operators produce outputs with the same type as their 1st input
CHECK_GT(op->inputs.size(), 0);