diff --git a/SSignalKit/SSignal+Catch.h b/SSignalKit/SSignal+Catch.h index 5cbdad90c9..dc8898c062 100644 --- a/SSignalKit/SSignal+Catch.h +++ b/SSignalKit/SSignal+Catch.h @@ -4,5 +4,6 @@ - (SSignal *)catch:(SSignal *(^)(id error))f; - (SSignal *)restart; +- (SSignal *)retryIf:(bool (^)(id error))predicate; @end diff --git a/SSignalKit/SSignal+Catch.m b/SSignalKit/SSignal+Catch.m index 2387c88946..f61e277750 100644 --- a/SSignalKit/SSignal+Catch.m +++ b/SSignalKit/SSignal+Catch.m @@ -91,4 +91,57 @@ static dispatch_block_t recursiveBlock(void (^block)(dispatch_block_t recurse)) }]; } +- (SSignal *)retryIf:(bool (^)(id error))predicate { + return [[SSignal alloc] initWithGenerator:^id (SSubscriber *subscriber) + { + SAtomic *shouldRestart = [[SAtomic alloc] initWithValue:@true]; + + SMetaDisposable *currentDisposable = [[SMetaDisposable alloc] init]; + + void (^start)() = recursiveBlock(^(dispatch_block_t recurse) + { + NSNumber *currentShouldRestart = [shouldRestart with:^id(NSNumber *current) + { + return current; + }]; + + if ([currentShouldRestart boolValue]) + { + id disposable = [self startWithNext:^(id next) + { + [subscriber putNext:next]; + } error:^(id error) + { + if (predicate(error)) { + recurse(); + } else { + [subscriber putError:error]; + } + } completed:^ + { + [shouldRestart modify:^id(__unused id current) { + return @false; + }]; + [subscriber putCompletion]; + }]; + [currentDisposable setDisposable:disposable]; + } else { + [subscriber putCompletion]; + } + }); + + start(); + + return [[SBlockDisposable alloc] initWithBlock:^ + { + [currentDisposable dispose]; + + [shouldRestart modify:^id(__unused id current) + { + return @false; + }]; + }]; + }]; +} + @end diff --git a/SSignalKitTests/SSignalBasicTests.m b/SSignalKitTests/SSignalBasicTests.m index cb5a986af6..56ca98337c 100644 --- a/SSignalKitTests/SSignalBasicTests.m +++ b/SSignalKitTests/SSignalBasicTests.m @@ -702,4 +702,63 @@ } } +- (void)testRetryIfNoError { + SSignal *s = [[SSignal single:@1] retryIf:^bool(__unused id error) { + return true; + }]; + [s startWithNext:^(id next) { + XCTAssertEqual(next, @1); + }]; +} + +- (void)testRetryErrorNoMatch { + SSignal *s = [[SSignal fail:@false] retryIf:^bool(id error) { + return false; + }]; +} + +- (void)testRetryErrorMatch { + __block counter = 1; + SSignal *s = [[[SSignal alloc] initWithGenerator:^id (SSubscriber *subscriber) { + if (counter == 1) { + counter++; + [subscriber putError:@true]; + } else { + [subscriber putNext:@(counter)]; + } + return nil; + }] retryIf:^bool(id error) { + return [error boolValue]; + }]; + + __block int value = 0; + [s startWithNext:^(id next) { + value = [next intValue]; + }]; + + XCTAssertEqual(value, 2); +} + +- (void)testRetryErrorFailNoMatch { + __block counter = 1; + SSignal *s = [[[SSignal alloc] initWithGenerator:^id (SSubscriber *subscriber) { + if (counter == 1) { + counter++; + [subscriber putError:@true]; + } else { + [subscriber putError:@false]; + } + return nil; + }] retryIf:^bool(id error) { + return [error boolValue]; + }]; + + __block bool errorMatches = false; + [s startWithNext:nil error:^(id error) { + errorMatches = ![error boolValue]; + } completed:nil]; + + XCTAssert(errorMatches); +} + @end