データ分析関連メモ(メモです)

仲秋の候、涼やかな秋風の下、ご一同様にはその後お健やかにお過ごしのことと存じます。

rpartの決定木から分岐情報を取り出す

rpartで作成した決定木から分岐の情報を抽出する。
かわいいかわいいpalmerpenguinsをサンプルデータとして決定木を作成。

library(palmerpenguins)

tree <- rpart::rpart(
  formula = species ~ .,
  data = penguins,
  method = "class")

tree
# n= 344 
# 
# node), split, n, loss, yval, (yprob)
# * denotes terminal node
# 
# 1) root 344 192 Adelie (0.441860465 0.197674419 0.360465116)  
# 2) flipper_length_mm< 206.5 214  64 Adelie (0.700934579 0.294392523 0.004672897)  
# 4) bill_length_mm< 43.35 151   5 Adelie (0.966887417 0.033112583 0.000000000) *
#   5) bill_length_mm>=43.35 63   5 Chinstrap (0.063492063 0.920634921 0.015873016) *
#   3) flipper_length_mm>=206.5 130   7 Gentoo (0.015384615 0.038461538 0.946153846)  
# 6) island=Dream,Torgersen 7   2 Chinstrap (0.285714286 0.714285714 0.000000000) *
#   7) island=Biscoe 123   0 Gentoo (0.000000000 0.000000000 1.000000000) *


rpart.plotはこんな感じ。

rpart.plot::rpart.plot(tree)


分岐情報は内部変数として持っているのでrpart:::labels.rpartでアクセスできる。
labels.rpart function - RDocumentation

「:::」はTriple Colon Operatorと呼ばれている内部変数にアクセスする演算子
R: Double Colon and Triple Colon Operators

rpart:::labels.rpart(tree)
# [1] "root"                     "flipper_length_mm< 206.5"
# [3] "bill_length_mm< 43.35"    "bill_length_mm>=43.35"   
# [5] "flipper_length_mm>=206.5" "island=bc"               
# [7] "island=a" 


内部でabbreviateという関数が使用されており、stringsはa,b,cといった形に省略されてしまう。今回のサンプルでもislandの分岐基準が「"island=bc" 」「"island=a”」と省略されてしまっている。
abbreviate function - RDocumentation

minlengthを0とすると省略せずに返してくれる。

rpart:::labels.rpart(tree, minlength = 0)
# [1] "root"                     "flipper_length_mm< 206.5"
# [3] "bill_length_mm< 43.35"    "bill_length_mm>=43.35"   
# [5] "flipper_length_mm>=206.5" "island=Dream,Torgersen"  
# [7] "island=Biscoe"


データフレームに色々まとめることもできる。
r - Extract variable labels from rpart decision tree - Stack Overflow

df_splits <- data.frame(
  splits = rpart:::labels.rpart(tree, minlength = 0),
  n = tree$frame$n,
  yval = tree$frame$yval
)

df_splits
# splits   n yval               var
# 1                     root 344    1 flipper_length_mm
# 2 flipper_length_mm< 206.5 214    1    bill_length_mm
# 3    bill_length_mm< 43.35 151    1            <leaf>
# 4    bill_length_mm>=43.35  63    2            <leaf>
# 5 flipper_length_mm>=206.5 130    3            island
# 6   island=Dream,Torgersen   7    2            <leaf>
# 7            island=Biscoe 123    3            <leaf>