 tensorflow/compiler/jit/kernels/xla_launch_op.cc |  4 +---
 tensorflow/compiler/jit/xla_compilation_cache.cc |  9 ++++-----
 tensorflow/compiler/jit/xla_compilation_cache.h  |  3 +--
 tensorflow/compiler/tf2xla/kernels/while_op.cc   | 10 ++++------
 tensorflow/compiler/tf2xla/xla_compiler.cc       | 27 +++++----------------------
 5 files changed, 15 insertions(+), 38 deletions(-)
diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
index 4842877d9a..4f3f17df9c 100644
--- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc
+++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
@@ -257,10 +257,8 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
const XlaCompiler::CompilationResult* kernel;
xla::LocalExecutable* executable;
-
OP_REQUIRES_OK(ctx, cache->Compile(options, function_, num_constant_args_,
- variables, ctx, &kernel, &executable,
- /*compile_options=*/nullptr));
+ variables, ctx, &kernel, &executable));
VLOG(1) << "Executing XLA Computation...";
diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc
index bfff52c55a..3717c2cc24 100644
--- a/tensorflow/compiler/jit/xla_compilation_cache.cc
+++ b/tensorflow/compiler/jit/xla_compilation_cache.cc
@@ -238,8 +238,7 @@ Status XlaCompilationCache::Compile(
int num_constant_args, const std::vector<OptionalTensor>& variable_args,
OpKernelContext* ctx,
const XlaCompiler::CompilationResult** compilation_result,
- xla::LocalExecutable** executable,
- const XlaCompiler::CompileOptions* compile_options) {
+ xla::LocalExecutable** executable) {
VLOG(1) << "XlaCompilationCache::Compile " << DebugString();
if (VLOG_IS_ON(2)) {
@@ -298,9 +297,9 @@ Status XlaCompilationCache::Compile(
XlaCompiler compiler(options);
entry->compiled = true;
- entry->compilation_status = compiler.CompileFunction(
- compile_options ? *compile_options : XlaCompiler::CompileOptions(),
- function, args, &entry->compilation_result);
+ entry->compilation_status =
+ compiler.CompileFunction(XlaCompiler::CompileOptions(), function, args,
+ &entry->compilation_result);
}
*compilation_result = &entry->compilation_result;
if (entry->compilation_status.ok() && executable) {
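The ternary above is gone: the cache now unconditionally compiles with default-constructed options. For a caller that hypothetically still needed explicit options, the remaining route is to drive XlaCompiler directly; a hedged sketch, assuming options, function, and args are in scope as in the surrounding code:

    // Hedged sketch, not part of this change: compiling with explicit
    // CompileOptions, which XlaCompilationCache no longer accepts.
    XlaCompiler compiler(options);
    XlaCompiler::CompileOptions compile_options;  // set fields as needed (assumed)
    XlaCompiler::CompilationResult result;
    TF_RETURN_IF_ERROR(
        compiler.CompileFunction(compile_options, function, args, &result));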
diff --git a/tensorflow/compiler/jit/xla_compilation_cache.h b/tensorflow/compiler/jit/xla_compilation_cache.h
index 0858020716..c3a8f68a15 100644
--- a/tensorflow/compiler/jit/xla_compilation_cache.h
+++ b/tensorflow/compiler/jit/xla_compilation_cache.h
@@ -66,8 +66,7 @@ class XlaCompilationCache : public ResourceBase {
const std::vector<OptionalTensor>& variable_args,
OpKernelContext* ctx,
const XlaCompiler::CompilationResult** compilation_result,
- xla::LocalExecutable** executable,
- const XlaCompiler::CompileOptions* compile_options);
+ xla::LocalExecutable** executable);
xla::LocalClient* client() const { return client_; }
const DeviceType& device_type() const { return device_type_; }
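Given the trimmed declaration, a compile-only lookup still appears to work: the guard in the .cc file above (entry->compilation_status.ok() && executable) tolerates a null executable out-param. A sketch under that assumption, with the earlier parameters as declared in this class:

    // Sketch: fetch only the CompilationResult; passing nullptr for the
    // executable out-param relies on the guard in the .cc file.
    const XlaCompiler::CompilationResult* result;
    TF_RETURN_IF_ERROR(cache->Compile(options, function, num_constant_args,
                                      variable_args, ctx, &result,
                                      /*executable=*/nullptr));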
diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc
index 5aea25dc7d..ee466520dd 100644
--- a/tensorflow/compiler/tf2xla/kernels/while_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc
@@ -201,12 +201,10 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
OP_REQUIRES_OK(ctx, compiler->CompileFunction(cond_options, cond_name_attr_,
arguments, &cond));
- OP_REQUIRES(ctx, body.xla_input_shapes.size() == 1,
- errors::FailedPrecondition("Expected one input shape"));
- xla::Shape body_input_shape = body.xla_input_shapes[0];
- OP_REQUIRES(ctx, cond.xla_input_shapes.size() == 1,
- errors::FailedPrecondition("Expected one input shape"));
- xla::Shape cond_input_shape = cond.xla_input_shapes[0];
+ xla::Shape body_input_shape =
+ xla::ShapeUtil::MakeTupleShape(body.xla_input_shapes);
+ xla::Shape cond_input_shape =
+ xla::ShapeUtil::MakeTupleShape(cond.xla_input_shapes);
VLOG(2) << "Body shape: " << xla::ShapeUtil::HumanString(body_input_shape)
<< " -> " << xla::ShapeUtil::HumanString(body.xla_output_shape);
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index 310cf20ec1..c55719be55 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -316,22 +316,15 @@ Status BuildArguments(const Graph& graph,
return Status::OK();
}
- std::vector<xla::Shape> arg_shapes;
- arg_shapes.reserve(parameters.size());
+ input_shapes->resize(parameters.size());
input_mapping->resize(parameters.size());
for (std::vector<int>::size_type i = 0; i < parameters.size(); ++i) {
const XlaCompiler::Argument& arg = args[parameters[i]];
// Computes the shapes of non-constant arguments.
- arg_shapes.push_back(arg.shape);
+ (*input_shapes)[i] = arg.shape;
(*input_mapping)[i] = parameters[i];
}
- if (use_tuple_arg) {
- input_shapes->push_back(xla::ShapeUtil::MakeTupleShape(arg_shapes));
- } else {
- *input_shapes = arg_shapes;
- }
-
// Use the _Arg nodes in the graph to resolve core assignments.
for (const Node* n : graph.nodes()) {
if (StringPiece(n->type_string()) != "_Arg") continue;
@@ -355,19 +348,9 @@ Status BuildArguments(const Graph& graph,
// Build parameter handles for non-constant arguments.
std::vector<xla::ComputationDataHandle> arg_handles(parameters.size());
if (use_tuple_arg) {
- xla::OpSharding tuple_sharding;
- tuple_sharding.set_type(xla::OpSharding::Type::OpSharding_Type_TUPLE);
- for (int64 parameter : parameters) {
- const int core = (*arg_cores)[parameter];
- const int root_device = 0;
- *tuple_sharding.add_tuple_shardings() =
- core == -1 ? xla::sharding_builder::AssignDevice(root_device)
- : xla::sharding_builder::AssignDevice(core);
- }
- xla::ScopedShardingAssignment assign_tuple_sharding(builder,
- tuple_sharding);
+ xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(*input_shapes);
xla::ComputationDataHandle tuple =
- builder->Parameter(0, (*input_shapes)[0], "arg_tuple");
+ builder->Parameter(0, tuple_shape, "arg_tuple");
for (std::vector<int>::size_type i = 0; i < parameters.size(); ++i) {
const int core = (*arg_cores)[parameters[i]];
xla::ScopedShardingAssignment assign_sharding(
@@ -391,7 +374,7 @@ Status BuildArguments(const Graph& graph,
for (std::vector<int>::size_type i = 0; i < parameters.size(); ++i) {
const XlaCompiler::Argument& arg = args[parameters[i]];
VLOG(2) << " XLA arg " << i
- << " shape: " << xla::ShapeUtil::HumanString(arg_shapes[i])
+ << " shape: " << xla::ShapeUtil::HumanString((*input_shapes)[i])
<< " name: " << arg.name << " TF arg " << parameters[i];
XlaExpression& arg_expression = (*arg_expressions)[parameters[i]];
switch (arg.kind) {
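Since input_shapes now carries one entry per parameter in both modes, the tuple-argument path derives its tuple shape on the spot and no longer needs the per-element tuple sharding removed above; per-parameter sharding is still applied inside the loop. A condensed sketch of that path, with the GetTupleElement unpacking assumed from the elided context:

    // Sketch of the tuple-argument path after this change: build the
    // tuple shape from per-arg shapes, declare one parameter, unpack it.
    // (Per-core ScopedShardingAssignment elided for brevity.)
    xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(*input_shapes);
    xla::ComputationDataHandle tuple =
        builder->Parameter(0, tuple_shape, "arg_tuple");
    for (std::vector<int>::size_type i = 0; i < parameters.size(); ++i) {
      arg_handles[i] = builder->GetTupleElement(tuple, i);  // assumed unpacking
    }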