diff options
author | James Qin <jamesqin@google.com> | 2017-10-06 11:04:52 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-10-06 11:41:35 -0700 |
commit | af6e00f7c661c7d93bacfc3adc40d17f0faeb9b4 (patch) | |
tree | 95c1bd53bebf4b8e7700dd76508c0f4ad831f19c | |
parent | 84b579e1d14760fc2a313c8e1d7ca100f74945a1 (diff) |
Fix a minor issue w/ allreduce
PiperOrigin-RevId: 171314944
-rw-r--r-- | tensorflow/compiler/xla/client/client.cc | 1 | ||||
-rw-r--r-- | tensorflow/compiler/xla/client/computation_builder.cc | 13 | ||||
-rw-r--r-- | tensorflow/compiler/xla/client/computation_builder.h | 10 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/cpu/ir_emitter.cc | 242 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/cpu/ir_emitter.h | 9 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/shape_inference.cc | 9 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/literal_test_util.cc | 73 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/literal_test_util.h | 2 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/reshape_test.cc | 18 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/executor.cc | 2 |
10 files changed, 187 insertions, 192 deletions
diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc index 7db2ea79fb..387253617e 100644 --- a/tensorflow/compiler/xla/client/client.cc +++ b/tensorflow/compiler/xla/client/client.cc @@ -206,7 +206,6 @@ StatusOr<std::unique_ptr<GlobalData>> Client::Execute( *request.mutable_execution_options() = *execution_options; } for (GlobalData* argument : arguments) { - CHECK(argument != nullptr) << "Argument pointers must not be null."; *request.add_arguments() = argument->handle(); } diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc index 925dcd36c0..15a713513f 100644 --- a/tensorflow/compiler/xla/client/computation_builder.cc +++ b/tensorflow/compiler/xla/client/computation_builder.cc @@ -489,16 +489,6 @@ ComputationDataHandle ComputationBuilder::Collapse( } std::unique_ptr<Shape> original_shape = shape_or_status.ConsumeValueOrDie(); - VLOG(3) << "original shape: " << ShapeUtil::HumanString(*original_shape); - VLOG(3) << "dims to collapse: " - << tensorflow::str_util::Join(dims_to_collapse, ","); - - if (dims_to_collapse.size() <= 1) { - // Not collapsing anything, trivially we can return the operand versus - // enqueueing a trivial reshape. - return operand; - } - std::vector<int64> new_sizes; for (int i = 0; i < ShapeUtil::Rank(*original_shape); ++i) { if (i <= dims_to_collapse.front() || i > dims_to_collapse.back()) { @@ -508,9 +498,6 @@ ComputationDataHandle ComputationBuilder::Collapse( } } - VLOG(3) << "new sizes: [" << tensorflow::str_util::Join(new_sizes, ",") - << "]"; - return Reshape(operand, new_sizes); } diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h index 7014685ea5..73972c1290 100644 --- a/tensorflow/compiler/xla/client/computation_builder.h +++ b/tensorflow/compiler/xla/client/computation_builder.h @@ -201,16 +201,6 @@ class ComputationBuilder { // {x=1024, y=32} by collapsing dims {0, 1, 2}. Collapsing dimensions must // be a consecutive, in-order subsequence of the operand dimensions. // - // Note that collapsing a single dimension does nothing: - // - // {256} collapsing {0} => {256} - // {1} collapsing {0} => {1} - // - // Collapsing multiple dimensions produces a single result dimension: - // - // {256, 2} collapsing {0,1} => {512} - // {256, 2, 3} collapsing {0,1} => {512, 3} - // // This could potentially cause data to be moved -- it provides a more // structured form of reshaping than an arbitrary Reshape operation. ComputationDataHandle Collapse(const ComputationDataHandle& operand, diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index e4fb7c0496..4375f13a0e 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -291,7 +291,8 @@ Status IrEmitter::HandleConstant(HloInstruction* constant, Status IrEmitter::HandleCopy(HloInstruction* copy) { if (ShapeUtil::IsTuple(copy->shape())) { // kCopy shallow copies a tuple so just memcpy the top-level buffer. - TF_RETURN_IF_ERROR(EmitTargetAddressForOp(copy)); + TF_ASSIGN_OR_RETURN(llvm::Value * copy_value, EmitTargetAddressForOp(copy)); + emitted_value_[copy] = copy_value; return EmitMemcpy(*(copy->operand(0)), *copy); } else { // Use the elemental emitter for non-tuple shapes. @@ -394,7 +395,9 @@ Status IrEmitter::HandleSelect(HloInstruction* select, HloInstruction* pred, TF_RET_CHECK(pred->shape().element_type() == PRED); if (ShapeUtil::IsTuple(select->shape())) { - TF_RETURN_IF_ERROR(EmitTargetAddressForOp(select)); + TF_ASSIGN_OR_RETURN(llvm::Value * output_address, + EmitTargetAddressForOp(select)); + emitted_value_[select] = output_address; llvm_ir::EmitTupleSelect(GetIrArrayForOp(select), GetIrArrayForOp(pred), GetEmittedValueFor(on_true), GetEmittedValueFor(on_false), &ir_builder_); @@ -411,8 +414,8 @@ Status IrEmitter::HandleInfeed(HloInstruction* infeed) { // The infeed operation produces data (dequeued from the infeed queue) at this // address, which has been provided by buffer assignment. - TF_RETURN_IF_ERROR(EmitTargetAddressForOp(infeed)); - llvm_ir::IrArray infeed_array = GetIrArrayForOp(infeed); + TF_ASSIGN_OR_RETURN(llvm::Value * target_address, + EmitTargetAddressForOp(infeed)); if (ShapeUtil::IsTuple(shape)) { TF_RET_CHECK(!ShapeUtil::IsNestedTuple(shape)); @@ -430,9 +433,9 @@ Status IrEmitter::HandleInfeed(HloInstruction* infeed) { ShapeUtil::GetTupleElementShape(shape, i); // Only the outer tuple buffer's target address is obtained from - // GetEmittedValueFor, to handle the case when Infeed is the root - // instruction. Target addresses for internal elements can be obtained - // from EmitTempBufferPointer. + // EmitTargetAddressForOp to handle the case when Infeed is the + // root instruction. Target addresses for internal elements can + // be obtained from EmitTempBufferPointer. llvm::Value* tuple_element_address = EmitTempBufferPointer(buffer, tuple_element_shape); @@ -442,12 +445,15 @@ Status IrEmitter::HandleInfeed(HloInstruction* infeed) { tuple_element_addresses.push_back(tuple_element_address); } - llvm_ir::EmitTuple(infeed_array, tuple_element_addresses, &ir_builder_); + llvm_ir::EmitTuple(llvm_ir::IrArray(target_address, shape), + tuple_element_addresses, &ir_builder_); } else { - TF_RETURN_IF_ERROR(EmitXfeedTransfer(XfeedKind::kInfeed, shape, - GetEmittedValueFor(infeed))); + TF_RETURN_IF_ERROR( + EmitXfeedTransfer(XfeedKind::kInfeed, shape, target_address)); } + emitted_value_[infeed] = target_address; + return Status::OK(); } @@ -561,12 +567,15 @@ Status IrEmitter::HandleSort(HloInstruction* sort, HloInstruction* operand) { Status IrEmitter::HandleTuple( HloInstruction* tuple, tensorflow::gtl::ArraySlice<HloInstruction*> operands) { - TF_RETURN_IF_ERROR(EmitTargetAddressForOp(tuple)); + TF_ASSIGN_OR_RETURN(llvm::Value * target_address, + EmitTargetAddressForOp(tuple)); std::vector<llvm::Value*> base_ptrs; for (auto operand : operands) { base_ptrs.push_back(GetEmittedValueFor(operand)); } - llvm_ir::EmitTuple(GetIrArrayForOp(tuple), base_ptrs, &ir_builder_); + llvm_ir::EmitTuple(llvm_ir::IrArray(target_address, tuple->shape()), + base_ptrs, &ir_builder_); + emitted_value_[tuple] = target_address; return Status::OK(); } @@ -883,8 +892,11 @@ Status IrEmitter::HandleDot(HloInstruction* dot, HloInstruction* lhs, llvm_ir::IrArray lhs_array(GetIrArrayForOp(lhs)); llvm_ir::IrArray rhs_array(GetIrArrayForOp(rhs)); - TF_RETURN_IF_ERROR(EmitTargetAddressForOp(dot)); - llvm_ir::IrArray target_array = GetIrArrayForOp(dot); + Shape target_shape = dot->shape(); + TF_ASSIGN_OR_RETURN(llvm::Value * target_address, + EmitTargetAddressForOp(dot)); + llvm_ir::IrArray target_array(target_address, target_shape); + AddAliasingInformationToIrArray(*dot, &target_array); VLOG(2) << "HandleDot: "; VLOG(2) << " lhs operand: " @@ -895,10 +907,13 @@ Status IrEmitter::HandleDot(HloInstruction* dot, HloInstruction* lhs, << llvm_ir::DumpToString(*target_array.GetBasePointer()); // Dot operation is complicated so we delegate to a helper class. - return DotOpEmitter::EmitDotOperation( + TF_RETURN_IF_ERROR(DotOpEmitter::EmitDotOperation( *dot, /*transpose_lhs=*/false, /*transpose_rhs=*/false, target_array, lhs_array, rhs_array, GetExecutableRunOptionsArgument(), &ir_builder_, - hlo_module_config_); + hlo_module_config_)); + + emitted_value_[dot] = target_address; + return Status::OK(); } Status IrEmitter::HandleConvolution(HloInstruction* convolution, @@ -926,7 +941,8 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution, bool one_dim_convolution = lhs_shape.dimensions_size() == 3; llvm::Value* lhs_address = GetEmittedValueFor(lhs); llvm::Value* rhs_address = GetEmittedValueFor(rhs); - TF_RETURN_IF_ERROR(EmitTargetAddressForOp(convolution)); + TF_ASSIGN_OR_RETURN(llvm::Value * target_address, + EmitTargetAddressForOp(convolution)); const ConvolutionDimensionNumbers& dnums = convolution->convolution_dimension_numbers(); @@ -1008,33 +1024,35 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution, conv_func->setDoesNotThrow(); conv_func->setOnlyAccessesArgMemory(); ir_builder_.CreateCall( - conv_func, { - GetExecutableRunOptionsArgument(), - ir_builder_.CreateBitCast( - GetEmittedValueFor(convolution), float_ptr_type), - ir_builder_.CreateBitCast(lhs_address, float_ptr_type), - ir_builder_.CreateBitCast(rhs_address, float_ptr_type), - ir_builder_.getInt64(input_batch), - ir_builder_.getInt64(input_rows), - ir_builder_.getInt64(input_cols), - ir_builder_.getInt64(input_channels), - ir_builder_.getInt64(kernel_rows), - ir_builder_.getInt64(kernel_cols), - ir_builder_.getInt64(kernel_channels), - ir_builder_.getInt64(kernel_filters), - ir_builder_.getInt64(output_rows), - ir_builder_.getInt64(output_cols), - ir_builder_.getInt64(row_stride), - ir_builder_.getInt64(col_stride), - ir_builder_.getInt64(padding_top), - ir_builder_.getInt64(padding_bottom), - ir_builder_.getInt64(padding_left), - ir_builder_.getInt64(padding_right), - ir_builder_.getInt64(lhs_row_dilation), - ir_builder_.getInt64(lhs_col_dilation), - ir_builder_.getInt64(rhs_row_dilation), - ir_builder_.getInt64(rhs_col_dilation), - }); + conv_func, + { + GetExecutableRunOptionsArgument(), + ir_builder_.CreateBitCast(target_address, float_ptr_type), + ir_builder_.CreateBitCast(lhs_address, float_ptr_type), + ir_builder_.CreateBitCast(rhs_address, float_ptr_type), + ir_builder_.getInt64(input_batch), + ir_builder_.getInt64(input_rows), + ir_builder_.getInt64(input_cols), + ir_builder_.getInt64(input_channels), + ir_builder_.getInt64(kernel_rows), + ir_builder_.getInt64(kernel_cols), + ir_builder_.getInt64(kernel_channels), + ir_builder_.getInt64(kernel_filters), + ir_builder_.getInt64(output_rows), + ir_builder_.getInt64(output_cols), + ir_builder_.getInt64(row_stride), + ir_builder_.getInt64(col_stride), + ir_builder_.getInt64(padding_top), + ir_builder_.getInt64(padding_bottom), + ir_builder_.getInt64(padding_left), + ir_builder_.getInt64(padding_right), + ir_builder_.getInt64(lhs_row_dilation), + ir_builder_.getInt64(lhs_col_dilation), + ir_builder_.getInt64(rhs_row_dilation), + ir_builder_.getInt64(rhs_col_dilation), + }); + target_address->setName(AsStringRef(IrName(convolution))); + emitted_value_[convolution] = target_address; return Status::OK(); } @@ -1349,7 +1367,9 @@ Status IrEmitter::HandleBatchNormTraining(HloInstruction* batch_norm_training) { mean_array, &ir_builder_) .EmitLoop(IrName(batch_norm_training, "mean_var"))); - TF_RETURN_IF_ERROR(EmitTargetAddressForOp(batch_norm_training)); + TF_ASSIGN_OR_RETURN(llvm::Value * target_address, + EmitTargetAddressForOp(batch_norm_training)); + TF_ASSIGN_OR_RETURN( const BufferAllocation::Slice slice, assignment_.GetUniqueSlice(batch_norm_training, /*index=*/{0})); @@ -1405,8 +1425,11 @@ Status IrEmitter::HandleBatchNormTraining(HloInstruction* batch_norm_training) { target_array, &ir_builder_) .EmitLoop(IrName(batch_norm_training, "normalize"))); - llvm_ir::EmitTuple(GetIrArrayForOp(batch_norm_training), - {normalized, mean, var}, &ir_builder_); + llvm_ir::EmitTuple( + llvm_ir::IrArray(target_address, batch_norm_training->shape()), + {normalized, mean, var}, &ir_builder_); + emitted_value_[batch_norm_training] = target_address; + return Status::OK(); } @@ -1766,7 +1789,6 @@ StatusOr<bool> IrEmitter::EmitVectorizedReduce( } CHECK(!ShapeUtil::IsTuple(reduce->shape())); - TF_RETURN_IF_ERROR(EmitTargetAddressForOp(reduce)); // We know we're not reducing over the most minor dimension, which means we // can lower the reduction loop as: @@ -1829,7 +1851,10 @@ StatusOr<bool> IrEmitter::EmitVectorizedReduce( reduction_generator, array_index, vector_type, init_value, arg, dimensions, element_alignment)); - llvm_ir::IrArray target_array = GetIrArrayForOp(reduce); + TF_ASSIGN_OR_RETURN(llvm::Value * target_address, + EmitTargetAddressForOp(reduce)); + llvm_ir::IrArray target_array(target_address, reduce->shape()); + AddAliasingInformationToIrArray(*reduce, &target_array); llvm::Value* output_address = target_array.EmitArrayElementAddress(array_index, &ir_builder_); EmitShardedVectorStore(output_address, accumulator, element_alignment, @@ -1861,7 +1886,10 @@ StatusOr<bool> IrEmitter::EmitVectorizedReduce( reduction_generator, array_index, vector_type, init_value, arg, dimensions, element_alignment)); - llvm_ir::IrArray target_array = GetIrArrayForOp(reduce); + TF_ASSIGN_OR_RETURN(llvm::Value * target_address, + EmitTargetAddressForOp(reduce)); + llvm_ir::IrArray target_array(target_address, reduce->shape()); + AddAliasingInformationToIrArray(*reduce, &target_array); llvm::Value* output_address = target_array.EmitArrayElementAddress(array_index, &ir_builder_); EmitShardedVectorStore(output_address, accumulator, element_alignment, @@ -1872,6 +1900,10 @@ StatusOr<bool> IrEmitter::EmitVectorizedReduce( ir_builder_.SetInsertPoint(outermost_loop_exit_block); } + TF_ASSIGN_OR_RETURN(llvm::Value * target_address, + EmitTargetAddressForOp(reduce)); + + emitted_value_[reduce] = target_address; return true; } @@ -1971,7 +2003,9 @@ Status IrEmitter::HandleSlice(HloInstruction* slice, HloInstruction* operand) { return DefaultAction(slice); } - TF_RETURN_IF_ERROR(EmitTargetAddressForOp(slice)); + TF_ASSIGN_OR_RETURN(llvm::Value * target_address, + EmitTargetAddressForOp(slice)); + emitted_value_[slice] = target_address; if (ShapeUtil::HasZeroElements(slice->shape())) { return Status::OK(); @@ -2043,7 +2077,8 @@ Status IrEmitter::HandleSlice(HloInstruction* slice, HloInstruction* operand) { outer_dims.push_back(memcpy_dim); } - llvm_ir::IrArray target_array = GetIrArrayForOp(slice); + llvm_ir::IrArray target_array(target_address, slice->shape()); + AddAliasingInformationToIrArray(*slice, &target_array); const int64 num_outer_loops = outer_dims.size(); llvm_ir::ForLoopNest loops(IrName(slice), &ir_builder_); @@ -2096,7 +2131,10 @@ Status IrEmitter::HandleDynamicSlice(HloInstruction* dynamic_slice, HloInstruction* operand, HloInstruction* /*start_indices*/) { if (ShapeUtil::IsScalar(dynamic_slice->shape())) { - TF_RETURN_IF_ERROR(EmitTargetAddressForOp(dynamic_slice)); + TF_ASSIGN_OR_RETURN(llvm::Value * target_address, + EmitTargetAddressForOp(dynamic_slice)); + target_address->setName(AsStringRef(IrName(dynamic_slice))); + emitted_value_[dynamic_slice] = target_address; return EmitMemcpy(*operand, *dynamic_slice); } return DefaultAction(dynamic_slice); @@ -2152,7 +2190,10 @@ Status IrEmitter::HandleDynamicUpdateSlice(HloInstruction* dynamic_update_slice, HloInstruction* update, HloInstruction* start_indices) { if (ShapeUtil::IsScalar(dynamic_update_slice->shape())) { - TF_RETURN_IF_ERROR(EmitTargetAddressForOp(dynamic_update_slice)); + TF_ASSIGN_OR_RETURN(llvm::Value * target_address, + EmitTargetAddressForOp(dynamic_update_slice)); + target_address->setName(AsStringRef(IrName(dynamic_update_slice))); + emitted_value_[dynamic_update_slice] = target_address; return EmitMemcpy(*update, *dynamic_update_slice); } else if (CanUpdateDynamicSliceInPlace(assignment_, dynamic_update_slice)) { VLOG(2) << "Emitting HandleDynamicUpdateSlice in-place."; @@ -2206,7 +2247,9 @@ Status IrEmitter::HandleDynamicUpdateSlice(HloInstruction* dynamic_update_slice, llvm_ir::LoopEmitter(loop_body_emitter, update->shape(), &ir_builder_) .EmitLoop(IrName(dynamic_update_slice, "in_place"))); - TF_RETURN_IF_ERROR(EmitTargetAddressForOp(dynamic_update_slice)); + TF_ASSIGN_OR_RETURN(llvm::Value * dynamic_update_slice_address, + EmitTargetAddressForOp(dynamic_update_slice)); + emitted_value_[dynamic_update_slice] = dynamic_update_slice_address; return Status::OK(); } return DefaultAction(dynamic_update_slice); @@ -2305,8 +2348,11 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) { llvm_ir::IrArray rhs_array(GetIrArrayForOp(rhs)); Shape target_shape = fusion->shape(); - TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fusion)); - llvm_ir::IrArray target_array = GetIrArrayForOp(fusion); + TF_ASSIGN_OR_RETURN(llvm::Value * target_address, + EmitTargetAddressForOp(fusion)); + llvm_ir::IrArray target_array(target_address, target_shape); + AddAliasingInformationToIrArray(*fusion, &target_array); + VLOG(2) << "HandleFusion kTransposeDot: "; VLOG(2) << " lhs operand: " << llvm_ir::DumpToString(*lhs_array.GetBasePointer()); @@ -2320,6 +2366,8 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) { *dot, dot->operand(0)->IsRank2Transpose(), dot->operand(1)->IsRank2Transpose(), target_array, lhs_array, rhs_array, GetExecutableRunOptionsArgument(), &ir_builder_, hlo_module_config_)); + + emitted_value_[fusion] = target_address; return Status::OK(); } else if (fusion->fusion_kind() == HloInstruction::FusionKind::kLoop) { std::vector<llvm_ir::IrArray> parameter_arrays; @@ -2345,9 +2393,14 @@ Status IrEmitter::HandleCall(HloInstruction* call) { parameter_addresses.push_back(GetEmittedValueFor(operand)); } - TF_RETURN_IF_ERROR(EmitTargetAddressForOp(call)); + TF_ASSIGN_OR_RETURN(llvm::Value * output_address, + EmitTargetAddressForOp(call)); + output_address->setName(AsStringRef(IrName(call))); + EmitArrayFunctionCallInto(call_ir_function, parameter_addresses, - emitted_value_[call], computation->name()); + output_address, computation->name()); + + emitted_value_[call] = output_address; return Status::OK(); } @@ -2376,13 +2429,17 @@ Status IrEmitter::HandleCustomCall( /*Params=*/{i8_ptr_type, operands_alloca->getType()}, /*isVarArg=*/false))); - TF_RETURN_IF_ERROR(EmitTargetAddressForOp(custom_call)); - auto* output_address_arg = ir_builder_.CreatePointerCast( - GetEmittedValueFor(custom_call), i8_ptr_type); + TF_ASSIGN_OR_RETURN(llvm::Value * output_address, + EmitTargetAddressForOp(custom_call)); + output_address->setName(AsStringRef(IrName(custom_call))); + + auto* output_address_arg = + ir_builder_.CreatePointerCast(output_address, i8_ptr_type); ir_builder_.CreateCall(custom_call_ir_function, {output_address_arg, operands_alloca}); + emitted_value_[custom_call] = output_address; return Status::OK(); } @@ -2526,8 +2583,10 @@ StatusOr<bool> IrEmitter::EmitFastConcatenate( llvm::Type* i8_ptr_type = ir_builder_.getInt8PtrTy(); llvm::Type* i8_type = ir_builder_.getInt8Ty(); - TF_RETURN_IF_ERROR(EmitTargetAddressForOp(concatenate)); - llvm_ir::IrArray target_array = GetIrArrayForOp(concatenate); + TF_ASSIGN_OR_RETURN(llvm::Value * target_address, + EmitTargetAddressForOp(concatenate)); + + llvm_ir::IrArray target_array(target_address, output_shape); llvm_ir::ForLoopNest loops(IrName(concatenate), &ir_builder_); llvm_ir::IrArray::Index outer_dims_index = @@ -2544,6 +2603,8 @@ StatusOr<bool> IrEmitter::EmitFastConcatenate( unsigned primitive_type_size = ShapeUtil::ByteSizeOfPrimitiveType(primitive_type); + AddAliasingInformationToIrArray(*concatenate, &target_array); + // Contiguous subregions from each operand to the concatenate contribute to a // contiguous subregion in the target buffer starting at target_region_begin. llvm::Value* target_region_begin = ir_builder_.CreateBitCast( @@ -2586,6 +2647,8 @@ StatusOr<bool> IrEmitter::EmitFastConcatenate( SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_); } + emitted_value_[concatenate] = target_address; + return true; } @@ -2779,6 +2842,15 @@ Status IrEmitter::Preprocess(HloInstruction* hlo) { } Status IrEmitter::Postprocess(HloInstruction* hlo) { + // Set the name of the emitted llvm::Value to IrName(hlo). Outfeed and send + // the only ops that don't emit a value. + if (hlo->opcode() != HloOpcode::kOutfeed && + hlo->opcode() != HloOpcode::kSend) { + auto it = emitted_value_.find(hlo); + CHECK(it != emitted_value_.end()); + it->second->setName(AsStringRef(IrName(hlo))); + } + if (auto* prof_counter = GetProfileCounterFor(hlo)) { profiling_state_.RecordCycleDelta(&ir_builder_, hlo, prof_counter); } @@ -2955,10 +3027,10 @@ llvm::Value* IrEmitter::EmitArrayFunctionCall( return return_value_buffer; } -Status IrEmitter::EmitTargetAddressForOp(const HloInstruction* op) { - llvm::Value* addr; - const Shape& target_shape = op->shape(); - if (op == op->parent()->root_instruction()) { +StatusOr<llvm::Value*> IrEmitter::EmitTargetAddressForOp( + const HloInstruction* op, const ShapeIndex& shape_index) { + const Shape& target_shape = ShapeUtil::GetSubshape(op->shape(), shape_index); + if (op == op->parent()->root_instruction() && shape_index.empty()) { // For the root node, we write directly to the output buffer of the // function. llvm::Argument* retval = GetResultArgument(); @@ -2968,18 +3040,15 @@ Status IrEmitter::EmitTargetAddressForOp(const HloInstruction* op) { attr_builder.addDereferenceableAttr(ByteSizeOf(target_shape)); retval->addAttrs(attr_builder); } - addr = ir_builder_.CreateBitCast(retval, + return ir_builder_.CreateBitCast(retval, IrShapeType(target_shape)->getPointerTo()); - } else { - // For other nodes, we need the temporary buffer allocated for this node to - // write the result into. - TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice, - assignment_.GetUniqueTopLevelSlice(op)); - addr = EmitTempBufferPointer(slice, target_shape); - } - addr->setName(AsStringRef(IrName(op))); - emitted_value_[op] = addr; - return Status::OK(); + } + + // For other nodes, we need the temporary buffer allocated for this node to + // write the result into. + TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice, + assignment_.GetUniqueTopLevelSlice(op)); + return EmitTempBufferPointer(slice, target_shape); } Status IrEmitter::EmitTargetElementLoop( @@ -2993,9 +3062,12 @@ Status IrEmitter::EmitTargetElementLoop( const llvm_ir::ElementGenerator& element_generator) { VLOG(2) << "EmitTargetElementLoop: " << target_op->ToString(); + // target_address will hold the address of the target buffer we will write the + // result of the computation into. const Shape& target_shape = target_op->shape(); - TF_RETURN_IF_ERROR(EmitTargetAddressForOp(target_op)); - llvm_ir::IrArray target_array = GetIrArrayForOp(target_op); + TF_ASSIGN_OR_RETURN(llvm::Value * target_address, + EmitTargetAddressForOp(target_op)); + VLOG(2) << " target address: " << llvm_ir::DumpToString(*target_address); if (target_op->IsMultiOutputFusion()) { // For multiple outputs fusion, we need to emit each operand and the root. @@ -3018,9 +3090,13 @@ Status IrEmitter::EmitTargetElementLoop( for (int64 i = 0; i < output_arrays.size(); ++i) { tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer()); } - llvm_ir::EmitTuple(target_array, tuple_operand_ptrs, &ir_builder_); + llvm_ir::EmitTuple(llvm_ir::IrArray(target_address, target_shape), + tuple_operand_ptrs, &ir_builder_); } else { + llvm_ir::IrArray target_array(target_address, target_shape); + AddAliasingInformationToIrArray(*target_op, &target_array); + if (ShouldEmitParallelLoopFor(*target_op)) { TF_RETURN_IF_ERROR(EmitParallelTargetElementLoop( target_shape, element_generator, IrName(target_op), &target_array)); @@ -3030,6 +3106,8 @@ Status IrEmitter::EmitTargetElementLoop( .EmitLoop(IrName(target_op))); } } + + emitted_value_[target_op] = target_address; return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index fd9ee71799..05663b6038 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -353,10 +353,11 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status EmitMemcpy(const HloInstruction& source, const HloInstruction& destination); - // Emits IR to compute the target address of the buffer for the given op. - // After calling this function, you can get a pointer to this buffer by - // calling GetIrArrayForOp or GetEmittedValueFor. - Status EmitTargetAddressForOp(const HloInstruction* op); + // Emit IR to compute the target address of the buffer for the given op. + // The returned Value is a pointer to a IR type that represents the op's + // element type. + StatusOr<llvm::Value*> EmitTargetAddressForOp( + const HloInstruction* op, const ShapeIndex& shape_index = {}); // Structurizes "array_elements" into an MD array that represents "shape". // This is a recursive function, and "dimension_index" indicates the index of diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 29221d2d29..ffd8018827 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -1894,16 +1894,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( Shape inferred_shape = ShapeUtil::MakeShape(operand.element_type(), new_sizes); - VLOG(3) << "Reshape inferred shape: " - << ShapeUtil::HumanString(inferred_shape); if (ShapeUtil::ElementsIn(operand) != ShapeUtil::ElementsIn(inferred_shape)) { return InvalidArgument( - "reshape operation has mismatched element counts: from=%lld (%s) " - "to=%lld (%s)", - ShapeUtil::ElementsIn(operand), ShapeUtil::HumanString(operand).c_str(), - ShapeUtil::ElementsIn(inferred_shape), - ShapeUtil::HumanString(inferred_shape).c_str()); + "reshape operation has mismatched element counts: from=%lld to=%lld", + ShapeUtil::ElementsIn(operand), ShapeUtil::ElementsIn(inferred_shape)); } std::vector<int64> indices(ShapeUtil::Rank(operand)); diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc index 2876a79dd8..061a4e190f 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util.cc @@ -39,60 +39,30 @@ limitations under the License. namespace xla { -/* static */ ::testing::AssertionResult LiteralTestUtil::EqualShapes( - const Shape& expected, const Shape& actual) { - if (ShapeUtil::IsTuple(expected) != ShapeUtil::IsTuple(actual)) { - return ::testing::AssertionFailure() - << "tupleness-mismatch! want: " << ShapeUtil::HumanString(expected) - << " got: " << ShapeUtil::HumanString(actual); - } +/* static */ void LiteralTestUtil::AssertEqualShapes(const Shape& expected, + const Shape& actual) { + ASSERT_EQ(ShapeUtil::IsTuple(expected), ShapeUtil::IsTuple(actual)); if (ShapeUtil::IsTuple(expected)) { - if (ShapeUtil::TupleElementCount(expected) != - ShapeUtil::TupleElementCount(actual)) { - return ::testing::AssertionFailure() - << "want tuple element count: " - << ShapeUtil::TupleElementCount(expected) - << " got tuple element count: " - << ShapeUtil::TupleElementCount(actual); - } + ASSERT_EQ(ShapeUtil::TupleElementCount(expected), + ShapeUtil::TupleElementCount(actual)); for (int i = 0; i < expected.tuple_shapes_size(); ++i) { - ::testing::AssertionResult result = - EqualShapes(expected.tuple_shapes(i), actual.tuple_shapes(i)); - if (!result) { - return result; - } + AssertEqualShapes(expected.tuple_shapes(i), actual.tuple_shapes(i)); } } else { - if (ShapeUtil::Rank(expected) != ShapeUtil::Rank(actual)) { - return ::testing::AssertionFailure() - << "want rank of: " << ShapeUtil::HumanString(expected) - << " got rank of: " << ShapeUtil::HumanString(actual); - } - if (expected.element_type() != actual.element_type()) { - return ::testing::AssertionFailure() - << PrimitiveType_Name(expected.element_type()) << " vs " - << PrimitiveType_Name(actual.element_type()); - } - if (expected.dimensions_size() != actual.dimensions_size()) { - return ::testing::AssertionFailure() - << "want dimensions_size " << expected.dimensions_size() - << " got dimensions_size " << actual.dimensions_size(); - } + ASSERT_EQ(ShapeUtil::Rank(expected), ShapeUtil::Rank(actual)) + << "want rank of: " << ShapeUtil::HumanString(expected) + << " got rank of: " << ShapeUtil::HumanString(actual); + ASSERT_EQ(expected.element_type(), actual.element_type()) + << PrimitiveType_Name(expected.element_type()) << " vs " + << PrimitiveType_Name(actual.element_type()); + ASSERT_EQ(expected.dimensions_size(), actual.dimensions_size()); for (int i = 0; i < expected.dimensions_size(); ++i) { - if (expected.dimensions(i) != actual.dimensions(i)) { - return ::testing::AssertionFailure() - << "mismatch in dimension #" << i - << " expected: " << ShapeUtil::HumanString(expected) - << " actual: " << ShapeUtil::HumanString(actual); - } + ASSERT_EQ(expected.dimensions(i), actual.dimensions(i)) + << "mismatch in dimension #" << i + << " expected: " << ShapeUtil::HumanString(expected) + << " actual: " << ShapeUtil::HumanString(actual); } } - return ::testing::AssertionSuccess(); -} - -/* static */ void LiteralTestUtil::AssertEqualShapes(const Shape& expected, - const Shape& actual) { - ASSERT_TRUE(EqualShapes(expected, actual)); } /* static */ void LiteralTestUtil::AssertEqualShapesAndLayouts( @@ -295,14 +265,7 @@ class NearComparator { VLOG(1) << "actual:"; XLA_VLOG_LINES(1, actual.ToString()); - // If the shapes mismatch, we simply fail the expectation instead of - // printing out data, as it's a type error rather than a value error. - ::testing::AssertionResult equal_shapes = - LiteralTestUtil::EqualShapes(expected.shape(), actual.shape()); - if (!equal_shapes) { - EXPECT_TRUE(equal_shapes); - return false; - } + LiteralTestUtil::AssertEqualShapes(expected.shape(), actual.shape()); // Set up members used during the comparison. num_miscompares_ = 0; diff --git a/tensorflow/compiler/xla/tests/literal_test_util.h b/tensorflow/compiler/xla/tests/literal_test_util.h index 467d44b857..f645c4e8dc 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.h +++ b/tensorflow/compiler/xla/tests/literal_test_util.h @@ -50,8 +50,6 @@ class LiteralTestUtil { public: // Asserts that the given shapes have the same rank, dimension sizes, and // primitive types. - static ::testing::AssertionResult EqualShapes(const Shape& expected, - const Shape& actual); static void AssertEqualShapes(const Shape& expected, const Shape& actual); // Asserts that the provided shapes are equal as defined in AssertEqualShapes diff --git a/tensorflow/compiler/xla/tests/reshape_test.cc b/tensorflow/compiler/xla/tests/reshape_test.cc index 72c68f24a0..bb7160e3a0 100644 --- a/tensorflow/compiler/xla/tests/reshape_test.cc +++ b/tensorflow/compiler/xla/tests/reshape_test.cc @@ -47,7 +47,7 @@ class ReshapeTest : public ClientLibraryTestBase { }; // Collapses 2-dimensional pseudo-scalar (single-element array) to 1 dimension. -XLA_TEST_F(ReshapeTest, CollapseTrivial1x1) { +XLA_TEST_F(ReshapeTest, Trivial1x1) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR2<float>({{1.0}}); builder.Collapse(/*operand=*/a, /*dimensions=*/{0, 1}); @@ -55,22 +55,6 @@ XLA_TEST_F(ReshapeTest, CollapseTrivial1x1) { ComputeAndCompareR1<float>(&builder, {1.0f}, {}, zero_error_spec_); } -XLA_TEST_F(ReshapeTest, CollapseTrivialR1EmptyDims) { - ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR1<float>({1.0}); - builder.Collapse(/*operand=*/a, /*dimensions=*/{}); - - ComputeAndCompareR1<float>(&builder, {1.0f}, {}, zero_error_spec_); -} - -XLA_TEST_F(ReshapeTest, CollapseTrivialR1OnlyDim) { - ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR1<float>({1.0}); - builder.Collapse(/*operand=*/a, /*dimensions=*/{0}); - - ComputeAndCompareR1<float>(&builder, {1.0f}, {}, zero_error_spec_); -} - // Collapses 2-dimensional pseudo-scalar (single-element array) to scalar. XLA_TEST_F(ReshapeTest, SingleElementArrayToScalar) { ComputationBuilder builder(client_, TestName()); diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc index 11e063d8d2..f57834cfbe 100644 --- a/tensorflow/core/common_runtime/executor.cc +++ b/tensorflow/core/common_runtime/executor.cc @@ -2008,7 +2008,7 @@ bool ExecutorState::NodeDone(const Status& s, const Node* node, NodeExecStatsWrapper* stats, TaggedNodeReadyQueue* inline_ready) { nodestats::SetAllEnd(stats); - if (stats_collector_ != nullptr && !SetTimelineLabel(node, stats)) { + if (!SetTimelineLabel(node, stats)) { // Only record non-transfer nodes. // Transfers 'stats' ownership to 'stats_collector_'. stats_collector_->Save(impl_->params_.device->name(), stats); |