author    A. Unique TensorFlower <gardener@tensorflow.org> 2018-03-23 16:13:13 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-03-25 04:22:41 -0700
commit    dd3adb6165605c28f1a993f9093e8f7c99b357c5 (patch)
tree      e24c590b6e346462484d4bbdcac1add5761672d7 /tensorflow/compiler/xla/service/local_service.cc
parent    084c10784887d7c4d467416430626cf7eb333cb8 (diff)
[XLA] Redesign: implement local client and local service interface.
PiperOrigin-RevId: 190291400
Diffstat (limited to 'tensorflow/compiler/xla/service/local_service.cc')
-rw-r--r--  tensorflow/compiler/xla/service/local_service.cc  153
1 file changed, 125 insertions(+), 28 deletions(-)
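For orientation, here is a minimal sketch of how the new XlaComputation-based
CompileExecutable overload added in the diff below might be driven. The
XlaBuilder calls are the standard client API of this period; the
already-constructed LocalService* and the wrapper function are assumptions for
illustration (real callers typically reach this path through
LocalClient::Compile rather than calling the service directly):

#include <memory>
#include <vector>

#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/service/local_service.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"

// Hypothetical helper: compiles f32 x + y through the new overload. The
// `service` pointer is assumed to be a live LocalService instance.
xla::StatusOr<std::unique_ptr<xla::Executable>> CompileAddTwoScalars(
    xla::LocalService* service) {
  xla::XlaBuilder builder("add_two_scalars");
  const xla::Shape f32 = xla::ShapeUtil::MakeShape(xla::F32, {});
  auto x = builder.Parameter(0, f32, "x");
  auto y = builder.Parameter(1, f32, "y");
  builder.Add(x, y);
  TF_ASSIGN_OR_RETURN(xla::XlaComputation computation, builder.Build());

  // One layout per parameter; the overload validates each against the
  // program shape and, on a mismatch, reports the parameter's source
  // metadata via the new ParameterMetadata helper.
  std::vector<const xla::Shape*> argument_layouts = {&f32, &f32};
  xla::ExecutableBuildOptions build_options;
  build_options.set_device_ordinal(0);

  return service->CompileExecutable(computation, argument_layouts,
                                    build_options);
}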
diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc
index 1e2d8eea58..499f280211 100644
--- a/tensorflow/compiler/xla/service/local_service.cc
+++ b/tensorflow/compiler/xla/service/local_service.cc
@@ -69,6 +69,68 @@ LocalService::LocalService(const ServiceOptions& options,
std::unique_ptr<Backend> execute_backend)
: Service(options, std::move(execute_backend)) {}
+namespace {
+
+// Retrieves the parameter metadata for the given computation and parameter
+// number.
+//
+// If the parameter number is invalid for this computation, nullopt is
+// returned. When the return value has_value(), nullptr will never be
+// the held value.
+tensorflow::gtl::optional<const OpMetadata*> ParameterMetadata(
+ const XlaComputation& computation, int parameter_number) {
+ for (const HloComputationProto& comp : computation.proto().computations()) {
+ if (comp.id() == computation.proto().entry_computation_id()) {
+ for (const HloInstructionProto& instr : comp.instructions()) {
+ if (instr.opcode() == HloOpcodeString(HloOpcode::kParameter) &&
+ instr.parameter_number() == parameter_number) {
+ if (!instr.has_metadata()) {
+ return tensorflow::gtl::nullopt;
+ }
+ return &instr.metadata();
+ }
+ }
+ }
+ }
+ return tensorflow::gtl::nullopt;
+}
+
+ExecutionOptions CreateExecutionOptions(
+ const ExecutableBuildOptions& build_options,
+ const ProgramShape* program_shape) {
+ ExecutionOptions execution_options = CreateDefaultExecutionOptions();
+ if (build_options.hlo_profile().has_value()) {
+ execution_options.mutable_debug_options()->set_xla_hlo_profile(
+ *build_options.hlo_profile());
+ }
+ if (build_options.generate_hlo_graph().has_value()) {
+ execution_options.mutable_debug_options()->set_xla_generate_hlo_graph(
+ build_options.generate_hlo_graph().value());
+ }
+ if (build_options.dump_optimized_hlo_proto_to().has_value()) {
+ execution_options.mutable_debug_options()
+ ->set_xla_dump_optimized_hlo_proto_to(
+ build_options.dump_optimized_hlo_proto_to().value());
+ }
+ if (build_options.dump_per_pass_hlo_proto_to().has_value()) {
+ execution_options.mutable_debug_options()
+ ->set_xla_dump_per_pass_hlo_proto_to(
+ build_options.dump_per_pass_hlo_proto_to().value());
+ }
+ if (build_options.result_layout() != nullptr) {
+ *execution_options.mutable_shape_with_output_layout() =
+ *build_options.result_layout();
+ } else {
+ *execution_options.mutable_shape_with_output_layout() =
+ program_shape->result();
+ LayoutUtil::SetToDefaultLayout(
+ execution_options.mutable_shape_with_output_layout());
+ }
+ return execution_options;
+}
+
+} // namespace
+
StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
const ComputationHandle& computation,
const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
@@ -118,34 +180,8 @@ StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
*build_options.result_layout(), program_shape->result()));
}
- ExecutionOptions execution_options = CreateDefaultExecutionOptions();
- if (build_options.hlo_profile().has_value()) {
- execution_options.mutable_debug_options()->set_xla_hlo_profile(
- *build_options.hlo_profile());
- }
- if (build_options.generate_hlo_graph().has_value()) {
- execution_options.mutable_debug_options()->set_xla_generate_hlo_graph(
- build_options.generate_hlo_graph().value());
- }
- if (build_options.dump_optimized_hlo_proto_to().has_value()) {
- execution_options.mutable_debug_options()
- ->set_xla_dump_optimized_hlo_proto_to(
- build_options.dump_optimized_hlo_proto_to().value());
- }
- if (build_options.dump_per_pass_hlo_proto_to().has_value()) {
- execution_options.mutable_debug_options()
- ->set_xla_dump_per_pass_hlo_proto_to(
- build_options.dump_per_pass_hlo_proto_to().value());
- }
- if (build_options.result_layout() != nullptr) {
- *execution_options.mutable_shape_with_output_layout() =
- *build_options.result_layout();
- } else {
- *execution_options.mutable_shape_with_output_layout() =
- program_shape->result();
- LayoutUtil::SetToDefaultLayout(
- execution_options.mutable_shape_with_output_layout());
- }
+ ExecutionOptions execution_options =
+ CreateExecutionOptions(build_options, program_shape.get());
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModuleConfig> module_config,
CreateModuleConfig(*program_shape, argument_layouts,
&execution_options, user_computation));
@@ -159,6 +195,67 @@ StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
build_options.device_allocator());
}
+StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
+ const XlaComputation& computation,
+ const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
+ const ExecutableBuildOptions& build_options) {
+ const HloModuleProto& proto = computation.proto();
+ TF_RET_CHECK(proto.has_program_shape());
+ const ProgramShape& program_shape = proto.program_shape();
+
+ // Validate incoming layouts.
+ if (argument_layouts.size() != program_shape.parameters_size()) {
+ return InvalidArgument(
+ "Invalid number of arguments for computation: expected %d, got %zu.",
+ program_shape.parameters_size(), argument_layouts.size());
+ }
+
+ for (int i = 0; i < argument_layouts.size(); ++i) {
+ const Shape& argument_shape = *argument_layouts[i];
+ TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(argument_shape));
+ if (!ShapeUtil::Compatible(argument_shape, program_shape.parameters(i))) {
+ tensorflow::gtl::optional<const OpMetadata*> metadata =
+ ParameterMetadata(computation, /*parameter_number=*/i);
+ auto metadata_string = [&metadata]() -> string {
+ if (!metadata.has_value()) {
+ return "";
+ }
+ CHECK(metadata.value() != nullptr);
+ const OpMetadata& m = *metadata.value();
+ if (!m.source_file().empty()) {
+ return tensorflow::strings::Printf(
+ " (%s:%d)", m.source_file().c_str(), m.source_line());
+ }
+ return "";
+ };
+ return InvalidArgument(
+ "Invalid argument shape for argument %d%s, expected %s, got %s.", i,
+ metadata_string().c_str(),
+ ShapeUtil::HumanString(program_shape.parameters(i)).c_str(),
+ ShapeUtil::HumanString(argument_shape).c_str());
+ }
+ }
+ if (build_options.result_layout() != nullptr) {
+ TF_RETURN_IF_ERROR(ValidateResultShapeWithLayout(
+ *build_options.result_layout(), program_shape.result()));
+ }
+
+ ExecutionOptions execution_options =
+ CreateExecutionOptions(build_options, &program_shape);
+
+ TF_ASSIGN_OR_RETURN(
+ std::unique_ptr<HloModuleConfig> module_config,
+ CreateModuleConfig(program_shape, argument_layouts, &execution_options));
+
+ TF_ASSIGN_OR_RETURN(
+ se::StreamExecutor * executor,
+ execute_backend_->stream_executor(build_options.device_ordinal()));
+
+ return BuildExecutable(proto, std::move(module_config),
+ execute_backend_.get(), executor,
+ build_options.device_allocator());
+}
+
StatusOr<int> LocalService::ReplicaNumberToDeviceOrdinal(int replica_number) {
return backend().computation_placer()->DeviceId(
replica_number, /*computation=*/0, options_.number_of_replicas(),