diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations')
53 files changed, 7478 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc new file mode 100644 index 0000000000..bf454c40c7 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc @@ -0,0 +1,98 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <memory> +#include <string> +#include <unordered_map> +#include <vector> + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool ConvertPureConvToDepthwise::Run(Model* model, std::size_t op_index) { + auto conv_it = model->operators.begin() + op_index; + if (conv_it->get()->type != OperatorType::kConv) { + return false; + } + const auto* conv_op = static_cast<ConvOperator*>(conv_it->get()); + if (conv_op->stride_width != conv_op->stride_height) { + return false; + } + auto& weights_array = model->GetArray(conv_op->inputs[1]); + if (!weights_array.buffer) { + // Yield until the weights are resolved as a constant array. 
+ return false; + } + if (weights_array.data_type != ArrayDataType::kFloat) { + return false; + } + if (weights_array.shape().dims(3) != 1) { + // Not a pure convolution: Conv does accumulation across the depth + // dimension. + return false; + } + // At this point we know we have a pure conv. Rewrite it as DepthwiseConv. + AddMessageF( + "%s is purely convolutional (input/weights depth is 1), replacing it by " + "a DepthwiseConv.", + LogName(*conv_op)); + auto* depthwiseconv_op = new DepthwiseConvOperator; + // Conv and DepthwiseConv take the same inputs + depthwiseconv_op->inputs = conv_op->inputs; + // Conv may have a 2nd output for im2col + depthwiseconv_op->outputs = {conv_op->outputs[0]}; + if (conv_op->outputs.size() > 1) { + // delete the im2col array. + model->arrays.erase(conv_op->outputs[1]); + } + depthwiseconv_op->fused_activation_function = + conv_op->fused_activation_function; + // Let PropagateFixedSizes recompute fixed padding, just in case some day it + // may be different for Conv vs DepthwiseConv. + depthwiseconv_op->padding.type = conv_op->padding.type; + depthwiseconv_op->stride_height = conv_op->stride_height; + depthwiseconv_op->stride_width = conv_op->stride_width; + depthwiseconv_op->depth_multiplier = weights_array.shape().dims(0); + // Replace the operator in the graph. + const auto depthwiseconv_it = + model->operators.emplace(conv_it, depthwiseconv_op); + conv_it = depthwiseconv_it + 1; + CHECK_EQ(conv_it->get(), conv_op); + model->operators.erase(conv_it); + // Shuffle the weights. 
+ const auto& weights_shape = weights_array.shape(); + auto& weights_buffer = + weights_array.GetMutableBuffer<ArrayDataType::kFloat>(); + const std::vector<float>& conv_weights_data = weights_buffer.data; + std::vector<float> depthwise_conv_weights_data(conv_weights_data.size()); + const int depth = weights_shape.dims(0); + const int width = weights_shape.dims(1); + const int height = weights_shape.dims(2); + const int width_height = width * height; + for (int c = 0; c < depth; c++) { + for (int xy = 0; xy < width_height; xy++) { + depthwise_conv_weights_data[c + depth * xy] = + conv_weights_data[xy + width_height * c]; + } + } + *weights_array.mutable_shape()->mutable_dims() = {1, width, height, depth}; + weights_buffer.data = depthwise_conv_weights_data; + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc b/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc new file mode 100644 index 0000000000..1735b51e5b --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc @@ -0,0 +1,69 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include <memory> +#include <string> +#include <unordered_map> +#include <vector> + +#include "absl/strings/str_cat.h" +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool CreateIm2colArrays::Run(Model* model, std::size_t op_index) { + auto conv_it = model->operators.begin() + op_index; + if (conv_it->get()->type != OperatorType::kConv) { + return false; + } + auto* conv_op = static_cast<ConvOperator*>(conv_it->get()); + if (conv_op->outputs.size() == 2) { + // We already have an im2col array + return false; + } + const auto& weights_array = *model->arrays[conv_op->inputs[1]]; + if (!weights_array.has_shape()) { + // We need to yield until weights dims have been resolved, because + // from the weights dims we determine whether an im2col array is + // needed. + return false; + } + const auto& weights_shape = weights_array.shape(); + const int kheight = weights_shape.dims(1); + const int kwidth = weights_shape.dims(2); + if (kwidth == 1 && kheight == 1 && conv_op->stride_width == 1 && + conv_op->stride_height == 1) { + // 1x1 unstrided conv does not need an im2col array. + return false; + } + + // Create the im2col array. 
+ CHECK_EQ(conv_op->outputs.size(), 1); + const string& im2col_array_name = + AvailableArrayName(*model, conv_op->inputs[0] + "_im2col"); + model->GetOrCreateArray(im2col_array_name); + conv_op->outputs.push_back(im2col_array_name); + AddMessageF( + "Created an im2col array for %s, with %dx%d kernel and stride_width=%d, " + "stride_height=%d", + LogName(*conv_op), kwidth, kheight, conv_op->stride_width, + conv_op->stride_height); + + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc new file mode 100644 index 0000000000..b89e3f5310 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc @@ -0,0 +1,223 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include <memory> +#include <string> +#include <unordered_map> +#include <vector> + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +template <ArrayDataType A> +void DequantizeBuffer(Array* array) { + const auto old_data = array->GetBuffer<A>().data; + array->buffer = nullptr; + array->data_type = ArrayDataType::kFloat; + auto& new_data = array->GetMutableBuffer<ArrayDataType::kFloat>().data; + new_data.resize(old_data.size()); + const auto& qparams = array->GetQuantizationParams(); + for (int i = 0; i < old_data.size(); i++) { + new_data[i] = qparams.scale * (old_data[i] - qparams.zero_point); + } +} + +std::vector<std::unique_ptr<Operator>>::iterator FindFirstOpWithInput( + Model* model, const string& array_name) { + for (auto it = model->operators.begin(); it != model->operators.end(); ++it) { + for (const auto& input : it->get()->inputs) { + if (input == array_name) { + return it; + } + } + } + return model->operators.end(); +} + +void ClearArrayQuantizationParams(const string& array_name, Model* model) { + auto* array = model->arrays.at(array_name).get(); + CHECK(array->quantization_params); + for (auto& input_array : *model->flags.mutable_input_arrays()) { + if (input_array.name() == array_name) { + auto& qparams = *array->quantization_params; + const double new_std_value = 1. 
/ qparams.scale; + const double new_mean_value = qparams.zero_point; + if (input_array.has_std_value()) { + CHECK_LE(std::abs(new_std_value - input_array.std_value()), 0.001); + } else { + input_array.set_std_value(new_std_value); + } + if (input_array.has_mean_value()) { + CHECK_LE(std::abs(new_mean_value - input_array.mean_value()), 0.001); + } else { + input_array.set_mean_value(new_mean_value); + } + } + } + array->quantization_params = nullptr; +} + +bool DequantizeArray(const string& array_name, + GraphTransformation* transformation, Model* model) { + auto* array = model->arrays.at(array_name).get(); + if (!array->quantization_params) { + return false; + } + transformation->AddMessageF("Dequantizing array: %s", array_name); + + // Dequantize any buffer + if (array->buffer) { + if (array->data_type == ArrayDataType::kUint8) { + DequantizeBuffer<ArrayDataType::kUint8>(array); + } else if (array->data_type == ArrayDataType::kInt32) { + DequantizeBuffer<ArrayDataType::kInt32>(array); + } else { + LOG(FATAL) << "Unhandled data type"; + } + CHECK(array->data_type == ArrayDataType::kFloat); + CHECK(array->buffer->type == ArrayDataType::kFloat); + + // Clear quantization params, officially makes this a non-quantized array. + ClearArrayQuantizationParams(array_name, model); + return true; + } else { + array->data_type = ArrayDataType::kFloat; + } + + // Clear quantization params, officially makes this a non-quantized array. + ClearArrayQuantizationParams(array_name, model); + + if (array->buffer) { + return true; + } + + auto* op_outputting_array = GetOpWithOutput(*model, array_name); + if (op_outputting_array) { + if (op_outputting_array->type == OperatorType::kTensorFlowReshape) { + return true; + } + } + + // If there was no minmax info, we can return now. 
Indeed, + // the below only serves to create a FakeQuant node, but some arrays are + // quantized without MinMax (see the CHECK above) and that corresponds to + // places where a FakeQuant node is actually not wanted, because the + // quantization params are meant to be inferred in another way (e.g. bias + // vector for a Conv op, see their special-casing in quantize.cc). + if (!array->minmax) { + return true; + } + + // Determine whether to insert a FakeQuant before or after + // this array. + bool must_insert_fakequant_before = false; + bool must_insert_fakequant_after = false; + if (IsInputArray(*model, array_name)) { + must_insert_fakequant_after = true; + } + for (const string& output_array : model->flags.output_arrays()) { + if (array_name == output_array) { + must_insert_fakequant_before = true; + } + } + for (const auto& rnn_state : model->flags.rnn_states()) { + if (array_name == rnn_state.state_array()) { + must_insert_fakequant_after = true; + } + if (array_name == rnn_state.back_edge_source_array()) { + must_insert_fakequant_before = true; + } + } + CHECK(!(must_insert_fakequant_before && must_insert_fakequant_after)); + + // Create and insert the FakeQuant node + auto* fakequant_op = new FakeQuantOperator; + model->operators.emplace(FindFirstOpWithInput(model, array_name), + fakequant_op); + const string& new_array_name = AvailableArrayName(*model, array_name); + auto& new_array = model->GetOrCreateArray(new_array_name); + new_array.data_type = ArrayDataType::kFloat; + new_array.copy_shape(array->shape()); + new_array.GetOrCreateMinMax() = array->GetMinMax(); + fakequant_op->minmax.reset(new MinMax); + *fakequant_op->minmax = array->GetMinMax(); + if (must_insert_fakequant_before) { + for (const auto& op : model->operators) { + for (string& output : op->outputs) { + if (output == array_name) { + output = new_array_name; + } + } + } + fakequant_op->inputs = {new_array_name}; + fakequant_op->outputs = {array_name}; + } else { + for (const auto& op : 
model->operators) { + for (string& input : op->inputs) { + if (input == array_name) { + input = new_array_name; + } + } + } + fakequant_op->inputs = {array_name}; + fakequant_op->outputs = {new_array_name}; + } + return true; +} + +} // namespace + +bool Dequantize::Run(Model* model, std::size_t op_index) { + const auto op_it = model->operators.begin() + op_index; + auto* op = op_it->get(); + + if (op->type == OperatorType::kDequantize) { + auto& input_array = model->GetArray(op->inputs[0]); + if (input_array.data_type == ArrayDataType::kFloat) { + return false; + } + if (input_array.final_data_type != ArrayDataType::kFloat) { + return false; + } + input_array.data_type = ArrayDataType::kFloat; + input_array.quantization_params = nullptr; + auto& output_array = model->GetArray(op->outputs[0]); + output_array.data_type = ArrayDataType::kFloat; + output_array.quantization_params = nullptr; + return RemoveTrivialPassthroughOp(this, model, op_index); + } + + std::vector<string> arrays; + for (const string& input : op->inputs) { + arrays.push_back(input); + } + for (const string& output : op->outputs) { + arrays.push_back(output); + } + bool changed = false; + for (const string& array : arrays) { + changed |= DequantizeArray(array, this, model); + } + + return changed; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/drop_fake_quant.cc b/tensorflow/contrib/lite/toco/graph_transformations/drop_fake_quant.cc new file mode 100644 index 0000000000..fea360740f --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/drop_fake_quant.cc @@ -0,0 +1,56 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <memory> +#include <string> +#include <unordered_map> +#include <vector> + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool DropFakeQuant::Run(Model* model, std::size_t op_index) { + const auto fakequant_it = model->operators.begin() + op_index; + auto* fakequant_base_op = fakequant_it->get(); + if (fakequant_base_op->type != OperatorType::kFakeQuant) { + return false; + } + auto* fakequant_op = static_cast<FakeQuantOperator*>(fakequant_base_op); + + if (!fakequant_op->minmax) { + return false; + } + + const auto& output_array = model->GetArray(fakequant_op->outputs[0]); + if (!output_array.minmax) { + return false; + } + + // Drop min/max inputs + for (int i = 1; i < fakequant_op->inputs.size(); i++) { + if (CountOpsWithInput(*model, fakequant_op->inputs[i]) == 1) { + model->arrays.erase(fakequant_op->inputs[i]); + } + } + fakequant_op->inputs.resize(1); + + return RemoveTrivialPassthroughOp(this, model, op_index); +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/drop_im2col_arrays.cc b/tensorflow/contrib/lite/toco/graph_transformations/drop_im2col_arrays.cc new file mode 100644 index 0000000000..a3ed6663bc --- /dev/null +++ 
b/tensorflow/contrib/lite/toco/graph_transformations/drop_im2col_arrays.cc @@ -0,0 +1,42 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool DropIm2colArrays::Run(Model* model, std::size_t op_index) { + auto conv_it = model->operators.begin() + op_index; + if (conv_it->get()->type != OperatorType::kConv) { + return false; + } + auto* conv_op = static_cast<ConvOperator*>(conv_it->get()); + if (conv_op->outputs.size() < 2) { + // Conv op does not have im2col. + return false; + } + + // Drop the im2col array. 
+ CHECK_EQ(conv_op->outputs.size(), 2); + model->arrays.erase(conv_op->outputs[1]); + conv_op->outputs.resize(1); + AddMessageF("Dropped an im2col array for %s", LogName(*conv_op)); + + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/ensure_bias_vectors.cc b/tensorflow/contrib/lite/toco/graph_transformations/ensure_bias_vectors.cc new file mode 100644 index 0000000000..badefeca88 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/ensure_bias_vectors.cc @@ -0,0 +1,57 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include <memory> +#include <string> +#include <unordered_map> +#include <vector> + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +bool ProcessLinearOperator(Model* model, Operator* op) { + if (op->inputs.size() >= 3) { + return false; + } + const string& output_name = op->outputs[0]; + const string& bias_name = AvailableArrayName(*model, output_name + "_bias"); + op->inputs.push_back(bias_name); + DCHECK_EQ(op->inputs.size(), 3); + auto& bias_array = model->GetOrCreateArray(bias_name); + bias_array.data_type = ArrayDataType::kFloat; + + return true; +} +} // namespace + +bool EnsureBiasVectors::Run(Model* model, std::size_t op_index) { + auto* op = model->operators[op_index].get(); + if (op->type == OperatorType::kConv || + op->type == OperatorType::kDepthwiseConv || + op->type == OperatorType::kFullyConnected) { + if (ProcessLinearOperator(model, op)) { + AddMessageF("Added bias vector to %s", LogName(*op)); + return true; + } + } + return false; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc b/tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc new file mode 100644 index 0000000000..7a86510025 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc @@ -0,0 +1,98 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <memory> +#include <string> +#include <unordered_map> +#include <vector> + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/runtime/types.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool FuseActivationFunctions::Run(Model* model, std::size_t op_index) { + const auto ac_it = model->operators.begin() + op_index; + const auto* ac_op = ac_it->get(); + + if (ac_op->type != OperatorType::kRelu6 && + ac_op->type != OperatorType::kRelu1 && + ac_op->type != OperatorType::kRelu) { + return false; + } + + // Find the op producing the array passed to this activation function + Operator* op = GetOpWithOutput(*model, ac_op->inputs[0]); + + if (!op) return false; + + if (CountTrueOutputs(*model, *op) > 1) { + AddMessageF( + "Not fusing activation function into %s because it has more than one " + " consumed output", + LogName(*op)); + return false; + } + + CHECK_EQ(op->outputs[0], ac_op->inputs[0]); + + int count_ops_consuming_output = CountOpsWithInput(*model, ac_op->inputs[0]); + DCHECK_GE(count_ops_consuming_output, 1); + if (count_ops_consuming_output > 1) { + AddMessageF( + "Not fusing activation function into %s because it is consumed by more " + "than 1 other operator", + LogName(*op)); + return false; + } + + if (op->fused_activation_function != FusedActivationFunctionType::kNone) 
{ +    AddMessageF( +        "Not fusing activation function into %s because it already has a fused " +        "activation function", +        LogName(*op)); +    return false; +  } + +  // TODO(dkalenichenko): Great many ops don't support activation function +  // fusing. Switch to the whitelist approach instead. +  if (op->type == OperatorType::kConcatenation || +      op->type == OperatorType::kSlice) { +    AddMessageF( +        "Not fusing activation function because the %s op doesn't support it", +        LogName(*op)); +    return false; +  } + +  AddMessageF("Fusing activation function %s into the preceding %s", +              LogName(*ac_op), LogName(*op)); +  if (ac_op->type == OperatorType::kRelu6) { +    op->fused_activation_function = FusedActivationFunctionType::kRelu6; +  } else if (ac_op->type == OperatorType::kRelu1) { +    op->fused_activation_function = FusedActivationFunctionType::kRelu1; +  } else if (ac_op->type == OperatorType::kRelu) { +    op->fused_activation_function = FusedActivationFunctionType::kRelu; +  } else { +    LOG(FATAL) << "Unhandled activation function type"; +  } +  model->arrays.erase(ac_op->inputs[0]); +  op->outputs[0] = ac_op->outputs[0]; +  model->operators.erase(ac_it); +  return true; +} + +}  // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc new file mode 100644 index 0000000000..4619d8bbee --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc @@ -0,0 +1,300 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <algorithm> +#include <memory> +#include <string> +#include <unordered_map> +#include <vector> + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/runtime/types.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +void FuseAddOrSubParamsIntoFollowingAffine(Model* model, Operator* following_op, + const Operator* add_or_sub_op, + int index_of_constant_input) { + CHECK(add_or_sub_op->type == OperatorType::kAdd || + add_or_sub_op->type == OperatorType::kSub); + CHECK(index_of_constant_input == 0 || index_of_constant_input == 1); + // If the op is a subtraction, the constant input should be the right hand + // side. + // This should have been checked before this point. + CHECK(add_or_sub_op->type != OperatorType::kSub || + index_of_constant_input == 1); + if (following_op->inputs.size() < 3) { + LOG(FATAL) << "Missing bias parameter"; + } + const auto& weights = model->GetArray(following_op->inputs[1]); + auto& bias = model->GetArray(following_op->inputs[2]); + bias.minmax = nullptr; + const auto& operand = + model->GetArray(add_or_sub_op->inputs[index_of_constant_input]); + // We're only supporting the case of a scalar operand. Should have + // been checked earlier. 
+ CHECK_EQ(RequiredBufferSizeForShape(operand.shape()), 1); + + const float scalar_operand = + operand.GetBuffer<ArrayDataType::kFloat>().data[0]; + // At this point we reduce the case of subtraction to that of addition + // by negating the operand. + float add_scalar_operand = 0.f; + if (add_or_sub_op->type == OperatorType::kAdd) { + add_scalar_operand = scalar_operand; + } else if (add_or_sub_op->type == OperatorType::kSub && + index_of_constant_input == 1) { + add_scalar_operand = -scalar_operand; + } else { + LOG(FATAL) << "Should not get here"; + } + // From here on we are fusing an addition. add_or_sub_op->type does not + // matter anymore. + + const Shape& weights_shape = weights.shape(); + const Shape& bias_shape = bias.shape(); + const auto& weights_buffer = weights.GetBuffer<ArrayDataType::kFloat>(); + const float* const weights_data = weights_buffer.data.data(); + auto& bias_buffer = bias.GetMutableBuffer<ArrayDataType::kFloat>(); + float* const bias_data = bias_buffer.data.data(); + + if (following_op->type == OperatorType::kConv || + following_op->type == OperatorType::kFullyConnected) { + const int output_depth = weights_shape.dims(0); + // TODO(b/62904716): Bias array should become 1-D when padding removed. 
+ CHECK_EQ(output_depth, bias_shape.dims(bias_shape.dimensions_count() - 1)); + const int weights_size = RequiredBufferSizeForShape(weights_shape); + const int weights_per_depth = weights_size / output_depth; + CHECK_EQ(weights_size, weights_per_depth * output_depth); + + for (int d = 0; d < output_depth; d++) { + float accumulation = 0; + for (int i = 0; i < weights_per_depth; i++) { + accumulation += + add_scalar_operand * weights_data[d * weights_per_depth + i]; + } + bias_data[d] += accumulation; + } + } else if (following_op->type == OperatorType::kDepthwiseConv) { + const int output_depth = + weights_shape.dims(weights_shape.dimensions_count() - 1); + const int weights_size = RequiredBufferSizeForShape(weights_shape); + const int weights_per_depth = weights_size / output_depth; + CHECK_EQ(weights_size, weights_per_depth * output_depth); + + for (int c = 0; c < output_depth; c++) { + float accumulation = 0; + for (int k = 0; k < weights_per_depth; k++) { + accumulation += add_scalar_operand * weights_data[k * output_depth + c]; + } + bias_data[c] += accumulation; + } + } else { + LOG(FATAL) << "Should not get here."; + } +} + +void FuseMulOrDivParamsIntoFollowingAffine(Model* model, Operator* following_op, + const Operator* mul_or_div_op, + int index_of_constant_input) { + CHECK(mul_or_div_op->type == OperatorType::kMul || + mul_or_div_op->type == OperatorType::kDiv); + CHECK(index_of_constant_input == 0 || index_of_constant_input == 1); + // If the op is a division, the constant input should be the right hand side. + // This should have been checked before this point. 
+ CHECK(mul_or_div_op->type != OperatorType::kDiv || + index_of_constant_input == 1); + const auto& weights_name = following_op->inputs[1]; + const auto& bias_name = following_op->inputs[2]; + auto& weights = model->GetArray(weights_name); + DropMinMax(model, weights_name); + DropMinMax(model, bias_name); + const auto& operand = + model->GetArray(mul_or_div_op->inputs[index_of_constant_input]); + // We're only supporting the case of a scalar operand. Should have + // been checked earlier. + CHECK_EQ(RequiredBufferSizeForShape(operand.shape()), 1); + + const float scalar_operand = + operand.GetBuffer<ArrayDataType::kFloat>().data[0]; + + float* weights_data = + weights.GetMutableBuffer<ArrayDataType::kFloat>().data.data(); + const int weights_size = RequiredBufferSizeForShape(weights.shape()); + for (int i = 0; i < weights_size; i++) { + if (mul_or_div_op->type == OperatorType::kMul) { + weights_data[i] *= scalar_operand; + } else if (mul_or_div_op->type == OperatorType::kDiv) { + weights_data[i] /= scalar_operand; + } else { + LOG(FATAL) << "Should not get here"; + } + } +} + +} // namespace + +bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) { + const auto binary_it = model->operators.begin() + op_index; + auto* binary_op = binary_it->get(); + if (binary_op->type != OperatorType::kAdd && + binary_op->type != OperatorType::kMul && + binary_op->type != OperatorType::kSub && + binary_op->type != OperatorType::kDiv) { + return false; + } + + CHECK_EQ(binary_op->inputs.size(), 2); + + // We can only fuse a binary when the two operands break down as follows: + // 1. One operand is the (variable) output of a typical affine (linear plus + // bias) + // op of a finite list of possible types: at the moment Conv, + // DepthwiseConv and + // FullyConnected are supported. + // 2. The other operand is a constant param array. 
+ const bool is_input_constant[2] = { + IsConstantParameterArray(*model, binary_op->inputs[0]), + IsConstantParameterArray(*model, binary_op->inputs[1]), + }; + if (!is_input_constant[0] && !is_input_constant[1]) { + // Neither input is constant, so nothing we can fuse into a constant. + return false; + } + if (is_input_constant[0] && is_input_constant[1]) { + // Both inputs are constants. That's a job for constants + // propagation, not for us to handle here. + return false; + } + const int index_of_constant_input = is_input_constant[0] ? 0 : 1; + const int index_of_variable_input = is_input_constant[0] ? 1 : 0; + CHECK(is_input_constant[index_of_constant_input]); + CHECK(!is_input_constant[index_of_variable_input]); + + // For division, we can only fuse if the denominator is constant. + if (binary_op->type == OperatorType::kDiv) { + if (index_of_constant_input != 1) { + AddMessageF("Not fusing %s because the denominator is not constant", + LogName(*binary_op)); + return false; + } + } + + const auto& operand_shape = + model->GetArray(binary_op->inputs[index_of_constant_input]).shape(); + for (const auto& dim : operand_shape.dims()) { + if (dim > 1) { + AddMessageF( + "Not fusing %s into the following affine op, because we only know " + "how to do so when the constant operand is a scalar", + LogName(*binary_op)); + return false; + } + } + + if (binary_op->fused_activation_function != + FusedActivationFunctionType::kNone) { + AddMessageF("Not fusing %s because it has a fused activation function", + LogName(*binary_op)); + return false; + } + + Operator* following_op = GetOpWithInput(*model, binary_op->outputs[0]); + + if (!following_op) { + AddMessageF( + "Not fusing %s because it is not consumed by exactly one other op", + LogName(*binary_op)); + return false; + } + + if (following_op->type != OperatorType::kConv && + following_op->type != OperatorType::kFullyConnected && + following_op->type != OperatorType::kDepthwiseConv) { + AddMessageF( + "Not fusing %s 
because the following %s is not of one of the supported " + "types", + LogName(*binary_op), LogName(*following_op)); + return false; + } + + if (following_op->inputs.size() < 3) { + AddMessageF( + "Not fusing %s because the following %s does not have a bias vector", + LogName(*following_op), LogName(*binary_op)); + return false; + } + + const auto& weights = model->GetArray(following_op->inputs[1]); + const auto& bias = model->GetArray(following_op->inputs[2]); + if (!weights.buffer || !bias.buffer) { + AddMessageF( + "Not fusing %s because the following %s has non-constant weights or " + "bias arrays", + LogName(*binary_op), LogName(*following_op)); + return false; + } + + // Try to fuse the binary params into the following op's params + if (binary_op->type == OperatorType::kAdd || + binary_op->type == OperatorType::kSub) { + if (following_op->type == OperatorType::kConv) { + if (static_cast<ConvOperator*>(following_op)->padding.type != + PaddingType::kValid) { + AddMessageF( + "Not fusing %s because the following %s does not use VALID padding", + LogName(*binary_op), LogName(*following_op)); + return false; + } + } + if (following_op->type == OperatorType::kDepthwiseConv) { + if (static_cast<DepthwiseConvOperator*>(following_op)->padding.type != + PaddingType::kValid) { + AddMessageF( + "Not fusing %s because the following %s does not use VALID padding", + LogName(*binary_op), LogName(*following_op)); + return false; + } + } + FuseAddOrSubParamsIntoFollowingAffine(model, following_op, binary_op, + index_of_constant_input); + } else if (binary_op->type == OperatorType::kMul || + binary_op->type == OperatorType::kDiv) { + FuseMulOrDivParamsIntoFollowingAffine(model, following_op, binary_op, + index_of_constant_input); + } else { + LOG(FATAL) << "should not get here"; + } + + AddMessageF("Fusing %s into the following %s", LogName(*binary_op), + LogName(*following_op)); + + model->arrays.erase(binary_op->outputs[0]); + following_op->inputs[0] = 
binary_op->inputs[index_of_variable_input]; + const auto& old_constant_param_name = + binary_op->inputs[index_of_constant_input]; + CHECK(IsConstantParameterArray(*model, old_constant_param_name)); + if (CountOpsWithInput(*model, old_constant_param_name) == 1) { + model->arrays.erase(old_constant_param_name); + } + model->operators.erase(binary_it); + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc new file mode 100644 index 0000000000..8948653ec3 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc @@ -0,0 +1,326 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
==============================================================================*/
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/contrib/lite/toco/model.h"
#include "tensorflow/contrib/lite/toco/runtime/types.h"
#include "tensorflow/contrib/lite/toco/tooling_util.h"
#include "tensorflow/core/platform/logging.h"

namespace toco {

namespace {

// Folds a constant Add/Sub operand into the bias of the preceding affine op:
// (x <affine> w + b) +/- c == x <affine> w + (b +/- c). For Sub with the
// constant on the left, the bias becomes (c - b) instead.
void FuseAddOrSubParamsIntoPrecedingAffine(Model* model, Operator* preceding_op,
                                           const Operator* add_or_sub_op,
                                           int index_of_constant_input) {
  CHECK(add_or_sub_op->type == OperatorType::kAdd ||
        add_or_sub_op->type == OperatorType::kSub);
  CHECK(index_of_constant_input == 0 || index_of_constant_input == 1);
  if (preceding_op->inputs.size() < 3) {
    LOG(FATAL) << "Missing bias parameter";
  }
  auto& bias = model->GetArray(preceding_op->inputs[2]);
  // The bias values are about to change, so any recorded min/max is stale.
  bias.minmax = nullptr;
  const auto& operand =
      model->GetArray(add_or_sub_op->inputs[index_of_constant_input]);

  const Shape& bias_shape = bias.shape();
  const Shape& operand_shape = operand.shape();
  auto& bias_buffer = bias.GetMutableBuffer<ArrayDataType::kFloat>();
  float* const bias_data = bias_buffer.data.data();
  const auto& operand_buffer = operand.GetBuffer<ArrayDataType::kFloat>();
  const float* const operand_data = operand_buffer.data.data();

  // TODO(b/62904716): Bias array should become 1-D when padding removed.
  // The operand must match the bias along the (innermost) depth dimension.
  const int depth = bias_shape.dims(bias_shape.dimensions_count() - 1);
  CHECK_EQ(depth, operand_shape.dims(operand_shape.dimensions_count() - 1));

  enum class OpType { BiasPlusOperand, BiasMinusOperand, OperandMinusBias };

  // For Sub, which side the constant sits on decides the sign convention.
  const OpType optype = (add_or_sub_op->type == OperatorType::kAdd)
                            ? OpType::BiasPlusOperand
                            : (index_of_constant_input == 1)
                                  ? OpType::BiasMinusOperand
                                  : OpType::OperandMinusBias;

  for (int i = 0; i < depth; i++) {
    float& bias_val = bias_data[i];
    const float operand_val = operand_data[i];
    if (optype == OpType::BiasPlusOperand) {
      bias_val += operand_val;
    } else if (optype == OpType::BiasMinusOperand) {
      bias_val -= operand_val;
    } else if (optype == OpType::OperandMinusBias) {
      bias_val = operand_val - bias_val;
    } else {
      LOG(FATAL) << "Should not get here.";
    }
  }
}

// Folds a constant Mul/Div operand into the weights and bias of the preceding
// affine op, per output channel. The operand may be a scalar (broadcast over
// all channels) or a vector matching the output depth.
void FuseMulOrDivParamsIntoPrecedingAffine(Model* model, Operator* preceding_op,
                                           const Operator* mul_or_div_op,
                                           int index_of_constant_input) {
  CHECK(mul_or_div_op->type == OperatorType::kMul ||
        mul_or_div_op->type == OperatorType::kDiv);
  CHECK(index_of_constant_input == 0 || index_of_constant_input == 1);
  // If the op is a division, the constant input should be the right hand side.
  // This should have been checked before this point.
  CHECK(mul_or_div_op->type != OperatorType::kDiv ||
        index_of_constant_input == 1);
  if (preceding_op->inputs.size() < 3) {
    LOG(FATAL) << "Missing bias parameter";
  }
  const auto& weights_name = preceding_op->inputs[1];
  const auto& bias_name = preceding_op->inputs[2];
  auto& weights = model->GetArray(weights_name);
  // Both arrays are about to be rescaled; their recorded min/max are stale.
  DropMinMax(model, weights_name);
  auto& bias = model->GetArray(bias_name);
  DropMinMax(model, bias_name);
  const auto& operand =
      model->GetArray(mul_or_div_op->inputs[index_of_constant_input]);

  const Shape& weights_shape = weights.shape();
  const Shape& bias_shape = bias.shape();
  const Shape& operand_shape = operand.shape();
  auto& weights_buffer = weights.GetMutableBuffer<ArrayDataType::kFloat>();
  float* const weights_data = weights_buffer.data.data();
  auto& bias_buffer = bias.GetMutableBuffer<ArrayDataType::kFloat>();
  float* const bias_data = bias_buffer.data.data();
  const auto& operand_buffer = operand.GetBuffer<ArrayDataType::kFloat>();
  const float* const operand_data = operand_buffer.data.data();

  // We support broadcasting the operand along the depth dimension,
  // when the operand's depth is 1. An increment of 0 re-reads the single
  // scalar for every channel; 1 walks the per-channel vector.
  int operand_channel_increment = 0;
  if (operand_shape.dimensions_count() >= 1 &&
      operand_shape.dims(operand_shape.dimensions_count() - 1) ==
          bias_shape.dims(bias_shape.dimensions_count() - 1)) {
    operand_channel_increment = 1;
  } else if (operand_shape.dimensions_count() == 0 ||
             operand_shape.dims(operand_shape.dimensions_count() - 1) == 1) {
    operand_channel_increment = 0;
  } else {
    LOG(FATAL) << "Operand shape mismatch.";
  }

  int output_depth;

  // Conv/FullyConnected store one output channel per outermost weights
  // dimension; DepthwiseConv keeps the output channel innermost.
  if (preceding_op->type == OperatorType::kConv ||
      preceding_op->type == OperatorType::kFullyConnected) {
    output_depth = weights_shape.dims(0);
  } else if (preceding_op->type == OperatorType::kDepthwiseConv) {
    output_depth = weights_shape.dims(weights_shape.dimensions_count() - 1);
  } else {
    LOG(FATAL) << "Should not get here";
  }

  const int weights_size = RequiredBufferSizeForShape(weights_shape);
  const int weights_per_depth = weights_size / output_depth;
  CHECK_EQ(weights_size, weights_per_depth * output_depth);

  int operand_channel = 0;
  for (int c = 0; c < output_depth; c++) {
    // The bias is scaled along with the weights so that the op's whole
    // affine output is multiplied/divided by the operand.
    if (mul_or_div_op->type == OperatorType::kMul) {
      bias_data[c] *= operand_data[operand_channel];
    } else if (mul_or_div_op->type == OperatorType::kDiv) {
      bias_data[c] /= operand_data[operand_channel];
    } else {
      LOG(FATAL) << "Should not get here";
    }
    if (preceding_op->type == OperatorType::kConv ||
        preceding_op->type == OperatorType::kFullyConnected) {
      for (int i = 0; i < weights_per_depth; i++) {
        if (mul_or_div_op->type == OperatorType::kMul) {
          weights_data[c * weights_per_depth + i] *=
              operand_data[operand_channel];
        } else if (mul_or_div_op->type == OperatorType::kDiv) {
          weights_data[c * weights_per_depth + i] /=
              operand_data[operand_channel];
        } else {
          LOG(FATAL) << "Should not get here";
        }
      }
    } else if (preceding_op->type == OperatorType::kDepthwiseConv) {
      for (int k = 0; k < weights_per_depth; k++) {
        if (mul_or_div_op->type == OperatorType::kMul) {
          weights_data[k * output_depth + c] *= operand_data[operand_channel];
        } else if (mul_or_div_op->type == OperatorType::kDiv) {
          weights_data[k * output_depth + c] /= operand_data[operand_channel];
        } else {
          LOG(FATAL) << "Should not get here";
        }
      }
    } else {
      LOG(FATAL) << "Should not get here";
    }
    operand_channel += operand_channel_increment;
  }
}
}  // namespace

// Fuses a binary (Add/Sub/Mul/Div) op with one constant operand into the
// affine op (Conv/DepthwiseConv/FullyConnected) that produces its variable
// input, by folding the constant into that op's weights and/or bias.
bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) {
  const auto binary_it = model->operators.begin() + op_index;
  const auto* binary_op = binary_it->get();
  if (binary_op->type != OperatorType::kAdd &&
      binary_op->type != OperatorType::kMul &&
      binary_op->type != OperatorType::kSub &&
      binary_op->type != OperatorType::kDiv) {
    return false;
  }

  CHECK_EQ(binary_op->inputs.size(), 2);

  // We only can fuse an binary when the two operands break down as follows:
  // 1. One operand is the (variable) output of a typical affine (linear plus
  //    bias) op of a finite list of possible types: at the moment Conv,
  //    DepthwiseConv and FullyConnected are supported.
  // 2. The other operand is a constant param array.
  const bool is_input_constant[2] = {
      IsConstantParameterArray(*model, binary_op->inputs[0]),
      IsConstantParameterArray(*model, binary_op->inputs[1]),
  };
  if (!is_input_constant[0] && !is_input_constant[1]) {
    // Neither input is constant, so nothing we can fuse into a constant.
    return false;
  }
  if (is_input_constant[0] && is_input_constant[1]) {
    // Both inputs are constants. That's a job for constants
    // propagation, not for us to handle here.
    return false;
  }
  const int index_of_constant_input = is_input_constant[0] ? 0 : 1;
  const int index_of_variable_input = is_input_constant[0] ? 1 : 0;
  CHECK(is_input_constant[index_of_constant_input]);
  CHECK(!is_input_constant[index_of_variable_input]);

  // For division, we can only fuse if the denominator is constant.
  if (binary_op->type == OperatorType::kDiv) {
    if (index_of_constant_input != 1) {
      AddMessageF("Not fusing %s because the denominator is not constant",
                  LogName(*binary_op));
      return false;
    }
  }

  Operator* preceding_op =
      GetOpWithOutput(*model, binary_op->inputs[index_of_variable_input]);
  if (!preceding_op) {
    AddMessageF("Not fusing %s because it is not the output of another op",
                LogName(*binary_op));
    return false;
  }

  // Do not fuse if the affine op's output is also a model output array:
  // fusing would remove that array from the model.
  for (const string& output_array : model->flags.output_arrays()) {
    if (preceding_op->outputs[0] == output_array) {
      return false;
    }
  }

  if (preceding_op->type != OperatorType::kConv &&
      preceding_op->type != OperatorType::kFullyConnected &&
      preceding_op->type != OperatorType::kDepthwiseConv) {
    AddMessageF(
        "Not fusing %s because the preceding %s is not of one of the supported "
        "types",
        LogName(*binary_op), LogName(*preceding_op));
    return false;
  }

  if (preceding_op->fused_activation_function !=
      FusedActivationFunctionType::kNone) {
    AddMessageF(
        "Not fusing %s because the preceding %s has a fused activation "
        "function",
        LogName(*binary_op), LogName(*preceding_op));
    return false;
  }

  if (preceding_op->inputs.size() < 3) {
    AddMessageF(
        "Not fusing %s because the preceding %s does not have a bias vector",
        LogName(*binary_op), LogName(*preceding_op));
    return false;
  }

  const auto& weights = model->GetArray(preceding_op->inputs[1]);
  const auto& bias = model->GetArray(preceding_op->inputs[2]);
  if (binary_op->type == OperatorType::kAdd ||
      binary_op->type == OperatorType::kSub) {
    // Add/Sub only rewrites the bias, so only the bias must be constant.
    if (!bias.buffer) {
      AddMessageF(
          "Not fusing %s because the preceding %s has a non-constant bias "
          "array",
          LogName(*binary_op), LogName(*preceding_op));
      return false;
    }
  } else {
    // Mul/Div rescales the weights too, so both must be constant.
    if (!weights.buffer || !bias.buffer) {
      AddMessageF(
          "Not fusing %s because the preceding %s has non-constant weights or "
          "bias arrays",
          LogName(*binary_op), LogName(*preceding_op));
      return false;
    }
  }

  // The affine op must be the unique consumer of its own output, since that
  // intermediate array is erased below.
  int count_ops_consuming_output =
      CountOpsWithInput(*model, preceding_op->outputs[0]);
  DCHECK_GE(count_ops_consuming_output, 1);
  if (count_ops_consuming_output > 1) {
    AddMessageF(
        "Not fusing %s because the output of the preceding %s is consumed by "
        "another op",
        LogName(*binary_op), LogName(*preceding_op));
    return false;
  }

  AddMessageF("Fusing %s into the preceding %s", LogName(*binary_op),
              LogName(*preceding_op));

  if (binary_op->type == OperatorType::kAdd ||
      binary_op->type == OperatorType::kSub) {
    FuseAddOrSubParamsIntoPrecedingAffine(model, preceding_op, binary_op,
                                          index_of_constant_input);
  } else if (binary_op->type == OperatorType::kMul ||
             binary_op->type == OperatorType::kDiv) {
    FuseMulOrDivParamsIntoPrecedingAffine(model, preceding_op, binary_op,
                                          index_of_constant_input);
  } else {
    LOG(FATAL) << "should not get here";
  }

  // Rewire: the affine op now directly produces the binary op's output and
  // inherits its fused activation function.
  model->arrays.erase(preceding_op->outputs[0]);
  preceding_op->outputs[0] = binary_op->outputs[0];
  preceding_op->fused_activation_function =
      binary_op->fused_activation_function;
  const auto& old_constant_param_name =
      binary_op->inputs[index_of_constant_input];
  CHECK(IsConstantParameterArray(*model, old_constant_param_name));
  if (CountOpsWithInput(*model, old_constant_param_name) == 1) {
    model->arrays.erase(old_constant_param_name);
  }
  model->operators.erase(binary_it);
  return true;
}

}  // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.cc b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.cc
new file mode 100644
index 0000000000..323fec6cf8
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.cc
@@ -0,0 +1,108 @@
+/* Copyright 2017 The TensorFlow Authors.
All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"

#include <algorithm>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/contrib/lite/toco/toco_port.h"
#include "tensorflow/contrib/lite/toco/tooling_util.h"
#include "tensorflow/core/platform/logging.h"

namespace toco {

namespace {

// Logs a one-line summary of the model: operator count, array count, and how
// many arrays carry quantization params.
void PrintModelStats(const string& label, const Model& model) {
  int quantized_arrays = 0;
  for (const auto& array : model.arrays) {
    if (array.second->quantization_params) {
      quantized_arrays++;
    }
  }
  LOG(INFO) << label << ": " << model.operators.size() << " operators, "
            << model.arrays.size() << " arrays (" << quantized_arrays
            << " quantized)";
}

// Runs every transformation at every operator position, walking forward
// (increment == 1) or backward (increment == -1). At each position, the first
// transformation that changes the model stops the inner loop and the same
// (clamped) position is revisited. Returns whether anything changed.
bool GraphTransformationsPass(int increment, Model* model,
                              const GraphTransformationsSet& transformations) {
  CHECK(increment == 1 || increment == -1);
  bool changed = false;
  CHECK(!model->operators.empty());
  int op_index = increment == 1 ? 0 : model->operators.size() - 1;
  while (true) {
    bool changed_now = false;
    // Loop over all transformations at the current position in the graph.
    for (const auto& transformation : transformations) {
      CHECK(!changed_now);
      CHECK(transformation->Messages().empty());
      changed_now = transformation->Run(model, op_index);
      if (changed_now) {
        DumpGraphvizVideoFrame(*model);
        CHECK(!model->operators.empty());
        // The transformation may have erased operators; keep op_index valid.
        op_index = std::min<int>(op_index, model->operators.size() - 1);
        // Uncomment for debugging
        // CheckInvariants(*model);
      }
      const char* made_a_change_msg =
          changed_now ? "made a change" : "did NOT make a change";
      const int log_level =
          changed_now ? kLogLevelModelChanged : kLogLevelModelUnchanged;
      for (const string& message : transformation->Messages()) {
        VLOG(log_level) << transformation->Name() << " " << made_a_change_msg
                        << " at op_index=" << op_index << "/"
                        << model->operators.size() - 1 << ": " << message;
      }
      transformation->ClearMessages();
      if (changed_now) {
        break;
      }
    }
    if (changed_now) {
      changed = true;
    } else {
      // Nothing fired at this position: advance, or stop at the last index.
      const int op_index_last =
          increment == 1 ? model->operators.size() - 1 : 0;
      if (op_index == op_index_last) {
        break;
      }
      op_index += increment;
    }
  }
  return changed;
}

}  // namespace

// Repeatedly runs passes, alternating sweep direction (forward first), until
// a whole pass makes no change (fixpoint). Invariants are checked after each
// pass, and model stats are logged before and after.
void RunGraphTransformations(Model* model, const string& msg,
                             const GraphTransformationsSet& transformations) {
  PrintModelStats(toco::port::StringF("Before %s", msg), *model);
  int pass_index = 0;
  while (GraphTransformationsPass((pass_index % 2) ? -1 : 1, model,
                                  transformations)) {
    pass_index++;
    const auto& label =
        toco::port::StringF("After %s pass %d", msg, pass_index);
    PrintModelStats(label, *model);
    CheckInvariants(*model);
  }
}

}  // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
new file mode 100644
index 0000000000..2cc24ff361
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
@@ -0,0 +1,186 @@
+/* Copyright 2017 The TensorFlow Authors.
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_GRAPH_TRANSFORMATIONS_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_GRAPH_TRANSFORMATIONS_H_ + +#include <cstddef> +#include <initializer_list> +#include <unordered_set> +#include <vector> + +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/toco_port.h" + +namespace toco { + +class GraphTransformation { + public: + virtual bool Run(Model* model, std::size_t op_index) = 0; + virtual const char* Name() const = 0; + virtual ~GraphTransformation() {} + // Returns the list of messages that this graph transformation + // generated since ClearMessages() was called. + const std::vector<string>& Messages() const { return messages_; } + // Clears the list of messages; should be called after every + // run of this graph transformation. + void ClearMessages() { return messages_.clear(); } + // Adds a message; normally only called by the graph transformation + // itself during its run (this function could be protected). + template <typename... Args> + void AddMessageF(const char* format, const Args&... args) { + return messages_.push_back(toco::port::StringF(format, args...)); + } + + protected: + GraphTransformation() {} + + // List of messages generated by this graph transformation. 
+ std::vector<string> messages_; + + private: + GraphTransformation(const GraphTransformation& other) = delete; + GraphTransformation(const GraphTransformation&& other) = delete; +}; + +class GraphTransformationsSet { + public: + // The choice of a container with fully-specified iteration order + // ensures that graph transformations are always run in the same order, + // which avoids having toco randomly fail or produce different results + // depending on the toolchain. Ideally success/results should be independent + // of the order in which graph transformations are run, but that's + // unfortunately not currently guaranteed to be the case. + using TransformationsContainer = + std::vector<std::unique_ptr<GraphTransformation>>; + + GraphTransformationsSet() {} + GraphTransformationsSet( + const std::initializer_list<GraphTransformation*> transformations) { + for (GraphTransformation* t : transformations) { + Add(t); + } + } + void Add(GraphTransformation* transformation) { + const string& name = transformation->Name(); + CHECK(!names_.count(name)); + names_.insert(name); + transformations_.emplace_back(transformation); + } + TransformationsContainer::const_iterator begin() const { + return transformations_.begin(); + } + TransformationsContainer::const_iterator end() const { + return transformations_.end(); + } + bool empty() const { return transformations_.empty(); } + + private: + GraphTransformationsSet(const GraphTransformationsSet& other) = delete; + GraphTransformationsSet(const GraphTransformationsSet&& other) = delete; + std::vector<std::unique_ptr<GraphTransformation>> transformations_; + // Names of transformations in the set. Only used to guard against dupes. + std::unordered_set<string> names_; +}; + +// Run the given list of graph transformations on the model. +// The message is only for logging purposes. +// The transformations is a rvalue reference, indicating that +// nothing else will use these pointers. 
The user is supposed to +// construct GraphTransformation objects by using 'new', pass us +// the resulting raw pointers, and this RunGraphTransformations +// takes care of delete'ing these pointers. +void RunGraphTransformations(Model* model, const string& message, + const GraphTransformationsSet& transformations); + +#define DECLARE_GRAPH_TRANSFORMATION(GTName) \ + class GTName : public GraphTransformation { \ + public: \ + bool Run(Model* model, std::size_t op_index) override; \ + const char* Name() const { return #GTName; } \ + }; + +// List of all graph transformations +DECLARE_GRAPH_TRANSFORMATION(ConvertPureConvToDepthwise) +DECLARE_GRAPH_TRANSFORMATION(EnsureBiasVectors) +DECLARE_GRAPH_TRANSFORMATION(FuseActivationFunctions) +DECLARE_GRAPH_TRANSFORMATION(FuseBinaryIntoFollowingAffine) +DECLARE_GRAPH_TRANSFORMATION(FuseBinaryIntoPrecedingAffine) +DECLARE_GRAPH_TRANSFORMATION(IdentifyL2Normalization) +DECLARE_GRAPH_TRANSFORMATION(IdentifyL2Pool) +DECLARE_GRAPH_TRANSFORMATION(IdentifyLstmCell) +DECLARE_GRAPH_TRANSFORMATION(IdentifyRelu1) +DECLARE_GRAPH_TRANSFORMATION(MakeInitialDequantizeOperator) +DECLARE_GRAPH_TRANSFORMATION(PropagateArrayDataTypes) +DECLARE_GRAPH_TRANSFORMATION(PropagateFixedSizes) +DECLARE_GRAPH_TRANSFORMATION(HardcodeMinMax) +DECLARE_GRAPH_TRANSFORMATION(Quantize) +DECLARE_GRAPH_TRANSFORMATION(RemoveFinalDequantizeOp) +DECLARE_GRAPH_TRANSFORMATION(RemoveTensorFlowAssert) +DECLARE_GRAPH_TRANSFORMATION(RemoveTensorFlowIdentity) +DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialBinaryOperator) +DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialConcatenation) +DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialConcatenationInput) +DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialQuantizedActivationFunc) +DECLARE_GRAPH_TRANSFORMATION(RemoveUnusedOp) +DECLARE_GRAPH_TRANSFORMATION(ResolveBatchNormalization) +DECLARE_GRAPH_TRANSFORMATION(ResolveConstantBinaryOperator) +DECLARE_GRAPH_TRANSFORMATION(ResolveConstantUnaryOperator) +DECLARE_GRAPH_TRANSFORMATION(CreateIm2colArrays) 
+DECLARE_GRAPH_TRANSFORMATION(DropIm2colArrays) +DECLARE_GRAPH_TRANSFORMATION(ReadFakeQuantMinMax) +DECLARE_GRAPH_TRANSFORMATION(ResolveReorderAxes) +DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowConcat) +DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowMatMul) +DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowMerge) +DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowSqueeze) +DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowSwitch) +DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowTile) +DECLARE_GRAPH_TRANSFORMATION(ResolveConstantFakeQuant) +DECLARE_GRAPH_TRANSFORMATION(ResolveConstantConcatenation) +DECLARE_GRAPH_TRANSFORMATION(DropFakeQuant) +DECLARE_GRAPH_TRANSFORMATION(UnfuseActivationFunctions) +DECLARE_GRAPH_TRANSFORMATION(ResolvePadAttributes) +DECLARE_GRAPH_TRANSFORMATION(ResolveStridedSliceAttributes) +DECLARE_GRAPH_TRANSFORMATION(ResolveSliceAttributes) +DECLARE_GRAPH_TRANSFORMATION(ResolveMeanAttributes) +DECLARE_GRAPH_TRANSFORMATION(ResolveConstantTensorFlowShape) +DECLARE_GRAPH_TRANSFORMATION(Dequantize) + +class ResolveReshapeAttributes : public GraphTransformation { + public: + bool Run(Model* model, std::size_t op_index) override; + const char* Name() const override { return "ResolveReshapeAttributes"; } +}; + +class RemoveTrivialReshape : public GraphTransformation { + public: + bool Run(Model* model, std::size_t op_index) override; + const char* Name() const override { return "RemoveTrivialReshape"; } + bool treat_expand_dims_as_trivial() const { + return treat_expand_dims_as_trivial_; + } + void set_treat_expand_dims_as_trivial(bool val) { + treat_expand_dims_as_trivial_ = val; + } + + private: + bool treat_expand_dims_as_trivial_ = false; +}; + +#undef DECLARE_GRAPH_TRANSFORMATION + +} // end namespace toco + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_GRAPH_TRANSFORMATIONS_H_ diff --git a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc 
b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc new file mode 100644 index 0000000000..d44b5dc7b0 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc @@ -0,0 +1,229 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <algorithm> +#include <memory> +#include <string> +#include <vector> + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +bool HardcodeMinMaxForIm2colArray(Model* model, Operator* op) { + if (op->outputs.size() != 2) { + return false; + } + auto& im2col_array = model->GetArray(op->outputs[1]); + if (im2col_array.minmax) { + return false; + } + const auto& input_array = model->GetArray(op->inputs[0]); + if (!input_array.minmax) { + return false; + } + const auto& input_minmax = input_array.GetMinMax(); + CHECK(!im2col_array.minmax); + auto& im2col_minmax = im2col_array.GetOrCreateMinMax(); + im2col_minmax.min = input_minmax.min; + im2col_minmax.max = input_minmax.max; + return true; +} + +bool HardcodeMinMaxForL2Normalization(Model* model, Operator* op) { + auto& output_array = model->GetArray(op->outputs[0]); + if (output_array.minmax) { + 
return false; + } + const auto& input_array = model->GetArray(op->inputs[0]); + if (!input_array.minmax) { + return false; + } + const auto& input_minmax = input_array.GetMinMax(); + CHECK(!output_array.minmax); + auto& output_minmax = output_array.GetOrCreateMinMax(); + output_minmax.min = input_minmax.min >= 0. ? 0. : -1.; + output_minmax.max = input_minmax.max <= 0. ? 0. : 1.; + return true; +} + +bool HardcodeMinMaxForConcatenation(Model* model, Operator* op) { + // Do not early return if the output already has min/max: + // we may still need to adjust the inputs min/max. + bool has_minmax = false; + double overall_min = std::numeric_limits<double>::infinity(); + double overall_max = -std::numeric_limits<double>::infinity(); + for (const auto& input : op->inputs) { + if (model->GetArray(input).minmax) { + has_minmax = true; + const auto* minmax = model->GetArray(input).minmax.get(); + if (minmax) { + overall_min = std::min(overall_min, minmax->min); + overall_max = std::max(overall_max, minmax->max); + } + } + } + auto& output = model->GetArray(op->outputs[0]); + if (output.minmax) { + has_minmax = true; + const auto* minmax = model->GetArray(op->outputs[0]).minmax.get(); + if (minmax) { + overall_min = std::min(overall_min, minmax->min); + overall_max = std::max(overall_max, minmax->max); + } + } + if (!has_minmax) { + return false; + } + MinMax overall_minmax; + overall_minmax.min = overall_min; + overall_minmax.max = overall_max; + bool changed = false; + for (const auto& input : op->inputs) { + auto& array = model->GetArray(input); + if (!array.minmax) { + changed = true; + } else if (!(overall_minmax == array.GetMinMax())) { + changed = true; + LOG(WARNING) + << "Tweaking the MinMax of array " << input << ", which is " + << "an input to " << LogName(*op) << ", because we want all inputs " + << "and outputs of a Concatenation operator to have the same MinMax " + << "so that it can be implemented as a pure byte-copy, no " + "arithmetic."; + } + 
array.GetOrCreateMinMax() = overall_minmax; + } + if (!output.minmax) { + changed = true; + } else if (!(overall_minmax == output.GetMinMax())) { + changed = true; + LOG(WARNING) + << "Tweaking the MinMax of the output array of " << LogName(*op) + << ", because we want all inputs " + << "and outputs of a Concatenation operator to have the same MinMax " + << "so that it can be implemented as a pure byte-copy, no arithmetic."; + } + output.GetOrCreateMinMax() = overall_minmax; + + return changed; +} + +// The output of average or max pooling is within the same range as its input. +bool HardcodeMinMaxForAverageOrMaxPool(Model* model, Operator* op) { + auto& output_array = model->GetArray(op->outputs[0]); + if (output_array.minmax) { + return false; + } + const auto& input_array = model->GetArray(op->inputs[0]); + if (!input_array.minmax) { + return false; + } + const auto& input_minmax = input_array.GetMinMax(); + CHECK(!output_array.minmax); + auto& output_minmax = output_array.GetOrCreateMinMax(); + output_minmax.min = std::min(input_minmax.min, 0.); + output_minmax.max = std::max(input_minmax.max, 0.); + return true; +} + +bool HardcodeMinMaxForReshape(Model* model, Operator* op) { + auto& output_array = model->GetArray(op->outputs[0]); + if (output_array.minmax) { + return false; + } + const auto& input_array = model->GetArray(op->inputs[0]); + if (!input_array.minmax) { + return false; + } + const auto& input_minmax = input_array.GetMinMax(); + CHECK(!output_array.minmax); + auto& output_minmax = output_array.GetOrCreateMinMax(); + output_minmax.min = input_minmax.min; + output_minmax.max = input_minmax.max; + return true; +} + +bool HardcodeMinMaxForOutput(Model* model, Operator* op, double min, + double max) { + CHECK_EQ(op->outputs.size(), 1); + auto& output_array = model->GetArray(op->outputs[0]); + if (output_array.minmax) { + return false; + } + const auto& input_array = model->GetArray(op->inputs[0]); + if (!input_array.minmax) { + return false; + } + 
CHECK(!output_array.minmax); + auto& output_minmax = output_array.GetOrCreateMinMax(); + output_minmax.min = min; + output_minmax.max = max; + return true; +} +} // namespace + +bool HardcodeMinMax::Run(Model* model, std::size_t op_index) { + auto it = model->operators.begin() + op_index; + auto* op = it->get(); + bool changed = false; + switch (op->type) { + case OperatorType::kConv: + changed = HardcodeMinMaxForIm2colArray(model, op); + break; + + case OperatorType::kL2Normalization: + changed = HardcodeMinMaxForL2Normalization(model, op); + break; + + case OperatorType::kConcatenation: + changed = HardcodeMinMaxForConcatenation(model, op); + break; + + case OperatorType::kAveragePool: + case OperatorType::kMaxPool: + changed = HardcodeMinMaxForAverageOrMaxPool(model, op); + break; + + case OperatorType::kTensorFlowReshape: + changed = HardcodeMinMaxForReshape(model, op); + break; + + case OperatorType::kLogistic: + // We hardcode quantization_params to: zero_point=0, scale=1/256. + // This choice of minmax is the one that is equivalent to that. + changed = HardcodeMinMaxForOutput(model, op, 0, 255. / 256.); + break; + + case OperatorType::kSoftmax: + // We hardcode quantization_params to: zero_point=0, scale=1/256. + // This choice of minmax is the one that is equivalent to that. + changed = HardcodeMinMaxForOutput(model, op, 0, 255. / 256.); + break; + + default: + break; + } + if (changed) { + AddMessageF("Hardcoded min-max through %s", LogName(*op)); + } + return changed; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc new file mode 100644 index 0000000000..01b75e37c6 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc @@ -0,0 +1,170 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <cmath> +#include <memory> +#include <string> +#include <unordered_map> +#include <vector> + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +std::vector<std::unique_ptr<Operator>>::iterator FindOperator( + Model* model, const Operator* op) { + auto it = model->operators.begin(); + for (; it != model->operators.end(); ++it) { + if (it->get() == op) { + break; + } + } + return it; +} +} // namespace + +bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) { + const auto div_it = model->operators.begin() + op_index; + const auto* div_or_mul_op = div_it->get(); + OperatorType expected_op_type_producing_div_or_mul_input; + if (div_or_mul_op->type == OperatorType::kDiv) { + expected_op_type_producing_div_or_mul_input = OperatorType::kTensorFlowSqrt; + } else if (div_or_mul_op->type == OperatorType::kMul) { + expected_op_type_producing_div_or_mul_input = + OperatorType::kTensorFlowRsqrt; + } else { + return false; + } + CHECK_EQ(div_or_mul_op->inputs.size(), 2); + Operator* op_producing_div_or_mul_input[2] = { + GetOpWithOutput(*model, div_or_mul_op->inputs[0]), + GetOpWithOutput(*model, div_or_mul_op->inputs[1]), + 
}; + if (!op_producing_div_or_mul_input[1] || + op_producing_div_or_mul_input[1]->type != + expected_op_type_producing_div_or_mul_input) { + return false; + } + Operator* sqrt_or_rsqrt_op = op_producing_div_or_mul_input[1]; + CHECK_EQ(sqrt_or_rsqrt_op->inputs.size(), 1); + Operator* op_producing_sqrt_or_rsqrt_input = + GetOpWithOutput(*model, sqrt_or_rsqrt_op->inputs[0]); + if (!op_producing_sqrt_or_rsqrt_input) { + return false; + } + + // There may be an Add or a Maximum here, adding or clamping to a "small" + // constant scalar. + // Reported bug: b/29395854 + Operator* add_op = nullptr; + Operator* op_producing_add_input = nullptr; + if (op_producing_sqrt_or_rsqrt_input->type == OperatorType::kAdd || + op_producing_sqrt_or_rsqrt_input->type == + OperatorType::kTensorFlowMaximum) { + add_op = op_producing_sqrt_or_rsqrt_input; + bool add_can_be_removed = false; + CHECK_EQ(op_producing_sqrt_or_rsqrt_input->inputs.size(), 2); + for (int i = 0; i < 2; i++) { + const auto& input_array = + model->GetArray(op_producing_sqrt_or_rsqrt_input->inputs[i]); + if (!input_array.buffer) { + continue; + } + if (input_array.buffer->type != ArrayDataType::kFloat) { + continue; + } + if (RequiredBufferSizeForShape(input_array.shape()) != 1) { + continue; + } + const auto& input_float_data = + input_array.GetBuffer<ArrayDataType::kFloat>().data; + if (std::abs(input_float_data[0]) > 1e-3f) { + continue; + } + add_can_be_removed = true; + op_producing_add_input = GetOpWithOutput(*model, add_op->inputs[1 - i]); + break; + } + if (!add_can_be_removed) { + AddMessageF( + "Giving up trying to identify L2Normalization subgraph " + " because the operator producing the input to the square root, %s," + ", does not match the expected pattern", + LogName(*op_producing_sqrt_or_rsqrt_input)); + return false; + } + } + + Operator* sum_op = + add_op ? 
op_producing_add_input : op_producing_sqrt_or_rsqrt_input; + if (sum_op->type != OperatorType::kTensorFlowSum) { + AddMessageF( + "Giving up trying to identify L2Normalization subgraph: " + "expected Sum op, got %s", + LogName(*sum_op)); + return false; + } + + Operator* square_op = GetOpWithOutput(*model, sum_op->inputs[0]); + if (square_op->type != OperatorType::kTensorFlowSquare) { + AddMessageF( + "Giving up trying to identify L2Normalization subgraph: " + "expected Square op, got %s", + LogName(*square_op)); + return false; + } + + CHECK_EQ(square_op->inputs.size(), 1); + + if (square_op->inputs[0] != div_or_mul_op->inputs[0]) { + AddMessageF( + "Giving up trying to identify L2Normalization subgraph: %s does not " + "take the same input as the Mul/Div node", + LogName(*square_op)); + return false; + } + + // Create and emplace the new L2Normalization + auto* l2norm_op = new L2NormalizationOperator; + l2norm_op->inputs = {div_or_mul_op->inputs[0]}; + l2norm_op->outputs = div_or_mul_op->outputs; + model->operators.emplace(div_it, l2norm_op); + + AddMessageF("Creating %s replacing equivalent subgraph", LogName(*l2norm_op)); + + // Erase the subgraph that is now replaced by L2Normalization + model->operators.erase(FindOperator(model, square_op)); + model->arrays.erase(sum_op->inputs[0]); + if (sum_op->inputs.size() > 1) { + model->arrays.erase(sum_op->inputs[1]); + } + model->operators.erase(FindOperator(model, sum_op)); + if (add_op) { + model->arrays.erase(add_op->inputs[0]); + model->arrays.erase(add_op->inputs[1]); + model->operators.erase(FindOperator(model, add_op)); + } + model->arrays.erase(sqrt_or_rsqrt_op->inputs[0]); + model->operators.erase(FindOperator(model, sqrt_or_rsqrt_op)); + model->arrays.erase(div_or_mul_op->inputs[1]); + model->operators.erase(FindOperator(model, div_or_mul_op)); + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc 
b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc new file mode 100644 index 0000000000..1865416fc2 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc @@ -0,0 +1,106 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <memory> +#include <string> +#include <unordered_map> +#include <vector> + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +std::vector<std::unique_ptr<Operator>>::iterator FindOperator( + Model* model, const Operator* op) { + auto it = model->operators.begin(); + for (; it != model->operators.end(); ++it) { + if (it->get() == op) { + break; + } + } + return it; +} +} // namespace + +bool IdentifyL2Pool::Run(Model* model, std::size_t op_index) { + const auto sqrt_it = model->operators.begin() + op_index; + const auto* sqrt_op = sqrt_it->get(); + if (sqrt_op->type != OperatorType::kTensorFlowSqrt) { + return false; + } + + CHECK_EQ(sqrt_op->inputs.size(), 1); + CHECK_EQ(sqrt_op->outputs.size(), 1); + + const AveragePoolOperator* avpool_op; + const Operator* square_op; + + Operator* prev_to_sqrt_op = GetOpWithOutput(*model, 
sqrt_op->inputs[0]); + if (prev_to_sqrt_op->type != OperatorType::kAveragePool) { + AddMessageF( + "Giving up trying to identify L2Pool subgraph: " + "expected AveragePool op, got %s", + LogName(*prev_to_sqrt_op)); + return false; + } + + avpool_op = static_cast<const AveragePoolOperator*>(prev_to_sqrt_op); + CHECK_EQ(avpool_op->inputs.size(), 1); + + square_op = GetOpWithOutput(*model, avpool_op->inputs[0]); + CHECK_EQ(square_op->inputs.size(), 1); + if (square_op->type != OperatorType::kTensorFlowSquare) { + AddMessageF( + "Giving up trying to identify L2Pool subgraph: " + "expected Square op, got %s", + LogName(*square_op)); + return false; + } + + // Create and emplace L2Pool node. + auto* l2pool_op = new L2PoolOperator; + + l2pool_op->inputs = {square_op->inputs[0]}; + l2pool_op->outputs = sqrt_op->outputs; + + l2pool_op->padding.type = avpool_op->padding.type; + // Note that we do not setup avpool_op->padding.fixed here. This is done by + // the PropagateFixedSizes graph transformation. + + l2pool_op->stride_height = avpool_op->stride_height; + l2pool_op->stride_width = avpool_op->stride_width; + l2pool_op->kheight = avpool_op->kheight; + l2pool_op->kwidth = avpool_op->kwidth; + model->operators.emplace(sqrt_it, l2pool_op); + + AddMessageF("Creating %s replacing equivalent subgraph", LogName(*l2pool_op)); + + // Erase intermediate arrays, keeping input to square op. + model->arrays.erase(avpool_op->inputs[0]); + model->arrays.erase(sqrt_op->inputs[0]); + + // Erase three operators being replaced. 
+ model->operators.erase(FindOperator(model, square_op)); + model->operators.erase(FindOperator(model, avpool_op)); + model->operators.erase(FindOperator(model, sqrt_op)); + + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc new file mode 100644 index 0000000000..082820fddc --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc @@ -0,0 +1,396 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include <memory> +#include <string> +#include <vector> + +#include "absl/strings/string_view.h" +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" + +namespace toco { + +namespace { + +std::vector<std::unique_ptr<Operator>>::iterator FindOperator( + Model* model, const Operator& op) { + auto it = model->operators.begin(); + for (; it != model->operators.end(); ++it) { + if (it->get() == &op) { + break; + } + } + return it; +} + +bool GetStateArrayForBackEdge(const Model& model, + const string& back_edge_source_array, + string* state_array = nullptr) { + for (const auto& rnn_state : model.flags.rnn_states()) { + if (back_edge_source_array == rnn_state.back_edge_source_array()) { + // Found LSTM cell output + if (state_array) { + *state_array = rnn_state.state_array(); + } + return true; + } + } + return false; +} + +// Returns true if the given operator has exactly 1 input, and is connected to +// the given op_type. +// We use kNone to indicate an input unattached to an operator output. Usually +// these are the static input arrays. +bool MatchOperatorInputs(const Operator& op, const Model& model, + OperatorType op_type, Operator** connected_op) { + // Check for required number of inputs + if (op.inputs.size() != 1) { + return false; + } + + // Check if first input is disconnected/connected to an operator + Operator* x = GetOpWithOutput(model, op.inputs[0]); + if ((op_type == OperatorType::kNone) && (x != nullptr)) { + return false; + } + if ((op_type != OperatorType::kNone) && (x == nullptr)) { + return false; + } + + // Check that first operator, if connected, is of correct type + if ((x != nullptr) && (x->type != op_type)) { + return false; + } + + // Successfully matched. Optionally return matching input operators. 
+ if (connected_op) { + *connected_op = x; + } + + return true; +} + +// Returns true if the given operator has exactly 2 inputs, which are connected +// to the given op_types. +// We use kNone to indicate an input unattached to an operator output. Usually +// these are the static input arrays. +bool MatchOperatorInputs(const Operator& op, const Model& model, + OperatorType a_op_type, Operator** a_op, + OperatorType b_op_type, Operator** b_op) { + // Check for required number of inputs + if (op.inputs.size() != 2) { + return false; + } + + // Check if first input is disconnected/connected to an operator + Operator* x = GetOpWithOutput(model, op.inputs[0]); + if ((a_op_type == OperatorType::kNone) && (x != nullptr)) { + return false; + } + if ((a_op_type != OperatorType::kNone) && (x == nullptr)) { + return false; + } + + // Check that first operator, if connected, is of correct type + if ((x != nullptr) && (x->type != a_op_type)) { + return false; + } + + // Check if second input is disconnected/connected to an operator + Operator* y = GetOpWithOutput(model, op.inputs[1]); + if ((b_op_type == OperatorType::kNone) && (y != nullptr)) { + return false; + } + if ((b_op_type != OperatorType::kNone) && (y == nullptr)) { + return false; + } + + // Check that second operator, if connected, is of correct type + if ((y != nullptr) && (y->type != b_op_type)) { + return false; + } + + // Successfully matched. Optionally return matching input operators. + if (a_op != nullptr) { + *a_op = x; + } + if (b_op != nullptr) { + *b_op = y; + } + return true; +} + +// Returns true if the given operator has exactly 3 inputs, which are connected +// to the given op_types. +// We use kNone to indicate an input unattached to an operator output. Usually +// these are the static input arrays. 
+bool MatchOperatorInputs(const Operator& op, const Model& model, + OperatorType a_op_type, Operator** a_op, + OperatorType b_op_type, Operator** b_op, + OperatorType c_op_type, Operator** c_op) { + // Check for required number of inputs + if (op.inputs.size() != 3) { + return false; + } + + // Check if first input is disconnected/connected to an operator + Operator* x = GetOpWithOutput(model, op.inputs[0]); + if ((a_op_type == OperatorType::kNone) && (x != nullptr)) { + return false; + } + if ((a_op_type != OperatorType::kNone) && (x == nullptr)) { + return false; + } + + // Check that first operator, if connected, is of correct type + if ((x != nullptr) && (x->type != a_op_type)) { + return false; + } + + // Check if second input is disconnected/connected to an operator + Operator* y = GetOpWithOutput(model, op.inputs[1]); + if ((b_op_type == OperatorType::kNone) && (y != nullptr)) { + return false; + } + if ((b_op_type != OperatorType::kNone) && (y == nullptr)) { + return false; + } + + // Check that second operator, if connected, is of correct type + if ((y != nullptr) && (y->type != b_op_type)) { + return false; + } + + // Check if third input is disconnected/connected to an operator + Operator* z = GetOpWithOutput(model, op.inputs[2]); + if ((c_op_type == OperatorType::kNone) && (z != nullptr)) { + return false; + } + if ((c_op_type != OperatorType::kNone) && (z == nullptr)) { + return false; + } + + // Check that third operator, if connected, is of correct type + if ((z != nullptr) && (z->type != c_op_type)) { + return false; + } + + // Successfully matched. Optionally return matching input operators. 
+ if (a_op != nullptr) { + *a_op = x; + } + if (b_op != nullptr) { + *b_op = y; + } + if (c_op != nullptr) { + *c_op = z; + } + return true; +} + +absl::string_view FindLongestCommonPrefix(absl::string_view a, + absl::string_view b) { + if (a.empty() || b.empty()) return absl::string_view(); + + const char* pa = a.data(); + const char* pb = b.data(); + size_t count = 0; + const ssize_t limit = std::min(a.size(), b.size()); + while (count < limit && *pa == *pb) { + ++pa; + ++pb; + ++count; + } + + return absl::string_view(a.data(), count); +} + +} // namespace + +bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) { + // This LSTM cell identification method is not invariant to commutation of + // commutative operator inputs. For example, if input[0] and input[1] of the + // final output multiplication were swapped, this method would not identify it + // as an LSTM cell. This is OK in most cases, because + // tf.rnn.contrib.BasicLSTMCell always generates LSTM cells the same way. + + // Final output multiply + auto op_it = model->operators.begin() + op_index; + Operator* final_output_mul = op_it->get(); + if (final_output_mul->type != OperatorType::kMul) { + return false; + } + Operator *state_output_tanh, *fc_output_sig; + if (!MatchOperatorInputs(*final_output_mul, *model, OperatorType::kTanh, + &state_output_tanh, OperatorType::kLogistic, + &fc_output_sig)) { + return false; + } + + // State output TanH + // (We don't count an operator as ID'd until we verify it has the correct + // operator types feeding into it.) 
+ Operator* state_combine_add; + if (!MatchOperatorInputs(*state_output_tanh, *model, OperatorType::kAdd, + &state_combine_add)) { + return false; + } + string prev_state; + if (!GetStateArrayForBackEdge(*model, state_output_tanh->inputs[0], + &prev_state)) { + return false; + } + + // State forget & remember addition + Operator *state_forget_mul, *state_remember_mul; + if (!MatchOperatorInputs(*state_combine_add, *model, OperatorType::kMul, + &state_forget_mul, OperatorType::kMul, + &state_remember_mul)) { + return false; + } + if (state_forget_mul->inputs[0] != prev_state) { + return false; + } + + // State forget gate + Operator* state_forget_sig; + if (!MatchOperatorInputs(*state_forget_mul, *model, OperatorType::kNone, + nullptr, OperatorType::kLogistic, + &state_forget_sig)) { + return false; + } + + // State remember gate + Operator *state_remember_sig, *state_info_tanh; + if (!MatchOperatorInputs(*state_remember_mul, *model, OperatorType::kLogistic, + &state_remember_sig, OperatorType::kTanh, + &state_info_tanh)) { + return false; + } + + // State remember "information" activation function + Operator* fc_output_split; + if (!MatchOperatorInputs(*state_info_tanh, *model, + OperatorType::kTensorFlowSplit, &fc_output_split)) { + return false; + } + // State remember gate activation function + Operator* tmp; + if (!MatchOperatorInputs(*state_remember_sig, *model, + OperatorType::kTensorFlowSplit, &tmp) || + (tmp != fc_output_split)) { + return false; + } + // State forget gate activation function + if (!MatchOperatorInputs(*state_forget_sig, *model, + OperatorType::kTensorFlowSplit, &tmp) || + (tmp != fc_output_split)) { + return false; + } + // Fully connected output activation function + if (!MatchOperatorInputs(*fc_output_sig, *model, + OperatorType::kTensorFlowSplit, &tmp) || + (tmp != fc_output_split)) { + return false; + } + // Fully connected output split + Operator* fully_connected; + if (!MatchOperatorInputs(*fc_output_split, *model, 
OperatorType::kNone, + nullptr, OperatorType::kFullyConnected, + &fully_connected)) { + return false; + } + + // Fully connected op + Operator* concat_inputs; + if (!MatchOperatorInputs(*fully_connected, *model, + OperatorType::kConcatenation, &concat_inputs, + OperatorType::kNone, nullptr, OperatorType::kNone, + nullptr)) { + return false; + } + + // Emplace a new LSTM cell operator + auto* lstm_cell_op = new LstmCellOperator; + lstm_cell_op->inputs.resize(LstmCellOperator::NUM_INPUTS); + lstm_cell_op->inputs[LstmCellOperator::DATA_INPUT] = concat_inputs->inputs[0]; + lstm_cell_op->inputs[LstmCellOperator::PREV_ACTIV_INPUT] = + concat_inputs->inputs[1]; + lstm_cell_op->inputs[LstmCellOperator::WEIGHTS_INPUT] = + fully_connected->inputs[1]; + lstm_cell_op->inputs[LstmCellOperator::BIASES_INPUT] = + fully_connected->inputs[2]; + lstm_cell_op->inputs[LstmCellOperator::PREV_STATE_INPUT] = prev_state; + lstm_cell_op->outputs.resize(LstmCellOperator::NUM_OUTPUTS); + lstm_cell_op->outputs[LstmCellOperator::STATE_OUTPUT] = + state_output_tanh->inputs[0]; + lstm_cell_op->outputs[LstmCellOperator::ACTIV_OUTPUT] = + final_output_mul->outputs[0]; + model->operators.emplace(op_it, lstm_cell_op); + AddMessageF("Creating %s replacing equivalent subgraph", + LogName(*lstm_cell_op)); + + // Create temp arrays used internally during runtime. 
+ const string base_name(FindLongestCommonPrefix( + lstm_cell_op->outputs[LstmCellOperator::STATE_OUTPUT], + lstm_cell_op->outputs[LstmCellOperator::ACTIV_OUTPUT])); + const string& concat_temp_array_name = + AvailableArrayName(*model, base_name + "concat_temp"); + model->GetOrCreateArray(concat_temp_array_name); + lstm_cell_op->outputs[LstmCellOperator::CONCAT_TEMP] = concat_temp_array_name; + const string& activ_temp_array_name = + AvailableArrayName(*model, base_name + "activ_temp"); + model->GetOrCreateArray(activ_temp_array_name); + lstm_cell_op->outputs[LstmCellOperator::ACTIV_TEMP] = activ_temp_array_name; + AddMessageF("Created temp outputs %s and %s on operator %s", + concat_temp_array_name, activ_temp_array_name, + LogName(*lstm_cell_op)); + + // Delete arrays and operators replaced by the LSTM cell operator. Order is + // important - DeleteArrayIfUnused() only succeeds if dependent operators + // have been removed first. Start at the output and work towards the input. + model->operators.erase(FindOperator(model, *final_output_mul)); + DeleteArrayIfUnused(state_output_tanh->outputs[0], model); + DeleteArrayIfUnused(fc_output_sig->outputs[0], model); + model->operators.erase(FindOperator(model, *state_output_tanh)); + model->operators.erase(FindOperator(model, *fc_output_sig)); + model->operators.erase(FindOperator(model, *state_combine_add)); + DeleteArrayIfUnused(state_forget_mul->outputs[0], model); + DeleteArrayIfUnused(state_remember_mul->outputs[0], model); + model->operators.erase(FindOperator(model, *state_forget_mul)); + model->operators.erase(FindOperator(model, *state_remember_mul)); + DeleteArrayIfUnused(state_forget_sig->outputs[0], model); + DeleteArrayIfUnused(state_info_tanh->outputs[0], model); + DeleteArrayIfUnused(state_remember_sig->outputs[0], model); + model->operators.erase(FindOperator(model, *state_forget_sig)); + model->operators.erase(FindOperator(model, *state_info_tanh)); + model->operators.erase(FindOperator(model, 
*state_remember_sig)); + DeleteArrayIfUnused(fc_output_split->outputs[0], model); + DeleteArrayIfUnused(fc_output_split->outputs[1], model); + DeleteArrayIfUnused(fc_output_split->outputs[2], model); + DeleteArrayIfUnused(fc_output_split->outputs[3], model); + string dims_array = fc_output_split->inputs[0]; + model->operators.erase(FindOperator(model, *fc_output_split)); + DeleteArrayIfUnused(dims_array, model); + DeleteArrayIfUnused(fully_connected->outputs[0], model); + model->operators.erase(FindOperator(model, *fully_connected)); + DeleteArrayIfUnused(concat_inputs->outputs[0], model); + model->operators.erase(FindOperator(model, *concat_inputs)); + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc new file mode 100644 index 0000000000..cfc77024e7 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc @@ -0,0 +1,103 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include <memory> +#include <string> +#include <unordered_map> +#include <vector> + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +std::vector<std::unique_ptr<Operator>>::iterator FindOperator( + Model* model, const Operator* op) { + auto it = model->operators.begin(); + for (; it != model->operators.end(); ++it) { + if (it->get() == op) { + break; + } + } + return it; +} + +bool CheckArrayIsScalarFloat(Model* model, const std::string& name, float val) { + const auto& op_array = model->GetArray(name); + if (!op_array.buffer || op_array.buffer->type != ArrayDataType::kFloat || + RequiredBufferSizeForShape(op_array.shape()) != 1) { + return false; + } + const auto& op_data = op_array.GetBuffer<ArrayDataType::kFloat>().data; + return op_data[0] == val; +} + +// Returns index of scalar input when there is exactly one scalar, -1 otherwise +int GetSingleScalarInputIndexOfBinaryOp(Model* model, const Operator* op, + float val) { + bool input0_is_scalar = CheckArrayIsScalarFloat(model, op->inputs[0], val); + bool input1_is_scalar = CheckArrayIsScalarFloat(model, op->inputs[1], val); + return input0_is_scalar == input1_is_scalar ? -1 : input0_is_scalar ? 
0 : 1; +} +} // namespace + +bool IdentifyRelu1::Run(Model* model, std::size_t op_index) { + const auto maximum_it = model->operators.begin() + op_index; + const auto* maximum_op = maximum_it->get(); + if (maximum_op->type != OperatorType::kTensorFlowMaximum) { + return false; + } + CHECK_EQ(maximum_op->inputs.size(), 2); + if (maximum_op->outputs.size() != 1) { + return false; + } + int scalar_input_index = + GetSingleScalarInputIndexOfBinaryOp(model, maximum_op, -1.0f); + if (scalar_input_index == -1) { + return false; + } + const auto* minimum_op = GetOpWithInput(*model, maximum_op->outputs[0]); + if (!minimum_op || minimum_op->type != OperatorType::kTensorFlowMinimum) { + return false; + } + if (GetSingleScalarInputIndexOfBinaryOp(model, minimum_op, 1.0f) == -1) { + return false; + } + CHECK_EQ(minimum_op->inputs.size(), 2); + + // Create and emplace Relu1 node + auto* relu1_op = new Relu1Operator; + relu1_op->inputs = {maximum_op->inputs[!scalar_input_index]}; + relu1_op->outputs = minimum_op->outputs; + model->operators.emplace(maximum_it, relu1_op); + + AddMessageF("Creating %s replacing equivalent subgraph", LogName(*relu1_op)); + + // Erase Maximum scalar input & operator + model->arrays.erase(maximum_op->inputs[scalar_input_index]); + model->operators.erase(FindOperator(model, maximum_op)); + + // Erase Minimum inputs & operator + model->arrays.erase(minimum_op->inputs[0]); + model->arrays.erase(minimum_op->inputs[1]); + model->operators.erase(FindOperator(model, minimum_op)); + + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc b/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc new file mode 100644 index 0000000000..d83603e9a2 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc @@ -0,0 +1,120 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
// This inserts an operator whose output is a float array (name:
// flags.input_array()). It has to wait for any existing operators that
// generate this output to be removed by graph transformations. Note that there
// may be more than one operator that takes the input_array as their input, and
// that some of these may be removed by graph transformations.
//
// Returns true if a Dequantize operator was inserted for `input_name`
// (i.e. the graph was modified). `op` is the consumer on whose behalf the
// insertion is attempted; `transformation` is used only for logging.
bool AddDequantizeOperatorToInput(const string& input_name, const Operator* op,
                                  GraphTransformation* transformation,
                                  Model* model) {
  // An operator with the required output may be a dequantize operator already
  // created. Alternatively it may be an operator that needs to be removed
  // because it is unused, in which case we wait for RemoveUnusedOp to do its
  // work.
  if (GetOpWithOutput(*model, input_name)) {
    return false;
  }

  // We only apply for the first operator if there is more than one. This is
  // not strictly necessary for ordering correctness, since we insert the
  // dequant operator at the beginning of the op sequence, but it makes the
  // insertion more predictable (eg forward vs backwards operator sweep).
  if (CountOpsWithInput(*model, input_name) > 1) {
    if (op != GetFirstOpWithInput(*model, input_name)) {
      return false;
    }
  }

  auto& input_array = model->GetArray(input_name);
  // Only float input arrays are candidates for dequantization.
  if (input_array.data_type != ArrayDataType::kFloat) {
    return false;
  }

  // Nothing to do unless a different final (quantized) data type is requested.
  if (input_array.final_data_type == input_array.data_type ||
      input_array.final_data_type == ArrayDataType::kNone) {
    return false;
  }

  // Reroute every consumer of input_name to the new dequantized array BEFORE
  // the Dequantize op is created, so the Dequantize op itself (created below
  // with input_name as its input) is not rerouted.
  const auto& dequantized_input_name =
      AvailableArrayName(*model, input_name + "_dequantized");
  for (auto& other_op : model->operators) {
    for (string& other_op_input : other_op->inputs) {
      if (other_op_input == input_name) {
        other_op_input = dequantized_input_name;
      }
    }
  }

  // Insert the Dequantize op at the very beginning of the op sequence:
  // quantized input -> Dequantize -> float array consumed by the graph.
  auto& dequantized_input_array =
      model->GetOrCreateArray(dequantized_input_name);
  auto* image_input_op = new DequantizeOperator;
  image_input_op->inputs = {input_name};
  image_input_op->outputs = {dequantized_input_name};
  model->operators.emplace(model->operators.begin(), image_input_op);

  // Fix up data types and quantization metadata: the original input becomes
  // uint8 with quantization params derived from its minmax, and the
  // dequantized copy inherits the minmax as a float array.
  CHECK(input_array.final_data_type == ArrayDataType::kUint8);
  input_array.data_type = ArrayDataType::kUint8;
  dequantized_input_array.data_type = ArrayDataType::kFloat;
  const auto& input_minmax = input_array.GetMinMax();
  auto& dequantized_input_minmax = dequantized_input_array.GetOrCreateMinMax();
  dequantized_input_minmax = input_minmax;
  auto& input_qparams = input_array.GetOrCreateQuantizationParams();
  GetQuantizationParamsFromMinMax<ArrayDataType::kUint8>(
      model->flags, input_minmax, &input_qparams);

  transformation->AddMessageF(
      "Created %s"
      " to handle quantized input image data, taking over existing"
      " mean_value and std_value flags. Cleared those flags.",
      LogName(*image_input_op));

  return true;
}

// Inserts an initial Dequantize operator for each model input array consumed
// by the operator at op_index, clearing the input's mean/std flags once the
// Dequantize op has taken them over. Returns true if the graph was modified.
bool MakeInitialDequantizeOperator::Run(Model* model, std::size_t op_index) {
  // This is effectively a transformation applied to edges. We iterate over the
  // specified node (op) and proceed for input edges.
  const auto it = model->operators.begin() + op_index;
  const auto* op = it->get();
  bool change_made = false;
  for (auto& input : op->inputs) {
    for (auto& input_array : *model->flags.mutable_input_arrays()) {
      if (input_array.name() == input) {
        if (AddDequantizeOperatorToInput(input_array.name(), op, this, model)) {
          change_made = true;
          input_array.clear_mean_value();
          input_array.clear_std_value();
        }
      }
    }
  }
  return change_made;
}
+==============================================================================*/ +#include <memory> +#include <string> +#include <unordered_map> +#include <vector> + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +ArrayDataType CommonDataTypeOfAllInputs(const Model& model, + const Operator& op) { + CHECK_GT(op.inputs.size(), 0); + const ArrayDataType data_type = model.GetArray(op.inputs[0]).data_type; + for (const auto& input : op.inputs) { + const auto& array = model.GetArray(input); + CHECK(array.data_type == data_type) + << " Unexpected: this operator has inputs with different data types."; + } + return data_type; +} + +void SetDataTypeForAllOutputs(Model* model, Operator* op, + ArrayDataType data_type) { + for (const auto& output : op->outputs) { + model->arrays[output]->data_type = data_type; + } +} +} // namespace + +bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { + auto it = model->operators.begin() + op_index; + auto* op = it->get(); + + // If the data type of some input is unknown, we need to yield. + for (const auto& input : op->inputs) { + if (model->arrays[input]->data_type == ArrayDataType::kNone) { + return false; + } + } + // Record data types of output before processing, so we can see at the + // end if we changed anything, and return the correct boolean value. + std::unordered_map<string, ArrayDataType> old_output_data_types; + for (const auto& output : op->outputs) { + old_output_data_types[output] = model->arrays[output]->data_type; + } + // Do the actual output data types propagation. 
+ if (op->type == OperatorType::kDequantize || + op->type == OperatorType::kResizeBilinear) { + // These operators unconditionally produce float outputs + SetDataTypeForAllOutputs(model, op, ArrayDataType::kFloat); + } else if (op->type == OperatorType::kTensorFlowLess || + op->type == OperatorType::kTensorFlowLessEqual || + op->type == OperatorType::kTensorFlowGreater || + op->type == OperatorType::kTensorFlowGreaterEqual) { + // These operators unconditionally produce bool outputs + SetDataTypeForAllOutputs(model, op, ArrayDataType::kBool); + } else if (op->type == OperatorType::kTensorFlowShape) { + // These operators are assumed to produce int32 outputs. + SetDataTypeForAllOutputs(model, op, ArrayDataType::kInt32); + } else if (op->type == OperatorType::kAveragePool || + op->type == OperatorType::kMaxPool || + op->type == OperatorType::kL2Pool || + op->type == OperatorType::kConv || + op->type == OperatorType::kDepthwiseConv || + op->type == OperatorType::kFullyConnected || + op->type == OperatorType::kTensorFlowMax || + op->type == OperatorType::kTensorFlowMin || + op->type == OperatorType::kPad || + op->type == OperatorType::kStridedSlice || + op->type == OperatorType::kTensorFlowReshape || + op->type == OperatorType::kSlice || + op->type == OperatorType::kSqueeze || + op->type == OperatorType::kTensorFlowSum || + op->type == OperatorType::kTensorFlowSwitch || + op->type == OperatorType::kTensorFlowTile || + op->type == OperatorType::kTensorFlowAll || + op->type == OperatorType::kReorderAxes || + op->type == OperatorType::kTensorFlowConcatV2 || + op->type == OperatorType::kFloor || + op->type == OperatorType::kGather || + op->type == OperatorType::kSpaceToBatchND || + op->type == OperatorType::kBatchToSpaceND || + op->type == OperatorType::kMean) { + // These operators produce outputs with the same type as their 1st input + CHECK_GT(op->inputs.size(), 0); + const ArrayDataType data_type = model->arrays[op->inputs[0]]->data_type; + 
SetDataTypeForAllOutputs(model, op, data_type); + } else if (op->type == OperatorType::kTensorFlowSplit || + op->type == OperatorType::kTensorFlowConcat) { + // These operators produce an output with the same type as their 2nd input + CHECK_GT(op->inputs.size(), 1); + const ArrayDataType data_type = model->arrays[op->inputs[1]]->data_type; + SetDataTypeForAllOutputs(model, op, data_type); + } else if (op->type == OperatorType::kCast) { + // Data type of the Cast op is specified. + CHECK_EQ(op->outputs.size(), 1); + auto* cast_op = static_cast<CastOperator*>(op); + model->arrays[op->outputs[0]]->data_type = cast_op->dst_data_type; + } else if (op->type == OperatorType::kTensorFlowUnsupported) { + auto* unsupported_op = static_cast<TensorFlowUnsupportedOperator*>(op); + if (unsupported_op->output_data_types.size() != op->outputs.size()) { + return false; + } + for (int i = 0; i < unsupported_op->output_data_types.size(); ++i) { + auto output = op->outputs[i]; + auto data_type = unsupported_op->output_data_types[i]; + model->arrays[output]->data_type = data_type; + } + } else { + // These operators produce an output with the same type as any of their + // inputs, which must always have the same type. + const ArrayDataType data_type = CommonDataTypeOfAllInputs(*model, *op); + SetDataTypeForAllOutputs(model, op, data_type); + } + // Return true if any output data type changed, false if none changed. + for (const auto& output : op->outputs) { + if (old_output_data_types[output] != model->arrays[output]->data_type) { + return true; + } + } + return false; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc new file mode 100644 index 0000000000..82a43bc2ce --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -0,0 +1,1129 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
// Computes the output shape and the implied symmetric padding for a
// conv-like op (conv, depthwise-conv, pooling) from a 4-D NHWC input shape,
// kernel size, strides and padding mode. Results are written to
// *output_shape and *fixed_padding.
void ComputeConvSizes(const Shape& input_shape, int output_depth, int kwidth,
                      int kheight, int stride_width, int stride_height,
                      PaddingType padding_type, Shape* output_shape,
                      FixedPadding* fixed_padding) {
  const int input_width = input_shape.dims(2);
  const int input_height = input_shape.dims(1);
  const int batch = input_shape.dims(0);

  int output_height = 0;
  int output_width = 0;
  if (padding_type == PaddingType::kValid) {
    // Number of valid kernel positions: floor((input - kernel) / stride) + 1.
    output_height = (input_height + stride_height - kheight) / stride_height;
    output_width = (input_width + stride_width - kwidth) / stride_width;
  } else if (padding_type == PaddingType::kSame) {
    // Integer form of ceil(input / stride).
    output_height = (input_height + stride_height - 1) / stride_height;
    output_width = (input_width + stride_width - 1) / stride_width;
  } else {
    LOG(FATAL) << "Only supporting SAME or VALID padding";
  }

  // Total padding needed to realize the output size, halved (rounding down)
  // to get the per-side padding.
  fixed_padding->height =
      ((output_height - 1) * stride_height + kheight - input_height) / 2;
  fixed_padding->width =
      ((output_width - 1) * stride_width + kwidth - input_width) / 2;

  // Actually had to debug a situation where those were negative due to bad
  // propagation of placeholder -1 sizes in TensorFlowReshape.
  CHECK_GT(output_width, 0);
  CHECK_GT(output_height, 0);
  output_shape->ReplaceDims({batch, output_height, output_width, output_depth});
}

// Picks the output shape of an elementwise binary op from its two input
// shapes: the input with the larger flat buffer size wins; on a tie, the
// input with at least as many dimensions wins. (This approximates
// broadcasting: the broadcast result has the larger operand's shape.)
void ComputeBinaryOperatorOutputSize(const Shape& input_shape1,
                                     const Shape& input_shape2,
                                     Array* output_array) {
  const int size1 = RequiredBufferSizeForShape(input_shape1);
  const int size2 = RequiredBufferSizeForShape(input_shape2);
  if (size1 > size2) {
    output_array->copy_shape(input_shape1);
  } else if (size2 > size1) {
    output_array->copy_shape(input_shape2);
  } else {
    CHECK_EQ(size1, size2);
    const int dims1 = input_shape1.dimensions_count();
    const int dims2 = input_shape2.dimensions_count();
    if (dims1 >= dims2) {
      output_array->copy_shape(input_shape1);
    } else {
      output_array->copy_shape(input_shape2);
    }
  }
  CHECK(output_array->has_shape());
}

// Returns the output depth implied by the weights array (inputs[1]) of a
// conv-like op. The depth lives in a different weights dimension depending on
// the op type: dim 0 for Conv/FullyConnected, dim 3 for DepthwiseConv.
int GetOutputDepthFromWeights(const Model& model, const Operator& op) {
  const string& weights_name = op.inputs[1];
  const auto& weights_shape = model.arrays.at(weights_name)->shape();
  if (op.type == OperatorType::kConv ||
      op.type == OperatorType::kFullyConnected) {
    return weights_shape.dims(0);
  } else if (op.type == OperatorType::kDepthwiseConv) {
    return weights_shape.dims(3);
  } else {
    LOG(FATAL) << "Unhandled operator type";
  }
}

// Gives the bias array (inputs[2]) of a conv-like op its shape — a vector of
// length output_depth — and zero-filled float data, if it has no shape yet.
// Returns true when the bias shape is (now) resolved; false when the caller
// must yield (weights shape unresolved, or no bias input present).
bool EnsureBiasVectorShape(Model* model, Operator* op) {
  const string& weights_name = op->inputs[1];
  const auto& weights_array = *model->arrays[weights_name];
  // Yield until weights shape has been resolved.
  if (!weights_array.has_shape()) {
    return false;
  }

  if (op->inputs.size() < 3) {
    return false;
  }
  auto& bias_array = *model->arrays[op->inputs[2]];
  if (bias_array.has_shape()) {
    return true;
  }

  const int output_depth = GetOutputDepthFromWeights(*model, *op);
  bias_array.copy_shape(Shape({output_depth}));

  // Zero-fill so a graph without an explicit bias behaves as if it had a
  // zero bias vector.
  auto& float_buffer = bias_array.GetMutableBuffer<ArrayDataType::kFloat>();
  float_buffer.data.resize(output_depth, 0);

  return true;
}
// Shape inference for Conv: computes the output NHWC shape from the input and
// weights shapes, and sizes the optional im2col temporary array if present.
void ProcessConvOperator(Model* model, ConvOperator* op) {
  // The bias vector must be shaped before the output can be.
  if (!EnsureBiasVectorShape(model, op)) {
    return;
  }

  const auto& input_array = *model->arrays[op->inputs[0]];
  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }
  const auto& input_shape = input_array.shape();
  CHECK_EQ(input_shape.dimensions_count(), 4);

  const auto& weights_array = *model->arrays[op->inputs[1]];
  // Yield until weights dims have been resolved.
  if (!weights_array.has_shape()) {
    return;
  }
  const auto& weights_shape = weights_array.shape();
  CHECK_EQ(weights_shape.dimensions_count(), 4);

  auto& output_array = model->GetArray(op->outputs[0]);
  // Conv weights layout here: {output_depth, kheight, kwidth, input_depth}.
  const int output_depth = weights_shape.dims(0);
  const int kheight = weights_shape.dims(1);
  const int kwidth = weights_shape.dims(2);
  ComputeConvSizes(input_shape, output_depth, kwidth, kheight, op->stride_width,
                   op->stride_height, op->padding.type,
                   output_array.mutable_shape(),
                   &op->padding.GetOrCreateFixedPadding());
  CHECK_EQ(output_array.shape().dimensions_count(), 4);

  // Set im2col array dimensions if there is one.
  if (op->outputs.size() == 2) {
    const auto& output_shape = output_array.shape();
    const int input_depth = weights_shape.dims(3);
    auto& im2col_array = *model->arrays[op->outputs[1]];
    // Each im2col entry holds one kheight*kwidth*input_depth patch.
    im2col_array.copy_shape(Shape{output_shape.dims(0), output_shape.dims(1),
                                  output_shape.dims(2),
                                  input_depth * kheight * kwidth});
  }
}

// Shape inference for DepthwiseConv: computes the output NHWC shape, and
// infers depth_multiplier from the weights when it was left at zero.
void ProcessDepthwiseConvOperator(Model* model, DepthwiseConvOperator* op) {
  if (!EnsureBiasVectorShape(model, op)) {
    return;
  }

  const auto& input_array = *model->arrays[op->inputs[0]];
  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }
  const auto& input_shape = input_array.shape();
  CHECK_EQ(input_shape.dimensions_count(), 4);

  const auto& weights_array = *model->arrays[op->inputs[1]];
  // Yield until weights dims have been resolved.
  if (!weights_array.has_shape()) {
    return;
  }
  const auto& weights_shape = weights_array.shape();
  CHECK_EQ(weights_shape.dimensions_count(), 4);

  const string& output_name = op->outputs[0];
  const int input_depth = input_shape.dims(3);
  const int output_depth = weights_shape.dims(3);
  // TensorFlow doesn't define the depth_multiplier value on DepthwiseConv ops,
  // instead it has to be inferred from the weights dims. However, once we are
  // here, weights dims have already been converted to our own internal format,
  // where the multiplier is no longer readily apparent. So instead we get it
  // as the quotient of output and input depths. We only want to do that when
  // depth_multiplier had the zero value: any other value should be checked
  // as done by the next if() below.
  if (!op->depth_multiplier) {
    op->depth_multiplier = output_depth / input_depth;
  }
  QCHECK_EQ(output_depth, input_depth * op->depth_multiplier)
      << "input/output depths and depth_multiplier don't match";

  const int kheight = weights_shape.dims(1);
  const int kwidth = weights_shape.dims(2);
  ComputeConvSizes(input_shape, output_depth, kwidth, kheight, op->stride_width,
                   op->stride_height, op->padding.type,
                   model->GetArray(output_name).mutable_shape(),
                   &op->padding.GetOrCreateFixedPadding());
}
+ if (!op->depth_multiplier) { + op->depth_multiplier = output_depth / input_depth; + } + QCHECK_EQ(output_depth, input_depth * op->depth_multiplier) + << "input/output depths and depth_multiplier don't match"; + + const int kheight = weights_shape.dims(1); + const int kwidth = weights_shape.dims(2); + ComputeConvSizes(input_shape, output_depth, kwidth, kheight, op->stride_width, + op->stride_height, op->padding.type, + model->GetArray(output_name).mutable_shape(), + &op->padding.GetOrCreateFixedPadding()); +} + +void ProcessDepthToSpaceOperator(Model* model, DepthToSpaceOperator* op) { + const auto& input_array = *model->arrays[op->inputs[0]]; + // Yield until input dims have been resolved. + if (!input_array.has_shape()) { + return; + } + const auto& input_shape = input_array.shape(); + CHECK_EQ(input_shape.dimensions_count(), 4); + + const string& output_name = op->outputs[0]; + const int block_size = op->block_size; + CHECK_NE(block_size, 0) << "Invalid block_size in " << output_name; + const int batch = input_shape.dims(0); + const int height = input_shape.dims(1); + const int width = input_shape.dims(2); + const int depth = input_shape.dims(3); + QCHECK_EQ(depth % (block_size * block_size), 0); + + model->GetArray(output_name) + .copy_shape(Shape({batch, height * block_size, width * block_size, + depth / block_size / block_size})); +} + +void ProcessSpaceToDepthOperator(Model* model, SpaceToDepthOperator* op) { + const auto& input_array = *model->arrays[op->inputs[0]]; + // Yield until input dims have been resolved. 
+ if (!input_array.has_shape()) { + return; + } + const auto& input_shape = input_array.shape(); + CHECK_EQ(input_shape.dimensions_count(), 4); + + const string& output_name = op->outputs[0]; + const int block_size = op->block_size; + CHECK_NE(block_size, 0) << "Invalid block_size in " << output_name; + const int batch = input_shape.dims(0); + const int height = input_shape.dims(1); + const int width = input_shape.dims(2); + const int depth = input_shape.dims(3); + QCHECK_EQ(width % block_size, 0); + QCHECK_EQ(height % block_size, 0); + + model->GetArray(output_name) + .copy_shape(Shape({batch, height / block_size, width / block_size, + depth * block_size * block_size})); +} + +void ProcessFullyConnectedOperator(Model* model, FullyConnectedOperator* op) { + if (!EnsureBiasVectorShape(model, op)) { + return; + } + + const auto& input_array = *model->arrays[op->inputs[0]]; + // Yield until input dims have been resolved. + if (!input_array.has_shape()) { + return; + } + const auto& input_shape = input_array.shape(); + CHECK_GE(input_shape.dimensions_count(), 1); + + const auto& weights_array = *model->arrays[op->inputs[1]]; + // Yield until weights dims have been resolved. 
+ if (!weights_array.has_shape()) { + return; + } + const auto& weights_shape = weights_array.shape(); + + const int weights_output_depth = weights_shape.dims(0); + CHECK_EQ(weights_shape.dimensions_count(), 2); + + const int input_overall_size = RequiredBufferSizeForShape(input_shape); + const int matmul_repeats = input_overall_size / weights_shape.dims(1); + CHECK_EQ(matmul_repeats * weights_shape.dims(1), input_overall_size); + + auto& output_array = model->GetArray(op->outputs[0]); + output_array.copy_shape(Shape({matmul_repeats, weights_output_depth})); +} + +void ProcessTensorFlowReshapeOperator(Model* model, + TensorFlowReshapeOperator* op) { + auto& output_array = *model->arrays[op->outputs[0]]; + // Bail if we already have output dims + if (output_array.has_shape()) { + return; + } + + const auto& input_array = *model->arrays[op->inputs[0]]; + // Yield until input dims have been resolved. + if (!input_array.has_shape()) { + return; + } + const auto& input_shape = input_array.shape(); + + const string& shape_name = op->inputs[1]; + auto& shape_array = model->GetArray(shape_name); + // Yield until the shape is resolved as a constant array + if (!shape_array.buffer) { + return; + } + CHECK(shape_array.data_type == ArrayDataType::kInt32); + // shape_data is the raw array of ints describing the shape + // in the TensorFlow node. We intentionally make a copy here, rather than + // modify wildcards in-place below, because in some graphs, the same shape + // array with a wildcard may be referenced from multiple Reshape nodes, where + // the wildcard needs to resolved to distinct values. + std::vector<int32> shape_data = + shape_array.GetBuffer<ArrayDataType::kInt32>().data; + // The Reshape shape may have a wildcard dim, encoded as -1. 
+ bool has_wildcard = false; + int wildcard_index = 0; + int product_non_wildcard_dims = 1; + for (int i = 0; i < shape_data.size(); i++) { + if (shape_data[i] == -1) { + CHECK(!has_wildcard); + has_wildcard = true; + wildcard_index = i; + } else { + product_non_wildcard_dims *= shape_data[i]; + } + } + const int input_flat_size = RequiredBufferSizeForShape(input_shape); + if (has_wildcard) { + shape_data[wildcard_index] = input_flat_size / product_non_wildcard_dims; + } + auto& output_shape = *output_array.mutable_shape(); + *output_shape.mutable_dims() = shape_data; + const int output_flat_size = RequiredBufferSizeForShape(output_shape); + CHECK_EQ(output_flat_size, input_flat_size); +} + +void ProcessSimpleOperator(Model* model, Operator* op) { + const auto& input_array = *model->arrays[op->inputs[0]]; + // Yield until input dims have been resolved. + if (!input_array.has_shape()) { + return; + } + + const string& output_name = op->outputs[0]; + auto& output_array = *model->arrays[output_name]; + if (output_array.has_shape()) { + return; + } + + output_array.copy_shape(input_array.shape()); +} + +void ProcessSimpleBinaryOperator(Model* model, Operator* op) { + CHECK_EQ(op->inputs.size(), 2); + const auto& input0_array = *model->arrays[op->inputs[0]]; + const auto& input1_array = *model->arrays[op->inputs[1]]; + // Yield until input dims have been resolved. + if (!input0_array.has_shape() || !input1_array.has_shape()) { + return; + } + const string& output_name = op->outputs[0]; + auto& output_array = *model->arrays[output_name]; + ComputeBinaryOperatorOutputSize(input0_array.shape(), input1_array.shape(), + &output_array); +} + +void ProcessTensorFlowReductionOperator(Model* model, Operator* op) { + CHECK_LE(op->inputs.size(), 2); + auto& output_array = *model->arrays[op->outputs[0]]; + if (output_array.has_shape()) { + return; + } + if (op->inputs.size() == 2) { + // There is a reduction_indices input. 
// Shape inference for Slice: output dim i is size[i], with a size of -1
// meaning "from begin[i] to the end of that input dimension".
void ProcessSliceOperator(Model* model, SliceOperator* op) {
  CHECK_EQ(op->inputs.size(), 3);
  CHECK_EQ(op->outputs.size(), 1);

  // Yield until the Slice params have been resolved.
  if (op->begin.empty()) return;

  // Yield until input dims have been resolved.
  const auto& input_array = *model->arrays[op->inputs[0]];
  if (!input_array.has_shape()) return;
  const Shape& input_shape = input_array.shape();

  auto& output_array = *model->arrays[op->outputs[0]];
  if (output_array.has_shape()) return;

  CHECK_EQ(input_shape.dims().size(), op->size.size());
  CHECK_EQ(op->begin.size(), op->size.size());

  std::vector<int> output_dims;
  for (int i = 0; i < op->begin.size(); ++i) {
    int size = op->size[i];
    if (size == -1) {
      // -1 size: take everything from begin[i] to the end of dimension i.
      size = input_array.shape().dims(i) - op->begin[i];
    }
    output_dims.push_back(size);
  }

  *output_array.mutable_shape()->mutable_dims() = output_dims;
}

// Shape inference for ReorderAxes: permutes the input dims from
// input_axes_order to output_axes_order via ShuffleDims.
void ProcessReorderAxesOperator(Model* model, ReorderAxesOperator* op) {
  const string& input_name = op->inputs[0];
  const auto& input_array = *model->arrays[input_name];
  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }
  const auto& input_shape = input_array.shape();
  const string& output_name = op->outputs[0];
  Shape* output_shape = model->GetArray(output_name).mutable_shape();
  ShuffleDims(input_shape, op->input_axes_order, op->output_axes_order,
              output_shape);
}

