[Stable]

The Dirichlet distribution is a multivariate generalisation of the Beta distribution. It is the conjugate prior of the Categorical and Multinomial distributions, and describes a probability distribution over the \((k-1)\)-simplex — the set of \(k\)-dimensional vectors whose components are non-negative and sum to one.
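Both relationships in the paragraph above can be checked with a short base R sketch. It does not call distributional; the counts are made-up illustrative values. Conjugacy means that observing category counts \(\mathbf{n} = (n_1, \ldots, n_k)\) under a Dirichlet(\(\boldsymbol{\alpha}\)) prior gives a Dirichlet(\(\boldsymbol{\alpha} + \mathbf{n}\)) posterior. With \(k = 2\), the Dirichlet density at \((x, 1 - x)\) coincides with the Beta density returned by dbeta().

```r
# Prior concentration parameters (same values as in the Examples)
prior_alpha <- c(2, 5, 3)

# Observed category counts from Categorical/Multinomial trials (hypothetical data)
counts <- c(4, 0, 6)

# Conjugate update: the posterior is Dirichlet(alpha + counts)
posterior_alpha <- prior_alpha + counts
posterior_alpha
#> [1] 6 5 9

# With k = 2, the Dirichlet density at (x, 1 - x) equals the Beta(a, b) density at x
a <- 2; b <- 5; x <- 0.3
dir_k2 <- gamma(a + b) / (gamma(a) * gamma(b)) * x^(a - 1) * (1 - x)^(b - 1)
all.equal(dir_k2, dbeta(x, a, b))
#> [1] TRUE
```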

dist_dirichlet(alpha)

Arguments

alpha

A list of numeric vectors of positive concentration parameters, one vector per distribution.

Details

We recommend reading this documentation on pkgdown, which renders math nicely. https://pkg.mitchelloharawild.com/distributional/reference/dist_dirichlet.html

In the following, let \(\mathbf{X} = (X_1, \ldots, X_k)\) be a Dirichlet random variable with concentration parameter alpha = \(\boldsymbol{\alpha} = (\alpha_1, \ldots, \alpha_k)\), where each \(\alpha_i > 0\).

Support: \(\mathbf{x}\) on the \((k-1)\)-simplex, i.e. \(x_i \geq 0\) and \(\sum_{i=1}^k x_i = 1\).
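The support condition can be expressed as a small base R check. The helper name on_simplex() is hypothetical (it is not part of distributional), and the tolerance on the sum is an assumption to absorb floating-point rounding.

```r
# Hypothetical helper: is x a point on the (k-1)-simplex, up to a numeric tolerance?
on_simplex <- function(x, tol = sqrt(.Machine$double.eps)) {
  all(x >= 0) && abs(sum(x) - 1) <= tol
}

on_simplex(c(0.2, 0.5, 0.3))   # non-negative and sums to one
#> [1] TRUE
on_simplex(c(0.5, 0.6, -0.1))  # sums to one but has a negative component
#> [1] FALSE
on_simplex(c(0.2, 0.2, 0.2))   # non-negative but sums to 0.6
#> [1] FALSE
```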

Mean: \(E(X_i) = \frac{\alpha_i}{\alpha_0}\) where \(\alpha_0 = \sum_{i=1}^k \alpha_i\).
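As a quick base R check of the mean formula, using the concentration vector from the Examples, the result agrees with the mean(dist) output shown there:

```r
alpha <- c(2, 5, 3)   # concentration vector used in the Examples
alpha0 <- sum(alpha)  # alpha_0 = 10
alpha / alpha0        # E(X_i) = alpha_i / alpha_0
#> [1] 0.2 0.5 0.3
```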

Variance:

$$ \mathrm{Var}(X_i) = \frac{\alpha_i(\alpha_0 - \alpha_i)}{\alpha_0^2(\alpha_0 + 1)} $$
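Evaluating the variance formula directly in base R reproduces the variance(dist) output in the Examples:

```r
alpha <- c(2, 5, 3)
alpha0 <- sum(alpha)
# Var(X_i) = alpha_i (alpha_0 - alpha_i) / (alpha_0^2 (alpha_0 + 1))
var_x <- alpha * (alpha0 - alpha) / (alpha0^2 * (alpha0 + 1))
var_x
#> [1] 0.01454545 0.02272727 0.01909091
```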

Covariance:

$$ \mathrm{Cov}(X_i, X_j) = \frac{-\alpha_i \alpha_j}{\alpha_0^2(\alpha_0 + 1)}, \quad i \neq j $$
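The variance and covariance formulas combine into one matrix, \(\Sigma = (\mathrm{diag}(\mathbf{m}) - \mathbf{m}\mathbf{m}^\top)/(\alpha_0 + 1)\) with \(\mathbf{m} = \boldsymbol{\alpha}/\alpha_0\). The base R sketch below builds it for the Examples' parameters. Each row of \(\Sigma\) sums to zero (up to rounding), reflecting the constraint \(\sum_i X_i = 1\).

```r
alpha <- c(2, 5, 3)
alpha0 <- sum(alpha)
m <- alpha / alpha0  # mean vector

# Diagonal entries reproduce Var(X_i); off-diagonal entries reproduce Cov(X_i, X_j)
sigma <- (diag(m) - outer(m, m)) / (alpha0 + 1)
sigma

# Rows sum to (numerically) zero because the components sum to one
rowSums(sigma)
```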

Probability density function (p.d.f.):

$$ f(\mathbf{x}) = \frac{1}{B(\boldsymbol{\alpha})} \prod_{i=1}^k x_i^{\alpha_i - 1} $$

where \(B(\boldsymbol{\alpha}) = \frac{\prod_{i=1}^k \Gamma(\alpha_i)}{\Gamma(\alpha_0)}\) is the multivariate Beta function.
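The density formula translates directly into base R. The sketch below works on the log scale with lgamma() to avoid overflow in \(B(\boldsymbol{\alpha})\) for large concentrations. The name ddirichlet_sketch() is hypothetical; this is not the package's implementation. It assumes x lies strictly inside the simplex and does not check the support. It reproduces the density(dist, ...) values in the Examples.

```r
# Hypothetical helper: Dirichlet density evaluated on the log scale.
# Assumes every x_i > 0 and sum(x) == 1 (support is not checked).
ddirichlet_sketch <- function(x, alpha, log = FALSE) {
  stopifnot(length(x) == length(alpha), all(alpha > 0))
  # log B(alpha) = sum(log Gamma(alpha_i)) - log Gamma(alpha_0)
  log_B <- sum(lgamma(alpha)) - lgamma(sum(alpha))
  log_f <- sum((alpha - 1) * log(x)) - log_B
  if (log) log_f else exp(log_f)
}

ddirichlet_sketch(c(0.2, 0.5, 0.3), c(2, 5, 3))
#> [1] 8.505
ddirichlet_sketch(c(0.2, 0.5, 0.3), c(2, 5, 3), log = TRUE)
#> [1] 2.140654
```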

Examples

dist <- dist_dirichlet(alpha = list(c(2, 5, 3)))
dist
#> <distribution[1]>
#> [1] Dirichlet[3]

mean(dist)
#>      [,1] [,2] [,3]
#> [1,]  0.2  0.5  0.3
variance(dist)
#>            [,1]       [,2]       [,3]
#> [1,] 0.01454545 0.02272727 0.01909091
support(dist)
#> <support_region[1]>
#> [1] [0,1]^3
generate(dist, 10)
#> [[1]]
#>             [,1]      [,2]       [,3]
#>  [1,] 0.10198579 0.4725398 0.42547441
#>  [2,] 0.14742814 0.8055114 0.04706045
#>  [3,] 0.20673588 0.5102639 0.28300019
#>  [4,] 0.18071082 0.5001574 0.31913176
#>  [5,] 0.09835732 0.6077442 0.29389843
#>  [6,] 0.08591234 0.8300227 0.08406500
#>  [7,] 0.40678183 0.2546595 0.33855863
#>  [8,] 0.07476767 0.6300937 0.29513864
#>  [9,] 0.33052580 0.4766198 0.19285445
#> [10,] 0.14007871 0.6494068 0.21051445
#> 

density(dist, cbind(0.2, 0.5, 0.3))
#> [1] 8.505
density(dist, cbind(0.2, 0.5, 0.3), log = TRUE)
#> [1] 2.140654
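Draws like those from generate() can be produced with a standard construction: sample independent \(Y_i \sim \mathrm{Gamma}(\alpha_i, 1)\) and normalise, \(X_i = Y_i / \sum_j Y_j\). The base R sketch below uses this approach. The helper rdirichlet_sketch() is hypothetical, and the package's own sampler may work differently. With many draws, the column means should be close to \(\boldsymbol{\alpha}/\alpha_0 = (0.2, 0.5, 0.3)\).

```r
# Hypothetical helper: n draws from Dirichlet(alpha) via normalised Gamma variates
rdirichlet_sketch <- function(n, alpha) {
  k <- length(alpha)
  # Filled by row, so column i holds Gamma(shape = alpha[i], rate = 1) draws
  g <- matrix(rgamma(n * k, shape = alpha), nrow = n, ncol = k, byrow = TRUE)
  g / rowSums(g)  # divide each row by its total, placing it on the simplex
}

set.seed(1)
x <- rdirichlet_sketch(10000, c(2, 5, 3))
colMeans(x)  # approximately 0.2, 0.5, 0.3
```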