callback) {
switch (operator) {
case PLUS:
case UNION:
return evalPlus(operands, env, context, callback);
case MINUS:
case EXCEPT:
return evalMinus(operands, env, context, callback);
case INTERSECT:
case CARET:
return evalIntersect(env, context, callback);
default:
throw new IllegalStateException(operator.toString());
}
}
/**
 * Evaluates an expression of the form "e1 + e2 + ... + eK" by evaluating all the subexpressions
 * separately.
 *
 * <p>Every operand is evaluated against the same {@code context} and reports to the same
 * {@code callback}. The returned future is the one produced by {@code env.whenAllSucceed} over
 * the per-operand evaluation futures.
 *
 * <p>N.B. {@code operands.size()} may be {@code 1}.
 */
private static <T> QueryTaskFuture<Void> evalPlus(
    ImmutableList<QueryExpression> operands,
    QueryEnvironment<T> env,
    VariableContext<T> context,
    Callback<T> callback) {
  // Presized: exactly one evaluation future per operand.
  ArrayList<QueryTaskFuture<Void>> queryTasks = new ArrayList<>(operands.size());
  for (QueryExpression operand : operands) {
    queryTasks.add(env.eval(operand, context, callback));
  }
  return env.whenAllSucceed(queryTasks);
}
/**
* Evaluates an expression of the form "e1 - e2 - ... - eK" by noting its equivalence to
* "e1 - (e2 + ... + eK)" and evaluating the subexpressions on the right-hand-side separately.
*/
private static QueryTaskFuture evalMinus(
final ImmutableList operands,
final QueryEnvironment env,
final VariableContext context,
final Callback callback) {
QueryTaskFuture> lhsValueFuture = QueryUtil.evalAll(env, context, operands.get(0));
Function, QueryTaskFuture> substractAsyncFunction =
new Function, QueryTaskFuture>() {
@Override
public QueryTaskFuture apply(Set lhsValue) {
final Set threadSafeLhsValue = Sets.newConcurrentHashSet(lhsValue);
Callback subtractionCallback = new Callback() {
@Override
public void process(Iterable partialResult) {
for (T target : partialResult) {
threadSafeLhsValue.remove(target);
}
}
};
QueryTaskFuture rhsEvaluatedFuture = evalPlus(
operands.subList(1, operands.size()), env, context, subtractionCallback);
return env.whenSucceedsCall(
rhsEvaluatedFuture,
new QueryTaskCallable() {
@Override
public Void call() throws QueryException, InterruptedException {
callback.process(threadSafeLhsValue);
return null;
}
});
}
};
return env.transformAsync(lhsValueFuture, substractAsyncFunction);
}
private QueryTaskFuture evalIntersect(
final QueryEnvironment env,
final VariableContext context,
final Callback callback) {
// For each right-hand side operand, intersection cannot be performed in a streaming manner; the
// entire result of that operand is needed. So, in order to avoid pinning too much in memory at
// once, we process each right-hand side operand one at a time and throw away that operand's
// result.
// TODO(bazel-team): Consider keeping just the name / label of the right-hand side results
// instead of the potentially heavy-weight instances of type T. This would let us process all
// right-hand side operands in parallel without worrying about memory usage.
QueryTaskFuture> rollingResultFuture = QueryUtil.evalAll(env, context, operands.get(0));
for (int i = 1; i < operands.size(); i++) {
final int index = i;
Function, QueryTaskFuture>> evalOperandAndIntersectAsyncFunction =
new Function, QueryTaskFuture>>() {
@Override
public QueryTaskFuture> apply(final Set rollingResult) {
final QueryTaskFuture> rhsOperandValueFuture =
QueryUtil.evalAll(env, context, operands.get(index));
return env.whenSucceedsCall(
rhsOperandValueFuture,
new QueryTaskCallable>() {
@Override
public Set call() throws QueryException, InterruptedException {
rollingResult.retainAll(rhsOperandValueFuture.getIfSuccessful());
return rollingResult;
}
});
}
};
rollingResultFuture =
env.transformAsync(rollingResultFuture, evalOperandAndIntersectAsyncFunction);
}
final QueryTaskFuture> resultFuture = rollingResultFuture;
return env.whenSucceedsCall(
resultFuture,
new QueryTaskCallable() {
@Override
public Void call() throws QueryException, InterruptedException {
callback.process(resultFuture.getIfSuccessful());
return null;
}
});
}
/**
 * Collects target patterns from every operand, in order, by delegating to each operand's own
 * {@code collectTargetPatterns} with the same {@code literals} collection.
 */
@Override
public void collectTargetPatterns(Collection literals) {
  for (int idx = 0, count = operands.size(); idx < count; idx++) {
    QueryExpression operand = operands.get(idx);
    operand.collectTargetPatterns(literals);
  }
}
/** Returns the result of {@code mapper.map(this)}, letting the mapper handle this expression. */
@Override
public QueryExpression getMapped(QueryExpressionMapper mapper) {
  QueryExpression mapped = mapper.map(this);
  return mapped;
}
/**
 * Renders the expression fully left-parenthesized. For example, operands {@code a, b, c} with
 * operator pretty name {@code p} render as {@code ((a p b) p c)}.
 */
@Override
public String toString() {
  int trailingOperands = operands.size() - 1;
  StringBuilder sb = new StringBuilder();
  // One opening parenthesis per binary application.
  for (int open = 0; open < trailingOperands; open++) {
    sb.append('(');
  }
  sb.append(operands.get(0));
  for (int idx = 1; idx <= trailingOperands; idx++) {
    sb.append(' ')
        .append(operator.getPrettyName())
        .append(' ')
        .append(operands.get(idx))
        .append(')');
  }
  return sb.toString();
}
}