SketchRNN

之前Google有一個有趣的小遊戲Quickdraw,給你一個題目要你20秒內畫出來,當然除了讓你打發時間外,大家手繪大作也變成機器的精神糧食,拿來學習了XD

7FB391C9-1C05-4051-88FE-80FF2961ED8F

SketchRNN是一個向量軌跡的生成模型,這篇部落格文章可以看到很多有趣的結果,詳細的模型架構與訓練過程可以看這一篇論文

整個模型可分為3個部分:

  • Seq2seq autoencoder + Gaussian mixture model
  • Variational inference

Seq2seq autoencoder + Gaussian mixture model

Seq2seq+GMM這是很經典的序列至序列架構,最早是出現在這一篇Generating Sequences With Recurrent Neural Networks,並且有一個手寫筆跡產生的應用,也算是這個塗鴉的前身之作,作者是Alex Graves,他可是Hinton的博士生喔!

架構上是一個序列至序列模型,$encoder$是雙向的LSTM,每個時間點吃入筆跡$x,y$實數座標值與下筆狀態類別,而$decoder$是單向的LSTM,每個時間點輸出下一個時間點的筆跡座標與下筆狀態

下筆狀態是離散的類別,這個很容易處理,只是簡單的分類問題,用個$softmax$就搞定了,輸出就是離散類別的機率分佈,但是連續資料我們要怎麼處理呢?怎麼建立連續的機率分佈呢?

我們常用高斯分佈是來作為連續變量的機率分佈,每個時間點可以輸出$\mu,\sigma$來建立高斯分佈,但對於複雜的資料,只用一個高斯分佈可能太過簡單,無法很好的捕捉資料的分佈,例如:筆跡座標可能有蠻大的變異性,很難說每個時間點就只用一個高斯就能代表,畢盡畫圖很隨性的呀~

所以用上GMM來增加輸出的機率分佈複雜度,由多個高斯分佈組合起來,形成一個複雜的機率分佈,如圖,用上3個不同參數的高斯分佈,就能組合出紅色的這個複雜分佈

每個時間點模型給出的預測$\hat{y}$為公式17,分別為是否停止、GMM的參數等,就可以直接做MLE訓練

BD8784F1-5355-4BE2-A726-BBB86271B739

Variational inference

VAE這個東西是一個非常熱門的生成模型,下次有機會再開一篇來講講,如果想簡單了解,可以直接看這裡,直觀上SketchRNN透過$encoder$把序列壓縮到一個連續空間中,這個連續空間為高斯分佈$N(0,1)$,而$decoder$要能還原回原本的軌跡,但我們能從連續空間中抽樣出新的點,由$decoder$產生軌跡,就是產生新的塗鴉囉!

Implement and Tricks

我的實作版本,以Pytorch實作

Data processing & Training

  • quickdraw-dataset有約300多種類別的塗鴉資料可以下載
  • 軌跡都是正整數,要先剔除一些異常資料,再做縮放正規化,這部分可以參考實作
  • Google實作還有做一些data augmentation,有些類別資料很少,但我實作裡沒有
  • 訓練就跟一般seq2seq差不多,而且蠻快的
  • 產生新樣本時,不論是直接從高斯抽樣,或是從給定的序列去還原,有一個部分要特別注意,就是tempture,調整tempture會改變輸出的機率分佈,低的溫度會使分佈更為集中,高的反之亦然,溫度低筆觸就不會到處亂飄,因為機率分佈集中了,比較容易畫出像樣的圖,而溫度高,分佈就比較均勻,筆觸就容易到處飄,就不容易畫出好的圖

Reference

  1. Teaching Machines to Draw
  2. A Neural Representation of Sketch Drawings
  3. magenta/models/sketch_rnn
  4. alexis-jacq/Pytorch-Sketch-RNN