// Shape inference for Concatenation: the output copies the first input's
// shape, except along concat_dim where the input sizes are summed.
void ProcessConcatenationOperator(Model* model, ConcatenationOperator* op) {
  // Yield until input dims have been resolved.
  for (const auto& input_name : op->inputs) {
    auto& input_array = *model->arrays[input_name];
    if (!input_array.has_shape()) {
      return;
    }
  }
  auto& output_array = model->GetArray(op->outputs[0]);
  // Use 0 input as basis for output dimensions.
  const auto& first_input_array = *model->arrays[op->inputs[0]];
  output_array.copy_shape(first_input_array.shape());
  // Determine the concat size, and enforce that all inputs have
  // the same dimensions count. Zero-dimensional (scalar) inputs are skipped
  // and contribute nothing to the concat size.
  int concat_size = 0;
  for (const auto& input_name : op->inputs) {
    auto& input_array = *model->arrays[input_name];
    CHECK(input_array.has_shape());
    if (input_array.shape().dimensions_count() == 0) {
      continue;
    }
    CHECK_EQ(input_array.shape().dimensions_count(),
             output_array.shape().dimensions_count());
    const std::vector<int>& input_dims = input_array.shape().dims();
    CHECK_LT(op->concat_dim, input_dims.size());
    concat_size += input_dims[op->concat_dim];
  }
  // Write out the concat_size on the output array shape.
  auto& output_shape = *output_array.mutable_shape();
  auto& output_dims = *output_shape.mutable_dims();
  CHECK_LT(op->concat_dim, output_shape.dimensions_count());
  output_dims[op->concat_dim] = concat_size;
}
+ int concat_size = 0; + for (const auto& input_name : op->inputs) { + auto& input_array = *model->arrays[input_name]; + CHECK(input_array.has_shape()); + if (input_array.shape().dimensions_count() == 0) { + continue; + } + CHECK_EQ(input_array.shape().dimensions_count(), + output_array.shape().dimensions_count()); + const std::vector<int>& input_dims = input_array.shape().dims(); + CHECK_LT(op->concat_dim, input_dims.size()); + concat_size += input_dims[op->concat_dim]; + } + // Write out the concat_size on the output array shape. + auto& output_shape = *output_array.mutable_shape(); + auto& output_dims = *output_shape.mutable_dims(); + CHECK_LT(op->concat_dim, output_shape.dimensions_count()); + output_dims[op->concat_dim] = concat_size; +} + +void ProcessTensorFlowSplitOperator(Model* model, TensorFlowSplitOperator* op) { + CHECK_EQ(op->inputs.size(), 2); + const string& input_name = op->inputs[1]; + const auto& input_array = *model->arrays[input_name]; + // Yield until input dims have been resolved. + if (!input_array.has_shape()) { + return; + } + const Shape& input_shape = input_array.shape(); + + // This code is slightly suspect. The TensorFlow docs say that the axis + // selection defaults to 0, but we are splitting across the final axis. + const int input_dims_count = input_shape.dimensions_count(); + const int input_depth = input_shape.dims(input_dims_count - 1); + CHECK_EQ(input_depth % op->num_split, 0); + const int split_depth = input_depth / op->num_split; + + Shape output_shape = input_shape; + (*output_shape.mutable_dims())[input_dims_count - 1] = split_depth; + + CHECK_EQ(op->outputs.size(), op->num_split); + for (const auto& output : op->outputs) { + model->arrays[output]->copy_shape(output_shape); + } +} + +void ProcessAveragePoolOperator(Model* model, AveragePoolOperator* op) { + const string& input_name = op->inputs[0]; + const auto& input_array = *model->arrays[input_name]; + // Yield until input dims have been resolved. 
+ if (!input_array.has_shape()) { + return; + } + const auto& input_shape = input_array.shape(); + CHECK_EQ(input_shape.dimensions_count(), 4); + const string& output_name = op->outputs[0]; + const int output_depth = input_shape.dims(3); + ComputeConvSizes(input_shape, output_depth, op->kwidth, op->kheight, + op->stride_width, op->stride_height, op->padding.type, + model->GetArray(output_name).mutable_shape(), + &op->padding.GetOrCreateFixedPadding()); +} + +void ProcessMaxPoolOperator(Model* model, MaxPoolOperator* op) { + const string& input_name = op->inputs[0]; + const auto& input_array = *model->arrays[input_name]; + // Yield until input dims have been resolved. + if (!input_array.has_shape()) { + return; + } + const auto& input_shape = input_array.shape(); + CHECK_EQ(input_shape.dimensions_count(), 4); + const string& output_name = op->outputs[0]; + const int output_depth = input_shape.dims(3); + ComputeConvSizes(input_shape, output_depth, op->kwidth, op->kheight, + op->stride_width, op->stride_height, op->padding.type, + model->GetArray(output_name).mutable_shape(), + &op->padding.GetOrCreateFixedPadding()); +} + +void ProcessL2PoolOperator(Model* model, L2PoolOperator* op) { + const string& input_name = op->inputs[0]; + const auto& input_array = *model->arrays[input_name]; + // Yield until input dims have been resolved. 
// Shape inference for ResizeBilinear: the output keeps the input's batch and
// depth, and takes its spatial dims from the constant int32 size input
// (a 1-D array of two elements: {new_height, new_width}).
void ProcessResizeBilinearOperator(Model* model, ResizeBilinearOperator* op) {
  CHECK_EQ(op->inputs.size(), 2);
  CHECK_EQ(op->outputs.size(), 1);

  // Yield until both the data input and the size input have shapes.
  if (!model->arrays[op->inputs[0]]->has_shape() ||
      !model->arrays[op->inputs[1]]->has_shape()) {
    return;
  }
  const auto& input_data_shape = model->arrays[op->inputs[0]]->shape();

  const string& output_size_name = op->inputs[1];
  const auto& output_size_array = *model->arrays[output_size_name];
  CHECK(output_size_array.data_type == ArrayDataType::kInt32);
  CHECK(output_size_array.has_shape());
  const auto& output_size_shape = output_size_array.shape();
  CHECK_EQ(output_size_shape.dimensions_count(), 1);
  CHECK_EQ(output_size_shape.dims(0), 2);
  // NOTE: despite the name, this holds the {height, width} size values, not
  // a Shape.
  std::vector<int32> output_shape =
      output_size_array.GetBuffer<ArrayDataType::kInt32>().data;
  model->arrays[op->outputs[0]]->copy_shape(
      Shape({input_data_shape.dims(0), output_shape[0], output_shape[1],
             input_data_shape.dims(3)}));
}

// Shape inference for the fused LstmCell op: derives the cell depth from the
// fused weights and sizes the two outputs plus the two temporary buffers.
void ProcessLstmCellOperator(Model* model, LstmCellOperator* op) {
  // I/O arrays should be allocated on creation of op.
  QCHECK_EQ(op->inputs.size(), LstmCellOperator::NUM_INPUTS);
  QCHECK_EQ(op->outputs.size(), LstmCellOperator::NUM_OUTPUTS);

  const auto& input_array =
      *model->arrays[op->inputs[LstmCellOperator::DATA_INPUT]];
  // Yield until all input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }
  const auto& input_shape = input_array.shape();
  CHECK_GE(input_shape.dimensions_count(), 2);

  const auto& prev_activ_array =
      *model->arrays[op->inputs[LstmCellOperator::PREV_ACTIV_INPUT]];
  // Yield until all input dims have been resolved.
  if (!prev_activ_array.has_shape()) {
    return;
  }
  const auto& prev_activ_shape = prev_activ_array.shape();
  CHECK_GE(prev_activ_shape.dimensions_count(), 2);

  const auto& weights_array =
      *model->arrays[op->inputs[LstmCellOperator::WEIGHTS_INPUT]];
  // Yield until weights dims have been resolved.
  if (!weights_array.has_shape()) {
    return;
  }
  const auto& weights_shape = weights_array.shape();
  CHECK_EQ(weights_shape.dimensions_count(), 2);

  const auto& bias_array =
      *model->arrays[op->inputs[LstmCellOperator::BIASES_INPUT]];
  // Yield until bias dims have been resolved.
  if (!bias_array.has_shape()) {
    return;
  }
  const auto& bias_shape = bias_array.shape();
  CHECK_GE(bias_shape.dimensions_count(), 1);

  const auto& prev_state_array =
      *model->arrays[op->inputs[LstmCellOperator::PREV_STATE_INPUT]];
  // Yield until all input dims have been resolved.
  if (!prev_state_array.has_shape()) {
    return;
  }
  const auto& prev_state_shape = prev_state_array.shape();
  CHECK_GE(prev_state_shape.dimensions_count(), 2);

  // The fused FC output depth must be a multiple of 4; the cell depth is a
  // quarter of it (presumably one slice per LSTM gate — consistent with the
  // % 4 check, but not otherwise visible here).
  const int fc_output_depth = weights_shape.dims(0);
  CHECK_EQ(fc_output_depth, bias_shape.dims(0));
  CHECK_EQ(fc_output_depth % 4, 0);
  const int depth = fc_output_depth / 4;

  // The FC input depth must equal input depth + cell depth, i.e. the
  // concatenation of the data input and the previous activations.
  const int input_depth = input_shape.dims(input_shape.dimensions_count() - 1);
  const int fc_input_depth = weights_shape.dims(1);
  CHECK_EQ(input_depth + depth, fc_input_depth);
  Shape output_shape(input_shape);
  (*output_shape.mutable_dims())[output_shape.dimensions_count() - 1] = depth;

  // Set output dimensions
  model->GetArray(op->outputs[LstmCellOperator::STATE_OUTPUT])
      .copy_shape(output_shape);
  model->GetArray(op->outputs[LstmCellOperator::ACTIV_OUTPUT])
      .copy_shape(output_shape);

  // Temporary buffers: concat_temp holds the concatenated FC input,
  // activ_temp the raw FC output; both share the input's leading dims.
  Shape concat_temp_shape(input_shape);
  (*concat_temp_shape
        .mutable_dims())[concat_temp_shape.dimensions_count() - 1] =
      fc_input_depth;
  model->GetArray(op->outputs[LstmCellOperator::CONCAT_TEMP])
      .copy_shape(concat_temp_shape);

  Shape activ_temp_shape(input_shape);
  (*activ_temp_shape.mutable_dims())[activ_temp_shape.dimensions_count() - 1] =
      fc_output_depth;
  model->GetArray(op->outputs[LstmCellOperator::ACTIV_TEMP])
      .copy_shape(activ_temp_shape);
}
+ if (!input_array.has_shape()) { + return; + } + const auto& input_shape = input_array.shape(); + CHECK_EQ(input_shape.dimensions_count(), 4); + const auto input_height = input_shape.dims(1); + const auto input_width = input_shape.dims(2); + + const auto& block_shape_array = *model->arrays[op->inputs[1]]; + const auto& paddings_array = *model->arrays[op->inputs[2]]; + const auto& block_shape_array_shape = block_shape_array.shape(); + const auto& paddings_array_shape = paddings_array.shape(); + QCHECK_EQ(block_shape_array_shape.dimensions_count(), 1); + QCHECK_EQ(paddings_array_shape.dimensions_count(), 2); + + // We only support two dimensions. + QCHECK_EQ(block_shape_array_shape.dims(0), 2); + if (!block_shape_array.buffer) { + return; + } + QCHECK(block_shape_array.data_type == ArrayDataType::kInt32); + const auto& block_shape_data = + block_shape_array.GetBuffer<ArrayDataType::kInt32>().data; + auto block_height = block_shape_data[0]; + auto block_width = block_shape_data[1]; + + QCHECK_EQ(paddings_array_shape.dims(0), 2); // Number of block dimensions + QCHECK_EQ(paddings_array_shape.dims(1), 2); // Two parameters per dimension. 
+ if (!paddings_array.buffer) { + return; + } + QCHECK(paddings_array.data_type == ArrayDataType::kInt32); + const auto& paddings_data = + paddings_array.GetBuffer<ArrayDataType::kInt32>().data; + int height_with_paddings = input_height + paddings_data[0] + paddings_data[1]; + int width_with_paddings = input_width + paddings_data[2] + paddings_data[3]; + QCHECK_EQ(height_with_paddings % block_height, 0); + QCHECK_EQ(width_with_paddings % block_width, 0); + int output_height = height_with_paddings / block_height; + int output_width = width_with_paddings / block_width; + + model->arrays[op->outputs[0]]->copy_shape( + Shape({input_shape.dims(0) * block_height * block_width, output_height, + output_width, input_shape.dims(3)})); +} + +void ProcessBatchToSpaceNDOperator(Model* model, BatchToSpaceNDOperator* op) { + const auto& input_array = *model->arrays[op->inputs[0]]; + // Yield until input dims have been resolved. + if (!input_array.has_shape()) { + return; + } + const auto& input_shape = input_array.shape(); + CHECK_EQ(input_shape.dimensions_count(), 4); + const auto input_height = input_shape.dims(1); + const auto input_width = input_shape.dims(2); + + const auto& block_shape_array = *model->arrays[op->inputs[1]]; + const auto& crops_array = *model->arrays[op->inputs[2]]; + const auto& block_shape_array_shape = block_shape_array.shape(); + const auto& crops_array_shape = crops_array.shape(); + QCHECK_EQ(block_shape_array_shape.dimensions_count(), 1); + QCHECK_EQ(crops_array_shape.dimensions_count(), 2); + + // We only support two dimensions. 
+ QCHECK_EQ(block_shape_array_shape.dims(0), 2); + if (!block_shape_array.buffer) { + return; + } + QCHECK(block_shape_array.data_type == ArrayDataType::kInt32); + const auto& block_shape_data = + block_shape_array.GetBuffer<ArrayDataType::kInt32>().data; + auto block_height = block_shape_data[0]; + auto block_width = block_shape_data[1]; + + QCHECK_EQ(crops_array_shape.dims(0), 2); // Number of block dimensions + QCHECK_EQ(crops_array_shape.dims(1), 2); // Two parameters per dimension. + if (!crops_array.buffer) { + return; + } + QCHECK(crops_array.data_type == ArrayDataType::kInt32); + const auto& crops_data = crops_array.GetBuffer<ArrayDataType::kInt32>().data; + // We don't support crops now. + QCHECK_EQ(crops_data[0], 0); + QCHECK_EQ(crops_data[1], 0); + QCHECK_EQ(crops_data[2], 0); + QCHECK_EQ(crops_data[3], 0); + + QCHECK_EQ(input_shape.dims(0) % (block_height * block_width), 0); + + int output_height = input_height * block_height; + int output_width = input_width * block_width; + + model->arrays[op->outputs[0]]->copy_shape( + Shape({input_shape.dims(0) / (block_height * block_width), output_height, + output_width, input_shape.dims(3)})); +} + +void ProcessGatherOperator(Model* model, GatherOperator* op) { + const auto& input_array = *model->arrays[op->inputs[0]]; + const auto& indices_array = *model->arrays[op->inputs[1]]; + auto& output_array = *model->arrays[op->outputs[0]]; + + // Bail if we already know the output shape. + if (output_array.has_shape()) { + return; + } + + // Yield until input dims have been resolved. + if (!input_array.has_shape() || !indices_array.has_shape()) { + return; + } + + const auto& input_shape = input_array.shape(); + const auto& indices_shape = indices_array.shape(); + QCHECK_GE(input_shape.dimensions_count(), 1); + op->input_rank = input_shape.dimensions_count(); + + // We only support 1-D indices. 
+ QCHECK_EQ(indices_shape.dimensions_count(), 1); + + // Copy the input dimensions to the output except for dimension 0, + // where the dimension of indices_shape is used. + auto output_dims = output_array.mutable_shape()->mutable_dims(); + output_dims->push_back(indices_shape.dims(0)); + for (int dim = 1; dim < input_shape.dimensions_count(); dim++) { + output_dims->push_back(input_shape.dims(dim)); + } +} + +void ProcessPadOperator(Model* model, PadOperator* op) { + CHECK_EQ(op->inputs.size(), 2); + CHECK_EQ(op->outputs.size(), 1); + + const auto& input_array = *model->arrays[op->inputs[0]]; + + // Yield until input dims have been resolved. + if (!input_array.has_shape()) return; + + if (op->left_padding.empty()) return; + CHECK_EQ(op->left_padding.size(), op->right_padding.size()); + + auto& output_array = *model->arrays[op->outputs[0]]; + if (output_array.has_shape()) return; + + Shape output_shape = input_array.shape(); + std::vector<int>& dims = *output_shape.mutable_dims(); + CHECK_EQ(op->left_padding.size(), dims.size()); + + for (int i = 0; i < op->left_padding.size(); ++i) { + dims[i] += op->left_padding[i] + op->right_padding[i]; + } + + output_array.copy_shape(output_shape); +} + +void ProcessMeanOperator(Model* model, MeanOperator* op) { + CHECK_EQ(op->inputs.size(), 2); + CHECK_EQ(op->outputs.size(), 1); + + const auto& input_array = *model->arrays[op->inputs[0]]; + + // Yield until input dims have been resolved. 
+ if (!input_array.has_shape()) return; + const std::vector<int>& indices = op->reduction_indices; + if (indices.empty()) return; + + auto& output_array = *model->arrays[op->outputs[0]]; + if (output_array.has_shape()) return; + + const std::vector<int>& input_dims = input_array.shape().dims(); + std::vector<int> output_dims; + for (int i = 0; i < input_dims.size(); ++i) { + if (std::find(indices.begin(), indices.end(), i) == indices.end()) { + output_dims.push_back(input_dims[i]); + } + } + CHECK(!output_dims.empty()); + CHECK_EQ(output_dims.size(), 2); + + *output_array.mutable_shape()->mutable_dims() = output_dims; +} + +void ProcessStridedSliceOperator(Model* model, StridedSliceOperator* op) { + CHECK_EQ(op->inputs.size(), 4); + CHECK_EQ(op->outputs.size(), 1); + + const auto& input_array = *model->arrays[op->inputs[0]]; + + // Yield until input dims have been resolved. + if (!input_array.has_shape()) return; + + if (op->start_indices.empty()) return; + CHECK_EQ(op->start_indices.size(), op->stop_indices.size()); + CHECK_EQ(op->start_indices.size(), op->strides.size()); + + auto& output_array = *model->arrays[op->outputs[0]]; + if (output_array.has_shape()) return; + + Shape output_shape = input_array.shape(); + std::vector<int>& dims = *output_shape.mutable_dims(); + CHECK_EQ(op->start_indices.size(), dims.size()); + + for (int i = 0; i < op->start_indices.size(); ++i) { + const int mask = 1 << i; + const int start = (op->begin_mask & mask) ? 0 : op->start_indices[i]; + const int stop = (op->end_mask & mask) ? input_array.shape().dims()[i] + : op->stop_indices[i]; + dims[i] = (stop - start) / op->strides[i]; + } + + output_array.copy_shape(output_shape); +} + +void ProcessSqueezeOperator(Model* model, SqueezeOperator* op) { + CHECK_EQ(op->inputs.size(), 1); + CHECK_EQ(op->outputs.size(), 1); + + const auto& input_array = *model->arrays[op->inputs[0]]; + + // Yield until input dims have been resolved. 
+ if (!input_array.has_shape()) return; + + auto& output_array = *model->arrays[op->outputs[0]]; + if (output_array.has_shape()) return; + + const std::vector<int>& input_dims = input_array.shape().dims(); + std::vector<int> output_dims; + + for (int i = 0; i < input_dims.size(); ++i) { + if (input_dims[i] != 1 || + (!op->squeeze_dims.empty() && + std::find(op->squeeze_dims.begin(), op->squeeze_dims.end(), i) == + op->squeeze_dims.end())) { + output_dims.push_back(input_dims[i]); + } + } + *output_array.mutable_shape()->mutable_dims() = output_dims; +} + +void ProcessSvdfOperator(Model* model, SvdfOperator* op) { + CHECK(op->inputs.size() == 3 || op->inputs.size() == 4); + const auto& input_array = *model->arrays[op->inputs[0]]; + if (!input_array.has_shape()) return; + + auto& weights_feature_array = *model->arrays[op->inputs[1]]; + if (!weights_feature_array.has_shape()) return; + + const auto& weights_time_array = *model->arrays[op->inputs[2]]; + if (!weights_time_array.has_shape()) return; + + const bool has_bias = (op->inputs.size() == 4); + if (has_bias) { + const auto& bias_array = *model->arrays[op->inputs[3]]; + if (!bias_array.has_shape()) return; + } + + const int batch_size = input_array.shape().dims()[0]; + const int num_units = weights_feature_array.shape().dims()[0]; + const int memory_size = weights_time_array.shape().dims()[1]; + + auto& state_array = model->GetArray(op->outputs[0]); + state_array.mutable_shape()->ReplaceDims( + {batch_size, memory_size * num_units}); + + auto& output_array = model->GetArray(op->outputs[1]); + output_array.mutable_shape()->ReplaceDims({batch_size, num_units}); +} +} // namespace + +bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { + auto it = model->operators.begin() + op_index; + auto* op = it->get(); + std::unordered_map<string, std::vector<int>> old_output_dims; + for (const auto& output : op->outputs) { + if (model->arrays[output]->has_shape()) { + old_output_dims[output] = 
model->arrays[output]->shape().dims(); + } + } + + switch (op->type) { + case OperatorType::kBatchNormalization: + case OperatorType::kL2Normalization: + case OperatorType::kDequantize: + case OperatorType::kRelu: + case OperatorType::kRelu1: + case OperatorType::kRelu6: + case OperatorType::kSoftmax: + case OperatorType::kLogistic: + case OperatorType::kTanh: + case OperatorType::kLocalResponseNormalization: + case OperatorType::kTensorFlowIdentity: + case OperatorType::kFakeQuant: + case OperatorType::kTensorFlowRsqrt: + case OperatorType::kTensorFlowSqrt: + case OperatorType::kTensorFlowSquare: + case OperatorType::kTensorFlowAll: + case OperatorType::kTensorFlowAssert: + case OperatorType::kCast: + case OperatorType::kFloor: + ProcessSimpleOperator(model, op); + break; + case OperatorType::kGather: + ProcessGatherOperator(model, static_cast<GatherOperator*>(op)); + break; + + case OperatorType::kAdd: + case OperatorType::kSub: + case OperatorType::kMul: + case OperatorType::kDiv: + case OperatorType::kTensorFlowLess: + case OperatorType::kTensorFlowLessEqual: + case OperatorType::kTensorFlowGreater: + case OperatorType::kTensorFlowMaximum: + case OperatorType::kTensorFlowMinimum: + case OperatorType::kTensorFlowGreaterEqual: + ProcessSimpleBinaryOperator(model, op); + break; + case OperatorType::kConv: + ProcessConvOperator(model, static_cast<ConvOperator*>(op)); + break; + case OperatorType::kDepthwiseConv: + ProcessDepthwiseConvOperator(model, + static_cast<DepthwiseConvOperator*>(op)); + break; + case OperatorType::kDepthToSpace: + ProcessDepthToSpaceOperator(model, + static_cast<DepthToSpaceOperator*>(op)); + break; + case OperatorType::kSpaceToDepth: + ProcessSpaceToDepthOperator(model, + static_cast<SpaceToDepthOperator*>(op)); + break; + case OperatorType::kFullyConnected: + ProcessFullyConnectedOperator(model, + static_cast<FullyConnectedOperator*>(op)); + break; + case OperatorType::kTensorFlowReshape: + ProcessTensorFlowReshapeOperator( + model, 
static_cast<TensorFlowReshapeOperator*>(op)); + break; + case OperatorType::kAveragePool: + ProcessAveragePoolOperator(model, static_cast<AveragePoolOperator*>(op)); + break; + case OperatorType::kMaxPool: + ProcessMaxPoolOperator(model, static_cast<MaxPoolOperator*>(op)); + break; + case OperatorType::kL2Pool: + ProcessL2PoolOperator(model, static_cast<L2PoolOperator*>(op)); + break; + case OperatorType::kTensorFlowMin: + case OperatorType::kTensorFlowMax: + case OperatorType::kTensorFlowSum: + ProcessTensorFlowReductionOperator(model, op); + break; + + case OperatorType::kSlice: + ProcessSliceOperator(model, static_cast<SliceOperator*>(op)); + break; + + case OperatorType::kTensorFlowTile: + // We don't currently implement the propagation of fixed sizes through + // a TensorFlow Tile. + // + // Fortunately, we don't need to: so far, we have only dealt with Tile + // or Slice ops in subgraphs that are identified as L2Normalization. + // See IdentifyL2Normalization. + break; + case OperatorType::kTensorFlowSwitch: + // We can't know the sizes of the outputs until we have resolved the + // predicate, and once we have resolved the predicate, the whole + // Switch node will get resolved away. + // See ResolveTensorFlowSwitch. + break; + case OperatorType::kTensorFlowMerge: + // No need to bother resolving TensorFlow Merge ops: other graph + // transformations will remove them anyway. + // See ResolveTensorFlowMerge. + break; + case OperatorType::kTensorFlowSplit: + ProcessTensorFlowSplitOperator(model, + static_cast<TensorFlowSplitOperator*>(op)); + break; + case OperatorType::kSqueeze: + ProcessSqueezeOperator(model, static_cast<SqueezeOperator*>(op)); + break; + case OperatorType::kTensorFlowConcat: + case OperatorType::kTensorFlowConcatV2: + // Unimplemented, hopefully another graph transformation will + // drop it or rewrite it. 
Concretely, either ResolveTensorFlowConcat + // will resolve this node to a DepthConcatenation, or else we have + // a more general non-depth concatenation that will hopefully be dropped, + // or else at the moment we will abort. + break; + case OperatorType::kTensorFlowShape: + // Unimplemented, hopefully another graph transformation will drop it or + // rewrite it. + break; + case OperatorType::kReorderAxes: + ProcessReorderAxesOperator(model, static_cast<ReorderAxesOperator*>(op)); + break; + case OperatorType::kConcatenation: + ProcessConcatenationOperator(model, + static_cast<ConcatenationOperator*>(op)); + break; + case OperatorType::kResizeBilinear: + ProcessResizeBilinearOperator(model, + static_cast<ResizeBilinearOperator*>(op)); + break; + case OperatorType::kLstmCell: + ProcessLstmCellOperator(model, static_cast<LstmCellOperator*>(op)); + break; + case OperatorType::kTensorFlowMatMul: + // MatMul operators are converted to FullyConnected, after which their + // shapes are propagated. + break; + case OperatorType::kSpaceToBatchND: + ProcessSpaceToBatchNDOperator(model, + static_cast<SpaceToBatchNDOperator*>(op)); + break; + case OperatorType::kBatchToSpaceND: + ProcessBatchToSpaceNDOperator(model, + static_cast<BatchToSpaceNDOperator*>(op)); + break; + case OperatorType::kPad: + ProcessPadOperator(model, static_cast<PadOperator*>(op)); + break; + case OperatorType::kMean: + ProcessMeanOperator(model, static_cast<MeanOperator*>(op)); + break; + case OperatorType::kStridedSlice: + ProcessStridedSliceOperator(model, + static_cast<StridedSliceOperator*>(op)); + break; + case OperatorType::kTensorFlowUnsupported: + break; + case OperatorType::kSvdf: + ProcessSvdfOperator(model, static_cast<SvdfOperator*>(op)); + break; + default: + // Unimplemented, another graph transformation should drop it. + LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type); + } + + // Return true if any output dim changed, false if none changed. 
+ // Assumption: no transformation clears an output shape, they only add shapes. + for (const auto& output : op->outputs) { + if (model->arrays[output]->has_shape() && + (old_output_dims[output] != model->arrays[output]->shape().dims())) { + return true; + } + } + return false; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc new file mode 100644 index 0000000000..5551755ea7 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc @@ -0,0 +1,467 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include <algorithm> +#include <cmath> +#include <limits> +#include <memory> +#include <string> +#include <unordered_map> +#include <vector> + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/model_flags.pb.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +bool SupportsQuantization(const Operator& op) { + auto type = op.type; + if (type == OperatorType::kTensorFlowUnsupported) { + auto* unsupported = static_cast<const TensorFlowUnsupportedOperator*>(&op); + return unsupported->quantized; + } + return type == OperatorType::kConv || type == OperatorType::kDepthwiseConv || + type == OperatorType::kFullyConnected || + type == OperatorType::kConcatenation || + type == OperatorType::kL2Normalization || type == OperatorType::kAdd || + type == OperatorType::kAveragePool || type == OperatorType::kMaxPool || + type == OperatorType::kLogistic || type == OperatorType::kSoftmax || + type == OperatorType::kTensorFlowReshape || + type == OperatorType::kMul || type == OperatorType::kSpaceToDepth || + type == OperatorType::kDepthToSpace; +} + +template <ArrayDataType A> +std::unique_ptr<GenericBuffer> QuantizeBuffer( + const GenericBuffer& buffer, + const QuantizationParams& quantization_params) { + const auto inverse_scale = 1. 
/ quantization_params.scale; + CHECK(buffer.type == ArrayDataType::kFloat); + const auto& float_buffer = + static_cast<const Buffer<ArrayDataType::kFloat>&>(buffer); + auto* quantized_buffer = new Buffer<A>; + quantized_buffer->data.resize(float_buffer.data.size()); + const auto qmin = static_cast<int32>(std::numeric_limits<DataType<A>>::min()); + const auto qmax = static_cast<int32>(std::numeric_limits<DataType<A>>::max()); + for (std::size_t i = 0; i < float_buffer.data.size(); i++) { + const float src_val = float_buffer.data[i]; + double scaled_val; // Astonishingly, using 'float' degrades accuracy just + // enough to make a few tests fail! + if (quantization_params.scale == 0) { + CHECK_EQ(src_val, 0) << "The quantization scale for this array is 0, " + << "so all its values should be 0."; + scaled_val = quantization_params.zero_point; + } else { + scaled_val = quantization_params.zero_point + inverse_scale * src_val; + } + const auto rounded_val = static_cast<int32>(std::round(scaled_val)); + const auto clamped_val = std::min(qmax, std::max(qmin, rounded_val)); + quantized_buffer->data[i] = static_cast<DataType<A>>(clamped_val); + } + return std::unique_ptr<GenericBuffer>(quantized_buffer); +} + +template <ArrayDataType A> +void QuantizeArray(GraphTransformation* transformation, Model* model, + const string& name, + const QuantizationParams& quantization_params) { + auto& array = model->GetArray(name); + CHECK(array.data_type == ArrayDataType::kFloat); + CHECK(!array.quantization_params); + array.GetOrCreateQuantizationParams() = quantization_params; + if (array.buffer) { + array.buffer = QuantizeBuffer<A>(*array.buffer, quantization_params); + } + array.data_type = A; + transformation->AddMessageF("Quantized array %s", name); +} + +void QuantizeArray(GraphTransformation* transformation, Model* model, + const string& name, ArrayDataType quantized_data_type, + const QuantizationParams& quantization_params) { + switch (quantized_data_type) { + case 
ArrayDataType::kUint8: + return QuantizeArray<ArrayDataType::kUint8>(transformation, model, name, + quantization_params); + case ArrayDataType::kInt32: + return QuantizeArray<ArrayDataType::kInt32>(transformation, model, name, + quantization_params); + default: + LOG(FATAL) << "Unhandled case."; + } +} + +const MinMax& GetOrComputeMinMax(Model* model, const string& array_name) { + auto& array = model->GetArray(array_name); + // Normally we should have a MinMax recorded on this Array, + // so we just use it. + if (array.minmax != nullptr) { + return *array.minmax; + } + + // We don't have a MinMax. That's bad news: we need + // the graph to provide MinMax info for all arrays in order + // for inference to reproduce faithfully the same quantization + // error as the training process had. + // + // But we still want to support a fallback for constant arrays, + // just using the plain min and max computed from array elements. + // We should hopefully never rely on that in production, as that + // will not give very good accuracy as that typically won't be + // exactly what the training process used. But it will be useful + // to allow easily trying out quantization even if the graph + // lacks some minmax information. + if (array.buffer != nullptr) { + LOG(WARNING) + << "Constant array " << array_name + << " lacks MinMax information. To make up for that, we will now compute" + << " the MinMax from actual array elements. That will result in" + << " quantization parameters that probably do not match whichever " + "arithmetic" + << " was used during training, and thus will probably be a cause of " + "poor" + << " inference accuracy."; + CHECK(array.buffer->type == ArrayDataType::kFloat); + const auto& data = array.GetBuffer<ArrayDataType::kFloat>().data; + // We always want [min, max] to contain 0. 
+ float min = 0.f; + float max = 0.f; + for (auto val : data) { + min = std::min(min, val); + max = std::max(max, val); + } + auto& minmax = array.GetOrCreateMinMax(); + minmax.min = min; + minmax.max = max; + return minmax; + } + + LOG(FATAL) << "Array " << array_name + << " does not have MinMax information, " + "and is not a constant array. Cannot " + "proceed with quantization."; +} + +bool ChooseQuantizationForOperatorInput( + GraphTransformation* transformation, Model* model, const Operator& op, + std::size_t input_index, ArrayDataType* quantized_data_type, + QuantizationParams* quantization_params) { + const auto& input = op.inputs[input_index]; + auto& array = model->GetArray(input); + if (array.data_type != ArrayDataType::kFloat) { + return false; + } + if (op.type == OperatorType::kConv || + op.type == OperatorType::kDepthwiseConv || + op.type == OperatorType::kFullyConnected) { + if (input_index == 2) { + // Quantization of bias vector. + // We need both of the mandatory inputs (input activations and weights) to + // have + // been already quantized. + const auto& input_activations = model->GetArray(op.inputs[0]); + const auto& input_weights = model->GetArray(op.inputs[1]); + if (!input_activations.quantization_params || + !input_weights.quantization_params) { + return false; + } + const auto input_activations_scale = + input_activations.quantization_params->scale; + const auto input_weights_scale = input_weights.quantization_params->scale; + quantization_params->scale = + input_activations_scale * input_weights_scale; + quantization_params->zero_point = 0; + *quantized_data_type = ArrayDataType::kInt32; + transformation->AddMessageF( + "Input array %s is a bias vector. 
Choosing quantization params " + "accordingly.", + input); + return true; + } + } + + const MinMax& minmax = GetOrComputeMinMax(model, input); + GetQuantizationParamsFromMinMax<ArrayDataType::kUint8>(model->flags, minmax, + quantization_params); + transformation->AddMessageF( + "For input array %s with min=%g" + ", max=%g" + ", chose to quantize as uint8 with zero_point=%d" + ", scale=%g", + input, minmax.min, minmax.max, quantization_params->zero_point, + quantization_params->scale); + *quantized_data_type = ArrayDataType::kUint8; + return true; +} + +bool IsExactlyRepresentable(double real_value, ArrayDataType data_type, + const QuantizationParams& quantization_params) { + const double scaled_value = + quantization_params.zero_point + real_value / quantization_params.scale; + const double fractional_scaled_value = + scaled_value - std::round(scaled_value); + if (std::abs(fractional_scaled_value) > 1e-12) { + return false; + } + const double rounded_scaled_value = std::round(scaled_value); + if (data_type == ArrayDataType::kUint8) { + if (rounded_scaled_value < 0 || rounded_scaled_value > 255) { + return false; + } + } + return true; +} + +bool ChooseHardcodedQuantizationForOperatorOutput( + const Operator& op, ArrayDataType* quantized_data_type, + QuantizationParams* quantization_params) { + if (op.type == OperatorType::kL2Normalization) { + // L2Normalization has range: [-1, 1]. + // 0 should be exactly representable, as values will typically be centered + // around 0, with many values near 0. + *quantized_data_type = ArrayDataType::kUint8; + quantization_params->zero_point = 128; + quantization_params->scale = 1. / 128.; + CHECK( + IsExactlyRepresentable(0., *quantized_data_type, *quantization_params)); + return true; + } + if ((op.type == OperatorType::kLogistic) || + (op.type == OperatorType::kSoftmax)) { + // Logistic and Softmax have range: [0, 1]. 
+ // + // For Logistic, 0.5 should be exactly representable, as implementations + // will typically exploit the symmetry logistic(-x) = 1 - logistic(x), and + // the glueing of the two halves of the graph will only be seamless if we + // are accurately representing logistic(0) == 0.5. + *quantized_data_type = ArrayDataType::kUint8; + quantization_params->zero_point = 0; + quantization_params->scale = 1. / 256.; + CHECK(IsExactlyRepresentable(0.5, *quantized_data_type, + *quantization_params)); + return true; + } + return false; +} + +bool ChooseQuantizationForOperatorOutput( + GraphTransformation* transformation, Model* model, const Operator& op, + std::size_t output_index, ArrayDataType* quantized_data_type, + QuantizationParams* quantization_params) { + const auto& output = op.outputs[output_index]; + auto& array = model->GetArray(output); + if (array.data_type != ArrayDataType::kFloat) { + return false; + } + if (ChooseHardcodedQuantizationForOperatorOutput(op, quantized_data_type, + quantization_params)) { + transformation->AddMessageF( + "Output array %s is produced by a %s operator. Choosing fixed " + "quantization params accordingly.", + output, OperatorTypeName(op.type)); + return true; + } + if ((op.type == OperatorType::kDepthToSpace) || + (op.type == OperatorType::kSpaceToDepth)) { + // DepthToSpace and SpaceToDepth should preserve the quantization parameters + // of the input array, as these are simple reshape operations. + const auto& input_quantization_params = + model->GetArray(op.inputs[0]).GetQuantizationParams(); + *quantized_data_type = ArrayDataType::kUint8; + quantization_params->zero_point = input_quantization_params.zero_point; + quantization_params->scale = input_quantization_params.scale; + + transformation->AddMessageF( + "Output array %s is produced by a %s operator. 
Copying quantization " + "params from input array.", + output, OperatorTypeName(op.type)); + return true; + } + const MinMax& minmax = GetOrComputeMinMax(model, output); + GetQuantizationParamsFromMinMax<ArrayDataType::kUint8>(model->flags, minmax, + quantization_params); + *quantized_data_type = ArrayDataType::kUint8; + transformation->AddMessageF( + "For output array %s with min=%g, max=%g" + ", chose to quantize as uint8 with zero_point=%d" + ", scale=%g", + output, minmax.min, minmax.max, quantization_params->zero_point, + quantization_params->scale); + + return true; +} +} // namespace + +bool Quantize::Run(Model* model, std::size_t op_index) { + // Our general "quantization" graph transformation consists in replacing + // QuantizedInputArrays[] -> + // DequantizeOperators[] -> + // FloatInputArrays[] -> + // Operator -> + // FloatOutputArray + // by + // QuantizedInputArrays[] -> + // Operator -> + // QuantizedOutputArray -> + // DequantizeOperator -> + // FloatOutputArray + // + // In other words, this is pushing Dequantize operators to the right of + // other operators. + // + + auto& op = *model->operators[op_index]; + if (op.type == OperatorType::kDequantize || + op.type == OperatorType::kFakeQuant) { + return false; + } + + // Our assumption here is that the input arrays are already quantized - + // that is typically the case in models operating on an input bitmap + // image, and MakeInitialDequantizeOp should have already resolved + // the handling of the input image as an initial Dequantize op. + // + // Thus we are building around the assumption that the graph always starts + // with a quantized input array, and only after some Dequantize op do we have + // float arrays. The problem of quantizing the graph thus becomes a problem of + // pushing Dequantize ops to the right of other ops. 
+ // + // Let us just guard this assumption by the following assertion: + for (const auto& input : op.inputs) { + if (IsInputArray(*model, input)) { + const auto& input_array = model->GetArray(input); + CHECK(input_array.quantization_params); + } + } + if (!SupportsQuantization(op)) { + LOG(FATAL) << "Unimplemented: this graph contains an operator of type " + << HelpfulOperatorTypeName(op) + << " for which the quantized form is not yet implemented. " + "Sorry, and patches welcome (that's a relatively fun patch " + "to write, mostly providing the actual quantized arithmetic " + "code for this op)."; + } + + for (const auto& input : op.inputs) { + const auto& array = model->GetArray(input); + if (array.data_type == ArrayDataType::kFloat) { + if (!array.minmax && !array.buffer) { + LOG(ERROR) << "Can't quantize input array " << input + << " because it lacks min/max info"; + return false; + } + const auto* other_op = GetOpWithOutput(*model, input); + if (other_op && other_op->type != OperatorType::kDequantize) { + AddMessageF( + "Not quantizing %s for now, because its input array %s is not " + "produced by a Dequantize op, " + "which means that we should yield and let other ops " + "get quantized first", + LogName(op), input); + return false; + } + } + } + + bool changed = false; + + // Quantize inputs, remove any Dequantize op on the inputs side + for (std::size_t input_index = 0; input_index < op.inputs.size(); + input_index++) { + ArrayDataType quantized_data_type; + QuantizationParams quantization_params; + if (ChooseQuantizationForOperatorInput(this, model, op, input_index, + &quantized_data_type, + &quantization_params)) { + changed = true; + const auto& input = op.inputs[input_index]; + if (IsConstantParameterArray(*model, input)) { + QuantizeArray(this, model, input, quantized_data_type, + quantization_params); + } else { + auto dequantize_it = FindOpWithOutput(*model, input); + CHECK(dequantize_it != model->operators.end()); + auto* dequantize_op = 
dequantize_it->get(); + CHECK(dequantize_op->type == OperatorType::kDequantize); + op.inputs[input_index] = dequantize_op->inputs[0]; + // Check if the output of that Dequantize op was not used by any + // other operator. We will then erase that Dequantize op. + if (!CountOpsWithInput(*model, dequantize_op->outputs[0])) { + // If any of the model's output_arrays was pointing to the + // Dequantize op's output, let it point to the Dequantize op's + // input instead. + for (int i = 0; i < model->flags.output_arrays_size(); i++) { + if (model->flags.output_arrays(i) == dequantize_op->outputs[0]) { + model->flags.set_output_arrays(i, dequantize_op->inputs[0]); + } + } + model->arrays.erase(dequantize_op->outputs[0]); + model->operators.erase(dequantize_it); + } + } + } + } + + // Quantize outputs, add Dequantize ops as needed on the outputs side + for (std::size_t output_index = 0; output_index < op.outputs.size(); + output_index++) { + ArrayDataType quantized_data_type; + QuantizationParams quantization_params; + if (ChooseQuantizationForOperatorOutput(this, model, op, output_index, + &quantized_data_type, + &quantization_params)) { + changed = true; + const auto& output = op.outputs[output_index]; + QuantizeArray(this, model, output, quantized_data_type, + quantization_params); + const auto& dequantized_output = + AvailableArrayName(*model, output + "_dequantized"); + const auto& output_array = model->GetArray(output); + const auto& output_minmax = output_array.GetMinMax(); + auto& dequantized_output_array = + model->GetOrCreateArray(dequantized_output); + dequantized_output_array.data_type = ArrayDataType::kFloat; + auto& dequantized_output_minmax = + dequantized_output_array.GetOrCreateMinMax(); + dequantized_output_minmax.min = output_minmax.min; + dequantized_output_minmax.max = output_minmax.max; + for (const auto& other_op : model->operators) { + for (auto& other_op_input : other_op->inputs) { + if (other_op_input == output) { + other_op_input = 
dequantized_output; + } + } + } + auto* dequantize_op = new DequantizeOperator; + dequantize_op->inputs = {output}; + dequantize_op->outputs = {dequantized_output}; + for (int i = 0; i < model->flags.output_arrays_size(); i++) { + if (model->flags.output_arrays(i) == output) { + model->flags.set_output_arrays(i, dequantized_output); + } + } + const auto op_it = FindOp(*model, &op); + model->operators.emplace(op_it + 1, dequantize_op); + } + } + + return changed; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/read_fake_quant_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/read_fake_quant_min_max.cc new file mode 100644 index 0000000000..371ced388a --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/read_fake_quant_min_max.cc @@ -0,0 +1,105 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include <algorithm> +#include <memory> +#include <string> +#include <unordered_map> +#include <vector> + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +bool ApplyMinMaxToArray(GraphTransformation* transformation, Model* model, + const MinMax& minmax, const string& array_name) { + auto& annotated_array = model->GetArray(array_name); + if (annotated_array.minmax) { + return false; + } + annotated_array.GetOrCreateMinMax() = minmax; + transformation->AddMessageF( + "Read min/max annotation for array %s: min=%g, max=%g", array_name, + minmax.min, minmax.max); + return true; +} + +} // end namespace + +bool ReadFakeQuantMinMax::Run(Model* model, std::size_t op_index) { + const auto fakequant_it = model->operators.begin() + op_index; + auto* fakequant_base_op = fakequant_it->get(); + if (fakequant_base_op->type != OperatorType::kFakeQuant) { + return false; + } + auto* fakequant_op = static_cast<FakeQuantOperator*>(fakequant_base_op); + + bool changed = false; + + if (!fakequant_op->minmax) { + CHECK_EQ(fakequant_op->inputs.size(), 3); + // We need to yield until the min and max parameters have been + // resolved to constant arrays. 
+ for (int i = 1; i <= 2; i++) { + if (!IsConstantParameterArray(*model, fakequant_op->inputs[1])) { + return false; + } + } + + // Obtain the final min/max values + const auto& min_array = model->GetArray(fakequant_op->inputs[1]); + const auto& max_array = model->GetArray(fakequant_op->inputs[2]); + CHECK_EQ(RequiredBufferSizeForShape(min_array.shape()), 1); + CHECK_EQ(RequiredBufferSizeForShape(max_array.shape()), 1); + fakequant_op->minmax.reset(new MinMax); + MinMax& minmax = *fakequant_op->minmax; + minmax.min = min_array.GetBuffer<ArrayDataType::kFloat>().data[0]; + minmax.max = max_array.GetBuffer<ArrayDataType::kFloat>().data[0]; + // We always want [min, max] to contain 0. + minmax.min = std::min(minmax.min, 0.); + minmax.max = std::max(minmax.max, 0.); + + // We won't use the input arrays that provided these min and max + // values, anymore. Delete them unless they are used by something + // else. + for (int i = 1; i <= 2; i++) { + if (CountOpsWithInput(*model, fakequant_op->inputs[i]) == 1) { + model->arrays.erase(fakequant_op->inputs[i]); + } + } + fakequant_op->inputs.resize(1); + changed = true; + } + + // At this point, this FakeQuantOperator should have a MinMax + // attached to it, and should only have 1 input (it should not have + // 2nd and 3rd input arrays giving min and max anymore). 
+ CHECK(fakequant_op->minmax); + CHECK_EQ(1, fakequant_op->inputs.size()); + + const MinMax& minmax = *fakequant_op->minmax; + + // Record the MinMax info on the input and output arrays + changed |= ApplyMinMaxToArray(this, model, minmax, fakequant_op->inputs[0]); + changed |= ApplyMinMaxToArray(this, model, minmax, fakequant_op->outputs[0]); + + return changed; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_final_dequantize_op.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_final_dequantize_op.cc new file mode 100644 index 0000000000..3992e7d1ef --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_final_dequantize_op.cc @@ -0,0 +1,59 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include <memory> +#include <string> +#include <unordered_map> +#include <vector> + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/model_flags.pb.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool RemoveFinalDequantizeOp::Run(Model* model, std::size_t op_index) { + const auto dequantize_it = model->operators.begin() + op_index; + const auto* dequantize_op = dequantize_it->get(); + if (dequantize_op->type != OperatorType::kDequantize) { + return false; + } + const auto& output = dequantize_op->outputs[0]; + // We can remove any dequantize op whose output is not consumed by + // any op. This is not necessarily equivalent to the output being + // one of the model's output arrays, as some intermediate array + // in the middle of the graph might be designated as an output + // array. + if (CountOpsWithInput(*model, output)) { + return false; + } + + // If one of the model's output arrays was actually the Dequantize op's + // output, then we need to update it to point to the Dequantize op's input. + for (int i = 0; i < model->flags.output_arrays_size(); i++) { + if (output == model->flags.output_arrays(i)) { + model->flags.set_output_arrays(i, dequantize_op->inputs[0]); + } + } + + // Remove the node and its output array. 
+ AddMessageF("Removed final %s", LogName(*dequantize_op)); + model->arrays.erase(output); + model->operators.erase(dequantize_it); + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_assert.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_assert.cc new file mode 100644 index 0000000000..35a0c46532 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_assert.cc @@ -0,0 +1,60 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include <memory> +#include <string> +#include <vector> + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool RemoveTensorFlowAssert::Run(Model* model, std::size_t op_index) { + const auto assert_it = model->operators.begin() + op_index; + const auto* assert_op = assert_it->get(); + if (assert_op->type != OperatorType::kTensorFlowAssert) { + return false; + } + + bool changed = false; + // Remove any other node's dependency on this assert node + for (const auto& op : model->operators) { + auto it = op->inputs.begin(); + while (it != op->inputs.end()) { + if (*it == assert_op->outputs[0]) { + op->inputs.erase(it); + changed = true; + } else { + ++it; + } + } + } + CHECK(!CountOpsWithInput(*model, assert_op->outputs[0])); + + if (changed) { + AddMessageF( + "Prepared for the removal of %s by removing any other op's dependency " + "on it", + LogName(*assert_op)); + } + + // That's it. We can stop here, no need to duplicate the work that + // RemoveUnusedOp will do removing this now-unused node. + return changed; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_identity.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_identity.cc new file mode 100644 index 0000000000..404269bbfd --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_identity.cc @@ -0,0 +1,38 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <memory> +#include <string> +#include <unordered_map> +#include <vector> + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool RemoveTensorFlowIdentity::Run(Model* model, std::size_t op_index) { + const auto passthru_it = model->operators.begin() + op_index; + const auto* passthru_op = passthru_it->get(); + if (passthru_op->type != OperatorType::kTensorFlowIdentity) { + return false; + } + + return RemoveTrivialPassthroughOp(this, model, op_index); +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc new file mode 100644 index 0000000000..6add443f2d --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc @@ -0,0 +1,113 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <iterator> +#include <memory> +#include <string> +#include <unordered_map> +#include <vector> + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +template <typename Scalar> +bool AreAllBufferElementsEqualTo(const std::vector<Scalar>& buffer_data, + Scalar value) { + for (auto x : buffer_data) { + if (x != value) { + return false; + } + } + return true; +} +} // namespace + +// A binary operator is called trivial when exactly one of its operands is +// a constant and is such that the binary operation is equivalent to +// the identity operation on its other input. +// For example, an Add operator is trivial if +// one of its operands is constant 0, a Mul operator is trivial +// if one of its operands is constant 1, etc. 
+bool RemoveTrivialBinaryOperator::Run(Model* model, std::size_t op_index) { + const auto binary_it = model->operators.begin() + op_index; + auto* binary_op = binary_it->get(); + if (binary_op->type != OperatorType::kAdd && + binary_op->type != OperatorType::kMul && + binary_op->type != OperatorType::kSub && + binary_op->type != OperatorType::kDiv) { + return false; + } + + CHECK_EQ(binary_op->inputs.size(), 2); + + // This graph transformation is only concerned with the case + // when one input is constant and the other is not constant. + const bool is_input_constant[2] = { + IsConstantParameterArray(*model, binary_op->inputs[0]), + IsConstantParameterArray(*model, binary_op->inputs[1]), + }; + if (!is_input_constant[0] && !is_input_constant[1]) { + // Neither input is constant, so nothing we can resolve here. + return false; + } + if (is_input_constant[0] && is_input_constant[1]) { + // Both inputs are constants. That's a job for constants + // propagation, not for us to handle here. + return false; + } + const int index_of_constant_input = is_input_constant[0] ? 0 : 1; + const int index_of_variable_input = is_input_constant[0] ? 1 : 0; + CHECK(is_input_constant[index_of_constant_input]); + CHECK(!is_input_constant[index_of_variable_input]); + + // Now check if the constant operand makes this binary + // operator trivial. + const auto& constant_input_array = + *model->arrays[binary_op->inputs[index_of_constant_input]]; + // For now, we only handle floats here. 
+ if (constant_input_array.data_type != ArrayDataType::kFloat) { + return false; + } + const auto& constant_input_float_data = + constant_input_array.GetBuffer<ArrayDataType::kFloat>().data; + bool is_trivial = false; + if (binary_op->type != OperatorType::kAdd) { + is_trivial = AreAllBufferElementsEqualTo(constant_input_float_data, 0.f); + } else if (binary_op->type != OperatorType::kSub) { + is_trivial = index_of_constant_input == 1 && + AreAllBufferElementsEqualTo(constant_input_float_data, 0.f); + } else if (binary_op->type != OperatorType::kMul) { + is_trivial = AreAllBufferElementsEqualTo(constant_input_float_data, 1.f); + } else if (binary_op->type != OperatorType::kDiv) { + is_trivial = index_of_constant_input == 1 && + AreAllBufferElementsEqualTo(constant_input_float_data, 1.f); + } + + if (!is_trivial) { + return false; + } + + // Now we know that this node is trivial, so we can remove it. + AddMessageF("Removing trivial %s", LogName(*binary_op)); + return RemoveTrivialPassthroughOp(this, model, op_index); +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation.cc new file mode 100644 index 0000000000..3ceb93d8ee --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation.cc @@ -0,0 +1,40 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <memory> +#include <string> +#include <unordered_map> +#include <vector> + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool RemoveTrivialConcatenation::Run(Model* model, std::size_t op_index) { + const auto concat_it = model->operators.begin() + op_index; + auto* concat_op = concat_it->get(); + if (concat_op->type != OperatorType::kConcatenation) { + return false; + } + if (concat_op->inputs.size() != 1) { + return false; + } + return RemoveTrivialPassthroughOp(this, model, op_index); +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation_input.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation_input.cc new file mode 100644 index 0000000000..b603735704 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation_input.cc @@ -0,0 +1,68 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include <memory> +#include <string> +#include <unordered_map> +#include <vector> + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool RemoveTrivialConcatenationInput::Run(Model* model, std::size_t op_index) { + // TensorFlow allows Concatenation nodes to have 0-D inputs, + // and they are then treated as empty i.e. omitted from concatenation, + // in violation of the notion that 0-D is equivalent to 1x1x1x1. + // Thus we have to drop these 0-D inputs from Concatenation nodes. + // Sometimes, there will remain only one non-trivial input, and + // the other graph transformation RemoveTrivialConcatenation will then drop + // it. + const auto concat_it = model->operators.begin() + op_index; + auto* concat_op = concat_it->get(); + if (concat_op->type != OperatorType::kConcatenation) { + return false; + } + std::vector<string> trivial_inputs; + std::vector<string> nontrivial_inputs; + for (const string& input : concat_op->inputs) { + const auto& input_array = model->GetArray(input); + const bool is_trivial = + input_array.has_shape() && input_array.shape().dimensions_count() == 0; + if (is_trivial) { + trivial_inputs.push_back(input); + } else { + nontrivial_inputs.push_back(input); + } + } + + if (trivial_inputs.empty()) { + return false; + } + + // Drop trivial inputs. 
+ for (const string& input : trivial_inputs) { + if (CountOpsWithInput(*model, input) == 1) { + model->arrays.erase(input); + } + } + concat_op->inputs = nontrivial_inputs; + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc new file mode 100644 index 0000000000..a0d1338298 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc @@ -0,0 +1,107 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <memory> +#include <string> +#include <unordered_map> +#include <vector> + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/model_flags.pb.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { +// Reroute all edges involving a given discardable array to another +// array instead. from_array is assumed to be discardable, and consequently +// this only updates operator edges (since discardable arrays only +// appear there, and not e.g. in model flags). 
+void RerouteEdges(const string& from_array, const string& to_array, + Model* model) { + for (const auto& op : model->operators) { + for (auto& output : op->outputs) { + if (output == from_array) { + output = to_array; + } + } + for (auto& input : op->inputs) { + if (input == from_array) { + input = to_array; + } + } + } +} + +} // end anonymous namespace + +bool RemoveTrivialPassthroughOp(GraphTransformation* transformation, + Model* model, std::size_t op_index) { + const auto passthru_it = model->operators.begin() + op_index; + auto* passthru_op = passthru_it->get(); + CHECK_EQ(passthru_op->outputs.size(), 1); + CHECK_GE(passthru_op->inputs.size(), 1); + int count_nonconstant_input_arrays = 0; + // We call 'main input' the unique nonconstant input array if there is one, + // or else the 0-th input. + int main_input_array_index = 0; + for (int i = 0; i < passthru_op->inputs.size(); i++) { + if (!model->GetArray(passthru_op->inputs[i]).buffer) { + count_nonconstant_input_arrays++; + main_input_array_index = i; + } + } + CHECK_LE(count_nonconstant_input_arrays, 1); + + const string main_input_name = passthru_op->inputs[main_input_array_index]; + const string output_name = passthru_op->outputs[0]; + if (IsDiscardableArray(*model, output_name)) { + transformation->AddMessageF( + "Removing %s, keeping its non-constant input array", + LogName(*passthru_op)); + model->arrays.erase(output_name); + for (const string& input : passthru_op->inputs) { + if (IsDiscardableArray(*model, input) && input != main_input_name && + CountOpsWithInput(*model, input) == 1) { + model->arrays.erase(input); + } + } + RerouteEdges(output_name, main_input_name, model); + } else if (IsDiscardableArray(*model, main_input_name)) { + transformation->AddMessageF("Removing %s, keeping its output array", + LogName(*passthru_op)); + for (const string& input : passthru_op->inputs) { + if (IsDiscardableArray(*model, input) && + (input == main_input_name || CountOpsWithInput(*model, input) == 1)) { + 
model->arrays.erase(input); + } + } + RerouteEdges(main_input_name, output_name, model); + } else { + transformation->AddMessageF( + "Cannot remove %s, neither its nonconstant input nor its output may be " + "discarded", + LogName(*passthru_op)); + return false; + } + + // Remove the pass-through node. + model->operators.erase(passthru_it); + + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h new file mode 100644 index 0000000000..b72c85c0e5 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h @@ -0,0 +1,55 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_REMOVE_TRIVIAL_PASSTHROUGH_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_REMOVE_TRIVIAL_PASSTHROUGH_H_ + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" + +namespace toco { + +// A "passthrough op" is an op that satisfies the following conditions: +// 1. It has at most one non-constant input (it may have other constant +// inputs). +// 2. It has exactly one output. +// 3. 
It forwards exactly its single non-constant input to its single output. +// +// Examples include: +// 1. TensorFlow Identity ops. (Have one input). +// 2. TensorFlow Reshape ops when the input and output shapes agree. +// 3. Any binary operator, one of whose two inputs is a constant and is the +// neutral value for that operation. For example, a binary Add operator +// where one of its inputs is a constant array filled with zeros. +// +// A passthrough op is "trivial" and can be removed when it is possible to +// discard either its single non-constant input or output array, rerouting any +// edge involving it to the other of these two arrays. +// +// It is only possible to discard such an array if it is not explicitly +// designated as a global input/output array of the graph, e.g. the model's +// input arrays, output arrays, and any array involved in a RNN back-edge +// specified by the model. +// +// This function does not check that the given operator is a passthrough op: +// that's the responsibility of the caller. +// Given that it is a passthrough op, this function checks whether it is trivial +// and then discards it and returns true, or, if it's not trivial (if neither +// the input nor the output may be discarded), returns false. +bool RemoveTrivialPassthroughOp(GraphTransformation* transformation, + Model* model, std::size_t op_index); + +} // namespace toco + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_REMOVE_TRIVIAL_PASSTHROUGH_H_ diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_activation_func.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_activation_func.cc new file mode 100644 index 0000000000..28f76c9d36 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_activation_func.cc @@ -0,0 +1,87 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <memory> +#include <string> +#include <vector> + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/runtime/types.h" +#include "tensorflow/contrib/lite/toco/toco_types.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool RemoveTrivialQuantizedActivationFunc::Run(Model* model, + std::size_t op_index) { + const auto it = model->operators.begin() + op_index; + auto* op = it->get(); + if (op->fused_activation_function != FusedActivationFunctionType::kRelu && + op->fused_activation_function != FusedActivationFunctionType::kRelu6) { + return false; + } + const auto& output_array = model->GetArray(op->outputs[0]); + if (!output_array.quantization_params) { + return false; + } + if (output_array.data_type != ArrayDataType::kUint8) { + return false; + } + const auto& quantization_params = output_array.GetQuantizationParams(); + + bool has_nontrivial_min_bound = false; + bool has_nontrivial_max_bound = false; + + if (op->fused_activation_function == FusedActivationFunctionType::kRelu || + op->fused_activation_function == FusedActivationFunctionType::kRelu6) { + double lowest_representable_output = + (0. 
- quantization_params.zero_point) * quantization_params.scale; + if (lowest_representable_output < 0.) { + has_nontrivial_min_bound = true; + AddMessageF( + "Quantized activation function is not trivial: " + "the lowest representable output value %g" + " less than the clamp min bound.", + lowest_representable_output); + } + } + if (op->fused_activation_function == FusedActivationFunctionType::kRelu6) { + double highest_representable_output = + (255. - quantization_params.zero_point) * quantization_params.scale; + if (highest_representable_output > 6.) { + has_nontrivial_max_bound = true; + AddMessageF( + "Quantized activation function is not trivial: " + "the highest representable output value %g" + " is greater than the clamp max bound.", + highest_representable_output); + } + } + + if (has_nontrivial_min_bound || has_nontrivial_max_bound) { + return false; + } + + op->fused_activation_function = FusedActivationFunctionType::kNone; + AddMessageF( + "Removing trivial quantized activation function on %s" + " because the output quantization parameters imply at least as tight" + " a clamp anyway.", + LogName(*op)); + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc new file mode 100644 index 0000000000..90f9381ec1 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc @@ -0,0 +1,92 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <iterator> +#include <memory> +#include <string> +#include <unordered_map> +#include <vector> + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +bool IsReshapeTrivial(const Model& model, const Operator& op, + RemoveTrivialReshape* transformation) { + CHECK(op.type == OperatorType::kTensorFlowReshape); + + // One way in which a reshape can be trivial is if its + // output shape is == its input shape + const auto& input_array = model.GetArray(op.inputs[0]); + const auto& output_array = model.GetArray(op.outputs[0]); + if (input_array.has_shape() && output_array.has_shape()) { + if (transformation->treat_expand_dims_as_trivial() && + ShapesAgreeUpToExtending(input_array.shape(), output_array.shape())) { + transformation->AddMessageF( + "%s is trivial because its input and output shapes are equal up to " + "extending " + "by 1's, and we are told to aggressively discard such Reshape ops.", + LogName(op)); + return true; + } + if (input_array.shape().dims() == output_array.shape().dims()) { + transformation->AddMessageF( + "%s is trivial because its input and output shapes are equal", + LogName(op)); + return true; + } + } + + // Another way in which a reshape can be 
trivial is if its output + // is only consumed by another reshape. + if (CountOpsWithInput(model, op.outputs[0]) == 1) { + const auto* next_op = GetOpWithInput(model, op.outputs[0]); + if (next_op->type == OperatorType::kTensorFlowReshape) { + transformation->AddMessageF( + "%s is trivial because its output is only consumed by another " + "Reshape op", + LogName(op)); + return true; + } + } + + return false; +} + +} // namespace + +bool RemoveTrivialReshape::Run(Model* model, std::size_t op_index) { + const auto reshape_it = model->operators.begin() + op_index; + auto* reshape_op = reshape_it->get(); + if (reshape_op->type != OperatorType::kTensorFlowReshape) { + return false; + } + + if (!IsReshapeTrivial(*model, *reshape_op, this)) { + return false; + } + + AddMessageF("Removing trivial %s", LogName(*reshape_op)); + + CHECK_EQ(reshape_op->inputs.size(), 2); + return RemoveTrivialPassthroughOp(this, model, op_index); +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc new file mode 100644 index 0000000000..1f1f1f6948 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc @@ -0,0 +1,122 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/contrib/lite/toco/model.h"
#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
#include "tensorflow/contrib/lite/toco/tooling_util.h"
#include "tensorflow/core/platform/logging.h"

