diff options
author | 2017-11-30 11:18:54 -0800 | |
---|---|---|
committer | 2017-11-30 11:22:15 -0800 | |
commit | 4146ff1259c0b4ada8afbbad11a7b37d8373d1b9 (patch) | |
tree | 1a4cb649245215420c7a34ce97506327caa0d1c4 /tensorflow/compiler/xla/service/hlo_cost_analysis.cc | |
parent | ea1c29552b01f3404e27999a27a1919b3accc594 (diff) |
[XLA] Adds Dot with DotDimensionNumbers proto for specifying arbitrary contracting and batch dimensions.
PiperOrigin-RevId: 177481231
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_cost_analysis.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_cost_analysis.cc | 5 |
1 files changed, 3 insertions, 2 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index 6fcc01dd64..0ed64e6779 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -201,10 +201,11 @@ Status HloCostAnalysis::HandleCopy(const HloInstruction*) { Status HloCostAnalysis::HandleDot(const HloInstruction* dot) { const Shape& lhs_shape = dot->operand(0)->shape(); const Shape& rhs_shape = dot->operand(1)->shape(); + const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); // Count of elements along the reduction dimension (last dimension for the // rhs). - int64 reduction_width = lhs_shape.dimensions(ShapeUtil::Rank(lhs_shape) - 1); - + int64 reduction_width = + lhs_shape.dimensions(dnums.lhs_contracting_dimensions(0)); // First divide by reduction width before multiplying by rhs elements to avoid // overflow. int64 fma_count; |