diff options
Diffstat (limited to 'tensorflow/compiler/tf2xla/xla_compiler.cc')
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_compiler.cc | 405 |
1 files changed, 405 insertions, 0 deletions
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
new file mode 100644
index 0000000000..e46c2a3148
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -0,0 +1,405 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/xla_compiler.h"

#include <numeric>

#include "tensorflow/compiler/tf2xla/dump_graph.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/public/version.h"

namespace tensorflow {

namespace {

// Returns true if `graph` contains at least one _Retval node, i.e. the
// function produces at least one output. Used to skip compilation of graphs
// with nothing to compute.
bool HasRetval(const Graph& graph) {
  for (const Node* n : graph.nodes()) {
    if (n->type_string() == "_Retval") return true;
  }
  return false;
}

// Verifies that the element types of `xla_shape` agree with `tf_types`.
// If `xla_shape` is a tuple, each tuple element is matched positionally
// against the corresponding TF type; otherwise `tf_types` must contain
// exactly one type matching the (non-tuple) shape's element type.
// Returns an Internal error on any mismatch.
Status CheckSignature(const DataTypeVector& tf_types,
                      const xla::Shape& xla_shape) {
  if (xla::ShapeUtil::IsTuple(xla_shape)) {
    if (xla::ShapeUtil::TupleElementCount(xla_shape) != tf_types.size()) {
      return errors::Internal("XLA shape has ",
                              xla::ShapeUtil::TupleElementCount(xla_shape),
                              " elements while function has ", tf_types.size());
    }
    for (int i = 0; i < tf_types.size(); ++i) {
      xla::PrimitiveType type;
      TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(tf_types[i], &type));
      if (type !=
          xla::ShapeUtil::GetTupleElementShape(xla_shape, i).element_type()) {
        return errors::Internal(
            "element ", i, " has XLA type ",
            xla::ShapeUtil::GetTupleElementShape(xla_shape, i).element_type(),
            " and TensorFlow type ", DataTypeString(tf_types[i]));
      }
    }
  } else {
    if (tf_types.size() != 1) {
      return errors::Internal("Expected singleton type, got ", tf_types.size(),
                              " types");
    }
    xla::PrimitiveType type;
    TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(tf_types[0], &type));
    if (type != xla_shape.element_type()) {
      return errors::Internal("singleton element has XLA type ",
                              xla_shape.element_type(), " and TensorFlow type ",
                              DataTypeString(tf_types[0]));
    }
  }
  return Status::OK();
}

}  // namespace

// Constructs a compiler with the given options. Ownership of the
// XlaCompilationDevice is transferred to device_mgr_ (the raw pointer in
// device_ is a non-owning alias used for later lookups).
XlaCompiler::XlaCompiler(const XlaCompiler::Options& options)
    : client_(options.client),
      allow_cpu_custom_calls_(options.allow_cpu_custom_calls),
      local_executable_has_hybrid_result_(
          options.local_executable_has_hybrid_result),
      next_step_id_(1),
      device_(new XlaCompilationDevice(SessionOptions(), options.device_type)),
      device_mgr_({device_}) {}

XlaCompiler::~XlaCompiler() = default;

// Returns a fresh step id for each symbolic execution; guarded by mu_ so the
// compiler can be driven from multiple threads.
int64 XlaCompiler::NextStepId() {
  mutex_lock l(mu_);
  return next_step_id_++;
}

// Compiles the named TF function into an XLA computation. Arguments are
// passed as separate parameters (use_tuple_arg=false). On success `result`
// holds the computation plus input/output metadata.
Status XlaCompiler::CompileFunction(
    FunctionLibraryRuntime* flr, const NameAttrList& function,
    const std::vector<XlaCompiler::Argument>& args,
    XlaCompiler::CompilationResult* result) {
  const string function_id = Canonicalize(function.name(), function.attr());
  VLOG(1) << "XlaCompiler::CompileFunction " << function_id;

  FunctionLibraryRuntime::Handle handle;
  TF_RETURN_IF_ERROR(
      flr->Instantiate(function.name(), function.attr(), &handle));

  const FunctionBody* fbody = flr->GetFunctionBody(handle);
  CHECK(fbody);

  return CompileFunctionBody(flr, *fbody, function_id, args,
                             /*use_tuple_arg=*/false, result);
}

