diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-10-02 14:35:49 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-02 14:39:45 -0700 |
commit | 05812d761031b108b43560c90867b96dc4f030eb (patch) | |
tree | f3b2307acfd9cb791c1a105ed62927542b8daa58 | |
parent | c921e45bccac86ce0becc71cedc3da2c702d5c38 (diff) |
Fixes for few issues in HloModule::CreateFromProto()
PiperOrigin-RevId: 215460064
-rw-r--r-- | tensorflow/compiler/xla/literal.cc | 2 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_computation.cc | 22 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction.cc | 20 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_sharding.cc | 8 | ||||
-rw-r--r-- | tensorflow/compiler/xla/shape_util.cc | 3 |
5 files changed, 48 insertions, 7 deletions
diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc index d1dad0d45f..deeb140b8f 100644 --- a/tensorflow/compiler/xla/literal.cc +++ b/tensorflow/compiler/xla/literal.cc @@ -287,6 +287,8 @@ Status MutableLiteralBase::CopyElementFrom(const LiteralSlice& src_literal, return InvalidArgument("LiteralProto has no layout"); } + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(proto.shape())); + Literal literal(proto.shape()); TF_RETURN_IF_ERROR(literal.root_piece_->ForEachMutableSubpieceWithStatus( diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 6ef67ab0a8..c2041c4667 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -535,6 +535,28 @@ HloComputation::CreateFromProto( return to_proto_id[a.get()] < to_proto_id[b.get()]; }); + TF_RETURN_IF_ERROR([&]() -> Status { + std::vector<bool> parameters_seen(parameter_count); + int parameters_seen_count = 0; + for (auto& instruction : instructions) { + if (instruction->opcode() == HloOpcode::kParameter) { + int64 param_no = instruction->parameter_number(); + TF_RET_CHECK(param_no >= 0 && param_no < parameter_count) + << "Invalid parameter number. Expected [0, " << parameter_count + << "), got " << param_no; + TF_RET_CHECK(!parameters_seen[param_no]) + << "Parameter number " << param_no + << " already allocated in this computation"; + parameters_seen[param_no] = true; + parameters_seen_count++; + } + } + TF_RET_CHECK(parameters_seen_count == parameter_count) + << "Not all parameters in range [0, " << parameter_count + << ") were referenced"; + return Status::OK(); + }()); + auto computation = absl::WrapUnique( new HloComputation(proto.name(), parameter_count, &instructions, root, /*fusion_instruction=*/nullptr)); diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index de22b2d3a5..5c16d6bb5e 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -81,6 +81,20 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( const auto computations = [&computation_map, &proto](int index) { return computation_map.at(proto.called_computation_ids(index)); }; + + TF_RET_CHECK(std::all_of( + proto.operand_ids().begin(), proto.operand_ids().end(), + [&instruction_map](int64 id) { return instruction_map.contains(id); })) + << proto.name() << " instruction contains invalid operand id(s)"; + + TF_RET_CHECK(std::all_of( + proto.called_computation_ids().begin(), + proto.called_computation_ids().end(), + [&computation_map](int64 id) { return computation_map.contains(id); })) + << proto.name() << " instruction references invalid computation id(s)"; + + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(proto.shape())); + switch (opcode) { // Ops migrated to subclasses. case HloOpcode::kBatchNormTraining: @@ -304,6 +318,8 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( } break; case HloOpcode::kOutfeed: TF_RET_CHECK(proto.operand_ids_size() == 2); + TF_RETURN_IF_ERROR( + ShapeUtil::ValidateShapeWithOptionalLayout(proto.outfeed_shape())); instruction = CreateOutfeed(proto.outfeed_shape(), operands(0), operands(1), proto.outfeed_config()); break; @@ -492,14 +508,10 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( default: { instruction = absl::WrapUnique(new HloInstruction(opcode, proto.shape())); for (const int64 operand_id : proto.operand_ids()) { - TF_RET_CHECK(ContainsKey(instruction_map, operand_id)) - << "No instruction with id " << operand_id; instruction->AppendOperand(instruction_map.at(operand_id)); } if (instruction->opcode() != HloOpcode::kFusion) { for (const int64 computation_id : proto.called_computation_ids()) { - TF_RET_CHECK(ContainsKey(computation_map, computation_id)) - << "No computation with id " << computation_id; instruction->called_computations_.push_back( computation_map.at(computation_id)); } diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index de7e6b53d4..94c7bafd3b 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -369,10 +369,14 @@ Status HloSharding::ValidateNonTuple(const Shape& shape, return HloSharding(tuple_shardings); } else if (proto.type() == OpSharding::Type::OpSharding_Type_REPLICATED) { return Replicate(); - } else if (proto.type() == OpSharding::Type::OpSharding_Type_MAXIMAL || - proto.tile_assignment_devices().size() == 1) { + } else if (proto.tile_assignment_devices().size() == 1) { return HloSharding(proto.tile_assignment_devices(0)); } + + TF_RET_CHECK(proto.type() != OpSharding::Type::OpSharding_Type_MAXIMAL) + << "Maximal sharding is expected to have single device assignment, but " + << proto.tile_assignment_devices().size() << " has provided."; + // Some versions of gcc cannot infer the TileAssignment constructor from a // braced initializer-list, so create one manually. std::vector<int64> devices(proto.tile_assignment_devices().begin(), diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 020c167ee9..476a9fe868 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -831,7 +831,8 @@ StatusOr<Shape> ParseShapeStringInternal(absl::string_view* s) { /* static */ Status ShapeUtil::ValidateShapeWithOptionalLayoutInternal( const Shape& shape) { - if (shape.element_type() == PRIMITIVE_TYPE_INVALID) { + if (shape.element_type() == PRIMITIVE_TYPE_INVALID || + !PrimitiveType_IsValid(shape.element_type())) { return InvalidArgument("shape has invalid element type: %s", shape.ShortDebugString()); } |