r, lda, topic-modeling, quanteda

R: Quanteda+LDA, how to Visualise the Results?


Please have a look at the snippet at the end of this post. I ran a simplified tutorial example of topic modelling with quanteda, but once the model has finished running, I find it difficult to extract the words with the highest probabilities in each topic and visualise them as Julia Silge does in the blog post referenced in the reprex.

Any suggestion is welcome.

library(seededlda)
#> Loading required package: quanteda
#> Package version: 3.3.1
#> Unicode version: 13.0
#> ICU version: 67.1
#> Parallel computing: 12 of 12 threads used.
#> See https://quanteda.io for tutorials and examples.
#> Loading required package: proxyC
#> 
#> Attaching package: 'proxyC'
#> The following object is masked from 'package:stats':
#> 
#>     dist
#> 
#> Attaching package: 'seededlda'
#> The following object is masked from 'package:stats':
#> 
#>     terms
library(quanteda)
library(RCurl)
library(lubridate)
#> 
#> Attaching package: 'lubridate'
#> The following objects are masked from 'package:base':
#> 
#>     date, intersect, setdiff, union

## See https://koheiw.github.io/seededlda/articles/pkgdown/basic.html


url <- "https://www.dropbox.com/s/abme18nlrwxgmz8/data_corpus_sputnik2022.rds?raw=1"
download.file(url,
              destfile = "sputnik.RDS",
              method = "auto")

corp_all <- readRDS("sputnik.RDS")

corp <- corpus_subset(corp_all, date > "2022-11-29")

toks <- tokens(corp, remove_punct = TRUE, remove_symbols = TRUE, 
               remove_numbers = TRUE, remove_url = TRUE)
dfmt <- dfm(toks) |> 
    dfm_remove(stopwords("en")) |>
    dfm_remove("*@*") |>
    dfm_trim(max_docfreq = 0.1, docfreq_type = "prop")
print(dfmt)
#> Document-feature matrix of: 550 documents, 17,190 features (99.14% sparse) and 4 docvars.
#>              features
#> docs          spanish firm instalaza city zaragoza received similar exploded
#>   s1104914730       5    1         3    1        2        2       2        2
#>   s1104912678       0    0         0    0        0        0       0        0
#>   s1104910731       0    0         0    0        0        0       1        0
#>   s1104906969       0    0         0    0        0        0       1        0
#>   s1104905548       0    0         0    0        0        0       2        0
#>   s1104891116       0    0         0    0        0        1       0        0
#>              features
#> docs          near embassy
#>   s1104914730    2       4
#>   s1104912678    0       0
#>   s1104910731    0       0
#>   s1104906969    0       0
#>   s1104905548    2       1
#>   s1104891116    0       0
#> [ reached max_ndoc ... 544 more documents, reached max_nfeat ... 17,180 more features ]

lda <- textmodel_lda(dfmt, k = 5, verbose = TRUE)
#> Fitting LDA with 5 topics
#>  ...initializing
#>  ...Gibbs sampling in 2000 iterations
#>  ......iteration 100 elapsed time: 1.57 seconds (delta: -0.02%)
#>  ......iteration 200 elapsed time: 2.90 seconds (delta: 0.13%)
#>  ......iteration 300 elapsed time: 4.23 seconds (delta: 0.12%)
#>  ......iteration 400 elapsed time: 5.54 seconds (delta: 0.02%)
#>  ......iteration 500 elapsed time: 6.86 seconds (delta: 0.08%)
#>  ......iteration 600 elapsed time: 8.16 seconds (delta: -0.10%)
#>  ......iteration 700 elapsed time: 9.56 seconds (delta: -0.03%)
#>  ......iteration 800 elapsed time: 11.16 seconds (delta: -0.03%)
#>  ......iteration 900 elapsed time: 12.72 seconds (delta: 0.01%)
#>  ......iteration 1000 elapsed time: 14.32 seconds (delta: -0.07%)
#>  ......iteration 1100 elapsed time: 15.92 seconds (delta: -0.06%)
#>  ......iteration 1200 elapsed time: 17.33 seconds (delta: -0.13%)
#>  ......iteration 1300 elapsed time: 18.83 seconds (delta: 0.19%)
#>  ......iteration 1400 elapsed time: 20.36 seconds (delta: -0.07%)
#>  ......iteration 1500 elapsed time: 21.68 seconds (delta: 0.03%)
#>  ......iteration 1600 elapsed time: 22.99 seconds (delta: 0.01%)
#>  ......iteration 1700 elapsed time: 24.30 seconds (delta: 0.04%)
#>  ......iteration 1800 elapsed time: 25.60 seconds (delta: 0.04%)
#>  ......iteration 1900 elapsed time: 26.91 seconds (delta: -0.01%)
#>  ......iteration 2000 elapsed time: 28.22 seconds (delta: -0.08%)
#>  ...computing theta and phi
#>  ...complete

knitr::kable(terms(lda))
|topic1     |topic2    |topic3      |topic4      |topic5    |
|:----------|:---------|:-----------|:-----------|:---------|
|twitter    |today     |crude       |joins       |patriot   |
|police     |know      |grain       |bill        |missiles  |
|data       |soviet    |exports     |journalist  |aircraft  |
|photo      |rose      |electricity |joined      |donbass   |
|british    |crimea    |african     |republicans |pentagon  |
|musk       |un        |products    |trump       |soldiers  |
|newspaper  |always    |french      |🇺🇦          |iraq      |
|company    |charlie   |g7          |author      |equipment |
|violence   |important |india       |senate      |training  |
|screenshot |believe   |natural     |hour        |shoigu    |

## How can I visualise the results in the style of what I see here:
## https://juliasilge.com/blog/sherlock-holmes-stm/
## in particular the figure entitled "Highest word probabilities for each topic"?

sessionInfo()
#> R version 4.3.1 (2023-06-16)
#> Platform: x86_64-pc-linux-gnu (64-bit)
#> Running under: Debian GNU/Linux 12 (bookworm)
#> 
#> Matrix products: default
#> BLAS:   /usr/lib/x86_64-linux-gnu/blas/libblas.so.3.11.0 
#> LAPACK: /usr/lib/x86_64-linux-gnu/lapack/liblapack.so.3.11.0
#> 
#> locale:
#>  [1] LC_CTYPE=en_GB.UTF-8       LC_NUMERIC=C              
#>  [3] LC_TIME=en_GB.UTF-8        LC_COLLATE=en_GB.UTF-8    
#>  [5] LC_MONETARY=en_GB.UTF-8    LC_MESSAGES=en_GB.UTF-8   
#>  [7] LC_PAPER=en_GB.UTF-8       LC_NAME=C                 
#>  [9] LC_ADDRESS=C               LC_TELEPHONE=C            
#> [11] LC_MEASUREMENT=en_GB.UTF-8 LC_IDENTIFICATION=C       
#> 
#> time zone: Europe/Brussels
#> tzcode source: system (glibc)
#> 
#> attached base packages:
#> [1] stats     graphics  grDevices utils     datasets  methods   base     
#> 
#> other attached packages:
#> [1] lubridate_1.9.3 RCurl_1.98-1.12 seededlda_1.1.0 proxyC_0.3.4   
#> [5] quanteda_3.3.1 
#> 
#> loaded via a namespace (and not attached):
#>  [1] vctrs_0.6.4        cli_3.6.1          knitr_1.44         rlang_1.1.1       
#>  [5] xfun_0.40          stringi_1.7.12     purrr_1.0.2        styler_1.10.2     
#>  [9] generics_0.1.3     RcppParallel_5.1.7 glue_1.6.2         htmltools_0.5.6.1 
#> [13] rmarkdown_2.25     R.cache_0.16.0     grid_4.3.1         evaluate_0.22     
#> [17] bitops_1.0-7       fastmap_1.1.1      yaml_2.3.7         lifecycle_1.0.3   
#> [21] compiler_4.3.1     fs_1.6.3           timechange_0.2.0   fastmatch_1.1-4   
#> [25] Rcpp_1.0.11        R.oo_1.25.0        R.utils_2.12.2     lattice_0.21-9    
#> [29] digest_0.6.33      reprex_2.0.2       stopwords_2.3      magrittr_2.0.3    
#> [33] R.methodsS3_1.8.2  Matrix_1.6-1.1     tools_4.3.1        withr_2.5.1

Created on 2023-10-30 with reprex v2.0.2


Solution

  • The code below (which may not be the cleanest) achieves what I want.

    library(tidyverse)
    library(quanteda)
    #> Package version: 3.3.1
    #> Unicode version: 13.0
    #> ICU version: 67.1
    #> Parallel computing: 4 of 4 threads used.
    #> See https://quanteda.io for tutorials and examples.
    library(seededlda)
    #> Loading required package: proxyC
    #> 
    #> Attaching package: 'proxyC'
    #> The following object is masked from 'package:stats':
    #> 
    #>     dist
    #> 
    #> Attaching package: 'seededlda'
    #> The following object is masked from 'package:stats':
    #> 
    #>     terms
    library(RCurl)
    #> 
    #> Attaching package: 'RCurl'
    #> The following object is masked from 'package:tidyr':
    #> 
    #>     complete
    
    
    
    
    
    
    ## See https://koheiw.github.io/seededlda/articles/pkgdown/basic.html
    
    
    url <- "https://www.dropbox.com/s/abme18nlrwxgmz8/data_corpus_sputnik2022.rds?raw=1"
    download.file(url,
                  destfile = "sputnik.RDS",
                  method = "auto")
    
    corp_all <- readRDS("sputnik.RDS")
    
    corp <- corpus_subset(corp_all, date > "2022-11-29")
    
    toks <- tokens(corp, remove_punct = TRUE, remove_symbols = TRUE, 
                   remove_numbers = TRUE, remove_url = TRUE)
    dfmt <- dfm(toks) |> 
        dfm_remove(stopwords("en")) |>
        dfm_remove("*@*") |>
        dfm_trim(max_docfreq = 0.1, docfreq_type = "prop")
    print(dfmt)
    #> Document-feature matrix of: 550 documents, 17,190 features (99.14% sparse) and 4 docvars.
    #>              features
    #> docs          spanish firm instalaza city zaragoza received similar exploded
    #>   s1104914730       5    1         3    1        2        2       2        2
    #>   s1104912678       0    0         0    0        0        0       0        0
    #>   s1104910731       0    0         0    0        0        0       1        0
    #>   s1104906969       0    0         0    0        0        0       1        0
    #>   s1104905548       0    0         0    0        0        0       2        0
    #>   s1104891116       0    0         0    0        0        1       0        0
    #>              features
    #> docs          near embassy
    #>   s1104914730    2       4
    #>   s1104912678    0       0
    #>   s1104910731    0       0
    #>   s1104906969    0       0
    #>   s1104905548    2       1
    #>   s1104891116    0       0
    #> [ reached max_ndoc ... 544 more documents, reached max_nfeat ... 17,180 more features ]
    
    lda <- textmodel_lda(dfmt, k = 5, verbose = TRUE)
    #> Fitting LDA with 5 topics
    #>  ...initializing
    #>  ...Gibbs sampling in 2000 iterations
    #>  ......iteration 100 elapsed time: 2.10 seconds (delta: -0.02%)
    #>  ......iteration 200 elapsed time: 3.71 seconds (delta: 0.18%)
    #>  ......iteration 300 elapsed time: 5.28 seconds (delta: -0.14%)
    #>  ......iteration 400 elapsed time: 7.27 seconds (delta: -0.06%)
    #>  ......iteration 500 elapsed time: 9.02 seconds (delta: 0.05%)
    #>  ......iteration 600 elapsed time: 10.66 seconds (delta: 0.07%)
    #>  ......iteration 700 elapsed time: 12.69 seconds (delta: 0.06%)
    #>  ......iteration 800 elapsed time: 15.02 seconds (delta: 0.06%)
    #>  ......iteration 900 elapsed time: 16.67 seconds (delta: -0.12%)
    #>  ......iteration 1000 elapsed time: 18.66 seconds (delta: 0.10%)
    #>  ......iteration 1100 elapsed time: 20.79 seconds (delta: 0.03%)
    #>  ......iteration 1200 elapsed time: 22.45 seconds (delta: 0.13%)
    #>  ......iteration 1300 elapsed time: 24.36 seconds (delta: 0.09%)
    #>  ......iteration 1400 elapsed time: 26.68 seconds (delta: 0.19%)
    #>  ......iteration 1500 elapsed time: 28.45 seconds (delta: -0.06%)
    #>  ......iteration 1600 elapsed time: 30.12 seconds (delta: 0.18%)
    #>  ......iteration 1700 elapsed time: 31.78 seconds (delta: -0.07%)
    #>  ......iteration 1800 elapsed time: 33.33 seconds (delta: 0.02%)
    #>  ......iteration 1900 elapsed time: 35.00 seconds (delta: -0.05%)
    #>  ......iteration 2000 elapsed time: 36.59 seconds (delta: -0.17%)
    #>  ...computing theta and phi
    #>  ...complete
    
    knitr::kable(terms(lda))
    
    |topic1      |topic2    |topic3   |topic4      |topic5    |
    |:-----------|:---------|:--------|:-----------|:---------|
    |crude       |today     |un       |twitter     |patriot   |
    |grain       |soviet    |french   |joins       |missiles  |
    |exports     |know      |macron   |bill        |aircraft  |
    |electricity |rose      |donbass  |journalist  |pentagon  |
    |products    |crimea    |visit    |joined      |soldiers  |
    |african     |important |france   |republicans |training  |
    |g7          |charlie   |brussels |🇺🇦          |army      |
    |india       |always    |assets   |musk        |data      |
    |natural     |believe   |alliance |author      |equipment |
    |pipeline    |problems  |deputy   |senate      |march     |
    
    ## How can I visualise the results in the style of what I see here:
    ## https://juliasilge.com/blog/sherlock-holmes-stm/
    ## in particular the figure entitled "Highest word probabilities for each topic"?
    
    
    ## top 10 terms per topic
    
    top10 <- terms(lda, n = 10) |>
        as_tibble() |>
        pivot_longer(cols = starts_with("t"),
                     names_to = "topic", values_to = "word")
    
    phi <- lda$phi |>
        as_tibble(rownames = "topic") |>
        pivot_longer(cols = c(-topic))
    
    top10phi <- top10 |>
        left_join(y = phi, by = c("topic", "word" = "name")) ## Finally, a tibble I can work with.
    
    top10phi
    #> # A tibble: 50 × 3
    #>    topic  word       value
    #>    <chr>  <chr>      <dbl>
    #>  1 topic1 crude    0.00509
    #>  2 topic2 today    0.00488
    #>  3 topic3 un       0.00482
    #>  4 topic4 twitter  0.0101 
    #>  5 topic5 patriot  0.00642
    #>  6 topic1 grain    0.00426
    #>  7 topic2 soviet   0.00457
    #>  8 topic3 french   0.00482
    #>  9 topic4 joins    0.00720
    #> 10 topic5 missiles 0.00612
    #> # ℹ 40 more rows
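    
    ## Aside: a similar long-format tibble could arguably be built directly
    ## from phi, skipping the separate terms() call and the join, e.g. with
    ## dplyr::slice_max (a sketch, not verified against this exact object):
    
    top10phi_alt <- lda$phi |>
        as_tibble(rownames = "topic") |>
        pivot_longer(cols = c(-topic), names_to = "word", values_to = "value") |>
        group_by(topic) |>
        slice_max(value, n = 10, with_ties = FALSE) |>
        ungroup()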
    
    
    ## See https://stackoverflow.com/questions/5409776/how-to-order-bars-in-faceted-ggplot2-bar-chart/5414445#5414445
    
    
    ## Order the bars within each facet: paste the facet variable (cat_a)
    ## and the category (cat_b) into a single factor (cat_out) whose levels
    ## are unique across facets, then reorder those levels by the ranking
    ## variable so that each facet's bars come out sorted.
    sort_facets <- function(df, cat_a, cat_b, cat_out, ranking_var){
        res <- df |>
            mutate({{cat_out}} := factor(paste({{cat_a}}, {{cat_b}}))) |>
            mutate({{cat_out}} := reorder({{cat_out}}, rank({{ranking_var}})))
        return(res)
    }
    
    
    
    
    dd2 <- sort_facets(top10phi, topic, word, category2, value)
    
    gpl <- ggplot(dd2, aes(y = category2, x = value)) +
        geom_col() +
        facet_wrap(. ~ topic, scales = "free_y", nrow = 3) +
        scale_y_discrete(labels = dd2$word, breaks = dd2$category2) +
        xlab("Probability") +
        ylab(NULL)
    
    gpl
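    
    ## An alternative way to order bars within facets, used in Julia
    ## Silge's own post, is tidytext::reorder_within() paired with
    ## tidytext::scale_y_reordered() (assumes the tidytext package is
    ## installed; a sketch, not a claim about the cleanest solution):
    
    library(tidytext)
    
    gpl2 <- top10phi |>
        mutate(word = reorder_within(word, value, topic)) |>
        ggplot(aes(y = word, x = value)) +
        geom_col() +
        scale_y_reordered() +
        facet_wrap(. ~ topic, scales = "free_y", nrow = 3) +
        xlab("Probability") +
        ylab(NULL)
    
    gpl2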
    

    sessionInfo()
    #> R version 4.3.1 (2023-06-16)
    #> Platform: x86_64-pc-linux-gnu (64-bit)
    #> Running under: Debian GNU/Linux 12 (bookworm)
    #> 
    #> Matrix products: default
    #> BLAS:   /usr/lib/x86_64-linux-gnu/blas/libblas.so.3.11.0 
    #> LAPACK: /usr/lib/x86_64-linux-gnu/lapack/liblapack.so.3.11.0
    #> 
    #> locale:
    #>  [1] LC_CTYPE=en_GB.UTF-8       LC_NUMERIC=C              
    #>  [3] LC_TIME=en_GB.UTF-8        LC_COLLATE=en_GB.UTF-8    
    #>  [5] LC_MONETARY=en_GB.UTF-8    LC_MESSAGES=en_GB.UTF-8   
    #>  [7] LC_PAPER=en_GB.UTF-8       LC_NAME=C                 
    #>  [9] LC_ADDRESS=C               LC_TELEPHONE=C            
    #> [11] LC_MEASUREMENT=en_GB.UTF-8 LC_IDENTIFICATION=C       
    #> 
    #> time zone: Europe/Brussels
    #> tzcode source: system (glibc)
    #> 
    #> attached base packages:
    #> [1] stats     graphics  grDevices utils     datasets  methods   base     
    #> 
    #> other attached packages:
    #>  [1] RCurl_1.98-1.8  seededlda_1.1.0 proxyC_0.3.4    quanteda_3.3.1 
    #>  [5] lubridate_1.9.2 forcats_1.0.0   stringr_1.5.0   dplyr_1.1.2    
    #>  [9] purrr_1.0.2     readr_2.1.4     tidyr_1.3.0     tibble_3.2.1   
    #> [13] ggplot2_3.4.3   tidyverse_2.0.0
    #> 
    #> loaded via a namespace (and not attached):
    #>  [1] utf8_1.2.2         generics_0.1.3     bitops_1.0-7       stringi_1.7.8     
    #>  [5] lattice_0.20-45    hms_1.1.3          digest_0.6.29      magrittr_2.0.3    
    #>  [9] evaluate_0.15      grid_4.3.1         timechange_0.2.0   fastmap_1.1.0     
    #> [13] Matrix_1.6-1.1     stopwords_2.3      fansi_1.0.3        scales_1.2.1      
    #> [17] cli_3.6.1          rlang_1.1.1        munsell_0.5.0      reprex_2.0.2      
    #> [21] withr_2.5.0        yaml_2.3.5         tools_4.3.1        tzdb_0.3.0        
    #> [25] colorspace_2.0-3   fastmatch_1.1-4    vctrs_0.6.3        R6_2.5.1          
    #> [29] lifecycle_1.0.3    fs_1.5.2           pkgconfig_2.0.3    RcppParallel_5.1.5
    #> [33] pillar_1.9.0       gtable_0.3.0       glue_1.6.2         Rcpp_1.0.9        
    #> [37] xfun_0.31          tidyselect_1.2.0   highr_0.9          knitr_1.39        
    #> [41] farver_2.1.1       htmltools_0.5.2    labeling_0.4.2     rmarkdown_2.14    
    #> [45] compiler_4.3.1
    

    Created on 2023-10-30 with reprex v2.0.2