Tags: algorithm, graph, tree, binary-search, minimum-spanning-tree

Constrained MST: Find a spanning tree of weight ≤ T that minimizes the number of red edges


We are given an undirected graph $G = (V, E)$, where each edge $e \in E$ has two attributes:

  • a weight $w(e)$;
  • a color, either red or blue.

In addition, we are given a threshold value $T$, which represents the maximum total weight allowed for our final solution.


🎯 Objective

Our goal is to construct a spanning tree of the graph $G$ that satisfies two conditions:

  1. The total weight of the spanning tree is less than or equal to $T$.
  2. Among all such spanning trees, we want to find one that uses the minimum possible number of red edges.

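Formally, we seek the smallest integer $k$ such that there exists a spanning tree $S$ of $G$ with:

$$\sum_{e \in S} w(e) \le T \quad \text{and} \quad \bigl|\{\, e \in S : e \text{ is red} \,\}\bigr| \le k.$$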


Sample Input:

Let the graph be:
Vertices: 4
Edges:
1 - 2 (weight = 1, color = blue)
2 - 3 (weight = 2, color = red)
3 - 4 (weight = 2, color = blue)
4 - 1 (weight = 3, color = red)
1 - 3 (weight = 4, color = blue)
Threshold T = 7


Expected Output:

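The minimal number of red edges needed in a spanning tree with total weight ≤ T is: 0

Example valid spanning tree: 1 - 2, 1 - 3, 3 - 4

Total weight = 1 + 4 + 2 = 7 ≤ T, red edges = 0 ✅

(The minimum spanning tree 1 - 2, 2 - 3, 3 - 4 is lighter, with total weight = 5, but it uses 1 red edge.)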
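On tiny inputs like this one, an exhaustive search over every set of n - 1 edges is a simple way to double-check expected outputs. The sketch below assumes the same `Edge` fields as the question's code (`u`, `v`, `w`, `red`); the function name `bruteForce` is hypothetical, there is no `main` (call it from your own), and it is exponential in m, so it is only meant as a test oracle. For the 4-vertex sample with T = 7 it returns 0 red edges at weight 7 (the all-blue tree 1 - 2, 1 - 3, 3 - 4), and for the failing input later in the post it returns 3 red edges at weight 8.

```cpp
#include <cassert>
#include <numeric>
#include <utility>
#include <vector>

// Same fields as the question's Edge: 1-indexed endpoints, weight, red flag (1 = red, 0 = blue).
struct Edge { int u, v; long long w; int red; };

// Hypothetical reference oracle for tiny graphs (m up to about 20): tries every
// subset of n - 1 edges. Returns {fewest red edges, lightest weight with that
// many reds} over spanning trees of weight <= T, or {-1, -1} if none exists.
std::pair<int, long long> bruteForce(int n, const std::vector<Edge>& e, long long T) {
    const int m = static_cast<int>(e.size());
    assert(m < 63);
    std::pair<int, long long> best{-1, -1};
    for (long long mask = 0; mask < (1LL << m); ++mask) {
        int chosen = 0;
        for (int i = 0; i < m; ++i) chosen += static_cast<int>((mask >> i) & 1);
        if (chosen != n - 1) continue;
        // Plain union-find without compression; graphs here are tiny.
        std::vector<int> parent(n + 1);
        std::iota(parent.begin(), parent.end(), 0);
        auto root = [&](int x) { while (parent[x] != x) x = parent[x]; return x; };
        bool acyclic = true;
        long long weight = 0;
        int red = 0;
        for (int i = 0; i < m; ++i) {
            if (!((mask >> i) & 1)) continue;
            int a = root(e[i].u), b = root(e[i].v);
            if (a == b) { acyclic = false; break; }
            parent[a] = b;
            weight += e[i].w;
            red += e[i].red;
        }
        // n - 1 edges with no cycle on n vertices always form a spanning tree.
        if (!acyclic || weight > T) continue;
        if (best.first == -1 || red < best.first || (red == best.first && weight < best.second))
            best = {red, weight};
    }
    return best;
}
```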


❓ Question:

How can I efficiently check, for a given k, whether a spanning tree with at most k red edges and total weight ≤ T exists?

Can this be done in near-linear time? Any insights or algorithmic strategies would be appreciated.

Constraints:


My Code:

#include <bits/stdc++.h>
using namespace std;

struct Edge {
    int u, v;
    long long w;
    int red;
};

struct DSU {
    vector<int> p, r;
    DSU(int n): p(n+1), r(n+1,0) {
        iota(p.begin(), p.end(), 0);
    }
    int find(int x) { return p[x]==x ? x : p[x]=find(p[x]); }
    bool unite(int a, int b) {
        a = find(a); b = find(b);
        if (a == b) return false;
        if (r[a] < r[b]) swap(a, b);
        p[b] = a;
        if (r[a] == r[b]) r[a]++;
        return true;
    }
};

int main() {
    ios::sync_with_stdio(false);
    cin.tie(NULL);

    int n, m;
    long long T;
    cin >> n >> m >> T;

    vector<Edge> edges(m);
    for (int i = 0; i < m; i++) {
        cin >> edges[i].u >> edges[i].v >> edges[i].w >> edges[i].red;
    }

    auto test = [&](long long lambda, long long &outW, int &outR) {
        vector<pair<long long,int>> order(m);
        for (int i = 0; i < m; i++) {
            order[i] = { edges[i].w + lambda * edges[i].red, i };
        }
        sort(order.begin(), order.end());

        DSU dsu(n);
        outW = 0; outR = 0;
        int cnt = 0;
        for (auto &p : order) {
            int i = p.second;
            if (dsu.unite(edges[i].u, edges[i].v)) {
                outW += edges[i].w;
                outR += edges[i].red;
                if (++cnt == n-1) break;
            }
        }
        return cnt == n-1;
    };

    long long lo = 0, hi = 1000000;
    int bestR = m;
    long long bestW = LLONG_MAX;

    while (lo <= hi) {
        long long mid = (lo + hi) / 2;
        long long w; int r;
        bool ok = test(mid, w, r);
        if (ok && w <= T) {
            if (r < bestR || (r == bestR && w < bestW)) {
                bestR = r;
                bestW = w;
            }
            lo = mid + 1;
        } else {
            hi = mid - 1;
        }
    }

    cout << bestR << "\n" << bestW << "\n";
    return 0;
}

Example where it fails:
5 # n = 5 vertices
6 # m = 6 edges
8 # T = 8 (max total weight)
u v w red?
1 2 1 1
2 3 1 1
3 4 1 1
4 5 1 1
1 5 5 0
2 4 5 0

This code outputs:
4 (red-edge count)
4 (total weight)
Expected output:
3 (red-edge count)
8 (total weight)
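Working the failing example by hand shows why searching over the penalty alone cannot return k = 3. Write $f(k)$ for the minimum weight of a spanning tree with exactly $k$ red edges and $r(S)$ for the number of red edges in a tree $S$. Only the two blue edges 1 - 5 and 2 - 4 exist, so $k \ge 2$:

$$
\begin{aligned}
f(4) &= 1 + 1 + 1 + 1 = 4, \\
f(3) &= 1 + 1 + 1 + 5 = 8, \\
f(2) &= 1 + 1 + 5 + 5 = 12, \\
f(k) + 4k &= 20 \quad \text{for every } k \in \{2, 3, 4\}.
\end{aligned}
$$

Because $f$ drops by exactly 4 per extra red edge here, the penalized Kruskal in `test` picks the all-red tree ($r = 4$, weight 4) for every penalty $\lambda < 4$ and a tree with both blue edges ($r = 2$, weight 12 > T) for every $\lambda > 4$. At $\lambda = 4$ all three counts tie, so the result depends only on the sort's tie-break (the `pair` comparison puts the lower-indexed red edges first, giving $r = 4$ again). The count $r = 3$ is never produced, even though it is the answer. It can still be recovered by interpolating at the tie: $f(3) = \min_S \bigl(w(S) + 4\,r(S)\bigr) - 4 \cdot 3 = 20 - 12 = 8$.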


Solution

  • You can instead binary search over the number of red edges used. On each iteration, check if it is possible to create a spanning tree of weight at most T using no more than a certain number of red edges.

    #include <vector>
    #include <iostream>
    #include <numeric>
    #include <algorithm>
    
    // ...
    
    int main() {
        int n, m;
        long long T;  // weights are long long in the question's Edge struct
        std::cin >> n >> m >> T;
        std::vector<Edge> edges(m);
        for (auto& edge : edges) std::cin >> edge.u >> edge.v >> edge.w >> edge.red;
        std::ranges::sort(edges, {}, &Edge::w);  // Kruskal order: lightest edges first
        int low = 0, high = n, ansRedUsed = -1;
        long long ansWeight = -1;
        // Binary search on the red-edge budget `mid`.
        while (low <= high) {
            int mid = std::midpoint(low, high), redUsed = 0, edgesUsed = 0;
            long long totalWeight = 0;
            DSU dsu(n);
            // Kruskal that skips a red edge once the budget `mid` is used up.
            for (auto& edge : edges)
                if (redUsed + edge.red <= mid && dsu.unite(edge.u, edge.v))
                    redUsed += edge.red, totalWeight += edge.w, ++edgesUsed;
            // Treat the budget as feasible if a spanning tree was built and fits under T.
            if (edgesUsed == n - 1 && totalWeight <= T) {
                ansRedUsed = redUsed;
                ansWeight = totalWeight;
                high = mid - 1;  // try a smaller budget
            } else low = mid + 1;
        }
        std::cout << ansRedUsed << '\n' << ansWeight << '\n';
    }
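One caveat on the greedy check above: it is not exact in general. With n = 4, edges 1 - 2 (w = 1, red), 2 - 3 (w = 2, blue), 1 - 3 (w = 3, blue), 3 - 4 (w = 10, red) and T = 15, a budget of 1 is spent on the cheap red edge 1 - 2, after which the bridge 3 - 4 can no longer be added, so the search reports 2 red edges (weight 13) although 2 - 3, 1 - 3, 3 - 4 uses 1 red edge at weight 15. A sketch of an exact check, still binary searching over the number of red edges, swaps in Lagrangian relaxation with interpolation (often called WQS binary search or the "Aliens trick"). It assumes non-negative integer weights and relies on the known, but not proven here, fact that $f(k)$, the minimum weight of a spanning tree with exactly $k$ red edges, is convex in $k$. The names `penalizedMST`, `minWeightExactly` and `solve` are hypothetical, and there is no `main`. It is not near-linear: roughly O(m log m · log W · log n) with W the largest weight.

```cpp
#include <algorithm>
#include <cassert>
#include <numeric>
#include <utility>
#include <vector>

// Same fields as the question's Edge: 1-indexed endpoints, weight, red flag (1 = red, 0 = blue).
struct Edge { int u, v; long long w; int red; };

// Minimal union-find (path compression only), 1-indexed.
struct DSU {
    std::vector<int> p;
    explicit DSU(int n) : p(n + 1) { std::iota(p.begin(), p.end(), 0); }
    int find(int x) { return p[x] == x ? x : p[x] = find(p[x]); }
    bool unite(int a, int b) {
        a = find(a); b = find(b);
        if (a == b) return false;
        p[b] = a;
        return true;
    }
};

// Kruskal on the penalized key w + lambda * red. On equal keys blue edges come
// first, so among all trees minimizing the penalized sum this one has the
// fewest red edges. Returns {penalized sum, red count}; red = -1 if disconnected.
std::pair<long long, int> penalizedMST(int n, const std::vector<Edge>& e, long long lambda) {
    auto key = [&](int i) { return e[i].w + lambda * e[i].red; };
    std::vector<int> idx(e.size());
    std::iota(idx.begin(), idx.end(), 0);
    std::sort(idx.begin(), idx.end(), [&](int a, int b) {
        return key(a) != key(b) ? key(a) < key(b) : e[a].red < e[b].red;
    });
    DSU dsu(n);
    long long sum = 0;
    int red = 0, used = 0;
    for (int i : idx)
        if (dsu.unite(e[i].u, e[i].v)) { sum += key(i); red += e[i].red; ++used; }
    return {sum, used == n - 1 ? red : -1};
}

// f(k) for k between the fewest and the most reds any spanning tree can have:
// find the smallest integer lambda whose blue-first tree has <= k reds, then
// interpolate f(k) = (penalized optimum at lambda) - lambda * k.
long long minWeightExactly(int n, const std::vector<Edge>& e, int k, long long maxW) {
    assert(k >= 0);
    long long lo = -maxW - 1, hi = maxW + 1;
    while (lo < hi) {
        long long mid = lo + (hi - lo) / 2;
        if (penalizedMST(n, e, mid).second <= k) hi = mid; else lo = mid + 1;
    }
    return penalizedMST(n, e, lo).first - lo * k;
}

// Fewest red edges in a spanning tree of weight <= T, plus the smallest weight
// with that many reds; {-1, -1} if G is disconnected or even the MST exceeds T.
std::pair<int, long long> solve(int n, const std::vector<Edge>& e, long long T) {
    long long maxW = 0;
    for (const auto& x : e) maxW = std::max(maxW, x.w);
    auto [mstWeight, r0] = penalizedMST(n, e, 0);  // lambda = 0: a plain MST
    if (r0 < 0 || mstWeight > T) return {-1, -1};
    // lambda = maxW + 1 puts every blue edge before every red one: fewest reds.
    int lo = penalizedMST(n, e, maxW + 1).second, hi = r0;
    // Assuming f is convex, it strictly decreases on [lo, r0], so binary search applies.
    while (lo < hi) {
        int mid = lo + (hi - lo) / 2;
        if (minWeightExactly(n, e, mid, maxW) <= T) hi = mid; else lo = mid + 1;
    }
    return {lo, minWeightExactly(n, e, lo, maxW)};
}
```

On the failing input from the question, `solve` returns {3, 8}; on the 4-vertex sample with T = 7 it returns {0, 7}; and on the counterexample above it returns {1, 15}. The blue-first tie-break is the detail that makes the interpolation well defined: it always reports the smallest red count among the penalty-optimal trees, so the search over lambda is monotone.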