Description
There are mismatched arguments in `problems.ODEProblem.odeint`.
My torchdyn version is 1.0.3.
Steps to Reproduce
I wanted to see how many steps the adaptive `dopri5` solver takes, so I looked for the `return_all_eval` argument mentioned in issue #131. The `NeuralODE` class does not expose such a keyword argument, so after digging into the source code a little I decided to pass `args={'return_all_eval': True}`. However, this still does not give the desired result. The code snippet is:
```python
from torchdyn.core import NeuralODE
import torch
import torch.nn as nn

class VectorField(nn.Module):
    def __init__(self):
        super(VectorField, self).__init__()
        self.net = nn.Linear(2, 2)

    def forward(self, t, x):
        print(f"In VectorField, t is fed as {t}")
        return self.net(t + x)

vf = VectorField()
ode = NeuralODE(vf, solver="dopri5", sensitivity="adjoint", atol=1e-4, rtol=1e-4)
time = torch.linspace(0, 1, 10)
initial = torch.randn(16, 20, 2)
eval_time, sol = ode(initial, time, args={'return_all_eval': True})
print(sol.shape)
```

Then I found that the `return_all_eval` keyword is not actually passed into the `numerics.odeint.odeint` function. The signature of that function is:
```python
def odeint(f:Callable, x:Tensor, t_span:Union[List, Tensor], solver:Union[str, nn.Module], atol:float=1e-3, rtol:float=1e-3,
           t_stops:Union[List, Tensor, None]=None, verbose:bool=False, interpolator:Union[str, Callable, None]=None, return_all_eval:bool=False,
           save_at:Union[List, Tensor]=(), args:Dict={}, seminorm:Tuple[bool, Union[int, None]]=(False, None)) -> Tuple[Tensor, Tensor]:
    ...  # body omitted
```

So `return_all_eval` is an explicit parameter of `odeint`, but in `numerics.sensitivity._gather_odefunc_adjoint._ODEProblemFunc.forward` it is hard-coded as `False`:
```python
def forward(ctx, vf_params, x, t_span, B=None, save_at=()):
    t_sol, sol = generic_odeint(problem_type, vf, x, t_span, solver, atol, rtol, interpolator, B,
                                False,  # return_all_eval is hard-coded here
                                maxiter, fine_steps, save_at)
    ctx.save_for_backward(sol, t_sol)
    return t_sol, sol
```

So, as it stands, there is no way to switch it on without modifying the source code.
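One way the flag could be threaded through instead of hard-coded, sketched under assumptions: the quoted `forward` reads `solver`, `atol`, `problem_type` and similar names from an enclosing scope, and this sketch assumes that scope is a factory function that could simply accept `return_all_eval` as one more parameter. `gather_forward` and `fake_generic_odeint` are hypothetical stand-ins, not torchdyn functions. The sketch covers only the forward-pass plumbing; the adjoint backward pass would also have to cope with the extra time points, which is not addressed here.

```python
from typing import Any, Callable, Dict, Sequence


def fake_generic_odeint(x: float, t_span: Sequence[float], return_all_eval: bool) -> Dict[str, Any]:
    """Hypothetical stand-in for generic_odeint: it only echoes what it received."""
    return {"x": x, "t_span": list(t_span), "return_all_eval": return_all_eval}


def gather_forward(solver_call: Callable, return_all_eval: bool = False) -> Callable:
    """Hypothetical factory: the flag is captured by the closure instead of hard-coded."""

    def forward(x, t_span):
        # The quoted code passes a literal False at this point, so the caller has no say.
        # Forwarding the captured value lets the user's choice reach the solver.
        return solver_call(x, t_span, return_all_eval)

    return forward


fwd_default = gather_forward(fake_generic_odeint)
fwd_all = gather_forward(fake_generic_odeint, return_all_eval=True)
print(fwd_default(1.0, [0.0, 1.0])["return_all_eval"])  # False
print(fwd_all(1.0, [0.0, 1.0])["return_all_eval"])  # True
```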
Another problem is an argument mismatch in `numerics.sensitivity._gather_odefunc_adjoint._ODEProblemFunc.forward`. It is called from `ODEProblem.odeint` like this (`torchdyn/torchdyn/core/problems.py`, line 85 at commit a0d0fc5):

```python
return self._autograd_func()(self.vf_params, x, t_span, save_at, args)
```
Because the arguments are passed positionally, the `save_at` parameter of `forward` actually receives the `args` dict, while the `B` parameter (whose purpose I do not understand) receives the real `save_at`. So far this has not caused any problems in my code, but I don't believe it is the intended behavior. I suggest someone debug this part of the code more thoroughly.
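The mismatch can be reproduced in plain Python. The sketch below mirrors only the parameter list of the quoted `forward` (minus `ctx`, which autograd supplies) and the positional call from `problems.py`; the function body is a hypothetical stand-in that just reports what each parameter received. It also shows that `args` has no slot of its own on this path. Assuming `self._autograd_func()` returns a `torch.autograd.Function`'s `apply`, which is conventionally called with positional arguments, the sketched fix fills the `B` slot explicitly rather than switching to keywords; whether `None` is the right value for `B` depends on what `B` is meant to hold, which this issue leaves open.

```python
# Hypothetical mirror of the quoted parameter list; autograd supplies ctx, so it is omitted.
def forward(vf_params, x, t_span, B=None, save_at=()):
    # Stand-in body: report what the last two parameters received.
    return {"B": B, "save_at": save_at}


vf_params, x, t_span = None, 0.0, [0.0, 1.0]
save_at = [0.5]
args = {"return_all_eval": True}

# Positional call in the style of problems.py line 85: after t_span, everything shifts one slot.
received = forward(vf_params, x, t_span, save_at, args)
print(received)  # {'B': [0.5], 'save_at': {'return_all_eval': True}}

# Sketched fix that stays positional: fill B explicitly so save_at lands in its own slot.
# Note that args still has no parameter to go to on this path.
fixed = forward(vf_params, x, t_span, None, save_at)
print(fixed)  # {'B': None, 'save_at': [0.5]}
```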
Screenshots
There is a traceback that shows the problem.

Expected behavior
The `return_all_eval` option should be exposed to the user and should control whether the ODE solver returns all evaluation time points.
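To make the expected behavior concrete, here is a toy adaptive solver with a `return_all_eval` flag. It is a sketch under stated assumptions: it uses a scalar state and the simple Heun-Euler 2(1) embedded pair instead of dopri5, and `heun_euler_odeint` is a hypothetical name, not torchdyn's implementation. With the flag off it returns the solution only at `t_span`; with the flag on it also returns every accepted step time, so the number of steps can be read off directly.

```python
import math
from typing import Callable, List, Sequence, Tuple


def heun_euler_odeint(
    f: Callable[[float, float], float],
    y0: float,
    t_span: Sequence[float],
    atol: float = 1e-4,
    rtol: float = 1e-4,
    return_all_eval: bool = False,
) -> Tuple[List[float], List[float]]:
    """Toy adaptive solver using the Heun-Euler 2(1) pair (hypothetical, not dopri5).

    return_all_eval=False -> solution only at the times in t_span.
    return_all_eval=True  -> additionally every accepted interior step time.
    """
    t, y = t_span[0], y0
    h = (t_span[-1] - t_span[0]) / 100  # initial step-size guess
    ts, ys = [t], [y]
    for t_out in t_span[1:]:
        while t < t_out:
            last = h >= t_out - t  # would this step reach the next output time?
            if last:
                h = t_out - t  # clamp so we never step past t_out
            k1 = f(t, y)
            k2 = f(t + h, y + h * k1)
            y_high = y + 0.5 * h * (k1 + k2)  # 2nd-order (Heun) estimate
            err = abs(y_high - (y + h * k1))  # distance to the 1st-order (Euler) estimate
            tol = atol + rtol * max(abs(y), abs(y_high))
            if err <= tol:  # accept the step
                t = t_out if last else t + h  # snap to t_out to avoid round-off drift
                y = y_high
                if return_all_eval and t < t_out:
                    ts.append(t)
                    ys.append(y)
            # Step-size update: safety factor 0.9, exponent 1/2 for a 1st-order error estimate.
            factor = 5.0 if err == 0 else 0.9 * math.sqrt(tol / err)
            h *= min(5.0, max(0.2, factor))
        ts.append(t_out)
        ys.append(y)
    return ts, ys
```

For a strictly increasing `t_span`, each accepted step contributes exactly one returned point, so `len(ts) - 1` is the accepted-step count in this toy; rejected steps are not recorded.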
Also, documentation on the meaning of these arguments and on the functionality provided is severely lacking; for example, it was not until I found that GitHub issue that I realized there is a way to return all the evaluation time stamps.