Pyro 是一個由Uber AI Lab所開發的Probabilistic Programming Language(PPL),用程式語言來描述具有隨機性的程序或是過程;從另一個角度來看,圖機率模型(PGM)是用圖的方式來描述一個機率過程,而PPL則是可以讓你用上程式語言來描述,條件、迴圈等等都可以用上,能描述更為複雜的機率過程
PPL也存在好一陣子了,像是MIT基於LISP所開發的Church、webppl都是,但都偏向是學術性的語言;隨著深度學習也用上了一些機率推斷方法,像是MCMC、Variational inference等,也出現基於Tensorflow、Pytorch框架的PPL,能更好的與深度學習方法結合,並且引入很多有用的機率數學工具
Introduction of Pyro
根據官網描述,Pyro有幾個特點
- Universal: Pyro can represent any computable probability distribution.
- Scalable: Pyro scales to large data sets with little overhead.
- Minimal: Pyro is implemented with a small core of powerful, composable abstractions.
- Flexible: Pyro aims for automation when you want it, control when you need it.
基於Pytorch開發,所以深度學習那些當然都可以整合一起用上,並且語法非常的Pythonic,也用上了很多高階特性,例如context manager,機率推斷演算法主要提供Stochastic Variational Inference(SVI),可以用上SGD來訓練模型,整體上的確以最小需要、高彈性、透明為主要設計,相較於基於Tensorflow的Edward、Tensorflow_probability,提供不只VI相關演算法,還有很多MCMC相關的,或是前一篇提到的NF,更多工具可以使用,但也相對不是那麼容易上手;不過結合更多機率特性是趨勢之一,這兩個主流的Deep PPL都很值得關注與學習。
Implement LDA model
以上介紹了Pyro,接著就來實作Latent Dirichlet allocation(LDA)模型,常用文本主題分析,且為圖機率模型的經典模型!接著用Pyro與SVI來實作並求解參數,本文並不會從頭介紹Pyro基礎,建議可以先看官網教學,程式碼我放在Colab,大家可以直接跑看看
Model
我們就直接從模型看起,先來看一下LDA式子長怎樣,$\phi$指每個字對應個主題的機率分佈,$\theta_d$指每一文件對應主題的機率分佈,$z$指每個字被分配到的主題,式子可以寫成:
$$
p(\phi,\theta,z,w) = p(\phi) \prod_{d=1}^D \prod_{n=1}^{N_d} p(z_{dn}|\theta_d)p(w_{dn}|z_{dn},\phi)
$$
式子對應Pyro程式碼如下:
1 |
|
Guide
$w$是已知觀測變量,想要估計的隱變量有$\phi,\theta,z$,我們可以用上Variational inference方法來最大化資料似然$\log p(w)$,並用另一個guide來近似$p(\phi,\theta,z|w)$,有關於VI相關會在後續寫個幾篇好好來講講
$$
q(\phi,\theta,z) = q(\phi) \prod_{d=1}^D q(\theta_d)\prod_{n=1}^{N_d} q(z_{dn})
$$
式子對應Pyro程式碼如下:
1 |
|
注意guide裡面有參數是需要訓練的,比如每一個字所對應的$z_dn$,我們都用一個Categorical分佈來代表主題,都有參數需要訓練,很直接去估文件裡的每個字代表什麼主題
Training
最後用SVI去最大化ELBO
1 | adam_params = {"lr": 0.01, "betas": (0.90, 0.999)} |
TraceEnum_ELBO適用於有離散變量需要估計的時候,會直接去做enumerate,不然單用Trace_ELBO遇到離散變量效果都很差
Result
1 | pyro.param('q_2') |
看一下第二筆文件裡的字,估計主題效果還不錯
1 | phi[0] |
來看第1個主題對應字的機率,都是第2個字的機率最高,代表主題1最相關的字為第2個,雖然分佈可能還有點差異,但效果也還不差
結語
深度學習結合機率方法是趨勢,這些PPL框架能更快投入研究與產品,後續會再分享一些Pyro應用與心得