aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/import_tensorflow.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-29 12:52:33 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-29 12:55:32 -0700
commit5be44f95a612b59be2182769c6650e4bb82c7355 (patch)
tree8826c2141c4edf6c7f3e2e284417889183052cbe /tensorflow/contrib/lite/toco/import_tensorflow.cc
parent4b988cac015c1b464b2f83024a2853f628a7c938 (diff)
Allow transposition of the weights in fully connected ops.
PiperOrigin-RevId: 202693036
Diffstat (limited to 'tensorflow/contrib/lite/toco/import_tensorflow.cc')
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow.cc13
1 files changed, 7 insertions, 6 deletions
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index 55e39d963f..5c32a39035 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -984,18 +984,19 @@ tensorflow::Status ConvertMatMulOperator(
Model* model) {
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
- // Transpose flags should be easy to support, but we don't have a
- // GraphDef with them to test on at the moment.
- CHECK_EQ(HasAttr(node, "transpose_a") && GetBoolAttr(node, "transpose_a"),
- false);
- CHECK_EQ(HasAttr(node, "transpose_b") && GetBoolAttr(node, "transpose_b"),
- false);
CHECK(!HasAttr(node, "adjoint_a") ||
(GetBoolAttr(node, "adjoint_a") == false));
CHECK(!HasAttr(node, "adjoint_b") ||
(GetBoolAttr(node, "adjoint_b") == false));
auto* matmul = new TensorFlowMatMulOperator;
+ if (HasAttr(node, "transpose_a")) {
+ matmul->transpose_a = GetBoolAttr(node, "transpose_a");
+ }
+ if (HasAttr(node, "transpose_b")) {
+ matmul->transpose_b = GetBoolAttr(node, "transpose_b");
+ }
+
matmul->inputs = {node.input(0), node.input(1)};
matmul->outputs = {node.name()};
model->operators.emplace_back(matmul);