diff options
author | 2017-06-27 13:33:56 -0700 | |
---|---|---|
committer | 2017-06-27 13:41:30 -0700 | |
commit | 00feb62ef5f44c8d3dd32199e552eb5de8049e59 (patch) | |
tree | d968e3d68570db313293aae40d2ff035b2269643 /tensorflow/compiler/xla/service/hlo_cost_analysis.cc | |
parent | febcf3c7b9fddffe9aceaae626f184d5c39b657f (diff) |
Add a time estimation to HloCostAnalysis and represent properties as a map so that adding more properties will be easier, e.g. in a sub-class.
PiperOrigin-RevId: 160318494
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_cost_analysis.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_cost_analysis.cc | 313 |
1 files changed, 195 insertions, 118 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index f4c0bb8a4a..522dddea4e 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -25,34 +25,56 @@ limitations under the License. namespace xla { +constexpr char HloCostAnalysis::kFlopsKey[]; +constexpr char HloCostAnalysis::kTranscendentalsKey[]; +constexpr char HloCostAnalysis::kBytesAccessedKey[]; +constexpr char HloCostAnalysis::kSecondsKey[]; + +HloCostAnalysis::HloCostAnalysis(const ShapeSizeFunction& shape_size) + : HloCostAnalysis(shape_size, {}) {} + +HloCostAnalysis::HloCostAnalysis(const ShapeSizeFunction& shape_size, + const Properties& per_second_rates) + : shape_size_(shape_size), per_second_rates_(per_second_rates) {} + Status HloCostAnalysis::Preprocess(HloInstruction* hlo) { // Set current instruction cost values to reasonable default values. Each - // handler can overwrite these values. In Postprocess, these value are + // handler can overwrite these values. In Postprocess, these values are // accumulated and written to the per-instruction maps. - current_flop_count_ = 0; - current_transcendental_count_ = 0; + current_properties_.clear(); + current_should_compute_bottleneck_time_ = true; - // The default element count for an instruction is the sum of elements in the - // operands and output. The default ShapeUtil::ByteSizeOf does not handle - // opaque types. - current_bytes_accessed_ = shape_size_(hlo->shape()); + // The default number of bytes accessed for an instruction is the sum of the + // sizes of the inputs and outputs. The default ShapeUtil::ByteSizeOf does not + // handle opaque types. + float bytes_accessed = shape_size_(hlo->shape()); for (const HloInstruction* operand : hlo->operands()) { - current_bytes_accessed_ += shape_size_(operand->shape()); + bytes_accessed += shape_size_(operand->shape()); } + current_properties_[kBytesAccessedKey] = bytes_accessed; return Status::OK(); } Status HloCostAnalysis::Postprocess(HloInstruction* hlo) { - // Accumulate cost values and write into per-instruction maps. - flop_count_ += current_flop_count_; - hlo_to_flop_count_[hlo] = current_flop_count_; - - transcendental_count_ += current_transcendental_count_; - hlo_to_transcendental_count_[hlo] = current_transcendental_count_; + if (current_should_compute_bottleneck_time_) { + // Compute the time as the time of the bottleneck, i.e. the slowest property + // given the per-second rate of each property. + float max_seconds = 0.0f; + for (const auto& property : current_properties_) { + if (property.first != kSecondsKey) { + max_seconds = std::max( + max_seconds, + property.second / GetProperty(property.first, per_second_rates_)); + } + } + current_properties_[kSecondsKey] = max_seconds; + } - bytes_accessed_ += current_bytes_accessed_; - hlo_to_bytes_accessed_[hlo] = current_bytes_accessed_; + TF_RET_CHECK(hlo_properties_.emplace(hlo, current_properties_).second); + for (const auto& property : current_properties_) { + properties_sum_[property.first] += property.second; + } return Status::OK(); } @@ -65,15 +87,32 @@ Status HloCostAnalysis::HandleElementwiseOp(HloInstruction* hlo_instruction) { auto opcode = hlo_instruction->opcode(); // We treat the two opcodes (kExp, kPower) as transcendental operations. if (opcode == HloOpcode::kExp || opcode == HloOpcode::kPower) { - current_transcendental_count_ = computation_count; + current_properties_[kTranscendentalsKey] = computation_count; } else { // Note: transcendental operations are considered a separate category from // FLOPs. - current_flop_count_ = computation_count; + current_properties_[kFlopsKey] = computation_count; } return Status::OK(); } +/*static*/ float HloCostAnalysis::GetProperty(const string& key, + const Properties& properties) { + auto key_value = properties.find(key); + return key_value == properties.end() ? 0.0f : key_value->second; +} + +/*static*/ float HloCostAnalysis::GetPropertyForHlo( + const HloInstruction& hlo, const string& key, + const HloToProperties& hlo_to_properties) { + auto it = hlo_to_properties.find(&hlo); + if (it == hlo_to_properties.end()) { + return 0.0f; + } else { + return GetProperty(key, it->second); + } +} + Status HloCostAnalysis::HandleElementwiseUnary(HloInstruction* hlo, HloOpcode opcode) { return HandleElementwiseOp(hlo); @@ -102,13 +141,13 @@ Status HloCostAnalysis::HandleReducePrecision(HloInstruction* hlo) { } Status HloCostAnalysis::HandleParameter(HloInstruction* parameter) { - current_bytes_accessed_ = 0; + current_properties_[kBytesAccessedKey] = 0; return Status::OK(); } Status HloCostAnalysis::HandleConstant(HloInstruction* constant, const Literal& literal) { - current_bytes_accessed_ = 0; + current_properties_[kBytesAccessedKey] = 0; return Status::OK(); } @@ -116,7 +155,7 @@ Status HloCostAnalysis::HandleGetTupleElement(HloInstruction* get_tuple_element, HloInstruction* operand) { // GetTupleElement forwards a pointer and does not touch each element in the // output. - current_bytes_accessed_ = 0; + current_properties_[kBytesAccessedKey] = 0; return Status::OK(); } @@ -154,8 +193,9 @@ Status HloCostAnalysis::HandleTuple( tensorflow::gtl::ArraySlice<HloInstruction*> operands) { // The tuple instruction only gathers pointers from inputs (it doesn't iterate // through them). The memory touched is then only the size of the output - // buffer. - current_bytes_accessed_ = shape_size_(tuple->shape()); + // index table of the tuple. + + current_properties_[kBytesAccessedKey] = shape_size_(tuple->shape()); return Status::OK(); } @@ -193,7 +233,7 @@ Status HloCostAnalysis::HandleDot(HloInstruction* dot, } // We count an FMA operation as 2 floating point operations. - current_flop_count_ = kFmaFlops * fma_count; + current_properties_[kFlopsKey] = kFmaFlops * fma_count; return Status::OK(); } @@ -209,16 +249,17 @@ Status HloCostAnalysis::HandleMap( HloInstruction* map, tensorflow::gtl::ArraySlice<HloInstruction*> operands, HloComputation* function, tensorflow::gtl::ArraySlice<HloInstruction*> /*static_operands*/) { - // Compute the cost of the user function. - HloInstruction* function_instruction = function->root_instruction(); - HloCostAnalysis visitor(shape_size_); - TF_RETURN_IF_ERROR(function_instruction->Accept(&visitor)); + // Compute properties of the mapped function. + TF_ASSIGN_OR_RETURN(const Properties sub_properties, + ProcessSubcomputation(function)); // Compute the cost of all elements for this Map operation. - int64 element_count = ShapeUtil::ElementsIn(map->shape()); - current_transcendental_count_ = - element_count * visitor.transcendental_count(); - current_flop_count_ = element_count * visitor.flop_count(); + const int64 element_count = ShapeUtil::ElementsIn(map->shape()); + for (const auto& property : sub_properties) { + if (property.first != kBytesAccessedKey) { + current_properties_[property.first] = property.second * element_count; + } + } return Status::OK(); } @@ -226,16 +267,17 @@ Status HloCostAnalysis::HandleReduce( HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value, tensorflow::gtl::ArraySlice<int64> dimensions, HloComputation* function) { // Compute the cost of the user function. - HloInstruction* function_instruction = function->root_instruction(); - HloCostAnalysis visitor(shape_size_); - TF_RETURN_IF_ERROR(function_instruction->Accept(&visitor)); + TF_ASSIGN_OR_RETURN(const Properties sub_properties, + ProcessSubcomputation(function)); // Compute the cost of all elements for this Reduce operation. int64 reduction_count = ShapeUtil::ElementsIn(arg->shape()) - ShapeUtil::ElementsIn(reduce->shape()); - current_flop_count_ = reduction_count * visitor.flop_count(); - current_transcendental_count_ = - reduction_count * visitor.transcendental_count(); + for (const auto& property : sub_properties) { + if (property.first != kBytesAccessedKey) { + current_properties_[property.first] = property.second * reduction_count; + } + } return Status::OK(); } @@ -243,55 +285,63 @@ Status HloCostAnalysis::HandleReduceWindow(HloInstruction* reduce_window, HloInstruction* operand, const Window& window, HloComputation* function) { - // Compute the cost of the user function. - HloInstruction* function_instruction = function->root_instruction(); - HloCostAnalysis visitor(shape_size_); - TF_RETURN_IF_ERROR(function_instruction->Accept(&visitor)); + // Compute the properties of the reduction function. + TF_ASSIGN_OR_RETURN(const Properties sub_properties, + ProcessSubcomputation(function)); // Compute the cost of all elements for this ReduceWindow operation. For each - // output element, (window_size - 1) number of user computations are applied. - auto output_size = ShapeUtil::ElementsIn(reduce_window->shape()); - int64 window_size = 1; + // output element there are window_size - 1 reductions to perform. + int64 window_element_count = 1; for (const auto& dimension : window.dimensions()) { - window_size *= dimension.size(); + window_element_count *= dimension.size(); + } + const int64 output_element_count = + ShapeUtil::ElementsIn(reduce_window->shape()); + const int64 reduction_count = + (window_element_count - 1) * output_element_count; + for (const auto& property : sub_properties) { + if (property.first != kBytesAccessedKey) { + current_properties_[property.first] = property.second * reduction_count; + } } - current_flop_count_ = output_size * (window_size - 1) * visitor.flop_count(); - current_transcendental_count_ = - output_size * (window_size - 1) * visitor.transcendental_count(); return Status::OK(); } Status HloCostAnalysis::HandleSelectAndScatter(HloInstruction* instruction) { - // Compute the cost of the select and scatter function. - HloInstruction* select = instruction->select()->root_instruction(); - HloCostAnalysis select_visitor(shape_size_); - TF_RETURN_IF_ERROR(select->Accept(&select_visitor)); - HloInstruction* scatter = instruction->scatter()->root_instruction(); - HloCostAnalysis scatter_visitor(shape_size_); - TF_RETURN_IF_ERROR(scatter->Accept(&scatter_visitor)); + // Compute the properties of the select and scatter function. + // Compute the properties of the reduction function. + TF_ASSIGN_OR_RETURN(const Properties select_properties, + ProcessSubcomputation(instruction->select())); + TF_ASSIGN_OR_RETURN(const Properties scatter_properties, + ProcessSubcomputation(instruction->scatter())); // Compute the cost of all elements for this operation. For each scatter - // source element, (window_size - 1) number of select computations and 1 - // scatter computation are applied. + // source element there are window_size - 1 select computations to perform and + // 1 scatter computation to perform. const auto source = instruction->operand(1); const auto source_element_count = ShapeUtil::ElementsIn(source->shape()); - int64 window_size = 1; + int64 window_element_count = 1; for (const auto& dimension : instruction->window().dimensions()) { - window_size *= dimension.size(); + window_element_count *= dimension.size(); + } + const int64 select_count = source_element_count * (window_element_count - 1); + for (const auto& property : select_properties) { + if (property.first != kBytesAccessedKey) { + current_properties_[property.first] += property.second * select_count; + } + } + for (const auto& property : scatter_properties) { + if (property.first != kBytesAccessedKey) { + current_properties_[property.first] += + property.second * source_element_count; + } } - current_flop_count_ = - source_element_count * ((window_size - 1) * select_visitor.flop_count() + - scatter_visitor.flop_count()); - current_transcendental_count_ = - source_element_count * - ((window_size - 1) * select_visitor.transcendental_count() + - scatter_visitor.transcendental_count()); return Status::OK(); } Status HloCostAnalysis::HandleBitcast(HloInstruction* bitcast) { // A bitcast does no computation and touches no memory. - current_bytes_accessed_ = 0; + current_properties_[kBytesAccessedKey] = 0; return Status::OK(); } @@ -331,12 +381,13 @@ Status HloCostAnalysis::HandleConvolution(HloInstruction* convolution, const int64 output_features = convolution->shape().dimensions(dnums.feature_dimension()); - // For each output element, we do one fma per element in the - // kernel at some given output feature index. + // For each output element, we do one fma per element in the kernel at some + // given output feature index. const int64 fmas_per_output_element = ShapeUtil::ElementsIn(rhs_instruction->shape()) / output_features; const int64 output_elements = ShapeUtil::ElementsIn(convolution->shape()); - current_flop_count_ = output_elements * fmas_per_output_element * kFmaFlops; + current_properties_[kFlopsKey] = + output_elements * fmas_per_output_element * kFmaFlops; return Status::OK(); } @@ -346,7 +397,7 @@ Status HloCostAnalysis::HandleCrossReplicaSum(HloInstruction* crs) { // // TODO(b/33004697): Compute correct cost here, taking the actual number of // replicas into account. - current_flop_count_ = ShapeUtil::ElementsIn(crs->shape()); + current_properties_[kFlopsKey] = ShapeUtil::ElementsIn(crs->shape()); return Status::OK(); } @@ -355,44 +406,43 @@ Status HloCostAnalysis::HandleRng(HloInstruction* random, // TODO(b/26346211): Implement better estimates for the RNG cost, since the // cost changes with the implementation and the distribution. For now, assume // the cost of each RNG is same as a transcendental operation. - current_transcendental_count_ = ShapeUtil::ElementsIn(random->shape()); + current_properties_[kTranscendentalsKey] = + ShapeUtil::ElementsIn(random->shape()); return Status::OK(); } Status HloCostAnalysis::HandleFusion(HloInstruction* fusion) { - // Compute the cost of the fused expression. - HloInstruction* fused_expression_root = fusion->fused_expression_root(); - // Don't compute sizes inside of fused ops. We don't use the size here and the - // operations inside might not have a layout. - HloCostAnalysis visitor([](const Shape&) { return 0; }); - TF_RETURN_IF_ERROR(fused_expression_root->Accept(&visitor)); - - // If a fusion node produces a tuple, it also produces the operands of that - // tuple. - current_bytes_accessed_ = 0; + // Compute the properties of the fused expression and attribute them to the + // fusion node. Use a dummy shape_size to avoid any errors from trying to + // calculate the size of a shape that does not have a layout, since nodes + // inside fusion nodes do not necessarily have a layout assigned. + ShapeSizeFunction shape_size = [](const Shape& shape) { return 0; }; + TF_ASSIGN_OR_RETURN( + current_properties_, + ProcessSubcomputation(fusion->fused_instructions_computation(), + &shape_size)); + + // Fusion nodes that produce a tuple also produce the entries in the tuple. + // Ignore the memory accessed inside fused ops, since fusion is supposed to + // prevent intermediate data from touching slow memory. + current_properties_[kBytesAccessedKey] = 0; ShapeUtil::ForEachSubshape( fusion->shape(), [this](const Shape& subshape, const ShapeIndex& /*shape_index*/) { - current_bytes_accessed_ += shape_size_(subshape); + current_properties_[kBytesAccessedKey] += shape_size_(subshape); }); for (const HloInstruction* operand : fusion->operands()) { - current_bytes_accessed_ += shape_size_(operand->shape()); + current_properties_[kBytesAccessedKey] += shape_size_(operand->shape()); } - // Attribute the cost of the fused expression to the fusion node. - current_transcendental_count_ = visitor.transcendental_count(); - current_flop_count_ = visitor.flop_count(); return Status::OK(); } Status HloCostAnalysis::HandleCall(HloInstruction* call) { - HloCostAnalysis computation_visitor(shape_size_); - TF_RETURN_IF_ERROR(call->to_apply()->Accept(&computation_visitor)); - - current_flop_count_ = computation_visitor.flop_count(); - current_transcendental_count_ = computation_visitor.transcendental_count(); - current_bytes_accessed_ = computation_visitor.bytes_accessed(); + TF_ASSIGN_OR_RETURN(current_properties_, + ProcessSubcomputation(call->to_apply())); + current_should_compute_bottleneck_time_ = false; return Status::OK(); } @@ -400,34 +450,38 @@ Status HloCostAnalysis::HandleCustomCall( HloInstruction* custom_call, tensorflow::gtl::ArraySlice<HloInstruction*> operands, tensorflow::StringPiece custom_call_target) { - return Unimplemented("custom-call"); + return Unimplemented("Custom-call is not implemented for HLO cost analysis."); } Status HloCostAnalysis::HandleSort(HloInstruction* sort, HloInstruction* operand_instruction) { - // The cost of sort is implementation dependent, so cannot determine at HLO - // level. Assume comparison based N*log(N) sorting. + // This assumes a comparison based N*log(N) algorithm. As for all ops, the + // actual properties of the op depend on the backend implementation. int64 elements = ShapeUtil::ElementsIn(operand_instruction->shape()); - current_flop_count_ = elements * tensorflow::Log2Ceiling(elements); + current_properties_[kFlopsKey] = elements * tensorflow::Log2Ceiling(elements); return Status::OK(); } Status HloCostAnalysis::HandleWhile(HloInstruction* xla_while) { - // Since the number of iterations of the while node is not statically - // determined, we cannot precisely compute the cost of a while node. For now - // compute the cost of a single iteration. - // TODO(b/26346211): Improve the cost analysis for while node. - HloCostAnalysis body_visitor(shape_size_); - TF_RETURN_IF_ERROR(xla_while->while_body()->Accept(&body_visitor)); - HloCostAnalysis condition_visitor(shape_size_); - TF_RETURN_IF_ERROR(xla_while->while_condition()->Accept(&condition_visitor)); + // Since the number of iterations of the while node will not always be + // something that we can statically analyze, we cannot precisely compute the + // cost of a while node. For now compute the cost of a single iteration. + // + // TODO(b/26346211): Improve the cost analysis for while nodes. + TF_ASSIGN_OR_RETURN(const Properties body_properties, + ProcessSubcomputation(xla_while->while_body())); - current_flop_count_ = - body_visitor.flop_count() + condition_visitor.flop_count(); - current_transcendental_count_ = body_visitor.transcendental_count() + - condition_visitor.transcendental_count(); - current_bytes_accessed_ = - body_visitor.bytes_accessed() + condition_visitor.bytes_accessed(); + TF_ASSIGN_OR_RETURN(const Properties condition_properties, + ProcessSubcomputation(xla_while->while_condition())); + + current_properties_.clear(); + for (const auto& property : body_properties) { + current_properties_[property.first] += property.second; + } + for (const auto& property : condition_properties) { + current_properties_[property.first] += property.second; + } + current_should_compute_bottleneck_time_ = false; return Status::OK(); } @@ -436,19 +490,42 @@ Status HloCostAnalysis::FinishVisit(HloInstruction* root) { return Status::OK(); } +float HloCostAnalysis::flop_count() const { + return GetProperty(kFlopsKey, properties_sum_); +} + +float HloCostAnalysis::transcendental_count() const { + return GetProperty(kTranscendentalsKey, properties_sum_); +} + +float HloCostAnalysis::bytes_accessed() const { + return GetProperty(kBytesAccessedKey, properties_sum_); +} + +float HloCostAnalysis::seconds() const { + return GetProperty(kSecondsKey, properties_sum_); +} + int64 HloCostAnalysis::flop_count(const HloInstruction& hlo) const { - auto it = hlo_to_flop_count_.find(&hlo); - return it == hlo_to_flop_count_.end() ? 0 : it->second; + return GetPropertyForHlo(hlo, kFlopsKey, hlo_properties_); } int64 HloCostAnalysis::transcendental_count(const HloInstruction& hlo) const { - auto it = hlo_to_transcendental_count_.find(&hlo); - return it == hlo_to_transcendental_count_.end() ? 0 : it->second; + return GetPropertyForHlo(hlo, kTranscendentalsKey, hlo_properties_); } int64 HloCostAnalysis::bytes_accessed(const HloInstruction& hlo) const { - auto it = hlo_to_bytes_accessed_.find(&hlo); - return it == hlo_to_bytes_accessed_.end() ? 0 : it->second; + return GetPropertyForHlo(hlo, kBytesAccessedKey, hlo_properties_); +} + +StatusOr<HloCostAnalysis::Properties> HloCostAnalysis::ProcessSubcomputation( + HloComputation* computation, const ShapeSizeFunction* shape_size) { + if (shape_size == nullptr) { + shape_size = &shape_size_; + } + HloCostAnalysis visitor(*shape_size, per_second_rates_); + TF_RETURN_IF_ERROR(computation->Accept(&visitor)); + return visitor.properties(); } } // namespace xla |