2022 CCPC 绵阳站 E题 (图上DP,根号分治)
题意
有一个由 n n n城市组成的国家,城市之间由一条权值为 w w w的边连接,共m条这样的边,并且保证整个国家是连通的。每个城市中有 a i a_i ai 个居民。
在接下来的q天,每天都会有一个城市遭受灾难 b 1 , b 2 , . . . , b q b_1,b_2,...,b_q b1,b2,...,bq,你必须将该城市的所有人都转移到其它城市才能避免居民受到灾难,转移一个居民到相邻城市的代价为两个城市之间路径的权值w。
请问你最少需要多少代价才能让所有居民都安全度过q天的灾难。
思路
我们不必管每个城市中有多少个人,我们只需要求出每个城市中转移一个人的最小代价,在最终计算总代价时再乘上人数即可。很容易想到一个暴力的DP解法如下:
令 f ( i , j ) f(i,j) f(i,j) 表示在第 j j j号点,第 i i i天后所有的操作中最小的代价。那么有 f ( i , j ) = M I N ( v , w ) ∈ e d g e b i { w + f ( i + 1 , v ) } f(i,j) = MIN_{(v,w) \in edge_{b_i}}\{w+f(i+1,v)\} f(i,j)=MIN(v,w)∈edgebi{w+f(i+1,v)} 。
我们发现每天只会更新一个dp值,于是我们可以直接省去f的第一维,然后倒序枚举天数 $i $ 从 q q q到 1 1 1 。 总的时间复杂度为 O ( ∑ i = 1 q d e g ( b i ) O(\sum_{i=1}^q deg(b_i) O(∑i=1qdeg(bi) ,我们发现当 b i b_i bi 的度较大时,复杂度会退化到 O ( q n ) O(qn) O(qn) 这是不被允许的。
于是考虑根号分治,设分治边界为 S Q SQ SQ ,那么当 d e g ( b i ) ≤ S Q deg(b_i) \le SQ deg(bi)≤SQ 时,枚举他的所有边来更新dp, 如果 d e g ( b i ) > S Q deg(b_i) > SQ deg(bi)>SQ 那么我们为这个节点建立一个multiset ,存储所有邻边的 d p [ v ] + w dp[v]+w dp[v]+w 的值,则multiset的第一项即为当前最小的dp值。每当一个节点的dp值更新时,将与他相邻的 d e g ( v ) > S Q deg(v) > SQ deg(v)>SQ的点更新。
当 S Q = 2 × m × l o g n SQ = \sqrt{2 \times m \times logn} SQ=2×m×logn 时,复杂度为 O ( q m × l o g n ) O(q\sqrt{m \times logn}) O(qm×logn)
代码
#include<bits/stdc++.h>
using namespace std;
#define IO ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
#define int long long
#define rep(i,l,r) for(int i = l;i<=r;i++)
#define per(i,r,l) for(int i = r;i>=l;i--)
const int INF = 0x3f3f3f3f3f3f3f3f;
typedef pair<int,int> PII;
int a[100005];
vector<PII> edge[100005];
vector<PII> edgeB[100005];
int deg[100005];
int que[100005];
int dp[100005];
multiset<int> mulst[100005];
const int mod = 998244353;
void solve(){int n,m,q;cin>>n>>m>>q;int SQ = sqrt(2LL * m * log2(n));for(int i = 1;i<=n;i++){cin >> a[i];}for(int i = 1;i<=m;i++){int u,v,w;cin>>u>>v>>w;edge[u].push_back({v,w});edge[v].push_back({u,w});deg[v]++;deg[u]++;}for(int u = 1;u<=n;u++){if(deg[u] > SQ){for(auto [v,w] : edge[u]){if(deg[v] > SQ){edgeB[u].push_back({v,w});}mulst[u].insert(w);}}}for(int i = 1;i<=q;i++){cin>>que[i];}for(int i = q;i>=1;i--){int u = que[i];if(deg[u] <= SQ){int cost = INF;for(auto [v,w] : edge[u]){cost = min(cost,dp[v]+w);}for(auto [v,w] : edge[u]){if(deg[v] > SQ){mulst[v].erase(mulst[v].find(w+dp[u]));mulst[v].insert(w+cost);}}dp[u] = cost;}else{int cost = *mulst[u].begin();for(auto [v,w] : edgeB[u]){if(deg[v] > SQ){mulst[v].erase(mulst[v].find(w+dp[u]));mulst[v].insert(w+cost); }}dp[u] = cost;}}int ans = 0;for(int i = 1;i<=n;i++){ans = (ans + dp[i] * a[i]) %mod; }cout<<ans;
}
signed main(){int T = 1;// cin>>T;while(T--){solve();}return 0;
}