I have a Markov Chain in R:
set.seed(123)
n_states <- 5

# Random non-negative weights. Stored under a name that does not mask
# base::matrix().
raw_weights <- matrix(runif(n_states^2), nrow = n_states)

# set some transitions to 0
raw_weights[1, 4:5] <- 0
raw_weights[5, 1:3] <- 0
raw_weights[2, 5] <- 0

# Normalise every row to sum to 1. A length-nrow vector is recycled down the
# columns, so element [i, j] is divided by the sum of row i.
transition_matrix <- raw_weights / rowSums(raw_weights)

state_names <- paste0("S", seq_len(n_states))
dimnames(transition_matrix) <- list(state_names, state_names)
print(round(transition_matrix, 3))
It looks like this:
S1 S2 S3 S4 S5
S1 0.2229340 0.03531601 0.7417500 0.00000000 0.0000000
S2 0.3910569 0.26197885 0.2248868 0.12207747 0.0000000
S3 0.1536622 0.33530265 0.2545791 0.01580275 0.2406533
S4 0.2652280 0.16563210 0.1719994 0.09849610 0.2986444
S5 0.0000000 0.00000000 0.0000000 0.59278229 0.4072177
For a fixed number of turns, I want to find out all possible state sequences that can occur and their corresponding probabilities.
I tried to do this manually using loops to enumerate all such sequences:
# Enumerate every positive-probability state sequence of a Markov chain.
#
# Starting from `start_state`, all paths containing 2, 3, ..., `max_turns`
# states (the start state included) are generated, following only transitions
# whose probability is > 0.
#
# Arguments:
#   transition_matrix  square matrix; entry [i, j] is P(next = j | current = i).
#   start_state        integer index (not a name) of the initial state.
#   max_turns          maximum number of states in a sequence; a whole
#                      number >= 1.
#
# Returns a data.frame, sorted by `turn` and then by decreasing `probability`:
#   turn         number of states in the sequence (numeric)
#   sequence_no  position of the sequence in a depth-first traversal that
#                visits successor states in increasing index order (integer)
#   sequence     state indices pasted together without a separator
#                (ambiguous once there are 10 or more states)
#   probability  product of the transition probabilities along the path
# With `max_turns = 1`, or when nothing is reachable, the result has zero rows.
#
# Implementation: instead of recursing once per path, all paths of one length
# are extended together with vectorised indexing. The number of rows still
# grows with the number of feasible sequences, which is exponential in
# `max_turns` in general.
find_sequences_all_turns <- function(transition_matrix, start_state = 1, max_turns = 5) {
  transition_matrix <- as.matrix(transition_matrix)
  n_states <- nrow(transition_matrix)
  stopifnot(
    n_states == ncol(transition_matrix),
    is.numeric(start_state),
    length(start_state) == 1L,
    start_state %in% seq_len(n_states),
    is.numeric(max_turns),
    length(max_turns) == 1L,
    is.finite(max_turns),
    max_turns >= 1,
    max_turns == round(max_turns)
  )

  # One row per live path; unused trailing positions hold 0 so that a prefix
  # sorts before any of its extensions.
  paths <- matrix(0L, nrow = 1L, ncol = max_turns)
  paths[1L, 1L] <- as.integer(start_state)
  probs <- 1
  levels <- list()

  for (step in seq_len(max_turns - 1)) {
    last_state <- paths[, step]

    # (path row, next state) pairs with a positive transition probability.
    hits <- which(
      transition_matrix[last_state, , drop = FALSE] > 0,
      arr.ind = TRUE,
      useNames = FALSE
    )
    if (nrow(hits) == 0L) {
      break
    }
    # which() reports hits column by column; regroup them per parent path with
    # successors in increasing order, i.e. the order a depth-first walk uses.
    hits <- hits[order(hits[, 1L], hits[, 2L]), , drop = FALSE]
    parent <- hits[, 1L]
    next_state <- hits[, 2L]

    probs <- probs[parent] *
      transition_matrix[cbind(last_state[parent], next_state)]
    paths <- paths[parent, , drop = FALSE]
    paths[, step + 1L] <- next_state

    labels <- do.call(
      paste0,
      lapply(seq_len(step + 1L), function(j) paths[, j])
    )
    levels[[step]] <- list(paths = paths, probs = probs, labels = labels)
  }

  if (length(levels) == 0L) {
    return(data.frame(
      turn = numeric(0),
      sequence_no = integer(0),
      sequence = character(0),
      probability = numeric(0)
    ))
  }

  # Lexicographic order of the zero-padded paths equals depth-first pre-order.
  all_paths <- do.call(rbind, lapply(levels, `[[`, "paths"))
  dfs_order <- do.call(
    order,
    lapply(seq_len(ncol(all_paths)), function(j) all_paths[, j])
  )
  sequence_no <- integer(length(dfs_order))
  sequence_no[dfs_order] <- seq_along(dfs_order)

  level_sizes <- vapply(levels, function(lv) length(lv$probs), integer(1))
  result_df <- data.frame(
    turn = rep(seq_along(levels) + 1, times = level_sizes),
    sequence_no = sequence_no,
    sequence = unlist(lapply(levels, `[[`, "labels")),
    probability = unlist(lapply(levels, `[[`, "probs"))
  )

  # Ties in probability fall back to traversal order.
  row_order <- order(
    result_df$turn,
    -result_df$probability,
    result_df$sequence_no
  )
  result_df <- result_df[row_order, ]
  rownames(result_df) <- NULL
  result_df
}
I then tried to call the function:
# Enumerate the sequences using the function's default arguments, spelled out.
sequences_df <- find_sequences_all_turns(
  transition_matrix,
  start_state = 1,
  max_turns = 5
)
> sequences_df
turn sequence_no sequence probability
1 2 154 13 7.417500e-01
2 2 1 11 2.229340e-01
3 2 65 12 3.531601e-02
4 3 171 132 2.487108e-01
5 3 193 133 1.888341e-01
6 3 243 135 1.785046e-01
7 3 40 113 1.653613e-01
8 3 155 131 1.139789e-01
9 3 2 111 4.969955e-02
10 3 66 121 1.381057e-02
11 3 218 134 1.172169e-02
12 3 82 122 9.252048e-03
13 3 104 123 7.942105e-03
14 3 18 112 7.873137e-03
15 3 129 124 4.311289e-03
Is there something I can do to make this code run faster for a large number of turns?
PS: I used this code to verify all probabilities sum to 1 for each turn:
library(dplyr)

