
Optimizing Python code to run in under 4 seconds


I have written some code to solve an online problem (here).

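mustConstraint = set()   # pairs of words that must share a group
notConstraint = set()    # pairs of words that must not share a group
violated = 0
satisfied = 0

# Read the X "must be together" pairs.
for i in range(0, int(input(''))):
    constraint = input('')
    mustConstraint.add(frozenset(constraint.split()))

# Read the Y "must not be together" pairs.
for i in range(0, int(input(''))):
    constraint = input('')
    notConstraint.add(frozenset(constraint.split()))

# Read the G groups and test every constraint against each group.
for i in range(0, int(input(''))):
    group = input('')
    group = set(group.split())

    for x in mustConstraint:
        if x & group == x:  # both words of x are in this group
            satisfied +=1

    for y in notConstraint:
        if y & group == y:
            violated += 1

# Every "must" pair that was never found together is also a violation.
violated += len(mustConstraint) - satisfied

print(violated)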

To summarize the problem: the first line of input contains a non-negative integer X (X >= 0). Each of the following X lines contains two words separated by a space; these two words must be in the same group.

The next line contains a non-negative integer Y (Y >= 0). Each of the following Y lines contains two words separated by a space; these two words must not be in the same group.

The next line contains a positive integer G (G >= 1). Each of the following G lines consists of three different words separated by single spaces; these three words have been placed in the same group.

Output a single integer between 0 and X+Y: the number of constraints that are violated.
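A direct reading of the statement may help make the format concrete: record which group each word ends up in, then test each constraint with two dictionary lookups. This is an illustrative sketch, not the original poster's or the answerer's code. The function name `count_violations_by_group` and the sample input are made up (the sample is not one of the official test cases), and the sketch assumes every word named in a constraint appears in exactly one group.

```python
def count_violations_by_group(lines):
    """Count violated constraints by recording which group each word is in.

    Hypothetical helper. Assumes every word named in a constraint appears
    in exactly one group.
    """
    it = iter(lines)
    must = [next(it).split() for _ in range(int(next(it)))]
    must_not = [next(it).split() for _ in range(int(next(it)))]

    group_of = {}  # word -> index of the group it was placed in
    for g in range(int(next(it))):
        for word in next(it).split():
            group_of[word] = g

    violated = 0
    for a, b in must:
        # A "must" pair is violated when its words are in different groups.
        if group_of[a] != group_of[b]:
            violated += 1
    for a, b in must_not:
        # A "must not" pair is violated when its words share a group.
        if group_of[a] == group_of[b]:
            violated += 1
    return violated


# Hypothetical input: X = 2 must pairs, Y = 1 must-not pair, G = 2 groups.
sample = ["2", "A B", "C D", "1", "A C", "2", "A B C", "D E F"]
print(count_violations_by_group(sample))  # → 2 (C/D split apart, A/C together)
```

Here "C D" is violated because C and D land in different groups, and "A C" is violated because A and C share a group, giving 2.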

I highly recommend you visit the problem site here, since the problem is much easier to understand with the sample cases provided and their explanations.

Unfortunately, since the last batch has ~300,000 inputs, my nested for loops are far too slow and fail to meet the 4-second time limit. Can someone help me optimize my code?
[Image: execution results]

Most of the delay comes from this block:

for i in range(0, int(input(''))):
    group = input('')
    group = set(group.split())

    for x in mustConstraint:
        if x & group == x:
            satisfied +=1
        
    for y in notConstraint:
        if y & group == y:
            violated += 1

Each of the two nested loops performs 100,000^2 iterations, for a total of 20 billion iterations (2 * 100,000^2) in that code block.

If someone could find a way to reduce the number of iterations, it would make a substantial difference.
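A quick back-of-the-envelope comparison, assuming X = Y = G = 100,000 (inferred from the ~300,000 inputs mentioned above, not stated exactly in the problem): the nested loops test every group against every constraint, whereas a group of three words contains only three word pairs that any constraint could refer to.

```python
# Assumed input sizes (inferred from the question, not exact limits).
X = Y = G = 100_000

# Original approach: every group is tested against every constraint.
nested_iterations = G * X + G * Y

# Pair-lookup idea: each group of three words has exactly three pairs,
# and each pair needs one hash lookup per constraint set.
pair_lookups = G * 3 * 2

print(f"{nested_iterations:,}")  # 20,000,000,000
print(f"{pair_lookups:,}")       # 600,000
```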


Solution

  • Here are two solutions; both easily pass all tests.

    Solution 1: Minimal change

    Replace your large inner loops with the tiny loop over the group's three pairs, i.e., change

        for x in mustConstraint:
            if x & group == x:
                satisfied +=1
            
        for y in notConstraint:
            if y & group == y:
                violated += 1
    

    to this:

        a, b, c = group
        for pair in {a, b}, {a, c}, {b, c}:
            if pair in mustConstraint:
                satisfied += 1
            if pair in notConstraint:
                violated += 1
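For reference, here is Solution 1 assembled into a complete function. It is an adaptation, not the answer's verbatim code: the name `count_violations_solution1` is hypothetical, input is taken from a list of lines so it can be tested without stdin, and the pairs are built as frozensets explicitly.

```python
def count_violations_solution1(lines):
    """Solution 1, reading from a list of lines instead of input()."""
    it = iter(lines)
    must_constraint = {frozenset(next(it).split()) for _ in range(int(next(it)))}
    not_constraint = {frozenset(next(it).split()) for _ in range(int(next(it)))}

    satisfied = violated = 0
    for _ in range(int(next(it))):
        a, b, c = next(it).split()  # each group has exactly three words
        # Only the three pairs inside this group can satisfy or violate anything.
        for pair in frozenset((a, b)), frozenset((a, c)), frozenset((b, c)):
            if pair in must_constraint:
                satisfied += 1
            if pair in not_constraint:
                violated += 1

    # Every "must" pair that was never found inside a group is violated.
    return violated + len(must_constraint) - satisfied


sample = ["2", "A B", "C D", "1", "A C", "2", "A B C", "D E F"]
print(count_violations_solution1(sample))  # → 2
```

The answer's version looks up plain sets such as `{a, b}` in a set of frozensets. That works because `set.__contains__` accepts a set argument and searches for an equivalent frozenset; building frozensets directly just avoids that implicit conversion.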
    

    Solution 2: Just two set operations

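    All three sections of the input specify a set of pairs (each group line contributes its three pairs), so it helps to put the parsing into a function. Counting the violations then takes just two set operations: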

    from itertools import combinations
    
    def pairs():
        return {
            frozenset(pair)
            for _ in range(int(input()))
            for pair in combinations(input().split(), 2)
        }
    
    must = pairs()
    must_not = pairs()
    are = pairs()
    
    print(len(must - are) + len(must_not & are))
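Solution 2 can be tested the same way by threading one shared iterator through the three `pairs()` calls instead of calling `input()`. This is a sketch adapted from the answer; the `it` parameter and the name `count_violations_solution2` are hypothetical additions.

```python
from itertools import combinations


def pairs(it):
    """Read one section (a count line, then that many lines of words) from
    the iterator `it` and return every unordered word pair on those lines."""
    return {
        frozenset(pair)
        for _ in range(int(next(it)))
        for pair in combinations(next(it).split(), 2)
    }


def count_violations_solution2(lines):
    it = iter(lines)  # shared, so each pairs() call resumes where the last stopped
    must = pairs(it)
    must_not = pairs(it)
    are = pairs(it)  # pairs of words that actually share a group
    # Required pairs that are not together + forbidden pairs that are together.
    return len(must - are) + len(must_not & are)


sample = ["2", "A B", "C D", "1", "A C", "2", "A B C", "D E F"]
print(count_violations_solution2(sample))  # → 2
```

Because the three calls consume the same iterator, they read the three sections in order, just as the `input()`-based version does. To feed it from standard input, `sys.stdin.read().splitlines()` can be passed as `lines`.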