author     A. Unique TensorFlower <gardener@tensorflow.org>  2018-10-04 13:35:31 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2018-10-04 13:44:11 -0700
commit     9f2d1e2cf6be4a17b6318b429447a71d9d48af32 (patch)
tree       624b86293ffb76e30bfcaaad09e02032cfba35b9 /tensorflow/compiler
parent     9e8c7afa5867bd19b6684458566b064148b2665b (diff)
Few more fixes for issues in parsing invalid HLO module proto.
PiperOrigin-RevId: 215794086
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r--  tensorflow/compiler/xla/literal.cc                  |  8
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.cc  |  4
-rw-r--r--  tensorflow/compiler/xla/service/hlo_parser_test.cc  |  2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_sharding.cc     | 15
-rw-r--r--  tensorflow/compiler/xla/shape_util.cc               |  7
5 files changed, 24 insertions, 12 deletions
diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc
index 177f39cc74..656ce720a1 100644
--- a/tensorflow/compiler/xla/literal.cc
+++ b/tensorflow/compiler/xla/literal.cc
@@ -1945,11 +1945,11 @@ Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) {
}
} break;
case TUPLE:
- LOG(FATAL) << "Should not be called on tuple shapes: "
- << ShapeUtil::HumanString(subshape());
- break;
+ return InvalidArgument("Should not be called on tuple shapes: %s",
+ ShapeUtil::HumanString(subshape()));
default:
- LOG(FATAL) << "Unhandled primitive type " << subshape().element_type();
+ return InvalidArgument("Is called on unsupported shape: %s",
+ ShapeUtil::HumanString(subshape()));
}
return Status::OK();
}
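
In this literal.cc hunk, two LOG(FATAL) calls become InvalidArgument statuses, so a malformed LiteralProto is reported to the caller instead of terminating the process. Below is a minimal caller-side sketch of that pattern, assuming Literal::CreateFromProto returns a StatusOr<Literal> as in this version of the tree; the wrapper function name is made up for illustration.

#include <utility>

#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"

// Hypothetical wrapper: a bad proto now surfaces as a Status the caller can
// propagate or log, rather than a process abort inside CopyFromProto.
xla::StatusOr<xla::Literal> LiteralFromProtoOrError(
    const xla::LiteralProto& proto) {
  TF_ASSIGN_OR_RETURN(xla::Literal literal,
                      xla::Literal::CreateFromProto(proto));
  return std::move(literal);
}
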
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index fb91adc302..2f6db7cd7c 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -465,8 +465,8 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
break;
}
case HloOpcode::kIota:
- TF_RET_CHECK(proto.dimensions_size() <= 1)
- << "Iota instruction should have at most 1 dimension but sees "
+ TF_RET_CHECK(proto.dimensions_size() == 1)
+ << "Iota instruction should have 1 dimension but sees "
<< proto.dimensions_size();
instruction = CreateIota(proto.shape(), proto.dimensions(0));
break;
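
The tightened check matters because the next line indexes proto.dimensions(0): with the old `<= 1` condition, an iota proto carrying zero dimensions would pass the check and then read past the end of the repeated field. A standalone sketch of this guard-before-index pattern, with illustrative names not taken from the change:

#include <cstdint>
#include <vector>

#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/status_macros.h"

// Illustrative only: validate the size of a repeated field before indexing it.
xla::Status CheckIotaDimensions(const std::vector<int64_t>& dimensions) {
  TF_RET_CHECK(dimensions.size() == 1)
      << "Iota instruction should have 1 dimension but sees "
      << dimensions.size();
  // Only now is dimensions[0] (the iota dimension) safe to read.
  return xla::Status::OK();
}
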
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index b618510640..255123d331 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -1304,7 +1304,7 @@ TEST_F(HloParserTest, MoreConstants) {
ENTRY %SelectScalarS32True.v4 () -> s32[] {
%constant.2 = pred[] constant(true)
- %constant.1 = s32[] constant(-42), sharding={s32[5,6] devices=[2,3]1,2,3,4}
+ %constant.1 = s32[] constant(-42), sharding={s32[5,6] devices=[2,2]1,2,3,4}
%constant = s32[] constant(42)
%select = s32[] select(pred[] %constant.2, s32[] %constant.1, s32[] %constant)
}
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc
index 94c7bafd3b..188f4acc79 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
+#include "tensorflow/compiler/xla/overflow_util.h"
#include "tensorflow/core/lib/core/errors.h"
namespace xla {
@@ -377,6 +378,20 @@ Status HloSharding::ValidateNonTuple(const Shape& shape,
<< "Maximal sharding is expected to have single device assignment, but "
<< proto.tile_assignment_devices().size() << " has provided.";
+ TF_RET_CHECK(proto.tile_assignment_devices().size() > 1);
+ TF_RET_CHECK(!proto.tile_assignment_dimensions().empty());
+
+ // RE: the product of tile assignment tensor dimensions must be
+ // equal to tile_assignment_devices.size().
+ int64 product_of_dimensions = 1;
+ for (auto dimension : proto.tile_assignment_dimensions()) {
+ TF_RET_CHECK(dimension > 0);
+ product_of_dimensions =
+ MultiplyWithoutOverflow(product_of_dimensions, dimension);
+ TF_RET_CHECK(product_of_dimensions > 0);
+ }
+ TF_RET_CHECK(product_of_dimensions == proto.tile_assignment_devices().size());
+
// Some versions of gcc cannot infer the TileAssignment constructor from a
// braced initializer-list, so create one manually.
std::vector<int64> devices(proto.tile_assignment_devices().begin(),
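
The new hlo_sharding.cc validation accumulates the product of the tile-assignment dimensions with MultiplyWithoutOverflow and re-checks after every step that the running product stays positive, so an overflow (or a non-positive dimension) trips a TF_RET_CHECK instead of wrapping silently; the final product must then equal the number of listed devices. This is also why the hlo_parser_test.cc hunk above changes devices=[2,3] to devices=[2,2]: 2x3 = 6 tiles would not match the 4 devices listed, while 2x2 = 4 does. Below is a sketch of an overflow-checked multiply in the same spirit; the real helper lives in tensorflow/compiler/xla/overflow_util.h and may differ in detail.

#include <cstdint>
#include <limits>

// Sketch only: multiply two non-negative 64-bit values, returning -1 when the
// true product does not fit in int64_t.
inline int64_t CheckedMultiplyNonNegative(int64_t x, int64_t y) {
  const uint64_t ux = static_cast<uint64_t>(x);
  const uint64_t uy = static_cast<uint64_t>(y);
  const uint64_t uxy = ux * uy;              // well defined in unsigned math
  if (ux != 0 && uxy / ux != uy) return -1;  // wrapped past 2^64
  if (uxy > static_cast<uint64_t>(std::numeric_limits<int64_t>::max())) {
    return -1;                               // exceeds the int64_t range
  }
  return static_cast<int64_t>(uxy);
}

A -1 (or any non-positive) running product then fails the TF_RET_CHECK(product_of_dimensions > 0) in the hunk above before the final equality check against tile_assignment_devices().size().
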
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
index 476a9fe868..d244923532 100644
--- a/tensorflow/compiler/xla/shape_util.cc
+++ b/tensorflow/compiler/xla/shape_util.cc
@@ -869,11 +869,8 @@ StatusOr<Shape> ParseShapeStringInternal(absl::string_view* s) {
return Status::OK();
}
- if (Rank(shape) != shape.dimensions_size()) {
- return InvalidArgument(
- "shape's rank is mismatched with dimension count; rank=%d "
- "dimensions_size=%d",
- Rank(shape), shape.dimensions_size());
+ if (LayoutUtil::IsSparseArray(shape) && Rank(shape) == 0) {
+ return InvalidArgument("sparse arrays must have rank > 0");
}
for (int64 i = 0; i < Rank(shape); ++i) {
int64 dimension = shape.dimensions(i);
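
The shape_util.cc hunk drops the rank-vs-dimensions_size mismatch error in favor of rejecting rank-0 sparse arrays. In isolation, the added guard behaves like the hypothetical helper below; LayoutUtil::IsSparseArray and ShapeUtil::Rank are the existing utilities, while the wrapper name is made up.

#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/util.h"  // xla::InvalidArgument

// Hypothetical helper mirroring the validation step added above.
xla::Status ValidateSparseArrayRank(const xla::Shape& shape) {
  if (xla::LayoutUtil::IsSparseArray(shape) &&
      xla::ShapeUtil::Rank(shape) == 0) {
    return xla::InvalidArgument("sparse arrays must have rank > 0");
  }
  return xla::Status::OK();
}
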