Representing Parametric Survival Model in 'Counting Process' form in JAGS

Tags: R, Survival Analysis, JAGS, Stan, Weibull

R Problem Overview


I'm trying to build a survival model in JAGS that allows for time-varying covariates. I'd like it to be a parametric model — for example, assuming survival follows the Weibull distribution (but I'd like to allow the hazard to vary, so exponential is too simple). So, this is essentially a Bayesian version of what can be done in the flexsurv package, which allows for time-varying covariates in parametric models.

Therefore, I want to be able to enter the data in 'counting-process' form, where each subject has multiple rows, each corresponding to a time interval in which their covariates remained constant (as described in this pdf or here). This is the (start, stop] formulation that the survival and flexsurv packages allow.

Unfortunately, every explanation of how to perform survival analysis in JAGS seems to assume one row per subject.

I attempted to take this simpler approach and extend it to the counting process format, but the model does not correctly estimate the distribution.

A Failed Attempt:

Here's an example. First we generate some data:

library('dplyr')
library('survival')

## Make the Data: -----
set.seed(3)
n_sub <- 1000
current_date <- 365*2

true_shape <- 2
true_scale <- 365

dat <- data_frame(person = 1:n_sub,
                  true_duration = rweibull(n = n_sub, shape = true_shape, scale = true_scale),
                  person_start_time = runif(n_sub, min= 0, max= true_scale*2),
                  person_censored = (person_start_time + true_duration) > current_date,
                  person_duration = ifelse(person_censored, current_date - person_start_time, true_duration)
)

  person person_start_time person_censored person_duration
   (int)             (dbl)           (lgl)           (dbl)
1      1          11.81416           FALSE        487.4553
2      2         114.20900           FALSE        168.7674
3      3          75.34220           FALSE        356.6298
4      4         339.98225           FALSE        385.5119
5      5         389.23357           FALSE        259.9791
6      6         253.71067           FALSE        259.0032
7      7         419.52305            TRUE        310.4770

Then we split the data into 2 observations per subject. I'm just splitting each subject at time = 300 (unless they didn't make it to time = 300, in which case they get just one observation).

## Split into multiple observations per person: --------
cens_point <- 300 # <----- try changing to 0 for no split; if so, model correctly estimates
dat_split <- dat %>%
  group_by(person) %>%
  do(data_frame(
    split = ifelse(.$person_duration > cens_point, cens_point, .$person_duration),
    START = c(0, split[1]),
    END = c(split[1], .$person_duration),
    TINTERVAL = c(split[1], .$person_duration - split[1]), 
    CENS = c(ifelse(.$person_duration > cens_point, 1, .$person_censored), .$person_censored), # <— edited original post here due to bug; but problem still present when fixing bug
    TINTERVAL_CENS = ifelse(CENS, NA, TINTERVAL),
    END_CENS = ifelse(CENS, NA, END)
  )) %>%
  filter(TINTERVAL != 0)

  person    split START      END TINTERVAL CENS TINTERVAL_CENS
   (int)    (dbl) (dbl)    (dbl)     (dbl) (dbl)        (dbl)
1      1 300.0000     0 300.0000 300.00000     1           NA
2      1 300.0000   300 487.4553 187.45530     0    187.45530
3      2 168.7674     0 168.7674 168.76738     1           NA
4      3 300.0000     0 300.0000 300.00000     1           NA
5      3 300.0000   300 356.6298  56.62979     0     56.62979
6      4 300.0000     0 300.0000 300.00000     1           NA
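
The survival package (already loaded above) also provides survSplit(), which performs this kind of split. The following is a minimal sketch, assuming a toy data frame whose column names and values (person, duration, event) are made up for illustration. Note that event here is 1 for an observed event, the reverse of the CENS flag used above.

```r
library(survival)

# Toy one-row-per-subject data (hypothetical values, not the simulated data above)
toy <- data.frame(
  person   = 1:3,
  duration = c(487.5, 168.8, 310.5),
  event    = c(1, 1, 0)   # 1 = event observed, 0 = right-censored
)

# Cut every follow-up at t = 300. Subjects who stop before 300 keep one row;
# the others get a censored row ending at 300 and a second row starting at 300.
toy_split <- survSplit(Surv(duration, event) ~ ., data = toy,
                       cut = 300, episode = "interval")
toy_split
```

The result should contain a tstart column (the interval start) alongside duration (now the interval end) and event, which together give the (start, stop] rows.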

Now we can set up the JAGS model.

## Set-Up JAGS Model -------
dat_jags <- as.list(dat_split)
dat_jags$N <- length(dat_jags$TINTERVAL)
inits <- replicate(n = 2, simplify = FALSE, expr = {
       list(TINTERVAL_CENS = with(dat_jags, ifelse(CENS, TINTERVAL + 1, NA)),
            END_CENS = with(dat_jags, ifelse(CENS, END + 1, NA)) )
})

model_string <- 
"
  model {
    # set priors on reparameterized version, as suggested
    # here: https://sourceforge.net/p/mcmc-jags/discussion/610036/thread/d5249e71/?limit=25#8c3b
    log_a ~ dnorm(0, .001) 
    log(a) <- log_a
    log_b ~ dnorm(0, .001)
    log(b) <- log_b
    nu <- a
    lambda <- (1/b)^a
    
    for (i in 1:N) {
      # Estimate Subject-Durations:
      CENS[i] ~ dinterval(TINTERVAL_CENS[i], TINTERVAL[i])
      TINTERVAL_CENS[i] ~ dweibull( nu, lambda )
    }
  }
"

library('runjags')
param_monitors <- c('a', 'b', 'nu', 'lambda') 
fit_jags <- run.jags(model = model_string,
                     burnin = 1000, sample = 1000, 
                     monitor = param_monitors,
                     n.chains = 2, data = dat_jags, inits = inits)
# estimates:
fit_jags
# actual:
c(a=true_shape, b=true_scale)

Depending on where the split point is, the model estimates very different parameters for the underlying distribution. It only gets the parameters right if the data isn't split into the counting process form. It seems like this is not the way to format the data for this kind of problem.
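
One way to see why the split matters, offered as a sketch of the statistics rather than a diagnosis of the JAGS code: the model above treats each row's interval length as a fresh Weibull draw starting at time zero. For a row covering (s, t], that puts S(t - s) into the likelihood, whereas the probability of surviving from s to t given survival to s is S(t)/S(s). Using the same parameterization as the model (shape a, scale b):

```latex
S(t) = \exp\{-(t/b)^{a}\}

\Pr(T > t \mid T > s) = \frac{S(t)}{S(s)} = \exp\{-[(t/b)^{a} - (s/b)^{a}]\}

S(t - s) = \exp\{-((t - s)/b)^{a}\}
```

The last two expressions agree for every 0 < s < t only when a = 1, i.e. the exponential case. That is consistent with the observation that the estimates are correct when the data are not split (s = 0 for every row).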

If I am missing an assumption and my problem is less related to JAGS and more related to how I'm formulating the problem, suggestions are very welcome. I might despair that time-varying covariates can't be used in parametric survival models (and can only be used in models like the Cox model, which assumes proportional hazards and leaves the baseline hazard unspecified, so it doesn't estimate the underlying distribution). However, as I mentioned above, the flexsurvreg function in the flexsurv package does accommodate the (start, stop] formulation in parametric models.

If anyone knows how to build a model like this in another language (e.g. Stan instead of JAGS), that would be appreciated too.

Edit:

Chris Jackson provides some helpful advice via email:

> I think the T() construct for truncation in JAGS is needed here. Essentially for each period (t[i], t[i+1]) where a person is alive but the covariate is constant, the survival time is left-truncated at the start of the period, and possibly also right-censored at the end. So you'd write something like y[i] ~ dweib(shape, scale[i])T(t[i], )
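
To spell out what this suggestion implies for the likelihood, here is a sketch in notation that is not part of the quoted email: row i covers (s_i, t_i], d_i = 1 if the event occurs at t_i and 0 if the row is right-censored (the reverse of the CENS flag), f is the Weibull density and S its survival function.

```latex
L_i = \frac{f(t_i)^{d_i} \, S(t_i)^{1 - d_i}}{S(s_i)}

% A subject split at c, with the event observed at t > c, contributes
\frac{S(c)}{S(0)} \cdot \frac{f(t)}{S(c)} = f(t)
```

Because S(0) = 1 and the S(c) terms cancel, the product of a subject's rows equals the contribution of that subject's single unsplit row. Under this likelihood, splitting by itself should not change the estimates.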

I tried implementing this suggestion as follows:

model {
  # same as before
  log_a ~ dnorm(0, .01) 
  log(a) <- log_a
  log_b ~ dnorm(0, .01)
  log(b) <- log_b
  nu <- a
  lambda <- (1/b)^a
  
  for (i in 1:N) {
    # modified to include left-truncation
    CENS[i] ~ dinterval(END_CENS[i], END[i])
    END_CENS[i] ~ dweibull( nu, lambda )T(START[i],)
  }
}

Unfortunately this doesn't quite do the trick. With the old code, the model was mostly getting the scale parameter right, but doing a very bad job on the shape parameter. With this new code, it gets very close to the correct shape parameter, but consistently over-estimates the scale parameter. I have noticed that the degree of over-estimation is correlated with how late the split point comes. If the split-point is early (cens_point = 50), there's not really any over-estimation; if it's late (cens_point = 350), there is a lot.

I thought maybe the problem could be related to 'double-counting' the observations: if we see a censored observation at t=300 and then, from that same person, an uncensored observation at t=400, it seems intuitive to me that this person is contributing two data points to our inference about the Weibull parameters when really they should contribute just one. I therefore tried incorporating a random effect for each person; however, this completely failed, with huge estimates (in the 50-90 range) for the nu parameter. I'm not sure why that is, but perhaps that's a question for a separate post. Since I'm not sure whether the problems are related, you can find the code for this whole post, including the JAGS code for that model, here.
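
The double-counting worry can be checked numerically outside JAGS. The sketch below uses base R only and assumes a simplified setup that differs from the simulation above: a fixed censoring time (cens_time) and a hand-written log-likelihood (loglik) in which each row is left-truncated at its start time. All names in it are illustrative.

```r
set.seed(3)
n <- 1000
true_shape <- 2
true_scale <- 365

t_event   <- rweibull(n, shape = true_shape, scale = true_scale)
cens_time <- 500                               # assumed fixed censoring time
time      <- pmin(t_event, cens_time)
event     <- as.integer(t_event <= cens_time)  # 1 = event observed

# Counting-process rows: split every follow-up at `cut`
cut  <- 300
long <- time > cut
split_dat <- rbind(
  data.frame(start = 0,   stop = pmin(time, cut), event = ifelse(long, 0L, event)),
  data.frame(start = cut, stop = time[long],      event = event[long])
)

# Weibull log-likelihood with each row left-truncated at `start`
loglik <- function(par, start, stop, event) {
  shape <- exp(par[1])
  scale <- exp(par[2])
  sum(event * dweibull(stop, shape, scale, log = TRUE) +
      (1 - event) * pweibull(stop, shape, scale, lower.tail = FALSE, log.p = TRUE) -
      pweibull(start, shape, scale, lower.tail = FALSE, log.p = TRUE))
}

# Split and unsplit data give the same log-likelihood at any parameter value
par0       <- c(log(1.5), log(300))
ll_split   <- loglik(par0, split_dat$start, split_dat$stop, split_dat$event)
ll_unsplit <- loglik(par0, rep(0, n), time, event)
all.equal(ll_split, ll_unsplit)

# Maximum-likelihood fit on the split data
fit <- optim(c(0, log(mean(time))),
             function(p) -loglik(p, split_dat$start, split_dat$stop, split_dat$event))
exp(fit$par)   # estimated shape and scale
```

If the two log-likelihoods agree, the extra rows carry no extra information under this likelihood, so a per-person random effect should not be needed just to account for the split.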

R Solutions


Solution 1 - R

You can use the rstanarm package, which is a wrapper around Stan. It lets you describe survival models with standard R formula notation. Its stan_surv function accepts data in "counting process" form, and several baseline hazard functions, including Weibull, can be used to fit the model.

The survival part of rstanarm (the stan_surv function) is not yet available on CRAN, so you should install the package directly from mc-stan.org.

install.packages("rstanarm", repos = c("https://mc-stan.org/r-packages/", getOption("repos")))

Please see the code below:

library(dplyr)
library(survival)
library(rstanarm)

## Make the Data: -----
set.seed(3)
n_sub <- 1000
current_date <- 365*2

true_shape <- 2
true_scale <- 365

dat <- data_frame(person = 1:n_sub,
                  true_duration = rweibull(n = n_sub, shape = true_shape, scale = true_scale),
                  person_start_time = runif(n_sub, min= 0, max= true_scale*2),
                  person_censored = (person_start_time + true_duration) > current_date,
                  person_duration = ifelse(person_censored, current_date - person_start_time, true_duration)
)

## Split into multiple observations per person: --------
cens_point <- 300 # <----- try changing to 0 for no split; if so, model correctly estimates
dat_split <- dat %>%
  group_by(person) %>%
  do(data_frame(
    split = ifelse(.$person_duration > cens_point, cens_point, .$person_duration),
    START = c(0, split[1]),
    END = c(split[1], .$person_duration),
    TINTERVAL = c(split[1], .$person_duration - split[1]), 
    CENS = c(ifelse(.$person_duration > cens_point, 1, .$person_censored), .$person_censored), # <— edited original post here due to bug; but problem still present when fixing bug
    TINTERVAL_CENS = ifelse(CENS, NA, TINTERVAL),
    END_CENS = ifelse(CENS, NA, END)
  )) %>%
  filter(TINTERVAL != 0)
# Surv() expects an event indicator (1 = event observed), so flip the censoring flag
dat_split$CENS <- as.integer(!(dat_split$CENS))


# Fit Stan survival model
mod_tvc <- stan_surv(
  formula = Surv(START, END, CENS) ~ 1,
  data = dat_split,
  iter = 1000,
  chains = 2,
  basehaz = "weibull-aft")

# Print fit coefficients
mod_tvc$coefficients[2]
unname(exp(mod_tvc$coefficients[1]))

Output, which is consistent with true values (true_shape <- 2; true_scale <- 365):

> mod_tvc$coefficients[2]
weibull-shape 
     1.943157 
> unname(exp(mod_tvc$coefficients[1]))
[1] 360.6058

You can also look at the Stan source using rstan::get_stanmodel(mod_tvc$stanfit) to compare the Stan code with the attempts you made in JAGS.

Attributions

All content for this solution is sourced from the original question on Stackoverflow.

The content on this page is licensed under the Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license.

Content Type     Original Author   Original Content on Stackoverflow
Question         jwdink            View Question on Stackoverflow
Solution 1 - R   Artem             View Answer on Stackoverflow