Skip to content

toona note

pandas の apply と transform の違いを知る

はじめに

Pandas の apply, tranform はどちらも group に対する演算を行う物です。
transform は演算対象が series である制約があります。

問題設定

ある 2 つのグループに対して、数値が与えられている。
この 2 つのグループそれぞれに対して、何らかの演算結果(今回は平均)を得たい。

df = pd.DataFrame(
    [[1, 2], [1, 3], [1, 4], [1, 5], [2, 6], [2, 7], [2, 8], [2, 9]],
    columns=["group", "value"],
)
diff_of_apply_transform_001

apply

やり方はいくつかありますが、まずは apply から。

d = df.groupby(by=['group']).mean()['value'].to_dict()
df['mean'] = df.apply(lambda x: d[x['group']], axis=1)
print(df)

# ===============================
   group  value  mean
0      1      2   3.5
1      1      3   3.5
2      1      4   3.5
3      1      5   3.5
4      2      6   7.5
5      2      7   7.5
6      2      8   7.5
7      2      9   7.5

map でもできる。 単純な平均を得たいならばこちらでやる気がする。

d = dict()
d['mean'] = df.groupby(by=['group'])['value'].mean()
df['mean'] = df['group'].map(d['mean'].to_dict())
print(df)

上記 2 つの変換ではどちらも、辞書を使いました。
groupby でまとめた DataFrame に対して演算を行うとグループごとにまとめられた結果が得られ、もとの DataFrame と形が違うためです。

print(df.groupby(by=['group'])['value'].mean())

# ==================================================
group
1    3.5
2    7.5

この方法は、平均以外の複雑な計算にも対応できます。
しかし、もっと賢い方法がありそう…
(より良い方法が見つかったら追記する。)

transform とは

transform を用いた計算では、もとの DataFrame と同一の形状を得ることができます。

df2 = df.groupby(by=['group'])['value'].transform(lambda x: x.mean())
print(df2)

# =====================================================
0    3.5
1    3.5
2    3.5
3    3.5
4    7.5
5    7.5
6    7.5
7    7.5

従って、今回得たい DataFrame は下記のように得られます。

df['mean'] = df.groupby(by=['group'])['value'].transform(lambda x: x.mean())

簡単!! 綺麗!!

transform の注意点

上のプログラムを apply と同じ思想で、書き換えると動きません。

# これは動かない
df["mean"] = df.groupby(by=["group"]).transform(
    lambda x: x["value"].mean(), axis=1
)
# ===========================
TypeError: Transform function invalid for data types

これは、transform の中で演算可能なものが Series であるから。
上の例では、axis=1 として、複数列に対する演算結果を得ようとしています。
この操作は、単一の series に対する計算ではありません。
このように、DataFrame 中の 2 つ以上の列に対する演算結果を得たいならば、 apply を使わなければなりません。
Series に対して演算可能で、単一の数値を返すような sum(), len() などは transofrm で演算可能です。

おわりに

transform という綺麗な方法を知り喜んで使っていたら見事に error に当たったのでまとめました。 groupby を経由する演算は途中経過を print できないので挙動を掴みにくいですが、参考先の web site が非常に役に立ちました。

参考