Creating Your Own Custom Combine Operator
A few things to note before we dive in. For those of you unfamiliar with how Combine works, there are a few essential concepts.
Combine controls
A Publisher does what you think: it publishes data of some kind. An operator is not an explicit type (there is no Operator protocol). You can think of it as a method that takes one or more input streams and returns a new stream. Finally, there's a Subscriber, which is what receives the data. It's important to note that a Publisher only starts sending data once a Subscriber subscribes to it.
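To make those three roles concrete, here is a minimal sketch that uses only built-in Combine types. Deferred is used purely so we can observe when work actually starts. The variable names are illustrative and are not part of the article's code.

```swift
import Combine

var didStart = false

// Publisher: Deferred wraps Just and records the moment it is actually asked for data.
let publisher = Deferred { () -> Just<Int> in
    didStart = true
    return Just(21)
}

// Operator: a method on Publisher that returns a new publisher (here a Publishers.Map).
let doubled = publisher.map { $0 * 2 }

// Building the chain does no work yet.
let startedBeforeSubscribing = didStart

// Subscriber: sink subscribes, requests values, and receives them.
var received: [Int] = []
let cancellable = doubled.sink { received.append($0) }
// startedBeforeSubscribing is false; didStart is now true and received is [42].
```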
Incidentally, if you're interested in creating your own publishers based on existing ones, check out this open-source implementation of Combine. It gives great insight into how Combine works under the hood.
Steps to a custom operator
So what custom operator are we going to create? In this case, we're going to create RetryOn. Combine comes with a Retry operator that re-subscribes to its upstream publisher when it receives a failure, which restarts everything above it in the chain. You can tell it how many times to retry. Example:
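import Foundation
import Combine

let url = URL(string: "https://wwt.com")!
let cancellable = URLSession.shared.dataTaskPublisher(for: url)
    .retry(3) // on any failure, re-subscribe to the upstream up to 3 more times
    .sink(receiveCompletion: { _ in /* handle completion or the final error */ },
          receiveValue: { _ in /* do something with the result */ })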
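The following sketch makes Retry's counting visible. It assumes a made-up DemoError and wraps Fail in Deferred, so that each subscription runs the closure again and increments a counter. Fail stands in for a network call that always fails.

```swift
import Combine

enum DemoError: Error {
    case failed
}

var subscriptionCount = 0

// Deferred runs its closure on every subscription, so each retry increments the counter.
let alwaysFails = Deferred { () -> Fail<Int, DemoError> in
    subscriptionCount += 1
    return Fail(error: .failed)
}

var finalCompletion: Subscribers.Completion<DemoError>?
let cancellable = alwaysFails
    .retry(3) // re-subscribe up to 3 more times after the first failure
    .sink(receiveCompletion: { finalCompletion = $0 },
          receiveValue: { _ in })

// subscriptionCount is now 4 (the original attempt plus 3 retries),
// and finalCompletion holds .failure(.failed).
```

Because Fail completes synchronously, all four attempts happen before sink returns. With a real network publisher, the retries would happen asynchronously.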
This is great, but what if we only want to retry when a certain error occurs? Even better, what if we want to retry only on a specific kind of error, and run an additional publisher before retrying?
A good example is a network call that receives an unauthorized error. We want to use a refresh token to get a new access token, then try the original call again once. Something like this:
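let url = URL(string: "https://wwt.com")!
// API.AuthorizationError and `refresh` are app-specific. In a real app, a tryMap before retryOn would turn
// a 401 response into .unauthorized, and `refresh` (a new-access-token request) must share Output and Failure.
let cancellable = URLSession.shared.dataTaskPublisher(for: url)
    .retryOn(API.AuthorizationError.unauthorized, retries: 1, chainedPublisher: refresh)
    .sink(receiveCompletion: { _ in /* handle completion or the final error */ },
          receiveValue: { _ in /* do something with the result */ })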
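The usage above skips one detail. dataTaskPublisher(for:) fails only with URLError, so an HTTP 401 never arrives as an error by itself. Below is a minimal sketch of that missing step. It assumes a hypothetical AuthorizationError type standing in for API.AuthorizationError, plus a hypothetical validate function.

```swift
import Foundation

// Hypothetical stand-in for the article's API.AuthorizationError.
enum AuthorizationError: Error, Equatable {
    case unauthorized
}

// Hypothetical helper: unwraps the response data, throwing .unauthorized on an HTTP 401.
// Intended for use as `.tryMap(validate)` right after dataTaskPublisher(for:).
func validate(_ output: (data: Data, response: URLResponse)) throws -> Data {
    if let http = output.response as? HTTPURLResponse, http.statusCode == 401 {
        throw AuthorizationError.unauthorized
    }
    return output.data
}

let url = URL(string: "https://wwt.com")!
let unauthorized: URLResponse = HTTPURLResponse(url: url, statusCode: 401, httpVersion: nil, headerFields: nil)!
let ok: URLResponse = HTTPURLResponse(url: url, statusCode: 200, httpVersion: nil, headerFields: nil)!

let rejected = (try? validate((data: Data(), response: unauthorized))) == nil
let accepted = (try? validate((data: Data([1, 2, 3]), response: ok))) == Data([1, 2, 3])
```

In the chain, you would write .tryMap(validate) before .retryOn(AuthorizationError.unauthorized, ...). Because tryMap changes the failure type to Error, refresh would then need to be an AnyPublisher<Data, Error>.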
So, like the good TDD practitioners we are, let's start with tests. The first thing we need is some way of manipulating a subscriber from our tests.
To do that, let's create a TestPublisher that gives us access to its Subscriber.
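import Combine

/// A publisher that hands each new subscriber to a closure, so a test decides what it receives.
public class TestPublisher<Output, Failure: Error>: Publisher {
    let subscribeBody: (AnySubscriber<Output, Failure>) -> Void

    public init(_ subscribe: @escaping (AnySubscriber<Output, Failure>) -> Void) {
        self.subscribeBody = subscribe
    }

    // Called once per subscription, so the closure runs again on every (re)subscription.
    public func receive<S: Subscriber>(subscriber: S) where Failure == S.Failure, Output == S.Input {
        self.subscribeBody(AnySubscriber(subscriber))
    }
}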
Now, when we create a TestPublisher, we supply a closure that receives each subscriber and can manipulate it. This also means we can force a subscriber to be notified that the Publisher has finished or has failed with an error. This sets us up nicely to test our RetryOn operator. Let's take a look.
import XCTest
import Combine

class RetryOnTests: XCTestCase {
    var subscribers = Set<AnyCancellable>()

    func testRetryOnStartsTheChainOverIfTheErrorMatches() {
        enum Err: Error {
            case e1
            case e2
        }
        var called = 0
        let pub = TestPublisher<Int, Err> { s in
            s.receive(subscription: Subscriptions.empty)
            called += 1
            // Safety valve: finish (and send nothing else) so a broken operator can't retry forever.
            if called > 3 {
                s.receive(completion: .finished)
                return
            }
            s.receive(completion: .failure(Err.e1))
        }
        pub.retryOn(Err.e1, retries: 1)
            .sink(receiveCompletion: { _ in }, receiveValue: { _ in })
            .store(in: &subscribers)
        waitUntil(called > 0)
        // The original subscription plus one retry.
        XCTAssertEqual(called, 2)
    }

    func testRetryOnStartsTheChainOverTheSpecifiedNumberOfTimesIfTheErrorMatches() {
        enum Err: Error {
            case e1
            case e2
        }
        let attempts = UInt.random(in: 2...5)
        var called = 0
        let pub = TestPublisher<Int, Err> { s in
            s.receive(subscription: Subscriptions.empty)
            called += 1
            if called > attempts {
                s.receive(completion: .finished)
                return
            }
            s.receive(completion: .failure(Err.e1))
        }
        pub.retryOn(Err.e1, retries: attempts)
            .sink(receiveCompletion: { _ in }, receiveValue: { _ in })
            .store(in: &subscribers)
        waitUntil(called > 0)
        XCTAssertEqual(called, Int(attempts) + 1)
    }

    func testRetryOnChainsPublishersBeforeRetrying() {
        enum Err: Error {
            case e1
            case e2
        }
        var called = 0
        // Deferred re-runs its closure on every subscription, so `called` counts how often
        // the chained publisher is actually subscribed. (Just(...).setFailureType(...).tryMap
        // builds a Result.Publisher, which evaluates its closure once, up front.)
        let refresh = Deferred { () -> Just<Int> in
            called += 1
            return Just(1)
        }
        .setFailureType(to: Err.self)
        .eraseToAnyPublisher()
        Fail<Int, Err>(error: Err.e1)
            .retryOn(Err.e1, retries: 1, chainedPublisher: refresh)
            .sink(receiveCompletion: { _ in }, receiveValue: { _ in })
            .store(in: &subscribers)
        waitUntil(called > 0)
        XCTAssertEqual(called, 1)
    }

    func testRetryOnDoesNotRetryIfErrorDoesNotMatch() {
        enum Err: Error {
            case e1
            case e2
        }
        var called = 0
        // Counts subscriptions to the failing upstream. An unwanted retry would make this 2.
        Deferred { () -> Fail<Int, Err> in
            called += 1
            return Fail(error: Err.e1)
        }
        .retryOn(Err.e2, retries: 1)
        .sink(receiveCompletion: { _ in }, receiveValue: { _ in })
        .store(in: &subscribers)
        waitUntil(called > 0)
        XCTAssertEqual(called, 1)
    }
}
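The tests call waitUntil(_:), which the article doesn't define. Below is a minimal sketch of what such a helper could look like. The name and argument shape are inferred from how the tests call it; the timeout parameter and its default are assumptions.

```swift
import Foundation

/// Hypothetical test helper: spins the current run loop until `condition` is true
/// or `timeout` seconds have passed, then reports whether the condition held.
@discardableResult
func waitUntil(_ condition: @autoclosure () -> Bool, timeout: TimeInterval = 1.0) -> Bool {
    let deadline = Date().addingTimeInterval(timeout)
    while !condition() && Date() < deadline {
        // Give queued run-loop work (for example, main-queue blocks) a chance to run.
        RunLoop.current.run(until: Date().addingTimeInterval(0.01))
    }
    return condition()
}
```

Taking the condition as an @autoclosure is what lets a call like waitUntil(called > 0) re-evaluate the expression on every pass. With the synchronous publishers used in these tests, the condition is already true on the first check.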
These tests should cover our use cases nicely. Now let's write the implementation.
First, we need to create our RetryOn publisher. Notice that we extend Publishers (plural) here, because that's the namespace where Combine's operator publishers live, such as Publishers.Map and Publishers.Retry.
extension Publishers {
    /// A publisher that attempts to recreate its subscription to a failed upstream publisher
    /// when the failure matches a specific error.
    struct RetryOn<Upstream: Publisher, ErrorType: Error & Equatable>: Publisher {
        typealias Output = Upstream.Output
        typealias Failure = Upstream.Failure

        let upstream: Upstream
        let retries: UInt
        let error: ErrorType
        let chainedPublisher: AnyPublisher<Output, Failure>?

        /// Creates a publisher that attempts to recreate its subscription to a failed upstream publisher.
        ///
        /// - Parameters:
        ///   - upstream: The publisher from which this publisher receives its elements.
        ///   - error: An equatable error that should trigger the retry.
        ///   - retries: The number of times to attempt to recreate the subscription.
        ///   - chainedPublisher: An optional publisher of the same type, to chain before the retry.
        init(upstream: Upstream, retries: UInt, error: ErrorType, chainedPublisher: AnyPublisher<Output, Failure>?) {
            self.upstream = upstream
            self.retries = retries
            self.error = error
            self.chainedPublisher = chainedPublisher
        }

        func receive<S: Subscriber>(subscriber: S) where Upstream.Failure == S.Failure, Upstream.Output == S.Input {
            self.upstream
                .catch { e -> AnyPublisher<Output, Failure> in
                    guard (e as? ErrorType) == self.error,
                          self.retries > 0 else {
                        // Not the error we retry on, or no retries left: pass the failure downstream unchanged.
                        return Fail<Output, Failure>(error: e).eraseToAnyPublisher()
                    }
                    if let chainedPublisher = self.chainedPublisher {
                        // Run the chained publisher first; each value it emits triggers a retry.
                        // The recursive call omits chainedPublisher, so it only runs before the first retry.
                        return chainedPublisher.flatMap { _ -> AnyPublisher<Output, Failure> in
                            self.upstream.retryOn(self.error, retries: self.retries - 1).eraseToAnyPublisher()
                        }.eraseToAnyPublisher()
                    }
                    // No chained publisher: retry immediately, with one fewer retry remaining.
                    return self.upstream.retryOn(self.error, retries: self.retries - 1).eraseToAnyPublisher()
                }
                .subscribe(subscriber)
        }
    }
}
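Our RetryOn type sits in the same namespace as Combine's built-in operator publishers. This sketch checks the concrete types that a few built-in operators return, including catch, which RetryOn builds on. DemoError is made up for the example. PassthroughSubject is used because some publishers, such as Just, have specialized overloads that return different types.

```swift
import Combine

enum DemoError: Error {
    case failed
}

let subject = PassthroughSubject<Int, DemoError>()

// Each operator wraps its upstream in a struct nested in the Publishers namespace.
let mapped = subject.map { $0 * 2 }
let retried = subject.retry(3)
let caught = subject.catch { _ in Just(0) }

let mappedIsNested = type(of: mapped) == Publishers.Map<PassthroughSubject<Int, DemoError>, Int>.self
let retriedIsNested = type(of: retried) == Publishers.Retry<PassthroughSubject<Int, DemoError>>.self
let caughtIsNested = type(of: caught) == Publishers.Catch<PassthroughSubject<Int, DemoError>, Just<Int>>.self
```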
Next, let's create the operator itself. Notice that this extension is on Publisher (singular), because that's where operator methods live.
extension Publisher {
    /// Attempts to recreate a failed subscription with the upstream publisher when it fails with a specific error.
    ///
    /// After exceeding the specified number of retries, or on any other error, the publisher passes the failure to the downstream receiver.
    /// - Parameter error: An equatable error that should trigger the retry.
    /// - Parameter retries: The number of times to attempt to recreate the subscription.
    /// - Parameter chainedPublisher: An optional publisher of the same type, to chain before the retry.
    /// - Returns: A publisher that attempts to recreate its subscription to a failed upstream publisher.
    func retryOn<E: Error & Equatable>(_ error: E, retries: UInt, chainedPublisher: AnyPublisher<Output, Failure>? = nil) -> Publishers.RetryOn<Self, E> {
        return .init(upstream: self,
                     retries: retries,
                     error: error,
                     chainedPublisher: chainedPublisher)
    }
}
We added documentation comments here so that retryOn feels like a natural Combine operator rather than something specific to our use case. One could imagine quite a few use cases where it'd be nice to retry on a specific error and chain an additional publisher first.
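If you don't need a named publisher type, you can compose the same behavior from existing operators and type-erase the result. The sketch below is a hypothetical alternative (retryOnErased is a made-up name), not the article's implementation. It mirrors RetryOn's matching rule and, like RetryOn, runs the chained publisher only before the first retry. TokenError is also made up for the example.

```swift
import Combine

extension Publisher {
    /// Hypothetical alternative to retryOn built only from catch, flatMap, and type erasure.
    func retryOnErased<E: Error & Equatable>(
        _ error: E,
        retries: UInt,
        chainedPublisher: AnyPublisher<Output, Failure>? = nil
    ) -> AnyPublisher<Output, Failure> {
        self.catch { e -> AnyPublisher<Output, Failure> in
            // Pass through errors that don't match, or any error once retries run out.
            guard (e as? E) == error, retries > 0 else {
                return Fail<Output, Failure>(error: e).eraseToAnyPublisher()
            }
            // Later retries don't re-run the chained publisher.
            let retry = self.retryOnErased(error, retries: retries - 1)
            guard let chained = chainedPublisher else { return retry }
            // Run the chained publisher first, then retry once per value it emits.
            return chained.flatMap { _ in retry }.eraseToAnyPublisher()
        }
        .eraseToAnyPublisher()
    }
}

// Usage: an upstream that always fails with .expired, with counters for each subscription.
enum TokenError: Error {
    case expired, other
}

var attempts = 0
var refreshes = 0
let upstream = Deferred { () -> Fail<String, TokenError> in
    attempts += 1
    return Fail(error: .expired)
}
let refresh = Deferred { () -> Just<String> in
    refreshes += 1
    return Just("new-token")
}
.setFailureType(to: TokenError.self)
.eraseToAnyPublisher()

var completion: Subscribers.Completion<TokenError>?
let cancellable = upstream
    .retryOnErased(TokenError.expired, retries: 1, chainedPublisher: refresh)
    .sink(receiveCompletion: { completion = $0 }, receiveValue: { _ in })
// attempts is 2, refreshes is 1, and completion is .failure(.expired).
```

The trade-off is mostly about presentation. AnyPublisher hides the concrete type, while a dedicated Publishers.RetryOn type reads like the built-in operators and shows up by name in type signatures.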
Mastered the concept of creating Combine operators? Take a look at the next steps in the series: Creating a Highly Testable Networking Layer in Combine.