You are given an unweighted tree with N nodes, and each node starts with value 0. There will be Q queries. The first type of query changes the value at a node to a new value. The second type of query asks for the sum of the values of the nodes on the simple path between two given nodes.
The value of any node is always between 0 and 10^9 and N, Q are between 1 and 10^5. Offline query answering is allowed.
I tried storing all node values in an array and running a DFS to find each path sum, but that takes O(N) time per query and exceeds the 1 s time limit.
#include <vector>
#include <cstring>
#include <iostream>
using namespace std;

const int MAX_INPUT = 1e5 + 5;
vector<int> graph[MAX_INPUT];
int values[MAX_INPUT];
bool seen[MAX_INPUT];

// Returns the sum of values on the path from `from` to `to`,
// or -1 if `to` cannot be reached through unvisited nodes.
long long getSum(int from, int to) {
    if (seen[from]) {
        return -1;
    }
    seen[from] = true;
    if (from == to) {
        return values[from];
    }
    for (int n : graph[from]) {
        // long long, not int: a path sum can reach 10^5 * 10^9
        long long nsum = getSum(n, to);
        if (nsum >= 0) {
            return values[from] + nsum;
        }
    }
    return -1;
}

int main() {
    ios::sync_with_stdio(0); cin.tie(0);
    int N;
    cin >> N;
    for (int i = 0; i < N - 1; i++) {
        int a, b;
        cin >> a >> b;
        graph[a].push_back(b);
        graph[b].push_back(a);
    }
    int Q;
    cin >> Q;
    for (int q = 0; q < Q; q++) {
        int type;
        cin >> type;
        if (type == 1) {
            int node, newvalue;
            cin >> node >> newvalue;
            values[node] = newvalue;
        } else {
            int node1, node2;
            cin >> node1 >> node2;
            memset(seen, 0, N + 1);  // reset visited flags for nodes 0..N
            long long pathsum = getSum(node1, node2);
            cout << pathsum << '\n';
        }
    }
}
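The input format is: the first line contains N; each of the next N - 1 lines contains two nodes joined by an edge; the next line contains Q; each of the following Q lines is either "1 node newvalue" (set the value at node) or "2 node1 node2" (print the path sum between node1 and node2).
An example input is: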
5
1 2
1 5
2 3
2 4
5
1 2 10
2 3 4
1 1 5
2 2 5
2 4 3
and the corresponding output is:
10
15
10
What is a more efficient way to solve these queries?
Root the tree arbitrarily and perform a DFS traversal. During the traversal, record a start time and a finish time for each node using an integer counter: the counter is incremented when a node is entered, giving its start time, and the finish time is the value of the counter when we leave the node, which is the largest start time in its subtree. All the ancestors of a node have a lower start time, as they are encountered first in the DFS. Furthermore, the start times of the nodes in any subtree form a consecutive range of integers, from the start time of the subtree root to its finish time.
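As a minimal sketch of this numbering, the following records start and finish times with a single counter. The struct EulerTour and its member names are hypothetical and chosen only for illustration, and the tree is assumed to use 1-based node labels.

```cpp
#include <vector>

// Hypothetical helper for illustration: DFS start/finish times, 1-based nodes.
struct EulerTour {
    std::vector<std::vector<int>> adj;
    std::vector<int> start, finish;
    int timer = 0;

    explicit EulerTour(int n) : adj(n + 1), start(n + 1), finish(n + 1) {}

    void addEdge(int a, int b) {
        adj[a].push_back(b);
        adj[b].push_back(a);
    }

    // Pass a parent value that is not a real node (e.g. 0) for the root.
    void dfs(int node, int parent) {
        start[node] = ++timer;   // counter value when the node is entered
        for (int next : adj[node])
            if (next != parent) dfs(next, node);
        finish[node] = timer;    // largest start time inside this subtree
    }

    // True when b lies in the subtree of a (a itself included).
    bool isAncestorOrSelf(int a, int b) const {
        return start[a] <= start[b] && start[b] <= finish[a];
    }
};
```

On the example tree from the question, rooted at node 1 with edges added in input order, node 2 receives start time 2 and finish time 4, so its subtree (nodes 2, 3 and 4) occupies positions 2 to 4.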
We can build a binary indexed tree (which enables logarithmic point updates and prefix sum queries) using these start times as positions. The binary indexed tree initially contains all 0s and we will maintain the fact that the prefix sum from position 1 to the start time of any node is equal to the sum of values along the path from the root to that node. When a node's value is updated, only the nodes in its subtree are affected as only the paths from the root to those nodes would pass through the updated node. As such, we can increment the value at the start time of the updated node in the binary indexed tree by the difference between the new value and the previous value, and decrement by the same difference at the position after the finish time of the node. Due to the consecutive start times of subtree nodes, this ensures that no nodes outside the updated subtree are changed, in much the same way as with a difference array.
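The difference-array idea can be sketched in isolation as follows. The struct DiffFenwick and its method names are hypothetical, not part of any library; positions are assumed to be 1-based start times.

```cpp
#include <vector>

// Hypothetical Fenwick tree storing a difference array (1-based positions).
struct DiffFenwick {
    int size;
    std::vector<long long> tree;

    explicit DiffFenwick(int n) : size(n), tree(n + 1, 0) {}

    void pointAdd(int i, long long delta) {
        for (; i <= size; i += i & -i) tree[i] += delta;
    }

    // Adds delta to every position in [l, r] using two point updates.
    void rangeAdd(int l, int r, long long delta) {
        pointAdd(l, delta);
        pointAdd(r + 1, -delta);  // does nothing when r == size
    }

    // Prefix sum of the differences, i.e. the current value at position i.
    long long valueAt(int i) const {
        long long res = 0;
        for (; i > 0; i -= i & -i) res += tree[i];
        return res;
    }
};
```

With the example tree rooted at node 1, setting node 2 to 10 corresponds to rangeAdd(2, 4, 10) and setting node 1 to 5 corresponds to rangeAdd(1, 5, 5). Afterwards valueAt(3) gives 15, the sum of values on the path from the root to node 3.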
For any two tree nodes, the path between them can be expressed as the path from one node to the lowest common ancestor (LCA) and the path from the LCA to the other node. If we now add the sum of values from the root to each of the two nodes (by querying the binary indexed tree at the start times of the nodes), the path from the root to the LCA is counted twice (but the ancestors of the LCA are not on the path between the two nodes). To rectify this, we can subtract out that path sum twice and then manually add the value at the LCA back to arrive at the sum of values on the path between the two nodes.
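Written as a formula, with notation introduced here only for illustration (P(x) for the sum of values from the root to x, val(x) for the value at x, and w for the LCA), the argument above reads as follows. The worked line uses the last query of the example input, assuming the tree is rooted at node 1; \operatorname requires amsmath.

```latex
\[
\operatorname{pathsum}(u, v) = P(u) + P(v) - 2\,P(w) + \operatorname{val}(w),
\qquad w = \operatorname{LCA}(u, v)
\]
% Example: u = 4, v = 3 after the updates val(2) = 10 and val(1) = 5.
% Then w = 2 and P(4) = P(3) = P(2) = 5 + 10 = 15.
\[
\operatorname{pathsum}(4, 3) = 15 + 15 - 2 \cdot 15 + 10 = 10
\]
```

This agrees with the last line of the expected output.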
In order to quickly compute the LCA, we can use binary lifting (although there are more efficient ways to do this, this method is simple to implement). During the DFS, we will additionally keep track of all the ancestors of each node with the difference in depth being a power of two. This information can be used to quickly jump to the LCA in O(log N) steps.
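The jump table itself can be sketched separately from the DFS. This version builds the table from a parent array instead; the struct Lifting and its members are hypothetical names, and it assumes parent[root] == root so that jumps past the root stay at the root.

```cpp
#include <vector>

// Hypothetical binary-lifting table: up[j][v] is the 2^j-th ancestor of v.
// Assumes parent[root] == root.
struct Lifting {
    int levels;
    std::vector<std::vector<int>> up;

    explicit Lifting(const std::vector<int>& parent) : levels(1) {
        int n = static_cast<int>(parent.size());
        while ((1 << levels) < n) ++levels;
        up.assign(levels, parent);  // level 0 holds the direct parents
        for (int j = 1; j < levels; ++j)
            for (int v = 0; v < n; ++v)
                up[j][v] = up[j - 1][up[j - 1][v]];  // two jumps of 2^(j-1)
    }

    // Ancestor k edges above v; assumes 0 <= k < parent.size().
    int kthAncestor(int v, int k) const {
        for (int j = 0; j < levels; ++j)
            if ((k >> j) & 1) v = up[j][v];
        return v;
    }
};
```

For the example tree rooted at node 1, the parent array {0, 1, 1, 2, 2, 1} (index 0 is an unused dummy) gives kthAncestor(3, 1) == 2 and kthAncestor(3, 2) == 1. An LCA query uses the same table: first lift the deeper node to the depth of the other, then lift both together while their ancestors differ.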
The time complexity is O((N + Q)log(N)).
#include <iostream>
#include <vector>
#include <bit>        // std::bit_width (C++20)
#include <algorithm>

const unsigned MAX_INPUT = 1e5 + 5;
std::vector<int> graph[MAX_INPUT];
int values[MAX_INPUT], ancestor[MAX_INPUT][std::bit_width(MAX_INPUT)], depth[MAX_INPUT],
    start[MAX_INPUT], finish[MAX_INPUT], N, Q, lg, visitTime;
long long bit[MAX_INPUT];

// Records depth, binary-lifting ancestors and start/finish times of every node.
void dfs(int node, int parent) {
    depth[node] = depth[parent] + 1;
    ancestor[node][0] = parent;
    for (int i = 1; i <= lg; ++i) ancestor[node][i] = ancestor[ancestor[node][i - 1]][i - 1];
    start[node] = ++visitTime;
    for (int next : graph[node])
        if (next != parent)
            dfs(next, node);
    finish[node] = visitTime;  // largest start time in this subtree
}

int LCA(int node1, int node2) {
    if (depth[node1] < depth[node2]) std::swap(node1, node2);
    // Lift the deeper node to the depth of the other one.
    for (int i = lg; i >= 0; --i)
        if (depth[ancestor[node1][i]] >= depth[node2])
            node1 = ancestor[node1][i];
    if (node1 == node2) return node1;
    // Lift both nodes together until they sit just below the LCA.
    for (int i = lg; i >= 0; --i)
        if (ancestor[node1][i] != ancestor[node2][i])
            node1 = ancestor[node1][i], node2 = ancestor[node2][i];
    return ancestor[node1][0];
}

// Point update on the binary indexed tree.
void update(int i, int delta) {
    for (; i <= N; i += i & -i) bit[i] += delta;
}

// Prefix sum over positions 1..i.
long long query(int i) {
    long long res{};
    for (; i; i &= i - 1) res += bit[i];
    return res;
}

int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cin >> N;
    lg = std::bit_width(unsigned(N)) - 1;
    for (int i = 1, a, b; i < N; ++i) {
        std::cin >> a >> b;
        graph[a].push_back(b);
        graph[b].push_back(a);
    }
    dfs(1, 1);  // the root is its own parent
    std::cin >> Q;
    for (int i = 0, type; i < Q; ++i) {
        std::cin >> type;
        if (type == 1) {
            int node, newValue;
            std::cin >> node >> newValue;
            // Add the change to the whole subtree range [start, finish].
            update(start[node], newValue - values[node]);
            update(finish[node] + 1, values[node] - newValue);
            values[node] = newValue;
        } else {
            int node1, node2;
            std::cin >> node1 >> node2;
            int lca = LCA(node1, node2);
            std::cout << query(start[node1]) + query(start[node2]) - 2 * query(start[lca]) + values[lca] << '\n';
        }
    }
}