データ分析関連メモ(メモです)

仲秋の候、涼やかな秋風の下、ご一同様にはその後お健やかにお過ごしのことと存じます。

rsample::sliding_window()関数の引数

『Rユーザのためのtidymodels[実践]入門』を読み始めた。
Rユーザのためのtidymodels[実践]入門 〜モダンな統計・機械学習モデリングの世界:書籍案内|技術評論社


第1章で時系列データを分析データと検証データに分割する関数rsample::sliding_window()関数の説明がある。
この関数の引数について自分向けに整理。



まずはデータの読み込み。データは参考書と同様。分割の仕方がわかり易いようにindexを追加している。
"S4248SM144NCEN"という販売量が入っている列名は、FRED(Federal Reserve Economic Data)のコードを意味しているのかもしれない。
Merchant Wholesalers, Except Manufacturers' Sales Branches and Offices: Nondurable Goods: Beer, Wine, and Distilled Alcoholic Beverages Sales (S4248SM144NCEN) | FRED | St. Louis Fed

data(drinks, package = "modeldata")

drinks_annual <- drinks %>% 
  dplyr::mutate(year = lubridate::year(date)) %>% 
  dplyr::filter(dplyr::between(year, 1992, 1994)) %>% 
  dplyr::mutate(index = dplyr::row_number(x = date))

drinks_annual %>% 
  head()
# date       sales  year index
# <date>     <dbl> <dbl> <int>
# 1992-01-01  3459  1992     1
# 1992-02-01  3458  1992     2
# 1992-03-01  4002  1992     3
# 1992-04-01  4564  1992     4
# 1992-05-01  4221  1992     5
# 1992-06-01  4529  1992     6


次にsliding_window()の基本的な使い方について。
先ず、rsample::initial_time_split()関数をデータに適用して、rsplitクラスのオブジェクト(drinks_split_ts2)を取得する。
次に、rsplitクラスのオブジェクトにrsample::training()関数を適用して学習データ(train_ts_data2)を参照する。

drinks_split_ts2 <- drinks_annual %>% 
  rsample::initial_time_split(prop = 0.68)

train_ts_data2 <- drinks_split_ts2 %>% 
  rsample::training()


この学習データ(train_ts_data2)にrsample::sliding_window()関数を適用すれば分析データと検証データに分割できるのだが、 この分割方法をrsample::sliding_window()関数の引数で指定することになる。

rsample::sliding_window()関数の引数は以下の通り(書籍から引用)。
lookbackが分析セット、assess系の2つは検証セットの操作をしている。

lookback……分析セットの件数(時系列の起点となる地震を含まない件数)
assess_start……検証セットの時系列上での起点
assess_stop……検証セットの時系列上での終点


引数のデフォルト値はそれぞれ0,1,1となっている。
Time-based Resampling — slide-resampling • rsample

sliding_window(
data,
...,
lookback = 0L,
assess_start = 1L,
assess_stop = 1L,
complete = TRUE,
step = 1L,
skip = 0L
)



ここから、引数に様々な値を与えて出力結果を確認していく。

デフォルトのまま実行すると、分析データは最初の1件から、検証データは分析データの次の1件で、1つずつずれていく。

fold_default <- rsample::sliding_window(data = train_ts_data2)

rsample::analysis(fold_default$splits[[1]])
# date       sales  year index
# <date>     <dbl> <dbl> <int>
# 1992-01-01  3459  1992     1

rsample::assessment(fold_default$splits[[1]])
# date       sales  year index
# <date>     <dbl> <dbl> <int>
# 1992-02-01  3458  1992     2

rsample::analysis(fold_default$splits[[2]])
# date       sales  year index
# <date>     <dbl> <dbl> <int>
# 1992-02-01  3458  1992     2

rsample::assessment(fold_default$splits[[2]])
# date       sales  year index
# <date>     <dbl> <dbl> <int>
# 1992-03-01  4002  1992     3


lookback = 10とすると、分析データ11件(起点となる1件が足され、10 + 1で11件)、検証データはデフォルトのまま1件。

fold_look10 <- rsample::sliding_window(data = train_ts_data2,
                                       lookback = 10)

