aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-07-25 12:53:50 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-25 12:58:31 -0700
commitb4c97bf13b618fbdc22981ce04f9faf358da034c (patch)
treeca129ae15fdcdee0eb9cd96232ef5b4d14490e89
parent50d48d606c3b2c08eef249b6fe4f543a51ca8455 (diff)
Implementation of UnsortedSegmentSum in tf2xla bridge.
PiperOrigin-RevId: 163109769
-rw-r--r--tensorflow/compiler/tests/BUILD14
-rw-r--r--tensorflow/compiler/tests/segment_reduction_ops_test.py139
-rw-r--r--tensorflow/compiler/tf2xla/kernels/BUILD1
-rw-r--r--tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc155
4 files changed, 309 insertions, 0 deletions
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 4f0137e8d9..c693f58f8b 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -354,6 +354,20 @@ tf_xla_py_test(
)
tf_xla_py_test(
+ name = "segment_reduction_ops_test",
+ size = "small",
+ srcs = ["segment_reduction_ops_test.py"],
+ deps = [
+ ":xla_test",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:math_ops_gen",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+tf_xla_py_test(
name = "spacetobatch_op_test",
size = "medium",
srcs = ["spacetobatch_op_test.py"],
diff --git a/tensorflow/compiler/tests/segment_reduction_ops_test.py b/tensorflow/compiler/tests/segment_reduction_ops_test.py
new file mode 100644
index 0000000000..260a04421b
--- /dev/null
+++ b/tensorflow/compiler/tests/segment_reduction_ops_test.py
@@ -0,0 +1,139 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Test cases for segment reduction ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+import numpy as np
+
+from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import googletest
+
+
+class SegmentReductionOpsTest(XLATestCase):
+ """Test cases for segment reduction ops."""
+
+ def UnsortedSegmentSum(self, data, indices, num_segments):
+ with self.test_session() as sess, self.test_scope():
+ d = array_ops.placeholder(data.dtype, shape=data.shape)
+ if isinstance(indices, int):
+ i = array_ops.placeholder(np.int32, shape=[])
+ else:
+ i = array_ops.placeholder(indices.dtype, shape=indices.shape)
+ return sess.run(
+ math_ops.unsorted_segment_sum(d, i, num_segments),
+ {d: data,
+ i: indices})
+
+ def testUnsortedSegmentSum0DIndices1DData(self):
+ for dtype in self.numeric_types:
+ self.assertAllClose(
+ np.array(
+ [[0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 1, 2, 3, 4, 5],
+ [0, 0, 0, 0, 0, 0]],
+ dtype=dtype),
+ self.UnsortedSegmentSum(
+ np.array([0, 1, 2, 3, 4, 5], dtype=dtype), 2, 4))
+
+ def testUnsortedSegmentSum1DIndices1DData(self):
+ for dtype in self.numeric_types:
+ self.assertAllClose(
+ np.array([1, 3, 2, 9], dtype=dtype),
+ self.UnsortedSegmentSum(
+ np.array([0, 1, 2, 3, 4, 5], dtype=dtype),
+ np.array([3, 0, 2, 1, 3, 3], dtype=np.int32), 4))
+
+ def testUnsortedSegmentSum1DIndices2DDataDisjoint(self):
+ for dtype in self.numeric_types:
+ data = np.array(
+ [[0, 1, 2, 3], [20, 21, 22, 23], [30, 31, 32, 33], [40, 41, 42, 43],
+ [50, 51, 52, 53]],
+ dtype=dtype)
+ indices = np.array([8, 1, 0, 3, 7], dtype=np.int32)
+ num_segments = 10
+ y = self.UnsortedSegmentSum(data, indices, num_segments)
+ self.assertAllClose(
+ np.array(
+ [[30, 31, 32, 33], [20, 21, 22, 23], [0, 0, 0, 0],
+ [40, 41, 42, 43], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0],
+ [50, 51, 52, 53], [0, 1, 2, 3], [0, 0, 0, 0]],
+ dtype=dtype), y)
+
+ def testUnsortedSegmentSum1DIndices2DDataNonDisjoint(self):
+ for dtype in self.numeric_types:
+ data = np.array(
+ [[0, 1, 2, 3], [20, 21, 22, 23], [30, 31, 32, 33], [40, 41, 42, 43],
+ [50, 51, 52, 53]],
+ dtype=dtype)
+ indices = np.array([0, 1, 2, 0, 1], dtype=np.int32)
+ num_segments = 4
+ y = self.UnsortedSegmentSum(data, indices, num_segments)
+ self.assertAllClose(
+ np.array(
+ [[40, 42, 44, 46], [70, 72, 74, 76], [30, 31, 32, 33],
+ [0, 0, 0, 0]],
+ dtype=dtype), y)
+
+ def testUnsortedSegmentSum2DIndices3DData(self):
+ for dtype in self.numeric_types:
+ data = np.array(
+ [[[0, 1, 2], [10, 11, 12]], [[100, 101, 102], [110, 111, 112]],
+ [[200, 201, 202], [210, 211, 212]], [[300, 301, 302],
+ [310, 311, 312]]],
+ dtype=dtype)
+ indices = np.array([[3, 5], [3, 1], [5, 0], [6, 2]], dtype=np.int32)
+ num_segments = 8
+ y = self.UnsortedSegmentSum(data, indices, num_segments)
+ self.assertAllClose(
+ np.array(
+ [[210, 211, 212], [110, 111, 112], [310, 311, 312],
+ [100, 102, 104], [0, 0, 0.], [210, 212, 214], [300, 301,
+ 302], [0, 0, 0]],
+ dtype=dtype), y)
+
+ def testUnsortedSegmentSum1DIndices3DData(self):
+ for dtype in self.numeric_types:
+ data = np.array(
+ [[[0, 1, 2], [10, 11, 12]], [[100, 101, 102], [110, 111, 112]],
+ [[200, 201, 202], [210, 211, 212]], [[300, 301, 302],
+ [310, 311, 312]]],
+ dtype=dtype)
+ indices = np.array([3, 0, 2, 5], dtype=np.int32)
+ num_segments = 6
+ y = self.UnsortedSegmentSum(data, indices, num_segments)
+ self.assertAllClose(
+ np.array(
+ [[[100, 101, 102.], [110, 111, 112]], [[0, 0, 0], [0, 0, 0]],
+ [[200, 201, 202], [210, 211, 212]], [[0, 1, 2.], [10, 11, 12]],
+ [[0, 0, 0], [0, 0, 0]], [[300, 301, 302], [310, 311, 312]]],
+ dtype=dtype), y)
+
+ def testUnsortedSegmentSumShapeError(self):
+ for dtype in self.numeric_types:
+ data = np.ones((4, 8, 7), dtype=dtype)
+ indices = np.ones((3, 2), dtype=np.int32)
+ num_segments = 4
+ self.assertRaises(ValueError,
+ functools.partial(self.UnsortedSegmentSum, data,
+ indices, num_segments))
+
+
+if __name__ == '__main__':
+ googletest.main()
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index 35bc6b5a24..546e9be864 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -47,6 +47,7 @@ tf_kernel_library(
"reshape_op.cc",
"retval_op.cc",
"reverse_op.cc",
+ "segment_reduction_ops.cc",
"select_op.cc",
"sequence_ops.cc",
"shape_op.cc",
diff --git a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc
new file mode 100644
index 0000000000..6a0ce775dc
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc
@@ -0,0 +1,155 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <sstream>
+#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/core/framework/kernel_def_builder.h"
+#include "tensorflow/core/framework/types.h"
+
+namespace tensorflow {
+namespace {
+
+class UnsortedSegmentSum : public XlaOpKernel {
+ public:
+ explicit UnsortedSegmentSum(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ // output = unsorted_segment_sum(data, indices, num_segments)
+ // Compute a tensor such that:
+ // output[i] = sum over {j where indices[j] == i} of data[j]
+ // output[i] == 0 if i does not appear in indices
+ //
+ // Contrast with segment_sum(), which assumes indices are sorted and that
+ // max(indices)+1 is the desired size of the output.
+ //
+ // The returned output tensor has the same type as data, and the same shape
+ // as data with the first indices.rank dimensions are replaced
+ // by a single dimension with size num_segments.
+
+ xla::ComputationBuilder* builder = ctx->builder();
+
+ auto data = ctx->Input(0);
+ auto data_shape = ctx->InputShape(0);
+
+ auto indices = ctx->Input(1);
+ auto indices_shape = ctx->InputShape(1);
+
+ OP_REQUIRES(ctx, data_shape.dims() >= indices_shape.dims(),
+ errors::InvalidArgument(
+ "UnsortedSegmentSum requires that indices' rank be"
+ " less than or equal to data's rank."));
+ // Validate that indices.shape is a prefix of data.shape.
+ for (int d = 0; d < indices_shape.dims(); ++d) {
+ OP_REQUIRES(ctx, (data_shape.dim_size(d) == indices_shape.dim_size(d)),
+ errors::InvalidArgument(
+ "UnsortedSegmentSum requires indices shape to be prefix"
+ " of data_shape, but dimension ",
+ d, " differs ", data_shape.dim_size(d), " vs. ",
+ indices_shape.dim_size(d)));
+ }
+
+ int64 num_segments;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(2, &num_segments));
+
+ // Flatten the indices into 1-D.
+ auto indices_1d = builder->Reshape(indices, {indices_shape.num_elements()});
+
+ // flatten data for dynamic indexing.
+ int64 out_tensor_dims = data_shape.dims() - indices_shape.dims();
+ std::vector<int64> flat_shape(1 + out_tensor_dims);
+ flat_shape[0] = indices_shape.num_elements();
+ for (int64 k = 0; k < out_tensor_dims; ++k) {
+ flat_shape[1 + k] = data_shape.dim_size(indices_shape.dims() + k);
+ }
+ auto data_flat = builder->Reshape(data, flat_shape);
+
+ // output shape; same as data_shape, but dimension 0 is num_segments.
+ std::vector<int64> out_shape(flat_shape);
+ out_shape[0] = num_segments;
+
+ // Pad the output array dims to rank >= 3 to work around lowering issues.
+ // TODO(b/37575001) This is awkward, and could be improved.
+ int64 extra_dims = 0;
+ if (out_shape.size() < 3) {
+ extra_dims = 3u - out_shape.size();
+ }
+ std::vector<int64> rshape(extra_dims + out_shape.size(), 1);
+ for (unsigned k = 0; k < out_shape.size(); ++k) {
+ rshape[extra_dims + k] = out_shape[k];
+ }
+ auto output = builder->Broadcast(XlaHelpers::Zero(builder, dtype_), rshape);
+
+ auto zero = builder->ConstantR1<int32>({0});
+
+ for (int64 i = 0; i < indices_shape.num_elements(); ++i) {
+ // output[indices[i]] += data[i]
+
+ std::vector<int64> data_start_indices(flat_shape.size());
+ data_start_indices[0] = i;
+ for (unsigned d = 1; d < flat_shape.size(); ++d) {
+ data_start_indices[d] = 0;
+ }
+ std::vector<int64> data_limit_indices(flat_shape);
+ data_limit_indices[0] = i + 1;
+ std::vector<int64> stride(flat_shape.size(), 1);
+
+ auto data_slice = builder->Slice(data_flat, data_start_indices,
+ data_limit_indices, stride);
+
+ // Reshape the sliced data into the R3+ shape to match output array.
+ std::vector<int64> rdata_shape(extra_dims + flat_shape.size());
+ for (int64 k = 0; k <= extra_dims; ++k) {
+ rdata_shape[k] = 1;
+ }
+ for (unsigned k = 1; k < data_limit_indices.size(); ++k) {
+ rdata_shape[extra_dims + k] = data_limit_indices[k];
+ }
+ auto rdata_slice = builder->Reshape(data_slice, rdata_shape);
+
+ auto index = builder->Slice(indices_1d, {i}, {i + 1}, {1});
+
+ // Construct the index into the R3+ output array 0, ..., <index>, 0, ...
+ std::vector<xla::ComputationDataHandle> out_start_index_parts(
+ extra_dims + flat_shape.size(), zero);
+ out_start_index_parts[extra_dims] = builder->Reshape(index, {1});
+ auto out_start_indices = builder->ConcatInDim(out_start_index_parts, 0);
+
+ std::vector<int64> slice_size(rshape);
+ slice_size[extra_dims] = 1;
+
+ auto out_slice =
+ builder->DynamicSlice(output, out_start_indices, slice_size);
+ auto sumval = builder->Add(out_slice, rdata_slice);
+ output = builder->DynamicUpdateSlice(output, sumval, out_start_indices);
+ }
+ auto reshaped_output = builder->Reshape(output, out_shape);
+ ctx->SetOutput(0, reshaped_output);
+ }
+
+ private:
+ DataType dtype_;
+};
+
+REGISTER_XLA_OP(Name("UnsortedSegmentSum"), UnsortedSegmentSum);
+
+} // namespace
+} // namespace tensorflow