diff options
author | Jakob Buchgraber <buchgr@google.com> | 2018-03-10 04:14:51 -0800 |
---|---|---|
committer | Copybara-Service <copybara-piper@google.com> | 2018-03-10 04:17:02 -0800 |
commit | deccc485603c004daad959fd747f1c0c9efc4f00 (patch) | |
tree | 6cdf8d42e01fd92fb32d5ef5f05325d7ea0d39e0 /src/test/java/com/google/devtools/build/lib | |
parent | 7e50ced9bb59b4ab445edd7904bf31601fd2cea0 (diff) |
remote/http: support refresh of oauth2 tokens in the remote cache.
Closes #4622.
PiperOrigin-RevId: 188595430
Diffstat (limited to 'src/test/java/com/google/devtools/build/lib')
4 files changed, 303 insertions, 4 deletions
diff --git a/src/test/java/com/google/devtools/build/lib/BUILD b/src/test/java/com/google/devtools/build/lib/BUILD index a11f48ac1a..aa77f34f59 100644 --- a/src/test/java/com/google/devtools/build/lib/BUILD +++ b/src/test/java/com/google/devtools/build/lib/BUILD @@ -1210,6 +1210,7 @@ java_test( "//src/main/java/com/google/devtools/common/options", "//src/main/protobuf:remote_execution_log_java_proto", "//third_party:api_client", + "//third_party:auth", "//third_party:mockito", "//third_party:netty", "//third_party/grpc:grpc-jar", diff --git a/src/test/java/com/google/devtools/build/lib/remote/blobstore/http/HttpBlobStoreTest.java b/src/test/java/com/google/devtools/build/lib/remote/blobstore/http/HttpBlobStoreTest.java new file mode 100644 index 0000000000..7ef77edc2f --- /dev/null +++ b/src/test/java/com/google/devtools/build/lib/remote/blobstore/http/HttpBlobStoreTest.java @@ -0,0 +1,296 @@ +// Copyright 2018 The Bazel Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package com.google.devtools.build.lib.remote.blobstore.http; + +import static com.google.common.truth.Truth.assertThat; +import static java.util.Collections.singletonList; +import static org.junit.Assert.fail; +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +import com.google.auth.Credentials; +import com.google.common.base.Charsets; +import com.google.devtools.build.lib.remote.blobstore.http.HttpBlobStoreTest.NotAuthorizedHandler.ErrorType; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandler.Sharable; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.ServerSocketChannel; +import io.netty.channel.socket.nio.NioServerSocketChannel; +import io.netty.channel.socket.nio.NioSocketChannel; +import io.netty.handler.codec.http.DefaultFullHttpResponse; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpObjectAggregator; +import io.netty.handler.codec.http.HttpResponseStatus; +import io.netty.handler.codec.http.HttpServerCodec; +import io.netty.handler.codec.http.HttpUtil; +import io.netty.handler.codec.http.HttpVersion; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.net.URI; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mockito; + +/** Tests for {@link HttpBlobStore}. */ +@RunWith(JUnit4.class) +public class HttpBlobStoreTest { + + private ServerSocketChannel startServer(ChannelHandler handler) throws Exception { + EventLoopGroup eventLoop = new NioEventLoopGroup(1); + ServerBootstrap sb = + new ServerBootstrap() + .group(eventLoop) + .channel(NioServerSocketChannel.class) + .childHandler( + new ChannelInitializer<NioSocketChannel>() { + @Override + protected void initChannel(NioSocketChannel ch) { + ch.pipeline().addLast(new HttpServerCodec()); + ch.pipeline().addLast(new HttpObjectAggregator(1000)); + ch.pipeline().addLast(handler); + } + }); + return ((ServerSocketChannel) sb.bind("localhost", 0).sync().channel()); + } + + @Test + public void expiredAuthTokensShouldBeRetried_get() throws Exception { + expiredAuthTokensShouldBeRetried_get(ErrorType.UNAUTHORIZED); + expiredAuthTokensShouldBeRetried_get(ErrorType.INVALID_TOKEN); + } + + private void expiredAuthTokensShouldBeRetried_get(ErrorType errorType) throws Exception { + ServerSocketChannel server = null; + try { + server = startServer(new NotAuthorizedHandler(errorType)); + int serverPort = server.localAddress().getPort(); + + Credentials credentials = newCredentials(); + HttpBlobStore blobStore = + new HttpBlobStore(new URI("http://localhost:" + serverPort), 5, credentials); + ByteArrayOutputStream out = Mockito.spy(new ByteArrayOutputStream()); + blobStore.get("key", out); + assertThat(out.toString(Charsets.US_ASCII.name())).isEqualTo("File Contents"); + verify(credentials, times(1)).refresh(); + verify(credentials, times(2)).getRequestMetadata(any(URI.class)); + verify(credentials, times(2)).hasRequestMetadata(); + // The caller is responsible to the close the stream. + verify(out, never()).close(); + verifyNoMoreInteractions(credentials); + } finally { + closeServerChannel(server); + } + } + + @Test + public void expiredAuthTokensShouldBeRetried_put() throws Exception { + expiredAuthTokensShouldBeRetried_put(ErrorType.UNAUTHORIZED); + expiredAuthTokensShouldBeRetried_put(ErrorType.INVALID_TOKEN); + } + + private void expiredAuthTokensShouldBeRetried_put(ErrorType errorType) throws Exception { + ServerSocketChannel server = null; + try { + server = startServer(new NotAuthorizedHandler(errorType)); + int serverPort = server.localAddress().getPort(); + + Credentials credentials = newCredentials(); + HttpBlobStore blobStore = + new HttpBlobStore(new URI("http://localhost:" + serverPort), 5, credentials); + byte[] data = "File Contents".getBytes(Charsets.US_ASCII); + ByteArrayInputStream in = new ByteArrayInputStream(data); + blobStore.put("key", data.length, in); + verify(credentials, times(1)).refresh(); + verify(credentials, times(2)).getRequestMetadata(any(URI.class)); + verify(credentials, times(2)).hasRequestMetadata(); + verifyNoMoreInteractions(credentials); + } finally { + closeServerChannel(server); + } + } + + @Test + public void errorCodesThatShouldNotBeRetried_get() throws InterruptedException { + errorCodeThatShouldNotBeRetried_get(ErrorType.INSUFFICIENT_SCOPE); + errorCodeThatShouldNotBeRetried_get(ErrorType.INVALID_REQUEST); + } + + private void errorCodeThatShouldNotBeRetried_get(ErrorType errorType) + throws InterruptedException { + ServerSocketChannel server = null; + try { + server = startServer(new NotAuthorizedHandler(errorType)); + int serverPort = server.localAddress().getPort(); + + Credentials credentials = newCredentials(); + HttpBlobStore blobStore = + new HttpBlobStore(new URI("http://localhost:" + serverPort), 5, credentials); + blobStore.get("key", new ByteArrayOutputStream()); + fail("Exception expected."); + } catch (Exception e) { + assertThat(e).isInstanceOf(HttpException.class); + assertThat(((HttpException) e).response().status()) + .isEqualTo(HttpResponseStatus.UNAUTHORIZED); + } finally { + closeServerChannel(server); + } + } + + @Test + public void errorCodesThatShouldNotBeRetried_put() throws InterruptedException { + errorCodeThatShouldNotBeRetried_put(ErrorType.INSUFFICIENT_SCOPE); + errorCodeThatShouldNotBeRetried_put(ErrorType.INVALID_REQUEST); + } + + private void errorCodeThatShouldNotBeRetried_put(ErrorType errorType) + throws InterruptedException { + ServerSocketChannel server = null; + try { + server = startServer(new NotAuthorizedHandler(errorType)); + int serverPort = server.localAddress().getPort(); + + Credentials credentials = newCredentials(); + HttpBlobStore blobStore = + new HttpBlobStore(new URI("http://localhost:" + serverPort), 5, credentials); + blobStore.put("key", 1, new ByteArrayInputStream(new byte[] {0})); + fail("Exception expected."); + } catch (Exception e) { + assertThat(e).isInstanceOf(HttpException.class); + assertThat(((HttpException) e).response().status()) + .isEqualTo(HttpResponseStatus.UNAUTHORIZED); + } finally { + closeServerChannel(server); + } + } + + private Credentials newCredentials() throws Exception { + Credentials credentials = mock(Credentials.class); + when(credentials.hasRequestMetadata()).thenReturn(true); + Map<String, List<String>> headers = new HashMap<>(); + headers.put("Authorization", singletonList("Bearer invalidToken")); + when(credentials.getRequestMetadata(any(URI.class))).thenReturn(headers); + Mockito.doAnswer( + (mock) -> { + Map<String, List<String>> headers2 = new HashMap<>(); + headers2.put("Authorization", singletonList("Bearer validToken")); + when(credentials.getRequestMetadata(any(URI.class))).thenReturn(headers2); + return null; + }) + .when(credentials) + .refresh(); + return credentials; + } + + /** + * {@link ChannelHandler} that on the first request responds with a 401 UNAUTHORIZED status code, + * which the client is expected to retry once with a new authentication token. + */ + @Sharable + static class NotAuthorizedHandler extends SimpleChannelInboundHandler<FullHttpRequest> { + + enum ErrorType { + UNAUTHORIZED, + INVALID_TOKEN, + INSUFFICIENT_SCOPE, + INVALID_REQUEST + } + + private final ErrorType errorType; + private int messageCount; + + NotAuthorizedHandler(ErrorType errorType) { + this.errorType = errorType; + } + + @Override + protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest request) { + if (messageCount == 0) { + if (!"Bearer invalidToken".equals(request.headers().get(HttpHeaderNames.AUTHORIZATION))) { + ctx.writeAndFlush( + new DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, HttpResponseStatus.INTERNAL_SERVER_ERROR)) + .addListener(ChannelFutureListener.CLOSE); + return; + } + final FullHttpResponse response; + if (errorType == ErrorType.UNAUTHORIZED) { + response = + new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.UNAUTHORIZED); + } else { + response = + new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.UNAUTHORIZED); + response + .headers() + .set( + HttpHeaderNames.WWW_AUTHENTICATE, + "Bearer realm=\"localhost\"," + + "error=\"" + + errorType.name().toLowerCase() + + "\"," + + "error_description=\"The access token expired\""); + } + ctx.writeAndFlush(response).addListener(ChannelFutureListener.CLOSE); + messageCount++; + } else if (messageCount == 1) { + if (!"Bearer validToken".equals(request.headers().get(HttpHeaderNames.AUTHORIZATION))) { + ctx.writeAndFlush( + new DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, HttpResponseStatus.INTERNAL_SERVER_ERROR)) + .addListener(ChannelFutureListener.CLOSE); + return; + } + ByteBuf content = ctx.alloc().buffer(); + content.writeCharSequence("File Contents", Charsets.US_ASCII); + FullHttpResponse response = + new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK, content); + HttpUtil.setKeepAlive(response, true); + HttpUtil.setContentLength(response, content.readableBytes()); + ctx.writeAndFlush(response); + messageCount++; + } else { + // No third message expected. + ctx.writeAndFlush( + new DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, HttpResponseStatus.INTERNAL_SERVER_ERROR)) + .addListener(ChannelFutureListener.CLOSE); + } + } + } + + private void closeServerChannel(ServerSocketChannel server) throws InterruptedException { + if (server != null) { + server.close(); + server.closeFuture().sync(); + server.eventLoop().shutdownGracefully().sync(); + } + } +} diff --git a/src/test/java/com/google/devtools/build/lib/remote/blobstore/http/HttpDownloadHandlerTest.java b/src/test/java/com/google/devtools/build/lib/remote/blobstore/http/HttpDownloadHandlerTest.java index ca544ee553..a44b7a7f2f 100644 --- a/src/test/java/com/google/devtools/build/lib/remote/blobstore/http/HttpDownloadHandlerTest.java +++ b/src/test/java/com/google/devtools/build/lib/remote/blobstore/http/HttpDownloadHandlerTest.java @@ -14,6 +14,7 @@ package com.google.devtools.build.lib.remote.blobstore.http; import static com.google.common.truth.Truth.assertThat; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import com.google.common.net.HttpHeaders; @@ -83,7 +84,7 @@ public class HttpDownloadHandlerTest extends AbstractHttpHandlerTest { assertThat(writePromise.isDone()).isTrue(); assertThat(out.toByteArray()).isEqualTo(new byte[] {1, 2, 3, 4, 5}); - verify(out).close(); + verify(out, never()).close(); assertThat(ch.isActive()).isTrue(); } @@ -102,11 +103,12 @@ public class HttpDownloadHandlerTest extends AbstractHttpHandlerTest { ch.writeInbound(response); assertThat(writePromise.isDone()).isTrue(); assertThat(writePromise.cause()).isInstanceOf(HttpException.class); - assertThat(((HttpException) writePromise.cause()).status()) + assertThat(((HttpException) writePromise.cause()).response().status()) .isEqualTo(HttpResponseStatus.NOT_FOUND); // No data should have been written to the OutputStream and it should have been closed. assertThat(out.size()).isEqualTo(0); - verify(out).close(); + // The caller is responsible for closing the stream. + verify(out, never()).close(); assertThat(ch.isOpen()).isFalse(); } } diff --git a/src/test/java/com/google/devtools/build/lib/remote/blobstore/http/HttpUploadHandlerTest.java b/src/test/java/com/google/devtools/build/lib/remote/blobstore/http/HttpUploadHandlerTest.java index 27914da542..0b6b02ab93 100644 --- a/src/test/java/com/google/devtools/build/lib/remote/blobstore/http/HttpUploadHandlerTest.java +++ b/src/test/java/com/google/devtools/build/lib/remote/blobstore/http/HttpUploadHandlerTest.java @@ -102,7 +102,7 @@ public class HttpUploadHandlerTest extends AbstractHttpHandlerTest { assertThat(writePromise.isDone()).isTrue(); assertThat(writePromise.cause()).isInstanceOf(HttpException.class); - assertThat(((HttpException) writePromise.cause()).status()) + assertThat(((HttpException) writePromise.cause()).response().status()) .isEqualTo(HttpResponseStatus.FORBIDDEN); assertThat(ch.isOpen()).isFalse(); } |