aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-10-04 13:35:31 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-04 13:44:11 -0700
commit9f2d1e2cf6be4a17b6318b429447a71d9d48af32 (patch)
tree624b86293ffb76e30bfcaaad09e02032cfba35b9 /tensorflow/compiler/xla/service
parent9e8c7afa5867bd19b6684458566b064148b2665b (diff)
Few more fixes for issued in parsing invalid HLO module proto.
PiperOrigin-RevId: 215794086
Diffstat (limited to 'tensorflow/compiler/xla/service')
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc4
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/hlo_sharding.cc15
3 files changed, 18 insertions, 3 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index fb91adc302..2f6db7cd7c 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -465,8 +465,8 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
break;
}
case HloOpcode::kIota:
- TF_RET_CHECK(proto.dimensions_size() <= 1)
- << "Iota instruction should have at most 1 dimension but sees "
+ TF_RET_CHECK(proto.dimensions_size() == 1)
+ << "Iota instruction should have 1 dimension but sees "
<< proto.dimensions_size();
instruction = CreateIota(proto.shape(), proto.dimensions(0));
break;
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index b618510640..255123d331 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -1304,7 +1304,7 @@ TEST_F(HloParserTest, MoreConstants) {
ENTRY %SelectScalarS32True.v4 () -> s32[] {
%constant.2 = pred[] constant(true)
- %constant.1 = s32[] constant(-42), sharding={s32[5,6] devices=[2,3]1,2,3,4}
+ %constant.1 = s32[] constant(-42), sharding={s32[5,6] devices=[2,2]1,2,3,4}
%constant = s32[] constant(42)
%select = s32[] select(pred[] %constant.2, s32[] %constant.1, s32[] %constant)
}
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc
index 94c7bafd3b..188f4acc79 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
+#include "tensorflow/compiler/xla/overflow_util.h"
#include "tensorflow/core/lib/core/errors.h"
namespace xla {
@@ -377,6 +378,20 @@ Status HloSharding::ValidateNonTuple(const Shape& shape,
<< "Maximal sharding is expected to have single device assignment, but "
<< proto.tile_assignment_devices().size() << " has provided.";
+ TF_RET_CHECK(proto.tile_assignment_devices().size() > 1);
+ TF_RET_CHECK(!proto.tile_assignment_dimensions().empty());
+
+ // RE: the product of tile assignment tensor dimensions must be
+ // equal to tile_assignment_devices.size().
+ int64 product_of_dimensions = 1;
+ for (auto dimension : proto.tile_assignment_dimensions()) {
+ TF_RET_CHECK(dimension > 0);
+ product_of_dimensions =
+ MultiplyWithoutOverflow(product_of_dimensions, dimension);
+ TF_RET_CHECK(product_of_dimensions > 0);
+ }
+ TF_RET_CHECK(product_of_dimensions == proto.tile_assignment_devices().size());
+
// Some versions of gcc cannot infer the TileAssignment constructor from a
// braced initializer-list, so create one manually.
std::vector<int64> devices(proto.tile_assignment_devices().begin(),