halcの競プロ精進ブログ

はるくが競プロしたことを書き留めるなにか

【今日の精進】ARC061E - Snuke's Subway Trip / すぬけ君の地下鉄旅行

毎日格上の問題を倒すやつの65日目です。

問題リンク

atcoder.jp

問題概要

町が電車でつながってるよ。

それぞれの電車は、同じ会社のものはいくら乗り換えても 1 円で済むけど、別の会社に乗り換えるともう 1 円かかるよ。

1 から町 N まで移動する値段の最小値を求めてね。

解法

それぞれの電車で行き来できる頂点対に辺を張ると、頂点数の二乗くらいの変数になってしまい、なかなかに厳しいです。

そのため、それぞれの塊ごとに頂点を追加して、そこにつないでいくと良いです。いわゆる超頂点です。

あとはこの上でBFSして距離を 2 で割って終わりです。

提出コード

from collections import defaultdict

class Graph():
    def __init__(self,size,directed=False):
        self.dir=directed
        self.size=size
        self.gr=[[] for i in range(size)]
        
    
    def add_edge(self,u,v):
        self.gr[u].append(v)
        self.gr[v].append(u)

    def node(self,v):
        return self.gr[v]

from collections import deque

def bfs(graph,start):
    dist=[-2 for i in range(graph.size)]
    dist[start]=0
    vert=deque([start])
    while len(vert)>0:
        pos=vert.popleft()
        for i in graph.node(pos):
            if dist[i]==-2:
                vert.append((i))
                dist[i]=dist[pos]+1
    return dist


class UnionFind():
    def __init__(self,size):
        self.uf=[-1 for i in range(size)]
        
    def unite(self,fir,sec):
        one=self.root(fir)
        two=self.root(sec)
        if one!=two:
            if self.uf[one]<self.uf[two]:
                one,two=two,one
            self.uf[two]+=self.uf[one]
            self.uf[one]=two
    
    def root(self,node):
        pos=node
        change=[]
        while self.uf[pos]>=0:
            change.append(pos)
            pos=self.uf[pos]
        for i in change:
            self.uf[i]=pos
        return pos

N,M=map(int,input().split())
subway=defaultdict(list)
for i in range(M):
    p,q,c=map(int,input().split())
    subway[c].append((p-1,q-1))
gr=Graph(N+M)
cnt=N
for i in subway:
    use=set()
    for j,k in subway[i]:
        use.add(j)
        use.add(k)
    use=list(use)
    use.sort()
    ser={}
    for j in range(len(use)):
        ser[use[j]]=j
    union=UnionFind(len(use))
    for j,k in subway[i]:
        union.unite(ser[j],ser[k])
    used={}
    for j in range(len(use)):
        if union.root(j) not in used:
            used[union.root(j)]=cnt
            cnt+=1
        gr.add_edge(used[union.root(j)],use[j])
print(bfs(gr,0)[N-1]//2)

Submission #47856645 - AtCoder Regular Contest 061

感想

今なら青までありそう