rsample::analysis(fold_look10$splits[[1]])
# date       sales  year index
# <date>     <dbl> <dbl> <int>
# 1992-01-01  3459  1992     1
# 1992-02-01  3458  1992     2
# 1992-03-01  4002  1992     3
# 1992-04-01  4564  1992     4
# 1992-05-01  4221  1992     5
# 1992-06-01  4529  1992     6
# 1992-07-01  4466  1992     7
# 1992-08-01  4137  1992     8
# 1992-09-01  4126  1992     9
# 1992-10-01  4259  1992    10
# 1992-11-01  4240  1992    11

rsample::assessment(fold_look10$splits[[1]])
# date       sales  year index
# <date>     <dbl> <dbl> <int>
# 1992-12-01  4936  1992    12

rsample::analysis(fold_look10$splits[[2]])
# date       sales  year index
# <date>     <dbl> <dbl> <int>
# 1992-02-01  3458  1992     2
# 1992-03-01  4002  1992     3
# 1992-04-01  4564  1992     4
# 1992-05-01  4221  1992     5
# 1992-06-01  4529  1992     6
# 1992-07-01  4466  1992     7
# 1992-08-01  4137  1992     8
# 1992-09-01  4126  1992     9
# 1992-10-01  4259  1992    10
# 1992-11-01  4240  1992    11
# 1992-12-01  4936  1992    12

rsample::assessment(fold_look10$splits[[2]])
# date       sales  year index
# <date>     <dbl> <dbl> <int>
# 1993-01-01  3031  1993    13


assess_start = 5とすると、分析データはデフォルトと同様に1件、検証データは分析データの5件あとのindexが6であるレコード1件。
assess_stopはassess_start以上の数値指定しないとエラーになる。

fold_start5 <- rsample::sliding_window(data = train_ts_data2,
                                       assess_start = 5, 
                                       assess_stop = 5)

rsample::analysis(fold_start5$splits[[1]])
# date       sales  year index
# <date>     <dbl> <dbl> <int>
# 1992-01-01  3459  1992     1

rsample::assessment(fold_start5$splits[[1]])
# date       sales  year index
# <date>     <dbl> <dbl> <int>
# 1992-06-01  4529  1992     6


assess_start = 5, assess_stop = 10とすると、検証データは分析データから数えて5件目~10件目の6件を取る。

fold_start5_stop10 <- rsample::sliding_window(data = train_ts_data2,
                                       assess_start = 5, 
                                       assess_stop = 10)

rsample::analysis(fold_start5_stop10$splits[[1]])
# date       sales  year index
# <date>     <dbl> <dbl> <int>
# 1992-01-01  3459  1992     1

rsample::assessment(fold_start5_stop10$splits[[1]])
# date       sales  year index
# <date>     <dbl> <dbl> <int>
# 1992-06-01  4529  1992     6
# 1992-07-01  4466  1992     7
# 1992-08-01  4137  1992     8
# 1992-09-01  4126  1992     9
# 1992-10-01  4259  1992    10
# 1992-11-01  4240  1992    11


lookback = 10, assess_start = 5, assess_stop = 10とすると、分析データは11件、検証データはそこから5件目~10件目を取る。

fold_look10_start5 <- rsample::sliding_window(data = train_ts_data2,
                                              lookback = 10,
                                              assess_start = 5, 
                                              assess_stop = 5)
rsample::analysis(fold_look10_start5$splits[[1]])
# date       sales  year index
# <date>     <dbl> <dbl> <int>
# 1992-01-01  3459  1992     1
# 1992-02-01  3458  1992     2
# 1992-03-01  4002  1992     3
# 1992-04-01  4564  1992     4
# 1992-05-01  4221  1992     5
# 1992-06-01  4529  1992     6
# 1992-07-01  4466  1992     7
# 1992-08-01  4137  1992     8
# 1992-09-01  4126  1992     9
# 1992-10-01  4259  1992    10
# 1992-11-01  4240  1992    11

rsample::assessment(fold_look10_start5$splits[[1]])
# date       sales  year index
# <date>     <dbl> <dbl> <int>
# 1993-04-01  4377  1993    16