# Per turn: total probability, number of sequences, and whether the total
# equals 1 up to floating-point round-off.
by_turn <- group_by(sequences_df, turn)
probability_sums <- summarise(
  by_turn,
  total_probability = sum(probability),
  num_sequences = n(),
  check_sum_to_one = abs(total_probability - 1) < 1e-10
)
print(probability_sums)
I think this should work well for your purpose (given tm <- transition_matrix, a shorter variable name for the transition matrix):
# State labels, taken from the row names of the transition matrix.
nms <- row.names(tm)
N <- 5

# For every sequence length k = 2, ..., N: build all k-tuples of states,
# multiply the transition probabilities of consecutive pairs, and keep the
# tuples whose overall probability is positive.
res <- lapply(2:N, function(k) {
  grid <- expand.grid(rep(list(nms), k))
  from <- grid[-k]
  to <- grid[-1]
  # One probability vector per step (Var1 -> Var2, Var2 -> Var3, ...).
  step_probs <- Map(function(a, b) tm[cbind(a, b)], from, to)
  path_prob <- Reduce(`*`, step_probs)
  subset(transform(grid, Freq = path_prob), Freq > 0)
})
and you can see
> lapply(res, head, n = 10)
[[1]]
Var1 Var2 Freq
1 S1 S1 0.22293395
2 S2 S1 0.39105686
3 S3 S1 0.15366217
4 S4 S1 0.26522803
6 S1 S2 0.03531601
7 S2 S2 0.26197885
8 S3 S2 0.33530265
9 S4 S2 0.16563210
11 S1 S3 0.74175004
12 S2 S3 0.22488682
[[2]]
Var1 Var2 Var3 Freq
1 S1 S1 S1 0.04969955
2 S2 S1 S1 0.08717985
3 S3 S1 S1 0.03425651
4 S4 S1 S1 0.05912833
6 S1 S2 S1 0.01381057
7 S2 S2 S1 0.10244863
8 S3 S2 S1 0.13112240
9 S4 S2 S1 0.06477157
11 S1 S3 S1 0.11397892
12 S2 S3 S1 0.03455660
[[3]]
Var1 Var2 Var3 Var4 Freq
1 S1 S1 S1 S1 0.011079716
2 S2 S1 S1 S1 0.019435348
3 S3 S1 S1 S1 0.007636940
4 S4 S1 S1 S1 0.013181713
6 S1 S2 S1 S1 0.003078844
7 S2 S2 S1 S1 0.022839277
8 S3 S2 S1 S1 0.029231635
9 S4 S2 S1 S1 0.014439781
11 S1 S3 S1 S1 0.025409771
12 S2 S3 S1 S1 0.007703838
[[4]]
Var1 Var2 Var3 Var4 Var5 Freq
1 S1 S1 S1 S1 S1 0.0024700449
2 S2 S1 S1 S1 S1 0.0043327990
3 S3 S1 S1 S1 S1 0.0017025332
4 S4 S1 S1 S1 S1 0.0029386513
6 S1 S2 S1 S1 S1 0.0006863789
7 S2 S2 S1 S1 S1 0.0050916503
8 S3 S2 S1 S1 S1 0.0065167238
9 S4 S2 S1 S1 S1 0.0032191175
11 S1 S3 S1 S1 S1 0.0056647005
12 S2 S3 S1 S1 S1 0.0017174471
Another option (which might be more efficient, since the rows that would result in 0 probabilities are filtered out in advance) is to use merge within a repeat loop to update the table iteratively:
N <- 5

