-rw-r--r--  tensorflow/compiler/xla/client/local_client.cc           33
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_compiler.cc       3
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_executable.cc     4
-rw-r--r--  tensorflow/compiler/xla/service/executable.h              4
-rw-r--r--  tensorflow/compiler/xla/service/gpu/gpu_compiler.cc       2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_module.cc            18
-rw-r--r--  tensorflow/compiler/xla/service/hlo_module.h             19
-rw-r--r--  tensorflow/compiler/xla/service/hlo_module_config.cc     23
-rw-r--r--  tensorflow/compiler/xla/service/hlo_module_config.h      49
-rw-r--r--  tensorflow/compiler/xla/service/hlo_parser.cc            11
-rw-r--r--  tensorflow/compiler/xla/service/hlo_parser_test.cc        2
-rw-r--r--  tensorflow/compiler/xla/service/interpreter/compiler.cc   2
-rw-r--r--  tensorflow/compiler/xla/service/local_service.cc          6
-rw-r--r--  tensorflow/compiler/xla/service/service.cc               48
-rw-r--r--  tensorflow/compiler/xla/service/service.h                 3
-rw-r--r--  tensorflow/compiler/xla/tests/hlo_test_base.h            20
16 files changed, 70 insertions(+), 177 deletions(-)
diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc
index cf07910c4a..5f9710914b 100644
--- a/tensorflow/compiler/xla/client/local_client.cc
+++ b/tensorflow/compiler/xla/client/local_client.cc
@@ -51,24 +51,17 @@ LocalExecutable::LocalExecutable(std::unique_ptr<Executable> executable,
Status LocalExecutable::ValidateExecutionOptions(
const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
const ExecutableRunOptions& run_options, const Backend& backend) {
- const ComputationLayout& host_computation_layout =
- executable_->module_config().host_entry_computation_layout();
- const ComputationLayout& device_computation_layout =
- executable_->module_config().device_entry_computation_layout();
+ const ComputationLayout& computation_layout =
+ executable_->module_config().entry_computation_layout();
// Check argument number, shapes, and layouts.
- if (arguments.size() != host_computation_layout.parameter_count()) {
+ if (arguments.size() != computation_layout.parameter_count()) {
return InvalidArgument(
"invalid number of arguments for computation: expected %d, got %zu",
- host_computation_layout.parameter_count(), arguments.size());
- }
- if (arguments.size() != device_computation_layout.parameter_count()) {
- return InvalidArgument(
- "invalid number of arguments for computation: expected %d, got %zu",
- device_computation_layout.parameter_count(), arguments.size());
+ computation_layout.parameter_count(), arguments.size());
}
for (int i = 0; i < arguments.size(); ++i) {
- if (!host_computation_layout.parameter_layout(i).MatchesLayoutInShape(
+ if (!computation_layout.parameter_layout(i).MatchesLayoutInShape(
arguments[i]->on_host_shape())) {
return InvalidParameterArgument(
executable_.get(), i,
@@ -76,24 +69,10 @@ Status LocalExecutable::ValidateExecutionOptions(
"parameter "
"%d: want %s, got %s",
i,
- ShapeUtil::HumanString(
- host_computation_layout.parameter_layout(i).shape())
+ ShapeUtil::HumanString(computation_layout.parameter_layout(i).shape())
.c_str(),
ShapeUtil::HumanString(arguments[i]->on_host_shape()).c_str());
}
- if (!device_computation_layout.parameter_layout(i).MatchesLayoutInShape(
- arguments[i]->on_device_shape())) {
- return InvalidParameterArgument(
- executable_.get(), i,
- "Argument does not match device shape or layout of computation "
- "parameter "
- "%d: want %s, got %s",
- i,
- ShapeUtil::HumanString(
- device_computation_layout.parameter_layout(i).shape())
- .c_str(),
- ShapeUtil::HumanString(arguments[i]->on_device_shape()).c_str());
- }
}
if (run_options.stream() != nullptr) {
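With the host/device split removed, argument validation reduces to one pass against the unified entry layout. A minimal sketch of the resulting logic, assuming `executable` and `arguments` stand in for the members used above:

    // Sketch only: mirrors the unified check above. The single
    // entry_computation_layout() now covers what the host and device
    // layouts used to describe separately.
    const ComputationLayout& layout =
        executable->module_config().entry_computation_layout();
    TF_RET_CHECK(arguments.size() == layout.parameter_count());
    for (int i = 0; i < arguments.size(); ++i) {
      TF_RET_CHECK(layout.parameter_layout(i).MatchesLayoutInShape(
          arguments[i]->on_host_shape()));
    }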
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index d039132535..52da9d6eac 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -303,8 +303,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
ReducePrecisionInsertion::PassTiming::AFTER_FUSION);
pipeline.AddPass<CpuLayoutAssignment>(
- module->mutable_device_entry_computation_layout(),
- &target_machine_features);
+ module->mutable_entry_computation_layout(), &target_machine_features);
// The LayoutAssignment pass may leave behind kCopy instructions which are
// duplicate or NOPs, so remove them with algebraic simplification and CSE.
pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
index cf43b74c69..1093559892 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
@@ -206,8 +206,8 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::CreateResultShapedBuffer(
tensorflow::gtl::MutableArraySlice<OwningDeviceMemory> buffers) {
se::Stream* stream = run_options->stream();
ScopedShapedBuffer result_buffer(
- /*on_host_shape=*/host_result_shape(),
- /*on_device_shape=*/host_result_shape(), run_options->allocator(),
+ /*on_host_shape=*/result_shape(),
+ /*on_device_shape=*/result_shape(), run_options->allocator(),
stream->parent()->device_ordinal());
// Move OwningDeviceMemory values which contain the array(s) of the result
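On CPU the on-host and on-device shapes coincide, which is why the single result_shape() accessor can feed both slots of the buffer. A hedged sketch of the pattern, with `executable`, the allocator, and the ordinal as placeholders:

    // Sketch: both shape slots come from the one remaining accessor.
    ScopedShapedBuffer result_buffer(
        /*on_host_shape=*/executable->result_shape(),
        /*on_device_shape=*/executable->result_shape(),
        run_options->allocator(), stream->parent()->device_ordinal());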
diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h
index bd92bfa50f..98eaeee30a 100644
--- a/tensorflow/compiler/xla/service/executable.h
+++ b/tensorflow/compiler/xla/service/executable.h
@@ -131,8 +131,8 @@ class Executable {
// The shape (including layout) that results from this execution. This is the
// shape of the DeviceMemoryBase result value in ExecuteOnStream above.
- const Shape& host_result_shape() const {
- return hlo_module_->config().host_entry_computation_layout().result_shape();
+ const Shape& result_shape() const {
+ return hlo_module_->config().entry_computation_layout().result_shape();
}
// Returns the size of the executable in bytes. Returns -1 by default if the
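Callers that previously asked for host_result_shape() now call result_shape(). For instance, a client sizing an output allocation might do the following (a sketch; `executable` is any compiled Executable):

    // Sketch: the result shape carries the layout chosen at compile time.
    const Shape& shape = executable->result_shape();
    int64 bytes = ShapeUtil::ByteSizeOf(shape);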
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index a040e6b681..decfc40daf 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -205,7 +205,7 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
{
HloPassPipeline pipeline("layout_assignment");
pipeline.AddPass<GpuLayoutAssignment>(
- hlo_module->mutable_device_entry_computation_layout(), stream_exec);
+ hlo_module->mutable_entry_computation_layout(), stream_exec);
// The LayoutAssignment pass may leave behind kCopy instructions which are
// duplicate or NOPs, so remove them with algebraic simplification and CSE.
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
index 11384c1456..39bc25ba42 100644
--- a/tensorflow/compiler/xla/service/hlo_module.cc
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -58,7 +58,7 @@ HloComputation* HloModule::AddComputationInternal(
// If the module configuration has no entry computation layout set, create a
// default one based on the program shape.
- if (!config_.has_host_entry_computation_layout()) {
+ if (!config_.has_entry_computation_layout()) {
config_.SetDefaultComputationLayout(
entry_computation_->ComputeProgramShape());
}
@@ -231,14 +231,11 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
TF_RET_CHECK(proto.has_program_shape())
<< "No program shape found in the proto";
const auto& expected_program_shape = proto.program_shape();
- TF_RET_CHECK(
- expected_program_shape.parameters_size() ==
- module_config.device_entry_computation_layout().parameter_count());
+ TF_RET_CHECK(expected_program_shape.parameters_size() ==
+ module_config.entry_computation_layout().parameter_count());
for (int i = 0; i < expected_program_shape.parameters_size(); ++i) {
const Shape& parameter_shape =
- module_config.device_entry_computation_layout()
- .parameter_layout(i)
- .shape();
+ module_config.entry_computation_layout().parameter_layout(i).shape();
TF_RET_CHECK(ShapeUtil::Compatible(expected_program_shape.parameters(i),
parameter_shape))
<< "HloModuleConfig has different shape for parameter " << i
@@ -248,7 +245,7 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
<< ", actual: " << ShapeUtil::HumanStringWithLayout(parameter_shape);
}
const Shape& result_shape =
- module_config.device_entry_computation_layout().result_layout().shape();
+ module_config.entry_computation_layout().result_layout().shape();
TF_RET_CHECK(
ShapeUtil::Compatible(expected_program_shape.result(), result_shape))
<< "HloModuleConfig has different result shape than the HLO module. "
@@ -327,7 +324,7 @@ StatusOr<HloModuleConfig> HloModule::CreateModuleConfigFromProto(
// The module config is constructed with default layouts regardless of what is
// passed in via the ProgramShape. Set the layouts to the appropriate values.
ComputationLayout* entry_layout =
- module_config.mutable_host_entry_computation_layout();
+ module_config.mutable_entry_computation_layout();
for (int64 i = 0; i < entry_layout->parameter_count(); ++i) {
TF_RETURN_IF_ERROR(
entry_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
@@ -335,9 +332,6 @@ StatusOr<HloModuleConfig> HloModule::CreateModuleConfigFromProto(
}
TF_RETURN_IF_ERROR(entry_layout->mutable_result_layout()->CopyLayoutFromShape(
program_shape.result()));
- *module_config.mutable_device_entry_computation_layout() =
- module_config.host_entry_computation_layout();
-
return module_config;
}
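The proto round trip now validates against one layout. A sketch of the flow; the exact argument list of CreateModuleConfigFromProto is elided here and should be treated as an assumption:

    // Sketch: config layouts are rebuilt from the proto's ProgramShape,
    // then CreateFromProto re-checks parameter and result compatibility.
    TF_ASSIGN_OR_RETURN(HloModuleConfig config,
                        HloModule::CreateModuleConfigFromProto(proto));
    TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
                        HloModule::CreateFromProto(proto, config));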
diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h
index 5dc94e78e3..d2e726a0db 100644
--- a/tensorflow/compiler/xla/service/hlo_module.h
+++ b/tensorflow/compiler/xla/service/hlo_module.h
@@ -105,20 +105,19 @@ class HloModule {
return entry_computation_;
}
- ComputationLayout* mutable_host_entry_computation_layout() {
- return config_.mutable_host_entry_computation_layout();
+ // Creates a ComputationLayout describing the current state of the HLO
+ // module's entry computation.
+ ComputationLayout compute_computation_layout() const {
+ return ComputationLayout(entry_computation()->ComputeProgramShape(),
+ /*ignore_layouts=*/false);
}
- const ComputationLayout& host_entry_computation_layout() const {
- return config_.host_entry_computation_layout();
+ ComputationLayout* mutable_entry_computation_layout() {
+ return config_.mutable_entry_computation_layout();
}
- ComputationLayout* mutable_device_entry_computation_layout() {
- return config_.mutable_device_entry_computation_layout();
- }
-
- const ComputationLayout& device_entry_computation_layout() const {
- return config_.device_entry_computation_layout();
+ const ComputationLayout& entry_computation_layout() const {
+ return config_.entry_computation_layout();
}
// Gets the computations in this module.
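The new compute_computation_layout() helper derives a layout from the entry computation's current program shape, replacing the need for a second stored copy. A sketch of a consistency check a caller might write with it:

    // Sketch: compare the layout implied by the HLO graph with the one
    // recorded in the module config.
    ComputationLayout actual = module->compute_computation_layout();
    const ComputationLayout& configured = module->entry_computation_layout();
    CHECK(ShapeUtil::Compatible(actual.result_shape(),
                                configured.result_shape()));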
diff --git a/tensorflow/compiler/xla/service/hlo_module_config.cc b/tensorflow/compiler/xla/service/hlo_module_config.cc
index dae5578a31..07a8c798db 100644
--- a/tensorflow/compiler/xla/service/hlo_module_config.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_config.cc
@@ -28,16 +28,14 @@ namespace xla {
using tensorflow::strings::StrAppend;
-HloModuleConfig::HloModuleConfig() {}
-
-HloModuleConfig::HloModuleConfig(const ProgramShape& program_shape)
- : host_entry_computation_layout_(program_shape),
- device_entry_computation_layout_(program_shape) {}
+HloModuleConfig::HloModuleConfig(const ProgramShape& program_shape,
+ bool ignore_layouts)
+ : entry_computation_layout_(
+ ComputationLayout(program_shape, ignore_layouts)) {}
void HloModuleConfig::SetDefaultComputationLayout(
const ProgramShape& program_shape) {
- host_entry_computation_layout_ = ComputationLayout(program_shape);
- device_entry_computation_layout_ = ComputationLayout(program_shape);
+ entry_computation_layout_ = ComputationLayout(program_shape);
}
string HloModuleConfig::compilation_cache_key() const {
@@ -46,18 +44,11 @@ string HloModuleConfig::compilation_cache_key() const {
StrAppend(&key, "::(");
std::vector<string> params;
for (const ShapeLayout& param_layout :
- host_entry_computation_layout_->parameter_layouts()) {
+ entry_computation_layout_->parameter_layouts()) {
params.push_back(param_layout.shape().DebugString());
}
StrAppend(&key, tensorflow::str_util::Join(params, ", "), ") => ",
- host_entry_computation_layout_->result_shape().SerializeAsString());
- for (const ShapeLayout& param_layout :
- device_entry_computation_layout_->parameter_layouts()) {
- params.push_back(param_layout.shape().DebugString());
- }
- StrAppend(
- &key, tensorflow::str_util::Join(params, ", "), ") => ",
- device_entry_computation_layout_->result_shape().SerializeAsString());
+ entry_computation_layout_->result_shape().SerializeAsString());
if (seed() != 0) {
// TODO(b/32083678): force recompilation to reset global state.
static std::atomic<int> counter{0};
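compilation_cache_key() now serializes the single entry layout instead of appending host and device variants; note the removed loop also reused `params` without clearing it, so the device entries were appended onto the already-joined host list. A usage sketch, assuming the default seed of zero:

    // Sketch: the key is derived from the one entry layout plus the seed,
    // so configs built from the same ProgramShape key identically.
    HloModuleConfig config(program_shape);
    string key = config.compilation_cache_key();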
diff --git a/tensorflow/compiler/xla/service/hlo_module_config.h b/tensorflow/compiler/xla/service/hlo_module_config.h
index cdb0b29a23..074e9c9070 100644
--- a/tensorflow/compiler/xla/service/hlo_module_config.h
+++ b/tensorflow/compiler/xla/service/hlo_module_config.h
@@ -37,48 +37,34 @@ class HloModuleConfig {
// ComputationLayout. The default ctor creates it without -- in this case
// accessing entry_computation_layout will CHECK-fail. The ctor accepting a
// ProgramShape creates a computation layout using this shape.
- HloModuleConfig();
- explicit HloModuleConfig(const ProgramShape& program_shape);
+ // The layouts in the ProgramShape will be reset to default unless
+ // ignore_layouts is set to false.
+ HloModuleConfig() = default;
- // Checks if this config has an entry computation layout already.
- bool has_host_entry_computation_layout() const {
- return host_entry_computation_layout_.has_value();
- }
+ explicit HloModuleConfig(const ProgramShape& program_shape,
+ bool ignore_layouts = true);
- bool has_device_entry_computation_layout() const {
- return device_entry_computation_layout_.has_value();
+ // Checks if this config has an entry computation layout already.
+ bool has_entry_computation_layout() const {
+ return entry_computation_layout_.has_value();
}
// Sets the entry computation layout for this config. If the entry computation
// layout already exists, it is silently replaced.
void SetDefaultComputationLayout(const ProgramShape& program_shape);
- // Returns a constant reference to the on-host layout of the entry
- // computation. Assumes the layout was set.
- const ComputationLayout& host_entry_computation_layout() const {
- CHECK(host_entry_computation_layout_.has_value());
- return *host_entry_computation_layout_;
- }
-
- // Returns a mutable pointer to the layout of the on-host entry computation.
+ // Returns a constant reference to the layout of the entry computation.
// Assumes the layout was set.
- ComputationLayout* mutable_host_entry_computation_layout() {
- CHECK(host_entry_computation_layout_.has_value());
- return &(*host_entry_computation_layout_);
- }
-
- // Returns a constant reference to the on-device layout of the entry
- // computation. Assumes the layout was set.
- const ComputationLayout& device_entry_computation_layout() const {
- CHECK(device_entry_computation_layout_.has_value());
- return *device_entry_computation_layout_;
+ const ComputationLayout& entry_computation_layout() const {
+ CHECK(entry_computation_layout_.has_value());
+ return *entry_computation_layout_;
}
- // Returns a mutable pointer to the layout of the on-device entry computation.
+ // Returns a mutable pointer to the layout of the entry computation.
// Assumes the layout was set.
- ComputationLayout* mutable_device_entry_computation_layout() {
- CHECK(device_entry_computation_layout_.has_value());
- return &(*device_entry_computation_layout_);
+ ComputationLayout* mutable_entry_computation_layout() {
+ CHECK(entry_computation_layout_.has_value());
+ return &(*entry_computation_layout_);
}
// Returns whether to enable HLO-level profiling.
@@ -127,8 +113,7 @@ class HloModuleConfig {
private:
// If you add new members, be sure to update compilation_cache_key.
- tensorflow::gtl::optional<ComputationLayout> host_entry_computation_layout_;
- tensorflow::gtl::optional<ComputationLayout> device_entry_computation_layout_;
+ tensorflow::gtl::optional<ComputationLayout> entry_computation_layout_;
// Whether this is a 'host module'.
bool is_host_module_ = false;
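With ignore_layouts defaulting to true, the constructor resets the ProgramShape's layouts to defaults; passing false preserves them. A minimal sketch:

    // Sketch: keep the layouts already present in program_shape.
    HloModuleConfig config(program_shape, /*ignore_layouts=*/false);
    CHECK(config.has_entry_computation_layout());
    const ComputationLayout& layout = config.entry_computation_layout();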
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index daa3bc4232..2cee74c314 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -327,22 +327,15 @@ bool HloParser::ParseComputations() {
// set the layouts to what the hlo text says.
for (int p = 0; p < computation->num_parameters(); p++) {
const Shape& param_shape = computation->parameter_instruction(p)->shape();
- TF_CHECK_OK(module_->mutable_host_entry_computation_layout()
- ->mutable_parameter_layout(p)
- ->CopyLayoutFromShape(param_shape));
- TF_CHECK_OK(module_->mutable_device_entry_computation_layout()
+ TF_CHECK_OK(module_->mutable_entry_computation_layout()
->mutable_parameter_layout(p)
->CopyLayoutFromShape(param_shape));
}
const Shape& result_shape = computation->root_instruction()->shape();
- TF_CHECK_OK(module_->mutable_host_entry_computation_layout()
- ->mutable_result_layout()
- ->CopyLayoutFromShape(result_shape));
- TF_CHECK_OK(module_->mutable_device_entry_computation_layout()
+ TF_CHECK_OK(module_->mutable_entry_computation_layout()
->mutable_result_layout()
->CopyLayoutFromShape(result_shape));
}
-
return true;
}
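After parsing, the single entry layout reflects whatever layouts the HLO text declared. A sketch of reading them back, using ParseHloString as exercised in the parser test below:

    // Sketch: the parsed module exposes one layout for the entry
    // parameters and result, matching the text.
    auto module = ParseHloString(hlo_text).ValueOrDie();
    const ComputationLayout& layout = module->entry_computation_layout();
    const Layout& result_layout = layout.result_layout().layout();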
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index d551400d1e..d481e07f60 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -1302,7 +1302,7 @@ ENTRY %Reduce (input: f32[8,16,256]) -> f32[8,16] {
auto module = ParseHloString(original);
TF_ASSERT_OK(module.status());
- auto program_layout = module.ValueOrDie()->host_entry_computation_layout();
+ auto program_layout = module.ValueOrDie()->entry_computation_layout();
ASSERT_EQ(program_layout.parameter_count(), 1);
auto param_layout = program_layout.parameter_layout(0).layout();
auto result_layout = program_layout.result_layout().layout();
diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc
index c166653068..9f8f4bda87 100644
--- a/tensorflow/compiler/xla/service/interpreter/compiler.cc
+++ b/tensorflow/compiler/xla/service/interpreter/compiler.cc
@@ -44,7 +44,7 @@ Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) {
HloPassPipeline pipeline("Interpreter");
pipeline.AddPass<LayoutAssignment>(
- hlo_module->mutable_device_entry_computation_layout());
+ hlo_module->mutable_entry_computation_layout());
return pipeline.Run(hlo_module).status();
}
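All backends now hand LayoutAssignment the same mutable layout. A sketch of the common pattern, mirroring the CPU and GPU hunks above (backend-specific subclasses add extra arguments):

    // Sketch: every pipeline seeds layout assignment from the unified
    // entry layout.
    HloPassPipeline pipeline("layout_assignment");
    pipeline.AddPass<LayoutAssignment>(
        module->mutable_entry_computation_layout());
    TF_RETURN_IF_ERROR(pipeline.Run(module).status());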
diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc
index a6aa8bf82c..53efc30c36 100644
--- a/tensorflow/compiler/xla/service/local_service.cc
+++ b/tensorflow/compiler/xla/service/local_service.cc
@@ -190,10 +190,8 @@ StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
std::unique_ptr<HloModuleConfig> module_config,
CreateModuleConfig(program_shape, argument_layouts, &execution_options));
- VLOG(3) << "Host Computation Layout: "
- << module_config->host_entry_computation_layout().ToString();
- VLOG(3) << "Device Computation Layout: "
- << module_config->device_entry_computation_layout().ToString();
+ VLOG(3) << "Computation Layout: "
+ << module_config->entry_computation_layout().ToString();
TF_ASSIGN_OR_RETURN(
se::StreamExecutor * executor,
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index 7ab39e01f2..da3b622bfa 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -244,10 +244,8 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
tensorflow::gtl::ArraySlice<const Shape*> argument_shapes,
const ExecutionOptions* execution_options) {
auto config = MakeUnique<HloModuleConfig>(program_shape);
- ComputationLayout* host_computation_layout =
- config->mutable_host_entry_computation_layout();
- ComputationLayout* device_computation_layout =
- config->mutable_device_entry_computation_layout();
+ ComputationLayout* computation_layout =
+ config->mutable_entry_computation_layout();
if (program_shape.parameters_size() != argument_shapes.size()) {
return InvalidArgument("computation takes %d parameters, but %zu given",
program_shape.parameters_size(),
@@ -264,10 +262,9 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
i, ShapeUtil::HumanString(program_shape.parameters(i)).c_str(),
ShapeUtil::HumanString(*argument_shapes[i]).c_str());
}
- TF_RETURN_IF_ERROR(host_computation_layout->mutable_parameter_layout(i)
- ->CopyLayoutFromShape(*argument_shapes[i]));
- TF_RETURN_IF_ERROR(device_computation_layout->mutable_parameter_layout(i)
- ->CopyLayoutFromShape(*argument_shapes[i]));
+ TF_RETURN_IF_ERROR(
+ computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
+ *argument_shapes[i]));
}
if (execution_options != nullptr &&
execution_options->has_shape_with_output_layout()) {
@@ -276,20 +273,11 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
TF_RETURN_IF_ERROR(
ValidateResultShape(shape_with_output_layout, program_shape.result()));
TF_RETURN_IF_ERROR(
- host_computation_layout->mutable_result_layout()->CopyLayoutFromShape(
- shape_with_output_layout));
- TF_RETURN_IF_ERROR(
- device_computation_layout->mutable_result_layout()->CopyLayoutFromShape(
+ computation_layout->mutable_result_layout()->CopyLayoutFromShape(
shape_with_output_layout));
} else {
// If the result layout is not set, then choose the default.
- // TODO(b/29118294): Allow the compiler to choose a better layout in this
- // case.
- // TODO(b/78356948): We are forcing the default layout here. We should fix
- // clients which expect a default layout, to be explicit about it, by
- // passing the proper ExecutionOptions with shape_with_output_layout set.
- host_computation_layout->mutable_result_layout()->SetToDefaultLayout();
- device_computation_layout->mutable_result_layout()->SetToDefaultLayout();
+ computation_layout->mutable_result_layout()->SetToDefaultLayout();
}
config->set_replica_count(options_.number_of_replicas());
@@ -377,24 +365,6 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
return std::move(executables);
}
-Status Service::ValidateEntryComputationLayout(HloModule* module) {
- const ComputationLayout& on_host = module->host_entry_computation_layout();
- const ComputationLayout& on_device =
- module->device_entry_computation_layout();
- for (int64 i = 0; i < on_device.parameter_count(); ++i) {
- TF_RET_CHECK(ShapeUtil::Compatible(on_device.parameter_shape(i),
- on_host.parameter_shape(i)))
- << ShapeUtil::HumanStringWithLayout(on_device.parameter_shape(i))
- << " vs "
- << ShapeUtil::HumanStringWithLayout(on_host.parameter_shape(i));
- }
- TF_RET_CHECK(
- ShapeUtil::Compatible(on_device.result_shape(), on_host.result_shape()))
- << ShapeUtil::HumanStringWithLayout(on_device.result_shape()) << " vs "
- << ShapeUtil::HumanStringWithLayout(on_host.result_shape());
- return Status::OK();
-}
-
StatusOr<std::vector<GlobalDataHandle>>
Service::ExecuteParallelAndRegisterResult(
tensorflow::gtl::ArraySlice<Executable*> executables,
@@ -690,7 +660,7 @@ Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg,
request.execution_options()));
VLOG(3)
<< "ExecuteGraphParallel created HloModuleConfig computation layout: "
- << module_config->host_entry_computation_layout().ToString();
+ << module_config->entry_computation_layout().ToString();
// Adds to the vectors to build and execute the computations after the loop.
all_arguments.push_back(replicated_arguments);
@@ -851,8 +821,6 @@ StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
TF_ASSIGN_OR_RETURN(
module, backend->compiler()->RunHloPasses(std::move(module), executor,
device_allocator));
- // Check that on-host and on-device shapes are consistent.
- TF_RETURN_IF_ERROR(ValidateEntryComputationLayout(module.get()));
TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
backend->compiler()->RunBackend(
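With ValidateEntryComputationLayout gone, the config built in CreateModuleConfig is the single source of truth for entry layouts. A sketch of how caller-supplied layouts land in it; `argument_shapes` is the same slice as above:

    // Sketch: parameter layouts are copied once from the caller's shapes;
    // the result layout comes from ExecutionOptions or falls back to the
    // default, as in the hunk above.
    ComputationLayout* layout = config->mutable_entry_computation_layout();
    for (int i = 0; i < argument_shapes.size(); ++i) {
      TF_RETURN_IF_ERROR(layout->mutable_parameter_layout(i)
                             ->CopyLayoutFromShape(*argument_shapes[i]));
    }
    layout->mutable_result_layout()->SetToDefaultLayout();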
diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h
index 7960429084..47d196fb2a 100644
--- a/tensorflow/compiler/xla/service/service.h
+++ b/tensorflow/compiler/xla/service/service.h
@@ -193,9 +193,6 @@ class Service : public ServiceInterface {
const ExecutionOptions& execution_options,
tensorflow::gtl::ArraySlice<const GlobalDataHandle*> arguments);
- // Assert that host- and device-shapes are in a consistent state.
- Status ValidateEntryComputationLayout(HloModule* module);
-
protected:
friend class LocalExecutable;
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h
index 249da87f48..9009d67cea 100644
--- a/tensorflow/compiler/xla/tests/hlo_test_base.h
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.h
@@ -185,13 +185,9 @@ class HloTestBase : public ::testing::Test {
// 'layout'.
void ForceParameterLayout(HloModule* module, int64 param_no,
const Layout& layout) {
- ASSERT_LT(
- param_no,
- module->mutable_host_entry_computation_layout()->parameter_count());
- module->mutable_host_entry_computation_layout()
- ->mutable_parameter_layout(param_no)
- ->ResetLayout(layout);
- module->mutable_device_entry_computation_layout()
+ ASSERT_LT(param_no,
+ module->mutable_entry_computation_layout()->parameter_count());
+ module->mutable_entry_computation_layout()
->mutable_parameter_layout(param_no)
->ResetLayout(layout);
}
@@ -199,10 +195,7 @@ class HloTestBase : public ::testing::Test {
// Convenience method to force the layout of the computation result in a
// module. The result layout of 'module' is set to 'layout'.
void ForceResultLayout(HloModule* module, const Layout& layout) {
- module->mutable_host_entry_computation_layout()
- ->mutable_result_layout()
- ->ResetLayout(layout);
- module->mutable_device_entry_computation_layout()
+ module->mutable_entry_computation_layout()
->mutable_result_layout()
->ResetLayout(layout);
}
@@ -210,10 +203,7 @@ class HloTestBase : public ::testing::Test {
// Convenience method to clear the layout of the computation result in
// 'module'.
void ForceClearResultLayout(HloModule* module) {
- module->mutable_host_entry_computation_layout()
- ->mutable_result_layout()
- ->Clear();
- module->mutable_device_entry_computation_layout()
+ module->mutable_entry_computation_layout()
->mutable_result_layout()
->Clear();
}
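Tests pin layouts through these helpers exactly as before, just without the duplicated host/device calls. A usage sketch from inside an HloTestBase test; the minor-to-major orders are illustrative only:

    // Sketch: force a parameter and the result to specific layouts before
    // running the module under test.
    ForceParameterLayout(module.get(), /*param_no=*/0,
                         LayoutUtil::MakeLayout({0, 1}));
    ForceResultLayout(module.get(), LayoutUtil::MakeLayout({1, 0}));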