author     A. Unique TensorFlower <gardener@tensorflow.org>   2018-09-19 12:55:34 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>    2018-09-19 13:03:19 -0700
commit     f72126c164ea67b226368bf51811c8528d81093b
tree       116987724ac043ba9e8b00692f4e5eed821b85f7
parent     c95353498e180e50f701dcb8331b994d9e5fad0b
Two improvements in resolve_tensorflow_matmul:
1. Before inserting a new Transpose node, check whether there is already one that can be reused. In practice there are two cases: either the array being transposed is a constant (by far the most common case) or it is not.
   * If it is constant, this doesn't really make a difference: ResolveConstantTranspose runs anyway, eliminating these Transpose nodes and also mooting this change, since it leaves no Transpose node to be reused. In that case, constant-array deduping is really the only thing that prevents duplication of data.
   * If it is not constant, this new logic really helps: the resulting Transpose nodes stay in the final graph, and this avoids inserting more of them than are needed.
2. transpose_a is not supported. However, rather than CHECK-failing, it is more useful to have this graph transformation bail out with a log message. The resulting 'unresolved' MatMul node could still be handled in some way at the TFLite level, or we could end up supporting MatMul per se.

PiperOrigin-RevId: 213678294
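For illustration, below is a minimal, self-contained C++ sketch of the reuse check described in point 1. The Op and Model types here are simplified stand-ins rather than toco's actual classes; the real implementation is the FindTransposeOpWithInput helper added in the diff below, which reads the permutation out of an int32 permutation array in the model instead of storing it inline.

#include <memory>
#include <string>
#include <vector>

// Simplified stand-in types, for illustration only (not toco's real classes).
struct Op {
  std::string type;                 // e.g. "Transpose"
  std::vector<std::string> inputs;  // inputs[0] is the array being transposed
  std::vector<int> perm;            // permutation, stored inline for brevity
  std::vector<std::string> outputs;
};

struct Model {
  std::vector<std::unique_ptr<Op>> ops;
};

// Returns an existing Transpose of `array_name` with permutation {1, 0},
// or nullptr if none exists yet, in which case the caller creates one.
Op* FindReusableTranspose(const Model& model, const std::string& array_name) {
  for (const auto& op : model.ops) {
    if (op->type == "Transpose" && !op->inputs.empty() &&
        op->inputs[0] == array_name && op->perm == std::vector<int>{1, 0}) {
      return op.get();
    }
  }
  return nullptr;
}

The payoff is exactly the non-constant case from point 1: if two MatMul ops share the same non-constant RHS array, the second one now reuses the Transpose node created for the first instead of inserting a duplicate.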
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc  80
1 file changed, 67 insertions(+), 13 deletions(-)
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc
index fcf30bd347..65346c4fe4 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc
@@ -24,6 +24,37 @@ limitations under the License.
namespace toco {
+namespace {
+
+TransposeOperator* FindTransposeOpWithInput(const Model& model,
+ const string& array_name) {
+ for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
+ Operator* op = it->get();
+ if (op->type != OperatorType::kTranspose) {
+ continue;
+ }
+ if (op->inputs[0] != array_name) {
+ continue;
+ }
+ const auto& permutation_array = model.GetArray(op->inputs[1]);
+ if (permutation_array.data_type != ArrayDataType::kInt32) {
+ continue;
+ }
+ const auto& permutation_data =
+ permutation_array.GetBuffer<ArrayDataType::kInt32>().data;
+ if (permutation_data.size() != 2) {
+ continue;
+ }
+ if (permutation_data[0] != 1 || permutation_data[1] != 0) {
+ continue;
+ }
+ return static_cast<TransposeOperator*>(op);
+ }
+ return nullptr;
+}
+
+} // namespace
+
bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) {
auto matmul_it = model->operators.begin() + op_index;
if (matmul_it->get()->type != OperatorType::kMatMul) {
@@ -37,7 +68,13 @@ bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) {
// TransposeOperator. However, the second input is supposed to be 2D, so we
// can actually handle transposition of that matrix, which happens to be more
// common anyway.
- CHECK(!matmul_op->transpose_a);
+ if (matmul_op->transpose_a) {
+ AddMessageF(
+ "Not replacing %s by a FullyConnected operator, because it has "
+ "the transpose_a attribute",
+ LogName(*matmul_op));
+ return false;
+ }
// Reorder the axes on the second input. TensorFlow uses row-major ordering
// on both inputs, however this is inefficient for the FullyConnected
@@ -46,18 +83,35 @@ bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) {
string input_lhs = matmul_op->inputs[0];
string input_rhs = matmul_op->inputs[1];
if (!matmul_op->transpose_b) {
- auto* transpose_op = new TransposeOperator;
- transpose_op->inputs = {
- matmul_op->inputs[1],
- CreateInt32Array(model,
- AvailableArrayName(
- *model, matmul_op->inputs[1] + "/transpose/perm"),
- {1, 0})};
- transpose_op->outputs = {
- AvailableArrayName(*model, matmul_op->inputs[1] + "/transpose")};
- model->GetOrCreateArray(transpose_op->outputs[0]);
- model->operators.emplace(matmul_it, transpose_op);
-
+ // Need to transpose input_rhs, by inserting a TransposeOperator.
+ // First, check if there already is a TransposeOperator transposing that
+ // array, so we can just reuse it.
+ auto* transpose_op = FindTransposeOpWithInput(*model, input_rhs);
+ if (!transpose_op) {
+ AddMessageF(
+ "While replacing %s by a FullyConnected operator, created new "
+ "Transpose op wrapping RHS input array %s",
+ LogName(*matmul_op), input_rhs);
+ // No such TransposeOperator found. Create one now.
+ transpose_op = new TransposeOperator;
+ transpose_op->inputs = {
+ input_rhs,
+ CreateInt32Array(
+ model, AvailableArrayName(*model, input_rhs + "/transpose/perm"),
+ {1, 0})};
+ transpose_op->outputs = {
+ AvailableArrayName(*model, input_rhs + "/transpose")};
+ model->GetOrCreateArray(transpose_op->outputs[0]);
+ model->operators.emplace(matmul_it, transpose_op);
+ // Sanity check
+ DCHECK_EQ(transpose_op, FindTransposeOpWithInput(*model, input_rhs));
+ } else {
+ AddMessageF(
+ "While replacing %s by a FullyConnected operator, reused existing "
+ "Transpose op wrapping RHS input array %s",
+ LogName(*matmul_op), input_rhs);
+ }
+ // Re-wire: have the matmul consume the transposed array.
input_rhs = transpose_op->outputs[0];
}