halcの競プロ精進ブログ

はるくが競プロしたことを書き留めるなにか

【今日の精進】ABC221F - Diameter set

毎日格上の問題を倒すやつの13日目です。1

問題リンク

atcoder.jp

問題概要

木が与えられるよ。

木から頂点を2つ以上いくつか選んだ時、選んだ頂点の組すべてについて距離が木の直径と等しい選び方の個数を 998244353 で割ったあまりを求めてね。

解法

木の直径は、DFS*2で簡単に求められます。

また、経路復元すると木の中心も求まります。木の中心は1点か2点ありますが、1点の時はそのままでよいです。

2点の時は、頂点をひとつ追加してその間に張ることで1点の時と同じようにできます。

また、木の中心を消して木をいくつかに分けたとき、その中で選べる頂点は高々一つです。これは、2つ以上選ぶと距離がかならず元の木の直径未満になるからです。

なので、木の中心から遠い頂点を木DPで求めて、総積を求めればよいです。

提出コード

import sys
from collections import deque
from math import inf
import pypyjit
sys.setrecursionlimit((1<<19)-1)
pypyjit.set_param('max_unroll_recursion=-1')

class Graph():
    def __init__(self,size,directed=False,weighted=False):
        self.dir=directed
        self.wei=weighted
        self.size=size
        self.gr=[[] for i in range(size)]
        self.edges=[]
    
    def add_edge(self,u,v,w=1):
        if not(self.wei):
            w=1
        self.gr[u].append(self.Edge(u,v,w))
        if not(self.dir):
            self.gr[v].append(self.Edge(v,u,w))
        self.edges.append(self.Edge(u,v,w))

    def node(self,v):
        return self.gr[v]
    
    class Edge():
        def __init__(self,st,to,weight):
            self.st=st
            self.to=to
            self.weight=weight

def bfs(graph,start):
    dist=[inf for i in range(graph.size)]
    used=[False for i in range(graph.size)]
    dist[start]=0
    vert=deque([(0,start)])
    while len(vert)>0:
        dis,pos=vert.popleft()
        if used[pos]:
            continue
        used[pos]=True
        for i in graph.node(pos):
            if dis+i.weight<dist[i.to]:
                vert.append((dis+i.weight,i.to))
                dist[i.to]=dis+i.weight
    return dist

def two(s,g):
    dist1=bfs(gr,s)
    dist2=bfs(gr,g)
    mul=[0 for i in range(N*2)]
    for i in range(N):
        mul[min(dist1[i],dist2[i])]=max(mul[min(dist1[i],dist2[i])],len(gr.node(i)))
    ans=2
    for i in mul:
        if i<=1:
            break
        ans*=i-1
    return ans

def dfs(pos,bef,cnt):
    if cnt==0:
        ans=1
        wa=0
        for i in gr.node(pos):
            now=dfs(i.to,pos,cnt+1)
            wa+=now
            ans*=(now+1)
            ans%=MOD
        ans-=wa+1
        ans%=MOD
        return ans
    else:
        ans=0
        for i in gr.node(pos):
            if i.to!=bef:
                ans+=dfs(i.to,pos,cnt+1)
        if cnt==diam//2:
            ans+=1
        return ans

MOD=998244353
N=int(input())
gr=Graph(N)
for i in range(N-1):
    a,b=map(int,input().split())
    gr.add_edge(a-1,b-1)
dist=bfs(gr,0)
dist=bfs(gr,dist.index(max(dist)))
diam=max(dist)
pos=dist.index(diam)
ans=1
for i in range(diam//2):
    for j in gr.node(pos):
        if dist[j.to]<dist[pos]:
            pos=j.to
            break
if diam%2==1:
    rep=pos
    for i in gr.node(pos):
        if dist[i.to]<dist[pos]:
            pos=i.to
            break
    reg=Graph(N+1)
    for i in gr.edges:
        if (i.to==pos and i.st==rep) or (i.st==pos and i.to==rep):
            reg.add_edge(i.st,N)
            reg.add_edge(i.to,N)
        else:
            reg.add_edge(i.to,i.st)
    gr=reg
    pos=N
    diam+=1
print(dfs(pos,-1,0))

Submission #46188189 - AtCoder Beginner Contest 221

感想

今日何も解けなくて死ぬかと思った2


  1. 滑り込みセーフ、爆発回避
  2. 普通の人は競プロに命を懸けません