diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-08-02 20:11:50 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-02 20:15:21 -0700 |
commit | eecadedbaae7b938e8a80dfb60c52679bcbf7196 (patch) | |
tree | 7f41ed6a9a3be126f5026a809593736d9bd10340 /tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc | |
parent | 2d3819668d8c3ab99cd09a769ffb7b76e453fd8f (diff) |
Implementation of ctc beam search decoder op in custom op fashion.
PiperOrigin-RevId: 207210333
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc | 12 |
1 files changed, 12 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc index f033ee013e..c8310161cb 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc @@ -215,6 +215,18 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { model->GetArray(op->outputs[0]).data_type = on_value_type; break; } + case OperatorType::kCTCBeamSearchDecoder: { + CHECK_EQ(op->inputs.size(), 2); + // All outputs (sparse tensors) are int32s (although tf uses int64s) + // except the last one (log probabilities) is float. + const int output_size = op->outputs.size(); + for (int i = 0; i < output_size - 1; ++i) { + model->GetArray(op->outputs[i]).data_type = ArrayDataType::kInt32; + } + model->GetArray(op->outputs[output_size - 1]).data_type = + ArrayDataType::kFloat; + break; + } default: { // These operators produce outputs with the same type as their 1st input CHECK_GT(op->inputs.size(), 0); |