Diffstat (limited to 'tensorflow/compiler/xla/service/service.cc')
 tensorflow/compiler/xla/service/service.cc | 49 +++++++++++++-----
 1 file changed, 37 insertions(+), 12 deletions(-)
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index fc848bdb03..849df1d8e6 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -34,8 +34,10 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
+#include "tensorflow/compiler/xla/service/hlo_proto_util.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/service/session.pb.h"
+#include "tensorflow/compiler/xla/service/source_map_util.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -55,6 +57,7 @@ namespace se = ::perftools::gputools;
using ::tensorflow::strings::Printf;
using ::tensorflow::strings::StrCat;
+using ::xla::source_map_util::InvalidParameterArgument;
namespace xla {
@@ -260,7 +263,8 @@ StatusOr<std::vector<const ShapedBuffer*>> Service::ResolveAndValidateArguments(
StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
const ProgramShape& program_shape,
tensorflow::gtl::ArraySlice<const Shape*> argument_shapes,
- const ExecutionOptions* execution_options) {
+ const ExecutionOptions* execution_options,
+ const UserComputation& user_computation) {
auto config = MakeUnique<HloModuleConfig>(program_shape);
auto* computation_layout = config->mutable_entry_computation_layout();
@@ -274,8 +278,10 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
// ProgramShape.
if (!ShapeUtil::Compatible(*argument_shapes[i],
program_shape.parameters(i))) {
- return InvalidArgument(
- "computation expects parameter %d to have shape %s, given shape %s",
+ return InvalidParameterArgument(
+ *user_computation.ParameterMetadata(i).value(),
+ "Argument does not match shape of computation parameter %d: want %s, "
+ "got %s",
i, ShapeUtil::HumanString(program_shape.parameters(i)).c_str(),
ShapeUtil::HumanString(*argument_shapes[i]).c_str());
}
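The new InvalidParameterArgument call replaces a bare InvalidArgument so the
error can carry the offending parameter's source location, taken from the
UserComputation's per-parameter OpMetadata. A minimal sketch of what such a
helper can look like, assuming only the source_file()/source_line() fields
that OpMetadata defines in xla_data.proto; the helper name and formatting
below are illustrative, not the actual source_map_util implementation:

    // Illustrative sketch only: attach the parameter's source location,
    // when present, to an InvalidArgument status.
    Status InvalidParameterArgumentSketch(const OpMetadata& metadata,
                                          const string& message) {
      if (!metadata.source_file().empty()) {
        // OpMetadata records where the op was created in client code.
        return InvalidArgument("%s (from %s:%d)", message.c_str(),
                               metadata.source_file().c_str(),
                               metadata.source_line());
      }
      return InvalidArgument("%s", message.c_str());
    }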
@@ -317,12 +323,14 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
const ProgramShape& program_shape,
tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
- const ExecutionOptions& execution_options) {
+ const ExecutionOptions& execution_options,
+ const UserComputation& user_computation) {
std::vector<const Shape*> argument_shapes;
for (const auto* arg : arguments) {
argument_shapes.push_back(&arg->on_host_shape());
}
- return CreateModuleConfig(program_shape, argument_shapes, &execution_options);
+ return CreateModuleConfig(program_shape, argument_shapes, &execution_options,
+ user_computation);
}
StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
@@ -419,6 +427,8 @@ StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
/*include_unreachable_instructions=*/
true));
+ TF_RETURN_IF_ERROR(MaybeDumpHloModule(*module));
+
TF_ASSIGN_OR_RETURN(
module, backend->compiler()->RunHloPasses(std::move(module), executor));
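The dump call sits between module construction and RunHloPasses, so what gets
written is the unoptimized, pre-pass HLO. A condensed view of the resulting
ordering in Service::BuildExecutable (comments only, paraphrasing the
surrounding code; step names are approximate):

    // Condensed ordering after this change:
    //   1. build the HloModule from the user computation (visible above)
    //   2. MaybeDumpHloModule(*module)   // optional pre-pass proto snapshot
    //   3. compiler->RunHloPasses(...)   // optimization pipeline
    //   4. compiler->RunBackend(...)     // code generation -> Executable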
@@ -566,7 +576,7 @@ Service::ExecuteParallelAndRegisterResult(
se::Stream* stream = index_to_profiled_stream.second;
Executable* executable = executables[device];
const HloModule& module = executable->module();
- HloExecutionProfile hlo_profile(&executable->hlo_profile_printer(),
+ HloExecutionProfile hlo_profile(&executable->hlo_profile_printer_data(),
&executable->hlo_profile_index_map());
TF_RETURN_IF_ERROR(
executable->PopulateExecutionProfile(&hlo_profile, stream->parent()));
@@ -739,9 +749,10 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg,
// Create an HloModuleConfig object for the computation, given the shape of
// the program and the argument allocations.
- TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModuleConfig> module_config,
- CreateModuleConfig(*program_shape, arguments,
- request.execution_options()));
+ TF_ASSIGN_OR_RETURN(
+ std::unique_ptr<HloModuleConfig> module_config,
+ CreateModuleConfig(*program_shape, arguments,
+ request.execution_options(), *user_computation));
VLOG(3) << "ExecuteParallel created HloModuleConfig computation layout: "
<< module_config->entry_computation_layout().ToString();
@@ -849,7 +860,8 @@ tensorflow::Status Service::Execute(const ExecuteRequest* arg,
TF_ASSIGN_OR_RETURN(
std::unique_ptr<HloModuleConfig> module_config,
- CreateModuleConfig(*program_shape, arguments, arg->execution_options()));
+ CreateModuleConfig(*program_shape, arguments, arg->execution_options(),
+ *user_computation));
VLOG(3) << "Execute created HloModuleConfig computation layout: "
<< module_config->entry_computation_layout().ToString();
@@ -913,7 +925,8 @@ tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg,
TF_ASSIGN_OR_RETURN(
std::unique_ptr<HloModuleConfig> module_config,
- CreateModuleConfig(*program_shape, arguments, arg->execution_options()));
+ CreateModuleConfig(*program_shape, arguments, arg->execution_options(),
+ *user_computation));
VLOG(3) << "ExecuteAsync created HloModuleConfig computation layout: "
<< module_config->entry_computation_layout().ToString();
@@ -1233,7 +1246,8 @@ tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg,
}
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModuleConfig> module_config,
- CreateModuleConfig(program_shape, {}, execution_options));
+ CreateModuleConfig(program_shape, {}, execution_options,
+ *user_computation));
// Exclude dead parameter instructions for the purpose of computing constants.
TF_ASSIGN_OR_RETURN(
@@ -1597,4 +1611,15 @@ StatusOr<std::vector<perftools::gputools::StreamExecutor*>> Service::Replicas(
return replicas;
}
+Status Service::MaybeDumpHloModule(const HloModule& module) const {
+ const string xla_dump_prepass_hlo_proto_to =
+ module.config().debug_options().xla_dump_prepass_hlo_proto_to();
+ if (xla_dump_prepass_hlo_proto_to.empty()) {
+ return Status::OK();
+ }
+ HloProto proto = MakeHloProto(module);
+ return protobuf_util::DumpProtoToDirectory(
+ proto, xla_dump_prepass_hlo_proto_to, module.name());
+}
+
} // namespace xla
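The dump is gated on the xla_dump_prepass_hlo_proto_to debug option read in
MaybeDumpHloModule above, which clients set on the DebugOptions sent with a
request. A hedged usage sketch; the directory path here is purely
illustrative:

    // Illustrative only: request a pre-pass HLO proto dump for an execution.
    ExecutionOptions execution_options;
    execution_options.mutable_debug_options()
        ->set_xla_dump_prepass_hlo_proto_to("/tmp/xla_prepass");  // hypothetical path
    // Modules built from requests carrying these options will have
    // MakeHloProto(module) written under /tmp/xla_prepass, keyed by module name.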