aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-12-15 11:21:13 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-15 11:24:59 -0800
commit9393d604ebd63664a27d28617f95fa5f60495270 (patch)
tree3254ac5224a10ec4e0a5bafb59d6f78e0ed7a48e /tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
parent04b5890cbdf6161c6d02db95d3365fac9cbfea05 (diff)
internal change
PiperOrigin-RevId: 179217499
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc5
1 files changed, 5 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
index 1d92bcbccd..4fe127544b 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
@@ -75,6 +75,11 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
CHECK_EQ(op->outputs.size(), 1);
auto* cast_op = static_cast<CastOperator*>(op);
model->arrays[op->outputs[0]]->data_type = cast_op->dst_data_type;
+ } else if (op->type == OperatorType::kArgMax) {
+ // Data type of the ArgMax op is specified.
+ CHECK_EQ(op->outputs.size(), 1);
+ auto* argmax_op = static_cast<ArgMaxOperator*>(op);
+ model->arrays[op->outputs[0]]->data_type = argmax_op->output_data_type;
} else if (op->type == OperatorType::kTensorFlowUnsupported) {
auto* unsupported_op = static_cast<TensorFlowUnsupportedOperator*>(op);
if (unsupported_op->output_data_types.size() != op->outputs.size()) {