AT274の技術だったりなかったり

あっとのTECH LOG

競プロのこととか技術のこととか。たまに日常を書きます。

Python:重み付きUnion-Find木について

今回は素集合データ構造であるUnion-Find木に重みを付けた、Weighted(重み付き) Union-Find木についてまとめます。
Union-Find木についてよくわからないという方は、

at274.hatenablog.com
こちらを先に見ていただいた方がいいかと思います。

Weighted Union-Find木について

実装前に、Weighted Union-Findがどんなものかざっくり説明しておきましょう。
まずはイメージ図を見てやってください。
f:id:AT274:20180203125413p:plain

特に深い説明はいらないかと思います。 イメージ図通りです。

例えば、「AさんはBさんより3歳年上。CさんはBさんより2つ下。ではAさんはCさんよりいくつ上?(もしくは下?)」みたいな問題に対して使えますね。
グラフの形状によっては、もっといろんな問題に使うことができるでしょう。(私の頭は悪いので応用できるかわかりませんが)

Pythonによる実装

ではPythonで実装していきましょう。

準備

イメージ図で説明したように、根への重みを管理することでWeighted Union-Findは実現します。
ですのでそれを管理するリスト、weightを初期化時に定義してあげましょう。ただし、根への重みでなく、親への重みという管理の仕方をします。
また例によってノード番号とリストのインデックスを揃えたいので、長さは(n+1)になっています。

    def __init__(self, n):
        self.par = [i for i in range(n+1)]
        self.rank = [0] * (n+1)
        # 親への重みを管理
        self.weight = [0] * (n+1)

検索(find)

次にあるノードの根を返すfind関数に手を加えます。

この関数は再起関数を用いることで、検索時に通ったノードを全て根につなぎ直すことができるものでした。この仕組みを重み付けにも応用し、検索時に通った全てのノードについて、根への重みをweightに格納するようにしましょう。実装例は以下のようになります。

    # 検索
    def find(self, x):
        if self.par[x] == x:
            return x
        else:
            y = self.find(self.par[x])
            # 親への重みを追加しながら根まで走査
            self.weight[x] += self.weight[self.par[x]]
            self.par[x] = y
            return y

併合(union)

次はグループ同士を併合するunionです。
ここで、union関数はx, y, wを引数とし、これらは「xからyへの重み(距離)はwである」ことを示すとします。

さて、union関数では「木の高さを比較し、低い方の根から高い方の根につなぐ」のでした。
ですから、xとyそれぞれが属する木の高さによって重み付けの方法を変えなければなりません。

ベースとなるUnionFindクラスの実装の都合上、「xの木の高さ<yの木の高さ」の場合と、「xの木の高さ≧yの木の高さ」の場合に分けて考えることにしましょう。
まずは、「xの木の高さ<yの木の高さ」の場合です。 上のイメージ図で用いた例を再使用しています。
f:id:AT274:20180203133718p:plain
こんな感じになります。ちなみに2の子から3への重みへの場合には、2の子から2に対する重みを加算する必要があります。
これらより、xの属する木の根の、yの属する木の根に対する重みは、「与えられた重み+yの根への重み-xの根への重み」となることがわかります。

次は、「xの木の高さ≧yの木の高さ」の場合です。
f:id:AT274:20180203134751p:plain
こんな感じになりますね。この場合においては、yの属する木の根の、xの属する木の根に対する重みは、「xの根への重み - 与えられた重み - yの根への重み」となります。一番最後の-yの根への重みは、3よりも深いノードに対する重みを計算する際に必要になります。(深いからと言って重みが大きくなるわけではないことに注意!)

これらから、union関数は以下のようになります。また見やすさのために、xの根、yの根をそれぞれrx, ryとしています。(しなくてもいいですが、findのタイミングに注意してくださいね。)

    # 併合
    def union(self, x, y, w):
        rx = self.find(x)
        ry = self.find(y)
        # xの木の高さ < yの木の高さ
        if self.rank[rx] < self.rank[ry]:
            self.par[rx] = ry
            self.weight[rx] = w - self.weight[x] + self.weight[y]
        # xの木の高さ ≧ yの木の高さ
        else:
            self.par[ry] = rx
            self.weight[ry] = -w - self.weight[y] + self.weight[x]
            # 木の高さが同じだった場合の処理
            if self.rank[rx] == self.rank[ry]:
                self.rank[rx] += 1

完成形

完成したWeightedUnionFindクラスは以下のようになります。
また、よく使うのでxからyへのコストを返す関数diffを作成しています。

class WeightedUnionFind:
    def __init__(self, n):
        self.par = [i for i in range(n+1)]
        self.rank = [0] * (n+1)
        # 根への距離を管理
        self.weight = [0] * (n+1)

    # 検索
    def find(self, x):
        if self.par[x] == x:
            return x
        else:
            y = self.find(self.par[x])
            # 親への重みを追加しながら根まで走査
            self.weight[x] += self.weight[self.par[x]]
            self.par[x] = y
            return y

    # 併合
    def union(self, x, y, w):
        rx = self.find(x)
        ry = self.find(y)
        # xの木の高さ < yの木の高さ
        if self.rank[rx] < self.rank[ry]:
            self.par[rx] = ry
            self.weight[rx] = w - self.weight[x] + self.weight[y]
        # xの木の高さ ≧ yの木の高さ
        else:
            self.par[ry] = rx
            self.weight[ry] = -w - self.weight[y] + self.weight[x]
            # 木の高さが同じだった場合の処理
            if self.rank[rx] == self.rank[ry]:
                self.rank[rx] += 1

    # 同じ集合に属するか
    def same(self, x, y):
        return self.find(x) == self.find(y)

    # xからyへのコスト
    def diff(self, x, y):
        return self.weight[x] - self.weight[y]

なんかうまいことまとまらなくて申し訳ないです。
読んでいただきありがとうございましたm(._.)m