diff options
Diffstat (limited to 'src')
11 files changed, 570 insertions, 74 deletions
diff --git a/src/main/java/com/google/devtools/build/lib/remote/SimpleBlobStoreActionCache.java b/src/main/java/com/google/devtools/build/lib/remote/SimpleBlobStoreActionCache.java index 656288b7c0..2d4f4e9bbc 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/SimpleBlobStoreActionCache.java +++ b/src/main/java/com/google/devtools/build/lib/remote/SimpleBlobStoreActionCache.java @@ -165,7 +165,9 @@ public final class SimpleBlobStoreActionCache extends AbstractRemoteActionCache } private Digest uploadBlob(byte[] blob, Digest digest) throws IOException, InterruptedException { - return uploadStream(digest, new ByteArrayInputStream(blob)); + try (InputStream in = new ByteArrayInputStream(blob)) { + return uploadStream(digest, in); + } } public Digest uploadStream(Digest digest, InputStream in) diff --git a/src/main/java/com/google/devtools/build/lib/remote/blobstore/SimpleBlobStore.java b/src/main/java/com/google/devtools/build/lib/remote/blobstore/SimpleBlobStore.java index 4231060603..b7e4db2c3a 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/blobstore/SimpleBlobStore.java +++ b/src/main/java/com/google/devtools/build/lib/remote/blobstore/SimpleBlobStore.java @@ -32,21 +32,27 @@ public interface SimpleBlobStore { /** * Fetches the BLOB associated with the {@code key} from the CAS and writes it to {@code out}. * + * <p>The caller is responsible to close {@code out}. + * * @return {@code true} if the {@code key} was found. {@code false} otherwise. */ boolean get(String key, OutputStream out) throws IOException, InterruptedException; /** - * Fetches the BLOB associated with the {@code key} from the Action Cache and writes it to - * {@code out}. + * Fetches the BLOB associated with the {@code key} from the Action Cache and writes it to {@code + * out}. + * + * <p>The caller is responsible to close {@code out}. * * @return {@code true} if the {@code key} was found. {@code false} otherwise. */ - boolean getActionResult(String actionKey, OutputStream out) throws IOException, - InterruptedException; + boolean getActionResult(String actionKey, OutputStream out) + throws IOException, InterruptedException; /** * Uploads a BLOB (as {@code in}) with length {@code length} indexed by {@code key} to the CAS. + * + * <p>The caller is responsible to close {@code in}. */ void put(String key, long length, InputStream in) throws IOException, InterruptedException; diff --git a/src/main/java/com/google/devtools/build/lib/remote/blobstore/http/AbstractHttpHandler.java b/src/main/java/com/google/devtools/build/lib/remote/blobstore/http/AbstractHttpHandler.java index fc8c14a005..0c4c8e2bee 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/blobstore/http/AbstractHttpHandler.java +++ b/src/main/java/com/google/devtools/build/lib/remote/blobstore/http/AbstractHttpHandler.java @@ -26,6 +26,7 @@ import io.netty.handler.codec.http.HttpRequest; import java.io.IOException; import java.net.SocketAddress; import java.net.URI; +import java.nio.channels.ClosedChannelException; import java.util.List; import java.util.Map; @@ -93,56 +94,61 @@ abstract class AbstractHttpHandler<T extends HttpObject> extends SimpleChannelIn } @Override - public void exceptionCaught(ChannelHandlerContext channelHandlerContext, Throwable throwable) - throws Exception { + public void exceptionCaught(ChannelHandlerContext channelHandlerContext, Throwable throwable) { failAndResetUserPromise(throwable); } @SuppressWarnings("FutureReturnValueIgnored") @Override - public void bind(ChannelHandlerContext ctx, SocketAddress localAddress, ChannelPromise promise) - throws Exception { + public void bind(ChannelHandlerContext ctx, SocketAddress localAddress, ChannelPromise promise) { ctx.bind(localAddress, promise); } - @SuppressWarnings("FutureReturnValueIgnored") + @SuppressWarnings("FutureReturnValueIgnored") @Override public void connect( ChannelHandlerContext ctx, SocketAddress remoteAddress, SocketAddress localAddress, - ChannelPromise promise) - throws Exception { + ChannelPromise promise) { ctx.connect(remoteAddress, localAddress, promise); } @SuppressWarnings("FutureReturnValueIgnored") @Override - public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { + public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise) { + failAndResetUserPromise(new ClosedChannelException()); ctx.disconnect(promise); } @SuppressWarnings("FutureReturnValueIgnored") @Override - public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { + public void close(ChannelHandlerContext ctx, ChannelPromise promise) { + failAndResetUserPromise(new ClosedChannelException()); ctx.close(promise); } @SuppressWarnings("FutureReturnValueIgnored") @Override - public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { + public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) { ctx.deregister(promise); } @SuppressWarnings("FutureReturnValueIgnored") @Override - public void read(ChannelHandlerContext ctx) throws Exception { + public void read(ChannelHandlerContext ctx) { ctx.read(); } - @SuppressWarnings("FutureReturnValueIgnored") + @SuppressWarnings("FutureReturnValueIgnored") @Override - public void flush(ChannelHandlerContext ctx) throws Exception { + public void flush(ChannelHandlerContext ctx) { ctx.flush(); } + + @Override + public void channelInactive(ChannelHandlerContext channelHandlerContext) throws Exception { + failAndResetUserPromise(new ClosedChannelException()); + super.channelInactive(channelHandlerContext); + } } diff --git a/src/main/java/com/google/devtools/build/lib/remote/blobstore/http/HttpBlobStore.java b/src/main/java/com/google/devtools/build/lib/remote/blobstore/http/HttpBlobStore.java index b1ddb0f4fb..bb10623913 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/blobstore/http/HttpBlobStore.java +++ b/src/main/java/com/google/devtools/build/lib/remote/blobstore/http/HttpBlobStore.java @@ -21,12 +21,15 @@ import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelOption; import io.netty.channel.ChannelPipeline; import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.pool.ChannelPool; import io.netty.channel.pool.ChannelPoolHandler; import io.netty.channel.pool.SimpleChannelPool; import io.netty.channel.socket.nio.NioSocketChannel; import io.netty.handler.codec.http.HttpClientCodec; +import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.HttpObjectAggregator; import io.netty.handler.codec.http.HttpRequestEncoder; +import io.netty.handler.codec.http.HttpResponse; import io.netty.handler.codec.http.HttpResponseDecoder; import io.netty.handler.codec.http.HttpResponseStatus; import io.netty.handler.ssl.OpenSsl; @@ -37,12 +40,20 @@ import io.netty.handler.ssl.SslProvider; import io.netty.handler.stream.ChunkedWriteHandler; import io.netty.util.internal.PlatformDependent; import java.io.ByteArrayInputStream; +import java.io.FileInputStream; +import java.io.FilterInputStream; +import java.io.FilterOutputStream; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.net.URI; +import java.util.List; import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.regex.Pattern; import javax.annotation.Nullable; +import javax.annotation.concurrent.GuardedBy; import javax.net.ssl.SSLEngine; /** @@ -71,11 +82,22 @@ import javax.net.ssl.SSLEngine; */ public final class HttpBlobStore implements SimpleBlobStore { + private static final Pattern INVALID_TOKEN_ERROR = + Pattern.compile("\\s*error\\s*=\\s*\"?invalid_token\"?"); + private final NioEventLoopGroup eventLoop = new NioEventLoopGroup(2 /* number of threads */); - private final SimpleChannelPool downloadChannels; - private final SimpleChannelPool uploadChannels; + private final ChannelPool downloadChannels; + private final ChannelPool uploadChannels; private final URI uri; + private final Object credentialsLock = new Object(); + + @GuardedBy("credentialsLock") + private final Credentials creds; + + @GuardedBy("credentialsLock") + private long lastRefreshTime; + public HttpBlobStore(URI uri, int timeoutMillis, @Nullable final Credentials creds) throws Exception { boolean useTls = uri.getScheme().equals("https"); @@ -92,11 +114,10 @@ public final class HttpBlobStore implements SimpleBlobStore { uri.getFragment()); } this.uri = uri; - final SslContext sslCtx; if (useTls) { - // OpenSsl gives us a > 2x speed improvement on fast networks, but requires netty tcnative - // to be there which is not available on all platforms and environments. + // OpenSsl gives us a > 2x speed improvement on fast networks, but requires netty tcnative + // to be there which is not available on all platforms and environments. SslProvider sslProvider = OpenSsl.isAvailable() ? SslProvider.OPENSSL : SslProvider.JDK; sslCtx = SslContextBuilder.forClient().sslProvider(sslProvider).build(); } else { @@ -114,13 +135,13 @@ public final class HttpBlobStore implements SimpleBlobStore { clientBootstrap, new ChannelPoolHandler() { @Override - public void channelReleased(Channel ch) throws Exception {} + public void channelReleased(Channel ch) {} @Override - public void channelAcquired(Channel ch) throws Exception {} + public void channelAcquired(Channel ch) {} @Override - public void channelCreated(Channel ch) throws Exception { + public void channelCreated(Channel ch) { ChannelPipeline p = ch.pipeline(); if (sslCtx != null) { SSLEngine engine = sslCtx.newEngine(ch.alloc()); @@ -136,13 +157,13 @@ public final class HttpBlobStore implements SimpleBlobStore { clientBootstrap, new ChannelPoolHandler() { @Override - public void channelReleased(Channel ch) throws Exception {} + public void channelReleased(Channel ch) {} @Override - public void channelAcquired(Channel ch) throws Exception {} + public void channelAcquired(Channel ch) {} @Override - public void channelCreated(Channel ch) throws Exception { + public void channelCreated(Channel ch) { ChannelPipeline p = ch.pipeline(); if (sslCtx != null) { SSLEngine engine = sslCtx.newEngine(ch.alloc()); @@ -158,6 +179,7 @@ public final class HttpBlobStore implements SimpleBlobStore { p.addLast(new HttpUploadHandler(creds)); } }); + this.creds = creds; } @Override @@ -173,33 +195,74 @@ public final class HttpBlobStore implements SimpleBlobStore { @SuppressWarnings("FutureReturnValueIgnored") private boolean get(String key, OutputStream out, boolean casDownload) throws IOException, InterruptedException { - final Channel ch; - try { - ch = downloadChannels.acquire().get(); - } catch (ExecutionException e) { - PlatformDependent.throwException(e.getCause()); - return false; - } - DownloadCommand download = new DownloadCommand(uri, casDownload, key, out); + final AtomicBoolean dataWritten = new AtomicBoolean(); + OutputStream wrappedOut = + new FilterOutputStream(out) { + + @Override + public void write(int b) throws IOException { + dataWritten.set(true); + super.write(b); + } + + @Override + public void close() { + // Ensure that the OutputStream can't be closed somewhere in the Netty + // pipeline, so that we can support retries. The OutputStream is closed in + // the finally block below. + } + }; + DownloadCommand download = new DownloadCommand(uri, casDownload, key, wrappedOut); + ; + Channel ch = null; try { + ch = acquireDownloadChannel(); ChannelFuture downloadFuture = ch.writeAndFlush(download); downloadFuture.sync(); + return true; } catch (Exception e) { // e can be of type HttpException, because Netty uses Unsafe.throwException to re-throw a // checked exception that hasn't been declared in the method signature. if (e instanceof HttpException) { - HttpResponseStatus status = ((HttpException) e).status(); - if (status.equals(HttpResponseStatus.NOT_FOUND) - || status.equals(HttpResponseStatus.NO_CONTENT)) { - // Cache miss. Supporting NO_CONTENT for nginx webdav compatibility. + HttpResponse response = ((HttpException) e).response(); + if (!dataWritten.get() && authTokenExpired(response)) { + // The error is due to an auth token having expired. Let's try again. + refreshCredentials(); + return getAfterCredentialRefresh(download); + } + if (cacheMiss(response.status())) { return false; } } throw e; } finally { - downloadChannels.release(ch); + if (ch != null) { + downloadChannels.release(ch); + } + } + } + + @SuppressWarnings("FutureReturnValueIgnored") + private boolean getAfterCredentialRefresh(DownloadCommand cmd) throws InterruptedException { + Channel ch = null; + try { + ch = acquireDownloadChannel(); + ChannelFuture downloadFuture = ch.writeAndFlush(cmd); + downloadFuture.sync(); + return true; + } catch (Exception e) { + if (e instanceof HttpException) { + HttpResponse response = ((HttpException) e).response(); + if (cacheMiss(response.status())) { + return false; + } + } + throw e; + } finally { + if (ch != null) { + downloadChannels.release(ch); + } } - return true; } @Override @@ -217,30 +280,85 @@ public final class HttpBlobStore implements SimpleBlobStore { @SuppressWarnings("FutureReturnValueIgnored") private void put(String key, long length, InputStream in, boolean casUpload) throws IOException, InterruptedException { - final Channel ch; + InputStream wrappedIn = + new FilterInputStream(in) { + @Override + public void close() { + // Ensure that the InputStream can't be closed somewhere in the Netty + // pipeline, so that we can support retries. The InputStream is closed in + // the finally block below. + } + }; + UploadCommand upload = new UploadCommand(uri, casUpload, key, wrappedIn, length); + Channel ch = null; try { - ch = uploadChannels.acquire().get(); - } catch (ExecutionException e) { - throw new IOException("Failed to obtain a channel from the pool.", e); + ch = acquireUploadChannel(); + ChannelFuture uploadFuture = ch.writeAndFlush(upload); + uploadFuture.sync(); + } catch (Exception e) { + // e can be of type HttpException, because Netty uses Unsafe.throwException to re-throw a + // checked exception that hasn't been declared in the method signature. + if (e instanceof HttpException) { + HttpResponse response = ((HttpException) e).response(); + if (authTokenExpired(response)) { + refreshCredentials(); + // The error is due to an auth token having expired. Let's try again. + if (!reset(in)) { + // The InputStream can't be reset and thus we can't retry as most likely + // bytes have already been read from the InputStream. + throw e; + } + putAfterCredentialRefresh(upload); + return; + } + } + throw e; + } finally { + in.close(); + if (ch != null) { + uploadChannels.release(ch); + } } - UploadCommand upload = new UploadCommand(uri, casUpload, key, in, length); + } + + @SuppressWarnings("FutureReturnValueIgnored") + private void putAfterCredentialRefresh(UploadCommand cmd) throws InterruptedException { + Channel ch = null; try { - ChannelFuture uploadFuture = ch.writeAndFlush(upload); + ch = acquireUploadChannel(); + ChannelFuture uploadFuture = ch.writeAndFlush(cmd); uploadFuture.sync(); } finally { - uploadChannels.release(ch); + if (ch != null) { + uploadChannels.release(ch); + } + } + } + + private boolean reset(InputStream in) throws IOException { + if (in.markSupported()) { + in.reset(); + return true; + } + if (in instanceof FileInputStream) { + // FileInputStream does not support reset(). + ((FileInputStream) in).getChannel().position(0); + return true; } + return false; } @Override - public void putActionResult(String actionKey, byte[] in) + public void putActionResult(String actionKey, byte[] data) throws IOException, InterruptedException { - put(actionKey, in.length, new ByteArrayInputStream(in), false); + try (InputStream in = new ByteArrayInputStream(data)) { + put(actionKey, data.length, in, false); + } } /** - * It's safe to suppress this warning because all methods on Netty - * futures return {@code this}. So we are not ignoring anything. + * It's safe to suppress this warning because all methods on Netty futures return {@code this}. So + * we are not ignoring anything. */ @SuppressWarnings("FutureReturnValueIgnored") @Override @@ -249,4 +367,58 @@ public final class HttpBlobStore implements SimpleBlobStore { uploadChannels.close(); eventLoop.shutdownGracefully(); } + + private boolean cacheMiss(HttpResponseStatus status) { + // Supporting NO_CONTENT for nginx webdav compatibility. + return status.equals(HttpResponseStatus.NOT_FOUND) + || status.equals(HttpResponseStatus.NO_CONTENT); + } + + /** See https://tools.ietf.org/html/rfc6750#section-3.1 */ + private boolean authTokenExpired(HttpResponse response) { + synchronized (credentialsLock) { + if (creds == null) { + return false; + } + } + List<String> values = response.headers().getAllAsString(HttpHeaderNames.WWW_AUTHENTICATE); + String value = String.join(",", values); + if (value != null && value.startsWith("Bearer")) { + return INVALID_TOKEN_ERROR.matcher(value).find(); + } else { + return response.status().equals(HttpResponseStatus.UNAUTHORIZED); + } + } + + private Channel acquireDownloadChannel() throws InterruptedException { + try { + return downloadChannels.acquire().get(); + } catch (ExecutionException e) { + PlatformDependent.throwException(e.getCause()); + return null; + } + } + + private Channel acquireUploadChannel() throws InterruptedException { + try { + return uploadChannels.acquire().get(); + } catch (ExecutionException e) { + PlatformDependent.throwException(e.getCause()); + return null; + } + } + + private void refreshCredentials() throws IOException { + synchronized (credentialsLock) { + long now = System.currentTimeMillis(); + // Call creds.refresh() at most once per second. The one second was arbitrarily chosen, as + // a small enough value that we don't expect to interfere with actual token lifetimes, but + // it should just make sure that potentially hundreds of threads don't call this method + // at the same time. + if ((now - lastRefreshTime) > TimeUnit.SECONDS.toMillis(1)) { + lastRefreshTime = now; + creds.refresh(); + } + } + } } diff --git a/src/main/java/com/google/devtools/build/lib/remote/blobstore/http/HttpDownloadHandler.java b/src/main/java/com/google/devtools/build/lib/remote/blobstore/http/HttpDownloadHandler.java index 502c0ecec6..7cdbc10534 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/blobstore/http/HttpDownloadHandler.java +++ b/src/main/java/com/google/devtools/build/lib/remote/blobstore/http/HttpDownloadHandler.java @@ -38,6 +38,8 @@ import java.io.OutputStream; /** ChannelHandler for downloads. */ final class HttpDownloadHandler extends AbstractHttpHandler<HttpObject> { + private long contentLength = -1; + private long bytesReceived; private OutputStream out; private boolean keepAlive = HttpVersion.HTTP_1_1.isKeepAliveDefault(); @@ -47,20 +49,27 @@ final class HttpDownloadHandler extends AbstractHttpHandler<HttpObject> { @Override protected void channelRead0(ChannelHandlerContext ctx, HttpObject msg) throws Exception { + if (!msg.decoderResult().isSuccess()) { + failAndResetUserPromise(new IOException("Failed to parse the HTTP response.")); + return; + } checkState(userPromise != null, "response before request"); if (msg instanceof HttpResponse) { HttpResponse response = (HttpResponse) msg; keepAlive = HttpUtil.isKeepAlive((HttpResponse) msg); + if (HttpUtil.isContentLengthSet(response)) { + contentLength = HttpUtil.getContentLength(response); + } if (!response.status().equals(HttpResponseStatus.OK)) { failAndReset( - new HttpException( - response.status(), "Download failed with Status: " + response.status(), null), + new HttpException(response, "Download failed with status: " + response.status(), null), ctx); } } else if (msg instanceof HttpContent) { ByteBuf content = ((HttpContent) msg).content(); + bytesReceived += content.readableBytes(); content.readBytes(out, content.readableBytes()); - if (msg instanceof LastHttpContent) { + if (bytesReceived == contentLength || msg instanceof LastHttpContent) { succeedAndReset(ctx); } } else { @@ -128,8 +137,9 @@ final class HttpDownloadHandler extends AbstractHttpHandler<HttpObject> { if (!keepAlive) { ctx.close(); } - out.close(); } finally { + contentLength = -1; + bytesReceived = 0; out = null; keepAlive = HttpVersion.HTTP_1_1.isKeepAliveDefault(); } diff --git a/src/main/java/com/google/devtools/build/lib/remote/blobstore/http/HttpException.java b/src/main/java/com/google/devtools/build/lib/remote/blobstore/http/HttpException.java index 0a5368b8d4..43a9c22689 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/blobstore/http/HttpException.java +++ b/src/main/java/com/google/devtools/build/lib/remote/blobstore/http/HttpException.java @@ -14,19 +14,19 @@ package com.google.devtools.build.lib.remote.blobstore.http; -import io.netty.handler.codec.http.HttpResponseStatus; +import io.netty.handler.codec.http.HttpResponse; import java.io.IOException; /** An exception that propagates the http status. */ final class HttpException extends IOException { - private final HttpResponseStatus status; + private final HttpResponse response; - HttpException(HttpResponseStatus status, String message, Throwable cause) { + HttpException(HttpResponse response, String message, Throwable cause) { super(message, cause); - this.status = status; + this.response = response; } - public HttpResponseStatus status() { - return status; + public HttpResponse response() { + return response; } } diff --git a/src/main/java/com/google/devtools/build/lib/remote/blobstore/http/HttpUploadHandler.java b/src/main/java/com/google/devtools/build/lib/remote/blobstore/http/HttpUploadHandler.java index 0cd919a862..11e243515f 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/blobstore/http/HttpUploadHandler.java +++ b/src/main/java/com/google/devtools/build/lib/remote/blobstore/http/HttpUploadHandler.java @@ -30,6 +30,7 @@ import io.netty.handler.codec.http.HttpUtil; import io.netty.handler.codec.http.HttpVersion; import io.netty.handler.stream.ChunkedStream; import io.netty.util.internal.StringUtil; +import java.io.IOException; /** ChannelHandler for uploads. */ final class HttpUploadHandler extends AbstractHttpHandler<FullHttpResponse> { @@ -40,8 +41,11 @@ final class HttpUploadHandler extends AbstractHttpHandler<FullHttpResponse> { @SuppressWarnings("FutureReturnValueIgnored") @Override - protected void channelRead0(ChannelHandlerContext ctx, FullHttpResponse response) - throws Exception { + protected void channelRead0(ChannelHandlerContext ctx, FullHttpResponse response) { + if (!response.decoderResult().isSuccess()) { + failAndResetUserPromise(new IOException("Failed to parse the HTTP response.")); + return; + } try { checkState(userPromise != null, "response before request"); if (!response.status().equals(HttpResponseStatus.OK) @@ -50,8 +54,7 @@ final class HttpUploadHandler extends AbstractHttpHandler<FullHttpResponse> { && !response.status().equals(HttpResponseStatus.NO_CONTENT)) { // Supporting more than OK status to be compatible with nginx webdav. failAndResetUserPromise( - new HttpException( - response.status(), "Download failed with " + "Status: " + response.status(), null)); + new HttpException(response, "Upload failed with status: " + response.status(), null)); } else { succeedAndResetUserPromise(); } @@ -82,7 +85,6 @@ final class HttpUploadHandler extends AbstractHttpHandler<FullHttpResponse> { if (f.isSuccess()) { return; } - body.close(); failAndResetUserPromise(f.cause()); }); ctx.writeAndFlush(body) @@ -91,7 +93,6 @@ final class HttpUploadHandler extends AbstractHttpHandler<FullHttpResponse> { if (f.isSuccess()) { return; } - body.close(); failAndResetUserPromise(f.cause()); }); } diff --git a/src/test/java/com/google/devtools/build/lib/BUILD b/src/test/java/com/google/devtools/build/lib/BUILD index a11f48ac1a..aa77f34f59 100644 --- a/src/test/java/com/google/devtools/build/lib/BUILD +++ b/src/test/java/com/google/devtools/build/lib/BUILD @@ -1210,6 +1210,7 @@ java_test( "//src/main/java/com/google/devtools/common/options", "//src/main/protobuf:remote_execution_log_java_proto", "//third_party:api_client", + "//third_party:auth", "//third_party:mockito", "//third_party:netty", "//third_party/grpc:grpc-jar", diff --git a/src/test/java/com/google/devtools/build/lib/remote/blobstore/http/HttpBlobStoreTest.java b/src/test/java/com/google/devtools/build/lib/remote/blobstore/http/HttpBlobStoreTest.java new file mode 100644 index 0000000000..7ef77edc2f --- /dev/null +++ b/src/test/java/com/google/devtools/build/lib/remote/blobstore/http/HttpBlobStoreTest.java @@ -0,0 +1,296 @@ +// Copyright 2018 The Bazel Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package com.google.devtools.build.lib.remote.blobstore.http; + +import static com.google.common.truth.Truth.assertThat; +import static java.util.Collections.singletonList; +import static org.junit.Assert.fail; +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +import com.google.auth.Credentials; +import com.google.common.base.Charsets; +import com.google.devtools.build.lib.remote.blobstore.http.HttpBlobStoreTest.NotAuthorizedHandler.ErrorType; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandler.Sharable; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.ServerSocketChannel; +import io.netty.channel.socket.nio.NioServerSocketChannel; +import io.netty.channel.socket.nio.NioSocketChannel; +import io.netty.handler.codec.http.DefaultFullHttpResponse; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpObjectAggregator; +import io.netty.handler.codec.http.HttpResponseStatus; +import io.netty.handler.codec.http.HttpServerCodec; +import io.netty.handler.codec.http.HttpUtil; +import io.netty.handler.codec.http.HttpVersion; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.net.URI; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mockito; + +/** Tests for {@link HttpBlobStore}. */ +@RunWith(JUnit4.class) +public class HttpBlobStoreTest { + + private ServerSocketChannel startServer(ChannelHandler handler) throws Exception { + EventLoopGroup eventLoop = new NioEventLoopGroup(1); + ServerBootstrap sb = + new ServerBootstrap() + .group(eventLoop) + .channel(NioServerSocketChannel.class) + .childHandler( + new ChannelInitializer<NioSocketChannel>() { + @Override + protected void initChannel(NioSocketChannel ch) { + ch.pipeline().addLast(new HttpServerCodec()); + ch.pipeline().addLast(new HttpObjectAggregator(1000)); + ch.pipeline().addLast(handler); + } + }); + return ((ServerSocketChannel) sb.bind("localhost", 0).sync().channel()); + } + + @Test + public void expiredAuthTokensShouldBeRetried_get() throws Exception { + expiredAuthTokensShouldBeRetried_get(ErrorType.UNAUTHORIZED); + expiredAuthTokensShouldBeRetried_get(ErrorType.INVALID_TOKEN); + } + + private void expiredAuthTokensShouldBeRetried_get(ErrorType errorType) throws Exception { + ServerSocketChannel server = null; + try { + server = startServer(new NotAuthorizedHandler(errorType)); + int serverPort = server.localAddress().getPort(); + + Credentials credentials = newCredentials(); + HttpBlobStore blobStore = + new HttpBlobStore(new URI("http://localhost:" + serverPort), 5, credentials); + ByteArrayOutputStream out = Mockito.spy(new ByteArrayOutputStream()); + blobStore.get("key", out); + assertThat(out.toString(Charsets.US_ASCII.name())).isEqualTo("File Contents"); + verify(credentials, times(1)).refresh(); + verify(credentials, times(2)).getRequestMetadata(any(URI.class)); + verify(credentials, times(2)).hasRequestMetadata(); + // The caller is responsible to the close the stream. + verify(out, never()).close(); + verifyNoMoreInteractions(credentials); + } finally { + closeServerChannel(server); + } + } + + @Test + public void expiredAuthTokensShouldBeRetried_put() throws Exception { + expiredAuthTokensShouldBeRetried_put(ErrorType.UNAUTHORIZED); + expiredAuthTokensShouldBeRetried_put(ErrorType.INVALID_TOKEN); + } + + private void expiredAuthTokensShouldBeRetried_put(ErrorType errorType) throws Exception { + ServerSocketChannel server = null; + try { + server = startServer(new NotAuthorizedHandler(errorType)); + int serverPort = server.localAddress().getPort(); + + Credentials credentials = newCredentials(); + HttpBlobStore blobStore = + new HttpBlobStore(new URI("http://localhost:" + serverPort), 5, credentials); + byte[] data = "File Contents".getBytes(Charsets.US_ASCII); + ByteArrayInputStream in = new ByteArrayInputStream(data); + blobStore.put("key", data.length, in); + verify(credentials, times(1)).refresh(); + verify(credentials, times(2)).getRequestMetadata(any(URI.class)); + verify(credentials, times(2)).hasRequestMetadata(); + verifyNoMoreInteractions(credentials); + } finally { + closeServerChannel(server); + } + } + + @Test + public void errorCodesThatShouldNotBeRetried_get() throws InterruptedException { + errorCodeThatShouldNotBeRetried_get(ErrorType.INSUFFICIENT_SCOPE); + errorCodeThatShouldNotBeRetried_get(ErrorType.INVALID_REQUEST); + } + + private void errorCodeThatShouldNotBeRetried_get(ErrorType errorType) + throws InterruptedException { + ServerSocketChannel server = null; + try { + server = startServer(new NotAuthorizedHandler(errorType)); + int serverPort = server.localAddress().getPort(); + + Credentials credentials = newCredentials(); + HttpBlobStore blobStore = + new HttpBlobStore(new URI("http://localhost:" + serverPort), 5, credentials); + blobStore.get("key", new ByteArrayOutputStream()); + fail("Exception expected."); + } catch (Exception e) { + assertThat(e).isInstanceOf(HttpException.class); + assertThat(((HttpException) e).response().status()) + .isEqualTo(HttpResponseStatus.UNAUTHORIZED); + } finally { + closeServerChannel(server); + } + } + + @Test + public void errorCodesThatShouldNotBeRetried_put() throws InterruptedException { + errorCodeThatShouldNotBeRetried_put(ErrorType.INSUFFICIENT_SCOPE); + errorCodeThatShouldNotBeRetried_put(ErrorType.INVALID_REQUEST); + } + + private void errorCodeThatShouldNotBeRetried_put(ErrorType errorType) + throws InterruptedException { + ServerSocketChannel server = null; + try { + server = startServer(new NotAuthorizedHandler(errorType)); + int serverPort = server.localAddress().getPort(); + + Credentials credentials = newCredentials(); + HttpBlobStore blobStore = + new HttpBlobStore(new URI("http://localhost:" + serverPort), 5, credentials); + blobStore.put("key", 1, new ByteArrayInputStream(new byte[] {0})); + fail("Exception expected."); + } catch (Exception e) { + assertThat(e).isInstanceOf(HttpException.class); + assertThat(((HttpException) e).response().status()) + .isEqualTo(HttpResponseStatus.UNAUTHORIZED); + } finally { + closeServerChannel(server); + } + } + + private Credentials newCredentials() throws Exception { + Credentials credentials = mock(Credentials.class); + when(credentials.hasRequestMetadata()).thenReturn(true); + Map<String, List<String>> headers = new HashMap<>(); + headers.put("Authorization", singletonList("Bearer invalidToken")); + when(credentials.getRequestMetadata(any(URI.class))).thenReturn(headers); + Mockito.doAnswer( + (mock) -> { + Map<String, List<String>> headers2 = new HashMap<>(); + headers2.put("Authorization", singletonList("Bearer validToken")); + when(credentials.getRequestMetadata(any(URI.class))).thenReturn(headers2); + return null; + }) + .when(credentials) + .refresh(); + return credentials; + } + + /** + * {@link ChannelHandler} that on the first request responds with a 401 UNAUTHORIZED status code, + * which the client is expected to retry once with a new authentication token. + */ + @Sharable + static class NotAuthorizedHandler extends SimpleChannelInboundHandler<FullHttpRequest> { + + enum ErrorType { + UNAUTHORIZED, + INVALID_TOKEN, + INSUFFICIENT_SCOPE, + INVALID_REQUEST + } + + private final ErrorType errorType; + private int messageCount; + + NotAuthorizedHandler(ErrorType errorType) { + this.errorType = errorType; + } + + @Override + protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest request) { + if (messageCount == 0) { + if (!"Bearer invalidToken".equals(request.headers().get(HttpHeaderNames.AUTHORIZATION))) { + ctx.writeAndFlush( + new DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, HttpResponseStatus.INTERNAL_SERVER_ERROR)) + .addListener(ChannelFutureListener.CLOSE); + return; + } + final FullHttpResponse response; + if (errorType == ErrorType.UNAUTHORIZED) { + response = + new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.UNAUTHORIZED); + } else { + response = + new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.UNAUTHORIZED); + response + .headers() + .set( + HttpHeaderNames.WWW_AUTHENTICATE, + "Bearer realm=\"localhost\"," + + "error=\"" + + errorType.name().toLowerCase() + + "\"," + + "error_description=\"The access token expired\""); + } + ctx.writeAndFlush(response).addListener(ChannelFutureListener.CLOSE); + messageCount++; + } else if (messageCount == 1) { + if (!"Bearer validToken".equals(request.headers().get(HttpHeaderNames.AUTHORIZATION))) { + ctx.writeAndFlush( + new DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, HttpResponseStatus.INTERNAL_SERVER_ERROR)) + .addListener(ChannelFutureListener.CLOSE); + return; + } + ByteBuf content = ctx.alloc().buffer(); + content.writeCharSequence("File Contents", Charsets.US_ASCII); + FullHttpResponse response = + new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK, content); + HttpUtil.setKeepAlive(response, true); + HttpUtil.setContentLength(response, content.readableBytes()); + ctx.writeAndFlush(response); + messageCount++; + } else { + // No third message expected. + ctx.writeAndFlush( + new DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, HttpResponseStatus.INTERNAL_SERVER_ERROR)) + .addListener(ChannelFutureListener.CLOSE); + } + } + } + + private void closeServerChannel(ServerSocketChannel server) throws InterruptedException { + if (server != null) { + server.close(); + server.closeFuture().sync(); + server.eventLoop().shutdownGracefully().sync(); + } + } +} diff --git a/src/test/java/com/google/devtools/build/lib/remote/blobstore/http/HttpDownloadHandlerTest.java b/src/test/java/com/google/devtools/build/lib/remote/blobstore/http/HttpDownloadHandlerTest.java index ca544ee553..a44b7a7f2f 100644 --- a/src/test/java/com/google/devtools/build/lib/remote/blobstore/http/HttpDownloadHandlerTest.java +++ b/src/test/java/com/google/devtools/build/lib/remote/blobstore/http/HttpDownloadHandlerTest.java @@ -14,6 +14,7 @@ package com.google.devtools.build.lib.remote.blobstore.http; import static com.google.common.truth.Truth.assertThat; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import com.google.common.net.HttpHeaders; @@ -83,7 +84,7 @@ public class HttpDownloadHandlerTest extends AbstractHttpHandlerTest { assertThat(writePromise.isDone()).isTrue(); assertThat(out.toByteArray()).isEqualTo(new byte[] {1, 2, 3, 4, 5}); - verify(out).close(); + verify(out, never()).close(); assertThat(ch.isActive()).isTrue(); } @@ -102,11 +103,12 @@ public class HttpDownloadHandlerTest extends AbstractHttpHandlerTest { ch.writeInbound(response); assertThat(writePromise.isDone()).isTrue(); assertThat(writePromise.cause()).isInstanceOf(HttpException.class); - assertThat(((HttpException) writePromise.cause()).status()) + assertThat(((HttpException) writePromise.cause()).response().status()) .isEqualTo(HttpResponseStatus.NOT_FOUND); // No data should have been written to the OutputStream and it should have been closed. assertThat(out.size()).isEqualTo(0); - verify(out).close(); + // The caller is responsible for closing the stream. + verify(out, never()).close(); assertThat(ch.isOpen()).isFalse(); } } diff --git a/src/test/java/com/google/devtools/build/lib/remote/blobstore/http/HttpUploadHandlerTest.java b/src/test/java/com/google/devtools/build/lib/remote/blobstore/http/HttpUploadHandlerTest.java index 27914da542..0b6b02ab93 100644 --- a/src/test/java/com/google/devtools/build/lib/remote/blobstore/http/HttpUploadHandlerTest.java +++ b/src/test/java/com/google/devtools/build/lib/remote/blobstore/http/HttpUploadHandlerTest.java @@ -102,7 +102,7 @@ public class HttpUploadHandlerTest extends AbstractHttpHandlerTest { assertThat(writePromise.isDone()).isTrue(); assertThat(writePromise.cause()).isInstanceOf(HttpException.class); - assertThat(((HttpException) writePromise.cause()).status()) + assertThat(((HttpException) writePromise.cause()).response().status()) .isEqualTo(HttpResponseStatus.FORBIDDEN); assertThat(ch.isOpen()).isFalse(); } |