From 7da384ab3d9577311c7e073ea29ad6eab6bccfc9 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 10 Jan 2018 19:55:40 -0800 Subject: Automated g4 rollback of changelist 181548597 PiperOrigin-RevId: 181553949 --- tensorflow/compiler/jit/kernels/xla_launch_op.cc | 4 +--- tensorflow/compiler/jit/xla_compilation_cache.cc | 9 ++++---- tensorflow/compiler/jit/xla_compilation_cache.h | 3 +-- tensorflow/compiler/tf2xla/kernels/while_op.cc | 14 ++++-------- tensorflow/compiler/tf2xla/xla_compiler.cc | 27 +++++------------------- 5 files changed, 15 insertions(+), 42 deletions(-) diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc index 4842877d9a..4f3f17df9c 100644 --- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc +++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc @@ -257,10 +257,8 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { const XlaCompiler::CompilationResult* kernel; xla::LocalExecutable* executable; - OP_REQUIRES_OK(ctx, cache->Compile(options, function_, num_constant_args_, - variables, ctx, &kernel, &executable, - /*compile_options=*/nullptr)); + variables, ctx, &kernel, &executable)); VLOG(1) << "Executing XLA Computation..."; diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index bfff52c55a..3717c2cc24 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -238,8 +238,7 @@ Status XlaCompilationCache::Compile( int num_constant_args, const std::vector& variable_args, OpKernelContext* ctx, const XlaCompiler::CompilationResult** compilation_result, - xla::LocalExecutable** executable, - const XlaCompiler::CompileOptions* compile_options) { + xla::LocalExecutable** executable) { VLOG(1) << "XlaCompilationCache::Compile " << DebugString(); if (VLOG_IS_ON(2)) { @@ -298,9 +297,9 @@ Status XlaCompilationCache::Compile( XlaCompiler compiler(options); entry->compiled = true; - entry->compilation_status = compiler.CompileFunction( - compile_options ? *compile_options : XlaCompiler::CompileOptions(), - function, args, &entry->compilation_result); + entry->compilation_status = + compiler.CompileFunction(XlaCompiler::CompileOptions(), function, args, + &entry->compilation_result); } *compilation_result = &entry->compilation_result; if (entry->compilation_status.ok() && executable) { diff --git a/tensorflow/compiler/jit/xla_compilation_cache.h b/tensorflow/compiler/jit/xla_compilation_cache.h index 0858020716..c3a8f68a15 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.h +++ b/tensorflow/compiler/jit/xla_compilation_cache.h @@ -66,8 +66,7 @@ class XlaCompilationCache : public ResourceBase { const std::vector& variable_args, OpKernelContext* ctx, const XlaCompiler::CompilationResult** compilation_result, - xla::LocalExecutable** executable, - const XlaCompiler::CompileOptions* compile_options); + xla::LocalExecutable** executable); xla::LocalClient* client() const { return client_; } const DeviceType& device_type() const { return device_type_; } diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc index 66730b9276..ee466520dd 100644 --- a/tensorflow/compiler/tf2xla/kernels/while_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc @@ -201,16 +201,10 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { OP_REQUIRES_OK(ctx, compiler->CompileFunction(cond_options, cond_name_attr_, arguments, &cond)); - OP_REQUIRES(ctx, body.xla_input_shapes.size() == 1, - errors::FailedPrecondition("Expected one input shape")); - xla::Shape body_input_shape = body.xla_input_shapes[0]; - OP_REQUIRES(ctx, xla::ShapeUtil::IsTuple(body_input_shape), - errors::FailedPrecondition("Expected tuple shape")); - OP_REQUIRES(ctx, cond.xla_input_shapes.size() == 1, - errors::FailedPrecondition("Expected one input shape")); - xla::Shape cond_input_shape = cond.xla_input_shapes[0]; - OP_REQUIRES(ctx, xla::ShapeUtil::IsTuple(cond_input_shape), - errors::FailedPrecondition("Expected tuple shape")); + xla::Shape body_input_shape = + xla::ShapeUtil::MakeTupleShape(body.xla_input_shapes); + xla::Shape cond_input_shape = + xla::ShapeUtil::MakeTupleShape(cond.xla_input_shapes); VLOG(2) << "Body shape: " << xla::ShapeUtil::HumanString(body_input_shape) << " -> " << xla::ShapeUtil::HumanString(body.xla_output_shape); diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 310cf20ec1..c55719be55 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -316,22 +316,15 @@ Status BuildArguments(const Graph& graph, return Status::OK(); } - std::vector arg_shapes; - arg_shapes.reserve(parameters.size()); + input_shapes->resize(parameters.size()); input_mapping->resize(parameters.size()); for (std::vector::size_type i = 0; i < parameters.size(); ++i) { const XlaCompiler::Argument& arg = args[parameters[i]]; // Computes the shapes of non-constant arguments. - arg_shapes.push_back(arg.shape); + (*input_shapes)[i] = arg.shape; (*input_mapping)[i] = parameters[i]; } - if (use_tuple_arg) { - input_shapes->push_back(xla::ShapeUtil::MakeTupleShape(arg_shapes)); - } else { - *input_shapes = arg_shapes; - } - // Use the _Arg nodes in the graph to resolve core assignments. for (const Node* n : graph.nodes()) { if (StringPiece(n->type_string()) != "_Arg") continue; @@ -355,19 +348,9 @@ Status BuildArguments(const Graph& graph, // Build parameter handles for non-constant arguments. std::vector arg_handles(parameters.size()); if (use_tuple_arg) { - xla::OpSharding tuple_sharding; - tuple_sharding.set_type(xla::OpSharding::Type::OpSharding_Type_TUPLE); - for (int64 parameter : parameters) { - const int core = (*arg_cores)[parameter]; - const int root_device = 0; - *tuple_sharding.add_tuple_shardings() = - core == -1 ? xla::sharding_builder::AssignDevice(root_device) - : xla::sharding_builder::AssignDevice(core); - } - xla::ScopedShardingAssignment assign_tuple_sharding(builder, - tuple_sharding); + xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(*input_shapes); xla::ComputationDataHandle tuple = - builder->Parameter(0, (*input_shapes)[0], "arg_tuple"); + builder->Parameter(0, tuple_shape, "arg_tuple"); for (std::vector::size_type i = 0; i < parameters.size(); ++i) { const int core = (*arg_cores)[parameters[i]]; xla::ScopedShardingAssignment assign_sharding( @@ -391,7 +374,7 @@ Status BuildArguments(const Graph& graph, for (std::vector::size_type i = 0; i < parameters.size(); ++i) { const XlaCompiler::Argument& arg = args[parameters[i]]; VLOG(2) << " XLA arg " << i - << " shape: " << xla::ShapeUtil::HumanString(arg_shapes[i]) + << " shape: " << xla::ShapeUtil::HumanString((*input_shapes)[i]) << " name: " << arg.name << " TF arg " << parameters[i]; XlaExpression& arg_expression = (*arg_expressions)[parameters[i]]; switch (arg.kind) { -- cgit v1.2.3