aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-10 01:53:40 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-10 01:56:38 -0700
commit7f85c95f71b01f711c366942a7cd911b0743b72c (patch)
treead92abbdf8a4c5736e323a75c3c153dadf3ce5c6 /tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
parentc59bb780ebd1674ab34dd96d193c71698682ed4d (diff)
Implementation of arg_min.
PiperOrigin-RevId: 203908601
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc10
1 files changed, 8 insertions, 2 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index 8eb0423283..4f95c57451 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -1404,7 +1404,8 @@ void ProcessTransposeOperator(Model* model, TransposeOperator* op) {
}
}
-void ProcessArgMaxOperator(Model* model, ArgMaxOperator* op) {
+template <typename Op>
+void ProcessArgMinMaxOperator(Model* model, Op* op) {
CHECK_EQ(op->inputs.size(), 2);
const auto& input_array = model->GetArray(op->inputs[0]);
// Yield until input dims have been resolved.
@@ -1696,7 +1697,12 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
static_cast<StridedSliceOperator*>(op));
break;
case OperatorType::kArgMax:
- ProcessArgMaxOperator(model, static_cast<ArgMaxOperator*>(op));
+ ProcessArgMinMaxOperator<ArgMaxOperator>(
+ model, static_cast<ArgMaxOperator*>(op));
+ break;
+ case OperatorType::kArgMin:
+ ProcessArgMinMaxOperator<ArgMinOperator>(
+ model, static_cast<ArgMinOperator*>(op));
break;
case OperatorType::kUnsupported:
break;