author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-01-24 16:23:09 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-01-24 16:29:50 -0800
commit    dc0511e51a8f2af8b0f053d7f04352f0b0be58fa (patch)
tree      36b0f3e60929072c3a37aca562b99e8f135b60b9
parent    e76a7a7c8bccb1fb67559160c9a06ba3a722fd54 (diff)
Basic AddN support in toco
PiperOrigin-RevId: 183160197
-rw-r--r--  tensorflow/contrib/lite/toco/BUILD                                                  1
-rw-r--r--  tensorflow/contrib/lite/toco/export_tensorflow.cc                                  15
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_addn_to_add.cc 51
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h         1
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc       25
-rw-r--r--  tensorflow/contrib/lite/toco/import_tensorflow.cc                                  15
-rw-r--r--  tensorflow/contrib/lite/toco/model.h                                               11
-rw-r--r--  tensorflow/contrib/lite/toco/tflite/operator.cc                                     2
-rw-r--r--  tensorflow/contrib/lite/toco/toco_tooling.cc                                        1
-rw-r--r--  tensorflow/contrib/lite/toco/tooling_util.cc                                       11
10 files changed, 133 insertions(+), 0 deletions(-)
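
For context, AddN is TensorFlow's N-ary elementwise sum: it adds N tensors of identical shape, with no broadcasting. A minimal reference sketch of those semantics on flat float buffers (the helper below is hypothetical and not part of this change):

#include <cstddef>
#include <vector>

// Elementwise sum of N equally-shaped tensors, here flattened to vectors.
std::vector<float> AddNReference(const std::vector<std::vector<float>>& inputs) {
  std::vector<float> out(inputs[0].size(), 0.0f);
  for (const auto& input : inputs) {
    for (std::size_t i = 0; i < out.size(); ++i) {
      out[i] += input[i];
    }
  }
  return out;
}
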
diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD
index ad8f0e4a47..041e248790 100644
--- a/tensorflow/contrib/lite/toco/BUILD
+++ b/tensorflow/contrib/lite/toco/BUILD
@@ -172,6 +172,7 @@ cc_library(
"graph_transformations/convert_expanddims_to_reshape.cc",
"graph_transformations/convert_pure_conv_to_depthwise.cc",
"graph_transformations/convert_reorder_axes.cc",
+ "graph_transformations/convert_trivial_addn_to_add.cc",
"graph_transformations/convert_trivial_transpose_to_reshape.cc",
"graph_transformations/create_im2col_arrays.cc",
"graph_transformations/dequantize.cc",
diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc
index 4fc01dbc20..529df3cd2e 100644
--- a/tensorflow/contrib/lite/toco/export_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc
@@ -519,6 +519,18 @@ void ConvertAddOperator(const Model& model, const AddOperator& src_op,
(*add_op->mutable_attr())["T"].set_type(DT_FLOAT);
}
+void ConvertAddNOperator(const Model& model, const AddNOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* add_op = tensorflow_graph->add_node();
+ add_op->set_op("AddN");
+ add_op->set_name(src_op.outputs[0]);
+ for (const auto& input : src_op.inputs) {
+ *add_op->add_input() = input;
+ }
+ (*add_op->mutable_attr())["N"].set_i(src_op.inputs.size());
+ (*add_op->mutable_attr())["T"].set_type(DT_FLOAT);
+}
+
void ConvertMulOperator(const Model& model, const MulOperator& src_op,
GraphDef* tensorflow_graph) {
auto* add_op = tensorflow_graph->add_node();
@@ -1406,6 +1418,9 @@ void ConvertOperator(const Model& model, const Operator& src_op,
} else if (src_op.type == OperatorType::kAdd) {
ConvertAddOperator(model, static_cast<const AddOperator&>(src_op),
tensorflow_graph);
+ } else if (src_op.type == OperatorType::kAddN) {
+ ConvertAddNOperator(model, static_cast<const AddNOperator&>(src_op),
+ tensorflow_graph);
} else if (src_op.type == OperatorType::kMul) {
ConvertMulOperator(model, static_cast<const MulOperator&>(src_op),
tensorflow_graph);
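
To illustrate what ConvertAddNOperator above emits, here is a hedged sketch of the NodeDef it would build for a three-input AddN; the tensor names "x", "y", "z" and "sum" are invented for the example:

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/types.pb.h"

tensorflow::NodeDef MakeAddNNodeSketch() {
  tensorflow::NodeDef node;
  node.set_op("AddN");
  node.set_name("sum");  // toco names the exported node after its output array
  node.add_input("x");
  node.add_input("y");
  node.add_input("z");
  (*node.mutable_attr())["N"].set_i(3);                        // number of inputs
  (*node.mutable_attr())["T"].set_type(tensorflow::DT_FLOAT);  // element type
  return node;
}
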
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_addn_to_add.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_addn_to_add.cc
new file mode 100644
index 0000000000..dcaaddbf3b
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_addn_to_add.cc
@@ -0,0 +1,51 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+// This pass will convert an AddN operator with only 2 inputs into a regular Add
+// operator, to which more optimizations may apply.
+bool ConvertTrivialAddNToAdd::Run(Model* model, std::size_t op_index) {
+ auto addn_it = model->operators.begin() + op_index;
+ if (addn_it->get()->type != OperatorType::kAddN) {
+ return false;
+ }
+ AddNOperator* addn_op = static_cast<AddNOperator*>(addn_it->get());
+ CHECK_GE(addn_op->inputs.size(), 2);
+ CHECK_EQ(addn_op->outputs.size(), 1);
+
+ // We only reduce AddN with N=2 to a regular Add.
+ if (addn_op->inputs.size() != 2) {
+ return false;
+ }
+
+ // Copy inputs & outputs to regular Add.
+ auto* add_op = new AddOperator;
+ add_op->inputs.push_back(addn_op->inputs[0]);
+ add_op->inputs.push_back(addn_op->inputs[1]);
+ add_op->outputs = addn_op->outputs;
+
+ // Replace the AddN operator in the graph.
+ const auto add_it = model->operators.emplace(addn_it, add_op);
+ addn_it = add_it + 1;
+ CHECK_EQ(addn_it->get(), addn_op);
+ model->operators.erase(addn_it);
+ return true;
+}
+
+} // namespace toco
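
The new pass is easiest to see on a toy model. Below is a minimal, test-style sketch (array names are invented) of a two-input AddN being rewritten in place to a plain Add by ConvertTrivialAddNToAdd:

#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/contrib/lite/toco/model.h"

namespace toco {

void ConvertTrivialAddNExample() {
  Model model;
  auto* addn = new AddNOperator;
  addn->inputs = {"a", "b"};  // exactly two inputs, so the pass applies
  addn->outputs = {"sum"};
  model.operators.emplace_back(addn);

  ConvertTrivialAddNToAdd transform;
  const bool changed = transform.Run(&model, /*op_index=*/0);
  // changed is true, and model.operators[0] is now an AddOperator with the
  // same inputs {"a", "b"} and the same output {"sum"}.
  (void)changed;
}

}  // namespace toco

An AddN with three or more inputs is left untouched by this pass; only the N=2 case is reduced to Add.
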
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
index 4ac2265be9..e11bebcd4e 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
+++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
@@ -114,6 +114,7 @@ void RunGraphTransformations(Model* model, const string& message,
// List of all graph transformations
DECLARE_GRAPH_TRANSFORMATION(ConvertExpandDimsToReshape)
DECLARE_GRAPH_TRANSFORMATION(ConvertPureConvToDepthwise)
+DECLARE_GRAPH_TRANSFORMATION(ConvertTrivialAddNToAdd)
DECLARE_GRAPH_TRANSFORMATION(ConvertTrivialTransposeToReshape)
DECLARE_GRAPH_TRANSFORMATION(ConvertReorderAxes)
DECLARE_GRAPH_TRANSFORMATION(EnsureBiasVectors)
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index ff0a3bd881..4fb3b6ae7a 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -406,6 +406,28 @@ void ProcessSimpleBinaryOperator(Model* model, Operator* op) {
&output_array);
}
+void ProcessAddNOperator(Model* model, Operator* op) {
+ // Yield until all input dims have been resolved.
+ //
+ // TODO(myenik): Since AddN does not support broadcasting, maybe we could
+ // actually use this to improve shape propagation by propagating the shape of
+ // one input to all other inputs once it is resolved instead of just the
+ // output, since all inputs must be the same size and shape for a well-formed
+ // graph.
+ for (const auto& input : op->inputs) {
+ const auto& input_array = model->GetArray(input);
+ if (!input_array.has_shape()) {
+ return;
+ }
+ }
+
+ // AddN does not support broadcasting; all inputs must be the same shape, so
+ // we just take the first input shape and apply it to the output.
+ const auto& input0_array = model->GetArray(op->inputs[0]);
+ auto& output_array = model->GetArray(op->outputs[0]);
+ output_array.copy_shape(input0_array.shape());
+}
+
bool KeepDims(const Operator& op) {
switch (op.type) {
case OperatorType::kTensorFlowMin:
@@ -1282,6 +1304,9 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
case OperatorType::kTensorFlowGreaterEqual:
ProcessSimpleBinaryOperator(model, op);
break;
+ case OperatorType::kAddN:
+ ProcessAddNOperator(model, op);
+ break;
case OperatorType::kConv:
ProcessConvOperator(model, static_cast<ConvOperator*>(op));
break;
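
To make the shape rule concrete: once every input of an AddN has a resolved shape, ProcessAddNOperator copies the first input's shape onto the output. A hedged sketch on a hand-built model (array names and dimensions are invented, and this assumes the usual toco Array/Shape accessors):

#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/contrib/lite/toco/model.h"

namespace toco {

void PropagateAddNShapesExample() {
  Model model;
  auto* addn = new AddNOperator;
  addn->inputs = {"a", "b"};
  addn->outputs = {"sum"};
  model.operators.emplace_back(addn);

  // Both inputs get the same resolved shape; AddN does not broadcast.
  *model.GetOrCreateArray("a").mutable_shape()->mutable_dims() = {1, 8, 8, 3};
  *model.GetOrCreateArray("b").mutable_shape()->mutable_dims() = {1, 8, 8, 3};
  model.GetOrCreateArray("sum");  // output exists but has no shape yet

  PropagateFixedSizes transform;
  transform.Run(&model, /*op_index=*/0);
  // model.GetArray("sum").shape() now equals {1, 8, 8, 3}, copied from input 0.
}

}  // namespace toco
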
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index e8f318cd43..ca378af4c5 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -696,6 +696,19 @@ void ConvertAddOperator(const NodeDef& node,
model->operators.emplace_back(op);
}
+void ConvertAddNOperator(const NodeDef& node,
+ const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
+ CHECK_EQ(node.op(), "AddN");
+ const int num_inputs = GetInputsCount(node, tf_import_flags);
+ auto* op = new AddNOperator;
+ for (int i = 0; i < num_inputs; ++i) {
+ op->inputs.push_back(node.input(i));
+ }
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
void ConvertMulOperator(const NodeDef& node,
const TensorFlowImportFlags& tf_import_flags,
Model* model) {
@@ -1862,6 +1875,8 @@ std::unique_ptr<Model> ImportTensorFlowGraphDef(
ConvertSquareOperator(node, tf_import_flags, model);
} else if (node.op() == "Add") {
ConvertAddOperator(node, tf_import_flags, model);
+ } else if (node.op() == "AddN") {
+ ConvertAddNOperator(node, tf_import_flags, model);
} else if (node.op() == "Mul") {
ConvertMulOperator(node, tf_import_flags, model);
} else if (node.op() == "Sub") {
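
The import direction mirrors the export: an "AddN" NodeDef becomes one AddNOperator whose inputs are the node's inputs and whose single output is the node's name. ConvertAddNOperator above is internal to import_tensorflow.cc, so the sketch below merely restates that mapping with a hypothetical free function:

#include "tensorflow/contrib/lite/toco/model.h"
#include "tensorflow/core/framework/node_def.pb.h"

// Hypothetical helper mirroring the importer's handling of "AddN" nodes.
void ImportAddNSketch(const tensorflow::NodeDef& node, toco::Model* model) {
  auto* op = new toco::AddNOperator;
  for (const auto& input : node.input()) {
    op->inputs.push_back(input);
  }
  op->outputs.push_back(node.name());
  model->operators.emplace_back(op);
}
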
diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h
index b1b9b718bb..d1af371fd4 100644
--- a/tensorflow/contrib/lite/toco/model.h
+++ b/tensorflow/contrib/lite/toco/model.h
@@ -32,6 +32,7 @@ enum class OperatorType {
kNone,
// General-purpose neural network operators.
kAdd,
+ kAddN,
kAveragePool,
kBatchNormalization,
kConv,
@@ -559,6 +560,16 @@ struct AddOperator : Operator {
AddOperator() : Operator(OperatorType::kAdd) {}
};
+// Element-wise addition operator for N inputs.
+//
+// Inputs:
+// inputs[i]: The i-th array to add together to form the output.
+//
+// TensorFlow equivalent: AddN
+struct AddNOperator : Operator {
+ AddNOperator() : Operator(OperatorType::kAddN) {}
+};
+
// Concatenation operator: concatenates its inputs
// along the axis.
//
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
index d75d1fcc5b..298f49025f 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator.cc
@@ -807,6 +807,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
// These operators are supported by Toco, but not by TF Lite, and have no
// attributes.
+ ops.emplace_back(
+ new SimpleOperator<AddNOperator>("ADDN", OperatorType::kAddN));
ops.emplace_back(new SimpleOperator<NegOperator>("NEG", OperatorType::kNeg));
ops.emplace_back(new SimpleOperator<TensorFlowRsqrtOperator>(
"RSQRT", OperatorType::kTensorFlowRsqrt));
diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc
index f2753c84e9..720c33777d 100644
--- a/tensorflow/contrib/lite/toco/toco_tooling.cc
+++ b/tensorflow/contrib/lite/toco/toco_tooling.cc
@@ -52,6 +52,7 @@ void MakeGeneralGraphTransformationsSet(
GraphTransformationsSet* transformations) {
CHECK(transformations->empty());
transformations->Add(new ConvertExpandDimsToReshape);
+ transformations->Add(new ConvertTrivialAddNToAdd);
transformations->Add(new ConvertTrivialTransposeToReshape);
transformations->Add(new ConvertReorderAxes);
transformations->Add(new ResolveReshapeAttributes);
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc
index 8543ba4742..99a54a300b 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util.cc
@@ -197,6 +197,7 @@ const char* OperatorTypeName(OperatorType type) {
case OperatorType::k##c: \
return #c;
HANDLE_OPERATORTYPENAME_CASE(Add)
+ HANDLE_OPERATORTYPENAME_CASE(AddN)
HANDLE_OPERATORTYPENAME_CASE(AveragePool)
HANDLE_OPERATORTYPENAME_CASE(BatchNormalization)
HANDLE_OPERATORTYPENAME_CASE(Conv)
@@ -1396,6 +1397,16 @@ bool EstimateArithmeticOpsCount(const Model& model, int64* result) {
total += RequiredBufferSizeForShape(output_array.shape());
break;
}
+ case OperatorType::kAddN: {
+ const auto& output_array = model.GetArray(op->outputs[0]);
+ if (!output_array.has_shape()) {
+ return false;
+ }
+ // AddN cost is roughly that of N-1 Adds.
+ const int num_adds = op->inputs.size() - 1;
+ total += num_adds * RequiredBufferSizeForShape(output_array.shape());
+ break;
+ }
case OperatorType::kLogistic:
case OperatorType::kSoftmax:
case OperatorType::kTanh: {
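
The cost rule added for kAddN reduces to a simple product. As a hypothetical worked example (the shape is chosen arbitrarily): summing 4 tensors of shape [1, 224, 224, 3] counts as 3 elementwise Adds over 150,528 elements each, i.e. 451,584 ops. A tiny sketch of that arithmetic:

#include <cstdint>

// Mirrors the estimate above: an N-input AddN costs about (N - 1) Adds,
// each over one output buffer's worth of elements.
int64_t EstimateAddNOps(int num_inputs, int64_t output_elements) {
  return static_cast<int64_t>(num_inputs - 1) * output_elements;
}
// EstimateAddNOps(4, 1 * 224 * 224 * 3) == 3 * 150528 == 451584.
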