about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
diff options
context:
space:
mode:
author: Bjarke Hammersholt Roune <broune@google.com> 2017-06-27 13:33:56 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2017-06-27 13:41:30 -0700
commit: 00feb62ef5f44c8d3dd32199e552eb5de8049e59 (patch)
tree: d968e3d68570db313293aae40d2ff035b2269643 /tensorflow/compiler/xla/service/hlo_cost_analysis.cc
parent: febcf3c7b9fddffe9aceaae626f184d5c39b657f (diff)
Add a time estimation to HloCostAnalysis and represent properties as a map so that adding more properties will be easier, e.g. in a sub-class.
PiperOrigin-RevId: 160318494
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_cost_analysis.cc')
-rw-r--r-- tensorflow/compiler/xla/service/hlo_cost_analysis.cc 313
1 files changed, 195 insertions, 118 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
index f4c0bb8a4a..522dddea4e 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
@@ -25,34 +25,56 @@ limitations under the License.
namespace xla {
+constexpr char HloCostAnalysis::kFlopsKey[];
+constexpr char HloCostAnalysis::kTranscendentalsKey[];
+constexpr char HloCostAnalysis::kBytesAccessedKey[];
+constexpr char HloCostAnalysis::kSecondsKey[];
+
+HloCostAnalysis::HloCostAnalysis(const ShapeSizeFunction& shape_size)
+ : HloCostAnalysis(shape_size, {}) {}
+
+HloCostAnalysis::HloCostAnalysis(const ShapeSizeFunction& shape_size,
+ const Properties& per_second_rates)
+ : shape_size_(shape_size), per_second_rates_(per_second_rates) {}
+
Status HloCostAnalysis::Preprocess(HloInstruction* hlo) {
// Set current instruction cost values to reasonable default values. Each
- // handler can overwrite these values. In Postprocess, these value are
+ // handler can overwrite these values. In Postprocess, these values are
// accumulated and written to the per-instruction maps.
- current_flop_count_ = 0;
- current_transcendental_count_ = 0;
+ current_properties_.clear();
+ current_should_compute_bottleneck_time_ = true;
- // The default element count for an instruction is the sum of elements in the
- // operands and output. The default ShapeUtil::ByteSizeOf does not handle
- // opaque types.
- current_bytes_accessed_ = shape_size_(hlo->shape());
+ // The default number of bytes accessed for an instruction is the sum of the
+ // sizes of the inputs and outputs. The default ShapeUtil::ByteSizeOf does not
+ // handle opaque types.
+ float bytes_accessed = shape_size_(hlo->shape());
for (const HloInstruction* operand : hlo->operands()) {
- current_bytes_accessed_ += shape_size_(operand->shape());
+ bytes_accessed += shape_size_(operand->shape());
}
+ current_properties_[kBytesAccessedKey] = bytes_accessed;
return Status::OK();
}
Status HloCostAnalysis::Postprocess(HloInstruction* hlo) {
- // Accumulate cost values and write into per-instruction maps.
- flop_count_ += current_flop_count_;
- hlo_to_flop_count_[hlo] = current_flop_count_;
-
- transcendental_count_ += current_transcendental_count_;
- hlo_to_transcendental_count_[hlo] = current_transcendental_count_;
+ if (current_should_compute_bottleneck_time_) {
+ // Compute the time as the time of the bottleneck, i.e. the slowest property
+ // given the per-second rate of each property.
+ float max_seconds = 0.0f;
+ for (const auto& property : current_properties_) {
+ if (property.first != kSecondsKey) {
+ max_seconds = std::max(
+ max_seconds,
+ property.second / GetProperty(property.first, per_second_rates_));
+ }
+ }
+ current_properties_[kSecondsKey] = max_seconds;
+ }
- bytes_accessed_ += current_bytes_accessed_;
- hlo_to_bytes_accessed_[hlo] = current_bytes_accessed_;
+ TF_RET_CHECK(hlo_properties_.emplace(hlo, current_properties_).second);
+ for (const auto& property : current_properties_) {
+ properties_sum_[property.first] += property.second;
+ }
return Status::OK();
}
@@ -65,15 +87,32 @@ Status HloCostAnalysis::HandleElementwiseOp(HloInstruction* hlo_instruction) {
auto opcode = hlo_instruction->opcode();
// We treat the two opcodes (kExp, kPower) as transcendental operations.
if (opcode == HloOpcode::kExp || opcode == HloOpcode::kPower) {
- current_transcendental_count_ = computation_count;
+ current_properties_[kTranscendentalsKey] = computation_count;
} else {
// Note: transcendental operations are considered a separate category from
// FLOPs.
- current_flop_count_ = computation_count;
+ current_properties_[kFlopsKey] = computation_count;
}
return Status::OK();
}
+/*static*/ float HloCostAnalysis::GetProperty(const string& key,
+ const Properties& properties) {
+ auto key_value = properties.find(key);
+ return key_value == properties.end() ? 0.0f : key_value->second;
+}
+
+/*static*/ float HloCostAnalysis::GetPropertyForHlo(
+ const HloInstruction& hlo, const string& key,
+ const HloToProperties& hlo_to_properties) {
+ auto it = hlo_to_properties.find(&hlo);
+ if (it == hlo_to_properties.end()) {
+ return 0.0f;
+ } else {
+ return GetProperty(key, it->second);
+ }
+}
+
Status HloCostAnalysis::HandleElementwiseUnary(HloInstruction* hlo,
HloOpcode opcode) {
return HandleElementwiseOp(hlo);
@@ -102,13 +141,13 @@ Status HloCostAnalysis::HandleReducePrecision(HloInstruction* hlo) {
}
Status HloCostAnalysis::HandleParameter(HloInstruction* parameter) {
- current_bytes_accessed_ = 0;
+ current_properties_[kBytesAccessedKey] = 0;
return Status::OK();
}
Status HloCostAnalysis::HandleConstant(HloInstruction* constant,
const Literal& literal) {
- current_bytes_accessed_ = 0;
+ current_properties_[kBytesAccessedKey] = 0;
return Status::OK();
}
@@ -116,7 +155,7 @@ Status HloCostAnalysis::HandleGetTupleElement(HloInstruction* get_tuple_element,
HloInstruction* operand) {
// GetTupleElement forwards a pointer and does not touch each element in the
// output.
- current_bytes_accessed_ = 0;
+ current_properties_[kBytesAccessedKey] = 0;
return Status::OK();
}
@@ -154,8 +193,9 @@ Status HloCostAnalysis::HandleTuple(
tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
// The tuple instruction only gathers pointers from inputs (it doesn't iterate
// through them). The memory touched is then only the size of the output
- // buffer.
- current_bytes_accessed_ = shape_size_(tuple->shape());
+ // index table of the tuple.
+
+ current_properties_[kBytesAccessedKey] = shape_size_(tuple->shape());
return Status::OK();
}
@@ -193,7 +233,7 @@ Status HloCostAnalysis::HandleDot(HloInstruction* dot,
}
// We count an FMA operation as 2 floating point operations.
- current_flop_count_ = kFmaFlops * fma_count;
+ current_properties_[kFlopsKey] = kFmaFlops * fma_count;
return Status::OK();
}
@@ -209,16 +249,17 @@ Status HloCostAnalysis::HandleMap(
HloInstruction* map, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
HloComputation* function,
tensorflow::gtl::ArraySlice<HloInstruction*> /*static_operands*/) {
- // Compute the cost of the user function.
- HloInstruction* function_instruction = function->root_instruction();
- HloCostAnalysis visitor(shape_size_);
- TF_RETURN_IF_ERROR(function_instruction->Accept(&visitor));
+ // Compute properties of the mapped function.
+ TF_ASSIGN_OR_RETURN(const Properties sub_properties,
+ ProcessSubcomputation(function));
// Compute the cost of all elements for this Map operation.
- int64 element_count = ShapeUtil::ElementsIn(map->shape());
- current_transcendental_count_ =
- element_count * visitor.transcendental_count();
- current_flop_count_ = element_count * visitor.flop_count();
+ const int64 element_count = ShapeUtil::ElementsIn(map->shape());
+ for (const auto& property : sub_properties) {
+ if (property.first != kBytesAccessedKey) {
+ current_properties_[property.first] = property.second * element_count;
+ }
+ }
return Status::OK();
}
@@ -226,16 +267,17 @@ Status HloCostAnalysis::HandleReduce(
HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value,
tensorflow::gtl::ArraySlice<int64> dimensions, HloComputation* function) {
// Compute the cost of the user function.
- HloInstruction* function_instruction = function->root_instruction();
- HloCostAnalysis visitor(shape_size_);
- TF_RETURN_IF_ERROR(function_instruction->Accept(&visitor));
+ TF_ASSIGN_OR_RETURN(const Properties sub_properties,
+ ProcessSubcomputation(function));
// Compute the cost of all elements for this Reduce operation.
int64 reduction_count = ShapeUtil::ElementsIn(arg->shape()) -
ShapeUtil::ElementsIn(reduce->shape());
- current_flop_count_ = reduction_count * visitor.flop_count();
- current_transcendental_count_ =
- reduction_count * visitor.transcendental_count();
+ for (const auto& property : sub_properties) {
+ if (property.first != kBytesAccessedKey) {
+ current_properties_[property.first] = property.second * reduction_count;
+ }
+ }
return Status::OK();
}
@@ -243,55 +285,63 @@ Status HloCostAnalysis::HandleReduceWindow(HloInstruction* reduce_window,
HloInstruction* operand,
const Window& window,
HloComputation* function) {
- // Compute the cost of the user function.
- HloInstruction* function_instruction = function->root_instruction();
- HloCostAnalysis visitor(shape_size_);
- TF_RETURN_IF_ERROR(function_instruction->Accept(&visitor));
+ // Compute the properties of the reduction function.
+ TF_ASSIGN_OR_RETURN(const Properties sub_properties,
+ ProcessSubcomputation(function));
// Compute the cost of all elements for this ReduceWindow operation. For each
- // output element, (window_size - 1) number of user computations are applied.
- auto output_size = ShapeUtil::ElementsIn(reduce_window->shape());
- int64 window_size = 1;
+ // output element there are window_size - 1 reductions to perform.
+ int64 window_element_count = 1;
for (const auto& dimension : window.dimensions()) {
- window_size *= dimension.size();
+ window_element_count *= dimension.size();
+ }
+ const int64 output_element_count =
+ ShapeUtil::ElementsIn(reduce_window->shape());
+ const int64 reduction_count =
+ (window_element_count - 1) * output_element_count;
+ for (const auto& property : sub_properties) {
+ if (property.first != kBytesAccessedKey) {
+ current_properties_[property.first] = property.second * reduction_count;
+ }
}
- current_flop_count_ = output_size * (window_size - 1) * visitor.flop_count();
- current_transcendental_count_ =
- output_size * (window_size - 1) * visitor.transcendental_count();
return Status::OK();
}
Status HloCostAnalysis::HandleSelectAndScatter(HloInstruction* instruction) {
- // Compute the cost of the select and scatter function.
- HloInstruction* select = instruction->select()->root_instruction();
- HloCostAnalysis select_visitor(shape_size_);
- TF_RETURN_IF_ERROR(select->Accept(&select_visitor));
- HloInstruction* scatter = instruction->scatter()->root_instruction();
- HloCostAnalysis scatter_visitor(shape_size_);
- TF_RETURN_IF_ERROR(scatter->Accept(&scatter_visitor));
+ // Compute the properties of the select and scatter function.
+ // Compute the properties of the reduction function.
+ TF_ASSIGN_OR_RETURN(const Properties select_properties,
+ ProcessSubcomputation(instruction->select()));
+ TF_ASSIGN_OR_RETURN(const Properties scatter_properties,
+ ProcessSubcomputation(instruction->scatter()));
// Compute the cost of all elements for this operation. For each scatter
- // source element, (window_size - 1) number of select computations and 1
- // scatter computation are applied.
+ // source element there are window_size - 1 select computations to perform and
+ // 1 scatter computation to perform.
const auto source = instruction->operand(1);
const auto source_element_count = ShapeUtil::ElementsIn(source->shape());
- int64 window_size = 1;
+ int64 window_element_count = 1;
for (const auto& dimension : instruction->window().dimensions()) {
- window_size *= dimension.size();
+ window_element_count *= dimension.size();
+ }
+ const int64 select_count = source_element_count * (window_element_count - 1);
+ for (const auto& property : select_properties) {
+ if (property.first != kBytesAccessedKey) {
+ current_properties_[property.first] += property.second * select_count;
+ }
+ }
+ for (const auto& property : scatter_properties) {
+ if (property.first != kBytesAccessedKey) {
+ current_properties_[property.first] +=
+ property.second * source_element_count;
+ }
}
- current_flop_count_ =
- source_element_count * ((window_size - 1) * select_visitor.flop_count() +
- scatter_visitor.flop_count());
- current_transcendental_count_ =
- source_element_count *
- ((window_size - 1) * select_visitor.transcendental_count() +
- scatter_visitor.transcendental_count());
return Status::OK();
}
Status HloCostAnalysis::HandleBitcast(HloInstruction* bitcast) {
// A bitcast does no computation and touches no memory.
- current_bytes_accessed_ = 0;
+ current_properties_[kBytesAccessedKey] = 0;
return Status::OK();
}
@@ -331,12 +381,13 @@ Status HloCostAnalysis::HandleConvolution(HloInstruction* convolution,
const int64 output_features =
convolution->shape().dimensions(dnums.feature_dimension());
- // For each output element, we do one fma per element in the
- // kernel at some given output feature index.
+ // For each output element, we do one fma per element in the kernel at some
+ // given output feature index.
const int64 fmas_per_output_element =
ShapeUtil::ElementsIn(rhs_instruction->shape()) / output_features;
const int64 output_elements = ShapeUtil::ElementsIn(convolution->shape());
- current_flop_count_ = output_elements * fmas_per_output_element * kFmaFlops;
+ current_properties_[kFlopsKey] =
+ output_elements * fmas_per_output_element * kFmaFlops;
return Status::OK();
}
@@ -346,7 +397,7 @@ Status HloCostAnalysis::HandleCrossReplicaSum(HloInstruction* crs) {
//
// TODO(b/33004697): Compute correct cost here, taking the actual number of
// replicas into account.
- current_flop_count_ = ShapeUtil::ElementsIn(crs->shape());
+ current_properties_[kFlopsKey] = ShapeUtil::ElementsIn(crs->shape());
return Status::OK();
}
@@ -355,44 +406,43 @@ Status HloCostAnalysis::HandleRng(HloInstruction* random,
// TODO(b/26346211): Implement better estimates for the RNG cost, since the
// cost changes with the implementation and the distribution. For now, assume
// the cost of each RNG is same as a transcendental operation.
- current_transcendental_count_ = ShapeUtil::ElementsIn(random->shape());
+ current_properties_[kTranscendentalsKey] =
+ ShapeUtil::ElementsIn(random->shape());
return Status::OK();
}
Status HloCostAnalysis::HandleFusion(HloInstruction* fusion) {
- // Compute the cost of the fused expression.
- HloInstruction* fused_expression_root = fusion->fused_expression_root();
- // Don't compute sizes inside of fused ops. We don't use the size here and the
- // operations inside might not have a layout.
- HloCostAnalysis visitor([](const Shape&) { return 0; });
- TF_RETURN_IF_ERROR(fused_expression_root->Accept(&visitor));
-
- // If a fusion node produces a tuple, it also produces the operands of that
- // tuple.
- current_bytes_accessed_ = 0;
+ // Compute the properties of the fused expression and attribute them to the
+ // fusion node. Use a dummy shape_size to avoid any errors from trying to
+ // calculate the size of a shape that does not have a layout, since nodes
+ // inside fusion nodes do not necessarily have a layout assigned.
+ ShapeSizeFunction shape_size = [](const Shape& shape) { return 0; };
+ TF_ASSIGN_OR_RETURN(
+ current_properties_,
+ ProcessSubcomputation(fusion->fused_instructions_computation(),
+ &shape_size));
+
+ // Fusion nodes that produce a tuple also produce the entries in the tuple.
+ // Ignore the memory accessed inside fused ops, since fusion is supposed to
+ // prevent intermediate data from touching slow memory.
+ current_properties_[kBytesAccessedKey] = 0;
ShapeUtil::ForEachSubshape(
fusion->shape(),
[this](const Shape& subshape, const ShapeIndex& /*shape_index*/) {
- current_bytes_accessed_ += shape_size_(subshape);
+ current_properties_[kBytesAccessedKey] += shape_size_(subshape);
});
for (const HloInstruction* operand : fusion->operands()) {
- current_bytes_accessed_ += shape_size_(operand->shape());
+ current_properties_[kBytesAccessedKey] += shape_size_(operand->shape());
}
- // Attribute the cost of the fused expression to the fusion node.
- current_transcendental_count_ = visitor.transcendental_count();
- current_flop_count_ = visitor.flop_count();
return Status::OK();
}
Status HloCostAnalysis::HandleCall(HloInstruction* call) {
- HloCostAnalysis computation_visitor(shape_size_);
- TF_RETURN_IF_ERROR(call->to_apply()->Accept(&computation_visitor));
-
- current_flop_count_ = computation_visitor.flop_count();
- current_transcendental_count_ = computation_visitor.transcendental_count();
- current_bytes_accessed_ = computation_visitor.bytes_accessed();
+ TF_ASSIGN_OR_RETURN(current_properties_,
+ ProcessSubcomputation(call->to_apply()));
+ current_should_compute_bottleneck_time_ = false;
return Status::OK();
}
@@ -400,34 +450,38 @@ Status HloCostAnalysis::HandleCustomCall(
HloInstruction* custom_call,
tensorflow::gtl::ArraySlice<HloInstruction*> operands,
tensorflow::StringPiece custom_call_target) {
- return Unimplemented("custom-call");
+ return Unimplemented("Custom-call is not implemented for HLO cost analysis.");
}
Status HloCostAnalysis::HandleSort(HloInstruction* sort,
HloInstruction* operand_instruction) {
- // The cost of sort is implementation dependent, so cannot determine at HLO
- // level. Assume comparison based N*log(N) sorting.
+ // This assumes a comparison based N*log(N) algorithm. As for all ops, the
+ // actual properties of the op depend on the backend implementation.
int64 elements = ShapeUtil::ElementsIn(operand_instruction->shape());
- current_flop_count_ = elements * tensorflow::Log2Ceiling(elements);
+ current_properties_[kFlopsKey] = elements * tensorflow::Log2Ceiling(elements);
return Status::OK();
}
Status HloCostAnalysis::HandleWhile(HloInstruction* xla_while) {
- // Since the number of iterations of the while node is not statically
- // determined, we cannot precisely compute the cost of a while node. For now
- // compute the cost of a single iteration.
- // TODO(b/26346211): Improve the cost analysis for while node.
- HloCostAnalysis body_visitor(shape_size_);
- TF_RETURN_IF_ERROR(xla_while->while_body()->Accept(&body_visitor));
- HloCostAnalysis condition_visitor(shape_size_);
- TF_RETURN_IF_ERROR(xla_while->while_condition()->Accept(&condition_visitor));
+ // Since the number of iterations of the while node will not always be
+ // something that we can statically analyze, we cannot precisely compute the
+ // cost of a while node. For now compute the cost of a single iteration.
+ //
+ // TODO(b/26346211): Improve the cost analysis for while nodes.
+ TF_ASSIGN_OR_RETURN(const Properties body_properties,
+ ProcessSubcomputation(xla_while->while_body()));
- current_flop_count_ =
- body_visitor.flop_count() + condition_visitor.flop_count();
- current_transcendental_count_ = body_visitor.transcendental_count() +
- condition_visitor.transcendental_count();
- current_bytes_accessed_ =
- body_visitor.bytes_accessed() + condition_visitor.bytes_accessed();
+ TF_ASSIGN_OR_RETURN(const Properties condition_properties,
+ ProcessSubcomputation(xla_while->while_condition()));
+
+ current_properties_.clear();
+ for (const auto& property : body_properties) {
+ current_properties_[property.first] += property.second;
+ }
+ for (const auto& property : condition_properties) {
+ current_properties_[property.first] += property.second;
+ }
+ current_should_compute_bottleneck_time_ = false;
return Status::OK();
}
@@ -436,19 +490,42 @@ Status HloCostAnalysis::FinishVisit(HloInstruction* root) {
return Status::OK();
}
+float HloCostAnalysis::flop_count() const {
+ return GetProperty(kFlopsKey, properties_sum_);
+}
+
+float HloCostAnalysis::transcendental_count() const {
+ return GetProperty(kTranscendentalsKey, properties_sum_);
+}
+
+float HloCostAnalysis::bytes_accessed() const {
+ return GetProperty(kBytesAccessedKey, properties_sum_);
+}
+
+float HloCostAnalysis::seconds() const {
+ return GetProperty(kSecondsKey, properties_sum_);
+}
+
int64 HloCostAnalysis::flop_count(const HloInstruction& hlo) const {
- auto it = hlo_to_flop_count_.find(&hlo);
- return it == hlo_to_flop_count_.end() ? 0 : it->second;
+ return GetPropertyForHlo(hlo, kFlopsKey, hlo_properties_);
}
int64 HloCostAnalysis::transcendental_count(const HloInstruction& hlo) const {
- auto it = hlo_to_transcendental_count_.find(&hlo);
- return it == hlo_to_transcendental_count_.end() ? 0 : it->second;
+ return GetPropertyForHlo(hlo, kTranscendentalsKey, hlo_properties_);
}
int64 HloCostAnalysis::bytes_accessed(const HloInstruction& hlo) const {
- auto it = hlo_to_bytes_accessed_.find(&hlo);
- return it == hlo_to_bytes_accessed_.end() ? 0 : it->second;
+ return GetPropertyForHlo(hlo, kBytesAccessedKey, hlo_properties_);
+}
+
+StatusOr<HloCostAnalysis::Properties> HloCostAnalysis::ProcessSubcomputation(
+ HloComputation* computation, const ShapeSizeFunction* shape_size) {
+ if (shape_size == nullptr) {
+ shape_size = &shape_size_;
+ }
+ HloCostAnalysis visitor(*shape_size, per_second_rates_);
+ TF_RETURN_IF_ERROR(computation->Accept(&visitor));
+ return visitor.properties();
}
} // namespace xla