aboutsummaryrefslogtreecommitdiffhomepage
path: root/src/java_tools/junitrunner/java/com/google/testing/junit
diff options
context:
space:
mode:
authorGravatar Irina Iancu <elenairina@google.com>2016-09-23 07:46:25 +0000
committerGravatar Laszlo Csomor <laszlocsomor@google.com>2016-09-23 08:16:14 +0000
commit33ad37612ab7edc32d5e82c3912acebacaef42dc (patch)
tree2b8781a519cf4cc4c1294de7168d89ba3d828ea8 /src/java_tools/junitrunner/java/com/google/testing/junit
parent254cde06bc5df24df03e2164dd7c9fe589e3f82c (diff)
Open sourcing junitrunner/java/com/google/testing/junit/runner/sharding/weighted.
-- MOS_MIGRATED_REVID=134046554
Diffstat (limited to 'src/java_tools/junitrunner/java/com/google/testing/junit')
-rw-r--r--src/java_tools/junitrunner/java/com/google/testing/junit/runner/sharding/weighted/BUILD24
-rw-r--r--src/java_tools/junitrunner/java/com/google/testing/junit/runner/sharding/weighted/BinStackingShardingFilterFactory.java82
-rw-r--r--src/java_tools/junitrunner/java/com/google/testing/junit/runner/sharding/weighted/WeightedShardingFilter.java154
3 files changed, 260 insertions, 0 deletions
diff --git a/src/java_tools/junitrunner/java/com/google/testing/junit/runner/sharding/weighted/BUILD b/src/java_tools/junitrunner/java/com/google/testing/junit/runner/sharding/weighted/BUILD
new file mode 100644
index 0000000000..a35ed2357c
--- /dev/null
+++ b/src/java_tools/junitrunner/java/com/google/testing/junit/runner/sharding/weighted/BUILD
@@ -0,0 +1,24 @@
+DEFAULT_VISIBILITY = [
+ "//java/com/google/testing/junit/runner:__subpackages__",
+ "//javatests/com/google/testing/junit/runner:__subpackages__",
+ "//third_party/bazel/src/java_tools/junitrunner/java/com/google/testing/junit/runner:__subpackages__",
+ "//third_party/bazel/src/java_tools/junitrunner/javatests/com/google/testing/junit/runner:__subpackages__",
+]
+
+package(default_visibility = ["//src:__subpackages__"])
+
+# TODO(bazel-team): This should be testonly = 1.
+java_library(
+ name = "weighted",
+ srcs = glob(["*.java"]),
+ deps = [
+ "//java/com/google/testing/util",
+ "//src/java_tools/junitrunner/java/com/google/testing/junit/runner/sharding/api",
+ "//third_party:junit4",
+ ],
+)
+
+filegroup(
+ name = "srcs",
+ srcs = glob(["**"]),
+)
diff --git a/src/java_tools/junitrunner/java/com/google/testing/junit/runner/sharding/weighted/BinStackingShardingFilterFactory.java b/src/java_tools/junitrunner/java/com/google/testing/junit/runner/sharding/weighted/BinStackingShardingFilterFactory.java
new file mode 100644
index 0000000000..5f597920bd
--- /dev/null
+++ b/src/java_tools/junitrunner/java/com/google/testing/junit/runner/sharding/weighted/BinStackingShardingFilterFactory.java
@@ -0,0 +1,82 @@
+// Copyright 2015 The Bazel Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package com.google.testing.junit.runner.sharding.weighted;
+
+import com.google.testing.junit.runner.sharding.api.ShardingFilterFactory;
+import com.google.testing.junit.runner.sharding.api.WeightStrategy;
+import com.google.testing.util.RuntimeCost;
+import java.util.Collection;
+import org.junit.Ignore;
+import org.junit.runner.Description;
+import org.junit.runner.manipulation.Filter;
+
+/**
+ * A factory that creates a {@link WeightedShardingFilter} that extracts the weight for a test from
+ * the {@link RuntimeCost} annotations present in descriptions of tests.
+ */
+public final class BinStackingShardingFilterFactory implements ShardingFilterFactory {
+ static final String DEFAULT_TEST_WEIGHT_PROPERTY = "test.sharding.default_weight";
+ static final int DEFAULT_TEST_WEIGHT = 1;
+
+ private final int defaultTestWeight;
+
+ public BinStackingShardingFilterFactory() {
+ this(getDefaultTestWeight());
+ }
+
+ // VisibleForTesting
+ BinStackingShardingFilterFactory(int defaultTestWeight) {
+ this.defaultTestWeight = defaultTestWeight;
+ }
+
+ static int getDefaultTestWeight() {
+ String property = System.getProperty(DEFAULT_TEST_WEIGHT_PROPERTY);
+ if (property != null) {
+ return Integer.parseInt(property);
+ }
+ return DEFAULT_TEST_WEIGHT;
+ }
+
+ @Override
+ public Filter createFilter(
+ Collection<Description> testDescriptions, int shardIndex, int totalShards) {
+ return new WeightedShardingFilter(
+ testDescriptions,
+ shardIndex,
+ totalShards,
+ new RuntimeCostWeightStrategy(defaultTestWeight));
+ }
+
+ static class RuntimeCostWeightStrategy implements WeightStrategy {
+
+ private final int defaultTestWeight;
+
+ RuntimeCostWeightStrategy(int defaultTestWeight) {
+ this.defaultTestWeight = defaultTestWeight;
+ }
+
+ @Override
+ public int getDescriptionWeight(Description description) {
+ RuntimeCost runtimeCost = description.getAnnotation(RuntimeCost.class);
+ Ignore ignore = description.getAnnotation(Ignore.class);
+
+ if (runtimeCost == null || ignore != null) {
+ return defaultTestWeight;
+ } else {
+ return runtimeCost.value();
+ }
+ }
+ }
+}
diff --git a/src/java_tools/junitrunner/java/com/google/testing/junit/runner/sharding/weighted/WeightedShardingFilter.java b/src/java_tools/junitrunner/java/com/google/testing/junit/runner/sharding/weighted/WeightedShardingFilter.java
new file mode 100644
index 0000000000..bb3a8aa718
--- /dev/null
+++ b/src/java_tools/junitrunner/java/com/google/testing/junit/runner/sharding/weighted/WeightedShardingFilter.java
@@ -0,0 +1,154 @@
+// Copyright 2015 The Bazel Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package com.google.testing.junit.runner.sharding.weighted;
+
+import com.google.testing.junit.runner.sharding.api.WeightStrategy;
+import com.google.testing.util.RuntimeCost;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.PriorityQueue;
+import org.junit.runner.Description;
+import org.junit.runner.manipulation.Filter;
+
+/**
+ * A sharding function that attempts to evenly use time on all available
+ * shards while considering the test's weight.
+ *
+ * <p>When all tests have the same weight the sharding function behaves
+ * similarly to round robin.
+ */
+public final class WeightedShardingFilter extends Filter {
+ private final Map<Description, Integer> testToShardMap;
+ private final int shardIndex;
+
+ public WeightedShardingFilter(Collection<Description> descriptions, int shardIndex,
+ int totalShards, WeightStrategy weightStrategy) {
+ if (shardIndex < 0 || totalShards <= shardIndex) {
+ throw new IllegalArgumentException();
+ }
+ this.shardIndex = shardIndex;
+ this.testToShardMap = buildTestToShardMap(descriptions, totalShards, weightStrategy);
+ }
+
+ @Override
+ public String describe() {
+ return "bin stacking filter";
+ }
+
+ @Override
+ public boolean shouldRun(Description description) {
+ if (description.isSuite()) {
+ return true;
+ }
+ Integer shardForTest = testToShardMap.get(description);
+ if (shardForTest == null) {
+ throw new IllegalArgumentException("This filter keeps a mapping from each test "
+ + "description to a shard, and the given description was not passed in when "
+ + "filter was constructed: " + description);
+ }
+ return shardForTest == shardIndex;
+ }
+
+ private static Map<Description, Integer> buildTestToShardMap(
+ Collection<Description> descriptions, int numShards, WeightStrategy weightStrategy) {
+ Map<Description, Integer> map = new HashMap<>();
+
+ // Sorting this list is incredibly important to correctness. Otherwise,
+ // "shuffled" suites would break the sharding protocol.
+ List<Description> sortedDescriptions = new ArrayList<>(descriptions);
+ Collections.sort(sortedDescriptions, new WeightClassAndTestNameComparator(weightStrategy));
+
+ PriorityQueue<Shard> queue = new PriorityQueue<>(numShards);
+ for (int i = 0; i < numShards; i++) {
+ queue.offer(new Shard(i));
+ }
+
+ // If we get two descriptions that are equal, the shard number for the second
+ // one will overwrite the shard number for the first. Thus they'll run on the
+ // same shard.
+ for (Description description : sortedDescriptions) {
+ if (!description.isTest()) {
+ throw new IllegalArgumentException("Test suite should not be included in the set of tests "
+ + "to shard: " + description.getDisplayName());
+ }
+
+ Shard shard = queue.remove();
+ shard.addWeight(weightStrategy.getDescriptionWeight(description));
+ queue.offer(shard);
+ map.put(description, shard.getIndex());
+ }
+ return Collections.unmodifiableMap(map);
+ }
+
+ /**
+ * A comparator that sorts by weight in descending order, then by test case name.
+ */
+ private static class WeightClassAndTestNameComparator implements Comparator<Description> {
+
+ private final WeightStrategy weightStrategy;
+
+ WeightClassAndTestNameComparator(WeightStrategy weightStrategy) {
+ this.weightStrategy = weightStrategy;
+ }
+
+ @Override
+ public int compare(Description d1, Description d2) {
+ int weight1 = weightStrategy.getDescriptionWeight(d1);
+ int weight2 = weightStrategy.getDescriptionWeight(d2);
+ if (weight1 != weight2) {
+ // We consider the reverse order when comparing weights.
+ return -1 * compareInts(weight1, weight2);
+ }
+ return d1.getDisplayName().compareTo(d2.getDisplayName());
+ }
+ }
+
+ /**
+ * A bean representing the sum of {@link RuntimeCost}s assigned to a shard.
+ */
+ private static class Shard implements Comparable<Shard> {
+ private final int index;
+ private int weight = 0;
+
+ Shard(int index) {
+ this.index = index;
+ }
+
+ void addWeight(int weight) {
+ this.weight += weight;
+ }
+
+ int getIndex() {
+ return index;
+ }
+
+ @Override
+ public int compareTo(Shard other) {
+ if (weight != other.weight) {
+ return compareInts(weight, other.weight);
+ }
+ return compareInts(index, other.index);
+ }
+ }
+
+ private static int compareInts(int value1, int value2) {
+ return value1 - value2;
+ }
+}