aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/compiler/xla/client/xla_builder.cc22
-rw-r--r--tensorflow/compiler/xla/client/xla_builder.h14
-rw-r--r--tensorflow/compiler/xla/service/BUILD3
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emitter.cc5
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emitter.h1
-rw-r--r--tensorflow/compiler/xla/service/elemental_ir_emitter.cc46
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter.cc5
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter.h1
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h20
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc29
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.h3
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.cc30
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.h24
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser.cc6
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier.cc17
-rw-r--r--tensorflow/compiler/xla/tests/BUILD8
-rw-r--r--tensorflow/compiler/xla/tests/iota_test.cc117
18 files changed, 276 insertions, 77 deletions
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index 5e92df2d63..819d324927 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -466,6 +466,19 @@ XlaOp XlaBuilder::ConstantLiteral(const LiteralSlice& literal) {
});
}
+XlaOp XlaBuilder::IotaGen(const Shape& shape, int64 iota_dimension) {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ HloInstructionProto instr;
+ *instr.mutable_shape() = shape;
+ instr.add_dimensions(iota_dimension);
+ return AddInstruction(std::move(instr), HloOpcode::kIota);
+ });
+}
+
+XlaOp XlaBuilder::IotaGen(PrimitiveType type, int64 size) {
+ return IotaGen(ShapeUtil::MakeShape(type, {size}), /*iota_dimension=*/0);
+}
+
XlaOp XlaBuilder::Call(const XlaComputation& computation,
tensorflow::gtl::ArraySlice<XlaOp> operands) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
@@ -3023,10 +3036,11 @@ XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
}
XlaOp IotaGen(XlaBuilder* builder, PrimitiveType type, int64 size) {
- HloInstructionProto instr;
- *instr.mutable_shape() = ShapeUtil::MakeShape(type, {size});
- return builder->ReportErrorOrReturn(
- builder->AddInstruction(std::move(instr), HloOpcode::kIota));
+ return builder->IotaGen(type, size);
+}
+
+XlaOp IotaGen(XlaBuilder* builder, const Shape& shape, int64 iota_dimension) {
+ return builder->IotaGen(shape, iota_dimension);
}
} // namespace xla
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index e9d5d3943c..193d8ed071 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -800,6 +800,12 @@ class XlaBuilder {
// entry was NaN.
XlaOp IsFinite(const XlaOp& operand);
+ // Enqueues an iota operation onto the computation.
+ XlaOp IotaGen(const Shape& shape, int64 iota_dimension);
+
+ // Enqueues a rank-1 iota operation onto the computation.
+ XlaOp IotaGen(PrimitiveType type, int64 size);
+
// Enqueues a convert instruction onto the computation that changes the
// element type of the operand array to primitive_type.
XlaOp ConvertElementType(const XlaOp& operand,
@@ -1304,6 +1310,8 @@ class XlaBuilder {
friend XlaOp IsFinite(const XlaOp& operand);
// TODO(b/64798317): Finish CPU & GPU implementation, then replace xla::Iota
// in xla/client/lib/numeric.h with this (renamed to xla::Iota).
+ friend XlaOp IotaGen(XlaBuilder* builder, const Shape& shape,
+ int64 iota_dimension);
friend XlaOp IotaGen(XlaBuilder* builder, PrimitiveType type, int64 size);
friend XlaOp ConvertElementType(const XlaOp& operand,
PrimitiveType new_element_type);
@@ -1960,6 +1968,12 @@ XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs,
// entry was NaN.
XlaOp IsFinite(const XlaOp& operand);
+// Enqueues an iota operation onto the computation.
+XlaOp IotaGen(XlaBuilder* builder, const Shape& shape, int64 iota_dimension);
+
+// Enqueues a rank-1 iota operation onto the computation.
+XlaOp IotaGen(XlaBuilder* builder, PrimitiveType type, int64 size);
+
// Enqueues a convert instruction onto the computation that changes the
// element type of the operand array to primitive_type.
XlaOp ConvertElementType(const XlaOp& operand, PrimitiveType new_element_type);
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 716c75da39..b68785949c 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -230,6 +230,7 @@ cc_library(
hdrs = ["hlo_evaluator.h"],
deps = [
":hlo",
+ ":hlo_casting_utils",
":hlo_query",
":shape_inference",
"//tensorflow/compiler/xla:literal",
@@ -2290,6 +2291,7 @@ cc_library(
":hlo_pass",
":shape_inference",
"//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
@@ -2682,6 +2684,7 @@ cc_library(
hdrs = ["elemental_ir_emitter.h"],
deps = [
":hlo",
+ ":hlo_casting_utils",
":hlo_module_config",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index 903e73f606..460363e18f 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -2437,11 +2437,6 @@ Status IrEmitter::HandleAfterAll(HloInstruction* gen_token) {
return Status::OK();
}
-Status IrEmitter::HandleIota(HloInstruction* iota) {
- // TODO(b/64798317): implement iota on CPU.
- return Unimplemented("Iota is not implemented on CPU.");
-}
-
Status IrEmitter::HandleRng(HloInstruction* rng) {
ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator;
for (const HloInstruction* operand : rng->operands()) {
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
index ec68710d3f..f98891246b 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -157,7 +157,6 @@ class IrEmitter : public DfsHloVisitorWithDefault,
Status HandleConditional(HloInstruction* conditional) override;
Status HandleScatter(HloInstruction* scatter) override;
Status HandleAfterAll(HloInstruction* gen_token) override;
- Status HandleIota(HloInstruction* iota) override;
Status HandleRng(HloInstruction* rng) override;
Status FinishVisit(HloInstruction* root) override;
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
index 61f6055fc9..813e93fafa 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
@@ -28,6 +28,8 @@ limitations under the License.
#include "llvm/IR/Intrinsics.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "tensorflow/compiler/xla/primitive_util.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
@@ -2095,6 +2097,50 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
target_index.SourceIndexOfBroadcast(hlo->shape(), operand->shape(),
hlo->dimensions(), b_));
};
+ case HloOpcode::kIota:
+ return [this, hlo](
+ const IrArray::Index& target_index) -> StatusOr<llvm::Value*> {
+ auto* iota = Cast<HloIotaInstruction>(hlo);
+ PrimitiveType element_type = iota->shape().element_type();
+ IrArray::Index elem_index =
+ ShapeUtil::Rank(iota->shape()) > 1
+ ? target_index.SourceIndexOfBroadcast(
+ iota->shape(),
+ ShapeUtil::MakeShapeWithDescendingLayout(
+ element_type,
+ {iota->shape().dimensions(iota->iota_dimension())}),
+ {iota->iota_dimension()}, b_)
+ : target_index;
+ llvm::Value* elem_index_linear = elem_index.linear();
+ if (elem_index_linear == nullptr) {
+ std::vector<int64> iota_bound = {
+ iota->shape().dimensions(iota->iota_dimension())};
+ elem_index_linear = elem_index.Linearize(iota_bound, b_);
+ }
+ if (ShapeUtil::ElementIsIntegral(iota->shape())) {
+ return b_->CreateIntCast(
+ elem_index_linear,
+ llvm_ir::PrimitiveTypeToIrType(element_type, module_),
+ /*isSigned=*/false);
+ } else {
+ TF_RET_CHECK(ShapeUtil::ElementIsFloating(iota->shape()))
+ << element_type;
+ llvm::Type* float_ir_type;
+ if (element_type == BF16) {
+ float_ir_type = llvm_ir::PrimitiveTypeToIrType(F32, module_);
+ } else {
+ float_ir_type =
+ llvm_ir::PrimitiveTypeToIrType(element_type, module_);
+ }
+ llvm::Value* float_val =
+ b_->CreateUIToFP(elem_index_linear, float_ir_type);
+ if (element_type == BF16) {
+ return EmitF32ToBF16(float_val, b_);
+ } else {
+ return float_val;
+ }
+ }
+ };
case HloOpcode::kSlice:
return [this, hlo, &operand_to_generator](
const IrArray::Index& index) -> StatusOr<llvm::Value*> {
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
index a620cebe04..bdf6aadde6 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
@@ -746,11 +746,6 @@ Status IrEmitter::HandleBatchNormGrad(HloInstruction*) {
"to a cudnn CustomCall using CudnnBatchNormRewriter.");
}
-Status IrEmitter::HandleIota(HloInstruction*) {
- // TODO(b/64798317): implement iota on GPU.
- return Unimplemented("Iota is not implemented on GPU.");
-}
-
StatusOr<llvm::Value*> IrEmitter::ComputeNestedElement(
const HloComputation& computation,
tensorflow::gtl::ArraySlice<llvm::Value*> parameter_elements) {
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h
index e096a07704..3673b9f58d 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h
@@ -97,7 +97,6 @@ class IrEmitter : public DfsHloVisitorWithDefault,
Status HandleBatchNormInference(HloInstruction* batch_norm) override;
Status HandleBatchNormTraining(HloInstruction* batch_norm) override;
Status HandleBatchNormGrad(HloInstruction* batch_norm) override;
- Status HandleIota(HloInstruction* iota) override;
Status FinishVisit(HloInstruction* root) override { return Status::OK(); }
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
index b6566ebefe..f682e69ee9 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
@@ -21,7 +21,9 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/core/lib/core/casts.h"
@@ -2493,11 +2495,21 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
std::is_same<NativeT, float>::value ||
std::is_same<NativeT, int32>::value ||
std::is_same<NativeT, uint32>::value>::type* = nullptr>
- Status HandleIota(HloInstruction* iota) {
- auto result = absl::make_unique<Literal>(iota->shape());
- auto data = result->data<ReturnT>();
+ Status HandleIota(HloInstruction* instruction) {
+ auto* iota = Cast<HloIotaInstruction>(instruction);
+ std::vector<NativeT> data(iota->shape().dimensions(iota->iota_dimension()));
std::iota(data.begin(), data.end(), 0);
- parent_->evaluated_[iota] = std::move(result);
+ auto result = LiteralUtil::CreateR1<NativeT>(data);
+
+ if (ShapeUtil::Rank(iota->shape()) > 1) {
+ TF_ASSIGN_OR_RETURN(
+ parent_->evaluated_[iota],
+ result->Broadcast(iota->shape(), {iota->iota_dimension()}));
+ } else {
+ TF_RET_CHECK(ShapeUtil::Rank(iota->shape()) == 1);
+ parent_->evaluated_[iota] = std::move(result);
+ }
+
return Status::OK();
}
template <typename NativeT,
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index c77699a06f..ed4e159910 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -428,6 +428,12 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
computations(0), *scatter_dimension_numbers);
break;
}
+ case HloOpcode::kIota:
+ TF_RET_CHECK(proto.dimensions_size() <= 1)
+ << "Iota instruction should have at most 1 dimension but sees "
+ << proto.dimensions_size();
+ instruction = CreateIota(proto.shape(), proto.dimensions(0));
+ break;
default: {
instruction = absl::WrapUnique(new HloInstruction(opcode, proto.shape()));
for (const int64 operand_id : proto.operand_ids()) {
@@ -490,8 +496,8 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateIota(
- const Shape& shape) {
- return absl::WrapUnique(new HloInstruction(HloOpcode::kIota, shape));
+ const Shape& shape, int64 iota_dimension) {
+ return absl::make_unique<HloIotaInstruction>(shape, iota_dimension);
}
/* static */ std::unique_ptr<HloInstruction>
@@ -2053,13 +2059,12 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
extra.push_back(
StrCat("to_apply=", PrintName(to_apply()->name(), options)));
} else if (!called_computations().empty()) {
- extra.push_back(
- StrCat("calls=",
- StrJoin(called_computations(), ", ",
- [&](string* out, const HloComputation* computation) {
- StrAppend(out,
- PrintName(computation->name(), options));
- })));
+ extra.push_back(StrCat(
+ "calls=",
+ StrJoin(called_computations(), ", ",
+ [&](string* out, const HloComputation* computation) {
+ StrAppend(out, PrintName(computation->name(), options));
+ })));
}
} else if (options.print_subcomputation_mode() ==
HloPrintOptions::PrintSubcomputationMode::kFullBodies) {
@@ -3218,12 +3223,12 @@ HloInstruction::source_target_pairs() const {
}
string HloInstruction::cross_replica_sum_barrier() const {
- return Cast<HloAllReduceInstruction>(this)->cross_replica_sum_barrier();
+ return Cast<HloAllReduceInstruction>(this)->cross_replica_sum_barrier();
}
void HloInstruction::set_cross_replica_sum_barrier(const string& barrier) {
- return Cast<HloAllReduceInstruction>(this)->set_cross_replica_sum_barrier(
- barrier);
+ return Cast<HloAllReduceInstruction>(this)->set_cross_replica_sum_barrier(
+ barrier);
}
absl::optional<int64> HloInstruction::all_reduce_id() const {
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index b393635e9d..4a424cebc0 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -350,7 +350,8 @@ class HloInstruction {
std::unique_ptr<Literal> literal);
// Creates an Iota instruction.
- static std::unique_ptr<HloInstruction> CreateIota(const Shape& shape);
+ static std::unique_ptr<HloInstruction> CreateIota(const Shape& shape,
+ int64 iota_dimension);
// Creates a get tuple element instruction.
static std::unique_ptr<HloInstruction> CreateGetTupleElement(
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index b93c758937..ffc74cfedd 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -2155,4 +2155,34 @@ std::unique_ptr<HloInstruction> HloScatterInstruction::CloneWithNewOperandsImpl(
scatter_dimension_numbers());
}
+HloIotaInstruction::HloIotaInstruction(const Shape& shape, int64 iota_dimension)
+ : HloInstruction(HloOpcode::kIota, shape),
+ iota_dimension_(iota_dimension) {}
+
+HloInstructionProto HloIotaInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ proto.add_dimensions(iota_dimension());
+ return proto;
+}
+
+std::vector<string> HloIotaInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const {
+ return {StrCat("iota_dimension=", iota_dimension())};
+}
+
+bool HloIotaInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const {
+ const auto& casted_other = static_cast<const HloIotaInstruction&>(other);
+ return iota_dimension() == casted_other.iota_dimension();
+}
+
+std::unique_ptr<HloInstruction> HloIotaInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const {
+ return absl::make_unique<HloIotaInstruction>(shape, iota_dimension());
+}
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index 29b187300d..ee6e337b6a 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -1279,6 +1279,30 @@ class HloScatterInstruction : public HloInstruction {
std::unique_ptr<ScatterDimensionNumbers> scatter_dimension_numbers_;
};
+class HloIotaInstruction : public HloInstruction {
+ public:
+ explicit HloIotaInstruction(const Shape& shape, int64 iota_dimension);
+ // Returns the dimension sizes or numbers associated with this instruction.
+ int64 iota_dimension() const { return iota_dimension_; }
+ // Returns a serialized representation of this instruction.
+ HloInstructionProto ToProto() const override;
+
+ private:
+ std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const override;
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+
+ const int64 iota_dimension_;
+};
+
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index c7a766f4e0..eae4508b24 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -562,11 +562,15 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
break;
}
case HloOpcode::kIota: {
+ optional<tensorflow::int64> iota_dimension;
+ attrs["iota_dimension"] = {/*required=*/true, AttrTy::kInt64,
+ &iota_dimension};
if (!ParseOperands(&operands, /*expected_size=*/0) ||
!ParseAttributes(attrs)) {
return false;
}
- instruction = builder->AddInstruction(HloInstruction::CreateIota(shape));
+ instruction = builder->AddInstruction(
+ HloInstruction::CreateIota(shape, *iota_dimension));
break;
}
// Unary ops.
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index b1ef288b8e..ba07ec432e 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -1116,7 +1116,7 @@ ENTRY CollectivePermute {
R"(HloModule iota
ENTRY Iota {
- ROOT iota = f32[100]{0} iota()
+ ROOT iota = f32[100]{0} iota(), iota_dimension=0
}
)"
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 0ed2c3b449..f1b29c2559 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
@@ -265,10 +266,18 @@ Status ShapeVerifier::HandleConstant(HloInstruction* constant) {
return CheckShape(constant, constant->literal().shape());
}
-Status ShapeVerifier::HandleIota(HloInstruction* iota) {
- return ShapeUtil::Rank(iota->shape()) == 1
- ? Status::OK()
- : InternalError("Iota only supports arrays of rank 1.");
+Status ShapeVerifier::HandleIota(HloInstruction* instruction) {
+ auto* iota = Cast<HloIotaInstruction>(instruction);
+ const int64 rank = ShapeUtil::Rank(iota->shape());
+ if (rank == 0) {
+ return InternalError("Iota does not support scalars.");
+ }
+ int64 iota_dimension = iota->iota_dimension();
+ if (iota_dimension >= rank) {
+ return InternalError(
+ "The iota dimension cannot go beyond the operation rank.");
+ }
+ return Status::OK();
}
Status ShapeVerifier::HandleGetTupleElement(HloInstruction* get_tuple_element) {
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index d5e3b747e7..a0829b0d02 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -2131,19 +2131,13 @@ xla_test(
xla_test(
name = "iota_test",
srcs = ["iota_test.cc"],
- blacklisted_backends = [
- "cpu",
- "gpu",
- ],
+ shard_count = 30,
tags = [
"enable_for_xla_interpreter",
],
deps = [
":client_library_test_base",
- ":literal_test_util",
":xla_internal_test_main",
- "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/core:lib",
- "//tensorflow/core:test",
],
)
diff --git a/tensorflow/compiler/xla/tests/iota_test.cc b/tensorflow/compiler/xla/tests/iota_test.cc
index 17ac95ae01..07c3c6b878 100644
--- a/tensorflow/compiler/xla/tests/iota_test.cc
+++ b/tensorflow/compiler/xla/tests/iota_test.cc
@@ -23,40 +23,95 @@ limitations under the License.
namespace xla {
namespace {
-class IotaTest : public ClientLibraryTestBase {
- public:
- explicit IotaTest(se::Platform* platform = nullptr)
- : ClientLibraryTestBase(platform) {}
- template <typename T>
- std::vector<T> GetExpected(const int64 num_elements) {
- std::vector<T> result(num_elements);
- std::iota(result.begin(), result.end(), 0);
- return result;
+template <typename T>
+std::vector<T> GetR1Expected(const int64 num_elements) {
+ std::vector<T> result(num_elements);
+ std::iota(result.begin(), result.end(), 0);
+ return result;
+}
+
+class IotaR1Test
+ : public ClientLibraryTestBase,
+ public ::testing::WithParamInterface<std::tuple<PrimitiveType, int>> {};
+
+TEST_P(IotaR1Test, DoIt) {
+ const auto& spec = GetParam();
+ const auto element_type = std::get<0>(spec);
+ const int64 num_elements = std::get<1>(spec);
+ XlaBuilder builder(TestName() + "_" + PrimitiveType_Name(element_type));
+ IotaGen(&builder, element_type, num_elements);
+ if (element_type == F32) {
+ ComputeAndCompareR1<float>(&builder, GetR1Expected<float>(num_elements), {},
+ ErrorSpec{0.0001});
+ } else if (element_type == U32) {
+ ComputeAndCompareR1<uint32>(&builder, GetR1Expected<uint32>(num_elements),
+ {});
+ } else {
+ CHECK_EQ(element_type, S32);
+ ComputeAndCompareR1<int32>(&builder, GetR1Expected<int32>(num_elements),
+ {});
}
-};
-
-XLA_TEST_F(IotaTest, SimpleR1) {
- for (int num_elements = 1; num_elements < 10000001; num_elements *= 10) {
- {
- XlaBuilder builder(TestName() + "_f32");
- IotaGen(&builder, F32, num_elements);
- ComputeAndCompareR1<float>(&builder, GetExpected<float>(num_elements), {},
- ErrorSpec{0.0001});
- }
- {
- XlaBuilder builder(TestName() + "_u32");
- IotaGen(&builder, U32, num_elements);
- ComputeAndCompareR1<uint32>(&builder, GetExpected<uint32>(num_elements),
- {});
- }
- {
- XlaBuilder builder(TestName() + "_s32");
- IotaGen(&builder, S32, num_elements);
- ComputeAndCompareR1<int32>(&builder, GetExpected<int32>(num_elements),
- {});
- }
+}
+
+INSTANTIATE_TEST_CASE_P(IotaR1TestInstantiation, IotaR1Test,
+ ::testing::Combine(::testing::Values(F32, U32, S32),
+ ::testing::Range(/*start=*/10,
+ /*end=*/10001,
+ /*step=*/10)));
+
+class IotaR2Test : public ClientLibraryTestBase,
+ public ::testing::WithParamInterface<
+ std::tuple<PrimitiveType, int, int>> {};
+
+TEST_P(IotaR2Test, DoIt) {
+ const auto& spec = GetParam();
+ const auto element_type = std::get<0>(spec);
+ const int64 num_elements = std::get<1>(spec);
+ const int64 iota_dim = std::get<2>(spec);
+ XlaBuilder builder(TestName() + "_" + PrimitiveType_Name(element_type));
+ std::vector<int64> dimensions = {42};
+ dimensions.insert(dimensions.begin() + iota_dim, num_elements);
+ IotaGen(&builder, ShapeUtil::MakeShape(element_type, dimensions), iota_dim);
+ if (primitive_util::IsFloatingPointType(element_type)) {
+ ComputeAndCompare(&builder, {}, ErrorSpec{0.0001});
+ } else {
+ ComputeAndCompare(&builder, {});
}
}
+INSTANTIATE_TEST_CASE_P(IotaR2TestInstantiation, IotaR2Test,
+ ::testing::Combine(::testing::Values(F32, S32),
+ ::testing::Range(/*start=*/10,
+ /*end=*/1001,
+ /*step=*/10),
+ ::testing::Values(0, 1)));
+
+class IotaR3Test : public ClientLibraryTestBase,
+ public ::testing::WithParamInterface<
+ std::tuple<PrimitiveType, int, int>> {};
+
+TEST_P(IotaR3Test, DoIt) {
+ const auto& spec = GetParam();
+ const auto element_type = std::get<0>(spec);
+ const int64 num_elements = std::get<1>(spec);
+ const int64 iota_dim = std::get<2>(spec);
+ XlaBuilder builder(TestName() + "_" + PrimitiveType_Name(element_type));
+ std::vector<int64> dimensions = {42, 19};
+ dimensions.insert(dimensions.begin() + iota_dim, num_elements);
+ IotaGen(&builder, ShapeUtil::MakeShape(element_type, dimensions), iota_dim);
+ if (primitive_util::IsFloatingPointType(element_type)) {
+ ComputeAndCompare(&builder, {}, ErrorSpec{0.0001});
+ } else {
+ ComputeAndCompare(&builder, {});
+ }
+}
+
+INSTANTIATE_TEST_CASE_P(IotaR3TestInstantiation, IotaR3Test,
+ ::testing::Combine(::testing::Values(F32, S32),
+ ::testing::Range(/*start=*/10,
+ /*end=*/1001,
+ /*step=*/10),
+ ::testing::Values(0, 1, 2)));
+
} // namespace
} // namespace xla