My purpose is to count permutations with certain properties. I first generate the permutations and then remove those that do not satisfy the desired properties. How could I improve the code to be able to enumerate more permutations?
```python
from itertools import permutations


def check(seq, verbose=False):
    """check that the elements of the sequence equal a difference of previous elements"""
    n = len(seq)
    for k in range(1, n-1):
        # build a set of admissible values
        dk = {abs(seq[i]-seq[j]) for i in range(0, k) for j in range(i+1, k+1) if i < j}
        if k > 0 and verbose:
            print('current index = ', k)
            print('current subsequence = ', seq[:k+1])
            print('current admissible values = ', dk)
            print('next element = ', seq[k+1])
        # check if the next element is in the set of admissible values
        if k > 0 and seq[k+1] not in dk:
            # return an invalid subsequence (k+2 to include the invalid element)
            return seq[:k+2]
    return seq


def is_valid(seq):
    """check that the sequence satisfies certain properties"""
    n = len(seq)
    if n < 3:
        return False
    if len(check(seq)) == n:
        return True
    return False


def filter_perms(perms):
    for perm in perms:
        if is_valid(perm):
            yield perm


def make_perms(n):
    """The elements of the list are integers, where a list of length n stores all integers from 1 to n."""
    for p in permutations(range(1, n-1)):
        yield (n,) + p + (n-1,)


def enumerate_perms(n):
    perms = make_perms(n)
    return filter_perms(perms)


# testing a good sequence
seq = (5, 2, 3, 1, 4)
check(seq, verbose=True)
# valid: returns the whole sequence (5, 2, 3, 1, 4)
# testing a bad sequence
seq = [5, 2, 1, 3, 4]
check(seq, verbose=True)
# invalid: returns the failing prefix [5, 2, 1]
# testing permutations
# testing enumeration
# ((5, 2, 3, 1, 4), (5, 3, 2, 1, 4))
# 29340
```
Summary of discussions in the comments section: Should I use numpy arrays? Should I save the permutations to a database?
Filtering is a good approach in terms of implementation simplicity. However, it has to iterate through all possible permutations, which quickly becomes infeasible as N grows.
For example, for N=20 you have to iterate over 18! ≈ 6.4 quadrillion permutations (the N - 2 middle elements are permuted). Even at a billion permutations per second, just iterating over them would take more than two months. So instead of filtering, you need to avoid generating unwanted elements in the first place. This is called pruning.
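To put numbers on this, here is a quick sketch that computes how many candidates the generate-then-filter approach visits and a rough time estimate. The rate of one billion candidates per second is an optimistic assumption for illustration, not a measurement, and the helper names are hypothetical.

```python
import math


def search_space_size(n: int) -> int:
    """Number of candidates make_perms(n) generates: the n - 2 middle elements are permuted."""
    return math.factorial(n - 2)


def days_to_iterate(n: int, rate_per_sec: float = 1e9) -> float:
    """Days needed just to iterate, at an assumed (optimistic) rate of rate_per_sec candidates/s."""
    return search_space_size(n) / rate_per_sec / 86_400


if __name__ == "__main__":
    for n in (14, 17, 20):
        print(f"n={n}: {search_space_size(n):,} candidates, ~{days_to_iterate(n):.4g} days at 1e9/s")
```

The factorial growth is the whole problem: every extra element multiplies the work, no matter how fast the per-candidate check is.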
Since you already mentioned pruning in the comments, I'll skip the explanation and just show you how to implement it. (The full code, including the benchmark, is at the bottom; use that if you want to try it out.)
```python
from collections.abc import Sequence
from typing import Callable, Iterable


def _filtered_permutations(
    items: Sequence[int],
    not_used_indexes: set[int],
    current_permutation: list[int],
    # The condition function is configurable to make this function more versatile.
    is_valid: Callable[[Sequence[int]], bool],
) -> Iterable[list[int]]:
    if len(current_permutation) == len(items):
        # Since we picked a valid element below, there is no need to validate it here.
        yield current_permutation
    for i in not_used_indexes:
        # This is the in-progress validation for the partial permutation.
        # If it is found to violate the conditions, we skip the further search.
        # This greatly reduces the generation of unwanted elements.
        next_permutation = [*current_permutation, items[i]]
        if is_valid(next_permutation):
            yield from _filtered_permutations(
                items=items,
                not_used_indexes=not_used_indexes - {i},
                current_permutation=next_permutation,
                is_valid=is_valid,
            )


def is_valid_original(seq: Sequence[int]) -> bool:
    # This is equivalent to your original is_valid function, except I've removed the unnecessary list creation.
    for i in range(len(seq) - 2):
        if seq[i + 2] not in {abs(seq[j] - seq[k]) for j in range(1, i + 2) for k in range(i + 1)}:
            return False
    return True


def enumerate_perms_pre_filter(n: int) -> Iterable[tuple[int, ...]]:
    lst = list(range(1, n - 1))
    head = n
    tail = n - 1

    def _is_valid(current):
        """A wrapper function that inserts the first and last elements."""
        if len(current) == len(lst):
            return is_valid_original((head, *current, tail))
        return is_valid_original((head, *current))

    for p in _filtered_permutations(
        items=lst,
        not_used_indexes=set(range(len(lst))),
        current_permutation=[],
        is_valid=_is_valid,
    ):
        yield head, *p, tail
```
This ran 20-30 times faster than the original code on my PC.
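To see where the speedup comes from, here is a small self-contained sketch (the helper names are hypothetical, not from the code above) that counts how many candidates the filtering approach builds versus how many partial prefixes a pruned depth-first search visits. Both use the same rule: every element from the third on must equal the absolute difference of two earlier elements. For simplicity it rechecks the whole prefix each time, so it illustrates the size of the search tree rather than the fastest check.

```python
from itertools import permutations


def is_valid_prefix(seq: tuple[int, ...]) -> bool:
    """Hypothetical helper: True if every element from index 2 on is a difference of two earlier elements."""
    for k in range(2, len(seq)):
        diffs = {abs(seq[i] - seq[j]) for i in range(k) for j in range(i + 1, k)}
        if seq[k] not in diffs:
            return False
    return True


def count_filtering(n: int) -> tuple[int, int]:
    """Return (candidates generated, valid permutations) for generate-then-filter."""
    generated = valid = 0
    for p in permutations(range(1, n - 1)):
        generated += 1
        if is_valid_prefix((n, *p, n - 1)):
            valid += 1
    return generated, valid


def count_pruning(n: int) -> tuple[int, int]:
    """Return (prefixes visited, valid permutations) for a pruned depth-first search."""
    visited = valid = 0

    def search(prefix: tuple[int, ...], remaining: list[int]) -> None:
        nonlocal visited, valid
        visited += 1
        if not remaining:
            # A complete middle part still needs the tail n - 1 checked.
            if is_valid_prefix((*prefix, n - 1)):
                valid += 1
            return
        for idx, value in enumerate(remaining):
            candidate = (*prefix, value)
            # Prune: never extend a prefix that already violates the rule.
            if is_valid_prefix(candidate):
                search(candidate, remaining[:idx] + remaining[idx + 1:])

    search((n,), list(range(1, n - 1)))
    return visited, valid


if __name__ == "__main__":
    for n in range(5, 10):
        print(n, count_filtering(n), count_pruning(n))
```

Both functions find the same valid permutations, but the pruned search abandons a branch as soon as a prefix fails, so it never builds the many candidates that share an invalid prefix.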
At this point, the bottleneck is that the `is_valid` function recalculates all the differences on every call.
I'm not sure this can be made generic, so here is an implementation specific to your problem.
It may look similar to the implementation above, but it achieves a significant speedup by decomposing the `is_valid` function and carrying the differences along instead of recomputing them.
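The identity behind this is that appending a value x to a prefix only adds the differences between x and the elements already in the prefix, so the set can be updated with k new pairs instead of being rebuilt from all k² pairs. Here is a small sketch checking that identity on the sequence from the question (the function names are hypothetical):

```python
from itertools import combinations


def all_diffs(seq: list[int]) -> set[int]:
    """Recompute every pairwise absolute difference from scratch: O(k^2) pairs."""
    return {abs(a - b) for a, b in combinations(seq, 2)}


def extend_diffs(diffs: set[int], prefix: list[int], x: int) -> set[int]:
    """Differences of prefix + [x], given the differences of prefix: only O(k) new pairs."""
    return diffs | {abs(x - p) for p in prefix}


if __name__ == "__main__":
    seq = [5, 2, 3, 1, 4]
    diffs: set[int] = set()
    for k, x in enumerate(seq):
        diffs = extend_diffs(diffs, seq[:k], x)
        # The incremental set always matches the full recomputation.
        assert diffs == all_diffs(seq[: k + 1])
    print(sorted(diffs))  # [1, 2, 3, 4]
```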
```python
from collections.abc import Iterator, Sequence


def _enumerate_perms_optimized(
    items: Sequence[int],
    current_permutation: list[int],
    not_used_indexes: set[int],
    current_diffs: set[int],
) -> Iterator[list[int]]:
    if len(current_permutation) == len(items) + 1:
        # Since we picked a valid element below, there is no need to validate it here.
        yield current_permutation
    for i in not_used_indexes:
        # By carrying the differences, we can validate it by simply checking whether the set contains the value.
        next_value = items[i]
        # The second element (right after the head) is unconstrained.
        if len(current_permutation) < 2 or next_value in current_diffs:
            yield from _enumerate_perms_optimized(
                items=items,
                not_used_indexes=not_used_indexes - {i},
                current_permutation=[*current_permutation, next_value],
                # The differences between the elements that have already been added have already been calculated,
                # so we only need to calculate the differences between the new element and them.
                current_diffs={*current_diffs, *(abs(next_value - prev_value) for prev_value in current_permutation)},
            )


def enumerate_perms_optimized(n: int) -> Iterator[tuple[int, ...]]:
    lst = list(range(1, n - 1))
    head = n
    tail = n - 1
    # The tail never needs checking: n - 1 = |n - 1| is always among the differences
    # once 1 has been placed after the head n.
    for p in _enumerate_perms_optimized(
        items=lst,
        current_permutation=[head],
        not_used_indexes=set(range(len(lst))),
        current_diffs=set(),
    ):
        yield *p, tail
```
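Since your goal is to count the permutations rather than list them, a variant that returns a count instead of yielding every permutation avoids creating a result object for each one. This is a sketch I have not benchmarked; it follows the same pruning rule as `enumerate_perms_optimized` and mutates one prefix list in place, and the name `count_perms` is hypothetical.

```python
def count_perms(n: int) -> int:
    """Count permutations (n, ..., n - 1) of 1..n where every element from the third on
    is the absolute difference of two earlier elements (same rule as the code above)."""
    if n < 3:
        return 0

    def search(prefix: list[int], remaining: set[int], diffs: set[int]) -> int:
        if not remaining:
            # The tail n - 1 = |n - 1| is always in diffs once 1 has been placed, so no check is needed.
            return 1
        total = 0
        for value in remaining:
            # The second element is unconstrained; later ones must be a known difference.
            if len(prefix) < 2 or value in diffs:
                new_diffs = diffs | {abs(value - p) for p in prefix}
                prefix.append(value)  # reuse one list instead of copying it at every level
                total += search(prefix, remaining - {value}, new_diffs)
                prefix.pop()
        return total

    return search([n], set(range(1, n - 1)), set())


if __name__ == "__main__":
    for n in range(5, 13):
        print(n, count_perms(n))
```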
Here is the test and the benchmark.
```python
import time
from collections.abc import Iterator, Sequence
from itertools import permutations
from typing import Callable, Iterable

# ---------- Original implementation ----------


def is_valid_original(seq: Sequence[int]) -> bool:
    # This is equivalent to your original is_valid function, except I've removed the unnecessary list creation.
    for i in range(len(seq) - 2):
        if seq[i + 2] not in {abs(seq[j] - seq[k]) for j in range(1, i + 2) for k in range(i + 1)}:
            return False
    return True


def enumerate_perms_1(n: int) -> list[tuple[int, ...]]:
    lst = list(range(1, n - 1))
    perms = [p for perm in permutations(lst) for p in [(n, *perm, n - 1)] if is_valid_original(p)]
    return perms


# ---------- Updated implementation ----------


def check(seq, verbose=False):
    """check that the elements of the sequence equal a difference of previous elements"""
    n = len(seq)
    for k in range(1, n - 1):
        # build a set of admissible values
        dk = {abs(seq[i] - seq[j]) for i in range(0, k) for j in range(i + 1, k + 1) if i < j}
        if k > 0 and verbose:
            print("current index = ", k)
            print("current subsequence = ", seq[: k + 1])
            print("current admissible values = ", dk)
            print("next element = ", seq[k + 1])
        # check if the next element is in the set of admissible values
        if k > 0 and seq[k + 1] not in dk:
            # return an invalid subsequence (k+2 to include the invalid element)
            return seq[: k + 2]
    return seq


def is_valid(seq):
    """check that the sequence satisfies certain properties"""
    n = len(seq)
    if n < 3:
        return False
    if len(check(seq)) == n:
        return True
    return False


def filter_perms(perms):
    for perm in perms:
        if is_valid(perm):
            yield perm


def make_perms(n):
    """The elements of the list are integers, where a list of length n stores all integers from 1 to n."""
    for p in permutations(range(1, n - 1)):
        yield (n,) + p + (n - 1,)


def enumerate_perms_2(n):
    perms = make_perms(n)
    return filter_perms(perms)


# ---------- Pre-filtered implementation ----------


def _filtered_permutations(
    items: Sequence[int],
    not_used_indexes: set[int],
    current_permutation: list[int],
    # The condition function is configurable to make this function more versatile.
    is_valid: Callable[[Sequence[int]], bool],
) -> Iterable[list[int]]:
    if len(current_permutation) == len(items):
        # Since we picked a valid element below, there is no need to validate it here.
        yield current_permutation
    for i in not_used_indexes:
        # This is the in-progress validation for the partial permutation.
        # If it is found to violate the conditions, we skip the further search.
        # This greatly reduces the generation of unwanted elements.
        next_permutation = [*current_permutation, items[i]]
        if is_valid(next_permutation):
            yield from _filtered_permutations(
                items=items,
                not_used_indexes=not_used_indexes - {i},
                current_permutation=next_permutation,
                is_valid=is_valid,
            )


def enumerate_perms_pre_filter(n: int) -> Iterable[tuple[int, ...]]:
    lst = list(range(1, n - 1))
    head = n
    tail = n - 1

    def _is_valid(current):
        """Wrapper function that inserts the first and last elements."""
        if len(current) == len(lst):
            return is_valid_original((head, *current, tail))
        return is_valid_original((head, *current))

    for p in _filtered_permutations(
        items=lst,
        not_used_indexes=set(range(len(lst))),
        current_permutation=[],
        is_valid=_is_valid,
    ):
        yield head, *p, tail


# ---------- Optimized implementation ----------


def _enumerate_perms_optimized(
    items: Sequence[int],
    current_permutation: list[int],
    not_used_indexes: set[int],
    current_diffs: set[int],
) -> Iterator[list[int]]:
    if len(current_permutation) == len(items) + 1:
        # Since we picked a valid element below, there is no need to validate it here.
        yield current_permutation
    for i in not_used_indexes:
        # By carrying the differences, we can validate it by simply checking whether the set contains the value.
        next_value = items[i]
        # The second element (right after the head) is unconstrained.
        if len(current_permutation) < 2 or next_value in current_diffs:
            yield from _enumerate_perms_optimized(
                items=items,
                not_used_indexes=not_used_indexes - {i},
                current_permutation=[*current_permutation, next_value],
                # The differences between the elements that have already been added have already been calculated,
                # so we only need to calculate the differences between the new element and them.
                current_diffs={*current_diffs, *(abs(next_value - prev_value) for prev_value in current_permutation)},
            )


def enumerate_perms_optimized(n: int) -> Iterator[tuple[int, ...]]:
    lst = list(range(1, n - 1))
    head = n
    tail = n - 1
    # The tail never needs checking: n - 1 = |n - 1| is always among the differences.
    for p in _enumerate_perms_optimized(
        items=lst,
        current_permutation=[head],
        not_used_indexes=set(range(len(lst))),
        current_diffs=set(),
    ):
        yield *p, tail


def test_implementations(candidates, n: int):
    expected = sorted(candidates[0](n))
    expected = [tuple(p) for p in expected]
    for f in candidates[1:]:
        actual = sorted(f(n))
        actual = [tuple(p) for p in actual]
        assert expected == actual, f"Results differ for n={n} with {f.__name__}"


def measure_performance(f, n_range: range):
    for n in n_range:
        n_perms = 0
        started = time.perf_counter()
        for _ in f(n):
            n_perms += 1
        elapsed = time.perf_counter() - started
        print(f"{f.__name__}({n=}): {elapsed:.3f} sec, {n_perms=:,}")


def main():
    candidates = [
        enumerate_perms_1,
        enumerate_perms_2,
        enumerate_perms_pre_filter,
        enumerate_perms_optimized,
    ]
    for n in range(3, 10):
        test_implementations(candidates, n)
        print(f"Tests passed for {n=}")

    measure_performance(enumerate_perms_1, range(10, 14))
    measure_performance(enumerate_perms_2, range(10, 14))
    measure_performance(enumerate_perms_pre_filter, range(10, 16))
    measure_performance(enumerate_perms_optimized, range(10, 21))


if __name__ == "__main__":
    main()
```
```
Tests passed for n=3
Tests passed for n=4
Tests passed for n=5
Tests passed for n=6
Tests passed for n=7
Tests passed for n=8
Tests passed for n=9
enumerate_perms_1(n=10): 0.029 sec, n_perms=36
enumerate_perms_1(n=11): 0.268 sec, n_perms=598
enumerate_perms_1(n=12): 2.413 sec, n_perms=1,096
enumerate_perms_1(n=13): 27.832 sec, n_perms=14,030
enumerate_perms_2(n=10): 0.035 sec, n_perms=36
enumerate_perms_2(n=11): 0.327 sec, n_perms=598
enumerate_perms_2(n=12): 3.098 sec, n_perms=1,096
enumerate_perms_2(n=13): 34.459 sec, n_perms=14,030
enumerate_perms_pre_filter(n=10): 0.001 sec, n_perms=36
enumerate_perms_pre_filter(n=11): 0.025 sec, n_perms=598
enumerate_perms_pre_filter(n=12): 0.058 sec, n_perms=1,096
enumerate_perms_pre_filter(n=13): 0.982 sec, n_perms=14,030
enumerate_perms_pre_filter(n=14): 2.595 sec, n_perms=29,340
enumerate_perms_pre_filter(n=15): 24.978 sec, n_perms=223,350
enumerate_perms_optimized(n=10): 0.000 sec, n_perms=36
enumerate_perms_optimized(n=11): 0.003 sec, n_perms=598
enumerate_perms_optimized(n=12): 0.005 sec, n_perms=1,096
enumerate_perms_optimized(n=13): 0.064 sec, n_perms=14,030
enumerate_perms_optimized(n=14): 0.138 sec, n_perms=29,340
enumerate_perms_optimized(n=15): 1.095 sec, n_perms=223,350
enumerate_perms_optimized(n=16): 9.790 sec, n_perms=1,936,172
enumerate_perms_optimized(n=17): 164.876 sec, n_perms=28,038,794
enumerate_perms_optimized(n=18): 542.482 sec, n_perms=90,125,652
```
As you can see, it is hundreds of times faster than the original implementation, but it will still take several hours if not days for n=20.
Finally, a few words on the options raised in the comments.
All of the code above is PyPy-compatible, so if you install PyPy you can run it without any modifications, and it will probably be several times faster. But that's about all you can expect from PyPy.
Cython may be more effective, but it requires learning its syntax, which takes some effort.
Numba does not work well with this approach because it does not support recursive generator functions. Major changes would be needed, and I'm not sure good performance is achievable without significantly complicating the code.
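For reference, a port in that direction would likely start by removing the recursive generator altogether. Below is a plain-Python sketch of the same pruned count using an explicit stack and flat lists (a `used` flag per value and a counter per difference, so differences can be removed on backtrack) instead of sets and recursion. I have not tried compiling it with Numba, so treat it as a possible starting point, not a working Numba implementation. The name `count_perms_iterative` is hypothetical.

```python
def count_perms_iterative(n: int) -> int:
    """Count the same permutations as the recursive versions, without recursion or generators."""
    if n < 3:
        return 0
    m = n - 2  # number of middle elements 1..n-2
    perm = [0] * (m + 1)  # perm[0] is the fixed head n; perm[1..m] hold the middle elements
    perm[0] = n
    used = [False] * (n + 1)  # used[v] is True if value v is currently placed
    diff_count = [0] * (n + 1)  # diff_count[d] = number of placed pairs with |a - b| == d
    next_candidate = [1] * (m + 2)  # next value to try at each depth
    depth = 1  # position currently being filled
    count = 0
    while depth >= 1:
        v = next_candidate[depth]
        if v > m:
            # This position is exhausted: backtrack and undo the element placed one level up.
            next_candidate[depth] = 1
            depth -= 1
            if depth >= 1:
                u = perm[depth]
                used[u] = False
                for i in range(depth):
                    diff_count[abs(u - perm[i])] -= 1
            continue
        next_candidate[depth] = v + 1
        # Skip used values; from the third element on, v must be an existing difference.
        if used[v] or (depth >= 2 and diff_count[v] == 0):
            continue
        perm[depth] = v
        used[v] = True
        for i in range(depth):
            diff_count[abs(v - perm[i])] += 1
        if depth == m:
            # Complete: the tail n - 1 = |n - 1| is always a difference, so just count it and undo.
            count += 1
            used[v] = False
            for i in range(depth):
                diff_count[abs(v - perm[i])] -= 1
        else:
            depth += 1
    return count


if __name__ == "__main__":
    for n in range(5, 13):
        print(n, count_perms_iterative(n))
```

Keeping all state in fixed-size lists of integers and booleans is the style that JIT compilers generally handle best, but whether it actually compiles and speeds up under Numba would need to be verified.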