diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-10-04 13:35:31 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-04 13:44:11 -0700 |
commit | 9f2d1e2cf6be4a17b6318b429447a71d9d48af32 (patch) | |
tree | 624b86293ffb76e30bfcaaad09e02032cfba35b9 /tensorflow/compiler/xla/service | |
parent | 9e8c7afa5867bd19b6684458566b064148b2665b (diff) |
Few more fixes for issued in parsing invalid HLO module proto.
PiperOrigin-RevId: 215794086
Diffstat (limited to 'tensorflow/compiler/xla/service')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction.cc | 4 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_parser_test.cc | 2 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_sharding.cc | 15 |
3 files changed, 18 insertions, 3 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index fb91adc302..2f6db7cd7c 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -465,8 +465,8 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( break; } case HloOpcode::kIota: - TF_RET_CHECK(proto.dimensions_size() <= 1) - << "Iota instruction should have at most 1 dimension but sees " + TF_RET_CHECK(proto.dimensions_size() == 1) + << "Iota instruction should have 1 dimension but sees " << proto.dimensions_size(); instruction = CreateIota(proto.shape(), proto.dimensions(0)); break; diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index b618510640..255123d331 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -1304,7 +1304,7 @@ TEST_F(HloParserTest, MoreConstants) { ENTRY %SelectScalarS32True.v4 () -> s32[] { %constant.2 = pred[] constant(true) - %constant.1 = s32[] constant(-42), sharding={s32[5,6] devices=[2,3]1,2,3,4} + %constant.1 = s32[] constant(-42), sharding={s32[5,6] devices=[2,2]1,2,3,4} %constant = s32[] constant(42) %select = s32[] select(pred[] %constant.2, s32[] %constant.1, s32[] %constant) } diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index 94c7bafd3b..188f4acc79 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -17,6 +17,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/overflow_util.h" #include "tensorflow/core/lib/core/errors.h" namespace xla { @@ -377,6 +378,20 @@ Status HloSharding::ValidateNonTuple(const Shape& shape, << "Maximal sharding is expected to have single device assignment, but " << proto.tile_assignment_devices().size() << " has provided."; + TF_RET_CHECK(proto.tile_assignment_devices().size() > 1); + TF_RET_CHECK(!proto.tile_assignment_dimensions().empty()); + + // RE: the product of tile assignment tensor dimensions must be + // equal to tile_assignment_devices.size(). + int64 product_of_dimensions = 1; + for (auto dimension : proto.tile_assignment_dimensions()) { + TF_RET_CHECK(dimension > 0); + product_of_dimensions = + MultiplyWithoutOverflow(product_of_dimensions, dimension); + TF_RET_CHECK(product_of_dimensions > 0); + } + TF_RET_CHECK(product_of_dimensions == proto.tile_assignment_devices().size()); + // Some versions of gcc cannot infer the TileAssignment constructor from a // braced initializer-list, so create one manually. std::vector<int64> devices(proto.tile_assignment_devices().begin(), |