diff options
-rw-r--r-- | tensorflow/compiler/xla/service/cpu/cpu_compiler.cc | 4 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/gpu_compiler.cc | 23 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/while_transformer_test.cc | 4 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_verifier.cc | 690 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_verifier.h | 92 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/shape_inference.cc | 13 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/hlo_test_base.cc | 4 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/hlo_verified_test_base.cc | 13 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/hlo_verified_test_base.h | 15 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/test_utils.cc | 8 |
10 files changed, 455 insertions, 411 deletions
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 9636f6b5b3..f0507982b3 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -234,7 +234,7 @@ class CollectProfileCandidates : public DfsHloVisitorWithDefault { Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { // Optimization pipeline. HloPassPipeline pipeline("CPU"); - pipeline.AddInvariantChecker<HloVerifier>(ShapeSizeBytesFunction()); + pipeline.AddInvariantChecker<HloVerifier>(); pipeline.AddPass<CpuHloSupportChecker>(); ReducePrecisionInsertion::AddPasses( @@ -253,7 +253,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { { auto& pass = pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification"); - pass.AddInvariantChecker<HloVerifier>(ShapeSizeBytesFunction()); + pass.AddInvariantChecker<HloVerifier>(); pass.AddPass<BatchNormExpander>( /*rewrite_training_op=*/true, diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index 7c2e693560..af0010e207 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -133,12 +133,10 @@ string GetLibdeviceDir(const string& config_cuda_data_dir) { } // Runs optimization passes on the given HLO module. -tensorflow::Status OptimizeHloModule( - HloModule* hlo_module, - const HloCostAnalysis::ShapeSizeFunction& shape_size_function) { +tensorflow::Status OptimizeHloModule(HloModule* hlo_module) { { HloPassPipeline pipeline("optimization"); - pipeline.AddInvariantChecker<HloVerifier>(shape_size_function); + pipeline.AddInvariantChecker<HloVerifier>(); pipeline.AddPass<GpuHloSupportChecker>(); ReducePrecisionInsertion::AddPasses( &pipeline, hlo_module->config().debug_options(), @@ -150,7 +148,7 @@ tensorflow::Status OptimizeHloModule( { auto& pass = pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification"); - pass.AddInvariantChecker<HloVerifier>(shape_size_function); + pass.AddInvariantChecker<HloVerifier>(); // If cudnn batchnorms are enabled, rewrite batchnorm HLOs to cudnn calls // where possible. Not every batchnorm op can be implemented as a call to @@ -191,14 +189,14 @@ tensorflow::Status OptimizeHloModule( } { HloPassFix<HloPassPipeline> fusion("fusion"); - fusion.AddInvariantChecker<HloVerifier>(shape_size_function); + fusion.AddInvariantChecker<HloVerifier>(); fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/false); fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/true); fusion.AddPass<FusionMerger>(); TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status()); HloPassPipeline reduce_pipeline("reduce-precision"); - reduce_pipeline.AddInvariantChecker<HloVerifier>(shape_size_function); + reduce_pipeline.AddInvariantChecker<HloVerifier>(); ReducePrecisionInsertion::AddPasses( &reduce_pipeline, hlo_module->config().debug_options(), ReducePrecisionInsertion::PassTiming::AFTER_FUSION); @@ -216,16 +214,14 @@ tensorflow::Status OptimizeHloModule( // Modifies the given HLO module so that it will be accepted by IrEmitter. // Unlike optimization passes, the passes are necessary for correctness. -tensorflow::Status PrepareHloModuleForIrEmitting( - HloModule* hlo_module, - const HloCostAnalysis::ShapeSizeFunction& shape_size_function) { +tensorflow::Status PrepareHloModuleForIrEmitting(HloModule* hlo_module) { // In some cases, we have to place the result of an instruction in a temporary // buffer. For instance, the buffer that holds an external parameter is // assumed immutable at this point, and should not be reused for output // (b/27180329). Therefore, in that case, we set the output to be a copy of // the parameter. HloPassPipeline pipeline("GPU-ir-emit-prepare"); - pipeline.AddInvariantChecker<HloVerifier>(shape_size_function); + pipeline.AddInvariantChecker<HloVerifier>(); pipeline.AddPass<PadInsertion>(); pipeline.AddPass<GpuLayoutAssignment>( hlo_module->mutable_entry_computation_layout()); @@ -409,7 +405,7 @@ StatusOr<std::unique_ptr<HloModule>> GpuCompiler::RunHloPasses( XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunHloPasses"); Tracing::TraceMe annotation("HLO Transforms", module->name(), /*is_expensive=*/true); - TF_RETURN_IF_ERROR(OptimizeHloModule(module.get(), ShapeSizeBytesFunction())); + TF_RETURN_IF_ERROR(OptimizeHloModule(module.get())); return std::move(module); } @@ -419,8 +415,7 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend( TF_RET_CHECK(stream_exec != nullptr); - TF_RETURN_IF_ERROR( - PrepareHloModuleForIrEmitting(module.get(), ShapeSizeBytesFunction())); + TF_RETURN_IF_ERROR(PrepareHloModuleForIrEmitting(module.get())); llvm::LLVMContext llvm_context; std::string buffer; diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc index f16daa0b54..2f290f61bd 100644 --- a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc +++ b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc @@ -117,9 +117,7 @@ class WhileTransformerTest : public HloTestBase { } void RunCopyInsertionPass() { - HloVerifier verifier([](const Shape& shape) { - return ShapeUtil::ByteSizeOf(shape, /*pointer_size=*/sizeof(void*)); - }); + HloVerifier verifier; TF_ASSERT_OK(verifier.Run(module_.get()).status()); CopyInsertion copy_insertion; TF_ASSERT_OK(copy_insertion.Run(module_.get()).status()); diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 9d5ca6673a..9d9cf0c0f6 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -14,425 +14,400 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_verifier.h" -#include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { -namespace { +Status ShapeVerifier::HandleElementwiseUnary(HloInstruction* hlo) { + return CheckUnaryShape(hlo); +} -// Visitor which verifies that the output shape is correctly set. Verifies -// against the inferred shape for the instruction. -// TODO(b/26024837): Check output shape for all instruction types. -class ShapeVerifier : public DfsHloVisitor { - public: - explicit ShapeVerifier( - const std::function<int64(const Shape&)>& shape_size_fn) - : shape_size_fn_(shape_size_fn) {} +Status ShapeVerifier::HandleElementwiseBinary(HloInstruction* hlo) { + return CheckBinaryShape(hlo); +} - Status HandleElementwiseUnary(HloInstruction* hlo) override { - return CheckUnaryShape(hlo); - } +Status ShapeVerifier::HandleClamp(HloInstruction* clamp) { + return CheckTernaryShape(clamp); +} - Status HandleElementwiseBinary(HloInstruction* hlo) override { - return CheckBinaryShape(hlo); - } +Status ShapeVerifier::HandleSelect(HloInstruction* select) { + return CheckTernaryShape(select); +} - Status HandleClamp(HloInstruction* clamp) override { - return CheckTernaryShape(clamp); +Status ShapeVerifier::HandleConcatenate(HloInstruction* concatenate) { + std::vector<const Shape*> operand_shapes; + for (const HloInstruction* operand : concatenate->operands()) { + operand_shapes.push_back(&operand->shape()); } + return CheckShape(concatenate, + ShapeInference::InferConcatOpShape( + operand_shapes, concatenate->concatenate_dimension())); +} - Status HandleSelect(HloInstruction* select) override { - return CheckTernaryShape(select); - } +Status ShapeVerifier::HandleConvert(HloInstruction* convert) { + return CheckShape(convert, ShapeInference::InferConvertShape( + convert->operand(0)->shape(), + convert->shape().element_type())); +} - Status HandleConcatenate(HloInstruction* concatenate) override { - std::vector<const Shape*> operand_shapes; - for (const HloInstruction* operand : concatenate->operands()) { - operand_shapes.push_back(&operand->shape()); - } - return CheckShape( - concatenate, ShapeInference::InferConcatOpShape( - operand_shapes, concatenate->concatenate_dimension())); - } +Status ShapeVerifier::HandleBitcastConvert(HloInstruction* convert) { + return CheckShape(convert, ShapeInference::InferBitcastConvertShape( + convert->operand(0)->shape(), + convert->shape().element_type())); +} - Status HandleConvert(HloInstruction* convert) override { - return CheckShape(convert, ShapeInference::InferConvertShape( - convert->operand(0)->shape(), - convert->shape().element_type())); - } +Status ShapeVerifier::HandleCopy(HloInstruction* copy) { + return CheckUnaryShape(copy); +} - Status HandleBitcastConvert(HloInstruction* convert) override { - return CheckShape(convert, ShapeInference::InferBitcastConvertShape( - convert->operand(0)->shape(), - convert->shape().element_type())); - } +Status ShapeVerifier::HandleDot(HloInstruction* dot) { + TF_ASSIGN_OR_RETURN(const Shape expected, + ShapeInference::InferDotOpShape( + dot->operand(0)->shape(), dot->operand(1)->shape(), + dot->dot_dimension_numbers())); + return CheckShape(dot, expected); +} - Status HandleCopy(HloInstruction* copy) override { - return CheckUnaryShape(copy); - } +Status ShapeVerifier::HandleConvolution(HloInstruction* convolution) { + TF_ASSIGN_OR_RETURN( + const Shape expected, + ShapeInference::InferConvolveShape( + convolution->operand(0)->shape(), convolution->operand(1)->shape(), + convolution->window(), convolution->convolution_dimension_numbers())); + return CheckShape(convolution, expected); +} - Status HandleDot(HloInstruction* dot) override { - TF_ASSIGN_OR_RETURN(const Shape expected, - ShapeInference::InferDotOpShape( - dot->operand(0)->shape(), dot->operand(1)->shape(), - dot->dot_dimension_numbers())); - return CheckShape(dot, expected); - } +Status ShapeVerifier::HandleFft(HloInstruction* fft) { + TF_ASSIGN_OR_RETURN( + const Shape expected, + ShapeInference::InferFftShape(fft->operand(0)->shape(), fft->fft_type(), + fft->fft_length())); + return CheckShape(fft, expected); +} - Status HandleConvolution(HloInstruction* convolution) override { - TF_ASSIGN_OR_RETURN( - const Shape expected, - ShapeInference::InferConvolveShape( - convolution->operand(0)->shape(), convolution->operand(1)->shape(), - convolution->window(), - convolution->convolution_dimension_numbers())); - return CheckShape(convolution, expected); +Status ShapeVerifier::HandleCrossReplicaSum(HloInstruction* crs) { + std::vector<const Shape*> operand_shapes; + for (const HloInstruction* operand : crs->operands()) { + operand_shapes.push_back(&operand->shape()); } + return CheckShape(crs, + ShapeInference::InferCrossReplicaSumShape(operand_shapes)); +} - Status HandleFft(HloInstruction* fft) override { - TF_ASSIGN_OR_RETURN( - const Shape expected, - ShapeInference::InferFftShape(fft->operand(0)->shape(), fft->fft_type(), - fft->fft_length())); - return CheckShape(fft, expected); - } +Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) { + return CheckShape(reduce_precision, ShapeInference::InferReducePrecisionShape( + reduce_precision->operand(0)->shape(), + reduce_precision->exponent_bits(), + reduce_precision->mantissa_bits())); +} - Status HandleCrossReplicaSum(HloInstruction* crs) override { - std::vector<const Shape*> operand_shapes; - for (const HloInstruction* operand : crs->operands()) { - operand_shapes.push_back(&operand->shape()); - } - return CheckShape( - crs, ShapeInference::InferCrossReplicaSumShape(operand_shapes)); - } +Status ShapeVerifier::HandleInfeed(HloInstruction*) { + return tensorflow::Status::OK(); +} - Status HandleReducePrecision(HloInstruction* reduce_precision) override { - return CheckShape(reduce_precision, - ShapeInference::InferReducePrecisionShape( - reduce_precision->operand(0)->shape(), - reduce_precision->exponent_bits(), - reduce_precision->mantissa_bits())); - } +Status ShapeVerifier::HandleOutfeed(HloInstruction*) { + return tensorflow::Status::OK(); +} - Status HandleInfeed(HloInstruction*) override { - return tensorflow::Status::OK(); - } +Status ShapeVerifier::HandleRng(HloInstruction*) { + return tensorflow::Status::OK(); +} - Status HandleOutfeed(HloInstruction*) override { - return tensorflow::Status::OK(); - } +Status ShapeVerifier::HandleReverse(HloInstruction* reverse) { + return CheckShape( + reverse, ShapeInference::InferReverseShape(reverse->operand(0)->shape(), + reverse->dimensions())); +} - Status HandleRng(HloInstruction*) override { - return tensorflow::Status::OK(); - } +Status ShapeVerifier::HandleSort(HloInstruction* sort) { + return CheckUnaryShape(sort); +} - Status HandleReverse(HloInstruction* reverse) override { - return CheckShape( - reverse, ShapeInference::InferReverseShape(reverse->operand(0)->shape(), - reverse->dimensions())); - } +Status ShapeVerifier::HandleConstant(HloInstruction* constant) { + return CheckShape(constant, constant->literal().shape()); +} - Status HandleSort(HloInstruction* sort) override { - return CheckUnaryShape(sort); - } +Status ShapeVerifier::HandleGetTupleElement(HloInstruction* get_tuple_element) { + return CheckShape(get_tuple_element, + ShapeInference::InferGetTupleElementShape( + get_tuple_element->operand(0)->shape(), + get_tuple_element->tuple_index())); +} - Status HandleConstant(HloInstruction* constant) override { - return CheckShape(constant, constant->literal().shape()); - } +Status ShapeVerifier::HandleReduce(HloInstruction* reduce) { + return CheckShape( + reduce, + ShapeInference::InferReduceShape( + reduce->operand(0)->shape(), reduce->operand(1)->shape(), + reduce->dimensions(), reduce->to_apply()->ComputeProgramShape())); +} - Status HandleGetTupleElement(HloInstruction* get_tuple_element) override { - return CheckShape(get_tuple_element, - ShapeInference::InferGetTupleElementShape( - get_tuple_element->operand(0)->shape(), - get_tuple_element->tuple_index())); - } +Status ShapeVerifier::HandleBitcast(HloInstruction* bitcast) { + return tensorflow::Status::OK(); +} - Status HandleReduce(HloInstruction* reduce) override { - return CheckShape( - reduce, - ShapeInference::InferReduceShape( - reduce->operand(0)->shape(), reduce->operand(1)->shape(), - reduce->dimensions(), reduce->to_apply()->ComputeProgramShape())); +Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) { + // HLO broadcast has no exact analog at the proto level so there is no + // ShapeInference method. Check the output shape explicitly. + const Shape& operand_shape = broadcast->operand(0)->shape(); + TF_RET_CHECK(ShapeUtil::Rank(operand_shape) == + broadcast->dimensions().size()); + for (int64 operand_dimension = 0; + operand_dimension < ShapeUtil::Rank(operand_shape); + ++operand_dimension) { + int64 output_dimension = broadcast->dimensions()[operand_dimension]; + TF_RET_CHECK(broadcast->shape().dimensions(output_dimension) == + operand_shape.dimensions(operand_dimension)); } + return tensorflow::Status::OK(); +} - Status HandleBitcast(HloInstruction* bitcast) override { - return tensorflow::Status::OK(); - } +Status ShapeVerifier::HandleReshape(HloInstruction* reshape) { + TF_RET_CHECK(ShapeUtil::ElementsIn(reshape->shape()) == + ShapeUtil::ElementsIn(reshape->operand(0)->shape())); + return tensorflow::Status::OK(); +} - Status HandleBroadcast(HloInstruction* broadcast) override { - // HLO broadcast has no exact analog at the proto level so there is no - // ShapeInference method. Check the output shape explicitly. - const Shape& operand_shape = broadcast->operand(0)->shape(); - TF_RET_CHECK(ShapeUtil::Rank(operand_shape) == - broadcast->dimensions().size()); - for (int64 operand_dimension = 0; - operand_dimension < ShapeUtil::Rank(operand_shape); - ++operand_dimension) { - int64 output_dimension = broadcast->dimensions()[operand_dimension]; - TF_RET_CHECK(broadcast->shape().dimensions(output_dimension) == - operand_shape.dimensions(operand_dimension)); - } - return tensorflow::Status::OK(); - } +Status ShapeVerifier::HandleTranspose(HloInstruction* transpose) { + return CheckShape( + transpose, ShapeInference::InferTransposeShape( + transpose->operand(0)->shape(), transpose->dimensions())); +} - Status HandleReshape(HloInstruction* reshape) override { - TF_RET_CHECK(ShapeUtil::ElementsIn(reshape->shape()) == - ShapeUtil::ElementsIn(reshape->operand(0)->shape())); - return tensorflow::Status::OK(); - } +Status ShapeVerifier::HandleParameter(HloInstruction*) { + return tensorflow::Status::OK(); +} - Status HandleTranspose(HloInstruction* transpose) override { - return CheckShape(transpose, ShapeInference::InferTransposeShape( - transpose->operand(0)->shape(), - transpose->dimensions())); - } +Status ShapeVerifier::HandleFusion(HloInstruction*) { + return tensorflow::Status::OK(); +} - Status HandleParameter(HloInstruction*) override { - return tensorflow::Status::OK(); - } +Status ShapeVerifier::HandleCall(HloInstruction* call) { + // The shape of kCall should match the shape of the computation it calls. + return CheckShape(call, call->to_apply()->ComputeProgramShape().result()); +} - Status HandleFusion(HloInstruction*) override { - return tensorflow::Status::OK(); - } +Status ShapeVerifier::HandleCustomCall(HloInstruction*) { + return tensorflow::Status::OK(); +} - Status HandleCall(HloInstruction* call) override { - // The shape of kCall should match the shape of the computation it calls. - return CheckShape(call, call->to_apply()->ComputeProgramShape().result()); - } +Status ShapeVerifier::HandleSlice(HloInstruction* slice) { + return CheckShape(slice, + ShapeInference::InferSliceShape( + slice->operand(0)->shape(), slice->slice_starts(), + slice->slice_limits(), slice->slice_strides())); +} - Status HandleCustomCall(HloInstruction*) override { - return tensorflow::Status::OK(); - } +Status ShapeVerifier::HandleDynamicSlice(HloInstruction* dynamic_slice) { + return CheckShape(dynamic_slice, ShapeInference::InferDynamicSliceShape( + dynamic_slice->operand(0)->shape(), + dynamic_slice->operand(1)->shape(), + dynamic_slice->dynamic_slice_sizes())); +} - Status HandleSlice(HloInstruction* slice) override { - return CheckShape(slice, - ShapeInference::InferSliceShape( - slice->operand(0)->shape(), slice->slice_starts(), - slice->slice_limits(), slice->slice_strides())); - } +Status ShapeVerifier::HandleDynamicUpdateSlice( + HloInstruction* dynamic_update_slice) { + return CheckShape(dynamic_update_slice, + ShapeInference::InferDynamicUpdateSliceShape( + dynamic_update_slice->operand(0)->shape(), + dynamic_update_slice->operand(1)->shape(), + dynamic_update_slice->operand(2)->shape())); +} - Status HandleDynamicSlice(HloInstruction* dynamic_slice) override { - return CheckShape(dynamic_slice, ShapeInference::InferDynamicSliceShape( - dynamic_slice->operand(0)->shape(), - dynamic_slice->operand(1)->shape(), - dynamic_slice->dynamic_slice_sizes())); - } +Status ShapeVerifier::HandleTuple(HloInstruction* tuple) { + return CheckVariadicShape(tuple); +} - Status HandleDynamicUpdateSlice( - HloInstruction* dynamic_update_slice) override { - return CheckShape(dynamic_update_slice, - ShapeInference::InferDynamicUpdateSliceShape( - dynamic_update_slice->operand(0)->shape(), - dynamic_update_slice->operand(1)->shape(), - dynamic_update_slice->operand(2)->shape())); - } +Status ShapeVerifier::HandleMap(HloInstruction* map) { + std::vector<const Shape*> operand_shapes; + int64 max_operand_rank = 0; + for (const HloInstruction* operand : map->operands()) { + operand_shapes.push_back(&operand->shape()); + max_operand_rank = + std::max(max_operand_rank, ShapeUtil::Rank(operand->shape())); + } + // TODO(b/65689298) Remove code below once Map is generalized to accept + // arbitrary map dimensions. + std::vector<int64> map_dims(max_operand_rank); + std::iota(map_dims.begin(), map_dims.end(), 0); + return CheckShape(map, ShapeInference::InferMapShape( + operand_shapes, + map->to_apply()->ComputeProgramShape(), map_dims)); +} - Status HandleTuple(HloInstruction* tuple) override { - return CheckVariadicShape(tuple); - } +Status ShapeVerifier::HandleReduceWindow(HloInstruction* reduce_window) { + return CheckShape( + reduce_window, + ShapeInference::InferReduceWindowShape( + reduce_window->operand(0)->shape(), + reduce_window->operand(1)->shape(), reduce_window->window(), + reduce_window->to_apply()->ComputeProgramShape())); +} - Status HandleMap(HloInstruction* map) override { - std::vector<const Shape*> operand_shapes; - int64 max_operand_rank = 0; - for (const HloInstruction* operand : map->operands()) { - operand_shapes.push_back(&operand->shape()); - max_operand_rank = - std::max(max_operand_rank, ShapeUtil::Rank(operand->shape())); - } - // TODO(b/65689298) Remove code below once Map is generalized to accept - // arbitrary map dimensions. - std::vector<int64> map_dims(max_operand_rank); - std::iota(map_dims.begin(), map_dims.end(), 0); - return CheckShape( - map, - ShapeInference::InferMapShape( - operand_shapes, map->to_apply()->ComputeProgramShape(), map_dims)); - } +Status ShapeVerifier::HandleSelectAndScatter(HloInstruction* instruction) { + return CheckShape( + instruction, + ShapeInference::InferSelectAndScatterShape( + instruction->operand(0)->shape(), + instruction->select()->ComputeProgramShape(), instruction->window(), + instruction->operand(1)->shape(), instruction->operand(2)->shape(), + instruction->scatter()->ComputeProgramShape())); +} - Status HandleReduceWindow(HloInstruction* reduce_window) override { - return CheckShape( - reduce_window, - ShapeInference::InferReduceWindowShape( - reduce_window->operand(0)->shape(), - reduce_window->operand(1)->shape(), reduce_window->window(), - reduce_window->to_apply()->ComputeProgramShape())); - } +Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) { + // The shape of kWhile should match the shape of the body computation it + // calls. + return CheckShape(xla_while, + xla_while->while_body()->ComputeProgramShape().result()); +} - Status HandleSelectAndScatter(HloInstruction* instruction) override { - return CheckShape( - instruction, - ShapeInference::InferSelectAndScatterShape( - instruction->operand(0)->shape(), - instruction->select()->ComputeProgramShape(), instruction->window(), - instruction->operand(1)->shape(), instruction->operand(2)->shape(), - instruction->scatter()->ComputeProgramShape())); - } +Status ShapeVerifier::HandleConditional(HloInstruction* conditional) { + TF_RETURN_IF_ERROR(CheckShape( + conditional, + conditional->true_computation()->ComputeProgramShape().result())); + return CheckShape( + conditional, + conditional->false_computation()->ComputeProgramShape().result()); +} - Status HandleWhile(HloInstruction* xla_while) override { - // The shape of kWhile should match the shape of the body computation it - // calls. - return CheckShape(xla_while, - xla_while->while_body()->ComputeProgramShape().result()); - } +Status ShapeVerifier::HandlePad(HloInstruction* pad) { + return CheckShape(pad, ShapeInference::InferPadShape(pad->operand(0)->shape(), + pad->operand(1)->shape(), + pad->padding_config())); +} - Status HandleConditional(HloInstruction* conditional) override { - TF_RETURN_IF_ERROR(CheckShape( - conditional, - conditional->true_computation()->ComputeProgramShape().result())); - return CheckShape( - conditional, - conditional->false_computation()->ComputeProgramShape().result()); - } +Status ShapeVerifier::HandleSend(HloInstruction* send) { + TF_RET_CHECK(send->users().size() == 1); + const HloInstruction* send_done = send->users().front(); + TF_RET_CHECK(send_done->opcode() == HloOpcode::kSendDone); + TF_RETURN_IF_ERROR(CheckSameChannel(send, send_done)); + return CheckShape( + send, ShapeUtil::MakeTupleShape( + {send->operand(0)->shape(), ShapeUtil::MakeShape(U32, {})})); +} - Status HandlePad(HloInstruction* pad) override { - return CheckShape(pad, - ShapeInference::InferPadShape(pad->operand(0)->shape(), - pad->operand(1)->shape(), - pad->padding_config())); - } +Status ShapeVerifier::HandleSendDone(HloInstruction* send_done) { + TF_RET_CHECK(send_done->operands().size() == 1); + const HloInstruction* send = send_done->operand(0); + TF_RET_CHECK(send->opcode() == HloOpcode::kSend); + TF_RETURN_IF_ERROR(CheckSameChannel(send, send_done)); + return CheckShape(send_done, ShapeUtil::MakeNil()); +} - Status HandleSend(HloInstruction* send) override { - TF_RET_CHECK(send->users().size() == 1); - const HloInstruction* send_done = send->users().front(); - TF_RET_CHECK(send_done->opcode() == HloOpcode::kSendDone); - TF_RETURN_IF_ERROR(CheckSameChannel(send, send_done)); - return CheckShape( - send, ShapeUtil::MakeTupleShape( - {send->operand(0)->shape(), ShapeUtil::MakeShape(U32, {})})); - } +Status ShapeVerifier::HandleRecv(HloInstruction* recv) { + TF_RET_CHECK(recv->users().size() == 1); + const HloInstruction* recv_done = recv->users().front(); + TF_RET_CHECK(recv_done->opcode() == HloOpcode::kRecvDone); + TF_RETURN_IF_ERROR(CheckSameChannel(recv, recv_done)); + return CheckShape(recv, + ShapeUtil::MakeTupleShape( + {recv_done->shape(), ShapeUtil::MakeShape(U32, {})})); +} - Status HandleSendDone(HloInstruction* send_done) override { - TF_RET_CHECK(send_done->operands().size() == 1); - const HloInstruction* send = send_done->operand(0); - TF_RET_CHECK(send->opcode() == HloOpcode::kSend); - TF_RETURN_IF_ERROR(CheckSameChannel(send, send_done)); - return CheckShape(send_done, ShapeUtil::MakeNil()); - } +Status ShapeVerifier::HandleRecvDone(HloInstruction* recv_done) { + TF_RET_CHECK(recv_done->operands().size() == 1); + const HloInstruction* recv = recv_done->operand(0); + TF_RET_CHECK(recv->opcode() == HloOpcode::kRecv); + TF_RETURN_IF_ERROR(CheckSameChannel(recv, recv_done)); + return CheckShape(recv_done, recv->shape().tuple_shapes(0)); +} - Status HandleRecv(HloInstruction* recv) override { - TF_RET_CHECK(recv->users().size() == 1); - const HloInstruction* recv_done = recv->users().front(); - TF_RET_CHECK(recv_done->opcode() == HloOpcode::kRecvDone); - TF_RETURN_IF_ERROR(CheckSameChannel(recv, recv_done)); - return CheckShape(recv, - ShapeUtil::MakeTupleShape( - {recv_done->shape(), ShapeUtil::MakeShape(U32, {})})); - } +Status ShapeVerifier::HandleBatchNormTraining( + HloInstruction* batch_norm_training) { + return CheckShape(batch_norm_training, + ShapeInference::InferBatchNormTrainingShape( + batch_norm_training->operand(0)->shape(), + batch_norm_training->operand(1)->shape(), + batch_norm_training->operand(2)->shape(), + batch_norm_training->feature_index())); +} - Status HandleRecvDone(HloInstruction* recv_done) override { - TF_RET_CHECK(recv_done->operands().size() == 1); - const HloInstruction* recv = recv_done->operand(0); - TF_RET_CHECK(recv->opcode() == HloOpcode::kRecv); - TF_RETURN_IF_ERROR(CheckSameChannel(recv, recv_done)); - return CheckShape(recv_done, recv->shape().tuple_shapes(0)); - } +Status ShapeVerifier::HandleBatchNormInference( + HloInstruction* batch_norm_inference) { + return CheckShape(batch_norm_inference, + ShapeInference::InferBatchNormInferenceShape( + batch_norm_inference->operand(0)->shape(), + batch_norm_inference->operand(1)->shape(), + batch_norm_inference->operand(2)->shape(), + batch_norm_inference->operand(3)->shape(), + batch_norm_inference->operand(4)->shape(), + batch_norm_inference->feature_index())); +} - Status HandleBatchNormTraining(HloInstruction* batch_norm_training) override { - return CheckShape(batch_norm_training, - ShapeInference::InferBatchNormTrainingShape( - batch_norm_training->operand(0)->shape(), - batch_norm_training->operand(1)->shape(), - batch_norm_training->operand(2)->shape(), - batch_norm_training->feature_index())); - } +Status ShapeVerifier::HandleBatchNormGrad(HloInstruction* batch_norm_grad) { + return CheckShape(batch_norm_grad, ShapeInference::InferBatchNormGradShape( + batch_norm_grad->operand(0)->shape(), + batch_norm_grad->operand(1)->shape(), + batch_norm_grad->operand(2)->shape(), + batch_norm_grad->operand(3)->shape(), + batch_norm_grad->operand(4)->shape(), + batch_norm_grad->feature_index())); +} - Status HandleBatchNormInference( - HloInstruction* batch_norm_inference) override { - return CheckShape(batch_norm_inference, - ShapeInference::InferBatchNormInferenceShape( - batch_norm_inference->operand(0)->shape(), - batch_norm_inference->operand(1)->shape(), - batch_norm_inference->operand(2)->shape(), - batch_norm_inference->operand(3)->shape(), - batch_norm_inference->operand(4)->shape(), - batch_norm_inference->feature_index())); +Status ShapeVerifier::CheckShape(const HloInstruction* instruction, + const Shape& expected_shape) { + if (!ShapeUtil::Compatible(instruction->shape(), expected_shape)) { + return InvalidArgument( + "Expected instruction to have shape compatible with %s, actual " + "shape is %s:\n%s", + ShapeUtil::HumanString(expected_shape).c_str(), + ShapeUtil::HumanString(instruction->shape()).c_str(), + instruction->ToString().c_str()); } + return tensorflow::Status::OK(); +} - Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override { - return CheckShape(batch_norm_grad, ShapeInference::InferBatchNormGradShape( - batch_norm_grad->operand(0)->shape(), - batch_norm_grad->operand(1)->shape(), - batch_norm_grad->operand(2)->shape(), - batch_norm_grad->operand(3)->shape(), - batch_norm_grad->operand(4)->shape(), - batch_norm_grad->feature_index())); +Status ShapeVerifier::CheckShape(const HloInstruction* instruction, + const StatusOr<Shape>& expected_shape_status) { + if (!expected_shape_status.ok()) { + Status s = expected_shape_status.status(); + tensorflow::errors::AppendToMessage(&s, ", for instruction ", + instruction->ToString()); + return s; } + return CheckShape(instruction, expected_shape_status.ValueOrDie()); +} - Status FinishVisit(HloInstruction*) override { - return tensorflow::Status::OK(); - } +Status ShapeVerifier::CheckUnaryShape(const HloInstruction* instruction) { + return CheckShape(instruction, + ShapeInference::InferUnaryOpShape(instruction->opcode(), + instruction->operand(0))); +} - private: - // Check the instruction's shape against the given expected shape and return - // an appropriate error if there is a mismatch. - Status CheckShape(const HloInstruction* instruction, - const Shape& expected_shape) { - if (!ShapeUtil::Compatible(instruction->shape(), expected_shape)) { - return InvalidArgument( - "Expected instruction to have shape compatible with %s, actual " - "shape is %s:\n%s", - ShapeUtil::HumanString(expected_shape).c_str(), - ShapeUtil::HumanString(instruction->shape()).c_str(), - instruction->ToString().c_str()); - } - return tensorflow::Status::OK(); - } +Status ShapeVerifier::CheckBinaryShape(const HloInstruction* instruction) { + return CheckShape( + instruction, ShapeInference::InferBinaryOpShape(instruction->opcode(), + instruction->operand(0), + instruction->operand(1))); +} - // Overload which takes a StatusOr to reduce boilerplate in the caller. - Status CheckShape(const HloInstruction* instruction, - const StatusOr<Shape>& expected_shape_status) { - if (!expected_shape_status.ok()) { - Status s = expected_shape_status.status(); - tensorflow::errors::AppendToMessage(&s, ", for instruction ", - instruction->ToString()); - return s; - } - return CheckShape(instruction, expected_shape_status.ValueOrDie()); - } +Status ShapeVerifier::CheckTernaryShape(const HloInstruction* instruction) { + return CheckShape(instruction, + ShapeInference::InferTernaryOpShape( + instruction->opcode(), instruction->operand(0), + instruction->operand(1), instruction->operand(2))); +} - // Check a unary (binary, etc) instruction's shape against the inferred shape. - Status CheckUnaryShape(const HloInstruction* instruction) { - return CheckShape(instruction, - ShapeInference::InferUnaryOpShape( - instruction->opcode(), instruction->operand(0))); - } - Status CheckBinaryShape(const HloInstruction* instruction) { - return CheckShape(instruction, - ShapeInference::InferBinaryOpShape( - instruction->opcode(), instruction->operand(0), - instruction->operand(1))); - } - Status CheckTernaryShape(const HloInstruction* instruction) { - return CheckShape(instruction, - ShapeInference::InferTernaryOpShape( - instruction->opcode(), instruction->operand(0), - instruction->operand(1), instruction->operand(2))); - } - Status CheckVariadicShape(const HloInstruction* instruction) { - return CheckShape(instruction, - ShapeInference::InferVariadicOpShape( - instruction->opcode(), instruction->operands())); - } +Status ShapeVerifier::CheckVariadicShape(const HloInstruction* instruction) { + return CheckShape(instruction, + ShapeInference::InferVariadicOpShape( + instruction->opcode(), instruction->operands())); +} - // Checks if the given two instructions shares the same channel id. - Status CheckSameChannel(const HloInstruction* instr1, - const HloInstruction* instr2) { - if (instr1->channel_id() != instr2->channel_id()) { - return FailedPrecondition( - "Expected to have the same channel id, actual channel ids are: %s " - "(%lld), %s (%lld)", - instr1->ToString().c_str(), instr1->channel_id(), - instr2->ToString().c_str(), instr2->channel_id()); - } - return tensorflow::Status::OK(); +// Checks if the given two instructions shares the same channel id. +Status ShapeVerifier::CheckSameChannel(const HloInstruction* instr1, + const HloInstruction* instr2) { + if (instr1->channel_id() != instr2->channel_id()) { + return FailedPrecondition( + "Expected to have the same channel id, actual channel ids are: %s " + "(%lld), %s (%lld)", + instr1->ToString().c_str(), instr1->channel_id(), + instr2->ToString().c_str(), instr2->channel_id()); } - - // Returns the size of a Shape in bytes. - const std::function<int64(const Shape&)> shape_size_fn_; -}; + return tensorflow::Status::OK(); +} string ComputationsToString( tensorflow::gtl::ArraySlice<HloComputation*> computations) { @@ -499,8 +474,6 @@ Status VerifyHloStructure(HloModule* module) { return tensorflow::Status::OK(); } -} // namespace - Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { // The parent fusion instruction of the fusion computation must be 'fusion'. HloComputation* fused_computation = fusion->fused_instructions_computation(); @@ -622,7 +595,6 @@ StatusOr<bool> HloVerifier::Run(HloModule* module) { TF_RETURN_IF_ERROR(VerifyHloStructure(module)); tensorflow::gtl::FlatMap<string, const HloInstruction*> instructions; - ShapeVerifier shape_verifier(shape_size_fn_); for (auto* computation : module->computations()) { for (const auto& instruction : computation->instructions()) { @@ -702,7 +674,7 @@ StatusOr<bool> HloVerifier::Run(HloModule* module) { instructions[instruction->name()] = instruction; } - TF_RETURN_IF_ERROR(computation->Accept(&shape_verifier)); + TF_RETURN_IF_ERROR(computation->Accept(shape_verifier_.get())); } return false; diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index e35a7f3642..6368611f32 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -18,14 +18,98 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/service/shape_inference.h" + namespace xla { +// Visitor which verifies that the output shape is correctly set. Verifies +// against the inferred shape for the instruction. +// TODO(b/26024837): Check output shape for all instruction types. +class ShapeVerifier : public DfsHloVisitor { + public: + Status HandleElementwiseUnary(HloInstruction* hlo) override; + Status HandleElementwiseBinary(HloInstruction* hlo) override; + Status HandleClamp(HloInstruction* clamp) override; + Status HandleSelect(HloInstruction* select) override; + Status HandleConcatenate(HloInstruction* concatenate) override; + Status HandleConvert(HloInstruction* convert) override; + Status HandleBitcastConvert(HloInstruction* convert) override; + Status HandleCopy(HloInstruction* copy) override; + Status HandleDot(HloInstruction* dot) override; + Status HandleConvolution(HloInstruction* convolution) override; + Status HandleFft(HloInstruction* fft) override; + Status HandleCrossReplicaSum(HloInstruction* crs) override; + Status HandleReducePrecision(HloInstruction* reduce_precision) override; + Status HandleInfeed(HloInstruction*) override; + Status HandleOutfeed(HloInstruction*) override; + Status HandleRng(HloInstruction*) override; + Status HandleReverse(HloInstruction* reverse) override; + Status HandleSort(HloInstruction* sort) override; + Status HandleConstant(HloInstruction* constant) override; + Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; + Status HandleReduce(HloInstruction* reduce) override; + Status HandleBitcast(HloInstruction* bitcast) override; + Status HandleBroadcast(HloInstruction* broadcast) override; + Status HandleReshape(HloInstruction* reshape) override; + Status HandleTranspose(HloInstruction* transpose) override; + Status HandleParameter(HloInstruction*) override; + Status HandleFusion(HloInstruction*) override; + Status HandleCall(HloInstruction* call) override; + Status HandleCustomCall(HloInstruction*) override; + Status HandleSlice(HloInstruction* slice) override; + Status HandleDynamicSlice(HloInstruction* dynamic_slice) override; + Status HandleDynamicUpdateSlice( + HloInstruction* dynamic_update_slice) override; + Status HandleTuple(HloInstruction* tuple) override; + Status HandleMap(HloInstruction* map) override; + Status HandleReduceWindow(HloInstruction* reduce_window) override; + Status HandleSelectAndScatter(HloInstruction* instruction) override; + Status HandleWhile(HloInstruction* xla_while) override; + Status HandleConditional(HloInstruction* conditional) override; + Status HandlePad(HloInstruction* pad) override; + Status HandleSend(HloInstruction* send) override; + Status HandleSendDone(HloInstruction* send_done) override; + Status HandleRecv(HloInstruction* recv) override; + Status HandleRecvDone(HloInstruction* recv_done) override; + Status HandleBatchNormTraining(HloInstruction* batch_norm_training) override; + Status HandleBatchNormInference( + HloInstruction* batch_norm_inference) override; + Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override; + + Status FinishVisit(HloInstruction*) override { + return tensorflow::Status::OK(); + } + + protected: + // Check the instruction's shape against the given expected shape and return + // an appropriate error if there is a mismatch. + Status CheckShape(const HloInstruction* instruction, + const Shape& expected_shape); + + // Overload which takes a StatusOr to reduce boilerplate in the caller. + Status CheckShape(const HloInstruction* instruction, + const StatusOr<Shape>& expected_shape_status); + + // Check a unary (binary, etc) instruction's shape against the inferred shape. + Status CheckUnaryShape(const HloInstruction* instruction); + Status CheckBinaryShape(const HloInstruction* instruction); + Status CheckTernaryShape(const HloInstruction* instruction); + Status CheckVariadicShape(const HloInstruction* instruction); + + // Checks if the given two instructions shares the same channel id. + Status CheckSameChannel(const HloInstruction* instr1, + const HloInstruction* instr2); +}; + // HLO pass that verifies invariants of HLO instructions for each computation in // the module. class HloVerifier : public HloPassInterface { public: - explicit HloVerifier(const std::function<int64(const Shape&)>& shape_size_fn) - : shape_size_fn_(shape_size_fn) {} + // Uses standard shape inference. + explicit HloVerifier() : shape_verifier_(MakeUnique<ShapeVerifier>()) {} + // Uses custom shape verification. + explicit HloVerifier(std::unique_ptr<ShapeVerifier> shape_verifier) + : shape_verifier_(std::move(shape_verifier)) {} ~HloVerifier() override = default; tensorflow::StringPiece name() const override { return "verifier"; } @@ -37,8 +121,8 @@ class HloVerifier : public HloPassInterface { // CHECKs various invariants of a fusion instruction. Status CheckFusionInstruction(HloInstruction* fusion) const; - // Returns the size of a Shape in bytes. - const std::function<int64(const Shape&)> shape_size_fn_; + // Verifies shapes match inferred expectations. + std::unique_ptr<ShapeVerifier> shape_verifier_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 6dc49ffe4c..a6d6c8b27f 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -1723,6 +1723,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument("FFT only supports ranks 1-3, but got %lld", fft_rank); } +#define RET_CHECK_RANK(x) \ + if (x.dimensions_size() < fft_rank) { \ + return InvalidArgument( \ + "FFT of rank %lld requires input of at least " \ + "same rank; got input of rank %d", \ + fft_rank, x.dimensions_size()); \ + } switch (fft_type) { case FFT: case IFFT: @@ -1731,12 +1738,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( FftType_Name(fft_type).c_str(), PrimitiveType_Name(in.element_type()).c_str()); } + RET_CHECK_RANK(in); return in; case RFFT: { if (in.element_type() != F32) { return InvalidArgument("RFFT requires F32 input type, found %s", PrimitiveType_Name(in.element_type()).c_str()); } + RET_CHECK_RANK(in); for (int i = 0; i < fft_rank; i++) { if (in.dimensions(in.dimensions_size() - fft_rank + i) != fft_length[i]) { @@ -1758,7 +1767,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument("IRFFT requires C64 input type, found %s", PrimitiveType_Name(in.element_type()).c_str()); } - Shape result = ShapeUtil::ChangeElementType(in, F32); + RET_CHECK_RANK(in); + Shape result = ShapeUtil::ComplexComponentShape(in); for (int i = 0; i < fft_rank - 1; i++) { if (in.dimensions(in.dimensions_size() - fft_rank + i) != fft_length[i]) { @@ -1785,6 +1795,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( default: LOG(FATAL) << "Unexpected fft_type: " << fft_type; } +#undef RET_CHECK_RANK } /* static */ StatusOr<Shape> ShapeInference::InferCrossReplicaSumShape( diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index a27e0f2c10..7c1a993b47 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -91,9 +91,7 @@ HloTestBase::HloTestBase() HloTestBase::HloTestBase(se::Platform* test_platform, se::Platform* reference_platform) : test_runner_(test_platform), reference_runner_(reference_platform) { - hlo_verifier_ = MakeUnique<HloVerifier>([this](const Shape& shape) { - return backend().transfer_manager()->GetByteSizeRequirement(shape); - }); + hlo_verifier_ = MakeUnique<HloVerifier>(); } /* static */ diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc index 31060b9e80..506091ddd8 100644 --- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc @@ -23,15 +23,8 @@ limitations under the License. namespace xla { -/*static*/ int64 HloVerifiedTestBase::DefaultShapeSize(const Shape& shape) { - constexpr int64 kPointerSize = sizeof(void*); - if (ShapeUtil::IsOpaque(shape)) { - return kPointerSize; - } - return ShapeUtil::ByteSizeOf(shape, kPointerSize); -} - -HloVerifiedTestBase::HloVerifiedTestBase() : shape_size_fn_(DefaultShapeSize) {} +HloVerifiedTestBase::HloVerifiedTestBase() + : shape_verifier_(MakeUnique<ShapeVerifier>()) {} HloVerifiedTestBase::~HloVerifiedTestBase() { // We can't call the ASSERT or EXPECT test macros in destructors, so we @@ -47,7 +40,7 @@ void HloVerifiedTestBase::TearDown() { << "TearDown called more than once; it should be called exactly once."; tear_down_called_ = true; if (module_) { - HloVerifier verifier(shape_size_fn_); + HloVerifier verifier; xla::StatusOr<bool> mutated = verifier.Run(module_.get()); if (!mutated.ok()) { ADD_FAILURE() << "HloVerifier failed: " << mutated.status(); diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h index b3d6b5af3b..492688bf7d 100644 --- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h @@ -28,14 +28,13 @@ namespace xla { // A base class for HLO tests that stores a default HloModule, and automatically // performs verification on that module on tear-down. class HloVerifiedTestBase : public HloTestBase { - public: - // Returns the size in bytes of the given shape, using a default pointer size. - static int64 DefaultShapeSize(const Shape& shape); - protected: HloVerifiedTestBase(); ~HloVerifiedTestBase() override; + // Constructs a default shape verifier. + std::unique_ptr<ShapeVerifier> MakeShapeVerifier(); + // Performs verification on the default HloModule returned by module(). // Automatically called by the testing framework for each test. // @@ -47,14 +46,14 @@ class HloVerifiedTestBase : public HloTestBase { HloModule& module(); // Sets the shape-size function used during hlo verification. If this isn't - // called, DefaultShapeSize is used instead. - void SetShapeSizeFn(std::function<int64(const Shape&)> shape_size_fn) { - shape_size_fn_ = std::move(shape_size_fn); + // called, a default ShapeVerifier is used instead. + void SetShapeVerifier(std::unique_ptr<ShapeVerifier> shape_verifier) { + shape_verifier_ = std::move(shape_verifier); } private: std::unique_ptr<HloModule> module_; // Lazily populated. Access via module(). - std::function<int64(const Shape&)> shape_size_fn_; + std::unique_ptr<ShapeVerifier> shape_verifier_; bool tear_down_called_ = false; }; diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index bb215be8af..d7346d65c8 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -271,13 +271,7 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments( Status VerifyHloModule(const perftools::gputools::Platform& platform, HloModule* const module) { - return HloVerifier( - std::bind( - &TransferManager::GetByteSizeRequirement, - TransferManager::GetForPlatform(&platform).ConsumeValueOrDie(), - std::placeholders::_1)) - .Run(module) - .status(); + return HloVerifier().Run(module).status(); } } // namespace xla |