diff options
author | 2017-07-25 05:19:15 -0700 | |
---|---|---|
committer | 2017-07-25 05:23:19 -0700 | |
commit | e07ddbe64c944f1ea69f23687803d70d450f2ac3 (patch) | |
tree | 28027989253484d100d841b02afcb6fe380bf50b /tensorflow/compiler | |
parent | 6610b3ec6bdb1a8843070a9fea6e4612681b9318 (diff) |
Use correct order of arguments in call of valid_bitcast_callback_.
There are platforms where bitcasts are not symmetric. I.e. there are shapes A and B so that A->B is a bitcast, but B->A not. So we have to consider the correct order when calling valid_bitcast_callback_.
PiperOrigin-RevId: 163058665
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r-- | tensorflow/compiler/xla/service/algebraic_simplifier.cc | 6 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/algebraic_simplifier.h | 13 |
2 files changed, 10 insertions, 9 deletions
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index b351861425..4837402c15 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -1488,9 +1488,9 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( // We cannot insert bitcasts if the layouts will not be compatible. // TODO(b/33178038): Consider inserting a transpose if a bitcast would be // invalid. - if (!valid_bitcast_callback_(lhs->shape(), input_shape) || - !valid_bitcast_callback_(rhs->shape(), new_filter_shape) || - !valid_bitcast_callback_(dot_output_shape, convolution_shape)) { + if (!valid_bitcast_callback_(input_shape, lhs->shape()) || + !valid_bitcast_callback_(new_filter_shape, rhs->shape()) || + !valid_bitcast_callback_(convolution_shape, dot_output_shape)) { return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.h b/tensorflow/compiler/xla/service/algebraic_simplifier.h index f8919f0caa..4295a3227a 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.h +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.h @@ -26,12 +26,13 @@ namespace xla { // A pass which performs AlgebraicSimplications. class AlgebraicSimplifier : public HloPassInterface { public: - // Given two shapes, determines if it is valid to bitcast between them after - // considering platform dependent effects on layout like alignment - // restrictions. - // Precondition: the two shapes have layouts, the same number of - // elements and ShapeUtil::ReshapeIsBitcast returns true. - using ValidBitcastCallback = std::function<bool(const Shape&, const Shape&)>; + // Given shapes 'from_shape' and 'to_shape', determines if it is valid to + // bitcast from 'from_shape' to 'to_shape' after considering platform + // dependent effects on layout like alignment restrictions. Precondition: the + // two shapes have layouts, the same number of elements and + // ShapeUtil::ReshapeIsBitcast returns true. + using ValidBitcastCallback = + std::function<bool(const Shape& from_shape, const Shape& to_shape)>; // If is_layout_sensitive is true, then the simplifier preserves layout during // transformation. Otherwise, layout is ignored. If valid_bitcast_callback |