diff options
author | Peter Hawkins <phawkins@google.com> | 2018-10-03 18:56:00 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-03 19:00:24 -0700 |
commit | 9bd6f5ed55e533ccac055a5bc7fbb771e2d432c5 (patch) | |
tree | e561aec682d97ff5ef0c8a8cd20e8e897efa8590 /tensorflow/compiler | |
parent | 54bebc286bbe7d6a866a3bdbcefd8af55adbe39a (diff) |
[TF:XLA] Use xla::Iota rather than expanding Range ops to constants.
PiperOrigin-RevId: 215668016
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/sequence_ops.cc | 39 |
1 files changed, 18 insertions, 21 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc index 25a5bcbe1d..0c32b8def0 100644 --- a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc @@ -18,7 +18,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" @@ -55,10 +57,10 @@ Status GetIntValue(int index, XlaOpKernelContext* ctx, int64* value) { // The type-specific part of the implementation of Range. template <typename T> -Status CreateRangeTensor(const xla::LiteralSlice& start_literal, - const xla::LiteralSlice& limit_literal, - const xla::LiteralSlice& delta_literal, - Tensor* output) { +xla::StatusOr<xla::XlaOp> CreateRangeTensor( + const xla::LiteralSlice& start_literal, + const xla::LiteralSlice& limit_literal, + const xla::LiteralSlice& delta_literal, xla::XlaBuilder* builder) { T start = start_literal.Get<T>({}); T limit = limit_literal.Get<T>({}); T delta = delta_literal.Get<T>({}); @@ -82,14 +84,10 @@ Status CreateRangeTensor(const xla::LiteralSlice& start_literal, ? ((std::abs(limit - start) + std::abs(delta) - 1) / std::abs(delta)) : std::ceil(std::abs((limit - start) / delta))); - *output = Tensor(DataTypeToEnum<T>::v(), TensorShape({size})); - auto flat = output->flat<T>(); - T val = start; - for (int64 i = 0; i < size; ++i) { - flat(i) = val; - val += delta; - } - return Status::OK(); + return xla::ConstantR0(builder, start) + + xla::ConstantR0(builder, delta) * + xla::Iota(builder, xla::primitive_util::NativeToPrimitiveType<T>(), + size); } class RangeOp : public XlaOpKernel { @@ -115,27 +113,26 @@ class RangeOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, ctx->ConstantInput(2, &delta)); DataType type = input_type(0); - Tensor output; - Status status; + xla::StatusOr<xla::XlaOp> output; switch (type) { case DT_INT32: - status = CreateRangeTensor<int32>(start, limit, delta, &output); + output = CreateRangeTensor<int32>(start, limit, delta, ctx->builder()); break; case DT_INT64: - status = CreateRangeTensor<int64>(start, limit, delta, &output); + output = CreateRangeTensor<int64>(start, limit, delta, ctx->builder()); break; case DT_FLOAT: - status = CreateRangeTensor<float>(start, limit, delta, &output); + output = CreateRangeTensor<float>(start, limit, delta, ctx->builder()); break; case DT_DOUBLE: - status = CreateRangeTensor<double>(start, limit, delta, &output); + output = CreateRangeTensor<double>(start, limit, delta, ctx->builder()); break; default: - status = errors::InvalidArgument("Invalid type for Range ", + output = errors::InvalidArgument("Invalid type for Range ", DataTypeString(type)); } - OP_REQUIRES_OK(ctx, status); - ctx->SetConstantOutput(0, output); + OP_REQUIRES_OK(ctx, output.status()); + ctx->SetOutput(0, output.ValueOrDie()); } }; |