about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/compiler/tf2xla/kernels/tile_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/tile_ops.cc')
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/tile_ops.cc | 128
1 files changed, 128 insertions, 0 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc
new file mode 100644
index 0000000000..45ac5e12c7
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc
@@ -0,0 +1,128 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// XLA-specific Tile Op.
+
+#include <vector>
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/type_index.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/macros.h"
+
+namespace tensorflow {
+namespace {
+
+// --------------------------------------------------------------------------
+// XLA kernel for the "Tile" op. Input 0 is the tensor to tile; input 1 holds
+// the per-dimension multiples, which are read as a constant at compile time.
+class TileOp : public XlaOpKernel {
+ public:
+  explicit TileOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+
+  // Validates the multiples and sets output 0 to the tiled input.
+  void Compile(XlaOpKernelContext* ctx) override {
+    const TensorShape input_shape = ctx->InputShape(0);
+    const TensorShape multiples_shape = ctx->InputShape(1);
+
+    OP_REQUIRES(
+        ctx, IsLegacyVector(multiples_shape),
+        errors::InvalidArgument("Expected multiples to be 1-D, but got shape ",
+                                multiples_shape.DebugString()));
+    // Report num_elements(), the quantity actually compared, in the message.
+    OP_REQUIRES(ctx, input_shape.dims() == multiples_shape.num_elements(),
+                errors::InvalidArgument(
+                    "Expected multiples argument to be a vector of length ",
+                    input_shape.dims(), " but got length ",
+                    multiples_shape.num_elements()));
+    const int input_dims = input_shape.dims();
+
+    // A scalar input has an empty multiples vector, so this is a no-op.
+    if (input_dims == 0) {
+      ctx->SetOutput(0, ctx->Input(0));
+      return;
+    }
+
+    xla::Literal literal;
+    OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &literal));
+
+    std::vector<int64> multiples_array(input_dims);
+    std::vector<int64> output_shape(input_dims);
+    bool all_multiples_are_one = true;
+    bool one_dimension_is_broadcasted_without_multiple = true;
+    for (int i = 0; i < input_dims; ++i) {
+      const int multiple = xla::LiteralUtil::Get<int>(literal, {i});
+      // Zero is a valid multiple (empty result); only negatives are rejected.
+      OP_REQUIRES(ctx, multiple >= 0,
+                  errors::InvalidArgument("Expected multiples[", i,
+                                          "] >= 0, but got ", multiple));
+      const int64 dim_size = input_shape.dim_size(i);
+      multiples_array[i] = multiple;
+      output_shape[i] = dim_size * multiple;
+      all_multiples_are_one = all_multiples_are_one && multiple == 1;
+      // If the multiple of a non-one dimension is not one, then binary
+      // operation broadcast semantics will not be sufficient to implement the
+      // tile operation.
+      one_dimension_is_broadcasted_without_multiple =
+          one_dimension_is_broadcasted_without_multiple &&
+          (dim_size == 1 || (dim_size > 1 && multiple == 1));
+    }
+    auto input = ctx->Input(0);
+    // If all multiples are 1, then the input is the same as the output.
+    if (all_multiples_are_one) {
+      ctx->SetOutput(0, input);
+      return;
+    }
+    if (one_dimension_is_broadcasted_without_multiple) {
+      // Create a constant Zero the size of the output shape to leverage binary
+      // operation broadcast semantics.
+      auto broadcasted_zero = ctx->builder()->Broadcast(
+          XlaHelpers::Zero(ctx->builder(), ctx->input_type(0)), output_shape);
+      ctx->SetOutput(0, ctx->builder()->Add(broadcasted_zero, input));
+      return;
+    }
+
+    // First broadcast the requisite number of multiples along each
+    // dimension. This prepends the broadcasted dimensions, so an
+    // input of shape [2,3,1] broadcast with multiples [5,4,3] will
+    // end up with shape [5,4,3,2,3,1].
+    auto broadcasted = ctx->builder()->Broadcast(input, multiples_array);
+    // Now flatten and reshape. The broadcasted dimensions are
+    // paired with the original dimensions so in the above example
+    // we flatten [0,3,1,4,2,5] then reshape to [10,12,3].
+    std::vector<int64> flattened;
+    flattened.reserve(2 * input_dims);
+    for (int i = 0; i < input_dims; ++i) {
+      flattened.push_back(i);
+      flattened.push_back(i + input_dims);
+    }
+    ctx->SetOutput(
+        0, ctx->builder()->Reshape(broadcasted, flattened, output_shape));
+  }
+
+ private:
+  TF_DISALLOW_COPY_AND_ASSIGN(TileOp);
+};
+
+REGISTER_XLA_OP("Tile", TileOp);  // Registers TileOp for the "Tile" op.
+
+} // namespace
+} // namespace tensorflow