diff options
Diffstat (limited to 'tensorflow/compiler/xla/tests/reduce_window_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/tests/reduce_window_test.cc | 138 |
1 files changed, 71 insertions, 67 deletions
diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index 741974480c..161b74a5c8 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -25,7 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" @@ -70,8 +70,8 @@ class ReduceWindowTest : public ::testing::WithParamInterface<bool>, tensorflow::gtl::ArraySlice<int64> window_dimensions, tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding) { - auto init = - CreateConstantFromLiteral(*Literal::CreateR0<float>(0.0f), &builder_); + auto init = CreateConstantFromLiteral(*LiteralUtil::CreateR0<float>(0.0f), + &builder_); ReduceWindow(input, init, CreateScalarAddComputation(FloatType(), &builder_), window_dimensions, window_strides, padding); @@ -81,7 +81,8 @@ class ReduceWindowTest : public ::testing::WithParamInterface<bool>, tensorflow::gtl::ArraySlice<int64> window_dimensions, tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding) { - auto init = CreateConstantFromLiteral(Literal::MinValue(F32), &builder_); + auto init = + CreateConstantFromLiteral(LiteralUtil::MinValue(F32), &builder_); ReduceWindow(input, init, CreateScalarMaxComputation(FloatType(), &builder_), window_dimensions, window_strides, padding); @@ -91,7 +92,8 @@ class ReduceWindowTest : public ::testing::WithParamInterface<bool>, tensorflow::gtl::ArraySlice<int64> window_dimensions, tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding) { - auto init = CreateConstantFromLiteral(Literal::MaxValue(F32), &builder_); + auto init = + CreateConstantFromLiteral(LiteralUtil::MaxValue(F32), &builder_); ReduceWindow(input, init, CreateScalarMinComputation(FloatType(), &builder_), window_dimensions, window_strides, padding); @@ -102,9 +104,9 @@ class ReduceWindowTest : public ::testing::WithParamInterface<bool>, TEST_P(ReduceWindowTest, MismatchedRanksGivesErrorStatus) { const auto input = CreateConstantFromLiteral( - *Literal::CreateR1<float>({1, 1, 1, 1}), &builder_); + *LiteralUtil::CreateR1<float>({1, 1, 1, 1}), &builder_); const auto init_value = - CreateConstantFromLiteral(*Literal::CreateR0<float>(0), &builder_); + CreateConstantFromLiteral(*LiteralUtil::CreateR0<float>(0), &builder_); TF_ASSERT_OK(builder_.first_error()); ReduceWindow(input, init_value, CreateScalarAddComputation(FloatType(), &builder_), @@ -119,32 +121,32 @@ TEST_P(ReduceWindowTest, MismatchedRanksGivesErrorStatus) { // Regression test for b/68964348. TEST_P(ReduceWindowTest, R0ReduceWindow) { const auto input = - CreateConstantFromLiteral(*Literal::CreateR0<float>(42.0), &builder_); + CreateConstantFromLiteral(*LiteralUtil::CreateR0<float>(42.0), &builder_); const auto init = - CreateConstantFromLiteral(*Literal::CreateR0<float>(1.0), &builder_); + CreateConstantFromLiteral(*LiteralUtil::CreateR0<float>(1.0), &builder_); ReduceWindow(input, init, CreateScalarAddComputation(FloatType(), &builder_), /*window_dimensions=*/{}, /*window_strides=*/{}, Padding::kSame); - ComputeAndCompareLiteral(&builder_, *Literal::CreateR0<float>(43.0), {}, + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateR0<float>(43.0), {}, ErrorSpec(0.00001)); } TEST_P(ReduceWindowTest, Min3In5Stride2) { const auto input = CreateConstantFromLiteral( - *Literal::CreateR1<float>({10000, 1000, 100, 10, 1}), &builder_); + *LiteralUtil::CreateR1<float>({10000, 1000, 100, 10, 1}), &builder_); ReduceWindowMin(input, {3}, {2}, Padding::kValid); - ComputeAndCompareLiteral(&builder_, *Literal::CreateR1<float>({100, 1}), {}, - ErrorSpec(0.00001)); + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateR1<float>({100, 1}), + {}, ErrorSpec(0.00001)); } TEST_P(ReduceWindowTest, Min3In5Stride1WithSamePadding) { const auto input = CreateConstantFromLiteral( - *Literal::CreateR1<float>({10000, 1000, 100, 10, 1}), &builder_); + *LiteralUtil::CreateR1<float>({10000, 1000, 100, 10, 1}), &builder_); ReduceWindowMin(input, /*window_dimensions=*/{3}, /*window_strides=*/{1}, Padding::kSame); ComputeAndCompareLiteral(&builder_, - *Literal::CreateR1<float>({1000, 100, 10, 1, 1}), {}, - ErrorSpec(0.00001)); + *LiteralUtil::CreateR1<float>({1000, 100, 10, 1, 1}), + {}, ErrorSpec(0.00001)); } XLA_TEST_P(ReduceWindowTest, ZeroElementSmall) { @@ -156,7 +158,7 @@ XLA_TEST_P(ReduceWindowTest, ZeroElementSmall) { auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 2, 1}, {1, 1, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), {}, + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {}, DefaultErrorSpec()); } @@ -171,7 +173,7 @@ TEST_P(ReduceWindowTest, NonSquareSmall) { auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 2, 1}, {1, 1, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), {}, + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {}, DefaultErrorSpec()); } @@ -185,7 +187,7 @@ TEST_P(ReduceWindowTest, MiddleDimsSmall) { auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 1, 1}, {1, 2, 2, 1}, padding); - ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), {}, + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {}, DefaultErrorSpec()); } @@ -202,7 +204,7 @@ TEST_P(ReduceWindowTest, Along2ndMinorDim) { auto res = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, lrn_diameter, 1}, {1, 1, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), {}, + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {}, DefaultErrorSpec()); } @@ -224,8 +226,8 @@ TEST_P(ReduceWindowTest, AmongMajor2Dims) { input_array, 0.0f, {win_len, win_len, 1, 1}, {win_stride, win_stride, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*result), {}, - DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result), + {}, DefaultErrorSpec()); } TEST_P(ReduceWindowTest, AmongMajor2DimsMediumSize) { @@ -247,8 +249,8 @@ TEST_P(ReduceWindowTest, AmongMajor2DimsMediumSize) { input_array, 0.0f, {win_len, win_len, 1, 1}, {win_stride, win_stride, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*result), {}, - DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result), + {}, DefaultErrorSpec()); } // Tests the super windowing logic w.r.t handling prime number of windows in a @@ -272,8 +274,8 @@ TEST_P(ReduceWindowTest, PrimeWindowsInReductionDimension) { input_array, 0.0f, {win_len, win_len, 1, 1}, {win_stride, win_stride, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*result), {}, - DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result), + {}, DefaultErrorSpec()); } TEST_P(ReduceWindowTest, ReduceAlongLaneDimension) { @@ -289,8 +291,8 @@ TEST_P(ReduceWindowTest, ReduceAlongLaneDimension) { auto result = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, 1, 11}, {1, 1, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*result), {}, - DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result), + {}, DefaultErrorSpec()); } // Tests a reduction function that is not a simple add/min/max/etc. @@ -308,12 +310,12 @@ XLA_TEST_P(ReduceWindowTest, NonstandardReduceFunction) { auto lhs = Parameter(b.get(), 0, scalar, "lhs"); auto rhs = Parameter(b.get(), 1, scalar, "rhs"); Min(Add(lhs, rhs), - CreateConstantFromLiteral(*Literal::CreateR0<float>(8.0f), b.get())); + CreateConstantFromLiteral(*LiteralUtil::CreateR0<float>(8.0f), b.get())); XlaComputation reduce_fn = b->BuildAndNoteError(); ReduceWindow( input, - CreateConstantFromLiteral(*Literal::CreateR0<float>(0.0f), &builder_), + CreateConstantFromLiteral(*LiteralUtil::CreateR0<float>(0.0f), &builder_), reduce_fn, /*window_dimensions=*/{1, 1, 2, 1}, /*window_strides=*/{1, 1, 1, 1}, padding); @@ -327,15 +329,15 @@ XLA_TEST_P(ReduceWindowTest, NonstandardReduceFunction) { /*window=*/{1, 1, 2, 1}, /*stride=*/{1, 1, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*expected), {}, - DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*expected), + {}, DefaultErrorSpec()); } TEST_P(ReduceWindowTest, R4UnitWindow) { Array4D<float> input_array(13, 12, 8, 15); input_array.FillRandom(2.f, 2.f); std::unique_ptr<Literal> input_literal = - Literal::CreateR4FromArray4DWithLayout( + LiteralUtil::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({0, 3, 2, 1})); XlaOp input; auto input_data = CreateParameterAndTransferLiteral( @@ -347,7 +349,7 @@ TEST_P(ReduceWindowTest, R4UnitWindow) { auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 7, 1}, {1, 4, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {input_data.get()}, DefaultErrorSpec()); } @@ -376,7 +378,7 @@ XLA_TEST_P(ReduceWindowTest, R6Add) { auto shape = ShapeUtil::MakeShape(F32, input_dims); std::unique_ptr<Literal> arg_literal = - Literal::CreateFullWithDescendingLayout<float>(input_dims, 1.0f); + LiteralUtil::CreateFullWithDescendingLayout<float>(input_dims, 1.0f); const auto input = CreateConstantFromLiteral(*arg_literal, &builder_); @@ -385,7 +387,7 @@ XLA_TEST_P(ReduceWindowTest, R6Add) { std::vector<int64> output_dims = {8, 8, 6, 6, 8, 8}; std::unique_ptr<Literal> expected = - Literal::CreateFullWithDescendingLayout<float>(output_dims, 9.0f); + LiteralUtil::CreateFullWithDescendingLayout<float>(output_dims, 9.0f); ComputeAndCompareLiteral(&builder_, *expected, {}, DefaultErrorSpec()); } @@ -394,7 +396,7 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorStride) { Array4D<float> input_array(2, 1, 27, 119); input_array.FillRandom(2.0f); std::unique_ptr<Literal> input_literal = - Literal::CreateR4FromArray4DWithLayout( + LiteralUtil::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaOp input; auto input_data = CreateParameterAndTransferLiteral( @@ -408,7 +410,7 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorStride) { auto res = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding); - ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {input_data.get()}, DefaultErrorSpec()); } @@ -416,7 +418,7 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorUnitStride) { Array4D<float> input_array(3, 2, 4, 64); input_array.FillRandom(2.0f); std::unique_ptr<Literal> input_literal = - Literal::CreateR4FromArray4DWithLayout( + LiteralUtil::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaOp input; auto input_data = CreateParameterAndTransferLiteral( @@ -430,7 +432,7 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorUnitStride) { auto res = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding); - ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {input_data.get()}, DefaultErrorSpec()); } @@ -438,7 +440,7 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorWin) { Array4D<float> input_array(1, 3, 12, 200); input_array.FillRandom(2.0f); std::unique_ptr<Literal> input_literal = - Literal::CreateR4FromArray4DWithLayout( + LiteralUtil::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaOp input; auto input_data = CreateParameterAndTransferLiteral( @@ -452,7 +454,7 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorWin) { auto res = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding); - ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {input_data.get()}, DefaultErrorSpec()); } @@ -473,18 +475,18 @@ TEST_P(ReduceWindowTest, AmongMajor2DimsMultipleMinor) { auto result = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {win_len, win_len, 1, 1}, {win_stride, win_stride, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*result), {}, - DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result), + {}, DefaultErrorSpec()); } XLA_TEST_P(ReduceWindowTest, Add24In1152_NoOverlap) { std::vector<float> input_vector(128 * 9, 1); const auto input = CreateConstantFromLiteral( - *Literal::CreateR1<float>(input_vector), &builder_); + *LiteralUtil::CreateR1<float>(input_vector), &builder_); ReduceWindowAdd(input, {32}, {128}, Padding::kValid); ComputeAndCompareLiteral( &builder_, - *Literal::CreateR1<float>({32, 32, 32, 32, 32, 32, 32, 32, 32}), {}, + *LiteralUtil::CreateR1<float>({32, 32, 32, 32, 32, 32, 32, 32, 32}), {}, DefaultErrorSpec()); } @@ -499,9 +501,9 @@ XLA_TEST_P(ReduceWindowTest, Add128In128Stride128) { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; const auto input = CreateConstantFromLiteral( - *Literal::CreateR1<float>(input_vector), &builder_); + *LiteralUtil::CreateR1<float>(input_vector), &builder_); ReduceWindowAdd(input, {128}, {128}, Padding::kValid); - ComputeAndCompareLiteral(&builder_, *Literal::CreateR1<float>({1088}), {}, + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateR1<float>({1088}), {}, DefaultErrorSpec()); } @@ -516,9 +518,9 @@ XLA_TEST_P(ReduceWindowTest, Add128In128) { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; const auto input = CreateConstantFromLiteral( - *Literal::CreateR1<float>(input_vector), &builder_); + *LiteralUtil::CreateR1<float>(input_vector), &builder_); ReduceWindowAdd(input, {128}, {1}, Padding::kValid); - ComputeAndCompareLiteral(&builder_, *Literal::CreateR1<float>({1088}), {}, + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateR1<float>({1088}), {}, DefaultErrorSpec()); } @@ -535,14 +537,15 @@ TEST_P(ReduceWindowTest, R2ReduceWindowInceptionFromBroadcast) { auto res = ReferenceUtil::ReduceWindow2DAdd( input_array, 0.0f, {win_len, win_len}, {stride, stride}, padding); - ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray<float>(*res), - {}, DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, + *LiteralUtil::CreateFromArray<float>(*res), {}, + DefaultErrorSpec()); } TEST_P(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) { Array2D<float> input_array(6, 4, 1.0f); XlaOp input = Broadcast( - CreateConstantFromLiteral(Literal::One(F32), &builder_), {6, 4}); + CreateConstantFromLiteral(LiteralUtil::One(F32), &builder_), {6, 4}); Padding padding = Padding::kSame; ReduceWindowAdd(input, {4, 2}, {3, 3}, padding); @@ -550,8 +553,9 @@ TEST_P(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) { auto res = ReferenceUtil::ReduceWindow2DAdd(input_array, 0.0f, {4, 2}, {3, 3}, padding); - ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray<float>(*res), - {}, DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, + *LiteralUtil::CreateFromArray<float>(*res), {}, + DefaultErrorSpec()); } INSTANTIATE_TEST_CASE_P(ReduceWindowTestInstance, ReduceWindowTest, @@ -609,7 +613,7 @@ class R4ReduceWindowTest : public ReduceWindowTestBase, param.base_bounds[2], param.base_bounds[3]); input.FillIota(1); std::unique_ptr<Literal> input_literal = - Literal::CreateR4FromArray4DWithLayout( + LiteralUtil::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout(param.layout)); XlaOp parameter; auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0", @@ -621,7 +625,7 @@ class R4ReduceWindowTest : public ReduceWindowTestBase, } auto init_value = - CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b); + CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b); CHECK(param.reducer == kAdd || param.reducer == kMax); auto computation = param.reducer == kAdd ? CreateScalarAddComputation(FloatType(), &b) @@ -647,7 +651,7 @@ class R4ReduceWindowTest : public ReduceWindowTestBase, /*stride=*/param.strides, /*padding=*/padding); std::unique_ptr<Literal> expected_literal = - Literal::CreateFromArray(*expected); + LiteralUtil::CreateFromArray(*expected); const Shape& expected_shape_with_layout = ShapeUtil::MakeShapeWithLayout( input_literal->shape().element_type(), AsInt64Slice(expected_literal->shape().dimensions()), param.layout); @@ -959,14 +963,14 @@ TEST_P(R3ReduceWindowTest, Add) { Array3D<float> input(param.base_bounds[0], param.base_bounds[1], param.base_bounds[2], 1.0f); std::unique_ptr<Literal> input_literal = - Literal::CreateR3FromArray3DWithLayout( + LiteralUtil::CreateR3FromArray3DWithLayout( input, LayoutUtil::MakeLayout(param.layout)); XlaOp parameter; auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0", &b, ¶meter); auto init_value = - CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b); + CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b); ReduceWindow(/*operand=*/parameter, /*init_value=*/init_value, /*computation=*/CreateScalarAddComputation(FloatType(), &b), @@ -977,7 +981,7 @@ TEST_P(R3ReduceWindowTest, Add) { /*operand=*/input, /*init=*/kInitValue, /*window=*/param.window_bounds, /*stride=*/param.strides, /*padding=*/param.padding); - ComputeAndCompareLiteral(&b, *Literal::CreateFromArray(*expected), + ComputeAndCompareLiteral(&b, *LiteralUtil::CreateFromArray(*expected), {input_arg.get()}, DefaultErrorSpec()); } @@ -1093,7 +1097,7 @@ class R2ReduceWindowTest : public ReduceWindowTestBase, const float kInitValue = 0.0f; Array2D<float> input(param.base_bounds[0], param.base_bounds[1], 1.0f); std::unique_ptr<Literal> input_literal = - Literal::CreateR2FromArray2DWithLayout( + LiteralUtil::CreateR2FromArray2DWithLayout( input, LayoutUtil::MakeLayout(param.layout)); XlaOp parameter; @@ -1107,7 +1111,7 @@ class R2ReduceWindowTest : public ReduceWindowTestBase, ? CreateScalarAddComputation(FloatType(), &b) : CreateScalarMaxComputation(FloatType(), &b); auto init_value = - CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b); + CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b); ReduceWindowWithGeneralPadding( /*operand=*/parameter, /*init_value=*/init_value, @@ -1123,7 +1127,7 @@ class R2ReduceWindowTest : public ReduceWindowTestBase, /*window=*/param.window_bounds, /*stride=*/param.strides, /*padding=*/padding); - ComputeAndCompareLiteral(&b, *Literal::CreateFromArray(*expected), + ComputeAndCompareLiteral(&b, *LiteralUtil::CreateFromArray(*expected), {input_arg.get()}, DefaultErrorSpec()); } }; @@ -1292,7 +1296,7 @@ TEST_P(R1ReduceWindowTest, DoIt) { std::vector<float> input_vector(param.base_bounds[0]); std::iota(std::begin(input_vector), std::end(input_vector), 0); std::unique_ptr<Literal> input_literal = - Literal::CreateR1(tensorflow::gtl::ArraySlice<float>(input_vector)); + LiteralUtil::CreateR1(tensorflow::gtl::ArraySlice<float>(input_vector)); XlaOp parameter; auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0", &b, ¶meter); @@ -1304,7 +1308,7 @@ TEST_P(R1ReduceWindowTest, DoIt) { ? CreateScalarAddComputation(FloatType(), &b) : CreateScalarMaxComputation(FloatType(), &b); auto init_value = - CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b); + CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b); ReduceWindowWithGeneralPadding( /*operand=*/parameter, /*init_value=*/init_value, @@ -1323,7 +1327,7 @@ TEST_P(R1ReduceWindowTest, DoIt) { /*stride=*/param.strides, /*padding=*/padding); - ComputeAndCompareLiteral(&b, *Literal::CreateR1<float>(*expected), + ComputeAndCompareLiteral(&b, *LiteralUtil::CreateR1<float>(*expected), {input_arg.get()}, DefaultErrorSpec()); } |