-rw-r--r--  tensorflow/compiler/tf2xla/xla_op_kernel.cc              | 13
-rw-r--r--  tensorflow/compiler/xla/client/computation_builder.cc    | 63
-rw-r--r--  tensorflow/compiler/xla/client/computation_builder.h     |  8
-rw-r--r--  tensorflow/compiler/xla/service/BUILD                    |  1
-rw-r--r--  tensorflow/compiler/xla/service/compile_only_service.cc  | 15
-rw-r--r--  tensorflow/compiler/xla/service/compile_only_service.h   |  5
-rw-r--r--  tensorflow/compiler/xla/service/local_service.cc         | 17
-rw-r--r--  tensorflow/compiler/xla/service/local_service.h          |  3
-rw-r--r--  tensorflow/compiler/xla/service/service.cc               | 68
-rw-r--r--  tensorflow/compiler/xla/service/service.h                | 21
-rw-r--r--  tensorflow/compiler/xla/tests/compute_constant_test.cc   | 33
-rw-r--r--  tensorflow/compiler/xla/xla.proto                        |  4
12 files changed, 90 insertions(+), 161 deletions(-)
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
index c5a68e05d9..e36eafa6e4 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
@@ -119,22 +119,15 @@ Status XlaOpKernelContext::ConstantInputReshaped(
xla::Layout layout = xla::LayoutUtil::MakeLayout(layout_indices);
// Ask the XLA compiler to evaluate the data handle to a literal.
- xla::StatusOr<std::unique_ptr<xla::GlobalData>> computed =
+ xla::StatusOr<std::unique_ptr<xla::Literal>> computed =
builder()->ComputeConstant(handle, &layout);
if (!computed.ok()) {
return errors::InvalidArgument(
"Error evaluating ", context_->op_kernel().name(), " input ", index,
": ", computed.status().error_message());
}
- // Fetch the literal from the compiler service.
- xla::StatusOr<std::unique_ptr<xla::Literal>> constant =
- builder()->client()->Transfer(*computed.ValueOrDie());
- if (!constant.ok()) {
- return errors::InvalidArgument(
- "Error evaluating ", context_->op_kernel().name(), " input ", index,
- ": ", constant.status().error_message());
- }
- constant_literal->Swap(constant.ValueOrDie().get());
+ constant_literal->Swap(computed.ValueOrDie().get());
+
return Status::OK();
}
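
The net effect of this first hunk: the tf2xla bridge now receives the evaluated
literal from a single ComputeConstant call, with no GlobalData handle to
Transfer back from the service afterwards. A minimal sketch of the new calling
pattern, assuming a ComputationBuilder* builder and a handle denoting a
constant subexpression (both hypothetical stand-ins, not part of this patch):

    // Post-patch: ComputeConstant evaluates straight to a Literal.
    xla::StatusOr<std::unique_ptr<xla::Literal>> computed =
        builder->ComputeConstant(handle, /*output_layout=*/nullptr);
    if (!computed.ok()) {
      // One failure path instead of two: previously ComputeConstant and the
      // follow-up client->Transfer() could each fail separately.
      return computed.status();
    }
    std::unique_ptr<xla::Literal> literal = computed.ConsumeValueOrDie();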
diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc
index cbecab5037..e6ffc4f98d 100644
--- a/tensorflow/compiler/xla/client/computation_builder.cc
+++ b/tensorflow/compiler/xla/client/computation_builder.cc
@@ -111,13 +111,12 @@ bool ComputationBuilder::MakeWindow(
return true;
} else {
NoteError(InvalidArgument(
- "%s",
- tensorflow::strings::StrCat(
- "Window has different number of window dimensions than of ",
- x_name, "\nNumber of window dimensions: ",
- window_dimensions.size(), "\nNumber of ", x_name, ": ", x,
- "\n")
- .c_str())); //
+ "%s", tensorflow::strings::StrCat(
+ "Window has different number of window dimensions than of ",
+ x_name, "\nNumber of window dimensions: ",
+ window_dimensions.size(), "\nNumber of ", x_name, ": ", x,
+ "\n")
+ .c_str())); //
return false;
}
};
@@ -663,24 +662,26 @@ bool ComputationBuilder::VerifyConvolution(
}
int num_spatial_dims = num_dims - 2;
- const auto check_spatial_dimensions = [&](
- const char* const field_name,
- const tensorflow::protobuf::RepeatedField<tensorflow::protobuf_int64>&
- numbers) {
- if (numbers.size() != num_spatial_dims) {
- NoteError(InvalidArgument("Expected %d elements for %s, but got %d.",
- num_spatial_dims, field_name, numbers.size()));
- return false;
- }
- for (int i = 0; i < numbers.size(); ++i) {
- if (numbers.Get(i) < 0 || numbers.Get(i) >= num_dims) {
- NoteError(InvalidArgument("Convolution %s[%d] is out of bounds: %lld",
- field_name, i, numbers.Get(i)));
- return false;
- }
- }
- return true;
- };
+ const auto check_spatial_dimensions =
+ [&](const char* const field_name,
+ const tensorflow::protobuf::RepeatedField<tensorflow::protobuf_int64>&
+ numbers) {
+ if (numbers.size() != num_spatial_dims) {
+ NoteError(InvalidArgument("Expected %d elements for %s, but got %d.",
+ num_spatial_dims, field_name,
+ numbers.size()));
+ return false;
+ }
+ for (int i = 0; i < numbers.size(); ++i) {
+ if (numbers.Get(i) < 0 || numbers.Get(i) >= num_dims) {
+ NoteError(
+ InvalidArgument("Convolution %s[%d] is out of bounds: %lld",
+ field_name, i, numbers.Get(i)));
+ return false;
+ }
+ }
+ return true;
+ };
return check_spatial_dimensions("spatial_dimensions",
dimension_numbers.spatial_dimensions()) &&
check_spatial_dimensions(
@@ -1268,7 +1269,7 @@ StatusOr<bool> ComputationBuilder::IsConstant(
return response.is_constant();
}
-StatusOr<std::unique_ptr<GlobalData>> ComputationBuilder::ComputeConstant(
+StatusOr<std::unique_ptr<Literal>> ComputationBuilder::ComputeConstant(
const ComputationDataHandle& operand, const Layout* output_layout) {
if (!first_error_.ok()) {
return first_error_;
@@ -1291,8 +1292,14 @@ StatusOr<std::unique_ptr<GlobalData>> ComputationBuilder::ComputeConstant(
return s;
}
- TF_RET_CHECK(response.output().handle() != 0);
- return MakeUnique<GlobalData>(client_->stub(), response.output());
+ VLOG(3) << "ComputeConstant: {" << response.DebugString() << "}";
+
+ if (!response.has_literal()) {
+ return InternalError(
+ "no computed literal in the provided response in ComputeConstant "
+ "request");
+ }
+ return MakeUnique<Literal>(response.literal());
}
ComputationDataHandle ComputationBuilder::Map(
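
On the wire, the response now carries the literal inline, so the builder
materializes it directly instead of wrapping a handle into GlobalData. A hedged
sketch of that reconstruction, mirroring the guard added above (the response
value is assumed to have already been received from the service):

    // Rebuild the client-side Literal from the response proto.
    if (!response.has_literal()) {
      // Same failure mode the new InternalError above reports.
      return InternalError(
          "no computed literal in the response to the ComputeConstant request");
    }
    std::unique_ptr<Literal> literal = MakeUnique<Literal>(response.literal());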
diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h
index e35ed6186a..cf1f3b074e 100644
--- a/tensorflow/compiler/xla/client/computation_builder.h
+++ b/tensorflow/compiler/xla/client/computation_builder.h
@@ -679,12 +679,12 @@ class ComputationBuilder {
// Computes the value of a constant indicated by a
// ComputationDataHandle.
//
- // The handle must be from the computation currently being built -
+ // The operand must be from the computation currently being built -
// i.e., returned from this builder with no intervening call to
// Build(). This happens to currently work regardless of that, but
// that may stop working at any time.
//
- // The handle must represent a constant value, which in this case
+ // The operand must represent a constant value, which in this case
// means that it must not statically depend on a parameter to the
// computation that is being built.
//
@@ -702,8 +702,8 @@ class ComputationBuilder {
//
// If output_layout is non-null, then the output of the computation
// will be stored using that layout.
- StatusOr<std::unique_ptr<GlobalData>> ComputeConstant(
- const ComputationDataHandle& handle,
+ StatusOr<std::unique_ptr<Literal>> ComputeConstant(
+ const ComputationDataHandle& operand,
const Layout* output_layout = nullptr);
// Returns a new ComputationBuilder whose resultant Computation is used only
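
Under the contract documented above, typical client-side use now looks like
this sketch (builder setup follows the pattern in compute_constant_test.cc
further down; the client object is assumed):

    xla::ComputationBuilder b(client, "compute_constant_example");
    // Any expression with no parameter dependence qualifies as a constant.
    auto sum = b.Add(b.ConstantR0<int32>(40), b.ConstantR0<int32>(2));
    // Evaluates at build time and returns the Literal directly.
    std::unique_ptr<xla::Literal> literal =
        b.ComputeConstant(sum).ConsumeValueOrDie();  // holds the int32 42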
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index ab41dd3654..0da4e14de0 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -428,6 +428,7 @@ cc_library(
":gpu_transfer_manager",
":hlo",
":hlo_cost_analysis",
+ ":hlo_evaluator",
":hlo_execution_profile",
":hlo_module_config",
":platform_util",
diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc
index 1dfe4a73b3..62dab56a71 100644
--- a/tensorflow/compiler/xla/service/compile_only_service.cc
+++ b/tensorflow/compiler/xla/service/compile_only_service.cc
@@ -50,19 +50,14 @@ CompileOnlyService::NewService(const ServiceOptions& options) {
TF_ASSIGN_OR_RETURN(auto compiler, Compiler::GetForPlatform(platform));
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Backend> compute_constant_backend,
- CreateComputeConstantBackend());
- std::unique_ptr<CompileOnlyService> service(new CompileOnlyService(
- options, compiler, std::move(compute_constant_backend)));
+ std::unique_ptr<CompileOnlyService> service(
+ new CompileOnlyService(options, compiler));
return std::move(service);
}
-CompileOnlyService::CompileOnlyService(
- const ServiceOptions& options, Compiler* compiler,
- std::unique_ptr<Backend> compute_constant_backend)
- : Service(options, /*backend=*/nullptr,
- std::move(compute_constant_backend)),
- compiler_(compiler) {}
+CompileOnlyService::CompileOnlyService(const ServiceOptions& options,
+ Compiler* compiler)
+ : Service(options, /*execute_backend=*/nullptr), compiler_(compiler) {}
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CompileOnlyService::CompileAheadOfTime(
diff --git a/tensorflow/compiler/xla/service/compile_only_service.h b/tensorflow/compiler/xla/service/compile_only_service.h
index 0a1911cbd1..9859941c6c 100644
--- a/tensorflow/compiler/xla/service/compile_only_service.h
+++ b/tensorflow/compiler/xla/service/compile_only_service.h
@@ -102,9 +102,8 @@ class CompileOnlyService : public Service {
}
private:
- explicit CompileOnlyService(
- const ServiceOptions& options, Compiler* compiler,
- std::unique_ptr<Backend> compute_constant_backend);
+ explicit CompileOnlyService(const ServiceOptions& options,
+ Compiler* compiler);
CompileOnlyService(const CompileOnlyService&) = delete;
void operator=(const CompileOnlyService&) = delete;
diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc
index 2042558a29..1eb4edbe3e 100644
--- a/tensorflow/compiler/xla/service/local_service.cc
+++ b/tensorflow/compiler/xla/service/local_service.cc
@@ -54,23 +54,19 @@ namespace xla {
}
BackendOptions backend_options;
- backend_options.set_platform(platform)
- .set_intra_op_parallelism_threads(options.intra_op_parallelism_threads());
+ backend_options.set_platform(platform).set_intra_op_parallelism_threads(
+ options.intra_op_parallelism_threads());
TF_ASSIGN_OR_RETURN(std::unique_ptr<Backend> backend,
Backend::CreateBackend(backend_options));
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Backend> compute_constant_backend,
- CreateComputeConstantBackend());
- std::unique_ptr<LocalService> service(new LocalService(
- options, std::move(backend), std::move(compute_constant_backend)));
+ std::unique_ptr<LocalService> service(
+ new LocalService(options, std::move(backend)));
return std::move(service);
}
LocalService::LocalService(const ServiceOptions& options,
- std::unique_ptr<Backend> execute_backend,
- std::unique_ptr<Backend> compute_constant_backend)
- : Service(options, std::move(execute_backend),
- std::move(compute_constant_backend)) {}
+ std::unique_ptr<Backend> execute_backend)
+ : Service(options, std::move(execute_backend)) {}
namespace {
// Returns the space required to allocate a shape. If
@@ -161,7 +157,6 @@ StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
std::vector<perftools::gputools::DeviceMemoryBase> argument_buffers(
argument_layouts.size());
return BuildExecutable(versioned_handle, std::move(module_config),
- /*executable_for_compute_constant=*/false,
argument_buffers, execute_backend_.get(), executor);
}
diff --git a/tensorflow/compiler/xla/service/local_service.h b/tensorflow/compiler/xla/service/local_service.h
index 13797ec045..c90943f3c0 100644
--- a/tensorflow/compiler/xla/service/local_service.h
+++ b/tensorflow/compiler/xla/service/local_service.h
@@ -57,8 +57,7 @@ class LocalService : public Service {
private:
explicit LocalService(const ServiceOptions& options,
- std::unique_ptr<Backend> backend,
- std::unique_ptr<Backend> compute_constant_backend);
+ std::unique_ptr<Backend> backend);
LocalService(const LocalService&) = delete;
void operator=(const LocalService&) = delete;
};
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index 75c9571d27..ad2d5235f8 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
+#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
@@ -144,36 +145,15 @@ int ServiceOptions::intra_op_parallelism_threads() const {
backend_options.set_platform(platform);
TF_ASSIGN_OR_RETURN(execute_backend, Backend::CreateBackend(backend_options));
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Backend> compute_constant_backend,
- CreateComputeConstantBackend());
std::unique_ptr<Service> service(
- new Service(options, std::move(execute_backend),
- std::move(compute_constant_backend)));
+ new Service(options, std::move(execute_backend)));
return std::move(service);
}
-/* static */ StatusOr<std::unique_ptr<Backend>>
-Service::CreateComputeConstantBackend() {
- TF_ASSIGN_OR_RETURN(std::vector<se::Platform*> platforms,
- PlatformUtil::GetSupportedPlatforms());
- for (auto* platform : platforms) {
- if (platform->id() == se::host::kHostPlatformId) {
- BackendOptions backend_options;
- backend_options.set_platform(platform);
- return Backend::CreateBackend(backend_options);
- }
- }
- return NotFound("CPU platform not found");
-}
-
Service::Service(const ServiceOptions& options,
- std::unique_ptr<Backend> execute_backend,
- std::unique_ptr<Backend> compute_constant_backend)
- : options_(options),
- execute_backend_(std::move(execute_backend)),
- compute_constant_backend_(std::move(compute_constant_backend)) {
+ std::unique_ptr<Backend> execute_backend)
+ : options_(options), execute_backend_(std::move(execute_backend)) {
CHECK(options_.number_of_replicas() > 0);
-
if (execute_backend_) {
if (execute_backend_->device_count() > 0) {
CHECK_GE(execute_backend_->device_count(), options_.number_of_replicas())
@@ -418,7 +398,6 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
const VersionedComputationHandle& versioned_handle,
std::unique_ptr<HloModuleConfig> module_config,
- bool executable_for_compute_constant,
const tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
arguments,
Backend* backend, se::StreamExecutor* executor) {
@@ -431,8 +410,7 @@ StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
module_config->debug_options().xla_dump_computations_to();
const string& other_directory_path =
module_config->debug_options().xla_dump_executions_to();
- if (!executable_for_compute_constant &&
- (!directory_path.empty() || !other_directory_path.empty())) {
+ if (!directory_path.empty() || !other_directory_path.empty()) {
TF_ASSIGN_OR_RETURN(
session_module,
computation_tracker_.SnapshotComputation(versioned_handle.handle));
@@ -450,7 +428,7 @@ StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
std::unique_ptr<HloModule> module,
computation_tracker_.BuildHloModule(versioned_handle, *module_config,
/*include_unreachable_instructions=*/
- !executable_for_compute_constant));
+ true));
TF_ASSIGN_OR_RETURN(
std::unique_ptr<Executable> executable,
@@ -490,8 +468,7 @@ StatusOr<std::shared_ptr<Executable>> Service::BuildAndCacheExecutable(
HloModuleConfig original_module_config = *module_config;
TF_ASSIGN_OR_RETURN(
std::unique_ptr<Executable> executable_unique_ptr,
- BuildExecutable(versioned_handle, std::move(module_config),
- /*executable_for_compute_constant=*/false, arguments,
+ BuildExecutable(versioned_handle, std::move(module_config), arguments,
backend, executor));
if (profile != nullptr) {
@@ -1098,7 +1075,6 @@ tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg,
TF_ASSIGN_OR_RETURN(bool is_constant,
user_computation->IsConstant(arg->operand()));
-
if (!is_constant) {
return InvalidArgument("Operand to ComputeConstant depends on parameter.");
}
@@ -1114,8 +1090,8 @@ tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg,
ExecutionOptions execution_options = xla::CreateDefaultExecutionOptions();
execution_options.mutable_debug_options()->set_xla_enable_fast_math(false);
- execution_options.mutable_debug_options()->set_xla_backend_optimization_level(
- 0);
+ execution_options.mutable_debug_options()
+ ->set_xla_eliminate_hlo_implicit_broadcast(true);
*execution_options.mutable_shape_with_output_layout() =
program_shape.result();
@@ -1130,20 +1106,22 @@ tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg,
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModuleConfig> module_config,
CreateModuleConfig(program_shape, {}, execution_options));
+ // Exclude dead parameter instructions for the purpose of computing constants.
TF_ASSIGN_OR_RETURN(
- std::shared_ptr<Executable> executable,
- BuildExecutable(versioned_handle, std::move(module_config),
- /*executable_for_compute_constant=*/true,
- /*arguments=*/{}, compute_constant_backend_.get(),
- compute_constant_backend_->default_stream_executor()));
+ std::unique_ptr<HloModule> module,
+ computation_tracker_.BuildHloModule(versioned_handle, *module_config,
+ /*include_unreachable_instructions=*/
+ false));
+
+ HloEvaluator evaluator;
+ TF_ASSIGN_OR_RETURN(auto result_literal, evaluator.Evaluate(*module, {}));
+  // The shape_with_output_layout option in ExecutionOptions has no effect on
+  // the Evaluator's result, so relayout the result explicitly here.
+ if (arg->has_output_layout()) {
+ result_literal = result_literal->Relayout(arg->output_layout());
+ }
+ *result->mutable_literal() = result_literal->ToProto();
- TF_ASSIGN_OR_RETURN(
- *result->mutable_output(),
- ExecuteAndRegisterResult(
- executable.get(), /*arguments=*/{}, compute_constant_backend_.get(),
- compute_constant_backend_->default_stream_executor(),
- "constant computed from " + user_computation->name(),
- /*profile=*/nullptr));
return tensorflow::Status::OK();
}
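
This last hunk is the heart of the change: rather than compiling and running an
executable on a dedicated host backend, the service now interprets the
parameter-free subgraph with HloEvaluator. Condensed from the added lines above
(module is the HloModule built by the computation tracker, with unreachable
instructions excluded):

    xla::HloEvaluator evaluator;
    // No arguments: IsConstant() already verified the computation is
    // parameter-free.
    TF_ASSIGN_OR_RETURN(auto result_literal,
                        evaluator.Evaluate(*module, /*args=*/{}));
    // The evaluator ignores shape_with_output_layout, so a requested layout
    // is applied to the result after the fact.
    if (arg->has_output_layout()) {
      result_literal = result_literal->Relayout(arg->output_layout());
    }
    *result->mutable_literal() = result_literal->ToProto();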
diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h
index a07f7cd042..bb86a53c62 100644
--- a/tensorflow/compiler/xla/service/service.h
+++ b/tensorflow/compiler/xla/service/service.h
@@ -71,9 +71,9 @@ class ServiceOptions {
int intra_op_parallelism_threads_ = -1;
};
-// The XLA service object, which is the same across all
-// platforms. It maintains the service state of computations and allocations,
-// and delegates target-specific requests to the target-specific infrastructure
+// The XLA service object, which is the same across all platforms. It maintains
+// the service state of computations and allocations, and delegates
+// target-specific requests to the target-specific infrastructure
// (target-specific compiler, StreamExecutor).
class Service : public ServiceInterface {
public:
@@ -258,8 +258,8 @@ class Service : public ServiceInterface {
// The constructor is private. Use the NewService factory to create new
// service objects.
- Service(const ServiceOptions& options, std::unique_ptr<Backend> backend,
- std::unique_ptr<Backend> compute_constant_backend);
+ Service(const ServiceOptions& options,
+ std::unique_ptr<Backend> execute_backend);
static StatusOr<std::unique_ptr<Backend>> CreateComputeConstantBackend();
@@ -280,16 +280,10 @@ class Service : public ServiceInterface {
const ExecutionOptions* execution_options,
bool has_hybrid_result = false);
- // Builds an Executable for the given parameters. If
- // executable_for_compute_constant is true, then the executable is intended to
- // be used for ComputeConstant which means dead parameter instructions are not
- // included in the executable.The parameter "profile" can optionally point to
- // an ExecutionProfile object which will be filled in with profile data
- // relevant to compilation.
+ // Builds an Executable for the given parameters.
StatusOr<std::unique_ptr<Executable>> BuildExecutable(
const VersionedComputationHandle& versioned_handle,
std::unique_ptr<HloModuleConfig> module_config,
- bool executable_for_compute_constant,
const tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
arguments,
Backend* backend, perftools::gputools::StreamExecutor* executor);
@@ -381,9 +375,6 @@ class Service : public ServiceInterface {
// TODO(b/28616830): Support multiple backends for execution.
std::unique_ptr<Backend> execute_backend_;
- // Backend to use when executing ComputeConstant.
- std::unique_ptr<Backend> compute_constant_backend_;
-
TF_DISALLOW_COPY_AND_ASSIGN(Service);
};
diff --git a/tensorflow/compiler/xla/tests/compute_constant_test.cc b/tensorflow/compiler/xla/tests/compute_constant_test.cc
index ec408e92d3..b2e9743af7 100644
--- a/tensorflow/compiler/xla/tests/compute_constant_test.cc
+++ b/tensorflow/compiler/xla/tests/compute_constant_test.cc
@@ -72,9 +72,8 @@ class ComputeConstantTest : public ::testing::Test {
StatusOr<std::unique_ptr<Literal>> ComputeConstantLiteral(
Client* client, const ComputationDataHandle& operand,
ComputationBuilder* builder, Layout* output_layout = nullptr) {
- TF_ASSIGN_OR_RETURN(auto remote_computed,
+ TF_ASSIGN_OR_RETURN(auto computed,
builder->ComputeConstant(operand, output_layout));
- TF_ASSIGN_OR_RETURN(auto computed, client->Transfer(*remote_computed));
return std::move(computed);
}
@@ -253,35 +252,5 @@ XLA_TEST_F(ComputeConstantTest, Layout) {
}
}
-// This test is permanently disabled on CPU because it requires that the
-// backend used for execution is different than the backend used for
-// ComputeConstant which is always cpu.
-TEST_F(ComputeConstantTest, DISABLED_ON_CPU(ReuseComputedConstant)) {
- // Compute a trivial constant, then try to use the value in an Execute
- // call. This should fail because the constant resides on the CPU and the
- // Execute call is executed on a different backend. This test only makes
- // sense with LocalClient, since CompileOnlyClient does not support
- // execution.
- Client* client = ClientOrDie(platform_, ClientType::kLocal);
- ComputationBuilder constant_b(client, TestName());
- auto constant = constant_b.ConstantR0<int32>(42);
- auto handle = constant_b.ComputeConstant(constant).ConsumeValueOrDie();
- auto literal = client->Transfer(*handle).ConsumeValueOrDie();
- LiteralTestUtil::ExpectR0Equal(42, *literal);
-
- // Build trivial computation which takes one parameter.
- ComputationBuilder b(client, TestName());
- b.Neg(b.Parameter(0, ShapeUtil::MakeShape(S32, {}), "param0"));
- auto computation = b.Build().ConsumeValueOrDie();
-
- // Try to use value from ComputeConstant in Execute.
- auto execute_status = client->Execute(computation, {handle.get()});
- EXPECT_FALSE(execute_status.ok());
- EXPECT_THAT(
- execute_status.status().error_message(),
- ::testing::ContainsRegex("argument 0 is on device Host:0 but computation "
- "will be executed on device"));
-}
-
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto
index d0f4a548ed..ae033ae826 100644
--- a/tensorflow/compiler/xla/xla.proto
+++ b/tensorflow/compiler/xla/xla.proto
@@ -350,7 +350,9 @@ message ComputeConstantRequest {
}
message ComputeConstantResponse {
- GlobalDataHandle output = 1;
+  // A LiteralProto is returned directly for this request, instead of a
+  // GlobalDataHandle.
+ LiteralProto literal = 1;
}
message DeconstructTupleRequest {
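
The proto change is the wire-level counterpart of the client and service edits
above: the computed value travels inline as a LiteralProto rather than as a
handle to server-resident data. A short sketch of the server side of that
contract (FillComputeConstantResponse is a hypothetical helper, not part of
this patch; the client side is the MakeUnique<Literal>(response.literal())
call shown in computation_builder.cc):

    // Server side: serialize the evaluated literal into the response.
    void FillComputeConstantResponse(const xla::Literal& result,
                                     xla::ComputeConstantResponse* response) {
      *response->mutable_literal() = result.ToProto();
    }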