Variational Inference: Foundations and Modern Methods
David Blei, Rajesh Ranganath, Shakir Mohamed
NIPS 2016 Tutorial · December 5, 2016
Communities discovered in a 3.7M node network of U.S. Patents [Gopalan and Blei, PNAS 2013]
Topics found in a corpus of 1.8 million articles from the New York Times. Modified from Hoffman et al. (2013). [Hoffman, Blei, Wang, Paisley, JMLR 2013]
1: Game Season Team Coach Play Points Games Giants Second Players
2: Life Know School Street Man Family Says House Children Night
3: Film Movie Show Life Television Films Director Man Story Says
4: Book Life Books Novel Story Man Author House War Children
5: Wine Street Hotel House Room Night Place Restaurant Park Garden
6: Bush Campaign Clinton Republican House Party Democratic Political Democrats Senator
7: Building Street Square Housing House Buildings Development Space Percent Real
8: Won Team Second Race Round Cup Open Game Play Win
9: Yankees Game Mets Season Run League Baseball Team Games Hit
10: Government War Military Officials Iraq Forces Iraqi Army Troops Soldiers
11: Children School Women Family Parents Child Life Says Help Mother
12: Stock Percent Companies Fund Market Bank Investors Funds Financial Business
13: Church War Women Life Black Political Catholic Government Jewish Pope
14: Art Museum Show Gallery Works Artists Street Artist Paintings Exhibition
15: Police Yesterday Man Officer Officers Case Found Charged Street Shot
Scenes, concepts and control [Eslami et al., 2016; Lake et al., 2015]
[Figure from "Attend, Infer, Repeat: Fast Scene Understanding with Generative Models": on 3D scenes, AIR closely aligns with ground-truth object and camera poses and achieves significantly lower reconstruction error and much higher count inference accuracy than a fully supervised network.]
Population analysis of 2 billion genetic measurements [Gopalan, Hao, Blei, Storey, Nature Genetics (in press)]
Neuroscience analysis of 220 million fMRI measurements [Manning et al., PLOS ONE 2014]
[Figure: mean and sampled reconstructions compared with JPEG and JPEG2000 at 0.2 bits/pixel.]
Compression and content generation. [Van den Oord et al., 2016, Gregor et al., 2016]
Analysis of 1.7M taxi trajectories, in Stan [Kucukelbir et al., 2016]
The probabilistic pipeline
Make assumptions → Discover patterns → Predict & Explore
[Figure: knowledge & question plus data enter the pipeline. Discovered patterns are illustrated by population structure inferred from the TGP data with TeraStructure at K = 7, 8, 9 (Figure S2): clusters match the major geographical regions, some identifying a specific region (e.g., Africa) and others admixture between regions (e.g., Europeans and Central/South Americans); the new K = 9 cluster is unpopulated.]
Customized data analysis is important to many fields.
Pipeline separates assumptions, computation, application
Eases collaborative solutions to statistics problems
The probabilistic pipeline
Make assumptions → Discover patterns → Predict & Explore
[Figure repeated: the pipeline with population-structure results.]
Inference is the key algorithmic problem.
Answers the question: What does this model say about this data?
Our goal: General and scalable approaches to inference
The probabilistic pipeline
Make assumptions → Discover patterns → Predict & Explore → Criticize model → Revise
[Figure repeated: the pipeline with population-structure results.]
[Box, 1980; Rubin, 1984; Gelman et al., 1996; Blei, 2014]
PART I Main ideas and historical context
Probabilistic Machine Learning
A probabilistic model is a joint distribution of hidden variables z and observed variables x, p(z, x).
Inference about the unknowns is through the posterior, the conditional distribution of the hidden variables given the observations:
p(z | x) = p(z, x) / p(x).
For most interesting models, the denominator is not tractable. We appeal to approximate posterior inference.
Variational Inference
[Figure: a variational family q(z; ν) near the exact posterior p(z | x); optimization moves from ν_init to ν*, minimizing KL(q(z; ν*) || p(z | x)).]
VI turns inference into optimization.
Posit a variational family of distributions over the latent variables, q(z; ν)
Fit the variational parameters ν to be close (in KL) to the exact posterior. (There are alternative divergences, which connect to algorithms like EP, BP, and others.)
Example: Mixture of Gaussians
[Figure: the variational fit at initialization and at iterations 20, 28, 35, and 50; the evidence lower bound and the average log predictive both increase over the iterations. Images by Alp Kucukelbir.]
History
[Figure: excerpts from Peterson and Anderson (1987) and from Jordan et al. (1999), including mean-field convergence plots and the Markov-blanket picture of the mean-field equations for sigmoid belief networks.]
Variational inference adapts ideas from statistical physics to probabilistic inference. Arguably, it began in the late eighties with Peterson and Anderson (1987), who used mean-field methods to fit a neural network.
This idea was picked up by Jordan's lab in the early 1990s (Tommi Jaakkola, Lawrence Saul, Zoubin Ghahramani), who generalized it to many probabilistic models. (A review paper is Jordan et al., 1999.)
In parallel, Hinton and van Camp (1993) also developed mean-field methods for neural networks. Neal and Hinton (1993) connected this idea to the EM algorithm, which led to further variational methods for mixtures of experts (Waterhouse et al., 1996) and HMMs (MacKay, 1997).
[Peterson and Anderson 1987; Jordan et al. 1999; Hinton and van Camp 1993]
Today
[Figures: samples from deep latent Gaussian models on the NORB, CIFAR10, and Frey faces datasets; imputation and denoising of MNIST digits under 60%, 80%, and patch missingness; a two-dimensional latent embedding of MNIST in which the digit classes separate. Kingma and Welling 2013; Rezende et al. 2014]
There is now a flurry of new work on variational inference, making it more scalable, easier to derive, faster, and more accurate, and applying it to more complicated models and applications.
For example, consider a model with a Poisson likelihood and an unknown positive rate theta; placing a Weibull prior on theta gives a nonconjugate, differentiable probability model:

data {
  int N;         // number of observations
  int x[N];      // discrete-valued observations
}
parameters {
  real<lower=0> theta;   // latent variable, must be positive
}
model {
  // non-conjugate prior for latent variable
  theta ~ weibull(1.5, 1);
  // likelihood
  for (n in 1:N)
    x[n] ~ poisson(theta);
}

Figure 2: Specifying a simple nonconjugate probability model in Stan. [Kucukelbir et al. 2015]

Modern VI touches many important areas: probabilistic programming, reinforcement learning, neural networks, convex optimization, Bayesian statistics, and myriad applications.
Our goal today is to teach you the basics, explain some of the newer ideas, and to suggest open areas of new research.
Variational Inference: Foundations and Modern Methods

Part II: Mean-field VI and stochastic VI
Jordan+, Introduction to Variational Methods for Graphical Models, 1999
Ghahramani and Beal, Propagation Algorithms for Variational Bayesian Learning, 2001
Hoffman+, Stochastic Variational Inference, 2013

Part III: Stochastic gradients of the ELBO
Kingma and Welling, Auto-Encoding Variational Bayes, 2014
Ranganath+, Black Box Variational Inference, 2014
Rezende+, Stochastic Backpropagation and Approximate Inference in Deep Generative Models, 2014

Part IV: Beyond the mean field
Agakov and Barber, An Auxiliary Variational Method, 2004
Gregor+, DRAW: A Recurrent Neural Network for Image Generation, 2015
Rezende+, Variational Inference with Normalizing Flows, 2015
Ranganath+, Hierarchical Variational Models, 2015
Maaløe+, Auxiliary Deep Generative Models, 2016
Variational Inference: Foundations and Modern Methods
[Figure: the VI cartoon; optimize ν from ν_init to ν*, minimizing KL(q(z; ν*) || p(z | x)).]
VI approximates difficult quantities from complex models. With stochastic optimization we can
scale up VI to massive data
enable VI on a wide class of difficult models
enable VI with elaborate and flexible families of approximations
PART II Mean-field variational inference and stochastic variational inference
Motivation: Topic Modeling
Topic models use posterior inference to discover the hidden thematic structure in a large collection of documents.
Example: Latent Dirichlet Allocation (LDA)
Documents exhibit multiple topics.
Example: Latent Dirichlet Allocation (LDA)
[Figure: topics as distributions over words (e.g., gene 0.04, dna 0.02, genetic 0.01; life 0.02, evolve 0.01, organism 0.01; brain 0.04, neuron 0.02, nerve 0.01; data 0.02, number 0.02, computer 0.01), documents, and topic proportions and assignments.]
Each topic is a distribution over words
Each document is a mixture of corpus-wide topics
Each word is drawn from one of those topics
[Figure repeated: topics, documents, topic proportions and assignments.]
But we only observe the documents; everything else is hidden.
So we want to calculate the posterior p(topics, proportions, assignments | documents) (Note: millions of documents; billions of latent variables)
LDA as a Graphical Model
[Graphical model: proportions parameter α → per-document topic proportions θ_d → per-word topic assignment z_{d,n} → observed word w_{d,n}; topics β_k with topic parameter η; plates over N words, D documents, and K topics.]
Encodes assumptions about data with a factorization of the joint
Connects assumptions to algorithms for computing with data
Defines the posterior (through the joint)
Posterior Inference
[Graphical model repeated: α, θ_d, z_{d,n}, w_{d,n}, η, β_k.]
The posterior of the latent variables given the documents is
p(β, θ, z | w) = p(β, θ, z, w) / ∫_β ∫_θ Σ_z p(β, θ, z, w).
We can’t compute the denominator, the marginal p(w).
We use approximate inference.
[Figure repeated: topics found in 1.8M articles from the New York Times. Modified from Hoffman et al. (2013).]
Mean-field VI and Stochastic VI
[Figure: the SVI cycle: subsample data, infer local structure, update global structure.]
Road map:
Define the generic class of conditionally conjugate models
Derive classical mean-field VI
Derive stochastic VI, which scales to massive data
A Generic Class of Models
[Graphical model: global variables β; local variables z_i paired with observations x_i, in a plate of size n.]
p(β, z, x) = p(β) ∏_{i=1}^n p(z_i, x_i | β)
The observations are x = x_{1:n}.
The local variables are z = z_{1:n}.
The global variables are β.
The ith data point x_i only depends on z_i and β.
The goal is to compute p(β, z | x).
A complete conditional is the conditional of a latent variable given the observations and the other latent variables. Assume each complete conditional is in the exponential family:
p(z_i | β, x_i) = h(z_i) exp{η_ℓ(β, x_i)ᵀ z_i − a(η_ℓ(β, x_i))}
p(β | z, x) = h(β) exp{η_g(z, x)ᵀ β − a(η_g(z, x))}.
The natural parameter of the global complete conditional comes from conjugacy [Bernardo and Smith, 1994]:
η_g(z, x) = α + Σ_{i=1}^n t(z_i, x_i),
where α is a hyperparameter and t(·) are sufficient statistics for (z_i, x_i).
Examples in this class:
Bayesian mixture models
Time series models (HMMs, linear dynamical systems)
Factorial models
Matrix factorization (factor analysis, PCA, CCA)
Dirichlet process mixtures, HDPs
Multilevel regression (linear, probit, Poisson)
Stochastic block models
Mixed-membership models (LDA and some variants)
Variational Inference
[Figure: the VI cartoon; optimize ν from ν_init to ν*, minimizing KL(q(z; ν*) || p(z | x)).]
Minimize KL between q(β, z; ν) and the posterior p(β, z | x).
The Evidence Lower Bound
L (ν) = Eq [log p(β, z, x)] − Eq [log q(β, z; ν)]
KL is intractable; VI optimizes the evidence lower bound (ELBO) instead.
The ELBO trades off two terms.
It is a lower bound on log p(x). Maximizing the ELBO is equivalent to minimizing the KL.
The first term prefers q(·) to place its mass on the MAP estimate. The second term encourages q(·) to be diffuse.
Caveat: The ELBO is not convex.
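The ELBO is an expectation under q, so when it lacks a closed form it can still be estimated by Monte Carlo. A minimal NumPy sketch for a toy model (the model, names, and settings here are illustrative choices, not from the tutorial):

import numpy as np

def log_joint(z, x):
    # Toy model: z ~ N(0, 1), x_i | z ~ N(z, 1); returns log p(x, z) at each z
    log_prior = -0.5 * z**2 - 0.5 * np.log(2 * np.pi)
    log_lik = np.sum(-0.5 * (x[:, None] - z)**2 - 0.5 * np.log(2 * np.pi), axis=0)
    return log_prior + log_lik

def elbo_estimate(mu, sigma, x, S=1000, seed=0):
    # Monte Carlo estimate of L(nu) = E_q[log p(x, z) - log q(z; nu)], q = N(mu, sigma^2)
    z = np.random.default_rng(seed).normal(mu, sigma, size=S)
    log_q = -0.5 * ((z - mu) / sigma)**2 - np.log(sigma) - 0.5 * np.log(2 * np.pi)
    return np.mean(log_joint(z, x) - log_q)

x = np.random.default_rng(1).normal(2.0, 1.0, size=20)
print(elbo_estimate(mu=1.9, sigma=0.2, x=x))  # a stochastic lower bound on log p(x)

Maximizing this estimate over (µ, σ) is exactly the optimization that the rest of the tutorial makes practical.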
Mean-field Variational Inference
[Figure: the model (β, z_i, x_i) next to the fully factorized variational family.]
We need to specify the form of q(β, z). The mean-field family is fully factorized:
q(β, z; λ, φ) = q(β; λ) ∏_{i=1}^n q(z_i; φ_i).
Each factor is in the same family as the model's complete conditional:
p(β | z, x) = h(β) exp{η_g(z, x)ᵀ β − a(η_g(z, x))}
q(β; λ) = h(β) exp{λᵀ β − a(λ)}.
Optimize the ELBO:
L(λ, φ) = E_q[log p(β, z, x)] − E_q[log q(β, z)].
Traditional VI uses coordinate ascent [Ghahramani and Beal, 2001]:
λ* = E_φ[η_g(z, x)];  φ_i* = E_λ[η_ℓ(β, x_i)]
Iteratively update each parameter, holding the others fixed.
Notice the relationship to Gibbs sampling [Gelfand and Smith, 1990].
Caveat: the ELBO is not convex.
Mean-field Variational Inference for LDA
[Graphical model: LDA beside its mean-field variational family, with parameters γ_d, φ_{d,n}, λ_k.]
The local variables are the per-document variables θ_d and z_d.
The global variables are the topics β_1, . . . , β_K.
The variational distribution is
q(β, θ, z) = ∏_{k=1}^K q(β_k; λ_k) ∏_{d=1}^D q(θ_d; γ_d) ∏_{n=1}^N q(z_{d,n}; φ_{d,n})
[Figure: inferred probability over topics 1-100 for a single document.]
Example inferred topics:
"Genetics": human genome dna genetic genes sequence gene molecular sequencing map information genetics mapping project sequences
"Evolution": evolution evolutionary species organisms life origin biology groups phylogenetic living diversity group new two common
"Disease": disease host bacteria diseases resistance bacterial new strains control infectious malaria parasite parasites united tuberculosis
"Computers": computer models information data computers system network systems model parallel methods networks software new simulations
Classical Variational Inference
Input: data x, model p(β, z, x).
Initialize λ randomly.
repeat:
  for each data point i:
    Set local parameter φ_i ← E_λ[η_ℓ(β, x_i)].
  Set global parameter λ ← α + Σ_{i=1}^n E_{φ_i}[t(Z_i, x_i)].
until the ELBO has converged
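As a concrete instance of this algorithm, here is a minimal NumPy sketch of coordinate-ascent VI for a Bayesian mixture of K unit-variance Gaussians (µ_k ~ N(0, σ₀²), assignments z_i uniform over components); this follows the standard textbook example, and all names are mine:

import numpy as np

def cavi_gmm(x, K=2, sigma0_sq=10.0, iters=50, seed=0):
    # q(mu_k) = N(m[k], s2[k]), q(z_i) = Categorical(phi[i])
    rng = np.random.default_rng(seed)
    m = rng.normal(size=K)          # variational means of the component locations
    s2 = np.ones(K)                 # variational variances
    for _ in range(iters):
        # Local step: phi[i, k] proportional to exp(E_q[log p(x_i | mu_k)])
        logits = x[:, None] * m[None, :] - 0.5 * (m**2 + s2)[None, :]
        logits -= logits.max(axis=1, keepdims=True)   # stabilize the softmax
        phi = np.exp(logits)
        phi /= phi.sum(axis=1, keepdims=True)
        # Global step: Gaussian complete conditional of each mu_k
        s2 = 1.0 / (1.0 / sigma0_sq + phi.sum(axis=0))
        m = s2 * (phi * x[:, None]).sum(axis=0)
    return m, s2, phi

x = np.concatenate([np.random.default_rng(1).normal(-2, 1, 200),
                    np.random.default_rng(2).normal(3, 1, 200)])
m, s2, phi = cavi_gmm(x)
print(m)  # the component means concentrate near -2 and 3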
A Generic Class of Models (recap)
p(β, z, x) = p(β) ∏_{i=1}^n p(z_i, x_i | β)
[Examples repeated: Bayesian mixture models; time series models; factorial models; matrix factorization; Dirichlet process mixtures, HDPs; multilevel regression; stochastic block models; mixed-membership models.]
Stochastic Variational Inference
[Graphical model: LDA with its variational parameters.]
Classical VI is inefficient:
Do some local computation for each data point. Aggregate these computations to re-estimate global structure. Repeat.
This cannot handle massive data.
Stochastic variational inference (SVI) scales VI to massive data.
Stochastic Variational Inference
[Figure: global hidden structure (population structure at K = 7, 8, 9) inferred from massive data.]
Subsample data → infer local structure → update global structure.
Stochastic Optimization
Replace the gradient with cheaper noisy estimates [Robbins and Monro, 1951]
Guaranteed to converge to a local optimum [Bottou, 1996]
Has enabled modern machine learning
Stochastic Optimization
With noisy gradients, update
ν_{t+1} = ν_t + ρ_t ∇̂_ν L(ν_t)
Requires unbiased gradients: E[∇̂_ν L(ν)] = ∇_ν L(ν)
Requires that the step-size sequence ρ_t follow the Robbins-Monro conditions (Σ_t ρ_t = ∞ and Σ_t ρ_t² < ∞)
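A minimal sketch of stochastic gradient ascent under these conditions, on a toy objective (ρ_t = (t + 1)^{-0.7} is one common choice satisfying the Robbins-Monro conditions; everything here is illustrative):

import numpy as np

rng = np.random.default_rng(0)
nu, target = 0.0, 3.0                            # maximize L(nu) = -(nu - target)^2 / 2
for t in range(2000):
    noisy_grad = (target - nu) + rng.normal()    # unbiased: its mean is the exact gradient
    rho = (t + 1) ** -0.7                        # sum of rho_t diverges, sum of rho_t^2 converges
    nu += rho * noisy_grad
print(nu)                                        # converges near 3.0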
Stochastic Variational Inference
The natural gradient of the ELBO [Amari, 1998; Sato, 2001]:
∇_λ^{nat} L(λ) = α + Σ_{i=1}^n E_{φ_i*}[t(Z_i, x_i)] − λ.
Construct a noisy natural gradient by sampling j ~ Uniform(1, . . . , n):
∇̂_λ^{nat} L(λ) = α + n E_{φ_j*}[t(Z_j, x_j)] − λ.
This is a good noisy gradient:
Its expectation is the exact gradient (unbiased).
It only depends on optimized parameters of one data point (cheap).
Stochastic Variational Inference
Input: data x, model p(β, z, x).
Initialize λ randomly. Set ρ_t appropriately.
repeat:
  Sample j ~ Unif(1, . . . , n).
  Set local parameter φ ← E_λ[η_ℓ(β, x_j)].
  Set intermediate global parameter λ̂ = α + n E_φ[t(Z_j, x_j)].
  Set global parameter λ = (1 − ρ_t)λ + ρ_t λ̂.
until forever
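Continuing the mixture-of-Gaussians sketch from the classical-VI example above (same hypothetical model and names), SVI touches one data point per iteration and averages in natural-parameter space:

import numpy as np

def svi_gmm(x, K=2, sigma0_sq=10.0, steps=5000, seed=0):
    rng = np.random.default_rng(seed)
    n = len(x)
    m, s2 = rng.normal(size=K), np.ones(K)
    for t in range(steps):
        j = rng.integers(n)                        # subsample one data point
        logits = x[j] * m - 0.5 * (m**2 + s2)      # local step for x_j
        phi = np.exp(logits - logits.max())
        phi /= phi.sum()
        # Intermediate global parameters, as if x_j were replicated n times
        s2_hat = 1.0 / (1.0 / sigma0_sq + n * phi)
        m_hat = s2_hat * n * phi * x[j]
        # Weighted average in natural-parameter space: (m / s2, -1 / (2 s2))
        rho = (t + 1) ** -0.7
        eta1 = (1 - rho) * (m / s2) + rho * (m_hat / s2_hat)
        eta2 = (1 - rho) * (-0.5 / s2) + rho * (-0.5 / s2_hat)
        s2 = -0.5 / eta2
        m = eta1 * s2
    return m, s2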
Stochastic Variational Inference
[Figure repeated: subsample data → infer local structure → update global structure.]
Stochastic Variational Inference in LDA
[Graphical model: LDA with its variational parameters.]
Sample a document
Estimate the local variational parameters using the current topics
Form intermediate topics from those local parameters
Update topics as a weighted average of intermediate and current topics
Stochastic Variational Inference in LDA
[Figure: held-out perplexity vs. documents seen (log scale). Online LDA fit to 98K and 3.3M documents reaches lower perplexity than batch LDA on 98K documents, and the top eight words of a topic sharpen as more documents are analyzed.]
[Hoffman et al., 2010]
[Figure repeated: topics, here using the HDP, found in 1.8M articles from the New York Times. Modified from Hoffman et al. (2013).]
SVI scales many models
Subsample data → infer local structure → update global structure.
Bayesian mixture models
Time series models (HMMs, linear dynamical systems)
Factorial models
Matrix factorization (factor analysis, PCA, CCA)
Dirichlet process mixtures, HDPs
Multilevel regression (linear, probit, Poisson)
Stochastic block models
Mixed-membership models (LDA and some variants)
[Figures: topics discovered in Science articles (e.g., "virus hiv aids infection viruses", "stars astronomers universe galaxies", "climate ocean ice changes climate change") and population structure across world populations.]
PART III Stochastic Gradients of the ELBO
Review: The Promise
[Figure repeated: the probabilistic pipeline; make assumptions, discover patterns, predict & explore.]
Realized for conditionally conjugate models
What about the general case?
The Variational Inference Recipe
Start with a model: p(z, x)
Choose a variational approximation: q(z; ν)
Write down the ELBO: L(ν) = E_{q(z;ν)}[log p(x, z) − log q(z; ν)]
Compute the expectation (integral). Example: L(ν) = xν² + log ν
Take derivatives. Example: ∇_ν L(ν) = 2xν + 1/ν
Optimize: ν_{t+1} = ν_t + ρ_t ∇_ν L
The Variational Inference Recipe
[Diagram: from p(x, z) and q(z; ν), compute the integral ∫(···) q(z; ν) dz, then take ∇_ν.]
Example: Bayesian Logistic Regression
Data pairs (x_i, y_i):
x_i are covariates
y_i are labels
z is the regression coefficient
Generative process:
z ~ N(0, 1)
y_i | x_i, z ~ Bernoulli(σ(z x_i))
VI for Bayesian Logistic Regression
Assume:
We have one data point (y, x)
x is a scalar
The approximating family q is normal: ν = (µ, σ²)
The ELBO is L(µ, σ²) = E_q[log p(z) + log p(y | x, z) − log q(z)]
L(µ, σ²) = E_q[log p(z) − log q(z) + log p(y | x, z)]
        = −(1/2)(µ² + σ²) + (1/2) log σ² + E_q[log p(y | x, z)] + C
        = −(1/2)(µ² + σ²) + (1/2) log σ² + E_q[yxz − log(1 + exp(xz))]
        = −(1/2)(µ² + σ²) + (1/2) log σ² + yxµ − E_q[log(1 + exp(xz))]
We are stuck:
1. We cannot analytically take that expectation.
2. The expectation hides the objective's dependence on the variational parameters, which makes it hard to optimize directly.
Options?
Derive a model-specific bound: [Jaakkola and Jordan, 1996], [Braun and McAuliffe, 2008], others
More general approximations that require model-specific analysis: [Wang and Blei; 2013], [Knowles and Minka; 2011]
Nonconjugate Models
Nonlinear time series models
Discrete choice models
Deep latent Gaussian models
Bayesian neural networks
Models with attention (such as DRAW)
Generalized linear models (Poisson regression)
Stochastic volatility models
Deep exponential families (e.g., sparse gamma or Poisson)
Correlated topic models (including nonparametric variants)
Sigmoid belief networks
We need a solution that does not entail model-specific work.
Black Box Variational Inference (BBVI)
[Diagram: any model + massive data → black box variational inference → reusable variational families.]
Sample from q(·)
Form noisy gradients without model-specific computation
Use stochastic optimization
The Problem in the Classical VI Recipe
[Diagram: p(x, z) and q(z; ν) → first compute ∫(···) q(z; ν) dz, then take ∇_ν.]
The New VI Recipe
[Diagram: p(x, z) and q(z; ν) → first take ∇_ν, then estimate ∫(···) q(z; ν) dz.]
Use stochastic optimization!
Computing Gradients of Expectations
Define g(z, ν) = log p(x, z) − log q(z; ν).
What is ∇_ν L?
∇_ν L = ∇_ν ∫ q(z; ν) g(z, ν) dz
      = ∫ ∇_ν q(z; ν) g(z, ν) + q(z; ν) ∇_ν g(z, ν) dz
      = ∫ q(z; ν) ∇_ν log q(z; ν) g(z, ν) + q(z; ν) ∇_ν g(z, ν) dz
      = E_{q(z;ν)}[∇_ν log q(z; ν) g(z, ν) + ∇_ν g(z, ν)]
using the identity ∇_ν log q = ∇_ν q / q.
Roadmap
Score Function Gradients
Pathwise Gradients
Amortized Inference
Score Function Gradients of the ELBO

Score Function Estimator
Recall ∇_ν L = E_{q(z;ν)}[∇_ν log q(z; ν) g(z, ν) + ∇_ν g(z, ν)].
Simplify using E_q[∇_ν g(z, ν)] = −E_q[∇_ν log q(z; ν)] = 0.
This gives the gradient
∇_ν L = E_{q(z;ν)}[∇_ν log q(z; ν)(log p(x, z) − log q(z; ν))]
Sometimes called likelihood ratio or REINFORCE gradients [Glynn 1990; Williams, 1992; Wingate+ 2013; Ranganath+ 2014; Mnih+ 2014]
Noisy Unbiased Gradients
Gradient: E_{q(z;ν)}[∇_ν log q(z; ν)(log p(x, z) − log q(z; ν))]
Noisy unbiased gradients with Monte Carlo:
(1/S) Σ_{s=1}^S ∇_ν log q(z_s; ν)(log p(x, z_s) − log q(z_s; ν)),  where z_s ~ q(z; ν)
Basic BBVI
Algorithm 1: Basic Black Box Variational Inference
Input: model log p(x, z), variational approximation q(z; ν)
Output: variational parameters ν
while not converged:
  Draw S samples z[s] ~ q
  ρ = t-th value of a Robbins-Monro sequence
  ν = ν + ρ (1/S) Σ_{s=1}^S ∇_ν log q(z[s]; ν)(log p(x, z[s]) − log q(z[s]; ν))
  t = t + 1
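A minimal NumPy sketch of Algorithm 1 for the one-data-point Bayesian logistic regression example from earlier, with q(z) = N(µ, σ²); the step-size schedule and sample sizes are illustrative choices:

import numpy as np

def log_joint(z, x, y):
    # log p(z) + log p(y | x, z): N(0, 1) prior, Bernoulli(sigmoid(z x)) likelihood
    return -0.5 * z**2 + y * x * z - np.logaddexp(0.0, x * z)

def bbvi_logistic(x, y, S=200, steps=2000, seed=0):
    rng = np.random.default_rng(seed)
    mu, log_sigma = 0.0, 0.0
    for t in range(steps):
        sigma = np.exp(log_sigma)
        z = rng.normal(mu, sigma, size=S)                 # draw S samples from q
        log_q = -0.5 * ((z - mu) / sigma)**2 - log_sigma - 0.5 * np.log(2 * np.pi)
        score_mu = (z - mu) / sigma**2                    # d log q / d mu
        score_ls = ((z - mu) / sigma)**2 - 1.0            # d log q / d log sigma
        w = log_joint(z, x, y) - log_q
        rho = (t + 1) ** -0.7                             # Robbins-Monro step size
        mu += rho * np.mean(score_mu * w)
        log_sigma += rho * np.mean(score_ls * w)
    return mu, np.exp(log_sigma)

print(bbvi_logistic(x=2.0, y=1))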
The Requirements for Inference
The noisy gradient:
(1/S) Σ_{s=1}^S ∇_ν log q(z_s; ν)(log p(x, z_s) − log q(z_s; ν)),  where z_s ~ q(z; ν)
To compute the noisy gradient of the ELBO we need to:
Sample from q(z)
Evaluate ∇_ν log q(z; ν)
Evaluate log p(x, z) and log q(z)
There is no model-specific work: the black box criteria are satisfied.
Black Box Variational Inference
[Diagram repeated: any model + massive data → BBVI → reusable variational families.]
Problem: Basic BBVI Doesn't Work
The variance of the gradient can be a problem:
Var_{q(z;ν)} = E_{q(z;ν)}[(∇_ν log q(z; ν)(log p(x, z) − log q(z; ν)) − ∇_ν L)²]
[Figure: the density of q, the absolute integrand, and the score function; rare samples can have large scores.]
Intuition: sampling rare values can lead to large scores and thus high variance.
Solution: Control Variates
Replace f with f̂, where E[f̂(z)] = E[f(z)]. A general such class:
f̂(z) = f(z) − a(h(z) − E[h(z)])
[Figure: densities of the estimates for f = x + x², comparing f̂ with h = x² and h = f.]
h is a function of our choice; a is chosen to minimize the variance. Good choices of h are highly correlated with the original function f.
For variational inference we need functions h with known expectation under q. Setting h = ∇_ν log q(z; ν) is simple, since E_q[∇_ν log q(z; ν)] = 0 for any q.
Many other techniques from Monte Carlo can help: importance sampling, quasi-Monte Carlo, Rao-Blackwellization [Ruiz+ 2016; Ranganath+ 2014; Titsias+ 2015; Mnih+ 2016]
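A minimal sketch of the control variate on this slide's own example: estimate E[f(z)] for f(z) = z + z² with z ~ N(0, 1), using h(z) = z² (whose expectation, 1, is known) and the variance-minimizing a = Cov(f, h) / Var(h):

import numpy as np

rng = np.random.default_rng(0)
z = rng.normal(size=10_000)
f = z + z**2
h = z**2                                  # known expectation: E[h] = 1 under N(0, 1)
a = np.cov(f, h)[0, 1] / np.var(h)        # minimizes the variance of f_hat
f_hat = f - a * (h - 1.0)
print(np.var(f), np.var(f_hat))           # variance drops (roughly 3 vs. 1 here)
print(f.mean(), f_hat.mean())             # both estimate E[f] = 1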
[List repeated: nonconjugate models, from nonlinear time series models to sigmoid belief networks.]
We can design models based on data rather than inference.
More Assumptions?
The current black box criteria
Sampling from q(z)
Evaluating ∇_ν log q(z; ν)
Evaluating log p(x, z) and log q(z)
Can we make additional assumptions that are not too restrictive?
Pathwise Gradients of the ELBO
Pathwise Estimator
Assume:
1. z = t(ε, ν) for ε ~ s(ε) implies z ~ q(z; ν).
   Example: ε ~ Normal(0, 1) and z = εσ + µ gives z ~ Normal(µ, σ²).
2. log p(x, z) and log q(z) are differentiable with respect to z.
Recall ∇_ν L = E_{q(z;ν)}[∇_ν log q(z; ν) g(z, ν) + ∇_ν g(z, ν)].
Rewrite using z = t(ε, ν):
∇_ν L = E_{s(ε)}[∇_ν log s(ε) g(t(ε, ν), ν) + ∇_ν g(t(ε, ν), ν)]
Since s(ε) does not depend on ν, only the second term survives:
∇L(ν) = E_{s(ε)}[∇_ν g(t(ε, ν), ν)]
      = E_{s(ε)}[∇_z[log p(x, z) − log q(z; ν)] ∇_ν t(ε, ν) − ∇_ν log q(z; ν)]
      = E_{s(ε)}[∇_z[log p(x, z) − log q(z; ν)] ∇_ν t(ε, ν)]
This is also known as the reparameterization gradient.
[Glasserman 1991; Fu 2006; Kingma+ 2014; Rezende+ 2014; Titsias+ 2014]
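A minimal NumPy sketch of the pathwise gradient for the Bayesian logistic regression example, with z = µ + σε; it uses the Gaussian entropy analytically rather than sampling the −log q term, a common variant (all names are mine):

import numpy as np

def grad_log_joint(z, x, y):
    # d/dz [log p(z) + log p(y | x, z)] for the N(0, 1) prior and logistic likelihood
    return -z + y * x - x / (1.0 + np.exp(-x * z))

def pathwise_grads(mu, log_sigma, x, y, S=100, seed=0):
    sigma = np.exp(log_sigma)
    eps = np.random.default_rng(seed).normal(size=S)
    z = mu + sigma * eps                      # z = t(eps, nu)
    g = grad_log_joint(z, x, y)
    grad_mu = np.mean(g)                      # dz/dmu = 1
    grad_ls = np.mean(g * sigma * eps) + 1.0  # dz/d(log sigma) = sigma * eps; +1 from the entropy
    return grad_mu, grad_ls

print(pathwise_grads(0.0, 0.0, x=2.0, y=1))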
Variance Comparison
[Figure: gradient-estimator variance vs. number of samples on a multivariate nonlinear regression model; the pathwise estimator exhibits lower variance than the score function estimator, even when the latter uses a control variate. Kucukelbir+ 2016]
Score Function Estimator vs. Pathwise Estimator

Score function:
Differentiates the density: ∇_ν q(z; ν)
Works for discrete and continuous models
Works for a large class of variational approximations
Variance can be a big problem

Pathwise:
Differentiates the function: ∇_z[log p(x, z) − log q(z; ν)]
Requires differentiable models
Requires the variational approximation to have the form z = t(ε, ν)
Generally better-behaved variance
Amortized Inference
Hierarchical Models
A generic class of models:
p(β, z, x) = p(β) ∏_{i=1}^n p(z_i, x_i | β)
[Graphical model: global variables β; local variables z_i with observations x_i.]
Examples: Bayesian mixture models, Dirichlet process mixtures, time series models, multilevel regression.
Mean-Field Variational Approximation
[Figure: the model beside its fully factorized approximation.]
SVI: Revisited
Input: data x, model p(β, z, x).
Initialize λ randomly. Set ρ_t appropriately.
repeat:
  Sample j ~ Unif(1, . . . , n).
  Set local parameter φ ← E_λ[η_ℓ(β, x_j)].
  Set intermediate global parameter λ̂ = α + n E_φ[t(Z_j, x_j)].
  Set global parameter λ = (1 − ρ_t)λ + ρ_t λ̂.
until forever
SVI: The Problem
In nonconjugate models, the expectations E_λ[η_ℓ(β, x_j)] and E_φ[t(Z_j, x_j)] are no longer tractable; an inner stochastic optimization would be needed for each data point.
Idea: learn a mapping f from x_i to φ_i.
Amortizing Inference
ELBO: L(λ, φ_{1:n}) = E_q[log p(β, z, x)] − E_q[log q(β; λ) + Σ_{i=1}^n log q(z_i; φ_i)]
Amortizing the ELBO with an inference network f:
L(λ, θ) = E_q[log p(β, z, x)] − E_q[log q(β; λ) + Σ_{i=1}^n log q(z_i | x_i; φ_i = f_θ(x_i))]
[Dayan+ 1995; Heess+ 2013; Gershman+ 2014, many others]
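A minimal sketch of an inference network: a small MLP mapping each data point x_i to its local variational parameters (µ_i, log σ_i) in one forward pass, instead of running an inner optimization per data point (the architecture is an illustrative choice):

import numpy as np

def init_inference_net(d_in, d_hidden=32, seed=0):
    rng = np.random.default_rng(seed)
    return {"W1": rng.normal(scale=0.1, size=(d_in, d_hidden)),
            "b1": np.zeros(d_hidden),
            "W2": rng.normal(scale=0.1, size=(d_hidden, 2)),  # outputs (mu, log sigma)
            "b2": np.zeros(2)}

def inference_net(theta, x):
    # phi_i = f_theta(x_i): amortized local variational parameters
    h = np.tanh(x @ theta["W1"] + theta["b1"])
    out = h @ theta["W2"] + theta["b2"]
    return out[..., 0], out[..., 1]          # mu_i, log sigma_i

theta = init_inference_net(d_in=5)
X = np.random.default_rng(1).normal(size=(10, 5))
mu, log_sigma = inference_net(theta, X)      # one pass gives phi for all 10 points

The parameters θ are shared across data points and trained by ascending the amortized ELBO above.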
Amortized SVI
Input: data x, model p(β, z, x).
Initialize λ randomly. Set ρ_t appropriately.
repeat:
  Sample β ~ q(β; λ).
  Sample j ~ Unif(1, . . . , n).
  Sample z_j ~ q(z_j | x_j; f_θ(x_j)).
  Compute stochastic gradients:
  ∇̂_λ L = ∇_λ log q(β; λ)(log p(β) + n log p(x_j, z_j | β) − log q(β; λ))
  ∇̂_θ L = n ∇_θ log q(z_j | x_j; θ)(log p(x_j, z_j | β) − log q(z_j | x_j; θ))
  Update λ = λ + ρ_t ∇̂_λ and θ = θ + ρ_t ∇̂_θ.
until forever
A Computational-Statistical Tradeoff
Amortized inference is faster, but admits a smaller class of approximations: the amortized family {∏_{i=1}^n q(z_i | x_i; f_θ(x_i))} is contained in the free family {∏_{i=1}^n q(z_i; φ_i)}.
The size of the smaller class depends on the flexibility of f.
Example: Variational Autoencoder (VAE)
[Graphical model: z → x]
p(z) = Normal(0, 1)
p(x | z) = Normal(µ_β(z), σ²_β(z))
µ and σ² are deep networks with parameters β. [Kingma+ 2014; Rezende+ 2014]
[Diagram: the inference network q(z | x) maps data x to z ~ q(z | x); the model p(x | z) maps z back to x ~ p(x | z).]
q(z | x) = Normal(f_θ^µ(x), f_θ^{σ²}(x))
All functions are deep networks.
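A minimal NumPy sketch of one VAE ELBO evaluation, using a Bernoulli likelihood for binary pixels (a simplification of the Gaussian decoder above; training would backpropagate through this computation, and all sizes and names are illustrative):

import numpy as np

def init_mlp(d_in, d_out, d_h=64, seed=0):
    rng = np.random.default_rng(seed)
    return {"W1": rng.normal(scale=0.1, size=(d_in, d_h)), "b1": np.zeros(d_h),
            "W2": rng.normal(scale=0.1, size=(d_h, d_out)), "b2": np.zeros(d_out)}

def mlp(p, x):
    return np.tanh(x @ p["W1"] + p["b1"]) @ p["W2"] + p["b2"]

def vae_elbo(enc, dec, x, d_z, seed=0):
    out = mlp(enc, x)                        # encoder: q(z | x) = N(mu, diag(sigma^2))
    mu, log_sigma = out[:d_z], out[d_z:]
    eps = np.random.default_rng(seed).normal(size=d_z)
    z = mu + np.exp(log_sigma) * eps         # reparameterized sample
    logits = mlp(dec, z)                     # decoder: p(x | z) = Bernoulli(sigmoid(logits))
    log_lik = np.sum(x * logits - np.logaddexp(0.0, logits))
    kl = 0.5 * np.sum(np.exp(2 * log_sigma) + mu**2 - 1.0 - 2 * log_sigma)
    return log_lik - kl                      # analytic KL(q(z | x) || N(0, I))

d_x, d_z = 784, 20
enc, dec = init_mlp(d_x, 2 * d_z), init_mlp(d_z, d_x, seed=1)
x = (np.random.default_rng(2).random(d_x) < 0.5).astype(float)
print(vae_elbo(enc, dec, x, d_z))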
[Figure: analogy-making with a VAE.]
Rules of Thumb for a New Model
If log p(x, z) is differentiable in z:
Try out an approximation q that is reparameterizable.
If log p(x, z) is not differentiable in z:
Use the score function estimator with control variates.
Add further variance reductions based on experimental evidence.
General Advice:
Use coordinate-specific learning rates (e.g., RMSProp, AdaGrad)
Annealing + Tempering
Consider parallelizing across samples from q
Software
Systems with Variational Inference:
Venture, WebPPL, Edward, Stan, PyMC3, Infer.net, Anglican
Good for trying out lots of models
Differentiation Tools:
Theano, Torch, Tensorflow, Stan Math, Caffe
Can lead to more scalable implementations of individual models
PART IV Beyond the Mean Field
Review: Variational Bound and Optimisation
[Diagram repeated: black box variational inference.]
Probabilistic modelling and variational inference.
Scalable inference through stochastic optimisation.
Black-box variational inference: non-conjugate models, Monte Carlo gradient estimators, and amortised inference.
These advances empower us with new ways to design more flexible approximate posterior distributions q(z).
Mean-field Approximations
[Figure: the VI cartoon and a fully factorised posterior over z_1, z_2, z_3.]
Fully factorised: q_MF(z|x) = ∏_k q(z_k)
The key part of the algorithm is the choice of approximate posterior q(z):
log p(x) ≥ L = E_{q(z|x)}[log p(x, z)] − E_{q(z|x)}[log q(z|x)]
(the first term is the expected likelihood; the second is the entropy)
Mean-Field Posterior Approximations
[Graphical model: a deep latent Gaussian model p(x, z), with prior p(z), latent z, and likelihood p(x | z).]
A mean-field or fully-factorised posterior is usually not sufficient.
Real-world Posterior Distributions
[Graphical model repeated: deep latent Gaussian model.]
Complex dependencies · non-Gaussian distributions · multiple modes
Families of Approximate Posteriors
Two high-level goals:
Build richer approximate posterior distributions.
Maintain computational efficiency and scalability.
[Figure: a spectrum of approximations, from the true posterior q*(z|x) ∝ p(x|z) p(z) (most expressive) to the fully factorised q_MF(z|x) = ∏_k q(z_k) (least expressive).]
Designing q is the same as the problem of specifying a model of the data itself.
Structured Posterior Approximations
[Figure: structured approximations sit between the true posterior (most expressive) and the fully factorised family (least expressive).]
q(z) = ∏_k q_k(z_k | {z_j}_{j≠k})
Structured mean field: introduce any form of dependency to provide a richer approximating class of distributions. [Saul and Jordan, 1996]
Gaussian Approximate Posteriors
Use a correlated Gaussian: q_G(z; ν) = N(z | µ, Σ), with variational parameters ν = {µ, Σ}.
Covariance models: the structure of Σ describes the dependency. Full covariance is richest, but computationally expensive.
Mean-field: Σ = diag(α_1, . . . , α_K)
Rank-1: Σ = diag(α_1, . . . , α_K) + uuᵀ
Rank-J: Σ = diag(α_1, . . . , α_K) + Σ_j u_j u_jᵀ
Full: Σ = UUᵀ
[Figure: test negative marginal likelihood for full, rank-1, diagonal, wake-sleep, and factor analysis posteriors.]
However the covariance is structured, the approximate posterior is always Gaussian.
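A minimal sketch of sampling from the rank-1 family Σ = diag(α) + uuᵀ without ever forming Σ, which is what keeps these structured Gaussians cheap (a standard construction; names are mine):

import numpy as np

def sample_rank1_gaussian(mu, alpha, u, seed=0):
    # z ~ N(mu, diag(alpha) + u u^T) in O(K) per sample:
    # independent eps1 (K-dim) and eps2 (scalar) give the required covariance
    rng = np.random.default_rng(seed)
    eps1 = rng.normal(size=len(mu))
    eps2 = rng.normal()
    return mu + np.sqrt(alpha) * eps1 + u * eps2

z = sample_rank1_gaussian(mu=np.zeros(3),
                          alpha=np.array([1.0, 0.5, 2.0]),
                          u=np.array([1.0, -1.0, 0.5]))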
Beyond Gaussian Approximations
Autoregressive distributions: impose an ordering and a non-linear dependency on all preceding variables.
q_AR(z; ν) = ∏_k q_k(z_k | z_{<k}; ν)
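A minimal sketch of sampling from an autoregressive Gaussian posterior, with each conditional's parameters depending (here linearly, for brevity; the slides allow non-linear maps) on all preceding variables:

import numpy as np

def sample_autoregressive_q(W_mu, b_mu, W_ls, b_ls, seed=0):
    # q_AR(z) = prod_k N(z_k | mu_k(z_{<k}), sigma_k(z_{<k})^2)
    rng = np.random.default_rng(seed)
    K = len(b_mu)
    z = np.zeros(K)
    for k in range(K):
        mu = W_mu[k, :k] @ z[:k] + b_mu[k]        # depends on z_1, ..., z_{k-1}
        log_sigma = W_ls[k, :k] @ z[:k] + b_ls[k]
        z[k] = mu + np.exp(log_sigma) * rng.normal()
    return z

K = 4
rng = np.random.default_rng(1)
z = sample_autoregressive_q(rng.normal(size=(K, K)) * 0.5, np.zeros(K),
                            rng.normal(size=(K, K)) * 0.1, np.zeros(K))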