r nearest-neighbor tidymodels

View which rows were the nearest neighbors in a tidymodels workflow


I have a question similar to this one: How to view the nearest neighbors in R?

Code example:

library(tidymodels)

knn_rec <- recipe(Species ~ ., data = iris)

knn_lookup <- workflow() %>%
  add_model(nearest_neighbor(neighbors = 3) %>%
    set_engine("kknn") %>% 
    set_mode("classification")) %>%
  add_recipe(knn_rec) %>%
  fit(data = iris)

knn_pred <- predict(knn_lookup, new_data = iris[1, 1:4], type = "raw")

How can I trace back which rows were identified as the 3 nearest neighbors? I'd prefer something simple like tidy() within the tidyverse.


Solution

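  • The kknn package already records this for you: the object returned by kknn() stores the indices of the nearest training rows in $C and their weights in $W. If your recipe contains any steps, make sure that you prep and bake the data first, so that kknn() sees the same values the model actually used.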

    Here's a simple example:

    library(tidymodels)
    library(kknn)
    
    data(two_class_dat)
    
    set.seed(1)
    dat_split <- initial_split(two_class_dat)
    dat_train <- training(dat_split)
    dat_test  <- testing(dat_split)
    
    knn_obj <- 
      kknn(
        Class ~ .,
        dat_train,
        dat_test %>% select(-Class) %>% head(),
        k = 3,
        kernel = "triangular"
      )
    
    # Indices of the neighbors (row numbers in dat_train), one row per test point
    knn_obj$C
    #>      [,1] [,2] [,3]
    #> [1,]   72  247  152
    #> [2,]  336  397  113
    #> [3,]   13  528  329
    #> [4,]  218  362  298
    #> [5,]  566  258  392
    #> [6,]  579   66  302
    
    # Weights of those neighbors
    knn_obj$W
    #>           [,1]      [,2]       [,3]
    #> [1,] 0.3461305 0.2225861 0.08587877
    #> [2,] 0.2912265 0.2832970 0.15322362
    #> [3,] 0.8111406 0.4207619 0.26617555
    #> [4,] 0.8868195 0.4606245 0.21953733
    #> [5,] 0.5464303 0.3773175 0.32209628
    #> [6,] 0.7769446 0.7006026 0.12411210
    

    Created on 2024-11-11 with reprex v2.1.1
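
    To connect this back to the workflow in the question, here is a minimal sketch rather than a definitive recipe. It pulls the prepped recipe out of the fitted workflow with extract_recipe(), bakes both the training data and the new rows, and then calls kknn() directly on the baked data. The object names (knn_fit, rec_prepped, neighbors_tbl, and so on) are illustrative. It also assumes that calling kknn() with k = 3 and its default kernel is close enough to what the parsnip fit does; if you changed weight_func or dist_power in nearest_neighbor(), pass the matching kernel and distance arguments.

    ```r
    library(tidymodels)
    library(kknn)

    knn_rec <- recipe(Species ~ ., data = iris)

    knn_fit <- workflow() %>%
      add_model(
        nearest_neighbor(neighbors = 3) %>%
          set_engine("kknn") %>%
          set_mode("classification")
      ) %>%
      add_recipe(knn_rec) %>%
      fit(data = iris)

    # Prepped recipe stored in the fitted workflow
    rec_prepped <- extract_recipe(knn_fit)

    # Bake the training data and the rows we want neighbors for
    train_baked <- as.data.frame(bake(rec_prepped, new_data = iris))
    new_baked   <- as.data.frame(bake(rec_prepped, new_data = iris[c(1, 51, 101), ]))

    k <- 3
    knn_obj <- kknn(Species ~ ., train = train_baked, test = new_baked, k = k)

    # One row per query point, one column per neighbor rank;
    # matrix() guards against the result having been simplified to a vector
    neighbor_idx <- matrix(knn_obj$C, nrow = nrow(new_baked), ncol = k)
    neighbor_wts <- matrix(knn_obj$W, nrow = nrow(new_baked), ncol = k)

    # Long, tidy layout: as.vector() unrolls the matrices column by column,
    # so query_row cycles fastest and rank changes every nrow(new_baked) entries
    neighbors_tbl <- tibble(
      query_row = rep(seq_len(nrow(new_baked)), times = k),
      rank      = rep(seq_len(k), each = nrow(new_baked)),
      train_row = as.vector(neighbor_idx),
      weight    = as.vector(neighbor_wts)
    )

    neighbors_tbl

    # Look up the actual training rows for the first query point
    train_baked[neighbor_idx[1, ], ]
    ```

    The train_row values index into train_baked, which here has the same row order as iris. If you fit on a subset (for example training(dat_split)), they refer to rows of that subset, not of the original data.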