// Compiles `function` for use as a subcomputation (e.g. a While body or
// condition) with externally imposed input/output XLA shapes. If
// `input_shape` is a tuple, the function's arguments are packed into a
// single tuple parameter; otherwise there must be exactly one argument
// (enforced by CheckSignature on the non-tuple path).
Status XlaCompiler::CompileSubComputation(FunctionLibraryRuntime* flr,
                                          const NameAttrList& function,
                                          const xla::Shape& input_shape,
                                          const xla::Shape& output_shape,
                                          xla::Computation* computation) {
  const string function_id = Canonicalize(function.name(), function.attr());
  VLOG(1) << "XlaCompiler::CompileSubComputation " << function_id;

  FunctionLibraryRuntime::Handle handle;
  TF_RETURN_IF_ERROR(
      flr->Instantiate(function.name(), function.attr(), &handle));

  const FunctionBody* fbody = flr->GetFunctionBody(handle);
  CHECK(fbody);

  // Validate that the caller-provided shapes agree with the function's
  // declared argument/result types before building Arguments from them.
  TF_RETURN_IF_ERROR(CheckSignature(fbody->arg_types, input_shape));
  TF_RETURN_IF_ERROR(CheckSignature(fbody->ret_types, output_shape));

  const bool use_tuple_arg = xla::ShapeUtil::IsTuple(input_shape);

  std::vector<XlaCompiler::Argument> args(fbody->arg_types.size());
  if (use_tuple_arg) {
    for (int i = 0; i < args.size(); ++i) {
      xla::Shape xla_shape =
          xla::ShapeUtil::GetTupleElementShape(input_shape, i);
      args[i].type = fbody->arg_types[i];
      args[i].shape = XLAShapeToTensorShape(xla_shape);
      args[i].parameter = i;
    }
  } else {
    // Non-tuple path: CheckSignature above guarantees exactly one argument.
    args[0].type = fbody->arg_types[0];
    args[0].shape = XLAShapeToTensorShape(input_shape);
    args[0].parameter = 0;
  }

  CompilationResult result;
  TF_RETURN_IF_ERROR(CompileFunctionBody(flr, *fbody, function_id, args,
                                         use_tuple_arg, &result));

  if (!xla::ShapeUtil::Compatible(result.xla_output_shape, output_shape)) {
    return errors::Internal("output shape mismatch from compilation");
  }
  *computation = std::move(result.computation);

  return Status::OK();
}

// Copies the function body's graph, optimizes it, and hands it to
// CompileGraph. Graphs without any _Retval node are skipped (OK status,
// empty result).
Status XlaCompiler::CompileFunctionBody(
    FunctionLibraryRuntime* flr, const FunctionBody& fbody,
    const string& function_id, const std::vector<XlaCompiler::Argument>& args,
    bool use_tuple_arg, XlaCompiler::CompilationResult* result) {
  VLOG(1) << "XlaCompiler::CompileFunctionBody " << function_id;

  std::unique_ptr<Graph> graph(new Graph(flr->GetFunctionLibraryDefinition()));
  CopyGraph(*fbody.graph, graph.get());

  if (VLOG_IS_ON(1)) {
    dump_graph::DumpGraphToFile(
        strings::StrCat("xla_jit_raw_input_", function_id), *graph);
  }

  if (!HasRetval(*graph)) {
    VLOG(1) << "Graph has no retvals. Skipping compilation.";
    return Status::OK();
  }

  // Optimize the graph before running it through the translator.
  // TODO(pbar) The constant folder currently does not simplify int32 operations
  // for devices other than CPU.
  OptimizerOptions opts;
  // NOTE(review): `optimizer` appears unused — optimization below goes
  // through OptimizeGraph(flr, &g), not this object. Confirm whether it can
  // be removed.
  GraphOptimizer optimizer(opts);
  // OptimizeGraph takes/returns a raw pointer, so temporarily release
  // ownership and re-wrap the (possibly replaced) graph afterwards.
  Graph* g = graph.release();
  OptimizeGraph(flr, &g);
  graph.reset(g);

  if (VLOG_IS_ON(1)) {
    dump_graph::DumpGraphToFile(
        strings::StrCat("xla_jit_final_graph_", function_id), *graph);
  }

  VLOG(1) << "====================================================";
  TF_RETURN_IF_ERROR(CompileGraph(function_id, std::move(graph), flr, args,
                                  use_tuple_arg, result));
  VLOG(1) << "====================================================";

  return Status::OK();
}

