path: root/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
author     Andrew Selle <aselle@google.com>     2017-11-10 10:35:35 -0800
committer  Andrew Selle <aselle@andyselle.com>  2017-11-10 16:14:42 -0800
commit 0b15439f8f0f2d4755587f4096c3ea04cb199d23 (patch)
tree 9aa4fc8162bf9b4ee50112a7b85703f70ca4df08 /tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
parent 7ac140a5845553275427162aabd9d54987144b4a (diff)
Internal Change.
PiperOrigin-RevId: 175307445
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc')
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc | 142
1 file changed, 142 insertions(+), 0 deletions(-)
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
new file mode 100644
index 0000000000..1ff4e827aa
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
@@ -0,0 +1,142 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+ArrayDataType CommonDataTypeOfAllInputs(const Model& model,
+ const Operator& op) {
+ CHECK_GT(op.inputs.size(), 0);
+ const ArrayDataType data_type = model.GetArray(op.inputs[0]).data_type;
+ for (const auto& input : op.inputs) {
+ const auto& array = model.GetArray(input);
+ CHECK(array.data_type == data_type)
+ << " Unexpected: this operator has inputs with different data types.";
+ }
+ return data_type;
+}
+
+void SetDataTypeForAllOutputs(Model* model, Operator* op,
+ ArrayDataType data_type) {
+ for (const auto& output : op->outputs) {
+ model->arrays[output]->data_type = data_type;
+ }
+}
+} // namespace
+
+bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
+ auto it = model->operators.begin() + op_index;
+ auto* op = it->get();
+
+ // If the data type of some input is unknown, we need to yield.
+ for (const auto& input : op->inputs) {
+ if (model->arrays[input]->data_type == ArrayDataType::kNone) {
+ return false;
+ }
+ }
+ // Record data types of output before processing, so we can see at the
+ // end if we changed anything, and return the correct boolean value.
+ std::unordered_map<string, ArrayDataType> old_output_data_types;
+ for (const auto& output : op->outputs) {
+ old_output_data_types[output] = model->arrays[output]->data_type;
+ }
+ // Do the actual output data types propagation.
+ if (op->type == OperatorType::kDequantize ||
+ op->type == OperatorType::kResizeBilinear) {
+ // These operators unconditionally produce float outputs
+ SetDataTypeForAllOutputs(model, op, ArrayDataType::kFloat);
+ } else if (op->type == OperatorType::kTensorFlowLess ||
+ op->type == OperatorType::kTensorFlowLessEqual ||
+ op->type == OperatorType::kTensorFlowGreater ||
+ op->type == OperatorType::kTensorFlowGreaterEqual) {
+ // These operators unconditionally produce bool outputs
+ SetDataTypeForAllOutputs(model, op, ArrayDataType::kBool);
+ } else if (op->type == OperatorType::kTensorFlowShape) {
+ // These operators are assumed to produce int32 outputs.
+ SetDataTypeForAllOutputs(model, op, ArrayDataType::kInt32);
+ } else if (op->type == OperatorType::kAveragePool ||
+ op->type == OperatorType::kMaxPool ||
+ op->type == OperatorType::kL2Pool ||
+ op->type == OperatorType::kConv ||
+ op->type == OperatorType::kDepthwiseConv ||
+ op->type == OperatorType::kFullyConnected ||
+ op->type == OperatorType::kTensorFlowMax ||
+ op->type == OperatorType::kTensorFlowMin ||
+ op->type == OperatorType::kPad ||
+ op->type == OperatorType::kStridedSlice ||
+ op->type == OperatorType::kTensorFlowReshape ||
+ op->type == OperatorType::kSlice ||
+ op->type == OperatorType::kSqueeze ||
+ op->type == OperatorType::kTensorFlowSum ||
+ op->type == OperatorType::kTensorFlowSwitch ||
+ op->type == OperatorType::kTensorFlowTile ||
+ op->type == OperatorType::kTensorFlowAll ||
+ op->type == OperatorType::kReorderAxes ||
+ op->type == OperatorType::kTensorFlowConcatV2 ||
+ op->type == OperatorType::kFloor ||
+ op->type == OperatorType::kGather ||
+ op->type == OperatorType::kSpaceToBatchND ||
+ op->type == OperatorType::kBatchToSpaceND ||
+ op->type == OperatorType::kMean) {
+ // These operators produce outputs with the same type as their 1st input
+ CHECK_GT(op->inputs.size(), 0);
+ const ArrayDataType data_type = model->arrays[op->inputs[0]]->data_type;
+ SetDataTypeForAllOutputs(model, op, data_type);
+ } else if (op->type == OperatorType::kTensorFlowSplit ||
+ op->type == OperatorType::kTensorFlowConcat) {
+ // These operators produce an output with the same type as their 2nd input
+ CHECK_GT(op->inputs.size(), 1);
+ const ArrayDataType data_type = model->arrays[op->inputs[1]]->data_type;
+ SetDataTypeForAllOutputs(model, op, data_type);
+ } else if (op->type == OperatorType::kCast) {
+ // Data type of the Cast op is specified.
+ CHECK_EQ(op->outputs.size(), 1);
+ auto* cast_op = static_cast<CastOperator*>(op);
+ model->arrays[op->outputs[0]]->data_type = cast_op->dst_data_type;
+ } else if (op->type == OperatorType::kTensorFlowUnsupported) {
+ auto* unsupported_op = static_cast<TensorFlowUnsupportedOperator*>(op);
+ if (unsupported_op->output_data_types.size() != op->outputs.size()) {
+ return false;
+ }
+ for (int i = 0; i < unsupported_op->output_data_types.size(); ++i) {
+ auto output = op->outputs[i];
+ auto data_type = unsupported_op->output_data_types[i];
+ model->arrays[output]->data_type = data_type;
+ }
+ } else {
+ // These operators produce an output with the same type as any of their
+ // inputs, which must always have the same type.
+ const ArrayDataType data_type = CommonDataTypeOfAllInputs(*model, *op);
+ SetDataTypeForAllOutputs(model, op, data_type);
+ }
+ // Return true if any output data type changed, false if none changed.
+ for (const auto& output : op->outputs) {
+ if (old_output_data_types[output] != model->arrays[output]->data_type) {
+ return true;
+ }
+ }
+ return false;
+}
+
+} // namespace toco
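
For context on how a pass whose Run(Model*, op_index) returns "did anything change?" is typically driven, below is a minimal, self-contained sketch of a fixed-point driver over a simplified model. The Model, Operator, and PropagateDataTypes types in the sketch are hypothetical stand-ins introduced here for illustration, not TOCO's actual classes from model.h or graph_transformations.h; the sketch only mirrors the behavior visible in the diff above (yield on an unknown input type, set all outputs, report whether any output type changed).

#include <cassert>
#include <cstddef>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

enum class ArrayDataType { kNone, kFloat, kInt32, kBool };

// Hypothetical, simplified stand-ins for toco::Operator and toco::Model.
struct Operator {
  std::string type;                  // stands in for OperatorType
  std::vector<std::string> inputs;   // names of input arrays
  std::vector<std::string> outputs;  // names of output arrays
};

struct Model {
  std::unordered_map<std::string, ArrayDataType> arrays;
  std::vector<Operator> operators;
};

// Simplified analogue of the default case in PropagateArrayDataTypes::Run:
// copy the first input's data type to every output, yielding (returning
// false) while that input type is still unknown, and report whether any
// output type actually changed.
bool PropagateDataTypes(Model* model, std::size_t op_index) {
  Operator& op = model->operators[op_index];
  if (op.inputs.empty()) return false;
  const ArrayDataType data_type = model->arrays[op.inputs[0]];
  if (data_type == ArrayDataType::kNone) return false;
  bool changed = false;
  for (const std::string& output : op.outputs) {
    if (model->arrays[output] != data_type) {
      model->arrays[output] = data_type;
      changed = true;
    }
  }
  return changed;
}

int main() {
  // Chain a -> op0 -> b -> op1 -> c, where only `a` starts with a known type.
  Model model;
  model.arrays = {{"a", ArrayDataType::kFloat},
                  {"b", ArrayDataType::kNone},
                  {"c", ArrayDataType::kNone}};
  model.operators = {{"Relu", {"a"}, {"b"}}, {"Relu", {"b"}, {"c"}}};

  // Fixed-point driver: keep sweeping all operators until one full sweep
  // reports no change anywhere.
  bool any_changed = true;
  while (any_changed) {
    any_changed = false;
    for (std::size_t i = 0; i < model.operators.size(); ++i) {
      any_changed |= PropagateDataTypes(&model, i);
    }
  }

  assert(model.arrays["c"] == ArrayDataType::kFloat);
  std::cout << "c resolved to the same type as a\n";
  return 0;
}

The changed/unchanged return value is what lets an outer loop of this shape stop: once a full sweep reports no change, data types have stabilized even across chains of operators whose inputs started out as kNone.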