diff options
11 files changed, 195 insertions, 70 deletions
diff --git a/Firestore/Example/Tests/Integration/FSTStreamTests.mm b/Firestore/Example/Tests/Integration/FSTStreamTests.mm index b944a93..3ad6868 100644 --- a/Firestore/Example/Tests/Integration/FSTStreamTests.mm +++ b/Firestore/Example/Tests/Integration/FSTStreamTests.mm @@ -18,8 +18,11 @@ #import <GRPCClient/GRPCCall.h> +#import <FirebaseFirestore/FIRFirestoreErrors.h> #import <FirebaseFirestore/FIRFirestoreSettings.h> +#include <utility> + #import "Firestore/Example/Tests/Util/FSTHelpers.h" #import "Firestore/Example/Tests/Util/FSTIntegrationTestCase.h" #import "Firestore/Source/Remote/FSTDatastore.h" @@ -327,4 +330,74 @@ using firebase::firestore::model::SnapshotVersion; ]]; } +class MockCredentialsProvider : public firebase::firestore::auth::EmptyCredentialsProvider { + public: + MockCredentialsProvider() { + observed_states_ = [NSMutableArray new]; + } + + void GetToken(firebase::firestore::auth::TokenListener completion) override { + [observed_states_ addObject:@"GetToken"]; + EmptyCredentialsProvider::GetToken(std::move(completion)); + } + + void InvalidateToken() override { + [observed_states_ addObject:@"InvalidateToken"]; + EmptyCredentialsProvider::InvalidateToken(); + } + + NSMutableArray<NSString *> *observed_states() const { + return observed_states_; + } + + private: + NSMutableArray<NSString *> *observed_states_; +}; + +- (void)testStreamRefreshesTokenUponExpiration { + MockCredentialsProvider credentials; + FSTDatastore *datastore = [[FSTDatastore alloc] initWithDatabaseInfo:&_databaseInfo + workerDispatchQueue:_workerDispatchQueue + credentials:&credentials]; + FSTWatchStream *watchStream = [datastore createWatchStream]; + + [_delegate awaitNotificationFromBlock:^{ + [watchStream startWithDelegate:_delegate]; + }]; + + // Simulate callback from GRPC with an unauthenticated error -- this should invalidate the token. + NSError *unauthenticatedError = [NSError errorWithDomain:FIRFirestoreErrorDomain + code:FIRFirestoreErrorCodeUnauthenticated + userInfo:nil]; + dispatch_async(_testQueue, ^{ + [watchStream.callbackFilter writesFinishedWithError:unauthenticatedError]; + }); + // Drain the queue. + dispatch_sync(_testQueue, ^{ + }); + + // Try reconnecting. + [_delegate awaitNotificationFromBlock:^{ + [watchStream startWithDelegate:_delegate]; + }]; + // Simulate a different error -- token should not be invalidated this time. + NSError *unavailableError = [NSError errorWithDomain:FIRFirestoreErrorDomain + code:FIRFirestoreErrorCodeUnavailable + userInfo:nil]; + dispatch_async(_testQueue, ^{ + [watchStream.callbackFilter writesFinishedWithError:unavailableError]; + }); + dispatch_sync(_testQueue, ^{ + }); + + [_delegate awaitNotificationFromBlock:^{ + [watchStream startWithDelegate:_delegate]; + }]; + dispatch_sync(_testQueue, ^{ + }); + + NSArray<NSString *> *expected = @[ @"GetToken", @"InvalidateToken", @"GetToken", @"GetToken" ]; + XCTAssertEqualObjects(credentials.observed_states(), expected); +} + @end diff --git a/Firestore/Example/Tests/Remote/FSTDatastoreTests.mm b/Firestore/Example/Tests/Remote/FSTDatastoreTests.mm index 6d6e912..9783e37 100644 --- a/Firestore/Example/Tests/Remote/FSTDatastoreTests.mm +++ b/Firestore/Example/Tests/Remote/FSTDatastoreTests.mm @@ -53,6 +53,12 @@ code:FIRFirestoreErrorCodeUnavailable userInfo:nil]; XCTAssertFalse([FSTDatastore isPermanentWriteError:error]); + + // "unauthenticated" is considered a recoverable error due to expired token. + error = [NSError errorWithDomain:FIRFirestoreErrorDomain + code:FIRFirestoreErrorCodeUnauthenticated + userInfo:nil]; + XCTAssertFalse([FSTDatastore isPermanentWriteError:error]); } @end diff --git a/Firestore/Source/Remote/FSTDatastore.mm b/Firestore/Source/Remote/FSTDatastore.mm index 5f79122..fdbeea3 100644 --- a/Firestore/Source/Remote/FSTDatastore.mm +++ b/Firestore/Source/Remote/FSTDatastore.mm @@ -157,10 +157,9 @@ typedef GRPCProtoCall * (^RPCFactory)(void); case FIRFirestoreErrorCodeResourceExhausted: case FIRFirestoreErrorCodeInternal: case FIRFirestoreErrorCodeUnavailable: + // Unauthenticated means something went wrong with our token and we need to retry with new + // credentials which will happen automatically. case FIRFirestoreErrorCodeUnauthenticated: - // Unauthenticated means something went wrong with our token and we need - // to retry with new credentials which will happen automatically. - // TODO(b/37325376): Give up after second unauthenticated error. return NO; case FIRFirestoreErrorCodeInvalidArgument: case FIRFirestoreErrorCodeNotFound: @@ -174,6 +173,7 @@ typedef GRPCProtoCall * (^RPCFactory)(void); case FIRFirestoreErrorCodeOutOfRange: case FIRFirestoreErrorCodeUnimplemented: case FIRFirestoreErrorCodeDataLoss: + return YES; default: return YES; } @@ -239,6 +239,9 @@ typedef GRPCProtoCall * (^RPCFactory)(void); handler:^(GCFSCommitResponse *response, NSError *_Nullable error) { error = [FSTDatastore firestoreErrorForError:error]; [self.workerDispatchQueue dispatchAsync:^{ + if (error != nil && error.code == FIRFirestoreErrorCodeUnauthenticated) { + _credentials->InvalidateToken(); + } LOG_DEBUG("RPC CommitRequest completed. Error: %s", error); [FSTDatastore logHeadersForRPC:rpc RPCName:@"CommitRequest"]; completion(error); @@ -273,6 +276,9 @@ typedef GRPCProtoCall * (^RPCFactory)(void); [self.workerDispatchQueue dispatchAsync:^{ if (error) { LOG_DEBUG("RPC BatchGetDocuments completed. Error: %s", error); + if (error.code == FIRFirestoreErrorCodeUnauthenticated) { + _credentials->InvalidateToken(); + } [FSTDatastore logHeadersForRPC:rpc RPCName:@"BatchGetDocuments"]; completion(nil, error); return; @@ -310,25 +316,21 @@ typedef GRPCProtoCall * (^RPCFactory)(void); - (void)invokeRPCWithFactory:(GRPCProtoCall * (^)(void))rpcFactory errorHandler:(FSTVoidErrorBlock)errorHandler { - // TODO(mikelehen): We should force a refresh if the previous RPC failed due to an expired token, - // but I'm not sure how to detect that right now. http://b/32762461 - _credentials->GetToken( - /*force_refresh=*/false, [self, rpcFactory, errorHandler](util::StatusOr<Token> result) { - [self.workerDispatchQueue dispatchAsyncAllowingSameQueue:^{ - if (!result.ok()) { - errorHandler(util::MakeNSError(result.status())); - } else { - GRPCProtoCall *rpc = rpcFactory(); - const Token &token = result.ValueOrDie(); - [FSTDatastore - prepareHeadersForRPC:rpc - databaseID:&self.databaseInfo->database_id() - token:(token.user().is_authenticated() ? token.token() - : absl::string_view())]; - [rpc start]; - } - }]; - }); + _credentials->GetToken([self, rpcFactory, errorHandler](util::StatusOr<Token> result) { + [self.workerDispatchQueue dispatchAsyncAllowingSameQueue:^{ + if (!result.ok()) { + errorHandler(util::MakeNSError(result.status())); + } else { + GRPCProtoCall *rpc = rpcFactory(); + const Token &token = result.ValueOrDie(); + [FSTDatastore prepareHeadersForRPC:rpc + databaseID:&self.databaseInfo->database_id() + token:(token.user().is_authenticated() ? token.token() + : absl::string_view())]; + [rpc start]; + } + }]; + }); } - (FSTWatchStream *)createWatchStream { diff --git a/Firestore/Source/Remote/FSTStream.mm b/Firestore/Source/Remote/FSTStream.mm index 3a6c035..2c9c8d9 100644 --- a/Firestore/Source/Remote/FSTStream.mm +++ b/Firestore/Source/Remote/FSTStream.mm @@ -265,12 +265,11 @@ static const NSTimeInterval kIdleTimeout = 60.0; HARD_ASSERT(_delegate == nil, "Delegate must be nil"); _delegate = delegate; - _credentials->GetToken( - /*force_refresh=*/false, [self](util::StatusOr<Token> result) { - [self.workerDispatchQueue dispatchAsyncAllowingSameQueue:^{ - [self resumeStartWithToken:result]; - }]; - }); + _credentials->GetToken([self](util::StatusOr<Token> result) { + [self.workerDispatchQueue dispatchAsyncAllowingSameQueue:^{ + [self resumeStartWithToken:result]; + }]; + }); } /** Add an access token to our RPC, after obtaining one from the credentials provider. */ @@ -283,8 +282,6 @@ static const NSTimeInterval kIdleTimeout = 60.0; } HARD_ASSERT(self.state == FSTStreamStateAuth, "State should still be auth (was %s)", self.state); - // TODO(mikelehen): We should force a refresh if the previous RPC failed due to an expired token, - // but I'm not sure how to detect that right now. http://b/32762461 if (!result.ok()) { // RPC has not been started yet, so just invoke higher-level close handler. [self handleStreamClose:util::MakeNSError(result.status())]; @@ -383,6 +380,10 @@ static const NSTimeInterval kIdleTimeout = 60.0; LOG_DEBUG("%s %s Using maximum backoff delay to prevent overloading the backend.", [self class], (__bridge void *)self); [self.backoff resetToMax]; + } else if (error != nil && error.code == FIRFirestoreErrorCodeUnauthenticated) { + // "unauthenticated" error means the token was rejected. Try force refreshing it in case it just + // expired. + _credentials->InvalidateToken(); } if (finalState != FSTStreamStateError) { diff --git a/Firestore/core/src/firebase/firestore/auth/credentials_provider.h b/Firestore/core/src/firebase/firestore/auth/credentials_provider.h index 0a1930a..d6ed39a 100644 --- a/Firestore/core/src/firebase/firestore/auth/credentials_provider.h +++ b/Firestore/core/src/firebase/firestore/auth/credentials_provider.h @@ -46,11 +46,14 @@ class CredentialsProvider { virtual ~CredentialsProvider(); + /** Requests token for the current user. */ + virtual void GetToken(TokenListener completion) = 0; + /** - * Requests token for the current user, optionally forcing a refreshed token - * to be fetched. + * Marks the last retrieved token as invalid, making the next `GetToken` + * request force refresh the token. */ - virtual void GetToken(bool force_refresh, TokenListener completion) = 0; + virtual void InvalidateToken() = 0; /** * Sets the listener to be notified of user changes (sign-in / sign-out). It diff --git a/Firestore/core/src/firebase/firestore/auth/empty_credentials_provider.cc b/Firestore/core/src/firebase/firestore/auth/empty_credentials_provider.cc index da4198d..77156cc 100644 --- a/Firestore/core/src/firebase/firestore/auth/empty_credentials_provider.cc +++ b/Firestore/core/src/firebase/firestore/auth/empty_credentials_provider.cc @@ -14,17 +14,13 @@ * limitations under the License. */ -#define UNUSED(x) (void)(x) - #include "Firestore/core/src/firebase/firestore/auth/empty_credentials_provider.h" namespace firebase { namespace firestore { namespace auth { -void EmptyCredentialsProvider::GetToken(bool force_refresh, - TokenListener completion) { - UNUSED(force_refresh); +void EmptyCredentialsProvider::GetToken(TokenListener completion) { if (completion) { // Unauthenticated token will force the GRPC fallback to use default // settings. @@ -39,6 +35,9 @@ void EmptyCredentialsProvider::SetUserChangeListener( } } +void EmptyCredentialsProvider::InvalidateToken() { +} + } // namespace auth } // namespace firestore } // namespace firebase diff --git a/Firestore/core/src/firebase/firestore/auth/empty_credentials_provider.h b/Firestore/core/src/firebase/firestore/auth/empty_credentials_provider.h index 55b3cc6..3ea0cab 100644 --- a/Firestore/core/src/firebase/firestore/auth/empty_credentials_provider.h +++ b/Firestore/core/src/firebase/firestore/auth/empty_credentials_provider.h @@ -26,7 +26,8 @@ namespace auth { /** `EmptyCredentialsProvider` always yields an empty token. */ class EmptyCredentialsProvider : public CredentialsProvider { public: - void GetToken(bool force_refresh, TokenListener completion) override; + void GetToken(TokenListener completion) override; + void InvalidateToken() override; void SetUserChangeListener(UserChangeListener listener) override; }; diff --git a/Firestore/core/src/firebase/firestore/auth/firebase_credentials_provider_apple.h b/Firestore/core/src/firebase/firestore/auth/firebase_credentials_provider_apple.h index 0e1da31..f54b72f 100644 --- a/Firestore/core/src/firebase/firestore/auth/firebase_credentials_provider_apple.h +++ b/Firestore/core/src/firebase/firestore/auth/firebase_credentials_provider_apple.h @@ -65,10 +65,12 @@ class FirebaseCredentialsProvider : public CredentialsProvider { ~FirebaseCredentialsProvider() override; - void GetToken(bool force_refresh, TokenListener completion) override; + void GetToken(TokenListener completion) override; void SetUserChangeListener(UserChangeListener listener) override; + void InvalidateToken() override; + private: /** * Most contents of the FirebaseCredentialProvider are kept in this @@ -95,6 +97,8 @@ class FirebaseCredentialsProvider : public CredentialsProvider { int user_counter = 0; std::mutex mutex; + + bool force_refresh = false; }; /** diff --git a/Firestore/core/src/firebase/firestore/auth/firebase_credentials_provider_apple.mm b/Firestore/core/src/firebase/firestore/auth/firebase_credentials_provider_apple.mm index 9d5b89e..74858c6 100644 --- a/Firestore/core/src/firebase/firestore/auth/firebase_credentials_provider_apple.mm +++ b/Firestore/core/src/firebase/firestore/auth/firebase_credentials_provider_apple.mm @@ -78,8 +78,7 @@ FirebaseCredentialsProvider::~FirebaseCredentialsProvider() { } } -void FirebaseCredentialsProvider::GetToken(bool force_refresh, - TokenListener completion) { +void FirebaseCredentialsProvider::GetToken(TokenListener completion) { HARD_ASSERT(auth_listener_handle_, "GetToken cannot be called after listener removed."); @@ -121,8 +120,13 @@ void FirebaseCredentialsProvider::GetToken(bool force_refresh, } }; - [contents_->app getTokenForcingRefresh:force_refresh + [contents_->app getTokenForcingRefresh:contents_->force_refresh withCallback:get_token_callback]; + contents_->force_refresh = false; +} + +void FirebaseCredentialsProvider::InvalidateToken() { + contents_->force_refresh = true; } void FirebaseCredentialsProvider::SetUserChangeListener( diff --git a/Firestore/core/test/firebase/firestore/auth/empty_credentials_provider_test.cc b/Firestore/core/test/firebase/firestore/auth/empty_credentials_provider_test.cc index 60845e5..a2f5780 100644 --- a/Firestore/core/test/firebase/firestore/auth/empty_credentials_provider_test.cc +++ b/Firestore/core/test/firebase/firestore/auth/empty_credentials_provider_test.cc @@ -25,15 +25,14 @@ namespace auth { TEST(EmptyCredentialsProvider, GetToken) { EmptyCredentialsProvider credentials_provider; - credentials_provider.GetToken( - /*force_refresh=*/true, [](util::StatusOr<Token> result) { - EXPECT_TRUE(result.ok()); - const Token& token = result.ValueOrDie(); - EXPECT_ANY_THROW(token.token()); - const User& user = token.user(); - EXPECT_EQ("", user.uid()); - EXPECT_FALSE(user.is_authenticated()); - }); + credentials_provider.GetToken([](util::StatusOr<Token> result) { + EXPECT_TRUE(result.ok()); + const Token& token = result.ValueOrDie(); + EXPECT_ANY_THROW(token.token()); + const User& user = token.user(); + EXPECT_EQ("", user.uid()); + EXPECT_FALSE(user.is_authenticated()); + }); } TEST(EmptyCredentialsProvider, SetListener) { @@ -46,6 +45,13 @@ TEST(EmptyCredentialsProvider, SetListener) { credentials_provider.SetUserChangeListener(nullptr); } +TEST(EmptyCredentialsProvider, InvalidateToken) { + EmptyCredentialsProvider credentials_provider; + credentials_provider.InvalidateToken(); + credentials_provider.GetToken( + [](util::StatusOr<Token> result) { EXPECT_TRUE(result.ok()); }); +} + } // namespace auth } // namespace firestore } // namespace firebase diff --git a/Firestore/core/test/firebase/firestore/auth/firebase_credentials_provider_test.mm b/Firestore/core/test/firebase/firestore/auth/firebase_credentials_provider_test.mm index 9d358b5..873f1b2 100644 --- a/Firestore/core/test/firebase/firestore/auth/firebase_credentials_provider_test.mm +++ b/Firestore/core/test/firebase/firestore/auth/firebase_credentials_provider_test.mm @@ -51,30 +51,28 @@ TEST(FirebaseCredentialsProviderTest, GetTokenUnauthenticated) { FIRApp* app = AppWithFakeUid(nil); FirebaseCredentialsProvider credentials_provider(app); - credentials_provider.GetToken( - /*force_refresh=*/true, [](util::StatusOr<Token> result) { - EXPECT_TRUE(result.ok()); - const Token& token = result.ValueOrDie(); - EXPECT_ANY_THROW(token.token()); - const User& user = token.user(); - EXPECT_EQ("", user.uid()); - EXPECT_FALSE(user.is_authenticated()); - }); + credentials_provider.GetToken([](util::StatusOr<Token> result) { + EXPECT_TRUE(result.ok()); + const Token& token = result.ValueOrDie(); + EXPECT_ANY_THROW(token.token()); + const User& user = token.user(); + EXPECT_EQ("", user.uid()); + EXPECT_FALSE(user.is_authenticated()); + }); } TEST(FirebaseCredentialsProviderTest, GetToken) { FIRApp* app = AppWithFakeUidAndToken(@"fake uid", @"token for fake uid"); FirebaseCredentialsProvider credentials_provider(app); - credentials_provider.GetToken( - /*force_refresh=*/true, [](util::StatusOr<Token> result) { - EXPECT_TRUE(result.ok()); - const Token& token = result.ValueOrDie(); - EXPECT_EQ("token for fake uid", token.token()); - const User& user = token.user(); - EXPECT_EQ("fake uid", user.uid()); - EXPECT_TRUE(user.is_authenticated()); - }); + credentials_provider.GetToken([](util::StatusOr<Token> result) { + EXPECT_TRUE(result.ok()); + const Token& token = result.ValueOrDie(); + EXPECT_EQ("token for fake uid", token.token()); + const User& user = token.user(); + EXPECT_EQ("fake uid", user.uid()); + EXPECT_TRUE(user.is_authenticated()); + }); } TEST(FirebaseCredentialsProviderTest, SetListener) { @@ -89,6 +87,34 @@ TEST(FirebaseCredentialsProviderTest, SetListener) { credentials_provider.SetUserChangeListener(nullptr); } +FIRApp* FakeAppExpectingForceRefreshToken(NSString* _Nullable uid, + NSString* _Nullable token) { + FIRApp* app = testutil::AppForUnitTesting(); + app.getUIDImplementation = ^NSString* { + return uid; + }; + app.getTokenImplementation = + ^(BOOL force_refresh, FIRTokenCallback callback) { + EXPECT_TRUE(force_refresh); + callback(token, nil); + }; + return app; +} + +TEST(FirebaseCredentialsProviderTest, InvalidateToken) { + FIRApp* app = + FakeAppExpectingForceRefreshToken(@"fake uid", @"token for fake uid"); + + FirebaseCredentialsProvider credentials_provider{app}; + credentials_provider.InvalidateToken(); + credentials_provider.GetToken([](util::StatusOr<Token> result) { + EXPECT_TRUE(result.ok()); + const Token& token = result.ValueOrDie(); + EXPECT_EQ("token for fake uid", token.token()); + EXPECT_EQ("fake uid", token.user().uid()); + }); +} + } // namespace auth } // namespace firestore } // namespace firebase |