author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-10-08 12:09:54 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-10-08 12:14:11 -0700
commit    dcd3b4307a3095e3f18aef53f5034787e3cc3af6 (patch)
tree      c7d797908bf37d97256ebd65ae3d3c88d33e16bf /tensorflow/contrib
parent    723fd1245ed650ad07e5049faec021f4f0f6d408 (diff)
Remove the restrictions that constant resolution of reduce_sum operators must be along axis 0 and can only apply to 1-D or 2-D inputs.
PiperOrigin-RevId: 216226776
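In outline: the change replaces a hand-rolled summation loop that only handled axis 0 of rank-1/2 inputs with a generic helper, ReduceGeneric, which walks every input element, squashes the indices of the reduced axes to zero via an elementwise mask, and folds each element into the resulting output offset. The standalone sketch below illustrates that index-masking technique; its Offset/ReverseOffset helpers are simplified row-major stand-ins for the ones in toco's tooling_util, not the actual implementations.

// reduce_sketch.cc: minimal sketch of mask-based multi-axis reduction,
// assuming row-major layout throughout.
#include <cstdio>
#include <vector>

// Flatten a multi-dimensional index into a row-major offset.
int Offset(const std::vector<int>& shape, const std::vector<int>& indices) {
  int offset = 0;
  for (int i = 0; i < static_cast<int>(shape.size()); ++i) {
    offset = offset * shape[i] + indices[i];
  }
  return offset;
}

// Recover the multi-dimensional index from a row-major flat offset.
std::vector<int> ReverseOffset(const std::vector<int>& shape, int offset) {
  std::vector<int> indices(shape.size());
  for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
    indices[i] = offset % shape[i];
    offset /= shape[i];
  }
  return indices;
}

// Sum input over every axis in axes, keeping reduced dims as size 1.
std::vector<float> ReduceSum(const std::vector<int>& input_shape,
                             const std::vector<float>& input,
                             const std::vector<int>& axes,
                             std::vector<int>* output_shape) {
  // A 0 in the mask squashes that axis's index to 0 in the output.
  std::vector<int> mask(input_shape.size(), 1);
  *output_shape = input_shape;
  for (int axis : axes) {
    mask[axis] = 0;
    (*output_shape)[axis] = 1;
  }
  int output_size = 1;
  for (int dim : *output_shape) output_size *= dim;

  // Sum has identity 0, so we can start from zeros; the real ReduceGeneric
  // instead assigns the base element per output slot, which also works for
  // reducers without an identity value.
  std::vector<float> output(output_size, 0.f);
  for (int in = 0; in < static_cast<int>(input.size()); ++in) {
    std::vector<int> idx = ReverseOffset(input_shape, in);
    for (int i = 0; i < static_cast<int>(idx.size()); ++i) idx[i] *= mask[i];
    output[Offset(*output_shape, idx)] += input[in];
  }
  return output;
}

int main() {
  // Same 2-D case as the new unit test: sum a 3x4 array over axis 0.
  std::vector<int> out_shape;
  std::vector<float> out = ReduceSum(
      {3, 4}, {3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8}, {0}, &out_shape);
  for (float v : out) std::printf("%g ", v);  // Prints: 13 13 11 15
  std::printf("\n");
}

The new unit tests below exercise the same 2-D case, plus an axis-1 reduction and a multi-axis 3-D reduction.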
Diffstat (limited to 'tensorflow/contrib')
 tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc           |  93
 tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD                          |  13
 tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_unary_test.cc | 140
3 files changed, 229 insertions(+), 17 deletions(-)
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc
index c698a9567a..5364eebbc9 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc
@@ -27,6 +27,73 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
namespace toco {
+namespace {
+
+// Using the function reducer, reduce input along all axes in axes.
+// Put the reduced data in output, which should already be appropriately
+// sized. check_output_shape is set to the final shape this code computes,
+// so it can be cross-checked against the shape computation logic.
+void ReduceGeneric(bool keep_dims, const std::vector<int>& axes,
+ const Shape& input_shape, const std::vector<float>& input,
+ Shape* check_output_shape, std::vector<float>* output,
+ const std::function<float(float, float)>& reducer) {
+ if (!IsNonEmpty(input_shape)) {
+ // Zero-dimensions will break the NextIndices() logic, so just early out if
+ // we have an empty shape.
+ return;
+ }
+
+ // Set up output_shape to be the same length as input_shape, with
+ // appropriate dimensions squashed to 1. If keep_dims is false, we'll strip
+ // out the 1-dims at the end, but it's convenient to leave them for now.
+ // We recompute the shape because we need the output shape to have 1-dims
+ // in all the squashed dimensions; the shape from shape computation may
+ // remove those squashed dimensions, depending on the options used.
+ Shape output_shape = input_shape;
+
+ // The reduction mask is multiplied elementwise against the input
+ // indices to compute the output index for each element.
+ std::vector<int> reduction_mask(input_shape.dimensions_count(), 1);
+ for (int axis : axes) {
+ CHECK_GE(axis, 0);
+ CHECK_LT(axis, input_shape.dimensions_count());
+ reduction_mask[axis] = 0;
+ output_shape.mutable_dims()->at(axis) = 1;
+ }
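+ // For example (hypothetical values): with input_shape = {2, 3} and
+ // axes = {0}, reduction_mask becomes {0, 1} and output_shape becomes
+ // {1, 3}; input index {1, 2} then maps to output index {0, 2} below.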
+
+ std::vector<int> output_indices(input_shape.dimensions_count());
+ for (int input_offset = 0; input_offset < input.size(); ++input_offset) {
+ std::vector<int> input_indices = ReverseOffset(input_shape, input_offset);
+ // Calculate the output location by squashing input indices to 0
+ // in reduced axes.
+ for (int i = 0; i < input_shape.dimensions_count(); ++i) {
+ output_indices[i] = input_indices[i] * reduction_mask[i];
+ }
+ int output_offset = Offset(output_shape, output_indices);
+ if (input_indices == output_indices) {
+ // Base element for the reduced axes
+ output->at(output_offset) = input.at(input_offset);
+ } else {
+ // Reduce with existing element.
+ output->at(output_offset) =
+ reducer(output->at(output_offset), input.at(input_offset));
+ }
+ }
+
+ if (!keep_dims) {
+ // Strip the reduced 1-dims out of output_shape.
+ std::vector<int> new_dims;
+ for (int i = 0; i < output_shape.dimensions_count(); ++i) {
+ if (reduction_mask[i]) {
+ new_dims.push_back(output_shape.dims(i));
+ }
+ }
+ output_shape.mutable_dims()->swap(new_dims);
+ }
+ *check_output_shape = output_shape;
+}
+
+} // namespace
bool CopyMinMaxFromFirstInput(const Operator& op, Model* model) {
auto& output_array = model->GetArray(op.outputs[0]);
@@ -176,27 +243,19 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
}
auto& axis_array = model->GetArray(unary_op->inputs[1]);
CHECK(axis_array.data_type == ArrayDataType::kInt32);
- int axis = axis_array.GetBuffer<ArrayDataType::kInt32>().data[0];
- CHECK_LT(axis, input_shape.dimensions_count()) << "Axis out of bounds";
- // We currently only handle reduction on axis 0.
- CHECK_EQ(axis, 0) << "Only reduction along axis 0 is supported";
- // We currently only handle 1-D and 2-D input tensors.
- CHECK_LE(input_shape.dimensions_count(), 2) << "Rank >2 not yet supported";
// We only support keep_dims=true; shape prop will need to change otherwise.
auto sum_op = static_cast<const TensorFlowSumOperator*>(unary_op);
- CHECK(sum_op->keep_dims) << "Only keep_dims=true is supported";
+ Shape check_output_shape;
- std::vector<int> indices(input_shape.dimensions_count());
- for (int i = 0; i < input_shape.dims(1); ++i) {
- indices[1] = i;
- float sum = 0.f;
- for (int j = 0; j < input_shape.dims(0); ++j) {
- indices[0] = j;
- sum += (*input_float_data)[Offset(input_shape, indices)];
- }
- output_float_data[i] = sum;
- }
+ ReduceGeneric(
+ sum_op->keep_dims, axis_array.GetBuffer<ArrayDataType::kInt32>().data,
+ input_shape, *input_float_data, &check_output_shape, &output_float_data,
+ [](float existing, float current) -> float {
+ return existing + current;
+ });
+ CHECK(check_output_shape == output_shape)
+ << "Shape propagation output shape doesn't match output shape from op";
} else if (unary_op->type == OperatorType::kReduceMin) {
// At the moment only full reduction across all dimensions is supported.
// TODO(starka): Output should not be padded.
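Two details of the new path are worth noting. First, ReduceGeneric initializes each output slot from its base input element (the element whose indices survive masking unchanged) rather than from an identity value, so the same helper would stay correct for reducers like min or max that have no universal identity; only the sum lambda is wired up in this change. Second, the CHECK against check_output_shape cross-validates the shape computed during constant folding with the shape produced earlier by shape propagation, so any divergence between the two code paths fails loudly instead of silently corrupting data.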
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD b/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD
index acf1e3ede5..6f1be298ca 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD
+++ b/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD
@@ -30,3 +30,16 @@ tf_cc_test(
"@com_google_googletest//:gtest_main",
],
)
+
+tf_cc_test(
+ name = "resolve_constant_unary_test",
+ srcs = ["resolve_constant_unary_test.cc"],
+ tags = ["no_oss"],
+ deps = [
+ "//tensorflow/contrib/lite/toco:graph_transformations",
+ "//tensorflow/contrib/lite/toco:model",
+ "//tensorflow/contrib/lite/toco:tooling_util",
+ "@com_google_absl//absl/memory",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_unary_test.cc b/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_unary_test.cc
new file mode 100644
index 0000000000..a53abc9941
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_unary_test.cc
@@ -0,0 +1,140 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <tuple>
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "absl/memory/memory.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+
+namespace toco {
+
+namespace {
+
+void RunResolveSum(const std::vector<float>& input,
+ const std::vector<int>& input_shape,
+ const std::vector<int>& axis,
+ const std::vector<int>& output_shape,
+ const std::vector<float>& expected_output) {
+ Model model;
+ Array& input0 = model.GetOrCreateArray("input0");
+ Array& input1 = model.GetOrCreateArray("input1");
+ Array& output = model.GetOrCreateArray("output");
+
+ *input0.mutable_shape()->mutable_dims() = input_shape;
+ input0.data_type = ArrayDataType::kFloat;
+ input0.GetMutableBuffer<ArrayDataType::kFloat>().data = input;
+
+ *input1.mutable_shape()->mutable_dims() = {static_cast<int>(axis.size())};
+ input1.GetMutableBuffer<ArrayDataType::kInt32>().data = axis;
+ input1.data_type = ArrayDataType::kInt32;
+
+ *output.mutable_shape()->mutable_dims() = output_shape;
+
+ auto sum_op = absl::make_unique<TensorFlowSumOperator>();
+ sum_op->keep_dims = true;
+ sum_op->inputs = {"input0", "input1"};
+ sum_op->outputs = {"output"};
+ model.operators.push_back(std::move(sum_op));
+ ResolveConstantUnaryOperator().Run(&model, 0);
+ EXPECT_EQ(model.GetArray("output").GetBuffer<ArrayDataType::kFloat>().data,
+ expected_output);
+ EXPECT_EQ(model.GetArray("output").shape().dims(), output_shape);
+}
+
+// Reduce a 2-D array across axis 0.
+TEST(ResolveConstantUnary, ResolveSumAxis0_2D) {
+ // clang-format off
+ RunResolveSum(
+ // Input data
+ {3, 1, 4, 1,
+ 5, 9, 2, 6,
+ 5, 3, 5, 8},
+
+ // Input shape
+ {3, 4},
+
+ // Axes
+ {0},
+
+ // Expected output shape
+ {1, 4},
+
+ // Expected output
+ {13, 13, 11, 15});
+ // clang-format on
+}
+
+// Reduce a 2-D array across axis 1.
+TEST(ResolveConstantUnary, ResolveSumAxis1_2D) {
+ // clang-format off
+ RunResolveSum(
+ // Input data
+ {3, 1, 4, 1,
+ 5, 9, 2, 6,
+ 5, 3, 5, 8},
+
+ // Input shape
+ {3, 4},
+
+ // Axes
+ {1},
+
+ // Expected output shape
+ {3, 1},
+
+ // Expected output
+ {9, 22, 21});
+ // clang-format on
+}
+
+// Reduce a 3-D tensor across axes 0 and 2.
+TEST(ResolveConstantUnary, ResolveSumAxis0_2_3D) {
+ // clang-format off
+ RunResolveSum(
+ // Input data
+ { 0, 1, 2,
+ 3, 10, 11,
+ 12, 13, 20,
+ 21, 22, 23,
+
+ 100, 101, 102,
+ 103, 110, 111,
+ 112, 113, 120,
+ 121, 122, 123,
+
+ 200, 201, 202,
+ 203, 210, 211,
+ 212, 213, 220,
+ 221, 222, 223 },
+
+ // Input shape
+ {3, 4, 3},
+
+ // Axes
+ {0, 2},
+
+ // Expected output shape
+ {1, 4, 1},
+
+ // Expected output, generated using Octave.
+ { 909, 972, 1035, 1098});
+ // clang-format on
+}
+
+} // namespace
+} // namespace toco
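As a sanity check on the Octave-generated values in the 3-D test: reducing the {3, 4, 3} input over axes 0 and 2 leaves one sum per middle index j, taken over all i and k of input[i][j][k]:

  j = 0: (0 + 1 + 2)   + (100 + 101 + 102) + (200 + 201 + 202) =  3 + 303 + 603 =  909
  j = 1: (3 + 10 + 11) + (103 + 110 + 111) + (203 + 210 + 211) = 24 + 324 + 624 =  972
  j = 2: (12 + 13 + 20) + (112 + 113 + 120) + (212 + 213 + 220) = 45 + 345 + 645 = 1035
  j = 3: (21 + 22 + 23) + (121 + 122 + 123) + (221 + 222 + 223) = 66 + 366 + 666 = 1098

which matches the expected output {909, 972, 1035, 1098} and, with keep_dims=true, the expected shape {1, 4, 1}.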