aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/compiler/xla/client/xla_client/xla_builder.cc170
-rw-r--r--tensorflow/compiler/xla/client/xla_client/xla_builder.h14
-rw-r--r--tensorflow/compiler/xla/tests/BUILD2
-rw-r--r--tensorflow/compiler/xla/tests/convolution_test.cc61
4 files changed, 210 insertions, 37 deletions
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
index 7481b357ff..9e4b9ccd25 100644
--- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
@@ -790,24 +790,101 @@ XlaOp XlaBuilder::DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
});
}
+Status XlaBuilder::VerifyConvolution(
+ const Shape& lhs_shape, const Shape& rhs_shape,
+ const ConvolutionDimensionNumbers& dimension_numbers) const {
+ if (ShapeUtil::Rank(lhs_shape) != ShapeUtil::Rank(rhs_shape)) {
+ return InvalidArgument(
+ "Convolution arguments must have same number of "
+ "dimensions. Got: %s and %s",
+ ShapeUtil::HumanString(lhs_shape).c_str(),
+ ShapeUtil::HumanString(rhs_shape).c_str());
+ }
+ int num_dims = ShapeUtil::Rank(lhs_shape);
+ if (num_dims < 2) {
+ return InvalidArgument(
+ "Convolution expects argument arrays with >= 2 dimensions. "
+ "Got: %s and %s",
+ ShapeUtil::HumanString(lhs_shape).c_str(),
+ ShapeUtil::HumanString(rhs_shape).c_str());
+ }
+ int num_spatial_dims = num_dims - 2;
+
+ const auto check_spatial_dimensions =
+ [&](const char* const field_name,
+ const tensorflow::protobuf::RepeatedField<tensorflow::protobuf_int64>&
+ numbers) {
+ if (numbers.size() != num_spatial_dims) {
+ return InvalidArgument("Expected %d elements for %s, but got %d.",
+ num_spatial_dims, field_name, numbers.size());
+ }
+ for (int i = 0; i < numbers.size(); ++i) {
+ if (numbers.Get(i) < 0 || numbers.Get(i) >= num_dims) {
+ return InvalidArgument("Convolution %s[%d] is out of bounds: %lld",
+ field_name, i, numbers.Get(i));
+ }
+ }
+ return Status::OK();
+ };
+ TF_RETURN_IF_ERROR(
+ check_spatial_dimensions("input_spatial_dimensions",
+ dimension_numbers.input_spatial_dimensions()));
+ TF_RETURN_IF_ERROR(
+ check_spatial_dimensions("kernel_spatial_dimensions",
+ dimension_numbers.kernel_spatial_dimensions()));
+ return check_spatial_dimensions(
+ "output_spatial_dimensions",
+ dimension_numbers.output_spatial_dimensions());
+}
+
XlaOp XlaBuilder::Conv(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides,
Padding padding) {
- return UnimplementedOp();
+ return ConvWithGeneralDimensions(
+ lhs, rhs, window_strides, padding,
+ CreateDefaultConvDimensionNumbers(window_strides.size()));
}
XlaOp XlaBuilder::ConvWithGeneralPadding(
const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides,
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding) {
- return UnimplementedOp();
+ return ConvGeneral(lhs, rhs, window_strides, padding,
+ CreateDefaultConvDimensionNumbers(window_strides.size()));
}
XlaOp XlaBuilder::ConvWithGeneralDimensions(
const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
const ConvolutionDimensionNumbers& dimension_numbers) {
- return UnimplementedOp();
+ return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
+ TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs));
+
+ TF_RETURN_IF_ERROR(
+ VerifyConvolution(lhs_shape, rhs_shape, dimension_numbers));
+
+ std::vector<int64> base_area_dimensions(
+ dimension_numbers.input_spatial_dimensions_size());
+ for (std::vector<int64>::size_type i = 0; i < base_area_dimensions.size();
+ ++i) {
+ base_area_dimensions[i] =
+ lhs_shape.dimensions(dimension_numbers.input_spatial_dimensions(i));
+ }
+
+ std::vector<int64> window_dimensions(
+ dimension_numbers.kernel_spatial_dimensions_size());
+ for (std::vector<int64>::size_type i = 0; i < window_dimensions.size();
+ ++i) {
+ window_dimensions[i] =
+ rhs_shape.dimensions(dimension_numbers.kernel_spatial_dimensions(i));
+ }
+
+ return ConvGeneral(lhs, rhs, window_strides,
+ MakePadding(base_area_dimensions, window_dimensions,
+ window_strides, padding),
+ dimension_numbers);
+ });
}
XlaOp XlaBuilder::ConvGeneral(
@@ -815,7 +892,8 @@ XlaOp XlaBuilder::ConvGeneral(
tensorflow::gtl::ArraySlice<int64> window_strides,
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
const ConvolutionDimensionNumbers& dimension_numbers) {
- return UnimplementedOp();
+ return ConvGeneralDilated(lhs, rhs, window_strides, padding, {}, {},
+ dimension_numbers);
}
XlaOp XlaBuilder::ConvGeneralDilated(
@@ -825,7 +903,89 @@ XlaOp XlaBuilder::ConvGeneralDilated(
tensorflow::gtl::ArraySlice<int64> lhs_dilation,
tensorflow::gtl::ArraySlice<int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers) {
- return UnimplementedOp();
+ return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ HloInstructionProto instr;
+ TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
+ TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs));
+ TF_RETURN_IF_ERROR(
+ VerifyConvolution(lhs_shape, rhs_shape, dimension_numbers));
+
+ std::vector<int64> window_dimensions(
+ dimension_numbers.kernel_spatial_dimensions_size());
+ for (std::vector<int64>::size_type i = 0; i < window_dimensions.size();
+ ++i) {
+ window_dimensions[i] =
+ rhs_shape.dimensions(dimension_numbers.kernel_spatial_dimensions(i));
+ }
+ TF_ASSIGN_OR_RETURN(*instr.mutable_window(),
+ MakeWindow(window_dimensions, window_strides, padding,
+ lhs_dilation, rhs_dilation));
+
+ TF_ASSIGN_OR_RETURN(
+ *instr.mutable_shape(),
+ ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, instr.window(),
+ dimension_numbers));
+
+ *instr.mutable_convolution_dimension_numbers() = dimension_numbers;
+
+ return AddInstruction(std::move(instr), HloOpcode::kConvolution,
+ {lhs, rhs});
+ });
+}
+
+StatusOr<Window> XlaBuilder::MakeWindow(
+ tensorflow::gtl::ArraySlice<int64> window_dimensions,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+ tensorflow::gtl::ArraySlice<int64> lhs_dilation,
+ tensorflow::gtl::ArraySlice<int64> rhs_dilation) const {
+ const auto verify_size = [&](const size_t x, const char* x_name) {
+ if (x == 0 || x == window_dimensions.size()) {
+ return Status::OK();
+ } else {
+ return InvalidArgument(
+ "%s", tensorflow::strings::StrCat(
+ "Window has different number of window dimensions than of ",
+ x_name,
+ "\nNumber of window dimensions: ", window_dimensions.size(),
+ "\nNumber of ", x_name, ": ", x, "\n")
+ .c_str());
+ }
+ };
+ TF_RETURN_IF_ERROR(verify_size(window_strides.size(), "window strides"));
+ TF_RETURN_IF_ERROR(verify_size(padding.size(), "padding entries"));
+ TF_RETURN_IF_ERROR(verify_size(lhs_dilation.size(), "lhs dilation factors"));
+ TF_RETURN_IF_ERROR(verify_size(rhs_dilation.size(), "rhs dilation factors"));
+
+ Window window;
+ for (size_t i = 0; i < window_dimensions.size(); i++) {
+ auto dim = window.add_dimensions();
+ dim->set_size(window_dimensions[i]);
+ if (!window_strides.empty()) {
+ dim->set_stride(window_strides[i]);
+ } else {
+ dim->set_stride(1);
+ }
+ if (!padding.empty()) {
+ dim->set_padding_low(padding[i].first);
+ dim->set_padding_high(padding[i].second);
+ } else {
+ dim->set_padding_low(0);
+ dim->set_padding_high(0);
+ }
+ if (!lhs_dilation.empty()) {
+ dim->set_base_dilation(lhs_dilation[i]);
+ } else {
+ dim->set_base_dilation(1);
+ }
+ if (!rhs_dilation.empty()) {
+ dim->set_window_dilation(rhs_dilation[i]);
+ } else {
+ dim->set_window_dilation(1);
+ }
+ dim->set_window_reversal(false);
+ }
+ return window;
}
XlaOp XlaBuilder::Fft(const XlaOp& operand, const FftType fft_type,
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.h b/tensorflow/compiler/xla/client/xla_client/xla_builder.h
index d747691f16..24e0be2ac1 100644
--- a/tensorflow/compiler/xla/client/xla_client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.h
@@ -835,6 +835,20 @@ class XlaBuilder {
void IsConstantVisitor(const int64 op_handle, std::set<int64>* visited,
bool* is_constant) const;
+ // Checks bounds for convolution parameters.
+ Status VerifyConvolution(
+ const Shape& lhs_shape, const Shape& rhs_shape,
+ const ConvolutionDimensionNumbers& dimension_numbers) const;
+
+ // Helper function for creating a Window proto from user-supplied data.
+ // Returns error if the user-supplied data was invalid.
+ StatusOr<Window> MakeWindow(
+ tensorflow::gtl::ArraySlice<int64> window_dimensions,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+ tensorflow::gtl::ArraySlice<int64> lhs_dilation,
+ tensorflow::gtl::ArraySlice<int64> rhs_dilation) const;
+
string name_; // Name to use for the built computation.
// The first error encountered while building the computation.
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 19fb4886db..67c53c6ac0 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -781,10 +781,10 @@ xla_test(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:padding",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc
index 72715398de..5eb3136abe 100644
--- a/tensorflow/compiler/xla/tests/convolution_test.cc
+++ b/tensorflow/compiler/xla/tests/convolution_test.cc
@@ -20,10 +20,10 @@ limitations under the License.
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array4d.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/padding.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
@@ -88,12 +88,12 @@ class ForwardPassConvolution_3x3x256_256_OutputZ_Iota : public ConvolutionTest {
ASSERT_EQ(2, arhs->width());
ASSERT_EQ(2, arhs->height());
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto lhs = builder.ConstantR4FromArray4D<T>(*alhs);
auto rhs = builder.ConstantR4FromArray4D<T>(*arhs);
- auto conv = builder.Conv(lhs, rhs, {1, 1}, Padding::kValid);
+ builder.Conv(lhs, rhs, {1, 1}, Padding::kValid);
- ComputeAndCompare(&builder, conv, {}, error_spec_);
+ ComputeAndCompare(&builder, {}, error_spec_);
}
};
@@ -106,12 +106,12 @@ template <typename T>
class Convolve_1x1x1x2_1x1x1x2_Valid : public ConvolutionTest {
public:
void RunTest() {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
Shape input_shape = ShapeUtil::MakeShapeWithType<T>({1, 1, 1, 2});
Shape filter_shape = ShapeUtil::MakeShapeWithType<T>({1, 1, 1, 2});
auto input = builder.Parameter(0, input_shape, "input");
auto filter = builder.Parameter(1, filter_shape, "filter");
- auto conv = builder.Conv(input, filter, {1, 1}, Padding::kValid);
+ builder.Conv(input, filter, {1, 1}, Padding::kValid);
Array4D<T> input_data(1, 1, 1, 2);
input_data.FillWithYX(Array2D<T>({
@@ -122,7 +122,7 @@ class Convolve_1x1x1x2_1x1x1x2_Valid : public ConvolutionTest {
{5.0f, 6.0f},
}));
- ComputeAndCompare(&builder, conv,
+ ComputeAndCompare(&builder,
{std::move(*Literal::CreateFromArray(input_data)),
std::move(*Literal::CreateFromArray(filter_data))},
error_spec_);
@@ -137,12 +137,12 @@ template <typename T>
class Convolve_1x1x4x4_1x1x2x2_Valid : public ConvolutionTest {
public:
void RunTest() {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
Shape input_shape = ShapeUtil::MakeShapeWithType<T>({1, 1, 4, 4});
Shape filter_shape = ShapeUtil::MakeShapeWithType<T>({1, 1, 2, 2});
auto input = builder.Parameter(0, input_shape, "input");
auto filter = builder.Parameter(1, filter_shape, "filter");
- auto conv = builder.Conv(input, filter, {1, 1}, Padding::kValid);
+ builder.Conv(input, filter, {1, 1}, Padding::kValid);
Array4D<T> input_data(1, 1, 4, 4);
input_data.FillWithYX(Array2D<T>({
@@ -156,7 +156,7 @@ class Convolve_1x1x4x4_1x1x2x2_Valid : public ConvolutionTest {
{5.0f, 6.0f},
{7.0f, 8.0f},
}));
- ComputeAndCompare(&builder, conv,
+ ComputeAndCompare(&builder,
{std::move(*Literal::CreateFromArray(input_data)),
std::move(*Literal::CreateFromArray(filter_data))},
error_spec_);
@@ -171,12 +171,12 @@ template <typename T>
class Convolve_1x1x4x4_1x1x2x2_Same : public ConvolutionTest {
public:
void RunTest() {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
Shape input_shape = ShapeUtil::MakeShapeWithType<T>({1, 1, 4, 4});
Shape filter_shape = ShapeUtil::MakeShapeWithType<T>({1, 1, 2, 2});
auto input = builder.Parameter(0, input_shape, "input");
auto filter = builder.Parameter(1, filter_shape, "filter");
- auto conv = builder.Conv(input, filter, {1, 1}, Padding::kSame);
+ builder.Conv(input, filter, {1, 1}, Padding::kSame);
Array4D<T> input_data(1, 1, 4, 4);
input_data.FillWithYX(Array2D<T>({
@@ -191,7 +191,7 @@ class Convolve_1x1x4x4_1x1x2x2_Same : public ConvolutionTest {
{7.0f, 8.0f},
}));
- ComputeAndCompare(&builder, conv,
+ ComputeAndCompare(&builder,
{std::move(*Literal::CreateFromArray(input_data)),
std::move(*Literal::CreateFromArray(filter_data))},
error_spec_);
@@ -207,12 +207,12 @@ template <typename T>
class Convolve_1x1x4x4_1x1x3x3_Same : public ConvolutionTest {
public:
void RunTest() {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
Shape input_shape = ShapeUtil::MakeShapeWithType<T>({1, 1, 4, 4});
Shape filter_shape = ShapeUtil::MakeShapeWithType<T>({1, 1, 3, 3});
auto input = builder.Parameter(0, input_shape, "input");
auto filter = builder.Parameter(1, filter_shape, "filter");
- auto conv = builder.Conv(input, filter, {1, 1}, Padding::kSame);
+ builder.Conv(input, filter, {1, 1}, Padding::kSame);
Array4D<T> input_data(1, 1, 4, 4);
input_data.FillWithYX(Array2D<T>({{1.0f, 2.0f, 3.0f, 4.0f},
@@ -223,7 +223,7 @@ class Convolve_1x1x4x4_1x1x3x3_Same : public ConvolutionTest {
filter_data.FillWithYX(Array2D<T>(
{{5.0f, 6.0f, 7.0f}, {8.0f, 9.0f, 10.0f}, {11.0f, 12.0f, 13.0f}}));
// clang-format on
- ComputeAndCompare(&builder, conv,
+ ComputeAndCompare(&builder,
{std::move(*Literal::CreateFromArray(input_data)),
std::move(*Literal::CreateFromArray(filter_data))},
error_spec_);
@@ -234,7 +234,7 @@ TYPED_TEST_CASE(Convolve_1x1x4x4_1x1x3x3_Same, TestTypes);
TYPED_TEST(Convolve_1x1x4x4_1x1x3x3_Same, Types) { this->RunTest(); }
XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_Valid) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
{
Shape input_shape = ShapeUtil::MakeShape(F32, {1, 2, 5});
Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});
@@ -264,7 +264,7 @@ template <typename T>
class Convolve1D_1x2x5_1x2x2_WithRHSDilation : public ConvolutionTest {
public:
void RunTest() {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
{
Shape input_shape = ShapeUtil::MakeShapeWithType<T>({1, 2, 5});
Shape filter_shape = ShapeUtil::MakeShapeWithType<T>({1, 2, 2});
@@ -300,7 +300,7 @@ TYPED_TEST_CASE(Convolve1D_1x2x5_1x2x2_WithRHSDilation, TestTypes);
TYPED_TEST(Convolve1D_1x2x5_1x2x2_WithRHSDilation, Types) { this->RunTest(); }
XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSDilation) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
{
Shape input_shape = ShapeUtil::MakeShape(F32, {1, 2, 5});
Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});
@@ -331,7 +331,7 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSDilation) {
}
XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSAndRHSDilation) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
{
Shape input_shape = ShapeUtil::MakeShape(F32, {1, 2, 5});
Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});
@@ -365,7 +365,7 @@ template <typename T>
class Convolve1D_1x2x5_1x2x2_WithPadding : public ConvolutionTest {
public:
void RunTest() {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
{
Shape input_shape = ShapeUtil::MakeShapeWithType<T>({1, 2, 5});
Shape filter_shape = ShapeUtil::MakeShapeWithType<T>({1, 2, 2});
@@ -402,7 +402,7 @@ TYPED_TEST_CASE(Convolve1D_1x2x5_1x2x2_WithPadding, TestTypes);
TYPED_TEST(Convolve1D_1x2x5_1x2x2_WithPadding, Types) { this->RunTest(); }
XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::vector<int64> input_dims = {1, 4, 2, 3, 3};
std::vector<int64> filter_dims = {2, 2, 2, 3, 3};
Shape input_shape = ShapeUtil::MakeShape(F32, input_dims);
@@ -469,7 +469,7 @@ template <typename T>
class Convolve2D_1x3x3x5_3x3x5x5_Valid : public ConvolutionTest {
public:
void RunTest() {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::vector<int64> input_dims = {1, 3, 3, 5};
std::vector<int64> filter_dims = {3, 3, 5, 3};
Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
@@ -537,7 +537,7 @@ XLA_TEST_P(ConvolveWithAndWithoutCanonicalization,
execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes(
"convolution-canonicalization");
}
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
Shape input_shape = ShapeUtil::MakeShape(F32, {4, 29});
Shape filter_shape = ShapeUtil::MakeShape(F32, {4, 10});
@@ -551,8 +551,7 @@ XLA_TEST_P(ConvolveWithAndWithoutCanonicalization,
dnums.set_kernel_output_feature_dimension(1);
dnums.set_output_batch_dimension(0);
dnums.set_output_feature_dimension(1);
- auto conv = builder.ConvWithGeneralDimensions(input, filter, {},
- Padding::kValid, dnums);
+ builder.ConvWithGeneralDimensions(input, filter, {}, Padding::kValid, dnums);
Array2D<float> param0(4, 29);
param0.FillUnique();
@@ -563,7 +562,7 @@ XLA_TEST_P(ConvolveWithAndWithoutCanonicalization,
Array2D<float> expected_result(29, 10);
expected_result.Fill(0);
- ComputeAndCompare(&builder, conv,
+ ComputeAndCompare(&builder,
{std::move(*Literal::CreateFromArray(param0)),
std::move(*Literal::CreateFromArray(param1))},
error_spec_);
@@ -587,7 +586,7 @@ class Convolve1D1WindowTestBase
protected:
template <typename T>
void TestImpl() {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
int64 input_feature = GetParam().input_feature;
int64 output_feature = GetParam().output_feature;
int64 batch = GetParam().batch;
@@ -724,12 +723,12 @@ INSTANTIATE_TEST_CASE_P(
#endif
XLA_TEST_F(ConvolutionTest, Convolve_bf16_1x1x1x2_1x1x1x2_Valid) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
Shape input_shape = ShapeUtil::MakeShape(BF16, {1, 1, 1, 2});
Shape filter_shape = ShapeUtil::MakeShape(BF16, {1, 1, 1, 2});
auto input = builder.Parameter(0, input_shape, "input");
auto filter = builder.Parameter(1, filter_shape, "filter");
- auto conv = builder.Conv(input, filter, {1, 1}, Padding::kValid);
+ builder.Conv(input, filter, {1, 1}, Padding::kValid);
Array4D<bfloat16> input_data(1, 1, 1, 2);
input_data.FillWithYX(Array2D<bfloat16>({
@@ -740,7 +739,7 @@ XLA_TEST_F(ConvolutionTest, Convolve_bf16_1x1x1x2_1x1x1x2_Valid) {
{bfloat16(5), bfloat16(6)},
}));
- ComputeAndCompare(&builder, conv,
+ ComputeAndCompare(&builder,
{std::move(*Literal::CreateFromArray(input_data)),
std::move(*Literal::CreateFromArray(filter_data))},
error_spec_);