diff options
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/sequence_ops.cc')
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/sequence_ops.cc | 213 |
1 files changed, 213 insertions, 0 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc new file mode 100644 index 0000000000..42ae978c3c --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc @@ -0,0 +1,213 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// XLA-specific sequence and range Ops. + +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" + +namespace tensorflow { +namespace { + +template <typename T> +Status GetValue(int index, XlaOpKernelContext* ctx, T* value) { + xla::Literal literal; + TF_RETURN_IF_ERROR(ctx->ConstantInput(index, &literal)); + *value = xla::LiteralUtil::Get<T>(literal, {}); + return Status::OK(); +} + +Status GetIntValue(int index, XlaOpKernelContext* ctx, int64* value) { + xla::Literal literal; + TF_RETURN_IF_ERROR(ctx->ConstantInput(index, &literal)); + switch (literal.shape().element_type()) { + case xla::S32: + *value = xla::LiteralUtil::Get<int32>(literal, {}); + break; + case xla::S64: + *value = xla::LiteralUtil::Get<int64>(literal, {}); + break; + default: + return errors::InvalidArgument("Invalid argument type for argument", + index); + } + return Status::OK(); +} + +// The type-specific part of the implementation of Range. +template <typename T> +Status CreateRangeTensor(const xla::Literal& start_literal, + const xla::Literal& limit_literal, + const xla::Literal& delta_literal, Tensor* output) { + T start = xla::LiteralUtil::Get<T>(start_literal, {}); + T limit = xla::LiteralUtil::Get<T>(limit_literal, {}); + T delta = xla::LiteralUtil::Get<T>(delta_literal, {}); + + if (delta == 0) { + return errors::InvalidArgument("Requires delta != 0: ", delta); + } + if (delta > 0) { + if (start > limit) { + return errors::InvalidArgument("Requires start <= limit when delta > 0: ", + start, "/", limit); + } + } else { + if (start < limit) { + return errors::InvalidArgument("Requires start >= limit when delta < 0: ", + start, "/", limit); + } + } + int64 size = + (std::is_integral<T>::value + ? ((std::abs(limit - start) + std::abs(delta) - 1) / std::abs(delta)) + : std::ceil(std::abs((limit - start) / delta))); + + *output = Tensor(DataTypeToEnum<T>::v(), TensorShape({size})); + auto flat = output->flat<T>(); + T val = start; + for (int64 i = 0; i < size; ++i) { + flat(i) = val; + val += delta; + } + return Status::OK(); +} + +class RangeOp : public XlaOpKernel { + public: + explicit RangeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape start_in_shape = ctx->InputShape(0); + const TensorShape limit_in_shape = ctx->InputShape(1); + const TensorShape delta_in_shape = ctx->InputShape(2); + OP_REQUIRES(ctx, IsLegacyScalar(start_in_shape), + errors::InvalidArgument("start must be a scalar, not shape ", + start_in_shape.DebugString())); + OP_REQUIRES(ctx, IsLegacyScalar(limit_in_shape), + errors::InvalidArgument("limit must be a scalar, not shape ", + limit_in_shape.DebugString())); + OP_REQUIRES(ctx, IsLegacyScalar(delta_in_shape), + errors::InvalidArgument("delta must be a scalar, not shape ", + delta_in_shape.DebugString())); + xla::Literal start, limit, delta; + OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &start)); + OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &limit)); + OP_REQUIRES_OK(ctx, ctx->ConstantInput(2, &delta)); + + DataType type = input_type(0); + Tensor output; + Status status; + switch (type) { + case DT_INT32: + status = CreateRangeTensor<int32>(start, limit, delta, &output); + break; + case DT_INT64: + status = CreateRangeTensor<int64>(start, limit, delta, &output); + break; + case DT_FLOAT: + status = CreateRangeTensor<float>(start, limit, delta, &output); + break; + case DT_DOUBLE: + status = CreateRangeTensor<double>(start, limit, delta, &output); + break; + default: + status = errors::InvalidArgument("Invalid type for Range ", + DataTypeString(type)); + } + OP_REQUIRES_OK(ctx, status); + ctx->SetConstantOutput(0, output); + } +}; + +REGISTER_XLA_OP("Range", RangeOp); + +class LinSpaceOp : public XlaOpKernel { + public: + explicit LinSpaceOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape start_in_shape = ctx->InputShape(0); + const TensorShape stop_in_shape = ctx->InputShape(1); + const TensorShape num_in_shape = ctx->InputShape(2); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(start_in_shape), + errors::InvalidArgument("start must be a scalar, not shape ", + start_in_shape.DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(stop_in_shape), + errors::InvalidArgument("stop must be a scalar, not shape ", + stop_in_shape.DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(num_in_shape), + errors::InvalidArgument("num must be a scalar, not shape ", + num_in_shape.DebugString())); + + DataType type = ctx->input_type(0); + + int64 num; + OP_REQUIRES_OK(ctx, GetIntValue(2, ctx, &num)); + OP_REQUIRES(ctx, num > 0, + errors::InvalidArgument("Requires num > 0: ", num)); + Tensor out_constant(type, TensorShape({num})); + + switch (type) { + case DT_FLOAT: { + float start, stop; + OP_REQUIRES_OK(ctx, GetValue(0, ctx, &start)); + OP_REQUIRES_OK(ctx, GetValue(1, ctx, &stop)); + auto flat = out_constant.flat<float>(); + if (num == 1) { + flat(0) = start; + } else { + const float step = (stop - start) / (num - 1); + for (int64 i = 0; i < num; ++i) { + flat(i) = start + step * i; + } + } + break; + } + case DT_DOUBLE: { + double start, stop; + OP_REQUIRES_OK(ctx, GetValue(0, ctx, &start)); + OP_REQUIRES_OK(ctx, GetValue(1, ctx, &stop)); + auto flat = out_constant.flat<double>(); + if (num == 1) { + flat(0) = start; + } else { + const double step = (stop - start) / (num - 1); + for (int64 i = 0; i < num; ++i) { + flat(i) = start + step * i; + } + } + break; + } + + default: + ctx->SetStatus(errors::InvalidArgument("Invalid argument type ", + DataTypeString(type))); + return; + } + ctx->SetConstantOutput(0, out_constant); + } +}; + +REGISTER_XLA_OP("LinSpace", LinSpaceOp); + +} // namespace +} // namespace tensorflow |