-rw-r--r--  tensorflow/compiler/xla/service/hlo_verifier.cc | 102
-rw-r--r--  tensorflow/compiler/xla/service/hlo_verifier.h | 7
-rw-r--r--  tensorflow/compiler/xla/service/hlo_verifier_test.cc | 103
-rw-r--r--  tensorflow/compiler/xla/service/reshape_mover_test.cc | 10
-rw-r--r--  tensorflow/docs_src/performance/xla/operation_semantics.md | 24
5 files changed, 213 insertions(+), 33 deletions(-)
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 3fae61f704..e7674f3ddd 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -194,7 +194,67 @@ Status ShapeVerifier::HandleHostCompute(HloInstruction*) {
return Status::OK();
}
-Status ShapeVerifier::HandleRng(HloInstruction*) { return Status::OK(); }
+bool ShapeVerifier::HasCompatibleElementTypes(const Shape& shape_0,
+ const Shape& shape_1,
+ const Shape& result_shape) {
+ return ShapeUtil::SameElementType(shape_0, shape_1) &&
+ (ShapeUtil::SameElementType(shape_0, result_shape) ||
+ (allow_mixed_precision_ &&
+ ShapeUtil::SameElementTypeIgnoringFpPrecision(shape_0,
+ result_shape)));
+}
+
+Status ShapeVerifier::HandleRng(HloInstruction* instruction) {
+ if (instruction->operand_count() != 2) {
+ return InternalError("Expected two operands for Rng instruction: %s",
+ instruction->ToString().c_str());
+ }
+
+ const Shape& shape_0 = instruction->operand(0)->shape();
+ const Shape& shape_1 = instruction->operand(1)->shape();
+ if (!ShapeUtil::IsScalar(shape_0) || !ShapeUtil::IsScalar(shape_1)) {
+ return InternalError(
+ "Expected scalar types for the two operands of Rng instruction: %s",
+ instruction->ToString().c_str());
+ }
+
+ if (!HasCompatibleElementTypes(shape_0, shape_1, instruction->shape())) {
+ return InternalError(
+ "Expected compatible element types for the result and the two operands"
+ " of Rng instruction: %s",
+ instruction->ToString().c_str());
+ }
+
+ PrimitiveType element_type = shape_0.element_type();
+ switch (instruction->random_distribution()) {
+ case RNG_UNIFORM:
+ if (!primitive_util::IsFloatingPointType(element_type) &&
+ !primitive_util::IsIntegralType(element_type) &&
+ element_type != PRED) {
+ return InternalError(
+ "Element type not supported."
+ " Expected element to be of floating point type, integral type or"
+ " predicate type for RngUniform: %s",
+ instruction->ToString().c_str());
+ }
+ break;
+
+ case RNG_NORMAL:
+ if (!primitive_util::IsFloatingPointType(element_type)) {
+ return InternalError(
+ "Element type not supported."
+ " Expected element to be FloatingPointType for RngNormal: %s",
+ instruction->ToString().c_str());
+ }
+ break;
+ default:
+ return InternalError(
+ "Invalid Rng distribution %s",
+ RandomDistribution_Name(instruction->random_distribution()).c_str());
+ }
+
+ return Status::OK();
+}
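A minimal sketch of an Rng instruction that passes these new checks: two scalar
F32 operands and an F32 result, mirroring the HloInstruction::CreateRng usage
in the reshape_mover_test.cc hunk further down (the surrounding builder setup
here is illustrative and not part of this change):

  // Builds rng.0 = f32[10] rng(f32[] 0.0, f32[] 1.0), distribution=rng_uniform.
  HloComputation::Builder builder("rng_example");
  HloInstruction* lo = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
  HloInstruction* hi = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
  builder.AddInstruction(HloInstruction::CreateRng(
      ShapeUtil::MakeShape(F32, {10}), RandomDistribution::RNG_UNIFORM,
      {lo, hi}));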
Status ShapeVerifier::HandleReverse(HloInstruction* reverse) {
return CheckShape(
@@ -463,9 +523,9 @@ namespace {
// inputs.
Status CheckMixedPrecisionOperands(const HloInstruction* instruction) {
switch (instruction->opcode()) {
- // White list the following opcodes for mixed-precision check, because they
- // involve data pass through or grouping via tuples, where the precisions
- // of buffers can be different.
+ // White list the following opcodes for mixed-precision check, because
+ // they involve data pass through or grouping via tuples, where the
+ // precisions of buffers can be different.
case HloOpcode::kCall:
case HloOpcode::kConditional:
case HloOpcode::kConstant:
@@ -647,7 +707,8 @@ string ComputationsToString(
// Verifies various invariants about the structure of the HLO:
//
-// (1) each instruction has a non-null parent() set to the HloComputation which
+// (1) each instruction has a non-null parent() set to the HloComputation
+// which
// contains it.
//
// (2) each computation has a non-null parent() set to the HloModule which
@@ -681,9 +742,9 @@ Status VerifyHloStructure(HloModule* module) {
}
// Check that operands are in the same computation separately from verifying
- // parent() correctness so conditions like a null HloInstruction::parent() are
- // identified and reported explicitly above rather than reporting a mismatched
- // operand.
+ // parent() correctness so conditions like a null HloInstruction::parent()
+ // are identified and reported explicitly above rather than reporting a
+ // mismatched operand.
for (const HloComputation* computation : module->computations()) {
for (const HloInstruction* instruction : computation->instructions()) {
for (int i = 0; i < instruction->operand_count(); ++i) {
@@ -707,13 +768,14 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const {
HloComputation* fused_computation = fusion->fused_instructions_computation();
if (fusion != fused_computation->FusionInstruction()) {
return InternalError(
- "Instruction of fused computation does not match expected instruction "
+ "Instruction of fused computation does not match expected "
+ "instruction "
"%s.",
fusion->ToString().c_str());
}
- // Fused root instruction and fused parameters must all be owned by the fusion
- // computation.
+ // Fused root instruction and fused parameters must all be owned by the
+ // fusion computation.
bool root_owned = false;
const std::vector<HloInstruction*>& fused_parameters =
fusion->fused_parameters();
@@ -755,8 +817,8 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const {
fusion->ToString().c_str());
}
- // All uses of fused instructions must be in the fusion computation, and every
- // non-root instruction must have at least one use.
+ // All uses of fused instructions must be in the fusion computation, and
+ // every non-root instruction must have at least one use.
for (auto* instruction :
fusion->fused_instructions_computation()->instructions()) {
if (instruction != fused_root) {
@@ -800,7 +862,8 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const {
if (!ShapeUtil::Compatible(fused_param->shape(),
fusion->operand(param_no)->shape())) {
return InternalError(
- "Shape mismatch between parameter number %lld and its operand in %s.",
+ "Shape mismatch between parameter number %lld and its operand in "
+ "%s.",
param_no, fusion->ToString().c_str());
}
}
@@ -918,8 +981,9 @@ Status CheckSameChannel(const HloInstruction* instr1,
return Status::OK();
}
-// Checks if the given two instructions have the same is_host_transfer attribute
-// value. Intsructions must be send/recv instructions or their 'done' variant.
+// Checks if the given two instructions have the same is_host_transfer
+// attribute value. Instructions must be send/recv instructions or their
+// 'done' variant.
Status CheckSameIsHostTransfer(const HloInstruction* instr1,
const HloInstruction* instr2) {
const HloSendRecvInstruction* send_recv1 =
@@ -930,7 +994,8 @@ Status CheckSameIsHostTransfer(const HloInstruction* instr1,
TF_RET_CHECK(send_recv2 != nullptr);
if (send_recv1->is_host_transfer() != send_recv2->is_host_transfer()) {
return InternalError(
- "Expected instructions to have the same is-host-transfer property: %s, "
+ "Expected instructions to have the same is-host-transfer property: "
+ "%s, "
"%s ",
instr1->ToString().c_str(), instr2->ToString().c_str());
}
@@ -949,7 +1014,8 @@ Status VerifySendsAndRecvs(const HloModule& module) {
host_channels.insert({sendrecv->channel_id(), sendrecv});
if (!it_inserted.second) {
return FailedPrecondition(
- "Channel %lld is used for multiple host send/recv instructions: %s "
+ "Channel %lld is used for multiple host send/recv instructions: "
+ "%s "
"and "
"%s",
sendrecv->channel_id(), sendrecv->ToString().c_str(),
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h
index 5a56a44f35..c942fab08e 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.h
+++ b/tensorflow/compiler/xla/service/hlo_verifier.h
@@ -106,6 +106,13 @@ class ShapeVerifier : public DfsHloVisitor {
Status CheckVariadicShape(const HloInstruction* instruction);
private:
+ // Return true if the shapes of the two operands have the same element type,
+ // and the result shape either has the same element type as the operand
+ // shapes or mixed precision is allowed and the result shape and the operand
+ // shapes have floating point element types.
+ bool HasCompatibleElementTypes(const Shape& shape_0, const Shape& shape_1,
+ const Shape& result_shape);
+
// Whether the inputs and output of an instruction can contain both F32s and
// BF16s. Tuples that include both F32s and BF16s are allowed regardless of
// this flag.
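Concretely, the rule the new private helper encodes can be read off the .cc
change above; a standalone sketch under the same ShapeUtil calls (the
free-function form and name here are illustrative only):

  // Operand element types must match exactly; the result may differ only in
  // floating point precision, and only when mixed precision is allowed.
  bool CompatibleForRng(const Shape& a, const Shape& b, const Shape& result,
                        bool allow_mixed_precision) {
    if (!ShapeUtil::SameElementType(a, b)) return false;      // f32 vs f16 operands: reject
    if (ShapeUtil::SameElementType(a, result)) return true;   // f32 operands, f32 result: accept
    return allow_mixed_precision &&
           ShapeUtil::SameElementTypeIgnoringFpPrecision(a, result);  // f32 operands, f16 result
  }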
diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc
index 04c6ba3eeb..d764964f3c 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc
@@ -34,7 +34,17 @@ namespace {
using ::testing::HasSubstr;
-using HloVerifierTest = HloTestBase;
+class HloVerifierTest : public HloTestBase {
+ public:
+ HloVerifierTest()
+ : HloTestBase(/*allow_mixed_precision_in_hlo_verifier=*/false) {}
+};
+
+class HloVerifierTestAllowMixedPrecision : public HloTestBase {
+ public:
+ HloVerifierTestAllowMixedPrecision()
+ : HloTestBase(/*allow_mixed_precision_in_hlo_verifier=*/true) {}
+};
TEST_F(HloVerifierTest, NullInstructionParent) {
HloComputation::Builder builder(TestName());
@@ -174,5 +184,96 @@ ENTRY entry {
HasSubstr("shape does not match parameter"));
}
+TEST_F(HloVerifierTest, RngOpnd0NotScalar) {
+ const char* const hlo_string = R"(
+ HloModule Module
+
+ ENTRY RngOpnd0NotScalar {
+ constant.0 = f32[] constant(0)
+ constant.1 = f16[2] constant({1, 3})
+ ROOT rng.0 = f32[10]{0} rng(f32[] constant.0, f16[2] constant.1),
+ distribution=rng_uniform
+ }
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string));
+
+ auto status = verifier().Run(module.get()).status();
+ ASSERT_FALSE(status.ok());
+ EXPECT_THAT(status.error_message(), HasSubstr("Expected scalar type"));
+}
+
+TEST_F(HloVerifierTest, RngOperandElementTypesDoNotMatch) {
+ const char* const hlo_string = R"(
+ HloModule Module
+
+ ENTRY RngOperandElementTypesNotMatch {
+ constant.0 = f32[] constant(0)
+ constant.1 = f16[] constant(1)
+ ROOT rng.0 = f32[10]{0} rng(f32[] constant.0, f16[] constant.1),
+ distribution=rng_normal
+ }
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string));
+
+ auto status = verifier().Run(module.get()).status();
+ ASSERT_FALSE(status.ok());
+ EXPECT_THAT(status.error_message(),
+ HasSubstr("Expected compatible element types"));
+}
+
+TEST_F(HloVerifierTest, RngMixedPrecisionNotAllowed) {
+ const char* const hlo_string = R"(
+ HloModule Module
+
+ ENTRY RngResultElementTypeNotMatch {
+ constant.0 = f32[] constant(0)
+ constant.1 = f32[] constant(1)
+ ROOT rng.0 = f16[10]{0} rng(f32[] constant.0, f32[] constant.1),
+ distribution=rng_normal
+ }
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string));
+
+ auto status = verifier().Run(module.get()).status();
+ ASSERT_FALSE(status.ok());
+ EXPECT_THAT(status.error_message(),
+ HasSubstr("Expected compatible element types"));
+}
+
+TEST_F(HloVerifierTestAllowMixedPrecision, RngMixedPrecisionAllowed) {
+ const char* const hlo_string = R"(
+ HloModule Module
+
+ ENTRY RngResultElementTypeNotMatch {
+ constant.0 = f32[] constant(0)
+ constant.1 = f32[] constant(1)
+ ROOT rng.0 = f16[10]{0} rng(f32[] constant.0, f32[] constant.1),
+ distribution=rng_normal
+ }
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string));
+
+ auto status = verifier().Run(module.get()).status();
+ ASSERT_TRUE(status.ok());
+}
+
+TEST_F(HloVerifierTest, RngElementTypeNotSupported) {
+ const char* const hlo_string = R"(
+ HloModule Module
+
+ ENTRY RngElementTypeNotSupported {
+ constant.0 = s32[] constant(0)
+ constant.1 = s32[] constant(1)
+ ROOT rng.0 = s32[10]{0} rng(s32[] constant.0, s32[] constant.1),
+ distribution=rng_normal
+ }
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string));
+
+ auto status = verifier().Run(module.get()).status();
+ ASSERT_FALSE(status.ok());
+ EXPECT_THAT(status.error_message(), HasSubstr("Element type not supported"));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/reshape_mover_test.cc b/tensorflow/compiler/xla/service/reshape_mover_test.cc
index ad3b662c20..ccb9fb3e3a 100644
--- a/tensorflow/compiler/xla/service/reshape_mover_test.cc
+++ b/tensorflow/compiler/xla/service/reshape_mover_test.cc
@@ -76,9 +76,13 @@ TEST_F(ReshapeMoverTest, ReshapesWithDifferentInputShapesNotMoved) {
TEST_F(ReshapeMoverTest, 1ConstantAnd1ReshapesOnRngNotMoved) {
HloComputation::Builder builder(TestName());
auto root_shape = ShapeUtil::MakeShape(F32, {8, 7});
- auto rng0 = builder.AddInstruction(
- HloInstruction::CreateRng(ShapeUtil::MakeShape(F32, {1, 8, 1, 7, 1}),
- RandomDistribution::RNG_UNIFORM, {}));
+ auto rng0 = builder.AddInstruction(HloInstruction::CreateRng(
+ ShapeUtil::MakeShape(F32, {1, 8, 1, 7, 1}),
+ RandomDistribution::RNG_UNIFORM,
+ {builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f))),
+ builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR0<float>(1.0f)))}));
auto reshape0 =
builder.AddInstruction(HloInstruction::CreateReshape(root_shape, rng0));
diff --git a/tensorflow/docs_src/performance/xla/operation_semantics.md b/tensorflow/docs_src/performance/xla/operation_semantics.md
index 02af71f8a3..fad9fd57f1 100644
--- a/tensorflow/docs_src/performance/xla/operation_semantics.md
+++ b/tensorflow/docs_src/performance/xla/operation_semantics.md
@@ -1877,19 +1877,19 @@ See also
[`XlaBuilder::RngNormal`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
Constructs an output of a given shape with random numbers generated following
-the $$N(\mu, \sigma)$$ normal distribution. The parameters `mu` and `sigma`, and
-output shape have to have elemental type F32. The parameters furthermore have to
-be scalar valued.
+the $$N(\mu, \sigma)$$ normal distribution. The parameters $$\mu$$ and
+$$\sigma$$, and the output shape have to have a floating point elemental type.
+The parameters furthermore have to be scalar valued.
-<b>`RngNormal(mean, sigma, shape)`</b>
+<b>`RngNormal(mu, sigma, shape)`</b>
| Arguments | Type | Semantics |
| --------- | ------- | --------------------------------------------------- |
-| `mu` | `XlaOp` | Scalar of type F32 specifying mean of generated |
-: : : numbers :
-| `sigma` | `XlaOp` | Scalar of type F32 specifying standard deviation of |
+| `mu` | `XlaOp` | Scalar of type T specifying mean of generated |
+: : : numbers :
+| `sigma` | `XlaOp` | Scalar of type T specifying standard deviation of |
: : : generated numbers :
-| `shape` | `Shape` | Output shape of type F32 |
+| `shape` | `Shape` | Output shape of type T |
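A minimal usage sketch, assuming the XlaBuilder member methods referenced
above (helper names such as ConstantR0 come from the client API of this
TensorFlow version and are not part of this change): draw ten F32 samples
from $$N(0, 1)$$.

  XlaBuilder builder("rng_normal_example");
  auto mu = builder.ConstantR0<float>(0.0f);     // scalar mean
  auto sigma = builder.ConstantR0<float>(1.0f);  // scalar standard deviation
  auto samples = builder.RngNormal(mu, sigma, ShapeUtil::MakeShape(F32, {10}));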
## RngUniform
@@ -1898,9 +1898,11 @@ See also
Constructs an output of a given shape with random numbers generated following
the uniform distribution over the interval $$[a,b)$$. The parameters and output
-shape may be either F32, S32 or U32, but the types have to be consistent.
-Furthermore, the parameters need to be scalar valued. If $$b <= a$$ the result
-is implementation-defined.
+element type have to be a boolean type, an integral type, or a floating point
+type, and the types have to be consistent. The CPU and GPU backends currently
+only support F64, F32, F16, BF16, S64, U64, S32 and U32. Furthermore, the
+parameters need to be scalar valued. If $$b \leq a$$ the result is
+implementation-defined.
<b>`RngUniform(a, b, shape)`</b>
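A corresponding sketch for RngUniform under the same XlaBuilder assumptions,
this time with one of the newly documented integral types: ten S32 samples
uniform in $$[0, 100)$$.

  XlaBuilder builder("rng_uniform_example");
  auto a = builder.ConstantR0<int32>(0);    // inclusive lower bound
  auto b = builder.ConstantR0<int32>(100);  // exclusive upper bound
  auto samples = builder.RngUniform(a, b, ShapeUtil::MakeShape(S32, {10}));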