aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-07-25 05:19:15 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-25 05:23:19 -0700
commite07ddbe64c944f1ea69f23687803d70d450f2ac3 (patch)
tree28027989253484d100d841b02afcb6fe380bf50b /tensorflow/compiler
parent6610b3ec6bdb1a8843070a9fea6e4612681b9318 (diff)
Use correct order of arguments in call of valid_bitcast_callback_.
There are platforms where bitcasts are not symmetric. I.e. there are shapes A and B so that A->B is a bitcast, but B->A not. So we have to consider the correct order when calling valid_bitcast_callback_. PiperOrigin-RevId: 163058665
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier.cc6
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier.h13
2 files changed, 10 insertions, 9 deletions
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index b351861425..4837402c15 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -1488,9 +1488,9 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
// We cannot insert bitcasts if the layouts will not be compatible.
// TODO(b/33178038): Consider inserting a transpose if a bitcast would be
// invalid.
- if (!valid_bitcast_callback_(lhs->shape(), input_shape) ||
- !valid_bitcast_callback_(rhs->shape(), new_filter_shape) ||
- !valid_bitcast_callback_(dot_output_shape, convolution_shape)) {
+ if (!valid_bitcast_callback_(input_shape, lhs->shape()) ||
+ !valid_bitcast_callback_(new_filter_shape, rhs->shape()) ||
+ !valid_bitcast_callback_(convolution_shape, dot_output_shape)) {
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.h b/tensorflow/compiler/xla/service/algebraic_simplifier.h
index f8919f0caa..4295a3227a 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.h
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.h
@@ -26,12 +26,13 @@ namespace xla {
// A pass which performs AlgebraicSimplications.
class AlgebraicSimplifier : public HloPassInterface {
public:
- // Given two shapes, determines if it is valid to bitcast between them after
- // considering platform dependent effects on layout like alignment
- // restrictions.
- // Precondition: the two shapes have layouts, the same number of
- // elements and ShapeUtil::ReshapeIsBitcast returns true.
- using ValidBitcastCallback = std::function<bool(const Shape&, const Shape&)>;
+ // Given shapes 'from_shape' and 'to_shape', determines if it is valid to
+ // bitcast from 'from_shape' to 'to_shape' after considering platform
+ // dependent effects on layout like alignment restrictions. Precondition: the
+ // two shapes have layouts, the same number of elements and
+ // ShapeUtil::ReshapeIsBitcast returns true.
+ using ValidBitcastCallback =
+ std::function<bool(const Shape& from_shape, const Shape& to_shape)>;
// If is_layout_sensitive is true, then the simplifier preserves layout during
// transformation. Otherwise, layout is ignored. If valid_bitcast_callback