
Why does modifying the returned pair from a recursive call break my solution to Binary Tree Longest Consecutive Sequence II?


I'm trying to solve the Binary Tree Longest Consecutive Sequence II problem (LeetCode 549). The goal is to find the length of the longest consecutive path in a binary tree, where the values along the path either increase by exactly 1 at every step or decrease by exactly 1 at every step. The path can go in both directions (child → parent → child), not just downward.
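To make "consecutive path" concrete: if you list the node values in path order, every step must be +1, or every step must be -1. The helper below is a hypothetical illustration of that rule only (the name `isConsecutivePath` is made up, and it is not part of either solution).

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: true if the values, listed in path order, form a
// consecutive run (every step is +1, or every step is -1).
bool isConsecutivePath(const std::vector<int>& values) {
    if (values.size() < 2) return true;           // a single node is a path of length 1
    const int step = values[1] - values[0];
    if (step != 1 && step != -1) return false;    // the first step must be +1 or -1
    for (std::size_t i = 2; i < values.size(); ++i) {
        if (values[i] - values[i - 1] != step) {  // the direction must never change
            return false;
        }
    }
    return true;
}
```

For a root 2 with children 1 and 3, the child → parent → child path 1, 2, 3 qualifies, while 1, 2, 1 does not because the direction flips. The chain 3, 4, 1, 2 from the reproducible case in Edit 2 is not consecutive as a whole; only its pieces 3, 4 and 1, 2 are.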

I initially implemented the following solution using recursion. Here's the code that does NOT work:

class Solution {
public:
    int maxLen = 0;

    pair<int, int> toNode(TreeNode *node) {
        if (!node) return {0, 0};

        pair<int, int> linc(0, 0), rdec(0, 0);

        if (node->left) {
            linc = toNode(node->left);
            if (node->left->val - 1 == node->val) {
                linc.second++;
            }
            if (node->left->val + 1 == node->val) {
                linc.first++;
            }
        }

        if (node->right) {
            rdec = toNode(node->right);
            if (node->right->val - 1 == node->val) {
                rdec.second++;
            }
            if (node->right->val + 1 == node->val) {
                rdec.first++;
            }
        }

        int inc = max(1, max(linc.first, rdec.first));
        int dec = max(1, max(linc.second, rdec.second));
        maxLen = max(maxLen, inc + dec - 1);
        return {inc, dec};
    }

    int longestConsecutive(TreeNode* root) {
        toNode(root);
        return maxLen;
    }
};

This version doesn't give the correct result for all test cases. However, when I refactor it as shown below, computing the increasing and decreasing counts in separate variables instead of mutating the values returned by the recursive calls, it works fine:

class Solution {
public:
    int maxLen = 0;
    pair<int, int> toNode(TreeNode *node){
        if(!node)return {1, 1};
        pair<int, int>linc(0, 0), rdec(0, 0);

        int lic=0, lid=0, ric=0, rid=0;

        if(node->left){
            linc = toNode(node->left);
            if(node->left->val == node->val-1)lid = linc.second+1;
            if(node->left->val == node->val+1)lic = linc.first+1;
        }
       
        if(node->right){
            rdec = toNode(node->right);
            if(node->right->val == node->val-1)rid = rdec.second+1;
            if(node->right->val == node->val+1)ric = rdec.first+1;
        }
        
        int inc = max(1, max(lic , ric));
        int dec = max(1, max(lid, rid));
        
        int cur = inc + dec - 1;
        
        maxLen = max(cur, maxLen);
        return {inc, dec};

    }
    int longestConsecutive(TreeNode* root) {
        toNode(root);
        return maxLen;
    }
};

Question: Why does the first version fail, even though I only increment the pair values (linc.first++ or linc.second++) when a condition holds? I don't use the returned pair anywhere else, so I expected this mutation to be safe.

Edit: I've now rewritten the code in a way that looks nearly identical to the original version.

Edit 2 (Reproducible Case):

[3, null, 4, null, 1, null, 2]  

Expected Outcome: 2
My code returns: 3
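To reproduce this outside LeetCode, the sketch below builds the tree from Edit 2 and runs the first version's logic. Assumptions: `TreeNode` mirrors LeetCode's usual definition, and `buildTree`, `NIL`, `extend` and the `reset` flag are hypothetical names introduced only for this harness. With `reset == false`, a child whose value is not consecutive with its parent still passes its counts up unchanged, exactly as in the first version; with `reset == true`, those counts are zeroed instead.

```cpp
#include <algorithm>
#include <cassert>
#include <climits>
#include <cstddef>
#include <queue>
#include <utility>
#include <vector>

// Assumed to match LeetCode's TreeNode.
struct TreeNode {
    int val;
    TreeNode *left;
    TreeNode *right;
    explicit TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
};

const int NIL = INT_MIN;  // hypothetical marker for "null" in level-order input

// Hypothetical helper: builds a tree from LeetCode's level-order format.
// Nodes are intentionally leaked to keep the sketch short.
TreeNode* buildTree(const std::vector<int>& vals) {
    if (vals.empty() || vals[0] == NIL) return nullptr;
    TreeNode* root = new TreeNode(vals[0]);
    std::queue<TreeNode*> q;
    q.push(root);
    std::size_t i = 1;
    while (!q.empty() && i < vals.size()) {
        TreeNode* cur = q.front();
        q.pop();
        if (i < vals.size() && vals[i] != NIL) { cur->left = new TreeNode(vals[i]); q.push(cur->left); }
        ++i;
        if (i < vals.size() && vals[i] != NIL) { cur->right = new TreeNode(vals[i]); q.push(cur->right); }
        ++i;
    }
    return root;
}

// Same conditions as the first version. With reset == false a broken chain
// keeps the child's stale count; with reset == true it is zeroed.
std::pair<int, int> extend(const TreeNode* parent, const TreeNode* child,
                           std::pair<int, int> p, bool reset) {
    if (child->val - 1 == parent->val) ++p.second; else if (reset) p.second = 0;
    if (child->val + 1 == parent->val) ++p.first;  else if (reset) p.first = 0;
    return p;
}

std::pair<int, int> toNode(TreeNode* node, bool reset, int& maxLen) {
    if (!node) return {0, 0};
    std::pair<int, int> l(0, 0), r(0, 0);
    if (node->left)  l = extend(node, node->left,  toNode(node->left,  reset, maxLen), reset);
    if (node->right) r = extend(node, node->right, toNode(node->right, reset, maxLen), reset);
    int inc = std::max(1, std::max(l.first, r.first));
    int dec = std::max(1, std::max(l.second, r.second));
    maxLen = std::max(maxLen, inc + dec - 1);
    return {inc, dec};
}

int longestConsecutive(TreeNode* root, bool reset) {
    int maxLen = 0;
    toNode(root, reset, maxLen);
    return maxLen;
}
```

On the Edit 2 tree, `longestConsecutive(buildTree({3, NIL, 4, NIL, 1, NIL, 2}), false)` reproduces the wrong answer 3, while passing `true` gives the expected 2.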


Solution

  • When an inner if condition is false, your faulty version keeps linc (or rdec) exactly as the recursive call returned it. That is not what the correct version does: there, when the inner condition is false, the corresponding first/second component returned from the recursion is ignored entirely, and the zero-initialised variable (lic, lid, ric or rid) is used instead. This is how it should be: when the parent and child values don't have the required relation, the corresponding increasing/decreasing chain is broken, and you must start counting again from scratch. In the tree from Edit 2, for example, node 1 reports a chain of length 2 (1, 2); node 4 is not consecutive with 1, yet passes that 2 up unchanged, so node 3 extends it to 3.

    It may help to look at this fix of the if statements in the faulty version, where else clauses reset the relevant count to zero:

            if (node->left) {
                linc = toNode(node->left);
                if (node->left->val - 1 == node->val) {
                    linc.second++;
                } else {
                    linc.second = 0; // Reset length!
                }
                if (node->left->val + 1 == node->val) {
                    linc.first++;
                } else {
                    linc.first = 0; // Reset
                }
            }
    
            if (node->right) {
                rdec = toNode(node->right);
                if (node->right->val - 1 == node->val) {
                    rdec.second++;
                } else {
                    rdec.second = 0; // Reset
                }
                if (node->right->val + 1 == node->val) {
                    rdec.first++;
                } else {
                    rdec.first = 0; // Reset
                }
            }
    

    This is not a bug, but in both versions of the code you could use 1 as both the default (initialisation) value and the reset value. That way you don't have to wrap the max call in another max(1, ...) call.
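Following that last suggestion, here is a sketch of the corrected first version with 1 as both the initial and the reset value, so the extra max(1, ...) disappears. Resetting to 1 rather than 0 works because a node on its own is already a chain of length 1. It assumes a stand-in for LeetCode's `TreeNode` (the three-argument constructor is only there to make test trees easy to write) and keeps the question's pair convention: `second` is extended when the child's value is the parent's plus 1, `first` when it is the parent's minus 1.

```cpp
#include <algorithm>
#include <cassert>
#include <utility>

// Stand-in for LeetCode's TreeNode (assumed shape).
struct TreeNode {
    int val;
    TreeNode *left;
    TreeNode *right;
    explicit TreeNode(int x, TreeNode* l = nullptr, TreeNode* r = nullptr)
        : val(x), left(l), right(r) {}
};

class Solution {
public:
    int maxLen = 0;

    std::pair<int, int> toNode(TreeNode* node) {
        std::pair<int, int> linc(1, 1), rdec(1, 1);  // 1 = the node on its own

        if (node->left) {
            linc = toNode(node->left);
            // Extend a chain when the values are consecutive, otherwise restart at 1.
            linc.second = (node->left->val - 1 == node->val) ? linc.second + 1 : 1;
            linc.first  = (node->left->val + 1 == node->val) ? linc.first + 1 : 1;
        }
        if (node->right) {
            rdec = toNode(node->right);
            rdec.second = (node->right->val - 1 == node->val) ? rdec.second + 1 : 1;
            rdec.first  = (node->right->val + 1 == node->val) ? rdec.first + 1 : 1;
        }

        // Every count is already >= 1, so no extra max(1, ...) is needed.
        int inc = std::max(linc.first, rdec.first);
        int dec = std::max(linc.second, rdec.second);
        maxLen = std::max(maxLen, inc + dec - 1);
        return {inc, dec};
    }

    int longestConsecutive(TreeNode* root) {
        if (root) toNode(root);
        return maxLen;
    }
};
```

Since maxLen is a member that accumulates across calls, each test tree should get a fresh Solution object, which is also how LeetCode runs it.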