首先说说重链序:
第一次dfs,算出dep[]、fa[N]、size[N]、son[N],其中son[x]通过size[y]计算最大值得到。
第二次dfs,传入x和链首元素tp:
对其儿子列表,首先递归son[x],然后再递归其它儿子(轻儿子)
重儿子递归,tp元素不变,轻儿子递归,tp改成轻儿子自己
每个结点都存储到o[++last]=x里,同时存储p[x]=last
最后,输出o[N]即为重链剖分以后的序。
首先树链剖分得到重链序,然后考虑操作:
1-单点更新:
2-区间更新:以某点为根的区间是p[x]和p[x]+size[x]-1,刚好连续
3-求路径:某点到根的路径,可以按重链、跳、重链…的组合来实现,也即多个区间的和组合起来。
显然,可以用线段树来维护区间操作和查询。
AC代码:
#include <cstdio>
#include <iostream>
#include <vector>
using namespace std;
const int N = 1e5 + 1;
typedef long long ll;
ll n, m, fa[N], top[N], siz[N], son[N], o[N], las, p[N], v[N];
bool vis[N];
vector<ll> g[N];
inline ll read()
{
ll x = 0, f = 1;
char ch = getchar();
while (ch < '0' || ch > '9')
{
if (ch == '-')
f = -1;
ch = getchar();
}
while (ch >= '0' && ch <= '9')
{
x = (x << 3) + (x << 1) + (ch ^ 48);
ch = getchar();
}
return x * f;
}
void dfs1(ll x, ll fat, ll len)
{
fa[x] = fat;
siz[x] = 1;
for (ll i = 0; i < g[x].size(); ++i)
{
ll y = g[x][i];
if (y == fat)
continue;
dfs1(y, x, len + 1);
siz[x] += siz[y];
if (siz[y] > siz[son[x]])
son[x] = y;
}
}
void dfs2(ll x, ll dp)
{
top[x] = dp;
o[++las] = v[x];
p[x] = las;
if (son[x] == 0)
return;
dfs2(son[x], dp);
for (ll i = 0; i < g[x].size(); ++i)
{
ll y = g[x][i];
if (y == fa[x] || y == son[x])
continue;
dfs2(y, y);
}
}
struct node
{
ll l, r, sum, tag;
ll len()
{
return r - l + 1;
}
};
node tree[4 * N];
void bud(ll l, ll r, ll k)
{
tree[k].l = l;
tree[k].r = r;
if (l == r)
{
tree[k].sum = o[l];
return;
}
ll mid = (l + r) / 2;
bud(l, mid, 2 * k);
bud(mid + 1, r, 2 * k + 1);
tree[k].sum = tree[2 * k].sum + tree[2 * k + 1].sum;
}
void but(ll k, ll tag)
{
tree[k].sum += tree[k].len() * tag;
tree[k].tag += tag;
}
void pud(ll k)
{
but(2 * k, tree[k].tag);
but(2 * k + 1, tree[k].tag);
tree[k].tag = 0;
}
void upd(ll x, ll y, ll a, ll k)
{
ll l = tree[k].l, r = tree[k].r;
if (x <= l && y >= r)
{
but(k, a);
return;
}
if (tree[k].tag)
pud(k);
ll mid = (l + r) / 2;
if (x <= mid)
upd(x, y, a, 2 * k);
if (y >= mid + 1)
upd(x, y, a, 2 * k + 1);
tree[k].sum = tree[2 * k].sum + tree[2 * k + 1].sum;
}
ll que(ll x, ll y, ll k)
{
ll l = tree[k].l, r = tree[k].r;
if (x <= l && y >= r)
return tree[k].sum;
if (tree[k].tag)
pud(k);
ll mid = (l + r) / 2, ans = 0;
if (x <= mid)
ans += que(x, y, 2 * k);
if (y >= mid + 1)
ans += que(x, y, 2 * k + 1);
return ans;
}
int main()
{
n = read();
m = read();
for (ll i = 1; i <= n; ++i)
v[i] = read();
for (ll i = 1; i <= n - 1; ++i)
{
ll x, y;
x = read();
y = read();
g[x].push_back(y);
g[y].push_back(x);
}
dfs1(1, 0, 1);
dfs2(1, 1);
bud(1, n, 1);
for (ll i = 1; i <= m; ++i)
{
ll cmd, x, a;
cmd = read();
x = read();
if (cmd == 1)
{
a = read();
upd(p[x], p[x], a, 1);
}
if (cmd == 2)
{
a = read();
upd(p[x], p[x] + siz[x] - 1, a, 1);
}
if (cmd == 3)
{
ll sum = 0;
while (top[x] != 1)
{
sum += que(p[top[x]], p[x], 1);
x = fa[top[x]];
}
sum += que(1, p[x], 1);
printf("%lld\n", sum);
}
}
return 0;
}