题目传送门qwq

发现并不需要求出期望,之需要求出所有方案的和即可。考虑 $\rm{DP}$ :

  • 设 $f(i,j)$ 表示选出的 $m$ 张牌中有 $i$ 张是强化牌,一共打出了 $j$ 张强化牌时所有方案的倍率和。
  • 设 $g(i,j)$ 表示选出的 $m$ 张牌中有 $i$ 张是攻击牌,一共打出了 $j$ 张攻击牌时所有方案的伤害和。

然后打牌的策略很显然是:

  • 如果 $m$ 张牌中有 $i$ 张强化牌,$i<k-1$ ,那么将所有强化牌打完,然后打出前 $k-i$ 大的攻击牌。
  • 如果 $m$ 张牌中有 $i$ 张强化牌,$i\geq k-1$ ,那么打出前 $k-1$ 大的强化牌,然后打出最大的攻击牌。

上面的 $\rm{DP}$ 不好转移。

考虑设 $dp_1(i,j)$ 表示一共选了 $i$ 张强化牌,其中最小的一张在所有强化牌中为第 $j$ 小时所有方案的倍率和,同样设 $dp_2(i,j)$ 表示一共选了 $i$ 张攻击牌,其中最小的一张在所有攻击牌中为第 $j$ 小时所有方案的伤害和。

先考虑 $dp_1$ 的转移,假设我们是从大往小选牌,上一次如果选到了第 $k$ 小,那么对于所有的 $k>j$ ,都可以用 $w_j\times dp_{1}(i-1,k)$ 来更新 $dp_1(i,j)$ :

$$ dp_1(i,j)=w_j\times \sum_{k=j+1}^{n} dp_1(i-1,k) $$

注意这里是将 $w_1$ 数组从小到大排了序的。

上面的 $\rm{DP}$ 是 $O(n^3)$ 的,用前缀和优化可以做到 $O(n^2)$ 。

然后考虑 $dp_2$ 的转移,一样是从大到小选牌,如果上一次选到了第 $k$ 小,那么对于所有的 $k>j$ ,都可以用 $dp_2(i-1,k)$ 来更新 $dp_2(i,j)$ 。然后显然 $dp_{2}(i,j)$ 还有 $w_j$ 的贡献,从 $n-j$ 这些牌里面选 $i$ 张牌的方案数一共有 ${n-j\choose i-1}$ 种,每一种方案现在的伤害和都需要加上 $w_j$ ,所以总共加上 ${n-j\choose i-1}\times w_j$ :

$$ dp_2(i,j)={n-j\choose i-1}\times w_j+\sum_{k=j+1}^{n} dp_2(i-1,k) $$

用前缀和优化一样能做到 $O(n^2)$ 。


最后统计答案,枚举 $i$ ,表示 $m$ 张牌中有 $i$ 张强化牌。

对于 $i<k-1$ ,答案为 $f(i,i)\times g(m-i,k-i)$ ,否则答案为 $f(i,k-1)\times g(m-i,1)$ ,求个和即可。


还需要考虑 $f,g$ 与 $dp_1,dp_2$ 的关系。

对于 $f(i,j)$ ,显然有 $j$ 张是需要打出的,$i-j$ 张不需要打出,一定要满足打出的牌中最小的比不打出的牌中最大的要大,枚举打出的牌中最小的牌的排名 $k$ ,然后用组合数计算即可。

$g(i,j)$ 的计算方式如法炮制。


Code:

#include <cstdio>
#include <string>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
typedef long long ll;

const int N=1.5e3+5;
const int mod=998244353;

int C[N<<1][N<<1];
int T,n,m,k,w1[N],w2[N],dp1[N][N],dp2[N][N],sum1[N],sum2[N];

template <typename _Tp> inline void IN(_Tp&x) {
    char ch;bool flag=0;x=0;
    while(ch=getchar(),!isdigit(ch)) if(ch=='-') flag=1;
    while(isdigit(ch)) x=x*10+ch-'0',ch=getchar();
    if(flag) x=-x;
}

inline int modpow(int x,int y,int res=1) {
    for(;y;y>>=1,x=1ll*x*x%mod) if(y&1) res=1ll*res*x%mod;
    return res;
}

inline int F(int x,int y,int res=0) {
    if(x<y) return 0;
    if(!y) return C[n][x];
    for(int i=x-y+1;i<=n-y+1;++i) (res+=1ll*dp1[y][i]*C[i-1][x-y]%mod)%=mod;
    return res;
}
inline int G(int x,int y,int res=0) {
    if(x<y) return 0;
    for(int i=x-y+1;i<=n-y+1;++i) (res+=1ll*dp2[y][i]*C[i-1][x-y]%mod)%=mod;
    return res;
}

inline void solve() {
    IN(n),IN(m),IN(k);
    for(int i=1;i<=n;++i) IN(w1[i]);
    for(int i=1;i<=n;++i) IN(w2[i]);
    sort(w1+1,w1+1+n),
    sort(w2+1,w2+1+n);
    for(int i=1;i<=n;++i) {
        for(int j=1;j<=n;++j) dp1[i][j]=dp2[i][j]=0;
        sum1[i]=(sum1[i-1]+w1[i])%mod,dp1[1][i]=w1[i];
        sum2[i]=(sum2[i-1]+w2[i])%mod,dp2[1][i]=w2[i];
    }
    for(int i=2;i<=n;++i) {
        for(int j=1;j<=n-i+1;++j)
            dp1[i][j]=1ll*w1[j]*(sum1[n]-sum1[j]+mod)%mod,
            dp2[i][j]=(1ll*w2[j]*C[n-j][i-1]%mod+(sum2[n]-sum2[j]+mod)%mod)%mod;
        for(int j=1;j<=n;++j)
            sum1[j]=(sum1[j-1]+dp1[i][j])%mod,
            sum2[j]=(sum2[j-1]+dp2[i][j])%mod;
    }
    int ans=0;
    for(int i=0;i<m;++i)
        if(i<k-1) (ans+=1ll*F(i,i)*G(m-i,k-i)%mod)%=mod;
        else (ans+=1ll*F(i,k-1)*G(m-i,1)%mod)%=mod;
    printf("%d\n",ans);
}

int main() {
    C[0][0]=1;
    for(int i=1,limit=3000;i<=limit;++i) {
        C[i][0]=C[i][i]=1;
        for(int j=1;j<i;++j) C[i][j]=(C[i-1][j]+C[i-1][j-1])%mod;
    }
    IN(T);
    while(T--) solve();
    return 0;
}
最后修改:2019 年 12 月 10 日 07 : 34 PM