10a: Simple Example comparing Different Optimizers#
Problem formulation#
We’ll consider a very simple setup to clearly illustrate what’s happening: we have a two-dimensional input space $\mathbf{x} \in \mathbb{R}^2$, a "physics" function $\mathcal{P}$ that maps it to a two-dimensional output $\mathbf{y} = \mathcal{P}(\mathbf{x})$, and a scalar loss $L(\mathbf{y})$.
Specifically, we’ll use the following $\mathcal{P}$ and $L$:
$$ \mathcal{P}(\mathbf{x}) = \begin{bmatrix} x_0 \\ x_1^2 \end{bmatrix} , \qquad L(\mathbf{y}) = |\mathbf{y}|_2^2 = y_0^2 + y_1^2 , $$
i.e. the first component is passed through unchanged, while the second one is squared before the simple squared-norm loss is applied.
For us as humans it’s quite obvious that $\mathbf{x} = (0,0)$ minimizes this loss, but the optimizers below have to find this point iteratively, starting from an initial guess.
3 Spaces#
In order to understand the following examples, it’s important to keep in mind that we’re dealing with mappings between the three spaces we’ve introduced here: the input space $\mathbf{x}$, the intermediate "physics" space $\mathbf{y}$, and the scalar loss $L$.
We’re targeting inverse problems, i.e. we want to retrieve an entry in $\mathbf{x}$ that minimizes $L$, while the $\mathbf{y}$ space acts as a latent space in between.
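Schematically, the composition we optimize through looks as follows (just a compact restatement of the setup above):
$$ \mathbf{x} \in \mathbb{R}^2 \ \xrightarrow{\ \mathcal{P}\ }\ \mathbf{y} \in \mathbb{R}^2 \ \xrightarrow{\ L\ }\ \mathbb{R} . $$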
Implementation#
For this example we’ll use the JAX framework, which represents a nice alternative for efficiently working with differentiable functions.
JAX also has a nice numpy wrapper that implements most of numpy’s functions. Below we’ll use this wrapper as np, and the original numpy as onp.
import jax
import jax.numpy as np
import numpy as onp
We’ll start by defining the physics function y and the loss L, together with the composite function loss_x, which calls both. Having a single native python function is necessary for many of the JAX operations.
# "physics" function y
def physics_y(x):
return np.array( [x[0], x[1]*x[1]] )
# simple L2 loss
def loss_y(y):
#return y[0]*y[0] + y[1]*y[1] # "manual version"
return np.sum( np.square(y) )
# composite function with L & y , evaluating the loss for x
def loss_x(x):
return loss_y(physics_y(x))
x = np.asarray([3,3], dtype=np.float32)
print("Starting point x = "+format(x) +"\n")
print("Some test calls of the functions we defined so far, from top to bottom, y, manual L(y), L(y):")
physics_y(x) , loss_y( physics_y(x) ), loss_x(x)
Starting point x = [3. 3.]
Some test calls of the functions we defined so far, from top to bottom, y, manual L(y), L(y):
(DeviceArray([3., 9.], dtype=float32),
DeviceArray(90., dtype=float32),
DeviceArray(90., dtype=float32))
Now we can evaluate the derivatives of our functions via jax.grad. E.g., jax.grad(loss_y)(physics_y(x)) evaluates the Jacobian $\partial L / \partial \mathbf{y}$ of the loss with respect to the intermediate output $\mathbf{y}$.
# this works:
print("Jacobian L(y): " + format(jax.grad(loss_y)(physics_y(x))) +"\n")
# the following would give an error as y (and hence physics_y) is not scalar
#jax.grad(physics_y)(x)
# computing the jacobian of y is a valid operation:
J = jax.jacobian(physics_y)(x)
print( "Jacobian y(x): \n" + format(J) )
# the code below also gives error, JAX grad needs a single function object
#jax.grad( loss_y(physics_y) )(x)
print( "\nSanity check with inverse Jacobian of y, this should give x again: " + format(np.linalg.solve(J, np.matmul(J,x) )) +"\n")
# instead use the composite 'loss_x' from above
print("Gradient for full L(x): " + format( jax.grad(loss_x)(x) ) +"\n")
Jacobian L(y): [ 6. 18.]
Jacobian y(x):
[[1. 0.]
[0. 6.]]
Sanity check with inverse Jacobian of y, this should give x again: [3. 3.]
Gradient for full L(x): [ 6. 108.]
The last line is worth a closer look: here we print the gradient $\partial L / \partial \mathbf{x}$ of the full chain. Even though both components of $\mathbf{x}$ are equally far from the optimum at the origin, the gradient $[6, 108]$ is heavily skewed towards the second component, caused by the squaring in the physics function.
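To see where these numbers come from, we can write out the composite loss and its gradient by hand (a quick derivation based on the functions defined above):
$$ L(\mathbf{x}) = x_0^2 + x_1^4 , \qquad \frac{\partial L}{\partial \mathbf{x}} = \begin{bmatrix} 2 x_0 \\ 4 x_1^3 \end{bmatrix} , \qquad \frac{\partial L}{\partial \mathbf{x}} \Big|_{\mathbf{x}=(3,3)} = \begin{bmatrix} 6 \\ 108 \end{bmatrix} , $$
which matches the output of jax.grad above.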
Let’s see how the different methods cope with this situation. We’ll compare
the first order method gradient descent (i.e., regular, non-stochastic, “steepest gradient descent”),
Newton’s method as a representative of the second order methods,
and scale-invariant updates from inverse simulators.
Gradient descent#
For gradient descent, the simple gradient based update from equation (2) in our setting gives the following update step in $\mathbf{x}$:
$$ \Delta \mathbf{x} = -\eta \, \frac{\partial L}{\partial \mathbf{x}} , $$
where $\eta$ denotes the learning rate.
Let’s start the optimization via gradient descent at $\mathbf{x} = (3,3)$ with a small learning rate of $\eta = 0.01$:
x = np.asarray([3.,3.])
eta = 0.01
historyGD = [x]; updatesGD = []

for i in range(10):
    G = jax.grad(loss_x)(x)
    x += -eta * G
    historyGD.append(x); updatesGD.append(G)
    print( "GD iter %d: "%i + format(x) )
GD iter 0: [2.94 1.9200001]
GD iter 1: [2.8812 1.6368846]
GD iter 2: [2.823576 1.4614503]
GD iter 3: [2.7671044 1.3365935]
GD iter 4: [2.7117622 1.2410815]
GD iter 5: [2.657527 1.1646168]
GD iter 6: [2.6043763 1.1014326]
GD iter 7: [2.5522888 1.0479842]
GD iter 8: [2.501243 1.0019454]
GD iter 9: [2.4512184 0.96171147]
Here we’ve already printed the resulting positions in $\mathbf{x}$ space for every iteration: the second component initially shrinks much faster than the first one, but progress slows down noticeably over the later iterations.
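As a quick sanity check, the first GD step can be reproduced by hand with the gradient computed earlier:
$$ \mathbf{x}^{(1)} = \mathbf{x}^{(0)} - \eta \, \frac{\partial L}{\partial \mathbf{x}} = \begin{bmatrix} 3 \\ 3 \end{bmatrix} - 0.01 \begin{bmatrix} 6 \\ 108 \end{bmatrix} = \begin{bmatrix} 2.94 \\ 1.92 \end{bmatrix} , $$
in agreement with the first printed iteration.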
Let’s take a look at the progression over the course of the iterations (the evolution was stored in the historyGD list above). The blue points denote the positions in $\mathbf{x}$ space, and the black cross marks the target at the origin.
import matplotlib.pyplot as plt
axes = plt.figure(figsize=(4, 4), dpi=100).gca()
historyGD = onp.asarray(historyGD)
updatesGD = onp.asarray(updatesGD) # for later
axes.scatter(historyGD[:,0], historyGD[:,1], lw=0.5, color='#1F77B4', label='GD')
axes.scatter([0], [0], lw=0.25, color='black', marker='x') # target at 0,0
axes.set_xlabel('x0'); axes.set_ylabel('x1'); axes.legend()
[Figure: GD trajectory in x space (x0 vs x1), target marked at the origin]
No surprise here: the initial step mostly moves downwards along the second component $x_1$, where the gradient is large, and afterwards the trajectory only slowly creeps towards the origin. It stays far from the ideal, diagonal path to the solution.
Newton#
For Newton’s method, the update step is given by
$$ \Delta \mathbf{x} = -\eta \, \left( \frac{\partial^2 L}{\partial \mathbf{x}^2} \right)^{-1} \frac{\partial L}{\partial \mathbf{x}} . $$
Hence, in addition to the same gradient as for GD, we now need to evaluate and invert the Hessian of $L$ with respect to $\mathbf{x}$.
This is quite straightforward in JAX: we can call jax.jacobian two times, and then use the JAX version of linalg.inv to invert the resulting matrix.
For the optimization with Newton’s method we’ll use a larger step size of $\eta = 1/3$. In the next cell, we apply the Newton updates ten times starting from the same initial guess:
x = np.asarray([3.,3.])
eta = 1./3.
historyNt = [x]; updatesNt = []

Gx = jax.grad(loss_x)
Hx = jax.jacobian(jax.jacobian(loss_x))

for i in range(10):
    g = Gx(x)
    h = Hx(x)
    hinv = np.linalg.inv(h)
    x += -eta * np.matmul( hinv , g )
    historyNt.append(x); updatesNt.append( np.matmul( hinv , g) )
    print( "Newton iter %d: "%i + format(x) )
Newton iter 0: [2. 2.6666667]
Newton iter 1: [1.3333333 2.3703704]
Newton iter 2: [0.88888884 2.1069958 ]
Newton iter 3: [0.59259254 1.8728852 ]
Newton iter 4: [0.39506167 1.6647868 ]
Newton iter 5: [0.26337445 1.4798105 ]
Newton iter 6: [0.17558296 1.315387 ]
Newton iter 7: [0.1170553 1.1692328]
Newton iter 8: [0.07803687 1.0393181 ]
Newton iter 9: [0.05202458 0.92383826]
The last line already indicates that Newton’s method does quite a bit better: its final point $(0.052, 0.924)$ is considerably closer to the origin than the final GD position $(2.45, 0.96)$, especially in the first component.
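As before, the first update can be verified by hand. The Hessian of the composite loss $L(\mathbf{x}) = x_0^2 + x_1^4$ is diagonal, which makes the Newton step easy to write out:
$$ \frac{\partial^2 L}{\partial \mathbf{x}^2} = \begin{bmatrix} 2 & 0 \\ 0 & 12 x_1^2 \end{bmatrix} , \qquad \mathbf{x}^{(1)} = \begin{bmatrix} 3 \\ 3 \end{bmatrix} - \frac{1}{3} \begin{bmatrix} 2 & 0 \\ 0 & 108 \end{bmatrix}^{-1} \begin{bmatrix} 6 \\ 108 \end{bmatrix} = \begin{bmatrix} 3 \\ 3 \end{bmatrix} - \frac{1}{3} \begin{bmatrix} 3 \\ 1 \end{bmatrix} = \begin{bmatrix} 2 \\ 2.667 \end{bmatrix} , $$
matching the first printed Newton iteration.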
Below, we plot the Newton trajectory in orange next to the GD version in blue.
axes = plt.figure(figsize=(4, 4), dpi=100).gca()
historyNt = onp.asarray(historyNt)
updatesNt = onp.asarray(updatesNt)
axes.scatter(historyGD[:,0], historyGD[:,1], lw=0.5, color='#1F77B4', label='GD')
axes.scatter(historyNt[:,0], historyNt[:,1], lw=0.5, color='#FF7F0E', label='Newton')
axes.scatter([0], [0], lw=0.25, color='black', marker='x') # target at 0,0
axes.set_xlabel('x0'); axes.set_ylabel('x1'); axes.legend()
[Figure: x-space trajectories of GD (blue) and Newton (orange), target at the origin]
Not completely surprising: for this simple example we can reliably evaluate the Hessian, and Newton’s method profits from the second order information. Its trajectory is much more diagonal (that would be the ideal, shortest path to the solution), and it does not slow down as much as GD.
Inverse simulators#
Now we also use an analytical inverse of $\mathcal{P}$ to compute the update: a Newton step is taken for the loss in $\mathbf{y}$ space, and the inverse physics function then maps this new position back to $\mathbf{x}$. This is the scale-invariant update from the inverse simulator, labeled PG below.
We first define our inverse function physics_y_inv, and then evaluate an optimization with the PG update for ten steps:
x = np.asarray([3.,3.])
eta = 0.3
historyPG = [x]; historyPGy = []; updatesPG = []

def physics_y_inv(y):
    return np.array( [y[0], np.power(y[1],0.5)] )

Gy = jax.grad(loss_y)
Hy = jax.jacobian(jax.jacobian(loss_y))

for i in range(10):
    # Newton step for L(y)
    zForw = physics_y(x)
    g = Gy(zForw)
    h = Hy(zForw)
    hinv = np.linalg.inv(h)

    # step in y space
    zBack = zForw -eta * np.matmul( hinv , g)
    historyPGy.append(zBack)

    # "inverse physics" step via y-inverse
    x = physics_y_inv(zBack)
    historyPG.append(x)
    updatesPG.append( historyPG[-2] - historyPG[-1] )
    print( "PG iter %d: "%i + format(x) )
PG iter 0: [2.1 2.5099802]
PG iter 1: [1.4699999 2.1000001]
PG iter 2: [1.0289999 1.7569861]
PG iter 3: [0.72029996 1.47 ]
PG iter 4: [0.50421 1.2298902]
PG iter 5: [0.352947 1.029 ]
PG iter 6: [0.24706289 0.86092323]
PG iter 7: [0.17294402 0.7203 ]
PG iter 8: [0.12106082 0.60264623]
PG iter 9: [0.08474258 0.50421 ]
Now we obtain $(0.085, 0.504)$ as the final position, which is again closer to the origin than the result of Newton’s method.
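Again, the first PG step can be checked by hand. In $\mathbf{y}$ space the loss is a plain squared norm, so the Newton step simply rescales the current $\mathbf{y}$, and the analytic inverse maps the result back to $\mathbf{x}$:
$$ \mathbf{y}^{(0)} = \mathcal{P}(3,3) = \begin{bmatrix} 3 \\ 9 \end{bmatrix} , \qquad \tilde{\mathbf{y}} = \mathbf{y}^{(0)} - \eta \left( \frac{\partial^2 L}{\partial \mathbf{y}^2} \right)^{-1} \frac{\partial L}{\partial \mathbf{y}} = (1 - \eta) \, \mathbf{y}^{(0)} = \begin{bmatrix} 2.1 \\ 6.3 \end{bmatrix} , $$
$$ \mathbf{x}^{(1)} = \mathcal{P}^{-1}(\tilde{\mathbf{y}}) = \begin{bmatrix} 2.1 \\ \sqrt{6.3} \end{bmatrix} \approx \begin{bmatrix} 2.1 \\ 2.51 \end{bmatrix} , $$
matching the first printed PG iteration.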
Let’s directly visualize how the PGs (in red) fare in comparison to Newton’s method (orange) and GD (blue).
historyPG = onp.asarray(historyPG)
updatesPG = onp.asarray(updatesPG)
axes = plt.figure(figsize=(4, 4), dpi=100).gca()
axes.scatter(historyGD[:,0], historyGD[:,1], lw=0.5, color='#1F77B4', label='GD')
axes.scatter(historyNt[:,0], historyNt[:,1], lw=0.5, color='#FF7F0E', label='Newton')
axes.scatter(historyPG[:,0], historyPG[:,1], lw=0.5, color='#D62728', label='PG')
axes.scatter([0], [0], lw=0.25, color='black', marker='x') # target at 0,0
axes.set_xlabel('x0'); axes.set_ylabel('x1'); axes.legend()
[Figure: x-space trajectories of GD (blue), Newton (orange) and PG (red), target at the origin]
This illustrates that the inverse simulator variant, PG in red, does even better than Newton’s method in orange. It yields a trajectory that is better aligned with the ideal diagonal trajectory, and its final state is closer to the origin. A key ingredient here is the inverse function for $\mathbf{y}$, i.e. physics_y_inv: it provides exact, higher-order information about the physics instead of a local approximation.
This difference also shows in the first update step of each method: below we measure how well it is aligned with the diagonal direction.
def mag(x):
    return np.sqrt(np.sum(np.square(x)))

def one_len(x):
    return np.dot( x/mag(x), np.array([1,1]))

print("Diagonal lengths (larger is better): GD %f, Nt %f, PG %f " %
      (one_len(updatesGD[0]) , one_len(updatesNt[0]) , one_len(updatesPG[0])) )
Diagonal lengths (larger is better): GD 1.053930, Nt 1.264911, PG 1.356443
The largest value of 1.356 for PG confirms what we’ve seen above: the PG gradient was the closest one to the diagonal direction from our starting point to the origin.
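For reference, the largest possible value of this metric is $\sqrt{2} \approx 1.414$, reached by an exactly diagonal update. A minimal check (an extra cell that is not part of the original comparison, reusing one_len from above):
# hypothetical extra check: an exactly diagonal update direction
# gives the maximum possible value of sqrt(2) ~ 1.414
print("Exactly diagonal update: %f" % one_len(np.array([1.0, 1.0])))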
y Space#
To understand the behavior and the differences between the methods here, it’s important to keep in mind that we’re not dealing with a black box that maps directly from $\mathbf{x}$ to $L$, but with a composition of functions that passes through the intermediate $\mathbf{y}$ space.
A first thing to note is that for PG, we explicitly compute positions in $\mathbf{y}$ space as part of each update (the result of the Newton step for $L(\mathbf{y})$), and we stored them in the historyPGy list above.
Let’s directly take a look at what the inverse simulator did in $\mathbf{y}$ space:
historyPGy = onp.asarray(historyPGy)
axes = plt.figure(figsize=(4, 4), dpi=100).gca()
axes.set_title('y space')
axes.scatter(historyPGy[:,0], historyPGy[:,1], lw=0.5, color='#D62728', marker='*', label='PG')
axes.scatter([0], [0], lw=0.25, color='black', marker='*')
axes.set_xlabel('z0'); axes.set_ylabel('z1'); axes.legend()
[Figure: PG trajectory in y space (z0 vs z1), target at the origin]
With this variant, we’re making explicit steps in $\mathbf{y}$ space: since $L$ is a simple squared norm, each Newton step in $\mathbf{y}$ merely rescales the current position (by a factor of $1-\eta = 0.7$ per step), and hence the red points march along a perfectly straight line towards the origin in $\mathbf{y}$ space.
Interestingly, neither GD nor Newton’s method give us explicit information about progress in intermediate spaces (like the $\mathbf{y}$ space here).
For GD we’re concatenating the Jacobians, so we’re moving in directions that locally should decrease the loss. However, the intermediate $\mathbf{y}$ positions are never part of the update; they only follow implicitly once $\mathcal{P}$ is evaluated again.
More specifically, the GD update chains the Jacobians via
$$ \Delta \mathbf{x} = -\eta \, \frac{\partial L}{\partial \mathbf{x}} = -\eta \, \frac{\partial L}{\partial \mathbf{y}} \frac{\partial \mathbf{y}}{\partial \mathbf{x}} , $$
and nothing in this update guarantees a good trajectory in $\mathbf{y}$ space itself.
Newton’s method does not fare much better: we compute first-order derivatives as for GD, plus the second-order derivatives for the Hessian of the full process. But since both are local approximations, the actual intermediate states resulting from an update step remain unknown until the full chain is evaluated again (see also the Consistency in function compositions paragraph for Newton’s method in physgrad).
With inverse simulators we do not have this problem: they can directly map points in $\mathbf{y}$ space back to $\mathbf{x}$, so the intermediate state is an explicit part of every update.
In the simple setting of this section we only have a single latent space, the $\mathbf{y}$ space, and we already stored all positions in $\mathbf{x}$ space (in the history lists above). Hence, we can now go back and re-evaluate physics_y for the GD and Newton trajectories to obtain the corresponding positions in $\mathbf{y}$ space:
x = np.asarray([3.,3.])
eta = 0.01
historyGDy = []
historyNty = []

for i in range(1,10):
    historyGDy.append(physics_y(historyGD[i]))
    historyNty.append(physics_y(historyNt[i]))
historyGDy = onp.asarray(historyGDy)
historyNty = onp.asarray(historyNty)
axes = plt.figure(figsize=(4, 4), dpi=100).gca()
axes.set_title('y space')
axes.scatter(historyGDy[:,0], historyGDy[:,1], lw=0.5, marker='*', color='#1F77B4', label='GD')
axes.scatter(historyNty[:,0], historyNty[:,1], lw=0.5, marker='*', color='#FF7F0E', label='Newton')
axes.scatter(historyPGy[:,0], historyPGy[:,1], lw=0.5, marker='*', color='#D62728', label='PG')
axes.scatter([0], [0], lw=0.25, color='black', marker='*')
axes.set_xlabel('z0'); axes.set_ylabel('z1'); axes.legend()
[Figure: y-space trajectories of GD (blue), Newton (orange) and PG (red), target at the origin]
These trajectories confirm the intuition outlined in the previous sections: GD in blue yields a strongly distorted, sub-optimal trajectory in $\mathbf{y}$ space, Newton’s method in orange fares better but is still curved, while the PG trajectory in red heads straight for the target.
The behavior in intermediate spaces becomes especially important when they’re not only abstract latent spaces as in this example, but when they have actual physical meanings.
Conclusions#
Despite its simplicity, this example already shows surprisingly large differences between gradient descent, Newton’s method, and using the inverse simulator.
The main takeaways of this section are the following.
GD easily yields “unbalanced” updates, and gets stuck.
Newton’s method does better, but is still far from optimal.
The higher-order information of the inverse simulator outperforms both, even if it is applied only partially (we still used Newton’s method for $L$ above).
Also, the choice of method (and of optimizer in general) strongly affects the progress in latent spaces, as shown for the $\mathbf{y}$ space above.
In the next sections we can build on these observations to use PGs for training NNs via invertible physical models.
Approximate inversions#
If an analytic inverse like physics_y_inv above is not readily available, we can actually resort to optimization schemes like Newton’s method or BFGS to obtain a local inverse numerically. This is a topic that is orthogonal to the comparison of different optimization methods, but it can easily be illustrated based on the inverse simulator variant from above.
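Concretely, the numerical inverse used below solves a small optimization problem for every evaluation (a sketch of what the BFGS call computes, given a target $\mathbf{y}^*$ and the current $\mathbf{x}$ as starting guess):
$$ \mathcal{P}^{-1}_{\text{approx}}(\mathbf{y}^*) \;\approx\; \arg\min_{\mathbf{x}} \, \big| \mathcal{P}(\mathbf{x}) - \mathbf{y}^* \big|_2^2 . $$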
Below, we’ll use the BFGS variant fmin_l_bfgs_b from scipy to compute the inverse. It’s not very complicated, but we’ll use numpy and scipy directly here, which makes the code a bit messier than it should be.
def physics_y_inv_opt(target_y, x_ini):
    # a bit ugly, we switch to pure scipy here inside each iteration for BFGS
    import numpy as np
    from scipy.optimize import fmin_l_bfgs_b

    target_y = onp.array(target_y)
    x_ini    = onp.array(x_ini)

    def physics_y_opt(x, target_y=[2,2]):
        y = onp.array( [x[0], x[1]*x[1]] ) # we cant use physics_y from JAX here
        ret = onp.sum( onp.square(y-target_y) )
        return ret

    ret = fmin_l_bfgs_b(lambda x: physics_y_opt(x,target_y), x_ini, approx_grad=True )
    #print( ret ) # return full BFGS details
    return ret[0]
print("BFGS optimization test run, find x such that y=[2,2]:")
physics_y_inv_opt([2,2], [3,3])
BFGS optimization test run, find x such that y=[2,2]:
array([2.00000003, 1.41421353])
Nonetheless, we can now use this numerically inverted physics_y_inv_opt in place of the analytic inverse; the rest of the code is unchanged.
x = np.asarray([3.,3.])
eta = 0.3
history = [x]; updates = []

Gy = jax.grad(loss_y)
Hy = jax.jacobian(jax.jacobian(loss_y))

for i in range(10):
    # same as before, Newton step for L(y)
    y = physics_y(x)
    g = Gy(y)
    y += -eta * np.matmul( np.linalg.inv( Hy(y) ) , g)

    # optimize for inverse physics, assuming we dont have access to an inverse for physics_y
    x = physics_y_inv_opt(y,x)
    history.append(x)
    updates.append( history[-2] - history[-1] )
    print( "PG iter %d: "%i + format(x) )
PG iter 0: [2.09999967 2.50998022]
PG iter 1: [1.46999859 2.10000011]
PG iter 2: [1.02899871 1.75698602]
PG iter 3: [0.72029824 1.4699998 ]
PG iter 4: [0.50420733 1.22988982]
PG iter 5: [0.35294448 1.02899957]
PG iter 6: [0.24705997 0.86092355]
PG iter 7: [0.17294205 0.72030026]
PG iter 8: [0.12106103 0.60264817]
PG iter 9: [0.08474171 0.50421247]
This confirms that the approximate inversion works, in line with the regular PG version above. There’s not much point plotting this, as it’s basically the same, but let’s measure the difference. Below, we compute the MAE, which for this simple example turns out to be on the order of our floating point accuracy.
historyPGa = onp.asarray(history)
updatesPGa = onp.asarray(updates)
print("MAE difference between analytic PG and approximate inversion: %f" % (np.average(np.abs(historyPGa-historyPG))) )
MAE difference between analytic PG and approximate inversion: 0.000001