aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_sharding.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-10-04 13:35:31 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-04 13:44:11 -0700
commit9f2d1e2cf6be4a17b6318b429447a71d9d48af32 (patch)
tree624b86293ffb76e30bfcaaad09e02032cfba35b9 /tensorflow/compiler/xla/service/hlo_sharding.cc
parent9e8c7afa5867bd19b6684458566b064148b2665b (diff)
Few more fixes for issued in parsing invalid HLO module proto.
PiperOrigin-RevId: 215794086
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_sharding.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_sharding.cc15
1 files changed, 15 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc
index 94c7bafd3b..188f4acc79 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
+#include "tensorflow/compiler/xla/overflow_util.h"
#include "tensorflow/core/lib/core/errors.h"
namespace xla {
@@ -377,6 +378,20 @@ Status HloSharding::ValidateNonTuple(const Shape& shape,
<< "Maximal sharding is expected to have single device assignment, but "
<< proto.tile_assignment_devices().size() << " has provided.";
+ TF_RET_CHECK(proto.tile_assignment_devices().size() > 1);
+ TF_RET_CHECK(!proto.tile_assignment_dimensions().empty());
+
+ // RE: the product of tile assignment tensor dimensions must be
+ // equal to tile_assignment_devices.size().
+ int64 product_of_dimensions = 1;
+ for (auto dimension : proto.tile_assignment_dimensions()) {
+ TF_RET_CHECK(dimension > 0);
+ product_of_dimensions =
+ MultiplyWithoutOverflow(product_of_dimensions, dimension);
+ TF_RET_CHECK(product_of_dimensions > 0);
+ }
+ TF_RET_CHECK(product_of_dimensions == proto.tile_assignment_devices().size());
+
// Some versions of gcc cannot infer the TileAssignment constructor from a
// braced initializer-list, so create one manually.
std::vector<int64> devices(proto.tile_assignment_devices().begin(),