diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-06-29 12:52:33 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-29 12:55:32 -0700 |
commit | 5be44f95a612b59be2182769c6650e4bb82c7355 (patch) | |
tree | 8826c2141c4edf6c7f3e2e284417889183052cbe /tensorflow/contrib/lite/toco/import_tensorflow.cc | |
parent | 4b988cac015c1b464b2f83024a2853f628a7c938 (diff) |
Allow transposition of the weights in fully connected ops.
PiperOrigin-RevId: 202693036
Diffstat (limited to 'tensorflow/contrib/lite/toco/import_tensorflow.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/import_tensorflow.cc | 13 |
1 files changed, 7 insertions, 6 deletions
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc index 55e39d963f..5c32a39035 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -984,18 +984,19 @@ tensorflow::Status ConvertMatMulOperator( Model* model) { TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); - // Transpose flags should be easy to support, but we don't have a - // GraphDef with them to test on at the moment. - CHECK_EQ(HasAttr(node, "transpose_a") && GetBoolAttr(node, "transpose_a"), - false); - CHECK_EQ(HasAttr(node, "transpose_b") && GetBoolAttr(node, "transpose_b"), - false); CHECK(!HasAttr(node, "adjoint_a") || (GetBoolAttr(node, "adjoint_a") == false)); CHECK(!HasAttr(node, "adjoint_b") || (GetBoolAttr(node, "adjoint_b") == false)); auto* matmul = new TensorFlowMatMulOperator; + if (HasAttr(node, "transpose_a")) { + matmul->transpose_a = GetBoolAttr(node, "transpose_a"); + } + if (HasAttr(node, "transpose_b")) { + matmul->transpose_b = GetBoolAttr(node, "transpose_b"); + } + matmul->inputs = {node.input(0), node.input(1)}; matmul->outputs = {node.name()}; model->operators.emplace_back(matmul); |