author     A. Unique TensorFlower <gardener@tensorflow.org>  2018-10-02 14:35:49 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2018-10-02 14:39:45 -0700
commit     05812d761031b108b43560c90867b96dc4f030eb (patch)
tree       f3b2307acfd9cb791c1a105ed62927542b8daa58
parent     c921e45bccac86ce0becc71cedc3da2c702d5c38 (diff)
Fixes for a few issues in HloModule::CreateFromProto()

Validates shapes, operand ids, called-computation ids, parameter numbers,
and sharding protos during deserialization, so that malformed HloProtos are
rejected with an error instead of crashing later.
PiperOrigin-RevId: 215460064
-rw-r--r--  tensorflow/compiler/xla/literal.cc                   2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_computation.cc  22
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.cc  20
-rw-r--r--  tensorflow/compiler/xla/service/hlo_sharding.cc      8
-rw-r--r--  tensorflow/compiler/xla/shape_util.cc                3
5 files changed, 48 insertions(+), 7 deletions(-)
diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc
index d1dad0d45f..deeb140b8f 100644
--- a/tensorflow/compiler/xla/literal.cc
+++ b/tensorflow/compiler/xla/literal.cc
@@ -287,6 +287,8 @@ Status MutableLiteralBase::CopyElementFrom(const LiteralSlice& src_literal,
return InvalidArgument("LiteralProto has no layout");
}
+ TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(proto.shape()));
+
Literal literal(proto.shape());
TF_RETURN_IF_ERROR(literal.root_piece_->ForEachMutableSubpieceWithStatus(
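
[Note: the guard added here rejects a malformed shape before the Literal constructor touches it. A minimal standalone sketch of the pattern follows; Status, ShapeProto, ValidateShape, and CreateLiteralFromProto are simplified stand-ins for illustration, not the XLA API.]

#include <iostream>
#include <string>

// Simplified stand-ins; the real path uses xla::Status and
// TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(...)).
struct Status {
  bool ok;
  std::string message;
};

enum PrimitiveType { PRIMITIVE_TYPE_INVALID = 0, F32 = 11 };

struct ShapeProto {
  PrimitiveType element_type;
};

// Validate the shape proto up front, mirroring the check added above.
Status ValidateShape(const ShapeProto& shape) {
  if (shape.element_type == PRIMITIVE_TYPE_INVALID) {
    return {false, "shape has invalid element type"};
  }
  return {true, ""};
}

Status CreateLiteralFromProto(const ShapeProto& proto) {
  Status s = ValidateShape(proto);
  if (!s.ok) return s;  // reject before constructing anything
  // ... construct the literal from the now-validated shape ...
  return {true, ""};
}

int main() {
  std::cout << CreateLiteralFromProto({PRIMITIVE_TYPE_INVALID}).message << "\n";
}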
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index 6ef67ab0a8..c2041c4667 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -535,6 +535,28 @@ HloComputation::CreateFromProto(
return to_proto_id[a.get()] < to_proto_id[b.get()];
});
+ TF_RETURN_IF_ERROR([&]() -> Status {
+ std::vector<bool> parameters_seen(parameter_count);
+ int parameters_seen_count = 0;
+ for (auto& instruction : instructions) {
+ if (instruction->opcode() == HloOpcode::kParameter) {
+ int64 param_no = instruction->parameter_number();
+ TF_RET_CHECK(param_no >= 0 && param_no < parameter_count)
+ << "Invalid parameter number. Expected [0, " << parameter_count
+ << "), got " << param_no;
+ TF_RET_CHECK(!parameters_seen[param_no])
+ << "Parameter number " << param_no
+ << " already allocated in this computation";
+ parameters_seen[param_no] = true;
+ parameters_seen_count++;
+ }
+ }
+ TF_RET_CHECK(parameters_seen_count == parameter_count)
+ << "Not all parameters in range [0, " << parameter_count
+ << ") were referenced";
+ return Status::OK();
+ }());
+
auto computation = absl::WrapUnique(
new HloComputation(proto.name(), parameter_count, &instructions, root,
/*fusion_instruction=*/nullptr));
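
[Note: the validation lambda above requires the kParameter instructions to cover [0, parameter_count) exactly once each. A standalone sketch of the same check with plain std types; CheckParameterNumbers is a hypothetical helper, not the XLA code.]

#include <iostream>
#include <string>
#include <vector>

// Returns an error message if param_numbers is not exactly a permutation
// of [0, parameter_count); returns an empty string on success.
std::string CheckParameterNumbers(const std::vector<int>& param_numbers,
                                  int parameter_count) {
  std::vector<bool> seen(parameter_count, false);
  int seen_count = 0;
  for (int n : param_numbers) {
    if (n < 0 || n >= parameter_count) {
      return "Invalid parameter number. Expected [0, " +
             std::to_string(parameter_count) + "), got " + std::to_string(n);
    }
    if (seen[n]) {
      return "Parameter number " + std::to_string(n) +
             " already allocated in this computation";
    }
    seen[n] = true;
    ++seen_count;
  }
  if (seen_count != parameter_count) {
    return "Not all parameters in range [0, " +
           std::to_string(parameter_count) + ") were referenced";
  }
  return "";
}

int main() {
  std::cout << CheckParameterNumbers({0, 2}, 3) << "\n";  // gap at 1
  std::cout << CheckParameterNumbers({0, 0}, 2) << "\n";  // duplicate 0
}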
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index de22b2d3a5..5c16d6bb5e 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -81,6 +81,20 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
const auto computations = [&computation_map, &proto](int index) {
return computation_map.at(proto.called_computation_ids(index));
};
+
+ TF_RET_CHECK(std::all_of(
+ proto.operand_ids().begin(), proto.operand_ids().end(),
+ [&instruction_map](int64 id) { return instruction_map.contains(id); }))
+ << proto.name() << " instruction contains invalid operand id(s)";
+
+ TF_RET_CHECK(std::all_of(
+ proto.called_computation_ids().begin(),
+ proto.called_computation_ids().end(),
+ [&computation_map](int64 id) { return computation_map.contains(id); }))
+ << proto.name() << " instruction references invalid computation id(s)";
+
+ TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(proto.shape()));
+
switch (opcode) {
// Ops migrated to subclasses.
case HloOpcode::kBatchNormTraining:
@@ -304,6 +318,8 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
} break;
case HloOpcode::kOutfeed:
TF_RET_CHECK(proto.operand_ids_size() == 2);
+ TF_RETURN_IF_ERROR(
+ ShapeUtil::ValidateShapeWithOptionalLayout(proto.outfeed_shape()));
instruction = CreateOutfeed(proto.outfeed_shape(), operands(0),
operands(1), proto.outfeed_config());
break;
@@ -492,14 +508,10 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
default: {
instruction = absl::WrapUnique(new HloInstruction(opcode, proto.shape()));
for (const int64 operand_id : proto.operand_ids()) {
- TF_RET_CHECK(ContainsKey(instruction_map, operand_id))
- << "No instruction with id " << operand_id;
instruction->AppendOperand(instruction_map.at(operand_id));
}
if (instruction->opcode() != HloOpcode::kFusion) {
for (const int64 computation_id : proto.called_computation_ids()) {
- TF_RET_CHECK(ContainsKey(computation_map, computation_id))
- << "No computation with id " << computation_id;
instruction->called_computations_.push_back(
computation_map.at(computation_id));
}
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc
index de7e6b53d4..94c7bafd3b 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding.cc
@@ -369,10 +369,14 @@ Status HloSharding::ValidateNonTuple(const Shape& shape,
return HloSharding(tuple_shardings);
} else if (proto.type() == OpSharding::Type::OpSharding_Type_REPLICATED) {
return Replicate();
- } else if (proto.type() == OpSharding::Type::OpSharding_Type_MAXIMAL ||
- proto.tile_assignment_devices().size() == 1) {
+ } else if (proto.tile_assignment_devices().size() == 1) {
return HloSharding(proto.tile_assignment_devices(0));
}
+
+ TF_RET_CHECK(proto.type() != OpSharding::Type::OpSharding_Type_MAXIMAL)
+ << "Maximal sharding is expected to have single device assignment, but "
+ << proto.tile_assignment_devices().size() << " has provided.";
+
// Some versions of gcc cannot infer the TileAssignment constructor from a
// braced initializer-list, so create one manually.
std::vector<int64> devices(proto.tile_assignment_devices().begin(),
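
[Note: reordering the branches changes behavior for a corrupt proto. Previously a MAXIMAL sharding with an empty device list matched the first branch and read tile_assignment_devices(0) out of range; now the single-device case is handled first, and a MAXIMAL proto without exactly one device fails the TF_RET_CHECK. A standalone sketch of the branch order; FromProto here is a simplified illustration, not xla::HloSharding::FromProto.]

#include <iostream>
#include <string>
#include <vector>

enum class OpShardingType { REPLICATED, MAXIMAL, OTHER };

// Single-device assignments are accepted regardless of type; a MAXIMAL
// proto that does not carry exactly one device is rejected instead of
// being read out of bounds.
std::string FromProto(OpShardingType type, const std::vector<int>& devices) {
  if (type == OpShardingType::REPLICATED) return "replicated";
  if (devices.size() == 1) {
    return "maximal on device " + std::to_string(devices[0]);
  }
  if (type == OpShardingType::MAXIMAL) {
    return "error: maximal sharding expects a single device, got " +
           std::to_string(devices.size());
  }
  return "tiled over " + std::to_string(devices.size()) + " devices";
}

int main() {
  std::cout << FromProto(OpShardingType::MAXIMAL, {}) << "\n";   // rejected
  std::cout << FromProto(OpShardingType::MAXIMAL, {3}) << "\n";  // accepted
}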
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
index 020c167ee9..476a9fe868 100644
--- a/tensorflow/compiler/xla/shape_util.cc
+++ b/tensorflow/compiler/xla/shape_util.cc
@@ -831,7 +831,8 @@ StatusOr<Shape> ParseShapeStringInternal(absl::string_view* s) {
/* static */ Status ShapeUtil::ValidateShapeWithOptionalLayoutInternal(
const Shape& shape) {
- if (shape.element_type() == PRIMITIVE_TYPE_INVALID) {
+ if (shape.element_type() == PRIMITIVE_TYPE_INVALID ||
+ !PrimitiveType_IsValid(shape.element_type())) {
return InvalidArgument("shape has invalid element type: %s",
shape.ShortDebugString());
}
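
[Note: PRIMITIVE_TYPE_INVALID is itself a declared enum value (0), so the protobuf-generated PrimitiveType_IsValid returns true for it and the explicit comparison still has to stay; the new clause catches out-of-range values, since a proto enum field can hold any int32 on the wire. A standalone sketch of the widened check; the hand-written PrimitiveType_IsValid and ElementTypeOk below stand in for the generated and XLA code.]

#include <iostream>

enum PrimitiveType { PRIMITIVE_TYPE_INVALID = 0, PRED = 1, F32 = 11 };

// Stand-in for the protobuf-generated range check over declared values.
bool PrimitiveType_IsValid(int value) {
  return value == 0 || value == 1 || value == 11;
}

// Rejects both the explicit "invalid" sentinel and out-of-range values.
bool ElementTypeOk(int element_type) {
  return element_type != PRIMITIVE_TYPE_INVALID &&
         PrimitiveType_IsValid(element_type);
}

int main() {
  std::cout << ElementTypeOk(11) << "\n";   // 1: F32 is fine
  std::cout << ElementTypeOk(0) << "\n";    // 0: explicitly invalid
  std::cout << ElementTypeOk(999) << "\n";  // 0: outside the enum range
}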