aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_matchers_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-05-23 18:38:34 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-23 18:41:13 -0700
commit3dfe81c60fac512703eadf224d0485e17fe7d55a (patch)
tree4d98b6486f2fe2c4c7d57da7f8d1e199f193b505 /tensorflow/compiler/xla/service/hlo_matchers_test.cc
parentda07aa28e0eef4aebe4851e9bdfc40e7b098cf04 (diff)
HloSharding parsing from string, used by new Sharding HloMatcher for ease of use.
PiperOrigin-RevId: 197825588
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_matchers_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_matchers_test.cc18
1 files changed, 17 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_matchers_test.cc b/tensorflow/compiler/xla/service/hlo_matchers_test.cc
index 016cc01e33..1d10e3c4fe 100644
--- a/tensorflow/compiler/xla/service/hlo_matchers_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_matchers_test.cc
@@ -15,7 +15,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
namespace op = xla::testing::opcode_matchers;
using ::testing::_;
@@ -147,6 +146,18 @@ TEST(HloMatchersTest, ShardingMatcher) {
"param.1");
p1->set_sharding(HloSharding::AssignDevice(1));
+ auto tuple_shape = ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(F32, {7}), ShapeUtil::MakeShape(S32, {9}),
+ ShapeUtil::MakeShape(F32, {11})});
+ auto p2 = HloInstruction::CreateParameter(1, tuple_shape, "param.2");
+ Array<int64> assignment({2});
+ assignment.SetValues({0, 1});
+ auto sharding = HloSharding::Tuple(
+ tuple_shape,
+ {HloSharding::Tile(ShapeUtil::MakeShape(F32, {5}), assignment),
+ HloSharding::AssignDevice(1), HloSharding::Replicate()});
+ p2->set_sharding(sharding);
+
EXPECT_THAT(p0.get(), op::NoSharding());
EXPECT_THAT(p0.get(),
::testing::Not(op::Sharding(HloSharding::AssignDevice(1))));
@@ -155,6 +166,11 @@ TEST(HloMatchersTest, ShardingMatcher) {
::testing::Not(op::Sharding(HloSharding::AssignDevice(0))));
EXPECT_THAT(p1.get(), op::Sharding(HloSharding::AssignDevice(1)));
+ EXPECT_THAT(
+ p2.get(),
+ op::Sharding(
+ "{{f32[5] devices=[2]0,1}, {maximal device=1}, {replicated}}"));
+
EXPECT_THAT(Explain(p0.get(), op::Sharding(HloSharding::AssignDevice(1))),
"%param.0 = f32[5]{0} parameter(0) has no sharding (expected: "
"{maximal device=1})");