aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-11-30 11:18:54 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-30 11:22:15 -0800
commit4146ff1259c0b4ada8afbbad11a7b37d8373d1b9 (patch)
tree1a4cb649245215420c7a34ce97506327caa0d1c4 /tensorflow/compiler/xla/service/hlo_cost_analysis.cc
parentea1c29552b01f3404e27999a27a1919b3accc594 (diff)
[XLA] Adds Dot with DotDimensionNumbers proto for specifying arbitrary contracting and batch dimensions.
PiperOrigin-RevId: 177481231
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_cost_analysis.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_cost_analysis.cc5
1 files changed, 3 insertions, 2 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
index 6fcc01dd64..0ed64e6779 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
@@ -201,10 +201,11 @@ Status HloCostAnalysis::HandleCopy(const HloInstruction*) {
Status HloCostAnalysis::HandleDot(const HloInstruction* dot) {
const Shape& lhs_shape = dot->operand(0)->shape();
const Shape& rhs_shape = dot->operand(1)->shape();
+ const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
// Count of elements along the reduction dimension (last dimension for the
// rhs).
- int64 reduction_width = lhs_shape.dimensions(ShapeUtil::Rank(lhs_shape) - 1);
-
+ int64 reduction_width =
+ lhs_shape.dimensions(dnums.lhs_contracting_dimensions(0));
// First divide by reduction width before multiplying by rhs elements to avoid
// overflow.
int64 fma_count;