aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/unroll_batch_matmul.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/unroll_batch_matmul.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/unroll_batch_matmul.cc15
1 files changed, 10 insertions, 5 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/unroll_batch_matmul.cc b/tensorflow/contrib/lite/toco/graph_transformations/unroll_batch_matmul.cc
index fedf4441e2..5ff39aa313 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/unroll_batch_matmul.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/unroll_batch_matmul.cc
@@ -36,10 +36,12 @@ namespace toco {
// slice_c = tf.matmul(slice_a, slice_b)
// result_slices[bat] = slice_c
// result = tf.stack(result_slices)
-bool UnrollBatchMatMul::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status UnrollBatchMatMul::Run(Model* model, std::size_t op_index,
+ bool* modified) {
+ *modified = false;
auto batch_op_it = model->operators.begin() + op_index;
if (batch_op_it->get()->type != OperatorType::kBatchMatMul) {
- return false;
+ return ::tensorflow::Status::OK();
}
const auto* batch_op =
static_cast<const BatchMatMulOperator*>(batch_op_it->get());
@@ -47,7 +49,8 @@ bool UnrollBatchMatMul::Run(Model* model, std::size_t op_index) {
// We must have the shape of at least one input to know our batch size.
const auto& input_array_a = model->GetArray(batch_op->inputs[0]);
const auto& input_array_b = model->GetArray(batch_op->inputs[1]);
- if (!input_array_a.has_shape() || !input_array_b.has_shape()) return false;
+ if (!input_array_a.has_shape() || !input_array_b.has_shape())
+ return ::tensorflow::Status::OK();
// We only support the rank 3 case. If you are batching on rank > 3 you'll
// have to figure that out.
@@ -66,7 +69,8 @@ bool UnrollBatchMatMul::Run(Model* model, std::size_t op_index) {
batch_op_it = matmul_op_it + 1;
CHECK_EQ(batch_op_it->get(), batch_op);
model->operators.erase(batch_op_it);
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
CHECK_EQ(input_array_a.shape().dimensions_count(), 3)
<< "Input arrays must have rank 3";
@@ -167,7 +171,8 @@ bool UnrollBatchMatMul::Run(Model* model, std::size_t op_index) {
CHECK(batch_op_it != model->operators.end());
CHECK(batch_op_it->get() == batch_op);
model->operators.erase(batch_op_it);
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco