path: root/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
author    Peter Hawkins <phawkins@google.com>    2018-06-27 12:12:33 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2018-06-27 12:15:39 -0700
commit    35cb434a9a95bef7ca8d7880d87dd9775eeba336 (patch)
tree      976358f9a935cbbdf76407f60688c08b6484aeae /tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
parent    1536bba6be3e16f3983b79dd6931de313c900114 (diff)
[TF:XLA] Refactor TF/XLA code to use free functions in xla:: namespace to build XlaOps, rather than calling XlaBuilder methods.
PiperOrigin-RevId: 202348891
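For context, the new style replaces XlaBuilder member calls with equivalent free functions in the xla:: namespace. The free functions recover the builder from their operands, so the builder is passed explicitly only to operand-less ops such as constants. A minimal before/after sketch of the two styles (AddOneOld/AddOneNew are illustrative names, not part of this commit):

    // Before: every op is a method on the builder, so `b` threads
    // through all expression-building code.
    xla::XlaOp AddOneOld(xla::XlaBuilder* b, xla::XlaOp x) {
      return b->Add(x, b->ConstantR0<int32>(1));
    }

    // After: ops are free functions in xla::; the builder is inferred
    // from the operands, and only constants (which have no operands)
    // still take it, as the first argument.
    xla::XlaOp AddOneNew(xla::XlaBuilder* b, xla::XlaOp x) {
      return xla::Add(x, xla::ConstantR0<int32>(b, 1));
    }

The same two translations (method call to free function, and the builder moving into the argument list of ConstantR0) account for every hunk in the diff below.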
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc')
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc  42
1 file changed, 21 insertions(+), 21 deletions(-)
diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
index 9adee78a1f..2f650ce305 100644
--- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/tf2xla/xla_resource.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
@@ -123,10 +124,9 @@ xla::XlaOp DynamicAddSlice(xla::XlaBuilder* builder, const xla::XlaOp& operand,
const xla::XlaOp& update,
const gtl::ArraySlice<int64>& update_dims,
const xla::XlaOp& start_indices) {
- xla::XlaOp current =
- builder->DynamicSlice(operand, start_indices, update_dims);
- xla::XlaOp sum = builder->Add(current, update);
- return builder->DynamicUpdateSlice(operand, sum, start_indices);
+ xla::XlaOp current = xla::DynamicSlice(operand, start_indices, update_dims);
+ xla::XlaOp sum = xla::Add(current, update);
+ return xla::DynamicUpdateSlice(operand, sum, start_indices);
}
class TensorArrayOp : public XlaOpKernel {
@@ -162,7 +162,7 @@ class TensorArrayOp : public XlaOpKernel {
ta_shape.AddDim(size);
ta_shape.AppendShape(shape);
xla::XlaOp zero = XlaHelpers::Zero(b, dtype_);
- value = b->Broadcast(zero, ta_shape.dim_sizes());
+ value = xla::Broadcast(zero, ta_shape.dim_sizes());
}
XlaContext& xc = XlaContext::Get(ctx);
@@ -215,12 +215,12 @@ class TensorArrayWriteOp : public XlaOpKernel {
// start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0].
auto start_indices =
- b->Pad(b->Reshape(index, {1}), b->ConstantR0<int32>(0),
- xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}}));
+ xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0<int32>(b, 0),
+ xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}}));
TensorShape slice_shape = elem_shape;
slice_shape.InsertDim(0, 1LL);
- auto update = b->Reshape(value, slice_shape.dim_sizes());
+ auto update = xla::Reshape(value, slice_shape.dim_sizes());
xla::XlaOp written =
DynamicAddSlice(b, ta, update, slice_shape.dim_sizes(), start_indices);
@@ -259,17 +259,17 @@ class TensorArrayReadOp : public XlaOpKernel {
// start_indices of the DynamicSlice are [index, 0, 0, ..., 0].
auto start_indices =
- b->Pad(b->Reshape(index, {1}), b->ConstantR0<int32>(0),
- xla::MakeEdgePaddingConfig({{0, ta_shape.dims() - 1}}));
+ xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0<int32>(b, 0),
+ xla::MakeEdgePaddingConfig({{0, ta_shape.dims() - 1}}));
auto slice_shape = ta_shape.dim_sizes();
slice_shape[0] = 1LL;
- xla::XlaOp read = b->DynamicSlice(ta, start_indices, slice_shape);
+ xla::XlaOp read = xla::DynamicSlice(ta, start_indices, slice_shape);
// Remove the leading '1' dimension.
std::vector<int64> value_shape(slice_shape.begin() + 1, slice_shape.end());
- ctx->SetOutput(0, b->Reshape(read, value_shape));
+ ctx->SetOutput(0, xla::Reshape(read, value_shape));
}
private:
@@ -326,7 +326,7 @@ class TensorArrayGatherOp : public XlaOpKernel {
for (auto i = 1; i < ta_shape.dims(); i++) {
end[i] = ta_shape.dim_size(i);
}
- ctx->SetOutput(0, b->Slice(ta, begin, end, strides));
+ ctx->SetOutput(0, xla::Slice(ta, begin, end, strides));
return;
}
}
@@ -391,7 +391,7 @@ class TensorArrayScatterOp : public XlaOpKernel {
}
if (scatter_all_elements_in_order) {
- ta = b->Add(ta, value);
+ ta = xla::Add(ta, value);
} else {
auto slice_dims = value_shape.dim_sizes();
slice_dims[0] = 1LL;
@@ -407,13 +407,13 @@ class TensorArrayScatterOp : public XlaOpKernel {
// Slice out part of the value.
value_starts[0] = i;
value_ends[0] = i + 1;
- auto slice = b->Slice(value, value_starts, value_ends, value_strides);
+ auto slice = xla::Slice(value, value_starts, value_ends, value_strides);
// start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0].
- auto index = b->Slice(indices, {i}, {i + 1}, {1});
+ auto index = xla::Slice(indices, {i}, {i + 1}, {1});
auto start_indices =
- b->Pad(b->Reshape(index, {1}), b->ConstantR0<int32>(0),
- xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}}));
+ xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0<int32>(b, 0),
+ xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}}));
ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices);
}
}
@@ -452,7 +452,7 @@ class TensorArrayConcatOp : public XlaOpKernel {
auto ta_dims = ta_shape.dim_sizes();
std::vector<int64> shape(ta_dims.begin() + 1, ta_dims.end());
shape[0] *= ta_shape.dim_size(0);
- ctx->SetOutput(0, b->Reshape(ta, shape));
+ ctx->SetOutput(0, xla::Reshape(ta, shape));
Tensor lengths(DT_INT64, {ta_dims[0]});
auto lengths_vec = lengths.vec<int64>();
@@ -522,8 +522,8 @@ class TensorArraySplitOp : public XlaOpKernel {
value_shape.DebugString(), " vs. ",
ta_shape.DebugString()));
- OP_REQUIRES_OK(ctx, resource->SetValue(b->Add(
- ta, b->Reshape(value, ta_shape.dim_sizes()))));
+ OP_REQUIRES_OK(ctx, resource->SetValue(xla::Add(
+ ta, xla::Reshape(value, ta_shape.dim_sizes()))));
ctx->SetOutput(0, flow);
}
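One idiom worth noting, which appears in three hunks above (TensorArrayWriteOp, TensorArrayReadOp, and TensorArrayScatterOp), is how the start_indices vector [index, 0, ..., 0] for DynamicSlice/DynamicUpdateSlice is built: reshape the scalar index into a length-1 vector, then edge-pad it on the right with zeros. A standalone sketch of that idiom in the new free-function style (MakeStartIndices is a hypothetical helper name, not from this commit):

    // Builds start_indices = [index, 0, ..., 0] with `rank` trailing
    // zeros: reshape the scalar index to a length-1 vector, then
    // edge-pad the right side with `rank` zero elements.
    xla::XlaOp MakeStartIndices(xla::XlaBuilder* b, xla::XlaOp index,
                                int64 rank) {
      return xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0<int32>(b, 0),
                      xla::MakeEdgePaddingConfig({{0, rank}}));
    }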