Tags: python, multithreading, byte, ndjson

Calculate byte positions to split ndjson files into chunks


Below is the code extracted from this repo:

import os.path, io
filename = ""

n_chunks = 12  # Number of processes to use -- will split the file up into this many pieces

def find_newline_pos(f, n):
    # Walk backward one byte at a time until we land on a newline (or byte 0)
    f.seek(n)
    c = f.read(1)
    while c != b'\n' and n > 0:  # the file is opened in binary mode, so compare bytes
        n -= 1
        f.seek(n)
        c = f.read(1)
    return n

def prestart():
    fsize = os.path.getsize(filename)
    # Candidate split offsets at roughly equal spacing; drop the final offset
    initial_chunks = list(range(0, fsize, fsize // n_chunks))[:-1]
    with io.open(filename, 'rb') as f:
        # Snap each candidate offset back to the nearest preceding newline
        pieces = sorted(set(find_newline_pos(f, n) for n in initial_chunks))
    pieces.append(fsize)
    # (start, end) pairs; each start is one byte past its newline, except byte 0
    args = list(zip([x + 1 if x > 0 else x for x in pieces], pieces[1:]))
    return args

args = prestart()

The purpose of each part of the above snippet:

| Part               | Purpose                                                                               |
| ------------------ | --------------------------------------------------------------------------------------|
| `find_newline_pos` | Moves backward from a byte offset to find the previous newline (`\n`).                |
| `prestart`         | Splits the file into roughly equal chunks that align with newline positions.          |
| `args`             | The list of `(start, end)` byte positions for each chunk — ready for multiprocessing. |

For a 2 GB file, the code above takes more than 10 minutes to run.

Is there a more efficient way to determine the byte positions at which to split an ndjson file?
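
For context, these (start, end) pairs are meant to feed per-process workers; a minimal sketch of such a consumer is below (process_chunk and handle_record are hypothetical stand-ins, not from the repo):

from multiprocessing import Pool

def handle_record(line):
    pass  # hypothetical per-record work, e.g. json.loads(line)

def process_chunk(bounds):
    start, end = bounds
    with open(filename, 'rb') as f:
        f.seek(start)
        # Each chunk starts at a line boundary and ends at a newline,
        # so the slice splits cleanly into whole ndjson records
        for line in f.read(end - start).splitlines():
            handle_record(line)

if __name__ == '__main__':
    with Pool(n_chunks) as pool:
        pool.map(process_chunk, args)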


Solution

  • Memory-mapping the file avoids per-byte Python I/O overhead and lets you search for each newline with mmap's .find(), which scans forward in C

    import os.path
    import mmap
    filename = ""

    n_chunks = 12  # Number of processes to use -- will split the file up into this many pieces

    def find_newline_pos_fast(mm, n):
        # mmap.find scans forward in C, so there is no Python-level byte loop
        return mm.find(b'\n', n)

    def prestart():
        fsize = os.path.getsize(filename)
        chunk_size = fsize // n_chunks

        with open(filename, "rb") as f, \
             mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
            # Each interior boundary sits one byte past the first newline at or
            # after the target offset; compute every position exactly once
            positions = (find_newline_pos_fast(mm, chunk_size * i) for i in range(1, n_chunks))
            pieces = sorted(set(p + 1 for p in positions if p != -1))

        pieces.insert(0, 0)  # first chunk starts at byte 0
        if pieces[-1] != fsize:
            pieces.append(fsize)  # avoid a zero-length final chunk if a boundary hit EOF

        args = list(zip(pieces[:-1], pieces[1:]))
        return args

    args = prestart()
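
  • A quick sanity check on the result (a sketch; it re-reads just one byte per chunk):

    # Every chunk after the first should begin immediately after a newline,
    # and consecutive chunks should tile the file with no gaps or overlaps.
    with open(filename, "rb") as f:
        for start, end in args:
            if start > 0:
                f.seek(start - 1)
                assert f.read(1) == b"\n"
    assert args[0][0] == 0
    assert args[-1][1] == os.path.getsize(filename)
    assert all(prev_end == next_start
               for (_, prev_end), (next_start, _) in zip(args, args[1:]))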