about summary refs log tree commit diff homepage
path: root/tensorflow/compiler/tf2xla/xla_helpers.cc
diff options
context:
space:
mode:
author: A. Unique TensorFlower <gardener@tensorflow.org> 2018-04-30 17:41:33 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-04-30 17:43:59 -0700
commit: 45bafe9a3589fc735c22c3c703f8689ea9c1e71e (patch)
tree: e39723521a1ca68e9c2c74d1a9d3ac5ef2e8abc4 /tensorflow/compiler/tf2xla/xla_helpers.cc
parent: c89a1d9605427d74079774af7da37933f9ca153c (diff)
[XLA] Redesign: migrate tensorflow/compiler/tf2xla, tensorflow/compiler/aot:
- xla::ComputationBuilder -> xla::XlaBuilder
- xla::ComputationDataHandle -> xla::XlaOp
- xla::Computation -> xla::XlaComputation
- xla::CompileOnlyClient::AotComputationInstance -> xla::CompileOnlyClient::AotXlaComputationInstance
- xla::SessionModule -> xla::HloSnapshot

PiperOrigin-RevId: 194874462
Diffstat (limited to 'tensorflow/compiler/tf2xla/xla_helpers.cc')
-rw-r--r--  tensorflow/compiler/tf2xla/xla_helpers.cc  95
1 file changed, 40 insertions, 55 deletions
diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc
index a3deb02a1f..f1594193af 100644
--- a/tensorflow/compiler/tf2xla/xla_helpers.cc
+++ b/tensorflow/compiler/tf2xla/xla_helpers.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
@@ -32,13 +32,12 @@ namespace tensorflow {
namespace {
-Status ArgMinMax(xla::ComputationBuilder* builder, XlaOpKernelContext* ctx,
- const xla::ComputationDataHandle& input,
- const TensorShape& input_shape, DataType input_type,
- DataType output_type, int axis, bool is_min,
- xla::ComputationDataHandle* argminmax) {
- xla::ComputationDataHandle init_value;
- const xla::Computation* reducer;
+Status ArgMinMax(xla::XlaBuilder* builder, XlaOpKernelContext* ctx,
+ const xla::XlaOp& input, const TensorShape& input_shape,
+ DataType input_type, DataType output_type, int axis,
+ bool is_min, xla::XlaOp* argminmax) {
+ xla::XlaOp init_value;
+ const xla::XlaComputation* reducer;
if (is_min) {
init_value = XlaHelpers::MaxValue(builder, input_type);
reducer = ctx->GetOrCreateMin(input_type);
@@ -50,13 +49,13 @@ Status ArgMinMax(xla::ComputationBuilder* builder, XlaOpKernelContext* ctx,
xla::PrimitiveType xla_output_type;
TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(output_type, &xla_output_type));
- xla::ComputationDataHandle input_max = builder->Reduce(
- input, init_value, *reducer, /*dimensions_to_reduce=*/{axis});
+ xla::XlaOp input_max = builder->Reduce(input, init_value, *reducer,
+ /*dimensions_to_reduce=*/{axis});
std::vector<int64> broadcast_dims(input_shape.dims() - 1);
std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0);
std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1);
// Compute a mask that has 1s for elements equal to the maximum.
- xla::ComputationDataHandle partial_mask = builder->ConvertElementType(
+ xla::XlaOp partial_mask = builder->ConvertElementType(
builder->Eq(input, input_max, broadcast_dims), xla_output_type);
// In order to make identity elements for a bitwise And, we:
@@ -65,23 +64,23 @@ Status ArgMinMax(xla::ComputationBuilder* builder, XlaOpKernelContext* ctx,
// 0xFF...F
int32 bits_in_type =
xla::ShapeUtil::ByteSizeOfPrimitiveType(xla_output_type) * 8 - 1;
- xla::ComputationDataHandle shift_amount =
+ xla::XlaOp shift_amount =
XlaHelpers::IntegerLiteral(builder, output_type, bits_in_type);
- xla::ComputationDataHandle full_mask = builder->ShiftRightArithmetic(
+ xla::XlaOp full_mask = builder->ShiftRightArithmetic(
builder->ShiftLeft(partial_mask, shift_amount), shift_amount);
// And with the vector [0, 1, 2, ...] to convert each 0xFF...F into its
// index.
- xla::ComputationDataHandle iota;
+ xla::XlaOp iota;
const int64 axis_size = input_shape.dim_size(axis);
TF_RETURN_IF_ERROR(XlaHelpers::Iota(builder, output_type, axis_size, &iota));
- xla::ComputationDataHandle product =
+ xla::XlaOp product =
builder->And(full_mask, iota, /*broadcast_dimensions=*/{axis});
// If there are multiple maximum elements, choose the one with the highest
// index.
- xla::ComputationDataHandle output =
+ xla::XlaOp output =
builder->Reduce(product, XlaHelpers::MinValue(builder, output_type),
*ctx->GetOrCreateMax(output_type),
/*dimensions_to_reduce=*/{axis});
@@ -91,36 +90,31 @@ Status ArgMinMax(xla::ComputationBuilder* builder, XlaOpKernelContext* ctx,
} // namespace
-xla::ComputationDataHandle XlaHelpers::MinValue(xla::ComputationBuilder* b,
- DataType data_type) {
+xla::XlaOp XlaHelpers::MinValue(xla::XlaBuilder* b, DataType data_type) {
xla::PrimitiveType type;
TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
return b->ConstantLiteral(xla::Literal::MinValue(type));
}
-xla::ComputationDataHandle XlaHelpers::MaxValue(xla::ComputationBuilder* b,
- DataType data_type) {
+xla::XlaOp XlaHelpers::MaxValue(xla::XlaBuilder* b, DataType data_type) {
xla::PrimitiveType type;
TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
return b->ConstantLiteral(xla::Literal::MaxValue(type));
}
-xla::ComputationDataHandle XlaHelpers::Zero(xla::ComputationBuilder* b,
- DataType data_type) {
+xla::XlaOp XlaHelpers::Zero(xla::XlaBuilder* b, DataType data_type) {
xla::PrimitiveType type;
TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
return b->ConstantLiteral(xla::Literal::Zero(type));
}
-xla::ComputationDataHandle XlaHelpers::One(xla::ComputationBuilder* b,
- DataType data_type) {
+xla::XlaOp XlaHelpers::One(xla::XlaBuilder* b, DataType data_type) {
xla::PrimitiveType type;
TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
return b->ConstantLiteral(xla::Literal::One(type));
}
-xla::ComputationDataHandle XlaHelpers::Epsilon(xla::ComputationBuilder* b,
- DataType data_type) {
+xla::XlaOp XlaHelpers::Epsilon(xla::XlaBuilder* b, DataType data_type) {
switch (data_type) {
case DT_HALF:
return b->ConstantR0<Eigen::half>(
@@ -137,16 +131,15 @@ xla::ComputationDataHandle XlaHelpers::Epsilon(xla::ComputationBuilder* b,
}
}
-xla::ComputationDataHandle XlaHelpers::IntegerLiteral(
- xla::ComputationBuilder* b, DataType data_type, int64 value) {
+xla::XlaOp XlaHelpers::IntegerLiteral(xla::XlaBuilder* b, DataType data_type,
+ int64 value) {
xla::PrimitiveType type;
TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
return ::tensorflow::IntegerLiteral(b, type, value);
}
-xla::ComputationDataHandle XlaHelpers::FloatLiteral(xla::ComputationBuilder* b,
- DataType data_type,
- double value) {
+xla::XlaOp XlaHelpers::FloatLiteral(xla::XlaBuilder* b, DataType data_type,
+ double value) {
xla::PrimitiveType type;
TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
return ::tensorflow::FloatLiteral(b, type, value);
@@ -183,28 +176,24 @@ static Tensor MakeLinspaceTensor(const TensorShape& shape, int64 depth) {
return linspace;
}
-Status XlaHelpers::ArgMax(xla::ComputationBuilder* builder,
- XlaOpKernelContext* ctx,
- const xla::ComputationDataHandle& input,
+Status XlaHelpers::ArgMax(xla::XlaBuilder* builder, XlaOpKernelContext* ctx,
+ const xla::XlaOp& input,
const TensorShape& input_shape, DataType input_type,
- DataType output_type, int axis,
- xla::ComputationDataHandle* argmax) {
+ DataType output_type, int axis, xla::XlaOp* argmax) {
return ArgMinMax(builder, ctx, input, input_shape, input_type, output_type,
axis, /*is_min=*/false, argmax);
}
-Status XlaHelpers::ArgMin(xla::ComputationBuilder* builder,
- XlaOpKernelContext* ctx,
- const xla::ComputationDataHandle& input,
+Status XlaHelpers::ArgMin(xla::XlaBuilder* builder, XlaOpKernelContext* ctx,
+ const xla::XlaOp& input,
const TensorShape& input_shape, DataType input_type,
- DataType output_type, int axis,
- xla::ComputationDataHandle* argmin) {
+ DataType output_type, int axis, xla::XlaOp* argmin) {
return ArgMinMax(builder, ctx, input, input_shape, input_type, output_type,
axis, /*is_min=*/true, argmin);
}
-Status XlaHelpers::Iota(xla::ComputationBuilder* builder, DataType dtype,
- int64 size, xla::ComputationDataHandle* iota) {
+Status XlaHelpers::Iota(xla::XlaBuilder* builder, DataType dtype, int64 size,
+ xla::XlaOp* iota) {
TensorShape linspace_shape({size});
Tensor linspace;
switch (dtype) {
@@ -227,13 +216,10 @@ Status XlaHelpers::Iota(xla::ComputationBuilder* builder, DataType dtype,
return Status::OK();
}
-Status XlaHelpers::OneHot(xla::ComputationBuilder* builder, int64 depth,
- int axis, DataType index_type,
- const TensorShape& indices_shape,
- const xla::ComputationDataHandle& indices,
- const xla::ComputationDataHandle& on_value,
- const xla::ComputationDataHandle& off_value,
- xla::ComputationDataHandle* one_hot) {
+Status XlaHelpers::OneHot(xla::XlaBuilder* builder, int64 depth, int axis,
+ DataType index_type, const TensorShape& indices_shape,
+ const xla::XlaOp& indices, const xla::XlaOp& on_value,
+ const xla::XlaOp& off_value, xla::XlaOp* one_hot) {
const int indices_dims = indices_shape.dims();
const int output_dims = indices_dims + 1;
@@ -267,7 +253,7 @@ Status XlaHelpers::OneHot(xla::ComputationBuilder* builder, int64 depth,
std::vector<int64> broadcast_dims(indices_shape.dims());
std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0);
std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1);
- xla::ComputationDataHandle one_hot_bool = builder->Eq(
+ xla::XlaOp one_hot_bool = builder->Eq(
indices, builder->ConstantLiteral(linspace_literal), broadcast_dims);
// Selects the user-provided off_value and on_value values.
@@ -284,10 +270,9 @@ DataType XlaHelpers::SumAccumulationType(const DataType& dtype) {
return dtype;
}
-xla::ComputationDataHandle XlaHelpers::ConvertElementType(
- xla::ComputationBuilder* const builder,
- const xla::ComputationDataHandle& operand,
- const DataType new_element_type) {
+xla::XlaOp XlaHelpers::ConvertElementType(xla::XlaBuilder* const builder,
+ const xla::XlaOp& operand,
+ const DataType new_element_type) {
xla::PrimitiveType convert_to;
TF_CHECK_OK(DataTypeToPrimitiveType(new_element_type, &convert_to));
return builder->ConvertElementType(operand, convert_to);