// Builds a LocalExecutable from a previously produced CompilationResult.
// Argument layouts follow result.xla_input_shapes, with an extra opaque
// parameter appended when the computation needs an XlaLocalRuntimeContext*.
Status XlaCompiler::BuildExecutable(
    const XlaCompiler::CompilationResult& result,
    std::unique_ptr<xla::LocalExecutable>* executable) {
  VLOG(2) << "Compiling to local executable";
  xla::Shape opaque_shape = xla::ShapeUtil::MakeOpaqueShape();

  std::vector<const xla::Shape*> argument_layouts(
      result.xla_input_shapes.size());
  for (int i = 0; i < result.xla_input_shapes.size(); ++i) {
    argument_layouts[i] = &result.xla_input_shapes[i].second;
  }
  if (result.requires_runtime_context) {
    // The final arg is the XlaLocalRuntimeContext*.
    argument_layouts.push_back(&opaque_shape);
  }
  // NOTE(review): assumes client() is always a LocalClient here; the
  // static_cast is unchecked — confirm this invariant holds for all callers.
  xla::LocalClient* local_client = static_cast<xla::LocalClient*>(client());
  xla::ExecutableBuildOptions build_options;
  build_options.set_device_ordinal(local_client->default_device_ordinal());
  build_options.set_platform(local_client->platform());
  build_options.set_result_layout(result.xla_output_shape);
  build_options.set_has_hybrid_result(local_executable_has_hybrid_result_);

  auto compile_result = local_client->Compile(result.computation,
                                              argument_layouts, build_options);
  if (!compile_result.ok()) {
    return compile_result.status();
  }
  *executable = std::move(compile_result.ValueOrDie());
  return Status::OK();
}

namespace {

// Runs `graph` symbolically on the XlaCompilationDevice so that executing
// the TF ops builds up an XLA computation inside `xla_context`. Ownership of
// `graph` is transferred to the local Executor created here.
Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
                    XlaCompilationDevice* device, FunctionLibraryRuntime* flib,
                    int64 step_id) {
  // Resource cleanup is a bit messy. XlaContext is a ref-counted resource; the
  // resource manager takes ownership via Create, and unrefs via Cleanup. We
  // explicitly add a reference to ensure the refcount at entry is maintained at
  // all exit points; Create and Cleanup are always called in this function.
  //
  // The Executor requires us to use ScopedStepContainer. We wrap it in a
  // unique_ptr so we can capture the cleanup status in the end.
  xla_context->Ref();
  Status cleanup_status;
  auto step_container = xla::MakeUnique<ScopedStepContainer>(
      step_id, [&cleanup_status, device](const string& name) {
        cleanup_status = device->resource_manager()->Cleanup(name);
      });
  TF_RETURN_IF_ERROR(device->resource_manager()->Create(
      step_container->name(), XlaContext::kXlaContextResourceName,
      xla_context));

  // Create a LocalExecutor that will own and run the graph.
  LocalExecutorParams exec_params;
  exec_params.device = device;
  exec_params.function_library = flib;
  exec_params.create_kernel = [flib](const NodeDef& ndef, OpKernel** kernel) {
    return flib->CreateKernel(ndef, kernel);
  };
  exec_params.delete_kernel = [](OpKernel* kernel) { delete kernel; };
  Executor* exec_ptr = nullptr;
  TF_RETURN_IF_ERROR(NewLocalExecutor(exec_params, graph.release(), &exec_ptr));
  std::unique_ptr<Executor> exec(exec_ptr);
  // At this point ownership of the graph has been transferred to exec.

  auto runner = [](Executor::Args::Closure c) {
    // TODO(misard) Temporarily just schedule c eagerly while we
    // decide what to do about the fact that the ComputationBuilder is
    // thread-compatible, but we don't really want Op writers to have
    // to remember to acquire a lock around every call to
    // ComputationBuilder. One possibility is to add the (generally
    // useful) ability to run a single-threaded Executor based on an
    // option in LocalExecutorParams. Another is to automagically
    // acquire a lock around ComputationBuilder calls using some
    // wrapper or RAII funny business.
    c();
  };

  // Run the graph symbolically, turning the graph into an XLA computation.
  Executor::Args exec_args;
  exec_args.step_id = step_id;
  exec_args.step_container = step_container.get();
  exec_args.runner = runner;
  TF_RETURN_WITH_CONTEXT_IF_ERROR(
      exec->Run(exec_args),
      "Conversion from TensorFlow graph to XLA computation failed.");

  // Explicitly clean up the step container, to capture the cleanup status.
  step_container.reset();
  return cleanup_status;
}

}  // namespace

