diff options
author | 2018-09-24 15:54:22 -0700 | |
---|---|---|
committer | 2018-09-24 15:58:37 -0700 | |
commit | 9c58005ec86297a1d0a17dc4f7ad7cbae9c47e4b (patch) | |
tree | 6d170b666d522db1c24f255900a5031ad6bca709 /tensorflow/compiler/jit | |
parent | 084f84f2ce44b8a1909b59bcc940652a95cd6fc9 (diff) |
Remove the "constants" input group from _XlaRun; NFC
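It wasn't actually needed: the must-be-constant inputs have already been baked into the compiled kernel, so _XlaRun does not need to receive them.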
PiperOrigin-RevId: 214346217
Diffstat (limited to 'tensorflow/compiler/jit')
-rw-r--r-- | tensorflow/compiler/jit/build_xla_ops_pass.cc | 20 | ||||
-rw-r--r-- | tensorflow/compiler/jit/kernels/xla_ops.cc | 35 | ||||
-rw-r--r-- | tensorflow/compiler/jit/ops/xla_ops.cc | 4 | ||||
-rw-r--r-- | tensorflow/compiler/jit/xla_compile_on_demand_op.cc | 6 | ||||
-rw-r--r-- | tensorflow/compiler/jit/xla_device_ops.h | 3 | ||||
-rw-r--r-- | tensorflow/compiler/jit/xla_launch_util.cc | 13 | ||||
-rw-r--r-- | tensorflow/compiler/jit/xla_launch_util.h | 14 |
7 files changed, 59 insertions, 36 deletions
diff --git a/tensorflow/compiler/jit/build_xla_ops_pass.cc b/tensorflow/compiler/jit/build_xla_ops_pass.cc index a6086f30a1..13a518d0e8 100644 --- a/tensorflow/compiler/jit/build_xla_ops_pass.cc +++ b/tensorflow/compiler/jit/build_xla_ops_pass.cc @@ -55,7 +55,6 @@ static Status BuildXlaCompileNode( } static Status BuildXlaRunNode(const string& nodename, const string& device_name, - const DataTypeVector& constant_dtypes, const DataTypeVector& arg_dtypes, const DataTypeVector& result_dtypes, Graph* graph, Node** node) { @@ -63,7 +62,6 @@ static Status BuildXlaRunNode(const string& nodename, const string& device_name, def.set_name(graph->NewName(nodename)); def.set_op("_XlaRun"); def.set_device(device_name); - AddNodeAttr("Tconstants", constant_dtypes, &def); AddNodeAttr("Targs", arg_dtypes, &def); AddNodeAttr("Tresults", result_dtypes, &def); @@ -98,12 +96,14 @@ static Status GetXlaAttrs(Node* node, int* num_constant_args, return Status::OK(); } -static void CopyIncomingEdges(Graph* g, Node* old_node, Node* new_node) { +static void CopyIncomingEdges(Graph* g, Node* old_node, Node* new_node, + int prefix_to_ignore) { for (const Edge* edge : old_node->in_edges()) { if (edge->IsControlEdge()) { g->AddControlEdge(edge->src(), new_node); - } else { - g->AddEdge(edge->src(), edge->src_output(), new_node, edge->dst_input()); + } else if (edge->dst_input() >= prefix_to_ignore) { + g->AddEdge(edge->src(), edge->src_output(), new_node, + edge->dst_input() - prefix_to_ignore); } } } @@ -145,17 +145,19 @@ static Status ReplaceNodeWithXlaCompileAndRun(Graph* g, Node* n) { } TF_RETURN_IF_ERROR(BuildXlaRunNode(n->name(), n->requested_device(), - const_dtypes, arg_dtypes_with_resources, + arg_dtypes_with_resources, n->output_types(), g, &run_node)); compile_node->set_assigned_device_name(n->assigned_device_name()); run_node->set_assigned_device_name(n->assigned_device_name()); - CopyIncomingEdges(g, /*old_node=*/n, /*new_node=*/compile_node); - CopyIncomingEdges(g, /*old_node=*/n, 
/*new_node=*/run_node); + CopyIncomingEdges(g, /*old_node=*/n, /*new_node=*/compile_node, + /*prefix_to_ignore=*/0); + CopyIncomingEdges(g, /*old_node=*/n, /*new_node=*/run_node, + /*prefix_to_ignore=*/num_constant_args); // The compilation_key output. - g->AddEdge(compile_node, 0, run_node, n->num_inputs()); + g->AddEdge(compile_node, 0, run_node, n->num_inputs() - num_constant_args); MoveOutgoingEdges(g, /*old_node=*/n, /*new_node=*/run_node); g->RemoveNode(n); diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc index c483841a7c..a85006eb03 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.cc +++ b/tensorflow/compiler/jit/kernels/xla_ops.cc @@ -98,11 +98,13 @@ class XlaExecutableClosure { explicit XlaExecutableClosure( xla::LocalClient* client, xla::LocalExecutable* executable, const XlaCompiler::CompilationResult* compilation_result, - std::map<int, OptionalTensor> resource_var_snapshots) + std::map<int, OptionalTensor> resource_var_snapshots, + int num_constant_args) : client_(client), executable_(executable), compilation_result_(compilation_result), - resource_var_snapshots_(std::move(resource_var_snapshots)) {} + resource_var_snapshots_(std::move(resource_var_snapshots)), + num_constant_args_(num_constant_args) {} XlaExecutableClosure(XlaExecutableClosure&&) = default; XlaExecutableClosure& operator=(XlaExecutableClosure&&) = default; @@ -115,12 +117,14 @@ class XlaExecutableClosure { const std::map<int, OptionalTensor>& resource_var_snapshots() const { return resource_var_snapshots_; } + int num_constant_args() const { return num_constant_args_; } private: xla::LocalClient* client_; xla::LocalExecutable* executable_; const XlaCompiler::CompilationResult* compilation_result_; std::map<int, OptionalTensor> resource_var_snapshots_; + int num_constant_args_; TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosure); }; @@ -298,7 +302,8 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { client, 
platform_info_.allocator(), /*allocate_xla_tensors=*/platform_info_.is_on_xla_device(), platform_info_.UseMultipleStreams()); - launch_context.PopulateInputs(ctx, kernel, variables); + launch_context.PopulateInputs(ctx, kernel, variables, + /*missing_ctx_input_prefix=*/0); // Execute the computation. VLOG(2) << "Executing computation."; @@ -317,7 +322,8 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { VLOG(2) << "Elapsed time: " << elapsed << "us"; OP_REQUIRES_OK(ctx, launch_context.PopulateOutputs( - ctx, kernel, run_result.ConsumeValueOrDie())); + ctx, kernel, run_result.ConsumeValueOrDie(), + /*missing_ctx_input_prefix=*/0)); VLOG(1) << "Done"; } @@ -406,7 +412,7 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) { // variables. XlaExecutableClosureStore::KeyT key = XlaExecutableClosureStore::Global()->Produce(XlaExecutableClosure( - client, executable, kernel, std::move(variables))); + client, executable, kernel, std::move(variables), constants_.size())); Allocator* cpu_allocator = [&] { AllocatorAttributes host_alloc_attrs; @@ -440,8 +446,13 @@ void XlaRunOp::Compute(OpKernelContext* ctx) { closure.client(), platform_info_.allocator(), /*allocate_xla_tensors=*/platform_info_.is_on_xla_device(), /*use_multiple_streams=*/platform_info_.UseMultipleStreams()); - launch_context.PopulateInputs(ctx, closure.compilation_result(), - closure.resource_var_snapshots()); + + // We're missing the must-be-constant inputs, tell `PopulateInputs` + // about this. We don't actually need these inputs because they've + // already been baked into the compiled kernel. + launch_context.PopulateInputs( + ctx, closure.compilation_result(), closure.resource_var_snapshots(), + /*missing_ctx_input_prefix=*/closure.num_constant_args()); se::Stream* stream = ctx->op_device_context() ? 
ctx->op_device_context()->stream() : nullptr; @@ -461,8 +472,10 @@ void XlaRunOp::Compute(OpKernelContext* ctx) { VLOG(2) << "Elapsed time in computation: " << elapsed << "us"; OP_REQUIRES_OK( - ctx, launch_context.PopulateOutputs(ctx, closure.compilation_result(), - run_result.ConsumeValueOrDie())); + ctx, + launch_context.PopulateOutputs( + ctx, closure.compilation_result(), run_result.ConsumeValueOrDie(), + /*missing_ctx_input_prefix=*/closure.num_constant_args())); } REGISTER_KERNEL_BUILDER(Name("XlaLaunch").Device(DEVICE_CPU), XlaLocalLaunchOp); @@ -481,8 +494,6 @@ REGISTER_KERNEL_BUILDER(Name("_XlaCompile") XlaCompileOp); REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE_CPU), XlaRunOp); - -REGISTER_KERNEL_BUILDER( - Name("_XlaRun").Device(DEVICE_GPU).HostMemory("constants"), XlaRunOp); +REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE_GPU), XlaRunOp); } // namespace tensorflow diff --git a/tensorflow/compiler/jit/ops/xla_ops.cc b/tensorflow/compiler/jit/ops/xla_ops.cc index 6b4cdaa1c1..bcd1a29b1f 100644 --- a/tensorflow/compiler/jit/ops/xla_ops.cc +++ b/tensorflow/compiler/jit/ops/xla_ops.cc @@ -76,10 +76,6 @@ for now. )"); REGISTER_OP("_XlaRun") - // TODO(sanjoy): We don't need constants and Tconstants and they should be - // removed. 
- .Input("constants: Tconstants") - .Attr("Tconstants: list(type) >= 0") .Input("args: Targs") .Attr("Targs: list(type) >= 0") .Output("results: Tresults") diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc index 3ba48e8c31..3c160aefe5 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc @@ -58,7 +58,8 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx, /*allocate_xla_tensors=*/true, /*use_multiple_streams=*/metadata.UseMultipleStreams()); - launch_context.PopulateInputs(ctx, result, variables); + launch_context.PopulateInputs(ctx, result, variables, + /*missing_ctx_input_prefix=*/0); se::Stream* stream = ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; @@ -79,7 +80,8 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx, TF_RETURN_IF_ERROR(run_result.status()); TF_RETURN_IF_ERROR(launch_context.PopulateOutputs( - ctx, result, run_result.ConsumeValueOrDie())); + ctx, result, run_result.ConsumeValueOrDie(), + /*missing_ctx_input_prefix=*/0)); return Status::OK(); } diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h index 639243973c..2ccee79761 100644 --- a/tensorflow/compiler/jit/xla_device_ops.h +++ b/tensorflow/compiler/jit/xla_device_ops.h @@ -73,8 +73,7 @@ class XlaAssignVariableOp : public AsyncOpKernel { KERNEL); #define REGISTER_XLA_RUN_KERNEL(DEVICE, KERNEL, TYPES) \ - REGISTER_KERNEL_BUILDER( \ - Name("_XlaRun").Device(DEVICE).HostMemory("constants"), KERNEL); + REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE), KERNEL); #define REGISTER_XLA_DEVICE_KERNELS(DEVICE, TYPES) \ REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE), SendOp); \ diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index 07a93e9c39..f5c8bdd6ee 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ 
b/tensorflow/compiler/jit/xla_launch_util.cc @@ -133,7 +133,8 @@ XlaComputationLaunchContext::XlaComputationLaunchContext( void XlaComputationLaunchContext::PopulateInputs( OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel, - const std::map<int, OptionalTensor>& variables) { + const std::map<int, OptionalTensor>& variables, + int missing_ctx_input_prefix) { se::Stream* stream = ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; // Build ShapedBuffers that point directly to the Tensor buffers. @@ -145,12 +146,13 @@ void XlaComputationLaunchContext::PopulateInputs( const Tensor* t; for (int i = 0; i < kernel->xla_input_shapes.size(); ++i) { int arg_num = kernel->input_mapping[i]; + DCHECK_GE(arg_num, missing_ctx_input_prefix); const xla::Shape& shape = kernel->xla_input_shapes[i]; if (variables.count(arg_num)) { t = &(variables.at(arg_num).value); CHECK(t); } else { - t = &(ctx->input(arg_num)); + t = &(ctx->input(arg_num - missing_ctx_input_prefix)); } if (use_multiple_streams_) { @@ -187,7 +189,7 @@ void XlaComputationLaunchContext::PopulateInputs( Status XlaComputationLaunchContext::PopulateOutputs( OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel, - ScopedShapedBuffer output) { + ScopedShapedBuffer output, int missing_ctx_input_prefix) { se::Stream* stream = ctx->op_device_context() ? 
ctx->op_device_context()->stream() : nullptr; @@ -315,7 +317,8 @@ Status XlaComputationLaunchContext::PopulateOutputs( for (int i = 0; i < kernel->resource_updates.size(); ++i) { Allocator* allocator = ctx->device()->GetAllocator({}); const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i]; - if (write.input_index < 0 || write.input_index >= ctx->num_inputs()) { + int actual_input_index = write.input_index - missing_ctx_input_prefix; + if (actual_input_index < 0 || actual_input_index >= ctx->num_inputs()) { return errors::Internal("Invalid input index for variable write."); } @@ -325,7 +328,7 @@ Status XlaComputationLaunchContext::PopulateOutputs( // TODO(b/35625933): tensorflow::Var should contain a PersistentTensor, // not a Tensor. TF_RETURN_IF_ERROR(LookupOrCreateResource<Var>( - ctx, HandleFromInput(ctx, write.input_index), &variable, + ctx, HandleFromInput(ctx, actual_input_index), &variable, [&write](Var** ptr) { *ptr = new Var(write.type); return Status::OK(); diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h index fa7a5e5f89..326d70a027 100644 --- a/tensorflow/compiler/jit/xla_launch_util.h +++ b/tensorflow/compiler/jit/xla_launch_util.h @@ -89,14 +89,24 @@ class XlaComputationLaunchContext { // Add all inputs within `ctx` as XLA arguments (returned by arguments()). // `variables` is a map from TensorFlow argument number to resource variable. + // + // Assumes that the first `missing_ctx_input_prefix` inputs to the kernel are + // missing and adjusts input indices accordingly. All elements in kernel's + // input_mapping must be greater than or equal to `missing_ctx_input_prefix` + // (in other words, no inputs actually required by the kernel can be missing). 
void PopulateInputs(OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel, - const std::map<int, OptionalTensor>& variables); + const std::map<int, OptionalTensor>& variables, + int missing_ctx_input_prefix); // Given the XLA output in `output`, populate all outputs of `ctx`. + // + // Assumes that the first `missing_ctx_input_prefix` inputs to the kernel are + // missing and adjusts input indices accordingly. Status PopulateOutputs(OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel, - xla::ScopedShapedBuffer output); + xla::ScopedShapedBuffer output, + int missing_ctx_input_prefix); // Return the argument list. Only valid after PopulateInputs() has been // called. |