diff options
5 files changed, 100 insertions, 14 deletions
diff --git a/src/main/java/com/google/devtools/build/lib/analysis/PlatformConfiguration.java b/src/main/java/com/google/devtools/build/lib/analysis/PlatformConfiguration.java index 1b86f7b8e9..83d622fdbf 100644 --- a/src/main/java/com/google/devtools/build/lib/analysis/PlatformConfiguration.java +++ b/src/main/java/com/google/devtools/build/lib/analysis/PlatformConfiguration.java @@ -34,11 +34,14 @@ public class PlatformConfiguration extends BuildConfiguration.Fragment { private final Label executionPlatform; private final ImmutableList<Label> targetPlatforms; + private final ImmutableList<Label> extraToolchains; - public PlatformConfiguration(Label executionPlatform, List<Label> targetPlatforms) { + public PlatformConfiguration( + Label executionPlatform, List<Label> targetPlatforms, List<Label> extraToolchains) { this.executionPlatform = executionPlatform; this.targetPlatforms = ImmutableList.copyOf(targetPlatforms); + this.extraToolchains = ImmutableList.copyOf(extraToolchains); } @SkylarkCallable( @@ -54,4 +57,9 @@ public class PlatformConfiguration extends BuildConfiguration.Fragment { public ImmutableList<Label> getTargetPlatforms() { return targetPlatforms; } + + /** Additional toolchains that should be considered during toolchain resolution. */ + public ImmutableList<Label> getExtraToolchains() { + return extraToolchains; + } } diff --git a/src/main/java/com/google/devtools/build/lib/analysis/PlatformConfigurationLoader.java b/src/main/java/com/google/devtools/build/lib/analysis/PlatformConfigurationLoader.java index 85cc1748aa..1eedd1265b 100644 --- a/src/main/java/com/google/devtools/build/lib/analysis/PlatformConfigurationLoader.java +++ b/src/main/java/com/google/devtools/build/lib/analysis/PlatformConfigurationLoader.java @@ -47,6 +47,6 @@ public class PlatformConfigurationLoader implements ConfigurationFragmentFactory // TODO(katre): This will change with remote execution. Label executionPlatform = options.hostPlatform; - return new PlatformConfiguration(executionPlatform, options.platforms); + return new PlatformConfiguration(executionPlatform, options.platforms, options.extraToolchains); } } diff --git a/src/main/java/com/google/devtools/build/lib/analysis/PlatformOptions.java b/src/main/java/com/google/devtools/build/lib/analysis/PlatformOptions.java index ed2e45248b..47d64eefa2 100644 --- a/src/main/java/com/google/devtools/build/lib/analysis/PlatformOptions.java +++ b/src/main/java/com/google/devtools/build/lib/analysis/PlatformOptions.java @@ -16,6 +16,7 @@ package com.google.devtools.build.lib.analysis; import com.google.common.collect.ImmutableList; import com.google.devtools.build.lib.analysis.config.BuildConfiguration; +import com.google.devtools.build.lib.analysis.config.BuildConfiguration.LabelListConverter; import com.google.devtools.build.lib.analysis.config.FragmentOptions; import com.google.devtools.build.lib.cmdline.Label; import com.google.devtools.common.options.Option; @@ -51,6 +52,17 @@ public class PlatformOptions extends FragmentOptions { ) public List<Label> platforms; + @Option( + name = "extra_toolchains", + converter = LabelListConverter.class, + defaultValue = "", + documentationCategory = OptionDocumentationCategory.UNDOCUMENTED, + effectTags = {OptionEffectTag.UNKNOWN}, + metadataTags = {OptionMetadataTag.HIDDEN}, + help = "Extra toolchains to be considered during toolchain resolution." + ) + public List<Label> extraToolchains; + @Override public PlatformOptions getHost(boolean fallback) { PlatformOptions host = (PlatformOptions) getDefault(); diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/RegisteredToolchainsFunction.java b/src/main/java/com/google/devtools/build/lib/skyframe/RegisteredToolchainsFunction.java index e77c53d2b9..666da966c1 100644 --- a/src/main/java/com/google/devtools/build/lib/skyframe/RegisteredToolchainsFunction.java +++ b/src/main/java/com/google/devtools/build/lib/skyframe/RegisteredToolchainsFunction.java @@ -16,10 +16,12 @@ package com.google.devtools.build.lib.skyframe; import com.google.common.collect.ImmutableList; import com.google.devtools.build.lib.analysis.ConfiguredTarget; +import com.google.devtools.build.lib.analysis.PlatformConfiguration; import com.google.devtools.build.lib.analysis.config.BuildConfiguration; import com.google.devtools.build.lib.analysis.platform.DeclaredToolchainInfo; import com.google.devtools.build.lib.cmdline.Label; import com.google.devtools.build.lib.rules.ExternalPackageUtil; +import com.google.devtools.build.lib.rules.ExternalPackageUtil.ExternalPackageException; import com.google.devtools.build.lib.skyframe.ConfiguredTargetFunction.ConfiguredValueCreationException; import com.google.devtools.build.skyframe.LegacySkyKey; import com.google.devtools.build.skyframe.SkyFunction; @@ -44,15 +46,22 @@ public class RegisteredToolchainsFunction implements SkyFunction { BuildConfiguration configuration = (BuildConfiguration) skyKey.argument(); - // Get the registered toolchains. - List<Label> registeredToolchainLabels = ExternalPackageUtil.getRegisteredToolchainLabels(env); - if (registeredToolchainLabels == null) { + ImmutableList.Builder<Label> registeredToolchainLabels = new ImmutableList.Builder<>(); + + // Get the toolchains from the configuration. + PlatformConfiguration platformConfiguration = + configuration.getFragment(PlatformConfiguration.class); + registeredToolchainLabels.addAll(platformConfiguration.getExtraToolchains()); + + // Get the registered toolchains from the WORKSPACE. + registeredToolchainLabels.addAll(getWorkspaceToolchains(env)); + if (env.valuesMissing()) { return null; } // Load the configured target for each, and get the declared toolchain providers. ImmutableList<DeclaredToolchainInfo> registeredToolchains = - configureRegisteredToolchains(env, configuration, registeredToolchainLabels); + configureRegisteredToolchains(env, configuration, registeredToolchainLabels.build()); if (env.valuesMissing()) { return null; } @@ -60,6 +69,15 @@ public class RegisteredToolchainsFunction implements SkyFunction { return RegisteredToolchainsValue.create(registeredToolchains); } + private Iterable<? extends Label> getWorkspaceToolchains(Environment env) + throws ExternalPackageException, InterruptedException { + List<Label> labels = ExternalPackageUtil.getRegisteredToolchainLabels(env); + if (labels == null) { + return ImmutableList.of(); + } + return labels; + } + private ImmutableList<DeclaredToolchainInfo> configureRegisteredToolchains( Environment env, BuildConfiguration configuration, List<Label> labels) throws InterruptedException, RegisteredToolchainsFunctionException { diff --git a/src/test/java/com/google/devtools/build/lib/skyframe/RegisteredToolchainsFunctionTest.java b/src/test/java/com/google/devtools/build/lib/skyframe/RegisteredToolchainsFunctionTest.java index 4d9c93f5f9..0626a6c0c4 100644 --- a/src/test/java/com/google/devtools/build/lib/skyframe/RegisteredToolchainsFunctionTest.java +++ b/src/test/java/com/google/devtools/build/lib/skyframe/RegisteredToolchainsFunctionTest.java @@ -19,11 +19,15 @@ import static com.google.devtools.build.skyframe.EvaluationResultSubjectFactory. import com.google.common.collect.ImmutableList; import com.google.common.testing.EqualsTester; +import com.google.common.truth.IterableSubject; import com.google.devtools.build.lib.analysis.platform.DeclaredToolchainInfo; +import com.google.devtools.build.lib.cmdline.Label; import com.google.devtools.build.lib.rules.platform.ToolchainTestCase; import com.google.devtools.build.lib.skyframe.util.SkyframeExecutorTestUtils; import com.google.devtools.build.skyframe.EvaluationResult; import com.google.devtools.build.skyframe.SkyKey; +import java.util.List; +import java.util.stream.Collectors; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -75,6 +79,38 @@ public class RegisteredToolchainsFunctionTest extends ToolchainTestCase { } @Test + public void testRegisteredToolchains_flagOverride() throws Exception { + + // Add an extra toolchain. + scratch.file( + "extra/BUILD", + "load('//toolchain:toolchain_def.bzl', 'test_toolchain')", + "toolchain(", + " name = 'extra_toolchain',", + " toolchain_type = '//toolchain:test_toolchain',", + " exec_compatible_with = ['//constraint:linux'],", + " target_compatible_with = ['//constraint:linux'],", + " toolchain = ':extra_toolchain_impl')", + "test_toolchain(", + " name='extra_toolchain_impl',", + " data = 'extra')"); + + rewriteWorkspace("register_toolchains('//toolchain:toolchain_1')"); + useConfiguration("--extra_toolchains=//extra:extra_toolchain"); + + SkyKey toolchainsKey = RegisteredToolchainsValue.key(targetConfig); + EvaluationResult<RegisteredToolchainsValue> result = + requestToolchainsFromSkyframe(toolchainsKey); + assertThatEvaluationResult(result).hasNoError(); + + // Verify that the target registered with the extra_toolchains flag is first in the list. + assertToolchainLabels(result.get(toolchainsKey)) + .containsExactly( + makeLabel("//extra:extra_toolchain_impl"), makeLabel("//toolchain:test_toolchain_1")) + .inOrder(); + } + + @Test public void testRegisteredToolchains_notToolchain() throws Exception { rewriteWorkspace("register_toolchains(", " '//error:not_a_toolchain')"); scratch.file("error/BUILD", "filegroup(name = 'not_a_toolchain')"); @@ -98,10 +134,8 @@ public class RegisteredToolchainsFunctionTest extends ToolchainTestCase { EvaluationResult<RegisteredToolchainsValue> result = requestToolchainsFromSkyframe(toolchainsKey); assertThatEvaluationResult(result).hasNoError(); - RegisteredToolchainsValue value = result.get(toolchainsKey); - assertThat(value.registeredToolchains()).hasSize(1); - assertThat(value.registeredToolchains().get(0).toolchainLabel()) - .isEqualTo(makeLabel("//toolchain:test_toolchain_1")); + assertToolchainLabels(result.get(toolchainsKey)) + .containsExactly(makeLabel("//toolchain:test_toolchain_1")); // Re-write the WORKSPACE. rewriteWorkspace("register_toolchains('//toolchain:toolchain_2')"); @@ -109,10 +143,8 @@ public class RegisteredToolchainsFunctionTest extends ToolchainTestCase { toolchainsKey = RegisteredToolchainsValue.key(targetConfig); result = requestToolchainsFromSkyframe(toolchainsKey); assertThatEvaluationResult(result).hasNoError(); - value = result.get(toolchainsKey); - assertThat(value.registeredToolchains()).hasSize(1); - assertThat(value.registeredToolchains().get(0).toolchainLabel()) - .isEqualTo(makeLabel("//toolchain:test_toolchain_2")); + assertToolchainLabels(result.get(toolchainsKey)) + .containsExactly(makeLabel("//toolchain:test_toolchain_2")); } @Test @@ -139,4 +171,20 @@ public class RegisteredToolchainsFunctionTest extends ToolchainTestCase { RegisteredToolchainsValue.create(ImmutableList.of(toolchain2)), RegisteredToolchainsValue.create(ImmutableList.of(toolchain2, toolchain1))); } + + private static IterableSubject assertToolchainLabels( + RegisteredToolchainsValue registeredToolchainsValue) { + assertThat(registeredToolchainsValue).isNotNull(); + ImmutableList<DeclaredToolchainInfo> declaredToolchains = + registeredToolchainsValue.registeredToolchains(); + List<Label> labels = collectToolchainLabels(declaredToolchains); + return assertThat(labels); + } + + private static List<Label> collectToolchainLabels(List<DeclaredToolchainInfo> toolchains) { + return toolchains + .stream() + .map((toolchain -> toolchain.toolchainLabel())) + .collect(Collectors.toList()); + } } |