「BZOJ 4033」「HAOI2015」树上染色 - DP

Link: 树上染色

题面

Description

有一棵点数为 N 的树,树边有边权。给你一个在 0 ~ N 之内的正整数 K ,你要在这棵树中选择 K 个点,将其染成黑色,并将其他的 N-K 个点染成白色。将所有点染色后,你会获得黑点两两之间的距离加上白点两两之间距离的和的收益。

问收益最大值是多少。

Input

第一行两个整数 N, K

接下来 N-1 行每行三个正整数 fr , to , dis,表示该树中存在一条长度为 dis 的边(fr, to)。

输入保证所有点之间是联通的。

N\leq 2000,0\leq K\leq N

Output

输出一个正整数,表示收益的最大值。

Sample Input

1
2
3
4
5
5 2
1 2 3
1 5 1
2 3 1
2 4 2

Sample Output

1
17

样例解释

将点 1 , 2 染黑就能获得最大收益。

HINT

2018.2017.9.12新加数据一组 By GXZlegend


题解

考虑进行DP

首先DP出以 i 为根子树的方案数是大概不可做的, 一条边会被后来的许多对点更新, 于是就考虑每条边被计算的次数 cnt 即可, 可以直接枚举边计算

当然上面的算法大概会 TLE, 于是重新拾回刚开始想法, 不在考虑以 i 为根子树的方案数, 而是考虑以 i 为根子树对答案的 贡献. 这时候考虑对于边 (u, v) 被计算的次数, 令 \text{siz}_x x 的子树大小, k_0 为以 v 为根的黑点个数, 那么答案就是

k_0\times(k-k_0)+(\text{siz}_v-k_0)\times(n-\text{siz}_v-k+k_0)

\text{dp}_{i,j} 表示以 i 为根的子树中有 j 个黑点对答案的贡献, 然后 dp 方程就是

\text{dp}_{u,j}=\max\{\text{dp}_{u,j-k_0}+\text{dp}_{v,k_0}+\text{cnt}\times \text{dist}_{u,v}\}

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
#include <cstdio>
#include <cstring>

#define int long long

const int N = 2000 + 10;

struct Edge {
int v, w;
int n;
Edge () {}
Edge (int a, int b, int c) : n(a), v(b), w(c) {}
};

int tot;
int hd[N];
Edge e[N << 1];

int dp[N][N];
int sz[N];

int n, k;

int min (int a, int b) {
return a > b ? b : a;
}

int max (int a, int b) {
return a + b - min(a, b);
}

void add (int u, int v, int w) {
tot++;
e[tot] = Edge (hd[u], v, w);
hd[u] = tot;
}

void init (int u, int f = 0) {
sz[u] = 1;
dp[u][0] = dp[u][1] = 0;
for (int i = hd[u]; i; i = e[i].n) {
int v = e[i].v;
if (v == f) continue;
init(v, u);
sz[u] += sz[v];
}
for (int i = hd[u]; i; i = e[i].n) {
int v = e[i].v;
if (v == f) continue;
for (int j = min(sz[u], k); j >= 0; j--) {
for (int _k = 0; _k <= min(sz[v], j); _k++) {
if (dp[u][j - _k] == -1) continue;
int cnt = _k * (k - _k) + (sz[v] - _k) * (n - sz[v] - k + _k);
dp[u][j] = max(dp[u][j], dp[u][j - _k] + dp[v][_k] + cnt * e[i].w);
}
}
}
}

signed main () {
scanf ("%lld%lld", &n, &k);

int u, v, w;
for (int i = 1; i < n; i++) {
scanf("%lld%lld%lld", &u, &v, &w);
add(u, v, w);
add(v, u, w);
}

memset (dp, -1, sizeof dp);
init(1);

printf("%lld\n", dp[1][k]);
}
坚持原创技术分享,您的支持将鼓励我继续创作!