11class UnionFind :
2- def __init__ (self , n ):
3- self ._parent = [i for i in range (n )]
4- self ._rank = [0 for _ in range (n )]
5- self ._group_size = [1 for _ in range (n )]
6- self .num_of_groups = n
2+ def __init__ (self , N ): self .N , self .group_count , self .root , self .rank = N , N , [- 1 ] * N , [0 ] * N
3+ def __repr__ (self ): return str (self .all_groups ())
74
85 def find (self , x ):
9- vs = []
10- while self ._parent [x ] != x :
11- vs .append (x )
12- x = self ._parent [x ]
13- for v in vs : self ._parent [v ] = x
6+ while self .root [x ] >= 0 : x = self .root [x ]
147 return x
158
169 def union (self , x , y ):
17- px , py = self .find (x ), self .find (y )
18- if px == py : return
19- if self ._rank [px ] < self ._rank [py ]:
20- self ._parent [px ] = py
21- self ._group_size [py ] += self ._group_size [px ]
22- else :
23- self ._parent [py ] = px
24- self ._group_size [px ] += self ._group_size [py ]
25- if self ._rank [px ] == self ._rank [py ]: self ._rank [py ] += 1
26- self .num_of_groups -= 1
10+ x , y = self .find (x ), self .find (y )
11+ if x == y : return
12+ if self .rank [x ] > self .rank [y ]: x , y = y , x
13+ self .root [y ] += self .root [x ]
14+ self .root [x ] = y
15+ if self .rank [x ] == self .rank [y ]: self .rank [y ] += 1
16+ self .group_count -= 1
2717
28- def is_same (self , x , y ): return self .find (x ) == self .find (y )
29- def group_size (self , x ): return self ._group_size [self .find (x )]
18+ def same (self , x , y ): return self .find (x ) == self .find (y )
19+ def count (self , x ): return - self .root [self .find (x )]
20+ def members (self , x ): return [i for i in range (self .N ) if self .same (x , i )]
21+ def roots (self ): return [i for i , x in enumerate (self .root ) if x < 0 ]
22+ def all_groups (self ):
23+ d = defaultdict (lambda : [])
24+ for i in range (self .N ): d [self .find (i )].append (i )
25+ return dict (d )
0 commit comments