diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-10-04 13:35:31 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-04 13:44:11 -0700 |
commit | 9f2d1e2cf6be4a17b6318b429447a71d9d48af32 (patch) | |
tree | 624b86293ffb76e30bfcaaad09e02032cfba35b9 /tensorflow/compiler/xla/service/hlo_sharding.cc | |
parent | 9e8c7afa5867bd19b6684458566b064148b2665b (diff) |
Few more fixes for issued in parsing invalid HLO module proto.
PiperOrigin-RevId: 215794086
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_sharding.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_sharding.cc | 15 |
1 files changed, 15 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index 94c7bafd3b..188f4acc79 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -17,6 +17,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/overflow_util.h" #include "tensorflow/core/lib/core/errors.h" namespace xla { @@ -377,6 +378,20 @@ Status HloSharding::ValidateNonTuple(const Shape& shape, << "Maximal sharding is expected to have single device assignment, but " << proto.tile_assignment_devices().size() << " has provided."; + TF_RET_CHECK(proto.tile_assignment_devices().size() > 1); + TF_RET_CHECK(!proto.tile_assignment_dimensions().empty()); + + // RE: the product of tile assignment tensor dimensions must be + // equal to tile_assignment_devices.size(). + int64 product_of_dimensions = 1; + for (auto dimension : proto.tile_assignment_dimensions()) { + TF_RET_CHECK(dimension > 0); + product_of_dimensions = + MultiplyWithoutOverflow(product_of_dimensions, dimension); + TF_RET_CHECK(product_of_dimensions > 0); + } + TF_RET_CHECK(product_of_dimensions == proto.tile_assignment_devices().size()); + // Some versions of gcc cannot infer the TileAssignment constructor from a // braced initializer-list, so create one manually. std::vector<int64> devices(proto.tile_assignment_devices().begin(), |