diff options
-rw-r--r-- | tensorflow/compiler/xla/client/xla_client/xla_builder.cc | 170 | ||||
-rw-r--r-- | tensorflow/compiler/xla/client/xla_client/xla_builder.h | 14 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/BUILD | 2 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/convolution_test.cc | 61 |
4 files changed, 210 insertions, 37 deletions
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc index 7481b357ff..9e4b9ccd25 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc @@ -790,24 +790,101 @@ XlaOp XlaBuilder::DotGeneral(const XlaOp& lhs, const XlaOp& rhs, }); } +Status XlaBuilder::VerifyConvolution( + const Shape& lhs_shape, const Shape& rhs_shape, + const ConvolutionDimensionNumbers& dimension_numbers) const { + if (ShapeUtil::Rank(lhs_shape) != ShapeUtil::Rank(rhs_shape)) { + return InvalidArgument( + "Convolution arguments must have same number of " + "dimensions. Got: %s and %s", + ShapeUtil::HumanString(lhs_shape).c_str(), + ShapeUtil::HumanString(rhs_shape).c_str()); + } + int num_dims = ShapeUtil::Rank(lhs_shape); + if (num_dims < 2) { + return InvalidArgument( + "Convolution expects argument arrays with >= 3 dimensions. " + "Got: %s and %s", + ShapeUtil::HumanString(lhs_shape).c_str(), + ShapeUtil::HumanString(rhs_shape).c_str()); + } + int num_spatial_dims = num_dims - 2; + + const auto check_spatial_dimensions = + [&](const char* const field_name, + const tensorflow::protobuf::RepeatedField<tensorflow::protobuf_int64>& + numbers) { + if (numbers.size() != num_spatial_dims) { + return InvalidArgument("Expected %d elements for %s, but got %d.", + num_spatial_dims, field_name, numbers.size()); + } + for (int i = 0; i < numbers.size(); ++i) { + if (numbers.Get(i) < 0 || numbers.Get(i) >= num_dims) { + return InvalidArgument("Convolution %s[%d] is out of bounds: %lld", + field_name, i, numbers.Get(i)); + } + } + return Status::OK(); + }; + TF_RETURN_IF_ERROR( + check_spatial_dimensions("input_spatial_dimensions", + dimension_numbers.input_spatial_dimensions())); + TF_RETURN_IF_ERROR( + check_spatial_dimensions("kernel_spatial_dimensions", + dimension_numbers.kernel_spatial_dimensions())); + return check_spatial_dimensions( + "output_spatial_dimensions", + dimension_numbers.output_spatial_dimensions()); +} + XlaOp XlaBuilder::Conv(const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding) { - return UnimplementedOp(); + return ConvWithGeneralDimensions( + lhs, rhs, window_strides, padding, + CreateDefaultConvDimensionNumbers(window_strides.size())); } XlaOp XlaBuilder::ConvWithGeneralPadding( const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice<int64> window_strides, tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding) { - return UnimplementedOp(); + return ConvGeneral(lhs, rhs, window_strides, padding, + CreateDefaultConvDimensionNumbers(window_strides.size())); } XlaOp XlaBuilder::ConvWithGeneralDimensions( const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding, const ConvolutionDimensionNumbers& dimension_numbers) { - return UnimplementedOp(); + return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); + TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); + + TF_RETURN_IF_ERROR( + VerifyConvolution(lhs_shape, rhs_shape, dimension_numbers)); + + std::vector<int64> base_area_dimensions( + dimension_numbers.input_spatial_dimensions_size()); + for (std::vector<int64>::size_type i = 0; i < base_area_dimensions.size(); + ++i) { + base_area_dimensions[i] = + lhs_shape.dimensions(dimension_numbers.input_spatial_dimensions(i)); + } + + std::vector<int64> window_dimensions( + dimension_numbers.kernel_spatial_dimensions_size()); + for (std::vector<int64>::size_type i = 0; i < window_dimensions.size(); + ++i) { + window_dimensions[i] = + rhs_shape.dimensions(dimension_numbers.kernel_spatial_dimensions(i)); + } + + return ConvGeneral(lhs, rhs, window_strides, + MakePadding(base_area_dimensions, window_dimensions, + window_strides, padding), + dimension_numbers); + }); } XlaOp XlaBuilder::ConvGeneral( @@ -815,7 +892,8 @@ XlaOp XlaBuilder::ConvGeneral( tensorflow::gtl::ArraySlice<int64> window_strides, tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding, const ConvolutionDimensionNumbers& dimension_numbers) { - return UnimplementedOp(); + return ConvGeneralDilated(lhs, rhs, window_strides, padding, {}, {}, + dimension_numbers); } XlaOp XlaBuilder::ConvGeneralDilated( @@ -825,7 +903,89 @@ XlaOp XlaBuilder::ConvGeneralDilated( tensorflow::gtl::ArraySlice<int64> lhs_dilation, tensorflow::gtl::ArraySlice<int64> rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers) { - return UnimplementedOp(); + return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); + TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); + TF_RETURN_IF_ERROR( + VerifyConvolution(lhs_shape, rhs_shape, dimension_numbers)); + + std::vector<int64> window_dimensions( + dimension_numbers.kernel_spatial_dimensions_size()); + for (std::vector<int64>::size_type i = 0; i < window_dimensions.size(); + ++i) { + window_dimensions[i] = + rhs_shape.dimensions(dimension_numbers.kernel_spatial_dimensions(i)); + } + TF_ASSIGN_OR_RETURN(*instr.mutable_window(), + MakeWindow(window_dimensions, window_strides, padding, + lhs_dilation, rhs_dilation)); + + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, instr.window(), + dimension_numbers)); + + *instr.mutable_convolution_dimension_numbers() = dimension_numbers; + + return AddInstruction(std::move(instr), HloOpcode::kConvolution, + {lhs, rhs}); + }); +} + +StatusOr<Window> XlaBuilder::MakeWindow( + tensorflow::gtl::ArraySlice<int64> window_dimensions, + tensorflow::gtl::ArraySlice<int64> window_strides, + tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding, + tensorflow::gtl::ArraySlice<int64> lhs_dilation, + tensorflow::gtl::ArraySlice<int64> rhs_dilation) const { + const auto verify_size = [&](const size_t x, const char* x_name) { + if (x == 0 || x == window_dimensions.size()) { + return Status::OK(); + } else { + return InvalidArgument( + "%s", tensorflow::strings::StrCat( + "Window has different number of window dimensions than of ", + x_name, + "\nNumber of window dimensions: ", window_dimensions.size(), + "\nNumber of ", x_name, ": ", x, "\n") + .c_str()); + } + }; + TF_RETURN_IF_ERROR(verify_size(window_strides.size(), "window strides")); + TF_RETURN_IF_ERROR(verify_size(padding.size(), "padding entries")); + TF_RETURN_IF_ERROR(verify_size(lhs_dilation.size(), "lhs dilation factors")); + TF_RETURN_IF_ERROR(verify_size(rhs_dilation.size(), "rhs dilation factors")); + + Window window; + for (size_t i = 0; i < window_dimensions.size(); i++) { + auto dim = window.add_dimensions(); + dim->set_size(window_dimensions[i]); + if (!window_strides.empty()) { + dim->set_stride(window_strides[i]); + } else { + dim->set_stride(1); + } + if (!padding.empty()) { + dim->set_padding_low(padding[i].first); + dim->set_padding_high(padding[i].second); + } else { + dim->set_padding_low(0); + dim->set_padding_high(0); + } + if (!lhs_dilation.empty()) { + dim->set_base_dilation(lhs_dilation[i]); + } else { + dim->set_base_dilation(1); + } + if (!rhs_dilation.empty()) { + dim->set_window_dilation(rhs_dilation[i]); + } else { + dim->set_window_dilation(1); + } + dim->set_window_reversal(false); + } + return window; } XlaOp XlaBuilder::Fft(const XlaOp& operand, const FftType fft_type, diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.h b/tensorflow/compiler/xla/client/xla_client/xla_builder.h index d747691f16..24e0be2ac1 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.h @@ -835,6 +835,20 @@ class XlaBuilder { void IsConstantVisitor(const int64 op_handle, std::set<int64>* visited, bool* is_constant) const; + // Checks bounds for convolution parameters. + Status VerifyConvolution( + const Shape& lhs_shape, const Shape& rhs_shape, + const ConvolutionDimensionNumbers& dimension_numbers) const; + + // Helper function for creating a Window proto from user-supplied data. + // Returns error if the user-supplied data was invalid. + StatusOr<Window> MakeWindow( + tensorflow::gtl::ArraySlice<int64> window_dimensions, + tensorflow::gtl::ArraySlice<int64> window_strides, + tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding, + tensorflow::gtl::ArraySlice<int64> lhs_dilation, + tensorflow::gtl::ArraySlice<int64> rhs_dilation) const; + string name_; // Name to use for the built computation. // The first error encountered while building the computation. diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 19fb4886db..67c53c6ac0 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -781,10 +781,10 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:padding", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc index 72715398de..5eb3136abe 100644 --- a/tensorflow/compiler/xla/tests/convolution_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_test.cc @@ -20,10 +20,10 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" @@ -88,12 +88,12 @@ class ForwardPassConvolution_3x3x256_256_OutputZ_Iota : public ConvolutionTest { ASSERT_EQ(2, arhs->width()); ASSERT_EQ(2, arhs->height()); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR4FromArray4D<T>(*alhs); auto rhs = builder.ConstantR4FromArray4D<T>(*arhs); - auto conv = builder.Conv(lhs, rhs, {1, 1}, Padding::kValid); + builder.Conv(lhs, rhs, {1, 1}, Padding::kValid); - ComputeAndCompare(&builder, conv, {}, error_spec_); + ComputeAndCompare(&builder, {}, error_spec_); } }; @@ -106,12 +106,12 @@ template <typename T> class Convolve_1x1x1x2_1x1x1x2_Valid : public ConvolutionTest { public: void RunTest() { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Shape input_shape = ShapeUtil::MakeShapeWithType<T>({1, 1, 1, 2}); Shape filter_shape = ShapeUtil::MakeShapeWithType<T>({1, 1, 1, 2}); auto input = builder.Parameter(0, input_shape, "input"); auto filter = builder.Parameter(1, filter_shape, "filter"); - auto conv = builder.Conv(input, filter, {1, 1}, Padding::kValid); + builder.Conv(input, filter, {1, 1}, Padding::kValid); Array4D<T> input_data(1, 1, 1, 2); input_data.FillWithYX(Array2D<T>({ @@ -122,7 +122,7 @@ class Convolve_1x1x1x2_1x1x1x2_Valid : public ConvolutionTest { {5.0f, 6.0f}, })); - ComputeAndCompare(&builder, conv, + ComputeAndCompare(&builder, {std::move(*Literal::CreateFromArray(input_data)), std::move(*Literal::CreateFromArray(filter_data))}, error_spec_); @@ -137,12 +137,12 @@ template <typename T> class Convolve_1x1x4x4_1x1x2x2_Valid : public ConvolutionTest { public: void RunTest() { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Shape input_shape = ShapeUtil::MakeShapeWithType<T>({1, 1, 4, 4}); Shape filter_shape = ShapeUtil::MakeShapeWithType<T>({1, 1, 2, 2}); auto input = builder.Parameter(0, input_shape, "input"); auto filter = builder.Parameter(1, filter_shape, "filter"); - auto conv = builder.Conv(input, filter, {1, 1}, Padding::kValid); + builder.Conv(input, filter, {1, 1}, Padding::kValid); Array4D<T> input_data(1, 1, 4, 4); input_data.FillWithYX(Array2D<T>({ @@ -156,7 +156,7 @@ class Convolve_1x1x4x4_1x1x2x2_Valid : public ConvolutionTest { {5.0f, 6.0f}, {7.0f, 8.0f}, })); - ComputeAndCompare(&builder, conv, + ComputeAndCompare(&builder, {std::move(*Literal::CreateFromArray(input_data)), std::move(*Literal::CreateFromArray(filter_data))}, error_spec_); @@ -171,12 +171,12 @@ template <typename T> class Convolve_1x1x4x4_1x1x2x2_Same : public ConvolutionTest { public: void RunTest() { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Shape input_shape = ShapeUtil::MakeShapeWithType<T>({1, 1, 4, 4}); Shape filter_shape = ShapeUtil::MakeShapeWithType<T>({1, 1, 2, 2}); auto input = builder.Parameter(0, input_shape, "input"); auto filter = builder.Parameter(1, filter_shape, "filter"); - auto conv = builder.Conv(input, filter, {1, 1}, Padding::kSame); + builder.Conv(input, filter, {1, 1}, Padding::kSame); Array4D<T> input_data(1, 1, 4, 4); input_data.FillWithYX(Array2D<T>({ @@ -191,7 +191,7 @@ class Convolve_1x1x4x4_1x1x2x2_Same : public ConvolutionTest { {7.0f, 8.0f}, })); - ComputeAndCompare(&builder, conv, + ComputeAndCompare(&builder, {std::move(*Literal::CreateFromArray(input_data)), std::move(*Literal::CreateFromArray(filter_data))}, error_spec_); @@ -207,12 +207,12 @@ template <typename T> class Convolve_1x1x4x4_1x1x3x3_Same : public ConvolutionTest { public: void RunTest() { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Shape input_shape = ShapeUtil::MakeShapeWithType<T>({1, 1, 4, 4}); Shape filter_shape = ShapeUtil::MakeShapeWithType<T>({1, 1, 3, 3}); auto input = builder.Parameter(0, input_shape, "input"); auto filter = builder.Parameter(1, filter_shape, "filter"); - auto conv = builder.Conv(input, filter, {1, 1}, Padding::kSame); + builder.Conv(input, filter, {1, 1}, Padding::kSame); Array4D<T> input_data(1, 1, 4, 4); input_data.FillWithYX(Array2D<T>({{1.0f, 2.0f, 3.0f, 4.0f}, @@ -223,7 +223,7 @@ class Convolve_1x1x4x4_1x1x3x3_Same : public ConvolutionTest { filter_data.FillWithYX(Array2D<T>( {{5.0f, 6.0f, 7.0f}, {8.0f, 9.0f, 10.0f}, {11.0f, 12.0f, 13.0f}})); // clang-format on - ComputeAndCompare(&builder, conv, + ComputeAndCompare(&builder, {std::move(*Literal::CreateFromArray(input_data)), std::move(*Literal::CreateFromArray(filter_data))}, error_spec_); @@ -234,7 +234,7 @@ TYPED_TEST_CASE(Convolve_1x1x4x4_1x1x3x3_Same, TestTypes); TYPED_TEST(Convolve_1x1x4x4_1x1x3x3_Same, Types) { this->RunTest(); } XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_Valid) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); { Shape input_shape = ShapeUtil::MakeShape(F32, {1, 2, 5}); Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 2, 2}); @@ -264,7 +264,7 @@ template <typename T> class Convolve1D_1x2x5_1x2x2_WithRHSDilation : public ConvolutionTest { public: void RunTest() { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); { Shape input_shape = ShapeUtil::MakeShapeWithType<T>({1, 2, 5}); Shape filter_shape = ShapeUtil::MakeShapeWithType<T>({1, 2, 2}); @@ -300,7 +300,7 @@ TYPED_TEST_CASE(Convolve1D_1x2x5_1x2x2_WithRHSDilation, TestTypes); TYPED_TEST(Convolve1D_1x2x5_1x2x2_WithRHSDilation, Types) { this->RunTest(); } XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSDilation) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); { Shape input_shape = ShapeUtil::MakeShape(F32, {1, 2, 5}); Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 2, 2}); @@ -331,7 +331,7 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSDilation) { } XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSAndRHSDilation) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); { Shape input_shape = ShapeUtil::MakeShape(F32, {1, 2, 5}); Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 2, 2}); @@ -365,7 +365,7 @@ template <typename T> class Convolve1D_1x2x5_1x2x2_WithPadding : public ConvolutionTest { public: void RunTest() { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); { Shape input_shape = ShapeUtil::MakeShapeWithType<T>({1, 2, 5}); Shape filter_shape = ShapeUtil::MakeShapeWithType<T>({1, 2, 2}); @@ -402,7 +402,7 @@ TYPED_TEST_CASE(Convolve1D_1x2x5_1x2x2_WithPadding, TestTypes); TYPED_TEST(Convolve1D_1x2x5_1x2x2_WithPadding, Types) { this->RunTest(); } XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector<int64> input_dims = {1, 4, 2, 3, 3}; std::vector<int64> filter_dims = {2, 2, 2, 3, 3}; Shape input_shape = ShapeUtil::MakeShape(F32, input_dims); @@ -469,7 +469,7 @@ template <typename T> class Convolve2D_1x3x3x5_3x3x5x5_Valid : public ConvolutionTest { public: void RunTest() { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector<int64> input_dims = {1, 3, 3, 5}; std::vector<int64> filter_dims = {3, 3, 5, 3}; Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims); @@ -537,7 +537,7 @@ XLA_TEST_P(ConvolveWithAndWithoutCanonicalization, execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes( "convolution-canonicalization"); } - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Shape input_shape = ShapeUtil::MakeShape(F32, {4, 29}); Shape filter_shape = ShapeUtil::MakeShape(F32, {4, 10}); @@ -551,8 +551,7 @@ XLA_TEST_P(ConvolveWithAndWithoutCanonicalization, dnums.set_kernel_output_feature_dimension(1); dnums.set_output_batch_dimension(0); dnums.set_output_feature_dimension(1); - auto conv = builder.ConvWithGeneralDimensions(input, filter, {}, - Padding::kValid, dnums); + builder.ConvWithGeneralDimensions(input, filter, {}, Padding::kValid, dnums); Array2D<float> param0(4, 29); param0.FillUnique(); @@ -563,7 +562,7 @@ XLA_TEST_P(ConvolveWithAndWithoutCanonicalization, Array2D<float> expected_result(29, 10); expected_result.Fill(0); - ComputeAndCompare(&builder, conv, + ComputeAndCompare(&builder, {std::move(*Literal::CreateFromArray(param0)), std::move(*Literal::CreateFromArray(param1))}, error_spec_); @@ -587,7 +586,7 @@ class Convolve1D1WindowTestBase protected: template <typename T> void TestImpl() { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); int64 input_feature = GetParam().input_feature; int64 output_feature = GetParam().output_feature; int64 batch = GetParam().batch; @@ -724,12 +723,12 @@ INSTANTIATE_TEST_CASE_P( #endif XLA_TEST_F(ConvolutionTest, Convolve_bf16_1x1x1x2_1x1x1x2_Valid) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Shape input_shape = ShapeUtil::MakeShape(BF16, {1, 1, 1, 2}); Shape filter_shape = ShapeUtil::MakeShape(BF16, {1, 1, 1, 2}); auto input = builder.Parameter(0, input_shape, "input"); auto filter = builder.Parameter(1, filter_shape, "filter"); - auto conv = builder.Conv(input, filter, {1, 1}, Padding::kValid); + builder.Conv(input, filter, {1, 1}, Padding::kValid); Array4D<bfloat16> input_data(1, 1, 1, 2); input_data.FillWithYX(Array2D<bfloat16>({ @@ -740,7 +739,7 @@ XLA_TEST_F(ConvolutionTest, Convolve_bf16_1x1x1x2_1x1x1x2_Valid) { {bfloat16(5), bfloat16(6)}, })); - ComputeAndCompare(&builder, conv, + ComputeAndCompare(&builder, {std::move(*Literal::CreateFromArray(input_data)), std::move(*Literal::CreateFromArray(filter_data))}, error_spec_); |