author     Peter Hawkins <phawkins@google.com>              2017-04-06 12:02:29 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>  2017-04-06 13:32:20 -0700
commit     96bc32eab3b21192bfd065a5944c12269fe3f5a4
tree       9de1d371e57e700c69bb20de84e82a6c6a2bb800
parent     b2e0dda6c6f92e4b14e567a508e1a5b5d475decc
[TF:XLA] Implement BatchToSpace, BatchToSpaceND, SpaceToBatch, SpaceToBatchND.
Fix crashes in core implementations of the same operators for zero-sized blocks.
Change: 152416903
 tensorflow/compiler/tests/BUILD                       |  13
 tensorflow/compiler/tests/randomized_tests.cc         | 162
 tensorflow/compiler/tests/spacetobatch_op_test.py     | 266
 tensorflow/compiler/tf2xla/const_analysis.cc          |   6
 tensorflow/compiler/tf2xla/kernels/BUILD              |   2
 tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc | 186
 tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc | 190
 tensorflow/compiler/tf2xla/xla_op_kernel.cc           |  25
 tensorflow/compiler/tf2xla/xla_op_kernel.h            |   3
 tensorflow/core/kernels/batchtospace_op.cc            |   4
 tensorflow/core/kernels/spacetobatch_op.cc            |   4
 11 files changed, 858 insertions(+), 3 deletions(-)
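For orientation, a minimal sketch of what these ops compute, using the public TF 1.x Python API (illustrative only, not part of the change): SpaceToBatchND zero-pads the spatial dimensions, then moves each intra-block offset into the batch dimension; BatchToSpaceND inverts it.

import numpy as np
import tensorflow as tf

x = np.arange(1, 5, dtype=np.float32).reshape([1, 2, 2, 1])  # NHWC input
y = tf.space_to_batch_nd(x, block_shape=[2, 2], paddings=[[0, 0], [0, 0]])
z = tf.batch_to_space_nd(y, block_shape=[2, 2], crops=[[0, 0], [0, 0]])
with tf.Session() as sess:
  y_val, z_val = sess.run([y, z])
print(y_val.shape)  # (4, 1, 1, 1): each 2x2 block element moved into batch
np.testing.assert_array_equal(z_val, x)  # BatchToSpaceND inverts SpaceToBatchND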
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 03e255e6b8..740a35c7ae 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -306,6 +306,19 @@ tf_xla_py_test(
 )

 tf_xla_py_test(
+    name = "spacetobatch_op_test",
+    size = "medium",
+    srcs = ["spacetobatch_op_test.py"],
+    deps = [
+        ":xla_test",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:framework_for_generated_wrappers",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:platform_test",
+    ],
+)
+
+tf_xla_py_test(
     name = "ternary_ops_test",
     size = "small",
     srcs = ["ternary_ops_test.py"],

diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc
index 18c4e3dcb1..7d91594db0 100644
--- a/tensorflow/compiler/tests/randomized_tests.cc
+++ b/tensorflow/compiler/tests/randomized_tests.cc
@@ -218,12 +218,11 @@ class OpTest : public ::testing::Test {
   static constexpr int kDefaultMaxRank = 5;
   static constexpr int64 kDefaultMaxDimensionSize = 20LL;

-  // Returns a random dimension size.
+  // Returns a random dimension size, in the range [min, max).
   int64 RandomDim(int64 min = 0, int64 max = kDefaultMaxDimensionSize);

   // Returns a random shape. The tensor has rank in the range [min_rank,
-  // max_rank).
-  // Each dimension has size [0, kDefaultMaxDimensionSize].
+  // max_rank). Each dimension has size [min_size, max_size).
   std::vector<int64> RandomDims(int min_rank = 0,
                                 int max_rank = kDefaultMaxRank,
                                 int64 min_size = 0,
@@ -668,6 +667,9 @@ void OpTest::ExpectTfAndXlaOutputsAreClose(const OpTestBuilder& builder,
     VLOG(1) << "Expected graph failed with status: " << s << ". Skipping test";
     return;
   }
+  for (const Tensor& expected : expected_outputs) {
+    VLOG(1) << "Expected: " << expected.DebugString();
+  }

   VLOG(1) << "Running test graph";
   TF_ASSERT_OK(session_->Run(test_feeds, test_fetches, {}, &test_outputs));
@@ -877,6 +879,79 @@ TEST_F(OpTest, BatchMatMul) {
   });
 }

+TEST_F(OpTest, BatchToSpace) {
+  Repeatedly([this]() {
+    const int num_block_dims = 2;
+    std::vector<int64> block_dims =
+        RandomDims(num_block_dims, num_block_dims, 0, 5);
+    int64 block_size = RandomDim(0, 4);
+
+    std::vector<int64> input_dims(1 + num_block_dims + 1);
+    input_dims[0] = RandomDim();
+    for (int i = 0; i < num_block_dims; ++i) {
+      input_dims[0] *= block_size;
+      input_dims[1 + i] = block_dims[i];
+    }
+    input_dims[1 + num_block_dims] = RandomDim();
+
+    std::vector<int64> crop_vals;
+    std::uniform_int_distribution<int> distribution(0, 4);
+    for (int i = 0; i < num_block_dims; ++i) {
+      // Chooses crop values; does not always choose legal values.
+      crop_vals.push_back(distribution(generator()));
+      crop_vals.push_back(distribution(generator()));
+    }
+    Tensor crops;
+    CHECK(crops.CopyFrom(AsIntTensor(DT_INT32, crop_vals),
+                         TensorShape({num_block_dims, 2})));
+
+    ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BatchToSpace")
+                                      .Input(RandomTensor(DT_FLOAT, input_dims))
+                                      .Input(crops)
+                                      .Attr("T", DT_FLOAT)
+                                      .Attr("block_size", block_size));
+  });
+}
+
+TEST_F(OpTest, BatchToSpaceND) {
+  Repeatedly([this]() {
+    std::vector<int64> block_dims = RandomDims(1, 3, 0, 5);
+    int num_block_dims = block_dims.size();
+    std::vector<int64> remaining_dims = RandomDims(0, 3);
+    std::vector<int64> block_multipliers =
+        RandomDims(block_dims.size(), block_dims.size(), 0, 4);
+
+    std::vector<int64> input_dims(1 + num_block_dims + remaining_dims.size());
+    input_dims[0] = RandomDim();
+    for (int i = 0; i < num_block_dims; ++i) {
+      input_dims[0] *= block_dims[i];
+    }
+    std::copy(block_multipliers.begin(), block_multipliers.end(),
+              input_dims.begin() + 1);
+    std::copy(remaining_dims.begin(), remaining_dims.end(),
+              input_dims.begin() + 1 + num_block_dims);
+
+    std::vector<int64> crop_vals;
+    std::uniform_int_distribution<int> distribution(0, 3);
+    for (int i = 0; i < num_block_dims; ++i) {
+      // Chooses crop values; does not always choose legal values.
+      crop_vals.push_back(distribution(generator()));
+      crop_vals.push_back(distribution(generator()));
+    }
+    Tensor crops;
+    CHECK(crops.CopyFrom(AsIntTensor(DT_INT32, crop_vals),
+                         TensorShape({num_block_dims, 2})));
+
+    ExpectTfAndXlaOutputsAreClose(
+        OpTestBuilder("BatchToSpaceND")
+            .Input(RandomTensor(DT_FLOAT, input_dims))
+            .Input(test::AsTensor<int32>(
+                std::vector<int32>(block_dims.begin(), block_dims.end())))
+            .Input(crops)
+            .Attr("T", DT_FLOAT));
+  });
+}
+
 TEST_F(OpTest, BiasAdd) {
   Repeatedly([this]() {
     auto x = RandomTensor(DT_FLOAT, RandomDims(2, kDefaultMaxRank));
@@ -2036,6 +2111,87 @@ TEST_F(OpTest, SoftplusGrad) {
   });
 }

+TEST_F(OpTest, SpaceToBatch) {
+  Repeatedly([this]() {
+    std::vector<int64> block_dims = RandomDims(4, 4, 0, 5);
+    const int num_block_dims = 2;
+    int64 block_size = RandomDim(0, 4);
+
+    std::vector<int64> input_dims(1 + num_block_dims + 1);
+    input_dims[0] = RandomDim();
+    for (int i = 0; i < num_block_dims; ++i) {
+      input_dims[1 + i] = block_dims[i] * block_size;
+    }
+    input_dims[1 + num_block_dims] = RandomDim();
+
+    std::vector<int64> padding_vals;
+    std::uniform_int_distribution<int> distribution(0, 7);
+    for (int i = 0; i < num_block_dims; ++i) {
+      int64 pad_before;
+      int64 pad_after;
+      do {
+        pad_before = distribution(generator());
+        pad_after = distribution(generator());
+      } while (pad_before + pad_after > input_dims[1 + i]);
+      input_dims[1 + i] -= pad_before + pad_after;
+      padding_vals.push_back(pad_before);
+      padding_vals.push_back(pad_after);
+    }
+    Tensor paddings;
+    CHECK(paddings.CopyFrom(AsIntTensor(DT_INT32, padding_vals),
+                            TensorShape({num_block_dims, 2})));
+
+    ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SpaceToBatch")
+                                      .Input(RandomTensor(DT_FLOAT, input_dims))
+                                      .Input(paddings)
+                                      .Attr("T", DT_FLOAT)
+                                      .Attr("block_size", block_size));
+  });
+}
+
+TEST_F(OpTest, SpaceToBatchND) {
+  Repeatedly([this]() {
+    std::vector<int64> block_dims = RandomDims(1, 3, 0, 5);
+    int num_block_dims = block_dims.size();
+    std::vector<int64> remaining_dims = RandomDims(0, 3);
+    std::vector<int64> block_multipliers =
+        RandomDims(block_dims.size(), block_dims.size(), 0, 4);
+
+    std::vector<int64> input_dims(1 + num_block_dims + remaining_dims.size());
+    input_dims[0] = RandomDim();
+    for (int i = 0; i < num_block_dims; ++i) {
+      input_dims[1 + i] = block_dims[i] * block_multipliers[i];
+    }
+    std::copy(remaining_dims.begin(), remaining_dims.end(),
+              input_dims.begin() + 1 + num_block_dims);
+
+    std::vector<int64> padding_vals;
+    std::uniform_int_distribution<int> distribution(0, 7);
+    for (int i = 0; i < num_block_dims; ++i) {
+      int64 pad_before;
+      int64 pad_after;
+      do {
+        pad_before = distribution(generator());
+        pad_after = distribution(generator());
+      } while (pad_before + pad_after > input_dims[1 + i]);
+      input_dims[1 + i] -= pad_before + pad_after;
+      padding_vals.push_back(pad_before);
+      padding_vals.push_back(pad_after);
+    }
+    Tensor paddings;
+    CHECK(paddings.CopyFrom(AsIntTensor(DT_INT32, padding_vals),
+                            TensorShape({num_block_dims, 2})));
+
+    ExpectTfAndXlaOutputsAreClose(
+        OpTestBuilder("SpaceToBatchND")
+            .Input(RandomTensor(DT_FLOAT, input_dims))
+            .Input(test::AsTensor<int32>(
+                std::vector<int32>(block_dims.begin(), block_dims.end())))
+            .Input(paddings)
+            .Attr("T", DT_FLOAT));
+  });
+}
+
 TEST_F(OpTest, SparseMatMul) {
   Repeatedly([this]() {
     int64 x = RandomDim();
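The randomized tests above construct inputs that satisfy the ops' shape constraints (and deliberately sometimes do not, to exercise the error paths). The underlying shape arithmetic, as a hypothetical Python helper for illustration only:

import numpy as np

def space_to_batch_output_shape(input_shape, block_shape, paddings):
  # [batch, s_1, ..., s_M, rest...] ->
  # [batch * prod(block), padded_i / block_i, ..., rest...]
  m = len(block_shape)
  batch, spatial, rest = input_shape[0], input_shape[1:1 + m], input_shape[1 + m:]
  padded = [s + lo + hi for s, (lo, hi) in zip(spatial, paddings)]
  assert all(p % b == 0 for p, b in zip(padded, block_shape))
  return ([batch * int(np.prod(block_shape))] +
          [p // b for p, b in zip(padded, block_shape)] + list(rest))

print(space_to_batch_output_shape([3, 2, 2, 1], [2, 2], [[0, 0], [0, 0]]))  # [12, 1, 1, 1]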
diff --git a/tensorflow/compiler/tests/spacetobatch_op_test.py b/tensorflow/compiler/tests/spacetobatch_op_test.py
new file mode 100644
index 0000000000..9c3b86c84b
--- /dev/null
+++ b/tensorflow/compiler/tests/spacetobatch_op_test.py
@@ -0,0 +1,266 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functional tests for SpaceToBatch and BatchToSpace ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_array_ops
+from tensorflow.python.platform import test
+
+
+def space_to_batch_direct(input_array, block_shape, paddings):
+  """Direct Python implementation of space-to-batch conversion.
+
+  This is used for tests only.
+
+  Args:
+    input_array: N-D array
+    block_shape: 1-D array of shape [num_block_dims].
+    paddings: 2-D array of shape [num_block_dims, 2].
+
+  Returns:
+    Converted tensor.
+  """
+  input_array = np.array(input_array)
+  block_shape = np.array(block_shape)
+  num_block_dims = len(block_shape)
+  paddings = np.array(paddings).reshape((len(block_shape), 2))
+
+  padded = np.pad(input_array,
+                  pad_width=([[0, 0]] + list(paddings) + [[0, 0]] *
+                             (input_array.ndim - 1 - num_block_dims)),
+                  mode="constant")
+  reshaped_padded_shape = [input_array.shape[0]]
+  output_shape = [input_array.shape[0] * np.prod(block_shape)]
+  for block_dim, block_shape_value in enumerate(block_shape):
+    reduced_size = padded.shape[block_dim + 1] // block_shape_value
+    reshaped_padded_shape.append(reduced_size)
+    output_shape.append(reduced_size)
+    reshaped_padded_shape.append(block_shape_value)
+  reshaped_padded_shape.extend(input_array.shape[num_block_dims + 1:])
+  output_shape.extend(input_array.shape[num_block_dims + 1:])
+
+  reshaped_padded = padded.reshape(reshaped_padded_shape)
+  permuted_reshaped_padded = np.transpose(reshaped_padded, (
+      list(np.arange(num_block_dims) * 2 + 2) + [0] +
+      list(np.arange(num_block_dims) * 2 + 1) + list(
+          np.arange(input_array.ndim - num_block_dims - 1) + 1 + num_block_dims
+          * 2)))
+  return permuted_reshaped_padded.reshape(output_shape)
+
+
+class SpaceToBatchTest(XLATestCase):
+  """Tests input-output pairs for the SpaceToBatch and BatchToSpace ops."""
+
+  def _testPad(self, inputs, paddings, block_size, outputs):
+    with self.test_session() as sess, self.test_scope():
+      for dtype in self.float_types:
+        # outputs = space_to_batch(inputs)
+        placeholder = array_ops.placeholder(dtype)
+        x_tf = gen_array_ops._space_to_batch(
+            placeholder, paddings, block_size=block_size)
+        self.assertAllEqual(sess.run(x_tf, {placeholder: inputs}), outputs)
+        # inputs = batch_to_space(outputs)
+        x_tf = gen_array_ops._batch_to_space(
+            placeholder, paddings, block_size=block_size)
+        self.assertAllEqual(sess.run(x_tf, {placeholder: outputs}), inputs)
+
+  def _testOne(self, inputs, block_size, outputs):
+    paddings = np.zeros((2, 2), dtype=np.int32)
+    self._testPad(inputs, paddings, block_size, outputs)
+
+  # [1, 2, 2, 1] <-> [4, 1, 1, 1]
+  def testSmallInput2x2(self):
+    x_np = [[[[1], [2]], [[3], [4]]]]
+    block_size = 2
+    x_out = [[[[1]]], [[[2]]], [[[3]]], [[[4]]]]
+    self._testOne(x_np, block_size, x_out)
+
+  # [1, 2, 2, 1] <-> [1, 3, 3, 1] (padding) <-> [9, 1, 1, 1]
+  def testSmallInput2x2Pad1x0(self):
+    x_np = [[[[1], [2]], [[3], [4]]]]
+    paddings = np.array([[1, 0], [1, 0]], dtype=np.int32)
+    block_size = 3
+    x_out = [[[[0]]], [[[0]]], [[[0]]], [[[0]]], [[[1]]], [[[2]]], [[[0]]],
+             [[[3]]], [[[4]]]]
+    self._testPad(x_np, paddings, block_size, x_out)
+
+  # Test with depth larger than 1.
+  # [1, 2, 2, 3] <-> [4, 1, 1, 3]
+  def testDepthInput2x2(self):
+    x_np = [[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]]
+    block_size = 2
+    x_out = [[[[1, 2, 3]]], [[[4, 5, 6]]], [[[7, 8, 9]]], [[[10, 11, 12]]]]
+    self._testOne(x_np, block_size, x_out)
+
+  # Test for larger input dimensions.
+  # [1, 4, 4, 1] <-> [4, 2, 2, 1]
+  def testLargerInput2x2(self):
+    x_np = [[[[1], [2], [3], [4]], [[5], [6], [7], [8]],
+             [[9], [10], [11], [12]], [[13], [14], [15], [16]]]]
+    block_size = 2
+    x_out = [[[[1], [3]], [[9], [11]]], [[[2], [4]], [[10], [12]]],
+             [[[5], [7]], [[13], [15]]], [[[6], [8]], [[14], [16]]]]
+    self._testOne(x_np, block_size, x_out)
+
+  # Test with batch larger than 1.
+  # [2, 2, 4, 1] <-> [8, 1, 2, 1]
+  def testBatchInput2x2(self):
+    x_np = [[[[1], [2], [3], [4]], [[5], [6], [7], [8]]],
+            [[[9], [10], [11], [12]], [[13], [14], [15], [16]]]]
+    block_size = 2
+    x_out = [[[[1], [3]]], [[[9], [11]]], [[[2], [4]]], [[[10], [12]]],
+             [[[5], [7]]], [[[13], [15]]], [[[6], [8]]], [[[14], [16]]]]
+    self._testOne(x_np, block_size, x_out)
+
+  # Tests for larger input spatial dimensions AND batch larger than 1, to
+  # ensure that elements are correctly laid out spatially and properly
+  # interleaved along the batch dimension.
+  # [2, 4, 4, 1] <-> [8, 2, 2, 1]
+  def testLargerInputBatch2x2(self):
+    x_np = [[[[1], [2], [3], [4]], [[5], [6], [7], [8]],
+             [[9], [10], [11], [12]], [[13], [14], [15], [16]]],
+            [[[17], [18], [19], [20]], [[21], [22], [23], [24]],
+             [[25], [26], [27], [28]], [[29], [30], [31], [32]]]]
+    x_out = [[[[1], [3]], [[9], [11]]], [[[17], [19]], [[25], [27]]],
+             [[[2], [4]], [[10], [12]]], [[[18], [20]], [[26], [28]]],
+             [[[5], [7]], [[13], [15]]], [[[21], [23]], [[29], [31]]],
+             [[[6], [8]], [[14], [16]]], [[[22], [24]], [[30], [32]]]]
+    block_size = 2
+    self._testOne(x_np, block_size, x_out)
+
+
+class SpaceToBatchNDTest(XLATestCase):
+  """Tests input-output pairs for the SpaceToBatchND and BatchToSpaceND ops."""
+
+  def _testPad(self, inputs, block_shape, paddings, outputs):
+    block_shape = np.array(block_shape)
+    paddings = np.array(paddings).reshape((len(block_shape), 2))
+    with self.test_session() as sess, self.test_scope():
+      for dtype in self.float_types:
+        placeholder = array_ops.placeholder(dtype)
+        # outputs = space_to_batch(inputs)
+        x_tf = array_ops.space_to_batch_nd(placeholder, block_shape, paddings)
+        self.assertAllEqual(sess.run(x_tf, {placeholder: inputs}), outputs)
+        # inputs = batch_to_space(outputs)
+        placeholder = array_ops.placeholder(dtype)
+        x_tf = array_ops.batch_to_space_nd(placeholder, block_shape, paddings)
+        self.assertAllEqual(sess.run(x_tf, {placeholder: outputs}), inputs)
+
+  def _testDirect(self, input_shape, block_shape, paddings):
+    inputs = np.arange(np.prod(input_shape), dtype=np.float32)
+    inputs = inputs.reshape(input_shape)
+    self._testPad(inputs, block_shape, paddings,
+                  space_to_batch_direct(inputs, block_shape, paddings))
+
+  def testZeroBlockDimsZeroRemainingDims(self):
+    self._testPad(
+        inputs=[1, 2],
+        block_shape=[],
+        paddings=[],
+        outputs=[1, 2],)
+
+  def testZeroBlockDimsOneRemainingDim(self):
+    self._testPad(
+        inputs=[[1, 2], [3, 4]],
+        block_shape=[],
+        paddings=[],
+        outputs=[[1, 2], [3, 4]])
+
+    # Same thing, but with a no-op block dim.
+    self._testPad(
+        inputs=[[1, 2], [3, 4]],
+        block_shape=[1],
+        paddings=[[0, 0]],
+        outputs=[[1, 2], [3, 4]])
+
+  def testZeroBlockDimsTwoRemainingDims(self):
+    self._testPad(
+        inputs=[[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
+        block_shape=[],
+        paddings=[],
+        outputs=[[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
+
+    # Same thing, but with a no-op block dim.
+    self._testPad(
+        inputs=[[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
+        block_shape=[1],
+        paddings=[[0, 0]],
+        outputs=[[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
+
+    # Same thing, but with two no-op block dims.
+    self._testPad(
+        inputs=[[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
+        block_shape=[1, 1],
+        paddings=[[0, 0], [0, 0]],
+        outputs=[[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
+
+  def testOneBlockDimZeroRemainingDims(self):
+    self._testPad(
+        inputs=[[1, 2, 3], [4, 5, 6]],
+        block_shape=[2],
+        paddings=[1, 0],
+        outputs=[[0, 2], [0, 5], [1, 3], [4, 6]])
+
+  def testOneBlockDimOneRemainingDim(self):
+    self._testPad(
+        inputs=[[[1, 11], [2, 21], [3, 31]], [[4, 41], [5, 51], [6, 61]]],
+        block_shape=[2],
+        paddings=[1, 0],
+        outputs=[[[0, 0], [2, 21]], [[0, 0], [5, 51]], [[1, 11], [3, 31]],
+                 [[4, 41], [6, 61]]])
+
+  def testDirect(self):
+    # Test with zero-size remaining dimension.
+    self._testDirect(
+        input_shape=[3, 1, 2, 0], block_shape=[3], paddings=[[0, 2]])
+
+    # Test with zero-size blocked dimension.
+    self._testDirect(
+        input_shape=[3, 0, 2, 5], block_shape=[3], paddings=[[0, 0]])
+
+    # Test with padding up from zero size.
+    self._testDirect(
+        input_shape=[3, 0, 2, 5], block_shape=[3], paddings=[[1, 2]])
+
+    self._testDirect(
+        input_shape=[3, 3, 4, 5, 2],
+        block_shape=[3, 4, 2],
+        paddings=[[1, 2], [0, 0], [3, 0]])
+
+    self._testDirect(
+        input_shape=[3, 3, 4, 5, 2],
+        block_shape=[3, 4, 2, 2],
+        paddings=[[1, 2], [0, 0], [3, 0], [0, 0]])
+
+    self._testDirect(
+        input_shape=[3, 2, 2, 3, 4, 5, 2, 5],
+        block_shape=[1, 1, 3, 4, 2, 2],
+        paddings=[[0, 0], [0, 0], [1, 2], [0, 0], [3, 0], [0, 0]])
+
+    self._testDirect(
+        input_shape=[3, 2, 2, 3, 4, 5, 2, 5],
+        block_shape=[1, 1, 3, 4, 2, 2, 1],
+        paddings=[[0, 0], [0, 0], [1, 2], [0, 0], [3, 0], [0, 0], [0, 0]])
+
+
+if __name__ == "__main__":
+  test.main()

diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc
index 53aa749a0a..44ff13ca34 100644
--- a/tensorflow/compiler/tf2xla/const_analysis.cc
+++ b/tensorflow/compiler/tf2xla/const_analysis.cc
@@ -35,6 +35,9 @@ Status BackwardsConstAnalysis(const Graph& g,
       {"Any", "reduction_indices"},
       {"ArgMax", "dimension"},
       {"AvgPoolGrad", "orig_input_shape"},
+      {"BatchToSpace", "crops"},
+      {"BatchToSpaceND", "block_shape"},
+      {"BatchToSpaceND", "crops"},
       {"BroadcastGradientArgs", "s0"},
       {"BroadcastGradientArgs", "s1"},
       {"Concat", "concat_dim"},
@@ -69,6 +72,9 @@ Status BackwardsConstAnalysis(const Graph& g,
       {"ReverseV2", "axis"},
       {"Slice", "begin"},
       {"Slice", "size"},
+      {"SpaceToBatch", "paddings"},
+      {"SpaceToBatchND", "block_shape"},
+      {"SpaceToBatchND", "paddings"},
       {"Split", "split_dim"},
       {"SplitV", "split_dim"},
       {"SplitV", "size_splits"},
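These entries tell the backwards const analysis that the XLA kernels must see block_shape, paddings, and crops as compile-time constants: their values, not just their shapes, determine the output shape, and XLA requires static shapes. A small illustration (hypothetical numbers, Python for clarity):

# Same input shape, different padding values => different output shapes,
# which is why XLA must resolve paddings to constants at compile time.
input_shape, block = [1, 2, 2, 1], [2, 2]
for paddings in ([[0, 0], [0, 0]], [[1, 1], [1, 1]]):
  padded = [input_shape[1 + i] + sum(paddings[i]) for i in range(2)]
  print([input_shape[0] * block[0] * block[1],
         padded[0] // block[0], padded[1] // block[1], input_shape[3]])
# -> [4, 1, 1, 1] and [4, 2, 2, 1]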
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index e4f73b529f..14d2a72f7c 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -15,6 +15,7 @@ tf_kernel_library(
     srcs = [
         "aggregate_ops.cc",
         "batch_matmul_op.cc",
+        "batchtospace_op.cc",
         "bcast_ops.cc",
         "bias_ops.cc",
         "binary_ops.cc",
@@ -50,6 +51,7 @@ tf_kernel_library(
         "shape_op.cc",
         "slice_op.cc",
         "softmax_op.cc",
+        "spacetobatch_op.cc",
        "split_op.cc",
         "strided_slice_op.cc",
         "tile_ops.cc",

diff --git a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc
new file mode 100644
index 0000000000..eb4bd47ee5
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc
@@ -0,0 +1,186 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+
+namespace tensorflow {
+namespace {
+
+void BatchToSpace(XlaOpKernelContext* ctx,
+                  const xla::ComputationDataHandle& input,
+                  DataType input_dtype, const TensorShape& input_tensor_shape,
+                  gtl::ArraySlice<int64> block_shape,
+                  const xla::Literal& crops) {
+  const int input_rank = input_tensor_shape.dims();
+  const gtl::InlinedVector<int64, 4> input_shape =
+      input_tensor_shape.dim_sizes();
+  const int block_rank = block_shape.size();
+
+  OP_REQUIRES(
+      ctx, input_rank >= 1 + block_rank,
+      errors::InvalidArgument("input rank should be >= ", 1 + block_rank,
+                              " instead of ", input_rank));
+  gtl::ArraySlice<int64> remainder_shape(input_shape);
+  remainder_shape.remove_prefix(1 + block_rank);
+
+  OP_REQUIRES(
+      ctx,
+      xla::ShapeUtil::Rank(crops.shape()) == 2 &&
+          block_rank == xla::ShapeUtil::GetDimension(crops.shape(), 0) &&
+          2 == xla::ShapeUtil::GetDimension(crops.shape(), 1),
+      errors::InvalidArgument("crops should have shape [", block_rank,
+                              ", 2] instead of ",
+                              xla::ShapeUtil::HumanString(crops.shape())));
+
+  xla::ComputationBuilder* b = ctx->builder();
+  const int64 batch_size = input_shape[0];
+
+  // Compute the product of the block_shape values.
+  int64 block_num_elems = 1;
+  for (int i = 0; i < block_rank; ++i) {
+    block_num_elems *= block_shape[i];
+  }
+  OP_REQUIRES(ctx, block_num_elems > 0,
+              errors::InvalidArgument(
+                  "The product of the block dimensions must be positive"));
+
+  // 1. Reshape `input` to `reshaped` of shape:
+  //      [block_shape[0], ..., block_shape[M-1],
+  //       batch / prod(block_shape),
+  //       input_shape[1], ..., input_shape[N-1]]
+  OP_REQUIRES(
+      ctx, batch_size % block_num_elems == 0,
+      errors::InvalidArgument("Input batch dimension (", batch_size,
+                              ") is not divisible by product of block sizes (",
+                              block_num_elems, ")"));
+  std::vector<int64> reshaped_shape(input_rank + block_rank);
+  std::copy(block_shape.begin(), block_shape.end(), reshaped_shape.begin());
+  reshaped_shape[block_rank] = batch_size / block_num_elems;
+  std::copy(input_shape.begin() + 1, input_shape.end(),
+            reshaped_shape.begin() + block_rank + 1);
+  xla::ComputationDataHandle reshaped = b->Reshape(input, reshaped_shape);
+
+  // 2. Permute dimensions of `reshaped` to produce `permuted` of shape
+  //      [batch / prod(block_shape),
+  //
+  //       input_shape[1], block_shape[0],
+  //       ...,
+  //       input_shape[M], block_shape[M-1],
+  //
+  //       input_shape[M+1], ..., input_shape[N-1]]
+  std::vector<int64> permutation(reshaped_shape.size());
+  permutation[0] = block_rank;
+  for (int i = 0; i < block_rank; ++i) {
+    permutation[1 + 2 * i] = block_rank + 1 + i;
+    permutation[1 + 2 * i + 1] = i;
+  }
+  std::iota(permutation.begin() + 1 + block_rank * 2, permutation.end(),
+            1 + block_rank * 2);
+  xla::ComputationDataHandle permuted = b->Transpose(reshaped, permutation);
+
+  // 3. Reshape `permuted` to produce `reshaped_permuted` of shape
+  //      [batch / prod(block_shape),
+  //
+  //       input_shape[1] * block_shape[0],
+  //       ...,
+  //       input_shape[M] * block_shape[M-1],
+  //
+  //       input_shape[M+1],
+  //       ...,
+  //       input_shape[N-1]]
+  std::vector<int64> reshaped_permuted_shape(input_rank);
+  reshaped_permuted_shape[0] = batch_size / block_num_elems;
+  for (int i = 0; i < block_rank; ++i) {
+    reshaped_permuted_shape[1 + i] = block_shape[i] * input_shape[1 + i];
+  }
+  std::copy(remainder_shape.begin(), remainder_shape.end(),
+            reshaped_permuted_shape.begin() + 1 + block_rank);
+
+  xla::ComputationDataHandle reshaped_permuted =
+      b->Reshape(permuted, reshaped_permuted_shape);
+
+  // 4. Crop the start and end of dimensions `[1, ..., M]` of
+  //    `reshaped_permuted` according to `crops` to produce the output of
+  //    shape:
+  //      [batch / prod(block_shape),
+  //
+  //       input_shape[1] * block_shape[0] - crops[0,0] - crops[0,1],
+  //       ...,
+  //       input_shape[M] * block_shape[M-1] - crops[M-1,0] - crops[M-1,1],
+  //
+  //       input_shape[M+1], ..., input_shape[N-1]]
+  std::vector<int64> start_indices(input_rank, 0);
+  std::vector<int64> end_indices = reshaped_permuted_shape;
+  for (int i = 0; i < block_rank; ++i) {
+    int64 crop_start = xla::LiteralUtil::Get<int64>(crops, {i, 0});
+    int64 crop_end = xla::LiteralUtil::Get<int64>(crops, {i, 1});
+    OP_REQUIRES(ctx, crop_start >= 0 && crop_end >= 0,
+                errors::InvalidArgument("Crops must be non-negative"));
+    start_indices[1 + i] = crop_start;
+    end_indices[1 + i] -= crop_end;
+    OP_REQUIRES(
+        ctx, start_indices[1 + i] <= end_indices[1 + i],
+        errors::InvalidArgument(
+            "Cropped size must be non-negative: start: ", crop_start,
+            " end: ", crop_end, " size ", reshaped_permuted_shape[1 + i]));
+  }
+  xla::ComputationDataHandle output =
+      b->Slice(reshaped_permuted, start_indices, end_indices);
+  ctx->SetOutput(0, output);
+}
+
+class BatchToSpaceNDOp : public XlaOpKernel {
+ public:
+  explicit BatchToSpaceNDOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+
+  void Compile(XlaOpKernelContext* ctx) override {
+    std::vector<int64> block_shape;
+    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &block_shape));
+
+    xla::Literal crops;
+    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsInt64Literal(2, &crops));
+
+    BatchToSpace(ctx, ctx->Input(0), input_type(0), ctx->InputShape(0),
+                 block_shape, crops);
+  }
+};
+REGISTER_XLA_OP(Name("BatchToSpaceND"), BatchToSpaceNDOp);
+
+class BatchToSpaceOp : public XlaOpKernel {
+ public:
+  explicit BatchToSpaceOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("block_size", &block_size_));
+    OP_REQUIRES(
+        ctx, block_size_ > 1,
+        errors::InvalidArgument("Block size should be > 1: ", block_size_));
+  }
+
+  void Compile(XlaOpKernelContext* ctx) override {
+    xla::Literal crops;
+    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsInt64Literal(1, &crops));
+
+    BatchToSpace(ctx, ctx->Input(0), input_type(0), ctx->InputShape(0),
+                 {block_size_, block_size_}, crops);
+  }
+
+ private:
+  int block_size_;
+};
+REGISTER_XLA_OP(Name("BatchToSpace"), BatchToSpaceOp);
+
+}  // namespace
+}  // namespace tensorflow
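The four reshape/transpose/slice steps above can be mirrored in NumPy; a minimal sketch for a rank-4 input and a 2-D block (illustration only, not the kernel itself):

import numpy as np

def batch_to_space_sketch(x, block_shape, crops):
  # Mirrors steps 1-4 in the kernel comments for a rank-4 NHWC input.
  b0, b1 = block_shape
  batch = x.shape[0] // (b0 * b1)
  # 1. Reshape to [block_0, block_1, batch / prod(block), s_1, s_2, depth].
  reshaped = x.reshape([b0, b1, batch] + list(x.shape[1:]))
  # 2. Permute to [batch / prod(block), s_1, block_0, s_2, block_1, depth].
  permuted = reshaped.transpose([2, 3, 0, 4, 1, 5])
  # 3. Reshape to [batch / prod(block), s_1 * block_0, s_2 * block_1, depth].
  rp = permuted.reshape([batch, x.shape[1] * b0, x.shape[2] * b1, x.shape[3]])
  # 4. Crop the start and end of the spatial dimensions.
  (c0_lo, c0_hi), (c1_lo, c1_hi) = crops
  return rp[:, c0_lo:rp.shape[1] - c0_hi, c1_lo:rp.shape[2] - c1_hi, :]

y = np.arange(4, dtype=np.float32).reshape([4, 1, 1, 1])
print(batch_to_space_sketch(y, [2, 2], [[0, 0], [0, 0]])[0, :, :, 0])
# [[0. 1.]
#  [2. 3.]]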
diff --git a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc
new file mode 100644
index 0000000000..f15b354cb2
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc
@@ -0,0 +1,190 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+
+namespace tensorflow {
+namespace {
+
+void SpaceToBatch(XlaOpKernelContext* ctx,
+                  const xla::ComputationDataHandle& input,
+                  DataType input_dtype, const TensorShape& input_tensor_shape,
+                  gtl::ArraySlice<int64> block_shape,
+                  const xla::Literal& paddings) {
+  const int input_rank = input_tensor_shape.dims();
+  const gtl::InlinedVector<int64, 4> input_shape =
+      input_tensor_shape.dim_sizes();
+  const int block_rank = block_shape.size();
+
+  OP_REQUIRES(
+      ctx, input_rank >= 1 + block_rank,
+      errors::InvalidArgument("input rank should be >= ", 1 + block_rank,
+                              " instead of ", input_rank));
+  gtl::ArraySlice<int64> remainder_shape(input_shape);
+  remainder_shape.remove_prefix(1 + block_rank);
+
+  OP_REQUIRES(
+      ctx,
+      xla::ShapeUtil::Rank(paddings.shape()) == 2 &&
+          block_rank == xla::ShapeUtil::GetDimension(paddings.shape(), 0) &&
+          2 == xla::ShapeUtil::GetDimension(paddings.shape(), 1),
+      errors::InvalidArgument("paddings should have shape [", block_rank,
+                              ", 2] instead of ",
+                              xla::ShapeUtil::HumanString(paddings.shape())));
+
+  xla::ComputationBuilder* b = ctx->builder();
+
+  // 1. Zero-pad the start and end of dimensions `[1, ..., M]` of the input
+  //    according to `paddings` to produce `padded` of shape `padded_shape`.
+  xla::PaddingConfig padding_config;
+  std::vector<int64> padded_shape(input_shape.begin(), input_shape.end());
+  int64 block_num_elems = 1LL;
+  padding_config.add_dimensions();  // Don't pad the batch dimension.
+  for (int i = 0; i < block_rank; ++i) {
+    auto* dim = padding_config.add_dimensions();
+    int64 pad_start = xla::LiteralUtil::Get<int64>(paddings, {i, 0});
+    int64 pad_end = xla::LiteralUtil::Get<int64>(paddings, {i, 1});
+    OP_REQUIRES(ctx, pad_start >= 0 && pad_end >= 0,
+                errors::InvalidArgument("Paddings must be non-negative"));
+    dim->set_edge_padding_low(pad_start);
+    dim->set_edge_padding_high(pad_end);
+    padded_shape[1 + i] += pad_start + pad_end;
+    block_num_elems *= block_shape[i];
+  }
+  // Don't pad the remainder dimensions.
+  for (int i = 0; i < remainder_shape.size(); ++i) {
+    padding_config.add_dimensions();
+  }
+  OP_REQUIRES(ctx, block_num_elems > 0,
+              errors::InvalidArgument(
+                  "The product of the block dimensions must be positive"));
+
+  xla::ComputationDataHandle padded =
+      b->Pad(input, XlaHelpers::Zero(b, input_dtype), padding_config);
+
+  // 2. Reshape `padded` to `reshaped_padded` of shape:
+  //
+  //      [batch] +
+  //      [padded_shape[1] / block_shape[0],
+  //       block_shape[0],
+  //       ...,
+  //       padded_shape[M] / block_shape[M-1],
+  //       block_shape[M-1]] +
+  //      remaining_shape
+  const int64 batch_size = input_shape[0];
+  std::vector<int64> reshaped_padded_shape(input_rank + block_rank);
+  reshaped_padded_shape[0] = batch_size;
+  for (int i = 0; i < block_rank; ++i) {
+    OP_REQUIRES(ctx, padded_shape[1 + i] % block_shape[i] == 0,
+                errors::InvalidArgument("padded_shape[", 1 + i,
+                                        "]=", padded_shape[1 + i],
+                                        " is not divisible by block_shape[", i,
+                                        "]=", block_shape[i]));
+    reshaped_padded_shape[1 + i * 2] = padded_shape[1 + i] / block_shape[i];
+    reshaped_padded_shape[1 + i * 2 + 1] = block_shape[i];
+  }
+  std::copy(remainder_shape.begin(), remainder_shape.end(),
+            reshaped_padded_shape.begin() + 1 + 2 * block_rank);
+
+  xla::ComputationDataHandle reshaped_padded =
+      b->Reshape(padded, reshaped_padded_shape);
+
+  // 3. Permute dimensions of `reshaped_padded` to produce
+  //    `permuted_reshaped_padded` of shape:
+  //
+  //      block_shape +
+  //      [batch] +
+  //      [padded_shape[1] / block_shape[0],
+  //       ...,
+  //       padded_shape[M] / block_shape[M-1]] +
+  //      remaining_shape
+  std::vector<int64> permutation(reshaped_padded_shape.size());
+  for (int i = 0; i < block_rank; ++i) {
+    permutation[i] = 1 + 2 * i + 1;
+    permutation[block_rank + 1 + i] = 1 + 2 * i;
+  }
+  permutation[block_rank] = 0;
+  std::iota(permutation.begin() + 1 + block_rank * 2, permutation.end(),
+            1 + block_rank * 2);
+  xla::ComputationDataHandle permuted_reshaped_padded =
+      b->Transpose(reshaped_padded, permutation);
+
+  // 4. Reshape `permuted_reshaped_padded` to flatten `block_shape` into the
+  //    batch dimension, producing an output tensor of shape:
+  //
+  //      [batch * prod(block_shape)] +
+  //      [padded_shape[1] / block_shape[0],
+  //       ...,
+  //       padded_shape[M] / block_shape[M-1]] +
+  //      remaining_shape
+  std::vector<int64> output_shape(input_rank);
+  output_shape[0] = batch_size * block_num_elems;
+  for (int i = 0; i < block_rank; ++i) {
+    output_shape[1 + i] = padded_shape[1 + i] / block_shape[i];
+  }
+  std::copy(remainder_shape.begin(), remainder_shape.end(),
+            output_shape.begin() + 1 + block_rank);
+
+  xla::ComputationDataHandle output =
+      b->Reshape(permuted_reshaped_padded, output_shape);
+  ctx->SetOutput(0, output);
+}
+
+class SpaceToBatchNDOp : public XlaOpKernel {
+ public:
+  explicit SpaceToBatchNDOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+
+  void Compile(XlaOpKernelContext* ctx) override {
+    std::vector<int64> block_shape;
+    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &block_shape));
+
+    xla::Literal paddings;
+    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsInt64Literal(2, &paddings));
+
+    SpaceToBatch(ctx, ctx->Input(0), input_type(0), ctx->InputShape(0),
+                 block_shape, paddings);
+  }
+};
+REGISTER_XLA_OP(Name("SpaceToBatchND"), SpaceToBatchNDOp);
+
+class SpaceToBatchOp : public XlaOpKernel {
+ public:
+  explicit SpaceToBatchOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("block_size", &block_size_));
+    OP_REQUIRES(
+        ctx, block_size_ > 1,
+        errors::InvalidArgument("Block size should be > 1: ", block_size_));
+  }
+
+  void Compile(XlaOpKernelContext* ctx) override {
+    xla::Literal paddings;
+    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsInt64Literal(1, &paddings));
+
+    SpaceToBatch(ctx, ctx->Input(0), input_type(0), ctx->InputShape(0),
+                 {block_size_, block_size_}, paddings);
+  }
+
+ private:
+  int block_size_;
+};
+REGISTER_XLA_OP(Name("SpaceToBatch"), SpaceToBatchOp);
+
+}  // namespace
+}  // namespace tensorflow
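This lowering is the same pad/reshape/transpose/reshape sequence as space_to_batch_direct in the new Python test. A compact NumPy sketch for the rank-4, 2-D-block case (illustration only):

import numpy as np

def space_to_batch_sketch(x, block_shape, paddings):
  # Pad, split each spatial dim into (count, block), move the block dims to
  # the front, then fold them into batch.
  b0, b1 = block_shape
  padded = np.pad(x, [[0, 0]] + list(paddings) + [[0, 0]], mode="constant")
  n, h, w, c = padded.shape
  reshaped = padded.reshape([n, h // b0, b0, w // b1, b1, c])
  permuted = reshaped.transpose([2, 4, 0, 1, 3, 5])  # same permutation as above
  return permuted.reshape([n * b0 * b1, h // b0, w // b1, c])

x = np.arange(16, dtype=np.float32).reshape([1, 4, 4, 1])
y = space_to_batch_sketch(x, [2, 2], [[0, 0], [0, 0]])
print(y.shape)  # (4, 2, 2, 1)
np.testing.assert_array_equal(y[0], x[0, ::2, ::2])  # block offset (0, 0)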
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
index 53dcdec7a2..a022de36a2 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
@@ -186,6 +186,31 @@ Status XlaOpKernelContext::ConstantInputAsIntVector(int index,
   return LiteralToInt64Vector(literal, out);
 }

+Status XlaOpKernelContext::ConstantInputAsInt64Literal(int index,
+                                                       xla::Literal* out) {
+  xla::Literal literal;
+  TF_RETURN_IF_ERROR(ConstantInput(index, &literal));
+  switch (literal.shape().element_type()) {
+    case xla::S32:
+      out->Clear();
+      *out->mutable_shape() = literal.shape();
+      out->mutable_shape()->set_element_type(xla::S64);
+      for (int32 x : literal.s32s()) {
+        out->add_s64s(x);
+      }
+      return Status::OK();
+
+    case xla::S64:
+      out->Swap(&literal);
+      return Status::OK();
+
+    default:
+      return errors::InvalidArgument(
+          "Invalid argument to ConstantInputAsInt64Literal: ",
+          xla::ShapeUtil::HumanString(literal.shape()));
+  }
+}
+
 // TODO(phawkins): validate that the dimensions form a valid shape, fail
 // gracefully if they do not.
 Status XlaOpKernelContext::ConstantInputAsShape(int index, TensorShape* shape) {

diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h
index 60e3b59d32..f97e07bea5 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.h
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h
@@ -110,6 +110,9 @@ class XlaOpKernelContext {
   // Converts a constant 1D int32 or int64 tensor into a vector of int64s.
   Status ConstantInputAsIntVector(int index, std::vector<int64>* out);

+  // Converts a constant int32 or int64 Tensor into an xla int64 Literal.
+  Status ConstantInputAsInt64Literal(int index, xla::Literal* out);
+
   // Converts a constant 1D int32 or int64 tensor into a TensorShape.
   Status ConstantInputAsShape(int index, TensorShape* shape);

diff --git a/tensorflow/core/kernels/batchtospace_op.cc b/tensorflow/core/kernels/batchtospace_op.cc
index b24a834083..99b5d3daaa 100644
--- a/tensorflow/core/kernels/batchtospace_op.cc
+++ b/tensorflow/core/kernels/batchtospace_op.cc
@@ -97,6 +97,10 @@ static void BatchToSpaceOpCompute(OpKernelContext* context,
   for (int block_dim = 0; block_dim < block_dims; ++block_dim) {
     block_shape_product *= block_shape[block_dim];
   }
+  OP_REQUIRES(
+      context, block_shape_product > 0,
+      errors::InvalidArgument("Product of block sizes must be positive, got ",
+                              block_shape_product));

   const int64 orig_input_batch_size = orig_input_tensor.dim_size(0);
   OP_REQUIRES(

diff --git a/tensorflow/core/kernels/spacetobatch_op.cc b/tensorflow/core/kernels/spacetobatch_op.cc
index 3815716ccd..c513683918 100644
--- a/tensorflow/core/kernels/spacetobatch_op.cc
+++ b/tensorflow/core/kernels/spacetobatch_op.cc
@@ -100,6 +100,10 @@ void SpaceToBatchOpCompute(OpKernelContext* context,
   for (int block_dim = 0; block_dim < block_dims; ++block_dim) {
     block_shape_product *= block_shape[block_dim];
   }
+  OP_REQUIRES(
+      context, block_shape_product > 0,
+      errors::InvalidArgument("Product of block sizes must be positive, got ",
+                              block_shape_product));

   const int internal_block_dims =
       block_dims - removed_prefix_block_dims - removed_suffix_block_dims;
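The two OP_REQUIRES checks above turn what was previously a crash on zero-sized blocks into a clean error. A sketch of the failure mode (hypothetical TF 1.x session usage; block_shape is fed at run time so the check fires in the kernel, though the exact point where the error surfaces may vary):

import numpy as np
import tensorflow as tf

x = np.zeros([1, 2, 2, 1], dtype=np.float32)
block = tf.placeholder(tf.int32, [2])  # values only known at run time
y = tf.space_to_batch_nd(x, block, paddings=[[0, 0], [0, 0]])
with tf.Session() as sess:
  try:
    sess.run(y, {block: [0, 2]})  # zero-sized block dimension
  except tf.errors.InvalidArgumentError as e:
    print(e)  # now reports: Product of block sizes must be positive, got 0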