AliasDict, a python dict for alias resolver

用 dict 存 alias 相當直覺

在做 NN optimization 的時候,會常常用到 alias,像是有些 operations 例如 RELU,是直接 execute in place 的,這時候原本的表達格式是不同的 tensors 的話,就可以將 output tensor 視為 input tensor 的一個 alias,名字不同但其實是同一塊這樣。

使用 python 的話,會很直覺地將這個關係用 dict 存起來,到時候就可以查詢這個 tensor name,是不是另一個的 alias。剛開始是很美好,但是實際用起來需要考慮到不少事情。

像是萬一是多層的關係呢?所以需要 recursive 的去尋找到最後的 alias。那萬一 recursive 過程中有 circular reference 了呢?這樣會一直跑不完,所以還需要有檢查並指示出這種錯誤。其實為了這個另外寫了不少東西。

因此在這次有機會重新實做的時候,

設計了一個 AliasDict class 封裝了上面這些邏輯

AliasDict 是個特別的 python dict

首先這個 AliasDict 可能還是要是個 dict,因為用 subscript operator [] 依然還是相當直覺,但希望要能直接在內部實做上解決上面的問題,不用另外處理,所以可以來定義這個 AliasDict 該有些什麼能力

  1. 如果沒有 key 的話,回傳 key 本身
  2. 可以 recursive 找到沒有 alias key 的 value
  3. 如果有 circular reference 會 raise ValueError

根據上面的描述,可以來寫些測試了,用 pytest 來寫的話大致上就是這樣

class TestAliasDict:
    def test_aliasdict(self, request):
        maps = AliasDict()
        maps['a'] = 'A'
        maps['c'] = 'a'
        maps['d'] = 'c'
        assert maps['a'] == 'A'
        assert maps['b'] == 'b'
        assert maps['c'] == 'A'
        assert maps['d'] == 'A'

def test_aliasdict_invalid_multiple(self, request):
        maps = AliasDict()
        maps['a'] = 'b'
        maps['b'] = 'c'
        maps['c'] = 'a'
        with pytest.raises(ValueError):
            a = maps['a']
        with pytest.raises(ValueError):
            a = maps['b']
        with pytest.raises(ValueError):
            a = maps['c']

AliasDict 的 getitem, missing

接著來實做 AliasDict,按造需求就是先繼承自 python dict,然後要改的實做就是兩個內部的 methods,__getitem__()__missing__(),而為了要檢查有沒有 circular reference,另外新增了一個 _getfinalitem(),會一直記住曾經找過的 alias 的 keys,如果新要找的 key 竟然曾經出現過,那就是有 circular 了,就不要再繼續下去直接 raise Error 出來。

lass AliasDict(dict):
    def __missing__(self, key):
        return key

    def __getitem__(self, key):
        cur = []
        return self._getfinalitem(key, cur)

    def _getfinalitem(self, key, cur=[]):
        val = super().__getitem__(key)
        if val == key:
            return val
        else:
            if val in cur:
                raise ValueError('circular reference detected {val} in {cur}')
            cur.append(val)
            return self._getfinalitem(val, cur)

AliasDict 也許其他地方也用得到

實做完 AliasDict 之後,跟之前的相比覺得封裝起來清爽許多,雖然並沒有真的行動,但的確是有股衝動將之前的 code 換成這組。

另外總覺得在其他的地方也應該會有相似的需求才是,所以特別寫出來紀錄一下,也許下次不是需要這麼一樣的功能,但透過這個 AliasDict 的例子,可以了解到原來 python 的 dict 是可以被繼承的,而透過改寫 __getitem__, __missing__ 可以得到很有趣的東西。

發佈留言

發佈留言必須填寫的電子郵件地址不會公開。 必填欄位標示為 *