namespace toco {

// Removes an operator none of whose outputs is actually consumed, along with
// any of its input/output arrays that nothing else references. Returns true
// iff the operator was removed.
bool RemoveUnusedOp::Run(Model* model, std::size_t op_index) {
  const auto it = model->operators.begin() + op_index;
  const auto* op = it->get();

  // Bail if any output is used, and is not an input_array of
  // the model. We allow specifying an arbitrary input_array,
  // treating the part of the graph leading up to it as unused.
  for (const auto& output : op->outputs) {
    CHECK(model->arrays.count(output));
    // If this output is provided as the model's input array,
    // then we don't need this operator to produce its contents.
    if (IsInputArray(*model, output)) {
      continue;
    }
    // If this output is provided as a RNN's state array,
    // then we don't need this operator to produce its contents.
    // So far this case has only been encountered with TensorFlow
    // Fill ops used to zero-initialize RNN states, which is
    // redundant for us as we zero-initialize RNN states anyway.
    bool found_output_as_rnn_state_array = false;
    for (const auto& rnn_state : model->flags.rnn_states()) {
      if (output == rnn_state.state_array()) {
        // Guard the "only Fill ops feed RNN states" assumption described
        // above: anything else reaching here would be a new, unhandled case.
        CHECK(op->type == OperatorType::kTensorFlowUnsupported);
        CHECK_EQ(static_cast<const TensorFlowUnsupportedOperator*>(op)
                     ->tensorflow_op,
                 "Fill");
        found_output_as_rnn_state_array = true;
        break;
      }
    }
    if (found_output_as_rnn_state_array) {
      continue;
    }
    // A model output array is, by definition, used.
    for (const string& output_array : model->flags.output_arrays()) {
      if (output == output_array) {
        return false;
      }
    }
    // An RNN back-edge source array is consumed by the back-edge.
    for (const auto& rnn_state : model->flags.rnn_states()) {
      if (output == rnn_state.back_edge_source_array()) {
        return false;
      }
    }
    // Finally, any other operator consuming this output makes it "used".
    if (CountOpsWithInput(*model, output)) {
      return false;
    }
  }

  if (op->unresolved_outputs) {
    AddMessageF("Not discarding %s because it has unresolved outputs.",
                LogName(*op));
    return false;
  }

  AddMessageF("Discarding %s because none of its outputs is used.",
              LogName(*op));

  // At that point we know that none of the outputs is used, so we will
  // definitely remove the node and all its outputs.

  // Remove any input array that is not used by anything else,
  // and that is not the output of some other operator.
  for (const auto& input : op->inputs) {
    if (CountOpsWithInput(*model, input) == 1 &&
        !GetOpWithOutput(*model, input)) {
      model->arrays.erase(input);
    }
  }

  // Remove the node and its now-unused output arrays.
  for (const auto& output : op->outputs) {
    // If the output array is the model's input array, don't remove that.
    // That's the case when cropping a model at a given --input_array.
    if (IsInputArray(*model, output)) {
      continue;
    }
    // Likewise, if the output array is a RNN state array, don't remove that.
    // (No Fill-op CHECK here, unlike the scan above: by now we only care
    // whether the array must be preserved, not how it was produced.)
    bool found_output_as_rnn_state_array = false;
    for (const auto& rnn_state : model->flags.rnn_states()) {
      if (output == rnn_state.state_array()) {
        found_output_as_rnn_state_array = true;
        break;
      }
    }
    if (found_output_as_rnn_state_array) {
      continue;
    }
    // Generic case: do delete this output array.
    model->arrays.erase(output);
  }
  model->operators.erase(it);
  return true;
}

}  // namespace toco
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/contrib/lite/toco/model.h"
#include "tensorflow/contrib/lite/toco/runtime/types.h"
#include "tensorflow/contrib/lite/toco/tooling_util.h"
#include "tensorflow/core/platform/logging.h"

