diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-05-23 18:38:34 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-23 18:41:13 -0700 |
commit | 3dfe81c60fac512703eadf224d0485e17fe7d55a (patch) | |
tree | 4d98b6486f2fe2c4c7d57da7f8d1e199f193b505 /tensorflow/compiler/xla/service/hlo_matchers_test.cc | |
parent | da07aa28e0eef4aebe4851e9bdfc40e7b098cf04 (diff) |
HloSharding parsing from string, used by new Sharding HloMatcher for ease of use.
PiperOrigin-RevId: 197825588
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_matchers_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_matchers_test.cc | 18 |
1 files changed, 17 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_matchers_test.cc b/tensorflow/compiler/xla/service/hlo_matchers_test.cc index 016cc01e33..1d10e3c4fe 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers_test.cc +++ b/tensorflow/compiler/xla/service/hlo_matchers_test.cc @@ -15,7 +15,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" namespace op = xla::testing::opcode_matchers; using ::testing::_; @@ -147,6 +146,18 @@ TEST(HloMatchersTest, ShardingMatcher) { "param.1"); p1->set_sharding(HloSharding::AssignDevice(1)); + auto tuple_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {7}), ShapeUtil::MakeShape(S32, {9}), + ShapeUtil::MakeShape(F32, {11})}); + auto p2 = HloInstruction::CreateParameter(1, tuple_shape, "param.2"); + Array<int64> assignment({2}); + assignment.SetValues({0, 1}); + auto sharding = HloSharding::Tuple( + tuple_shape, + {HloSharding::Tile(ShapeUtil::MakeShape(F32, {5}), assignment), + HloSharding::AssignDevice(1), HloSharding::Replicate()}); + p2->set_sharding(sharding); + EXPECT_THAT(p0.get(), op::NoSharding()); EXPECT_THAT(p0.get(), ::testing::Not(op::Sharding(HloSharding::AssignDevice(1)))); @@ -155,6 +166,11 @@ TEST(HloMatchersTest, ShardingMatcher) { ::testing::Not(op::Sharding(HloSharding::AssignDevice(0)))); EXPECT_THAT(p1.get(), op::Sharding(HloSharding::AssignDevice(1))); + EXPECT_THAT( + p2.get(), + op::Sharding( + "{{f32[5] devices=[2]0,1}, {maximal device=1}, {replicated}}")); + EXPECT_THAT(Explain(p0.get(), op::Sharding(HloSharding::AssignDevice(1))), "%param.0 = f32[5]{0} parameter(0) has no sharding (expected: " "{maximal device=1})"); |