I am using JAGS to estimate a Dirichlet Process Mixture of Normals. The code works well and the estimated density is accurate. However, I would like to know which component each observation is assigned to and the corresponding parameters for that component. This is hard due to the label switching problem in mixture models.
Any suggestions? Putting an order restriction on the means is one recommended way to avoid this, but I have not been able to implement it well in JAGS.
Below is the R code followed by the JAGS model.
library(R2jags)
library(mixtools)
# Simulate `nobs` draws from a three-component normal mixture.
# popmn / popstd / popprob hold the component means, standard deviations
# and mixing weights of the data-generating mixture.
nobs <- 200
popmn <- c(-6, -3, 3)
popstd <- c(1.5, 1.5, .7)
popprob <- c(.6, .3, .1)
x <- rnormmix(n = nobs, lambda = popprob, mu = popmn, sigma = popstd)
# Evaluation grid for the density (500 equally spaced points on [-8, 8]).
# `length.out` is spelled out instead of relying on partial matching of
# `length`.
xgrid <- seq(-8, 8, length.out = 500)
### true density
# Mixture density at each grid point: weighted sum of the component normal
# densities. vapply() replaces the index loop and guarantees one numeric
# value per grid point.
densTrue <- vapply(
  xgrid,
  function(g) sum(popprob * dnorm(g, popmn, popstd)),
  numeric(1)
)
# Rescale so the values sum to 1 over the grid.
densTrue <- densTrue / sum(densTrue)
### Call JAGS
# K is the maximum number of mixture components passed to the model.
K <- 10
# Data passed to the model. `n` is derived from the data vector rather than
# hard-coded, so it stays correct if the sample size changes.
dataList <- list(
  y = x,
  n = length(x),
  K = K,
  grid = xgrid,
  n2 = length(xgrid)
)
# Initial values: all component means at 0, all precisions at 1, equal
# mixing weights.
jags.inits <- function() {
  list("mu" = rep(0, K), "tau" = rep(1, K), "prob" = rep(1 / K, K))
}
# Only the predictive density over the grid is monitored.
parameters <- c("f")
# The model is read from the file "jagsModel".
jags.fit <- jags(
  data = dataList,
  inits = jags.inits,
  parameters.to.save = parameters,
  model.file = "jagsModel",
  n.chains = 1,
  n.iter = 800,
  n.burnin = 300
)
# Explicit print so the summary is also shown when the script is source()d.
print(jags.fit)
# Construct the estimated density: posterior mean of f at each grid point,
# rescaled to sum to 1 over the grid.
DPDens2 <- jags.fit$BUGSoutput$mean$f
DPDens2 <- DPDens2 / sum(DPDens2)
### Plot
# The true density vector is named `densTrue`; the original call referenced
# the undefined name `trueDens` and failed with "object not found".
# ylim spans both curves so the estimate is not clipped by the plot region.
plot(xgrid, densTrue, type = "l", ylim = range(densTrue, DPDens2))
lines(xgrid, DPDens2, col = 2)
JAGS Model:
# Finite mixture of normals with a symmetric Dirichlet(alpha/K) prior on the
# weights. Data supplied by the caller: y[1:n] observations, K the maximum
# number of components, grid[1:n2] points at which the density is evaluated.
model{
  for(i in 1:n){
    # z[i] is the latent component label of observation i.
    z[i] ~ dcat(prob[])
    # dnorm takes a precision (tau), not a standard deviation.
    y[i] ~ dnorm(mu[z[i]],tau[z[i]])
    for(j in 1:K){
      # Weighted density of y[i] under component j. This must use mu[j] and
      # tau[j]; indexing by z[i] made every column identical up to prob[j].
      # Normalising f.d[i,] over j gives the conditional membership
      # probabilities of observation i.
      f.d[i,j] <- prob[j]*coi*exp(-0.5*pow(y[i]-mu[j],2)*tau[j])*sqrt(tau[j])
      # ind[i,j] is 1 when observation i is currently assigned to component j.
      ind[i,j] <- equals(z[i],j)
    }
  }
  prob[1:K] ~ ddirch(prob.par[])
  for(j in 1:K){
    prob.par[j] <- alpha/K
    mu[j] ~ dnorm(0,.001)
    tau[j] ~ dgamma(0.001,0.001)
    # n.in[j] is the number of observations currently in component j;
    # used[j] is 1 when that count is at least one.
    n.in[j] <- sum(ind[1:n,j])
    used[j] <- step(n.in[j]-1)
  }
  # total is number of active "used" components, counted over all n
  # observations. The previous pairwise comparison only looked at
  # z[1],...,z[K], i.e. the labels of the first K observations.
  total <- sum(used[])
  alpha <- 1
  ###Density
  for(i in 1:n2){
    for(j in 1:K){
      f.p[i,j] <- prob[j]*coi*exp(-0.5*pow(grid[i]-mu[j],2)*tau[j])*sqrt(tau[j])
    }
    # f[i] is predictive density over grid[i]
    f[i] <- sum(f.p[i,1:K])
  }
  # coi is 1/sqrt(2*pi), the normalising constant of the normal density.
  coi <- 0.3989422804
}