matlab parfor

Matlab parallel code with many anonymous functions leading to memory errors


I have code that solves a scientific problem with many different inputs/parameters. I'm using a parallel for loop to iterate through a range of parameters, and I'm running into trouble with memory usage. I've done my best to put together an MWE that represents my code.

Basically, for each parameter combination I run a small loop over several different solver options. In my real code, this changes the solver tolerances and the equations used (we have a few different transformations which can help conditioning). Each computation is effectively a shooting method for a small ODE system (3 equations, but each is quite complicated and generally stiff), with an optimisation routine calling the ODE solver. Each run takes seconds to minutes, so the parallelisation overhead is negligible, and the speedup scales pretty much exactly with the number of cores.

To explain the code below, start with driver. First define some parameters (a and f in the MWE) and save them in a file. The filename gets passed around between functions. Then create the 3 (in this case) sets of solver parameters, which choose the ODE solver, tolerance, and set of equations to use. Then enter the parfor loop, looping over some other parameter c, and at each iteration use each of the sets of solver parameters to call the optimisation function. Finally, I save a temporary file with the results of each iteration (so I don't lose everything if the server goes down). These files are about 1 kB each, and I will only have around 10,000 of them, so the overall size is on the order of 10 MB. After the main loop I recombine everything back into single vectors.

The equations function creates the actual differential equations to solve, using a switch statement to choose which equations to return. The objectiveFunction function uses str2func to specify the ODE solver, calls equations to get the equations to solve, then solves them and computes an objective function value.

The problem is that there appears to be some sort of memory leak. After some time, on the order of days, the code slows down and finally gives an out-of-memory error (running on 48 cores with ~380GB memory available, ode15s gave the error). The increase in memory usage over time is fairly gradual, but is definitely there, and I can't figure out what is causing it.

The MWE with 10,000 values of c takes quite a while to run (1,000 is probably sufficient actually), and the memory usage per worker does increase over time. I think the file loading/saving and job distribution cause quite a lot of overhead, unlike in my actual code, but this doesn't affect memory usage.

My question is, what could be causing this slow increase in memory usage?

My ideas for what is causing the problem are:

  1. Using str2func isn't great, should I use a switch instead and accept having to write the solvers into the code explicitly?
  2. All the anonymous functions getting called all the time (in the ODE solver) are holding on to workspace data and not releasing it at the end of each parfor iteration
  3. Suppressed warnings are causing issues: I suppress lots of ODE step size warnings (this shouldn't be a factor because the bug that means this causes issues was fixed in 2017a, and the server I use runs 2017b)
  4. Something in fminbnd or ode15s is actually leaking memory

I can't come up with a way to get around 1 and 2 nicely and efficiently (both from a code performance and code writing point of view), and I doubt 3 or 4 are actually the problem.
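
For reference, idea 1 would look something like the sketch below. Note that selectSolver is a hypothetical helper, not part of the MWE, and the list of solvers is only an assumption about which ones would be needed. I have not confirmed that this changes the memory behaviour at all.

```matlab
function solver = selectSolver(method)
% Hypothetical replacement for str2func(setup.method): map the solver
% name to a function handle explicitly, so that no handle is ever
% constructed from a string at run time.
switch method
    case 'ode45'
        solver = @ode45;
    case 'ode15s'
        solver = @ode15s;
    case 'ode23s'
        solver = @ode23s;
    otherwise
        % fail loudly rather than returning an undefined output
        error('selectSolver:unknownMethod', ...
            'Unknown ODE solver: %s',method);
end
end
```

In objectiveFunction this would be called as solver = selectSolver(setup.method); in place of the str2func line, at the cost of having to list every solver by hand.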

Here is the driver function:

function [xi,mfv] = driver()

% a and f are used in all cases. In actual code these are defined in a
% separate function
paramFile = 'params';
a = 4;
f = @(x) 2*x;

% this filename (params) gets passed around from function to function
save(paramFile,'a','f') % save appends the .mat extension

% The struct setup has specific options for each solver configuration
setup(1).method = 'ode45'; % any ODE solver can be used here
setup(1).atol = 1e-3; % change the ODE solver tolerance
setup(1).eqs = 'second'; % changes what equations are solved

setup(2).method = 'ode15s';
setup(2).atol = 1e-3;
setup(2).eqs = 'second';

setup(3).method = 'ode15s';
setup(3).atol = 1e-4;
setup(3).eqs = 'first';

c = linspace(0,1);

parfor i = 1:numel(c) % loop over parameter c
    xi = 0;
    minFVal = inf;
    for j = 1:numel(setup) % loop over each solver configuration in setup

        % find optimal initial condition and record corresponding value of
        % objective function
        [xInitial,fval] = fminsearch(@(x0) objectiveFunction(x0,c(i),...
            paramFile,setup(j)),1);

        if fval<minFVal % keep the best solution
            xi = xInitial;
            minFVal = fval;
        end
    end
    % save some variables
    saveInParForLoop(['tempresult_' num2str(i)],xi,minFVal);
end

% Now combine temporary files into single vectors
xi = zeros(size(c)); mfv = xi;
for i = 1:numel(c)
    S = load(['tempresult_' num2str(i) '.mat'],'xi','minFVal');
    xi(i) = S.xi;
    mfv(i) = S.minFVal;
end

% delete the temporary files now that the data has been consolidated
for i = 1:numel(c)
    delete(['tempresult_' num2str(i) '.mat']);
end
end

function saveInParForLoop(filename,xi,minFVal)
% you can't save directly in a parfor loop, this is the workaround
save(filename,'xi','minFVal')
end

Here is the function to define the equations:

function [der,transform] = equations(paramFile,setup)
% Defines the differential equation and a transformation for the solution
% used to calculate the objective function
% Note in my actual code I generate these equations earlier
% and pass them around directly, rather than always redefining them

load(paramFile,'a','f')

switch setup.eqs
    case 'first'
        der = @(x) f(x)*2+a;
        transform = @(x) exp(x);
    case 'second'
        der = @(x) f(x)/2-a;
        transform = @(x) sqrt(abs(x));
end

And here is the function to evaluate the objective function:

function val = objectiveFunction(x0,c,paramFile,setup)

load(paramFile,'a')

% specify the ODE solver and AbsTol from setup
solver = str2func(setup.method);
options = odeset('AbsTol',setup.atol);

% get the differential equation and transform equations
[der,transform] = equations(paramFile,setup);
dxdt = @(t,y) der(y);

% solve the IVP
[~,y] = solver(dxdt,0:.05:1,x0,options);

% calculate the objective function value
val = norm(transform(y)-c*a);

If you run this code it will create 100 temporary files, then delete them. It will also create the params file, which won't be deleted. You will need the Parallel Computing Toolbox.


Solution

  • There's just a chance you might be running into this known problem: https://uk.mathworks.com/support/bugreports/1976165 . This is marked as fixed in R2019b, which has just been released. (The leak it causes is tiny but persistent, so it might indeed take days to become apparent.)
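
If upgrading is not an option, one possible workaround (a sketch only, and it assumes the leaked memory lives in the worker processes) is to split the parameter range into batches and shut the pool down between batches, so each batch starts on fresh workers. Here runInBatches, processOne and batchSize are hypothetical names, and the sketch assumes the default preference that lets parfor create a pool automatically. Otherwise call parpool explicitly before each batch.

```matlab
function runInBatches(nTotal,batchSize)
% Process indices 1..nTotal in consecutive batches, discarding the
% parallel pool (and whatever its workers have accumulated) after each.
for first = 1:batchSize:nTotal
    last = min(first+batchSize-1,nTotal);
    parfor i = first:last
        processOne(i);
    end
    % shut the current pool down; the next parfor starts a fresh one
    delete(gcp('nocreate'));
end
end

function processOne(i)
% Placeholder for the body of the original parfor loop; the real code
% would run the optimisation and save tempresult_<i>.mat here.
fprintf('processed %d\n',i);
end
```

Restarting the pool costs some start-up time per batch, so batchSize would need to be large enough that this is small compared with the run time of a batch. Because each iteration already writes its own temporary file, the batches need no other coordination.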