about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/compiler/jit
diff options
context:
space:
mode:
author:    Sanjoy Das <sanjoy@google.com>  2018-09-24 15:54:22 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org>  2018-09-24 15:58:37 -0700
commit 9c58005ec86297a1d0a17dc4f7ad7cbae9c47e4b (patch)
tree   6d170b666d522db1c24f255900a5031ad6bca709 /tensorflow/compiler/jit
parent 084f84f2ce44b8a1909b59bcc940652a95cd6fc9 (diff)
Remove the "constants" input group from _XlaRun; NFC
It wasn't actually needed.

PiperOrigin-RevId: 214346217
Diffstat (limited to 'tensorflow/compiler/jit')
-rw-r--r--  tensorflow/compiler/jit/build_xla_ops_pass.cc        | 20
-rw-r--r--  tensorflow/compiler/jit/kernels/xla_ops.cc           | 35
-rw-r--r--  tensorflow/compiler/jit/ops/xla_ops.cc               |  4
-rw-r--r--  tensorflow/compiler/jit/xla_compile_on_demand_op.cc  |  6
-rw-r--r--  tensorflow/compiler/jit/xla_device_ops.h             |  3
-rw-r--r--  tensorflow/compiler/jit/xla_launch_util.cc           | 13
-rw-r--r--  tensorflow/compiler/jit/xla_launch_util.h            | 14
7 files changed, 59 insertions, 36 deletions
diff --git a/tensorflow/compiler/jit/build_xla_ops_pass.cc b/tensorflow/compiler/jit/build_xla_ops_pass.cc
index a6086f30a1..13a518d0e8 100644
--- a/tensorflow/compiler/jit/build_xla_ops_pass.cc
+++ b/tensorflow/compiler/jit/build_xla_ops_pass.cc
@@ -55,7 +55,6 @@ static Status BuildXlaCompileNode(
}
static Status BuildXlaRunNode(const string& nodename, const string& device_name,
- const DataTypeVector& constant_dtypes,
const DataTypeVector& arg_dtypes,
const DataTypeVector& result_dtypes, Graph* graph,
Node** node) {
@@ -63,7 +62,6 @@ static Status BuildXlaRunNode(const string& nodename, const string& device_name,
def.set_name(graph->NewName(nodename));
def.set_op("_XlaRun");
def.set_device(device_name);
- AddNodeAttr("Tconstants", constant_dtypes, &def);
AddNodeAttr("Targs", arg_dtypes, &def);
AddNodeAttr("Tresults", result_dtypes, &def);
@@ -98,12 +96,14 @@ static Status GetXlaAttrs(Node* node, int* num_constant_args,
return Status::OK();
}
-static void CopyIncomingEdges(Graph* g, Node* old_node, Node* new_node) {
+static void CopyIncomingEdges(Graph* g, Node* old_node, Node* new_node,
+ int prefix_to_ignore) {
for (const Edge* edge : old_node->in_edges()) {
if (edge->IsControlEdge()) {
g->AddControlEdge(edge->src(), new_node);
- } else {
- g->AddEdge(edge->src(), edge->src_output(), new_node, edge->dst_input());
+ } else if (edge->dst_input() >= prefix_to_ignore) {
+ g->AddEdge(edge->src(), edge->src_output(), new_node,
+ edge->dst_input() - prefix_to_ignore);
}
}
}
@@ -145,17 +145,19 @@ static Status ReplaceNodeWithXlaCompileAndRun(Graph* g, Node* n) {
}
TF_RETURN_IF_ERROR(BuildXlaRunNode(n->name(), n->requested_device(),
- const_dtypes, arg_dtypes_with_resources,
+ arg_dtypes_with_resources,
n->output_types(), g, &run_node));
compile_node->set_assigned_device_name(n->assigned_device_name());
run_node->set_assigned_device_name(n->assigned_device_name());
- CopyIncomingEdges(g, /*old_node=*/n, /*new_node=*/compile_node);
- CopyIncomingEdges(g, /*old_node=*/n, /*new_node=*/run_node);
+ CopyIncomingEdges(g, /*old_node=*/n, /*new_node=*/compile_node,
+ /*prefix_to_ignore=*/0);
+ CopyIncomingEdges(g, /*old_node=*/n, /*new_node=*/run_node,
+ /*prefix_to_ignore=*/num_constant_args);
// The compilation_key output.
- g->AddEdge(compile_node, 0, run_node, n->num_inputs());
+ g->AddEdge(compile_node, 0, run_node, n->num_inputs() - num_constant_args);
MoveOutgoingEdges(g, /*old_node=*/n, /*new_node=*/run_node);
g->RemoveNode(n);
diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc
index c483841a7c..a85006eb03 100644
--- a/tensorflow/compiler/jit/kernels/xla_ops.cc
+++ b/tensorflow/compiler/jit/kernels/xla_ops.cc
@@ -98,11 +98,13 @@ class XlaExecutableClosure {
explicit XlaExecutableClosure(
xla::LocalClient* client, xla::LocalExecutable* executable,
const XlaCompiler::CompilationResult* compilation_result,
- std::map<int, OptionalTensor> resource_var_snapshots)
+ std::map<int, OptionalTensor> resource_var_snapshots,
+ int num_constant_args)
: client_(client),
executable_(executable),
compilation_result_(compilation_result),
- resource_var_snapshots_(std::move(resource_var_snapshots)) {}
+ resource_var_snapshots_(std::move(resource_var_snapshots)),
+ num_constant_args_(num_constant_args) {}
XlaExecutableClosure(XlaExecutableClosure&&) = default;
XlaExecutableClosure& operator=(XlaExecutableClosure&&) = default;
@@ -115,12 +117,14 @@ class XlaExecutableClosure {
const std::map<int, OptionalTensor>& resource_var_snapshots() const {
return resource_var_snapshots_;
}
+ int num_constant_args() const { return num_constant_args_; }
private:
xla::LocalClient* client_;
xla::LocalExecutable* executable_;
const XlaCompiler::CompilationResult* compilation_result_;
std::map<int, OptionalTensor> resource_var_snapshots_;
+ int num_constant_args_;
TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosure);
};
@@ -298,7 +302,8 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
client, platform_info_.allocator(),
/*allocate_xla_tensors=*/platform_info_.is_on_xla_device(),
platform_info_.UseMultipleStreams());
- launch_context.PopulateInputs(ctx, kernel, variables);
+ launch_context.PopulateInputs(ctx, kernel, variables,
+ /*missing_ctx_input_prefix=*/0);
// Execute the computation.
VLOG(2) << "Executing computation.";
@@ -317,7 +322,8 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
VLOG(2) << "Elapsed time: " << elapsed << "us";
OP_REQUIRES_OK(ctx, launch_context.PopulateOutputs(
- ctx, kernel, run_result.ConsumeValueOrDie()));
+ ctx, kernel, run_result.ConsumeValueOrDie(),
+ /*missing_ctx_input_prefix=*/0));
VLOG(1) << "Done";
}
@@ -406,7 +412,7 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) {
// variables.
XlaExecutableClosureStore::KeyT key =
XlaExecutableClosureStore::Global()->Produce(XlaExecutableClosure(
- client, executable, kernel, std::move(variables)));
+ client, executable, kernel, std::move(variables), constants_.size()));
Allocator* cpu_allocator = [&] {
AllocatorAttributes host_alloc_attrs;
@@ -440,8 +446,13 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
closure.client(), platform_info_.allocator(),
/*allocate_xla_tensors=*/platform_info_.is_on_xla_device(),
/*use_multiple_streams=*/platform_info_.UseMultipleStreams());
- launch_context.PopulateInputs(ctx, closure.compilation_result(),
- closure.resource_var_snapshots());
+
+ // We're missing the must-be-constant inputs, tell `PopulateInputs`
+ // about this. We don't actually need these inputs because they've
+ // already been baked into the compiled kernel.
+ launch_context.PopulateInputs(
+ ctx, closure.compilation_result(), closure.resource_var_snapshots(),
+ /*missing_ctx_input_prefix=*/closure.num_constant_args());
se::Stream* stream =
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
@@ -461,8 +472,10 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
VLOG(2) << "Elapsed time in computation: " << elapsed << "us";
OP_REQUIRES_OK(
- ctx, launch_context.PopulateOutputs(ctx, closure.compilation_result(),
- run_result.ConsumeValueOrDie()));
+ ctx,
+ launch_context.PopulateOutputs(
+ ctx, closure.compilation_result(), run_result.ConsumeValueOrDie(),
+ /*missing_ctx_input_prefix=*/closure.num_constant_args()));
}
REGISTER_KERNEL_BUILDER(Name("XlaLaunch").Device(DEVICE_CPU), XlaLocalLaunchOp);
@@ -481,8 +494,6 @@ REGISTER_KERNEL_BUILDER(Name("_XlaCompile")
XlaCompileOp);
REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE_CPU), XlaRunOp);
-
-REGISTER_KERNEL_BUILDER(
- Name("_XlaRun").Device(DEVICE_GPU).HostMemory("constants"), XlaRunOp);
+REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE_GPU), XlaRunOp);
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/ops/xla_ops.cc b/tensorflow/compiler/jit/ops/xla_ops.cc
index 6b4cdaa1c1..bcd1a29b1f 100644
--- a/tensorflow/compiler/jit/ops/xla_ops.cc
+++ b/tensorflow/compiler/jit/ops/xla_ops.cc
@@ -76,10 +76,6 @@ for now.
)");
REGISTER_OP("_XlaRun")
- // TODO(sanjoy): We don't need constants and Tconstants and they should be
- // removed.
- .Input("constants: Tconstants")
- .Attr("Tconstants: list(type) >= 0")
.Input("args: Targs")
.Attr("Targs: list(type) >= 0")
.Output("results: Tresults")
diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
index 3ba48e8c31..3c160aefe5 100644
--- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
+++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
@@ -58,7 +58,8 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
/*allocate_xla_tensors=*/true,
/*use_multiple_streams=*/metadata.UseMultipleStreams());
- launch_context.PopulateInputs(ctx, result, variables);
+ launch_context.PopulateInputs(ctx, result, variables,
+ /*missing_ctx_input_prefix=*/0);
se::Stream* stream =
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
@@ -79,7 +80,8 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
TF_RETURN_IF_ERROR(run_result.status());
TF_RETURN_IF_ERROR(launch_context.PopulateOutputs(
- ctx, result, run_result.ConsumeValueOrDie()));
+ ctx, result, run_result.ConsumeValueOrDie(),
+ /*missing_ctx_input_prefix=*/0));
return Status::OK();
}
diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h
index 639243973c..2ccee79761 100644
--- a/tensorflow/compiler/jit/xla_device_ops.h
+++ b/tensorflow/compiler/jit/xla_device_ops.h
@@ -73,8 +73,7 @@ class XlaAssignVariableOp : public AsyncOpKernel {
KERNEL);
#define REGISTER_XLA_RUN_KERNEL(DEVICE, KERNEL, TYPES) \
- REGISTER_KERNEL_BUILDER( \
- Name("_XlaRun").Device(DEVICE).HostMemory("constants"), KERNEL);
+ REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE), KERNEL);
#define REGISTER_XLA_DEVICE_KERNELS(DEVICE, TYPES) \
REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE), SendOp); \
diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc
index 07a93e9c39..f5c8bdd6ee 100644
--- a/tensorflow/compiler/jit/xla_launch_util.cc
+++ b/tensorflow/compiler/jit/xla_launch_util.cc
@@ -133,7 +133,8 @@ XlaComputationLaunchContext::XlaComputationLaunchContext(
void XlaComputationLaunchContext::PopulateInputs(
OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
- const std::map<int, OptionalTensor>& variables) {
+ const std::map<int, OptionalTensor>& variables,
+ int missing_ctx_input_prefix) {
se::Stream* stream =
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
// Build ShapedBuffers that point directly to the Tensor buffers.
@@ -145,12 +146,13 @@ void XlaComputationLaunchContext::PopulateInputs(
const Tensor* t;
for (int i = 0; i < kernel->xla_input_shapes.size(); ++i) {
int arg_num = kernel->input_mapping[i];
+ DCHECK_GE(arg_num, missing_ctx_input_prefix);
const xla::Shape& shape = kernel->xla_input_shapes[i];
if (variables.count(arg_num)) {
t = &(variables.at(arg_num).value);
CHECK(t);
} else {
- t = &(ctx->input(arg_num));
+ t = &(ctx->input(arg_num - missing_ctx_input_prefix));
}
if (use_multiple_streams_) {
@@ -187,7 +189,7 @@ void XlaComputationLaunchContext::PopulateInputs(
Status XlaComputationLaunchContext::PopulateOutputs(
OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
- ScopedShapedBuffer output) {
+ ScopedShapedBuffer output, int missing_ctx_input_prefix) {
se::Stream* stream =
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
@@ -315,7 +317,8 @@ Status XlaComputationLaunchContext::PopulateOutputs(
for (int i = 0; i < kernel->resource_updates.size(); ++i) {
Allocator* allocator = ctx->device()->GetAllocator({});
const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i];
- if (write.input_index < 0 || write.input_index >= ctx->num_inputs()) {
+ int actual_input_index = write.input_index - missing_ctx_input_prefix;
+ if (actual_input_index < 0 || actual_input_index >= ctx->num_inputs()) {
return errors::Internal("Invalid input index for variable write.");
}
@@ -325,7 +328,7 @@ Status XlaComputationLaunchContext::PopulateOutputs(
// TODO(b/35625933): tensorflow::Var should contain a PersistentTensor,
// not a Tensor.
TF_RETURN_IF_ERROR(LookupOrCreateResource<Var>(
- ctx, HandleFromInput(ctx, write.input_index), &variable,
+ ctx, HandleFromInput(ctx, actual_input_index), &variable,
[&write](Var** ptr) {
*ptr = new Var(write.type);
return Status::OK();
diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h
index fa7a5e5f89..326d70a027 100644
--- a/tensorflow/compiler/jit/xla_launch_util.h
+++ b/tensorflow/compiler/jit/xla_launch_util.h
@@ -89,14 +89,24 @@ class XlaComputationLaunchContext {
// Add all inputs within `ctx` as XLA arguments (returned by arguments()).
// `variables` is a map from TensorFlow argument number to resource variable.
+ //
+ // Assumes that the first `missing_ctx_input_prefix` inputs to the kernel are
+ // missing and adjusts input indices accordingly. All elements in kernel's
+ // input_mapping must be greater than or equal to `missing_ctx_input_prefix`
+ // (in other words, no inputs actually required by the kernel can be missing).
void PopulateInputs(OpKernelContext* ctx,
const XlaCompiler::CompilationResult* kernel,
- const std::map<int, OptionalTensor>& variables);
+ const std::map<int, OptionalTensor>& variables,
+ int missing_ctx_input_prefix);
// Given the XLA output in `output`, populate all outputs of `ctx`.
+ //
+ // Assumes that the first `missing_ctx_input_prefix` inputs to the kernel are
+ // missing and adjusts input indices accordingly.
Status PopulateOutputs(OpKernelContext* ctx,
const XlaCompiler::CompilationResult* kernel,
- xla::ScopedShapedBuffer output);
+ xla::ScopedShapedBuffer output,
+ int missing_ctx_input_prefix);
// Return the argument list. Only valid after PopulateInputs() has been
// called.