aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-02 20:11:50 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-02 20:15:21 -0700
commiteecadedbaae7b938e8a80dfb60c52679bcbf7196 (patch)
tree7f41ed6a9a3be126f5026a809593736d9bd10340 /tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
parent2d3819668d8c3ab99cd09a769ffb7b76e453fd8f (diff)
Implementation of ctc beam search decoder op in custom op fashion.
PiperOrigin-RevId: 207210333
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc12
1 files changed, 12 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
index f033ee013e..c8310161cb 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
@@ -215,6 +215,18 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
model->GetArray(op->outputs[0]).data_type = on_value_type;
break;
}
+ case OperatorType::kCTCBeamSearchDecoder: {
+ CHECK_EQ(op->inputs.size(), 2);
+ // All outputs (sparse tensors) are int32s (although tf uses int64s)
+ // except the last one (log probabilities) is float.
+ const int output_size = op->outputs.size();
+ for (int i = 0; i < output_size - 1; ++i) {
+ model->GetArray(op->outputs[i]).data_type = ArrayDataType::kInt32;
+ }
+ model->GetArray(op->outputs[output_size - 1]).data_type =
+ ArrayDataType::kFloat;
+ break;
+ }
default: {
// These operators produce outputs with the same type as their 1st input
CHECK_GT(op->inputs.size(), 0);