diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc | 247 |
1 file changed, 247 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc new file mode 100644 index 0000000000..53e1be7a05 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc @@ -0,0 +1,247 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <algorithm> +#include <memory> +#include <string> +#include <unordered_map> +#include <vector> + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/runtime/types.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +std::vector<bool> VectorGreaterThan(const std::vector<int>& a, + const std::vector<int>& b) { + DCHECK_EQ(a.size(), b.size()); + const int size = a.size(); + std::vector<bool> result(size); + for (int i = 0; i < size; i++) { + result[i] = a[i] > b[i]; + } + return result; +} + +void PairwiseVectorSelect(const std::vector<bool>& selector, + const std::vector<int>& input_a, + const std::vector<int>& input_b, + std::vector<int>* output_a, + std::vector<int>* output_b) { + DCHECK_EQ(input_a.size(), input_b.size()); + 
DCHECK_EQ(output_a->size(), output_b->size()); + DCHECK_EQ(input_a.size(), output_a->size()); + DCHECK_EQ(selector.size(), input_a.size()); + const int size = input_a.size(); + for (int i = 0; i < size; i++) { + if (selector[i]) { + (*output_a)[i] = input_a[i]; + (*output_b)[i] = input_b[i]; + } else { + (*output_a)[i] = input_b[i]; + (*output_b)[i] = input_a[i]; + } + } +} + +template <ArrayDataType InputsDataType, ArrayDataType OutputDataType> +void EvaluateBinaryOperatorOnConstantInputs(Model* model, + const Operator* binary_op) { + CHECK(IsConstantParameterArray(*model, binary_op->inputs[0])); + CHECK(IsConstantParameterArray(*model, binary_op->inputs[1])); + CHECK(binary_op->fused_activation_function == + FusedActivationFunctionType::kNone); + const auto& input0_array = model->GetArray(binary_op->inputs[0]); + const auto& input1_array = model->GetArray(binary_op->inputs[1]); + const auto& output_name = binary_op->outputs[0]; + auto& output_array = model->GetArray(output_name); + CHECK(input0_array.data_type == InputsDataType); + CHECK(input1_array.data_type == InputsDataType); + CHECK(output_array.data_type == OutputDataType); + + // We have already tested above for existence of input buffers + // (synonymous to being a constant param). + CHECK(input0_array.buffer); + CHECK(input1_array.buffer); + // On the other hand, the output should not already have a buffer. 
+ CHECK(!output_array.buffer); + + const auto& input0_data = input0_array.GetBuffer<InputsDataType>().data; + const auto& input1_data = input1_array.GetBuffer<InputsDataType>().data; + // Create the buffer on the output array, effectively turning it into + // a constant parameter + + const Shape& output_shape = output_array.shape(); + auto& output_data = output_array.GetMutableBuffer<OutputDataType>().data; + const int output_buffer_size = RequiredBufferSizeForShape(output_shape); + output_data.resize(output_buffer_size); + const int dims_count = output_shape.dimensions_count(); + + // It will be convenient here to have copies of the operands shapes + // extended to match the number of dimensions of the output shape. + Shape input0_shape = input0_array.shape(); + Shape input1_shape = input1_array.shape(); + ExtendShape(&input0_shape, dims_count); + ExtendShape(&input1_shape, dims_count); + // Now we may still have operands of different sizes, which would indicate + // that we have to "broadcast" the smaller dimension. We do this using a + // a vector of Booleans indicating which input is the larger in each + // dimension. + CHECK_EQ(input0_shape.dimensions_count(), input1_shape.dimensions_count()); + CHECK_EQ(input0_shape.dimensions_count(), dims_count); + const std::vector<bool> input0_larger = + VectorGreaterThan(input0_shape.dims(), input1_shape.dims()); + + std::vector<int> big_sizes(dims_count); + std::vector<int> small_sizes(dims_count); + PairwiseVectorSelect(input0_larger, input0_shape.dims(), input1_shape.dims(), + &big_sizes, &small_sizes); + + // The output should already be correctly sized to match the big dimensions. 
+ for (int i = 0; i < dims_count; i++) { + CHECK_EQ(output_shape.dims(i), big_sizes[i]); + } + + std::vector<int> input0_indices(dims_count); + std::vector<int> input1_indices(dims_count); + std::vector<int> modulo_indices(dims_count); + + for (int k = 0; k < output_buffer_size; k++) { + const std::vector<int> output_indices = ReverseOffset(output_shape, k); + for (int i = 0; i < dims_count; i++) { + modulo_indices[i] = output_indices[i] % small_sizes[i]; + } + PairwiseVectorSelect(input0_larger, output_indices, modulo_indices, + &input0_indices, &input1_indices); + const auto val0 = input0_data[Offset(input0_shape, input0_indices)]; + const auto val1 = input1_data[Offset(input1_shape, input1_indices)]; + + DataType<OutputDataType> outval; + if (binary_op->type == OperatorType::kAdd) { + outval = val0 + val1; + } else if (binary_op->type == OperatorType::kMul) { + outval = val0 * val1; + } else if (binary_op->type == OperatorType::kSub) { + outval = val0 - val1; + } else if (binary_op->type == OperatorType::kDiv) { + outval = val0 / val1; + } else if (binary_op->type == OperatorType::kTensorFlowMinimum) { + outval = std::min(val0, val1); + } else if (binary_op->type == OperatorType::kTensorFlowMaximum) { + outval = std::max(val0, val1); + } else if (binary_op->type == OperatorType::kTensorFlowLess) { + outval = val0 < val1; + } else if (binary_op->type == OperatorType::kTensorFlowLessEqual) { + outval = val0 <= val1; + } else if (binary_op->type == OperatorType::kTensorFlowGreater) { + outval = val0 > val1; + } else if (binary_op->type == OperatorType::kTensorFlowGreaterEqual) { + outval = val0 >= val1; + } else { + LOG(FATAL) << "should not get here"; + } + output_data[Offset(output_shape, output_indices)] = outval; + } +} + +void EvaluateBinaryOperatorOnConstantInputs(Model* model, + const Operator* binary_op) { + const auto inputs_data_type = model->arrays[binary_op->inputs[0]]->data_type; + const auto output_data_type = 
model->arrays[binary_op->outputs[0]]->data_type; +#define TOCO_HANDLE_CASE(InputsDataType, OutputDataType) \ + if (inputs_data_type == InputsDataType && \ + output_data_type == OutputDataType) { \ + EvaluateBinaryOperatorOnConstantInputs<InputsDataType, OutputDataType>( \ + model, binary_op); \ + return; \ + } + TOCO_HANDLE_CASE(ArrayDataType::kFloat, ArrayDataType::kFloat) + TOCO_HANDLE_CASE(ArrayDataType::kFloat, ArrayDataType::kBool) + TOCO_HANDLE_CASE(ArrayDataType::kInt32, ArrayDataType::kInt32) + TOCO_HANDLE_CASE(ArrayDataType::kInt32, ArrayDataType::kBool) + TOCO_HANDLE_CASE(ArrayDataType::kInt64, ArrayDataType::kInt64) + TOCO_HANDLE_CASE(ArrayDataType::kInt64, ArrayDataType::kBool) + LOG(FATAL) << "Unimplemented: don't know how to resolve a constant " + << "binary operator for these data types."; +#undef TOCO_HANDLE_CASE +} +} // namespace + +bool ResolveConstantBinaryOperator::Run(Model* model, std::size_t op_index) { + const auto binary_it = model->operators.begin() + op_index; + const auto* binary_op = binary_it->get(); + // Test for binary ops of types that we know how to resolve + if (binary_op->type != OperatorType::kAdd && + binary_op->type != OperatorType::kMul && + binary_op->type != OperatorType::kSub && + binary_op->type != OperatorType::kDiv && + binary_op->type != OperatorType::kTensorFlowMinimum && + binary_op->type != OperatorType::kTensorFlowMaximum && + binary_op->type != OperatorType::kTensorFlowLess && + binary_op->type != OperatorType::kTensorFlowLessEqual && + binary_op->type != OperatorType::kTensorFlowGreater && + binary_op->type != OperatorType::kTensorFlowGreaterEqual) { + return false; + } + CHECK_EQ(binary_op->inputs.size(), 2); + + const auto& input0_array = model->GetArray(binary_op->inputs[0]); + const auto& input1_array = model->GetArray(binary_op->inputs[1]); + // Check if both inputs are constant parameters. 
+ if (!input0_array.buffer || !input1_array.buffer) { + return false; + } + + auto& output_array = *model->arrays[binary_op->outputs[0]]; + // Yield until the output array dims have been resolved. + if (!output_array.has_shape()) { + return false; + } + + // At the moment we don't want to care about fused activation functions. + // The idea is that we should do the present constants-propagation before + // activation functions get fused. + if (binary_op->fused_activation_function != + FusedActivationFunctionType::kNone) { + AddMessageF( + "Not resolving constant %s because it has a fused activation function", + LogName(*binary_op)); + return false; + } + + // Check that input data types agree. + CHECK(input0_array.data_type == input1_array.data_type); + + // Do the actual constants propagation + EvaluateBinaryOperatorOnConstantInputs(model, binary_op); + + // Remove the binary operator and its inputs + if (CountOpsWithInput(*model, binary_op->inputs[0]) == 1) { + model->arrays.erase(binary_op->inputs[0]); + } + if (CountOpsWithInput(*model, binary_op->inputs[1]) == 1) { + model->arrays.erase(binary_op->inputs[1]); + } + AddMessageF("Resolved constant %s to the equivalent constant array", + LogName(*binary_op)); + model->operators.erase(binary_it); + return true; +} + +} // namespace toco |