diff options
author | 2018-06-18 15:32:55 -0700 | |
---|---|---|
committer | 2018-06-18 15:38:37 -0700 | |
commit | ae377d44a9796a2b226306aeade57888d2f2df03 (patch) | |
tree | 31347bb415f53aadaf57b21e171c7689d4c3e36f /tensorflow/compiler/xla/service/local_service.cc | |
parent | 205fe2dbb8e00ebe25e5e9a480a24a49f0d87646 (diff) |
Enable the natural layouts of the entry computation to flow into the parameters and result layouts of the entry ComputationLayout.
If the arguments shapes passed in to the servie.cc API do not have a layout, it is assumed the caller is willing to accept the natural layout propagated by the XLA compiler.
Similarly, if the ExecutionOptions has a shape for the result, but no layout is set in such shape, it is assumed the caller is willing to accept the natural layout propagated by the XLA compiler.
Same thing for the ExecutableBuildOptions result_layout().
PiperOrigin-RevId: 201070858
Diffstat (limited to 'tensorflow/compiler/xla/service/local_service.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/local_service.cc | 12 |
1 files changed, 9 insertions, 3 deletions
diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index 296d04d436..a6aa8bf82c 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -154,7 +154,8 @@ StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable( for (int i = 0; i < argument_layouts.size(); ++i) { const Shape& argument_shape = *argument_layouts[i]; - TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(argument_shape)); + TF_RETURN_IF_ERROR( + ShapeUtil::ValidateShapeWithOptionalLayout(argument_shape)); if (!ShapeUtil::Compatible(argument_shape, program_shape.parameters(i))) { tensorflow::gtl::optional<const OpMetadata*> metadata = ParameterMetadata(computation, /*parameter_number=*/i); @@ -178,8 +179,8 @@ StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable( } } if (build_options.result_layout() != nullptr) { - TF_RETURN_IF_ERROR(ValidateResultShapeWithLayout( - *build_options.result_layout(), program_shape.result())); + TF_RETURN_IF_ERROR(ValidateResultShape(*build_options.result_layout(), + program_shape.result())); } ExecutionOptions execution_options = @@ -189,6 +190,11 @@ StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable( std::unique_ptr<HloModuleConfig> module_config, CreateModuleConfig(program_shape, argument_layouts, &execution_options)); + VLOG(3) << "Host Computation Layout: " + << module_config->host_entry_computation_layout().ToString(); + VLOG(3) << "Device Computation Layout: " + << module_config->device_entry_computation_layout().ToString(); + TF_ASSIGN_OR_RETURN( se::StreamExecutor * executor, execute_backend_->stream_executor(build_options.device_ordinal())); |