Supervised Topic Modeling for Short Texts: My Workflow and A Worked Example

Many organizations have a substantial amount of human-generated text from which they are not extracting a proportional amount of insight. For example, open-ended questions are found in most surveys—but are rarely given the same amount of attention (if any attention at all) as the easier-to-analyze quantitative data. I have tested out many supposedly “AI-powered” or “NLP-driven” tools for analyzing text in my career, and I haven’t found anything to be useful at finding topics or modeling sentiment when fed real data. I wrote on my reservations about common topic modeling methods over four years ago, where I showed how I perform exploratory analysis on text data based on word co-occurrences. That was an unsupervised approach: No a priori topics are given to a model to learn from. It looks at patterns of how frequently words are used together to infer topics.

I lay out my approach for supervised topic modeling in short texts (e.g., open-response survey data) here. My philosophy is one where there is no free lunch: If you want data that make sense for your specific organization and use case, you’re gonna actually have to read and consider the text you get. This isn’t using statistical learning to do the job for you—it is using these models as a tool to work with. You can’t solve this purely from the command line. There will need to be many humans in the loop. This process can be a grind, but what you’ll come out of it with is a scalable, bespoke model. I don’t think I’m saying anything particularly revolutionary in this post, but hopefully the walk-through and code might help you develop a similar workflow.

The working example I’m using here is a set of Letterboxd reviews for Wes Anderson’s newest film, Asteroid City.

Step 1: Thematic Content Analysis

Create a coding corpus that is a random sample (maybe 2000 cases, depending on how long each piece of text is, how much bandwidth you have for this project, etc.) from your entire corpus of text (e.g., reviews, open-ended responses, help tickets). Perform a thematic content analysis in partnership with colleagues. I am not a qualitative methodologist by any means, but I start by reading the entirety of the random sample. Then I read through the coding corpus again, taking notes about key themes. I look at these themes, and I think about how they may be grouped together. I talk with my colleagues to see what they think. We discuss if certain themes can be combined or if they need to be kept separate. We talk about the maximum number of themes we think we should be looking for. We finalize themes, discuss what they mean in detail, and name them. From these readings and conversations, I write a standardized coding manual that others can use to read responses and code according to the themes.
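A minimal sketch of drawing the coding corpus in base R, assuming the full corpus sits in a data frame with one row per piece of text. The corpus object, its id and text columns, and the sizes used here are placeholders for illustration, not anything from the project described in this post.

```r
# Hypothetical full corpus: one row per piece of text
set.seed(1839)
corpus <- data.frame(
  id = seq_len(10000),
  text = paste("response number", seq_len(10000)),
  stringsAsFactors = FALSE
)

# Draw a simple random sample of rows to serve as the coding corpus
n_code <- 2000
coding_corpus <- corpus[sample(nrow(corpus), n_code), ]

# Everything not sampled stays uncoded and can be scored by the model later
uncoded <- corpus[!corpus$id %in% coding_corpus$id, ]

nrow(coding_corpus)
nrow(uncoded)
```

Setting the seed before sampling means the same coding corpus can be regenerated later, which helps if the sheets sent to coders ever need to be rebuilt.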

In my example, I wrote an R script (see Appendix A or my GitHub) to scrape Letterboxd for reviews. For each possible rating (0.5 stars to 5.0 stars, in increments of half stars), I pulled six pages. Each page had twelve reviews. Ten possible ratings, six pages each, and twelve reviews per page multiplies out to 720 reviews. In the process of coding, I dropped any review that was not written in English—a limitation of mine here due to me making the regrettable decision to study Latin in college (although an apropos course for a Wes Anderson blog post).

I decided to stick to just one theme for the purposes of this exercise: Did the review discuss the visual style of the film? Anderson is known for his distinctive visual style, which at this point is unfortunately a bit of a cultural meme. There was a whole Instagram/TikTok trend of posting videos of oneself in the style of Anderson; there’s those cynical, intellectually vapid, artless, AI-generated simulacra imagining [insert well-known intellectual property here] in the Andersonian style; and there’s even a Hertz (the rent-a-car company) commercial that references the meme, which in turn references the style (Baudrillard’s head would be spinning, folks).

This visual style, by my eye, was solidified by the time of The Grand Budapest Hotel. Symmetrical framing, meticulously organized sets, the use of miniatures, a pastel color palette contrasted with saturated primary colors, distinctive costumes, straight lines, lateral tracking shots, whip pans, and so on. However, I did not consider aspects of Anderson’s style that are unrelated to the visual sense. This is where defining the themes with a clear line matters—and often there will be ambiguities, but one must do their best, because the process we’re doing is fundamentally a simplification of the rich diversity of the text. Thus, I did not consider the following to be in Anderson’s visual style: stories involving precocious children, fraught familial relations, uncommon friendships, dry humor, monotonous dialogue, soundtracks usually involving The Kinks, a fascination with stage productions, nesting doll narratives, or a decidedly twee yet bittersweet tone.

Step 2: Independent Coders

This is where you will need organizational buy-in and the help of your colleagues. Recruit fellow researchers (or subject experts) that have domain expertise. For example, if the text data are all reviews for a specific app, make sure you recruit colleagues that have knowledge about that specific app. I generate standardized Google Sheets where each row is a text response and each column is a theme, and I send them, along with the coding manual, to the coders. They are instructed to put in a “1” wherever a text touches on that theme as described in the coding manual. I have each piece of text coded by two independent coders. After everyone is done, I resolve disagreements by a third-party vote (which may or may not be myself). Combine the various Google Sheets (I use R packages for reading the sheets and merging these data), and congratulations! You have supervised text data that are ready for training.
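The reconciliation step can be sketched in base R. This assumes the two coders' sheets have already been read in and joined on a response id; the codes and tiebreak data frames and all of their column names are hypothetical, as are the values in them.

```r
# Hypothetical codes from two independent coders on the same six responses
codes <- data.frame(
  id      = 1:6,
  coder_a = c(1, 0, 1, 0, 0, 1),
  coder_b = c(1, 0, 0, 0, 1, 1)
)

# Flag rows where the two coders disagree
codes$disagree <- codes$coder_a != codes$coder_b

# Hypothetical tie-breaking votes from a third coder, only for disagreements
tiebreak <- data.frame(id = c(3, 5), coder_c = c(1, 0))
codes <- merge(codes, tiebreak, by = "id", all.x = TRUE)

# Final code: the agreed value, or the third coder's vote on disagreements
codes$final <- ifelse(codes$disagree, codes$coder_c, codes$coder_a)

# Raw percent agreement between the two original coders
agreement <- mean(!codes$disagree)

codes
agreement
```

Raw agreement is only a rough check. If it is low for a theme, that is usually a sign the coding manual needs a clearer definition before any model is trained on it.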

For Asteroid City reviews, I just did this on a Saturday afternoon myself. Any piece of text that touched on the visual style was flagged with a “1,” while the rest were marked as “0.” The text I pulled was abbreviated for longer reviews; however, I don’t mind this, as I feel it stays true to the analogy of short open-ended survey responses that way.

Step 3: Train the Models

This is where I stand on the shoulders of Emil Hvitfeldt and Julia Silge by way of their book, Supervised Machine Learning for Text Analysis in R. I define an entire machine learning workflow (i.e., data processing, cross-validation, model selection, and holdout performance), and then I package this workflow into convenient wrapper functions. This allows me to use that wrapper function on each column, so that every theme gets its own model. I write the cross-validation results and final models out as RDS files so that I can access them later.

The wrapper functions are in Appendix B below and at my GitHub. It’s, uh, long. (I go against my own imperative to write modular code here, but if there’s a less verbose way of defining this workflow, please genuinely let me know.) I am making three basic functions: one to create the train/test split, another to perform cross-validation, and a third to pull the best model.

The prep_data function takes your data, the name of the outcome and text variables, and how large of a holdout set you want.

You can change exactly what pre-processing steps and algorithms you want to use from what I have. But I have found that the processing of the data usually is more influential than a specific model. So in the wrapper function do_cv, I am trying out all combinations of the following:

  • Stop words: None, SMART, or ISO (different lists of stop words)
  • N-grams: Words only or words and bigrams
  • Stemming: Yes or no
  • Word frequency filtering: Words used 1 or more times in the corpus, 2 or more times, or 5 or more
  • Algorithm: Elastic net or random forest

This leaves 3 * 2 * 2 * 3 * 2 = 72 combinations of pre-processing steps and algorithms. The models were tuned using grid search (the size of which is an argument to the do_cv function) with k-fold cross-validation (where k is also an argument). You could add other pre-processing steps (e.g., word embeddings) or other algorithms (e.g., XGBoost, neural networks) to the body of this function.
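To make the arithmetic concrete, here is a small base R sketch that enumerates the combinations. The short labels (word, both, nostop, f0, and so on) follow the naming scheme used for the workflow ids in Appendix B, where f0 stands for no frequency filter; the grid object itself is only for illustration and is not part of do_cv.

```r
# Enumerate every combination of pre-processing choice and algorithm
grid <- expand.grid(
  tokens    = c("word", "both"),             # words only, or words and bigrams
  stopwords = c("nostop", "smart", "iso"),   # stop word list
  stemming  = c("nostem", "stemmed"),        # stemming: no or yes
  min_freq  = c("f0", "f2", "f5"),           # minimum word frequency filter
  algorithm = c("elasticnet", "randforest"), # model type
  stringsAsFactors = FALSE
)

# Build ids in the same pattern as the workflow set names
grid$wflow_id <- with(
  grid,
  paste(tokens, stopwords, stemming, min_freq, algorithm, sep = "_")
)

nrow(grid)
head(grid$wflow_id)
```

Each of these 72 workflows is then tuned over its own grid of hyperparameters within every fold, which is why the full run takes a while.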

When that’s done, the best_model function will pull everything you need to know about the model that performed best in cross-validation, and it will give you performance metrics on the holdout set.

When all is said and done, the actual running of the code then looks deceptively simple:

library(tidyverse)
source("funs.R")

dat <- read_csv("ratings_coded.csv") %>% 
  filter(visual_style != 9) # remove non-english

set.seed(1839)
dat <- prep_data(dat, "visual_style", "text", .20)
cv_res <- do_cv(training(dat), 4, 10)
write_rds(cv_res, "visual_style-res.rds")
mod <- best_model(dat, cv_res, "roc_auc")

Where funs.R is a file that includes the functions in Appendix B. This has a 20% holdout set, 4 folds, and a grid search with size 10. It then writes out those results to an RDS object. Lastly, we pull the best model according to the area under the ROC curve. If you had multiple codes, you could write a for loop here for the name of each column (e.g., prep_data(dat, i, "text", .20)).

Step 4: Use the Models

First, let’s check out what some of these objects look like. Let’s take a look at the metrics for the first model that went into the cross-validation results set.

cv_res <- read_rds("visual_style-res.rds")
cv_res$cv_res$result[[1]]$.metrics
## [[1]]
## # A tibble: 60 × 6
##        penalty mixture .metric     .estimator .estimate .config              
##          <dbl>   <dbl> <chr>       <chr>          <dbl> <chr>                
##  1 0.000000787  0.0698 accuracy    binary         0.774 Preprocessor1_Model01
##  2 0.000000787  0.0698 sensitivity binary         0.923 Preprocessor1_Model01
##  3 0.000000787  0.0698 specificity binary         0     Preprocessor1_Model01
##  4 0.000000787  0.0698 precision   binary         0.828 Preprocessor1_Model01
##  5 0.000000787  0.0698 f_meas      binary         0.873 Preprocessor1_Model01
##  6 0.000000787  0.0698 roc_auc     binary         0.777 Preprocessor1_Model01
##  7 0.000374     0.166  accuracy    binary         0.774 Preprocessor1_Model02
##  8 0.000374     0.166  sensitivity binary         0.923 Preprocessor1_Model02
##  9 0.000374     0.166  specificity binary         0     Preprocessor1_Model02
## 10 0.000374     0.166  precision   binary         0.828 Preprocessor1_Model02
## # ℹ 50 more rows
## 
## [[2]]
## # A tibble: 60 × 6
##        penalty mixture .metric     .estimator .estimate .config              
##          <dbl>   <dbl> <chr>       <chr>          <dbl> <chr>                
##  1 0.000000787  0.0698 accuracy    binary         0.767 Preprocessor1_Model01
##  2 0.000000787  0.0698 sensitivity binary         0.913 Preprocessor1_Model01
##  3 0.000000787  0.0698 specificity binary         0.286 Preprocessor1_Model01
##  4 0.000000787  0.0698 precision   binary         0.808 Preprocessor1_Model01
##  5 0.000000787  0.0698 f_meas      binary         0.857 Preprocessor1_Model01
##  6 0.000000787  0.0698 roc_auc     binary         0.714 Preprocessor1_Model01
##  7 0.000374     0.166  accuracy    binary         0.767 Preprocessor1_Model02
##  8 0.000374     0.166  sensitivity binary         0.913 Preprocessor1_Model02
##  9 0.000374     0.166  specificity binary         0.286 Preprocessor1_Model02
## 10 0.000374     0.166  precision   binary         0.808 Preprocessor1_Model02
## # ℹ 50 more rows
## 
## [[3]]
## # A tibble: 60 × 6
##        penalty mixture .metric     .estimator .estimate .config              
##          <dbl>   <dbl> <chr>       <chr>          <dbl> <chr>                
##  1 0.000000787  0.0698 accuracy    binary         0.667 Preprocessor1_Model01
##  2 0.000000787  0.0698 sensitivity binary         1     Preprocessor1_Model01
##  3 0.000000787  0.0698 specificity binary         0.167 Preprocessor1_Model01
##  4 0.000000787  0.0698 precision   binary         0.643 Preprocessor1_Model01
##  5 0.000000787  0.0698 f_meas      binary         0.783 Preprocessor1_Model01
##  6 0.000000787  0.0698 roc_auc     binary         0.824 Preprocessor1_Model01
##  7 0.000374     0.166  accuracy    binary         0.667 Preprocessor1_Model02
##  8 0.000374     0.166  sensitivity binary         1     Preprocessor1_Model02
##  9 0.000374     0.166  specificity binary         0.167 Preprocessor1_Model02
## 10 0.000374     0.166  precision   binary         0.643 Preprocessor1_Model02
## # ℹ 50 more rows
## 
## [[4]]
## # A tibble: 60 × 6
##        penalty mixture .metric     .estimator .estimate .config              
##          <dbl>   <dbl> <chr>       <chr>          <dbl> <chr>                
##  1 0.000000787  0.0698 accuracy    binary         0.733 Preprocessor1_Model01
##  2 0.000000787  0.0698 sensitivity binary         0.913 Preprocessor1_Model01
##  3 0.000000787  0.0698 specificity binary         0.143 Preprocessor1_Model01
##  4 0.000000787  0.0698 precision   binary         0.778 Preprocessor1_Model01
##  5 0.000000787  0.0698 f_meas      binary         0.84  Preprocessor1_Model01
##  6 0.000000787  0.0698 roc_auc     binary         0.745 Preprocessor1_Model01
##  7 0.000374     0.166  accuracy    binary         0.7   Preprocessor1_Model02
##  8 0.000374     0.166  sensitivity binary         0.870 Preprocessor1_Model02
##  9 0.000374     0.166  specificity binary         0.143 Preprocessor1_Model02
## 10 0.000374     0.166  precision   binary         0.769 Preprocessor1_Model02
## # ℹ 50 more rows

What we see here are the 6 metrics for each of the 10 tunings from the grid search. We get four tibbles here, one for each fold. These are the results for the first workflow in the set (the [[1]] in the call above). Which pre-processing steps and algorithm does that refer to?

cv_res$wfs$wflow_id[[1]]
## [1] "word_nostop_nostem_f2_elasticnet"

This is looking at just words, no stop-word removal, no stemming, a minimum use of 2 in the corpus, and the elastic net.

But what we want to know is which model is best. We used best_model for that above.

mod$best_id
## [1] "word_smart_stemmed_f2_randforest"
mod$best_params
## # A tibble: 1 × 3
##    mtry min_n .config              
##   <int> <int> <chr>                
## 1    20    34 Preprocessor1_Model09

What performed best in cross-validation was only looking at words, using the SMART dictionary for stop words, stemming, only considering words used at least twice in the corpus, and using the random forest, sampling 20 candidate words at each split (mtry) and requiring at least 34 observations in a node before it can be split further (min_n).

We can also look at the holdout metrics:

mod$ho_metrics
## # A tibble: 1 × 7
##   accuracy precision sensitivity specificity est_pct_class est_pct_prob act_pct
##      <dbl>     <dbl>       <dbl>       <dbl>         <dbl>        <dbl>   <dbl>
## 1    0.765     0.762       0.994       0.110        0.0327        0.191   0.259

You’ll be familiar with the first four, but the last three are ones I use. Often, I am using these types of models to estimate, “How many people are talking about X theme?” So, I don’t care as much about individual-level predictions as I do about the aggregate. I do this by averaging up the estimated probabilities (est_pct_prob) and using that as my estimate. So here, I would say that about 19% of reviews are discussing Anderson’s visual style. The actual percent (act_pct) is a bit more at about 26%. But not bad for a Saturday afternoon.

What about variable importance? What words are helping us figure out if the review touches on Wes’s visual style or not?

print(mod$var_imp, n = 10)
## # A tibble: 368 × 2
##    Variable Importance
##    <fct>         <dbl>
##  1 color         1.56 
##  2 visual        1.35 
##  3 cool          1.02 
##  4 audienc       0.994
##  5 shot          0.940
##  6 plot          0.751
##  7 design        0.575
##  8 pretti        0.554
##  9 set           0.503
## 10 landscap      0.492
## # ℹ 358 more rows

And yeah, this makes sense. The last thing we can do is plug in new reviews that have gone up since I made the model and generate predictions for them. The first and the third talk about visual style, while the others do not.

new_reviews <- tibble(
  text = c("Kodak film and miniatures make for a cool movie",
           "To me, Wes never misses",
           paste(
             "Visuals and characters were stunning as always but the dual plot",
             "was hard to follow and the whole thing felt a bit hollow"
            ),
           "my head hurts idk how i feel about this yet",
           "Me after asking AI to write me a wes anderson script")
)

# predict on these
predict(mod$fit, new_reviews, type = "prob")
## # A tibble: 5 × 2
##   .pred_0 .pred_1
##     <dbl>   <dbl>
## 1   0.675  0.325 
## 2   0.955  0.0448
## 3   0.545  0.455 
## 4   0.968  0.0317
## 5   0.870  0.130

Unfortunately, neither of the two reviews that actually are about his style gets a predicted probability above .5, so we don’t get hits on them. In truth, 40% of these new reviews (two of five) are about his style. Averaging up the .pred_1 column only gives us about 20%, but we can’t be surprised about that big of a miss from just five new reviews.
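The three aggregate numbers can be reproduced by hand from the predictions printed above. This is a base R sketch; the probabilities are copied from that output, the actual codes come from the statement that the first and third reviews are about visual style, and the 0.5 cutoff for the class prediction is an assumption about how the hard classification is made.

```r
# Predicted probabilities of class 1, copied from the predict() output
pred_1 <- c(0.325, 0.0448, 0.455, 0.0317, 0.130)

# Codes for the same five reviews: the first and third discuss visual style
actual <- c(1, 0, 1, 0, 0)

# Share of reviews classified as 1, assuming a 0.5 cutoff
est_pct_class <- mean(pred_1 > 0.5)

# Average predicted probability, the aggregate estimate used in this post
est_pct_prob <- mean(pred_1)

# Actual share of reviews about visual style
act_pct <- mean(actual)

round(c(est_pct_class = est_pct_class,
        est_pct_prob = est_pct_prob,
        act_pct = act_pct), 3)
```

The hard classifications give an estimate of zero here, while the averaged probabilities land at roughly half of the true share, which mirrors the gap between est_pct_class and est_pct_prob in the holdout metrics.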

Step 5: Scale and Maintain

You’ve got the best model(s) in-hand. What now? You scale it! In Step 1, we only looked at a random sample of maybe 2000 pieces of text. But you may have much, much more than that; you might be continuously collecting new text data that adhere to the coding manual. Use the predict() function on those cases. Then what I do is average up the predicted probabilities to get an estimate of how many people are talking about a given theme. But you could also use the predicted probabilities to decide which users to email about a specific theme, for example. Since your coding scheme came from a specific prompt or data source, make sure that you are applying it only to text that comes from this same generating mechanism (e.g., survey question, feedback prompt, ticketing system).

What is absolutely crucial here is keeping humans in the loop. As new data come in, figure out a regular cadence where you generate a (smaller than the original) random sample of data, code it, and then see how your predictions are holding up to new data. (After checking performance, these newly coded cases can then be added to the training set and the model can be updated to include them.) Also, read through these data to see if the coding manual needs to be updated, themes are no longer relevant, or new themes need to be added. This is, again, a grind of a process. But it is a way of scaling up the unique expertise of the humans in your organization to data that generally don’t get much attention. And, in my experience, keeping humans in the loop is not only good for model performance, but reading real responses from real people helps you as a researcher understand your data and your users (or respondents or whoever is generating the data) better.
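One way to sketch that periodic check in base R: compare the model's aggregate estimate to the human-coded rate on the fresh sample, then fold the new cases into the training data. Everything here is hypothetical, including the data frames, the predicted probabilities, and the 5-point tolerance, which is an arbitrary placeholder rather than a recommendation.

```r
# Hypothetical freshly coded sample: text, human code, and model probability
new_coded <- data.frame(
  text   = paste("new response", 1:10),
  y      = c(1, 0, 0, 1, 0, 0, 0, 1, 0, 0),
  pred_1 = c(0.62, 0.10, 0.25, 0.40, 0.05, 0.30, 0.15, 0.55, 0.20, 0.08),
  stringsAsFactors = FALSE
)

# Aggregate estimate from the model versus the human-coded rate
est_pct <- mean(new_coded$pred_1)
act_pct <- mean(new_coded$y)
gap <- est_pct - act_pct

# Flag the theme for a closer look if the gap exceeds a chosen tolerance
tolerance <- 0.05  # hypothetical threshold
needs_review <- abs(gap) > tolerance

# After the check, add the newly coded cases to the existing training set
old_train <- data.frame(
  text = c("old response 1", "old response 2"),
  y    = c(1, 0),
  stringsAsFactors = FALSE
)
updated_train <- rbind(old_train, new_coded[, c("text", "y")])

c(est_pct = est_pct, act_pct = act_pct, gap = gap)
needs_review
nrow(updated_train)
```

The updated training set would then go back through the same prep, cross-validation, and model selection steps as before.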

Appendix A

library(rvest)
library(tidyverse)

# 12 reviews per page
# 10 possible ratings
# 6 pages each
# n = 720 reviews total

base_url1 <- "https://letterboxd.com/film/asteroid-city/reviews/rated/"
base_url2 <- "/by/activity/page/"
ratings <- seq(.5, 5, .5)
pages <- 6

# people often say to use purrr::map or an apply statement instead of a for
# loop. but if it fails, I want to know where, and I want to keep the data
# we already collected. so I'm initializing an empty data frame and then
# slotting the data in one page at a time
dat <- tibble(rating = vector("double"), text = vector("character"))

for (r in ratings) {
  pg <- 1
  
  while (pg <= pages) {
    url <- paste0(base_url1, r, base_url2, pg)
    
    txt <- url %>% 
      read_html() %>% 
      html_nodes(".collapsible-text") %>% 
      html_text2() %>% 
      map_chr(str_replace_all, fixed("\n"), " ")
    
    dat <- bind_rows(dat, tibble(rating = r, text = txt))
    
    cat("finished page", pg, "of", r, "stars\n")
    
    pg <- pg + 1
    Sys.sleep(runif(1, 0, 10))
  }
}

write_csv(dat, "ratings_uncoded.csv")

Appendix B

# funs -------------------------------------------------------------------------
library(textrecipes)
library(vip)
library(stopwords)
library(tidymodels)
library(workflowsets)
library(tidyverse)

# prepare data for modeling, do train/test split
# prop is the proportion of rows to hold out as the test set
# stratify on the outcome variable
prep_data <- function(dat, y, txt, prop) {
  dat %>%
    transmute(text = .data[[txt]], y = factor(.data[[y]])) %>%
    # initial_split() takes the proportion kept for training, so pass 1 - prop
    initial_split(prop = 1 - prop, strata = y)
}

# do cross-validation with the same, pre-defined, hard-coded engines and recipes
# n_folds is number of folds; grid_size is the size of the grid
do_cv <- function(dat_train, n_folds, grid_size) {
  
  # define recipes -------------------------------------------------------------
  ## base ----------------------------------------------------------------------
  rec_base <- recipe(y ~ text, dat_train)
  
  ## tokenize ------------------------------------------------------------------
  rec_word_nostop <- rec_base %>%
    step_tokenize(
      text,
      token = "words"
    )
  
  rec_word_smart <- rec_base %>%
    step_tokenize(
      text,
      token = "words",
      options = list(stopwords = stopwords(source = "smart"))
    )
  
  rec_word_iso <- rec_base %>%
    step_tokenize(
      text,
      token = "words",
      options = list(stopwords = stopwords(source = "stopwords-iso"))
    )
  
  rec_both_nostop <- rec_base %>%
    step_tokenize(
      text,
      token = "skip_ngrams",
      options = list(
        n = 2,
        k = 0
      )
    )
  
  rec_both_smart <- rec_base %>%
    step_tokenize(
      text,
      token = "skip_ngrams",
      options = list(
        stopwords = stopwords(source = "smart"),
        n = 2,
        k = 0
      )
    )
  
  rec_both_iso <- rec_base %>%
    step_tokenize(
      text,
      token = "skip_ngrams",
      options = list(
        stopwords = stopwords(source = "stopwords-iso"),
        n = 2,
        k = 0
      )
    )
  
  ## stem ----------------------------------------------------------------------
  rec_word_nostop_stemmed <- rec_word_nostop %>%
    step_stem(text)
  
  rec_word_smart_stemmed <- rec_word_smart %>%
    step_stem(text)
  
  rec_word_iso_stemmed <- rec_word_iso %>%
    step_stem(text)
  
  rec_both_nostop_stemmed <- rec_both_nostop %>%
    step_stem(text)
  
  rec_both_smart_stemmed <- rec_both_smart %>%
    step_stem(text)
  
  rec_both_iso_stemmed <- rec_both_iso %>%
    step_stem(text)
  
  ## filter, weight ------------------------------------------------------------
  rec_word_nostop_f2 <- rec_word_nostop %>%
    step_tokenfilter(text, min_times = 2, max_tokens = 5000) %>%
    step_tf(text, weight_scheme = "binary")
  
  rec_word_smart_f2 <- rec_word_smart %>%
    step_tokenfilter(text, min_times = 2, max_tokens = 5000) %>%
    step_tf(text, weight_scheme = "binary")
  
  rec_word_iso_f2 <- rec_word_iso %>%
    step_tokenfilter(text, min_times = 2, max_tokens = 5000) %>%
    step_tf(text, weight_scheme = "binary")
  
  rec_both_nostop_f2 <- rec_both_nostop %>%
    step_tokenfilter(text, min_times = 2, max_tokens = 5000) %>%
    step_tf(text, weight_scheme = "binary")
  
  rec_both_smart_f2 <- rec_both_smart %>%
    step_tokenfilter(text, min_times = 2, max_tokens = 5000) %>%
    step_tf(text, weight_scheme = "binary")
  
  rec_both_iso_f2 <- rec_both_iso %>%
    step_tokenfilter(text, min_times = 2, max_tokens = 5000) %>%
    step_tf(text, weight_scheme = "binary")
  
  rec_word_nostop_stemmed_f2 <- rec_word_nostop_stemmed %>%
    step_tokenfilter(text, min_times = 2, max_tokens = 5000) %>%
    step_tf(text, weight_scheme = "binary")
  
  rec_word_smart_stemmed_f2 <- rec_word_smart_stemmed %>%
    step_tokenfilter(text, min_times = 2, max_tokens = 5000) %>%
    step_tf(text, weight_scheme = "binary")
  
  rec_word_iso_stemmed_f2 <- rec_word_iso_stemmed %>%
    step_tokenfilter(text, min_times = 2, max_tokens = 5000) %>%
    step_tf(text, weight_scheme = "binary")
  
  rec_both_nostop_stemmed_f2 <- rec_both_nostop_stemmed %>%
    step_tokenfilter(text, min_times = 2, max_tokens = 5000) %>%
    step_tf(text, weight_scheme = "binary")
  
  rec_both_smart_stemmed_f2 <- rec_both_smart_stemmed %>%
    step_tokenfilter(text, min_times = 2, max_tokens = 5000) %>%
    step_tf(text, weight_scheme = "binary")
  
  rec_both_iso_stemmed_f2 <- rec_both_iso_stemmed %>%
    step_tokenfilter(text, min_times = 2, max_tokens = 5000) %>%
    step_tf(text, weight_scheme = "binary")
  
  rec_word_nostop_f5 <- rec_word_nostop %>%
    step_tokenfilter(text, min_times = 5, max_tokens = 5000) %>%
    step_tf(text, weight_scheme = "binary")
  
  rec_word_smart_f5 <- rec_word_smart %>%
    step_tokenfilter(text, min_times = 5, max_tokens = 5000) %>%
    step_tf(text, weight_scheme = "binary")
  
  rec_word_iso_f5 <- rec_word_iso %>%
    step_tokenfilter(text, min_times = 5, max_tokens = 5000) %>%
    step_tf(text, weight_scheme = "binary")
  
  rec_both_nostop_f5 <- rec_both_nostop %>%
    step_tokenfilter(text, min_times = 5, max_tokens = 5000) %>%
    step_tf(text, weight_scheme = "binary")
  
  rec_both_smart_f5 <- rec_both_smart %>%
    step_tokenfilter(text, min_times = 5, max_tokens = 5000) %>%
    step_tf(text, weight_scheme = "binary")
  
  rec_both_iso_f5 <- rec_both_iso %>%
    step_tokenfilter(text, min_times = 5, max_tokens = 5000) %>%
    step_tf(text, weight_scheme = "binary")
  
  rec_word_nostop_stemmed_f5 <- rec_word_nostop_stemmed %>%
    step_tokenfilter(text, min_times = 5, max_tokens = 5000) %>%
    step_tf(text, weight_scheme = "binary")
  
  rec_word_smart_stemmed_f5 <- rec_word_smart_stemmed %>%
    step_tokenfilter(text, min_times = 5, max_tokens = 5000) %>%
    step_tf(text, weight_scheme = "binary")
  
  rec_word_iso_stemmed_f5 <- rec_word_iso_stemmed %>%
    step_tokenfilter(text, min_times = 5, max_tokens = 5000) %>%
    step_tf(text, weight_scheme = "binary")
  
  rec_both_nostop_stemmed_f5 <- rec_both_nostop_stemmed %>%
    step_tokenfilter(text, min_times = 5, max_tokens = 5000) %>%
    step_tf(text, weight_scheme = "binary")
  
  rec_both_smart_stemmed_f5 <- rec_both_smart_stemmed %>%
    step_tokenfilter(text, min_times = 5, max_tokens = 5000) %>%
    step_tf(text, weight_scheme = "binary")
  
  rec_both_iso_stemmed_f5 <- rec_both_iso_stemmed %>%
    step_tokenfilter(text, min_times = 5, max_tokens = 5000) %>%
    step_tf(text, weight_scheme = "binary")
  
  rec_word_nostop <- rec_word_nostop %>%
    step_tf(text, weight_scheme = "binary")
  
  rec_word_smart <- rec_word_smart %>%
    step_tf(text, weight_scheme = "binary")
  
  rec_word_iso <- rec_word_iso %>%
    step_tf(text, weight_scheme = "binary")
  
  rec_both_nostop <- rec_both_nostop %>%
    step_tf(text, weight_scheme = "binary")
  
  rec_both_smart <- rec_both_smart %>%
    step_tf(text, weight_scheme = "binary")
  
  rec_both_iso <- rec_both_iso %>%
    step_tf(text, weight_scheme = "binary")
  
  rec_word_nostop_stemmed <- rec_word_nostop_stemmed %>%
    step_tf(text, weight_scheme = "binary")
  
  rec_word_smart_stemmed <- rec_word_smart_stemmed %>%
    step_tf(text, weight_scheme = "binary")
  
  rec_word_iso_stemmed <- rec_word_iso_stemmed %>%
    step_tf(text, weight_scheme = "binary")
  
  rec_both_nostop_stemmed <- rec_both_nostop_stemmed %>%
    step_tf(text, weight_scheme = "binary")
  
  rec_both_smart_stemmed <- rec_both_smart_stemmed %>%
    step_tf(text, weight_scheme = "binary")
  
  rec_both_iso_stemmed <- rec_both_iso_stemmed %>%
    step_tf(text, weight_scheme = "binary")
  
  ## define specs --------------------------------------------------------------
  spec_elasticnet <- logistic_reg(
    mode = "classification",
    engine = "glmnet",
    penalty = tune(),
    mixture = tune()
  )
  
  spec_randforest <- rand_forest(
    mode = "classification",
    mtry = tune(),
    min_n = tune(),
    trees = 500
  ) %>%
    set_engine(engine = "ranger", importance = "impurity")
  
  # make workflowset -----------------------------------------------------------
  wfs <- workflow_set(
    preproc = list(
      word_nostop_nostem_f2 = rec_word_nostop_f2,
      word_smart_nostem_f2 = rec_word_smart_f2,
      word_iso_nostem_f2 = rec_word_iso_f2,
      both_nostop_nostem_f2 = rec_both_nostop_f2,
      both_smart_nostem_f2 = rec_both_smart_f2,
      both_iso_nostem_f2 = rec_both_iso_f2,
      word_nostop_stemmed_f2 = rec_word_nostop_stemmed_f2,
      word_smart_stemmed_f2 = rec_word_smart_stemmed_f2,
      word_iso_stemmed_f2 = rec_word_iso_stemmed_f2,
      both_nostop_stemmed_f2 = rec_both_nostop_stemmed_f2,
      both_smart_stemmed_f2 = rec_both_smart_stemmed_f2,
      both_iso_stemmed_f2 = rec_both_iso_stemmed_f2,
      word_nostop_nostem_f5 = rec_word_nostop_f5,
      word_smart_nostem_f5 = rec_word_smart_f5,
      word_iso_nostem_f5 = rec_word_iso_f5,
      both_nostop_nostem_f5 = rec_both_nostop_f5,
      both_smart_nostem_f5 = rec_both_smart_f5,
      both_iso_nostem_f5 = rec_both_iso_f5,
      word_nostop_stemmed_f5 = rec_word_nostop_stemmed_f5,
      word_smart_stemmed_f5 = rec_word_smart_stemmed_f5,
      word_iso_stemmed_f5 = rec_word_iso_stemmed_f5,
      both_nostop_stemmed_f5 = rec_both_nostop_stemmed_f5,
      both_smart_stemmed_f5 = rec_both_smart_stemmed_f5,
      both_iso_stemmed_f5 = rec_both_iso_stemmed_f5,
      word_nostop_nostem_f0 = rec_word_nostop,
      word_smart_nostem_f0 = rec_word_smart,
      word_iso_nostem_f0 = rec_word_iso,
      both_nostop_nostem_f0 = rec_both_nostop,
      both_smart_nostem_f0 = rec_both_smart,
      both_iso_nostem_f0 = rec_both_iso,
      word_nostop_stemmed_f0 = rec_word_nostop_stemmed,
      word_smart_stemmed_f0 = rec_word_smart_stemmed,
      word_iso_stemmed_f0 = rec_word_iso_stemmed,
      both_nostop_stemmed_f0 = rec_both_nostop_stemmed,
      both_smart_stemmed_f0 = rec_both_smart_stemmed,
      both_iso_stemmed_f0 = rec_both_iso_stemmed
    ),
    models = list(
      elasticnet = spec_elasticnet,
      randforest = spec_randforest
    ),
    cross = TRUE
  )
  
  # do cross-validation
  folds <- vfold_cv(dat_train, v = n_folds)
  
  # get result of cross-validation
  cv_res <- wfs %>%
    workflow_map(
      "tune_grid",
      grid = grid_size,
      resamples = folds,
      metrics = metric_set(
        accuracy,
        sensitivity,
        specificity,
        precision,
        f_meas,
        roc_auc
      ),
      verbose = TRUE
    )
  
  # return the entire workflow set and the results of cross-validation
  return(list(wfs = wfs, cv_res = cv_res))
}

# take results from cross-validation, get a final model and held-out metrics
best_model <- function(dat, cv_out, metric = "roc_auc") {
  
  # get the id of the best model, as according to the specified metric
  best_id <- cv_out$cv_res %>%
    rank_results(rank_metric = metric) %>%
    filter(.metric == metric & rank == 1) %>%
    pull(wflow_id)
  
  # get the name of the model
  # if you name the workflow sets differently, this step will change
  # it's based on the hardcoded names I gave them in defining the workflow set
  m <- str_split(best_id, "_")[[1]][[5]]
  
  # get best parameters, do final fit
  best_params <- cv_out$cv_res %>%
    extract_workflow_set_result(best_id) %>%
    select_best(metric = metric)
  
  final_fit <- cv_out$cv_res %>%
    extract_workflow(best_id) %>%
    finalize_workflow(best_params) %>%
    fit(training(dat))
  
  # run it on the holdout data
  dat_test <- testing(dat)
  
  # if glmnet, feed it the right penalty
  # I'm pretty sure this is necessary,
  # because glmnet fits many lambda as it's more efficient
  if (m == "elasticnet") {
    dat_test <- bind_cols(
      dat_test,
      predict(final_fit, dat_test, penalty = best_params$penalty),
      predict(final_fit, dat_test, type = "prob", penalty = best_params$penalty)
    )
  } else {
    dat_test <- bind_cols(
      dat_test,
      predict(final_fit, dat_test),
      predict(final_fit, dat_test, type = "prob")
    )
  }
  
  # define metrics to output
  ms <- metric_set(accuracy, sensitivity, specificity, precision)
  
  # return metrics
  metrics_res <- ms(dat_test, truth = y, estimate = .pred_class) %>%
    select(-.estimator) %>%
    spread(.metric, .estimate) %>%
    mutate(
      est_pct_class = mean(dat_test$.pred_class == 1),
      est_pct_prob = mean(dat_test$.pred_1),
      act_pct = mean(dat_test$y == 1)
    )
  
  # get variable importance, which depends on the model
  if (m == "randforest") {
    var_imp <- final_fit %>%
      extract_fit_parsnip() %>%
      vi() %>%
      mutate(
        Variable = str_remove_all(Variable, "tf_text_"),
        Variable = factor(Variable, Variable)
      )
  } else if (m == "elasticnet") {
    var_imp <- final_fit %>%
      tidy() %>%
      filter(
        penalty == best_params$penalty &
          term != "(Intercept)" &
          estimate > 0
      ) %>%
      arrange(desc(abs(estimate))) %>%
      select(-penalty) %>%
      mutate(
        term = str_remove_all(term, "tf_text_"),
        term = factor(term, term)
      )
  } else {
    # if you add more models, this is where you'd change the code to get
    # variable importance for that specific class of model output
    var_imp <- NA
  }
  
  return(
    list(
      best_id = best_id,
      best_params = best_params,
      var_imp = var_imp,
      fit = final_fit,
      ho_metrics = metrics_res
    )
  )
}
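
The `[[5]]` index above only works while every recipe name has exactly four underscore-separated pieces. Here is a small sketch of a slightly more forgiving alternative, assuming workflow ids of the form "<recipe name>_<model name>" and model names that contain no underscore. The helper name model_from_id is hypothetical and not part of the original workflow.

```r
# hypothetical helper: take the last underscore-separated piece of the id
# assumes the model name itself contains no underscore
model_from_id <- function(wflow_id) {
  pieces <- strsplit(wflow_id, "_", fixed = TRUE)[[1]]
  pieces[length(pieces)]
}

# works regardless of how many pieces the recipe name has
model_from_id("both_iso_stemmed_f0_elasticnet")
model_from_id("short_name_randforest")
```

If a model name ever included an underscore, this would break in a different way, so it trades one naming convention for another.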

Probabilistic Photograph Manipulation with ggplot2 and imager

I started taking photos earlier this year. And as someone who loves thinking about probability, statistics, chance, randomness, and R programming, I started thinking about ways to apply probabilistic programming to photography. This is my first attempt.

I’m going to be using one shot I particularly like. It’s a tower on 47th between Wyandotte and Baltimore in Kansas City, Missouri—as seen from the parking garage roof above The Cheesecake Factory:

Through futzing around for a while, I developed an, uh, “algorithm,” sure, let’s call it that, to perturb and abstract a photograph. At a high level, what it is doing is changing the location of pixels according to a uniform distribution and changing the colors according to a normal distribution.

The code for the following steps is found at the bottom of the page and linked to at my GitHub.

The Steps

  1. Represent a picture as a five-column data.frame, where each row is a pixel: Two columns for the x and y location, then three columns for red, green, and blue values that determine the color of that pixel.

  2. Pull one number from a uniform distribution bounded at .25 and .75. This is what I’ll call “jumble probability.”

  3. For each pixel, draw from a Bernoulli distribution with p set to that “jumble probability.”

  4. Take all of the pixels that drew a 1 in Step 3 and make them their own set. Then “jumble” them: Shuffle them around, re-arranging them randomly in the x-y plane.
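
To make Steps 2 through 4 concrete, here is a minimal sketch on a made-up 4 x 4 grayscale "image" rather than a real photograph. The toy data frame and its column names are assumptions for illustration. Swapping colors among a fixed set of positions is the same thing as moving those pixels around in the x-y plane.

```r
set.seed(1839)

# toy "image": a 4 x 4 grid of pixels, each with its own gray value
toy <- expand.grid(x = 1:4, y = 1:4)
toy$hex <- gray(seq(0, 1, length.out = nrow(toy)))

# Step 2: one draw for the jumble probability
p_jumble <- runif(1, .25, .75)

# Step 3: one Bernoulli draw per pixel
to_jumble <- as.logical(rbinom(nrow(toy), 1, p_jumble))

# Step 4: permute colors among the selected pixels only
jumbled <- toy
jumbled$hex[to_jumble] <- sample(toy$hex[to_jumble])
```

Pixels that drew a 0 keep their original color, and the overall set of colors in the image is unchanged; only their positions differ.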

All of the red, green, and blue values in the imager package are normalized from 0 to 1. And we want to nudge these around a bit, so:

  1. Take three draws from a normal distribution with a mean of 0 and a standard deviation of .1.

  2. Add the first draw to the red value, the second draw to the green value, and the third draw to the blue value.

  3. Wherever this leads to values greater than 1, make them 1; wherever this leads to values less than 0, make them 0. These three values make up the new color of the pixel.
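
The color nudge can be sketched on three made-up pixels. This is an illustration under assumptions: the toy values and the helper name clamp01 are hypothetical, and pmin()/pmax() are just one compact way to do the clamping in the last step.

```r
set.seed(1839)

# toy pixels: three color channels already scaled to [0, 1]
px <- data.frame(
  c.1 = c(0.00, 0.50, 0.98),
  c.2 = c(0.02, 0.50, 1.00),
  c.3 = c(0.10, 0.50, 0.95)
)

# one draw per channel, shared by every pixel in the image
c_err <- rnorm(3, 0, .1)

# hypothetical helper: force values back into [0, 1]
clamp01 <- function(v) pmin(pmax(v, 0), 1)

# shift each channel by its draw, then clamp
px_new <- as.data.frame(Map(function(col, e) clamp01(col + e), px, c_err))

# convert to hex colors for plotting
hex <- rgb(px_new$c.1, px_new$c.2, px_new$c.3)
```

Because the same three draws are applied to every pixel, the whole image shifts in tint together rather than getting per-pixel noise.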

With high-resolution images, you have a ton of pixels. My photograph had a data.frame with 24,000,000 rows. Trying to plot all of these took a lot of computing power—and frankly, I just did not want to wait that long to see the images. So, given this practical consideration, let’s add another bit of abstraction:

  1. Draw one number, let’s call it “pixel count,” from a uniform distribution bounded at 1,000 and 1,000,000. (Round to the nearest integer.)

  2. Randomly filter down to a subset of “pixel count” pixels.

This creates some white space, so I made each pixel a square point in ggplot2 and randomly varied the size:

  1. Draw a number from a uniform distribution bounded at 5 and 30, again rounding to the nearest integer, and use this as the size parameter in geom_point().

  2. Make a scatterplot with each row represented as a square.


The Result

I did this 100 times and used ImageMagick in the terminal (see code below) to make a .gif that shows 10 of these images every second. This gives us an interesting look at probability applied to an underlying image:


This is where I talk about how memory is reconstructive and abstract and how time distorts our memories. So every time we recall a memory, it’s slightly different in random ways. And this piece shows that. We never get the full image back, just fractured bits. Or, maybe this is where I talk about how we lay out all of our life plans—but life is chaos and random and stochastic, so this piece represents how even if we may control the general direction our life is headed, we don't end up quite there due to randomness inherent in human existence. Or this is where I say I just thought it was a fun .gif to make; read into it as much as you will.


R Code

library(imager)
library(tidyverse)

plot_point = function(img, n, ...) {
  ggplot(slice_sample(img, n = n), aes(x, y)) + 
    geom_point(aes(color = hex), ...) +
    scale_color_identity() +
    scale_y_reverse() +
    theme_void()
}

img <- load.image("20221112_DSC_0068_TP.JPG") # load image in

dims <- dim(img)[1:2] # get dimensions for exporting

# change to data frame
img_dat <- img %>% 
  as.data.frame(wide = "c") %>% 
  mutate(hex = rgb(c.1, c.2, c.3), xy = paste(x, y, sep = "_"))

# make up an "algorithm", do it like 100 times
set.seed(1839)
for (i in seq_len(100)) {
  cat("starting", i, "\n")
  
  # jumble with probability
  p_jumble <- runif(1, .25, .75)
  
  # figure out which points to jumble
  to_jumble <- as.logical(rbinom(nrow(img_dat), 1, p_jumble))
  
  # make a jumbled order, brb
  jumbled <- order(runif(sum(to_jumble)))
  
  # add some error to each color column
  # then turn to hex value
  c_err <- rnorm(3, 0, .1)
  img_dat_edit <- img_dat %>% 
    mutate(
      # need to make between 0 and 1
      c.1 = c.1 + c_err[1], 
      c.1 = ifelse(c.1 > 1, 1, c.1),
      c.1 = ifelse(c.1 < 0, 0, c.1),
      c.2 = c.2 + c_err[2], 
      c.2 = ifelse(c.2 > 1, 1, c.2),
      c.2 = ifelse(c.2 < 0, 0, c.2),
      c.3 = c.3 + c_err[3],
      c.3 = ifelse(c.3 > 1, 1, c.3),
      c.3 = ifelse(c.3 < 0, 0, c.3),
      hex = rgb(c.1, c.2, c.3)
    )
  
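  # then use jumble to shuffle the colors among the selected pixels only
  # (indexing hex[jumbled] directly would pull from the first rows instead)
  img_dat_edit$hex[to_jumble] <- img_dat_edit$hex[to_jumble][jumbled]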
  
  # select n random pixels of random size
  n <- round(runif(1, 1000, 1000000))
  size <- round(runif(1, 5, 30))
  
  # plot and save
  p <- plot_point(img_dat_edit, n, shape = "square", size = size)
  ggsave(
    paste0("plaza/plaza_iter_", i, ".png"),
    p,
    width = dims[1], 
    height = dims[2], 
    units = "px"
  )
}

There’s a way to make a .gif using the magick package for R, but it was creating a truly massive file and taking forever, so I used the underlying ImageMagick package in the command line.

convert -resize 15% -delay 10 -loop 0 -dispose previous plaza/*.png plaza.gif

Color-Swapping Film Palettes in R with imager, ggplot2, and kmeans

I like visual arts, but I’m moderately colorblind and thus have never been great at making my own works. When I’m plotting data and need colors, my standard procedure is having a website generate me a color palette or finding a visually pleasing one someone else has made and posted online.

I also love film, and I started thinking about ways I could generate color palettes from films that use color beautifully. There are a number of packages that can generate color palettes from images in R, but I wanted to try writing the code myself.

I also wanted to not just generate a color palette from an image, but then swap it with a different color palette from a different film. This is similar to neural style transfer with TensorFlow, but much simpler. I’m one of those people who likes to joke that OLS is undefeated; I generally praise the use of simpler models over more complex ones. So instead of a neural network, I use k-means clustering to transfer a color palette of one still frame from a film onto another frame from a different movie.

Here’s the code for the functions I’ll be using. I’ll describe them in more detail below.

library(imager)
library(tidyverse)

norm <- function(x) (x - min(x)) / (max(x) - min(x))

shuffle <- function(x) x[sample(seq_along(x), length(x))]

get_palette <- function(filename, k, mdn = FALSE) {
  
  dat_pal <- load.image(filename) %>% 
    as.data.frame(wide = "c")
  
  res_pal <- dat_pal %>% 
    select(starts_with("c")) %>% 
    kmeans(k, algorithm = "Lloyd", iter.max = 500)
  
  if (!mdn) {
    pal <- res_pal$centers %>% 
      as_tibble() %>% 
      mutate(hex = rgb(c.1, c.2, c.3)) %>% 
      pull(hex)
  } else {
    pal <- dat_pal %>% 
      mutate(cluster = res_pal$cluster) %>% 
      group_by(cluster) %>% 
      summarise(across(starts_with("c"), median)) %>% 
      mutate(hex = rgb(c.1, c.2, c.3)) %>% 
      pull(hex)
  }
  
  return(pal)
}

make_plot <- function(filename_in, pal, xy = TRUE) {
  
  the_shot <- load.image(filename_in)
  
  dat_shot <- the_shot %>% 
    as.data.frame(wide = "c")
  
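  # optionally drop x and y so clustering uses color only, then normalize
  # (avoids purrr::when(), which is deprecated as of purrr 1.0.0)
  dat_shot_norm <- dat_shot
  if (!xy) dat_shot_norm <- select(dat_shot_norm, starts_with("c"))
  dat_shot_norm <- mutate(dat_shot_norm, across(everything(), norm))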
  
  res_shot <- kmeans(
    dat_shot_norm, 
    length(pal), 
    algorithm = "Lloyd", 
    iter.max = 500
  )
  
  dat_shot$clust <- factor(res_shot$cluster)
  
  p <- ggplot(dat_shot, aes(x = x, y = y)) +
    geom_raster(aes(fill = clust)) +
    scale_y_reverse() +
    theme_void() +
    theme(legend.position = "none") +
    scale_fill_manual(values = pal)
  
  return(list(plot = p, dims = dim(the_shot)[1:2]))
}

When I thought about transferring the color of one film onto an image from another film, two things came to mind immediately. The Umbrellas of Cherbourg is one of the most visually striking films I’ve ever watched; there’s such a dazzling variety of colors, and it displays a wide collection of unique wallpaper. As for what shot to impose those colors onto, one of my favorite shots is the “coffee scene” from Chungking Express.

The function get_palette() reads an image in using the imager package. This package allows you to decompose the image into a data frame, where each row is a pixel. There are x and y columns, which show where in the image the pixel sits when plotted. There are three additional columns that contain the RGB values. I k-means cluster the three RGB columns, using the built-in kmeans function and giving it an arbitrary k, extract the average RGB values from each cluster (i.e., the cluster centers), and then convert them to hex values using the built-in rgb function.
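
The palette step can be sketched without an image file. The block below builds synthetic pixels around three known colors and clusters them; the data are made up, and nstart = 25 is my addition (not in get_palette()) to make the toy result less sensitive to the random starting centers.

```r
set.seed(1839)

# synthetic "image": 300 pixels scattered around three known colors
true_cols <- rbind(c(.9, .1, .1), c(.1, .8, .2), c(.1, .2, .9))
px <- true_cols[rep(1:3, each = 100), ] + matrix(rnorm(900, 0, .02), ncol = 3)
px <- pmin(pmax(px, 0), 1) # keep channels inside [0, 1]
colnames(px) <- c("c.1", "c.2", "c.3")

# cluster on the RGB columns only
res <- kmeans(px, centers = 3, algorithm = "Lloyd", iter.max = 500, nstart = 25)

# cluster centers are the mean RGB of each cluster; convert them to hex
pal <- rgb(res$centers[, "c.1"], res$centers[, "c.2"], res$centers[, "c.3"])
pal
```

The order of the colors in pal follows the cluster labels, which are arbitrary, so the same image can return its palette in a different order under a different seed.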

And then make_plot() works similarly. It takes a file name, reads it in, and converts it to a data frame. This time, I allow the x and y columns to be used in the clustering. This means that clustering will be a mix of (a) what the original color was, and (b) where in the frame the pixel is. All columns are normalized. I use the length of the color palette to determine k. I then plot it with ggplot2, using the new color palette to fill in according to the clustering of the new pixels. It’s more or less a coloring book, where the lines are determined by k-means clustering.
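
Here is a rough sketch of that clustering-then-recoloring logic on a made-up 6 x 4 frame, skipping the plotting. The toy frame, the two hex values, and the helper name norm01 are assumptions for illustration, and I use the default kmeans algorithm with nstart = 10 to keep the toy result stable.

```r
set.seed(1839)

# toy frame: 6 x 4 pixels, dark on the left half and bright on the right
frame <- expand.grid(x = 1:6, y = 1:4)
bright <- ifelse(frame$x > 3, .9, .1)
frame$c.1 <- bright
frame$c.2 <- bright
frame$c.3 <- bright

# rescale every column to [0, 1] so position and color share a scale
norm01 <- function(v) (v - min(v)) / (max(v) - min(v))
frame_norm <- as.data.frame(lapply(frame, norm01))

# k comes from the palette length; these hex values are made up
pal <- c("#D7263D", "#1B998B")
res <- kmeans(frame_norm, centers = length(pal), nstart = 10)

# each pixel takes the palette color indexed by its cluster label
frame$fill <- pal[res$cluster]
```

Which palette color lands on which half depends on the arbitrary cluster labels, which is why reordering the palette changes the look of the final image without changing the clustering.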

set.seed(1839)

pal1 <- get_palette("umbrellas.jpeg", k = 12)
plot1 <- make_plot("chungking.jpeg", pal = pal1)

ggsave(
  "chungking_k12.png", 
  plot1$plot, 
  width = plot1$dims[1], 
  height = plot1$dims[2],
  units = "px"
)

I write the file out to the same dimensions to preserve the integrity of the aspect ratio. Here are the two original shots, and then the one produced with make_plot():

We can see that the new coloring is a blend of pixel location and the color of the original pixel.

I started playing around with other ideas and included two new parts in this next image blend. First, I get the median value, instead of the mean, of the RGB values when clustering for the palette; and second, I shuffle up the order of the palette randomly before feeding it into the plotting function.

I wanted to apply a movie with warm colors to a movie with cool colors. My mind went to Her and Blade Runner, respectively.

set.seed(1839)

pal2 <- get_palette("her.jpeg", k = 3, mdn = TRUE)
plot2 <- make_plot("bladerunner.jpeg", pal = shuffle(pal2))

ggsave(
  "bladerunner_k3.png", 
  plot2$plot, 
  width = plot2$dims[1],
  height = plot2$dims[2],
  units = "px"
)

The originals:

And the blend:

What I like about this is that, since we include x and y in the clustering of the second image, we get different colors on either side of Roy Batty’s face.

I also wanted to see what the influence of taking out the x and y values would be. xy = FALSE removes any influence of where the pixel is placed in the image, so clustering is done purely on RGB values.

set.seed(1839)

pal3 <- get_palette("2001.jpeg", k = 5, mdn = TRUE)
plot3 <- make_plot("arrival.jpeg", pal = shuffle(pal3), xy = FALSE)

ggsave(
  "arrival_k5.png", 
  plot3$plot, 
  width = plot3$dims[1],
  height = plot3$dims[2],
  units = "px"
)

I wanted to combine these two shots from 2001: A Space Odyssey and Arrival because they visually rhyme with one another:

We can see in the color-blended image that colors fill in on spaces that are separated geographically from one another in the xy-plane of the image:

Compare this to another version I made, where I allowed x and y to be included:

We see that vertical line in the upper third of the shot forming due to the influence of x in the data. This also demonstrates overfitting: It’ll draw a line where two adjacent data points are functionally equivalent if you misspecify k. But for aesthetic purposes, overfitting isn’t necessarily a problem!

We also see indistinct boundaries between one color and another here. The underlying image has few distinct lines: the entire image is ink drawn onto a wispy mist. What about when we get distinct lines and contrast? The easy answer for clean lines would have been to go to Wes Anderson here, but I felt like that was too expected from a blog post written by somebody such as myself. So instead, I took colors from the animated Lion King, a vibrant film, and projected them onto one of Roger Deakins’ best shots from Fargo.

set.seed(1839)

pal4 <- get_palette("lionking.jpeg", k = 2, mdn = TRUE)
plot4 <- make_plot("fargo.jpeg", pal = rev(pal4), xy = FALSE)

ggsave(
  "fargo_k2.png", 
  plot4$plot, 
  width = plot4$dims[1],
  height = plot4$dims[2],
  units = "px"
)

The last thing I wanted to do was look at a shot that had two primary colors and project it onto a black-and-white film, replacing that underlying dichotomy with two other colors.

set.seed(1839)

pal5 <- get_palette("killbill.jpeg", k = 2, mdn = TRUE)
plot5 <- make_plot("strangelove.jpeg", pal = pal5, xy = FALSE)

ggsave(
  "strangelove_k2.png", 
  plot5$plot, 
  width = plot5$dims[1],
  height = plot5$dims[2],
  units = "px"
)

The first shot below from Kill Bill Vol. 1 came to mind for a shot that was mostly two colors, while I went with my favorite scene from Dr. Strangelove, perhaps the funniest film ever made, for the black-and-white still:

The functions are above and the full code is at my GitHub. Try playing with the functions and blending images; it’s fun, but it’s also a visual guide that helps you truly understand what exactly k-means clustering is doing.