diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/service.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/service.cc | 49 |
1 files changed, 37 insertions, 12 deletions
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index fc848bdb03..849df1d8e6 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -34,8 +34,10 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/hlo_proto_util.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/service/session.pb.h" +#include "tensorflow/compiler/xla/service/source_map_util.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/shape_layout.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -55,6 +57,7 @@ namespace se = ::perftools::gputools; using ::tensorflow::strings::Printf; using ::tensorflow::strings::StrCat; +using ::xla::source_map_util::InvalidParameterArgument; namespace xla { @@ -260,7 +263,8 @@ StatusOr<std::vector<const ShapedBuffer*>> Service::ResolveAndValidateArguments( StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig( const ProgramShape& program_shape, tensorflow::gtl::ArraySlice<const Shape*> argument_shapes, - const ExecutionOptions* execution_options) { + const ExecutionOptions* execution_options, + const UserComputation& user_computation) { auto config = MakeUnique<HloModuleConfig>(program_shape); auto* computation_layout = config->mutable_entry_computation_layout(); @@ -274,8 +278,10 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig( // ProgramShape. if (!ShapeUtil::Compatible(*argument_shapes[i], program_shape.parameters(i))) { - return InvalidArgument( - "computation expects parameter %d to have shape %s, given shape %s", + return InvalidParameterArgument( + *user_computation.ParameterMetadata(i).value(), + "Argument does not match shape of computation parameter %d: want %s, " + "got %s", i, ShapeUtil::HumanString(program_shape.parameters(i)).c_str(), ShapeUtil::HumanString(*argument_shapes[i]).c_str()); } @@ -317,12 +323,14 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig( StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig( const ProgramShape& program_shape, tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments, - const ExecutionOptions& execution_options) { + const ExecutionOptions& execution_options, + const UserComputation& user_computation) { std::vector<const Shape*> argument_shapes; for (const auto* arg : arguments) { argument_shapes.push_back(&arg->on_host_shape()); } - return CreateModuleConfig(program_shape, argument_shapes, &execution_options); + return CreateModuleConfig(program_shape, argument_shapes, &execution_options, + user_computation); } StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables( @@ -419,6 +427,8 @@ StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable( /*include_unreachable_instructions=*/ true)); + TF_RETURN_IF_ERROR(MaybeDumpHloModule(*module)); + TF_ASSIGN_OR_RETURN( module, backend->compiler()->RunHloPasses(std::move(module), executor)); @@ -566,7 +576,7 @@ Service::ExecuteParallelAndRegisterResult( se::Stream* stream = index_to_profiled_stream.second; Executable* executable = executables[device]; const HloModule& module = executable->module(); - HloExecutionProfile hlo_profile(&executable->hlo_profile_printer(), + HloExecutionProfile hlo_profile(&executable->hlo_profile_printer_data(), &executable->hlo_profile_index_map()); TF_RETURN_IF_ERROR( executable->PopulateExecutionProfile(&hlo_profile, stream->parent())); @@ -739,9 +749,10 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, // Create an HloModuleConfig object for the computation, given the shape of // the program and the argument allocations. - TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModuleConfig> module_config, - CreateModuleConfig(*program_shape, arguments, - request.execution_options())); + TF_ASSIGN_OR_RETURN( + std::unique_ptr<HloModuleConfig> module_config, + CreateModuleConfig(*program_shape, arguments, + request.execution_options(), *user_computation)); VLOG(3) << "ExecuteParallel created HloModuleConfig computation layout: " << module_config->entry_computation_layout().ToString(); @@ -849,7 +860,8 @@ tensorflow::Status Service::Execute(const ExecuteRequest* arg, TF_ASSIGN_OR_RETURN( std::unique_ptr<HloModuleConfig> module_config, - CreateModuleConfig(*program_shape, arguments, arg->execution_options())); + CreateModuleConfig(*program_shape, arguments, arg->execution_options(), + *user_computation)); VLOG(3) << "Execute created HloModuleConfig computation layout: " << module_config->entry_computation_layout().ToString(); @@ -913,7 +925,8 @@ tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg, TF_ASSIGN_OR_RETURN( std::unique_ptr<HloModuleConfig> module_config, - CreateModuleConfig(*program_shape, arguments, arg->execution_options())); + CreateModuleConfig(*program_shape, arguments, arg->execution_options(), + *user_computation)); VLOG(3) << "ExecuteAsync created HloModuleConfig computation layout: " << module_config->entry_computation_layout().ToString(); @@ -1233,7 +1246,8 @@ tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg, } TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModuleConfig> module_config, - CreateModuleConfig(program_shape, {}, execution_options)); + CreateModuleConfig(program_shape, {}, execution_options, + *user_computation)); // Exclude dead parameter instructions for the purpose of computing constants. TF_ASSIGN_OR_RETURN( @@ -1597,4 +1611,15 @@ StatusOr<std::vector<perftools::gputools::StreamExecutor*>> Service::Replicas( return replicas; } +Status Service::MaybeDumpHloModule(const HloModule& module) const { + const string xla_dump_prepass_hlo_proto_to = + module.config().debug_options().xla_dump_prepass_hlo_proto_to(); + if (xla_dump_prepass_hlo_proto_to.empty()) { + return Status::OK(); + } + HloProto proto = MakeHloProto(module); + return protobuf_util::DumpProtoToDirectory( + proto, xla_dump_prepass_hlo_proto_to, module.name()); +} + } // namespace xla |