namespace toco {

// Rewrites a BatchNormalization op into an equivalent Mul followed by Add:
//   output = input * multiplier + (offset - mean * multiplier)
// The Mul/Add constant parameter arrays are computed here from the BN's
// constant mean/multiplier/offset inputs, which are then erased along with
// the original op. Returns true iff the rewrite was performed.
bool ResolveBatchNormalization::Run(Model* model, std::size_t op_index) {
  auto bn_it = model->operators.begin() + op_index;
  if (bn_it->get()->type != OperatorType::kBatchNormalization) {
    return false;
  }
  const auto* bn_op =
      static_cast<const BatchNormalizationOperator*>(bn_it->get());

  // BN inputs: [0]=activations, [1]=mean, [2]=multiplier, [3]=offset.
  const auto& mean_array = model->GetArray(bn_op->inputs[1]);
  const auto& multiplier_array = model->GetArray(bn_op->inputs[2]);
  const auto& offset_array = model->GetArray(bn_op->inputs[3]);

  CHECK(IsConstantParameterArray(*model, bn_op->inputs[1]) &&
        IsConstantParameterArray(*model, bn_op->inputs[2]) &&
        IsConstantParameterArray(*model, bn_op->inputs[3]))
      << "Batch normalization resolution requires that mean, multiplier and "
         "offset arrays be constant.";

  // We should only have *float* BatchNormalizations... let's guard this
  // assumption by CHECK's.
  CHECK(mean_array.data_type == ArrayDataType::kFloat);
  CHECK(multiplier_array.data_type == ArrayDataType::kFloat);
  CHECK(offset_array.data_type == ArrayDataType::kFloat);

  // Create the new Mul, Add operators. Ownership passes to model->operators
  // (a container of unique_ptr) via emplace below.
  auto* mul_op = new MulOperator;
  auto* add_op = new AddOperator;
  const string mul_name =
      AvailableArrayName(*model, bn_op->outputs[0] + "_mul");
  const string add_name =
      AvailableArrayName(*model, bn_op->outputs[0] + "_add");
  const string mul_param_name = AvailableArrayName(*model, mul_name + "_param");
  const string add_param_name = AvailableArrayName(*model, add_name + "_param");
  mul_op->inputs = {bn_op->inputs[0], mul_param_name};
  mul_op->outputs = {mul_name};
  add_op->inputs = {mul_name, add_param_name};
  // The Add takes over the BN's output name, so downstream consumers are
  // untouched.
  add_op->outputs = {bn_op->outputs[0]};
  AddMessageF("Splitting %s into %s and %s", LogName(*bn_op), LogName(*mul_op),
              LogName(*add_op));

  // Create the intermediate activation array (output of mul, input of add)
  auto& intermediate_array = model->GetOrCreateArray(mul_op->outputs[0]);
  intermediate_array.data_type = model->GetArray(bn_op->inputs[0]).data_type;

  // Insert the new operators in the graph, ahead of the BN op so the final
  // order is [mul, add, bn] until the BN is erased at the end.
  auto add_it = model->operators.emplace(bn_it, add_op);
  auto mul_it = model->operators.emplace(add_it, mul_op);
  // update invalidated iterators: each emplace invalidates iterators at or
  // after the insertion point, so re-derive add_it/bn_it from mul_it and
  // sanity-check with DCHECKs.
  DCHECK_EQ(mul_it->get(), mul_op);
  add_it = mul_it + 1;
  DCHECK_EQ(add_it->get(), add_op);
  bn_it = add_it + 1;
  DCHECK_EQ(bn_it->get(), bn_op);

  // Create the new param arrays
  const auto& mean_shape = mean_array.shape();
  const auto& multiplier_shape = multiplier_array.shape();
  const auto& offset_shape = offset_array.shape();
  CHECK(mean_shape.dims() == multiplier_shape.dims());
  CHECK(mean_shape.dims() == offset_shape.dims());
  const auto& param_shape = mean_shape;
  const int buffer_size = RequiredBufferSizeForShape(param_shape);
  auto& mul_param_array = model->GetOrCreateArray(mul_param_name);
  auto& add_param_array = model->GetOrCreateArray(add_param_name);
  // Fresh constant params should not inherit any stale minmax.
  DropMinMax(model, mul_param_name);
  DropMinMax(model, add_param_name);
  mul_param_array.copy_shape(param_shape);
  add_param_array.copy_shape(param_shape);
  mul_param_array.data_type = ArrayDataType::kFloat;
  add_param_array.data_type = ArrayDataType::kFloat;
  auto& mul_float_data =
      mul_param_array.GetMutableBuffer<ArrayDataType::kFloat>().data;
  auto& add_float_data =
      add_param_array.GetMutableBuffer<ArrayDataType::kFloat>().data;
  mul_float_data.resize(buffer_size);
  add_float_data.resize(buffer_size);
  const auto& mean_float_data =
      mean_array.GetBuffer<ArrayDataType::kFloat>().data;
  const auto& multiplier_float_data =
      multiplier_array.GetBuffer<ArrayDataType::kFloat>().data;
  const auto& offset_float_data =
      offset_array.GetBuffer<ArrayDataType::kFloat>().data;

  CHECK(mul_float_data.size() == buffer_size);
  CHECK(add_float_data.size() == buffer_size);
  CHECK(mean_float_data.size() == buffer_size);
  CHECK(multiplier_float_data.size() == buffer_size);
  CHECK(offset_float_data.size() == buffer_size);

  // Fold BN math into the two param arrays:
  //   mul_param = multiplier
  //   add_param = offset - mean * multiplier
  for (int i = 0; i < buffer_size; i++) {
    mul_float_data[i] = multiplier_float_data[i];
    add_float_data[i] =
        offset_float_data[i] - mean_float_data[i] * multiplier_float_data[i];
  }

  // Remove the old param arrays
  model->arrays.erase(bn_op->inputs[1]);
  model->arrays.erase(bn_op->inputs[2]);
  model->arrays.erase(bn_op->inputs[3]);

  // Remove the old operator
  DCHECK_EQ(bn_it->get(), bn_op);
  model->operators.erase(bn_it);

  return true;
}

}  // namespace toco
#include <algorithm>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/contrib/lite/toco/model.h"
#include "tensorflow/contrib/lite/toco/runtime/types.h"
#include "tensorflow/contrib/lite/toco/tooling_util.h"
#include "tensorflow/core/platform/logging.h"

