aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/compiler/xla/client/computation.h2
-rw-r--r--tensorflow/compiler/xla/client/computation_builder.h2
-rw-r--r--tensorflow/compiler/xla/client/lib/BUILD5
-rw-r--r--tensorflow/compiler/xla/client/lib/arithmetic.cc90
-rw-r--r--tensorflow/compiler/xla/client/lib/arithmetic.h55
-rw-r--r--tensorflow/compiler/xla/client/lib/testing.cc16
-rw-r--r--tensorflow/compiler/xla/client/lib/testing.h1
-rw-r--r--tensorflow/compiler/xla/service/BUILD10
-rw-r--r--tensorflow/compiler/xla/service/cpu/BUILD4
-rw-r--r--tensorflow/compiler/xla/service/cpu/sample_harness.cc10
-rw-r--r--tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc73
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator_test.cc10
-rw-r--r--tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc1
-rw-r--r--tensorflow/compiler/xla/service/transpose_folding_test.cc10
-rw-r--r--tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc1
-rw-r--r--tensorflow/compiler/xla/tests/BUILD18
-rw-r--r--tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc19
-rw-r--r--tensorflow/compiler/xla/tests/set_return_value_test.cc98
-rw-r--r--tensorflow/compiler/xla/tests/vector_ops_simple_test.cc3
19 files changed, 85 insertions, 343 deletions
diff --git a/tensorflow/compiler/xla/client/computation.h b/tensorflow/compiler/xla/client/computation.h
index a53fc9e9cf..9a1bcde763 100644
--- a/tensorflow/compiler/xla/client/computation.h
+++ b/tensorflow/compiler/xla/client/computation.h
@@ -30,6 +30,8 @@ namespace xla {
// Wraps a ComputationHandle protobuf with a lifetime. Computation is
// movable and not copyable to capture the same kind of unique
// ownership that std::unique_ptr represents.
+//
+// TODO(b/74197823): Deprecated. Use XlaComputation instead.
class Computation {
public:
// Creates a null Computation.
diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h
index 9431c2c459..ac1eb915cc 100644
--- a/tensorflow/compiler/xla/client/computation_builder.h
+++ b/tensorflow/compiler/xla/client/computation_builder.h
@@ -48,6 +48,8 @@ namespace xla {
// deferred from being handled until Build() is called.
//
// Thread-compatible.
+//
+// TODO(b/74197823): Deprecated. Use XlaBuilder instead.
class ComputationBuilder {
public:
// client: client in which to build the computation.
diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD
index 59c4a53c05..d49d959a6c 100644
--- a/tensorflow/compiler/xla/client/lib/BUILD
+++ b/tensorflow/compiler/xla/client/lib/BUILD
@@ -22,8 +22,6 @@ cc_library(
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/core:lib",
@@ -43,9 +41,8 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/core:lib",
diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.cc b/tensorflow/compiler/xla/client/lib/arithmetic.cc
index 63df449e0b..a1d34796cc 100644
--- a/tensorflow/compiler/xla/client/lib/arithmetic.cc
+++ b/tensorflow/compiler/xla/client/lib/arithmetic.cc
@@ -17,7 +17,8 @@ limitations under the License.
#include <string>
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
@@ -27,28 +28,6 @@ limitations under the License.
namespace xla {
namespace {
-using InstructionGenerator =
- ComputationDataHandle (*)(ComputationBuilder*, const ComputationDataHandle&,
- const ComputationDataHandle&);
-
-Computation CreateScalarComputation(const string& name, PrimitiveType type,
- ComputationBuilder* builder,
- InstructionGenerator generator) {
- std::unique_ptr<ComputationBuilder> b;
- if (type == PRED) {
- b = builder->CreateSubBuilder(name);
- } else {
- b = builder->CreateSubBuilder(
- tensorflow::strings::StrCat(name, "_", PrimitiveType_Name(type)));
- }
-
- const Shape scalar = ShapeUtil::MakeShape(type, {});
- auto lhs = b->Parameter(0, scalar, "lhs");
- auto rhs = b->Parameter(1, scalar, "rhs");
- generator(b.get(), lhs, rhs);
- return b->BuildAndNoteError();
-}
-
using XlaOpGenerator = XlaOp (*)(XlaBuilder*, const XlaOp&, const XlaOp&);
XlaComputation CreateScalarComputation(const string& name, PrimitiveType type,
@@ -71,71 +50,6 @@ XlaComputation CreateScalarComputation(const string& name, PrimitiveType type,
} // namespace
-Computation CreateScalarAddComputation(PrimitiveType type,
- ComputationBuilder* builder) {
- return CreateScalarComputation(
- "add", type, builder,
- [](ComputationBuilder* b, const ComputationDataHandle& lhs,
- const ComputationDataHandle& rhs) { return b->Add(lhs, rhs); });
-}
-
-Computation CreateScalarMultiplyComputation(PrimitiveType type,
- ComputationBuilder* builder) {
- return CreateScalarComputation(
- "mul", type, builder,
- [](ComputationBuilder* b, const ComputationDataHandle& lhs,
- const ComputationDataHandle& rhs) { return b->Mul(lhs, rhs); });
-}
-
-Computation CreateScalarGeComputation(PrimitiveType type,
- ComputationBuilder* builder) {
- return CreateScalarComputation(
- "ge", type, builder,
- [](ComputationBuilder* b, const ComputationDataHandle& lhs,
- const ComputationDataHandle& rhs) { return b->Ge(lhs, rhs); });
-}
-
-Computation CreateScalarMaxComputation(PrimitiveType type,
- ComputationBuilder* builder) {
- return CreateScalarComputation(
- "max", type, builder,
- [](ComputationBuilder* b, const ComputationDataHandle& lhs,
- const ComputationDataHandle& rhs) { return b->Max(lhs, rhs); });
-}
-
-Computation CreateScalarMinComputation(PrimitiveType type,
- ComputationBuilder* builder) {
- return CreateScalarComputation(
- "min", type, builder,
- [](ComputationBuilder* b, const ComputationDataHandle& lhs,
- const ComputationDataHandle& rhs) { return b->Min(lhs, rhs); });
-}
-
-Computation CreateScalarAndComputation(ComputationBuilder* builder) {
- return CreateScalarComputation(
- "and", PRED, builder,
- [](ComputationBuilder* b, const ComputationDataHandle& lhs,
- const ComputationDataHandle& rhs) { return b->And(lhs, rhs); });
-}
-
-Computation CreateScalarOrComputation(ComputationBuilder* builder) {
- return CreateScalarComputation(
- "or", PRED, builder,
- [](ComputationBuilder* b, const ComputationDataHandle& lhs,
- const ComputationDataHandle& rhs) { return b->Or(lhs, rhs); });
-}
-
-StatusOr<ComputationDataHandle> Any(const ComputationDataHandle& predicates,
- ComputationBuilder* builder) {
- auto f = builder->ConstantR0<bool>(false);
- Computation logical_or = CreateScalarOrComputation(builder);
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Shape> predicates_shape,
- builder->GetShape(predicates));
- std::vector<int64> all_dimensions(ShapeUtil::Rank(*predicates_shape));
- std::iota(all_dimensions.begin(), all_dimensions.end(), 0);
- return builder->Reduce(predicates, f, logical_or, all_dimensions);
-}
-
XlaComputation CreateScalarAddComputation(PrimitiveType type,
XlaBuilder* builder) {
return CreateScalarComputation(
diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.h b/tensorflow/compiler/xla/client/lib/arithmetic.h
index f4d3fc8015..64b6b7d633 100644
--- a/tensorflow/compiler/xla/client/lib/arithmetic.h
+++ b/tensorflow/compiler/xla/client/lib/arithmetic.h
@@ -18,8 +18,6 @@ limitations under the License.
#include <memory>
-#include "tensorflow/compiler/xla/client/computation.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -27,74 +25,31 @@ limitations under the License.
namespace xla {
// Creates a scalar add computation and returns it.
-Computation CreateScalarAddComputation(PrimitiveType type,
- ComputationBuilder* builder);
-
-// Creates a scalar multiply computation and returns it.
-Computation CreateScalarMultiplyComputation(PrimitiveType type,
- ComputationBuilder* builder);
-
-// Creates a scalar ge computation and returns it.
-Computation CreateScalarGeComputation(PrimitiveType type,
- ComputationBuilder* builder);
-
-// Creates a scalar max computation and returns it.
-Computation CreateScalarMaxComputation(PrimitiveType type,
- ComputationBuilder* builder);
-
-// Creates a scalar min computation and returns it.
-Computation CreateScalarMinComputation(PrimitiveType type,
- ComputationBuilder* builder);
-
-// Creates a scalar logical AND computation and returns it.
-Computation CreateScalarAndComputation(ComputationBuilder* builder);
-
-// Creates a scalar logical OR computation and returns it.
-Computation CreateScalarOrComputation(ComputationBuilder* builder);
-
-// Returns whether any predicate in "predicates" is set.
-//
-// Note: if predicates is zero-sized, Any() vacuously returns false.
-StatusOr<ComputationDataHandle> Any(const ComputationDataHandle& predicates,
- ComputationBuilder* builder);
-
-// TODO(b/74197823): This is a part of a NOT YET ready refactor.
-//
-// Creates a scalar add computation and returns it.
XlaComputation CreateScalarAddComputation(PrimitiveType type,
XlaBuilder* builder);
-// TODO(b/74197823): This is a part of a NOT YET ready refactor.
-//
+
// Creates a scalar multiply computation and returns it.
XlaComputation CreateScalarMultiplyComputation(PrimitiveType type,
XlaBuilder* builder);
-// TODO(b/74197823): This is a part of a NOT YET ready refactor.
-//
+
// Creates a scalar ge computation and returns it.
XlaComputation CreateScalarGeComputation(PrimitiveType type,
XlaBuilder* builder);
-// TODO(b/74197823): This is a part of a NOT YET ready refactor.
-//
+
// Creates a scalar max computation and returns it.
XlaComputation CreateScalarMaxComputation(PrimitiveType type,
XlaBuilder* builder);
-// TODO(b/74197823): This is a part of a NOT YET ready refactor.
-//
+
// Creates a scalar min computation and returns it.
XlaComputation CreateScalarMinComputation(PrimitiveType type,
XlaBuilder* builder);
-// TODO(b/74197823): This is a part of a NOT YET ready refactor.
-//
+
// Creates a scalar logical AND computation and returns it.
XlaComputation CreateScalarAndComputation(XlaBuilder* builder);
-// TODO(b/74197823): This is a part of a NOT YET ready refactor.
-//
// Creates a scalar logical OR computation and returns it.
XlaComputation CreateScalarOrComputation(XlaBuilder* builder);
-// TODO(b/74197823): This is a part of a NOT YET ready refactor.
-//
// Returns whether any predicate in "predicates" is set.
//
// Note: if predicates is zero-sized, Any() vacuously returns false.
diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc
index 311dc4bdd7..9cd87f7473 100644
--- a/tensorflow/compiler/xla/client/lib/testing.cc
+++ b/tensorflow/compiler/xla/client/lib/testing.cc
@@ -15,8 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/lib/testing.h"
-#include "tensorflow/compiler/xla/client/computation.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -46,16 +45,14 @@ int64 DataSizeOfShape(const Shape& shape) {
return total_size;
}
-// Create a ComputationDataHandle for an op what generates fake data with the
-// given shape.
-ComputationDataHandle BuildFakeDataOpOnDevice(const Shape& shape,
- ComputationBuilder* builder) {
+// Creates a XlaOp for an op what generates fake data with the given shape.
+XlaOp BuildFakeDataOpOnDevice(const Shape& shape, XlaBuilder* builder) {
if (ShapeUtil::IsArray(shape)) {
return builder->Broadcast(
builder->ConstantLiteral(Literal::One(shape.element_type())),
AsInt64Slice(shape.dimensions()));
}
- std::vector<ComputationDataHandle> parts;
+ std::vector<XlaOp> parts;
for (const Shape& s : shape.tuple_shapes()) {
parts.push_back(BuildFakeDataOpOnDevice(s, builder));
}
@@ -64,11 +61,10 @@ ComputationDataHandle BuildFakeDataOpOnDevice(const Shape& shape,
std::unique_ptr<GlobalData> MakeFakeDataViaDeviceOrDie(const Shape& shape,
Client* client) {
- ComputationBuilder b(
- client,
+ XlaBuilder b(
tensorflow::strings::StrCat("make_fake_", ShapeUtil::HumanString(shape)));
BuildFakeDataOpOnDevice(shape, &b);
- Computation computation = b.Build().ConsumeValueOrDie();
+ XlaComputation computation = b.Build().ConsumeValueOrDie();
auto execution_options = CreateDefaultExecutionOptions();
*execution_options.mutable_shape_with_output_layout() = shape;
diff --git a/tensorflow/compiler/xla/client/lib/testing.h b/tensorflow/compiler/xla/client/lib/testing.h
index 1dc2622972..9e06141b1f 100644
--- a/tensorflow/compiler/xla/client/lib/testing.h
+++ b/tensorflow/compiler/xla/client/lib/testing.h
@@ -20,7 +20,6 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/client/client.h"
-#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 0b8b22b44c..9c362d8cad 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -233,7 +233,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/service:hlo_element_type_converter",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
@@ -1669,10 +1669,10 @@ tf_cc_test(
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla/client",
"//tensorflow/compiler/xla/client:client_library",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:padding",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
@@ -2406,7 +2406,6 @@ tf_cc_test(
srcs = ["hlo_tfgraph_builder_test.cc"],
deps = [
":hlo_tfgraph_builder",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:protos_all_cc",
@@ -2475,7 +2474,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/service/gpu:ir_emission_utils",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -2512,6 +2511,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index cb81e413a3..7e6d58c7fa 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -365,10 +365,10 @@ tf_cc_binary(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client",
"//tensorflow/compiler/xla/client:client_library",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/core:lib",
],
)
diff --git a/tensorflow/compiler/xla/service/cpu/sample_harness.cc b/tensorflow/compiler/xla/service/cpu/sample_harness.cc
index b3f4609d46..167aa4adda 100644
--- a/tensorflow/compiler/xla/service/cpu/sample_harness.cc
+++ b/tensorflow/compiler/xla/service/cpu/sample_harness.cc
@@ -19,10 +19,10 @@ limitations under the License.
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/client/computation.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
@@ -48,13 +48,13 @@ int main(int argc, char** argv) {
client->TransferToServer(*param1_literal).ConsumeValueOrDie();
// Build computation.
- xla::ComputationBuilder builder(client, "");
+ xla::XlaBuilder builder("");
auto p0 = builder.Parameter(0, param0_literal->shape(), "param0");
auto p1 = builder.Parameter(1, param1_literal->shape(), "param1");
auto add = builder.Add(p1, p0, {0});
- xla::StatusOr<xla::Computation> computation_status = builder.Build();
- xla::Computation computation = computation_status.ConsumeValueOrDie();
+ xla::StatusOr<xla::XlaComputation> computation_status = builder.Build();
+ xla::XlaComputation computation = computation_status.ConsumeValueOrDie();
// Execute and transfer result of computation.
xla::ExecutionProfile profile;
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
index 81cc7c4bdc..16fdda8a8b 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
@@ -20,16 +20,13 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/client/computation.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/padding.h"
-#include "tensorflow/compiler/xla/service/computation_tracker.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/local_service.h"
#include "tensorflow/compiler/xla/service/service.h"
-#include "tensorflow/compiler/xla/service/user_computation.h"
-#include "tensorflow/compiler/xla/service/versioned_computation_handle.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/core/platform/logging.h"
@@ -58,11 +55,10 @@ class HloCostAnalysisTest : public ::testing::Test {
// whitebox accesses to the user computation built from the client,
// as shown in the BuildHloGraph functions below.
service_(static_cast<Service*>(ClientLibrary::GetXlaService(
- static_cast<LocalClient*>(client_)->platform()))),
- computation_tracker_(service_->computation_tracker()) {
+ static_cast<LocalClient*>(client_)->platform()))) {
// Create a computation for a unary user function: x => exp(x + 0.5)
{
- ComputationBuilder builder(client_, "add_and_exp");
+ XlaBuilder builder("add_and_exp");
auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
auto half = builder.ConstantR0<float>(0.5);
builder.Exp(builder.Add(x, half));
@@ -73,7 +69,7 @@ class HloCostAnalysisTest : public ::testing::Test {
// Create a computation for a binary user function: (x, y) => x + y
{
- ComputationBuilder builder(client_, "add");
+ XlaBuilder builder("add");
auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
builder.Add(x, y);
@@ -84,7 +80,7 @@ class HloCostAnalysisTest : public ::testing::Test {
// Create a computation for a sigmoid function: x => 1 / (1 + exp(-x))
{
- ComputationBuilder builder(client_, "sigmoid");
+ XlaBuilder builder("sigmoid");
auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
auto one = builder.ConstantR0<float>(1.0);
builder.Div(one, builder.Add(one, builder.Exp(builder.Neg(x))));
@@ -95,7 +91,7 @@ class HloCostAnalysisTest : public ::testing::Test {
// Create a computation for a binary max function: (x, y) => max (x, y)
{
- ComputationBuilder builder(client_, "max");
+ XlaBuilder builder("max");
auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
builder.Max(x, y);
@@ -106,7 +102,7 @@ class HloCostAnalysisTest : public ::testing::Test {
// Create a computation for a binary GT function: (x, y) => x > y
{
- ComputationBuilder builder(client_, "gt");
+ XlaBuilder builder("gt");
auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
builder.Gt(x, y);
@@ -117,35 +113,30 @@ class HloCostAnalysisTest : public ::testing::Test {
}
// Build HLO graph from the given builder and return the HLO module.
- std::unique_ptr<HloModule> BuildHloGraph(ComputationBuilder* builder) {
+ std::unique_ptr<HloModule> BuildHloGraph(XlaBuilder* builder) {
auto computation_status = builder->Build();
TF_CHECK_OK(computation_status.status());
auto computation = computation_status.ConsumeValueOrDie();
- auto user_computation_status =
- computation_tracker_.Resolve(computation.handle());
- TF_CHECK_OK(user_computation_status.status());
- auto user_computation = user_computation_status.ConsumeValueOrDie();
- VersionedComputationHandle versioned_handle =
- user_computation->GetVersionedHandle();
- return std::move(
- computation_tracker_.BuildHloModule(versioned_handle, HloModuleConfig())
- .ValueOrDie());
+ auto config = HloModule::CreateModuleConfigFromProto(computation.proto(),
+ DebugOptions())
+ .ConsumeValueOrDie();
+ return HloModule::CreateFromProto(computation.proto(), config)
+ .ConsumeValueOrDie();
}
Client* client_;
Service* service_;
- const ComputationTracker& computation_tracker_;
// User computations used for higher order operations (e.g., Map, Reduce).
- Computation add_;
- Computation add_and_exp_;
- Computation sigmoid_;
- Computation max_;
- Computation gt_;
+ XlaComputation add_;
+ XlaComputation add_and_exp_;
+ XlaComputation sigmoid_;
+ XlaComputation max_;
+ XlaComputation gt_;
};
TEST_F(HloCostAnalysisTest, MatrixMultiply) {
- ComputationBuilder builder(client_, "matrix_multiply");
+ XlaBuilder builder("matrix_multiply");
auto lhs = builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 5}), "lhs");
auto rhs = builder.Parameter(1, ShapeUtil::MakeShape(F32, {5, 30}), "rhs");
auto result = builder.Dot(lhs, rhs);
@@ -167,7 +158,7 @@ TEST_F(HloCostAnalysisTest, MatrixMultiply) {
}
TEST_F(HloCostAnalysisTest, Map) {
- ComputationBuilder builder(client_, "map");
+ XlaBuilder builder("map");
auto input = builder.Parameter(0, ShapeUtil::MakeShape(F32, {10}), "in");
auto result = builder.Map({input}, add_and_exp_, {0});
@@ -184,7 +175,7 @@ TEST_F(HloCostAnalysisTest, Map) {
}
TEST_F(HloCostAnalysisTest, Convolution) {
- ComputationBuilder builder(client_, "convolution");
+ XlaBuilder builder("convolution");
auto input = builder.Parameter(
0,
ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/10,
@@ -213,7 +204,7 @@ TEST_F(HloCostAnalysisTest, Convolution) {
}
TEST_F(HloCostAnalysisTest, Reduce) {
- ComputationBuilder builder(client_, "reduce");
+ XlaBuilder builder("reduce");
auto input =
builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 20}), "input");
auto result =
@@ -231,7 +222,7 @@ TEST_F(HloCostAnalysisTest, Reduce) {
}
TEST_F(HloCostAnalysisTest, ReduceWindow) {
- ComputationBuilder builder(client_, "reduce_window");
+ XlaBuilder builder("reduce_window");
auto input =
builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 20}), "input");
auto result = builder.ReduceWindow(input, builder.ConstantR0<float>(0), add_,
@@ -248,7 +239,7 @@ TEST_F(HloCostAnalysisTest, ReduceWindow) {
}
TEST_F(HloCostAnalysisTest, SelectAndScatter) {
- ComputationBuilder builder(client_, "select_and_scatter");
+ XlaBuilder builder("select_and_scatter");
auto operand =
builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 20}), "input");
auto source =
@@ -269,7 +260,7 @@ TEST_F(HloCostAnalysisTest, SelectAndScatter) {
}
TEST_F(HloCostAnalysisTest, Broadcast) {
- ComputationBuilder b(client_, "broadcast");
+ XlaBuilder b("broadcast");
b.Broadcast(b.ConstantR0<float>(42), {10, 7});
auto hlo_module = BuildHloGraph(&b);
HloCostAnalysis analysis(ShapeSize);
@@ -280,7 +271,7 @@ TEST_F(HloCostAnalysisTest, Broadcast) {
// Calculates the computation cost of a graph with more than one HLO node.
TEST_F(HloCostAnalysisTest, FullyConnectedForward) {
- ComputationBuilder builder(client_, "fully_connected_forward");
+ XlaBuilder builder("fully_connected_forward");
auto input =
builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 5}), "input");
auto weight =
@@ -305,7 +296,7 @@ TEST_F(HloCostAnalysisTest, FullyConnectedForward) {
TEST_F(HloCostAnalysisTest, MatmulAndConvolutionCanBeTheSameComputation) {
HloCostAnalysis conv_analysis(ShapeSize);
{
- ComputationBuilder builder(client_, "conv_looking_matmul");
+ XlaBuilder builder("conv_looking_matmul");
auto lhs = builder.Parameter(0, ShapeUtil::MakeShape(F32, {64, 64, 1, 1}),
"input");
auto rhs = builder.Parameter(1, ShapeUtil::MakeShape(F32, {64, 64, 1, 1}),
@@ -318,7 +309,7 @@ TEST_F(HloCostAnalysisTest, MatmulAndConvolutionCanBeTheSameComputation) {
HloCostAnalysis matmul_analysis(ShapeSize);
{
- ComputationBuilder builder(client_, "matmul");
+ XlaBuilder builder("matmul");
auto lhs =
builder.Parameter(0, ShapeUtil::MakeShape(F32, {64, 64}), "input");
auto rhs =
@@ -427,7 +418,7 @@ TEST_F(FusionCostAnalysis, NoLayout) {
TEST_F(HloCostAnalysisTest, TupleCost) {
HloCostAnalysis analysis(ShapeSize);
{
- ComputationBuilder builder(client_, "matmul");
+ XlaBuilder builder("matmul");
auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {123}), "x");
auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {42}), "y");
auto tuple = builder.Tuple({x, y});
@@ -443,7 +434,7 @@ TEST_F(HloCostAnalysisTest, TupleCost) {
}
TEST_F(HloCostAnalysisTest, BaseDilatedConvolution) {
- ComputationBuilder builder(client_, "BaseDilatedConvolution");
+ XlaBuilder builder("BaseDilatedConvolution");
auto input = builder.Parameter(
0,
ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/10,
@@ -458,7 +449,7 @@ TEST_F(HloCostAnalysisTest, BaseDilatedConvolution) {
auto result = builder.ConvGeneralDilated(
input, kernel, /*window_strides=*/{1, 1}, /*padding=*/{{1, 1}, {1, 1}},
/*lhs_dilation=*/{3, 5}, /*rhs_dilation=*/{7, 11},
- ComputationBuilder::CreateDefaultConvDimensionNumbers(2));
+ XlaBuilder::CreateDefaultConvDimensionNumbers(2));
// Run HLO cost analysis.
auto hlo_module = BuildHloGraph(&builder);
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
index 230147abfe..cc16446778 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include <utility>
#include <vector>
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -827,7 +827,7 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) {
*window.add_dimensions() = dim;
ConvolutionDimensionNumbers dnums =
- ComputationBuilder::CreateDefaultConvDimensionNumbers(2);
+ XlaBuilder::CreateDefaultConvDimensionNumbers(2);
const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4});
b.AddInstruction(HloInstruction::CreateConvolve(
@@ -1046,7 +1046,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) {
*window.add_dimensions() = dim;
ConvolutionDimensionNumbers dnums =
- ComputationBuilder::CreateDefaultConvDimensionNumbers(2);
+ XlaBuilder::CreateDefaultConvDimensionNumbers(2);
const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 7, 7});
b.AddInstruction(HloInstruction::CreateConvolve(
@@ -1109,7 +1109,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) {
*window.add_dimensions() = dim;
ConvolutionDimensionNumbers dnums =
- ComputationBuilder::CreateDefaultConvDimensionNumbers(2);
+ XlaBuilder::CreateDefaultConvDimensionNumbers(2);
const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 8, 8});
b.AddInstruction(HloInstruction::CreateConvolve(
@@ -1180,7 +1180,7 @@ TEST_P(HloEvaluatorTest,
*window.add_dimensions() = dim;
ConvolutionDimensionNumbers dnums =
- ComputationBuilder::CreateDefaultConvDimensionNumbers(2);
+ XlaBuilder::CreateDefaultConvDimensionNumbers(2);
const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 9, 3});
b.AddInstruction(HloInstruction::CreateConvolve(
diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc
index f8d98f0678..be156d765d 100644
--- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc
@@ -14,7 +14,6 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc
index c7c4160345..0319109f7f 100644
--- a/tensorflow/compiler/xla/service/transpose_folding_test.cc
+++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include <unordered_set>
#include <vector>
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -222,7 +222,7 @@ TEST_F(TransposeFoldingTest, FoldConvDimSwapTransposeRhs) {
HloInstruction* transpose_y =
builder.AddInstruction(HloInstruction::CreateTranspose(
ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), y, {1, 0, 2, 3}));
- auto dnums = ComputationBuilder::CreateDefaultConvDimensionNumbers();
+ auto dnums = XlaBuilder::CreateDefaultConvDimensionNumbers();
Window window;
for (int i = 0; i < 2; ++i) {
WindowDimension* dim = window.add_dimensions();
@@ -275,7 +275,7 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeRhs) {
HloInstruction* transpose_y =
builder.AddInstruction(HloInstruction::CreateTranspose(
ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), y, {1, 3, 0, 2}));
- auto dnums = ComputationBuilder::CreateDefaultConvDimensionNumbers();
+ auto dnums = XlaBuilder::CreateDefaultConvDimensionNumbers();
Window window;
for (int i = 0; i < 2; ++i) {
WindowDimension* dim = window.add_dimensions();
@@ -334,7 +334,7 @@ TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) {
HloInstruction* transpose_x =
builder.AddInstruction(HloInstruction::CreateTranspose(
ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), x, {1, 0, 2, 3}));
- auto dnums = ComputationBuilder::CreateDefaultConvDimensionNumbers();
+ auto dnums = XlaBuilder::CreateDefaultConvDimensionNumbers();
Window window;
for (int i = 0; i < 2; ++i) {
WindowDimension* dim = window.add_dimensions();
@@ -398,7 +398,7 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeLhs) {
HloInstruction* transpose_x =
builder.AddInstruction(HloInstruction::CreateTranspose(
ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), x, {1, 0, 3, 2}));
- auto dnums = ComputationBuilder::CreateDefaultConvDimensionNumbers();
+ auto dnums = XlaBuilder::CreateDefaultConvDimensionNumbers();
Window window;
for (int i = 0; i < 2; ++i) {
WindowDimension* dim = window.add_dimensions();
diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc
index a4e67cc9d9..f5331280ee 100644
--- a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc
+++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc
@@ -19,7 +19,6 @@ limitations under the License.
#include <unordered_set>
#include <vector>
-#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 54cf0543b8..0571ff5055 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -1934,24 +1934,6 @@ xla_test(
)
xla_test(
- name = "set_return_value_test",
- srcs = ["set_return_value_test.cc"],
- deps = [
- "//tensorflow/compiler/xla:shape_util",
- "//tensorflow/compiler/xla/client:computation_builder",
- "//tensorflow/compiler/xla/client:local_client",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
- "//tensorflow/compiler/xla/client/xla_client:xla_computation",
- "//tensorflow/compiler/xla/tests:client_library_test_base",
- "//tensorflow/compiler/xla/tests:hlo_test_base",
- "//tensorflow/compiler/xla/tests:literal_test_util",
- "//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/core:lib",
- "//tensorflow/core:test",
- ],
-)
-
-xla_test(
name = "reshape_motion_test",
srcs = ["reshape_motion_test.cc"],
deps = [
diff --git a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc
index 3704ddd801..a366afe826 100644
--- a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc
+++ b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc
@@ -21,7 +21,8 @@ limitations under the License.
#include "llvm/ADT/Triple.h"
#include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/types.h"
@@ -29,27 +30,31 @@ limitations under the License.
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
+namespace {
+
using xla::string;
-xla::Computation Doubler(xla::Client* client) {
- xla::ComputationBuilder builder(client, "doubler");
+xla::XlaComputation Doubler() {
+ xla::XlaBuilder builder("doubler");
auto r0f32 = xla::ShapeUtil::MakeShape(xla::F32, {});
auto x = builder.Parameter(0, r0f32, "x");
builder.Mul(x, builder.ConstantR0<float>(2.0));
return std::move(builder.Build().ValueOrDie());
}
+} // namespace
+
int main(int argc, char** argv) {
tensorflow::port::InitMain(argv[0], &argc, &argv);
auto client = xla::ClientLibrary::GetOrCreateCompileOnlyClient().ValueOrDie();
- xla::ComputationBuilder builder(client, "aot_test_helper");
+ xla::XlaBuilder builder("aot_test_helper");
auto opaque_shape = xla::ShapeUtil::MakeOpaqueShape();
auto opaque_param = builder.Parameter(0, opaque_shape, "x");
auto r0f32 = xla::ShapeUtil::MakeShape(xla::F32, {});
auto sum = builder.CustomCall("SumStructElements", {opaque_param}, r0f32);
- builder.Call(Doubler(client), {sum});
+ builder.Call(Doubler(), {sum});
if (argc != 2) {
LOG(FATAL) << "local_client_aot_test_helper TARGET_CPU";
@@ -71,8 +76,8 @@ int main(int argc, char** argv) {
llvm::Triple triple(xla::llvm_ir::AsStringRef(triple_string));
- xla::Computation computation = builder.Build().ConsumeValueOrDie();
- xla::CompileOnlyClient::AotComputationInstance instance{
+ xla::XlaComputation computation = builder.Build().ConsumeValueOrDie();
+ xla::CompileOnlyClient::AotXlaComputationInstance instance{
&computation, /*argument_layouts=*/{&opaque_shape}, &r0f32};
xla::cpu::CpuAotCompilationOptions options(
diff --git a/tensorflow/compiler/xla/tests/set_return_value_test.cc b/tensorflow/compiler/xla/tests/set_return_value_test.cc
deleted file mode 100644
index 29f79ec28a..0000000000
--- a/tensorflow/compiler/xla/tests/set_return_value_test.cc
+++ /dev/null
@@ -1,98 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include <vector>
-
-#include "tensorflow/compiler/xla/client/computation_builder.h"
-#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
-#include "tensorflow/compiler/xla/tests/literal_test_util.h"
-#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/platform/test.h"
-
-namespace xla {
-namespace {
-
-class SetReturnValueTest : public ClientLibraryTestBase {};
-
-TEST_F(SetReturnValueTest, NoSetValue) {
- ComputationBuilder builder(client_, "no_set_value");
- auto alpha = builder.ConstantR0<float>(1.0);
- auto x = builder.ConstantR1<float>(
- {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0});
- auto ax = builder.Add(alpha, x);
- auto aax = builder.Add(alpha, ax);
-
- std::vector<float> expected = {1.0, 3.0, 4.0, 0.0, -1.0,
- 5.0, 6.0, -2.0, -3.0, 7.0};
-
- ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
-}
-
-TEST_F(SetReturnValueTest, SetValue) {
- ComputationBuilder builder(client_, "set_value");
- auto alpha = builder.ConstantR0<float>(1.0);
- auto x = builder.ConstantR1<float>(
- {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0});
- auto ax = builder.Add(alpha, x);
- auto aax = builder.Add(alpha, ax);
- auto builder_status = builder.SetReturnValue(ax);
- EXPECT_TRUE(builder_status.ok());
-
- std::vector<float> expected = {0.0, 2.0, 3.0, -1.0, -2.0,
- 4.0, 5.0, -3.0, -4.0, 6.0};
-
- ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
-}
-
-TEST_F(SetReturnValueTest, SetValueAndModify) {
- ComputationBuilder builder(client_, "set_value_and_modify");
- auto alpha = builder.ConstantR0<float>(1.0);
- auto x = builder.ConstantR1<float>(
- {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0});
- auto ax = builder.Add(alpha, x);
- auto aax = builder.Add(alpha, ax);
- auto builder_status = builder.SetReturnValue(ax);
- EXPECT_TRUE(builder_status.ok());
- auto aaax = builder.Add(alpha, aax);
-
- std::vector<float> expected = {0.0, 2.0, 3.0, -1.0, -2.0,
- 4.0, 5.0, -3.0, -4.0, 6.0};
-
- ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
-}
-
-TEST_F(SetReturnValueTest, SetValueMultipleTimesAndModify) {
- ComputationBuilder builder(client_, "set_value_multiple_times_and_modify");
- auto alpha = builder.ConstantR0<float>(1.0);
- auto x = builder.ConstantR1<float>(
- {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0});
- auto ax = builder.Add(alpha, x);
- auto aax = builder.Add(alpha, ax);
- auto builder_status = builder.SetReturnValue(aax);
- EXPECT_TRUE(builder_status.ok());
- auto aaax = builder.Add(alpha, aax);
- builder_status = builder.SetReturnValue(ax);
- EXPECT_TRUE(builder_status.ok());
- auto aaaax = builder.Add(alpha, aaax);
-
- std::vector<float> expected = {0.0, 2.0, 3.0, -1.0, -2.0,
- 4.0, 5.0, -3.0, -4.0, 6.0};
-
- ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
-}
-
-} // namespace
-} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc
index 3dded3f715..5cce7a2bf8 100644
--- a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc
+++ b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc
@@ -18,7 +18,6 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/array4d.h"
-#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/local_client.h"
@@ -350,7 +349,7 @@ XLA_TEST_F(VecOpsSimpleTest, ClampTenValuesConstantNonzeroLower) {
}
XLA_TEST_F(VecOpsSimpleTest, ClampValuesConstantS64) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto zero = builder.ConstantR0<int64>(0);
auto one = builder.ConstantR0<int64>(10);
auto x = builder.ConstantR1<int64>({-3, 3, 9, 13});