# One-step transitions with positive probability, in long format:
# Var1 (from), Var2 (to), Freq (probability).
dp <- d <- subset(as.data.frame.table(tm), Freq > 0)

# res[[i]] holds the sequences of i + 1 states; start with the pairs.
res <- list(dp)
repeat {
  # `>=` rather than `==`: with N <= 1 the equality could never become true
  # and the loop would not terminate.
  if (length(res) >= N - 1) break

  # Shift the state columns of the current table one position to the right
  # (Var1 -> Var2, Var2 -> Var3, ...) and park its probability in `yFreq`.
  p <- names(dp[-length(dp)])
  q <- paste0("Var", as.integer(sub("\\D+", "", p)) + 1)
  dnew <- setNames(dp, c(q, "yFreq"))

  # Joining on the shared column Var2 prepends one transition from `d` to
  # every existing sequence; the new probability is the product of both.
  dp <- subset(
    transform(
      merge(d, dnew)[unique(c(p, names(dnew), "Freq"))],
      Freq = Freq * yFreq
    ),
    select = -yFreq
  )
  res[[length(res) + 1]] <- dp
}
such that
> lapply(res, head, n = 10)
[[1]]
Var1 Var2 Freq
1 S1 S1 0.22293395
2 S2 S1 0.39105686
3 S3 S1 0.15366217
4 S4 S1 0.26522803
6 S1 S2 0.03531601
7 S2 S2 0.26197885
8 S3 S2 0.33530265
9 S4 S2 0.16563210
11 S1 S3 0.74175004
12 S2 S3 0.22488682
[[2]]
Var1 Var2 Var3 Freq
1 S1 S1 S1 0.049699546
2 S1 S1 S2 0.007873137
3 S1 S1 S3 0.165361267
4 S2 S1 S1 0.087179851
5 S2 S1 S2 0.013810568
6 S2 S1 S3 0.290066443
7 S3 S1 S1 0.034256514
8 S3 S1 S2 0.005426735
9 S3 S1 S3 0.113978919
10 S4 S1 S1 0.059128333
[[3]]
Var1 Var2 Var3 Var4 Freq
1 S1 S1 S1 S1 0.0110797161
2 S1 S1 S1 S2 0.0017551896
3 S1 S1 S1 S3 0.0368646403
4 S1 S1 S3 S2 0.0554460705
5 S1 S1 S3 S3 0.0420975207
6 S1 S1 S3 S4 0.0026131624
7 S1 S1 S3 S5 0.0397947423
8 S1 S1 S2 S4 0.0009611327
9 S1 S1 S2 S1 0.0030788444
10 S1 S1 S2 S2 0.0020625955
[[4]]
Var1 Var2 Var3 Var4 Var5 Freq
1 S1 S1 S1 S1 S1 0.0024700449
2 S1 S1 S1 S1 S2 0.0003912914
3 S1 S1 S1 S1 S3 0.0082183799
4 S1 S1 S1 S3 S2 0.0123608115
5 S1 S1 S1 S3 S3 0.0093849666
6 S1 S1 S1 S3 S4 0.0005825626
7 S1 S1 S1 S3 S5 0.0088715991
8 S1 S1 S1 S2 S4 0.0002142691
9 S1 S1 S1 S2 S1 0.0006863789
10 S1 S1 S1 S2 S2 0.0004598226