diff options
Diffstat (limited to 'tensorflow/compiler/xla/reference_util.cc')
-rw-r--r-- | tensorflow/compiler/xla/reference_util.cc | 75 |
1 files changed, 34 insertions, 41 deletions
diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc index 9f1afa2671..ceb5e74db7 100644 --- a/tensorflow/compiler/xla/reference_util.cc +++ b/tensorflow/compiler/xla/reference_util.cc @@ -186,11 +186,10 @@ ReferenceUtil::SeparableConvArray4D(const Array4D<float>& input, /* static */ std::unique_ptr<std::vector<float>> ReferenceUtil::ReduceWindow1DGeneric( - const absl::Span<const float>& operand, float init, + absl::Span<const float> operand, float init, const std::function<float(float, float)>& reduce_func, - const absl::Span<const int64>& window, - const absl::Span<const int64>& stride, - const absl::Span<const std::pair<int64, int64>>& padding) { + absl::Span<const int64> window, absl::Span<const int64> stride, + absl::Span<const std::pair<int64, int64>> padding) { std::vector<int64> dim_lengths{static_cast<int64>(operand.size())}; std::vector<int64> window_counts(window.size(), 0); std::vector<int64> pad_low(window.size(), 0); @@ -218,10 +217,9 @@ ReferenceUtil::ReduceWindow1DGeneric( } /* static */ std::unique_ptr<std::vector<float>> -ReferenceUtil::ReduceWindow1DAdd(const absl::Span<const float>& operand, - float init, - const absl::Span<const int64>& window, - const absl::Span<const int64>& stride, +ReferenceUtil::ReduceWindow1DAdd(absl::Span<const float> operand, float init, + absl::Span<const int64> window, + absl::Span<const int64> stride, Padding padding) { const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; }; std::vector<int64> dim_lengths{static_cast<int64>(operand.size())}; @@ -234,9 +232,8 @@ ReferenceUtil::ReduceWindow1DAdd(const absl::Span<const float>& operand, ReferenceUtil::ReduceWindow2DGeneric( const Array2D<float>& operand, float init, const std::function<float(float, float)>& reduce_func, - const absl::Span<const int64>& window, - const absl::Span<const int64>& stride, - const absl::Span<const std::pair<int64, int64>>& padding) { + absl::Span<const int64> window, absl::Span<const int64> stride, + absl::Span<const std::pair<int64, int64>> padding) { std::vector<int64> dim_lengths{operand.height(), operand.width()}; std::vector<int64> window_counts(window.size(), 0); @@ -273,9 +270,8 @@ ReferenceUtil::ReduceWindow2DGeneric( } /* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::ReduceWindow2DAdd( - const Array2D<float>& operand, float init, - const absl::Span<const int64>& window, - const absl::Span<const int64>& stride, Padding padding) { + const Array2D<float>& operand, float init, absl::Span<const int64> window, + absl::Span<const int64> stride, Padding padding) { const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; }; std::vector<int64> dim_lengths{operand.height(), operand.width()}; return ReduceWindow2DGeneric( @@ -284,9 +280,8 @@ ReferenceUtil::ReduceWindow2DGeneric( } /* static */ std::unique_ptr<Array3D<float>> ReferenceUtil::ReduceWindow3DAdd( - const Array3D<float>& operand, float init, - const absl::Span<const int64>& window, - const absl::Span<const int64>& stride, Padding padding) { + const Array3D<float>& operand, float init, absl::Span<const int64> window, + absl::Span<const int64> stride, Padding padding) { std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3()}; auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding); @@ -332,8 +327,8 @@ ReferenceUtil::ReduceWindow2DGeneric( ReferenceUtil::ReduceWindow4DGeneric( const Array4D<float>& operand, float init, const std::function<float(float, float)>& reduce_func, - const absl::Span<const int64>& window, - const absl::Span<const int64>& stride, Padding padding) { + absl::Span<const int64> window, absl::Span<const int64> stride, + Padding padding) { std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(), operand.n4()}; return ReduceWindow4DGeneric( @@ -345,9 +340,8 @@ ReferenceUtil::ReduceWindow4DGeneric( ReferenceUtil::ReduceWindow4DGeneric( const Array4D<float>& operand, float init, const std::function<float(float, float)>& reduce_func, - const absl::Span<const int64>& window, - const absl::Span<const int64>& stride, - const absl::Span<const std::pair<int64, int64>>& padding) { + absl::Span<const int64> window, absl::Span<const int64> stride, + absl::Span<const std::pair<int64, int64>> padding) { std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(), operand.n4()}; @@ -399,9 +393,8 @@ ReferenceUtil::ReduceWindow4DGeneric( } /* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::ReduceWindow4DAdd( - const Array4D<float>& operand, float init, - const absl::Span<const int64>& window, - const absl::Span<const int64>& stride, Padding padding) { + const Array4D<float>& operand, float init, absl::Span<const int64> window, + absl::Span<const int64> stride, Padding padding) { const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; }; return ReduceWindow4DGeneric(operand, init, add_reduce, window, stride, padding); @@ -425,8 +418,8 @@ ReferenceUtil::ReduceWindow4DGeneric( ReferenceUtil::SelectAndScatter4DGePlus(const Array4D<float>& operand, const Array4D<float>& source, float init, - const absl::Span<const int64>& window, - const absl::Span<const int64>& stride, + absl::Span<const int64> window, + absl::Span<const int64> stride, bool same_padding) { Padding padding = same_padding ? Padding::kSame : Padding::kValid; auto result = absl::make_unique<Array4D<float>>(operand.n1(), operand.n2(), @@ -529,13 +522,13 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( } ordered_input_dimensions[0] = - lhs_literal->shape().dimensions(dnums.input_spatial_dimensions(0)); + lhs_literal.shape().dimensions(dnums.input_spatial_dimensions(0)); ordered_input_dimensions[1] = - lhs_literal->shape().dimensions(dnums.input_spatial_dimensions(1)); + lhs_literal.shape().dimensions(dnums.input_spatial_dimensions(1)); ordered_kernel_dimensions[0] = - rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(0)); + rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(0)); ordered_kernel_dimensions[1] = - rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(1)); + rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(1)); std::vector<std::pair<int64, int64>> paddings = MakePadding(ordered_input_dimensions, ordered_kernel_dimensions, @@ -546,7 +539,7 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( WindowDimension dim; dim.set_size( - rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(0))); + rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(0))); dim.set_stride(kernel_stride.first); dim.set_padding_low(paddings[0].first); dim.set_padding_high(paddings[0].second); @@ -556,7 +549,7 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( WindowDimension dim2; dim2.set_size( - rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(1))); + rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(1))); dim2.set_stride(kernel_stride.second); dim2.set_padding_low(paddings[1].first); dim2.set_padding_high(paddings[1].second); @@ -565,7 +558,7 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( *window.add_dimensions() = dim2; const Shape& shape = ShapeInference::InferConvolveShape( - lhs_literal->shape(), rhs_literal->shape(), + lhs_literal.shape(), rhs_literal.shape(), /*feature_group_count=*/1, window, dnums) .ConsumeValueOrDie(); @@ -585,18 +578,18 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( auto computation = module.AddEntryComputation(b.Build()); HloEvaluator evaluator; - std::unique_ptr<Literal> result_literal = + Literal result_literal = evaluator.Evaluate<const Literal*>(*computation, {}).ConsumeValueOrDie(); - CHECK_EQ(ShapeUtil::Rank(result_literal->shape()), 4); + CHECK_EQ(ShapeUtil::Rank(result_literal.shape()), 4); auto result = - absl::make_unique<Array4D<float>>(result_literal->shape().dimensions(0), - result_literal->shape().dimensions(1), - result_literal->shape().dimensions(2), - result_literal->shape().dimensions(3)); + absl::make_unique<Array4D<float>>(result_literal.shape().dimensions(0), + result_literal.shape().dimensions(1), + result_literal.shape().dimensions(2), + result_literal.shape().dimensions(3)); result->Each([&](absl::Span<const int64> indices, float* value) { - *value = result_literal->Get<float>(indices); + *value = result_literal.Get<float>(indices); }); return result; |