author    David Majnemer <majnemer@google.com>  2018-08-29 17:48:09 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-08-29 17:51:38 -0700
commit    7cda8c3a8ad528f2e11fc47b0abf08e01f97af45 (patch)
tree      4abceb5e1c3ca6692f41d53d71f8e4de1c4108fb /tensorflow
parent    e528493c8cde468451ba1b1995e649ebe9c29b02 (diff)
[XLA] Switch to using kIota from TF
We were using a broadcast of a constant instead of the kIota HLO. To make
switching to kIota practical, we need to do a few things first:
- Don't constant fold kIota.
- Don't hoist kIota from loops without good cause.

PiperOrigin-RevId: 210825834
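For orientation, a minimal sketch of the renamed builder entry points (both
signatures appear in the xla_builder diff below; the wrapper function and the
concrete shapes here are hypothetical):

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

// Each call enqueues a single kIota HLO rather than materializing a
// constant vector, which is the point of this change.
void BuildIotaExamples(xla::XlaBuilder* builder) {
  // Rank-1 overload: {0, 1, ..., 9} as S32.
  xla::Iota(builder, xla::S32, 10);
  // Shaped overload: an F32[3,5] array counting along dimension 1, so
  // every row reads {0, 1, 2, 3, 4}.
  xla::Iota(builder, xla::ShapeUtil::MakeShape(xla::F32, {3, 5}),
            /*iota_dimension=*/1);
}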
Diffstat (limited to 'tensorflow')
-rw-r--r--  tensorflow/compiler/xla/client/lib/numeric.cc                       | 47
-rw-r--r--  tensorflow/compiler/xla/client/lib/numeric_test.cc                  | 10
-rw-r--r--  tensorflow/compiler/xla/client/xla_builder.cc                       | 14
-rw-r--r--  tensorflow/compiler/xla/client/xla_builder.h                        | 16
-rw-r--r--  tensorflow/compiler/xla/literal.cc                                  | 55
-rw-r--r--  tensorflow/compiler/xla/literal.h                                   |  3
-rw-r--r--  tensorflow/compiler/xla/service/algebraic_simplifier.cc             |  8
-rw-r--r--  tensorflow/compiler/xla/service/algebraic_simplifier_test.cc        | 17
-rw-r--r--  tensorflow/compiler/xla/service/elemental_ir_emitter.cc             | 31
-rw-r--r--  tensorflow/compiler/xla/service/hlo_constant_folding.cc             |  3
-rw-r--r--  tensorflow/compiler/xla/service/hlo_element_type_converter.cc       |  6
-rw-r--r--  tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc |  1
-rw-r--r--  tensorflow/compiler/xla/tests/iota_test.cc                          |  6
13 files changed, 128 insertions, 89 deletions
diff --git a/tensorflow/compiler/xla/client/lib/numeric.cc b/tensorflow/compiler/xla/client/lib/numeric.cc
index 02bed80162..38e440c68d 100644
--- a/tensorflow/compiler/xla/client/lib/numeric.cc
+++ b/tensorflow/compiler/xla/client/lib/numeric.cc
@@ -23,53 +23,6 @@ limitations under the License.
namespace xla {
-namespace {
-
-template <typename T>
-XlaOp MakeIota(XlaBuilder* builder, int64 size) {
- std::vector<T> values(size);
- for (int64 i = 0; i < size; ++i) {
- values[i] = static_cast<T>(i);
- }
- return ConstantR1<T>(builder, values);
-}
-
-} // namespace
-
-XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size) {
- switch (type) {
- case S8:
- return MakeIota<int8>(builder, size);
- case S16:
- return MakeIota<int16>(builder, size);
- case S32:
- return MakeIota<int32>(builder, size);
- case S64:
- return MakeIota<int64>(builder, size);
- case U8:
- return MakeIota<uint8>(builder, size);
- case U16:
- return MakeIota<uint16>(builder, size);
- case U32:
- return MakeIota<uint32>(builder, size);
- case U64:
- return MakeIota<uint64>(builder, size);
- case BF16:
- return MakeIota<bfloat16>(builder, size);
- case F16:
- return MakeIota<half>(builder, size);
- case F32:
- return MakeIota<float>(builder, size);
- case F64:
- return MakeIota<double>(builder, size);
- case C64:
- return MakeIota<complex64>(builder, size);
- default:
- return builder->ReportError(InvalidArgument(
- "Unimplemented type for Iota: %s.", PrimitiveType_Name(type)));
- }
-}
-
XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64 m,
int64 n) {
auto a = Iota(builder, type, m);
diff --git a/tensorflow/compiler/xla/client/lib/numeric_test.cc b/tensorflow/compiler/xla/client/lib/numeric_test.cc
index 8a96ec68d2..7d6aedd494 100644
--- a/tensorflow/compiler/xla/client/lib/numeric_test.cc
+++ b/tensorflow/compiler/xla/client/lib/numeric_test.cc
@@ -30,16 +30,6 @@ class NumericTest : public ClientLibraryTestBase {
void TestMatrixDiagonal();
};
-// TODO(b/64798317): Delete this test case once xla::IotaGen is converted to
-// xla::Iota. This test is already implemented for xla::IotaGen in
-// xla/tests/iota_test.cc.
-XLA_TEST_F(NumericTest, Iota) {
- XlaBuilder builder(TestName());
- Iota(&builder, S32, 10);
-
- ComputeAndCompareR1<int32>(&builder, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, {});
-}
-
XLA_TEST_F(NumericTest, Triangle) {
XlaBuilder builder(TestName());
Array3D<int32> input(2, 3, 4);
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index ea53287068..531b8dd66b 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -466,7 +466,7 @@ XlaOp XlaBuilder::ConstantLiteral(const LiteralSlice& literal) {
});
}
-XlaOp XlaBuilder::IotaGen(const Shape& shape, int64 iota_dimension) {
+XlaOp XlaBuilder::Iota(const Shape& shape, int64 iota_dimension) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
*instr.mutable_shape() = shape;
@@ -475,8 +475,8 @@ XlaOp XlaBuilder::IotaGen(const Shape& shape, int64 iota_dimension) {
});
}
-XlaOp XlaBuilder::IotaGen(PrimitiveType type, int64 size) {
- return IotaGen(ShapeUtil::MakeShape(type, {size}), /*iota_dimension=*/0);
+XlaOp XlaBuilder::Iota(PrimitiveType type, int64 size) {
+ return Iota(ShapeUtil::MakeShape(type, {size}), /*iota_dimension=*/0);
}
XlaOp XlaBuilder::Call(const XlaComputation& computation,
@@ -3063,12 +3063,12 @@ XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
grad_output, epsilon, feature_index);
}
-XlaOp IotaGen(XlaBuilder* builder, PrimitiveType type, int64 size) {
- return builder->IotaGen(type, size);
+XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size) {
+ return builder->Iota(type, size);
}
-XlaOp IotaGen(XlaBuilder* builder, const Shape& shape, int64 iota_dimension) {
- return builder->IotaGen(shape, iota_dimension);
+XlaOp Iota(XlaBuilder* builder, const Shape& shape, int64 iota_dimension) {
+ return builder->Iota(shape, iota_dimension);
}
} // namespace xla
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index 9b82cc03b3..b9e651f2ae 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -808,10 +808,10 @@ class XlaBuilder {
XlaOp IsFinite(const XlaOp& operand);
// Enqueues an iota operation onto the computation.
- XlaOp IotaGen(const Shape& shape, int64 iota_dimension);
+ XlaOp Iota(const Shape& shape, int64 iota_dimension);
// Enqueues a rank-1 iota operation onto the computation.
- XlaOp IotaGen(PrimitiveType type, int64 size);
+ XlaOp Iota(PrimitiveType type, int64 size);
// Enqueues a convert instruction onto the computation that changes the
// element type of the operand array to primitive_type.
@@ -1320,11 +1320,9 @@ class XlaBuilder {
friend XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
friend XlaOp IsFinite(const XlaOp& operand);
- // TODO(b/64798317): Finish CPU & GPU implementation, then replace xla::Iota
- // in xla/client/lib/numeric.h with this (renamed to xla::Iota).
- friend XlaOp IotaGen(XlaBuilder* builder, const Shape& shape,
- int64 iota_dimension);
- friend XlaOp IotaGen(XlaBuilder* builder, PrimitiveType type, int64 size);
+ friend XlaOp Iota(XlaBuilder* builder, const Shape& shape,
+ int64 iota_dimension);
+ friend XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size);
friend XlaOp ConvertElementType(const XlaOp& operand,
PrimitiveType new_element_type);
friend XlaOp BitcastConvertType(const XlaOp& operand,
@@ -1988,10 +1986,10 @@ XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs,
XlaOp IsFinite(const XlaOp& operand);
// Enqueues an iota operation onto the computation.
-XlaOp IotaGen(XlaBuilder* builder, const Shape& shape, int64 iota_dimension);
+XlaOp Iota(XlaBuilder* builder, const Shape& shape, int64 iota_dimension);
// Enqueues a rank-1 iota operation onto the computation.
-XlaOp IotaGen(XlaBuilder* builder, PrimitiveType type, int64 size);
+XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size);
// Enqueues a convert instruction onto the computation that changes the
// element type of the operand array to primitive_type.
diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc
index 93e808469a..3dd0abee79 100644
--- a/tensorflow/compiler/xla/literal.cc
+++ b/tensorflow/compiler/xla/literal.cc
@@ -1687,6 +1687,61 @@ bool LiteralBase::IsAllFirst() const {
});
}
+bool LiteralBase::IsR1Iota() const {
+ if (!ShapeUtil::IsArray(shape())) {
+ return false;
+ }
+
+ if (ShapeUtil::Rank(shape()) != 1) {
+ return false;
+ }
+
+ auto is_iota_at_idx = [&](const int64 idx) {
+ switch (shape().element_type()) {
+ case U8:
+ return Get<uint8>({idx}) == idx;
+ case U16:
+ return Get<uint16>({idx}) == idx;
+ case U32:
+ return Get<uint32>({idx}) == idx;
+ case U64:
+ return Get<uint64>({idx}) == idx;
+ case S8:
+ return Get<int8>({idx}) == idx;
+ case S16:
+ return Get<int16>({idx}) == idx;
+ case S32:
+ return Get<int32>({idx}) == idx;
+ case S64:
+ return Get<int64>({idx}) == idx;
+ case F32:
+ return Get<float>({idx}) == idx;
+ case F64:
+ return Get<double>({idx}) == idx;
+ case F16:
+ return Get<half>({idx}) == static_cast<half>(idx);
+ case BF16:
+ return Get<bfloat16>({idx}) == static_cast<bfloat16>(idx);
+ case C64:
+ return Get<complex64>({idx}) == complex64(idx, 0.0f);
+ case PRED:
+ return Get<bool>({idx}) == idx;
+ // token, opaque, tuple, etc. are all not iota.
+ default:
+ return false;
+ }
+ };
+
+ const int64 elements = ShapeUtil::ElementsIn(shape());
+ for (int64 idx = 0; idx < elements; ++idx) {
+ if (!is_iota_at_idx(idx)) {
+ return false;
+ }
+ }
+
+ return true;
+}
+
bool LiteralBase::IsZero(tensorflow::gtl::ArraySlice<int64> indices) const {
CHECK(ShapeUtil::IsArray(shape()));
switch (shape().element_type()) {
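A hedged sketch of the new predicate's contract (assuming LiteralUtil::CreateR1
returns a Literal by value, as the test code elsewhere in this patch suggests;
the harness function is illustrative, not part of the change):

#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/platform/logging.h"

void IsR1IotaExamples() {
  // True only for a rank-1 array whose elements read 0, 1, 2, ... in order.
  xla::Literal ramp = xla::LiteralUtil::CreateR1<xla::int32>({0, 1, 2, 3});
  CHECK(ramp.IsR1Iota());

  // Starting anywhere other than zero disqualifies the literal.
  xla::Literal shifted = xla::LiteralUtil::CreateR1<xla::int32>({1, 2, 3, 4});
  CHECK(!shifted.IsR1Iota());
}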
diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h
index aad435ed5b..8370043da1 100644
--- a/tensorflow/compiler/xla/literal.h
+++ b/tensorflow/compiler/xla/literal.h
@@ -195,6 +195,9 @@ class LiteralBase {
// Literal consists entirely of the first element of the literal.
bool IsAllFirst() const;
+ // Literal consists entirely of an iota.
+ bool IsR1Iota() const;
+
// Returns whether this literal is zero at the specified index. This literal
// must be an array with a dense layout.
bool IsZero(tensorflow::gtl::ArraySlice<int64> indices) const;
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 19bb4da9a6..196865f333 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -553,6 +553,14 @@ Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) {
constant,
HloInstruction::CreateBroadcast(constant->shape(), scalar, {}));
}
+
+ // If a literal is an increasing sequence from zero, replace it with an iota.
+ if (ShapeUtil::Rank(constant->shape()) == 1 &&
+ ShapeUtil::ElementsIn(constant->shape()) > 1 &&
+ constant->literal().IsR1Iota()) {
+ return ReplaceWithNewInstruction(
+ constant, HloInstruction::CreateIota(constant->shape(), 0));
+ }
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index 1900a05750..917ed86b69 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -298,6 +298,21 @@ TEST_F(AlgebraicSimplifierTest, ConstantNotToBroadcast) {
EXPECT_THAT(root, op::Constant());
}
+TEST_F(AlgebraicSimplifierTest, IotaToBroadcast) {
+ HloComputation::Builder builder(TestName());
+ builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f})));
+
+ auto computation = module().AddEntryComputation(builder.Build());
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_THAT(root, op::Constant());
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
+ root = computation->root_instruction();
+ EXPECT_THAT(root, op::Iota());
+}
+
// Test that A - 0 is simplified to A
TEST_F(AlgebraicSimplifierTest, SubZero) {
Shape r0f32 = ShapeUtil::MakeShape(F32, {});
@@ -521,7 +536,7 @@ TEST_F(AlgebraicSimplifierTest, DivideByConstant) {
HloInstruction::CreateParameter(0, r1f32, "param0"));
HloInstruction* constant =
builder.AddInstruction(HloInstruction::CreateConstant(
- LiteralUtil::CreateR1<float>({0.f, 1.f, 2.f})));
+ LiteralUtil::CreateR1<float>({1.f, 2.f, 3.f})));
builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kDivide,
param0, constant));
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
index 813e93fafa..def42f9c77 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
@@ -2117,29 +2117,40 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
iota->shape().dimensions(iota->iota_dimension())};
elem_index_linear = elem_index.Linearize(iota_bound, b_);
}
- if (ShapeUtil::ElementIsIntegral(iota->shape())) {
- return b_->CreateIntCast(
+ Shape component_shape =
+ ShapeUtil::ElementIsComplex(iota->shape())
+ ? ShapeUtil::ComplexComponentShape(iota->shape())
+ : iota->shape();
+ PrimitiveType component_element_type = component_shape.element_type();
+ llvm::Value* iota_result;
+ if (ShapeUtil::ElementIsIntegral(component_shape)) {
+ iota_result = b_->CreateIntCast(
elem_index_linear,
- llvm_ir::PrimitiveTypeToIrType(element_type, module_),
+ llvm_ir::PrimitiveTypeToIrType(component_element_type, module_),
/*isSigned=*/false);
} else {
- TF_RET_CHECK(ShapeUtil::ElementIsFloating(iota->shape()))
- << element_type;
+ TF_RET_CHECK(ShapeUtil::ElementIsFloating(component_shape))
+ << component_element_type;
llvm::Type* float_ir_type;
- if (element_type == BF16) {
+ if (component_element_type == BF16) {
float_ir_type = llvm_ir::PrimitiveTypeToIrType(F32, module_);
} else {
float_ir_type =
- llvm_ir::PrimitiveTypeToIrType(element_type, module_);
+ llvm_ir::PrimitiveTypeToIrType(component_element_type, module_);
}
llvm::Value* float_val =
b_->CreateUIToFP(elem_index_linear, float_ir_type);
- if (element_type == BF16) {
- return EmitF32ToBF16(float_val, b_);
+ if (component_element_type == BF16) {
+ iota_result = EmitF32ToBF16(float_val, b_);
} else {
- return float_val;
+ iota_result = float_val;
}
}
+ if (ShapeUtil::ElementIsComplex(iota->shape())) {
+ return EmitComposeComplex(iota, iota_result, nullptr);
+ } else {
+ return iota_result;
+ }
};
case HloOpcode::kSlice:
return [this, hlo, &operand_to_generator](
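The emitter change above extends iota to complex element types by generating
the component value and composing it with a zero imaginary part. A host-side
sketch of the value a C64 iota now produces at linear index i (illustrative
semantics only, not the emitted IR):

#include <complex>
#include <cstdint>

std::complex<float> C64IotaValueAt(int64_t i) {
  // The real component counts up from zero; the imaginary component is
  // composed as zero.
  return std::complex<float>(static_cast<float>(i), 0.0f);
}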
diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc
index 2ed645c3ae..8a45939c61 100644
--- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc
+++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc
@@ -71,7 +71,8 @@ StatusOr<bool> HloConstantFolding::Run(HloModule* module) {
// Broadcasts dramatically increase the size of constants, which is often
// detrimental to performance and memory capacity, so do not fold
// broadcasts.
- if (instruction->opcode() == HloOpcode::kBroadcast) {
+ if (instruction->opcode() == HloOpcode::kBroadcast ||
+ instruction->opcode() == HloOpcode::kIota) {
continue;
}
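Why skipping kIota matters for module size, with a hypothetical shape
(back-of-the-envelope arithmetic only):

#include <cstdint>

// Folding a hypothetical F32[16777216] kIota would trade one HLO
// instruction for a 64 MiB constant literal.
constexpr int64_t kElements = 16777216;    // 2^24 elements
constexpr int64_t kBytes = kElements * 4;  // 4 bytes per float
static_assert(kBytes == 64LL * 1024 * 1024, "64 MiB of folded data");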
diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc
index b9244b8e9e..72006e17e7 100644
--- a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc
+++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc
@@ -151,7 +151,11 @@ StatusOr<bool> HloElementTypeConverter::Run(HloModule* module) {
}
TF_RET_CHECK(hlo->called_computations().empty()) << hlo->ToString();
- if (!HasOperandType(hlo, eliminate_type_)) {
+ bool nullary = hlo->operands().empty();
+ bool wrong_element_type = hlo->shape().element_type() == eliminate_type_;
+ bool should_eliminate_type = (nullary && wrong_element_type) ||
+ HasOperandType(hlo, eliminate_type_);
+ if (!should_eliminate_type) {
// If this CHECK fires, then this was an instruction that does not take
// the elimination type as an operand but it does return it. This pass
// does not have a feature to change the output type in that case, so
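The subtlety here is that nullary instructions such as kIota never trip the
operand-type check. A hedged restatement of the added condition as a
standalone helper (HasOperandType is the pass's existing file-local helper;
this framing is illustrative):

#include "tensorflow/compiler/xla/service/hlo_instruction.h"

// A nullary producer of the eliminated type (e.g. a BF16 kIota) has no
// operands for HasOperandType to inspect, so its result type must be
// checked directly.
bool NullaryProducesType(const xla::HloInstruction* hlo,
                         xla::PrimitiveType eliminate_type) {
  return hlo->operands().empty() &&
         hlo->shape().element_type() == eliminate_type;
}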
diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc
index f4098f28b3..e8fe33e626 100644
--- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc
+++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc
@@ -110,6 +110,7 @@ bool WhileLoopInvariantCodeMotion::NotWorthHoistingIndividually(
case HloOpcode::kBitcast:
case HloOpcode::kBroadcast:
+ case HloOpcode::kIota:
case HloOpcode::kReshape:
case HloOpcode::kReverse:
case HloOpcode::kSlice:
diff --git a/tensorflow/compiler/xla/tests/iota_test.cc b/tensorflow/compiler/xla/tests/iota_test.cc
index 07c3c6b878..310f349592 100644
--- a/tensorflow/compiler/xla/tests/iota_test.cc
+++ b/tensorflow/compiler/xla/tests/iota_test.cc
@@ -39,7 +39,7 @@ TEST_P(IotaR1Test, DoIt) {
const auto element_type = std::get<0>(spec);
const int64 num_elements = std::get<1>(spec);
XlaBuilder builder(TestName() + "_" + PrimitiveType_Name(element_type));
- IotaGen(&builder, element_type, num_elements);
+ Iota(&builder, element_type, num_elements);
if (element_type == F32) {
ComputeAndCompareR1<float>(&builder, GetR1Expected<float>(num_elements), {},
ErrorSpec{0.0001});
@@ -71,7 +71,7 @@ TEST_P(IotaR2Test, DoIt) {
XlaBuilder builder(TestName() + "_" + PrimitiveType_Name(element_type));
std::vector<int64> dimensions = {42};
dimensions.insert(dimensions.begin() + iota_dim, num_elements);
- IotaGen(&builder, ShapeUtil::MakeShape(element_type, dimensions), iota_dim);
+ Iota(&builder, ShapeUtil::MakeShape(element_type, dimensions), iota_dim);
if (primitive_util::IsFloatingPointType(element_type)) {
ComputeAndCompare(&builder, {}, ErrorSpec{0.0001});
} else {
@@ -98,7 +98,7 @@ TEST_P(IotaR3Test, DoIt) {
XlaBuilder builder(TestName() + "_" + PrimitiveType_Name(element_type));
std::vector<int64> dimensions = {42, 19};
dimensions.insert(dimensions.begin() + iota_dim, num_elements);
- IotaGen(&builder, ShapeUtil::MakeShape(element_type, dimensions), iota_dim);
+ Iota(&builder, ShapeUtil::MakeShape(element_type, dimensions), iota_dim);
if (primitive_util::IsFloatingPointType(element_type)) {
ComputeAndCompare(&builder, {}, ErrorSpec{0.0001});
} else {