aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc')
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc  38
1 file changed, 24 insertions, 14 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc
index d496f5ae5e..fcf30bd347 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc
@@ -32,21 +32,34 @@ bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) {
const auto* matmul_op =
static_cast<const TensorFlowMatMulOperator*>(matmul_it->get());
+ // Handling transposition of the first input here isn't very simple because
+ // we need to know the actual shape in order to produce a proper
+ // TransposeOperator. However, the second input is supposed to be 2D, so we
+ // can actually handle transposition of that matrix, which happens to be more
+ // common anyway.
+ CHECK(!matmul_op->transpose_a);
+
// Reorder the axes on the second input. TensorFlow uses row-major ordering
// on both inputs, however this is inefficient for the FullyConnected
// operator. We'll transpose the second input to be in column-major order now
// and let constant propagation optimize things (if possible).
- auto* transpose_op = new TransposeOperator;
- transpose_op->inputs = {
- matmul_op->inputs[1],
- CreateInt32Array(
- model,
- AvailableArrayName(*model, matmul_op->inputs[1] + "/transpose/perm"),
- {1, 0})};
- transpose_op->outputs = {
- AvailableArrayName(*model, matmul_op->inputs[1] + "/transpose")};
- model->GetOrCreateArray(transpose_op->outputs[0]);
- model->operators.emplace(matmul_it, transpose_op);
+ string input_lhs = matmul_op->inputs[0];
+ string input_rhs = matmul_op->inputs[1];
+ if (!matmul_op->transpose_b) {
+ auto* transpose_op = new TransposeOperator;
+ transpose_op->inputs = {
+ matmul_op->inputs[1],
+ CreateInt32Array(model,
+ AvailableArrayName(
+ *model, matmul_op->inputs[1] + "/transpose/perm"),
+ {1, 0})};
+ transpose_op->outputs = {
+ AvailableArrayName(*model, matmul_op->inputs[1] + "/transpose")};
+ model->GetOrCreateArray(transpose_op->outputs[0]);
+ model->operators.emplace(matmul_it, transpose_op);
+
+ input_rhs = transpose_op->outputs[0];
+ }
// Refresh iterator.
matmul_it = model->operators.begin();
@@ -57,9 +70,6 @@ bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) {
}
DCHECK_EQ(matmul_it->get(), matmul_op);
- string input_lhs = matmul_op->inputs[0];
- string input_rhs = transpose_op->outputs[0];
-
// Construct the new FullyConnectedOperator.
auto* fc_op = new FullyConnectedOperator;
fc_op->outputs = matmul_op->outputs;