aboutsummaryrefslogtreecommitdiffhomepage
path: root/src/java_tools/junitrunner/java/com/google/testing/junit/runner/sharding/RoundRobinShardingFilter.java
blob: 38d9f110b7cd6b2e62c02755e1e1d689aa9b24c0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
// Copyright 2010 The Bazel Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.testing.junit.runner.sharding;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;

import org.junit.runner.Description;
import org.junit.runner.manipulation.Filter;

import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;

/**
 * Implements the round-robin sharding strategy.
 *
 * <p>This is done by equally dividing up the tests across all the shards.
 * Each test is numbered (by its position once all tests are sorted by display
 * name) and the test number is modded with the number of shards and checked
 * against the shard number to see whether it should run on a particular shard.
 *
 * <p>{@link #equals(Object)} and {@link #hashCode()} are not necessary for
 * correct sharding; they are implemented so that this filter can be compared
 * in tests. Two filters are equal when they have the same test-to-number
 * mapping, shard index and total shard count.
 */
public final class RoundRobinShardingFilter extends Filter {

  /** Unmodifiable mapping from each test description to its number in sorted order. */
  @VisibleForTesting
  final Map<Description, Integer> testToShardMap;

  /** Zero-based index of the shard this filter selects tests for. */
  @VisibleForTesting
  final int shardIndex;

  /** Total number of shards the tests are divided across. */
  @VisibleForTesting
  final int totalShards;

  /**
   * Creates a filter that accepts the tests assigned to shard {@code shardIndex}.
   *
   * @param testDescriptions descriptions of every individual test to be sharded; must not
   *     contain suites
   * @param shardIndex zero-based index of this shard; must be non-negative
   * @param totalShards total number of shards; must be greater than {@code shardIndex}
   * @throws IllegalArgumentException if the shard arguments are out of range or if
   *     {@code testDescriptions} contains a description that is not a test
   */
  public RoundRobinShardingFilter(Collection<Description> testDescriptions,
      int shardIndex, int totalShards) {
    Preconditions.checkArgument(shardIndex >= 0,
        "shardIndex must be non-negative, but was %s", shardIndex);
    Preconditions.checkArgument(totalShards > shardIndex,
        "totalShards (%s) must be greater than shardIndex (%s)", totalShards, shardIndex);
    this.testToShardMap = buildTestToShardMap(testDescriptions);
    this.shardIndex = shardIndex;
    this.totalShards = totalShards;
  }

  /**
   * Given a collection of test case descriptions, returns an unmodifiable mapping
   * from each description to its index in the list obtained by sorting the
   * descriptions with {@link DescriptionComparator}.
   *
   * @throws IllegalArgumentException if any description is not a test
   */
  private static Map<Description, Integer> buildTestToShardMap(
      Collection<Description> testDescriptions) {
    Map<Description, Integer> map = Maps.newHashMap();

    // Sorting this list is incredibly important to correctness. Otherwise,
    // "shuffled" suites would break the sharding protocol.
    List<Description> sortedDescriptions = Lists.newArrayList(testDescriptions);
    Collections.sort(sortedDescriptions, new DescriptionComparator());

    // If we get two descriptions that are equal, the shard number for the second
    // one will overwrite the shard number for the first.  Thus they'll run on the
    // same shard.
    int index = 0;
    for (Description description : sortedDescriptions) {
      Preconditions.checkArgument(description.isTest(),
          "Test suite should not be included in the set of tests to shard: %s",
          description.getDisplayName());
      map.put(description, index);
      index++;
    }
    return Collections.unmodifiableMap(map);
  }

  /**
   * Decides whether the given description should run on this shard.
   *
   * <p>Suites are always accepted. A test is accepted when its number, modulo
   * the total shard count, equals this filter's shard index.
   *
   * @throws IllegalArgumentException if {@code description} is not a suite and was not among
   *     the descriptions passed to the constructor
   */
  @Override
  public boolean shouldRun(Description description) {
    if (description.isSuite()) {
      return true;
    }
    Integer testNumber = testToShardMap.get(description);
    if (testNumber == null) {
      throw new IllegalArgumentException("This filter keeps a mapping from each test "
          + "description to a shard, and the given description was not passed in when "
          + "filter was constructed: " + description);
    }
    return (testNumber % totalShards) == shardIndex;
  }

  @Override
  public String describe() {
    return "round robin sharding filter";
  }

  /**
   * Compares by value: two filters are equal when their test-to-number mapping,
   * shard index and total shard count are all equal.
   */
  @Override
  public boolean equals(Object object) {
    if (this == object) {
      return true;
    }
    // The class is final, so instanceof cannot admit a subclass with different semantics.
    if (!(object instanceof RoundRobinShardingFilter)) {
      return false;
    }
    RoundRobinShardingFilter other = (RoundRobinShardingFilter) object;
    return shardIndex == other.shardIndex
        && totalShards == other.totalShards
        && testToShardMap.equals(other.testToShardMap);
  }

  /** Hashes the same three fields that {@link #equals(Object)} compares. */
  @Override
  public int hashCode() {
    int result = testToShardMap.hashCode();
    result = 31 * result + shardIndex;
    result = 31 * result + totalShards;
    return result;
  }

  /** Orders descriptions by their display name, using {@link String#compareTo(String)}. */
  @VisibleForTesting
  static class DescriptionComparator implements Comparator<Description> {
    @Override
    public int compare(Description d1, Description d2) {
      return d1.getDisplayName().compareTo(d2.getDisplayName());
    }
  }

}