author    Peter Hawkins <phawkins@google.com>  2017-04-06 12:02:29 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-04-06 13:32:20 -0700
commit    96bc32eab3b21192bfd065a5944c12269fe3f5a4 (patch)
tree      9de1d371e57e700c69bb20de84e82a6c6a2bb800
parent    b2e0dda6c6f92e4b14e567a508e1a5b5d475decc (diff)
[TF:XLA] Implement BatchToSpace, BatchToSpaceND, SpaceToBatch, SpaceToBatchND.
Fix crashes in core implementations of the same operators for zero-sized blocks.
Change: 152416903
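
For orientation (not part of the commit), here is a minimal NumPy sketch of the SpaceToBatch semantics the new XLA kernels implement, using the [1, 2, 2, 1] input and block size 2 from testSmallInput2x2 in the new test file; the reshape/transpose steps mirror the lowering described in the kernel comments further down.

```python
import numpy as np

# Illustrative only: SpaceToBatch with block_size=2 on a [1, 2, 2, 1] input.
# Each 2x2 spatial block is redistributed into the batch dimension, giving a
# [4, 1, 1, 1] output, matching testSmallInput2x2 in spacetobatch_op_test.py.
x = np.array([[[[1], [2]], [[3], [4]]]], dtype=np.float32)  # [1, 2, 2, 1]
b, h, w, c = x.shape
block = 2
reshaped = x.reshape(b, h // block, block, w // block, block, c)
permuted = reshaped.transpose(2, 4, 0, 1, 3, 5)  # block dims, batch, spatial, depth
out = permuted.reshape(b * block * block, h // block, w // block, c)
print(out[:, 0, 0, 0])  # [1. 2. 3. 4.]
```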
-rw-r--r--  tensorflow/compiler/tests/BUILD | 13
-rw-r--r--  tensorflow/compiler/tests/randomized_tests.cc | 162
-rw-r--r--  tensorflow/compiler/tests/spacetobatch_op_test.py | 266
-rw-r--r--  tensorflow/compiler/tf2xla/const_analysis.cc | 6
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/BUILD | 2
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc | 186
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc | 190
-rw-r--r--  tensorflow/compiler/tf2xla/xla_op_kernel.cc | 25
-rw-r--r--  tensorflow/compiler/tf2xla/xla_op_kernel.h | 3
-rw-r--r--  tensorflow/core/kernels/batchtospace_op.cc | 4
-rw-r--r--  tensorflow/core/kernels/spacetobatch_op.cc | 4
11 files changed, 858 insertions(+), 3 deletions(-)
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 03e255e6b8..740a35c7ae 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -306,6 +306,19 @@ tf_xla_py_test(
)
tf_xla_py_test(
+ name = "spacetobatch_op_test",
+ size = "medium",
+ srcs = ["spacetobatch_op_test.py"],
+ deps = [
+ ":xla_test",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+tf_xla_py_test(
name = "ternary_ops_test",
size = "small",
srcs = ["ternary_ops_test.py"],
diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc
index 18c4e3dcb1..7d91594db0 100644
--- a/tensorflow/compiler/tests/randomized_tests.cc
+++ b/tensorflow/compiler/tests/randomized_tests.cc
@@ -218,12 +218,11 @@ class OpTest : public ::testing::Test {
static constexpr int kDefaultMaxRank = 5;
static constexpr int64 kDefaultMaxDimensionSize = 20LL;
- // Returns a random dimension size.
+ // Returns a random dimension size, in the range [min, max).
int64 RandomDim(int64 min = 0, int64 max = kDefaultMaxDimensionSize);
// Returns a random shape. The tensor has rank in the range [min_rank,
- // max_rank).
- // Each dimension has size [0, kDefaultMaxDimensionSize].
+ // max_rank). Each dimension has size [min_size, max_size).
std::vector<int64> RandomDims(int min_rank = 0,
int max_rank = kDefaultMaxRank,
int64 min_size = 0,
@@ -668,6 +667,9 @@ void OpTest::ExpectTfAndXlaOutputsAreClose(const OpTestBuilder& builder,
VLOG(1) << "Expected graph failed with status: " << s << ". Skipping test";
return;
}
+ for (const Tensor& expected : expected_outputs) {
+ VLOG(1) << "Expected: " << expected.DebugString();
+ }
VLOG(1) << "Running test graph";
TF_ASSERT_OK(session_->Run(test_feeds, test_fetches, {}, &test_outputs));
@@ -877,6 +879,79 @@ TEST_F(OpTest, BatchMatMul) {
});
}
+TEST_F(OpTest, BatchToSpace) {
+ Repeatedly([this]() {
+ const int num_block_dims = 2;
+ std::vector<int64> block_dims =
+ RandomDims(num_block_dims, num_block_dims, 0, 5);
+ int64 block_size = RandomDim(0, 4);
+
+ std::vector<int64> input_dims(1 + num_block_dims + 1);
+ input_dims[0] = RandomDim();
+ for (int i = 0; i < num_block_dims; ++i) {
+ input_dims[0] *= block_size;
+ input_dims[1 + i] = block_dims[i];
+ }
+ input_dims[1 + num_block_dims] = RandomDim();
+
+ std::vector<int64> crop_vals;
+ std::uniform_int_distribution<int> distribution(0, 4);
+ for (int i = 0; i < num_block_dims; ++i) {
+ // Chooses crop values; does not always choose legal values.
+ crop_vals.push_back(distribution(generator()));
+ crop_vals.push_back(distribution(generator()));
+ }
+ Tensor crops;
+ CHECK(crops.CopyFrom(AsIntTensor(DT_INT32, crop_vals),
+ TensorShape({num_block_dims, 2})));
+
+ ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BatchToSpace")
+ .Input(RandomTensor(DT_FLOAT, input_dims))
+ .Input(crops)
+ .Attr("T", DT_FLOAT)
+ .Attr("block_size", block_size));
+ });
+}
+
+TEST_F(OpTest, BatchToSpaceND) {
+ Repeatedly([this]() {
+ std::vector<int64> block_dims = RandomDims(1, 3, 0, 5);
+ int num_block_dims = block_dims.size();
+ std::vector<int64> remaining_dims = RandomDims(0, 3);
+ std::vector<int64> block_multipliers =
+ RandomDims(block_dims.size(), block_dims.size(), 0, 4);
+
+ std::vector<int64> input_dims(1 + num_block_dims + remaining_dims.size());
+ input_dims[0] = RandomDim();
+ for (int i = 0; i < num_block_dims; ++i) {
+ input_dims[0] *= block_dims[i];
+ }
+ std::copy(block_multipliers.begin(), block_multipliers.end(),
+ input_dims.begin() + 1);
+ std::copy(remaining_dims.begin(), remaining_dims.end(),
+ input_dims.begin() + 1 + num_block_dims);
+
+ std::vector<int64> crop_vals;
+ std::uniform_int_distribution<int> distribution(0, 3);
+ for (int i = 0; i < num_block_dims; ++i) {
+ // Chooses crop values; does not always choose legal values.
+ crop_vals.push_back(distribution(generator()));
+ crop_vals.push_back(distribution(generator()));
+ }
+ Tensor crops;
+ CHECK(crops.CopyFrom(AsIntTensor(DT_INT32, crop_vals),
+ TensorShape({num_block_dims, 2})));
+
+ ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("BatchToSpaceND")
+ .Input(RandomTensor(DT_FLOAT, input_dims))
+ .Input(test::AsTensor<int32>(
+ std::vector<int32>(block_dims.begin(), block_dims.end())))
+ .Input(crops)
+ .Attr("T", DT_FLOAT));
+ });
+}
+
TEST_F(OpTest, BiasAdd) {
Repeatedly([this]() {
auto x = RandomTensor(DT_FLOAT, RandomDims(2, kDefaultMaxRank));
@@ -2036,6 +2111,87 @@ TEST_F(OpTest, SoftplusGrad) {
});
}
+TEST_F(OpTest, SpaceToBatch) {
+ Repeatedly([this]() {
+ std::vector<int64> block_dims = RandomDims(4, 4, 0, 5);
+ const int num_block_dims = 2;
+ int64 block_size = RandomDim(0, 4);
+
+ std::vector<int64> input_dims(1 + num_block_dims + 1);
+ input_dims[0] = RandomDim();
+ for (int i = 0; i < num_block_dims; ++i) {
+ input_dims[1 + i] = block_dims[i] * block_size;
+ }
+ input_dims[1 + num_block_dims] = RandomDim();
+
+ std::vector<int64> padding_vals;
+ std::uniform_int_distribution<int> distribution(0, 7);
+ for (int i = 0; i < num_block_dims; ++i) {
+ int64 pad_before;
+ int64 pad_after;
+ do {
+ pad_before = distribution(generator());
+ pad_after = distribution(generator());
+ } while (pad_before + pad_after > input_dims[1 + i]);
+ input_dims[1 + i] -= pad_before + pad_after;
+ padding_vals.push_back(pad_before);
+ padding_vals.push_back(pad_after);
+ }
+ Tensor paddings;
+ CHECK(paddings.CopyFrom(AsIntTensor(DT_INT32, padding_vals),
+ TensorShape({num_block_dims, 2})));
+
+ ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SpaceToBatch")
+ .Input(RandomTensor(DT_FLOAT, input_dims))
+ .Input(paddings)
+ .Attr("T", DT_FLOAT)
+ .Attr("block_size", block_size));
+ });
+}
+
+TEST_F(OpTest, SpaceToBatchND) {
+ Repeatedly([this]() {
+ std::vector<int64> block_dims = RandomDims(1, 3, 0, 5);
+ int num_block_dims = block_dims.size();
+ std::vector<int64> remaining_dims = RandomDims(0, 3);
+ std::vector<int64> block_multipliers =
+ RandomDims(block_dims.size(), block_dims.size(), 0, 4);
+
+ std::vector<int64> input_dims(1 + num_block_dims + remaining_dims.size());
+ input_dims[0] = RandomDim();
+ for (int i = 0; i < num_block_dims; ++i) {
+ input_dims[1 + i] = block_dims[i] * block_multipliers[i];
+ }
+ std::copy(remaining_dims.begin(), remaining_dims.end(),
+ input_dims.begin() + 1 + num_block_dims);
+
+ std::vector<int64> padding_vals;
+ std::uniform_int_distribution<int> distribution(0, 7);
+ for (int i = 0; i < num_block_dims; ++i) {
+ int64 pad_before;
+ int64 pad_after;
+ do {
+ pad_before = distribution(generator());
+ pad_after = distribution(generator());
+ } while (pad_before + pad_after > input_dims[1 + i]);
+ input_dims[1 + i] -= pad_before + pad_after;
+ padding_vals.push_back(pad_before);
+ padding_vals.push_back(pad_after);
+ }
+ Tensor paddings;
+ CHECK(paddings.CopyFrom(AsIntTensor(DT_INT32, padding_vals),
+ TensorShape({num_block_dims, 2})));
+
+ ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("SpaceToBatchND")
+ .Input(RandomTensor(DT_FLOAT, input_dims))
+ .Input(test::AsTensor<int32>(
+ std::vector<int32>(block_dims.begin(), block_dims.end())))
+ .Input(paddings)
+ .Attr("T", DT_FLOAT));
+ });
+}
+
TEST_F(OpTest, SparseMatMul) {
Repeatedly([this]() {
int64 x = RandomDim();
diff --git a/tensorflow/compiler/tests/spacetobatch_op_test.py b/tensorflow/compiler/tests/spacetobatch_op_test.py
new file mode 100644
index 0000000000..9c3b86c84b
--- /dev/null
+++ b/tensorflow/compiler/tests/spacetobatch_op_test.py
@@ -0,0 +1,266 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functional tests for SpaceToBatch and BatchToSpace ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_array_ops
+from tensorflow.python.platform import test
+
+
+def space_to_batch_direct(input_array, block_shape, paddings):
+ """Direct Python implementation of space-to-batch conversion.
+
+ This is used for tests only.
+
+ Args:
+ input_array: N-D array
+ block_shape: 1-D array of shape [num_block_dims].
+ paddings: 2-D array of shape [num_block_dims, 2].
+
+ Returns:
+ Converted tensor.
+ """
+ input_array = np.array(input_array)
+ block_shape = np.array(block_shape)
+ num_block_dims = len(block_shape)
+ paddings = np.array(paddings).reshape((len(block_shape), 2))
+
+ padded = np.pad(input_array,
+ pad_width=([[0, 0]] + list(paddings) + [[0, 0]] *
+ (input_array.ndim - 1 - num_block_dims)),
+ mode="constant")
+ reshaped_padded_shape = [input_array.shape[0]]
+ output_shape = [input_array.shape[0] * np.prod(block_shape)]
+ for block_dim, block_shape_value in enumerate(block_shape):
+ reduced_size = padded.shape[block_dim + 1] // block_shape_value
+ reshaped_padded_shape.append(reduced_size)
+ output_shape.append(reduced_size)
+ reshaped_padded_shape.append(block_shape_value)
+ reshaped_padded_shape.extend(input_array.shape[num_block_dims + 1:])
+ output_shape.extend(input_array.shape[num_block_dims + 1:])
+
+ reshaped_padded = padded.reshape(reshaped_padded_shape)
+ permuted_reshaped_padded = np.transpose(reshaped_padded, (
+ list(np.arange(num_block_dims) * 2 + 2) + [0] +
+ list(np.arange(num_block_dims) * 2 + 1) + list(
+ np.arange(input_array.ndim - num_block_dims - 1) + 1 + num_block_dims
+ * 2)))
+ return permuted_reshaped_padded.reshape(output_shape)
+
+
+class SpaceToBatchTest(XLATestCase):
+ """Tests input-output pairs for the SpaceToBatch and BatchToSpace ops."""
+
+ def _testPad(self, inputs, paddings, block_size, outputs):
+ with self.test_session() as sess, self.test_scope():
+ for dtype in self.float_types:
+ # outputs = space_to_batch(inputs)
+ placeholder = array_ops.placeholder(dtype)
+ x_tf = gen_array_ops._space_to_batch(
+ placeholder, paddings, block_size=block_size)
+ self.assertAllEqual(sess.run(x_tf, {placeholder: inputs}), outputs)
+ # inputs = batch_to_space(outputs)
+ x_tf = gen_array_ops._batch_to_space(
+ placeholder, paddings, block_size=block_size)
+ self.assertAllEqual(sess.run(x_tf, {placeholder: outputs}), inputs)
+
+ def _testOne(self, inputs, block_size, outputs):
+ paddings = np.zeros((2, 2), dtype=np.int32)
+ self._testPad(inputs, paddings, block_size, outputs)
+
+ # [1, 2, 2, 1] <-> [4, 1, 1, 1]
+ def testSmallInput2x2(self):
+ x_np = [[[[1], [2]], [[3], [4]]]]
+ block_size = 2
+ x_out = [[[[1]]], [[[2]]], [[[3]]], [[[4]]]]
+ self._testOne(x_np, block_size, x_out)
+
+ # [1, 2, 2, 1] <-> [1, 3, 3, 1] (padding) <-> [9, 1, 1, 1]
+ def testSmallInput2x2Pad1x0(self):
+ x_np = [[[[1], [2]], [[3], [4]]]]
+ paddings = np.array([[1, 0], [1, 0]], dtype=np.int32)
+ block_size = 3
+ x_out = [[[[0]]], [[[0]]], [[[0]]], [[[0]]], [[[1]]], [[[2]]], [[[0]]],
+ [[[3]]], [[[4]]]]
+ self._testPad(x_np, paddings, block_size, x_out)
+
+ # Test with depth larger than 1.
+ # [1, 2, 2, 3] <-> [4, 1, 1, 3]
+ def testDepthInput2x2(self):
+ x_np = [[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]]
+ block_size = 2
+ x_out = [[[[1, 2, 3]]], [[[4, 5, 6]]], [[[7, 8, 9]]], [[[10, 11, 12]]]]
+ self._testOne(x_np, block_size, x_out)
+
+ # Test for larger input dimensions.
+ # [1, 4, 4, 1] <-> [4, 2, 2, 1]
+ def testLargerInput2x2(self):
+ x_np = [[[[1], [2], [3], [4]], [[5], [6], [7], [8]],
+ [[9], [10], [11], [12]], [[13], [14], [15], [16]]]]
+ block_size = 2
+ x_out = [[[[1], [3]], [[9], [11]]], [[[2], [4]], [[10], [12]]],
+ [[[5], [7]], [[13], [15]]], [[[6], [8]], [[14], [16]]]]
+ self._testOne(x_np, block_size, x_out)
+
+ # Test with batch larger than 1.
+ # [2, 2, 4, 1] <-> [8, 1, 2, 1]
+ def testBatchInput2x2(self):
+ x_np = [[[[1], [2], [3], [4]], [[5], [6], [7], [8]]],
+ [[[9], [10], [11], [12]], [[13], [14], [15], [16]]]]
+ block_size = 2
+ x_out = [[[[1], [3]]], [[[9], [11]]], [[[2], [4]]], [[[10], [12]]],
+ [[[5], [7]]], [[[13], [15]]], [[[6], [8]]], [[[14], [16]]]]
+ self._testOne(x_np, block_size, x_out)
+
+ # Tests for larger input spatial dimensions AND batch larger than 1, to ensure
+ # that elements are correctly laid out spatially and properly interleaved
+ # along the batch dimension.
+ # [2, 4, 4, 1] <-> [8, 2, 2, 1]
+ def testLargerInputBatch2x2(self):
+ x_np = [[[[1], [2], [3], [4]], [[5], [6], [7], [8]],
+ [[9], [10], [11], [12]], [[13], [14], [15], [16]]],
+ [[[17], [18], [19], [20]], [[21], [22], [23], [24]],
+ [[25], [26], [27], [28]], [[29], [30], [31], [32]]]]
+ x_out = [[[[1], [3]], [[9], [11]]], [[[17], [19]], [[25], [27]]],
+ [[[2], [4]], [[10], [12]]], [[[18], [20]], [[26], [28]]],
+ [[[5], [7]], [[13], [15]]], [[[21], [23]], [[29], [31]]],
+ [[[6], [8]], [[14], [16]]], [[[22], [24]], [[30], [32]]]]
+ block_size = 2
+ self._testOne(x_np, block_size, x_out)
+
+
+class SpaceToBatchNDTest(XLATestCase):
+ """Tests input-output pairs for the SpaceToBatchND and BatchToSpaceND ops."""
+
+ def _testPad(self, inputs, block_shape, paddings, outputs):
+ block_shape = np.array(block_shape)
+ paddings = np.array(paddings).reshape((len(block_shape), 2))
+ with self.test_session() as sess, self.test_scope():
+ for dtype in self.float_types:
+ placeholder = array_ops.placeholder(dtype)
+ # outputs = space_to_batch(inputs)
+ x_tf = array_ops.space_to_batch_nd(placeholder, block_shape, paddings)
+ self.assertAllEqual(sess.run(x_tf, {placeholder: inputs}), outputs)
+ # inputs = batch_to_space(outputs)
+ placeholder = array_ops.placeholder(dtype)
+ x_tf = array_ops.batch_to_space_nd(placeholder, block_shape, paddings)
+ self.assertAllEqual(sess.run(x_tf, {placeholder: outputs}), inputs)
+
+ def _testDirect(self, input_shape, block_shape, paddings):
+ inputs = np.arange(np.prod(input_shape), dtype=np.float32)
+ inputs = inputs.reshape(input_shape)
+ self._testPad(inputs, block_shape, paddings,
+ space_to_batch_direct(inputs, block_shape, paddings))
+
+ def testZeroBlockDimsZeroRemainingDims(self):
+ self._testPad(
+ inputs=[1, 2],
+ block_shape=[],
+ paddings=[],
+ outputs=[1, 2],)
+
+ def testZeroBlockDimsOneRemainingDim(self):
+ self._testPad(
+ inputs=[[1, 2], [3, 4]],
+ block_shape=[],
+ paddings=[],
+ outputs=[[1, 2], [3, 4]])
+
+ # Same thing, but with a no-op block dim.
+ self._testPad(
+ inputs=[[1, 2], [3, 4]],
+ block_shape=[1],
+ paddings=[[0, 0]],
+ outputs=[[1, 2], [3, 4]])
+
+ def testZeroBlockDimsTwoRemainingDims(self):
+ self._testPad(
+ inputs=[[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
+ block_shape=[],
+ paddings=[],
+ outputs=[[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
+
+ # Same thing, but with a no-op block dim.
+ self._testPad(
+ inputs=[[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
+ block_shape=[1],
+ paddings=[[0, 0]],
+ outputs=[[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
+
+ # Same thing, but with two no-op block dims.
+ self._testPad(
+ inputs=[[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
+ block_shape=[1, 1],
+ paddings=[[0, 0], [0, 0]],
+ outputs=[[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
+
+ def testOneBlockDimZeroRemainingDims(self):
+ self._testPad(
+ inputs=[[1, 2, 3], [4, 5, 6]],
+ block_shape=[2],
+ paddings=[1, 0],
+ outputs=[[0, 2], [0, 5], [1, 3], [4, 6]])
+
+ def testOneBlockDimOneRemainingDim(self):
+ self._testPad(
+ inputs=[[[1, 11], [2, 21], [3, 31]], [[4, 41], [5, 51], [6, 61]]],
+ block_shape=[2],
+ paddings=[1, 0],
+ outputs=[[[0, 0], [2, 21]], [[0, 0], [5, 51]], [[1, 11], [3, 31]],
+ [[4, 41], [6, 61]]])
+
+ def testDirect(self):
+ # Test with zero-size remaining dimension.
+ self._testDirect(
+ input_shape=[3, 1, 2, 0], block_shape=[3], paddings=[[0, 2]])
+
+ # Test with zero-size blocked dimension.
+ self._testDirect(
+ input_shape=[3, 0, 2, 5], block_shape=[3], paddings=[[0, 0]])
+
+ # Test with padding up from zero size.
+ self._testDirect(
+ input_shape=[3, 0, 2, 5], block_shape=[3], paddings=[[1, 2]])
+
+ self._testDirect(
+ input_shape=[3, 3, 4, 5, 2],
+ block_shape=[3, 4, 2],
+ paddings=[[1, 2], [0, 0], [3, 0]])
+
+ self._testDirect(
+ input_shape=[3, 3, 4, 5, 2],
+ block_shape=[3, 4, 2, 2],
+ paddings=[[1, 2], [0, 0], [3, 0], [0, 0]])
+
+ self._testDirect(
+ input_shape=[3, 2, 2, 3, 4, 5, 2, 5],
+ block_shape=[1, 1, 3, 4, 2, 2],
+ paddings=[[0, 0], [0, 0], [1, 2], [0, 0], [3, 0], [0, 0]])
+
+ self._testDirect(
+ input_shape=[3, 2, 2, 3, 4, 5, 2, 5],
+ block_shape=[1, 1, 3, 4, 2, 2, 1],
+ paddings=[[0, 0], [0, 0], [1, 2], [0, 0], [3, 0], [0, 0], [0, 0]])
+
+
+if __name__ == "__main__":
+ test.main()
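
As a quick standalone check (not part of the commit), space_to_batch_direct above reproduces the hand-written expectation of testSmallInput2x2Pad1x0; the driver below assumes the function is importable from the test module or copied locally.

```python
import numpy as np
# Assumes space_to_batch_direct from spacetobatch_op_test.py is available.
x_np = [[[[1], [2]], [[3], [4]]]]                      # shape [1, 2, 2, 1]
paddings = np.array([[1, 0], [1, 0]], dtype=np.int32)  # padded shape [1, 3, 3, 1]
out = space_to_batch_direct(x_np, block_shape=[3, 3], paddings=paddings)
print(out.shape)        # (9, 1, 1, 1)
print(out[:, 0, 0, 0])  # [0 0 0 0 1 2 0 3 4], as in the test's x_out
```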
diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc
index 53aa749a0a..44ff13ca34 100644
--- a/tensorflow/compiler/tf2xla/const_analysis.cc
+++ b/tensorflow/compiler/tf2xla/const_analysis.cc
@@ -35,6 +35,9 @@ Status BackwardsConstAnalysis(const Graph& g,
{"Any", "reduction_indices"},
{"ArgMax", "dimension"},
{"AvgPoolGrad", "orig_input_shape"},
+ {"BatchToSpace", "crops"},
+ {"BatchToSpaceND", "block_shape"},
+ {"BatchToSpaceND", "crops"},
{"BroadcastGradientArgs", "s0"},
{"BroadcastGradientArgs", "s1"},
{"Concat", "concat_dim"},
@@ -69,6 +72,9 @@ Status BackwardsConstAnalysis(const Graph& g,
{"ReverseV2", "axis"},
{"Slice", "begin"},
{"Slice", "size"},
+ {"SpaceToBatch", "paddings"},
+ {"SpaceToBatchND", "block_shape"},
+ {"SpaceToBatchND", "paddings"},
{"Split", "split_dim"},
{"SplitV", "split_dim"},
{"SplitV", "size_splits"},
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index e4f73b529f..14d2a72f7c 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -15,6 +15,7 @@ tf_kernel_library(
srcs = [
"aggregate_ops.cc",
"batch_matmul_op.cc",
+ "batchtospace_op.cc",
"bcast_ops.cc",
"bias_ops.cc",
"binary_ops.cc",
@@ -50,6 +51,7 @@ tf_kernel_library(
"shape_op.cc",
"slice_op.cc",
"softmax_op.cc",
+ "spacetobatch_op.cc",
"split_op.cc",
"strided_slice_op.cc",
"tile_ops.cc",
diff --git a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc
new file mode 100644
index 0000000000..eb4bd47ee5
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc
@@ -0,0 +1,186 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+
+namespace tensorflow {
+namespace {
+
+void BatchToSpace(XlaOpKernelContext* ctx,
+ const xla::ComputationDataHandle& input, DataType input_dtype,
+ const TensorShape& input_tensor_shape,
+ gtl::ArraySlice<int64> block_shape,
+ const xla::Literal& crops) {
+ const int input_rank = input_tensor_shape.dims();
+ const gtl::InlinedVector<int64, 4> input_shape =
+ input_tensor_shape.dim_sizes();
+ const int block_rank = block_shape.size();
+
+ OP_REQUIRES(
+ ctx, input_rank >= 1 + block_rank,
+ errors::InvalidArgument("input rank should be >= ", 1 + block_rank,
+ " instead of ", input_rank));
+ gtl::ArraySlice<int64> remainder_shape(input_shape);
+ remainder_shape.remove_prefix(1 + block_rank);
+
+ OP_REQUIRES(
+ ctx,
+ xla::ShapeUtil::Rank(crops.shape()) == 2 &&
+ block_rank == xla::ShapeUtil::GetDimension(crops.shape(), 0) &&
+ 2 == xla::ShapeUtil::GetDimension(crops.shape(), 1),
+ errors::InvalidArgument("crops should have shape [", block_rank,
+ ", 2] instead of ",
+ xla::ShapeUtil::HumanString(crops.shape())));
+
+ xla::ComputationBuilder* b = ctx->builder();
+ const int64 batch_size = input_shape[0];
+
+ // Compute the product of the block_shape values.
+ int64 block_num_elems = 1;
+ for (int i = 0; i < block_rank; ++i) {
+ block_num_elems *= block_shape[i];
+ }
+ OP_REQUIRES(ctx, block_num_elems > 0,
+ errors::InvalidArgument(
+ "The product of the block dimensions must be positive"));
+
+ // 1. Reshape `input` to `reshaped` of shape:
+ // [block_shape[0], ..., block_shape[M-1],
+ // batch / prod(block_shape),
+ // input_shape[1], ..., input_shape[N-1]]
+
+ OP_REQUIRES(
+ ctx, batch_size % block_num_elems == 0,
+ errors::InvalidArgument("Input batch dimension (", batch_size,
+ ") is not divisible by product of block sizes (",
+ block_num_elems, ")"));
+ std::vector<int64> reshaped_shape(input_rank + block_rank);
+ std::copy(block_shape.begin(), block_shape.end(), reshaped_shape.begin());
+ reshaped_shape[block_rank] = batch_size / block_num_elems;
+ std::copy(input_shape.begin() + 1, input_shape.end(),
+ reshaped_shape.begin() + block_rank + 1);
+ xla::ComputationDataHandle reshaped = b->Reshape(input, reshaped_shape);
+
+ // 2. Permute dimensions of `reshaped` to produce `permuted` of shape
+ // [batch / prod(block_shape),
+ //
+ // input_shape[1], block_shape[0],
+ // ...,
+ // input_shape[M], block_shape[M-1],
+ //
+ // input_shape[M+1], ..., input_shape[N-1]]
+ std::vector<int64> permutation(reshaped_shape.size());
+ permutation[0] = block_rank;
+ for (int i = 0; i < block_rank; ++i) {
+ permutation[1 + 2 * i] = block_rank + 1 + i;
+ permutation[1 + 2 * i + 1] = i;
+ }
+ std::iota(permutation.begin() + 1 + block_rank * 2, permutation.end(),
+ 1 + block_rank * 2);
+ xla::ComputationDataHandle permuted = b->Transpose(reshaped, permutation);
+
+ // 3. Reshape `permuted` to produce `reshaped_permuted` of shape
+ // [batch / prod(block_shape),
+ //
+ // input_shape[1] * block_shape[0],
+ // ...,
+ // input_shape[M] * block_shape[M-1],
+ //
+ // input_shape[M+1],
+ // ...,
+ // input_shape[N-1]]
+ std::vector<int64> reshaped_permuted_shape(input_rank);
+ reshaped_permuted_shape[0] = batch_size / block_num_elems;
+ for (int i = 0; i < block_rank; ++i) {
+ reshaped_permuted_shape[1 + i] = block_shape[i] * input_shape[1 + i];
+ }
+ std::copy(remainder_shape.begin(), remainder_shape.end(),
+ reshaped_permuted_shape.begin() + 1 + block_rank);
+
+ xla::ComputationDataHandle reshaped_permuted =
+ b->Reshape(permuted, reshaped_permuted_shape);
+
+ // 4. Crop the start and end of dimensions `[1, ..., M]` of
+ // `reshaped_permuted` according to `crops` to produce the output of shape:
+ // [batch / prod(block_shape),
+ //
+ // input_shape[1] * block_shape[0] - crops[0,0] - crops[0,1],
+ // ...,
+ // input_shape[M] * block_shape[M-1] - crops[M-1,0] - crops[M-1,1],
+ //
+ // input_shape[M+1], ..., input_shape[N-1]]
+ std::vector<int64> start_indices(input_rank, 0);
+ std::vector<int64> end_indices = reshaped_permuted_shape;
+ for (int i = 0; i < block_rank; ++i) {
+ int64 crop_start = xla::LiteralUtil::Get<int64>(crops, {i, 0});
+ int64 crop_end = xla::LiteralUtil::Get<int64>(crops, {i, 1});
+ OP_REQUIRES(ctx, crop_start >= 0 && crop_end >= 0,
+ errors::InvalidArgument("Crops must be non-negative"));
+ start_indices[1 + i] = crop_start;
+ end_indices[1 + i] -= crop_end;
+ OP_REQUIRES(
+ ctx, start_indices[1 + i] <= end_indices[1 + i],
+ errors::InvalidArgument(
+ "Cropped size must be non-negative: start: ", crop_start,
+ " end: ", crop_end, " size ", reshaped_permuted_shape[1 + i]));
+ }
+ xla::ComputationDataHandle output =
+ b->Slice(reshaped_permuted, start_indices, end_indices);
+ ctx->SetOutput(0, output);
+}
+
+class BatchToSpaceNDOp : public XlaOpKernel {
+ public:
+ explicit BatchToSpaceNDOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ std::vector<int64> block_shape;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &block_shape));
+
+ xla::Literal crops;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInputAsInt64Literal(2, &crops));
+
+ BatchToSpace(ctx, ctx->Input(0), input_type(0), ctx->InputShape(0),
+ block_shape, crops);
+ }
+};
+REGISTER_XLA_OP(Name("BatchToSpaceND"), BatchToSpaceNDOp);
+
+class BatchToSpaceOp : public XlaOpKernel {
+ public:
+ explicit BatchToSpaceOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("block_size", &block_size_));
+ OP_REQUIRES(
+ ctx, block_size_ > 1,
+ errors::InvalidArgument("Block size should be > 1: ", block_size_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ xla::Literal crops;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInputAsInt64Literal(1, &crops));
+
+ BatchToSpace(ctx, ctx->Input(0), input_type(0), ctx->InputShape(0),
+ {block_size_, block_size_}, crops);
+ }
+
+ private:
+ int block_size_;
+};
+REGISTER_XLA_OP(Name("BatchToSpace"), BatchToSpaceOp);
+
+} // namespace
+} // namespace tensorflow
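
The four steps in batchtospace_op.cc above (reshape, transpose, reshape, crop) can be mirrored in plain NumPy; the sketch below is illustrative only, for an arbitrary number of block dimensions, and builds the same permutation as the kernel.

```python
import numpy as np

def batch_to_space_sketch(x, block_shape, crops):
    """Illustrative NumPy mirror of the XLA BatchToSpace lowering above."""
    block_shape = list(block_shape)
    m = len(block_shape)
    prod_block = int(np.prod(block_shape))
    # 1. [block_0, ..., block_{M-1}, batch / prod(block_shape), input_1, ...].
    reshaped = x.reshape(block_shape + [x.shape[0] // prod_block] +
                         list(x.shape[1:]))
    # 2. [batch / prod(block_shape), input_1, block_0, ..., input_M,
    #     block_{M-1}, remaining dims].
    perm = [m]
    for i in range(m):
        perm += [m + 1 + i, i]
    perm += list(range(2 * m + 1, reshaped.ndim))
    permuted = reshaped.transpose(perm)
    # 3. Merge each (input_i, block_{i-1}) pair into one spatial dimension.
    merged_shape = ([x.shape[0] // prod_block] +
                    [x.shape[1 + i] * block_shape[i] for i in range(m)] +
                    list(x.shape[1 + m:]))
    merged = permuted.reshape(merged_shape)
    # 4. Crop the start and end of each spatial dimension.
    slices = [slice(None)]
    for i, (lo, hi) in enumerate(crops):
        slices.append(slice(lo, merged_shape[1 + i] - hi))
    return merged[tuple(slices)]

# With block_shape [2, 2] and zero crops, a [4, 1, 1, 1] input produced by the
# SpaceToBatch sketch near the top of this page is restored to [1, 2, 2, 1].
```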
diff --git a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc
new file mode 100644
index 0000000000..f15b354cb2
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc
@@ -0,0 +1,190 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+
+namespace tensorflow {
+namespace {
+
+void SpaceToBatch(XlaOpKernelContext* ctx,
+ const xla::ComputationDataHandle& input, DataType input_dtype,
+ const TensorShape& input_tensor_shape,
+ gtl::ArraySlice<int64> block_shape,
+ const xla::Literal& paddings) {
+ const int input_rank = input_tensor_shape.dims();
+ const gtl::InlinedVector<int64, 4> input_shape =
+ input_tensor_shape.dim_sizes();
+ const int block_rank = block_shape.size();
+
+ OP_REQUIRES(
+ ctx, input_rank >= 1 + block_rank,
+ errors::InvalidArgument("input rank should be >= ", 1 + block_rank,
+ " instead of ", input_rank));
+ gtl::ArraySlice<int64> remainder_shape(input_shape);
+ remainder_shape.remove_prefix(1 + block_rank);
+
+ OP_REQUIRES(
+ ctx,
+ xla::ShapeUtil::Rank(paddings.shape()) == 2 &&
+ block_rank == xla::ShapeUtil::GetDimension(paddings.shape(), 0) &&
+ 2 == xla::ShapeUtil::GetDimension(paddings.shape(), 1),
+ errors::InvalidArgument("paddings should have shape [", block_rank,
+ ", 2] instead of ",
+ xla::ShapeUtil::HumanString(paddings.shape())));
+
+ xla::ComputationBuilder* b = ctx->builder();
+
+ // 1. Zero-pad the start and end of dimensions `[1, ..., M]` of the
+ // input according to `paddings` to produce `padded` of shape `padded_shape`.
+ xla::PaddingConfig padding_config;
+ std::vector<int64> padded_shape(input_shape.begin(), input_shape.end());
+ int64 block_num_elems = 1LL;
+ padding_config.add_dimensions(); // Don't pad the batch dimension.
+ for (int i = 0; i < block_rank; ++i) {
+ auto* dim = padding_config.add_dimensions();
+ int64 pad_start = xla::LiteralUtil::Get<int64>(paddings, {i, 0});
+ int64 pad_end = xla::LiteralUtil::Get<int64>(paddings, {i, 1});
+ OP_REQUIRES(ctx, pad_start >= 0 && pad_end >= 0,
+ errors::InvalidArgument("Paddings must be non-negative"));
+ dim->set_edge_padding_low(pad_start);
+ dim->set_edge_padding_high(pad_end);
+ padded_shape[1 + i] += pad_start + pad_end;
+ block_num_elems *= block_shape[i];
+ }
+ // Don't pad the remainder dimensions.
+ for (int i = 0; i < remainder_shape.size(); ++i) {
+ padding_config.add_dimensions();
+ }
+ OP_REQUIRES(ctx, block_num_elems > 0,
+ errors::InvalidArgument(
+ "The product of the block dimensions must be positive"));
+
+ xla::ComputationDataHandle padded =
+ b->Pad(input, XlaHelpers::Zero(b, input_dtype), padding_config);
+
+ // 2. Reshape `padded` to `reshaped_padded` of shape:
+ //
+ // [batch] +
+ // [padded_shape[1] / block_shape[0],
+ // block_shape[0],
+ // ...,
+ // padded_shape[M] / block_shape[M-1],
+ // block_shape[M-1]] +
+ // remaining_shape
+ const int64 batch_size = input_shape[0];
+ std::vector<int64> reshaped_padded_shape(input_rank + block_rank);
+ reshaped_padded_shape[0] = batch_size;
+ for (int i = 0; i < block_rank; ++i) {
+ OP_REQUIRES(ctx, padded_shape[1 + i] % block_shape[i] == 0,
+ errors::InvalidArgument("padded_shape[", 1 + i,
+ "]=", padded_shape[1 + i],
+ " is not divisible by block_shape[", i,
+ "]=", block_shape[i]));
+
+ reshaped_padded_shape[1 + i * 2] = padded_shape[1 + i] / block_shape[i];
+ reshaped_padded_shape[1 + i * 2 + 1] = block_shape[i];
+ }
+ std::copy(remainder_shape.begin(), remainder_shape.end(),
+ reshaped_padded_shape.begin() + 1 + 2 * block_rank);
+
+ xla::ComputationDataHandle reshaped_padded =
+ b->Reshape(padded, reshaped_padded_shape);
+
+ // 3. Permute dimensions of `reshaped_padded` to produce
+ // `permuted_reshaped_padded` of shape:
+ //
+ // block_shape +
+ // [batch] +
+ // [padded_shape[1] / block_shape[0],
+ // ...,
+ // padded_shape[M] / block_shape[M-1]] +
+ // remaining_shape
+ std::vector<int64> permutation(reshaped_padded_shape.size());
+ for (int i = 0; i < block_rank; ++i) {
+ permutation[i] = 1 + 2 * i + 1;
+ permutation[block_rank + 1 + i] = 1 + 2 * i;
+ }
+ permutation[block_rank] = 0;
+ std::iota(permutation.begin() + 1 + block_rank * 2, permutation.end(),
+ 1 + block_rank * 2);
+ xla::ComputationDataHandle permuted_reshaped_padded =
+ b->Transpose(reshaped_padded, permutation);
+
+ // 4. Reshape `permuted_reshaped_padded` to flatten `block_shape` into the
+ // batch dimension, producing an output tensor of shape:
+ //
+ // [batch * prod(block_shape)] +
+ // [padded_shape[1] / block_shape[0],
+ // ...,
+ // padded_shape[M] / block_shape[M-1]] +
+ // remaining_shape
+ // Determine the length of the prefix of block dims that can be combined
+ // into the batch dimension due to having no padding and block_shape=1.
+ std::vector<int64> output_shape(input_rank);
+ output_shape[0] = batch_size * block_num_elems;
+ for (int i = 0; i < block_rank; ++i) {
+ output_shape[1 + i] = padded_shape[1 + i] / block_shape[i];
+ }
+ std::copy(remainder_shape.begin(), remainder_shape.end(),
+ output_shape.begin() + 1 + block_rank);
+
+ xla::ComputationDataHandle output =
+ b->Reshape(permuted_reshaped_padded, output_shape);
+ ctx->SetOutput(0, output);
+}
+
+class SpaceToBatchNDOp : public XlaOpKernel {
+ public:
+ explicit SpaceToBatchNDOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ std::vector<int64> block_shape;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &block_shape));
+
+ xla::Literal paddings;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInputAsInt64Literal(2, &paddings));
+
+ SpaceToBatch(ctx, ctx->Input(0), input_type(0), ctx->InputShape(0),
+ block_shape, paddings);
+ }
+};
+REGISTER_XLA_OP(Name("SpaceToBatchND"), SpaceToBatchNDOp);
+
+class SpaceToBatchOp : public XlaOpKernel {
+ public:
+ explicit SpaceToBatchOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("block_size", &block_size_));
+ OP_REQUIRES(
+ ctx, block_size_ > 1,
+ errors::InvalidArgument("Block size should be > 1: ", block_size_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ xla::Literal paddings;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInputAsInt64Literal(1, &paddings));
+
+ SpaceToBatch(ctx, ctx->Input(0), input_type(0), ctx->InputShape(0),
+ {block_size_, block_size_}, paddings);
+ }
+
+ private:
+ int block_size_;
+};
+REGISTER_XLA_OP(Name("SpaceToBatch"), SpaceToBatchOp);
+
+} // namespace
+} // namespace tensorflow
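
As a shape-only illustration of the SpaceToBatch lowering above (not part of the commit), the output shape follows directly from the padded spatial sizes and block_shape, mirroring the divisibility check in step 2:

```python
import numpy as np

def space_to_batch_output_shape(input_shape, block_shape, paddings):
    """Illustrative output-shape computation for SpaceToBatchND."""
    m = len(block_shape)
    padded = [input_shape[1 + i] + lo + hi for i, (lo, hi) in enumerate(paddings)]
    # Mirrors the kernel's requirement that each padded dim divide evenly.
    assert all(p % b == 0 for p, b in zip(padded, block_shape))
    return ([input_shape[0] * int(np.prod(block_shape))] +
            [p // b for p, b in zip(padded, block_shape)] +
            list(input_shape[1 + m:]))

# e.g. input [2, 4, 6, 3], block_shape [2, 3], paddings [[0, 0], [0, 0]]
print(space_to_batch_output_shape([2, 4, 6, 3], [2, 3], [[0, 0], [0, 0]]))
# -> [12, 2, 2, 3]
```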
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
index 53dcdec7a2..a022de36a2 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
@@ -186,6 +186,31 @@ Status XlaOpKernelContext::ConstantInputAsIntVector(int index,
return LiteralToInt64Vector(literal, out);
}
+Status XlaOpKernelContext::ConstantInputAsInt64Literal(int index,
+ xla::Literal* out) {
+ xla::Literal literal;
+ TF_RETURN_IF_ERROR(ConstantInput(index, &literal));
+ switch (literal.shape().element_type()) {
+ case xla::S32:
+ out->Clear();
+ *out->mutable_shape() = literal.shape();
+ out->mutable_shape()->set_element_type(xla::S64);
+ for (int32 x : literal.s32s()) {
+ out->add_s64s(x);
+ }
+ return Status::OK();
+
+ case xla::S64:
+ out->Swap(&literal);
+ return Status::OK();
+
+ default:
+ return errors::InvalidArgument(
+ "Invalid argument to ConstantInputAsInt64Literal: ",
+ xla::ShapeUtil::HumanString(literal.shape()));
+ }
+}
+
// TODO(phawkins): validate that the dimensions form a valid shape, fail
// gracefully if they do not.
Status XlaOpKernelContext::ConstantInputAsShape(int index, TensorShape* shape) {
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h
index 60e3b59d32..f97e07bea5 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.h
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h
@@ -110,6 +110,9 @@ class XlaOpKernelContext {
// Converts a constant 1D int32 or int64 tensor into a vector of int64s.
Status ConstantInputAsIntVector(int index, std::vector<int64>* out);
+ // Converts a constant int32 or int64 Tensor into an xla int64 Literal.
+ Status ConstantInputAsInt64Literal(int index, xla::Literal* out);
+
// Converts a constant 1D int32 or int64 tensor into a TensorShape.
Status ConstantInputAsShape(int index, TensorShape* shape);
diff --git a/tensorflow/core/kernels/batchtospace_op.cc b/tensorflow/core/kernels/batchtospace_op.cc
index b24a834083..99b5d3daaa 100644
--- a/tensorflow/core/kernels/batchtospace_op.cc
+++ b/tensorflow/core/kernels/batchtospace_op.cc
@@ -97,6 +97,10 @@ static void BatchToSpaceOpCompute(OpKernelContext* context,
for (int block_dim = 0; block_dim < block_dims; ++block_dim) {
block_shape_product *= block_shape[block_dim];
}
+ OP_REQUIRES(
+ context, block_shape_product > 0,
+ errors::InvalidArgument("Product of block sizes must be positive, got ",
+ block_shape_product));
const int64 orig_input_batch_size = orig_input_tensor.dim_size(0);
OP_REQUIRES(
diff --git a/tensorflow/core/kernels/spacetobatch_op.cc b/tensorflow/core/kernels/spacetobatch_op.cc
index 3815716ccd..c513683918 100644
--- a/tensorflow/core/kernels/spacetobatch_op.cc
+++ b/tensorflow/core/kernels/spacetobatch_op.cc
@@ -100,6 +100,10 @@ void SpaceToBatchOpCompute(OpKernelContext* context,
for (int block_dim = 0; block_dim < block_dims; ++block_dim) {
block_shape_product *= block_shape[block_dim];
}
+ OP_REQUIRES(
+ context, block_shape_product > 0,
+ errors::InvalidArgument("Product of block sizes must be positive, got ",
+ block_shape_product));
const int internal_block_dims =
block_dims - removed_prefix_block_dims - removed_suffix_block_dims;
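
The two OP_REQUIRES additions above reject a block_shape whose product is zero before it is used as a divisor, which is the zero-sized-block crash the commit message refers to; below is a hedged Python sketch of the same guard (illustrative only, not the TensorFlow API).

```python
import numpy as np

def check_block_shape_product(block_shape):
    """Illustrative mirror of the new core-kernel check: the product of the
    block dimensions must be positive before any division by it."""
    block_shape_product = int(np.prod(block_shape))
    if block_shape_product <= 0:
        raise ValueError("Product of block sizes must be positive, got %d"
                         % block_shape_product)
    return block_shape_product

check_block_shape_product([2, 3])    # ok, returns 6
# check_block_shape_product([0, 3])  # now rejected up front; previously the
#                                    # zero product reached later shape
#                                    # arithmetic in the kernels and crashed.
```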