// Core entry point: symbolically executes `graph` to produce an XLA
// computation and fills in `result` (computation, input/output shapes,
// per-output constant/non-constant descriptions). Arguments with
// parameter < 0 are compile-time constants and get no XLA parameter.
Status XlaCompiler::CompileGraph(string const& name,
                                 std::unique_ptr<Graph> graph,
                                 FunctionLibraryRuntime* flib,
                                 const std::vector<XlaCompiler::Argument>& args,
                                 bool use_tuple_arg,
                                 CompilationResult* result) {
  VLOG(1) << "Executing graph symbolically to populate ComputationBuilder.";

  // Converts the input shapes into xla::Shape instances.
  result->xla_input_shapes.reserve(args.size());
  for (int i = 0; i < args.size(); ++i) {
    if (args[i].parameter < 0) {
      continue;
    }
    result->xla_input_shapes.push_back(std::make_pair(i, xla::Shape()));
    TF_RETURN_IF_ERROR(TensorShapeToXLAShape(
        args[i].type, args[i].shape, &result->xla_input_shapes.back().second));
  }

  // The context is ref-counted; ScopedUnref balances the reference taken
  // here, while ExecuteGraph manages the resource-manager reference.
  XlaContext* xla_context =
      new XlaContext(client(), name, allow_cpu_custom_calls_);
  core::ScopedUnref xla_context_unref(xla_context);

  TF_RETURN_IF_ERROR(xla_context->BuildArguments(args, use_tuple_arg));

  TF_RETURN_IF_ERROR(
      ExecuteGraph(xla_context, std::move(graph), device_, flib, NextStepId()));

  std::vector<XlaContext::ConstRetVal> compile_time_constants;
  int num_nonconst_outputs;
  TF_RETURN_IF_ERROR(xla_context->CollectResults(
      &result->computation, &result->requires_runtime_context,
      &compile_time_constants, &num_nonconst_outputs));

  // Record outputs that were resolved to compile-time constants; the
  // remaining slots are filled from the computation's result shape below.
  result->outputs.resize(compile_time_constants.size() + num_nonconst_outputs);
  for (const auto& c : compile_time_constants) {
    if (!c.status.ok()) {
      Status constant_status = c.status;
      errors::AppendToMessage(&constant_status,
                              "Failed evaluating constant XLA return "
                              "value ",
                              c.index);
      return constant_status;
    }
    if (c.index >= result->outputs.size()) {
      return errors::InvalidArgument("Invalid argument index ", c.index);
    }
    OutputDescription& output = result->outputs[c.index];
    output.shape = c.value.shape();
    output.is_constant = true;
    output.constant_value = c.value;
  }

  // If every output was a compile-time constant there is no computation to
  // describe further.
  if (result->computation.IsNull()) {
    return Status::OK();
  }

  // Compute the output shapes, if there is a computation with non-constant
  // outputs.
  auto computation_shape = client()->GetComputationShape(result->computation);
  if (!computation_shape.ok()) {
    return computation_shape.status();
  }

  result->xla_output_shape.Swap(
      computation_shape.ValueOrDie()->mutable_result());

  auto num_non_constant_outputs =
      (xla::ShapeUtil::IsTuple(result->xla_output_shape))
          ? xla::ShapeUtil::TupleElementCount(result->xla_output_shape)
          : 1;
  // Tensorflow expects a major-to-minor order of results.
  if (1 == num_non_constant_outputs) {
    xla::Shape& s = result->xla_output_shape;
    auto& minor_to_major = *s.mutable_layout()->mutable_minor_to_major();
    minor_to_major.Resize(xla::ShapeUtil::Rank(s), 0);
    std::iota(minor_to_major.rbegin(), minor_to_major.rend(), 0);
  } else {
    for (xla::Shape& s : *result->xla_output_shape.mutable_tuple_shapes()) {
      auto& minor_to_major = *s.mutable_layout()->mutable_minor_to_major();
      minor_to_major.Resize(xla::ShapeUtil::Rank(s), 0);
      std::iota(minor_to_major.rbegin(), minor_to_major.rend(), 0);
    }
  }

  // Converts the output shapes to TensorShapes, walking non-constant outputs
  // in order and matching them against the computation's result elements.
  int computation_output = 0;
  for (int i = 0; i < result->outputs.size(); ++i) {
    if (!result->outputs[i].is_constant) {
      CHECK_LT(computation_output, num_non_constant_outputs);
      if (num_non_constant_outputs > 1) {
        result->outputs[i].shape =
            XLAShapeToTensorShape(xla::ShapeUtil::GetTupleElementShape(
                result->xla_output_shape, computation_output));
      } else {
        result->outputs[i].shape =
            XLAShapeToTensorShape(result->xla_output_shape);
      }
      ++computation_output;
    }
  }
  return Status::OK();
}

}  // namespace tensorflow