aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/reference_util.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/reference_util.cc')
-rw-r--r--tensorflow/compiler/xla/reference_util.cc75
1 files changed, 34 insertions, 41 deletions
diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc
index 9f1afa2671..ceb5e74db7 100644
--- a/tensorflow/compiler/xla/reference_util.cc
+++ b/tensorflow/compiler/xla/reference_util.cc
@@ -186,11 +186,10 @@ ReferenceUtil::SeparableConvArray4D(const Array4D<float>& input,
/* static */ std::unique_ptr<std::vector<float>>
ReferenceUtil::ReduceWindow1DGeneric(
- const absl::Span<const float>& operand, float init,
+ absl::Span<const float> operand, float init,
const std::function<float(float, float)>& reduce_func,
- const absl::Span<const int64>& window,
- const absl::Span<const int64>& stride,
- const absl::Span<const std::pair<int64, int64>>& padding) {
+ absl::Span<const int64> window, absl::Span<const int64> stride,
+ absl::Span<const std::pair<int64, int64>> padding) {
std::vector<int64> dim_lengths{static_cast<int64>(operand.size())};
std::vector<int64> window_counts(window.size(), 0);
std::vector<int64> pad_low(window.size(), 0);
@@ -218,10 +217,9 @@ ReferenceUtil::ReduceWindow1DGeneric(
}
/* static */ std::unique_ptr<std::vector<float>>
-ReferenceUtil::ReduceWindow1DAdd(const absl::Span<const float>& operand,
- float init,
- const absl::Span<const int64>& window,
- const absl::Span<const int64>& stride,
+ReferenceUtil::ReduceWindow1DAdd(absl::Span<const float> operand, float init,
+ absl::Span<const int64> window,
+ absl::Span<const int64> stride,
Padding padding) {
const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; };
std::vector<int64> dim_lengths{static_cast<int64>(operand.size())};
@@ -234,9 +232,8 @@ ReferenceUtil::ReduceWindow1DAdd(const absl::Span<const float>& operand,
ReferenceUtil::ReduceWindow2DGeneric(
const Array2D<float>& operand, float init,
const std::function<float(float, float)>& reduce_func,
- const absl::Span<const int64>& window,
- const absl::Span<const int64>& stride,
- const absl::Span<const std::pair<int64, int64>>& padding) {
+ absl::Span<const int64> window, absl::Span<const int64> stride,
+ absl::Span<const std::pair<int64, int64>> padding) {
std::vector<int64> dim_lengths{operand.height(), operand.width()};
std::vector<int64> window_counts(window.size(), 0);
@@ -273,9 +270,8 @@ ReferenceUtil::ReduceWindow2DGeneric(
}
/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::ReduceWindow2DAdd(
- const Array2D<float>& operand, float init,
- const absl::Span<const int64>& window,
- const absl::Span<const int64>& stride, Padding padding) {
+ const Array2D<float>& operand, float init, absl::Span<const int64> window,
+ absl::Span<const int64> stride, Padding padding) {
const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; };
std::vector<int64> dim_lengths{operand.height(), operand.width()};
return ReduceWindow2DGeneric(
@@ -284,9 +280,8 @@ ReferenceUtil::ReduceWindow2DGeneric(
}
/* static */ std::unique_ptr<Array3D<float>> ReferenceUtil::ReduceWindow3DAdd(
- const Array3D<float>& operand, float init,
- const absl::Span<const int64>& window,
- const absl::Span<const int64>& stride, Padding padding) {
+ const Array3D<float>& operand, float init, absl::Span<const int64> window,
+ absl::Span<const int64> stride, Padding padding) {
std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3()};
auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding);
@@ -332,8 +327,8 @@ ReferenceUtil::ReduceWindow2DGeneric(
ReferenceUtil::ReduceWindow4DGeneric(
const Array4D<float>& operand, float init,
const std::function<float(float, float)>& reduce_func,
- const absl::Span<const int64>& window,
- const absl::Span<const int64>& stride, Padding padding) {
+ absl::Span<const int64> window, absl::Span<const int64> stride,
+ Padding padding) {
std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(),
operand.n4()};
return ReduceWindow4DGeneric(
@@ -345,9 +340,8 @@ ReferenceUtil::ReduceWindow4DGeneric(
ReferenceUtil::ReduceWindow4DGeneric(
const Array4D<float>& operand, float init,
const std::function<float(float, float)>& reduce_func,
- const absl::Span<const int64>& window,
- const absl::Span<const int64>& stride,
- const absl::Span<const std::pair<int64, int64>>& padding) {
+ absl::Span<const int64> window, absl::Span<const int64> stride,
+ absl::Span<const std::pair<int64, int64>> padding) {
std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(),
operand.n4()};
@@ -399,9 +393,8 @@ ReferenceUtil::ReduceWindow4DGeneric(
}
/* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::ReduceWindow4DAdd(
- const Array4D<float>& operand, float init,
- const absl::Span<const int64>& window,
- const absl::Span<const int64>& stride, Padding padding) {
+ const Array4D<float>& operand, float init, absl::Span<const int64> window,
+ absl::Span<const int64> stride, Padding padding) {
const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; };
return ReduceWindow4DGeneric(operand, init, add_reduce, window, stride,
padding);
@@ -425,8 +418,8 @@ ReferenceUtil::ReduceWindow4DGeneric(
ReferenceUtil::SelectAndScatter4DGePlus(const Array4D<float>& operand,
const Array4D<float>& source,
float init,
- const absl::Span<const int64>& window,
- const absl::Span<const int64>& stride,
+ absl::Span<const int64> window,
+ absl::Span<const int64> stride,
bool same_padding) {
Padding padding = same_padding ? Padding::kSame : Padding::kValid;
auto result = absl::make_unique<Array4D<float>>(operand.n1(), operand.n2(),
@@ -529,13 +522,13 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
}
ordered_input_dimensions[0] =
- lhs_literal->shape().dimensions(dnums.input_spatial_dimensions(0));
+ lhs_literal.shape().dimensions(dnums.input_spatial_dimensions(0));
ordered_input_dimensions[1] =
- lhs_literal->shape().dimensions(dnums.input_spatial_dimensions(1));
+ lhs_literal.shape().dimensions(dnums.input_spatial_dimensions(1));
ordered_kernel_dimensions[0] =
- rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(0));
+ rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(0));
ordered_kernel_dimensions[1] =
- rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(1));
+ rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(1));
std::vector<std::pair<int64, int64>> paddings =
MakePadding(ordered_input_dimensions, ordered_kernel_dimensions,
@@ -546,7 +539,7 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
WindowDimension dim;
dim.set_size(
- rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(0)));
+ rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(0)));
dim.set_stride(kernel_stride.first);
dim.set_padding_low(paddings[0].first);
dim.set_padding_high(paddings[0].second);
@@ -556,7 +549,7 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
WindowDimension dim2;
dim2.set_size(
- rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(1)));
+ rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(1)));
dim2.set_stride(kernel_stride.second);
dim2.set_padding_low(paddings[1].first);
dim2.set_padding_high(paddings[1].second);
@@ -565,7 +558,7 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
*window.add_dimensions() = dim2;
const Shape& shape = ShapeInference::InferConvolveShape(
- lhs_literal->shape(), rhs_literal->shape(),
+ lhs_literal.shape(), rhs_literal.shape(),
/*feature_group_count=*/1, window, dnums)
.ConsumeValueOrDie();
@@ -585,18 +578,18 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
auto computation = module.AddEntryComputation(b.Build());
HloEvaluator evaluator;
- std::unique_ptr<Literal> result_literal =
+ Literal result_literal =
evaluator.Evaluate<const Literal*>(*computation, {}).ConsumeValueOrDie();
- CHECK_EQ(ShapeUtil::Rank(result_literal->shape()), 4);
+ CHECK_EQ(ShapeUtil::Rank(result_literal.shape()), 4);
auto result =
- absl::make_unique<Array4D<float>>(result_literal->shape().dimensions(0),
- result_literal->shape().dimensions(1),
- result_literal->shape().dimensions(2),
- result_literal->shape().dimensions(3));
+ absl::make_unique<Array4D<float>>(result_literal.shape().dimensions(0),
+ result_literal.shape().dimensions(1),
+ result_literal.shape().dimensions(2),
+ result_literal.shape().dimensions(3));
result->Each([&](absl::Span<const int64> indices, float* value) {
- *value = result_literal->Get<float>(indices);
+ *value = result_literal.Get<float>(indices);
});
return result;