author    James Qin <jamesqin@google.com> 2017-10-06 11:04:52 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2017-10-06 11:41:35 -0700
commit    af6e00f7c661c7d93bacfc3adc40d17f0faeb9b4 (patch)
tree      95c1bd53bebf4b8e7700dd76508c0f4ad831f19c
parent    84b579e1d14760fc2a313c8e1d7ca100f74945a1 (diff)
Fix a minor issue with allreduce
PiperOrigin-RevId: 171314944
-rw-r--r-- tensorflow/compiler/xla/client/client.cc | 1
-rw-r--r-- tensorflow/compiler/xla/client/computation_builder.cc | 13
-rw-r--r-- tensorflow/compiler/xla/client/computation_builder.h | 10
-rw-r--r-- tensorflow/compiler/xla/service/cpu/ir_emitter.cc | 242
-rw-r--r-- tensorflow/compiler/xla/service/cpu/ir_emitter.h | 9
-rw-r--r-- tensorflow/compiler/xla/service/shape_inference.cc | 9
-rw-r--r-- tensorflow/compiler/xla/tests/literal_test_util.cc | 73
-rw-r--r-- tensorflow/compiler/xla/tests/literal_test_util.h | 2
-rw-r--r-- tensorflow/compiler/xla/tests/reshape_test.cc | 18
-rw-r--r-- tensorflow/core/common_runtime/executor.cc | 2
10 files changed, 187 insertions(+), 192 deletions(-)
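
The bulk of this change converts IrEmitter::EmitTargetAddressForOp from a Status-returning method that recorded its result in emitted_value_ as a side effect into one returning StatusOr<llvm::Value*>, so each handler now binds the address with TF_ASSIGN_OR_RETURN and records it explicitly. Below is a minimal stand-alone sketch of that pattern; StatusOrLite, Op, and HandleOp are simplified stand-ins for illustration, not the real xla::StatusOr or TF_ASSIGN_OR_RETURN machinery.

// Sketch: a helper that used to fill a map as a side effect now returns
// StatusOr<T>, and the caller does the bookkeeping itself.
#include <cassert>
#include <map>
#include <string>

struct Status {
  bool ok = true;
  std::string message;
};

template <typename T>
class StatusOrLite {
 public:
  StatusOrLite(T value) : value_(value) {}
  StatusOrLite(Status status) : status_(status) {}
  bool ok() const { return status_.ok; }
  T ValueOrDie() const { assert(ok()); return value_; }
 private:
  T value_{};
  Status status_;
};

struct Op {};
std::map<const Op*, int*> emitted_value;  // stand-in for emitted_value_

// After the change, the helper only computes and returns the target address;
// previously it also wrote emitted_value_[op] before returning Status::OK().
StatusOrLite<int*> EmitTargetAddressForOp(const Op* op) {
  static int buffer = 0;
  return &buffer;  // stand-in for the buffer-assignment lookup
}

// Each rewritten handler binds the value and updates the map explicitly,
// mirroring TF_ASSIGN_OR_RETURN(llvm::Value* target_address, ...) below.
Status HandleOp(const Op* op) {
  StatusOrLite<int*> addr = EmitTargetAddressForOp(op);
  if (!addr.ok()) return Status{false, "emit failed"};
  emitted_value[op] = addr.ValueOrDie();
  return Status{};
}

int main() {
  Op op;
  return HandleOp(&op).ok ? 0 : 1;
}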
diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc
index 7db2ea79fb..387253617e 100644
--- a/tensorflow/compiler/xla/client/client.cc
+++ b/tensorflow/compiler/xla/client/client.cc
@@ -206,7 +206,6 @@ StatusOr<std::unique_ptr<GlobalData>> Client::Execute(
*request.mutable_execution_options() = *execution_options;
}
for (GlobalData* argument : arguments) {
- CHECK(argument != nullptr) << "Argument pointers must not be null.";
*request.add_arguments() = argument->handle();
}
diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc
index 925dcd36c0..15a713513f 100644
--- a/tensorflow/compiler/xla/client/computation_builder.cc
+++ b/tensorflow/compiler/xla/client/computation_builder.cc
@@ -489,16 +489,6 @@ ComputationDataHandle ComputationBuilder::Collapse(
}
std::unique_ptr<Shape> original_shape = shape_or_status.ConsumeValueOrDie();
- VLOG(3) << "original shape: " << ShapeUtil::HumanString(*original_shape);
- VLOG(3) << "dims to collapse: "
- << tensorflow::str_util::Join(dims_to_collapse, ",");
-
- if (dims_to_collapse.size() <= 1) {
- // Not collapsing anything, trivially we can return the operand versus
- // enqueueing a trivial reshape.
- return operand;
- }
-
std::vector<int64> new_sizes;
for (int i = 0; i < ShapeUtil::Rank(*original_shape); ++i) {
if (i <= dims_to_collapse.front() || i > dims_to_collapse.back()) {
@@ -508,9 +498,6 @@ ComputationDataHandle ComputationBuilder::Collapse(
}
}
- VLOG(3) << "new sizes: [" << tensorflow::str_util::Join(new_sizes, ",")
- << "]";
-
return Reshape(operand, new_sizes);
}
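
With the size()<=1 early return gone, collapsing a single dimension now flows through the same Reshape path as the multi-dimension case. A stand-alone sketch of the collapsed-shape computation above, with plain int64_t vectors standing in for xla::Shape:

#include <cassert>
#include <cstdint>
#include <vector>

// Dimensions outside [to_collapse.front(), to_collapse.back()] are copied
// through; sizes inside that contiguous range fold into one dimension.
std::vector<int64_t> CollapsedSizes(const std::vector<int64_t>& dims,
                                    const std::vector<int64_t>& to_collapse) {
  std::vector<int64_t> new_sizes;
  for (size_t i = 0; i < dims.size(); ++i) {
    // The empty() guard is an addition of this sketch, not in the original.
    if (to_collapse.empty() ||
        static_cast<int64_t>(i) <= to_collapse.front() ||
        static_cast<int64_t>(i) > to_collapse.back()) {
      new_sizes.push_back(dims[i]);
    } else {
      // Dimension i folds into the most recently emitted dimension.
      new_sizes.back() *= dims[i];
    }
  }
  return new_sizes;
}

int main() {
  // {256, 2, 3} collapsing {0, 1} => {512, 3}, matching the header comment.
  assert((CollapsedSizes({256, 2, 3}, {0, 1}) ==
          std::vector<int64_t>{512, 3}));
  // A single dimension now also flows through: {256} collapsing {0} => {256}.
  assert((CollapsedSizes({256}, {0}) == std::vector<int64_t>{256}));
}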
diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h
index 7014685ea5..73972c1290 100644
--- a/tensorflow/compiler/xla/client/computation_builder.h
+++ b/tensorflow/compiler/xla/client/computation_builder.h
@@ -201,16 +201,6 @@ class ComputationBuilder {
// {x=1024, y=32} by collapsing dims {0, 1, 2}. Collapsing dimensions must
// be a consecutive, in-order subsequence of the operand dimensions.
//
- // Note that collapsing a single dimension does nothing:
- //
- // {256} collapsing {0} => {256}
- // {1} collapsing {0} => {1}
- //
- // Collapsing multiple dimensions produces a single result dimension:
- //
- // {256, 2} collapsing {0,1} => {512}
- // {256, 2, 3} collapsing {0,1} => {512, 3}
- //
// This could potentially cause data to be moved -- it provides a more
// structured form of reshaping than an arbitrary Reshape operation.
ComputationDataHandle Collapse(const ComputationDataHandle& operand,
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index e4fb7c0496..4375f13a0e 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -291,7 +291,8 @@ Status IrEmitter::HandleConstant(HloInstruction* constant,
Status IrEmitter::HandleCopy(HloInstruction* copy) {
if (ShapeUtil::IsTuple(copy->shape())) {
// kCopy shallow copies a tuple so just memcpy the top-level buffer.
- TF_RETURN_IF_ERROR(EmitTargetAddressForOp(copy));
+ TF_ASSIGN_OR_RETURN(llvm::Value * copy_value, EmitTargetAddressForOp(copy));
+ emitted_value_[copy] = copy_value;
return EmitMemcpy(*(copy->operand(0)), *copy);
} else {
// Use the elemental emitter for non-tuple shapes.
@@ -394,7 +395,9 @@ Status IrEmitter::HandleSelect(HloInstruction* select, HloInstruction* pred,
TF_RET_CHECK(pred->shape().element_type() == PRED);
if (ShapeUtil::IsTuple(select->shape())) {
- TF_RETURN_IF_ERROR(EmitTargetAddressForOp(select));
+ TF_ASSIGN_OR_RETURN(llvm::Value * output_address,
+ EmitTargetAddressForOp(select));
+ emitted_value_[select] = output_address;
llvm_ir::EmitTupleSelect(GetIrArrayForOp(select), GetIrArrayForOp(pred),
GetEmittedValueFor(on_true),
GetEmittedValueFor(on_false), &ir_builder_);
@@ -411,8 +414,8 @@ Status IrEmitter::HandleInfeed(HloInstruction* infeed) {
// The infeed operation produces data (dequeued from the infeed queue) at this
// address, which has been provided by buffer assignment.
- TF_RETURN_IF_ERROR(EmitTargetAddressForOp(infeed));
- llvm_ir::IrArray infeed_array = GetIrArrayForOp(infeed);
+ TF_ASSIGN_OR_RETURN(llvm::Value * target_address,
+ EmitTargetAddressForOp(infeed));
if (ShapeUtil::IsTuple(shape)) {
TF_RET_CHECK(!ShapeUtil::IsNestedTuple(shape));
@@ -430,9 +433,9 @@ Status IrEmitter::HandleInfeed(HloInstruction* infeed) {
ShapeUtil::GetTupleElementShape(shape, i);
// Only the outer tuple buffer's target address is obtained from
- // GetEmittedValueFor, to handle the case when Infeed is the root
- // instruction. Target addresses for internal elements can be obtained
- // from EmitTempBufferPointer.
+ // EmitTargetAddressForOp to handle the case when Infeed is the
+ // root instruction. Target addresses for internal elements can
+ // be obtained from EmitTempBufferPointer.
llvm::Value* tuple_element_address =
EmitTempBufferPointer(buffer, tuple_element_shape);
@@ -442,12 +445,15 @@ Status IrEmitter::HandleInfeed(HloInstruction* infeed) {
tuple_element_addresses.push_back(tuple_element_address);
}
- llvm_ir::EmitTuple(infeed_array, tuple_element_addresses, &ir_builder_);
+ llvm_ir::EmitTuple(llvm_ir::IrArray(target_address, shape),
+ tuple_element_addresses, &ir_builder_);
} else {
- TF_RETURN_IF_ERROR(EmitXfeedTransfer(XfeedKind::kInfeed, shape,
- GetEmittedValueFor(infeed)));
+ TF_RETURN_IF_ERROR(
+ EmitXfeedTransfer(XfeedKind::kInfeed, shape, target_address));
}
+ emitted_value_[infeed] = target_address;
+
return Status::OK();
}
@@ -561,12 +567,15 @@ Status IrEmitter::HandleSort(HloInstruction* sort, HloInstruction* operand) {
Status IrEmitter::HandleTuple(
HloInstruction* tuple,
tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
- TF_RETURN_IF_ERROR(EmitTargetAddressForOp(tuple));
+ TF_ASSIGN_OR_RETURN(llvm::Value * target_address,
+ EmitTargetAddressForOp(tuple));
std::vector<llvm::Value*> base_ptrs;
for (auto operand : operands) {
base_ptrs.push_back(GetEmittedValueFor(operand));
}
- llvm_ir::EmitTuple(GetIrArrayForOp(tuple), base_ptrs, &ir_builder_);
+ llvm_ir::EmitTuple(llvm_ir::IrArray(target_address, tuple->shape()),
+ base_ptrs, &ir_builder_);
+ emitted_value_[tuple] = target_address;
return Status::OK();
}
@@ -883,8 +892,11 @@ Status IrEmitter::HandleDot(HloInstruction* dot, HloInstruction* lhs,
llvm_ir::IrArray lhs_array(GetIrArrayForOp(lhs));
llvm_ir::IrArray rhs_array(GetIrArrayForOp(rhs));
- TF_RETURN_IF_ERROR(EmitTargetAddressForOp(dot));
- llvm_ir::IrArray target_array = GetIrArrayForOp(dot);
+ Shape target_shape = dot->shape();
+ TF_ASSIGN_OR_RETURN(llvm::Value * target_address,
+ EmitTargetAddressForOp(dot));
+ llvm_ir::IrArray target_array(target_address, target_shape);
+ AddAliasingInformationToIrArray(*dot, &target_array);
VLOG(2) << "HandleDot: ";
VLOG(2) << " lhs operand: "
@@ -895,10 +907,13 @@ Status IrEmitter::HandleDot(HloInstruction* dot, HloInstruction* lhs,
<< llvm_ir::DumpToString(*target_array.GetBasePointer());
// Dot operation is complicated so we delegate to a helper class.
- return DotOpEmitter::EmitDotOperation(
+ TF_RETURN_IF_ERROR(DotOpEmitter::EmitDotOperation(
*dot, /*transpose_lhs=*/false, /*transpose_rhs=*/false, target_array,
lhs_array, rhs_array, GetExecutableRunOptionsArgument(), &ir_builder_,
- hlo_module_config_);
+ hlo_module_config_));
+
+ emitted_value_[dot] = target_address;
+ return Status::OK();
}
Status IrEmitter::HandleConvolution(HloInstruction* convolution,
@@ -926,7 +941,8 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution,
bool one_dim_convolution = lhs_shape.dimensions_size() == 3;
llvm::Value* lhs_address = GetEmittedValueFor(lhs);
llvm::Value* rhs_address = GetEmittedValueFor(rhs);
- TF_RETURN_IF_ERROR(EmitTargetAddressForOp(convolution));
+ TF_ASSIGN_OR_RETURN(llvm::Value * target_address,
+ EmitTargetAddressForOp(convolution));
const ConvolutionDimensionNumbers& dnums =
convolution->convolution_dimension_numbers();
@@ -1008,33 +1024,35 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution,
conv_func->setDoesNotThrow();
conv_func->setOnlyAccessesArgMemory();
ir_builder_.CreateCall(
- conv_func, {
- GetExecutableRunOptionsArgument(),
- ir_builder_.CreateBitCast(
- GetEmittedValueFor(convolution), float_ptr_type),
- ir_builder_.CreateBitCast(lhs_address, float_ptr_type),
- ir_builder_.CreateBitCast(rhs_address, float_ptr_type),
- ir_builder_.getInt64(input_batch),
- ir_builder_.getInt64(input_rows),
- ir_builder_.getInt64(input_cols),
- ir_builder_.getInt64(input_channels),
- ir_builder_.getInt64(kernel_rows),
- ir_builder_.getInt64(kernel_cols),
- ir_builder_.getInt64(kernel_channels),
- ir_builder_.getInt64(kernel_filters),
- ir_builder_.getInt64(output_rows),
- ir_builder_.getInt64(output_cols),
- ir_builder_.getInt64(row_stride),
- ir_builder_.getInt64(col_stride),
- ir_builder_.getInt64(padding_top),
- ir_builder_.getInt64(padding_bottom),
- ir_builder_.getInt64(padding_left),
- ir_builder_.getInt64(padding_right),
- ir_builder_.getInt64(lhs_row_dilation),
- ir_builder_.getInt64(lhs_col_dilation),
- ir_builder_.getInt64(rhs_row_dilation),
- ir_builder_.getInt64(rhs_col_dilation),
- });
+ conv_func,
+ {
+ GetExecutableRunOptionsArgument(),
+ ir_builder_.CreateBitCast(target_address, float_ptr_type),
+ ir_builder_.CreateBitCast(lhs_address, float_ptr_type),
+ ir_builder_.CreateBitCast(rhs_address, float_ptr_type),
+ ir_builder_.getInt64(input_batch),
+ ir_builder_.getInt64(input_rows),
+ ir_builder_.getInt64(input_cols),
+ ir_builder_.getInt64(input_channels),
+ ir_builder_.getInt64(kernel_rows),
+ ir_builder_.getInt64(kernel_cols),
+ ir_builder_.getInt64(kernel_channels),
+ ir_builder_.getInt64(kernel_filters),
+ ir_builder_.getInt64(output_rows),
+ ir_builder_.getInt64(output_cols),
+ ir_builder_.getInt64(row_stride),
+ ir_builder_.getInt64(col_stride),
+ ir_builder_.getInt64(padding_top),
+ ir_builder_.getInt64(padding_bottom),
+ ir_builder_.getInt64(padding_left),
+ ir_builder_.getInt64(padding_right),
+ ir_builder_.getInt64(lhs_row_dilation),
+ ir_builder_.getInt64(lhs_col_dilation),
+ ir_builder_.getInt64(rhs_row_dilation),
+ ir_builder_.getInt64(rhs_col_dilation),
+ });
+ target_address->setName(AsStringRef(IrName(convolution)));
+ emitted_value_[convolution] = target_address;
return Status::OK();
}
@@ -1349,7 +1367,9 @@ Status IrEmitter::HandleBatchNormTraining(HloInstruction* batch_norm_training) {
mean_array, &ir_builder_)
.EmitLoop(IrName(batch_norm_training, "mean_var")));
- TF_RETURN_IF_ERROR(EmitTargetAddressForOp(batch_norm_training));
+ TF_ASSIGN_OR_RETURN(llvm::Value * target_address,
+ EmitTargetAddressForOp(batch_norm_training));
+
TF_ASSIGN_OR_RETURN(
const BufferAllocation::Slice slice,
assignment_.GetUniqueSlice(batch_norm_training, /*index=*/{0}));
@@ -1405,8 +1425,11 @@ Status IrEmitter::HandleBatchNormTraining(HloInstruction* batch_norm_training) {
target_array, &ir_builder_)
.EmitLoop(IrName(batch_norm_training, "normalize")));
- llvm_ir::EmitTuple(GetIrArrayForOp(batch_norm_training),
- {normalized, mean, var}, &ir_builder_);
+ llvm_ir::EmitTuple(
+ llvm_ir::IrArray(target_address, batch_norm_training->shape()),
+ {normalized, mean, var}, &ir_builder_);
+ emitted_value_[batch_norm_training] = target_address;
+
return Status::OK();
}
@@ -1766,7 +1789,6 @@ StatusOr<bool> IrEmitter::EmitVectorizedReduce(
}
CHECK(!ShapeUtil::IsTuple(reduce->shape()));
- TF_RETURN_IF_ERROR(EmitTargetAddressForOp(reduce));
// We know we're not reducing over the most minor dimension, which means we
// can lower the reduction loop as:
@@ -1829,7 +1851,10 @@ StatusOr<bool> IrEmitter::EmitVectorizedReduce(
reduction_generator, array_index, vector_type,
init_value, arg, dimensions, element_alignment));
- llvm_ir::IrArray target_array = GetIrArrayForOp(reduce);
+ TF_ASSIGN_OR_RETURN(llvm::Value * target_address,
+ EmitTargetAddressForOp(reduce));
+ llvm_ir::IrArray target_array(target_address, reduce->shape());
+ AddAliasingInformationToIrArray(*reduce, &target_array);
llvm::Value* output_address =
target_array.EmitArrayElementAddress(array_index, &ir_builder_);
EmitShardedVectorStore(output_address, accumulator, element_alignment,
@@ -1861,7 +1886,10 @@ StatusOr<bool> IrEmitter::EmitVectorizedReduce(
reduction_generator, array_index, vector_type,
init_value, arg, dimensions, element_alignment));
- llvm_ir::IrArray target_array = GetIrArrayForOp(reduce);
+ TF_ASSIGN_OR_RETURN(llvm::Value * target_address,
+ EmitTargetAddressForOp(reduce));
+ llvm_ir::IrArray target_array(target_address, reduce->shape());
+ AddAliasingInformationToIrArray(*reduce, &target_array);
llvm::Value* output_address =
target_array.EmitArrayElementAddress(array_index, &ir_builder_);
EmitShardedVectorStore(output_address, accumulator, element_alignment,
@@ -1872,6 +1900,10 @@ StatusOr<bool> IrEmitter::EmitVectorizedReduce(
ir_builder_.SetInsertPoint(outermost_loop_exit_block);
}
+ TF_ASSIGN_OR_RETURN(llvm::Value * target_address,
+ EmitTargetAddressForOp(reduce));
+
+ emitted_value_[reduce] = target_address;
return true;
}
@@ -1971,7 +2003,9 @@ Status IrEmitter::HandleSlice(HloInstruction* slice, HloInstruction* operand) {
return DefaultAction(slice);
}
- TF_RETURN_IF_ERROR(EmitTargetAddressForOp(slice));
+ TF_ASSIGN_OR_RETURN(llvm::Value * target_address,
+ EmitTargetAddressForOp(slice));
+ emitted_value_[slice] = target_address;
if (ShapeUtil::HasZeroElements(slice->shape())) {
return Status::OK();
@@ -2043,7 +2077,8 @@ Status IrEmitter::HandleSlice(HloInstruction* slice, HloInstruction* operand) {
outer_dims.push_back(memcpy_dim);
}
- llvm_ir::IrArray target_array = GetIrArrayForOp(slice);
+ llvm_ir::IrArray target_array(target_address, slice->shape());
+ AddAliasingInformationToIrArray(*slice, &target_array);
const int64 num_outer_loops = outer_dims.size();
llvm_ir::ForLoopNest loops(IrName(slice), &ir_builder_);
@@ -2096,7 +2131,10 @@ Status IrEmitter::HandleDynamicSlice(HloInstruction* dynamic_slice,
HloInstruction* operand,
HloInstruction* /*start_indices*/) {
if (ShapeUtil::IsScalar(dynamic_slice->shape())) {
- TF_RETURN_IF_ERROR(EmitTargetAddressForOp(dynamic_slice));
+ TF_ASSIGN_OR_RETURN(llvm::Value * target_address,
+ EmitTargetAddressForOp(dynamic_slice));
+ target_address->setName(AsStringRef(IrName(dynamic_slice)));
+ emitted_value_[dynamic_slice] = target_address;
return EmitMemcpy(*operand, *dynamic_slice);
}
return DefaultAction(dynamic_slice);
@@ -2152,7 +2190,10 @@ Status IrEmitter::HandleDynamicUpdateSlice(HloInstruction* dynamic_update_slice,
HloInstruction* update,
HloInstruction* start_indices) {
if (ShapeUtil::IsScalar(dynamic_update_slice->shape())) {
- TF_RETURN_IF_ERROR(EmitTargetAddressForOp(dynamic_update_slice));
+ TF_ASSIGN_OR_RETURN(llvm::Value * target_address,
+ EmitTargetAddressForOp(dynamic_update_slice));
+ target_address->setName(AsStringRef(IrName(dynamic_update_slice)));
+ emitted_value_[dynamic_update_slice] = target_address;
return EmitMemcpy(*update, *dynamic_update_slice);
} else if (CanUpdateDynamicSliceInPlace(assignment_, dynamic_update_slice)) {
VLOG(2) << "Emitting HandleDynamicUpdateSlice in-place.";
@@ -2206,7 +2247,9 @@ Status IrEmitter::HandleDynamicUpdateSlice(HloInstruction* dynamic_update_slice,
llvm_ir::LoopEmitter(loop_body_emitter, update->shape(), &ir_builder_)
.EmitLoop(IrName(dynamic_update_slice, "in_place")));
- TF_RETURN_IF_ERROR(EmitTargetAddressForOp(dynamic_update_slice));
+ TF_ASSIGN_OR_RETURN(llvm::Value * dynamic_update_slice_address,
+ EmitTargetAddressForOp(dynamic_update_slice));
+ emitted_value_[dynamic_update_slice] = dynamic_update_slice_address;
return Status::OK();
}
return DefaultAction(dynamic_update_slice);
@@ -2305,8 +2348,11 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) {
llvm_ir::IrArray rhs_array(GetIrArrayForOp(rhs));
Shape target_shape = fusion->shape();
- TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fusion));
- llvm_ir::IrArray target_array = GetIrArrayForOp(fusion);
+ TF_ASSIGN_OR_RETURN(llvm::Value * target_address,
+ EmitTargetAddressForOp(fusion));
+ llvm_ir::IrArray target_array(target_address, target_shape);
+ AddAliasingInformationToIrArray(*fusion, &target_array);
+
VLOG(2) << "HandleFusion kTransposeDot: ";
VLOG(2) << " lhs operand: "
<< llvm_ir::DumpToString(*lhs_array.GetBasePointer());
@@ -2320,6 +2366,8 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) {
*dot, dot->operand(0)->IsRank2Transpose(),
dot->operand(1)->IsRank2Transpose(), target_array, lhs_array, rhs_array,
GetExecutableRunOptionsArgument(), &ir_builder_, hlo_module_config_));
+
+ emitted_value_[fusion] = target_address;
return Status::OK();
} else if (fusion->fusion_kind() == HloInstruction::FusionKind::kLoop) {
std::vector<llvm_ir::IrArray> parameter_arrays;
@@ -2345,9 +2393,14 @@ Status IrEmitter::HandleCall(HloInstruction* call) {
parameter_addresses.push_back(GetEmittedValueFor(operand));
}
- TF_RETURN_IF_ERROR(EmitTargetAddressForOp(call));
+ TF_ASSIGN_OR_RETURN(llvm::Value * output_address,
+ EmitTargetAddressForOp(call));
+ output_address->setName(AsStringRef(IrName(call)));
+
EmitArrayFunctionCallInto(call_ir_function, parameter_addresses,
- emitted_value_[call], computation->name());
+ output_address, computation->name());
+
+ emitted_value_[call] = output_address;
return Status::OK();
}
@@ -2376,13 +2429,17 @@ Status IrEmitter::HandleCustomCall(
/*Params=*/{i8_ptr_type, operands_alloca->getType()},
/*isVarArg=*/false)));
- TF_RETURN_IF_ERROR(EmitTargetAddressForOp(custom_call));
- auto* output_address_arg = ir_builder_.CreatePointerCast(
- GetEmittedValueFor(custom_call), i8_ptr_type);
+ TF_ASSIGN_OR_RETURN(llvm::Value * output_address,
+ EmitTargetAddressForOp(custom_call));
+ output_address->setName(AsStringRef(IrName(custom_call)));
+
+ auto* output_address_arg =
+ ir_builder_.CreatePointerCast(output_address, i8_ptr_type);
ir_builder_.CreateCall(custom_call_ir_function,
{output_address_arg, operands_alloca});
+ emitted_value_[custom_call] = output_address;
return Status::OK();
}
@@ -2526,8 +2583,10 @@ StatusOr<bool> IrEmitter::EmitFastConcatenate(
llvm::Type* i8_ptr_type = ir_builder_.getInt8PtrTy();
llvm::Type* i8_type = ir_builder_.getInt8Ty();
- TF_RETURN_IF_ERROR(EmitTargetAddressForOp(concatenate));
- llvm_ir::IrArray target_array = GetIrArrayForOp(concatenate);
+ TF_ASSIGN_OR_RETURN(llvm::Value * target_address,
+ EmitTargetAddressForOp(concatenate));
+
+ llvm_ir::IrArray target_array(target_address, output_shape);
llvm_ir::ForLoopNest loops(IrName(concatenate), &ir_builder_);
llvm_ir::IrArray::Index outer_dims_index =
@@ -2544,6 +2603,8 @@ StatusOr<bool> IrEmitter::EmitFastConcatenate(
unsigned primitive_type_size =
ShapeUtil::ByteSizeOfPrimitiveType(primitive_type);
+ AddAliasingInformationToIrArray(*concatenate, &target_array);
+
// Contiguous subregions from each operand to the concatenate contribute to a
// contiguous subregion in the target buffer starting at target_region_begin.
llvm::Value* target_region_begin = ir_builder_.CreateBitCast(
@@ -2586,6 +2647,8 @@ StatusOr<bool> IrEmitter::EmitFastConcatenate(
SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_);
}
+ emitted_value_[concatenate] = target_address;
+
return true;
}
@@ -2779,6 +2842,15 @@ Status IrEmitter::Preprocess(HloInstruction* hlo) {
}
Status IrEmitter::Postprocess(HloInstruction* hlo) {
+ // Set the name of the emitted llvm::Value to IrName(hlo). Outfeed and Send
+ // are the only ops that don't emit a value.
+ if (hlo->opcode() != HloOpcode::kOutfeed &&
+ hlo->opcode() != HloOpcode::kSend) {
+ auto it = emitted_value_.find(hlo);
+ CHECK(it != emitted_value_.end());
+ it->second->setName(AsStringRef(IrName(hlo)));
+ }
+
if (auto* prof_counter = GetProfileCounterFor(hlo)) {
profiling_state_.RecordCycleDelta(&ir_builder_, hlo, prof_counter);
}
@@ -2955,10 +3027,10 @@ llvm::Value* IrEmitter::EmitArrayFunctionCall(
return return_value_buffer;
}
-Status IrEmitter::EmitTargetAddressForOp(const HloInstruction* op) {
- llvm::Value* addr;
- const Shape& target_shape = op->shape();
- if (op == op->parent()->root_instruction()) {
+StatusOr<llvm::Value*> IrEmitter::EmitTargetAddressForOp(
+ const HloInstruction* op, const ShapeIndex& shape_index) {
+ const Shape& target_shape = ShapeUtil::GetSubshape(op->shape(), shape_index);
+ if (op == op->parent()->root_instruction() && shape_index.empty()) {
// For the root node, we write directly to the output buffer of the
// function.
llvm::Argument* retval = GetResultArgument();
@@ -2968,18 +3040,15 @@ Status IrEmitter::EmitTargetAddressForOp(const HloInstruction* op) {
attr_builder.addDereferenceableAttr(ByteSizeOf(target_shape));
retval->addAttrs(attr_builder);
}
- addr = ir_builder_.CreateBitCast(retval,
+ return ir_builder_.CreateBitCast(retval,
IrShapeType(target_shape)->getPointerTo());
- } else {
- // For other nodes, we need the temporary buffer allocated for this node to
- // write the result into.
- TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice,
- assignment_.GetUniqueTopLevelSlice(op));
- addr = EmitTempBufferPointer(slice, target_shape);
- }
- addr->setName(AsStringRef(IrName(op)));
- emitted_value_[op] = addr;
- return Status::OK();
+ }
+
+ // For other nodes, we need the temporary buffer allocated for this node to
+ // write the result into.
+ TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice,
+ assignment_.GetUniqueTopLevelSlice(op));
+ return EmitTempBufferPointer(slice, target_shape);
}
Status IrEmitter::EmitTargetElementLoop(
@@ -2993,9 +3062,12 @@ Status IrEmitter::EmitTargetElementLoop(
const llvm_ir::ElementGenerator& element_generator) {
VLOG(2) << "EmitTargetElementLoop: " << target_op->ToString();
+ // target_address will hold the address of the target buffer we will write the
+ // result of the computation into.
const Shape& target_shape = target_op->shape();
- TF_RETURN_IF_ERROR(EmitTargetAddressForOp(target_op));
- llvm_ir::IrArray target_array = GetIrArrayForOp(target_op);
+ TF_ASSIGN_OR_RETURN(llvm::Value * target_address,
+ EmitTargetAddressForOp(target_op));
+ VLOG(2) << " target address: " << llvm_ir::DumpToString(*target_address);
if (target_op->IsMultiOutputFusion()) {
// For multiple outputs fusion, we need to emit each operand and the root.
@@ -3018,9 +3090,13 @@ Status IrEmitter::EmitTargetElementLoop(
for (int64 i = 0; i < output_arrays.size(); ++i) {
tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer());
}
- llvm_ir::EmitTuple(target_array, tuple_operand_ptrs, &ir_builder_);
+ llvm_ir::EmitTuple(llvm_ir::IrArray(target_address, target_shape),
+ tuple_operand_ptrs, &ir_builder_);
} else {
+ llvm_ir::IrArray target_array(target_address, target_shape);
+ AddAliasingInformationToIrArray(*target_op, &target_array);
+
if (ShouldEmitParallelLoopFor(*target_op)) {
TF_RETURN_IF_ERROR(EmitParallelTargetElementLoop(
target_shape, element_generator, IrName(target_op), &target_array));
@@ -3030,6 +3106,8 @@ Status IrEmitter::EmitTargetElementLoop(
.EmitLoop(IrName(target_op)));
}
}
+
+ emitted_value_[target_op] = target_address;
return Status::OK();
}
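
The new Postprocess check is what makes the per-handler emitted_value_ bookkeeping safe: every op except Outfeed and Send must have recorded a value by the time its handler returns, and that value is then named after the op. A simplified stand-alone sketch, with small structs standing in for HloInstruction and llvm::Value:

#include <cassert>
#include <map>
#include <string>

enum class Opcode { kAdd, kOutfeed, kSend };

struct Instruction {
  Opcode opcode;
  std::string name;
};

struct Value {
  std::string name;
};

std::map<const Instruction*, Value*> emitted_value;

void Postprocess(const Instruction* hlo) {
  if (hlo->opcode != Opcode::kOutfeed && hlo->opcode != Opcode::kSend) {
    auto it = emitted_value.find(hlo);
    assert(it != emitted_value.end() && "handler forgot to record its value");
    it->second->name = hlo->name;  // mirrors setName(AsStringRef(IrName(hlo)))
  }
}

int main() {
  Instruction add{Opcode::kAdd, "add.1"};
  Value v;
  emitted_value[&add] = &v;  // what each rewritten handler now does
  Postprocess(&add);
  assert(v.name == "add.1");
}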
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
index fd9ee71799..05663b6038 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -353,10 +353,11 @@ class IrEmitter : public DfsHloVisitorWithDefault {
Status EmitMemcpy(const HloInstruction& source,
const HloInstruction& destination);
- // Emits IR to compute the target address of the buffer for the given op.
- // After calling this function, you can get a pointer to this buffer by
- // calling GetIrArrayForOp or GetEmittedValueFor.
- Status EmitTargetAddressForOp(const HloInstruction* op);
+ // Emits IR to compute the target address of the buffer for the given op.
+ // The returned Value is a pointer to an IR type that represents the op's
+ // element type.
+ StatusOr<llvm::Value*> EmitTargetAddressForOp(
+ const HloInstruction* op, const ShapeIndex& shape_index = {});
// Structurizes "array_elements" into an MD array that represents "shape".
// This is a recursive function, and "dimension_index" indicates the index of
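
The added shape_index parameter selects a subshape of a tuple-shaped op via ShapeUtil::GetSubshape; the default empty index preserves the old whole-op behavior. A stand-alone sketch of how such an index walks a nested shape, using a small recursive struct in place of xla::Shape:

#include <cassert>
#include <vector>

struct Shape {
  int element_count = 0;            // stand-in for array-shape payload
  std::vector<Shape> tuple_shapes;  // non-empty means this is a tuple
};

// A ShapeIndex is a sequence of tuple-element indices walked from the root.
const Shape& GetSubshape(const Shape& shape, const std::vector<int>& index) {
  const Shape* s = &shape;
  for (int i : index) {
    s = &s->tuple_shapes[i];  // descend one tuple level per index entry
  }
  return *s;
}

int main() {
  // (f32[16], (f32[4], f32[8])) as a nested tuple.
  Shape inner{0, {Shape{4}, Shape{8}}};
  Shape root{0, {Shape{16}, inner}};
  assert((GetSubshape(root, {}).tuple_shapes.size() == 2));  // empty: root
  assert((GetSubshape(root, {1, 0}).element_count == 4));    // inner element
}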
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index 29221d2d29..ffd8018827 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -1894,16 +1894,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
Shape inferred_shape =
ShapeUtil::MakeShape(operand.element_type(), new_sizes);
- VLOG(3) << "Reshape inferred shape: "
- << ShapeUtil::HumanString(inferred_shape);
if (ShapeUtil::ElementsIn(operand) != ShapeUtil::ElementsIn(inferred_shape)) {
return InvalidArgument(
- "reshape operation has mismatched element counts: from=%lld (%s) "
- "to=%lld (%s)",
- ShapeUtil::ElementsIn(operand), ShapeUtil::HumanString(operand).c_str(),
- ShapeUtil::ElementsIn(inferred_shape),
- ShapeUtil::HumanString(inferred_shape).c_str());
+ "reshape operation has mismatched element counts: from=%lld to=%lld",
+ ShapeUtil::ElementsIn(operand), ShapeUtil::ElementsIn(inferred_shape));
}
std::vector<int64> indices(ShapeUtil::Rank(operand));
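
The check above guards the reshape invariant that element counts are preserved. A stand-alone sketch of that rule, using dimension-size vectors in place of xla::Shape:

#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Total element count is the product of the dimension sizes.
int64_t ElementsIn(const std::vector<int64_t>& dims) {
  return std::accumulate(dims.begin(), dims.end(), int64_t{1},
                         std::multiplies<int64_t>());
}

bool ReshapeIsValid(const std::vector<int64_t>& from,
                    const std::vector<int64_t>& to) {
  return ElementsIn(from) == ElementsIn(to);
}

int main() {
  assert((ReshapeIsValid({2, 3}, {6})));   // 6 elements -> 6 elements: ok
  assert((!ReshapeIsValid({2, 3}, {4})));  // mismatched element counts
}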
diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc
index 2876a79dd8..061a4e190f 100644
--- a/tensorflow/compiler/xla/tests/literal_test_util.cc
+++ b/tensorflow/compiler/xla/tests/literal_test_util.cc
@@ -39,60 +39,30 @@ limitations under the License.
namespace xla {
-/* static */ ::testing::AssertionResult LiteralTestUtil::EqualShapes(
- const Shape& expected, const Shape& actual) {
- if (ShapeUtil::IsTuple(expected) != ShapeUtil::IsTuple(actual)) {
- return ::testing::AssertionFailure()
- << "tupleness-mismatch! want: " << ShapeUtil::HumanString(expected)
- << " got: " << ShapeUtil::HumanString(actual);
- }
+/* static */ void LiteralTestUtil::AssertEqualShapes(const Shape& expected,
+ const Shape& actual) {
+ ASSERT_EQ(ShapeUtil::IsTuple(expected), ShapeUtil::IsTuple(actual));
if (ShapeUtil::IsTuple(expected)) {
- if (ShapeUtil::TupleElementCount(expected) !=
- ShapeUtil::TupleElementCount(actual)) {
- return ::testing::AssertionFailure()
- << "want tuple element count: "
- << ShapeUtil::TupleElementCount(expected)
- << " got tuple element count: "
- << ShapeUtil::TupleElementCount(actual);
- }
+ ASSERT_EQ(ShapeUtil::TupleElementCount(expected),
+ ShapeUtil::TupleElementCount(actual));
for (int i = 0; i < expected.tuple_shapes_size(); ++i) {
- ::testing::AssertionResult result =
- EqualShapes(expected.tuple_shapes(i), actual.tuple_shapes(i));
- if (!result) {
- return result;
- }
+ AssertEqualShapes(expected.tuple_shapes(i), actual.tuple_shapes(i));
}
} else {
- if (ShapeUtil::Rank(expected) != ShapeUtil::Rank(actual)) {
- return ::testing::AssertionFailure()
- << "want rank of: " << ShapeUtil::HumanString(expected)
- << " got rank of: " << ShapeUtil::HumanString(actual);
- }
- if (expected.element_type() != actual.element_type()) {
- return ::testing::AssertionFailure()
- << PrimitiveType_Name(expected.element_type()) << " vs "
- << PrimitiveType_Name(actual.element_type());
- }
- if (expected.dimensions_size() != actual.dimensions_size()) {
- return ::testing::AssertionFailure()
- << "want dimensions_size " << expected.dimensions_size()
- << " got dimensions_size " << actual.dimensions_size();
- }
+ ASSERT_EQ(ShapeUtil::Rank(expected), ShapeUtil::Rank(actual))
+ << "want rank of: " << ShapeUtil::HumanString(expected)
+ << " got rank of: " << ShapeUtil::HumanString(actual);
+ ASSERT_EQ(expected.element_type(), actual.element_type())
+ << PrimitiveType_Name(expected.element_type()) << " vs "
+ << PrimitiveType_Name(actual.element_type());
+ ASSERT_EQ(expected.dimensions_size(), actual.dimensions_size());
for (int i = 0; i < expected.dimensions_size(); ++i) {
- if (expected.dimensions(i) != actual.dimensions(i)) {
- return ::testing::AssertionFailure()
- << "mismatch in dimension #" << i
- << " expected: " << ShapeUtil::HumanString(expected)
- << " actual: " << ShapeUtil::HumanString(actual);
- }
+ ASSERT_EQ(expected.dimensions(i), actual.dimensions(i))
+ << "mismatch in dimension #" << i
+ << " expected: " << ShapeUtil::HumanString(expected)
+ << " actual: " << ShapeUtil::HumanString(actual);
}
}
- return ::testing::AssertionSuccess();
-}
-
-/* static */ void LiteralTestUtil::AssertEqualShapes(const Shape& expected,
- const Shape& actual) {
- ASSERT_TRUE(EqualShapes(expected, actual));
}
/* static */ void LiteralTestUtil::AssertEqualShapesAndLayouts(
@@ -295,14 +265,7 @@ class NearComparator {
VLOG(1) << "actual:";
XLA_VLOG_LINES(1, actual.ToString());
- // If the shapes mismatch, we simply fail the expectation instead of
- // printing out data, as it's a type error rather than a value error.
- ::testing::AssertionResult equal_shapes =
- LiteralTestUtil::EqualShapes(expected.shape(), actual.shape());
- if (!equal_shapes) {
- EXPECT_TRUE(equal_shapes);
- return false;
- }
+ LiteralTestUtil::AssertEqualShapes(expected.shape(), actual.shape());
// Set up members used during the comparison.
num_miscompares_ = 0;
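
This diff replaces a composable ::testing::AssertionResult predicate with a void helper built on ASSERT_EQ. In gtest, a failed ASSERT_* inside a void helper returns only from that helper, which is why the tuple case re-invokes AssertEqualShapes per element rather than propagating a result. A minimal sketch of the two styles, assuming gtest is available (build with -lgtest -lgtest_main):

#include <vector>
#include "gtest/gtest.h"

// Style removed by the diff: composable, returns a result the caller checks.
::testing::AssertionResult EqualSizes(const std::vector<int>& a,
                                      const std::vector<int>& b) {
  if (a.size() != b.size()) {
    return ::testing::AssertionFailure()
           << "want size " << a.size() << " got size " << b.size();
  }
  return ::testing::AssertionSuccess();
}

// Style kept by the diff: asserts in place, nothing to propagate.
void AssertEqualSizes(const std::vector<int>& a, const std::vector<int>& b) {
  ASSERT_EQ(a.size(), b.size());
}

TEST(ShapeSketch, BothStylesPass) {
  std::vector<int> x{1, 2}, y{3, 4};
  EXPECT_TRUE(EqualSizes(x, y));
  AssertEqualSizes(x, y);
}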
diff --git a/tensorflow/compiler/xla/tests/literal_test_util.h b/tensorflow/compiler/xla/tests/literal_test_util.h
index 467d44b857..f645c4e8dc 100644
--- a/tensorflow/compiler/xla/tests/literal_test_util.h
+++ b/tensorflow/compiler/xla/tests/literal_test_util.h
@@ -50,8 +50,6 @@ class LiteralTestUtil {
public:
// Asserts that the given shapes have the same rank, dimension sizes, and
// primitive types.
- static ::testing::AssertionResult EqualShapes(const Shape& expected,
- const Shape& actual);
static void AssertEqualShapes(const Shape& expected, const Shape& actual);
// Asserts that the provided shapes are equal as defined in AssertEqualShapes
diff --git a/tensorflow/compiler/xla/tests/reshape_test.cc b/tensorflow/compiler/xla/tests/reshape_test.cc
index 72c68f24a0..bb7160e3a0 100644
--- a/tensorflow/compiler/xla/tests/reshape_test.cc
+++ b/tensorflow/compiler/xla/tests/reshape_test.cc
@@ -47,7 +47,7 @@ class ReshapeTest : public ClientLibraryTestBase {
};
// Collapses 2-dimensional pseudo-scalar (single-element array) to 1 dimension.
-XLA_TEST_F(ReshapeTest, CollapseTrivial1x1) {
+XLA_TEST_F(ReshapeTest, Trivial1x1) {
ComputationBuilder builder(client_, TestName());
auto a = builder.ConstantR2<float>({{1.0}});
builder.Collapse(/*operand=*/a, /*dimensions=*/{0, 1});
@@ -55,22 +55,6 @@ XLA_TEST_F(ReshapeTest, CollapseTrivial1x1) {
ComputeAndCompareR1<float>(&builder, {1.0f}, {}, zero_error_spec_);
}
-XLA_TEST_F(ReshapeTest, CollapseTrivialR1EmptyDims) {
- ComputationBuilder builder(client_, TestName());
- auto a = builder.ConstantR1<float>({1.0});
- builder.Collapse(/*operand=*/a, /*dimensions=*/{});
-
- ComputeAndCompareR1<float>(&builder, {1.0f}, {}, zero_error_spec_);
-}
-
-XLA_TEST_F(ReshapeTest, CollapseTrivialR1OnlyDim) {
- ComputationBuilder builder(client_, TestName());
- auto a = builder.ConstantR1<float>({1.0});
- builder.Collapse(/*operand=*/a, /*dimensions=*/{0});
-
- ComputeAndCompareR1<float>(&builder, {1.0f}, {}, zero_error_spec_);
-}
-
// Collapses 2-dimensional pseudo-scalar (single-element array) to scalar.
XLA_TEST_F(ReshapeTest, SingleElementArrayToScalar) {
ComputationBuilder builder(client_, TestName());
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc
index 11e063d8d2..f57834cfbe 100644
--- a/tensorflow/core/common_runtime/executor.cc
+++ b/tensorflow/core/common_runtime/executor.cc
@@ -2008,7 +2008,7 @@ bool ExecutorState::NodeDone(const Status& s, const Node* node,
NodeExecStatsWrapper* stats,
TaggedNodeReadyQueue* inline_ready) {
nodestats::SetAllEnd(stats);
- if (stats_collector_ != nullptr && !SetTimelineLabel(node, stats)) {
+ if (!SetTimelineLabel(node, stats)) {
// Only record non-transfer nodes.
// Transfers 'stats' ownership to 'stats_collector_'.
stats_collector_->Save(impl_->params_.device->name(), stats);