This module implements automatic differentiation. Automatic (or algorithmic) differentiation is exact (up to floating-point arithmetic limitations) and is based on the mathematical rules of symbolic differentiation. Unlike the finite-difference or interpolation methods commonly used in numerical differentiation, it introduces no numerical approximation. Moreover, automatic differentiation can differentiate arbitrary functions, dealing automatically with conditional tests and loops:
local math = require("sci.math").generic
local diff = require("sci.diff")
local alg = require("sci.alg") -- Needed for alg.tovec() and alg.vec() below.
-- Scalar case, one variable:
local function f1(x)
  return math.sin(x)
end
local x = 1.3
-- Partial derivative, exact:
local df1 = diff.derivativef(f1, 1)
local y, dx = df1(x)
assert(y == f1(x))
assert(dx == math.cos(x))
-- Scalar case, two variables:
local function f2(x1, x2)
  return math.sin(2*x1 - x2)
end
local x1, x2 = 1.3, 0.6
-- Partial derivatives, exact:
local df2 = diff.derivativef(f2, 2)
local y, dx, dy = df2(x1, x2)
assert(y == f2(x1, x2))
assert(dx == 2*math.cos(2*x1 - x2))
assert(dy == -math.cos(2*x1 - x2))
-- Vector case, dimension 5:
local function f3(x) -- Sum of squares.
  local sum = 0
  for i=1,#x do sum = sum + x[i]^2 end
  return sum
end
local x = alg.tovec({ 1,2,3,4,5 })
-- The function df3 will compute value and gradient:
local df3 = diff.gradientf(f3, 5)
local grad = alg.vec(5) -- To store the gradient.
local y = df3(x, grad) -- Compute value and derivatives.
assert(y == f3(x)) -- Value is exact.
for i=1,5 do
  assert(grad[i] == 2*x[i]) -- Gradient is exact.
end
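For contrast with the exact derivatives above, the following minimal sketch shows the approximation error of a forward finite difference. It uses only the standard Lua math library; the helper name fd and the step size h are illustrative choices, not part of this module.

```lua
-- Hypothetical helper: forward finite difference, an approximation only.
local function fd(f, x, h)
  return (f(x + h) - f(x)) / h
end

local x, h = 1.3, 1e-4
local approx = fd(math.sin, x, h) -- Approximate derivative of sin at x.
local exact = math.cos(x)         -- Analytical derivative of sin at x.
local err = math.abs(approx - exact)

-- The error is roughly proportional to h, so it is far from machine precision.
print(string.format("finite-difference error: %.3e", err))
```

Shrinking h reduces the truncation error but amplifies floating-point cancellation, which is the trade-off automatic differentiation avoids.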
First of all, use sci.math.generic instead of sci.math (and math) everywhere the arguments of the to-be-differentiated function can propagate. For example, in the code below f1() must use gmath.exp even if we're interested in differentiating f2() only:
local math = require("sci.math")
local gmath = math.generic
local function f1(x)
  return gmath.exp(x)
end
local function f2(x)
  return f1(x)
end
Then create a new function that jointly computes the function value and the derivatives, using diff.derivativef() if the function to be differentiated takes multiple scalars as input, or diff.gradientf() if it takes a single vector as input.
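The two steps can be put together as in the sketch below. It assumes the sci library is installed and uses the same call shape as the earlier examples, diff.derivativef(f, n); the evaluation point 0.5 and the tolerance are arbitrary choices.

```lua
local diff = require("sci.diff")
local gmath = require("sci.math").generic

-- Step 1: use the generic math functions wherever the argument can propagate.
local function f1(x)
  return gmath.exp(x)
end
local function f2(x)
  return f1(x)
end

-- Step 2: build a function returning the value and the derivative.
local df2 = diff.derivativef(f2, 1)
local y, dx = df2(0.5)

-- exp is its own derivative, so value and derivative should agree.
assert(math.abs(y - math.exp(0.5)) < 1e-12)
assert(math.abs(dx - math.exp(0.5)) < 1e-12)
```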
In most cases the two steps detailed above are everything that needs to be done. Automatic differentiation is implemented here via operator and function overloading: instead of passing Lua numbers (or a vector of Lua numbers) to the function to be differentiated, we pass a custom type which carries both value and derivative information. Cases that require further changes:
Case | Fix |
---|---|
sci.alg vectors and matrices of Lua numbers are homogeneous in the element type | vectors and matrices with element type diff.dn |
in the loop for i=first,last,incr do, the expressions first, last, incr must evaluate to Lua numbers | tonumber() |
type(x) == "number" only evaluates true for Lua numbers | type(x) == "number" or ffi.istype(diff.dn, x) |
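
The type-check fix in the last row can be wrapped in a small helper, sketched below. It assumes LuaJIT (for the ffi module) and an installed sci library; the name isnumeric is hypothetical and not part of sci.diff.

```lua
local ffi = require("ffi")
local diff = require("sci.diff")

-- Hypothetical helper: true for Lua numbers and for diff.dn values.
local function isnumeric(x)
  return type(x) == "number" or ffi.istype(diff.dn, x)
end

print(isnumeric(1.5))   -- true
print(isnumeric("1.5")) -- false
```

Using such a helper in place of a bare type(x) == "number" test keeps the same branch reachable when the function is called during differentiation.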
diff.derivativef() takes as input an integer n, a function f which can accept n arguments of type diff.dn, and an optional list of distinct integers between 1 and n. It returns a function with signature df(x1, x2, ..., xn) which takes as input n Lua numbers. If no optional list of integers has been passed to diff.derivativef(), then df() returns
f(x), df/dx1, df/dx2, ..., df/dxn
where all derivatives are evaluated at x1, x2, ..., xn. Otherwise df() returns
f(x), df/dxj1, df/dxj2, ..., df/dxjm
where j1, j2, ..., jm are the distinct integers that have been passed to diff.derivativef() and all derivatives are evaluated at x1, x2, ..., xn. This allows for the computation of a limited subset of partial derivatives.
diff.gradientf() takes as input an integer n and a function f which can accept as its single argument a vector of element type diff.dn and length n. It returns a function with signature df(x, grad) which takes as input two vectors of Lua numbers of length n. The function df() sets grad to the gradient of f computed at x and returns the value of f(x).
diff.dn is the FFI ctype used to implement automatic differentiation: instead of passing Lua numbers (or a vector of Lua numbers) to the function to be differentiated, objects of this type (or a vector with this element type) are passed.
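The idea of a type that carries both value and derivative can be illustrated in plain Lua with a metatable-based dual number. This is only a sketch of the general technique (forward-mode differentiation by overloading), not the actual implementation of diff.dn; the names Dual, dual, lift, gsin and derivative are all hypothetical.

```lua
-- Hypothetical dual number: v holds the value, d the derivative.
local Dual = {}

local function dual(v, d)
  return setmetatable({ v = v, d = d }, Dual)
end

-- Treat plain Lua numbers as constants (derivative 0).
local function lift(x)
  if type(x) == "number" then return dual(x, 0) end
  return x
end

Dual.__add = function(a, b)
  a, b = lift(a), lift(b)
  return dual(a.v + b.v, a.d + b.d)             -- Sum rule.
end

Dual.__mul = function(a, b)
  a, b = lift(a), lift(b)
  return dual(a.v * b.v, a.d * b.v + a.v * b.d) -- Product rule.
end

-- Generic sine: works on Lua numbers and on dual numbers.
local function gsin(x)
  if type(x) == "number" then return math.sin(x) end
  return dual(math.sin(x.v), math.cos(x.v) * x.d) -- Chain rule.
end

-- Function with a loop and a conditional: 3*x^2, plus sin(x) when x > 0.
local function f(x)
  local acc = 0
  for i = 1, 3 do acc = acc + x * x end
  if lift(x).v > 0 then acc = acc + gsin(x) end
  return acc
end

-- Seed the input with derivative 1 and read both parts of the result.
local function derivative(g)
  return function(x)
    local r = lift(g(dual(x, 1)))
    return r.v, r.d
  end
end

local y, dy = derivative(f)(2)
print(y, dy) -- Value 3*2^2 + sin(2) and derivative 6*2 + cos(2).
```

The function f is written once and works for both plain numbers and dual numbers, which is why the generic math functions must be used wherever the differentiated argument can propagate.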