namespace toco {

namespace {

// Elementwise a[i] > b[i] over two equal-sized int vectors.
std::vector<bool> VectorGreaterThan(const std::vector<int>& a,
                                    const std::vector<int>& b) {
  DCHECK_EQ(a.size(), b.size());
  const int size = a.size();
  std::vector<bool> result(size);
  for (int i = 0; i < size; i++) {
    result[i] = a[i] > b[i];
  }
  return result;
}

// For each position i, routes input_a[i]/input_b[i] into output_a/output_b:
// when selector[i] is true they pass straight through; when false they are
// swapped. Used below to sort per-dimension values into (big, small) pairs.
void PairwiseVectorSelect(const std::vector<bool>& selector,
                          const std::vector<int>& input_a,
                          const std::vector<int>& input_b,
                          std::vector<int>* output_a,
                          std::vector<int>* output_b) {
  DCHECK_EQ(input_a.size(), input_b.size());
  DCHECK_EQ(output_a->size(), output_b->size());
  DCHECK_EQ(input_a.size(), output_a->size());
  DCHECK_EQ(selector.size(), input_a.size());
  const int size = input_a.size();
  for (int i = 0; i < size; i++) {
    if (selector[i]) {
      (*output_a)[i] = input_a[i];
      (*output_b)[i] = input_b[i];
    } else {
      (*output_a)[i] = input_b[i];
      (*output_b)[i] = input_a[i];
    }
  }
}

// Evaluates a binary operator over two constant input arrays of type
// InputsDataType, writing a constant buffer of type OutputDataType into the
// output array. Implements numpy-style broadcasting: the smaller dimension
// is tiled (via modulo indexing) up to the larger one.
template <ArrayDataType InputsDataType, ArrayDataType OutputDataType>
void EvaluateBinaryOperatorOnConstantInputs(Model* model,
                                            const Operator* binary_op) {
  CHECK(IsConstantParameterArray(*model, binary_op->inputs[0]));
  CHECK(IsConstantParameterArray(*model, binary_op->inputs[1]));
  CHECK(binary_op->fused_activation_function ==
        FusedActivationFunctionType::kNone);
  const auto& input0_array = model->GetArray(binary_op->inputs[0]);
  const auto& input1_array = model->GetArray(binary_op->inputs[1]);
  const auto& output_name = binary_op->outputs[0];
  auto& output_array = model->GetArray(output_name);
  CHECK(input0_array.data_type == InputsDataType);
  CHECK(input1_array.data_type == InputsDataType);
  CHECK(output_array.data_type == OutputDataType);

  // We have already tested above for existence of input buffers
  // (synonymous to being a constant param).
  CHECK(input0_array.buffer);
  CHECK(input1_array.buffer);
  // On the other hand, the output should not already have a buffer.
  CHECK(!output_array.buffer);

  const auto& input0_data = input0_array.GetBuffer<InputsDataType>().data;
  const auto& input1_data = input1_array.GetBuffer<InputsDataType>().data;
  // Create the buffer on the output array, effectively turning it into
  // a constant parameter

  const Shape& output_shape = output_array.shape();
  auto& output_data = output_array.GetMutableBuffer<OutputDataType>().data;
  const int output_buffer_size = RequiredBufferSizeForShape(output_shape);
  output_data.resize(output_buffer_size);
  const int dims_count = output_shape.dimensions_count();

  // It will be convenient here to have copies of the operands shapes
  // extended to match the number of dimensions of the output shape.
  Shape input0_shape = input0_array.shape();
  Shape input1_shape = input1_array.shape();
  ExtendShape(&input0_shape, dims_count);
  ExtendShape(&input1_shape, dims_count);
  // Now we may still have operands of different sizes, which would indicate
  // that we have to "broadcast" the smaller dimension. We do this using a
  // a vector of Booleans indicating which input is the larger in each
  // dimension.
  CHECK_EQ(input0_shape.dimensions_count(), input1_shape.dimensions_count());
  CHECK_EQ(input0_shape.dimensions_count(), dims_count);
  const std::vector<bool> input0_larger =
      VectorGreaterThan(input0_shape.dims(), input1_shape.dims());

  std::vector<int> big_sizes(dims_count);
  std::vector<int> small_sizes(dims_count);
  PairwiseVectorSelect(input0_larger, input0_shape.dims(), input1_shape.dims(),
                       &big_sizes, &small_sizes);

  // The output should already be correctly sized to match the big dimensions.
  for (int i = 0; i < dims_count; i++) {
    CHECK_EQ(output_shape.dims(i), big_sizes[i]);
  }

  std::vector<int> input0_indices(dims_count);
  std::vector<int> input1_indices(dims_count);
  std::vector<int> modulo_indices(dims_count);

  // For each output element: the larger input is addressed with the output
  // indices directly, the smaller one with indices wrapped by its own size
  // (modulo), which implements the broadcast.
  for (int k = 0; k < output_buffer_size; k++) {
    const std::vector<int> output_indices = ReverseOffset(output_shape, k);
    for (int i = 0; i < dims_count; i++) {
      modulo_indices[i] = output_indices[i] % small_sizes[i];
    }
    PairwiseVectorSelect(input0_larger, output_indices, modulo_indices,
                         &input0_indices, &input1_indices);
    const auto val0 = input0_data[Offset(input0_shape, input0_indices)];
    const auto val1 = input1_data[Offset(input1_shape, input1_indices)];

    DataType<OutputDataType> outval;
    if (binary_op->type == OperatorType::kAdd) {
      outval = val0 + val1;
    } else if (binary_op->type == OperatorType::kMul) {
      outval = val0 * val1;
    } else if (binary_op->type == OperatorType::kSub) {
      outval = val0 - val1;
    } else if (binary_op->type == OperatorType::kDiv) {
      outval = val0 / val1;
    } else if (binary_op->type == OperatorType::kTensorFlowMinimum) {
      outval = std::min(val0, val1);
    } else if (binary_op->type == OperatorType::kTensorFlowMaximum) {
      outval = std::max(val0, val1);
    } else if (binary_op->type == OperatorType::kTensorFlowLess) {
      outval = val0 < val1;
    } else if (binary_op->type == OperatorType::kTensorFlowLessEqual) {
      outval = val0 <= val1;
    } else if (binary_op->type == OperatorType::kTensorFlowGreater) {
      outval = val0 > val1;
    } else if (binary_op->type == OperatorType::kTensorFlowGreaterEqual) {
      outval = val0 >= val1;
    } else {
      LOG(FATAL) << "should not get here";
    }
    output_data[Offset(output_shape, output_indices)] = outval;
  }
}

// Runtime dispatch to the templated evaluator above, based on the actual
// input/output data types. Only the listed (input, output) type pairs are
// supported; anything else is a fatal error.
void EvaluateBinaryOperatorOnConstantInputs(Model* model,
                                            const Operator* binary_op) {
  const auto inputs_data_type = model->arrays[binary_op->inputs[0]]->data_type;
  const auto output_data_type = model->arrays[binary_op->outputs[0]]->data_type;
#define TOCO_HANDLE_CASE(InputsDataType, OutputDataType)                     \
  if (inputs_data_type == InputsDataType &&                                  \
      output_data_type == OutputDataType) {                                  \
    EvaluateBinaryOperatorOnConstantInputs<InputsDataType, OutputDataType>(  \
        model, binary_op);                                                   \
    return;                                                                  \
  }
  TOCO_HANDLE_CASE(ArrayDataType::kFloat, ArrayDataType::kFloat)
  TOCO_HANDLE_CASE(ArrayDataType::kFloat, ArrayDataType::kBool)
  TOCO_HANDLE_CASE(ArrayDataType::kInt32, ArrayDataType::kInt32)
  TOCO_HANDLE_CASE(ArrayDataType::kInt32, ArrayDataType::kBool)
  TOCO_HANDLE_CASE(ArrayDataType::kInt64, ArrayDataType::kInt64)
  TOCO_HANDLE_CASE(ArrayDataType::kInt64, ArrayDataType::kBool)
  LOG(FATAL) << "Unimplemented: don't know how to resolve a constant "
             << "binary operator for these data types.";
#undef TOCO_HANDLE_CASE
}
}  // namespace

// Constant-folds a supported binary operator whose two inputs are both
// constant arrays, replacing the op with a constant output array. Yields
// (returns false) until both inputs are constant and the output shape is
// known. Returns true iff the op was folded and erased.
bool ResolveConstantBinaryOperator::Run(Model* model, std::size_t op_index) {
  const auto binary_it = model->operators.begin() + op_index;
  const auto* binary_op = binary_it->get();
  // Test for binary ops of types that we know how to resolve
  if (binary_op->type != OperatorType::kAdd &&
      binary_op->type != OperatorType::kMul &&
      binary_op->type != OperatorType::kSub &&
      binary_op->type != OperatorType::kDiv &&
      binary_op->type != OperatorType::kTensorFlowMinimum &&
      binary_op->type != OperatorType::kTensorFlowMaximum &&
      binary_op->type != OperatorType::kTensorFlowLess &&
      binary_op->type != OperatorType::kTensorFlowLessEqual &&
      binary_op->type != OperatorType::kTensorFlowGreater &&
      binary_op->type != OperatorType::kTensorFlowGreaterEqual) {
    return false;
  }
  CHECK_EQ(binary_op->inputs.size(), 2);

  const auto& input0_array = model->GetArray(binary_op->inputs[0]);
  const auto& input1_array = model->GetArray(binary_op->inputs[1]);
  // Check if both inputs are constant parameters.
  if (!input0_array.buffer || !input1_array.buffer) {
    return false;
  }

  auto& output_array = *model->arrays[binary_op->outputs[0]];
  // Yield until the output array dims have been resolved.
  if (!output_array.has_shape()) {
    return false;
  }

  // At the moment we don't want to care about fused activation functions.
  // The idea is that we should do the present constants-propagation before
  // activation functions get fused.
  if (binary_op->fused_activation_function !=
      FusedActivationFunctionType::kNone) {
    AddMessageF(
        "Not resolving constant %s because it has a fused activation function",
        LogName(*binary_op));
    return false;
  }

  // Check that input data types agree.
  CHECK(input0_array.data_type == input1_array.data_type);

  // Do the actual constants propagation
  EvaluateBinaryOperatorOnConstantInputs(model, binary_op);

  // Remove the binary operator and its inputs, but only erase an input
  // array if this op was its sole consumer.
  if (CountOpsWithInput(*model, binary_op->inputs[0]) == 1) {
    model->arrays.erase(binary_op->inputs[0]);
  }
  if (CountOpsWithInput(*model, binary_op->inputs[1]) == 1) {
    model->arrays.erase(binary_op->inputs[1]);
  }
  AddMessageF("Resolved constant %s to the equivalent constant array",
              LogName(*binary_op));
  model->operators.erase(binary_it);
  return true;
}

}  // namespace toco
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "absl/strings/str_cat.h"
#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/contrib/lite/toco/model.h"
#include "tensorflow/contrib/lite/toco/tooling_util.h"
#include "tensorflow/core/platform/logging.h"

