Say I have some complicated function f(fvar1, ..., fvarN), such as:

    def f(fvar1, ..., fvarN):
        return (complicated function of fvar1, ..., fvarN)

Now a function g(gvar1, ..., gvarM) has an expression in terms of f(fvar1, ..., fvarN), let's say:

    def g(gvar1, ..., gvarM):
        return stuff * f(gvar1 * gvar2, ..., gvar5 * gvarM) - stuff * f(gvar3, gvar2, ..., gvarM)

where the arguments of f inside g can be different linear combinations of gvar1, ..., gvarM.
Because f is a complicated function, it is costly to call, but it is also difficult to store its value locally in g, because g contains many instances of f with different argument combinations. Is there a way to store values of f, so that f is not called again and again with the same arguments, without having to define every different instance of f locally within g?
Yes, this is called memoisation. The basic idea is to have f()
maintain some sort of data store based on the parameters passed in. Then, if it's called with the same parameters, it simply returns the stored value rather than recalculating it.
The data store probably needs to be limited in size and optimised for the pattern of calls you expect, by removing parameter sets based on some rules. For example, if the number of times a parameter set is used indicates its likelihood of being used in future, you probably want to remove parameter sets that are used infrequently, and keep those that are used more often.
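For a point of comparison, Python's standard library already provides a size-limited cache with exactly this kind of eviction rule: functools.lru_cache discards the least-recently-used entry once the cache is full. A minimal sketch (the function slow_add is just an illustrative stand-in for an expensive calculation):

    import functools

    @functools.lru_cache(maxsize=128)   # keep at most 128 entries, evict least-recently-used
    def slow_add(a, b):
        # Stand-in for an expensive calculation.
        return a + b

    slow_add(2, 3)                      # calculated, then cached
    slow_add(2, 3)                      # served from the cache
    print(slow_add.cache_info())        # reports hits=1, misses=1

The decorator handles the key construction, lookup, and eviction for you; the trade-off is that you don't control the eviction policy beyond the maxsize.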
Consider, for example, the following Python code for adding two numbers (let us pretend that this is a massively time-expensive operation):
    import random

    def addTwo(a, b):
        return a + b

    for _ in range(100):
        x = random.randint(1, 5)
        y = random.randint(1, 5)
        z = addTwo(x, y)
        print(f"{x} + {y} = {z}")
That works but, of course, it is inefficient if the same numbers come up repeatedly. You can add memoisation as follows.
The code will "remember" a limited number of calculations. Since Python 3.7, dictionaries preserve insertion order, so the eviction below removes the oldest entry. If the function gets a pair it already knows about, it just returns the cached value.
Otherwise, it calculates the value, stores it in the cache, and ensures the cache doesn't grow too big:
    import random, time

    # Cache, and the stats for it.
    (pairToSumMap, cached, calculated) = ({}, 0, 0)

    def addTwo(a, b):
        global cached, calculated
        # Attempt two different cache lookups first (a:b, b:a).
        total = None
        try:
            total = pairToSumMap[f"{a}:{b}"]
        except KeyError:
            try:
                total = pairToSumMap[f"{b}:{a}"]
            except KeyError:
                pass

        # Found in cache, return.
        if total is not None:
            print("Using cached value: ", end="")
            cached += 1
            return total

        # Not found, calculate and add to cache (with limited cache size).
        print("Calculating value: ", end="")
        calculated += 1
        time.sleep(1) ; total = a + b  # Make expensive.
        if len(pairToSumMap) > 10:
            del pairToSumMap[next(iter(pairToSumMap))]  # Evict the oldest entry.
        pairToSumMap[f"{a}:{b}"] = total
        return total

    for _ in range(100):
        x = random.randint(1, 5)
        y = random.randint(1, 5)
        z = addTwo(x, y)
        print(f"{x} + {y} = {z}")

    print(f"Calculated {calculated}, cached {cached}")
You'll see I've also added cached/calculated information, including a final statistics line which shows the caching in action, for example:
Calculated 29, cached 71
I've also made the calculation an expensive operation (the time.sleep(1)) so you can see the caching in action from the speed of the output: cached results come back immediately, while calculating a sum takes a second.
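As an aside, the double lookup for (a, b) and (b, a) can be avoided by normalising the argument order before caching, since addition is commutative. A sketch combining that idea with functools.lru_cache (the function names here are illustrative, not part of the code above):

    import functools

    @functools.lru_cache(maxsize=10)
    def _add_sorted(a, b):
        # Only ever called with a <= b, so (3, 5) and (5, 3) share one cache entry.
        return a + b  # stand-in for the expensive calculation

    def addTwoCached(a, b):
        # Normalise argument order before hitting the cache.
        return _add_sorted(min(a, b), max(a, b))

After calling addTwoCached(3, 5) and then addTwoCached(5, 3), _add_sorted.cache_info() reports one miss and one hit, confirming both orderings share a single cache entry.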