-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_compiler.cc           |   4
-rw-r--r--  tensorflow/compiler/xla/service/gpu/gpu_compiler.cc           |  23
-rw-r--r--  tensorflow/compiler/xla/service/gpu/while_transformer_test.cc |   4
-rw-r--r--  tensorflow/compiler/xla/service/hlo_verifier.cc               | 690
-rw-r--r--  tensorflow/compiler/xla/service/hlo_verifier.h                |  92
-rw-r--r--  tensorflow/compiler/xla/service/shape_inference.cc            |  13
-rw-r--r--  tensorflow/compiler/xla/tests/hlo_test_base.cc                |   4
-rw-r--r--  tensorflow/compiler/xla/tests/hlo_verified_test_base.cc       |  13
-rw-r--r--  tensorflow/compiler/xla/tests/hlo_verified_test_base.h        |  15
-rw-r--r--  tensorflow/compiler/xla/tests/test_utils.cc                   |   8
10 files changed, 455 insertions(+), 411 deletions(-)
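Taken together, the change below removes HloVerifier's shape-size callback and replaces it with a pluggable ShapeVerifier visitor. A minimal before/after sketch of the construction sites (a sketch only, assuming xla::MakeUnique from this tree's ptr_util.h; not part of the patch itself):

#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"

void ConstructVerifiers() {
  // Before: every call site threaded a shape-size function through, e.g.
  //   HloVerifier verifier([](const xla::Shape& s) {
  //     return xla::ShapeUtil::ByteSizeOf(s, sizeof(void*));
  //   });

  // After: the default constructor uses standard shape inference...
  xla::HloVerifier default_verifier;

  // ...and custom checking is injected as a ShapeVerifier (or a subclass).
  xla::HloVerifier custom_verifier(xla::MakeUnique<xla::ShapeVerifier>());
}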
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index 9636f6b5b3..f0507982b3 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -234,7 +234,7 @@ class CollectProfileCandidates : public DfsHloVisitorWithDefault {
Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) {
// Optimization pipeline.
HloPassPipeline pipeline("CPU");
- pipeline.AddInvariantChecker<HloVerifier>(ShapeSizeBytesFunction());
+ pipeline.AddInvariantChecker<HloVerifier>();
pipeline.AddPass<CpuHloSupportChecker>();
ReducePrecisionInsertion::AddPasses(
@@ -253,7 +253,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) {
{
auto& pass =
pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification");
- pass.AddInvariantChecker<HloVerifier>(ShapeSizeBytesFunction());
+ pass.AddInvariantChecker<HloVerifier>();
pass.AddPass<BatchNormExpander>(
/*rewrite_training_op=*/true,
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index 7c2e693560..af0010e207 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -133,12 +133,10 @@ string GetLibdeviceDir(const string& config_cuda_data_dir) {
}
// Runs optimization passes on the given HLO module.
-tensorflow::Status OptimizeHloModule(
- HloModule* hlo_module,
- const HloCostAnalysis::ShapeSizeFunction& shape_size_function) {
+tensorflow::Status OptimizeHloModule(HloModule* hlo_module) {
{
HloPassPipeline pipeline("optimization");
- pipeline.AddInvariantChecker<HloVerifier>(shape_size_function);
+ pipeline.AddInvariantChecker<HloVerifier>();
pipeline.AddPass<GpuHloSupportChecker>();
ReducePrecisionInsertion::AddPasses(
&pipeline, hlo_module->config().debug_options(),
@@ -150,7 +148,7 @@ tensorflow::Status OptimizeHloModule(
{
auto& pass =
pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification");
- pass.AddInvariantChecker<HloVerifier>(shape_size_function);
+ pass.AddInvariantChecker<HloVerifier>();
// If cudnn batchnorms are enabled, rewrite batchnorm HLOs to cudnn calls
// where possible. Not every batchnorm op can be implemented as a call to
@@ -191,14 +189,14 @@ tensorflow::Status OptimizeHloModule(
}
{
HloPassFix<HloPassPipeline> fusion("fusion");
- fusion.AddInvariantChecker<HloVerifier>(shape_size_function);
+ fusion.AddInvariantChecker<HloVerifier>();
fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/false);
fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/true);
fusion.AddPass<FusionMerger>();
TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status());
HloPassPipeline reduce_pipeline("reduce-precision");
- reduce_pipeline.AddInvariantChecker<HloVerifier>(shape_size_function);
+ reduce_pipeline.AddInvariantChecker<HloVerifier>();
ReducePrecisionInsertion::AddPasses(
&reduce_pipeline, hlo_module->config().debug_options(),
ReducePrecisionInsertion::PassTiming::AFTER_FUSION);
@@ -216,16 +214,14 @@ tensorflow::Status OptimizeHloModule(
// Modifies the given HLO module so that it will be accepted by IrEmitter.
// Unlike optimization passes, the passes are necessary for correctness.
-tensorflow::Status PrepareHloModuleForIrEmitting(
- HloModule* hlo_module,
- const HloCostAnalysis::ShapeSizeFunction& shape_size_function) {
+tensorflow::Status PrepareHloModuleForIrEmitting(HloModule* hlo_module) {
// In some cases, we have to place the result of an instruction in a temporary
// buffer. For instance, the buffer that holds an external parameter is
// assumed immutable at this point, and should not be reused for output
// (b/27180329). Therefore, in that case, we set the output to be a copy of
// the parameter.
HloPassPipeline pipeline("GPU-ir-emit-prepare");
- pipeline.AddInvariantChecker<HloVerifier>(shape_size_function);
+ pipeline.AddInvariantChecker<HloVerifier>();
pipeline.AddPass<PadInsertion>();
pipeline.AddPass<GpuLayoutAssignment>(
hlo_module->mutable_entry_computation_layout());
@@ -409,7 +405,7 @@ StatusOr<std::unique_ptr<HloModule>> GpuCompiler::RunHloPasses(
XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunHloPasses");
Tracing::TraceMe annotation("HLO Transforms", module->name(),
/*is_expensive=*/true);
- TF_RETURN_IF_ERROR(OptimizeHloModule(module.get(), ShapeSizeBytesFunction()));
+ TF_RETURN_IF_ERROR(OptimizeHloModule(module.get()));
return std::move(module);
}
@@ -419,8 +415,7 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
TF_RET_CHECK(stream_exec != nullptr);
- TF_RETURN_IF_ERROR(
- PrepareHloModuleForIrEmitting(module.get(), ShapeSizeBytesFunction()));
+ TF_RETURN_IF_ERROR(PrepareHloModuleForIrEmitting(module.get()));
llvm::LLVMContext llvm_context;
std::string buffer;
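The compiler call sites shrink accordingly. A sketch of the post-change pipeline wiring, using only names that appear in this diff:

#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"

tensorflow::Status RunExamplePipeline(xla::HloModule* module) {
  xla::HloPassPipeline pipeline("example");
  // The verifier still runs as an invariant checker around every pass, but
  // no ShapeSizeBytesFunction() needs to be threaded through.
  pipeline.AddInvariantChecker<xla::HloVerifier>();
  return pipeline.Run(module).status();
}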
diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc
index f16daa0b54..2f290f61bd 100644
--- a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc
@@ -117,9 +117,7 @@ class WhileTransformerTest : public HloTestBase {
}
void RunCopyInsertionPass() {
- HloVerifier verifier([](const Shape& shape) {
- return ShapeUtil::ByteSizeOf(shape, /*pointer_size=*/sizeof(void*));
- });
+ HloVerifier verifier;
TF_ASSERT_OK(verifier.Run(module_.get()).status());
CopyInsertion copy_insertion;
TF_ASSERT_OK(copy_insertion.Run(module_.get()).status());
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 9d5ca6673a..9d9cf0c0f6 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -14,425 +14,400 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
-#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
namespace xla {
-namespace {
+Status ShapeVerifier::HandleElementwiseUnary(HloInstruction* hlo) {
+ return CheckUnaryShape(hlo);
+}
-// Visitor which verifies that the output shape is correctly set. Verifies
-// against the inferred shape for the instruction.
-// TODO(b/26024837): Check output shape for all instruction types.
-class ShapeVerifier : public DfsHloVisitor {
- public:
- explicit ShapeVerifier(
- const std::function<int64(const Shape&)>& shape_size_fn)
- : shape_size_fn_(shape_size_fn) {}
+Status ShapeVerifier::HandleElementwiseBinary(HloInstruction* hlo) {
+ return CheckBinaryShape(hlo);
+}
- Status HandleElementwiseUnary(HloInstruction* hlo) override {
- return CheckUnaryShape(hlo);
- }
+Status ShapeVerifier::HandleClamp(HloInstruction* clamp) {
+ return CheckTernaryShape(clamp);
+}
- Status HandleElementwiseBinary(HloInstruction* hlo) override {
- return CheckBinaryShape(hlo);
- }
+Status ShapeVerifier::HandleSelect(HloInstruction* select) {
+ return CheckTernaryShape(select);
+}
- Status HandleClamp(HloInstruction* clamp) override {
- return CheckTernaryShape(clamp);
+Status ShapeVerifier::HandleConcatenate(HloInstruction* concatenate) {
+ std::vector<const Shape*> operand_shapes;
+ for (const HloInstruction* operand : concatenate->operands()) {
+ operand_shapes.push_back(&operand->shape());
}
+ return CheckShape(concatenate,
+ ShapeInference::InferConcatOpShape(
+ operand_shapes, concatenate->concatenate_dimension()));
+}
- Status HandleSelect(HloInstruction* select) override {
- return CheckTernaryShape(select);
- }
+Status ShapeVerifier::HandleConvert(HloInstruction* convert) {
+ return CheckShape(convert, ShapeInference::InferConvertShape(
+ convert->operand(0)->shape(),
+ convert->shape().element_type()));
+}
- Status HandleConcatenate(HloInstruction* concatenate) override {
- std::vector<const Shape*> operand_shapes;
- for (const HloInstruction* operand : concatenate->operands()) {
- operand_shapes.push_back(&operand->shape());
- }
- return CheckShape(
- concatenate, ShapeInference::InferConcatOpShape(
- operand_shapes, concatenate->concatenate_dimension()));
- }
+Status ShapeVerifier::HandleBitcastConvert(HloInstruction* convert) {
+ return CheckShape(convert, ShapeInference::InferBitcastConvertShape(
+ convert->operand(0)->shape(),
+ convert->shape().element_type()));
+}
- Status HandleConvert(HloInstruction* convert) override {
- return CheckShape(convert, ShapeInference::InferConvertShape(
- convert->operand(0)->shape(),
- convert->shape().element_type()));
- }
+Status ShapeVerifier::HandleCopy(HloInstruction* copy) {
+ return CheckUnaryShape(copy);
+}
- Status HandleBitcastConvert(HloInstruction* convert) override {
- return CheckShape(convert, ShapeInference::InferBitcastConvertShape(
- convert->operand(0)->shape(),
- convert->shape().element_type()));
- }
+Status ShapeVerifier::HandleDot(HloInstruction* dot) {
+ TF_ASSIGN_OR_RETURN(const Shape expected,
+ ShapeInference::InferDotOpShape(
+ dot->operand(0)->shape(), dot->operand(1)->shape(),
+ dot->dot_dimension_numbers()));
+ return CheckShape(dot, expected);
+}
- Status HandleCopy(HloInstruction* copy) override {
- return CheckUnaryShape(copy);
- }
+Status ShapeVerifier::HandleConvolution(HloInstruction* convolution) {
+ TF_ASSIGN_OR_RETURN(
+ const Shape expected,
+ ShapeInference::InferConvolveShape(
+ convolution->operand(0)->shape(), convolution->operand(1)->shape(),
+ convolution->window(), convolution->convolution_dimension_numbers()));
+ return CheckShape(convolution, expected);
+}
- Status HandleDot(HloInstruction* dot) override {
- TF_ASSIGN_OR_RETURN(const Shape expected,
- ShapeInference::InferDotOpShape(
- dot->operand(0)->shape(), dot->operand(1)->shape(),
- dot->dot_dimension_numbers()));
- return CheckShape(dot, expected);
- }
+Status ShapeVerifier::HandleFft(HloInstruction* fft) {
+ TF_ASSIGN_OR_RETURN(
+ const Shape expected,
+ ShapeInference::InferFftShape(fft->operand(0)->shape(), fft->fft_type(),
+ fft->fft_length()));
+ return CheckShape(fft, expected);
+}
- Status HandleConvolution(HloInstruction* convolution) override {
- TF_ASSIGN_OR_RETURN(
- const Shape expected,
- ShapeInference::InferConvolveShape(
- convolution->operand(0)->shape(), convolution->operand(1)->shape(),
- convolution->window(),
- convolution->convolution_dimension_numbers()));
- return CheckShape(convolution, expected);
+Status ShapeVerifier::HandleCrossReplicaSum(HloInstruction* crs) {
+ std::vector<const Shape*> operand_shapes;
+ for (const HloInstruction* operand : crs->operands()) {
+ operand_shapes.push_back(&operand->shape());
}
+ return CheckShape(crs,
+ ShapeInference::InferCrossReplicaSumShape(operand_shapes));
+}
- Status HandleFft(HloInstruction* fft) override {
- TF_ASSIGN_OR_RETURN(
- const Shape expected,
- ShapeInference::InferFftShape(fft->operand(0)->shape(), fft->fft_type(),
- fft->fft_length()));
- return CheckShape(fft, expected);
- }
+Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) {
+ return CheckShape(reduce_precision, ShapeInference::InferReducePrecisionShape(
+ reduce_precision->operand(0)->shape(),
+ reduce_precision->exponent_bits(),
+ reduce_precision->mantissa_bits()));
+}
- Status HandleCrossReplicaSum(HloInstruction* crs) override {
- std::vector<const Shape*> operand_shapes;
- for (const HloInstruction* operand : crs->operands()) {
- operand_shapes.push_back(&operand->shape());
- }
- return CheckShape(
- crs, ShapeInference::InferCrossReplicaSumShape(operand_shapes));
- }
+Status ShapeVerifier::HandleInfeed(HloInstruction*) {
+ return tensorflow::Status::OK();
+}
- Status HandleReducePrecision(HloInstruction* reduce_precision) override {
- return CheckShape(reduce_precision,
- ShapeInference::InferReducePrecisionShape(
- reduce_precision->operand(0)->shape(),
- reduce_precision->exponent_bits(),
- reduce_precision->mantissa_bits()));
- }
+Status ShapeVerifier::HandleOutfeed(HloInstruction*) {
+ return tensorflow::Status::OK();
+}
- Status HandleInfeed(HloInstruction*) override {
- return tensorflow::Status::OK();
- }
+Status ShapeVerifier::HandleRng(HloInstruction*) {
+ return tensorflow::Status::OK();
+}
- Status HandleOutfeed(HloInstruction*) override {
- return tensorflow::Status::OK();
- }
+Status ShapeVerifier::HandleReverse(HloInstruction* reverse) {
+ return CheckShape(
+ reverse, ShapeInference::InferReverseShape(reverse->operand(0)->shape(),
+ reverse->dimensions()));
+}
- Status HandleRng(HloInstruction*) override {
- return tensorflow::Status::OK();
- }
+Status ShapeVerifier::HandleSort(HloInstruction* sort) {
+ return CheckUnaryShape(sort);
+}
- Status HandleReverse(HloInstruction* reverse) override {
- return CheckShape(
- reverse, ShapeInference::InferReverseShape(reverse->operand(0)->shape(),
- reverse->dimensions()));
- }
+Status ShapeVerifier::HandleConstant(HloInstruction* constant) {
+ return CheckShape(constant, constant->literal().shape());
+}
- Status HandleSort(HloInstruction* sort) override {
- return CheckUnaryShape(sort);
- }
+Status ShapeVerifier::HandleGetTupleElement(HloInstruction* get_tuple_element) {
+ return CheckShape(get_tuple_element,
+ ShapeInference::InferGetTupleElementShape(
+ get_tuple_element->operand(0)->shape(),
+ get_tuple_element->tuple_index()));
+}
- Status HandleConstant(HloInstruction* constant) override {
- return CheckShape(constant, constant->literal().shape());
- }
+Status ShapeVerifier::HandleReduce(HloInstruction* reduce) {
+ return CheckShape(
+ reduce,
+ ShapeInference::InferReduceShape(
+ reduce->operand(0)->shape(), reduce->operand(1)->shape(),
+ reduce->dimensions(), reduce->to_apply()->ComputeProgramShape()));
+}
- Status HandleGetTupleElement(HloInstruction* get_tuple_element) override {
- return CheckShape(get_tuple_element,
- ShapeInference::InferGetTupleElementShape(
- get_tuple_element->operand(0)->shape(),
- get_tuple_element->tuple_index()));
- }
+Status ShapeVerifier::HandleBitcast(HloInstruction* bitcast) {
+ return tensorflow::Status::OK();
+}
- Status HandleReduce(HloInstruction* reduce) override {
- return CheckShape(
- reduce,
- ShapeInference::InferReduceShape(
- reduce->operand(0)->shape(), reduce->operand(1)->shape(),
- reduce->dimensions(), reduce->to_apply()->ComputeProgramShape()));
+Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) {
+ // HLO broadcast has no exact analog at the proto level so there is no
+ // ShapeInference method. Check the output shape explicitly.
+ const Shape& operand_shape = broadcast->operand(0)->shape();
+ TF_RET_CHECK(ShapeUtil::Rank(operand_shape) ==
+ broadcast->dimensions().size());
+ for (int64 operand_dimension = 0;
+ operand_dimension < ShapeUtil::Rank(operand_shape);
+ ++operand_dimension) {
+ int64 output_dimension = broadcast->dimensions()[operand_dimension];
+ TF_RET_CHECK(broadcast->shape().dimensions(output_dimension) ==
+ operand_shape.dimensions(operand_dimension));
}
+ return tensorflow::Status::OK();
+}
- Status HandleBitcast(HloInstruction* bitcast) override {
- return tensorflow::Status::OK();
- }
+Status ShapeVerifier::HandleReshape(HloInstruction* reshape) {
+ TF_RET_CHECK(ShapeUtil::ElementsIn(reshape->shape()) ==
+ ShapeUtil::ElementsIn(reshape->operand(0)->shape()));
+ return tensorflow::Status::OK();
+}
- Status HandleBroadcast(HloInstruction* broadcast) override {
- // HLO broadcast has no exact analog at the proto level so there is no
- // ShapeInference method. Check the output shape explicitly.
- const Shape& operand_shape = broadcast->operand(0)->shape();
- TF_RET_CHECK(ShapeUtil::Rank(operand_shape) ==
- broadcast->dimensions().size());
- for (int64 operand_dimension = 0;
- operand_dimension < ShapeUtil::Rank(operand_shape);
- ++operand_dimension) {
- int64 output_dimension = broadcast->dimensions()[operand_dimension];
- TF_RET_CHECK(broadcast->shape().dimensions(output_dimension) ==
- operand_shape.dimensions(operand_dimension));
- }
- return tensorflow::Status::OK();
- }
+Status ShapeVerifier::HandleTranspose(HloInstruction* transpose) {
+ return CheckShape(
+ transpose, ShapeInference::InferTransposeShape(
+ transpose->operand(0)->shape(), transpose->dimensions()));
+}
- Status HandleReshape(HloInstruction* reshape) override {
- TF_RET_CHECK(ShapeUtil::ElementsIn(reshape->shape()) ==
- ShapeUtil::ElementsIn(reshape->operand(0)->shape()));
- return tensorflow::Status::OK();
- }
+Status ShapeVerifier::HandleParameter(HloInstruction*) {
+ return tensorflow::Status::OK();
+}
- Status HandleTranspose(HloInstruction* transpose) override {
- return CheckShape(transpose, ShapeInference::InferTransposeShape(
- transpose->operand(0)->shape(),
- transpose->dimensions()));
- }
+Status ShapeVerifier::HandleFusion(HloInstruction*) {
+ return tensorflow::Status::OK();
+}
- Status HandleParameter(HloInstruction*) override {
- return tensorflow::Status::OK();
- }
+Status ShapeVerifier::HandleCall(HloInstruction* call) {
+ // The shape of kCall should match the shape of the computation it calls.
+ return CheckShape(call, call->to_apply()->ComputeProgramShape().result());
+}
- Status HandleFusion(HloInstruction*) override {
- return tensorflow::Status::OK();
- }
+Status ShapeVerifier::HandleCustomCall(HloInstruction*) {
+ return tensorflow::Status::OK();
+}
- Status HandleCall(HloInstruction* call) override {
- // The shape of kCall should match the shape of the computation it calls.
- return CheckShape(call, call->to_apply()->ComputeProgramShape().result());
- }
+Status ShapeVerifier::HandleSlice(HloInstruction* slice) {
+ return CheckShape(slice,
+ ShapeInference::InferSliceShape(
+ slice->operand(0)->shape(), slice->slice_starts(),
+ slice->slice_limits(), slice->slice_strides()));
+}
- Status HandleCustomCall(HloInstruction*) override {
- return tensorflow::Status::OK();
- }
+Status ShapeVerifier::HandleDynamicSlice(HloInstruction* dynamic_slice) {
+ return CheckShape(dynamic_slice, ShapeInference::InferDynamicSliceShape(
+ dynamic_slice->operand(0)->shape(),
+ dynamic_slice->operand(1)->shape(),
+ dynamic_slice->dynamic_slice_sizes()));
+}
- Status HandleSlice(HloInstruction* slice) override {
- return CheckShape(slice,
- ShapeInference::InferSliceShape(
- slice->operand(0)->shape(), slice->slice_starts(),
- slice->slice_limits(), slice->slice_strides()));
- }
+Status ShapeVerifier::HandleDynamicUpdateSlice(
+ HloInstruction* dynamic_update_slice) {
+ return CheckShape(dynamic_update_slice,
+ ShapeInference::InferDynamicUpdateSliceShape(
+ dynamic_update_slice->operand(0)->shape(),
+ dynamic_update_slice->operand(1)->shape(),
+ dynamic_update_slice->operand(2)->shape()));
+}
- Status HandleDynamicSlice(HloInstruction* dynamic_slice) override {
- return CheckShape(dynamic_slice, ShapeInference::InferDynamicSliceShape(
- dynamic_slice->operand(0)->shape(),
- dynamic_slice->operand(1)->shape(),
- dynamic_slice->dynamic_slice_sizes()));
- }
+Status ShapeVerifier::HandleTuple(HloInstruction* tuple) {
+ return CheckVariadicShape(tuple);
+}
- Status HandleDynamicUpdateSlice(
- HloInstruction* dynamic_update_slice) override {
- return CheckShape(dynamic_update_slice,
- ShapeInference::InferDynamicUpdateSliceShape(
- dynamic_update_slice->operand(0)->shape(),
- dynamic_update_slice->operand(1)->shape(),
- dynamic_update_slice->operand(2)->shape()));
- }
+Status ShapeVerifier::HandleMap(HloInstruction* map) {
+ std::vector<const Shape*> operand_shapes;
+ int64 max_operand_rank = 0;
+ for (const HloInstruction* operand : map->operands()) {
+ operand_shapes.push_back(&operand->shape());
+ max_operand_rank =
+ std::max(max_operand_rank, ShapeUtil::Rank(operand->shape()));
+ }
+ // TODO(b/65689298) Remove code below once Map is generalized to accept
+ // arbitrary map dimensions.
+ std::vector<int64> map_dims(max_operand_rank);
+ std::iota(map_dims.begin(), map_dims.end(), 0);
+ return CheckShape(map, ShapeInference::InferMapShape(
+ operand_shapes,
+ map->to_apply()->ComputeProgramShape(), map_dims));
+}
- Status HandleTuple(HloInstruction* tuple) override {
- return CheckVariadicShape(tuple);
- }
+Status ShapeVerifier::HandleReduceWindow(HloInstruction* reduce_window) {
+ return CheckShape(
+ reduce_window,
+ ShapeInference::InferReduceWindowShape(
+ reduce_window->operand(0)->shape(),
+ reduce_window->operand(1)->shape(), reduce_window->window(),
+ reduce_window->to_apply()->ComputeProgramShape()));
+}
- Status HandleMap(HloInstruction* map) override {
- std::vector<const Shape*> operand_shapes;
- int64 max_operand_rank = 0;
- for (const HloInstruction* operand : map->operands()) {
- operand_shapes.push_back(&operand->shape());
- max_operand_rank =
- std::max(max_operand_rank, ShapeUtil::Rank(operand->shape()));
- }
- // TODO(b/65689298) Remove code below once Map is generalized to accept
- // arbitrary map dimensions.
- std::vector<int64> map_dims(max_operand_rank);
- std::iota(map_dims.begin(), map_dims.end(), 0);
- return CheckShape(
- map,
- ShapeInference::InferMapShape(
- operand_shapes, map->to_apply()->ComputeProgramShape(), map_dims));
- }
+Status ShapeVerifier::HandleSelectAndScatter(HloInstruction* instruction) {
+ return CheckShape(
+ instruction,
+ ShapeInference::InferSelectAndScatterShape(
+ instruction->operand(0)->shape(),
+ instruction->select()->ComputeProgramShape(), instruction->window(),
+ instruction->operand(1)->shape(), instruction->operand(2)->shape(),
+ instruction->scatter()->ComputeProgramShape()));
+}
- Status HandleReduceWindow(HloInstruction* reduce_window) override {
- return CheckShape(
- reduce_window,
- ShapeInference::InferReduceWindowShape(
- reduce_window->operand(0)->shape(),
- reduce_window->operand(1)->shape(), reduce_window->window(),
- reduce_window->to_apply()->ComputeProgramShape()));
- }
+Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) {
+ // The shape of kWhile should match the shape of the body computation it
+ // calls.
+ return CheckShape(xla_while,
+ xla_while->while_body()->ComputeProgramShape().result());
+}
- Status HandleSelectAndScatter(HloInstruction* instruction) override {
- return CheckShape(
- instruction,
- ShapeInference::InferSelectAndScatterShape(
- instruction->operand(0)->shape(),
- instruction->select()->ComputeProgramShape(), instruction->window(),
- instruction->operand(1)->shape(), instruction->operand(2)->shape(),
- instruction->scatter()->ComputeProgramShape()));
- }
+Status ShapeVerifier::HandleConditional(HloInstruction* conditional) {
+ TF_RETURN_IF_ERROR(CheckShape(
+ conditional,
+ conditional->true_computation()->ComputeProgramShape().result()));
+ return CheckShape(
+ conditional,
+ conditional->false_computation()->ComputeProgramShape().result());
+}
- Status HandleWhile(HloInstruction* xla_while) override {
- // The shape of kWhile should match the shape of the body computation it
- // calls.
- return CheckShape(xla_while,
- xla_while->while_body()->ComputeProgramShape().result());
- }
+Status ShapeVerifier::HandlePad(HloInstruction* pad) {
+ return CheckShape(pad, ShapeInference::InferPadShape(pad->operand(0)->shape(),
+ pad->operand(1)->shape(),
+ pad->padding_config()));
+}
- Status HandleConditional(HloInstruction* conditional) override {
- TF_RETURN_IF_ERROR(CheckShape(
- conditional,
- conditional->true_computation()->ComputeProgramShape().result()));
- return CheckShape(
- conditional,
- conditional->false_computation()->ComputeProgramShape().result());
- }
+Status ShapeVerifier::HandleSend(HloInstruction* send) {
+ TF_RET_CHECK(send->users().size() == 1);
+ const HloInstruction* send_done = send->users().front();
+ TF_RET_CHECK(send_done->opcode() == HloOpcode::kSendDone);
+ TF_RETURN_IF_ERROR(CheckSameChannel(send, send_done));
+ return CheckShape(
+ send, ShapeUtil::MakeTupleShape(
+ {send->operand(0)->shape(), ShapeUtil::MakeShape(U32, {})}));
+}
- Status HandlePad(HloInstruction* pad) override {
- return CheckShape(pad,
- ShapeInference::InferPadShape(pad->operand(0)->shape(),
- pad->operand(1)->shape(),
- pad->padding_config()));
- }
+Status ShapeVerifier::HandleSendDone(HloInstruction* send_done) {
+ TF_RET_CHECK(send_done->operands().size() == 1);
+ const HloInstruction* send = send_done->operand(0);
+ TF_RET_CHECK(send->opcode() == HloOpcode::kSend);
+ TF_RETURN_IF_ERROR(CheckSameChannel(send, send_done));
+ return CheckShape(send_done, ShapeUtil::MakeNil());
+}
- Status HandleSend(HloInstruction* send) override {
- TF_RET_CHECK(send->users().size() == 1);
- const HloInstruction* send_done = send->users().front();
- TF_RET_CHECK(send_done->opcode() == HloOpcode::kSendDone);
- TF_RETURN_IF_ERROR(CheckSameChannel(send, send_done));
- return CheckShape(
- send, ShapeUtil::MakeTupleShape(
- {send->operand(0)->shape(), ShapeUtil::MakeShape(U32, {})}));
- }
+Status ShapeVerifier::HandleRecv(HloInstruction* recv) {
+ TF_RET_CHECK(recv->users().size() == 1);
+ const HloInstruction* recv_done = recv->users().front();
+ TF_RET_CHECK(recv_done->opcode() == HloOpcode::kRecvDone);
+ TF_RETURN_IF_ERROR(CheckSameChannel(recv, recv_done));
+ return CheckShape(recv,
+ ShapeUtil::MakeTupleShape(
+ {recv_done->shape(), ShapeUtil::MakeShape(U32, {})}));
+}
- Status HandleSendDone(HloInstruction* send_done) override {
- TF_RET_CHECK(send_done->operands().size() == 1);
- const HloInstruction* send = send_done->operand(0);
- TF_RET_CHECK(send->opcode() == HloOpcode::kSend);
- TF_RETURN_IF_ERROR(CheckSameChannel(send, send_done));
- return CheckShape(send_done, ShapeUtil::MakeNil());
- }
+Status ShapeVerifier::HandleRecvDone(HloInstruction* recv_done) {
+ TF_RET_CHECK(recv_done->operands().size() == 1);
+ const HloInstruction* recv = recv_done->operand(0);
+ TF_RET_CHECK(recv->opcode() == HloOpcode::kRecv);
+ TF_RETURN_IF_ERROR(CheckSameChannel(recv, recv_done));
+ return CheckShape(recv_done, recv->shape().tuple_shapes(0));
+}
- Status HandleRecv(HloInstruction* recv) override {
- TF_RET_CHECK(recv->users().size() == 1);
- const HloInstruction* recv_done = recv->users().front();
- TF_RET_CHECK(recv_done->opcode() == HloOpcode::kRecvDone);
- TF_RETURN_IF_ERROR(CheckSameChannel(recv, recv_done));
- return CheckShape(recv,
- ShapeUtil::MakeTupleShape(
- {recv_done->shape(), ShapeUtil::MakeShape(U32, {})}));
- }
+Status ShapeVerifier::HandleBatchNormTraining(
+ HloInstruction* batch_norm_training) {
+ return CheckShape(batch_norm_training,
+ ShapeInference::InferBatchNormTrainingShape(
+ batch_norm_training->operand(0)->shape(),
+ batch_norm_training->operand(1)->shape(),
+ batch_norm_training->operand(2)->shape(),
+ batch_norm_training->feature_index()));
+}
- Status HandleRecvDone(HloInstruction* recv_done) override {
- TF_RET_CHECK(recv_done->operands().size() == 1);
- const HloInstruction* recv = recv_done->operand(0);
- TF_RET_CHECK(recv->opcode() == HloOpcode::kRecv);
- TF_RETURN_IF_ERROR(CheckSameChannel(recv, recv_done));
- return CheckShape(recv_done, recv->shape().tuple_shapes(0));
- }
+Status ShapeVerifier::HandleBatchNormInference(
+ HloInstruction* batch_norm_inference) {
+ return CheckShape(batch_norm_inference,
+ ShapeInference::InferBatchNormInferenceShape(
+ batch_norm_inference->operand(0)->shape(),
+ batch_norm_inference->operand(1)->shape(),
+ batch_norm_inference->operand(2)->shape(),
+ batch_norm_inference->operand(3)->shape(),
+ batch_norm_inference->operand(4)->shape(),
+ batch_norm_inference->feature_index()));
+}
- Status HandleBatchNormTraining(HloInstruction* batch_norm_training) override {
- return CheckShape(batch_norm_training,
- ShapeInference::InferBatchNormTrainingShape(
- batch_norm_training->operand(0)->shape(),
- batch_norm_training->operand(1)->shape(),
- batch_norm_training->operand(2)->shape(),
- batch_norm_training->feature_index()));
- }
+Status ShapeVerifier::HandleBatchNormGrad(HloInstruction* batch_norm_grad) {
+ return CheckShape(batch_norm_grad, ShapeInference::InferBatchNormGradShape(
+ batch_norm_grad->operand(0)->shape(),
+ batch_norm_grad->operand(1)->shape(),
+ batch_norm_grad->operand(2)->shape(),
+ batch_norm_grad->operand(3)->shape(),
+ batch_norm_grad->operand(4)->shape(),
+ batch_norm_grad->feature_index()));
+}
- Status HandleBatchNormInference(
- HloInstruction* batch_norm_inference) override {
- return CheckShape(batch_norm_inference,
- ShapeInference::InferBatchNormInferenceShape(
- batch_norm_inference->operand(0)->shape(),
- batch_norm_inference->operand(1)->shape(),
- batch_norm_inference->operand(2)->shape(),
- batch_norm_inference->operand(3)->shape(),
- batch_norm_inference->operand(4)->shape(),
- batch_norm_inference->feature_index()));
+Status ShapeVerifier::CheckShape(const HloInstruction* instruction,
+ const Shape& expected_shape) {
+ if (!ShapeUtil::Compatible(instruction->shape(), expected_shape)) {
+ return InvalidArgument(
+ "Expected instruction to have shape compatible with %s, actual "
+ "shape is %s:\n%s",
+ ShapeUtil::HumanString(expected_shape).c_str(),
+ ShapeUtil::HumanString(instruction->shape()).c_str(),
+ instruction->ToString().c_str());
}
+ return tensorflow::Status::OK();
+}
- Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override {
- return CheckShape(batch_norm_grad, ShapeInference::InferBatchNormGradShape(
- batch_norm_grad->operand(0)->shape(),
- batch_norm_grad->operand(1)->shape(),
- batch_norm_grad->operand(2)->shape(),
- batch_norm_grad->operand(3)->shape(),
- batch_norm_grad->operand(4)->shape(),
- batch_norm_grad->feature_index()));
+Status ShapeVerifier::CheckShape(const HloInstruction* instruction,
+ const StatusOr<Shape>& expected_shape_status) {
+ if (!expected_shape_status.ok()) {
+ Status s = expected_shape_status.status();
+ tensorflow::errors::AppendToMessage(&s, ", for instruction ",
+ instruction->ToString());
+ return s;
}
+ return CheckShape(instruction, expected_shape_status.ValueOrDie());
+}
- Status FinishVisit(HloInstruction*) override {
- return tensorflow::Status::OK();
- }
+Status ShapeVerifier::CheckUnaryShape(const HloInstruction* instruction) {
+ return CheckShape(instruction,
+ ShapeInference::InferUnaryOpShape(instruction->opcode(),
+ instruction->operand(0)));
+}
- private:
- // Check the instruction's shape against the given expected shape and return
- // an appropriate error if there is a mismatch.
- Status CheckShape(const HloInstruction* instruction,
- const Shape& expected_shape) {
- if (!ShapeUtil::Compatible(instruction->shape(), expected_shape)) {
- return InvalidArgument(
- "Expected instruction to have shape compatible with %s, actual "
- "shape is %s:\n%s",
- ShapeUtil::HumanString(expected_shape).c_str(),
- ShapeUtil::HumanString(instruction->shape()).c_str(),
- instruction->ToString().c_str());
- }
- return tensorflow::Status::OK();
- }
+Status ShapeVerifier::CheckBinaryShape(const HloInstruction* instruction) {
+ return CheckShape(
+ instruction, ShapeInference::InferBinaryOpShape(instruction->opcode(),
+ instruction->operand(0),
+ instruction->operand(1)));
+}
- // Overload which takes a StatusOr to reduce boilerplate in the caller.
- Status CheckShape(const HloInstruction* instruction,
- const StatusOr<Shape>& expected_shape_status) {
- if (!expected_shape_status.ok()) {
- Status s = expected_shape_status.status();
- tensorflow::errors::AppendToMessage(&s, ", for instruction ",
- instruction->ToString());
- return s;
- }
- return CheckShape(instruction, expected_shape_status.ValueOrDie());
- }
+Status ShapeVerifier::CheckTernaryShape(const HloInstruction* instruction) {
+ return CheckShape(instruction,
+ ShapeInference::InferTernaryOpShape(
+ instruction->opcode(), instruction->operand(0),
+ instruction->operand(1), instruction->operand(2)));
+}
- // Check a unary (binary, etc) instruction's shape against the inferred shape.
- Status CheckUnaryShape(const HloInstruction* instruction) {
- return CheckShape(instruction,
- ShapeInference::InferUnaryOpShape(
- instruction->opcode(), instruction->operand(0)));
- }
- Status CheckBinaryShape(const HloInstruction* instruction) {
- return CheckShape(instruction,
- ShapeInference::InferBinaryOpShape(
- instruction->opcode(), instruction->operand(0),
- instruction->operand(1)));
- }
- Status CheckTernaryShape(const HloInstruction* instruction) {
- return CheckShape(instruction,
- ShapeInference::InferTernaryOpShape(
- instruction->opcode(), instruction->operand(0),
- instruction->operand(1), instruction->operand(2)));
- }
- Status CheckVariadicShape(const HloInstruction* instruction) {
- return CheckShape(instruction,
- ShapeInference::InferVariadicOpShape(
- instruction->opcode(), instruction->operands()));
- }
+Status ShapeVerifier::CheckVariadicShape(const HloInstruction* instruction) {
+ return CheckShape(instruction,
+ ShapeInference::InferVariadicOpShape(
+ instruction->opcode(), instruction->operands()));
+}
- // Checks if the given two instructions shares the same channel id.
- Status CheckSameChannel(const HloInstruction* instr1,
- const HloInstruction* instr2) {
- if (instr1->channel_id() != instr2->channel_id()) {
- return FailedPrecondition(
- "Expected to have the same channel id, actual channel ids are: %s "
- "(%lld), %s (%lld)",
- instr1->ToString().c_str(), instr1->channel_id(),
- instr2->ToString().c_str(), instr2->channel_id());
- }
- return tensorflow::Status::OK();
+// Checks whether the two given instructions share the same channel id.
+Status ShapeVerifier::CheckSameChannel(const HloInstruction* instr1,
+ const HloInstruction* instr2) {
+ if (instr1->channel_id() != instr2->channel_id()) {
+ return FailedPrecondition(
+ "Expected to have the same channel id, actual channel ids are: %s "
+ "(%lld), %s (%lld)",
+ instr1->ToString().c_str(), instr1->channel_id(),
+ instr2->ToString().c_str(), instr2->channel_id());
}
-
- // Returns the size of a Shape in bytes.
- const std::function<int64(const Shape&)> shape_size_fn_;
-};
+ return tensorflow::Status::OK();
+}
string ComputationsToString(
tensorflow::gtl::ArraySlice<HloComputation*> computations) {
@@ -499,8 +474,6 @@ Status VerifyHloStructure(HloModule* module) {
return tensorflow::Status::OK();
}
-} // namespace
-
Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const {
// The parent fusion instruction of the fusion computation must be 'fusion'.
HloComputation* fused_computation = fusion->fused_instructions_computation();
@@ -622,7 +595,6 @@ StatusOr<bool> HloVerifier::Run(HloModule* module) {
TF_RETURN_IF_ERROR(VerifyHloStructure(module));
tensorflow::gtl::FlatMap<string, const HloInstruction*> instructions;
- ShapeVerifier shape_verifier(shape_size_fn_);
for (auto* computation : module->computations()) {
for (const auto& instruction : computation->instructions()) {
@@ -702,7 +674,7 @@ StatusOr<bool> HloVerifier::Run(HloModule* module) {
instructions[instruction->name()] = instruction;
}
- TF_RETURN_IF_ERROR(computation->Accept(&shape_verifier));
+ TF_RETURN_IF_ERROR(computation->Accept(shape_verifier_.get()));
}
return false;
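With ShapeVerifier hoisted out of the anonymous namespace, Run() walks each computation with the stored verifier instead of a locally constructed one. A condensed restatement (a sketch under the assumption that VerifyHloStructure is visible to the caller, as it is inside this .cc):

#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/core/errors.h"

xla::StatusOr<bool> RunSketch(xla::HloModule* module,
                              xla::ShapeVerifier* shape_verifier) {
  TF_RETURN_IF_ERROR(xla::VerifyHloStructure(module));
  for (auto* computation : module->computations()) {
    // Each computation is now walked with the injected visitor rather than
    // a ShapeVerifier constructed from a shape-size function.
    TF_RETURN_IF_ERROR(computation->Accept(shape_verifier));
  }
  return false;
}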
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h
index e35a7f3642..6368611f32 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.h
+++ b/tensorflow/compiler/xla/service/hlo_verifier.h
@@ -18,14 +18,98 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
+#include "tensorflow/compiler/xla/service/shape_inference.h"
+
namespace xla {
+// Visitor which verifies that the output shape is correctly set. Verifies
+// against the inferred shape for the instruction.
+// TODO(b/26024837): Check output shape for all instruction types.
+class ShapeVerifier : public DfsHloVisitor {
+ public:
+ Status HandleElementwiseUnary(HloInstruction* hlo) override;
+ Status HandleElementwiseBinary(HloInstruction* hlo) override;
+ Status HandleClamp(HloInstruction* clamp) override;
+ Status HandleSelect(HloInstruction* select) override;
+ Status HandleConcatenate(HloInstruction* concatenate) override;
+ Status HandleConvert(HloInstruction* convert) override;
+ Status HandleBitcastConvert(HloInstruction* convert) override;
+ Status HandleCopy(HloInstruction* copy) override;
+ Status HandleDot(HloInstruction* dot) override;
+ Status HandleConvolution(HloInstruction* convolution) override;
+ Status HandleFft(HloInstruction* fft) override;
+ Status HandleCrossReplicaSum(HloInstruction* crs) override;
+ Status HandleReducePrecision(HloInstruction* reduce_precision) override;
+ Status HandleInfeed(HloInstruction*) override;
+ Status HandleOutfeed(HloInstruction*) override;
+ Status HandleRng(HloInstruction*) override;
+ Status HandleReverse(HloInstruction* reverse) override;
+ Status HandleSort(HloInstruction* sort) override;
+ Status HandleConstant(HloInstruction* constant) override;
+ Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
+ Status HandleReduce(HloInstruction* reduce) override;
+ Status HandleBitcast(HloInstruction* bitcast) override;
+ Status HandleBroadcast(HloInstruction* broadcast) override;
+ Status HandleReshape(HloInstruction* reshape) override;
+ Status HandleTranspose(HloInstruction* transpose) override;
+ Status HandleParameter(HloInstruction*) override;
+ Status HandleFusion(HloInstruction*) override;
+ Status HandleCall(HloInstruction* call) override;
+ Status HandleCustomCall(HloInstruction*) override;
+ Status HandleSlice(HloInstruction* slice) override;
+ Status HandleDynamicSlice(HloInstruction* dynamic_slice) override;
+ Status HandleDynamicUpdateSlice(
+ HloInstruction* dynamic_update_slice) override;
+ Status HandleTuple(HloInstruction* tuple) override;
+ Status HandleMap(HloInstruction* map) override;
+ Status HandleReduceWindow(HloInstruction* reduce_window) override;
+ Status HandleSelectAndScatter(HloInstruction* instruction) override;
+ Status HandleWhile(HloInstruction* xla_while) override;
+ Status HandleConditional(HloInstruction* conditional) override;
+ Status HandlePad(HloInstruction* pad) override;
+ Status HandleSend(HloInstruction* send) override;
+ Status HandleSendDone(HloInstruction* send_done) override;
+ Status HandleRecv(HloInstruction* recv) override;
+ Status HandleRecvDone(HloInstruction* recv_done) override;
+ Status HandleBatchNormTraining(HloInstruction* batch_norm_training) override;
+ Status HandleBatchNormInference(
+ HloInstruction* batch_norm_inference) override;
+ Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override;
+
+ Status FinishVisit(HloInstruction*) override {
+ return tensorflow::Status::OK();
+ }
+
+ protected:
+ // Check the instruction's shape against the given expected shape and return
+ // an appropriate error if there is a mismatch.
+ Status CheckShape(const HloInstruction* instruction,
+ const Shape& expected_shape);
+
+ // Overload which takes a StatusOr to reduce boilerplate in the caller.
+ Status CheckShape(const HloInstruction* instruction,
+ const StatusOr<Shape>& expected_shape_status);
+
+ // Check a unary (binary, etc.) instruction's shape against the inferred shape.
+ Status CheckUnaryShape(const HloInstruction* instruction);
+ Status CheckBinaryShape(const HloInstruction* instruction);
+ Status CheckTernaryShape(const HloInstruction* instruction);
+ Status CheckVariadicShape(const HloInstruction* instruction);
+
+ // Checks whether the two given instructions share the same channel id.
+ Status CheckSameChannel(const HloInstruction* instr1,
+ const HloInstruction* instr2);
+};
+
// HLO pass that verifies invariants of HLO instructions for each computation in
// the module.
class HloVerifier : public HloPassInterface {
public:
- explicit HloVerifier(const std::function<int64(const Shape&)>& shape_size_fn)
- : shape_size_fn_(shape_size_fn) {}
+ // Uses standard shape inference.
+ explicit HloVerifier() : shape_verifier_(MakeUnique<ShapeVerifier>()) {}
+ // Uses custom shape verification.
+ explicit HloVerifier(std::unique_ptr<ShapeVerifier> shape_verifier)
+ : shape_verifier_(std::move(shape_verifier)) {}
~HloVerifier() override = default;
tensorflow::StringPiece name() const override { return "verifier"; }
@@ -37,8 +121,8 @@ class HloVerifier : public HloPassInterface {
// CHECKs various invariants of a fusion instruction.
Status CheckFusionInstruction(HloInstruction* fusion) const;
- // Returns the size of a Shape in bytes.
- const std::function<int64(const Shape&)> shape_size_fn_;
+ // Verifies shapes match inferred expectations.
+ std::unique_ptr<ShapeVerifier> shape_verifier_;
};
} // namespace xla
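The protected CheckShape helpers turn ShapeVerifier into a subclassing point. A hypothetical subclass sketch; LoggingShapeVerifier and its behavior are illustrative assumptions, not part of this commit:

#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/core/platform/logging.h"

class LoggingShapeVerifier : public xla::ShapeVerifier {
 public:
  // Keep the base class's (deliberately permissive) bitcast handling, but
  // record what was skipped.
  tensorflow::Status HandleBitcast(xla::HloInstruction* bitcast) override {
    VLOG(2) << "bitcast left unchecked: " << bitcast->name();
    return xla::ShapeVerifier::HandleBitcast(bitcast);
  }
};

// Injected through the new constructor:
//   xla::HloVerifier verifier(xla::MakeUnique<LoggingShapeVerifier>());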
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index 6dc49ffe4c..a6d6c8b27f 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -1723,6 +1723,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
return InvalidArgument("FFT only supports ranks 1-3, but got %lld",
fft_rank);
}
+#define RET_CHECK_RANK(x) \
+ if (x.dimensions_size() < fft_rank) { \
+ return InvalidArgument( \
+ "FFT of rank %lld requires input of at least " \
+ "same rank; got input of rank %d", \
+ fft_rank, x.dimensions_size()); \
+ }
switch (fft_type) {
case FFT:
case IFFT:
@@ -1731,12 +1738,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
FftType_Name(fft_type).c_str(),
PrimitiveType_Name(in.element_type()).c_str());
}
+ RET_CHECK_RANK(in);
return in;
case RFFT: {
if (in.element_type() != F32) {
return InvalidArgument("RFFT requires F32 input type, found %s",
PrimitiveType_Name(in.element_type()).c_str());
}
+ RET_CHECK_RANK(in);
for (int i = 0; i < fft_rank; i++) {
if (in.dimensions(in.dimensions_size() - fft_rank + i) !=
fft_length[i]) {
@@ -1758,7 +1767,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
return InvalidArgument("IRFFT requires C64 input type, found %s",
PrimitiveType_Name(in.element_type()).c_str());
}
- Shape result = ShapeUtil::ChangeElementType(in, F32);
+ RET_CHECK_RANK(in);
+ Shape result = ShapeUtil::ComplexComponentShape(in);
for (int i = 0; i < fft_rank - 1; i++) {
if (in.dimensions(in.dimensions_size() - fft_rank + i) !=
fft_length[i]) {
@@ -1785,6 +1795,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
default:
LOG(FATAL) << "Unexpected fft_type: " << fft_type;
}
+#undef RET_CHECK_RANK
}
/* static */ StatusOr<Shape> ShapeInference::InferCrossReplicaSumShape(
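The RET_CHECK_RANK guard added above encodes a single predicate: an FFT over the trailing fft_rank dimensions needs an operand of at least that rank. The same check as a standalone function (sketch only):

#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

// Mirrors RET_CHECK_RANK: an FFT over the trailing fft_rank dimensions
// requires an operand with at least fft_rank dimensions.
bool FftInputRankSufficient(const xla::Shape& in, xla::int64 fft_rank) {
  return in.dimensions_size() >= fft_rank;
}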
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc
index a27e0f2c10..7c1a993b47 100644
--- a/tensorflow/compiler/xla/tests/hlo_test_base.cc
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc
@@ -91,9 +91,7 @@ HloTestBase::HloTestBase()
HloTestBase::HloTestBase(se::Platform* test_platform,
se::Platform* reference_platform)
: test_runner_(test_platform), reference_runner_(reference_platform) {
- hlo_verifier_ = MakeUnique<HloVerifier>([this](const Shape& shape) {
- return backend().transfer_manager()->GetByteSizeRequirement(shape);
- });
+ hlo_verifier_ = MakeUnique<HloVerifier>();
}
/* static */
diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc
index 31060b9e80..506091ddd8 100644
--- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc
+++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc
@@ -23,15 +23,8 @@ limitations under the License.
namespace xla {
-/*static*/ int64 HloVerifiedTestBase::DefaultShapeSize(const Shape& shape) {
- constexpr int64 kPointerSize = sizeof(void*);
- if (ShapeUtil::IsOpaque(shape)) {
- return kPointerSize;
- }
- return ShapeUtil::ByteSizeOf(shape, kPointerSize);
-}
-
-HloVerifiedTestBase::HloVerifiedTestBase() : shape_size_fn_(DefaultShapeSize) {}
+HloVerifiedTestBase::HloVerifiedTestBase()
+ : shape_verifier_(MakeUnique<ShapeVerifier>()) {}
HloVerifiedTestBase::~HloVerifiedTestBase() {
// We can't call the ASSERT or EXPECT test macros in destructors, so we
@@ -47,7 +40,7 @@ void HloVerifiedTestBase::TearDown() {
<< "TearDown called more than once; it should be called exactly once.";
tear_down_called_ = true;
if (module_) {
- HloVerifier verifier(shape_size_fn_);
+ HloVerifier verifier;
xla::StatusOr<bool> mutated = verifier.Run(module_.get());
if (!mutated.ok()) {
ADD_FAILURE() << "HloVerifier failed: " << mutated.status();
diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h
index b3d6b5af3b..492688bf7d 100644
--- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h
+++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h
@@ -28,14 +28,13 @@ namespace xla {
// A base class for HLO tests that stores a default HloModule, and automatically
// performs verification on that module on tear-down.
class HloVerifiedTestBase : public HloTestBase {
- public:
- // Returns the size in bytes of the given shape, using a default pointer size.
- static int64 DefaultShapeSize(const Shape& shape);
-
protected:
HloVerifiedTestBase();
~HloVerifiedTestBase() override;
+ // Constructs a default shape verifier.
+ std::unique_ptr<ShapeVerifier> MakeShapeVerifier();
+
// Performs verification on the default HloModule returned by module().
// Automatically called by the testing framework for each test.
//
@@ -47,14 +46,14 @@ class HloVerifiedTestBase : public HloTestBase {
HloModule& module();
- // Sets the shape-size function used during hlo verification. If this isn't
- // called, DefaultShapeSize is used instead.
- void SetShapeSizeFn(std::function<int64(const Shape&)> shape_size_fn) {
- shape_size_fn_ = std::move(shape_size_fn);
+ // Sets the shape verifier used during HLO verification. If this isn't
+ // called, a default ShapeVerifier is used instead.
+ void SetShapeVerifier(std::unique_ptr<ShapeVerifier> shape_verifier) {
+ shape_verifier_ = std::move(shape_verifier);
}
private:
std::unique_ptr<HloModule> module_; // Lazily populated. Access via module().
- std::function<int64(const Shape&)> shape_size_fn_;
+ std::unique_ptr<ShapeVerifier> shape_verifier_;
bool tear_down_called_ = false;
};
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc
index bb215be8af..d7346d65c8 100644
--- a/tensorflow/compiler/xla/tests/test_utils.cc
+++ b/tensorflow/compiler/xla/tests/test_utils.cc
@@ -271,13 +271,7 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
Status VerifyHloModule(const perftools::gputools::Platform& platform,
HloModule* const module) {
- return HloVerifier(
- std::bind(
- &TransferManager::GetByteSizeRequirement,
- TransferManager::GetForPlatform(&platform).ConsumeValueOrDie(),
- std::placeholders::_1))
- .Run(module)
- .status();
+ return HloVerifier().Run(module).status();
}
} // namespace xla
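Since the TransferManager-derived byte-size callback is gone, verification no longer depends on the platform (the platform parameter above is now effectively unused). A caller sketch:

#include "tensorflow/compiler/xla/service/hlo_verifier.h"

tensorflow::Status VerifyStandalone(xla::HloModule* module) {
  // No TransferManager or platform handle is needed any longer.
  return xla::HloVerifier().Run(module).status();
}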