python algorithm sorting merge tradeoff

Merging k sorted lists in Python3, problem with trade-off between memory and time


Input: the first line contains the number of arrays (k); each following line starts with the array size, followed by the elements of that array.

Max k is 1024. Max array size is 10*k. All numbers between 0 and 100. Memory limit - 10MB, time limit - 1s. Recommended complexity is k ⋅ log(k) ⋅ n, where n is an array length.

Example input:

4            
6 2 26 64 88 96 96
4 8 20 65 86
7 1 4 16 42 58 61 69
1 84

Example output:

1 2 4 8 16 20 26 42 58 61 64 65 69 84 86 88 96 96 

I have 4 solutions. One uses heapq and reads the input lines in blocks, one uses heapq alone, one uses Counter, and one uses no library at all.

This one uses heapq (good for time but bad for memory; I think heaps are the right way, but maybe it can be optimized by reading the lines in parts, so that I don't need memory for the whole input):

from heapq import merge


if __name__ == '__main__':
    # The first input line is the number of arrays that follow.
    array_count = int(input())

    arrays = []
    for _ in range(array_count):
        fields = input().split(' ')
        # fields[0] is the declared array size; the elements come after it.
        arrays.append([int(el) for el in fields[1:]])

    # heapq.merge lazily combines the already-sorted arrays into one stream.
    print(*merge(*arrays), sep=' ')

This one is an advanced version of the previous one. It reads the lines in blocks; however, it is a very complex solution, and I don't know how to optimize the reading:

from heapq import merge
from functools import reduce


def read_block(n, fd, cursors, offset, has_unused_items, memory_limit=10240000):
    """Read the next chunk of every unfinished input line.

    Each of the ``n`` lines starting at file position ``offset`` has the form
    ``"<size> <item> <item> ..."``.  On every call, up to roughly
    ``memory_limit // n`` characters of each unfinished line are parsed,
    always stopping on a token boundary so that no number is split in two.

    Args:
        n: Number of data lines in the file.
        fd: Seekable text stream positioned anywhere; it is rewound to
            ``offset`` at the start of every call.
        cursors: Per-line count of characters already consumed.  Updated in
            place.  A value of 0 means the line has not been touched yet.
        offset: Stream position of the first data line, as returned by
            ``fd.tell()`` right after the header line was read.
        has_unused_items: Per-line flag, set to ``False`` in place once the
            line has been read to its end.
        memory_limit: Character budget shared by all ``n`` blocks.

    Returns:
        A list with one list of ints for every line that was still
        unfinished when the call started (possibly an empty list for a line
        that holds no items).
    """
    result = []
    if n <= 0:
        return result

    # Integer budget per line; at least one character so progress is made.
    block_size = max(1, memory_limit // n)

    # Only seek to a value obtained from tell(): arithmetic on text-mode
    # positions is not portable, so lines are reached by reading forward.
    fd.seek(offset)
    for i in range(n):
        # Always consume the line, finished or not, so that the stream is
        # positioned at the start of line i + 1 for the next iteration.
        # One full line is held transiently here.
        line = fd.readline()
        if not has_unused_items[i]:
            continue

        body = line.rstrip()
        start = cursors[i]
        if start == 0:
            # First visit: skip the leading size field.  A line without any
            # space has no items at all.
            first_space = body.find(' ')
            start = len(body) if first_space == -1 else first_space + 1

        end = start + block_size
        if end < len(body):
            # Extend to just past the next separator so no token is cut.
            boundary = body.find(' ', end)
            end = len(body) if boundary == -1 else boundary + 1
        if end >= len(body):
            end = len(body)
            has_unused_items[i] = False

        # split() without arguments drops empty tokens from trailing or
        # repeated separators, so int() never sees ''.
        result.append([int(token) for token in body[start:end].split()])
        cursors[i] = end

    return result


def to_output(fd, iter):
    """Write every element of ``iter`` to ``fd`` as one space-separated line.

    Args:
        fd: Writable text stream.
        iter: Iterable of values; each is converted with ``str``.
    """
    rendered = map(str, iter)
    fd.write(' '.join(rendered))


if __name__ == '__main__':
    with open('input.txt') as fd_input, open('output.txt', 'w') as fd_output:
        # Header line: number of sorted arrays that follow, one per line.
        n = int(fd_input.readline())
        # Position of the first data line; read_block is given this on
        # every call.
        offset = fd_input.tell()
        cursors = [0] * n
        has_unused_items = [True] * n
        result = []

        # any() also copes with n == 0, where reduce() without an initial
        # value raises TypeError on the empty list.
        while any(has_unused_items):
            blocks = read_block(n, fd_input, cursors, offset, has_unused_items)
            # NOTE(review): merge() is lazy and is only consumed by
            # to_output() below, so every block read so far stays referenced
            # until the end -- reading in blocks does not appear to lower
            # peak memory here.
            result = merge(result, *blocks)

        to_output(fd_output, result)

This one is good for memory (it sorts by counting with a Counter, but it doesn't use the fact that all arrays are already sorted):

from collections import Counter


def solution():
    """Yield all input numbers in ascending order using a counting sort.

    Reads the array count from the first stdin line, then one array per
    line (the first field of each line, the array size, is ignored).  The
    raw string tokens are tallied, and only the distinct values are
    converted to ``int`` and sorted.
    """
    occurrences = Counter()

    remaining = int(input())
    while remaining > 0:
        fields = input().split(' ')
        occurrences.update(fields[1:])
        remaining -= 1

    distinct = [int(token) for token in occurrences]
    distinct.sort()
    for value in distinct:
        # The tally is keyed by the string form of the value.
        yield from [value] * occurrences[str(value)]

This one is good for time (but maybe not good enough):

def solution():
    """Yield all input numbers in ascending order by scanning value by value.

    Reads the array count from the first stdin line, then one sorted array
    per line (the first field of each line, the array size, is ignored).
    For every candidate value 0..100 each array's cursor is advanced over
    the run of elements equal to that value.  Elements outside 0..100 are
    never emitted.
    """
    array_count = int(input())
    arrays = tuple(
        tuple(int(el) for el in input().split(' ')[1:])
        for _ in range(array_count)
    )
    cursors = [0] * len(arrays)  # next unread index in each array

    for value in range(101):
        for j, arr in enumerate(arrays):
            # Walk by index instead of slicing arr[cursor:], which would
            # copy the whole remaining tail for every (value, array) pair.
            pos = cursors[j]
            while pos < len(arr) and arr[pos] == value:
                yield value
                pos += 1
            cursors[j] = pos

Ideally, the first example would read the arrays in parts, so that I don't need memory for the whole input, but I don't know how to read lines in blocks correctly.

Could you please suggest something to solve the problem?


Solution

  • O Deep Thought computer, what is the answer to life the universe and everything

    Here is the code I used for the tests

    """4
    6 2 26 64 88 96 96
    4 8 20 65 86
    7 1 4 16 42 58 61 69
    1 84"""
    
    from bisect import insort
    from collections import Counter
    from heapq import merge
    from io import StringIO
    from timeit import timeit
    
    def solution():
        """Placeholder: paste the candidate implementation under test here."""
        pass


    def _timed_run():
        """Call the candidate and drain whatever it returns.

        Several candidates return a generator or a ``heapq.merge`` iterator;
        calling them does no merging work until the result is iterated, so
        timing the bare call would only measure object creation.
        """
        result = solution()
        if result is not None:
            for _ in result:
                pass


    times = []
    for _ in range(5000):
        # Candidates read the module-level stream ``f`` to its end, so a
        # fresh one over the module docstring is needed for every run.
        f = StringIO(__doc__)
        times.append(timeit(_timed_run, number=1))

    # Report the best of all runs to reduce scheduling noise.
    print(min(times))
    

    And here are the results, I tested solutions proposed in the comments:

    6.5e-06 sec

    def solution():
        """Return a lazy ``heapq.merge`` over every line of the stream ``f``.

        All lines are read at once with ``readlines()``; the first
        space-separated field of each line is dropped and the rest are
        converted to ``int`` on demand.  A line without any space (such as a
        bare count line) contributes no values.
        """
        parsed_lines = (
            (int(token) for token in row.split(' ')[1:])
            for row in f.readlines()
        )
        # The leading empty list mirrors the original seed for the merge.
        return merge([], *parsed_lines)
    

    7.1e-06 sec

    def solution():
        """Return a lazy merge built by folding in one line of ``f`` at a time.

        The first line of ``f`` gives the number of data lines.  Each data
        line is read immediately, its first field (the array size) is
        dropped, and the remaining tokens are converted to ``int`` only when
        the returned iterator is consumed.
        """
        merged = []
        rows_left = int(f.readline())
        while rows_left > 0:
            tokens = f.readline().split(' ')[1:]
            merged = merge(merged, (int(tok) for tok in tokens))
            rows_left -= 1
        return merged
    

    7.9e-07 sec

    def solution():
        """Yield every number from the stream ``f`` in ascending order.

        The first line of ``f`` gives the number of data lines; the first
        field of each data line (the array size) is ignored.  Raw tokens are
        tallied in a ``Counter`` and only the distinct values are converted
        and sorted (here with ``sorted``).
        """
        A = Counter()
        for _ in range(int(f.readline())):
            # split() without arguments also strips the trailing newline.
            # With split(' ') the last token was e.g. '86\n', a different
            # key from '86', so the A[str(k)] lookup below returned the
            # wrong count for values at the end of a line.
            A.update(f.readline().split()[1:])
        for k in sorted([int(el) for el in A]):
            for _ in range(A[str(k)]):
                yield k
    

    8.3e-06 sec

    def solution():
        """Return a sorted list of every number in ``f`` built with ``insort``.

        The first line of ``f`` gives the number of data lines; the first
        field of each data line (the array size) is ignored.
        """
        A = []
        for _ in range(int(f.readline())):
            for i in f.readline().split(' ')[1:]:
                # Convert before inserting: the raw tokens are strings, which
                # would be ordered lexicographically ('8' after '20') and
                # would keep the trailing newline of the last token.
                insort(A, int(i))
        return A
    

    6.2e-07 sec

    def solution():
        """Yield every number from the stream ``f`` in ascending order.

        Same counting approach as the ``sorted``-based candidate, except that
        the distinct values are sorted in place with ``list.sort``.
        """
        A = Counter()
        for _ in range(int(f.readline())):
            # split() without arguments also strips the trailing newline.
            # With split(' ') the last token was e.g. '86\n', a different
            # key from '86', so the A[str(k)] lookup below returned the
            # wrong count for values at the end of a line.
            A.update(f.readline().split()[1:])
        l = [int(el) for el in A]
        l.sort()
        for k in l:
            for _ in range(A[str(k)]):
                yield k
    

    Your code is great; just don't use sorted() and sort the list in place instead (the impact becomes more significant with bigger arrays). You should test it with bigger inputs (I used the sample you gave). [image omitted]

    This one includes only the winners of the previous comparison (plus solution 6, which is the second one you gave). It appears that the speed limit is set by the I/O of the program and not by the sorting itself. [image omitted]

    Note that I generated square inputs (number of lines == numbers per line).