diff options
author | Jared Duke <jdduke@google.com> | 2018-07-26 10:53:21 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-26 10:56:43 -0700 |
commit | 6e658c0a5ca77677a954a34fb98f241c592c970d (patch) | |
tree | b645103887539af5232b3f70d80a2eb9b77ed63a /tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc | |
parent | 0a3155f7fbf56df5e81c7cbf35afd45173359635 (diff) |
Add one_hot op support to TFLite
PiperOrigin-RevId: 206185190
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc | 12 |
1 files changed, 12 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc index 9c22497d5e..0f94006f34 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc @@ -201,6 +201,18 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { SetDataTypeForAllOutputs(model, op, data_type); break; } + case OperatorType::kOneHot: { + CHECK_EQ(op->inputs.size(), 4); + CHECK_EQ(op->outputs.size(), 1); + const ArrayDataType on_value_type = + model->GetArray(op->inputs[OneHotOperator::ON_VALUE_INPUT]).data_type; + const ArrayDataType off_value_type = + model->GetArray(op->inputs[OneHotOperator::OFF_VALUE_INPUT]) + .data_type; + CHECK(on_value_type == off_value_type); + model->GetArray(op->outputs[0]).data_type = on_value_type; + break; + } default: { // These operators produce outputs with the same type as their 1st input CHECK_GT(op->inputs.size(), 0); |