diff options
author | 2017-06-21 23:12:51 +0200 | |
---|---|---|
committer | 2017-06-22 12:41:41 +0200 | |
commit | bf2e2d8003a2418941404741fce82f1e51b3b28b (patch) | |
tree | 635b390996de6d1d8207ee0f923e00b31646aa0e /src/main/java/com/google/devtools | |
parent | aade4f64350a7817e452225733ae98cf19c63a69 (diff) |
In the Blaze Query implementation, use Set and Map implementations backed by the same KeyExtractor used that the Uniquifier implementation uses. This fixes a hypothetical issue where we were previously relying on Target#equals/hashCode.
RELNOTES: None
PiperOrigin-RevId: 159741545
Diffstat (limited to 'src/main/java/com/google/devtools')
19 files changed, 371 insertions, 182 deletions
diff --git a/src/main/java/com/google/devtools/build/lib/query2/AbstractBlazeQueryEnvironment.java b/src/main/java/com/google/devtools/build/lib/query2/AbstractBlazeQueryEnvironment.java index 4b1a1d5e6e..55cb759a8d 100644 --- a/src/main/java/com/google/devtools/build/lib/query2/AbstractBlazeQueryEnvironment.java +++ b/src/main/java/com/google/devtools/build/lib/query2/AbstractBlazeQueryEnvironment.java @@ -238,7 +238,8 @@ public abstract class AbstractBlazeQueryEnvironment<T> extends AbstractQueryEnvi return true; } - public QueryTaskFuture<Set<T>> evalTargetPattern(QueryExpression caller, String pattern) { + public QueryTaskFuture<ThreadSafeMutableSet<T>> evalTargetPattern( + QueryExpression caller, String pattern) { try { preloadOrThrow(caller, ImmutableList.of(pattern)); } catch (TargetParsingException tpe) { @@ -253,14 +254,15 @@ public abstract class AbstractBlazeQueryEnvironment<T> extends AbstractQueryEnvi } catch (InterruptedException e) { return immediateCancelledFuture(); } - final AggregateAllCallback<T> aggregatingCallback = QueryUtil.newAggregateAllCallback(); + final AggregateAllCallback<T, ThreadSafeMutableSet<T>> aggregatingCallback = + QueryUtil.newAggregateAllCallback(this); QueryTaskFuture<Void> evalFuture = getTargetsMatchingPattern(caller, pattern, aggregatingCallback); return whenSucceedsCall( evalFuture, - new QueryTaskCallable<Set<T>>() { + new QueryTaskCallable<ThreadSafeMutableSet<T>>() { @Override - public Set<T> call() { + public ThreadSafeMutableSet<T> call() { return aggregatingCallback.getResult(); } }); diff --git a/src/main/java/com/google/devtools/build/lib/query2/BlazeQueryEnvironment.java b/src/main/java/com/google/devtools/build/lib/query2/BlazeQueryEnvironment.java index 361514b5e7..7fa00512d9 100644 --- a/src/main/java/com/google/devtools/build/lib/query2/BlazeQueryEnvironment.java +++ b/src/main/java/com/google/devtools/build/lib/query2/BlazeQueryEnvironment.java @@ -16,12 +16,15 @@ package com.google.devtools.build.lib.query2; import com.google.common.base.Function; import com.google.common.base.Predicate; import com.google.common.collect.Collections2; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Maps; import com.google.devtools.build.lib.cmdline.Label; import com.google.devtools.build.lib.cmdline.LabelSyntaxException; +import com.google.devtools.build.lib.cmdline.PackageIdentifier; import com.google.devtools.build.lib.cmdline.ResolvedTargets; import com.google.devtools.build.lib.cmdline.TargetParsingException; +import com.google.devtools.build.lib.concurrent.ThreadSafety.ThreadSafe; import com.google.devtools.build.lib.events.ExtendedEventHandler; import com.google.devtools.build.lib.graph.Digraph; import com.google.devtools.build.lib.graph.Node; @@ -42,6 +45,8 @@ import com.google.devtools.build.lib.query2.engine.QueryEvalResult; import com.google.devtools.build.lib.query2.engine.QueryException; import com.google.devtools.build.lib.query2.engine.QueryExpression; import com.google.devtools.build.lib.query2.engine.QueryUtil.MinDepthUniquifierImpl; +import com.google.devtools.build.lib.query2.engine.QueryUtil.MutableKeyExtractorBackedMapImpl; +import com.google.devtools.build.lib.query2.engine.QueryUtil.ThreadSafeMutableKeyExtractorBackedSetImpl; import com.google.devtools.build.lib.query2.engine.QueryUtil.UniquifierImpl; import com.google.devtools.build.lib.query2.engine.SkyframeRestartQueryException; import com.google.devtools.build.lib.query2.engine.ThreadSafeOutputFormatterCallback; @@ -226,7 +231,7 @@ public class BlazeQueryEnvironment extends AbstractBlazeQueryEnvironment<Target> @Override public Collection<Target> getFwdDeps(Iterable<Target> targets) { - Set<Target> result = new HashSet<>(); + ThreadSafeMutableSet<Target> result = createThreadSafeMutableSet(); for (Target target : targets) { result.addAll(getTargetsFromNodes(getNode(target).getSuccessors())); } @@ -235,7 +240,7 @@ public class BlazeQueryEnvironment extends AbstractBlazeQueryEnvironment<Target> @Override public Collection<Target> getReverseDeps(Iterable<Target> targets) { - Set<Target> result = new HashSet<>(); + ThreadSafeMutableSet<Target> result = createThreadSafeMutableSet(); for (Target target : targets) { result.addAll(getTargetsFromNodes(getNode(target).getPredecessors())); } @@ -243,7 +248,8 @@ public class BlazeQueryEnvironment extends AbstractBlazeQueryEnvironment<Target> } @Override - public Set<Target> getTransitiveClosure(Set<Target> targetNodes) { + public ThreadSafeMutableSet<Target> getTransitiveClosure( + ThreadSafeMutableSet<Target> targetNodes) { for (Target node : targetNodes) { checkBuilt(node); } @@ -270,11 +276,10 @@ public class BlazeQueryEnvironment extends AbstractBlazeQueryEnvironment<Target> @Override public void buildTransitiveClosure(QueryExpression caller, - Set<Target> targetNodes, + ThreadSafeMutableSet<Target> targetNodes, int maxDepth) throws QueryException, InterruptedException { - Set<Target> targets = targetNodes; - preloadTransitiveClosure(targets, maxDepth); - labelVisitor.syncWithVisitor(eventHandler, targets, keepGoing, + preloadTransitiveClosure(targetNodes, maxDepth); + labelVisitor.syncWithVisitor(eventHandler, targetNodes, keepGoing, loadingPhaseThreads, maxDepth, errorObserver, new GraphBuildingObserver()); if (errorObserver.hasErrors()) { @@ -283,8 +288,24 @@ public class BlazeQueryEnvironment extends AbstractBlazeQueryEnvironment<Target> } @Override - public Set<Target> getNodesOnPath(Target from, Target to) { - return getTargetsFromNodes(graph.getShortestPath(getNode(from), getNode(to))); + public Iterable<Target> getNodesOnPath(Target from, Target to) { + ImmutableList.Builder<Target> builder = ImmutableList.builder(); + for (Node<Target> node : graph.getShortestPath(getNode(from), getNode(to))) { + builder.add(node.getLabel()); + } + return builder.build(); + } + + @ThreadSafe + @Override + public ThreadSafeMutableSet<Target> createThreadSafeMutableSet() { + return new ThreadSafeMutableKeyExtractorBackedSetImpl<>( + TargetKeyExtractor.INSTANCE, Target.class); + } + + @Override + public <V> MutableMap<Target, V> createMutableMap() { + return new MutableKeyExtractorBackedMapImpl<>(TargetKeyExtractor.INSTANCE); } @Override @@ -297,7 +318,7 @@ public class BlazeQueryEnvironment extends AbstractBlazeQueryEnvironment<Target> return new MinDepthUniquifierImpl<>(TargetKeyExtractor.INSTANCE, /*concurrencyLevel=*/ 1); } - private void preloadTransitiveClosure(Set<Target> targets, int maxDepth) + private void preloadTransitiveClosure(ThreadSafeMutableSet<Target> targets, int maxDepth) throws InterruptedException { if (maxDepth >= MAX_DEPTH_FULL_SCAN_LIMIT && transitivePackageLoader != null) { // Only do the full visitation if "maxDepth" is large enough. Otherwise, the benefits of @@ -350,15 +371,15 @@ public class BlazeQueryEnvironment extends AbstractBlazeQueryEnvironment<Target> // TODO(bazel-team): rename this to getDependentFiles when all implementations // of QueryEnvironment is fixed. @Override - public Set<Target> getBuildFiles( + public ThreadSafeMutableSet<Target> getBuildFiles( final QueryExpression caller, - Set<Target> nodes, + ThreadSafeMutableSet<Target> nodes, boolean buildFiles, boolean subincludes, boolean loads) throws QueryException { - Set<Target> dependentFiles = new LinkedHashSet<>(); - Set<Package> seenPackages = new HashSet<>(); + ThreadSafeMutableSet<Target> dependentFiles = createThreadSafeMutableSet(); + Set<PackageIdentifier> seenPackages = new HashSet<>(); // Keep track of seen labels, to avoid adding a fake subinclude label that also exists as a // real target. Set<Label> seenLabels = new HashSet<>(); @@ -367,7 +388,7 @@ public class BlazeQueryEnvironment extends AbstractBlazeQueryEnvironment<Target> // extensions) for package "pkg", to "buildfiles". for (Target x : nodes) { Package pkg = x.getPackage(); - if (seenPackages.add(pkg)) { + if (seenPackages.add(pkg.getPackageIdentifier())) { if (buildFiles) { addIfUniqueLabel(getNode(pkg.getBuildFile()), seenLabels, dependentFiles); } @@ -438,8 +459,8 @@ public class BlazeQueryEnvironment extends AbstractBlazeQueryEnvironment<Target> } /** Given a set of target nodes, returns the targets. */ - private static Set<Target> getTargetsFromNodes(Iterable<Node<Target>> input) { - Set<Target> result = new LinkedHashSet<>(); + private ThreadSafeMutableSet<Target> getTargetsFromNodes(Iterable<Node<Target>> input) { + ThreadSafeMutableSet<Target> result = createThreadSafeMutableSet(); for (Node<Target> node : input) { result.add(node.getLabel()); } diff --git a/src/main/java/com/google/devtools/build/lib/query2/SkyQueryEnvironment.java b/src/main/java/com/google/devtools/build/lib/query2/SkyQueryEnvironment.java index 8a22522d26..a2cf884ffd 100644 --- a/src/main/java/com/google/devtools/build/lib/query2/SkyQueryEnvironment.java +++ b/src/main/java/com/google/devtools/build/lib/query2/SkyQueryEnvironment.java @@ -62,17 +62,21 @@ import com.google.devtools.build.lib.packages.Target; import com.google.devtools.build.lib.pkgcache.PathPackageLocator; import com.google.devtools.build.lib.pkgcache.TargetPatternEvaluator; import com.google.devtools.build.lib.profiler.AutoProfiler; +import com.google.devtools.build.lib.query2.AbstractBlazeQueryEnvironment.TargetKeyExtractor; import com.google.devtools.build.lib.query2.engine.AllRdepsFunction; import com.google.devtools.build.lib.query2.engine.Callback; import com.google.devtools.build.lib.query2.engine.FunctionExpression; import com.google.devtools.build.lib.query2.engine.KeyExtractor; import com.google.devtools.build.lib.query2.engine.MinDepthUniquifier; import com.google.devtools.build.lib.query2.engine.OutputFormatterCallback; +import com.google.devtools.build.lib.query2.engine.QueryEnvironment.MutableMap; import com.google.devtools.build.lib.query2.engine.QueryEvalResult; import com.google.devtools.build.lib.query2.engine.QueryException; import com.google.devtools.build.lib.query2.engine.QueryExpression; import com.google.devtools.build.lib.query2.engine.QueryExpressionMapper; import com.google.devtools.build.lib.query2.engine.QueryUtil.MinDepthUniquifierImpl; +import com.google.devtools.build.lib.query2.engine.QueryUtil.MutableKeyExtractorBackedMapImpl; +import com.google.devtools.build.lib.query2.engine.QueryUtil.ThreadSafeMutableKeyExtractorBackedSetImpl; import com.google.devtools.build.lib.query2.engine.QueryUtil.UniquifierImpl; import com.google.devtools.build.lib.query2.engine.RdepsFunction; import com.google.devtools.build.lib.query2.engine.StreamableQueryEnvironment; @@ -110,7 +114,6 @@ import java.util.Collection; import java.util.Deque; import java.util.HashMap; import java.util.HashSet; -import java.util.LinkedHashSet; import java.util.LinkedList; import java.util.List; import java.util.Map; @@ -430,30 +433,6 @@ public class SkyQueryEnvironment extends AbstractBlazeQueryEnvironment<Target> return result.build(); } - private Map<Target, Collection<Target>> targetifyKeys(Map<SkyKey, Collection<Target>> input) - throws InterruptedException { - Map<SkyKey, Target> targets = makeTargetsFromSkyKeys(input.keySet()); - ImmutableMap.Builder<Target, Collection<Target>> resultBuilder = ImmutableMap.builder(); - for (Map.Entry<SkyKey, Collection<Target>> entry : input.entrySet()) { - SkyKey key = entry.getKey(); - Target target = targets.get(key); - if (target != null) { - resultBuilder.put(target, entry.getValue()); - } - } - return resultBuilder.build(); - } - - private Map<Target, Collection<Target>> targetifyKeysAndValues( - Map<SkyKey, Iterable<SkyKey>> input) throws InterruptedException { - return targetifyKeys(targetifyValues(input)); - } - - private Map<Target, Collection<Target>> getRawFwdDeps(Iterable<Target> targets) - throws InterruptedException { - return targetifyKeysAndValues(graph.getDirectDeps(makeTransitiveTraversalKeys(targets))); - } - private Map<SkyKey, Collection<Target>> getRawReverseDeps( Iterable<SkyKey> transitiveTraversalKeys) throws InterruptedException { return targetifyValues(graph.getReverseDeps(transitiveTraversalKeys)); @@ -482,22 +461,24 @@ public class SkyQueryEnvironment extends AbstractBlazeQueryEnvironment<Target> }); } - /** Targets may not be in the graph because they are not in the universe or depend on cycles. */ - private void warnIfMissingTargets( - Iterable<Target> targets, Set<Target> result) { - if (Iterables.size(targets) != result.size()) { - Set<Target> missingTargets = Sets.difference(ImmutableSet.copyOf(targets), result); + @Override + public ThreadSafeMutableSet<Target> getFwdDeps(Iterable<Target> targets) + throws InterruptedException { + Map<SkyKey, Target> targetsByKey = new HashMap<>(Iterables.size(targets)); + for (Target target : targets) { + targetsByKey.put(TARGET_TO_SKY_KEY.apply(target), target); + } + Map<SkyKey, Collection<Target>> directDeps = targetifyValues( + graph.getDirectDeps(targetsByKey.keySet())); + if (targetsByKey.keySet().size() != directDeps.keySet().size()) { + Iterable<Label> missingTargets = Iterables.transform( + Sets.difference(targetsByKey.keySet(), directDeps.keySet()), + SKYKEY_TO_LABEL); eventHandler.handle(Event.warn("Targets were missing from graph: " + missingTargets)); } - } - - @Override - public Collection<Target> getFwdDeps(Iterable<Target> targets) throws InterruptedException { - Set<Target> result = new HashSet<>(); - Map<Target, Collection<Target>> rawFwdDeps = getRawFwdDeps(targets); - warnIfMissingTargets(targets, rawFwdDeps.keySet()); - for (Map.Entry<Target, Collection<Target>> entry : rawFwdDeps.entrySet()) { - result.addAll(filterFwdDeps(entry.getKey(), entry.getValue())); + ThreadSafeMutableSet<Target> result = createThreadSafeMutableSet(); + for (Map.Entry<SkyKey, Collection<Target>> entry : directDeps.entrySet()) { + result.addAll(filterFwdDeps(targetsByKey.get(entry.getKey()), entry.getValue())); } return result; } @@ -555,35 +536,46 @@ public class SkyQueryEnvironment extends AbstractBlazeQueryEnvironment<Target> } @Override - public Set<Target> getTransitiveClosure(Set<Target> targets) throws InterruptedException { - Set<Target> visited = new HashSet<>(); - Collection<Target> current = targets; + public ThreadSafeMutableSet<Target> getTransitiveClosure(ThreadSafeMutableSet<Target> targets) + throws InterruptedException { + ThreadSafeMutableSet<Target> visited = createThreadSafeMutableSet(); + ThreadSafeMutableSet<Target> current = targets; while (!current.isEmpty()) { - Collection<Target> toVisit = Collections2.filter(current, + Iterable<Target> toVisit = Iterables.filter(current, Predicates.not(Predicates.in(visited))); current = getFwdDeps(toVisit); - visited.addAll(toVisit); + Iterables.addAll(visited, toVisit); } - return ImmutableSet.copyOf(visited); + return visited; } // Implemented with a breadth-first search. @Override - public Set<Target> getNodesOnPath(Target from, Target to) throws InterruptedException { + public ImmutableList<Target> getNodesOnPath(Target from, Target to) + throws InterruptedException { // Tree of nodes visited so far. - Map<Target, Target> nodeToParent = new HashMap<>(); + Map<Label, Label> nodeToParent = new HashMap<>(); + Map<Label, Target> labelToTarget = new HashMap<>(); // Contains all nodes left to visit in a (LIFO) stack. Deque<Target> toVisit = new ArrayDeque<>(); toVisit.add(from); - nodeToParent.put(from, null); + nodeToParent.put(from.getLabel(), null); + labelToTarget.put(from.getLabel(), from); while (!toVisit.isEmpty()) { Target current = toVisit.removeFirst(); if (to.equals(current)) { - return ImmutableSet.copyOf(Digraph.getPathToTreeNode(nodeToParent, to)); + List<Label> labelPath = Digraph.getPathToTreeNode(nodeToParent, to.getLabel()); + ImmutableList.Builder<Target> targetPathBuilder = ImmutableList.builder(); + for (Label label : labelPath) { + targetPathBuilder.add(Preconditions.checkNotNull(labelToTarget.get(label), label)); + } + return targetPathBuilder.build(); } for (Target dep : getFwdDeps(ImmutableList.of(current))) { - if (!nodeToParent.containsKey(dep)) { - nodeToParent.put(dep, current); + Label depLabel = dep.getLabel(); + if (!nodeToParent.containsKey(depLabel)) { + nodeToParent.put(depLabel, current.getLabel()); + labelToTarget.put(depLabel, dep); toVisit.addFirst(dep); } } @@ -649,6 +641,18 @@ public class SkyQueryEnvironment extends AbstractBlazeQueryEnvironment<Target> @ThreadSafe @Override + public ThreadSafeMutableSet<Target> createThreadSafeMutableSet() { + return new ThreadSafeMutableKeyExtractorBackedSetImpl<>( + TargetKeyExtractor.INSTANCE, Target.class, DEFAULT_THREAD_COUNT); + } + + @Override + public <V> MutableMap<Target, V> createMutableMap() { + return new MutableKeyExtractorBackedMapImpl<Target, Label, V>(TargetKeyExtractor.INSTANCE); + } + + @ThreadSafe + @Override public Uniquifier<Target> createUniquifier() { return createTargetUniquifier(); } @@ -731,15 +735,15 @@ public class SkyQueryEnvironment extends AbstractBlazeQueryEnvironment<Target> @ThreadSafe @Override - public Set<Target> getBuildFiles( + public ThreadSafeMutableSet<Target> getBuildFiles( QueryExpression caller, - Set<Target> nodes, + ThreadSafeMutableSet<Target> nodes, boolean buildFiles, boolean subincludes, boolean loads) throws QueryException { - Set<Target> dependentFiles = new LinkedHashSet<>(); - Set<Package> seenPackages = new HashSet<>(); + ThreadSafeMutableSet<Target> dependentFiles = createThreadSafeMutableSet(); + Set<PackageIdentifier> seenPackages = new HashSet<>(); // Keep track of seen labels, to avoid adding a fake subinclude label that also exists as a // real target. Set<Label> seenLabels = new HashSet<>(); @@ -748,7 +752,7 @@ public class SkyQueryEnvironment extends AbstractBlazeQueryEnvironment<Target> // extensions) for package "pkg", to "buildfiles". for (Target x : nodes) { Package pkg = x.getPackage(); - if (seenPackages.add(pkg)) { + if (seenPackages.add(pkg.getPackageIdentifier())) { if (buildFiles) { addIfUniqueLabel(pkg.getBuildFile(), seenLabels, dependentFiles); } @@ -843,8 +847,10 @@ public class SkyQueryEnvironment extends AbstractBlazeQueryEnvironment<Target> } @Override - public void buildTransitiveClosure(QueryExpression caller, Set<Target> targets, int maxDepth) - throws QueryException, InterruptedException { + public void buildTransitiveClosure( + QueryExpression caller, + ThreadSafeMutableSet<Target> targets, + int maxDepth) throws QueryException, InterruptedException { // Everything has already been loaded, so here we just check for errors so that we can // pre-emptively throw/report if needed. Iterable<SkyKey> transitiveTraversalKeys = makeTransitiveTraversalKeys(targets); diff --git a/src/main/java/com/google/devtools/build/lib/query2/engine/AbstractQueryEnvironment.java b/src/main/java/com/google/devtools/build/lib/query2/engine/AbstractQueryEnvironment.java index a078a4b7ad..84529ca3ed 100644 --- a/src/main/java/com/google/devtools/build/lib/query2/engine/AbstractQueryEnvironment.java +++ b/src/main/java/com/google/devtools/build/lib/query2/engine/AbstractQueryEnvironment.java @@ -127,8 +127,8 @@ public abstract class AbstractQueryEnvironment<T> implements QueryEnvironment<T> // QueryEnvironment#buildTransitiveClosure. So if the implementation of that method does some // heavyweight blocking work, then it's best to do this blocking work in a single batch. // Importantly, the callback we pass in needs to maintain order. - final QueryUtil.AggregateAllCallback<T> aggregateAllCallback = - QueryUtil.newOrderedAggregateAllOutputFormatterCallback(); + final QueryUtil.AggregateAllCallback<T, ?> aggregateAllCallback = + QueryUtil.newOrderedAggregateAllOutputFormatterCallback(this); QueryTaskFuture<Void> evalAllFuture = expr.eval(this, context, aggregateAllCallback); return whenSucceedsCall( evalAllFuture, diff --git a/src/main/java/com/google/devtools/build/lib/query2/engine/AllPathsFunction.java b/src/main/java/com/google/devtools/build/lib/query2/engine/AllPathsFunction.java index 81be4c8ca2..e0b5a45acd 100644 --- a/src/main/java/com/google/devtools/build/lib/query2/engine/AllPathsFunction.java +++ b/src/main/java/com/google/devtools/build/lib/query2/engine/AllPathsFunction.java @@ -23,6 +23,7 @@ import com.google.devtools.build.lib.query2.engine.QueryEnvironment.ArgumentType import com.google.devtools.build.lib.query2.engine.QueryEnvironment.QueryFunction; import com.google.devtools.build.lib.query2.engine.QueryEnvironment.QueryTaskCallable; import com.google.devtools.build.lib.query2.engine.QueryEnvironment.QueryTaskFuture; +import com.google.devtools.build.lib.query2.engine.QueryEnvironment.ThreadSafeMutableSet; import java.util.Collection; import java.util.HashSet; import java.util.List; @@ -57,9 +58,9 @@ public class AllPathsFunction implements QueryFunction { final QueryExpression expression, List<Argument> args, final Callback<T> callback) { - final QueryTaskFuture<Set<T>> fromValueFuture = + final QueryTaskFuture<ThreadSafeMutableSet<T>> fromValueFuture = QueryUtil.evalAll(env, context, args.get(0).getExpression()); - final QueryTaskFuture<Set<T>> toValueFuture = + final QueryTaskFuture<ThreadSafeMutableSet<T>> toValueFuture = QueryUtil.evalAll(env, context, args.get(1).getExpression()); return env.whenAllSucceedCall( @@ -73,8 +74,8 @@ public class AllPathsFunction implements QueryFunction { // closure and intersection operations are interleaved for efficiency. // "result" holds the intersection. - Set<T> fromValue = fromValueFuture.getIfSuccessful(); - Set<T> toValue = toValueFuture.getIfSuccessful(); + ThreadSafeMutableSet<T> fromValue = fromValueFuture.getIfSuccessful(); + ThreadSafeMutableSet<T> toValue = toValueFuture.getIfSuccessful(); env.buildTransitiveClosure(expression, fromValue, Integer.MAX_VALUE); @@ -85,7 +86,7 @@ public class AllPathsFunction implements QueryFunction { callback.process(result); Collection<T> worklist = result; while (!worklist.isEmpty()) { - Collection<T> reverseDeps = env.getReverseDeps(worklist); + Iterable<T> reverseDeps = env.getReverseDeps(worklist); worklist = uniquifier.unique(Iterables.filter(reverseDeps, reachable)); callback.process(worklist); } diff --git a/src/main/java/com/google/devtools/build/lib/query2/engine/AllRdepsFunction.java b/src/main/java/com/google/devtools/build/lib/query2/engine/AllRdepsFunction.java index d7123a4840..2024b8d4c7 100644 --- a/src/main/java/com/google/devtools/build/lib/query2/engine/AllRdepsFunction.java +++ b/src/main/java/com/google/devtools/build/lib/query2/engine/AllRdepsFunction.java @@ -95,8 +95,10 @@ public class AllRdepsFunction implements QueryFunction { // Filter already visited nodes: if we see a node in a later round, then we don't // need to visit it again, because the depth at which we see it must be greater // than or equal to the last visit. - next.addAll(env.getReverseDeps( - minDepthUniquifier.uniqueAtDepthLessThanOrEqualTo(currentInUniverse, i))); + Iterables.addAll( + next, + env.getReverseDeps( + minDepthUniquifier.uniqueAtDepthLessThanOrEqualTo(currentInUniverse, i))); callback.process(currentInUniverse); if (next.isEmpty()) { // Exit when there are no more nodes to visit. diff --git a/src/main/java/com/google/devtools/build/lib/query2/engine/BinaryOperatorExpression.java b/src/main/java/com/google/devtools/build/lib/query2/engine/BinaryOperatorExpression.java index f9d20dbb19..cd042ea235 100644 --- a/src/main/java/com/google/devtools/build/lib/query2/engine/BinaryOperatorExpression.java +++ b/src/main/java/com/google/devtools/build/lib/query2/engine/BinaryOperatorExpression.java @@ -15,9 +15,9 @@ package com.google.devtools.build.lib.query2.engine; import com.google.common.base.Function; import com.google.common.collect.ImmutableList; -import com.google.common.collect.Sets; import com.google.devtools.build.lib.query2.engine.QueryEnvironment.QueryTaskCallable; import com.google.devtools.build.lib.query2.engine.QueryEnvironment.QueryTaskFuture; +import com.google.devtools.build.lib.query2.engine.QueryEnvironment.ThreadSafeMutableSet; import com.google.devtools.build.lib.util.Preconditions; import java.util.ArrayList; import java.util.Collection; @@ -100,12 +100,13 @@ public class BinaryOperatorExpression extends QueryExpression { final QueryEnvironment<T> env, final VariableContext<T> context, final Callback<T> callback) { - QueryTaskFuture<Set<T>> lhsValueFuture = QueryUtil.evalAll(env, context, operands.get(0)); - Function<Set<T>, QueryTaskFuture<Void>> substractAsyncFunction = - new Function<Set<T>, QueryTaskFuture<Void>>() { + QueryTaskFuture<ThreadSafeMutableSet<T>> lhsValueFuture = + QueryUtil.evalAll(env, context, operands.get(0)); + Function<ThreadSafeMutableSet<T>, QueryTaskFuture<Void>> subtractAsyncFunction = + new Function<ThreadSafeMutableSet<T>, QueryTaskFuture<Void>>() { @Override - public QueryTaskFuture<Void> apply(Set<T> lhsValue) { - final Set<T> threadSafeLhsValue = Sets.newConcurrentHashSet(lhsValue); + public QueryTaskFuture<Void> apply(ThreadSafeMutableSet<T> lhsValue) { + final Set<T> threadSafeLhsValue = lhsValue; Callback<T> subtractionCallback = new Callback<T>() { @Override public void process(Iterable<T> partialResult) { @@ -127,7 +128,7 @@ public class BinaryOperatorExpression extends QueryExpression { }); } }; - return env.transformAsync(lhsValueFuture, substractAsyncFunction); + return env.transformAsync(lhsValueFuture, subtractAsyncFunction); } private <T> QueryTaskFuture<Void> evalIntersect( @@ -141,20 +142,24 @@ public class BinaryOperatorExpression extends QueryExpression { // TODO(bazel-team): Consider keeping just the name / label of the right-hand side results // instead of the potentially heavy-weight instances of type T. This would let us process all // right-hand side operands in parallel without worrying about memory usage. - QueryTaskFuture<Set<T>> rollingResultFuture = QueryUtil.evalAll(env, context, operands.get(0)); + QueryTaskFuture<ThreadSafeMutableSet<T>> rollingResultFuture = + QueryUtil.evalAll(env, context, operands.get(0)); for (int i = 1; i < operands.size(); i++) { final int index = i; - Function<Set<T>, QueryTaskFuture<Set<T>>> evalOperandAndIntersectAsyncFunction = - new Function<Set<T>, QueryTaskFuture<Set<T>>>() { + Function<ThreadSafeMutableSet<T>, QueryTaskFuture<ThreadSafeMutableSet<T>>> + evalOperandAndIntersectAsyncFunction = + new Function<ThreadSafeMutableSet<T>, QueryTaskFuture<ThreadSafeMutableSet<T>>>() { @Override - public QueryTaskFuture<Set<T>> apply(final Set<T> rollingResult) { - final QueryTaskFuture<Set<T>> rhsOperandValueFuture = + public QueryTaskFuture<ThreadSafeMutableSet<T>> apply( + final ThreadSafeMutableSet<T> rollingResult) { + final QueryTaskFuture<ThreadSafeMutableSet<T>> rhsOperandValueFuture = QueryUtil.evalAll(env, context, operands.get(index)); return env.whenSucceedsCall( rhsOperandValueFuture, - new QueryTaskCallable<Set<T>>() { + new QueryTaskCallable<ThreadSafeMutableSet<T>>() { @Override - public Set<T> call() throws QueryException, InterruptedException { + public ThreadSafeMutableSet<T> call() + throws QueryException, InterruptedException { rollingResult.retainAll(rhsOperandValueFuture.getIfSuccessful()); return rollingResult; } @@ -164,7 +169,7 @@ public class BinaryOperatorExpression extends QueryExpression { rollingResultFuture = env.transformAsync(rollingResultFuture, evalOperandAndIntersectAsyncFunction); } - final QueryTaskFuture<Set<T>> resultFuture = rollingResultFuture; + final QueryTaskFuture<ThreadSafeMutableSet<T>> resultFuture = rollingResultFuture; return env.whenSucceedsCall( resultFuture, new QueryTaskCallable<Void>() { diff --git a/src/main/java/com/google/devtools/build/lib/query2/engine/BuildFilesFunction.java b/src/main/java/com/google/devtools/build/lib/query2/engine/BuildFilesFunction.java index c64cb6c5fa..e0feac5de1 100644 --- a/src/main/java/com/google/devtools/build/lib/query2/engine/BuildFilesFunction.java +++ b/src/main/java/com/google/devtools/build/lib/query2/engine/BuildFilesFunction.java @@ -15,13 +15,12 @@ package com.google.devtools.build.lib.query2.engine; import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterables; -import com.google.devtools.build.lib.collect.CompactHashSet; import com.google.devtools.build.lib.query2.engine.QueryEnvironment.Argument; import com.google.devtools.build.lib.query2.engine.QueryEnvironment.ArgumentType; import com.google.devtools.build.lib.query2.engine.QueryEnvironment.QueryFunction; import com.google.devtools.build.lib.query2.engine.QueryEnvironment.QueryTaskFuture; +import com.google.devtools.build.lib.query2.engine.QueryEnvironment.ThreadSafeMutableSet; import java.util.List; -import java.util.Set; /** * A buildfiles(x) query expression, which computes the set of BUILD files and @@ -55,7 +54,7 @@ public class BuildFilesFunction implements QueryFunction { @Override public void process(Iterable<T> partialResult) throws QueryException, InterruptedException { - Set<T> result = CompactHashSet.create(); + ThreadSafeMutableSet<T> result = env.createThreadSafeMutableSet(); Iterables.addAll(result, partialResult); callback.process(uniquifier.unique( env.getBuildFiles( diff --git a/src/main/java/com/google/devtools/build/lib/query2/engine/DepsFunction.java b/src/main/java/com/google/devtools/build/lib/query2/engine/DepsFunction.java index 8b1fc37d62..de4cd341dd 100644 --- a/src/main/java/com/google/devtools/build/lib/query2/engine/DepsFunction.java +++ b/src/main/java/com/google/devtools/build/lib/query2/engine/DepsFunction.java @@ -14,14 +14,13 @@ package com.google.devtools.build.lib.query2.engine; import com.google.common.collect.ImmutableList; -import com.google.common.collect.Sets; +import com.google.common.collect.Iterables; import com.google.devtools.build.lib.query2.engine.QueryEnvironment.Argument; import com.google.devtools.build.lib.query2.engine.QueryEnvironment.ArgumentType; import com.google.devtools.build.lib.query2.engine.QueryEnvironment.QueryFunction; import com.google.devtools.build.lib.query2.engine.QueryEnvironment.QueryTaskFuture; -import java.util.Collection; +import com.google.devtools.build.lib.query2.engine.QueryEnvironment.ThreadSafeMutableSet; import java.util.List; -import java.util.Set; /** * A "deps" query expression, which computes the dependencies of the argument. An optional @@ -64,8 +63,9 @@ final class DepsFunction implements QueryFunction { return env.eval(args.get(0).getExpression(), context, new Callback<T>() { @Override public void process(Iterable<T> partialResult) throws QueryException, InterruptedException { - Collection<T> current = Sets.newHashSet(partialResult); - env.buildTransitiveClosure(expression, (Set<T>) current, depthBound); + ThreadSafeMutableSet<T> current = env.createThreadSafeMutableSet(); + Iterables.addAll(current, partialResult); + env.buildTransitiveClosure(expression, current, depthBound); // We need to iterate depthBound + 1 times. for (int i = 0; i <= depthBound; i++) { @@ -75,7 +75,8 @@ final class DepsFunction implements QueryFunction { ImmutableList<T> toProcess = minDepthUniquifier.uniqueAtDepthLessThanOrEqualTo(current, i); callback.process(toProcess); - current = ImmutableList.copyOf(env.getFwdDeps(toProcess)); + current = env.createThreadSafeMutableSet(); + Iterables.addAll(current, env.getFwdDeps(toProcess)); if (current.isEmpty()) { // Exit when there are no more nodes to visit. break; diff --git a/src/main/java/com/google/devtools/build/lib/query2/engine/LetExpression.java b/src/main/java/com/google/devtools/build/lib/query2/engine/LetExpression.java index a7c3abeb62..05021e6f8a 100644 --- a/src/main/java/com/google/devtools/build/lib/query2/engine/LetExpression.java +++ b/src/main/java/com/google/devtools/build/lib/query2/engine/LetExpression.java @@ -15,8 +15,8 @@ package com.google.devtools.build.lib.query2.engine; import com.google.common.base.Function; import com.google.devtools.build.lib.query2.engine.QueryEnvironment.QueryTaskFuture; +import com.google.devtools.build.lib.query2.engine.QueryEnvironment.ThreadSafeMutableSet; import java.util.Collection; -import java.util.Set; import java.util.regex.Pattern; /** @@ -74,11 +74,12 @@ class LetExpression extends QueryExpression { return env.immediateFailedFuture( new QueryException(this, "invalid variable name '" + varName + "' in let expression")); } - QueryTaskFuture<Set<T>> varValueFuture = QueryUtil.evalAll(env, context, varExpr); - Function<Set<T>, QueryTaskFuture<Void>> evalBodyAsyncFunction = - new Function<Set<T>, QueryTaskFuture<Void>>() { + QueryTaskFuture<ThreadSafeMutableSet<T>> varValueFuture = + QueryUtil.evalAll(env, context, varExpr); + Function<ThreadSafeMutableSet<T>, QueryTaskFuture<Void>> evalBodyAsyncFunction = + new Function<ThreadSafeMutableSet<T>, QueryTaskFuture<Void>>() { @Override - public QueryTaskFuture<Void> apply(Set<T> varValue) { + public QueryTaskFuture<Void> apply(ThreadSafeMutableSet<T> varValue) { VariableContext<T> bodyContext = VariableContext.with(context, varName, varValue); return env.eval(bodyExpr, bodyContext, callback); } diff --git a/src/main/java/com/google/devtools/build/lib/query2/engine/LoadFilesFunction.java b/src/main/java/com/google/devtools/build/lib/query2/engine/LoadFilesFunction.java index dd25f15aa2..a008779ae5 100644 --- a/src/main/java/com/google/devtools/build/lib/query2/engine/LoadFilesFunction.java +++ b/src/main/java/com/google/devtools/build/lib/query2/engine/LoadFilesFunction.java @@ -15,10 +15,9 @@ package com.google.devtools.build.lib.query2.engine; import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterables; -import com.google.devtools.build.lib.collect.CompactHashSet; import com.google.devtools.build.lib.query2.engine.QueryEnvironment.QueryTaskFuture; +import com.google.devtools.build.lib.query2.engine.QueryEnvironment.ThreadSafeMutableSet; import java.util.List; -import java.util.Set; /** * A loadfiles(x) query expression, which computes the set of .bzl files @@ -51,7 +50,7 @@ public class LoadFilesFunction implements QueryEnvironment.QueryFunction { @Override public void process(Iterable<T> partialResult) throws QueryException, InterruptedException { - Set<T> result = CompactHashSet.create(); + ThreadSafeMutableSet<T> result = env.createThreadSafeMutableSet(); Iterables.addAll(result, partialResult); callback.process(uniquifier.unique( env.getBuildFiles( diff --git a/src/main/java/com/google/devtools/build/lib/query2/engine/QueryEnvironment.java b/src/main/java/com/google/devtools/build/lib/query2/engine/QueryEnvironment.java index 8281e5bf11..453a049cb8 100644 --- a/src/main/java/com/google/devtools/build/lib/query2/engine/QueryEnvironment.java +++ b/src/main/java/com/google/devtools/build/lib/query2/engine/QueryEnvironment.java @@ -16,17 +16,20 @@ package com.google.devtools.build.lib.query2.engine; import com.google.common.base.Function; import com.google.common.collect.ImmutableList; import com.google.devtools.build.lib.concurrent.ThreadSafety.ThreadSafe; -import java.util.Collection; import java.util.List; import java.util.Set; import java.util.concurrent.Callable; import javax.annotation.Nonnull; +import javax.annotation.Nullable; /** * The environment of a Blaze query. Implementations do not need to be thread-safe. The generic type * T represents a node of the graph on which the query runs; as such, there is no restriction on T. * However, query assumes a certain graph model, and the {@link TargetAccessor} class is used to - * access properties of these nodes. + * access properties of these nodes. Also, the query engine doesn't assume T's + * {@link Object#hashCode} and {@link Object#equals} are meaningful and instead uses + * {@link QueryEnvironment#createUniquifier}, {@link QueryEnvironment#createThreadSafeMutableSet()}, + * and {@link QueryEnvironment#createMutableMap()} when appropriate. * * @param <T> the node type of the dependency graph */ @@ -152,16 +155,17 @@ public interface QueryEnvironment<T> { T getOrCreate(T target); /** Returns the direct forward dependencies of the specified targets. */ - Collection<T> getFwdDeps(Iterable<T> targets) throws InterruptedException; + Iterable<T> getFwdDeps(Iterable<T> targets) throws InterruptedException; /** Returns the direct reverse dependencies of the specified targets. */ - Collection<T> getReverseDeps(Iterable<T> targets) throws InterruptedException; + Iterable<T> getReverseDeps(Iterable<T> targets) throws InterruptedException; /** * Returns the forward transitive closure of all of the targets in "targets". Callers must ensure * that {@link #buildTransitiveClosure} has been called for the relevant subgraph. */ - Set<T> getTransitiveClosure(Set<T> targets) throws InterruptedException; + ThreadSafeMutableSet<T> getTransitiveClosure(ThreadSafeMutableSet<T> targets) + throws InterruptedException; /** * Construct the dependency graph for a depth-bounded forward transitive closure @@ -173,11 +177,11 @@ public interface QueryEnvironment<T> { * after it is built anyway. */ void buildTransitiveClosure(QueryExpression caller, - Set<T> targetNodes, + ThreadSafeMutableSet<T> targetNodes, int maxDepth) throws QueryException, InterruptedException; - /** Returns the set of nodes on some path from "from" to "to". */ - Set<T> getNodesOnPath(T from, T to) throws InterruptedException; + /** Returns the ordered sequence of nodes on some path from "from" to "to". */ + Iterable<T> getNodesOnPath(T from, T to) throws InterruptedException; /** * Returns a {@link QueryTaskFuture} representing the asynchronous evaluation of the given @@ -335,6 +339,41 @@ public interface QueryEnvironment<T> { } /** + * A mutable {@link ThreadSafe} {@link Set} that uses proper equality semantics for {@code T}. + * {@link QueryExpression}/{@link QueryFunction} implementations should use + * {@code ThreadSafeMutableSet<T>} they need a set-like data structure for {@code T}. + */ + @ThreadSafe + interface ThreadSafeMutableSet<T> extends Set<T> { + } + + /** Returns a fresh {@link ThreadSafeMutableSet} instance for the type {@code T}. */ + ThreadSafeMutableSet<T> createThreadSafeMutableSet(); + + /** + * A simple map-like interface that uses proper equality semantics for the key type. + * {@link QueryExpression}/{@link QueryFunction} implementations should use + * {@code ThreadSafeMutableSet<T, V>} they need a map-like data structure for {@code T}. + */ + interface MutableMap<K, V> { + /** + * Returns the value {@code value} associated with the given key by the most recent call to + * {@code put(key, value)}, or {@code null} if there was no such call. + */ + @Nullable + V get(K key); + + /** + * Associates the given key with the given value and returns the previous value associated with + * the key, or {@code null} if there wasn't one. + */ + V put(K key, V value); + } + + /** Returns a fresh {@link MutableMap} instance with key type {@code T}. */ + <V> MutableMap<T, V> createMutableMap(); + + /** * Creates a Uniquifier for use in a {@code QueryExpression}. Note that the usage of this * uniquifier should not be used for returning unique results to the parent callback. It should * only be used to avoid processing the same elements multiple times within this QueryExpression. @@ -355,9 +394,9 @@ public interface QueryEnvironment<T> { * Returns the set of BUILD, and optionally sub-included and Skylark files that define the given * set of targets. Each such file is itself represented as a target in the result. */ - Set<T> getBuildFiles( - QueryExpression caller, Set<T> nodes, boolean buildFiles, boolean subincludes, boolean loads) - throws QueryException, InterruptedException; + ThreadSafeMutableSet<T> getBuildFiles( + QueryExpression caller, ThreadSafeMutableSet<T> nodes, boolean buildFiles, + boolean subincludes, boolean loads) throws QueryException, InterruptedException; /** * Returns an object that can be used to query information about targets. Implementations should diff --git a/src/main/java/com/google/devtools/build/lib/query2/engine/QueryUtil.java b/src/main/java/com/google/devtools/build/lib/query2/engine/QueryUtil.java index b423803744..73dd930ac3 100644 --- a/src/main/java/com/google/devtools/build/lib/query2/engine/QueryUtil.java +++ b/src/main/java/com/google/devtools/build/lib/query2/engine/QueryUtil.java @@ -16,14 +16,21 @@ package com.google.devtools.build.lib.query2.engine; import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterables; import com.google.common.collect.MapMaker; -import com.google.common.collect.Sets; import com.google.devtools.build.lib.collect.CompactHashSet; +import com.google.devtools.build.lib.query2.engine.QueryEnvironment.MutableMap; import com.google.devtools.build.lib.query2.engine.QueryEnvironment.QueryTaskCallable; import com.google.devtools.build.lib.query2.engine.QueryEnvironment.QueryTaskFuture; +import com.google.devtools.build.lib.query2.engine.QueryEnvironment.ThreadSafeMutableSet; +import java.util.AbstractSet; +import java.util.ArrayList; import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; import java.util.Set; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.atomic.AtomicInteger; +import javax.annotation.Nullable; /** Several query utilities to make easier to work with query callbacks and uniquifiers. */ public final class QueryUtil { @@ -31,19 +38,23 @@ public final class QueryUtil { private QueryUtil() { } /** A {@link Callback} that can aggregate all the partial results into one set. */ - public interface AggregateAllCallback<T> extends Callback<T> { - /** Returns a (mutable) set of all the results. */ - Set<T> getResult(); + public interface AggregateAllCallback<T, S extends Set<T>> extends Callback<T> { + /** Returns a {@link Set} of all the results. */ + S getResult(); } /** A {@link OutputFormatterCallback} that is also a {@link AggregateAllCallback}. */ - public abstract static class AggregateAllOutputFormatterCallback<T> - extends ThreadSafeOutputFormatterCallback<T> implements AggregateAllCallback<T> { + public abstract static class AggregateAllOutputFormatterCallback<T, S extends Set<T>> + extends ThreadSafeOutputFormatterCallback<T> implements AggregateAllCallback<T, S> { } private static class AggregateAllOutputFormatterCallbackImpl<T> - extends AggregateAllOutputFormatterCallback<T> { - private final Set<T> result = Sets.newConcurrentHashSet(); + extends AggregateAllOutputFormatterCallback<T, ThreadSafeMutableSet<T>> { + private final ThreadSafeMutableSet<T> result; + + private AggregateAllOutputFormatterCallbackImpl(QueryEnvironment<T> env) { + this.result = env.createThreadSafeMutableSet(); + } @Override public final void processOutput(Iterable<T> partialResult) { @@ -51,22 +62,35 @@ public final class QueryUtil { } @Override - public Set<T> getResult() { + public ThreadSafeMutableSet<T> getResult() { return result; } } private static class OrderedAggregateAllOutputFormatterCallbackImpl<T> - extends AggregateAllOutputFormatterCallback<T> { - private final Set<T> result = CompactHashSet.create(); + extends AggregateAllOutputFormatterCallback<T, Set<T>> { + private final Set<T> resultSet; + private final List<T> resultList; + + private OrderedAggregateAllOutputFormatterCallbackImpl(QueryEnvironment<T> env) { + this.resultSet = env.createThreadSafeMutableSet(); + this.resultList = new ArrayList<>(); + } @Override public final synchronized void processOutput(Iterable<T> partialResult) { - Iterables.addAll(result, partialResult); + for (T element : partialResult) { + if (resultSet.add(element)) { + resultList.add(element); + } + } } @Override public synchronized Set<T> getResult() { + // A CompactHashSet's iteration order is the same as its insertion order. + CompactHashSet<T> result = CompactHashSet.createWithExpectedSize(resultList.size()); + result.addAll(resultList); return result; } } @@ -76,35 +100,120 @@ public final class QueryUtil { * {@link AggregateAllCallback#getResult} returns all the elements of the result in the order they * were processed. */ - public static <T> AggregateAllOutputFormatterCallback<T> - newOrderedAggregateAllOutputFormatterCallback() { - return new OrderedAggregateAllOutputFormatterCallbackImpl<>(); + public static <T> AggregateAllOutputFormatterCallback<T, Set<T>> + newOrderedAggregateAllOutputFormatterCallback(QueryEnvironment<T> env) { + return new OrderedAggregateAllOutputFormatterCallbackImpl<>(env); } /** Returns a fresh {@link AggregateAllCallback} instance. */ - public static <T> AggregateAllCallback<T> newAggregateAllCallback() { - return new AggregateAllOutputFormatterCallbackImpl<>(); + public static <T> AggregateAllCallback<T, ThreadSafeMutableSet<T>> newAggregateAllCallback( + QueryEnvironment<T> env) { + return new AggregateAllOutputFormatterCallbackImpl<>(env); } /** - * Returns a {@link QueryTaskFuture} representing the evaluation of {@code expr} as a (mutable) - * {@link Set} comprised of all the results. + * Returns a {@link QueryTaskFuture} representing the evaluation of {@code expr} as a mutable, + * thread safe {@link Set} comprised of all the results. * * <p>Should only be used by QueryExpressions when it is the only way of achieving correctness. */ - public static <T> QueryTaskFuture<Set<T>> evalAll( + public static <T> QueryTaskFuture<ThreadSafeMutableSet<T>> evalAll( QueryEnvironment<T> env, VariableContext<T> context, QueryExpression expr) { - final AggregateAllCallback<T> callback = newAggregateAllCallback(); + final AggregateAllCallback<T, ThreadSafeMutableSet<T>> callback = newAggregateAllCallback(env); return env.whenSucceedsCall( env.eval(expr, context, callback), - new QueryTaskCallable<Set<T>>() { + new QueryTaskCallable<ThreadSafeMutableSet<T>>() { @Override - public Set<T> call() { + public ThreadSafeMutableSet<T> call() { return callback.getResult(); } }); } + /** + * A mutable thread safe {@link Set} that uses a {@link KeyExtractor} for determining equality of + * its elements. This is useful e.g. when {@code T} isn't guaranteed to have a useful + * {@link Object#equals} and {@link Object#hashCode} but {@code K} is. + */ + public static class ThreadSafeMutableKeyExtractorBackedSetImpl<T, K> + extends AbstractSet<T> implements ThreadSafeMutableSet<T> { + private final KeyExtractor<T, K> extractor; + private final Class<T> elementClass; + private final ConcurrentMap<K, T> map; + + public ThreadSafeMutableKeyExtractorBackedSetImpl( + KeyExtractor<T, K> extractor, Class<T> elementClass) { + this(extractor, elementClass, /*concurrencyLevel=*/ 1); + } + + public ThreadSafeMutableKeyExtractorBackedSetImpl( + KeyExtractor<T, K> extractor, + Class<T> elementClass, + int concurrencyLevel) { + this.extractor = extractor; + this.elementClass = elementClass; + this.map = new MapMaker().concurrencyLevel(concurrencyLevel).makeMap(); + } + + @Override + public Iterator<T> iterator() { + return map.values().iterator(); + } + + @Override + public int size() { + return map.size(); + } + + @Override + public boolean add(T element) { + return map.putIfAbsent(extractor.extractKey(element), element) == null; + } + + @Override + public boolean contains(Object obj) { + if (!elementClass.isInstance(obj)) { + return false; + } + T element = elementClass.cast(obj); + return map.containsKey(extractor.extractKey(element)); + } + + @Override + public boolean remove(Object obj) { + if (!elementClass.isInstance(obj)) { + return false; + } + T element = elementClass.cast(obj); + return map.remove(extractor.extractKey(element)) != null; + } + } + + /** + * A {@link MutableMap} implementation that uses a {@link KeyExtractor} for determining equality + * of its keys. + */ + public static class MutableKeyExtractorBackedMapImpl<T, K, V> implements MutableMap<T, V> { + private final KeyExtractor<T, K> extractor; + private final HashMap<K, V> map; + + public MutableKeyExtractorBackedMapImpl(KeyExtractor<T, K> extractor) { + this.extractor = extractor; + this.map = new HashMap<>(); + } + + @Override + @Nullable + public V get(T key) { + return map.get(extractor.extractKey(key)); + } + + @Override + public V put(T key, V value) { + return map.put(extractor.extractKey(key), value); + } + } + /** A trivial {@link Uniquifier} implementation. */ public static class UniquifierImpl<T, K> implements Uniquifier<T> { private final KeyExtractor<T, K> extractor; diff --git a/src/main/java/com/google/devtools/build/lib/query2/engine/RdepsFunction.java b/src/main/java/com/google/devtools/build/lib/query2/engine/RdepsFunction.java index 82faf72538..81c0cfcaf0 100644 --- a/src/main/java/com/google/devtools/build/lib/query2/engine/RdepsFunction.java +++ b/src/main/java/com/google/devtools/build/lib/query2/engine/RdepsFunction.java @@ -21,8 +21,8 @@ import com.google.common.collect.ImmutableList; import com.google.devtools.build.lib.query2.engine.QueryEnvironment.Argument; import com.google.devtools.build.lib.query2.engine.QueryEnvironment.ArgumentType; import com.google.devtools.build.lib.query2.engine.QueryEnvironment.QueryTaskFuture; +import com.google.devtools.build.lib.query2.engine.QueryEnvironment.ThreadSafeMutableSet; import java.util.List; -import java.util.Set; /** * An "rdeps" query expression, which computes the reverse dependencies of the argument within the @@ -62,12 +62,12 @@ public final class RdepsFunction extends AllRdepsFunction { final QueryExpression expression, final List<Argument> args, final Callback<T> callback) { - QueryTaskFuture<Set<T>> universeValueFuture = + QueryTaskFuture<ThreadSafeMutableSet<T>> universeValueFuture = QueryUtil.evalAll(env, context, args.get(0).getExpression()); - Function<Set<T>, QueryTaskFuture<Void>> evalInUniverseAsyncFunction = - new Function<Set<T>, QueryTaskFuture<Void>>() { + Function<ThreadSafeMutableSet<T>, QueryTaskFuture<Void>> evalInUniverseAsyncFunction = + new Function<ThreadSafeMutableSet<T>, QueryTaskFuture<Void>>() { @Override - public QueryTaskFuture<Void> apply(Set<T> universeValue) { + public QueryTaskFuture<Void> apply(ThreadSafeMutableSet<T> universeValue) { Predicate<T> universe; try { env.buildTransitiveClosure(expression, universeValue, Integer.MAX_VALUE); diff --git a/src/main/java/com/google/devtools/build/lib/query2/engine/SomePathFunction.java b/src/main/java/com/google/devtools/build/lib/query2/engine/SomePathFunction.java index 229863c79a..93650d5fbb 100644 --- a/src/main/java/com/google/devtools/build/lib/query2/engine/SomePathFunction.java +++ b/src/main/java/com/google/devtools/build/lib/query2/engine/SomePathFunction.java @@ -22,8 +22,8 @@ import com.google.devtools.build.lib.query2.engine.QueryEnvironment.ArgumentType import com.google.devtools.build.lib.query2.engine.QueryEnvironment.QueryFunction; import com.google.devtools.build.lib.query2.engine.QueryEnvironment.QueryTaskCallable; import com.google.devtools.build.lib.query2.engine.QueryEnvironment.QueryTaskFuture; +import com.google.devtools.build.lib.query2.engine.QueryEnvironment.ThreadSafeMutableSet; import java.util.List; -import java.util.Set; /** * A somepath(x, y) query expression, which computes the set of nodes @@ -57,9 +57,9 @@ class SomePathFunction implements QueryFunction { final QueryExpression expression, List<Argument> args, final Callback<T> callback) { - final QueryTaskFuture<Set<T>> fromValueFuture = + final QueryTaskFuture<ThreadSafeMutableSet<T>> fromValueFuture = QueryUtil.evalAll(env, context, args.get(0).getExpression()); - final QueryTaskFuture<Set<T>> toValueFuture = + final QueryTaskFuture<ThreadSafeMutableSet<T>> toValueFuture = QueryUtil.evalAll(env, context, args.get(1).getExpression()); return env.whenAllSucceedCall( @@ -72,8 +72,8 @@ class SomePathFunction implements QueryFunction { // to an arbitrary node in the intersection, and return the path. This // avoids computing the full transitive closure of "from" in some cases. - Set<T> fromValue = fromValueFuture.getIfSuccessful(); - Set<T> toValue = toValueFuture.getIfSuccessful(); + ThreadSafeMutableSet<T> fromValue = fromValueFuture.getIfSuccessful(); + ThreadSafeMutableSet<T> toValue = toValueFuture.getIfSuccessful(); env.buildTransitiveClosure(expression, fromValue, Integer.MAX_VALUE); @@ -81,7 +81,9 @@ class SomePathFunction implements QueryFunction { Uniquifier<T> uniquifier = env.createUniquifier(); for (T x : uniquifier.unique(fromValue)) { - Set<T> xtc = env.getTransitiveClosure(ImmutableSet.of(x)); + ThreadSafeMutableSet<T> xSet = env.createThreadSafeMutableSet(); + xSet.add(x); + ThreadSafeMutableSet<T> xtc = env.getTransitiveClosure(xSet); SetView<T> result; if (xtc.size() > toValue.size()) { result = Sets.intersection(toValue, xtc); diff --git a/src/main/java/com/google/devtools/build/lib/query2/engine/TestsFunction.java b/src/main/java/com/google/devtools/build/lib/query2/engine/TestsFunction.java index d9ed576a5d..15950d7114 100644 --- a/src/main/java/com/google/devtools/build/lib/query2/engine/TestsFunction.java +++ b/src/main/java/com/google/devtools/build/lib/query2/engine/TestsFunction.java @@ -14,20 +14,19 @@ package com.google.devtools.build.lib.query2.engine; import com.google.common.collect.ImmutableList; -import com.google.common.collect.Sets; import com.google.devtools.build.lib.concurrent.ThreadSafety.ThreadSafe; import com.google.devtools.build.lib.query2.engine.QueryEnvironment.Argument; import com.google.devtools.build.lib.query2.engine.QueryEnvironment.ArgumentType; +import com.google.devtools.build.lib.query2.engine.QueryEnvironment.MutableMap; import com.google.devtools.build.lib.query2.engine.QueryEnvironment.QueryFunction; import com.google.devtools.build.lib.query2.engine.QueryEnvironment.QueryTaskFuture; import com.google.devtools.build.lib.query2.engine.QueryEnvironment.Setting; +import com.google.devtools.build.lib.query2.engine.QueryEnvironment.ThreadSafeMutableSet; import java.util.ArrayList; import java.util.Collection; -import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; import java.util.List; -import java.util.Map; import java.util.Set; /** @@ -146,7 +145,7 @@ class TestsFunction implements QueryFunction { private static final class Closure<T> { private final QueryExpression expression; /** A dynamically-populated mapping from test_suite rules to their tests. */ - private final Map<T, Set<T>> testsInSuite = new HashMap<>(); + private final MutableMap<T, ThreadSafeMutableSet<T>> testsInSuite; /** The environment in which this query is being evaluated. */ private final QueryEnvironment<T> env; @@ -157,6 +156,7 @@ class TestsFunction implements QueryFunction { this.expression = expression; this.env = env; this.strict = env.isSettingEnabled(Setting.TESTS_EXPRESSION_STRICT); + this.testsInSuite = env.createMutableMap(); } /** @@ -165,11 +165,11 @@ class TestsFunction implements QueryFunction { * * @precondition env.getAccessor().isTestSuite(testSuite) */ - private synchronized Set<T> getTestsInSuite(T testSuite) + private synchronized ThreadSafeMutableSet<T> getTestsInSuite(T testSuite) throws QueryException, InterruptedException { - Set<T> tests = testsInSuite.get(testSuite); + ThreadSafeMutableSet<T> tests = testsInSuite.get(testSuite); if (tests == null) { - tests = Sets.newHashSet(); + tests = env.createThreadSafeMutableSet(); testsInSuite.put(testSuite, tests); // break cycles by inserting empty set early. computeTestsInSuite(testSuite, tests); } @@ -184,7 +184,7 @@ class TestsFunction implements QueryFunction { * * @precondition env.getAccessor().isTestSuite(testSuite) */ - private void computeTestsInSuite(T testSuite, Set<T> result) + private void computeTestsInSuite(T testSuite, ThreadSafeMutableSet<T> result) throws QueryException, InterruptedException { List<T> testsAndSuites = new ArrayList<>(); // Note that testsAndSuites can contain input file targets; the test_suite rule does not @@ -245,7 +245,7 @@ class TestsFunction implements QueryFunction { * @precondition {@code env.getAccessor().isTestSuite(testSuite)} * @precondition {@code env.getAccessor().isTestRule(test)} for all test in tests */ - private void filterTests(T testSuite, Set<T> tests) { + private void filterTests(T testSuite, ThreadSafeMutableSet<T> tests) { List<String> tagsAttribute = env.getAccessor().getStringListAttr(testSuite, "tags"); // Split the tags list into positive and negative tags Set<String> requiredTags = new HashSet<>(); diff --git a/src/main/java/com/google/devtools/build/lib/query2/engine/VisibleFunction.java b/src/main/java/com/google/devtools/build/lib/query2/engine/VisibleFunction.java index b09910c715..7a364b8285 100644 --- a/src/main/java/com/google/devtools/build/lib/query2/engine/VisibleFunction.java +++ b/src/main/java/com/google/devtools/build/lib/query2/engine/VisibleFunction.java @@ -20,6 +20,7 @@ import com.google.devtools.build.lib.query2.engine.QueryEnvironment.Argument; import com.google.devtools.build.lib.query2.engine.QueryEnvironment.ArgumentType; import com.google.devtools.build.lib.query2.engine.QueryEnvironment.QueryFunction; import com.google.devtools.build.lib.query2.engine.QueryEnvironment.QueryTaskFuture; +import com.google.devtools.build.lib.query2.engine.QueryEnvironment.ThreadSafeMutableSet; import java.util.List; import java.util.Set; @@ -59,12 +60,12 @@ public class VisibleFunction implements QueryFunction { QueryExpression expression, final List<Argument> args, final Callback<T> callback) { - final QueryTaskFuture<Set<T>> toSetFuture = + final QueryTaskFuture<ThreadSafeMutableSet<T>> toSetFuture = QueryUtil.evalAll(env, context, args.get(0).getExpression()); - Function<Set<T>, QueryTaskFuture<Void>> computeVisibleNodesAsyncFunction = - new Function<Set<T>, QueryTaskFuture<Void>>() { + Function<ThreadSafeMutableSet<T>, QueryTaskFuture<Void>> computeVisibleNodesAsyncFunction = + new Function<ThreadSafeMutableSet<T>, QueryTaskFuture<Void>>() { @Override - public QueryTaskFuture<Void> apply(final Set<T> toSet) { + public QueryTaskFuture<Void> apply(final ThreadSafeMutableSet<T> toSet) { return env.eval(args.get(1).getExpression(), context, new Callback<T>() { @Override public void process(Iterable<T> partialResult) diff --git a/src/main/java/com/google/devtools/build/lib/rules/genquery/GenQuery.java b/src/main/java/com/google/devtools/build/lib/rules/genquery/GenQuery.java index f5abf25019..fbd0943744 100644 --- a/src/main/java/com/google/devtools/build/lib/rules/genquery/GenQuery.java +++ b/src/main/java/com/google/devtools/build/lib/rules/genquery/GenQuery.java @@ -284,8 +284,7 @@ public class GenQuery implements RuleConfiguredTargetFactory { DigraphQueryEvalResult<Target> queryResult; OutputFormatter formatter; - AggregateAllOutputFormatterCallback<Target> targets = - QueryUtil.newOrderedAggregateAllOutputFormatterCallback(); + AggregateAllOutputFormatterCallback<Target, ?> targets; try { Set<Setting> settings = queryOptions.toSettings(); @@ -327,6 +326,7 @@ public class GenQuery implements RuleConfiguredTargetFactory { /*blockUniverseEvaluationErrors=*/ false); QueryExpression expr = QueryExpression.parse(query, queryEnvironment); formatter.verifyCompatible(queryEnvironment, expr); + targets = QueryUtil.newOrderedAggregateAllOutputFormatterCallback(queryEnvironment); queryResult = queryEnvironment.evaluateQuery(expr, targets); } catch (SkyframeRestartQueryException e) { // Do not emit errors for skyframe restarts. They make output of the ConfiguredTargetFunction diff --git a/src/main/java/com/google/devtools/build/lib/runtime/commands/QueryCommand.java b/src/main/java/com/google/devtools/build/lib/runtime/commands/QueryCommand.java index 8881db18ba..c9d31ba945 100644 --- a/src/main/java/com/google/devtools/build/lib/runtime/commands/QueryCommand.java +++ b/src/main/java/com/google/devtools/build/lib/runtime/commands/QueryCommand.java @@ -167,7 +167,7 @@ public final class QueryCommand implements BlazeCommand { queryOptions.aspectDeps.createResolver(env.getPackageManager(), env.getReporter())); callback = streamedFormatter.createStreamCallback(out, queryOptions, queryEnv); } else { - callback = QueryUtil.newOrderedAggregateAllOutputFormatterCallback(); + callback = QueryUtil.newOrderedAggregateAllOutputFormatterCallback(queryEnv); } boolean catastrophe = true; try { @@ -211,7 +211,8 @@ public final class QueryCommand implements BlazeCommand { if (!streamResults) { disableAnsiCharactersFiltering(env); try { - Set<Target> targets = ((AggregateAllOutputFormatterCallback<Target>) callback).getResult(); + Set<Target> targets = + ((AggregateAllOutputFormatterCallback<Target, ?>) callback).getResult(); QueryOutputUtils.output( queryOptions, result, |