Many organizations have a substantial amount of human-generated text
from which they are not extracting a proportional amount of insight. For
example, open-ended questions are found in most surveys—but are rarely
given the same amount of attention (if any attention at all) as the
easier-to-analyze quantitative data. Over my career, I have tested many
supposedly “AI-powered” or “NLP-driven” tools for analyzing text, and I
haven’t found any of them useful for finding topics or modeling
sentiment when fed real data. I wrote about my
reservations about common topic modeling methods over four years
ago, showing how I perform exploratory analysis on text data
based on word co-occurrences. That was an unsupervised
approach: No a priori topics are given to a model to learn from. It
looks at patterns of how frequently words are used together to infer
topics.
I lay out my approach for supervised topic modeling in short
texts (e.g., open-response survey data) here. My philosophy is that
there is no free lunch: If you want data that make sense for your
specific organization and use case, you are going to have to actually
read and consider the text you get. This isn’t using
statistical learning to do the job for you; it is using these models as a
tool to work with. You can’t solve this purely from the command
line. There will need to be many humans in the loop. This process can be
a grind, but what you’ll come out of it with is a scalable, bespoke
model. I don’t think I’m saying anything particularly revolutionary in
this post, but hopefully the walk-through and code might help you
develop a similar workflow.
The working example I’m using here is Letterboxd reviews of Wes
Anderson’s newest film, Asteroid City.
Step 1: Thematic Content Analysis
Create a coding corpus that is a random sample (maybe 2000 cases,
depending on how long each piece of text is, how much bandwidth you have
for this project, etc.) from your entire corpus of text (e.g., reviews,
open-ended responses, help tickets). Perform a thematic
content analysis in partnership with colleagues. I am not a
qualitative methodologist by any means, but I start by reading the
entirety of the random sample. Then I read through the coding corpus
again, taking notes about key themes. I look at these themes, and I
think about how they may be grouped together. I talk with my colleagues
to see what they think. We discuss if certain themes can be combined or
if they need to be kept separate. We talk about the maximum number of
themes we think we should be looking for. We finalize themes, discuss
what they mean in detail, and name them. From these readings and
conversations, I write a standardized coding manual that others can use
to read responses and code according to the themes.
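If it helps to see that sampling step concretely, here is a minimal sketch, assuming the full corpus sits in a CSV with one piece of text per row (the file and column names are hypothetical):
library(tidyverse)
# hypothetical file: one row per piece of text in the full corpus
full_corpus <- read_csv("all_text.csv")
# draw a coding corpus of 2000 cases; slice_sample() silently truncates
# to the corpus size if there are fewer than 2000 rows
set.seed(1839)
coding_corpus <- slice_sample(full_corpus, n = 2000)
write_csv(coding_corpus, "coding_corpus.csv")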
In my example, I wrote an R script (see Appendix A or my
GitHub) to scrape Letterboxd for reviews. For each possible rating
(0.5 stars to 5.0 stars, in increments of half stars), I pulled six
pages. Each page had twelve reviews. Ten possible ratings, six pages
each, and twelve reviews per page multiply out to 720 reviews. In the
process of coding, I dropped any review that was not written in
English, a limitation owing to my regrettable decision to study Latin
in college (although an apropos course for a Wes Anderson blog
post).
I decided to stick to just one theme for the purposes of this
exercise: Did the review discuss the visual style of the film? Anderson
is known for his distinctive visual style, which at this point is
unfortunately a bit of a cultural meme. There was a whole
Instagram/TikTok trend of posting videos of oneself in the style of
Anderson; there are those cynical, intellectually
vapid, artless, AI-generated simulacra imagining [insert well-known
intellectual property here] in the Andersonian style; and there’s even a
Hertz (the rent-a-car company) commercial that references the meme,
which in turn references the style (Baudrillard’s head would be
spinning, folks).
This visual style, to my eye, was solidified by the time of The
Grand Budapest Hotel. Symmetrical framing, meticulously organized
sets, the use of miniatures, a pastel color palette contrasted with
saturated primary colors, distinctive costumes, straight lines, lateral
tracking shots, whip pans, and so on. However, I did not consider
aspects of Anderson’s style that are unrelated to the visual sense. This
is where drawing a clear line around each theme matters. There will
often be ambiguities, and one must do their best to resolve them,
because this process is fundamentally a simplification of the rich
diversity of the text. Thus, I did not consider the following to be part
of Anderson’s
visual style: stories involving precocious children, fraught
familial relations, uncommon friendships, dry humor, monotonous
dialogue, soundtracks usually involving The Kinks, a fascination with
stage productions, nesting doll narratives, or a decidedly twee yet
bittersweet tone.
Step 2: Independent Coders
This is where you will need organizational buy-in and the help of
your colleagues. Recruit fellow researchers (or subject-matter experts)
who have relevant domain expertise. For example, if the text data are
all reviews for a specific app, recruit colleagues who have knowledge
about that specific app. I generate standardized Google Sheets where
each row is a text response, each column is a theme, and I send
them, along with the coding manual, to the coders. They are instructed
to put a “1” wherever a text touches on that theme as described in
the coding manual. I have each piece of text coded by two independent
coders. After everyone is done, I resolve disagreements by a third-party
vote (which may or may not be myself). Combine the various Google Sheets
(I use the googlesheets4 R package and the tidyverse suite of packages
for merging these data), and congratulations! You have supervised text
data that are ready for training.
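As a rough sketch of that plumbing (assuming googlesheets4 is authenticated; the sheet and file names are hypothetical, and blanks left by coders are treated as 0s):
library(googlesheets4)
library(tidyverse)
# create one standardized sheet per coder from the coding corpus
coding_corpus <- read_csv("coding_corpus.csv") %>%
  mutate(visual_style = NA_real_) # one empty column per theme
ss_a <- gs4_create("asteroid-city-coder-a", sheets = list(codes = coding_corpus))
ss_b <- gs4_create("asteroid-city-coder-b", sheets = list(codes = coding_corpus))
# ...after coding is done, read both sheets back and flag disagreements
# for the third-party vote
coded <- inner_join(
  read_sheet(ss_a, sheet = "codes"),
  read_sheet(ss_b, sheet = "codes"),
  by = "text",
  suffix = c("_a", "_b")
) %>%
  mutate(across(starts_with("visual_style"), ~ replace_na(.x, 0)))
disagreements <- filter(coded, visual_style_a != visual_style_b)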
For Asteroid City reviews, I just did this on a Saturday
afternoon myself. Any piece of text that touched on the visual style was
flagged with a “1,” while the rest were marked as “0.” The text I pulled
was truncated for longer reviews; I don’t mind this, as it keeps the
data true to the analogy of short open-ended survey responses.
Step 3: Train the Models
This is where I stand on the shoulders of Emil Hvitfeldt and Julia
Silge by way of their book, Supervised
Machine Learning for Text Analysis in R. I define an entire
machine learning workflow (i.e., data processing, cross-validation,
model selection, and holdout performance), and then I package this
workflow into convenient wrapper functions. This allows me to apply the
same workflow to each column, so that every theme gets its own model.
I write the cross-validation results and final models out as RDS files
so that I can access them later.
The wrapper functions are in Appendix B below and at my
GitHub. It’s, uh, long. (I go against my own imperative to
write modular code here, but if there’s a less verbose way of defining
this workflow, please genuinely let me know.) I define three basic
functions: one to create the train/test split, another to perform
cross-validation, and a third to pull the best model.
The prep_data function takes your data, the names of the outcome and
text variables, and how large of a holdout set you want. You can change
exactly which pre-processing steps and algorithms you want to use from
what I have, but I have found that the processing of the data is usually
more influential than the specific model. So in the wrapper function
do_cv, I am trying out all combinations of the following:
- Stop words: None, SMART, or ISO (different lists of stop
words)
- N-grams: Words only or words and bigrams
- Stemming: Yes or no
- Word frequency filtering: Words used 1 or more times in the corpus,
2 or more times, or 5 or more
- Algorithm: Elastic net or random forest
This leaves 3 * 2 * 2 * 3 * 2 = 72 pre-processor and algorithm
combinations. The models were tuned via grid search (the size of which
is an argument to the do_cv function) using k-fold cross-validation
(where k is also an argument). You could add other pre-processing steps
(e.g., word embeddings) or other algorithms (e.g., XGBoost, neural
networks) to the body of this function.
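For instance, a boosted-tree spec could sit alongside the two existing ones. This is only a sketch: the tuning parameters here are illustrative, and you would also need to add the spec to the models list in workflow_set().
library(tidymodels)
# an illustrative third spec you could define inside do_cv()
spec_xgboost <- boost_tree(
  mode = "classification",
  trees = 500,
  tree_depth = tune(),
  learn_rate = tune()
) %>%
  set_engine("xgboost")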
When that’s done, the best_model function will pull everything you need
to know about the model that performed best in cross-validation, and it
will give you performance metrics on the holdout set.
When all is said and done, the actual running of the code then looks
deceptively simple:
library(tidyverse)
source("funs.R")
dat <- read_csv("ratings_coded.csv") %>%
  filter(visual_style != 9) # drop non-English reviews, which were coded 9
set.seed(1839)
dat <- prep_data(dat, "visual_style", "text", .20)
cv_res <- do_cv(training(dat), 4, 10)
write_rds(cv_res, "visual_style-res.rds")
mod <- best_model(dat, cv_res, "roc_auc")
Here, funs.R is a file that includes the functions in Appendix B. This
gives a 20% holdout set, 4 folds, and a grid search of size 10. It then
writes out those results to an RDS object. Lastly, we pull the best
model according to the area under the ROC curve. If you had multiple
codes, you could write a for loop here over the name of each column
(e.g., prep_data(dat, i, "text", .20)).
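A sketch of that loop, assuming dat is still the full coded data frame and using hypothetical theme column names:
# hypothetical names of the coded theme columns
themes <- c("visual_style", "humor", "soundtrack")
for (i in themes) {
  split_i <- prep_data(dat, i, "text", .20)
  cv_i <- do_cv(training(split_i), 4, 10)
  write_rds(cv_i, paste0(i, "-res.rds"))
  mod_i <- best_model(split_i, cv_i, "roc_auc")
  write_rds(mod_i, paste0(i, "-mod.rds"))
}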
Step 4: Use the Models
First, let’s check out what some of these objects look like. Let’s
take a look at the metrics for the first model that went into the
cross-validation results set.
cv_res <- read_rds("visual_style-res.rds")
cv_res$cv_res$result[[1]]$.metrics
## [[1]]
## # A tibble: 60 × 6
## penalty mixture .metric .estimator .estimate .config
## <dbl> <dbl> <chr> <chr> <dbl> <chr>
## 1 0.000000787 0.0698 accuracy binary 0.774 Preprocessor1_Model01
## 2 0.000000787 0.0698 sensitivity binary 0.923 Preprocessor1_Model01
## 3 0.000000787 0.0698 specificity binary 0 Preprocessor1_Model01
## 4 0.000000787 0.0698 precision binary 0.828 Preprocessor1_Model01
## 5 0.000000787 0.0698 f_meas binary 0.873 Preprocessor1_Model01
## 6 0.000000787 0.0698 roc_auc binary 0.777 Preprocessor1_Model01
## 7 0.000374 0.166 accuracy binary 0.774 Preprocessor1_Model02
## 8 0.000374 0.166 sensitivity binary 0.923 Preprocessor1_Model02
## 9 0.000374 0.166 specificity binary 0 Preprocessor1_Model02
## 10 0.000374 0.166 precision binary 0.828 Preprocessor1_Model02
## # ℹ 50 more rows
##
## [[2]]
## # A tibble: 60 × 6
## penalty mixture .metric .estimator .estimate .config
## <dbl> <dbl> <chr> <chr> <dbl> <chr>
## 1 0.000000787 0.0698 accuracy binary 0.767 Preprocessor1_Model01
## 2 0.000000787 0.0698 sensitivity binary 0.913 Preprocessor1_Model01
## 3 0.000000787 0.0698 specificity binary 0.286 Preprocessor1_Model01
## 4 0.000000787 0.0698 precision binary 0.808 Preprocessor1_Model01
## 5 0.000000787 0.0698 f_meas binary 0.857 Preprocessor1_Model01
## 6 0.000000787 0.0698 roc_auc binary 0.714 Preprocessor1_Model01
## 7 0.000374 0.166 accuracy binary 0.767 Preprocessor1_Model02
## 8 0.000374 0.166 sensitivity binary 0.913 Preprocessor1_Model02
## 9 0.000374 0.166 specificity binary 0.286 Preprocessor1_Model02
## 10 0.000374 0.166 precision binary 0.808 Preprocessor1_Model02
## # ℹ 50 more rows
##
## [[3]]
## # A tibble: 60 × 6
## penalty mixture .metric .estimator .estimate .config
## <dbl> <dbl> <chr> <chr> <dbl> <chr>
## 1 0.000000787 0.0698 accuracy binary 0.667 Preprocessor1_Model01
## 2 0.000000787 0.0698 sensitivity binary 1 Preprocessor1_Model01
## 3 0.000000787 0.0698 specificity binary 0.167 Preprocessor1_Model01
## 4 0.000000787 0.0698 precision binary 0.643 Preprocessor1_Model01
## 5 0.000000787 0.0698 f_meas binary 0.783 Preprocessor1_Model01
## 6 0.000000787 0.0698 roc_auc binary 0.824 Preprocessor1_Model01
## 7 0.000374 0.166 accuracy binary 0.667 Preprocessor1_Model02
## 8 0.000374 0.166 sensitivity binary 1 Preprocessor1_Model02
## 9 0.000374 0.166 specificity binary 0.167 Preprocessor1_Model02
## 10 0.000374 0.166 precision binary 0.643 Preprocessor1_Model02
## # ℹ 50 more rows
##
## [[4]]
## # A tibble: 60 × 6
## penalty mixture .metric .estimator .estimate .config
## <dbl> <dbl> <chr> <chr> <dbl> <chr>
## 1 0.000000787 0.0698 accuracy binary 0.733 Preprocessor1_Model01
## 2 0.000000787 0.0698 sensitivity binary 0.913 Preprocessor1_Model01
## 3 0.000000787 0.0698 specificity binary 0.143 Preprocessor1_Model01
## 4 0.000000787 0.0698 precision binary 0.778 Preprocessor1_Model01
## 5 0.000000787 0.0698 f_meas binary 0.84 Preprocessor1_Model01
## 6 0.000000787 0.0698 roc_auc binary 0.745 Preprocessor1_Model01
## 7 0.000374 0.166 accuracy binary 0.7 Preprocessor1_Model02
## 8 0.000374 0.166 sensitivity binary 0.870 Preprocessor1_Model02
## 9 0.000374 0.166 specificity binary 0.143 Preprocessor1_Model02
## 10 0.000374 0.166 precision binary 0.769 Preprocessor1_Model02
## # ℹ 50 more rows
What we see here are the 6 metrics for each of the 10 parameter
combinations from the grid search. We get four tibbles, one for each
fold. The .config column tells us that this is pre-processor 1. What
does that refer to?
cv_res$wfs$wflow_id[[1]]
## [1] "word_nostop_nostem_f2_elasticnet"
This workflow uses words only, no stop-word removal, no stemming, a
minimum token frequency of 2 in the corpus, and the elastic net.
But what we want to know is which model is best. We used best_model for
that above.
mod$best_id
## [1] "word_smart_stemmed_f2_randforest"
mod$best_params
## # A tibble: 1 × 3
## mtry min_n .config
## <int> <int> <chr>
## 1 20 34 Preprocessor1_Model09
What performed best in cross-validation looked only at words, used the
SMART list for stop words, stemmed, considered only words used at least
twice in the corpus, and used the random forest with 20 candidate words
at each split and a minimum node size of 34.
We can also look at the holdout metrics:
mod$ho_metrics
## # A tibble: 1 × 7
## accuracy precision sensitivity specificity est_pct_class est_pct_prob act_pct
## <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
## 1 0.765 0.762 0.994 0.110 0.0327 0.191 0.259
You’ll be familiar with the first four, but the last three are custom
ones I use. Often, I am using these types of models to estimate, “How
many people are talking about theme X?” So I don’t care as much about
individual-level predictions as I do about the aggregate. I get that
aggregate by averaging the estimated probabilities (est_pct_prob) and
using that as my estimate. So here, I would say that about 19% of
reviews discuss Anderson’s visual style. The actual percentage
(act_pct) is a bit higher at about 26%. But not bad for a Saturday
afternoon.
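There is nothing fancy behind that estimate; best_model computes est_pct_prob on the holdout set, but on any scored data it is just the mean of the predicted probabilities. For example (an elastic net winner would also need the penalty argument, as in best_model):
ho_preds <- predict(mod$fit, testing(dat), type = "prob")
# estimated share of reviews discussing visual style
mean(ho_preds$.pred_1)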
What about variable importance? What words are helping us figure out
if the review touches on Wes’s visual style or not?
print(mod$var_imp, n = 10)
## # A tibble: 368 × 2
## Variable Importance
## <fct> <dbl>
## 1 color 1.56
## 2 visual 1.35
## 3 cool 1.02
## 4 audienc 0.994
## 5 shot 0.940
## 6 plot 0.751
## 7 design 0.575
## 8 pretti 0.554
## 9 set 0.503
## 10 landscap 0.492
## # ℹ 358 more rows
And yeah, this makes sense. The last thing we can do is plug in
new reviews that have gone up since I made the model and
generate predictions for them. The first and the third talk about visual
style, while the others do not.
new_reviews <- tibble(
  text = c("Kodak film and miniatures make for a cool movie",
           "To me, Wes never misses",
           paste(
             "Visuals and characters were stunning as always but the dual plot",
             "was hard to follow and the whole thing felt a bit hollow"
           ),
           "my head hurts idk how i feel about this yet",
           "Me after asking AI to write me a wes anderson script")
)
# predict on these
predict(mod$fit, new_reviews, type = "prob")
## # A tibble: 5 × 2
## .pred_0 .pred_1
## <dbl> <dbl>
## 1 0.675 0.325
## 2 0.955 0.0448
## 3 0.545 0.455
## 4 0.968 0.0317
## 5 0.870 0.130
Unfortunately, neither of the two reviews that actually are about his
style gets a predicted probability above .5. In the aggregate, 40% of
these new reviews discuss his style, while averaging the .pred_1 column
only gives us about 20%. But we can’t be surprised by that big of a miss
from just five new reviews.
Step 5: Scale and Maintain
You’ve got the best model(s) in-hand. What now? You scale it! In Step
1, we only looked at a random sample of maybe 2000 pieces of text. But
you may have much, much more than that; you might be continuously
collecting new text data that adhere to the coding manual. Use the
predict()
function on those cases. Then what I do is
average up the predicted probabilities to get an estimate of how many
people are talking about a given theme. But you could also use the
predicted probabilities to decide which users to email about a specific
theme, for example. Since your coding scheme came from a specific prompt
or data source, make sure that you are applying it only to text that
come from this same generating mechanism (e.g., survey question,
feedback prompt, ticketing system).
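As a sketch of that scaling step, using the uncoded scrape from Appendix A (in practice, you would drop the non-English reviews first):
# score every review in the full corpus with the best model
full_corpus <- read_csv("ratings_uncoded.csv")
scored <- bind_cols(
  full_corpus,
  predict(mod$fit, full_corpus, type = "prob")
)
# aggregate estimate of theme prevalence across the whole corpus
mean(scored$.pred_1)
# or pull high-probability cases for targeted follow-up
high_prob <- filter(scored, .pred_1 > .8)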
What is absolutely crucial here is keeping humans in
the loop. As new data come in, figure out a regular cadence where you
generate a (smaller than the original) random sample of data, code it,
and then see how your predictions are holding up to new data.
(After checking performance, these newly coded cases can then
be added to the training set and the model can be updated to include
them.) Also, read through these data to see if the coding manual needs
to be updated, themes are no longer relevant, or new themes need to be
added. This is, again, a grind of a process. But it is a way of scaling
up the unique expertise of the humans in your organization to data that
generally don’t get much attention. And in my experience, keeping
humans in the loop is not only good for model performance; reading
real responses from real people also helps you as a researcher better
understand your data and your users (or respondents or whoever is
generating the data).
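A sketch of what one of those regular checks might look like, assuming a hypothetical fresh sample that has been hand-coded since training (with the same text and visual_style columns):
new_coded <- read_csv("new_sample_coded.csv") # hypothetical file
check <- bind_cols(
  new_coded,
  predict(mod$fit, new_coded),
  predict(mod$fit, new_coded, type = "prob")
)
# individual-level agreement with the human coders
mean(check$.pred_class == check$visual_style)
# aggregate drift: estimated vs. actual prevalence
c(estimated = mean(check$.pred_1), actual = mean(check$visual_style))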
Appendix A
library(rvest)
library(tidyverse)
# 12 reviews per page
# 10 possible ratings
# 6 pages each
# n = 720 reviews total
base_url1 <- "https://letterboxd.com/film/asteroid-city/reviews/rated/"
base_url2 <- "/by/activity/page/"
ratings <- seq(.5, 5, .5)
pages <- 6
# people often say to use purrr::map or an apply statement instead of a for
# loop. but if it fails, I want to know where, and I want to keep the data
# we already collected. so I'm initializing an empty data frame and then
# slotting the data in one page at a time
dat <- tibble(rating = vector("double"), text = vector("character"))
for (r in ratings) {
  pg <- 1
  while (pg <= pages) {
    url <- paste0(base_url1, r, base_url2, pg)
    txt <- url %>%
      read_html() %>%
      html_nodes(".collapsible-text") %>%
      html_text2() %>%
      map_chr(str_replace_all, fixed("\n"), " ")
    dat <- bind_rows(dat, tibble(rating = r, text = txt))
    cat("finished page", pg, "of", r, "stars\n")
    pg <- pg + 1
    Sys.sleep(runif(1, 0, 10)) # pause politely between requests
  }
}
write_csv(dat, "ratings_uncoded.csv")
Appendix B
# funs -------------------------------------------------------------------------
library(textrecipes)
library(vip)
library(stopwords)
library(tidymodels)
library(workflowsets)
library(tidyverse)
# prepare data for modeling, do train/test split
# prop is the proportion of data held out for testing (defaults to 15%)
# stratify on the outcome variable
prep_data <- function(dat, y, txt, prop = 0.15) {
  dat %>%
    transmute(text = .data[[txt]], y = factor(.data[[y]])) %>%
    initial_split(prop = 1 - prop, strata = y)
}
# do cross-validation with the same, pre-defined, hard-coded engines and recipes
# n_folds is number of folds; grid_size is the size of the grid
do_cv <- function(dat_train, n_folds, grid_size) {
  # define recipes -------------------------------------------------------------
  ## base ----------------------------------------------------------------------
  rec_base <- recipe(y ~ text, dat_train)
  ## tokenize ------------------------------------------------------------------
  rec_word_nostop <- rec_base %>%
    step_tokenize(
      text,
      token = "words"
    )
  rec_word_smart <- rec_base %>%
    step_tokenize(
      text,
      token = "words",
      options = list(stopwords = stopwords(source = "smart"))
    )
  rec_word_iso <- rec_base %>%
    step_tokenize(
      text,
      token = "words",
      options = list(stopwords = stopwords(source = "stopwords-iso"))
    )
  rec_both_nostop <- rec_base %>%
    step_tokenize(
      text,
      token = "skip_ngrams",
      options = list(
        n = 2,
        k = 0
      )
    )
  rec_both_smart <- rec_base %>%
    step_tokenize(
      text,
      token = "skip_ngrams",
      options = list(
        stopwords = stopwords(source = "smart"),
        n = 2,
        k = 0
      )
    )
  rec_both_iso <- rec_base %>%
    step_tokenize(
      text,
      token = "skip_ngrams",
      options = list(
        stopwords = stopwords(source = "stopwords-iso"),
        n = 2,
        k = 0
      )
    )
  ## stem ----------------------------------------------------------------------
  rec_word_nostop_stemmed <- rec_word_nostop %>%
    step_stem(text)
  rec_word_smart_stemmed <- rec_word_smart %>%
    step_stem(text)
  rec_word_iso_stemmed <- rec_word_iso %>%
    step_stem(text)
  rec_both_nostop_stemmed <- rec_both_nostop %>%
    step_stem(text)
  rec_both_smart_stemmed <- rec_both_smart %>%
    step_stem(text)
  rec_both_iso_stemmed <- rec_both_iso %>%
    step_stem(text)
  ## filter, weight ------------------------------------------------------------
  rec_word_nostop_f2 <- rec_word_nostop %>%
    step_tokenfilter(text, min_times = 2, max_tokens = 5000) %>%
    step_tf(text, weight_scheme = "binary")
  rec_word_smart_f2 <- rec_word_smart %>%
    step_tokenfilter(text, min_times = 2, max_tokens = 5000) %>%
    step_tf(text, weight_scheme = "binary")
  rec_word_iso_f2 <- rec_word_iso %>%
    step_tokenfilter(text, min_times = 2, max_tokens = 5000) %>%
    step_tf(text, weight_scheme = "binary")
  rec_both_nostop_f2 <- rec_both_nostop %>%
    step_tokenfilter(text, min_times = 2, max_tokens = 5000) %>%
    step_tf(text, weight_scheme = "binary")
  rec_both_smart_f2 <- rec_both_smart %>%
    step_tokenfilter(text, min_times = 2, max_tokens = 5000) %>%
    step_tf(text, weight_scheme = "binary")
  rec_both_iso_f2 <- rec_both_iso %>%
    step_tokenfilter(text, min_times = 2, max_tokens = 5000) %>%
    step_tf(text, weight_scheme = "binary")
  rec_word_nostop_stemmed_f2 <- rec_word_nostop_stemmed %>%
    step_tokenfilter(text, min_times = 2, max_tokens = 5000) %>%
    step_tf(text, weight_scheme = "binary")
  rec_word_smart_stemmed_f2 <- rec_word_smart_stemmed %>%
    step_tokenfilter(text, min_times = 2, max_tokens = 5000) %>%
    step_tf(text, weight_scheme = "binary")
  rec_word_iso_stemmed_f2 <- rec_word_iso_stemmed %>%
    step_tokenfilter(text, min_times = 2, max_tokens = 5000) %>%
    step_tf(text, weight_scheme = "binary")
  rec_both_nostop_stemmed_f2 <- rec_both_nostop_stemmed %>%
    step_tokenfilter(text, min_times = 2, max_tokens = 5000) %>%
    step_tf(text, weight_scheme = "binary")
  rec_both_smart_stemmed_f2 <- rec_both_smart_stemmed %>%
    step_tokenfilter(text, min_times = 2, max_tokens = 5000) %>%
    step_tf(text, weight_scheme = "binary")
  rec_both_iso_stemmed_f2 <- rec_both_iso_stemmed %>%
    step_tokenfilter(text, min_times = 2, max_tokens = 5000) %>%
    step_tf(text, weight_scheme = "binary")
  rec_word_nostop_f5 <- rec_word_nostop %>%
    step_tokenfilter(text, min_times = 5, max_tokens = 5000) %>%
    step_tf(text, weight_scheme = "binary")
  rec_word_smart_f5 <- rec_word_smart %>%
    step_tokenfilter(text, min_times = 5, max_tokens = 5000) %>%
    step_tf(text, weight_scheme = "binary")
  rec_word_iso_f5 <- rec_word_iso %>%
    step_tokenfilter(text, min_times = 5, max_tokens = 5000) %>%
    step_tf(text, weight_scheme = "binary")
  rec_both_nostop_f5 <- rec_both_nostop %>%
    step_tokenfilter(text, min_times = 5, max_tokens = 5000) %>%
    step_tf(text, weight_scheme = "binary")
  rec_both_smart_f5 <- rec_both_smart %>%
    step_tokenfilter(text, min_times = 5, max_tokens = 5000) %>%
    step_tf(text, weight_scheme = "binary")
  rec_both_iso_f5 <- rec_both_iso %>%
    step_tokenfilter(text, min_times = 5, max_tokens = 5000) %>%
    step_tf(text, weight_scheme = "binary")
  rec_word_nostop_stemmed_f5 <- rec_word_nostop_stemmed %>%
    step_tokenfilter(text, min_times = 5, max_tokens = 5000) %>%
    step_tf(text, weight_scheme = "binary")
  rec_word_smart_stemmed_f5 <- rec_word_smart_stemmed %>%
    step_tokenfilter(text, min_times = 5, max_tokens = 5000) %>%
    step_tf(text, weight_scheme = "binary")
  rec_word_iso_stemmed_f5 <- rec_word_iso_stemmed %>%
    step_tokenfilter(text, min_times = 5, max_tokens = 5000) %>%
    step_tf(text, weight_scheme = "binary")
  rec_both_nostop_stemmed_f5 <- rec_both_nostop_stemmed %>%
    step_tokenfilter(text, min_times = 5, max_tokens = 5000) %>%
    step_tf(text, weight_scheme = "binary")
  rec_both_smart_stemmed_f5 <- rec_both_smart_stemmed %>%
    step_tokenfilter(text, min_times = 5, max_tokens = 5000) %>%
    step_tf(text, weight_scheme = "binary")
  rec_both_iso_stemmed_f5 <- rec_both_iso_stemmed %>%
    step_tokenfilter(text, min_times = 5, max_tokens = 5000) %>%
    step_tf(text, weight_scheme = "binary")
  rec_word_nostop <- rec_word_nostop %>%
    step_tf(text, weight_scheme = "binary")
  rec_word_smart <- rec_word_smart %>%
    step_tf(text, weight_scheme = "binary")
  rec_word_iso <- rec_word_iso %>%
    step_tf(text, weight_scheme = "binary")
  rec_both_nostop <- rec_both_nostop %>%
    step_tf(text, weight_scheme = "binary")
  rec_both_smart <- rec_both_smart %>%
    step_tf(text, weight_scheme = "binary")
  rec_both_iso <- rec_both_iso %>%
    step_tf(text, weight_scheme = "binary")
  rec_word_nostop_stemmed <- rec_word_nostop_stemmed %>%
    step_tf(text, weight_scheme = "binary")
  rec_word_smart_stemmed <- rec_word_smart_stemmed %>%
    step_tf(text, weight_scheme = "binary")
  rec_word_iso_stemmed <- rec_word_iso_stemmed %>%
    step_tf(text, weight_scheme = "binary")
  rec_both_nostop_stemmed <- rec_both_nostop_stemmed %>%
    step_tf(text, weight_scheme = "binary")
  rec_both_smart_stemmed <- rec_both_smart_stemmed %>%
    step_tf(text, weight_scheme = "binary")
  rec_both_iso_stemmed <- rec_both_iso_stemmed %>%
    step_tf(text, weight_scheme = "binary")
  ## define specs --------------------------------------------------------------
  spec_elasticnet <- logistic_reg(
    mode = "classification",
    engine = "glmnet",
    penalty = tune(),
    mixture = tune()
  )
  spec_randforest <- rand_forest(
    mode = "classification",
    mtry = tune(),
    min_n = tune(),
    trees = 500
  ) %>%
    set_engine(engine = "ranger", importance = "impurity")
  # make workflowset -----------------------------------------------------------
  wfs <- workflow_set(
    preproc = list(
      word_nostop_nostem_f2 = rec_word_nostop_f2,
      word_smart_nostem_f2 = rec_word_smart_f2,
      word_iso_nostem_f2 = rec_word_iso_f2,
      both_nostop_nostem_f2 = rec_both_nostop_f2,
      both_smart_nostem_f2 = rec_both_smart_f2,
      both_iso_nostem_f2 = rec_both_iso_f2,
      word_nostop_stemmed_f2 = rec_word_nostop_stemmed_f2,
      word_smart_stemmed_f2 = rec_word_smart_stemmed_f2,
      word_iso_stemmed_f2 = rec_word_iso_stemmed_f2,
      both_nostop_stemmed_f2 = rec_both_nostop_stemmed_f2,
      both_smart_stemmed_f2 = rec_both_smart_stemmed_f2,
      both_iso_stemmed_f2 = rec_both_iso_stemmed_f2,
      word_nostop_nostem_f5 = rec_word_nostop_f5,
      word_smart_nostem_f5 = rec_word_smart_f5,
      word_iso_nostem_f5 = rec_word_iso_f5,
      both_nostop_nostem_f5 = rec_both_nostop_f5,
      both_smart_nostem_f5 = rec_both_smart_f5,
      both_iso_nostem_f5 = rec_both_iso_f5,
      word_nostop_stemmed_f5 = rec_word_nostop_stemmed_f5,
      word_smart_stemmed_f5 = rec_word_smart_stemmed_f5,
      word_iso_stemmed_f5 = rec_word_iso_stemmed_f5,
      both_nostop_stemmed_f5 = rec_both_nostop_stemmed_f5,
      both_smart_stemmed_f5 = rec_both_smart_stemmed_f5,
      both_iso_stemmed_f5 = rec_both_iso_stemmed_f5,
      word_nostop_nostem_f0 = rec_word_nostop,
      word_smart_nostem_f0 = rec_word_smart,
      word_iso_nostem_f0 = rec_word_iso,
      both_nostop_nostem_f0 = rec_both_nostop,
      both_smart_nostem_f0 = rec_both_smart,
      both_iso_nostem_f0 = rec_both_iso,
      word_nostop_stemmed_f0 = rec_word_nostop_stemmed,
      word_smart_stemmed_f0 = rec_word_smart_stemmed,
      word_iso_stemmed_f0 = rec_word_iso_stemmed,
      both_nostop_stemmed_f0 = rec_both_nostop_stemmed,
      both_smart_stemmed_f0 = rec_both_smart_stemmed,
      both_iso_stemmed_f0 = rec_both_iso_stemmed
    ),
    models = list(
      elasticnet = spec_elasticnet,
      randforest = spec_randforest
    ),
    cross = TRUE
  )
  # do cross-validation
  folds <- vfold_cv(dat_train, v = n_folds)
  # get result of cross-validation
  cv_res <- wfs %>%
    workflow_map(
      "tune_grid",
      grid = grid_size,
      resamples = folds,
      metrics = metric_set(
        accuracy,
        sensitivity,
        specificity,
        precision,
        f_meas,
        roc_auc
      ),
      verbose = TRUE
    )
  # return the entire workflow set and the results of cross-validation
  return(list(wfs = wfs, cv_res = cv_res))
}
# take results from cross-validation, get a final model and held-out metrics
best_model <- function(dat, cv_out, metric = "roc_auc") {
  # get the id of the best model, according to the specified metric
  best_id <- cv_out$cv_res %>%
    rank_results(rank_metric = metric) %>%
    filter(.metric == metric & rank == 1) %>%
    pull(wflow_id)
  # get the name of the model
  # if you name the workflow sets differently, this step will change
  # it's based on the hardcoded names I gave them in defining the workflow set
  m <- str_split(best_id, "_")[[1]][[5]]
  # get best parameters, do final fit
  best_params <- cv_out$cv_res %>%
    extract_workflow_set_result(best_id) %>%
    select_best(metric = metric)
  final_fit <- cv_out$cv_res %>%
    extract_workflow(best_id) %>%
    finalize_workflow(best_params) %>%
    fit(training(dat))
  # run it on the holdout data
  dat_test <- testing(dat)
  # if glmnet, feed it the right penalty
  # I'm pretty sure this is necessary,
  # because glmnet fits many lambda as it's more efficient
  if (m == "elasticnet") {
    dat_test <- bind_cols(
      dat_test,
      predict(final_fit, dat_test, penalty = best_params$penalty),
      predict(final_fit, dat_test, type = "prob", penalty = best_params$penalty)
    )
  } else {
    dat_test <- bind_cols(
      dat_test,
      predict(final_fit, dat_test),
      predict(final_fit, dat_test, type = "prob")
    )
  }
  # define metrics to output
  ms <- metric_set(accuracy, sensitivity, specificity, precision)
  # return metrics
  metrics_res <- ms(dat_test, truth = y, estimate = .pred_class) %>%
    select(-.estimator) %>%
    spread(.metric, .estimate) %>%
    mutate(
      est_pct_class = mean(dat_test$.pred_class == 1),
      est_pct_prob = mean(dat_test$.pred_1),
      act_pct = mean(dat_test$y == 1)
    )
  # get variable importance, which depends on the model
  if (m == "randforest") {
    var_imp <- final_fit %>%
      extract_fit_parsnip() %>%
      vi() %>%
      mutate(
        Variable = str_remove_all(Variable, "tf_text_"),
        Variable = factor(Variable, Variable)
      )
  } else if (m == "elasticnet") {
    var_imp <- final_fit %>%
      tidy() %>%
      filter(
        penalty == best_params$penalty &
          term != "(Intercept)" &
          estimate != 0 # keep nonzero coefficients of either sign
      ) %>%
      arrange(desc(abs(estimate))) %>%
      select(-penalty) %>%
      mutate(
        term = str_remove_all(term, "tf_text_"),
        term = factor(term, term)
      )
  } else {
    # if you add more models, this is where you'd change the code to get
    # variable importance for that specific class of model output
    var_imp <- NA
  }
  return(
    list(
      best_id = best_id,
      best_params = best_params,
      var_imp = var_imp,
      fit = final_fit,
      ho_metrics = metrics_res
    )
  )
}