namespace toco {

namespace {

// Copies data from multiple source arrays to a destination array based on a
// concatenation dimension. From each array in input_arrays, it copies chunk
// sizes provided in array_copy_size vector (per array). It uses the buffer
// in concatenated_array as destination buffer.
// Note: returns early (a no-op) if any input array has no buffer yet.
template <ArrayDataType A, typename T>
void CopyTensorSegments(const std::vector<Array*>& input_arrays,
                        const std::vector<int>& array_copy_size,
                        const int num_elements_concatenated_array,
                        Array* concatenated_array) {
  for (Array* input_array : input_arrays) {
    if (!input_array->buffer) {
      return;
    }
  }

  auto& concatenated_array_buffer =
      concatenated_array->GetMutableBuffer<A>().data;
  concatenated_array_buffer.resize(num_elements_concatenated_array);

  // It does not matter which array to use to find the value for the total
  // number of copy steps.
  CHECK(!input_arrays.empty());
  CHECK_NE(array_copy_size[0], 0);
  const int total_copy_steps =
      input_arrays[0]->GetBuffer<A>().data.size() / array_copy_size[0];

  // Initialize the source pointers to point to beginning of the array buffers.
  std::vector<const T*> src_ptr;
  src_ptr.reserve(input_arrays.size());
  for (Array* input_array : input_arrays) {
    src_ptr.push_back(input_array->GetBuffer<A>().data.data());
  }

  // Copy the data from input_arrays to concatenated_array_buffer.
  // Interleave one chunk from each input per step; each source pointer
  // advances by its own chunk size, the destination pointer by every chunk.
  T* dest_ptr = concatenated_array_buffer.data();
  for (int s = 0; s < total_copy_steps; s++) {
    for (int i = 0; i < input_arrays.size(); i++) {
      std::copy(src_ptr[i], src_ptr[i] + array_copy_size[i], dest_ptr);
      src_ptr[i] += array_copy_size[i];
      dest_ptr += array_copy_size[i];
    }
  }
}

// Receives a series of input arrays of type Array and an integer showing the
// axis on which those arrays will be concatenated. It returns the concatenated
// array (written into concatenated_array, which must already have its shape).
template <ArrayDataType A>
void ConcatenateTensorBuffers(const std::vector<Array*>& input_arrays,
                              int concatenation_axis,
                              Array* concatenated_array) {
  int num_elements_concatenated_array = 1;
  for (int i = 0; i < concatenated_array->shape().dimensions_count(); i++) {
    num_elements_concatenated_array *= concatenated_array->shape().dims()[i];
  }
  // Prepare the data needed for segmented copy from multiple source arrays to
  // a destination array based on a concatenation dimension: each input's
  // chunk size is the product of its dimensions from the concat axis onward.
  std::vector<int> array_copy_size(input_arrays.size());
  int count = 0;
  for (Array* input_array : input_arrays) {
    const Shape array_shape = input_array->shape();
    array_copy_size[count] = 1;
    for (int i = concatenation_axis; i < array_shape.dimensions_count(); i++) {
      array_copy_size[count] *= array_shape.dims()[i];
    }
    count++;
  }

  // Do the actual data copy.
  CopyTensorSegments<A, DataType<A>>(input_arrays, array_copy_size,
                                     num_elements_concatenated_array,
                                     concatenated_array);
}

// Sets the minimum and maximum values for the concatenated array. If it's
// already set (e.g. because of previous pass in TOCO), it doesn't change it and
// returns. Otherwise it uses the input arrays min and max values to compute the
// concatenated array min and max (the union of the input ranges). If any
// input lacks minmax, leaves the output minmax unset.
void SetMinMaxForConcatenedArray(const std::vector<Array*>& input_arrays,
                                 Array* concatenated_array) {
  CHECK(concatenated_array->data_type == ArrayDataType::kFloat);
  // If the minmax is already set, use it
  if (concatenated_array->minmax) return;

  double concat_min = std::numeric_limits<double>::infinity();
  double concat_max = -std::numeric_limits<double>::infinity();

  for (Array* input_array : input_arrays) {
    // If any of the input arrays minmax is not set, return.
    // TODO(ghodrat): shall we add the logic to compute the minmax?
    if (!input_array->minmax) return;
    const MinMax& input_minmax = input_array->GetMinMax();
    concat_min = std::min(concat_min, input_minmax.min);
    concat_max = std::max(concat_max, input_minmax.max);
  }
  MinMax& minmax = concatenated_array->GetOrCreateMinMax();
  minmax.min = concat_min;
  minmax.max = concat_max;
}

}  // namespace

// Resolves the concatenation operator if all its inputs are constant arrays,
// replacing it with a single constant output array and erasing the inputs.
// Returns true iff the op was resolved and removed.
bool ResolveConstantConcatenation::Run(Model* model, std::size_t op_index) {
  const auto concat_it = model->operators.begin() + op_index;
  const auto* concat_base_op = concat_it->get();
  if (concat_base_op->type != OperatorType::kConcatenation) {
    return false;
  }
  const auto* concat_op =
      static_cast<const ConcatenationOperator*>(concat_base_op);

  for (const string& input_name : concat_op->inputs) {
    // We only expect constant unquantized arrays as input, otherwise we return.
    // We also make sure the shapes of the input arrays are known and they are
    // all discardable.
    const Operator* input_op = GetOpWithOutput(*model, input_name);
    if (input_op) return false;
    if (!IsConstantParameterArray(*model, input_name)) return false;
    if (!model->GetArray(input_name).has_shape()) return false;
    if (model->GetArray(input_name).quantization_params) return false;
    if (!IsDiscardableArray(*model, input_name)) return false;
  }

  const int concatenation_axis = concat_op->concat_dim;

  CHECK_EQ(concat_op->outputs.size(), 1);
  string concatenated_array_name = concat_op->outputs[0];
  Array& concatenated_array = model->GetOrCreateArray(concatenated_array_name);
  std::vector<Array*> input_arrays;
  for (const string& input_name : concat_op->inputs) {
    input_arrays.push_back(&model->GetArray(input_name));
  }

  // Dispatch on the output data type; minmax propagation only applies to
  // the float case.
  switch (concatenated_array.data_type) {
    case ArrayDataType::kFloat:
      ConcatenateTensorBuffers<ArrayDataType::kFloat>(
          input_arrays, concatenation_axis, &concatenated_array);
      SetMinMaxForConcatenedArray(input_arrays, &concatenated_array);
      break;
    case ArrayDataType::kUint8:
      ConcatenateTensorBuffers<ArrayDataType::kUint8>(
          input_arrays, concatenation_axis, &concatenated_array);
      break;
    case ArrayDataType::kInt32:
      ConcatenateTensorBuffers<ArrayDataType::kInt32>(
          input_arrays, concatenation_axis, &concatenated_array);
      break;
    case ArrayDataType::kInt64:
      ConcatenateTensorBuffers<ArrayDataType::kInt64>(
          input_arrays, concatenation_axis, &concatenated_array);
      break;
    default:
      LOG(FATAL) << "ArrayDataType not supported";
  }

  // Remove all the resolved arrays.
  for (const string& input_name : concat_op->inputs) {
    model->arrays.erase(input_name);
  }

  // Remove concatenate operator
  model->operators.erase(concat_it);
  return true;
}

}  // namespace toco
+==============================================================================*/ +#include <memory> +#include <string> +#include <unordered_map> +#include <vector> + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool ResolveConstantFakeQuant::Run(Model* model, std::size_t op_index) { + const auto fakequant_it = model->operators.begin() + op_index; + const auto* fakequant_base_op = fakequant_it->get(); + if (fakequant_base_op->type != OperatorType::kFakeQuant) { + return false; + } + + const auto* fakequant_op = + static_cast<const FakeQuantOperator*>(fakequant_base_op); + + // Yield until the fakequant MinMax has been resolved. + if (!fakequant_op->minmax) { + return false; + } + + // This transformation only applies when the input array is constant. + if (!IsConstantParameterArray(*model, fakequant_op->inputs[0])) { + return false; + } + + const auto& input_array = model->GetArray(fakequant_op->inputs[0]); + auto& output_array = model->GetArray(fakequant_op->outputs[0]); + CHECK(input_array.data_type == ArrayDataType::kFloat); + output_array.data_type = ArrayDataType::kFloat; + CHECK(!output_array.buffer); + const auto& input_buffer = input_array.GetBuffer<ArrayDataType::kFloat>(); + auto& output_buffer = output_array.GetMutableBuffer<ArrayDataType::kFloat>(); + const int size = input_buffer.data.size(); + output_buffer.data.resize(size); + QuantizationParams qparams; + GetQuantizationParamsFromMinMax<ArrayDataType::kUint8>( + model->flags, *fakequant_op->minmax, &qparams); + for (int i = 0; i < size; i++) { + const double src_val = input_buffer.data[i]; + const double unclamped_quantized_val = + std::round(qparams.zero_point + src_val / qparams.scale); + const double quantized_val = + std::min(255., std::max(0., unclamped_quantized_val)); + const 
double dst_val = qparams.scale * (quantized_val - qparams.zero_point); + output_buffer.data[i] = dst_val; + } + if (CountOpsWithInput(*model, fakequant_op->inputs[0]) == 1) { + model->arrays.erase(fakequant_op->inputs[0]); + } + model->operators.erase(fakequant_it); + + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tensorflow_shape.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tensorflow_shape.cc new file mode 100644 index 0000000000..8cc6db1619 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tensorflow_shape.cc @@ -0,0 +1,62 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include <cstddef> +#include <memory> +#include <string> +#include <unordered_map> +#include <vector> + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool ResolveConstantTensorFlowShape::Run(Model* model, std::size_t op_index) { + const auto tfshape_it = model->operators.begin() + op_index; + const auto* tfshape_base_op = tfshape_it->get(); + if (tfshape_base_op->type != OperatorType::kTensorFlowShape) { + return false; + } + + const auto* tfshape_op = + static_cast<const TensorFlowShapeOperator*>(tfshape_base_op); + + const auto& input_array = model->GetArray(tfshape_op->inputs[0]); + auto& output_array = model->GetArray(tfshape_op->outputs[0]); + + // Yield until the input array's shape has been resolved. + if (!input_array.has_shape()) { + return false; + } + + // Create a buffer for the output array, making it a constant array, and + // copy the input shape into the output buffer. 
+ CHECK(!output_array.buffer); + auto& output_buffer = output_array.GetMutableBuffer<ArrayDataType::kInt32>(); + output_buffer.data = input_array.shape().dims(); + + // Erase the input array if no longer used + if (IsDiscardableArray(*model, tfshape_op->inputs[0]) && + CountOpsWithInput(*model, tfshape_op->inputs[0]) == 1) { + model->arrays.erase(tfshape_op->inputs[0]); + } + model->operators.erase(tfshape_it); + + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc new file mode 100644 index 0000000000..bb9bda3c82 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc @@ -0,0 +1,175 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include <string.h> +#include <algorithm> +#include <cmath> +#include <memory> +#include <string> +#include <unordered_map> +#include <vector> + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/runtime/types.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) { + const auto unary_it = model->operators.begin() + op_index; + const auto* unary_op = unary_it->get(); + // Test for unary ops of types that we know how to resolve + if (unary_op->type != OperatorType::kTensorFlowRsqrt && + unary_op->type != OperatorType::kTensorFlowSqrt && + unary_op->type != OperatorType::kTensorFlowSquare && + unary_op->type != OperatorType::kTensorFlowSum && + unary_op->type != OperatorType::kTensorFlowMin && + unary_op->type != OperatorType::kTensorFlowMax && + unary_op->type != OperatorType::kTensorFlowReshape) { + return false; + } + // Check if the input is a constant parameter. + if (!IsConstantParameterArray(*model, unary_op->inputs[0])) { + return false; + } + + // if the unary op involves a tensor required by a rnn state, ignore it + for (const auto& rnn_state : model->flags.rnn_states()) { + if (unary_op->inputs[0] == rnn_state.back_edge_source_array()) { + return false; + } + if (unary_op->inputs[0] == rnn_state.state_array()) { + return false; + } + } + + // At the moment we don't want to care about fused activation functions. + // The idea is that we should do the present constants-propagation before + // activation functions get fused. 
+ if (unary_op->fused_activation_function != + FusedActivationFunctionType::kNone) { + AddMessageF( + "Not resolving constant %s " + " because it has a fused activation function", + LogName(*unary_op)); + return false; + } + const auto& input_array = model->GetArray(unary_op->inputs[0]); + // We have already tested above for existence of buffers (synonymous to being + // a constant param). + CHECK(input_array.buffer); + // At the moment we only support float buffers. + if (input_array.buffer->type != ArrayDataType::kFloat) { + return false; + } + const auto& input_float_data = + input_array.GetBuffer<ArrayDataType::kFloat>().data; + // Create the float buffer on the output array, effectively turning it into + // a constant parameter + const auto& output_name = unary_op->outputs[0]; + auto& output_array = model->GetArray(output_name); + // Yield until the output array dims have been resolved. + if (!output_array.has_shape()) { + return false; + } + + int input_buffer_size = RequiredBufferSizeForShape(input_array.shape()); + int output_buffer_size = RequiredBufferSizeForShape(output_array.shape()); + const Shape& input_shape = input_array.shape(); + const Shape& output_shape = output_array.shape(); + + auto& output_float_data = + output_array.GetMutableBuffer<ArrayDataType::kFloat>().data; + output_float_data.resize(output_buffer_size); + + const int output_dims_count = output_shape.dimensions_count(); + if (unary_op->type == OperatorType::kTensorFlowReshape) { + CHECK(input_buffer_size == output_buffer_size); + memcpy(output_float_data.data(), input_float_data.data(), + input_buffer_size * sizeof(input_float_data[0])); + } else if (unary_op->type == OperatorType::kTensorFlowSum) { + // At the moment only full reduction across all dimensions is supported. 
+ for (int i = 0; i < output_dims_count; i++) { + CHECK_EQ(output_shape.dims(i), 1); + } + float sum = 0.f; + const int input_size = RequiredBufferSizeForShape(input_shape); + for (int i = 0; i < input_size; i++) { + sum += input_float_data[i]; + } + output_float_data[0] = sum; + } else if (unary_op->type == OperatorType::kTensorFlowMin) { + // At the moment only full reduction across all dimensions is supported. + // TODO(starka): Output should not be padded. + for (int i = 0; i < output_dims_count; i++) { + CHECK_EQ(output_shape.dims(i), 1); + } + float min = input_float_data[0]; + const int input_size = RequiredBufferSizeForShape(input_shape); + for (int i = 0; i < input_size; i++) { + min = std::min(min, input_float_data[i]); + } + output_float_data[0] = min; + } else if (unary_op->type == OperatorType::kTensorFlowMax) { + // At the moment only full reduction across all dimensions is supported. + // TODO(starka): Output should not be padded. + for (int i = 0; i < output_dims_count; i++) { + CHECK_EQ(output_shape.dims(i), 1); + } + float max = input_float_data[0]; + const int input_size = RequiredBufferSizeForShape(input_shape); + for (int i = 0; i < input_size; i++) { + max = std::max(max, input_float_data[i]); + } + output_float_data[0] = max; + } else if (unary_op->type == OperatorType::kTensorFlowRsqrt || + unary_op->type == OperatorType::kTensorFlowSqrt || + unary_op->type == OperatorType::kTensorFlowSquare) { + // Element-wise ops. Should have perfectly matching sizes here. 
+ const int input_size = RequiredBufferSizeForShape(input_shape); + for (int i = 0; i < output_dims_count; i++) { + CHECK_EQ(output_shape.dims(i), input_shape.dims(i)); + } + + for (int i = 0; i < input_size; i++) { + const float val = input_float_data[i]; + float outval = 0.f; + if (unary_op->type == OperatorType::kTensorFlowRsqrt) { + outval = 1.0f / std::sqrt(val); + } else if (unary_op->type == OperatorType::kTensorFlowSqrt) { + outval = std::sqrt(val); + } else if (unary_op->type == OperatorType::kTensorFlowSquare) { + outval = val * val; + } else { + LOG(FATAL) << "should not get here."; + } + output_float_data[i] = outval; + } + } else { + LOG(FATAL) << "should not get here."; + } + for (const auto& input : unary_op->inputs) { + if (CountOpsWithInput(*model, input) == 1) { + model->arrays.erase(input); + } + } + AddMessageF("Resolved constant %s to the equivalent constant array", + LogName(*unary_op)); + model->operators.erase(unary_it); + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_mean_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_mean_attributes.cc new file mode 100644 index 0000000000..d25c773f19 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_mean_attributes.cc @@ -0,0 +1,51 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include <memory> +#include <string> +#include <unordered_map> +#include <vector> + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool ResolveMeanAttributes::Run(Model* model, std::size_t op_index) { + auto* mean_op = model->operators[op_index].get(); + if (mean_op->type != OperatorType::kMean) return false; + auto* op = static_cast<MeanOperator*>(mean_op); + + if (!op->reduction_indices.empty()) return false; + if (op->inputs.size() != 2) return false; + if (!IsConstantParameterArray(*model, op->inputs[1])) return false; + + const auto& indices_array = *model->arrays[op->inputs[1]]; + if (!indices_array.has_shape()) return false; + + op->reduction_indices = indices_array.GetBuffer<ArrayDataType::kInt32>().data; + + // At the moment, we only support simultaneous reduction over width and + // height. This is mainly limited by the fact that currently, the runtime + // arrays are always 4-dimensional. + CHECK_EQ(op->reduction_indices.size(), 2); + CHECK((op->reduction_indices[0] == 1 && op->reduction_indices[1] == 2) || + (op->reduction_indices[0] == 2 && op->reduction_indices[1] == 1)); + + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_pad_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_pad_attributes.cc new file mode 100644 index 0000000000..d5f5869c62 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_pad_attributes.cc @@ -0,0 +1,55 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <memory> +#include <string> +#include <unordered_map> +#include <vector> + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool ResolvePadAttributes::Run(Model* model, std::size_t op_index) { + const auto pad_it = model->operators.begin() + op_index; + auto* pad_op = pad_it->get(); + if (pad_op->type != OperatorType::kPad) return false; + + auto* op = static_cast<PadOperator*>(pad_op); + if (!op->left_padding.empty()) return false; + + CHECK_EQ(op->inputs.size(), 2); + if (!IsConstantParameterArray(*model, op->inputs[1])) return false; + + const auto& array = *model->arrays[op->inputs[1]]; + if (!array.has_shape()) return false; + + const std::vector<int>& dims = array.shape().dims(); + CHECK_EQ(dims.size(), 2); + + std::vector<int> buffer = array.GetBuffer<ArrayDataType::kInt32>().data; + + for (int i = 0; i < dims[0]; ++i) { + op->left_padding.push_back(buffer[i * 2]); + op->right_padding.push_back(buffer[i * 2 + 1]); + } + + // TODO(dkalenichenko): Delete the extra input? 
+ + return true; +} +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc new file mode 100644 index 0000000000..8fa7b83bed --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc @@ -0,0 +1,93 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <algorithm> +#include <memory> +#include <string> +#include <unordered_map> +#include <vector> + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool ResolveReorderAxes::Run(Model* model, std::size_t op_index) { + auto reorder_it = model->operators.begin() + op_index; + auto* reorder_op = static_cast<ReorderAxesOperator*>(reorder_it->get()); + if (reorder_op->type != OperatorType::kReorderAxes) { + return false; + } + const auto& input_array_name = reorder_op->inputs[0]; + const auto& output_array_name = reorder_op->outputs[0]; + auto& input_array = model->GetArray(input_array_name); + auto& output_array = model->GetArray(output_array_name); + string constant_input_array_name = input_array_name; + if 
(!input_array.buffer) { + const auto* op_producing_input = GetOpWithOutput(*model, input_array_name); + if (op_producing_input && + op_producing_input->type == OperatorType::kFakeQuant) { + constant_input_array_name = op_producing_input->inputs[0]; + } + } + auto& constant_input_array = model->GetArray(constant_input_array_name); + if (!constant_input_array.buffer) { + return false; + } + // Yield until output dims have been resolved. + if (!output_array.has_shape()) { + return false; + } + // Reorder the input array dims and buffer data + CHECK(constant_input_array.buffer->type == ArrayDataType::kFloat); + CHECK(!output_array.buffer); + auto& input_data = + constant_input_array.GetMutableBuffer<ArrayDataType::kFloat>().data; + std::vector<float> reordered_data; + reordered_data.resize(RequiredBufferSizeForShape(output_array.shape())); + const auto input_axes_order = reorder_op->input_axes_order; + const auto output_axes_order = reorder_op->output_axes_order; + // TODO(b/62904716) Shapes should be used directly. + Shape input_shape = constant_input_array.shape(); + Shape output_shape = output_array.shape(); + if (AxesCount(input_axes_order) == 2) { + UnextendShape(&input_shape, 2); + UnextendShape(&output_shape, 2); + } + ShuffleArray(input_shape, input_axes_order, output_axes_order, output_shape, + input_data.data(), reordered_data.data()); + input_data = reordered_data; + input_array.copy_shape(output_array.shape()); + constant_input_array.copy_shape(output_array.shape()); + + // Update the edges of the graph to point to the input array + for (const auto& other_op : model->operators) { + for (auto& input : other_op->inputs) { + if (input == output_array_name) { + input = input_array_name; + } + } + } + + AddMessageF("Reordered axes for array %s", input_array_name); + + // Remove the op and output array. 
+ model->arrays.erase(output_array_name); + model->operators.erase(reorder_it); + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc new file mode 100644 index 0000000000..bed2a85bd2 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc @@ -0,0 +1,49 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include <iterator> +#include <memory> +#include <string> +#include <unordered_map> +#include <vector> + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool ResolveReshapeAttributes::Run(Model* model, std::size_t op_index) { + const auto reshape_it = model->operators.begin() + op_index; + auto* reshape_op = reshape_it->get(); + if (reshape_op->type != OperatorType::kTensorFlowReshape) { + return false; + } + + auto* op = static_cast<TensorFlowReshapeOperator*>(reshape_op); + + if (!op->shape.empty()) return false; + + if (IsConstantParameterArray(*model, reshape_op->inputs[1])) { + const auto& constant_input_array = *model->arrays[reshape_op->inputs[1]]; + op->shape = constant_input_array.GetBuffer<ArrayDataType::kInt32>().data; + } + + if (op->shape.empty()) return false; + + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_slice_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_slice_attributes.cc new file mode 100644 index 0000000000..1d0a2ec8f6 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_slice_attributes.cc @@ -0,0 +1,52 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <memory> +#include <string> +#include <unordered_map> +#include <vector> + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool ResolveSliceAttributes::Run(Model* model, std::size_t op_index) { + const auto slice_it = model->operators.begin() + op_index; + auto* slice_op = slice_it->get(); + if (slice_op->type != OperatorType::kSlice) return false; + + auto* op = static_cast<SliceOperator*>(slice_op); + if (!op->begin.empty()) return false; + + CHECK_EQ(op->inputs.size(), 3); + if (!IsConstantParameterArray(*model, op->inputs[1])) return false; + if (!IsConstantParameterArray(*model, op->inputs[2])) return false; + + const auto& begin_array = *model->arrays[op->inputs[1]]; + if (!begin_array.has_shape()) return false; + + const auto& size_array = *model->arrays[op->inputs[2]]; + if (!size_array.has_shape()) return false; + + op->begin = begin_array.GetBuffer<ArrayDataType::kInt32>().data; + op->size = size_array.GetBuffer<ArrayDataType::kInt32>().data; + + // TODO(dkalenichenko): Delete the extra inputs? + + return true; +} +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc new file mode 100644 index 0000000000..5fc3b25bc1 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc @@ -0,0 +1,62 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <memory> +#include <string> +#include <unordered_map> +#include <vector> + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool ResolveStridedSliceAttributes::Run(Model* model, std::size_t op_index) { + const auto slice_it = model->operators.begin() + op_index; + auto* slice_op = slice_it->get(); + if (slice_op->type != OperatorType::kStridedSlice) return false; + + auto* op = static_cast<StridedSliceOperator*>(slice_op); + if (!op->start_indices.empty()) return false; + + CHECK_EQ(op->inputs.size(), 4); + if (!IsConstantParameterArray(*model, op->inputs[1])) return false; + if (!IsConstantParameterArray(*model, op->inputs[2])) return false; + if (!IsConstantParameterArray(*model, op->inputs[3])) return false; + + const auto& start_array = *model->arrays[op->inputs[1]]; + if (!start_array.has_shape()) return false; + + const auto& stop_array = *model->arrays[op->inputs[2]]; + if (!stop_array.has_shape()) return false; + + const auto& stride_array = *model->arrays[op->inputs[3]]; + if (!stride_array.has_shape()) return false; + + op->start_indices = start_array.GetBuffer<ArrayDataType::kInt32>().data; + op->stop_indices = 
stop_array.GetBuffer<ArrayDataType::kInt32>().data; + op->strides = stride_array.GetBuffer<ArrayDataType::kInt32>().data; + + // Only 4D arrays are supported for now. + CHECK_EQ(op->start_indices.size(), 4); + CHECK_EQ(op->stop_indices.size(), 4); + CHECK_EQ(op->strides.size(), 4); + + // TODO(dkalenichenko): Delete the extra inputs? + + return true; +} +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc new file mode 100644 index 0000000000..b482f5cf51 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc @@ -0,0 +1,86 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include <algorithm> +#include <memory> +#include <string> +#include <unordered_map> +#include <vector> + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool ResolveTensorFlowConcat::Run(Model* model, std::size_t op_index) { + auto concat_it = model->operators.begin() + op_index; + const auto* tf_concat_op = concat_it->get(); + if (tf_concat_op->type != OperatorType::kTensorFlowConcat && + tf_concat_op->type != OperatorType::kTensorFlowConcatV2) { + return false; + } + + CHECK_GE(tf_concat_op->inputs.size(), 2); + // TensorFlow Concat and ConcatV2 nodes only differ by the ordering + // of inputs: in Concat, the concat_dim is the first input, while in + // ConcatV2, it is the last input. + std::size_t concat_dim_pos = 0; + if (tf_concat_op->type == OperatorType::kTensorFlowConcatV2) { + concat_dim_pos = tf_concat_op->inputs.size() - 1; + } + const string concat_dim_name = tf_concat_op->inputs[concat_dim_pos]; + std::vector<string> concat_input_names; + for (std::size_t i = 0; i < tf_concat_op->inputs.size(); i++) { + if (i != concat_dim_pos) { + concat_input_names.push_back(tf_concat_op->inputs[i]); + } + } + // If the concat_dim array hasn't been resolved to a constant yet, + // we need to yield. 
+ const auto& concat_dim_array = model->GetArray(concat_dim_name); + if (!concat_dim_array.buffer) { + AddMessageF("Waiting for the concat_dim of %s to be resolved to a constant", + LogName(*tf_concat_op)); + return false; + } + + CHECK(concat_dim_array.data_type == ArrayDataType::kInt32); + const auto& concat_dim_data = + concat_dim_array.GetBuffer<ArrayDataType::kInt32>().data; + CHECK_EQ(concat_dim_data.size(), 1); + const int concat_dim = concat_dim_data[0]; + + // Create the Concatenation op replacing the TensorFlowConcat op. + auto* concatenation_op = new ConcatenationOperator; + concatenation_op->concat_dim = concat_dim; + concatenation_op->inputs = concat_input_names; + concatenation_op->outputs = {tf_concat_op->outputs[0]}; + auto depth_concat_it = model->operators.emplace(concat_it, concatenation_op); + CHECK_EQ(depth_concat_it->get(), concatenation_op); + // Update invalidated iterator + concat_it = depth_concat_it + 1; + CHECK_EQ(concat_it->get(), tf_concat_op); + + // Remove the concat_dim array if it is not used by anything else. + if (CountOpsWithInput(*model, concat_dim_name) == 1) { + model->arrays.erase(concat_dim_name); + } + // Remove the TensorFlowConcat op + model->operators.erase(concat_it); + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc new file mode 100644 index 0000000000..bea7487051 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc @@ -0,0 +1,106 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <memory> +#include <string> +#include <unordered_map> +#include <vector> + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) { + auto matmul_it = model->operators.begin() + op_index; + if (matmul_it->get()->type != OperatorType::kTensorFlowMatMul) { + return false; + } + const auto* matmul_op = matmul_it->get(); + + // Find the op producing the array passed to this MatMul + auto previous_op_it = model->operators.begin(); + bool found = false; + for (; previous_op_it != model->operators.end(); ++previous_op_it) { + for (const auto& output : (*previous_op_it)->outputs) { + if (output == matmul_op->inputs[0]) { + found = true; + break; + } + } + if (found) { + break; + } + } + Operator* previous_op = (found) ? 
previous_op_it->get() : nullptr; + + // construct the new FullyConnectedOperator + auto* fc_op = new FullyConnectedOperator; + fc_op->outputs = matmul_op->outputs; + + // insert the newly constructed FullyConnectedOperator + auto fc_it = model->operators.emplace(matmul_it, fc_op); + + // refresh invalidated iterator + matmul_it = fc_it + 1; + DCHECK_EQ(matmul_it->get(), matmul_op); + + // The way that TensorFlow encodes FullyConnected ops is as a pair + // (Reshape, MatMul), so we want to remove the Reshape op and rewrite the + // MatMul + // op as a FullyConnected. However, TensorFlow skips the Reshape ops if the + // input doesn't need reshaping, so we can't just match (Reshape, MatMul) + // pairs. + if (previous_op && previous_op->type == OperatorType::kTensorFlowReshape) { + AddMessageF("Combining %s and %s into %s", LogName(*previous_op), + LogName(*matmul_op), LogName(*fc_op)); + const auto& previous_op_output = previous_op->outputs[0]; + if (CountOpsWithInput(*model, previous_op_output) == 1) { + model->arrays.erase(previous_op_output); + } + CHECK_EQ(previous_op->inputs.size(), 2); + fc_op->inputs = {previous_op->inputs[0], matmul_op->inputs[1]}; + // Only remove Reshape node if no other node uses its output. + if (CountOpsWithInput(*model, previous_op_output) == 1) { + const auto& previous_op_shape = previous_op->inputs[1]; + if (CountOpsWithInput(*model, previous_op_shape) == 1 && + !GetOpWithOutput(*model, previous_op_shape)) { + model->arrays.erase(previous_op_shape); + } + model->operators.erase(previous_op_it); + } + + // We may have just invalidated matmul_it, so let's refresh it now. 
+ matmul_it = model->operators.begin(); + for (; matmul_it != model->operators.end(); ++matmul_it) { + if (matmul_it->get() == matmul_op) { + break; + } + } + CHECK(matmul_it != model->operators.end()); + CHECK(matmul_it->get() == matmul_op); + } else { + AddMessageF("Replacing %s by a FullyConnected operator", + LogName(*matmul_op)); + fc_op->inputs = {matmul_op->inputs[0], matmul_op->inputs[1]}; + } + + // erase the MatMul operator + model->operators.erase(matmul_it); + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc new file mode 100644 index 0000000000..cfa5ce0716 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc @@ -0,0 +1,63 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include <memory> +#include <string> +#include <unordered_map> +#include <vector> + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool ResolveTensorFlowMerge::Run(Model* model, std::size_t op_index) { + const auto merge_it = model->operators.begin() + op_index; + const auto* merge_op = merge_it->get(); + if (merge_op->type != OperatorType::kTensorFlowMerge) { + return false; + } + + // We need to yield until this Merge node has only 1 input, which will mean + // that that is the selected input. Other graph transformations on other nodes + // such as ResolveTensorFlowSwitch, will take care of trimming the + // non-selected inputs, so that at some point there will be only 1 input left. + if (merge_op->inputs.size() > 1) { + AddMessageF("Waiting for %s to be resolved", LogName(*merge_op)); + return false; + } + + // Now that the merge node has 1 input exactly, it is the same as an Identity + // node and can be resolved trivially. + CHECK_EQ(merge_op->inputs.size(), 1); + + // Update the edges of the graph ahead of removing the node. + for (const auto& other_op : model->operators) { + for (auto& input : other_op->inputs) { + if (input == merge_op->outputs[0]) { + input = merge_op->inputs[0]; + } + } + } + + // Remove the node and its output array. 
+ AddMessageF("Removing already-resolved %s", LogName(*merge_op)); + model->arrays.erase(merge_op->outputs[0]); + model->operators.erase(merge_it); + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_squeeze.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_squeeze.cc new file mode 100644 index 0000000000..1d3f42b5ec --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_squeeze.cc @@ -0,0 +1,54 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include <memory> +#include <string> +#include <unordered_map> +#include <vector> + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool ResolveTensorFlowSqueeze::Run(Model* model, std::size_t op_index) { + const auto squeeze_it = model->operators.begin() + op_index; + const auto* squeeze_op = squeeze_it->get(); + if (squeeze_op->type != OperatorType::kSqueeze) { + return false; + } + + CHECK_EQ(squeeze_op->inputs.size(), 1); + CHECK_EQ(squeeze_op->outputs.size(), 1); + + // If the output is consumed by a reshape op, it's a trivial squeeze. + if (CountOpsWithInput(*model, squeeze_op->outputs[0]) == 1) { + const auto* next_op = GetOpWithInput(*model, squeeze_op->outputs[0]); + if (next_op->type == OperatorType::kTensorFlowReshape) { + AddMessageF( + "%s is trivial because its output is only consumed by a " + "Reshape op", + LogName(*squeeze_op)); + + return RemoveTrivialPassthroughOp(this, model, op_index); + } + } + + return false; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc new file mode 100644 index 0000000000..55adfca037 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc @@ -0,0 +1,123 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <memory> +#include <string> +#include <unordered_map> +#include <vector> + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool ResolveTensorFlowSwitch::Run(Model* model, std::size_t op_index) { + const auto switch_it = model->operators.begin() + op_index; + const auto* switch_op = switch_it->get(); + if (switch_op->type != OperatorType::kTensorFlowSwitch) { + return false; + } + + CHECK_EQ(switch_op->inputs.size(), 2); + CHECK_EQ(switch_op->outputs.size(), 2); + const string& predicate_name = switch_op->inputs[1]; + // If the predicate array hasn't been resolved to a constant yet, + // we need to yield. + if (!IsConstantParameterArray(*model, predicate_name)) { + AddMessageF( + "Waiting for the boolean predicate of %s to be resolved to a constant", + LogName(*switch_op)); + return false; + } + + // The predicate should be boolean, and should consist of a single value. + const auto& predicate_array = model->GetArray(predicate_name); + CHECK(predicate_array.data_type == ArrayDataType::kBool); + for (const auto& dim : predicate_array.shape().dims()) { + CHECK_EQ(dim, 1); + } + + // Obtain the predicate boolean value. 
+ const auto& predicate_data = + predicate_array.GetBuffer<ArrayDataType::kBool>().data; + CHECK_EQ(predicate_data.size(), 1); + const bool predicate_value = predicate_data[0]; + + // From the TensorFlow docs on .switch() in + // third_party/tensorflow/python/ops/control_flow_ops.py + // + // If `pred` is false, the `data` input is forwared to the first output. + // Otherwise, the data goes to the second output. + // + // Note that this comment used to say the opposite and was recently fixed: + // https://github.com/tensorflow/tensorflow/commit/bc456e361d49d1d89a74b80060c70efb51fd7d87#diff-76ab9dafbe12c20ddc3769c6b108986c + const int selected_output_index = predicate_value ? 1 : 0; + const int nonselected_output_index = predicate_value ? 0 : 1; + + // Update the edges of the graph ahead of removing the node: + // edges that were pointing to the selected output, should instead + // point to the input of the Switch node. + for (const auto& other_op : model->operators) { + for (auto& input : other_op->inputs) { + if (input == switch_op->outputs[selected_output_index]) { + input = switch_op->inputs[0]; + } + } + } + + // There remains to handle the edges that were pointing to the nonselected + // output. We will just discard those edges. Concretely, at the moment, + // our only examples of graphs with Switch nodes have them feeding into Merge + // nodes, so what we're saying here is that we'll make the convention, + // in our toco internal representation, that Merge nodes with only 1 input + // are Merge nodes that have been resolved already and should be have as + // Identity nodes, simply forwarding their input. 
+ // + for (const auto& other_op : model->operators) { + auto input_it = other_op->inputs.begin(); + while (input_it != other_op->inputs.end()) { + if (*input_it == switch_op->outputs[nonselected_output_index]) { + // Let us guard our assumption that only Merge nodes consume the outputs + // of Switch nodes: + CHECK(other_op->type == OperatorType::kTensorFlowMerge); + input_it = other_op->inputs.erase(input_it); + } else { + ++input_it; + } + } + } + + // Remove the output arrays if they are now unused. + for (int i = 0; i < 2; i++) { + if (!GetOpWithInput(*model, switch_op->outputs[i])) { + model->arrays.erase(switch_op->outputs[i]); + } + } + // Remove input arrays if they are only used by the switch itself and aren't + // the output of another op (will get handled by RemoveUnusedOp in that case). + for (const auto& input : switch_op->inputs) { + if (CountOpsWithInput(*model, input) == 1 && + !GetOpWithOutput(*model, input)) { + model->arrays.erase(input); + } + } + // Remove the switch node itself. + AddMessageF("Removing already-resolved %s", LogName(*switch_op)); + model->operators.erase(switch_it); + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_tile.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_tile.cc new file mode 100644 index 0000000000..9f7e7c42a2 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_tile.cc @@ -0,0 +1,97 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <memory> +#include <string> +#include <unordered_map> +#include <vector> + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +void RemoveTileOperator(Model* model, Operator* tile_op, Operator* binary_op, + int operand_index) { + CHECK(tile_op->type == OperatorType::kTensorFlowTile); + CHECK_EQ(binary_op->inputs.size(), 2); + CHECK_EQ(tile_op->inputs.size(), 2); + const string tile_multiplier_array = tile_op->inputs[1]; + const string tile_output_array = tile_op->outputs[0]; + binary_op->inputs[operand_index] = tile_op->inputs[0]; + auto tile_it = model->operators.begin(); + for (; tile_it != model->operators.end(); ++tile_it) { + if (tile_it->get() == tile_op) { + break; + } + } + CHECK(tile_it != model->operators.end()); + CHECK(tile_it->get() == tile_op); + model->operators.erase(tile_it); + if (!CountOpsWithInput(*model, tile_multiplier_array) && + !GetOpWithOutput(*model, tile_multiplier_array)) { + model->arrays.erase(tile_multiplier_array); + } + if (!CountOpsWithInput(*model, tile_output_array)) { + model->arrays.erase(tile_output_array); + } +} +} // namespace + +bool ResolveTensorFlowTile::Run(Model* model, std::size_t op_index) { + const auto binary_it = model->operators.begin() + op_index; + auto* binary_op = binary_it->get(); + // Test for binary ops of types that we know how to resolve + if (binary_op->inputs.size() != 2) { + return false; + } + if (binary_op->type != OperatorType::kAdd && + binary_op->type != OperatorType::kMul && + binary_op->type != OperatorType::kSub && + binary_op->type != OperatorType::kDiv) { + return false; + } + + 
Operator* const op[2] = { + GetOpWithOutput(*model, binary_op->inputs[0]), + GetOpWithOutput(*model, binary_op->inputs[1]), + }; + + // In the unlikely case where both operands are Tile, we can't infer the + // output + // size without the Tile nodes, so we have to bail out. + if (op[0] && op[0]->type == OperatorType::kTensorFlowTile && op[1] && + op[1]->type == OperatorType::kTensorFlowTile) { + return false; + } + + for (int i = 0; i < 2; i++) { + if (op[i] && op[i]->type == OperatorType::kTensorFlowTile) { + // We can only remove a Tile operator is no other op than the present + // binary op was consuming its tiled output. + if (CountOpsWithInput(*model, binary_op->inputs[i]) == 1) { + AddMessageF("Removing %s", LogName(*op[i])); + RemoveTileOperator(model, op[i], binary_op, i); + return true; + } + } + } + return false; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD b/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD new file mode 100644 index 0000000000..8931498782 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD @@ -0,0 +1,31 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +load( + "//tensorflow:tensorflow.bzl", + "tf_cc_test", +) + +tf_cc_test( + name = "resolve_constant_concatenation_test", + srcs = ["resolve_constant_concatenation_test.cc"], + deps = [ + "//tensorflow/contrib/lite/toco:graph_transformations", + "//tensorflow/contrib/lite/toco:model", + "//tensorflow/contrib/lite/toco:tooling_util", + "@com_google_googletest//:gtest_main", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc 
b/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc new file mode 100644 index 0000000000..c6705ad305 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc @@ -0,0 +1,221 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <memory> +#include <string> +#include <unordered_map> +#include <vector> + +#include <gmock/gmock.h> +#include <gtest/gtest.h> +//#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" + +namespace toco { + +namespace { +// A gmock matcher that check that elements of a float vector match to a given +// tolerance. 
+std::vector<testing::Matcher<float>> ArrayFloatNear( + const std::vector<float>& values, float max_abs_error = 1e-5) { + std::vector<testing::Matcher<float>> matchers; + matchers.reserve(values.size()); + for (const float& v : values) { + matchers.emplace_back(testing::FloatNear(v, max_abs_error)); + } + return matchers; +} +} // namespace + +// The following 3 tests make sure the concatenation operation on different axis +// values match TensorFlow results listed below: +// +// x0 = [[[0, 1], [2, 3]], [[4, 5], [6, 7]]] +// x1 = [[[10, 11], [12, 13]], [[14, 15], [16, 17]]] +// x2 = [[[20, 21], [22, 23]], [[24, 25], [26, 27]]] +// x3 = [[[30, 31], [32, 33]], [[34, 35], [36, 37]]] +// +// ConcatAtAxis0 test: +// t0 = tf.concat([x0, x1, x2, x3], 0) +// [[[ 0 1] +// [ 2 3]] +// +// [[ 4 5] +// [ 6 7]] +// +// [[10 11] +// [12 13]] +// +// [[14 15] +// [16 17]] +// +// [[20 21] +// [22 23]] +// +// [[24 25] +// [26 27]] +// +// [[30 31] +// [32 33]] +// +// [[34 35] +// [36 37]]] +// +// ConcatAtAxis1 test: +// t1 = tf.concat([x0, x1, x2, x3], 1) +// [[[ 0 1] +// [ 2 3] +// [10 11] +// [12 13] +// [20 21] +// [22 23] +// [30 31] +// [32 33]] +// +// [[ 4 5] +// [ 6 7] +// [14 15] +// [16 17] +// [24 25] +// [26 27] +// [34 35] +// [36 37]]] +// +// ConcatAtAxis2 test: +// t2 = tf.concat([x0, x1, x2, x3], 2) +// [[[ 0 1 10 11 20 21 30 31] +// [ 2 3 12 13 22 23 32 33]] +// +// [[ 4 5 14 15 24 25 34 35] +// [ 6 7 16 17 26 27 36 37]]] + +class ResolveConstantConcatenationTest : public ::testing::Test { + protected: + ResolveConstantConcatenationTest() {} + + // Prepare a hypothetical TOCO model with one Concatenation operator in it + // together with 4 arrays as its inputs. + // It receives the dimension of concatenation as input. 
+ void PrepareModel(Model* model, int concat_dim) { + std::vector<string> concat_input_names = {"array0", "array1", "array2", + "array3"}; + + const int kDim = 3; + const int kElementPerDim = 2; + const int kBufSize = 8; + const int kNumArrays = 4; + static float in_buf[kNumArrays][kBufSize] = { + {0., 1., 2., 3., 4., 5., 6., 7.}, + {10., 11., 12., 13., 14., 15., 16., 17.}, + {20., 21., 22., 23., 24., 25., 26., 27.}, + {30., 31., 32., 33., 34., 35., 36., 37.}}; + int cnt = 0; + for (const string& concat_input_name : concat_input_names) { + Array& in_array = model->GetOrCreateArray(concat_input_name); + in_array.data_type = ArrayDataType::kFloat; + + // Initialize shape for the input array. + Shape* in_array_shape = in_array.mutable_shape(); + std::vector<int>* in_array_shape_dim = in_array_shape->mutable_dims(); + for (int i = 0; i < kDim; i++) { + in_array_shape_dim->push_back(kElementPerDim); + } + auto& in_array_buffer = + in_array.GetMutableBuffer<toco::ArrayDataType::kFloat>(); + in_array_buffer.data.resize(kBufSize); + float* buf_ptr = + in_array.GetMutableBuffer<toco::ArrayDataType::kFloat>().data.data(); + std::copy(in_buf[cnt], in_buf[cnt] + kBufSize, buf_ptr); + cnt++; + } + auto* concatenation_op = new ConcatenationOperator; + concatenation_op->concat_dim = concat_dim; + concatenation_op->inputs = concat_input_names; + concatenation_op->outputs = {"concat_op_outputs"}; + Array& out_array = model->GetOrCreateArray(concatenation_op->outputs[0]); + out_array.data_type = ArrayDataType::kFloat; + Shape* out_array_shape = out_array.mutable_shape(); + std::vector<int>* out_array_shape_dim = out_array_shape->mutable_dims(); + out_array_shape_dim->resize(kDim); + for (int i = 0; i < kDim; i++) { + if (i == concat_dim) { + (*out_array_shape_dim)[i] = kNumArrays * kElementPerDim; + } else { + (*out_array_shape_dim)[i] = kElementPerDim; + } + } + model->operators.push_back(std::unique_ptr<Operator>(concatenation_op)); + } +}; + 
+TEST_F(ResolveConstantConcatenationTest, ConcatAtAxis0) { + Model model; + const int concat_dim = 0; + PrepareModel(&model, concat_dim); + + GraphTransformationsSet graph_transformation_set; + graph_transformation_set.Add(new toco::ResolveConstantConcatenation); + EXPECT_THAT(model.arrays.size(), 5); + (*graph_transformation_set.begin())->Run(&model, /*op_index=*/0); + EXPECT_THAT(model.arrays.size(), 1); + + auto& concatenated_array = (*model.arrays.begin()).second; + EXPECT_THAT(concatenated_array->GetBuffer<toco::ArrayDataType::kFloat>().data, + ElementsAreArray(ArrayFloatNear( + {0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., + 13., 14., 15., 16., 17., 20., 21., 22., 23., 24., 25., + 26., 27., 30., 31., 32., 33., 34., 35., 36., 37.}))); +} + +TEST_F(ResolveConstantConcatenationTest, ConcatAtAxis1) { + Model model; + const int concat_dim = 1; + PrepareModel(&model, concat_dim); + + GraphTransformationsSet graph_transformation_set; + graph_transformation_set.Add(new toco::ResolveConstantConcatenation); + EXPECT_THAT(model.arrays.size(), 5); + (*graph_transformation_set.begin())->Run(&model, /*op_index=*/0); + EXPECT_THAT(model.arrays.size(), 1); + + auto& concatenated_array = (*model.arrays.begin()).second; + EXPECT_THAT(concatenated_array->GetBuffer<toco::ArrayDataType::kFloat>().data, + ElementsAreArray(ArrayFloatNear( + {0., 1., 2., 3., 10., 11., 12., 13., 20., 21., 22., + 23., 30., 31., 32., 33., 4., 5., 6., 7., 14., 15., + 16., 17., 24., 25., 26., 27., 34., 35., 36., 37.}))); +} + +TEST_F(ResolveConstantConcatenationTest, ConcatAtAxis2) { + Model model; + const int concat_dim = 2; + PrepareModel(&model, concat_dim); + + GraphTransformationsSet graph_transformation_set; + graph_transformation_set.Add(new toco::ResolveConstantConcatenation); + EXPECT_THAT(model.arrays.size(), 5); + (*graph_transformation_set.begin())->Run(&model, /*op_index=*/0); + EXPECT_THAT(model.arrays.size(), 1); + + auto& concatenated_array = (*model.arrays.begin()).second; + 
EXPECT_THAT(concatenated_array->GetBuffer<toco::ArrayDataType::kFloat>().data, + ElementsAreArray(ArrayFloatNear( + {0., 1., 10., 11., 20., 21., 30., 31., 2., 3., 12., + 13., 22., 23., 32., 33., 4., 5., 14., 15., 24., 25., + 34., 35., 6., 7., 16., 17., 26., 27., 36., 37.}))); +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc b/tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc new file mode 100644 index 0000000000..4e273343df --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc @@ -0,0 +1,73 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <memory> +#include <string> +#include <unordered_map> +#include <vector> + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/runtime/types.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool UnfuseActivationFunctions::Run(Model* model, std::size_t op_index) { + const auto it = model->operators.begin() + op_index; + auto* op = it->get(); + + // If a conv operation has an im2col array, yield: it should be dropped first. 
+ if ((op->type == OperatorType::kConv) && (op->outputs.size() == 2)) { + return false; + } + + Operator* ac_op = nullptr; + switch (op->fused_activation_function) { + case FusedActivationFunctionType::kRelu: + ac_op = new ReluOperator; + break; + case FusedActivationFunctionType::kRelu6: + ac_op = new Relu6Operator; + break; + case FusedActivationFunctionType::kRelu1: + ac_op = new Relu1Operator; + break; + default: + return false; + } + + // At this point we know that the op has a fused activation function. At the + // moment that only happens with ops having a single output, may be + // relaxed in the future. + CHECK_EQ(op->outputs.size(), 1); + + // Emplace unfused activation function, drop the fused one. + model->operators.emplace(it + 1, ac_op); + op->fused_activation_function = FusedActivationFunctionType::kNone; + + // Wire up arrays, constructing a new intermediate array to connect the + // op to its new unfused activation function. + ac_op->outputs = op->outputs; + const string& tmp_array_name = + AvailableArrayName(*model, op->outputs[0] + "_unfused"); + CHECK(!model->arrays.count(tmp_array_name)); + model->GetOrCreateArray(tmp_array_name); + ac_op->inputs = {tmp_array_name}; + op->outputs = {tmp_array_name}; + return true; +} + +} // namespace toco |