about summary refs log tree commit diff homepage
path: root/tensorflow/compiler/tf2xla/lib/util.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tf2xla/lib/util.cc')
-rw-r--r--  tensorflow/compiler/tf2xla/lib/util.cc  92
1 files changed, 46 insertions, 46 deletions
diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc
index 31d823ca33..cc7b13571c 100644
--- a/tensorflow/compiler/tf2xla/lib/util.cc
+++ b/tensorflow/compiler/tf2xla/lib/util.cc
@@ -27,15 +27,14 @@ limitations under the License.
namespace tensorflow {
-xla::ComputationDataHandle Zeros(xla::ComputationBuilder* builder,
- const xla::Shape& shape) {
+xla::XlaOp Zeros(xla::XlaBuilder* builder, const xla::Shape& shape) {
return builder->Broadcast(
builder->ConstantLiteral(xla::Literal::Zero(shape.element_type())),
xla::AsInt64Slice(shape.dimensions()));
}
-xla::ComputationDataHandle FloatLiteral(xla::ComputationBuilder* builder,
- xla::PrimitiveType type, double value) {
+xla::XlaOp FloatLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
+ double value) {
switch (type) {
case xla::F16:
return builder->ConstantR0<xla::half>(static_cast<xla::half>(value));
@@ -57,9 +56,8 @@ xla::ComputationDataHandle FloatLiteral(xla::ComputationBuilder* builder,
}
}
-xla::ComputationDataHandle IntegerLiteral(xla::ComputationBuilder* builder,
- xla::PrimitiveType type,
- int64 value) {
+xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
+ int64 value) {
xla::Literal literal;
switch (type) {
case xla::U8:
@@ -112,17 +110,18 @@ xla::ComputationDataHandle IntegerLiteral(xla::ComputationBuilder* builder,
return builder->ConstantLiteral(literal);
}
-xla::StatusOr<xla::ComputationDataHandle> SliceInMinorDims(
- xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
- gtl::ArraySlice<int64> start, gtl::ArraySlice<int64> end) {
+xla::StatusOr<xla::XlaOp> SliceInMinorDims(xla::XlaBuilder* builder,
+ const xla::XlaOp& x,
+ gtl::ArraySlice<int64> start,
+ gtl::ArraySlice<int64> end) {
TF_RET_CHECK(start.size() == end.size());
int64 n_minor_dims = start.size();
- TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> shape, builder->GetShape(x));
+ TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
- const int64 n_dims = xla::ShapeUtil::Rank(*shape);
+ const int64 n_dims = xla::ShapeUtil::Rank(shape);
TF_RET_CHECK(n_minor_dims <= n_dims);
- gtl::ArraySlice<int64> major_dims(xla::AsInt64Slice(shape->dimensions()),
+ gtl::ArraySlice<int64> major_dims(xla::AsInt64Slice(shape.dimensions()),
/*pos=*/0,
/*len=*/n_dims - n_minor_dims);
@@ -140,7 +139,7 @@ xla::StatusOr<xla::ComputationDataHandle> SliceInMinorDims(
return builder->Slice(x, padded_start, padded_end, strides);
}
-std::vector<int64> PrependMajorDims(xla::ComputationBuilder* builder,
+std::vector<int64> PrependMajorDims(xla::XlaBuilder* builder,
const gtl::ArraySlice<int64>& major_dims,
const gtl::ArraySlice<int64>& indices) {
std::vector<int64> output(indices.size() + major_dims.size());
@@ -149,16 +148,16 @@ std::vector<int64> PrependMajorDims(xla::ComputationBuilder* builder,
return output;
}
-xla::StatusOr<xla::ComputationDataHandle> DynamicSliceInMinorDims(
- xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
- const std::vector<xla::ComputationDataHandle>& starts,
+xla::StatusOr<xla::XlaOp> DynamicSliceInMinorDims(
+ xla::XlaBuilder* builder, const xla::XlaOp& x,
+ const std::vector<xla::XlaOp>& starts,
const gtl::ArraySlice<int64>& sizes) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> shape, builder->GetShape(x));
- const int64 n_dims = xla::ShapeUtil::Rank(*shape);
+ TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
+ const int64 n_dims = xla::ShapeUtil::Rank(shape);
int64 n_minor_dims = starts.size();
TF_RET_CHECK(n_minor_dims == sizes.size());
TF_RET_CHECK(n_minor_dims <= n_dims);
- gtl::ArraySlice<int64> major_dims(xla::AsInt64Slice(shape->dimensions()),
+ gtl::ArraySlice<int64> major_dims(xla::AsInt64Slice(shape.dimensions()),
/*pos=*/0,
/*len=*/n_dims - sizes.size());
TF_ASSIGN_OR_RETURN(auto padded_starts,
@@ -167,27 +166,29 @@ xla::StatusOr<xla::ComputationDataHandle> DynamicSliceInMinorDims(
return builder->DynamicSlice(x, padded_starts, padded_sizes);
}
-xla::StatusOr<xla::ComputationDataHandle> UpdateSlice(
- xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
- const xla::ComputationDataHandle& update, gtl::ArraySlice<int64> start) {
+xla::StatusOr<xla::XlaOp> UpdateSlice(xla::XlaBuilder* builder,
+ const xla::XlaOp& x,
+ const xla::XlaOp& update,
+ gtl::ArraySlice<int64> start) {
// TODO(phawkins): make int64 work on all backends, remove the int32 cast.
std::vector<int32> start_as_int32(start.begin(), start.end());
auto start_constant = builder->ConstantR1<int32>(start_as_int32);
- TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> shape, builder->GetShape(x));
- const int64 n_dims = xla::ShapeUtil::Rank(*shape);
- TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> start_constant_shape,
+ TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
+ const int64 n_dims = xla::ShapeUtil::Rank(shape);
+ TF_ASSIGN_OR_RETURN(xla::Shape start_constant_shape,
builder->GetShape(start_constant));
const int64 start_length =
- xla::ShapeUtil::GetDimension(*start_constant_shape, -1);
+ xla::ShapeUtil::GetDimension(start_constant_shape, -1);
TF_RET_CHECK(start_length == n_dims);
return builder->DynamicUpdateSlice(x, update, start_constant);
}
-xla::StatusOr<xla::ComputationDataHandle> UpdateSliceInMinorDims(
- xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
- const xla::ComputationDataHandle& update, gtl::ArraySlice<int64> start) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> shape, builder->GetShape(x));
- const int64 n_dims = xla::ShapeUtil::Rank(*shape);
+xla::StatusOr<xla::XlaOp> UpdateSliceInMinorDims(xla::XlaBuilder* builder,
+ const xla::XlaOp& x,
+ const xla::XlaOp& update,
+ gtl::ArraySlice<int64> start) {
+ TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
+ const int64 n_dims = xla::ShapeUtil::Rank(shape);
const int64 n_minor_dims = start.size();
TF_RET_CHECK(n_minor_dims <= n_dims);
std::vector<int64> padded_start(n_dims, 0);
@@ -196,22 +197,21 @@ xla::StatusOr<xla::ComputationDataHandle> UpdateSliceInMinorDims(
return UpdateSlice(builder, x, update, padded_start);
}
-xla::StatusOr<xla::ComputationDataHandle> DynamicUpdateSliceInMinorDims(
- xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
- const xla::ComputationDataHandle& update,
- const std::vector<xla::ComputationDataHandle>& starts) {
+xla::StatusOr<xla::XlaOp> DynamicUpdateSliceInMinorDims(
+ xla::XlaBuilder* builder, const xla::XlaOp& x, const xla::XlaOp& update,
+ const std::vector<xla::XlaOp>& starts) {
TF_ASSIGN_OR_RETURN(auto padded_starts,
PrependZerosInMajorDims(builder, x, starts));
return builder->DynamicUpdateSlice(x, update, padded_starts);
}
-xla::StatusOr<xla::ComputationDataHandle> PrependZerosInMajorDims(
- xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
- const std::vector<xla::ComputationDataHandle>& starts) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> shape, builder->GetShape(x));
- const int64 n_dims = xla::ShapeUtil::Rank(*shape);
+xla::StatusOr<xla::XlaOp> PrependZerosInMajorDims(
+ xla::XlaBuilder* builder, const xla::XlaOp& x,
+ const std::vector<xla::XlaOp>& starts) {
+ TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
+ const int64 n_dims = xla::ShapeUtil::Rank(shape);
auto zero = builder->Reshape(builder->ConstantR0<int32>(0), {1});
- std::vector<xla::ComputationDataHandle> padded_starts(n_dims, zero);
+ std::vector<xla::XlaOp> padded_starts(n_dims, zero);
for (int i = 0; i < starts.size(); ++i) {
padded_starts[n_dims - starts.size() + i] =
builder->Reshape(starts[i], {1});
@@ -219,10 +219,10 @@ xla::StatusOr<xla::ComputationDataHandle> PrependZerosInMajorDims(
return builder->ConcatInDim(padded_starts, 0);
}
-xla::StatusOr<xla::ComputationDataHandle> TransposeInMinorDims(
- xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> shape, builder->GetShape(x));
- const int64 n_dims = xla::ShapeUtil::Rank(*shape);
+xla::StatusOr<xla::XlaOp> TransposeInMinorDims(xla::XlaBuilder* builder,
+ const xla::XlaOp& x) {
+ TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
+ const int64 n_dims = xla::ShapeUtil::Rank(shape);
TF_RET_CHECK(n_dims >= 2);
std::vector<int64> permutation(n_dims);
std::iota(permutation.begin(), permutation.end(), 0);