Intro to Differentiable Swift, Part 4: Differentiable Swift API Details
In Part 3, we got a feel for the AutoDiff API. Now let’s go a little deeper.
Custom Derivatives
Some functions aren’t already differentiable, but you can register custom derivatives to make them differentiable.
sin(x) is an automatically differentiable function. But let’s pretend it weren’t, and we needed to make it so:
    import Foundation
    import _Differentiation

    func sinButNotDifferentiable(x: Double) -> Double {
        return sin(x)
    }
If you’ll recall, the derivative of sin(x) is cos(x). Here’s how we tell that to the compiler: we’ll make a function that returns a tuple of two things: the regular output of the function (called value), and a pullback, which will calculate the derivative.
Remember from Part 3, the pullback defines the reverse pass for sin(x). It should receive the derivative with respect to the output, and multiply it by the derivative of sin(x), which is cos(x). This produces the derivative with respect to the input, which the pullback returns.
By doing this, the pullback automatically becomes a link in the chain of pullbacks which make up the reverse pass.*
We’ll label this function with @derivative(of:) so the compiler can recognize it properly.
    typealias D = Double // just for readability

    @derivative(of: sinButNotDifferentiable)
    func valueAndPullbackOfSin(x: D) -> (value: D, pullback: (D) -> D) {
        let regularOutput = sinButNotDifferentiable(x: x)

        // Reverse pass: scale the incoming derivative by cos(x), the derivative of sin(x).
        func pullback(derivativeToOutput: D) -> D {
            let derivativeToInput = cos(x) * derivativeToOutput
            return derivativeToInput
        }

        return (value: regularOutput, pullback: pullback)
    }
Let’s test our derivative against the built-in automatic differentiation:
    let gradAutomatic = gradient(at: 2.2, of: sin)
    let gradManual = gradient(at: 2.2, of: sinButNotDifferentiable)

    if gradManual == gradAutomatic {
        print("the manually registered derivative is correct!")
    }
    // Prints: the manually registered derivative is correct!
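To see the pullback on its own, here is a minimal sketch that calls it directly instead of going through gradient(at:of:). It assumes a toolchain that ships the _Differentiation module and uses valueWithPullback(at:of:) from that module. The definitions are repeated so the snippet stands alone, and the closure-based pullback is just a more compact variant of the nested function above.

```swift
import Foundation
import _Differentiation

func sinButNotDifferentiable(x: Double) -> Double {
    return sin(x)
}

// Same registration as before, written with a closure for brevity.
@derivative(of: sinButNotDifferentiable)
func valueAndPullbackOfSin(x: Double) -> (value: Double, pullback: (Double) -> Double) {
    return (value: sinButNotDifferentiable(x: x), pullback: { v in cos(x) * v })
}

// Ask for the value and the pullback at x = 2.2 without running the reverse pass yet.
let (value, pullback) = valueWithPullback(at: 2.2, of: sinButNotDifferentiable)

// Passing 1.0 as the derivative with respect to the output gives the plain derivative, cos(2.2).
let plainDerivative = pullback(1.0)

// Passing 3.0 mimics a later link in the chain that scaled the output by 3.
let scaledDerivative = pullback(3.0)

print(value, plainDerivative, scaledDerivative)
```

The second call is what happens inside a longer reverse pass: whatever derivative arrives from the next function in the chain is multiplied by cos(x) and handed on.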
Here are further references and examples for manually registering derivatives: https://www.tensorflow.org/swift/tutorials/custom_differentiation
Finally, structs can be made differentiable by conforming to Differentiable. For example:
    struct Thing: Differentiable {
        var a: Double = 5
        var b: Double = 5
    }

    func takesThing(thing: Thing) -> Double {
        return thing.a * 3 + thing.b * 2
    }

    let thing = Thing()
    let gradWithRespectToThing = gradient(at: thing, of: takesThing)
    print(gradWithRespectToThing)
    // Prints: TangentVector(a: 3.0, b: 2.0)
The derivatives with respect to a Thing are held in a new, automatically created type: Thing.TangentVector. Thing.TangentVector also has an a and b, and they hold the derivatives that correspond to Thing.a and Thing.b.
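As a rough illustration of how those fields might be used, the sketch below reads a and b off the tangent vector and takes one hand-written gradient descent step. The learningRate value and the update itself are assumptions for illustration, not part of the AutoDiff API, and the snippet assumes a toolchain that ships the _Differentiation module.

```swift
import _Differentiation

struct Thing: Differentiable {
    var a: Double = 5
    var b: Double = 5
}

func takesThing(thing: Thing) -> Double {
    return thing.a * 3 + thing.b * 2
}

var thing = Thing()
let grad = gradient(at: thing, of: takesThing)

// Hypothetical step size, chosen only for this example.
let learningRate = 0.1

// Move each member against its own derivative to reduce the output of takesThing.
thing.a -= learningRate * grad.a
thing.b -= learningRate * grad.b

print(thing.a, thing.b)
```

Each member of Thing is nudged by the matching member of Thing.TangentVector, which is why the two types share the same field names.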
If Thing had a var bool: Bool member along with a and b, the bool member would be left out of Thing.TangentVector, since Bool isn’t differentiable.
(Neither is Int. This is because Bool and Int values aren’t continuous; they’re discrete. Discrete values aren’t natively at home in the world of infinitesimals that is calculus.)
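If you would rather state that intent explicitly than rely on the member being skipped, you can mark it with @noDerivative. Here is a minimal sketch, assuming a toolchain that ships the _Differentiation module; ThingWithFlag and isEnabled are hypothetical names made up for this example.

```swift
import _Differentiation

struct ThingWithFlag: Differentiable {
    var a: Double = 5
    var b: Double = 5
    // Discrete member: explicitly excluded from ThingWithFlag.TangentVector.
    @noDerivative var isEnabled: Bool = true
}

func takesThingWithFlag(thing: ThingWithFlag) -> Double {
    return thing.a * 3 + thing.b * 2
}

let gradWithFlag = gradient(at: ThingWithFlag(), of: takesThingWithFlag)

// The tangent vector only carries a and b; there is no isEnabled field on it.
print(gradWithFlag.a, gradWithFlag.b)
// Prints: 3.0 2.0
```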
I hope AutoDiff makes some sense now! May you find efficient solutions to all your optimization problems! Happy Differentiating!
If you’re hungry for more, here’s another great AutoDiff tutorial. And of course, here’s the Source of Truth.
Automatic Differentiation in Swift is still in beta. You can download an Xcode toolchain with import _Differentiation included from here (you must use a toolchain under the title “Snapshots -> Trunk Development (main)”).
When the compiler starts giving you errors you don’t recognize, check out this short guide on the less mature aspects of Differentiable Swift. (Automatic Differentiation in Swift has come a long way, but there are still sharp edges. They are slowly disappearing, though!)
Automatic Differentiation in Swift exists thanks to the Differentiable Swift authors (Richard Wei, Dan Zheng, Marc Rasi, Brad Larson et al.) and the Swift Community! See the latest pull requests involving AutoDiff here.
*Note: If you decide to manually register derivatives for functions that involve nested operations, for example sin(x²), don’t forget about the chain rule!
As a quick clue, a function like sin(x²) is actually two functions, x² wrapped by sin. Or, f(g(x)) where f(x) = sin(x) and g(x) = x². In this case of nested functions, the chain rule defines the derivative: f’(g(x)) * g’(x). You might intuitively guess that it’s just f’(g’(x)), but it’s not.
The derivative of f(x), aka f’(x), is cos(x), and the derivative of g(x), aka g’(x), is 2x, so the derivative of sin(x²) is actually f’(g(x)) * g’(x), or cos(x²) * 2x.
In most cases, however, you can avoid applying the chain rule by hand: register derivatives for f(x) and g(x) separately, and the compiler will apply the chain rule for you automatically.
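Here is a minimal sketch of that approach for sin(x²), assuming a toolchain that ships the _Differentiation module. The names mySin, mySquare, and sinOfSquare are hypothetical helpers invented for this example. Each helper gets its own registered derivative, and the composed function is checked against the hand-derived cos(x²) * 2x.

```swift
import Foundation
import _Differentiation

// Hypothetical building blocks; pretend the compiler cannot differentiate them on its own.
func mySin(_ x: Double) -> Double { sin(x) }
func mySquare(_ x: Double) -> Double { x * x }

// f'(x) = cos(x)
@derivative(of: mySin)
func mySinDerivative(_ x: Double) -> (value: Double, pullback: (Double) -> Double) {
    (value: mySin(x), pullback: { v in cos(x) * v })
}

// g'(x) = 2x
@derivative(of: mySquare)
func mySquareDerivative(_ x: Double) -> (value: Double, pullback: (Double) -> Double) {
    (value: mySquare(x), pullback: { v in 2 * x * v })
}

// f(g(x)): no derivative is registered for the composition itself.
func sinOfSquare(_ x: Double) -> Double { mySin(mySquare(x)) }

let x = 1.5
let composedGradient = gradient(at: x, of: sinOfSquare)
let handDerived = cos(x * x) * 2 * x

print(abs(composedGradient - handDerived) < 1e-12)
// Prints: true
```

Neither registered pullback mentions the other function, yet the result matches the chain rule because the compiler links the two pullbacks together in the reverse pass.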