LDA model use Pyro

Pyro 是一個由Uber AI Lab所開發的Probabilistic Programming Language(PPL),用程式語言來描述具有隨機性的程序或是過程;從另一個角度來看,圖機率模型(PGM)是用圖的方式來描述一個機率過程,而PPL則是可以讓你用上程式語言來描述,條件、迴圈等等都可以用上,能描述更為複雜的機率過程

PPL也存在好一陣子了,像是MIT基於LISP所開發的Churchwebppl都是,但都偏向是學術性的語言;隨著深度學習也用上了一些機率推斷方法,像是MCMC、Variational inference等,也出現基於Tensorflow、Pytorch框架的PPL,能更好的與深度學習方法結合,並且引入很多有用的機率數學工具

Introduction of Pyro

根據官網描述,Pyro有幾個特點

  • Universal: Pyro can represent any computable probability distribution.
  • Scalable: Pyro scales to large data sets with little overhead.
  • Minimal: Pyro is implemented with a small core of powerful, composable abstractions.
  • Flexible: Pyro aims for automation when you want it, control when you need it.

基於Pytorch開發,所以深度學習那些當然都可以整合一起用上,並且語法非常的Pythonic,也用上了很多高階特性,例如context manager,機率推斷演算法主要提供Stochastic Variational Inference(SVI),可以用上SGD來訓練模型,整體上的確以最小需要、高彈性、透明為主要設計,相較於基於Tensorflow的EdwardTensorflow_probability,提供不只VI相關演算法,還有很多MCMC相關的,或是前一篇提到的NF,更多工具可以使用,但也相對不是那麼容易上手;不過結合更多機率特性是趨勢之一,這兩個主流的Deep PPL都很值得關注與學習。

Implement LDA model

以上介紹了Pyro,接著就來實作Latent Dirichlet allocation(LDA)模型,常用文本主題分析,且為圖機率模型的經典模型!接著用Pyro與SVI來實作並求解參數,本文並不會從頭介紹Pyro基礎,建議可以先看官網教學,程式碼我放在Colab,大家可以直接跑看看

Model

我們就直接從模型看起,先來看一下LDA式子長怎樣,$\phi$指每個字對應個主題的機率分佈,$\theta_d$指每一文件對應主題的機率分佈,$z$指每個字被分配到的主題,式子可以寫成:

$$
p(\phi,\theta,z,w) = p(\phi) \prod_{d=1}^D \prod_{n=1}^{N_d} p(z_{dn}|\theta_d)p(w_{dn}|z_{dn},\phi)
$$

式子對應Pyro程式碼如下:

1
2
3
4
5
6
7
8
9
10
@pyro.poutine.broadcast
def model(data):
phi = pyro.sample("phi",dist.Dirichlet(torch.ones([K, V])).independent(1))

for d in pyro.irange("documents", D):
theta_d = pyro.sample("theta_%d"%d, dist.Dirichlet(torch.ones([K])))

with pyro.iarange("words_%d"%d, N[d]):
z = pyro.sample("z_%d"%d, dist.Categorical(theta_d))
pyro.sample("w_%d"%d, dist.Categorical(phi[z]), obs=data[d])

Guide

$w$是已知觀測變量,想要估計的隱變量有$\phi,\theta,z$,我們可以用上Variational inference方法來最大化資料似然$\log p(w)$,並用另一個guide來近似$p(\phi,\theta,z|w)$,有關於VI相關會在後續寫個幾篇好好來講講

$$
q(\phi,\theta,z) = q(\phi) \prod_{d=1}^D q(\theta_d)\prod_{n=1}^{N_d} q(z_{dn})
$$

式子對應Pyro程式碼如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
@pyro.poutine.broadcast
def guide(data):
beta_q = pyro.param("beta_q", torch.ones([K, V]),constraint=constraints.positive)
phi_q = pyro.sample("phi",dist.Dirichlet(beta_q).independent(1))

for d in pyro.irange("documents", D):
alpha_q = pyro.param("alpha_q_%d"%d, torch.ones([K]),constraint=constraints.positive)
q_theta_d = pyro.sample("theta_%d"%d, dist.Dirichlet(alpha_q))

with pyro.iarange("words_%d"%d, N[d]):
q_i = pyro.param("q_%d"%d, torch.randn([N[d], K]).exp(),
constraint=constraints.simplex)
pyro.sample("z_%d"%d, dist.Categorical(q_i))

注意guide裡面有參數是需要訓練的,比如每一個字所對應的$z_dn$,我們都用一個Categorical分佈來代表主題,都有參數需要訓練,很直接去估文件裡的每個字代表什麼主題

Training

最後用SVI去最大化ELBO

1
2
3
4
5
6
7
adam_params = {"lr": 0.01, "betas": (0.90, 0.999)}
optimizer = Adam(adam_params)

svi = SVI(model, config_enumerate(guide, 'parallel'), optimizer, loss=TraceEnum_ELBO(max_iarange_nesting=1))

for _ in range(3000):
loss = svi.step(data)

TraceEnum_ELBO適用於有離散變量需要估計的時候,會直接去做enumerate,不然單用Trace_ELBO遇到離散變量效果都很差

Result

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
pyro.param('q_2')
>>>
tensor([[0.8072, 0.0559, 0.0606, 0.0543, 0.0220],
[0.9249, 0.0258, 0.0146, 0.0242, 0.0104],
[0.9250, 0.0258, 0.0145, 0.0243, 0.0105],
[0.3842, 0.0408, 0.5215, 0.0381, 0.0153],
[0.9406, 0.0206, 0.0110, 0.0193, 0.0085],
[0.9406, 0.0206, 0.0111, 0.0193, 0.0084],
[0.9406, 0.0204, 0.0110, 0.0192, 0.0088],
[0.9406, 0.0205, 0.0111, 0.0193, 0.0085],
[0.9250, 0.0258, 0.0145, 0.0242, 0.0105]], grad_fn=<DivBackward1>)

z[2]
>>>
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0])

看一下第二筆文件裡的字,估計主題效果還不錯

1
2
3
4
5
6
7
8
9
10
phi[0]
>>>
tensor([1.1714e-03, 7.1110e-01, 1.4448e-03, 1.8015e-03, 1.0734e-01, 2.6613e-06,
1.1921e-07, 3.6218e-03, 9.7121e-02, 1.1921e-07, 8.0717e-03, 7.9834e-07,
6.8322e-02, 1.1921e-07, 1.1921e-07])

dist.Dirichlet(pyro.param('beta_q')).sample()[0]
>>>
tensor([0.0089, 0.4932, 0.0083, 0.0549, 0.1486, 0.0355, 0.0272, 0.0291, 0.0509,
0.0098, 0.0045, 0.0631, 0.0331, 0.0021, 0.0308])

來看第1個主題對應字的機率,都是第2個字的機率最高,代表主題1最相關的字為第2個,雖然分佈可能還有點差異,但效果也還不差

結語

深度學習結合機率方法是趨勢,這些PPL框架能更快投入研究與產品,後續會再分享一些Pyro應用與心得