Feature request: log matrix product #3026

Open
mhollanders opened this issue Feb 22, 2024 · 6 comments

Comments

@mhollanders

In some applications such as the forward algorithm for hidden Markov models, matrix multiplication can be used to compute the marginal probabilities if working on the probability scale. However, if working on the log scale, there is no function (AFAIK) that computes log(A * B) using log(A) and log(B), where A and B are transition probability matrices. Such a function is easy enough to write in a Stan program:

  /**
   * Log matrix product
   * 
   * Compute log(A * B) from log(A) and log(B)
   *
   * @param a:  Logarithm of matrix or row_vector A
   * @param b:  Logarithm of matrix or vector B
   *
   * @return  Logarithm of matrix A * B
   */
  matrix log_mat_prod(matrix a, matrix b) {
    int x = rows(a);
    int y = cols(b);
    matrix[x, y] c;
    for (i in 1:x) {
      for (j in 1:y) {
        c[i, j] = log_sum_exp(row(a, i)' + col(b, j));
      }
    }
    return(c);
  }

However, I anticipate this problem may be common enough to warrant a built-in Stan function. Please ignore if I am overlooking something obvious.
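
For reference, the identity this function relies on can be written out elementwise. This is a sketch of the standard log-sum-exp form, where $K$ (a symbol introduced here for illustration) is the shared inner dimension, i.e. the number of columns of $A$ and rows of $B$:

```latex
\log (AB)_{ij}
  = \log \sum_{k=1}^{K} A_{ik} B_{kj}
  = \log \sum_{k=1}^{K} \exp\!\left( \log A_{ik} + \log B_{kj} \right)
```

The right-hand side is what `log_sum_exp(row(a, i)' + col(b, j))` evaluates when `a` and `b` hold the elementwise logs of $A$ and $B$.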

@bob-carpenter
Contributor

bob-carpenter commented Feb 22, 2024

Thanks, @mhollanders. That looks like a reasonable function to implement. If you want to implement directly in Stan, this form will be more efficient.

  /**
   * Return the natural logarithm of the product of the elementwise exponentiation of the specified matrices.
   *
   * @param a first matrix
   * @param b second matrix
   * @return  log(exp(a) * exp(b))
   * @throws reject if cols(a) != rows(b)
   */
  matrix log_product_exp(matrix a, matrix b) {
    int M = rows(a);
    int K = cols(a);
    int N = cols(b);
    // the inner dimensions must match: cols(a) == rows(b)
    if (rows(b) != K) reject("matrices must conform for multiplication");
    // transpose once; a_tr is K x M, so row m of a is column m of a_tr
    matrix[K, M] a_tr = a';
    matrix[M, N] c;
    // inner loop runs down a column of c, matching column-major storage
    for (n in 1:N)
      for (m in 1:M)
        c[m, n] = log_sum_exp(a_tr[ , m] + b[ , n]);
    return c;
  }

That's because (a) it's more efficient to transpose the matrix all at once rather than a row at a time, since the underlying algorithm can use cache-sensitive blocking, and (b) matrices are column major in Stan, so it's more efficient to traverse them in the revised order.

Unlike in R, return isn't a function, so it doesn't need parens. I also rewrote the doc to conform to our expected style, which describes what a function returns as explicitly as possible in one sentence. Along with this, I changed the name to more closely follow log-sum-exp.

In Stan, to extend this to row vectors and vectors, we'll need an additional signature that is specific about return type (here I chose vector, but it could've been row_vector and it can't be both).

real log_product_exp(row_vector a, vector b);

It'll be more efficient to rewrite this function than delegate to the matrix version in Stan. This could also be done as an outer product of (vector, row_vector).
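
A minimal sketch of what the direct `(row_vector, vector)` version could look like, assuming a Stan release that allows overloading of user-defined functions. The body is an illustration, not a tested or built-in implementation; it simply adds the two log-scale arguments elementwise and applies `log_sum_exp`:

```stan
functions {
  /**
   * Return the natural logarithm of the product of the elementwise
   * exponentiation of the specified row vector and vector.
   *
   * @param a row vector on the log scale
   * @param b vector on the log scale
   * @return log(exp(a) * exp(b))
   * @throws reject if cols(a) != rows(b)
   */
  real log_product_exp(row_vector a, vector b) {
    if (cols(a) != rows(b)) reject("arguments must conform for multiplication");
    // a' is a vector, so the sum is elementwise and no matrix is built
    return log_sum_exp(a' + b);
  }
}
```

Written this way there is no intermediate matrix and no double loop, which is the reason for not delegating to the matrix version.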

P.S. We have a built-in HMM distribution implementing the forward/backward algorithm under the hood if you have a standard HMM.

@mhollanders
Author

Thanks for the prompt response and tips, @bob-carpenter, I have made the changes to my code for efficiency! FYI, my HMM is a bit more complex but it's good to see those HMM functions built-in!

@bob-carpenter
Contributor

I'm curious about the added complexity and how we might improve our interface. One thing we didn't support well was varying transition matrices involving predictors, whereas I know those are things that show up in, say, moveHMM in R, which is popular in ecology (we tested against some of their simpler examples when setting up our HMM interfaces).

@bob-carpenter
Contributor

P.S. I hope you caught my bugs varying m/n to i/j! I'll go back and edit the code for coherence.

@mhollanders
Author

Hey, firstly, yes I did catch the bugs.

For my model, I firstly just generate new TPMs for each individual because there are very often individual-by-time varying covariates, such as individual infection intensity in disease models. Secondly, this model is sort of a doubly nested HMM (there's probably a better term that exists). The idea is that ecological transitions (such as mortality and infection state transitions) occur between so-called primary occasions. Within each primary occasion, multiple replicated secondary surveys occur. In my disease example, multiple surveys are conducted and during each capture, a sample is collected to determine infection status. However, we get both false-negatives and false-positives in the sampling process. Therefore, both the ecological states (alive and uninfected/infected, dead) and possible observed states (detected and uninfected/infected, not detected) are (partially) latent. Finally, each collected sample is subjected to multiple diagnostic runs to determine whether that sample was infected; once again, there are false-negatives and false-positives in this process too. Note that this is another place where time-varying individual covariates are included; the diagnostic runs inform the latent sample infection intensity, and the (sometimes) replicated sample infection intensities inform the latent individual infection intensities. These infection intensities are used to model the detection probabilities of the sampling and diagnostic processes.

If you check out my first post in this thread, you can see this model in Stan code. Otherwise, here I have it written in equations (with latent states).

Actually, one thing that's really common with HMMs is unequal time intervals. You might consider adding a continuous-time variant as a Stan function, where you simply take the transition rate matrix (TRM), multiply it by the time interval between (primary) occasions, and take the matrix exponential to yield the transition probability matrix (I do this in my model too).
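
The continuous-time step described above can be sketched in a few lines of Stan. The function name `rate_to_tpm` and the argument names `Q` and `dt` are hypothetical, chosen only for illustration; `matrix_exp` is the built-in matrix exponential. The sketch assumes `Q` is a valid rate matrix (non-negative off-diagonal entries, rows summing to zero) and does not check this:

```stan
functions {
  /**
   * Return the transition probability matrix implied by the specified
   * transition rate matrix over the specified time interval.
   *
   * @param Q transition rate matrix (assumed valid; not checked here)
   * @param dt time interval between occasions
   * @return matrix exponential of dt * Q
   */
  matrix rate_to_tpm(matrix Q, real dt) {
    return matrix_exp(dt * Q);
  }
}
```

If the forward algorithm is run on the log scale, the elementwise `log` of the result is what would be passed on to a function such as `log_product_exp`.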

@bob-carpenter
Copy link
Contributor

Thanks for the clarification. I'll reply in Discourse.

@bob-carpenter bob-carpenter transferred this issue from stan-dev/stan Feb 22, 2024