diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-12-15 11:21:13 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-12-15 11:24:59 -0800 |
commit | 9393d604ebd63664a27d28617f95fa5f60495270 (patch) | |
tree | 3254ac5224a10ec4e0a5bafb59d6f78e0ed7a48e /tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc | |
parent | 04b5890cbdf6161c6d02db95d3365fac9cbfea05 (diff) |
internal change
PiperOrigin-RevId: 179217499
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc | 5 |
1 files changed, 5 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc index 1d92bcbccd..4fe127544b 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc @@ -75,6 +75,11 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { CHECK_EQ(op->outputs.size(), 1); auto* cast_op = static_cast<CastOperator*>(op); model->arrays[op->outputs[0]]->data_type = cast_op->dst_data_type; + } else if (op->type == OperatorType::kArgMax) { + // Data type of the ArgMax op is specified. + CHECK_EQ(op->outputs.size(), 1); + auto* argmax_op = static_cast<ArgMaxOperator*>(op); + model->arrays[op->outputs[0]]->data_type = argmax_op->output_data_type; } else if (op->type == OperatorType::kTensorFlowUnsupported) { auto* unsupported_op = static_cast<TensorFlowUnsupportedOperator*>(op); if (unsupported_op->output_data_types.size() != op->outputs.size()) { |