diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/unroll_batch_matmul.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/unroll_batch_matmul.cc | 15 |
1 files changed, 10 insertions, 5 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/unroll_batch_matmul.cc b/tensorflow/contrib/lite/toco/graph_transformations/unroll_batch_matmul.cc index fedf4441e2..5ff39aa313 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/unroll_batch_matmul.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/unroll_batch_matmul.cc @@ -36,10 +36,12 @@ namespace toco { // slice_c = tf.matmul(slice_a, slice_b) // result_slices[bat] = slice_c // result = tf.stack(result_slices) -bool UnrollBatchMatMul::Run(Model* model, std::size_t op_index) { +::tensorflow::Status UnrollBatchMatMul::Run(Model* model, std::size_t op_index, + bool* modified) { + *modified = false; auto batch_op_it = model->operators.begin() + op_index; if (batch_op_it->get()->type != OperatorType::kBatchMatMul) { - return false; + return ::tensorflow::Status::OK(); } const auto* batch_op = static_cast<const BatchMatMulOperator*>(batch_op_it->get()); @@ -47,7 +49,8 @@ bool UnrollBatchMatMul::Run(Model* model, std::size_t op_index) { // We must have the shape of at least one input to know our batch size. const auto& input_array_a = model->GetArray(batch_op->inputs[0]); const auto& input_array_b = model->GetArray(batch_op->inputs[1]); - if (!input_array_a.has_shape() || !input_array_b.has_shape()) return false; + if (!input_array_a.has_shape() || !input_array_b.has_shape()) + return ::tensorflow::Status::OK(); // We only support the rank 3 case. If you are batching on rank > 3 you'll // have to figure that out. @@ -66,7 +69,8 @@ bool UnrollBatchMatMul::Run(Model* model, std::size_t op_index) { batch_op_it = matmul_op_it + 1; CHECK_EQ(batch_op_it->get(), batch_op); model->operators.erase(batch_op_it); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } CHECK_EQ(input_array_a.shape().dimensions_count(), 3) << "Input arrays must have rank 3"; @@ -167,7 +171,8 @@ bool UnrollBatchMatMul::Run(Model* model, std::size_t op_index) { CHECK(batch_op_it != model->operators.end()); CHECK(batch_op_it->get() == batch_op); model->operators.erase(batch_op_it); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco |