Skip to content

toona note

sklearn の StratifiedKFold のサンプル数に対する挙動の変化

はじめに

scikit-learn の StratifiedKFold の引数 n_splits と、データのサンプル数による挙動と、分割不能になる数を確認します。

Stratified KFold とは

データセットのクラスを保持しながらデータを分割する方法です。
たとえば、データセット内の、クラス 1 とクラス 2 のデータの比率が 9:1 であった場合、テストデータにおいてもクラスの比率が 9:1 であると、実運用段階に近く、よい正確な精度の見積もりができる可能性が高いと考えられます。
このような比率を保持した分割を行うのが StratifiedKFold です。
[scikit-learn](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.StratifiedKFold. html#sklearn.model_selection.StratifiedKFold)

疑問点

scikit-learn の StratifiedKFold は可能な限りデータのクラスの比率を保持しますが、そのままでは比率の保持ができないデータを渡した場合はどのような挙動をするのでしょうか?

例えばクラス 1 に属するデータは 5 個しかないにも関わらず、学習・試験データを 6 セット求められた場合です。
以降、順に挙動を確認します。

1. n_splits <= 最小クラスのデータ数

確認のための、目的変数(クラス)は、クラス 3 のデータが 3 個、クラス 4 のデータが 4 個、クラス 5 のデータが 5 個とします。

ここから、3 つの学習・試験データの分割を得ることを考えましょう。この場合データは十分にあるので、素直な挙動になります。
各データが 1 度ずつ登場し、3 つの分割方法が得られます。

2. 最小クラスのデータ数 < n_splits <= 最大クラスのデータ数

先の場合と同じデータを用いて、5 つの学習・試験データの分割を得ることを考えます。

最小クラスは 3 データしかないので、データの分割はできなさそうですが、分割できてしまいます。
内容を見ると、試験データにおいては、データ数が少ないクラスのデータが含まれないことがあるようです。
また、この際、「UserWarning: The least populated class in y has only 3 members, which is less than n_splits=5.」 と警告が出ます。
警告は出ますが、実行はできるので注意が必要です。

3. n_splits > 最大クラスのデータ数

先の場合と同じデータを用いて、6 セット得ようとしてみます。

この場合は、最大クラスのデータ数以上にデータを分割することを求められているので当然動作しません。

"ValueError: n_splits=6 cannot be greater than the number of members in each class."

とエラーが出ます。

私は今まで、この each class というのは、「任意のクラスのデータよりも~」という意味だと勘違いしていました。
ただ、n_splits の数は最初クラスのデータ数まで、と説明している記事もあったので、scikit-learn のバージョンにより挙動が異なるのかもしれません。
一応 scikit-learn の最新版のコードを読んで、最大クラスのデータ数で制限をかけていることを確認しました。

まとめ:n_splits の数に対する挙動

  • 最もサンプル数の少ないクラスのサンプル数以下である場合
    • そのまま動作する
  • 最もサンプル数の少ないクラスのサンプル数より多く、最もサンプル数の多いクラスのサンプル数以下である場合
    • UserWarning を出しながら動作する
  • 最もサンプル数の多いクラスのサンプル数より多い場合
    • Error

動作環境

  • python: 3.8.11
  • scikit-learn: 0.24.2