diff options
author | 2017-07-25 12:53:50 -0700 | |
---|---|---|
committer | 2017-07-25 12:58:31 -0700 | |
commit | b4c97bf13b618fbdc22981ce04f9faf358da034c (patch) | |
tree | ca129ae15fdcdee0eb9cd96232ef5b4d14490e89 | |
parent | 50d48d606c3b2c08eef249b6fe4f543a51ca8455 (diff) |
Implementation of UnsortedSegmentSum in tf2xla bridge.
PiperOrigin-RevId: 163109769
-rw-r--r-- | tensorflow/compiler/tests/BUILD | 14 | ||||
-rw-r--r-- | tensorflow/compiler/tests/segment_reduction_ops_test.py | 139 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc | 155 |
4 files changed, 309 insertions, 0 deletions
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 4f0137e8d9..c693f58f8b 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -354,6 +354,20 @@ tf_xla_py_test( ) tf_xla_py_test( + name = "segment_reduction_ops_test", + size = "small", + srcs = ["segment_reduction_ops_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:math_ops_gen", + "//tensorflow/python:platform_test", + ], +) + +tf_xla_py_test( name = "spacetobatch_op_test", size = "medium", srcs = ["spacetobatch_op_test.py"], diff --git a/tensorflow/compiler/tests/segment_reduction_ops_test.py b/tensorflow/compiler/tests/segment_reduction_ops_test.py new file mode 100644 index 0000000000..260a04421b --- /dev/null +++ b/tensorflow/compiler/tests/segment_reduction_ops_test.py @@ -0,0 +1,139 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test cases for segment reduction ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import googletest + + +class SegmentReductionOpsTest(XLATestCase): + """Test cases for segment reduction ops.""" + + def UnsortedSegmentSum(self, data, indices, num_segments): + with self.test_session() as sess, self.test_scope(): + d = array_ops.placeholder(data.dtype, shape=data.shape) + if isinstance(indices, int): + i = array_ops.placeholder(np.int32, shape=[]) + else: + i = array_ops.placeholder(indices.dtype, shape=indices.shape) + return sess.run( + math_ops.unsorted_segment_sum(d, i, num_segments), + {d: data, + i: indices}) + + def testUnsortedSegmentSum0DIndices1DData(self): + for dtype in self.numeric_types: + self.assertAllClose( + np.array( + [[0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 1, 2, 3, 4, 5], + [0, 0, 0, 0, 0, 0]], + dtype=dtype), + self.UnsortedSegmentSum( + np.array([0, 1, 2, 3, 4, 5], dtype=dtype), 2, 4)) + + def testUnsortedSegmentSum1DIndices1DData(self): + for dtype in self.numeric_types: + self.assertAllClose( + np.array([1, 3, 2, 9], dtype=dtype), + self.UnsortedSegmentSum( + np.array([0, 1, 2, 3, 4, 5], dtype=dtype), + np.array([3, 0, 2, 1, 3, 3], dtype=np.int32), 4)) + + def testUnsortedSegmentSum1DIndices2DDataDisjoint(self): + for dtype in self.numeric_types: + data = np.array( + [[0, 1, 2, 3], [20, 21, 22, 23], [30, 31, 32, 33], [40, 41, 42, 43], + [50, 51, 52, 53]], + dtype=dtype) + indices = np.array([8, 1, 0, 3, 7], dtype=np.int32) + num_segments = 10 + y = self.UnsortedSegmentSum(data, indices, num_segments) + self.assertAllClose( + np.array( + [[30, 31, 32, 33], [20, 21, 22, 23], [0, 0, 0, 0], + [40, 41, 42, 43], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], + [50, 51, 52, 53], [0, 1, 2, 3], [0, 0, 0, 0]], + dtype=dtype), y) + + def testUnsortedSegmentSum1DIndices2DDataNonDisjoint(self): + for dtype in self.numeric_types: + data = np.array( + [[0, 1, 2, 3], [20, 21, 22, 23], [30, 31, 32, 33], [40, 41, 42, 43], + [50, 51, 52, 53]], + dtype=dtype) + indices = np.array([0, 1, 2, 0, 1], dtype=np.int32) + num_segments = 4 + y = self.UnsortedSegmentSum(data, indices, num_segments) + self.assertAllClose( + np.array( + [[40, 42, 44, 46], [70, 72, 74, 76], [30, 31, 32, 33], + [0, 0, 0, 0]], + dtype=dtype), y) + + def testUnsortedSegmentSum2DIndices3DData(self): + for dtype in self.numeric_types: + data = np.array( + [[[0, 1, 2], [10, 11, 12]], [[100, 101, 102], [110, 111, 112]], + [[200, 201, 202], [210, 211, 212]], [[300, 301, 302], + [310, 311, 312]]], + dtype=dtype) + indices = np.array([[3, 5], [3, 1], [5, 0], [6, 2]], dtype=np.int32) + num_segments = 8 + y = self.UnsortedSegmentSum(data, indices, num_segments) + self.assertAllClose( + np.array( + [[210, 211, 212], [110, 111, 112], [310, 311, 312], + [100, 102, 104], [0, 0, 0.], [210, 212, 214], [300, 301, + 302], [0, 0, 0]], + dtype=dtype), y) + + def testUnsortedSegmentSum1DIndices3DData(self): + for dtype in self.numeric_types: + data = np.array( + [[[0, 1, 2], [10, 11, 12]], [[100, 101, 102], [110, 111, 112]], + [[200, 201, 202], [210, 211, 212]], [[300, 301, 302], + [310, 311, 312]]], + dtype=dtype) + indices = np.array([3, 0, 2, 5], dtype=np.int32) + num_segments = 6 + y = self.UnsortedSegmentSum(data, indices, num_segments) + self.assertAllClose( + np.array( + [[[100, 101, 102.], [110, 111, 112]], [[0, 0, 0], [0, 0, 0]], + [[200, 201, 202], [210, 211, 212]], [[0, 1, 2.], [10, 11, 12]], + [[0, 0, 0], [0, 0, 0]], [[300, 301, 302], [310, 311, 312]]], + dtype=dtype), y) + + def testUnsortedSegmentSumShapeError(self): + for dtype in self.numeric_types: + data = np.ones((4, 8, 7), dtype=dtype) + indices = np.ones((3, 2), dtype=np.int32) + num_segments = 4 + self.assertRaises(ValueError, + functools.partial(self.UnsortedSegmentSum, data, + indices, num_segments)) + + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 35bc6b5a24..546e9be864 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -47,6 +47,7 @@ tf_kernel_library( "reshape_op.cc", "retval_op.cc", "reverse_op.cc", + "segment_reduction_ops.cc", "select_op.cc", "sequence_ops.cc", "shape_op.cc", diff --git a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc new file mode 100644 index 0000000000..6a0ce775dc --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc @@ -0,0 +1,155 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <sstream> +#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/types.h" + +namespace tensorflow { +namespace { + +class UnsortedSegmentSum : public XlaOpKernel { + public: + explicit UnsortedSegmentSum(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + // output = unsorted_segment_sum(data, indices, num_segments) + // Compute a tensor such that: + // output[i] = sum over {j where indices[j] == i} of data[j] + // output[i] == 0 if i does not appear in indices + // + // Contrast with segment_sum(), which assumes indices are sorted and that + // max(indices)+1 is the desired size of the output. + // + // The returned output tensor has the same type as data, and the same shape + // as data with the first indices.rank dimensions are replaced + // by a single dimension with size num_segments. + + xla::ComputationBuilder* builder = ctx->builder(); + + auto data = ctx->Input(0); + auto data_shape = ctx->InputShape(0); + + auto indices = ctx->Input(1); + auto indices_shape = ctx->InputShape(1); + + OP_REQUIRES(ctx, data_shape.dims() >= indices_shape.dims(), + errors::InvalidArgument( + "UnsortedSegmentSum requires that indices' rank be" + " less than or equal to data's rank.")); + // Validate that indices.shape is a prefix of data.shape. + for (int d = 0; d < indices_shape.dims(); ++d) { + OP_REQUIRES(ctx, (data_shape.dim_size(d) == indices_shape.dim_size(d)), + errors::InvalidArgument( + "UnsortedSegmentSum requires indices shape to be prefix" + " of data_shape, but dimension ", + d, " differs ", data_shape.dim_size(d), " vs. ", + indices_shape.dim_size(d))); + } + + int64 num_segments; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(2, &num_segments)); + + // Flatten the indices into 1-D. + auto indices_1d = builder->Reshape(indices, {indices_shape.num_elements()}); + + // flatten data for dynamic indexing. + int64 out_tensor_dims = data_shape.dims() - indices_shape.dims(); + std::vector<int64> flat_shape(1 + out_tensor_dims); + flat_shape[0] = indices_shape.num_elements(); + for (int64 k = 0; k < out_tensor_dims; ++k) { + flat_shape[1 + k] = data_shape.dim_size(indices_shape.dims() + k); + } + auto data_flat = builder->Reshape(data, flat_shape); + + // output shape; same as data_shape, but dimension 0 is num_segments. + std::vector<int64> out_shape(flat_shape); + out_shape[0] = num_segments; + + // Pad the output array dims to rank >= 3 to work around lowering issues. + // TODO(b/37575001) This is awkward, and could be improved. + int64 extra_dims = 0; + if (out_shape.size() < 3) { + extra_dims = 3u - out_shape.size(); + } + std::vector<int64> rshape(extra_dims + out_shape.size(), 1); + for (unsigned k = 0; k < out_shape.size(); ++k) { + rshape[extra_dims + k] = out_shape[k]; + } + auto output = builder->Broadcast(XlaHelpers::Zero(builder, dtype_), rshape); + + auto zero = builder->ConstantR1<int32>({0}); + + for (int64 i = 0; i < indices_shape.num_elements(); ++i) { + // output[indices[i]] += data[i] + + std::vector<int64> data_start_indices(flat_shape.size()); + data_start_indices[0] = i; + for (unsigned d = 1; d < flat_shape.size(); ++d) { + data_start_indices[d] = 0; + } + std::vector<int64> data_limit_indices(flat_shape); + data_limit_indices[0] = i + 1; + std::vector<int64> stride(flat_shape.size(), 1); + + auto data_slice = builder->Slice(data_flat, data_start_indices, + data_limit_indices, stride); + + // Reshape the sliced data into the R3+ shape to match output array. + std::vector<int64> rdata_shape(extra_dims + flat_shape.size()); + for (int64 k = 0; k <= extra_dims; ++k) { + rdata_shape[k] = 1; + } + for (unsigned k = 1; k < data_limit_indices.size(); ++k) { + rdata_shape[extra_dims + k] = data_limit_indices[k]; + } + auto rdata_slice = builder->Reshape(data_slice, rdata_shape); + + auto index = builder->Slice(indices_1d, {i}, {i + 1}, {1}); + + // Construct the index into the R3+ output array 0, ..., <index>, 0, ... + std::vector<xla::ComputationDataHandle> out_start_index_parts( + extra_dims + flat_shape.size(), zero); + out_start_index_parts[extra_dims] = builder->Reshape(index, {1}); + auto out_start_indices = builder->ConcatInDim(out_start_index_parts, 0); + + std::vector<int64> slice_size(rshape); + slice_size[extra_dims] = 1; + + auto out_slice = + builder->DynamicSlice(output, out_start_indices, slice_size); + auto sumval = builder->Add(out_slice, rdata_slice); + output = builder->DynamicUpdateSlice(output, sumval, out_start_indices); + } + auto reshaped_output = builder->Reshape(output, out_shape); + ctx->SetOutput(0, reshaped_output); + } + + private: + DataType dtype_; +}; + +REGISTER_XLA_OP(Name("UnsortedSegmentSum"), UnsortedSegmentSum); + +} // namespace +} // namespace tensorflow |