1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
|
// Copyright 2010 The Bazel Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.testing.junit.runner.sharding;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import org.junit.runner.Description;
import org.junit.runner.manipulation.Filter;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
/**
 * Implements the round-robin sharding strategy.
 *
 * <p>This is done by dividing the tests up across all the shards. The test
 * descriptions are sorted by display name and numbered in that order; a test
 * runs on this shard when its number modulo the total number of shards equals
 * this shard's index. Suite descriptions are always allowed to run.
 *
 * <p>Equals and hashCode implementations are not necessary for correct
 * sharding, but are provided so that this filter can be compared in tests.
 */
public final class RoundRobinShardingFilter extends Filter {

  /** Maps each known test description to its position in the display-name-sorted test list. */
  @VisibleForTesting
  final Map<Description, Integer> testToShardMap;

  /** Index of the shard this filter selects tests for. */
  @VisibleForTesting
  final int shardIndex;

  /** Total number of shards the tests are divided across. */
  @VisibleForTesting
  final int totalShards;

  /**
   * Creates a filter that selects the tests belonging to one shard.
   *
   * @param testDescriptions all test cases that may later be passed to {@link #shouldRun};
   *     every element must be a test, not a suite
   * @param shardIndex index of the shard to select; must be non-negative
   * @param totalShards total number of shards; must be greater than {@code shardIndex}
   * @throws IllegalArgumentException if the shard arguments are out of range or if
   *     {@code testDescriptions} contains a description that is not a test
   */
  public RoundRobinShardingFilter(Collection<Description> testDescriptions,
      int shardIndex, int totalShards) {
    Preconditions.checkArgument(shardIndex >= 0,
        "shardIndex must be non-negative, but was %s", shardIndex);
    Preconditions.checkArgument(totalShards > shardIndex,
        "totalShards (%s) must be greater than shardIndex (%s)", totalShards, shardIndex);
    this.testToShardMap = buildTestToShardMap(testDescriptions);
    this.shardIndex = shardIndex;
    this.totalShards = totalShards;
  }

  /**
   * Given a collection of test case descriptions, returns an unmodifiable mapping
   * from each description to its index in the list obtained by sorting the
   * descriptions by display name.
   *
   * @throws IllegalArgumentException if any description is not a test
   */
  private static Map<Description, Integer> buildTestToShardMap(
      Collection<Description> testDescriptions) {
    Map<Description, Integer> map = Maps.newHashMap();

    // Sorting this list is incredibly important to correctness. Otherwise,
    // "shuffled" suites would break the sharding protocol.
    List<Description> sortedDescriptions = Lists.newArrayList(testDescriptions);
    Collections.sort(sortedDescriptions, new DescriptionComparator());

    // If we get two descriptions that are equal, the index stored for the second
    // one will overwrite the index stored for the first. Thus they'll run on the
    // same shard.
    int index = 0;
    for (Description description : sortedDescriptions) {
      Preconditions.checkArgument(description.isTest(),
          "Test suite should not be included in the set of tests to shard: %s",
          description.getDisplayName());
      map.put(description, index);
      index++;
    }
    return Collections.unmodifiableMap(map);
  }

  /**
   * Returns whether the given description should run on this shard.
   *
   * <p>Suites always return {@code true}. A test returns {@code true} when its
   * index in the sorted test list, modulo the total number of shards, equals
   * this filter's shard index.
   *
   * @throws IllegalArgumentException if the description is not a suite and was not among the
   *     descriptions passed to the constructor
   */
  @Override
  public boolean shouldRun(Description description) {
    if (description.isSuite()) {
      return true;
    }
    Integer testNumber = testToShardMap.get(description);
    if (testNumber == null) {
      throw new IllegalArgumentException("This filter keeps a mapping from each test "
          + "description to a shard, and the given description was not passed in when "
          + "filter was constructed: " + description);
    }
    return (testNumber % totalShards) == shardIndex;
  }

  @Override
  public String describe() {
    return "round robin sharding filter";
  }

  /**
   * Two filters are equal when they have the same test-to-index mapping, the
   * same shard index and the same total number of shards.
   */
  @Override
  public boolean equals(Object obj) {
    if (this == obj) {
      return true;
    }
    if (!(obj instanceof RoundRobinShardingFilter)) {
      return false;
    }
    RoundRobinShardingFilter other = (RoundRobinShardingFilter) obj;
    return shardIndex == other.shardIndex
        && totalShards == other.totalShards
        && testToShardMap.equals(other.testToShardMap);
  }

  /** Combines the same three fields that {@link #equals} compares. */
  @Override
  public int hashCode() {
    int result = testToShardMap.hashCode();
    result = 31 * result + shardIndex;
    result = 31 * result + totalShards;
    return result;
  }

  /** Orders descriptions by their display name. */
  @VisibleForTesting
  static class DescriptionComparator implements Comparator<Description> {
    @Override
    public int compare(Description d1, Description d2) {
      return d1.getDisplayName().compareTo(d2.getDisplayName());
    }
  }
}
|