rpartの結果をtidyに扱う(変数重要度と分岐情報を取り出す)
『Rユーザのためのtidymodels[実践]入門』を読み進めている。
Rユーザのためのtidymodels[実践]入門 〜モダンな統計・機械学習モデリングの世界:書籍案内|技術評論社
そのtidymodels関係で、決定木モデルを作成するrpartをtidyにするパッケージのお話。
kyphosisデータ
サンプルデータはrpartパッケージのkyphosisデータ。
# https://rdrr.io/cran/rpart/man/kyphosis.html data(kyphosis, package = "rpart") kyphosis %>% summary() # Kyphosis Age Number Start # absent :64 Min. : 1.00 Min. : 2.000 Min. : 1.00 # present:17 1st Qu.: 26.00 1st Qu.: 3.000 1st Qu.: 9.00 # Median : 87.00 Median : 4.000 Median :13.00 # Mean : 83.65 Mean : 4.049 Mean :11.49 # 3rd Qu.:130.00 3rd Qu.: 5.000 3rd Qu.:16.00 # Max. :206.00 Max. :10.000 Max. :18.00
broomパッケージ
まずはtidymodelsのbroomパッケージの関数について。
主に3つあり、前処理やモデルの実行結果をtidyにするtidy()関数、学習データに予測値と残差を付与するaugment()関数、モデル選択に利用するAICや決定係数を取得するglance()関数。
この記事ではtidy()関数を扱う。
挙動確認のため、試しにtidy()関数をlm()関数の結果に適用。偏回帰係数などがtidyに返ってくる。
lm_fit <- lm(Start ~ Age + Number, data = kyphosis) broom::tidy(lm_fit) # A tibble: 3 × 5 # term estimate std.error statistic p.value # <chr> <dbl> <dbl> <dbl> <dbl> # (Intercept) 16.3 1.54 10.6 7.88e-17 # Age 0.00427 0.00860 0.496 6.21e- 1 # Number -1.28 0.309 -4.15 8.55e- 5
一方、rpartの結果を渡すとエラーになってしまう。rpartクラスオブジェクトへのメソッドは無い、と。
fit_rpart <- rpart::rpart(Kyphosis ~ Age + Number + Start, data = kyphosis) broom::tidy(fit_rpart) # Error: No tidy method for objects of class rpart
broomstickパッケージ
そこで出てくるのがbroomstickパッケージ。broomstickとは箒の柄のこと。
broomstick.njtierney.com
2023年1月21日時点でCRANに登録されていないので、GitHubからいただいてくる。
このパッケージのtidy()関数を適用すると、変数重要度が返ってくる。
remotes::install_github("njtierney/broomstick") broomstick::tidy(fit_rpart) # A tibble: 3 × 2 # variable importance # <chr> <dbl> # Start 8.20 # Age 3.10 # Number 1.52
tidyrulesパッケージ
また、別のrpartに使えるパッケージでtidyrulesというものもある。 talegari.github.io
このパッケージのtidyRules()関数を適用すると、分岐情報などの実行結果が返ってくる。
tidyrules::tidyRules(fit_rpart) # A tibble: 5 × 6 # id LHS RHS support confidence lift # <int> <chr> <chr> <int> <dbl> <dbl> # 1 Start >= 8.5 & Start >= 14.5 absent 29 0.968 1.22 # 2 Start >= 8.5 & Start < 14.5 & Age < 55 absent 12 0.929 1.18 # 3 Start >= 8.5 & Start < 14.5 & Age >= 55 & Age >= 111 absent 14 0.812 1.03 # 4 Start >= 8.5 & Start < 14.5 & Age >= 55 & Age < 111 present 7 0.556 2.65 # 5 Start < 8.5 present 19 0.571 2.72
各変数の内容は以下の通り。
LHS
: Rules.RHS
: Predicted Class.support
: Number of observation covered by the rule.confidence
: Prediction accuracy for respective class. (laplace correction is implemented by default)lift
: The result of dividing the rule's estimated accuracy by the relative frequency of the predicted class in the training set.
tidyrules/tidyrules_vignette.Rmd at master · talegari/tidyrules · GitHub
通常分岐情報を取得するには、以下のように内部変数にアクセスする必要があるところを楽にしてくれている。
rpart:::labels.rpart(fit_rpart, minlength = 0) # [1] "root" "Start>=8.5" "Start>=14.5" "Start< 14.5" "Age< 55" # [6] "Age>=55" "Age>=111" "Age< 111" "Start< 8.5" partykit::as.party(fit_rpart) %>% partykit:::.list.rules.party() %>% stringr::str_replace_all(pattern = "\\\"","'") %>% stringr::str_remove_all(pattern = ", 'NA'") %>% stringr::str_remove_all(pattern = "'NA',") %>% stringr::str_remove_all(pattern = "'NA'") %>% stringr::str_squish() # [1] "Start >= 8.5 & Start >= 14.5" # [2] "Start >= 8.5 & Start < 14.5 & Age < 55" # [3] "Start >= 8.5 & Start < 14.5 & Age >= 55 & Age >= 111" # [4] "Start >= 8.5 & Start < 14.5 & Age >= 55 & Age < 111" # [5] "Start < 8.5"
結びに
分岐情報取得の話は以前書いた。
rpartの決定木から分岐情報を取り出す - データ分析関連メモ(メモです)
スライドはこちら。
speakerdeck.com
以上