Manipulating Discrete Joint Distributions

Robin Evans

2022-06-08

Marginal and Conditional Distributions

library(rje)

First let’s generate a joint probability distribution for a \(2 \times 2 \times 2 \times 2\)-table.

set.seed(123)
p = rprobdist(2, 4)

We can easily calculate the marginal distribution for the first two variables:

marginTable(p, 1:2)
##           [,1]      [,2]
## [1,] 0.1095329 0.4286592
## [2,] 0.2429444 0.2188636

The base function base::margin.table() computes the same result as marginTable(), but is slower. The dimensions of the output follow the order in which the variables are given to the function:

marginTable(p, 2:1)
##           [,1]      [,2]
## [1,] 0.1095329 0.2429444
## [2,] 0.4286592 0.2188636

but this can be overridden by setting the argument order=FALSE.
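As a rough illustration of the ordering rule using only base R (the array q and the seed are made up for this sketch, and no rje functions are called), margin.table() likewise returns its dimensions in the order the margins are listed, so swapping the margins transposes the result:

```r
# Hypothetical 2 x 2 x 2 x 2 joint distribution built with base R only
set.seed(1)
q <- array(runif(16), dim = c(2, 2, 2, 2))
q <- q / sum(q)                  # normalise so the entries sum to 1

m12 <- margin.table(q, 1:2)      # rows index variable 1, columns variable 2
m21 <- margin.table(q, 2:1)      # rows index variable 2, columns variable 1

# the two margins hold the same numbers, transposed
same_up_to_transpose <- isTRUE(all.equal(c(m21), c(t(m12))))
```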

We can also obtain conditional distributions:

conditionTable(p, 3, 1)
##           [,1]      [,2]
## [1,] 0.3399897 0.5162332
## [2,] 0.6600103 0.4837668

conditionTable() orders output with the ‘free’ variables first (as ordered in the argument variables) followed by the conditioning variables. Sometimes it’s useful to keep a conditional (or marginal) distribution in the same form as the original table, even for variables which are removed. conditionTable2() does this, repeating each value across the dimensions of the variables that do not appear:

conditionTable2(p, 3, 1)
## , , 1, 1
## 
##           [,1]      [,2]
## [1,] 0.3399897 0.3399897
## [2,] 0.5162332 0.5162332
## 
## , , 2, 1
## 
##           [,1]      [,2]
## [1,] 0.6600103 0.6600103
## [2,] 0.4837668 0.4837668
## 
## , , 1, 2
## 
##           [,1]      [,2]
## [1,] 0.3399897 0.3399897
## [2,] 0.5162332 0.5162332
## 
## , , 2, 2
## 
##           [,1]      [,2]
## [1,] 0.6600103 0.6600103
## [2,] 0.4837668 0.4837668
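The two layouts can be sketched in base R, which makes the indexing explicit. This is only an illustration under assumptions: q is a made-up distribution, c3.1 and full are hypothetical names, and nothing here calls rje.

```r
set.seed(1)
q <- array(runif(16), dim = c(2, 2, 2, 2))
q <- q / sum(q)

# P(X3 | X1) as a 2 x 2 table: rows index x3 (free), columns index x1 (conditioning)
m31  <- margin.table(q, c(3, 1))
c3.1 <- prop.table(m31, margin = 2)   # divide each column by P(X1 = x1)

# Same conditional, repeated so that it has the shape of the full table:
# cell (x1, x2, x3, x4) holds P(X3 = x3 | X1 = x1)
idx  <- as.matrix(expand.grid(x1 = 1:2, x2 = 1:2, x3 = 1:2, x4 = 1:2))
full <- array(c3.1[idx[, c("x3", "x1")]], dim = c(2, 2, 2, 2))

full[, , 1, 1]   # rows are x1, columns are x2; each row is constant
```

The compact table is smaller, but the expanded form can be multiplied elementwise with any other table of the same shape.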

Interventions

In causal inference we often want to know what happens if we intervene on a variable under a given causal ordering. This amounts to dividing the joint distribution by a particular conditional distribution. Here we intervene on variable 3, which means dividing by the conditional distribution of variable 3 given variables 1 and 2:

p_int = interventionTable(p, 3, 1:2)
## check this is p(1,2) * p(4 | 1, 2, 3)
p_int2 = conditionTable2(p, 1:2, c())*conditionTable2(p, 4, 1:3)
all.equal(p_int, p_int2)
## [1] TRUE
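The identity being checked can also be written out with base R's sweep(). The following is a sketch under assumptions (q is a made-up distribution and all object names are hypothetical; rje is not used):

```r
set.seed(1)
q <- array(runif(16), dim = c(2, 2, 2, 2))
q <- q / sum(q)

m12  <- apply(q, 1:2, sum)                # P(X1, X2)
m123 <- apply(q, 1:3, sum)                # P(X1, X2, X3)

# P(X3 | X1, X2), stored as a 2 x 2 x 2 array indexed by (x1, x2, x3)
c3.12 <- sweep(m123, 1:2, m12, "/")

# divide the joint by P(X3 | X1, X2); the divisor is reused for every x4
q_int <- sweep(q, 1:3, c3.12, "/")

# compare with P(X1, X2) * P(X4 | X1, X2, X3)
c4.123 <- sweep(q, 1:3, m123, "/")
q_int2 <- sweep(c4.123, 1:2, m12, "*")
all.equal(q_int, q_int2)
```

Note that q_int sums to 2 rather than 1: for each of the two values of the intervened variable it contains a full distribution over the remaining variables.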

Multiple distributions

#rprobMat(100, dim=2, d=4)

Reconstructing Joint Distributions

When dealing with margins of multivariate distributions, it can be useful to be able to repeat probabilities to match the pattern of a joint distribution. In particular if we are given various conditional distributions (say from a Bayesian network model), we may wish to multiply them together to obtain the joint distribution.

For example, the model in which \(X_2\) is independent of \(X_3\) given \(X_1\) might be stored as the conditional probability tables \(P(X_1)\), \(P(X_2 | X_1)\) and \(P(X_3 | X_1)\). In order to reconstruct the joint distribution \(P(X_1,X_2,X_3)\), one needs to multiply \[ P(X_1=x_1,X_2=x_2,X_3=x_3) = P(X_1=x_1) \cdot P(X_2 = x_2 | X_1=x_1) \cdot P(X_3 = x_3 | X_1=x_1), \] so that the values of \(x_1,x_2,x_3\) match.

To use R’s vectorization for this we must turn the probability tables into vectors indexed by \((x_1,x_2,x_3)\), regardless of which variables are actually represented in the table; if a variable is not represented then values will be repeated. The indexing should be in reverse lexicographical order (i.e. first index changes fastest: 000, 100, 010, 110, …, 111), which is the way arrays are stored in R.
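To see this ordering concretely, base R's expand.grid() enumerates index combinations with the first variable changing fastest, which is the order in which array entries are stored. The names idx and labels below are just for this sketch:

```r
# all (x1, x2, x3) combinations for three binary variables
idx <- expand.grid(x1 = 0:1, x2 = 0:1, x3 = 0:1)
labels <- paste0(idx$x1, idx$x2, idx$x3)
labels
# "000" "100" "010" "110" "001" "101" "011" "111"

# a 2 x 2 x 2 array filled from a vector uses the same order
a <- array(labels, dim = c(2, 2, 2))
a[2, 1, 2]   # "101"
```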

For example, if \(X_1,X_2,X_3\) are all binary (i.e. take values in \(\{0,1\}\)) then we’d transform the table of \(X_3 | X_1\) into \[ P(X_3 = 0 | X_1 = 0), \, P(X_3 = 0 | X_1 = 1), \, P(X_3 = 0 | X_1 = 0), \, P(X_3 = 0 | X_1 = 1), \\ P(X_3 = 1 | X_1 = 0), \, P(X_3 = 1 | X_1 = 1), \, P(X_3 = 1 | X_1 = 0), \, P(X_3 = 1 | X_1 = 1). \] Now suppose we already have a vector for \(P(X_3 = x_3 | X_1 = x_1)\) indexed by \((x_1, x_3)\) in reverse lexicographical order: \[ P(X_3 = 0 | X_1 = 0), \, P(X_3 = 0 | X_1 = 1), \, P(X_3 = 1 | X_1 = 0), \, P(X_3 = 1 | X_1 = 1). \] To expand it we need the first and second entries as a pair, taken twice, followed by the third and fourth entries, also taken twice:

patternRepeat0(c(1,3), c(2,2,2))
## [1] 1 2 1 2 3 4 3 4

patternRepeat0() only requires us to specify which variables are present and the dimensions of the full distribution. The existing table is assumed to be stored in reverse lexicographic order, regardless of the order in which the variables are given in the first argument, but this can be overridden with keep.order=TRUE.

patternRepeat0(c(3,1), c(2,2,2))
## [1] 1 2 1 2 3 4 3 4
patternRepeat0(c(3,1), c(2,2,2), keep.order=TRUE)
## [1] 1 3 1 3 2 4 2 4

Another way to think about this is that if we take the possible indices for a 3 dimensional array and match them to the indices of just the first and third dimensions, patternRepeat0() tells us which entry of the smaller table each one is matched to.
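This matching can be reproduced by hand in base R. The helper pattern_sketch() below is hypothetical and is not how rje implements patternRepeat0(); it simply computes, for every cell of the full array, its position in a sub-table whose first listed variable changes fastest:

```r
# Hypothetical helper (not part of rje): position of each cell of the full
# array within a sub-table storing only the variables in `vars`,
# with the first variable in `vars` changing fastest.
pattern_sketch <- function(vars, dims) {
  full <- arrayInd(seq_len(prod(dims)), dims)       # one row per cell, 1-based
  sub_dims <- dims[vars]
  stride <- cumprod(c(1, sub_dims[-length(sub_dims)]))
  as.vector(1 + (full[, vars, drop = FALSE] - 1) %*% stride)
}

pattern_sketch(c(1, 3), c(2, 2, 2))   # 1 2 1 2 3 4 3 4
pattern_sketch(c(3, 1), c(2, 2, 2))   # 1 3 1 3 2 4 2 4
```

The first call reproduces the default output shown above, and the second, where the sub-table is taken to be indexed by \((x_3, x_1)\), reproduces the keep.order=TRUE output.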

Example

Let’s generate some conditional probability tables.

set.seed(134)

p1 = c(rdirichlet(1,c(1,1)))
p2.1 = c(rdirichlet(2,c(1,1)))
p3.1 = c(rdirichlet(2,c(1,1)))

p12 = p1*p2.1
## get joint distribution
p123 = p12*p3.1[patternRepeat0(c(1,3), c(2,2,2))]

## put into an array to verify that this has the correct
## conditional distribution
dim(p123) = c(2,2,2)
conditionTable(p123, 3, 1)
##           [,1]      [,2]
## [1,] 0.5956466 0.6709426
## [2,] 0.4043534 0.3290574
## can also get conditional distribution indexed by all variables
p3.1[patternRepeat0(c(1,3), c(2,2,2))]
## [1] 0.5956466 0.6709426 0.5956466 0.6709426 0.4043534 0.3290574 0.4043534
## [8] 0.3290574
c(conditionTable2(p123, 3, 1))
## [1] 0.5956466 0.6709426 0.5956466 0.6709426 0.4043534 0.3290574 0.4043534
## [8] 0.3290574
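As a further check, one can confirm that a joint distribution built this way satisfies the conditional independence of \(X_2\) and \(X_3\) given \(X_1\). The sketch below uses base R only, with made-up probabilities and hypothetical names (q1, q2.1, q3.1, q123) so that it does not overwrite the objects above; the index vector is the pattern for variables 1 and 3 in a \(2 \times 2 \times 2\) table shown earlier.

```r
q1   <- c(0.3, 0.7)               # P(X1)
q2.1 <- c(0.2, 0.6, 0.8, 0.4)     # P(X2 | X1), indexed by (x1, x2)
q3.1 <- c(0.5, 0.1, 0.5, 0.9)     # P(X3 | X1), indexed by (x1, x3)

# repeat P(X3 | X1) to match the (x1, x2, x3) indexing, then multiply;
# q1 and q1 * q2.1 are recycled by R's vectorization
pattern <- c(1, 2, 1, 2, 3, 4, 3, 4)
q123 <- array(q1 * q2.1 * q3.1[pattern], dim = c(2, 2, 2))

# for each value of x1, P(X2, X3 | X1 = x1) should factorise as
# P(X2 | X1 = x1) * P(X3 | X1 = x1)
ok <- sapply(1:2, function(x1) {
  lhs <- q123[x1, , ] / q1[x1]
  rhs <- outer(q2.1[c(x1, x1 + 2)], q3.1[c(x1, x1 + 2)])
  isTRUE(all.equal(lhs, rhs))
})
ok
```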