算法競賽2021 ICPC Southeastern Europe Regional Contest_Werewolves
//#include "stdafx.h"
#include <cstdio>
#include <cstring>
#include <string>
#include <cmath>
#include <iostream>
#include <algorithm>
#include <queue>
#include <cstdlib>
#include <vector>
using namespace std;
typedef long long ll;
const int MAXN = 3000 + 10;
const int mod = 998244353;
vector<int> g[MAXN];
int n, m;
?
int c[MAXN]={0,1,1,3,3};
//int c[MAXN]={0,2,3,3};
char vis[MAXN];
int dp1[MAXN][MAXN], dp2[MAXN][MAXN], dp3[MAXN];
ll tmp1[MAXN], tmp2[MAXN], tmp3;
int res=0;
int edge[610]={1, 2, 1, 3,1,4};
//int edge[610]={1, 2, 2, 3};
int dfs(int u, int start, int ci) {
//printf("df? ?u=%d, start=%d\n",u,start);?
? ? int p=1;
? ? if(c[u] == c[ci]) dp1[u][1] = 1;
? ? else dp2[u][1] = 1;
? ? for(auto v : g[u])?
{
? ? ? ? if(v == start) continue;
? ? ? ? int sz = dfs(v, u,ci);?
//printf("sz = %d,? %d? ?%d\n",sz,v,u);?
tmp3 = dp3[u];
? ? ? ? for(int i = 1; i <= min(p, m); i++) {
? ? ? ? ? ? tmp1[i] = dp1[u][i];
? ? ? ? ? ? tmp2[i] = dp2[u][i];? ? ? ? ? ?
? ? ? ? }
? ? ? ? dp3[u] = (dp3[u] + tmp3 * dp3[v]) % mod;
? ? ? ? for(int j = 1; j <= min(sz,m); j++) {
? ? ? ? ? ? dp1[u][j] = (dp1[u][j] + tmp3 * dp1[v][j]) % mod;
? ? ? ? ? ? dp2[u][j] = (dp2[u][j] + tmp3 * dp2[v][j]) % mod;
? ? ? ? }
? ? ? ? for(int i = 1; i <= min(p, m); i++) {
? ? ? ? ? ? dp1[u][i] = (dp1[u][i] + tmp1[i] * dp3[v]) % mod;
? ? ? ? ? ? dp2[u][i] = (dp2[u][i] + tmp2 [i] * dp3[v]) % mod;
? ? ? ? ? ? for(int j = 1; j <= min(sz,m); j++) {
? ? ? ? ? ? ? ?if(i+j <= m)?
{
? ? ? ? ? ? ? ? ? ? dp1[u][i+j] = (dp1[u][i+j] + tmp1[i] * dp1[v][j]) % mod;
? ? ? ? ? ? ? ? ? ? dp2[u][i+j] = (dp2[u][i+j] + tmp2[i] * dp2[v][j]) % mod;
? ? ? ? ? ? ? ? }
? ? ? ? ? ? ? ? if(i>j ) {
? ? ? ? ? ? ? ? ? ? dp1[u][i-j] = (dp1[u][i-j] + tmp1[i] * dp2[v][j]) % mod;
? ? ? ? ? ? ? ? ? ? dp2[u][i-j] = (dp2[u][i-j] + tmp2[i] * dp1[v][j]) % mod;
? ? ? ? ? ? ? ? }
? ? ? ? ? ? ? ? if(j>i ) {
? ? ? ? ? ? ? ? ? ? dp1[u][j-i] = (dp1[u][j-i] + tmp2[i] * dp1[v][j]) % mod;
? ? ? ? ? ? ? ? ? ? dp2[u][j-i] = (dp2[u][j-i] + tmp1[i] * dp2[v][j]) % mod;
? ? ? ? ? ? ? ? }
? ? ? ? ? ? ? ? if(i == j) {
? ? ? ? ? ? ? ? ? ? dp3[u] = (dp3[u] + tmp1[i] * dp2[v][j] + tmp2[i] * dp1[v][j]) % mod;
? ? ? ? ? ? ? ? }
? ? ? ? ? ? }
? ? ? ? }
? ? ? ?p += sz;
? ? }
? ? for(int i = 1; i <= min(p, m); i++) {
? ? ? ? res = (res + dp1[u][i]) % mod;
//printf("dfs? %d,? ?dp1[u][i]=%d, u=%d,? ?i=%d,? p=%d,? m=%d\n",res,? dp1[u][i],u,i,p,m );?
? ? }
? ? return p;
}
void solve() {
/*/ ?
n=4;
? ? for(int i = 1; i < n; i++) {
? ? ? ? g[edge[2*(i-1)]].push_back(edge[2*(i-1)+1]);?
? ? ? ? g[edge[2*(i-1)+1]].push_back(edge[2*(i-1)]);
? ? }
//*/
? ? for(int i = 1; i <= n; i++)?
{
? ? ? ? if(vis[c[i]]) continue;
? ? ? ? vis[c[i]] = 1; m = 0;? ? ?
? ? ? ? for(int j = 1; j <= n; j++)? ?if(c[j] == c[i]) m++;
? ? ? ??
? ? ? ? for(int j = 1; j <= n; j++)?
{
? ? ? ? ? ? for(int k = 0; k <= m; k++) {
? ? ? ? ? ? ? ? dp1[j][k] = dp2[j][k] = dp3[j] = 0;
? ? ? ? ? ? }
? ? ? ? }
? ? ? ? dfs(1,? 0, i);
}
printf("%d\n",res);?
}
int main() {
//*
scanf("%d",&n);?
? ?for(int i = 1; i <= n; i++) scanf("%d", &c[i]);
? ?int u, v;?
? ?for(int i = 1; i < n; i++) {
? ?scanf("%d", &u); scanf("%d", &v);
? ? g[u].push_back(v);
g[v].push_back(u);
? ?}//*/
? ? solve();
? ? return 0;
}