Reply to: Neural Network Back-Propagation Revisited with Ordinary Differential Equations
This article is worth reading: Neural Network Back-Propagation Revisited with Ordinary Differential Equations
I replied:
Thank you very much for this very informative article providing many links, the Python code, and the results.
According to the Owens and Filkin paper mentioned in the article, using a stiff ODE solver should give a speedup of between two and 1,000. Your results clearly show that the number of iterations is far lower than for gradient descent and all its variants. Unfortunately, the reported run times are far slower than for gradient descent. This was not expected.
The following factors could play a role in this:
- All the methods search for a local minimum (the gradient should be zero; the Hessian is neither checked nor known). They do not necessarily find a global minimum. So when these different methods are run, each probably iterates towards a different local minimum, and each method has likely done something different.
- I wonder why you used the zvode solver from scipy.integrate, which is intended for complex-valued problems. I would recommend vode, or even better lsoda. Your chosen tolerances are quite strict (atol=1e-8, rtol=1e-6), at least for the beginning. Such strict tolerances may force smaller step sizes than actually required. In particular, as the start values are random, there seems to be no compelling reason to use strict tolerances right from the start. It is also known that strict tolerances can lead the ODE code to use higher-order BDF methods, which in turn are not stable enough for highly stiff ODEs: only the BDF methods up to order 2 are A-stable. So looser tolerances such as atol=1e-4, rtol=1e-4 might show different behaviour.
- Although the resulting ODE is expected to be stiff, it might be the case that in your particular setting the system was only very mildly stiff. Your charts give some indication of that, at least in the beginning. This can be checked by simply re-running with a non-stiff code such as dopri5, again with looser tolerances.
- As can be seen in your chart, above a learning accuracy of 80% the ODE solver takes a huge number of steps. It would be interesting to know whether at this stage the ODE solver performs many Jacobian evaluations, or whether there are many rejected steps due to Newton convergence failures within the ODE code.
- scipy.integrate apparently does not offer automatic differentiation. Therefore the ODE solver must resort to numerical differencing for evaluating the Jacobian, which is very slow. So using something like Zygote might improve matters here.
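Several of these points can be tried out cheaply on a toy problem before touching the real network. The following is only a minimal sketch under stated assumptions, not the setup from the article: the quadratic loss, the matrix `H`, the target `w_star`, and the helper `run` are all hypothetical stand-ins, and it uses `scipy.integrate.solve_ivp` rather than the solver interface used in the article. It integrates the gradient flow dw/dt = -grad L(w) with loose tolerances (atol=1e-4, rtol=1e-4), once with lsoda and a finite-difference Jacobian, once with lsoda and BDF using an analytic Jacobian, and once with a non-stiff Runge-Kutta code, and then reports the step, function-evaluation, and Jacobian-evaluation counts.

```python
import numpy as np
from scipy.integrate import solve_ivp

# Toy stand-in for a training loss (an assumption, not the article's network):
# L(w) = 0.5 * (w - w_star)^T H (w - w_star).
# The eigenvalues of H span three orders of magnitude, so the gradient flow
# dw/dt = -grad L(w) is a stiff linear ODE.
rng = np.random.default_rng(0)
Q, _ = np.linalg.qr(rng.standard_normal((3, 3)))
H = Q @ np.diag([1.0, 10.0, 1000.0]) @ Q.T
w_star = np.array([1.0, -2.0, 0.5])
w0 = rng.standard_normal(3)  # random start values, as in training


def gradient_flow(t, w):
    """Right-hand side of dw/dt = -grad L(w)."""
    return -H @ (w - w_star)


def jacobian(t, w):
    """Analytic Jacobian of the right-hand side (constant for this toy loss)."""
    return -H


def run(method, **options):
    """Integrate the gradient flow on [0, 20] with loose tolerances."""
    return solve_ivp(gradient_flow, (0.0, 20.0), w0, method=method,
                     rtol=1e-4, atol=1e-4, **options)


results = {
    "LSODA, finite-difference Jacobian": run("LSODA"),
    "LSODA, analytic Jacobian": run("LSODA", jac=jacobian),
    "BDF, analytic Jacobian": run("BDF", jac=jacobian),
    "RK45 (non-stiff)": run("RK45"),
}

# Compare the work each solver needed and how close it got to the minimum.
for name, sol in results.items():
    err = np.max(np.abs(sol.y[:, -1] - w_star))
    print(f"{name:36s} steps={sol.t.size - 1:6d} nfev={sol.nfev:6d} "
          f"njev={sol.njev:3d} final error={err:.1e}")
```

If the non-stiff code needs roughly as few function evaluations as the stiff ones, the problem is at most mildly stiff. A large gap between the two lsoda runs would point to the numerically differenced Jacobian as a cost factor. For a real network the analytic Jacobian would have to come from an automatic differentiation tool, which this sketch does not cover.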
I have always wondered why the findings of Owens and Filkin were not widely adopted. Your article provides an answer, although a negative one. Taking the above points into account, I still have hope that stiff ODE solvers have great potential for machine learning with regard to performance. You already mentioned that ODE solvers provide the benefit that hyperparameters no longer need to be estimated.