aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler
diff options
context:
space:
mode:
authorGravatar Peter Hawkins <phawkins@google.com>2018-10-03 18:56:00 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-03 19:00:24 -0700
commit9bd6f5ed55e533ccac055a5bc7fbb771e2d432c5 (patch)
treee561aec682d97ff5ef0c8a8cd20e8e897efa8590 /tensorflow/compiler
parent54bebc286bbe7d6a866a3bdbcefd8af55adbe39a (diff)
[TF:XLA] Use xla::Iota rather than expanding Range ops to constants.
PiperOrigin-RevId: 215668016
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r--tensorflow/compiler/tf2xla/kernels/sequence_ops.cc39
1 files changed, 18 insertions, 21 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc
index 25a5bcbe1d..0c32b8def0 100644
--- a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc
@@ -18,7 +18,9 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/numeric.h"
#include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
@@ -55,10 +57,10 @@ Status GetIntValue(int index, XlaOpKernelContext* ctx, int64* value) {
// The type-specific part of the implementation of Range.
template <typename T>
-Status CreateRangeTensor(const xla::LiteralSlice& start_literal,
- const xla::LiteralSlice& limit_literal,
- const xla::LiteralSlice& delta_literal,
- Tensor* output) {
+xla::StatusOr<xla::XlaOp> CreateRangeTensor(
+ const xla::LiteralSlice& start_literal,
+ const xla::LiteralSlice& limit_literal,
+ const xla::LiteralSlice& delta_literal, xla::XlaBuilder* builder) {
T start = start_literal.Get<T>({});
T limit = limit_literal.Get<T>({});
T delta = delta_literal.Get<T>({});
@@ -82,14 +84,10 @@ Status CreateRangeTensor(const xla::LiteralSlice& start_literal,
? ((std::abs(limit - start) + std::abs(delta) - 1) / std::abs(delta))
: std::ceil(std::abs((limit - start) / delta)));
- *output = Tensor(DataTypeToEnum<T>::v(), TensorShape({size}));
- auto flat = output->flat<T>();
- T val = start;
- for (int64 i = 0; i < size; ++i) {
- flat(i) = val;
- val += delta;
- }
- return Status::OK();
+ return xla::ConstantR0(builder, start) +
+ xla::ConstantR0(builder, delta) *
+ xla::Iota(builder, xla::primitive_util::NativeToPrimitiveType<T>(),
+ size);
}
class RangeOp : public XlaOpKernel {
@@ -115,27 +113,26 @@ class RangeOp : public XlaOpKernel {
OP_REQUIRES_OK(ctx, ctx->ConstantInput(2, &delta));
DataType type = input_type(0);
- Tensor output;
- Status status;
+ xla::StatusOr<xla::XlaOp> output;
switch (type) {
case DT_INT32:
- status = CreateRangeTensor<int32>(start, limit, delta, &output);
+ output = CreateRangeTensor<int32>(start, limit, delta, ctx->builder());
break;
case DT_INT64:
- status = CreateRangeTensor<int64>(start, limit, delta, &output);
+ output = CreateRangeTensor<int64>(start, limit, delta, ctx->builder());
break;
case DT_FLOAT:
- status = CreateRangeTensor<float>(start, limit, delta, &output);
+ output = CreateRangeTensor<float>(start, limit, delta, ctx->builder());
break;
case DT_DOUBLE:
- status = CreateRangeTensor<double>(start, limit, delta, &output);
+ output = CreateRangeTensor<double>(start, limit, delta, ctx->builder());
break;
default:
- status = errors::InvalidArgument("Invalid type for Range ",
+ output = errors::InvalidArgument("Invalid type for Range ",
DataTypeString(type));
}
- OP_REQUIRES_OK(ctx, status);
- ctx->SetConstantOutput(0, output);
+ OP_REQUIRES_OK(ctx, output.status());
+ ctx->SetOutput(0, output.ValueOrDie());
}
};