Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc')
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc  93
1 file changed, 76 insertions, 17 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc
index c698a9567a..5364eebbc9 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc
@@ -27,6 +27,73 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
namespace toco {
+namespace {
+
+// Using the function reducer, reduce input along all axes in axes.
+// Put the reduced data in output, which should already be appropriately
+// sized. check_output_shape is set to the final shape this code computes,
+// so it can be cross-checked against the shape computation logic.
+void ReduceGeneric(bool keep_dims, const std::vector<int>& axes,
+ const Shape& input_shape, const std::vector<float>& input,
+ Shape* check_output_shape, std::vector<float>* output,
+ const std::function<float(float, float)>& reducer) {
+ if (!IsNonEmpty(input_shape)) {
+ // Zero dimensions will break the Offset()/ReverseOffset() logic below, so
+ // just early out if we have an empty shape.
+ return;
+ }
+
+ // Set up output_shape to be the same length as input_shape, with the
+ // reduced dimensions squashed to 1. If keep_dims is false, we'll strip
+ // out the size-1 dimensions at the end, but it's convenient to leave them
+ // for now. We recompute the shape because we need the output shape to have
+ // 1-dims in all the squashed dimensions; the shape from shape computation
+ // may drop those squashed dimensions, depending on the options used.
+ Shape output_shape = input_shape;
+
+ // The reduction mask is multiplied elementwise with the input indices
+ // to compute the output index for each element.
+ std::vector<int> reduction_mask(input_shape.dimensions_count(), 1);
+ for (int axis : axes) {
+ CHECK_GE(axis, 0);
+ CHECK_LT(axis, input_shape.dimensions_count());
+ reduction_mask[axis] = 0;
+ output_shape.mutable_dims()->at(axis) = 1;
+ }
+
+ std::vector<int> output_indices(input_shape.dimensions_count());
+ for (int input_offset = 0; input_offset < input.size(); ++input_offset) {
+ std::vector<int> input_indices = ReverseOffset(input_shape, input_offset);
+ // Calculate the output location by squashing input indices to 0
+ // in reduced axes.
+ for (int i = 0; i < input_shape.dimensions_count(); ++i) {
+ output_indices[i] = input_indices[i] * reduction_mask[i];
+ }
+ int output_offset = Offset(output_shape, output_indices);
+ if (input_indices == output_indices) {
+ // Base element for the reduced axes
+ output->at(output_offset) = input.at(input_offset);
+ } else {
+ // Reduce with existing element.
+ output->at(output_offset) =
+ reducer(output->at(output_offset), input.at(input_offset));
+ }
+ }
+
+ if (!keep_dims) {
+ // Strip the reduced (size-1) dims out of output_shape.
+ std::vector<int> new_dims;
+ for (int i = 0; i < output_shape.dimensions_count(); ++i) {
+ if (reduction_mask[i]) {
+ new_dims.push_back(output_shape.dims(i));
+ }
+ }
+ output_shape.mutable_dims()->swap(new_dims);
+ }
+ *check_output_shape = output_shape;
+}
+
+} // namespace
bool CopyMinMaxFromFirstInput(const Operator& op, Model* model) {
auto& output_array = model->GetArray(op.outputs[0]);
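
To make the index arithmetic in ReduceGeneric concrete: each flattened input offset is expanded back into per-axis indices, the indices of the reduced axes are zeroed by the elementwise mask, and the element is folded into the output slot those masked indices address. Below is a minimal standalone C++ sketch of that scheme, summing a 2x3 tensor along axis 0. FlatOffset and UnflattenOffset are simplified row-major stand-ins written for this sketch, not toco's actual Offset() and ReverseOffset() helpers:

#include <iostream>
#include <vector>

// Row-major flattened offset for the given indices (simplified stand-in
// for toco's Offset()).
int FlatOffset(const std::vector<int>& dims, const std::vector<int>& indices) {
  int offset = 0;
  for (size_t i = 0; i < dims.size(); ++i) {
    offset = offset * dims[i] + indices[i];
  }
  return offset;
}

// Inverse mapping from a flat offset back to per-axis indices (simplified
// stand-in for toco's ReverseOffset()).
std::vector<int> UnflattenOffset(const std::vector<int>& dims, int offset) {
  std::vector<int> indices(dims.size());
  for (int i = static_cast<int>(dims.size()) - 1; i >= 0; --i) {
    indices[i] = offset % dims[i];
    offset /= dims[i];
  }
  return indices;
}

int main() {
  // Sum a 2x3 tensor along axis 0, mirroring ReduceGeneric's mask scheme.
  const std::vector<int> dims = {2, 3};
  const std::vector<float> input = {1, 2, 3, 4, 5, 6};
  std::vector<int> out_dims = dims;
  std::vector<int> reduction_mask(dims.size(), 1);
  for (int axis : {0}) {
    reduction_mask[axis] = 0;  // Reduced axes always map to output index 0.
    out_dims[axis] = 1;
  }
  std::vector<float> output(3, 0.f);
  for (int in = 0; in < static_cast<int>(input.size()); ++in) {
    const std::vector<int> in_idx = UnflattenOffset(dims, in);
    std::vector<int> out_idx(in_idx.size());
    for (size_t i = 0; i < in_idx.size(); ++i) {
      out_idx[i] = in_idx[i] * reduction_mask[i];
    }
    const int out = FlatOffset(out_dims, out_idx);
    // The element whose reduced-axis indices are all zero seeds the output
    // slot; every other element is folded in with the reducer (here: +).
    output[out] = (in_idx == out_idx) ? input[in] : output[out] + input[in];
  }
  for (float v : output) std::cout << v << " ";  // Prints: 5 7 9
  std::cout << "\n";
  return 0;
}

The sketch prints 5 7 9, the three column sums, which is exactly what the rewritten kSum branch in the next hunk computes by delegating to ReduceGeneric.
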
@@ -176,27 +243,19 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
}
auto& axis_array = model->GetArray(unary_op->inputs[1]);
CHECK(axis_array.data_type == ArrayDataType::kInt32);
- int axis = axis_array.GetBuffer<ArrayDataType::kInt32>().data[0];
- CHECK_LT(axis, input_shape.dimensions_count()) << "Axis out of bounds";
- // We currently only handle reduction on axis 0.
- CHECK_EQ(axis, 0) << "Only reduction along axis 0 is supported";
- // We currently only handle 1-D and 2-D input tensors.
- CHECK_LE(input_shape.dimensions_count(), 2) << "Rank >2 not yet supported";
// We only support keep_dims=true; shape prop will need to change otherwise.
auto sum_op = static_cast<const TensorFlowSumOperator*>(unary_op);
- CHECK(sum_op->keep_dims) << "Only keep_dims=true is supported";
+ Shape check_output_shape;
- std::vector<int> indices(input_shape.dimensions_count());
- for (int i = 0; i < input_shape.dims(1); ++i) {
- indices[1] = i;
- float sum = 0.f;
- for (int j = 0; j < input_shape.dims(0); ++j) {
- indices[0] = j;
- sum += (*input_float_data)[Offset(input_shape, indices)];
- }
- output_float_data[i] = sum;
- }
+ ReduceGeneric(
+ sum_op->keep_dims, axis_array.GetBuffer<ArrayDataType::kInt32>().data,
+ input_shape, *input_float_data, &check_output_shape, &output_float_data,
+ [](float existing, float current) -> float {
+ return existing + current;
+ });
+ CHECK(check_output_shape == output_shape)
+ << "Shape propagation output shape doesn't match output shape from op";
} else if (unary_op->type == OperatorType::kReduceMin) {
// At the moment only full reduction across all dimensions is supported.
// TODO(starka): Output should not be padded.
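
With the arithmetic centralized in ReduceGeneric, the reducer lambda is the only part that differs between reductions. As an illustration only (this call is not part of the commit, and min_op is a hypothetical pointer to the reduce-min operator), the kReduceMin branch could in principle reuse the same helper by swapping in a min-selecting lambda, with std::min from <algorithm>:

Shape check_output_shape;
ReduceGeneric(
    min_op->keep_dims, axis_array.GetBuffer<ArrayDataType::kInt32>().data,
    input_shape, *input_float_data, &check_output_shape, &output_float_data,
    [](float existing, float current) -> float {
      return std::min(existing, current);
    });

Because ReduceGeneric seeds each output slot from its base element (the one whose reduced-axis indices are all zero) rather than from an identity value, the same loop is correct for min and max as well as sum.
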