halcの競プロ精進ブログ

はるくが競プロしたことを書き留めるなにか

【今日の精進】ABC215H - Cabbage Master

毎日格上の問題を倒すやつの68日目です。

問題リンク

atcoder.jp

問題概要

互いに区別できるキャベツがあるよ。

正確には、品種 i のキャベツが A_i 個あるよ。

また、店 i はいくつかの品種のキャベツを、合計 B_i 個欲しがってるよ。

高橋君は、持っているキャベツをすべての店の要望を満たすように出品できたとき、「キャベツ名人」になれるよ。

キャベツ嫌いのすぬけ君は、できるだけ少ない個数のキャベツを食べることで高橋君がキャベツ名人になれないようにしたいよ。

その個数と、食べ方の通り数を求めてね。

解法

前にも使った、Hallの結婚定理と高速ゼータ変換をフル活用します。

ビット列を全探索した時、それらの品種をすべて受け入れる店を高速に求められれば、食べるべき個数は容易に求められます。

これに関しては、やってることが完全にゼータ変換と同じです。

その後、パターン数を求めます。

最小値を達成可能なbit列も容易に求められるので、それらから奪う個数を求めます。

たとえば補集合をとった高速ゼータ変換で、bit列の部分bit列も容易に求めることが可能です。

それらに関して、「各品種を必ず1つ選んだ時の選び方のパターン」を求めます。

これは、包除原理を高速ゼータ変換でやればよいです。

提出

fact=[]
rev=[]
def comb_init(N=1000000,MOD=998244353):
    global fact,rev
    fact=[1]
    rev=[1]
    for i in range(N):
        fact.append((fact[-1]*(i+1))%MOD)
        rev.append((rev[-1]*pow(i+1,-1,MOD))%MOD)

def comb(n,r,MOD=998244353):
    if n<r:
        return 0
    return (fact[n]*rev[r]*rev[n-r])%MOD

comb_init(3000000)
N,M=map(int,input().split())
A=list(map(int,input().split()))
B=list(map(int,input().split()))
c=[list(map(int,input().split())) for i in range(N)]
ship=[0]*(1<<N)
for i in range(M):
    bit=0
    for j in range(N):
        if c[j][i]==1:
            bit+=1<<j
    ship[bit]+=B[i]
for i in range(N):
    bit=1<<i
    for j in range(1<<N):
        if j&bit:
            ship[j]+=ship[j^bit]
minima=float("inf")
for i in range(1<<N):
    if ship[i]==0:
        continue
    lim=0
    for j in range(N):
        if (i>>j)&1:
            lim+=A[j]
    minima=min(minima,max(0,lim-ship[i]+1))
ok=[0]*(1<<N)
for i in range(1<<N):
    if ship[i]==0:
        continue
    lim=0
    for j in range(N):
        if (i>>j)&1:
            lim+=A[j]
    if max(0,lim-ship[i]+1)==minima:
        ok[i]+=1
for i in range(N):
    bit=1<<i
    for j in range(1<<N):
        if j&bit==0:
            ok[j]+=ok[j|bit]
pat=[0]*(1<<N)
for i in range(1<<N):
    lim=0
    for j in range(N):
        if (i>>j)&1:
            lim+=A[j]
    if i.bit_count()%2:
        pat[i]+=comb(lim,minima)
    else:
        pat[i]-=comb(lim,minima)
for i in range(N):
    bit=1<<i
    for j in range(1<<N):
        if j&bit:
            pat[j]+=pat[j^bit]
ans=0
for i in range(1<<N):
    if ok[i]!=0:
        if i.bit_count()%2:
            ans+=pat[i]
        else:
            ans-=pat[i]
        ans%=998244353
print(minima,ans)

Submission #47978684 - AtCoder Beginner Contest 215

感想

キャベツ名人に、なりました…*1

*1:好きな問題の1位を更新しました