Union Find Template

class UnionFind(object):
    def __init__(self,n):
        self.parents = list(range(n))
        self.count = n
    
    def find(self,x):
        if x != self.parents[x]:
            self.parents[x] = self.find(self.parents[x])
        return self.parents[x]
    
    def union(self,x,y):
        xSet = self.find(x)
        ySet = self.find(y)
        self.parents[ySet] = xSet
        if xSet == ySet:
            return False
        self.count-=1
        return True

这里面这段Code非常重要

def find(self,x):
        if x == self.parents[x]:
            return x
        return self.find(self.parents[x])

这么写其实没错的,但是效率会很低,因为每次你call这个find recursion function的时候,你并没有把parents这个arr更新,也就是说你每次都在做重复计算。但如果下面这么写,你就是在每次find时把结果更新到parents里,相当于cache了。

def find(self,x):
        if x != self.parents[x]:
            self.parents[x] = self.find(self.parents[x])
        return self.parents[x]

Last updated

Was this helpful?