 tensorflow/compiler/xla/service/hlo_verifier.cc            | 102
 tensorflow/compiler/xla/service/hlo_verifier.h             |   7
 tensorflow/compiler/xla/service/hlo_verifier_test.cc       | 103
 tensorflow/compiler/xla/service/reshape_mover_test.cc      |  10
 tensorflow/docs_src/performance/xla/operation_semantics.md |  24
 5 files changed, 213 insertions(+), 33 deletions(-)
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 3fae61f704..e7674f3ddd 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -194,7 +194,67 @@ Status ShapeVerifier::HandleHostCompute(HloInstruction*) {
   return Status::OK();
 }
 
-Status ShapeVerifier::HandleRng(HloInstruction*) { return Status::OK(); }
+bool ShapeVerifier::HasCompatibleElementTypes(const Shape& shape_0,
+                                              const Shape& shape_1,
+                                              const Shape& result_shape) {
+  return ShapeUtil::SameElementType(shape_0, shape_1) &&
+         (ShapeUtil::SameElementType(shape_0, result_shape) ||
+          (allow_mixed_precision_ &&
+           ShapeUtil::SameElementTypeIgnoringFpPrecision(shape_0,
+                                                         result_shape)));
+}
+
+Status ShapeVerifier::HandleRng(HloInstruction* instruction) {
+  if (instruction->operand_count() != 2) {
+    return InternalError("Expected two operands for Rng instruction: %s",
+                         instruction->ToString().c_str());
+  }
+
+  const Shape& shape_0 = instruction->operand(0)->shape();
+  const Shape& shape_1 = instruction->operand(1)->shape();
+  if (!ShapeUtil::IsScalar(shape_0) || !ShapeUtil::IsScalar(shape_1)) {
+    return InternalError(
+        "Expected scalar types for the two operands of Rng instruction: %s",
+        instruction->ToString().c_str());
+  }
+
+  if (!HasCompatibleElementTypes(shape_0, shape_1, instruction->shape())) {
+    return InternalError(
+        "Expected compatible element types for the result and the two operands"
+        " of Rng instruction: %s",
+        instruction->ToString().c_str());
+  }
+
+  PrimitiveType element_type = shape_0.element_type();
+  switch (instruction->random_distribution()) {
+    case RNG_UNIFORM:
+      if (!primitive_util::IsFloatingPointType(element_type) &&
+          !primitive_util::IsIntegralType(element_type) &&
+          element_type != PRED) {
+        return InternalError(
+            "Element type not supported."
+            " Expected element to be of floating point type, integral type or"
+            " predicate type for RngUniform: %s",
+            instruction->ToString().c_str());
+      }
+      break;
+
+    case RNG_NORMAL:
+      if (!primitive_util::IsFloatingPointType(element_type)) {
+        return InternalError(
+            "Element type not supported."
+            " Expected element to be FloatingPointType for RngNormal: %s",
+            instruction->ToString().c_str());
+      }
+      break;
+    default:
+      return InternalError(
+          "Invalid Rng distribution %s",
+          RandomDistribution_Name(instruction->random_distribution()).c_str());
+  }
+
+  return Status::OK();
+}
 
 Status ShapeVerifier::HandleReverse(HloInstruction* reverse) {
   return CheckShape(
@@ -463,9 +523,9 @@ namespace {
 // inputs.
 Status CheckMixedPrecisionOperands(const HloInstruction* instruction) {
   switch (instruction->opcode()) {
-    // White list the following opcodes for mixed-precision check, because they
-    // involve data pass through or grouping via tuples, where the precisions
-    // of buffers can be different.
+    // White list the following opcodes for mixed-precision check, because
+    // they involve data pass through or grouping via tuples, where the
+    // precisions of buffers can be different.
     case HloOpcode::kCall:
     case HloOpcode::kConditional:
     case HloOpcode::kConstant:
@@ -647,7 +707,8 @@ string ComputationsToString(
 
 // Verifies various invariants about the structure of the HLO:
 //
-// (1) each instruction has a non-null parent() set to the HloComputation which
+// (1) each instruction has a non-null parent() set to the HloComputation
+// which
 //     contains it.
 //
 // (2) each computation has a non-null parent() set to the HloModule which
@@ -681,9 +742,9 @@ Status VerifyHloStructure(HloModule* module) {
   }
 
   // Check that operands are in the same computation separately from verifying
-  // parent() correctness so conditions like a null HloInstruction::parent() are
-  // identified and reported explicitly above rather than reporting a mismatched
-  // operand.
+  // parent() correctness so conditions like a null HloInstruction::parent()
+  // are identified and reported explicitly above rather than reporting a
+  // mismatched operand.
   for (const HloComputation* computation : module->computations()) {
     for (const HloInstruction* instruction : computation->instructions()) {
       for (int i = 0; i < instruction->operand_count(); ++i) {
@@ -707,13 +768,14 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const {
   HloComputation* fused_computation = fusion->fused_instructions_computation();
   if (fusion != fused_computation->FusionInstruction()) {
     return InternalError(
-        "Instruction of fused computation does not match expected instruction "
+        "Instruction of fused computation does not match expected "
+        "instruction "
        "%s.",
        fusion->ToString().c_str());
  }
 
-  // Fused root instruction and fused parameters must all be owned by the fusion
-  // computation.
+  // Fused root instruction and fused parameters must all be owned by the
+  // fusion computation.
   bool root_owned = false;
   const std::vector<HloInstruction*>& fused_parameters =
       fusion->fused_parameters();
@@ -755,8 +817,8 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const {
         fusion->ToString().c_str());
   }
 
-  // All uses of fused instructions must be in the fusion computation, and every
-  // non-root instruction must have at least one use.
+  // All uses of fused instructions must be in the fusion computation, and
+  // every non-root instruction must have at least one use.
   for (auto* instruction :
        fusion->fused_instructions_computation()->instructions()) {
     if (instruction != fused_root) {
@@ -800,7 +862,8 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const {
     if (!ShapeUtil::Compatible(fused_param->shape(),
                                fusion->operand(param_no)->shape())) {
       return InternalError(
-          "Shape mismatch between parameter number %lld and its operand in %s.",
+          "Shape mismatch between parameter number %lld and its operand in "
+          "%s.",
           param_no, fusion->ToString().c_str());
     }
   }
@@ -918,8 +981,9 @@ Status CheckSameChannel(const HloInstruction* instr1,
   return Status::OK();
 }
 
-// Checks if the given two instructions have the same is_host_transfer attribute
-// value. Intsructions must be send/recv instructions or their 'done' variant.
+// Checks if the given two instructions have the same is_host_transfer
+// attribute value. Instructions must be send/recv instructions or their
+// 'done' variant.
 Status CheckSameIsHostTransfer(const HloInstruction* instr1,
                                const HloInstruction* instr2) {
   const HloSendRecvInstruction* send_recv1 =
@@ -930,7 +994,8 @@ Status CheckSameIsHostTransfer(const HloInstruction* instr1,
   TF_RET_CHECK(send_recv2 != nullptr);
   if (send_recv1->is_host_transfer() != send_recv2->is_host_transfer()) {
     return InternalError(
-        "Expected instructions to have the same is-host-transfer property: %s, "
+        "Expected instructions to have the same is-host-transfer property: "
+        "%s, "
         "%s ",
         instr1->ToString().c_str(), instr2->ToString().c_str());
   }
@@ -949,7 +1014,8 @@ Status VerifySendsAndRecvs(const HloModule& module) {
         host_channels.insert({sendrecv->channel_id(), sendrecv});
     if (!it_inserted.second) {
       return FailedPrecondition(
-          "Channel %lld is used for multiple host send/recv instructions: %s "
+          "Channel %lld is used for multiple host send/recv instructions: "
+          "%s "
          "and "
          "%s",
          sendrecv->channel_id(), sendrecv->ToString().c_str(),
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h
index 5a56a44f35..c942fab08e 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.h
+++ b/tensorflow/compiler/xla/service/hlo_verifier.h
@@ -106,6 +106,13 @@ class ShapeVerifier : public DfsHloVisitor {
   Status CheckVariadicShape(const HloInstruction* instruction);
 
  private:
+  // Return true if the shapes of the two operands have the same element type,
+  // and the result shape either has the same element type as the operand
+  // shapes or mixed precision is allowed and the result shape and the operand
+  // shapes have floating point element types.
+  bool HasCompatibleElementTypes(const Shape& shape_0, const Shape& shape_1,
+                                 const Shape& result_shape);
+
   // Whether the inputs and output of an instruction can contain both F32s and
   // BF16s. Tuples that include both F32s and BF16s are allowed regardless of
   // this flag.
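To read the new compatibility rule in isolation: the two operands must agree exactly on element type, and the result must also agree unless mixed precision is allowed, in which case any floating point result type is accepted for floating point operands. Below is a self-contained sketch of that rule; the enum and helpers are hypothetical stand-ins for `PrimitiveType` and the `ShapeUtil` predicates, not the real XLA API.

```cpp
#include <cstdio>

// Hypothetical stand-ins for XLA's PrimitiveType and ShapeUtil helpers,
// reduced to raw enums so the rule can be read on its own.
enum PrimitiveType { F32, F16, BF16, S32, U32, PRED };

bool IsFloatingPoint(PrimitiveType t) { return t == F32 || t == F16 || t == BF16; }

// Mirrors ShapeUtil::SameElementTypeIgnoringFpPrecision: any two floating
// point types compare equal; everything else requires an exact match.
bool SameIgnoringFpPrecision(PrimitiveType a, PrimitiveType b) {
  return a == b || (IsFloatingPoint(a) && IsFloatingPoint(b));
}

// Mirrors ShapeVerifier::HasCompatibleElementTypes from the diff above.
bool HasCompatibleElementTypes(PrimitiveType op0, PrimitiveType op1,
                               PrimitiveType result,
                               bool allow_mixed_precision) {
  return op0 == op1 &&
         (op0 == result ||
          (allow_mixed_precision && SameIgnoringFpPrecision(op0, result)));
}

int main() {
  // f32 operands with an f16 result: rejected strictly, accepted when mixed
  // precision is allowed.
  std::printf("%d\n", HasCompatibleElementTypes(F32, F32, F16, false));  // 0
  std::printf("%d\n", HasCompatibleElementTypes(F32, F32, F16, true));   // 1
  // Operands of different element types never agree, regardless of the flag.
  std::printf("%d\n", HasCompatibleElementTypes(F32, F16, F32, true));   // 0
}
```

This is exactly the pair of cases the `RngMixedPrecisionNotAllowed` and `RngMixedPrecisionAllowed` tests below exercise.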
diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc
index 04c6ba3eeb..d764964f3c 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc
@@ -34,7 +34,17 @@ namespace {
 
 using ::testing::HasSubstr;
 
-using HloVerifierTest = HloTestBase;
+class HloVerifierTest : public HloTestBase {
+ public:
+  HloVerifierTest()
+      : HloTestBase(/*allow_mixed_precision_in_hlo_verifier=*/false) {}
+};
+
+class HloVerifierTestAllowMixedPrecision : public HloTestBase {
+ public:
+  HloVerifierTestAllowMixedPrecision()
+      : HloTestBase(/*allow_mixed_precision_in_hlo_verifier=*/true) {}
+};
 
 TEST_F(HloVerifierTest, NullInstructionParent) {
   HloComputation::Builder builder(TestName());
@@ -174,5 +184,96 @@ ENTRY entry {
               HasSubstr("shape does not match parameter"));
 }
 
+TEST_F(HloVerifierTest, RngOpnd0NotScalar) {
+  const char* const hlo_string = R"(
+  HloModule Module
+
+  ENTRY RngOpnd0NotScalar {
+   constant.0 = f32[] constant(0)
+   constant.1 = f16[2] constant({1, 3})
+   ROOT rng.0 = f32[10]{0} rng(f32[] constant.0, f16[2] constant.1),
+    distribution=rng_uniform
+  }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string));
+
+  auto status = verifier().Run(module.get()).status();
+  ASSERT_FALSE(status.ok());
+  EXPECT_THAT(status.error_message(), HasSubstr("Expected scalar type"));
+}
+
+TEST_F(HloVerifierTest, RngOperandElementTypesDoNotMatch) {
+  const char* const hlo_string = R"(
+  HloModule Module
+
+  ENTRY RngOperandElementTypesNotMatch {
+   constant.0 = f32[] constant(0)
+   constant.1 = f16[] constant(1)
+   ROOT rng.0 = f32[10]{0} rng(f32[] constant.0, f16[] constant.1),
+    distribution=rng_normal
+  }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string));
+
+  auto status = verifier().Run(module.get()).status();
+  ASSERT_FALSE(status.ok());
+  EXPECT_THAT(status.error_message(),
+              HasSubstr("Expected compatible element types"));
+}
+
+TEST_F(HloVerifierTest, RngMixedPrecisionNotAllowed) {
+  const char* const hlo_string = R"(
+  HloModule Module
+
+  ENTRY RngResultElementTypeNotMatch {
+   constant.0 = f32[] constant(0)
+   constant.1 = f32[] constant(1)
+   ROOT rng.0 = f16[10]{0} rng(f32[] constant.0, f32[] constant.1),
+    distribution=rng_normal
+  }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string));
+
+  auto status = verifier().Run(module.get()).status();
+  ASSERT_FALSE(status.ok());
+  EXPECT_THAT(status.error_message(),
+              HasSubstr("Expected compatible element types"));
+}
+
+TEST_F(HloVerifierTestAllowMixedPrecision, RngMixedPrecisionAllowed) {
+  const char* const hlo_string = R"(
+  HloModule Module
+
+  ENTRY RngResultElementTypeNotMatch {
+   constant.0 = f32[] constant(0)
+   constant.1 = f32[] constant(1)
+   ROOT rng.0 = f16[10]{0} rng(f32[] constant.0, f32[] constant.1),
+    distribution=rng_normal
+  }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string));
+
+  auto status = verifier().Run(module.get()).status();
+  ASSERT_TRUE(status.ok());
+}
+
+TEST_F(HloVerifierTest, RngElementTypeNotSupported) {
+  const char* const hlo_string = R"(
+  HloModule Module
+
+  ENTRY RngElementTypeNotSupported {
+   constant.0 = s32[] constant(0)
+   constant.1 = s32[] constant(1)
+   ROOT rng.0 = s32[10]{0} rng(s32[] constant.0, s32[] constant.1),
+    distribution=rng_normal
+  }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string));
+
+  auto status = verifier().Run(module.get()).status();
+  ASSERT_FALSE(status.ok());
+  EXPECT_THAT(status.error_message(), HasSubstr("Element type not supported"));
+}
+
 }  // namespace
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/reshape_mover_test.cc b/tensorflow/compiler/xla/service/reshape_mover_test.cc
index ad3b662c20..ccb9fb3e3a 100644
--- a/tensorflow/compiler/xla/service/reshape_mover_test.cc
+++ b/tensorflow/compiler/xla/service/reshape_mover_test.cc
@@ -76,9 +76,13 @@ TEST_F(ReshapeMoverTest, ReshapesWithDifferentInputShapesNotMoved) {
 TEST_F(ReshapeMoverTest, 1ConstantAnd1ReshapesOnRngNotMoved) {
   HloComputation::Builder builder(TestName());
   auto root_shape = ShapeUtil::MakeShape(F32, {8, 7});
-  auto rng0 = builder.AddInstruction(
-      HloInstruction::CreateRng(ShapeUtil::MakeShape(F32, {1, 8, 1, 7, 1}),
-                                RandomDistribution::RNG_UNIFORM, {}));
+  auto rng0 = builder.AddInstruction(HloInstruction::CreateRng(
+      ShapeUtil::MakeShape(F32, {1, 8, 1, 7, 1}),
+      RandomDistribution::RNG_UNIFORM,
+      {builder.AddInstruction(
+           HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f))),
+       builder.AddInstruction(HloInstruction::CreateConstant(
+           LiteralUtil::CreateR0<float>(1.0f)))}));
   auto reshape0 =
       builder.AddInstruction(HloInstruction::CreateReshape(root_shape, rng0));
diff --git a/tensorflow/docs_src/performance/xla/operation_semantics.md b/tensorflow/docs_src/performance/xla/operation_semantics.md
index 02af71f8a3..fad9fd57f1 100644
--- a/tensorflow/docs_src/performance/xla/operation_semantics.md
+++ b/tensorflow/docs_src/performance/xla/operation_semantics.md
@@ -1877,19 +1877,19 @@ See also
 [`XlaBuilder::RngNormal`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
 
 Constructs an output of a given shape with random numbers generated following
-the $$N(\mu, \sigma)$$ normal distribution. The parameters `mu` and `sigma`, and
-output shape have to have elemental type F32. The parameters furthermore have to
-be scalar valued.
+the $$N(\mu, \sigma)$$ normal distribution. The parameters $$\mu$$ and
+$$\sigma$$, and output shape have to have a floating point elemental type. The
+parameters furthermore have to be scalar valued.
 
-<b>`RngNormal(mean, sigma, shape)`</b>
+<b>`RngNormal(mu, sigma, shape)`</b>
 
 | Arguments | Type    | Semantics                                            |
 | --------- | ------- | ---------------------------------------------------- |
-| `mu`      | `XlaOp` | Scalar of type F32 specifying mean of generated      |
-:           :         : numbers                                              :
-| `sigma`   | `XlaOp` | Scalar of type F32 specifying standard deviation of  |
+| `mu`      | `XlaOp` | Scalar of type T specifying mean of generated        |
+:           :         : numbers                                              :
+| `sigma`   | `XlaOp` | Scalar of type T specifying standard deviation of    |
 :           :         : generated numbers                                    :
-| `shape`   | `Shape` | Output shape of type F32                             |
+| `shape`   | `Shape` | Output shape of type T                               |
 
 ## RngUniform
 
@@ -1898,9 +1898,11 @@
 See also
 
 Constructs an output of a given shape with random numbers generated following
 the uniform distribution over the interval $$[a,b)$$. The parameters and output
-shape may be either F32, S32 or U32, but the types have to be consistent.
-Furthermore, the parameters need to be scalar valued. If $$b <= a$$ the result
-is implementation-defined.
+element type have to be a boolean type, an integral type or a floating point
+type, and the types have to be consistent. The CPU and GPU backends currently
+only support F64, F32, F16, BF16, S64, U64, S32 and U32. Furthermore, the
+parameters need to be scalar valued. If $$b <= a$$ the result is
+implementation-defined.
 
 <b>`RngUniform(a, b, shape)`</b>
 
 |
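To make the updated operation semantics concrete, here is a client-side sketch of ops that satisfy the verifier: scalar parameters whose element type matches the requested output shape. It assumes the free-function builder ops (`xla::ConstantR0`, `xla::RngNormal`, `xla::RngUniform`) declared in `xla_builder.h`; treat the exact entry points and the `ValueOrDie()` accessor as assumptions for this revision of the tree.

```cpp
#include <cstdint>

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

// Builds one F32 normal draw and one U32 uniform draw. Both parameter pairs
// are scalars whose element type matches the requested output shape, so the
// HandleRng checks above pass.
xla::XlaComputation BuildRngExample() {
  xla::XlaBuilder builder("rng_example");

  // RngNormal: parameters and output must share a floating point element type.
  auto mu = xla::ConstantR0<float>(&builder, 0.0f);
  auto sigma = xla::ConstantR0<float>(&builder, 1.0f);
  xla::RngNormal(mu, sigma, xla::ShapeUtil::MakeShape(xla::F32, {10}));

  // RngUniform: integral element types are allowed too; draws from [0, 100).
  auto lo = xla::ConstantR0<uint32_t>(&builder, 0);
  auto hi = xla::ConstantR0<uint32_t>(&builder, 100);
  xla::RngUniform(lo, hi, xla::ShapeUtil::MakeShape(xla::U32, {10}));

  // The last op built becomes the root of the computation.
  return builder.Build().ValueOrDie();
}
```

Note that the documented caveat still applies: for `RngUniform` with $$b <= a$$ the result remains implementation-defined.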