Variational Inference: Foundations and Modern Methods


David Blei, Rajesh Ranganath, Shakir Mohamed

NIPS 2016 Tutorial · December 5, 2016

Communities discovered in a 3.7M node network of U.S. Patents [Gopalan and Blei, PNAS 2013]

[Figure: topics found in 1.8M New York Times articles, e.g. sports (Game, Season, Team, Coach, Play), politics (Bush, Campaign, Clinton, Republican, House), finance (Stock, Percent, Companies, Fund, Market), war (Government, War, Military, Officials, Iraq), and the arts (Art, Museum, Show, Gallery, Works). Modified from Hoffman et al. (2013).]

[Hoffman, Blei, Wang, Paisley, JMLR 2013]

Scenes, concepts and control: the Attend, Infer, Repeat (AIR) model for fast scene understanding. [Eslami et al., 2016; Lake et al., 2015]


Population analysis of 2 billion genetic measurements [Gopalan, Hao, Blei, Storey, Nature Genetics (in press)]

Neuroscience analysis of 220 million fMRI measurements [Manning et al., PLOS ONE 2014]

[Figure: reconstructions at 0.2 bits/pixel compared with JPEG and JPEG 2000.]

Compression and content generation. [Van den Oord et al., 2016, Gregor et al., 2016]

Analysis of 1.7M taxi trajectories, in Stan [Kucukelbir et al., 2016]

The probabilistic pipeline

KNOWLEDGE & QUESTION

DATA

R A. 7Aty

This content downloaded from 128.59.38.144 on Thu, 12 Nov 2015 01:49:31 UTC All use subject to JSTOR Terms and Conditions

K=7 LWK

YRI

ACB

ASW

CDX

CHB

CHS

JPT

KHV

LWK

YRI

ACB

ASW

CDX

CHB

CHS

JPT

KHV

CEU

FIN

GBR

IBS

TSI

MXL

PUR

CLM

PEL

GIH

FIN

GBR

IBS

TSI

MXL

PUR

CLM

PEL

GIH

pops 1 2 3 4 5 6 7

K=8

Make assumptions

CEU

Discover patterns

pops 1 2 3 4 5 6 7

Predict & Explore

8

K=9 LWK

YRI

ACB

ASW

CDX

CHB

CHS

JPT

KHV

CEU

FIN

GBR

IBS

TSI

MXL

PUR

CLM

PEL

GIH pops 1 2 3 4 5 6 7 8 9

Figure S2: Population structure inferred from the TGP data set using the TeraStructure algorithm at three values for the number of populations K. The visualization of the ✓’s in the Figure shows patterns consistent with the major geographical regions. Some of the clusters identify a specific region (e.g. red for Africa) while others represent admixture between regions (e.g. green for Europeans and Central/South Americans). The presence of clusters that are shared between different regions demonstrates the more continuous nature of the structure. The new cluster from K = 7 to K = 8 matches structure differentiating between American groups. For K = 9, the new cluster is unpopulated.

„

Customized data analysis is important to many fields.

„

Pipeline separates assumptions, computation, application

„

Eases collaborative solutions to statistics problems


The probabilistic pipeline

KNOWLEDGE & QUESTION

DATA

R A. 7Aty

This content downloaded from 128.59.38.144 on Thu, 12 Nov 2015 01:49:31 UTC All use subject to JSTOR Terms and Conditions

K=7 LWK

YRI

ACB

ASW

CDX

CHB

CHS

JPT

KHV

LWK

YRI

ACB

ASW

CDX

CHB

CHS

JPT

KHV

CEU

FIN

GBR

IBS

TSI

MXL

PUR

CLM

PEL

GIH

FIN

GBR

IBS

TSI

MXL

PUR

CLM

PEL

GIH

pops 1 2 3 4 5 6 7

K=8

Make assumptions

CEU

Discover patterns

pops 1 2 3 4 5 6 7

Predict & Explore

8

K=9 LWK

YRI

ACB

ASW

CDX

CHB

CHS

JPT

KHV

CEU

FIN

GBR

IBS

TSI

MXL

PUR

CLM

PEL

GIH pops 1 2 3 4 5 6 7 8 9

Figure S2: Population structure inferred from the TGP data set using the TeraStructure algorithm at three values for the number of populations K. The visualization of the ✓’s in the Figure shows patterns consistent with the major geographical regions. Some of the clusters identify a specific region (e.g. red for Africa) while others represent admixture between regions (e.g. green for Europeans and Central/South Americans). The presence of clusters that are shared between different regions demonstrates the more continuous nature of the structure. The new cluster from K = 7 to K = 8 matches structure differentiating between American groups. For K = 9, the new cluster is unpopulated.

„

Inference is the key algorithmic problem.

„

Answers the question: What does this model say about this data?

„

Our goal: General and scalable approaches to inference


KNOWLEDGE & QUESTION

DATA

R A. 7Aty

This content downloaded from 128.59.38.144 on Thu, 12 Nov 2015 01:49:31 UTC All use subject to JSTOR Terms and Conditions

K=7 LWK

YRI

ACB

ASW

CDX

CHB

CHS

JPT

KHV

CEU

FIN

GBR

IBS

TSI

MXL

PUR

CLM

PEL

GIH pops 1 2 3 4 5 6 7

K=8 LWK

Make assumptions

YRI

ACB

ASW

CDX

CHB

CHS

JPT

KHV

CEU

FIN

GBR

IBS

TSI

MXL

PUR

CLM

PEL

GIH

Discover patterns

pops 1 2 3 4 5 6 7

Predict & Explore

8

K=9 LWK

YRI

ACB

ASW

CDX

CHB

CHS

JPT

KHV

CEU

FIN

GBR

IBS

TSI

MXL

PUR

CLM

PEL

GIH pops 1 2 3 4 5 6 7 8 9

Figure S2: Population structure inferred from the TGP data set using the TeraStructure algorithm at three values for the number of populations K. The visualization of the ✓’s in the Figure shows patterns consistent with the major geographical regions. Some of the clusters identify a specific region (e.g. red for Africa) while others represent admixture between regions (e.g. green for Europeans and Central/South Americans). The presence of clusters that are shared between different regions demonstrates the more continuous nature of the structure. The new cluster from K = 7 to K = 8 matches structure differentiating between American groups. For K = 9, the new cluster is unpopulated.

28

Criticize model Revise

[Box, 1980; Rubin, 1984; Gelman et al., 1996; Blei, 2014]

PART I Main ideas and historical context

Probabilistic Machine Learning

„

A probabilistic model is a joint distribution of hidden variables z and observed variables x, p(z, x).

„

Inference about the unknowns is through the posterior, the conditional distribution of the hidden variables given the observations,

p(z | x) = p(z, x) / p(x).

„

For most interesting models, the denominator is not tractable. We appeal to approximate posterior inference.

Variational Inference

[Figure: the variational family {q(z; ν)}; starting from ν_init, optimization moves toward the member ν* that minimizes KL(q(z; ν*) || p(z | x)).]

„

VI turns inference into optimization.

„

Posit a variational family of distributions over the latent variables, q(z; ν)

„

Fit the variational parameters ν to be close (in KL) to the exact posterior. (There are alternative divergences, which connect to algorithms like EP, BP, and others.)

Example: Mixture of Gaussians

[Figure: coordinate-ascent VI on a Gaussian mixture. Panels show the fit at initialization and at iterations 20, 28, 35, and 50; the evidence lower bound and the average log predictive both increase over roughly 60 iterations. Images by Alp Kucukelbir.]

History

[Figures from Peterson and Anderson (1987) and Jordan et al. (1999): mean-field fits of a Boltzmann machine on the 2-4-1 XOR problem, and the mean-field consistency equations on a node's Markov blanket in a sigmoid belief network.]

[Peterson and Anderson 1987]   [Hinton and van Camp 1993]   [Jordan et al. 1999]

„ Variational inference adapts ideas from statistical physics to probabilistic inference. Arguably, it began in the late eighties with Peterson and Anderson (1987), who used mean-field methods to fit a neural network.

„ This idea was picked up by Jordan's lab in the early 1990s (Tommi Jaakkola, Lawrence Saul, Zoubin Ghahramani), who generalized it to many probabilistic models. (A review paper is Jordan et al., 1999.)

„ In parallel, Hinton and Van Camp (1993) also developed mean-field methods for neural networks. Neal and Hinton (1993) connected this idea to the EM algorithm, which led to further variational methods for mixtures of experts (Waterhouse et al., 1996) and HMMs (MacKay, 1997).

Today

[Figure: samples and imputations from deep latent Gaussian models fit with amortized variational inference: NORB, CIFAR-10, and Frey faces samples, MNIST imputation, and a 2-D MNIST embedding. Kingma and Welling 2013; Rezende et al. 2014]

A nonconjugate example from Kucukelbir et al. 2015 (their Figure 2): a Poisson likelihood whose positive rate has a Weibull prior, specified in Stan.

data {
  int N;      // number of observations
  int x[N];   // discrete-valued observations
}
parameters {
  // latent variable, must be positive
  real<lower=0> theta;
}
model {
  // non-conjugate prior for latent variable
  theta ~ weibull(1.5, 1);

  // likelihood
  for (n in 1:N)
    x[n] ~ poisson(theta);
}

„ There is now a flurry of new work on variational inference, making it more scalable, easier to derive, faster, and more accurate, and applying it to more complicated models and applications.

„ Modern VI touches many important areas: probabilistic programming, reinforcement learning, neural networks, convex optimization, Bayesian statistics, and myriad applications.

„ Our goal today is to teach you the basics, explain some of the newer ideas, and to suggest open areas of new research.

Variational Inference: Foundations and Modern Methods

Part II: Mean-field VI and stochastic VI
  Jordan+, Introduction to Variational Methods for Graphical Models, 1999
  Ghahramani and Beal, Propagation Algorithms for Variational Bayesian Learning, 2001
  Hoffman+, Stochastic Variational Inference, 2013

Part III: Stochastic gradients of the ELBO
  Kingma and Welling, Auto-Encoding Variational Bayes, 2014
  Ranganath+, Black Box Variational Inference, 2014
  Rezende+, Stochastic Backpropagation and Approximate Inference in Deep Generative Models, 2014

Part IV: Beyond the mean field
  Agakov and Barber, An Auxiliary Variational Method, 2004
  Gregor+, DRAW: A Recurrent Neural Network for Image Generation, 2015
  Rezende+, Variational Inference with Normalizing Flows, 2015
  Ranganath+, Hierarchical Variational Models, 2015
  Maaløe+, Auxiliary Deep Generative Models, 2016

Variational Inference: Foundations and Modern Methods

[Figure: optimizing over the variational family {q(z; ν)}, from ν_init toward the ν* that minimizes KL(q(z; ν*) || p(z | x)).]

VI approximates difficult quantities from complex models. With stochastic optimization we can „

scale up VI to massive data

„

enable VI on a wide class of difficult models

„

enable VI with elaborate and flexible families of approximations

PART II Mean-field variational inference and stochastic variational inference

Motivation: Topic Modeling

Topic models use posterior inference to discover the hidden thematic structure in a large collection of documents.

Example: Latent Dirichlet Allocation (LDA)

Documents exhibit multiple topics.

Example: Latent Dirichlet Allocation (LDA)

[Figure: the LDA cartoon. Topics are distributions over words (e.g. gene/dna/genetic, life/evolve/organism, brain/neuron/nerve, data/number/computer); each document mixes the corpus-wide topics (its topic proportions) and draws each word from one of them (its topic assignment).]

„

Each topic is a distribution over words

„

Each document is a mixture of corpus-wide topics

„

Each word is drawn from one of those topics

Example: Latent Dirichlet Allocation (LDA)

[The same cartoon, with only the documents observed; the topics, topic proportions, and per-word assignments are hidden.]

„

But we only observe the documents; everything else is hidden.

„

So we want to calculate the posterior p(topics, proportions, assignments | documents) (Note: millions of documents; billions of latent variables)

LDA as a Graphical Model

[Graphical model: proportions parameter α → per-document topic proportions θ_d → per-word topic assignment z_{d,n} → observed word w_{d,n} ← topics β_k, with topic parameter η; plates over N words, D documents, and K topics.]

„

Encodes assumptions about data with a factorization of the joint

„

Connects assumptions to algorithms for computing with data

„

Defines the posterior (through the joint)

Posterior Inference

[The LDA graphical model again: α → θ_d → z_{d,n} → w_{d,n} ← β_k, with plates over N, D, K.]

„ The posterior of the latent variables given the documents is

p(β, θ, z | w) = p(β, θ, z, w) / ( ∫_β ∫_θ Σ_z p(β, θ, z, w) ).

„

We can’t compute the denominator, the marginal p(w).

„

We use approximate inference.

[Figure: topics found in a corpus of 1.8 million articles from the New York Times. Modified from Hoffman et al. (2013).]

Mean-field VI and Stochastic VI

Subsample data

Infer local structure

Road map: „

Define the generic class of conditionally conjugate models

„

Derive classical mean-field VI

„

Derive stochastic VI, which scales to massive data

Update global structure

A Generic Class of Models

[Graphical model: global variables β; local variables z_i paired with observations x_i, i = 1, ..., n.]

p(β, z, x) = p(β) ∏_{i=1}^n p(z_i, x_i | β)

„

The observations are x = x1:n .

„

The local variables are z = z1:n .

„

The global variables are β.

„

The ith data point xi only depends on zi and β. Compute p(β, z | x).

A Generic Class of Models

p(β, z, x) = p(β) ∏_{i=1}^n p(z_i, x_i | β)

„ A complete conditional is the conditional of a latent variable given the observations and the other latent variables.

„ Assume each complete conditional is in the exponential family:

p(z_i | β, x_i) = h(z_i) exp{η_ℓ(β, x_i)ᵀ z_i − a(η_ℓ(β, x_i))}
p(β | z, x) = h(β) exp{η_g(z, x)ᵀ β − a(η_g(z, x))}
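As a concrete illustration (an example we add here, not on the slides): in a mixture of K unit-variance Gaussians with a N(0, σ²) prior on each component mean µ_k, the complete conditional of µ_k given the assignments z and the data x is itself Gaussian,

p(µ_k | z, x) = N( Σ_{i: z_i = k} x_i / (1/σ² + n_k),  1 / (1/σ² + n_k) ),

where n_k is the number of points currently assigned to component k. Its natural parameters have exactly the conjugate form η_g(z, x) = α + Σ_i t(z_i, x_i) described on the next slide.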

A Generic Class of Models

p(β, z, x) = p(β) ∏_{i=1}^n p(z_i, x_i | β)

„ A complete conditional is the conditional of a latent variable given the observations and the other latent variables.

„ The global parameter comes from conjugacy [Bernardo and Smith, 1994],

η_g(z, x) = α + Σ_{i=1}^n t(z_i, x_i),

where α is a hyperparameter and t(·) are sufficient statistics for [z_i, x_i].

A Generic Class of Models

p(β, z, x) = p(β) ∏_{i=1}^n p(z_i, x_i | β)

„ Bayesian mixture models
„ Time series models (HMMs, linear dynamic systems)
„ Factorial models
„ Matrix factorization (factor analysis, PCA, CCA)
„ Dirichlet process mixtures, HDPs
„ Multilevel regression (linear, probit, Poisson)
„ Stochastic block models
„ Mixed-membership models (LDA and some variants)

Variational Inference

[Figure: optimize over the variational family q(z; ν), from ν_init toward the ν* that minimizes KL(q(z; ν*) || p(z | x)).]

Minimize KL between q(β, z; ν) and the posterior p(β, z | x).

The Evidence Lower Bound

L (ν) = Eq [log p(β, z, x)] − Eq [log q(β, z; ν)]

„

KL is intractable; VI optimizes the evidence lower bound (ELBO) instead.
  ƒ It is a lower bound on log p(x).
  ƒ Maximizing the ELBO is equivalent to minimizing the KL.

„ The ELBO trades off two terms.
  ƒ The first term prefers q(·) to place its mass on the MAP estimate.
  ƒ The second term encourages q(·) to be diffuse.

„ Caveat: The ELBO is not convex.

Mean-field Variational Inference

[Figure: the model (global β, local z_i, x_i) next to the mean-field variational family, which breaks the dependencies and gives each latent variable its own free parameter.]

„ We need to specify the form of q(β, z).

„ The mean-field family is fully factorized,

q(β, z; λ, φ) = q(β; λ) ∏_{i=1}^n q(z_i; φ_i).

„ Each factor is in the same family as the model's complete conditional:

p(β | z, x) = h(β) exp{η_g(z, x)ᵀ β − a(η_g(z, x))}
q(β; λ) = h(β) exp{λᵀ β − a(λ)}

Mean-field Variational Inference

„ Optimize the ELBO,

L(λ, φ) = E_q[log p(β, z, x)] − E_q[log q(β, z)].

„ Traditional VI uses coordinate ascent [Ghahramani and Beal, 2001],

λ* = E_φ[η_g(z, x)];   φ_i* = E_λ[η_ℓ(β, x_i)].

„ Iteratively update each parameter, holding the others fixed.
  ƒ Notice the relationship to Gibbs sampling [Gelfand and Smith, 1990].
  ƒ Caveat: The ELBO is not convex.

Mean-field Variational Inference for LDA

[The LDA graphical model (α, θ_d, z_{d,n}, w_{d,n}, β_k, η) alongside its fully factorized variational family with parameters γ_d, φ_{d,n}, λ_k.]

„

The local variables are the per-document variables θd and zd .

„

The global variables are the topics β1 , . . . , βK .

„

The variational distribution is

q(β, θ, z) = ∏_{k=1}^K q(β_k; λ_k) ∏_{d=1}^D q(θ_d; γ_d) ∏_{n=1}^N q(z_{d,n}; φ_{d,n})

Mean-field Variational Inference for LDA

[Figure: a document's inferred topic proportions (bar plot over roughly 100 topics); most of the probability mass falls on a handful of topics.]

Mean-field Variational Inference for LDA

Top words of four inferred topics:
"Genetics": human, genome, dna, genetic, genes, sequence, gene, molecular, sequencing, map, information, genetics, mapping, project, sequences
"Evolution": evolution, evolutionary, species, organisms, life, origin, biology, groups, phylogenetic, living, diversity, group, new, two, common
"Disease": disease, host, bacteria, diseases, resistance, bacterial, new, strains, control, infectious, malaria, parasite, parasites, united, tuberculosis
"Computers": computer, models, information, data, computers, system, network, systems, model, parallel, methods, networks, software, new, simulations

Classical Variational Inference

Input: data x, model p(β, z, x).
Initialize λ randomly.
repeat
    for each data point i do
        Set local parameter φ_i ← E_λ[η_ℓ(β, x_i)].
    end
    Set global parameter λ ← α + Σ_{i=1}^n E_{φ_i}[t(Z_i, x_i)].
until the ELBO has converged
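To make the updates concrete, here is a minimal coordinate-ascent sketch in Python for a toy model that is not from the slides: a univariate mixture of K unit-variance Gaussians with a N(0, sigma2) prior on the component means, using the mean-field family q(µ_k) = N(m_k, s2_k) and q(c_i) = Categorical(φ_i). The function name, hyperparameters, and test data are our own choices; the local/global pattern is the algorithm above.

# A minimal sketch of coordinate-ascent VI (CAVI) for an assumed toy model:
# x_i ~ N(mu_{c_i}, 1), mu_k ~ N(0, sigma2), c_i uniform over K components.
import numpy as np

def cavi_gmm(x, K, sigma2=10.0, iters=100, seed=0):
    rng = np.random.default_rng(seed)
    n = x.shape[0]
    m = rng.normal(size=K)            # variational means of the components
    s2 = np.ones(K)                   # variational variances of the components
    phi = np.full((n, K), 1.0 / K)    # variational assignment probabilities
    for _ in range(iters):
        # Local step: update each q(c_i), holding q(mu) fixed.
        logits = np.outer(x, m) - 0.5 * (s2 + m**2)
        logits -= logits.max(axis=1, keepdims=True)   # numerical stability
        phi = np.exp(logits)
        phi /= phi.sum(axis=1, keepdims=True)
        # Global step: update each q(mu_k), holding q(c) fixed.
        denom = 1.0 / sigma2 + phi.sum(axis=0)
        m = (phi * x[:, None]).sum(axis=0) / denom
        s2 = 1.0 / denom
    return m, s2, phi

# Example: three well-separated clusters.
x = np.concatenate([np.random.normal(-5, 1, 100),
                    np.random.normal(0, 1, 100),
                    np.random.normal(5, 1, 100)])
m, s2, phi = cavi_gmm(x, K=3)
print("estimated component means:", np.sort(m))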

A Generic Class of Models

p(β, z, x) = p(β) ∏_{i=1}^n p(z_i, x_i | β)

„ Bayesian mixture models
„ Time series models (HMMs, linear dynamic systems)
„ Factorial models
„ Matrix factorization (factor analysis, PCA, CCA)
„ Dirichlet process mixtures, HDPs
„ Multilevel regression (linear, probit, Poisson)
„ Stochastic block models
„ Mixed-membership models (LDA and some variants)

Stochastic Variational Inference

[The LDA graphical model with variational parameters γ_d, φ_{d,n}, λ_k.]

„ Classical VI is inefficient:
  ƒ Do some local computation for each data point.
  ƒ Aggregate these computations to re-estimate global structure.
  ƒ Repeat.

„

This cannot handle massive data.

„

Stochastic variational inference (SVI) scales VI to massive data.

Stochastic Variational Inference

[Figure: MASSIVE DATA (the TGP genotypes) on one side and GLOBAL HIDDEN STRUCTURE (population memberships at K = 7, 8, 9) on the other. SVI cycles: subsample data → infer local structure → update global structure.]

Stochastic Optimization

„

Replace the gradient with cheaper noisy estimates [Robbins and Monro, 1951]

„

Guaranteed to converge to a local optimum [Bottou, 1996]

„

Has enabled modern machine learning

Stochastic Optimization

„ With noisy gradients, update

ν_{t+1} = ν_t + ρ_t ∇̂_ν L(ν_t)

„ Requires unbiased gradients, E[∇̂_ν L(ν)] = ∇_ν L(ν)

„ Requires the step-size sequence ρ_t to satisfy the Robbins-Monro conditions (Σ_t ρ_t = ∞, Σ_t ρ_t² < ∞)

Stochastic Variational Inference

„ The natural gradient of the ELBO [Amari, 1998; Sato, 2001]:

∇_λ^nat L(λ) = α + Σ_{i=1}^n E_{φ_i*}[t(Z_i, x_i)] − λ.

„ Construct a noisy natural gradient with j ∼ Uniform(1, . . . , n):

∇̂_λ^nat L(λ) = α + n E_{φ_j*}[t(Z_j, x_j)] − λ.

„ This is a good noisy gradient.
  ƒ Its expectation is the exact gradient (unbiased).
  ƒ It only depends on the optimized parameters of one data point (cheap).

Stochastic Variational Inference

Input: data x, model p(β, z, x).
Initialize λ randomly. Set ρ_t appropriately.
repeat
    Sample j ∼ Unif(1, . . . , n).
    Set local parameter φ ← E_λ[η_ℓ(β, x_j)].
    Set intermediate global parameter λ̂ = α + n E_φ[t(Z_j, x_j)].
    Set global parameter λ = (1 − ρ_t)λ + ρ_t λ̂.
until forever
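For concreteness, a minimal Python sketch of this loop (ours, not from the slides), for the same toy Gaussian mixture used in the CAVI sketch above: each iteration touches one randomly sampled data point, forms the intermediate global natural parameters from the prior plus n times that point's expected sufficient statistics, and takes a Robbins-Monro step. The step-size schedule and hyperparameters are illustrative choices.

# A minimal SVI sketch for the assumed toy mixture x_i ~ N(mu_{c_i}, 1),
# mu_k ~ N(0, sigma2). The global natural parameters of q(mu_k) = N(m_k, s2_k)
# are lam1_k = m_k / s2_k (precision times mean) and lam2_k = 1 / s2_k (precision).
import numpy as np

def svi_gmm(x, K, sigma2=10.0, iters=5000, seed=0):
    rng = np.random.default_rng(seed)
    n = x.shape[0]
    lam1 = rng.normal(size=K)          # precision-times-mean
    lam2 = np.ones(K)                  # precision
    for t in range(1, iters + 1):
        m, s2 = lam1 / lam2, 1.0 / lam2
        j = rng.integers(n)            # subsample one data point
        # Local step: optimal assignment probabilities for x_j.
        logits = x[j] * m - 0.5 * (s2 + m**2)
        phi = np.exp(logits - logits.max())
        phi /= phi.sum()
        # Intermediate global parameter: prior + n * (scaled sufficient statistics).
        lam1_hat = 0.0 + n * phi * x[j]
        lam2_hat = 1.0 / sigma2 + n * phi
        # Robbins-Monro step size and convex-combination update.
        rho = (t + 10.0) ** -0.7
        lam1 = (1 - rho) * lam1 + rho * lam1_hat
        lam2 = (1 - rho) * lam2 + rho * lam2_hat
    return lam1 / lam2, 1.0 / lam2

x = np.concatenate([np.random.normal(-5, 1, 200), np.random.normal(5, 1, 200)])
m, s2 = svi_gmm(x, K=2)
print("estimated component means:", np.sort(m))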

Stochastic Variational Inference

[The same picture: subsample data → infer local structure → update global structure, scaling to the massive TGP genotype data.]

Stochastic Variational Inference in LDA

[The LDA graphical model with variational parameters γ_d, φ_{d,n}, λ_k.]

„

Sample a document

„

Estimate the local variational parameters using the current topics

„

Form intermediate topics from those local parameters

„

Update topics as a weighted average of intermediate and current topics

Stochastic Variational Inference in LDA

[Figure: held-out perplexity against the number of documents seen (log scale). Online inference on 98K and 3.3M documents reaches lower perplexity than batch inference on 98K documents, and a topic's top eight words sharpen (toward business/service/companies/industry) as more documents are analyzed.]

[Hoffman et al., 2010]

[Figure: topics found with the HDP in a corpus of 1.8 million articles from the New York Times. Modified from Hoffman et al. (2013).]

SVI scales many models

Subsample data → Infer local structure → Update global structure

„ Bayesian mixture models
„ Time series models (HMMs, linear dynamic systems)
„ Factorial models
„ Matrix factorization (factor analysis, PCA, CCA)
„ Dirichlet process mixtures, HDPs
„ Multilevel regression (linear, probit, Poisson)
„ Stochastic block models
„ Mixed-membership models (LDA and some variants)

[Figure: two large-scale SVI applications: topics discovered in a large scientific corpus (e.g. immunology, genetics, climate, astronomy) and population structure inferred from worldwide genotype data.]

PART III Stochastic Gradients of the ELBO

Review: The Promise

KNOWLEDGE & QUESTION

DATA

R A. 7Aty

This content downloaded from 128.59.38.144 on Thu, 12 Nov 2015 01:49:31 UTC All use subject to JSTOR Terms and Conditions

K=7 LWK

YRI

ACB

ASW

CDX

CHB

CHS

JPT

KHV

LWK

YRI

ACB

ASW

CDX

CHB

CHS

JPT

KHV

CEU

FIN

GBR

IBS

TSI

MXL

PUR

CLM

PEL

GIH

FIN

GBR

IBS

TSI

MXL

PUR

CLM

PEL

GIH

pops 1 2 3 4 5 6 7

K=8

Make assumptions

CEU

Discover patterns

pops 1 2 3 4 5 6 7 8

K=9 LWK

YRI

ACB

ASW

CDX

CHB

CHS

JPT

KHV

CEU

FIN

GBR

IBS

TSI

MXL

PUR

CLM

PEL

GIH pops 1 2 3 4 5 6 7 8 9

Figure S2: Population structure inferred from the TGP data set using the TeraStructure algorithm at three values for the number of populations K. The visualization of the ✓’s in the Figure shows patterns consistent with the major geographical regions. Some of the clusters identify a specific region (e.g. red for Africa) while others represent admixture between regions (e.g. green for Europeans and Central/South Americans). The presence of clusters that are shared between different regions demonstrates the more continuous nature of the structure. The new cluster from K = 7 to K = 8 matches structure differentiating between American groups. For K = 9, the new cluster is unpopulated.

28

„

Realized for conditionally conjugate models

„

What about the general case?

Predict & Explore

The Variational Inference Recipe Start with a model:

p(z, x)

The Variational Inference Recipe Choose a variational approximation:

q(z; ν)

The Variational Inference Recipe Write down the ELBO:

L (ν) = Eq(z;ν) [log p(x, z) − log q(z; ν)]

The Variational Inference Recipe Compute the expectation (integral):

Example: L(ν) = xν² + log ν

The Variational Inference Recipe Take derivatives:

Example: ∇_ν L(ν) = 2xν + 1/ν

The Variational Inference Recipe Optimize:

νt+1 = νt + ρt ∇ν L

The Variational Inference Recipe

[Diagram: model p(x, z) + variational family q(z; ν) → compute the expectation ∫ (···) q(z; ν) dz → take gradients ∇_ν → optimize.]

Example: Bayesian Logistic Regression

„

Data pairs yi , xi

„

xi are covariates

„

yi are labels

„

z is the regression coefficient

„

Generative process p(z) ∼ N(0, 1)

p(yi | xi , z) ∼ Bernoulli(σ(zxi ))

VI for Bayesian Logistic Regression

Assume: „

We have one data point (y, x)

„

x is a scalar

„

The approximating family q is the normal; ν = (µ, σ2 )

The ELBO is L (µ, σ2 ) = Eq [log p(z) + log p(y | x, z) − log q(z)]

VI for Bayesian Logistic Regression

L(µ, σ²) = E_q[log p(z) − log q(z) + log p(y | x, z)]
         = −(1/2)(µ² + σ²) + (1/2) log σ² + E_q[log p(y | x, z)] + C
         = −(1/2)(µ² + σ²) + (1/2) log σ² + E_q[yxz − log(1 + exp(xz))]
         = −(1/2)(µ² + σ²) + (1/2) log σ² + yxµ − E_q[log(1 + exp(xz))]

We are stuck.
1. We cannot analytically take that expectation.
2. The expectation hides the objective's dependence on the variational parameters. This makes it hard to directly optimize.

Options?

„

Derive a model-specific bound: [Jordan and Jaakkola; 1996], [Braun and McAuliffe; 2008], others

„

More general approximations that require model-specific analysis: [Wang and Blei; 2013], [Knowles and Minka; 2011]

Nonconjugate Models

„

Nonlinear Time series Models

„

Discrete Choice Models

„

Deep Latent Gaussian Models

„

Bayesian Neural Networks

„ Models with Attention (such as DRAW)
„ Generalized Linear Models (Poisson Regression)
„ Stochastic Volatility Models
„ Deep Exponential Families (e.g. Sparse Gamma or Poisson)
„ Correlated Topic Model (including nonparametric variants)
„ Sigmoid Belief Network

We need a solution that does not entail model specific work

Black Box Variational Inference (BBVI)

[Diagram: ANY MODEL + MASSIVE DATA + REUSABLE VARIATIONAL FAMILIES → BLACK BOX VARIATIONAL INFERENCE → approximate posterior p(β, z | x).]

„ Sample from q(·)
„ Form noisy gradients without model-specific computation
„ Use stochastic optimization

The Problem in the Classical VI Recipe

[Diagram: the model p(x, z) and the family q(z; ν) feed into the integral ∫ (···) q(z; ν) dz before the gradient ∇_ν; the integral is the bottleneck.]

The New VI Recipe

[Diagram: take gradients ∇_ν first, then estimate the remaining expectation ∫ (···) q(z; ν) dz by sampling from q(z; ν).]

Use stochastic optimization!

Computing Gradients of Expectations „

Define g(z, ν) = log p(x, z) − log q(z; ν)

„

What is ∇_ν L?

∇_ν L = ∇_ν ∫ q(z; ν) g(z, ν) dz
      = ∫ [∇_ν q(z; ν)] g(z, ν) + q(z; ν) ∇_ν g(z, ν) dz
      = ∫ q(z; ν) [∇_ν log q(z; ν)] g(z, ν) + q(z; ν) ∇_ν g(z, ν) dz
      = E_{q(z;ν)}[∇_ν log q(z; ν) g(z, ν) + ∇_ν g(z, ν)]

using ∇_ν log q = ∇_ν q / q.

Roadmap

„

Score Function Gradients

„

Pathwise Gradients

„

Amortized Inference

Score Function Gradients of the ELBO

Score Function Estimator. Recall

∇_ν L = E_{q(z;ν)}[∇_ν log q(z; ν) g(z, ν) + ∇_ν g(z, ν)]

Simplify using E_q[∇_ν g(z, ν)] = −E_q[∇_ν log q(z; ν)] = 0. This gives the gradient

∇_ν L = E_{q(z;ν)}[∇_ν log q(z; ν) (log p(x, z) − log q(z; ν))]

Sometimes called likelihood ratio or REINFORCE gradients [Glynn 1990; Williams, 1992; Wingate+ 2013; Ranganath+ 2014; Mnih+ 2014]

Noisy Unbiased Gradients

Gradient: E_{q(z;ν)}[∇_ν log q(z; ν)(log p(x, z) − log q(z; ν))]

Noisy unbiased gradients with Monte Carlo!

(1/S) Σ_{s=1}^S ∇_ν log q(z_s; ν)(log p(x, z_s) − log q(z_s; ν)),   where z_s ∼ q(z; ν)

Basic BBVI

Algorithm 1: Basic Black Box Variational Inference
Input:  model log p(x, z), variational approximation q(z; ν)
Output: variational parameters ν
while not converged do
    z[s] ∼ q                         // draw S samples from q
    ρ = t-th value of a Robbins-Monro sequence
    ν = ν + ρ (1/S) Σ_{s=1}^S ∇_ν log q(z[s]; ν)(log p(x, z[s]) − log q(z[s]; ν))
    t = t + 1
end
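A minimal Python sketch of this algorithm (our illustration, not code from the tutorial), applied to the one-data-point Bayesian logistic regression example from earlier, with q(z; ν) = N(µ, σ²) and ν = (µ, log σ). Function names, step sizes, and the test point are assumptions made for the example.

# Basic BBVI with score-function gradients for the assumed model:
# p(z) = N(0, 1), p(y | x, z) = Bernoulli(sigmoid(z * x)), q(z; nu) = N(mu, sigma^2).
import numpy as np

def log_joint(z, x, y):
    # log p(z) + log p(y | x, z), up to an additive constant
    return -0.5 * z**2 + y * x * z - np.log1p(np.exp(x * z))

def log_q(z, mu, log_sigma):
    sigma = np.exp(log_sigma)
    return -0.5 * ((z - mu) / sigma) ** 2 - log_sigma - 0.5 * np.log(2 * np.pi)

def grad_log_q(z, mu, log_sigma):
    sigma = np.exp(log_sigma)
    d_mu = (z - mu) / sigma**2
    d_log_sigma = ((z - mu) / sigma) ** 2 - 1.0
    return np.array([d_mu, d_log_sigma])

def bbvi(x, y, S=200, iters=2000, seed=0):
    rng = np.random.default_rng(seed)
    mu, log_sigma = 0.0, 0.0
    for t in range(1, iters + 1):
        z = mu + np.exp(log_sigma) * rng.standard_normal(S)   # z[s] ~ q
        f = log_joint(z, x, y) - log_q(z, mu, log_sigma)
        grads = grad_log_q(z, mu, log_sigma) * f              # score * (log p - log q)
        rho = 0.5 / (t + 100.0)                               # Robbins-Monro step size
        mu += rho * grads[0].mean()
        log_sigma += rho * grads[1].mean()
    # Note: without control variates this estimator is noisy (see the next slides).
    return mu, np.exp(log_sigma)

print(bbvi(x=2.0, y=1))   # approximate posterior over the coefficient for one (x, y) pair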

The requirements for inference

The noisy gradient:

(1/S) Σ_{s=1}^S ∇_ν log q(z_s; ν)(log p(x, z_s) − log q(z_s; ν)),   where z_s ∼ q(z; ν)

To compute the noisy gradient of the ELBO we need to
„ sample from q(z)
„ evaluate ∇_ν log q(z; ν)
„ evaluate log p(x, z) and log q(z)

There is no model-specific work: the black box criteria are satisfied.

Black Box Variational Inference

[The black box VI diagram again: any model + massive data + reusable variational families.]

„ Sample from q(·)
„ Form noisy gradients without model-specific computation
„ Use stochastic optimization

Problem: Basic BBVI doesn’t work

Variance of the gradient can be a problem:

Var_{q(z;ν)} = E_{q(z;ν)}[ ( ∇_ν log q(z; ν)(log p(x, z) − log q(z; ν)) − ∇_ν L )² ].

[Figure: the density of q together with the magnitude of the score ∇_µ log q; the score grows large in the tails of q.]

Intuition: Sampling rare values can lead to large scores and thus high variance

Solution: Control Variates

Replace f with f̂, where E[f̂(z)] = E[f(z)]. A general such class:

f̂(z) ≜ f(z) − a (h(z) − E[h(z)])

[Figure: the density, the function f = x + x², and control-variate versions f̂ with h = x² and h = f.]

„ h is a function of our choice
„ a is chosen to minimize the variance
„ Good h have high correlation with the original function f

Solution: Control Variates

f̂(z) ≜ f(z) − a (h(z) − E[h(z)])

„ For variational inference we need functions with a known expectation under q
„ Set h to be ∇_ν log q(z; ν)
„ This is simple, since E_q[∇_ν log q(z; ν)] = 0 for any q

Solution: Control Variates

Many of the other techniques from Monte Carlo can help: importance sampling, quasi-Monte Carlo, Rao-Blackwellization [Ruiz+ 2016; Ranganath+ 2014; Titsias+ 2015; Mnih+ 2016]
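To see the effect numerically, here is a small Python sketch (an illustration we add, assuming a toy integrand that is not from the slides): the score-function gradient of E_{N(µ,1)}[z²] with respect to µ (true value 2µ), with and without the control variate h(z) = ∇_µ log q(z) = z − µ, using the variance-minimizing a estimated from the same samples.

# Control-variate demo on an assumed toy example: d/dmu E_{z ~ N(mu,1)}[z^2] = 2*mu.
import numpy as np

rng = np.random.default_rng(0)
mu, S = 1.0, 10000
z = rng.normal(mu, 1.0, size=S)

f = z**2                             # the integrand
score = z - mu                       # grad_mu log q(z) for q = N(mu, 1)
g = f * score                        # plain score-function gradient samples

h = score                            # control variate with known expectation E[h] = 0
a = np.cov(g, h)[0, 1] / np.var(h)   # variance-minimizing coefficient
g_cv = g - a * h                     # control-variated gradient samples (still unbiased)

print("true gradient:          ", 2 * mu)
print("score-function estimate:", g.mean(), "  std of mean:", g.std() / np.sqrt(S))
print("with control variate:   ", g_cv.mean(), "  std of mean:", g_cv.std() / np.sqrt(S))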

Nonconjugate Models

„

Nonlinear Time series Models

„

Discrete Choice Models

„

Deep Latent Gaussian Models

„

Bayesian Neural Networks

„ Models with Attention (such as DRAW)
„ Generalized Linear Models (Poisson Regression)
„ Stochastic Volatility Models
„ Deep Exponential Families (e.g. Sparse Gamma or Poisson)
„ Correlated Topic Model (including nonparametric variants)
„ Sigmoid Belief Network

We can design models based on data rather than inference.

More Assumptions?

The current black box criteria:
„ Sampling from q(z)
„ Evaluating ∇_ν log q(z; ν)
„ Evaluating log p(x, z) and log q(z)

Can we make additional assumptions that are not too restrictive?

Pathwise Gradients of the ELBO

Pathwise Estimator

Assume:
1. z = t(ε, ν) with ε ∼ s(ε) implies z ∼ q(z; ν).
   Example: ε ∼ Normal(0, 1), z = εσ + µ  ⇒  z ∼ Normal(µ, σ²).
2. log p(x, z) and log q(z) are differentiable with respect to z.

Pathwise Estimator. Recall

∇_ν L = E_{q(z;ν)}[∇_ν log q(z; ν) g(z, ν) + ∇_ν g(z, ν)]

Rewrite using z = t(ε, ν):

∇_ν L = E_{s(ε)}[∇_ν log s(ε) g(t(ε, ν), ν) + ∇_ν g(t(ε, ν), ν)]

Differentiating:

∇L(ν) = E_{s(ε)}[∇_ν g(t(ε, ν), ν)]
      = E_{s(ε)}[ ∇_z[log p(x, z) − log q(z; ν)] ∇_ν t(ε, ν) − ∇_ν log q(z; ν) ]
      = E_{s(ε)}[ ∇_z[log p(x, z) − log q(z; ν)] ∇_ν t(ε, ν) ]

This is also known as the reparameterization gradient.

[Glasserman 1991; Fu 2006; Kingma+ 2014; Rezende+ 2014; Titsias+ 2014]
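Here is the same one-data-point logistic regression example, now fit with the pathwise estimator just derived (our sketch, not tutorial code): q(z; ν) = N(µ, σ²) with z = t(ε, ν) = µ + σε and ε ∼ N(0, 1). The function names and step-size schedule are assumptions for the illustration.

# Pathwise / reparameterization gradient for the assumed model:
# p(z) = N(0, 1), p(y | x, z) = Bernoulli(sigmoid(z * x)), q(z; nu) = N(mu, sigma^2).
import numpy as np

def sigmoid(u):
    return 1.0 / (1.0 + np.exp(-u))

def grad_z_log_joint(z, x, y):
    # d/dz [ log p(z) + log p(y | x, z) ]
    return -z + x * (y - sigmoid(x * z))

def pathwise_vi(x, y, S=50, iters=2000, seed=0):
    rng = np.random.default_rng(seed)
    mu, log_sigma = 0.0, 0.0
    for t in range(1, iters + 1):
        sigma = np.exp(log_sigma)
        eps = rng.standard_normal(S)
        z = mu + sigma * eps                        # z = t(eps, nu)
        # grad_z [ log p(x, z) - log q(z; nu) ]
        gz = grad_z_log_joint(z, x, y) + (z - mu) / sigma**2
        grad_mu = gz.mean()                         # d t / d mu = 1
        grad_log_sigma = (gz * sigma * eps).mean()  # d t / d log_sigma = sigma * eps
        rho = 0.1 * (t + 10.0) ** -0.7              # Robbins-Monro step size
        mu += rho * grad_mu
        log_sigma += rho * grad_log_sigma
    return mu, np.exp(log_sigma)

print(pathwise_vi(x=2.0, y=1))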

Variance Comparison

[Figure: empirical gradient variance versus the number of Monte Carlo samples for a multivariate nonlinear regression model. The pathwise estimator exhibits lower variance than the score-function estimator, with or without a control variate, and does not require control-variate variance reduction. Kucukelbir+ 2016]

Score Function Estimator vs. Pathwise Estimator

Score Function „ Differentiates the density ∇ν q(z; ν) „

„

„

Works for discrete and continuous models

Pathwise „ Differentiates the function ∇z [log p(x, z) − log q(z; ν)] „

„

Works for large class of variational approximations Variance can be a big problem

„

Requires differentiable models Requires variational approximation to have form z = t(ε, ν) Generally better behaved variance

Amortized Inference

Hierarchical Models

A generic class of models Global variables

Local variables

ˇ xi

zi

n p( , z, x) = p( )

n Y i=1

p(zi , xi | )

Ñ

Bayesian mixture models

Ñ

Dirichlet process mix

Ñ

Time series models

Ñ

Multilevel regression

Mean Field Variational Approximation

ˇ

ˇ ELBO

zi

xi

i

n

zi n

SVI: Revisited Input: data x, model p(β, z, x). Initialize λ randomly.

Set ρt appropriately.

repeat Sample j ∼ Unif(1, . . . , n).

  Set local parameter φ ← Eλ η` (β, xj ) .

Set intermediate global parameter ˆ = α + nEφ [t(Zj , xj )]. λ

Set global parameter until forever

ˆ λ = (1 − ρt )λ + ρt λ.

SVI: The problem Input: data x, model p(β, z, x). Initialize λ randomly.

Set ρt appropriately.

repeat Sample j ∼ Unif(1, . . . , n).

  Set local parameter φ ← Eλ η` (β, xj ) .

Set intermediate global parameter

ˆ = α + nEφ [t(Zj , xj )]. λ

Set global parameter

ˆ λ = (1 − ρt )λ + ρt λ.

until forever „ „

These expectations are no longer tractable Inner stochastic optimization needed for each data point.

SVI: The problem Input: data x, model p(β, z, x). Initialize λ randomly.

Set ρt appropriately.

repeat Sample j ∼ Unif(1, . . . , n).

  Set local parameter φ ← Eλ η` (β, xj ) .

Set intermediate global parameter

ˆ = α + nEφ [t(Zj , xj )]. λ

Set global parameter

ˆ λ = (1 − ρt )λ + ρt λ.

until forever Idea: Learn a mapping f from xi to φi

Amortizing Inference

ELBO: – L (λ, φ1...n ) = Eq [log p(β, z, x)] − Eq log q(β; λ) + Amortizing the ELBO with inference network f : – L (λ, θ ) = Eq [log p(β, z, x)] − Eq log q(β; λ) +

n X i=1

n X

™ q(zi ; φi )

i=1

™ q(zi | xi ; φi = fθ (xi ))

[Dayan+ 1995; Heess+ 2013; Gershman+ 2014, many others]

Amortized SVI Input: data x, model p(β, z, x). Initialize λ randomly.

Set ρt appropriately.

repeat Sample β ∼ q(β; λ).

Sample j ∼ Unif(1, . . . , n).

Sample zj ∼ q(zj | xj ; φθ (xj ).

Compute stochastic gradients ∇ˆλ L = ∇λ log q(β; λ)(log p(β) + n log p(xj , zj | β) − log q(β))

∇ˆθ L = n∇θ log q(zj | xj ; θ )(log p(xj , zj | β) − log q(zj | xk ; θ )) Update λ = λ + ρt ∇ˆλ

θ = θ + ρt ∇ˆθ . until forever

A computational-statistical tradeoff „

Amortized inference is faster, but admits a smaller class of approximations

„

The size of the smaller class depends on the flexibility of f

n Y

i=1

q(zi ;

i)

n Y

i=1

q(zi |xi ; f✓ (xi ))

Example: Variational Autoencoder (VAE)

z

p(z) = Normal(0, 1)

x

p(x|z) = Normal(µ (z),

2

(z))

µ and σ2 are deep networks with parameters β. [Kingma+ 2014; Rezende+ 2014]

Example: Variational Autoencoder (VAE)

z

z ~ q(z | x)

Model p(x |z)

Inference Network q(z |x)

2

q(z|x) = Normal(f✓µ (x), f✓ (x))

x ~ p(x | z) Data x

All functions are deep networks

Example: Variational Autoencoder (VAE)

Analogies Analogy-making

Rules of Thumb for a New Model

If log p(x, z) is differentiable with respect to z „

Try out an approximation q that is reparameterizable

If log p(x, z) is not differentiable with respect to z „

Use score function estimator with control variates

„

Add further variance reductions based on experimental evidence


General Advice: „

Use coordinate specific learning rates (e.g. RMSProp, AdaGrad)

„

Annealing + Tempering

„

Consider parallelizing across samples from q

Software

Systems with Variational Inference: „

Venture, WebPPL, Edward, Stan, PyMC3, Infer.net, Anglican

Good for trying out lots of models

Differentiation Tools: „

Theano, Torch, Tensorflow, Stan Math, Caffe

Can lead to more scalable implementations of individual models

PART IV Beyond the Mean Field

Review: Variational Bound and Optimisation

[The black box VI diagram: any model + massive data + reusable variational families → black box variational inference.]

„ Probabilistic modelling and variational inference.
„ Scalable inference through stochastic optimisation.
„ Black-box variational inference: non-conjugate models, Monte Carlo gradient estimators, and amortised inference.

These advances empower us with new ways to design more flexible approximate posterior distributions q(z).

Mean-field Approximations

[Figure: the fully factorised family treats z1, z2, z3 independently; optimization moves from ν_init toward the member of the family closest in KL to p(z | x).]

q_MF(z | x) = ∏_k q(z_k)

A key part of the algorithm is the choice of approximate posterior q(z):

log p(x) ≥ L = E_{q(z|x)}[log p(x, z)] − E_{q(z|x)}[log q(z | x)]
                (expected likelihood)    (entropy)

Mean-Field Posterior Approximations

[The deep latent Gaussian model: latent variable model p(x, z) = p(z) p(x | z).]

Mean-field or fully-factorised posterior is usually not sufficient

Real-world Posterior Distributions

[The deep latent Gaussian model p(x, z) = p(z) p(x | z) again.]

Complex dependencies · Non-Gaussian distributions · Multiple modes

Families of Approximate Posteriors Two high-level goals: „

Build richer approximate posterior distributions.

„

Maintain computational efficiency and scalability.

Families of Approximate Posteriors

Two high-level goals:
„ Build richer approximate posterior distributions.
„ Maintain computational efficiency and scalability.

[Spectrum: from the fully factorised family q_MF(z | x) = ∏_k q(z_k) (least expressive) to the true posterior q*(z | x) ∝ p(x | z) p(z) (most expressive).]

Families of Approximate Posteriors

[The same spectrum, from fully factorised to the true posterior.]

Same as the problem of specifying a model of the data itself.

Structured Posterior Approximations

[Spectrum: the true posterior q*(z | x) ∝ p(x | z) p(z) (most expressive), a structured approximation q(z) = ∏_k q_k(z_k | {z_j}_{j≠k}), and the fully factorised q_MF(z | x) = ∏_k q(z_k) (least expressive).]

Structured mean field: Introduce any form of dependency to provide a richer approximating class of distributions. [Saul and Jordan, 1996.]
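As a concrete sketch of a richer-than-mean-field family (our illustration, anticipating the Gaussian approximations below): a Gaussian with diagonal-plus-rank-one covariance Σ = diag(α) + uuᵀ can be sampled by reparameterization in O(K) per sample, so it stays compatible with the stochastic-gradient machinery of Part III. The function name and test values are ours.

# Reparameterized sampling from q(z; nu) = N(mu, Sigma) with Sigma = diag(alpha) + u u^T.
import numpy as np

def sample_rank1_gaussian(mu, alpha, u, S, rng):
    # z = mu + sqrt(alpha) * eps1 + u * eps2, with eps1 ~ N(0, I_K) and eps2 ~ N(0, 1),
    # has covariance diag(alpha) + u u^T.
    K = mu.shape[0]
    eps1 = rng.standard_normal((S, K))
    eps2 = rng.standard_normal((S, 1))
    return mu + np.sqrt(alpha) * eps1 + eps2 * u

rng = np.random.default_rng(0)
K = 5
mu = np.zeros(K)
alpha = np.full(K, 0.1)            # diagonal part
u = np.linspace(-1.0, 1.0, K)      # rank-one direction introduces correlation
z = sample_rank1_gaussian(mu, alpha, u, S=100000, rng=rng)
print(np.round(np.cov(z.T), 2))    # empirical covariance is close to diag(alpha) + u u^T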

Gaussian Approximate Posteriors

Use a correlated Gaussian, q_G(z; ν) = N(z | µ, Σ), with variational parameters ν = {µ, Σ}.

Covariance models: the structure of the covariance Σ describes the dependency. Full covariance is richest, but computationally expensive.

Mean-field:  Σ = diag(α_1, . . . , α_K)
Rank-1:      Σ = diag(α_1, . . . , α_K) + uuᵀ
Rank-J:      Σ = diag(α_1, . . . , α_K) + Σ_j u_j u_jᵀ
Full:        Σ = UUᵀ

[Figure: test negative marginal likelihood for full-covariance, rank-1, diagonal, wake-sleep, and factor-analysis approximations; richer covariance structures fit better.]

The approximate posterior is always Gaussian.

Wake−Sleep

FA

Beyond Gaussian Approximations Autoregressive distributions: Impose an ordering and non-linear dependency on all preceding variables. Y qAR (z; ν) = qk (zk |z