TCAVの日本一わかりやすい解説
今回は, TCAV (Testing with Concept Activation Vectors)について詳しく&わかりやすく解説していきます.
タイトルを世界一にしようかと思いましたが, 日本語の記事なので日本一としました (日本一というのもあくまで主観的な見解です).
- TCAVとは
- 既存研究に比べて優れているところ
- Conceptの重要度の求め方
- 最終的なTCAV Scoreの算出 (平均値+T検定(意味があるScoreかどうか検定))
- Conceptの相対的な比較 (Relative TCAV)
- 評価実験
- 参考文献
TCAVとは
- 機械学習の説明性 (Explainalbe AI, XAI)に関する手法
- Interpretability Beyond Feature Attribution: Quantitative Testing with Concept Activation Vectors (TCAV) (ICML2018)で提案された手法
- Tensorflowでライブラリ化 (GitHub - tensorflow/tcav)
一言でまとめると
TCAVは, 画像分類に対して, あるConcept (人間が簡単に理解できるような概念)があるクラスを予測する際にどの程度重要であるかを判断する手法である.→ → わかりやすく → →
画像分類するときにConcept (人間が簡単に理解できるような概念)をどれだけ重要視するかがわかるようになった!
Conceptとは?
本文では, "human-friendly concepts", "high-level concepts that humans easily understand"と記述されているため, Conceptとは「人間が直感的に理解できる何らかの概念」と表現するすることができます. より平易に説明すると, その言葉を聞いたときに, ほとんどの人が同じようなイメージを持つことができるものがConceptです.本文で具体例として, 模様 (ストライプ, 縞模様, ドット柄)が挙げられており, 評価実験では, 色や人種などをConceptとして用いています.
既存研究に比べて優れているところ
予測に対する説明性の手法は非常に多く提案されていますが, それらと比較してTCAVが特に優れている点は以下の2点であると考えています.- Conceptベースな説明法の獲得 (⇆ Pixelベース)
- あるクラス全体を同じ指標で評価可能 (Classベース) (⇆ Instanceベース)
これまでの説明方法は以下の図ように, ある予測に対して, どのピクセルが重要かどうか (Pixelベース)な説明方法が主流でした (色が白くなっているところが予測に重要). その一方でTCAVは, Conceptベースな説明法を提案しました.
また, 以下の図のように, Instanceベース (各サンプルについて説明を行う)ではなく, あるクラス全体に説明を行う(Classベース)ことで, より一貫性のある説明ができるようになりました. ここでClassベースな説明とは, 例えばシマウマの一つ一つの画像に対して説明を行うのではなくて, シマウマというクラス全体で1つの説明を行うということです (e.g. シマウマクラスは縞模様が重要である).
Conceptの重要度の求め方
Conceptの重要度の求め方の概要は以下の図のようになります. これからこの図について説明していきます.TCAVは, 訓練済みモデルの予測に対するConceptの重要度を計算します (Post-hoc). つまり訓練済みモデルが重要度を計算するために必要です (あたりまえ). 重要度の計算方法の前に, 入力として必要なものを考えてみます. 入力は以下のようになります (アルファベットは上記の図と対応しております). ここからわかりやすいように, あるクラス=シマウマ, あるConcept=縞模様として具体例も同時に考えていきます.
- (a) あるConceptを持つデータセット + あるConceptを持たないランダムなデータセット (e.g. 縞模様画像のデータセット+縞模様を含まない適当なデータセット)
- (b) あるクラスのサンプル (シマウマの画像データセット)
- (c) 訓練済みモデル ※ (シマウマ, ウマ, ...を予測するネットワーク ※シマウマクラスが含まれていれば多クラスでも良い)
重要度の計算方法は, 大きく分けて3つの工程から構成されます.
- あるクラスに対して, 各サンプルのConceptの重要度 (Sensitivity)を求める. (e.g. シマウマクラスの各データに対して縞模様の重要度を計算)
- あるクラスに対するConceptの重要度を求める (e.g. シマウマクラス全体に対して縞模様の重要度を計算)
- 統計的仮説検定 (Two side t test)を用いて意味のある重要度かどうか検定して, 意味のある場合は, 重要度を出力 (意味のない場合は何も出力しない)
各サンプルベースの重要度を求めた後に, その結果をまとめてあるクラスの重要度を算出するようなイメージです. 3の工程については, 後ほど詳しく説明します.
重要度 (Sensitivity)の算出
ここでは, 重要度を定義します. ここで, 訓練済みモデルについて, (c)のように, 入力→中間層 : , 中間層 →出力 : と表します.まずは自然言語で重要度の定義を行います.
これを数式で定義すると, 以下のようになります (e). ここで, Concept C方向のベクトルをとします.
\begin{eqnarray}
S_{C,k,l} (x) &=& \lim_{\epsilon \rightarrow 0} \frac{h_{l,k}(f_l(x)+\epsilon v_C^l) - h_{l,k}(f_l(x))}{\epsilon} \\ \nonumber
&=& \Delta h_{l,k} (f_l(x)) \cdot v_C^l
\end{eqnarray}
以上より, 工程1 (各サンプルのConceptの重要度算出)が達成されました. 工程2を求めるために以下を行います.
\begin{eqnarray}
&&\mathrm{TCAV_Q}_{C,k,l} = \frac{|\{x \in X_k : S_{C,k,l} (x) > 0 \}|}{|X_k|} \\ \nonumber
&&where X_k~is~all~inputs~with~that~given~label
\end{eqnarray}
つまり, Class kの全ての画像に対してSensitivityを計算して, そのsensitivityが正になったものの割合をTCAV Score とします.
以下の図は潜在空間が2次元の場合で単純化したときの計算方法です.
Concept方向のベクトル (Concept Activation Vector)の求め方
計算方法は示しましたが, Concept C方向のベクトルの求め方がわからないので計算することができませんよね. 次に実際にをどのように求めるかを説明します. ここで, をConcept Activation Vector (CAV)と呼びます.(a)のデータセットを用いて計算します. 求め方自体は非常に簡単で, Concept Cを持つデータセットとConcept Cを持たないランダムなデータセットの線形分離を考えて, その分離平面の法線ベクトルがとなります(d). これでTCAV Scoreを計算することができるようになりました (工程1, 2). ここで, TCAV Scoreは訓練済みモデル・Concept set (Conceptを含むデータセット)・Random set (Conceptを含まないデータセット)が与えられたら計算可能なので本文では, 便宜上, 関数と表現します.
最終的なTCAV Scoreの算出 (平均値+T検定(意味があるScoreかどうか検定))
CAVは, Concept set (Conceptを含むデータセット)とRandom set (Conceptを含まないデータセット)の線形分離によって求まります. つまり, Random setの選び方によって, CAVは大きく異なります (実装コードでは, Random setはImagenetのランダムなClassのデータセットとしている). そこで, CAVを安定させるために, 複数のRandom set ()を使用して, 求められたTCAV Scoreの平均値を最終的な重要度として使用します (評価実験では500セット). \begin{eqnarray}
&&\mathrm{TCAV}_{C,k,l} =\frac{1}{m} \sum_{i=1}^m {TCAV_Q}_{k,l}(f,C,Rand_i)
&&where~ m~is~number~of~Random~set.
\end{eqnarray}
Random setの選択方法によりCAVは大きく変化して, TCAV Scoreも大きく変化します. つまり, Random setの選択方法によって無意味なCAVが学習されて, 無意味なTCAV Scoreを出力される可能性があります. そこで, 意味のないTCAV Scoreを検定によって排除します (工程3). ここで, 無意味なTCAV ScoreをRandom sets同士によって得られるCAVによって算出されるTCAV Scoreと定義します (e.g. ).
得られたTCAV Scoreが無意味である場合, 値を出力しません. 無意味なTCAV Scoreかどうか判別するために, 以下に定義する集合 で両側T検定を行います.
\begin{eqnarray}
Set_{Concept} &=& \{{TCAV_Q}_{k,l}(f,C,Rand_1), {TCAV_Q}_{k,l}(f,C,Rand_2),\\&& {TCAV_Q}_{k,l}(f,C,Rand_3), \cdots \} \\
Set_{Rand} &=& \{ {TCAV_Q}_{k,l}(f,Rand_1,Rand_2), {TCAV_Q}_{k,l}(f,Rand_1,Rand_3),\\&& {TCAV_Q}_{k,l}(f,Rand_2,Rand_3), \cdots \}
\end{eqnarray}
ここで, 暗黙の仮定として, Random set同士のCAVおよびTCAV Scoreは意味のないスコアであるとしています. そして, 帰無仮説は「Random set同士のTCAV ScoreとConcept setとRandom setのTCAV Scoreの平均値に有意な差がない」となります. 最終的なTCAV Scorは平均値を使うので, 母集団が正規分布に従っているという仮定を満たしているかということは, 考慮しないとすると, T検定を行う妥当性はあると考えらます.
よって, 帰無仮説が棄却される (Random set同士のTCAV ScoreとConcept setとRandom setのTCAV Scoreの平均値に有意な差がないと言えない)場合, Concept setとRandom setのTCAV Scoreの平均値は (Randomと比べて)意味のある値であるとして, TCAV Scoreの平均値を出力します. その一方で, 帰無仮説が棄却されない場合は, Concept setとRandom setのTCAV Scoreの平均値は (Randomと比べて)意味のない値であるとして, 何も出力しません.
つまり, 検定によって, できるだけ無意味であると思われるようなTCAV Scoreを排除しているということになります.
Conceptの相対的な比較 (Relative TCAV)
ここで, Concept同士に何らかの関連がある (茶髪 vs 黒髪)場合は, Conceptの相対的な比較が可能なRelative TCAVを使用します. 手法自体は, 至ってシンプルで, 先ほどTCAVで用いていたRandom setをConcept setに置き換えるだけです. つまり, Concept set同士の比較になります.具体例を考えます. 比較したいConceptを{茶髪, 黒髪, 白髪, 金髪}として, 茶髪のConceptの相対的な重要度が知りたいとします. その場合は, 以下の集合 を用意して, T検定を行い棄却された場合に, の平均値を出力します.
\begin{eqnarray}
Set_{own} &=& \{{TCAV_Q}_{k,l}(f,茶髪,黒髪), {TCAV_Q}_{k,l}(f,茶髪, 白髪), \\&& {TCAV_Q}_{k,l}(f,茶髪, 金髪)\} \\
Set_{another} &=& \{ {TCAV_Q}_{k,l}(f,黒髪, 白髪), {TCAV_Q}_{k,l}(f,黒髪, 金髪),\\&& {TCAV_Q}_{k,l}(f,白髪, 黒髪), \cdots\}
\end{eqnarray}
※ TCAVに比べて極端にSetのサンプル数が少なくなることを防ぐために, Conceptの各データセットの一部を除去するなどをしてデータセットを水増しします (Githubのissueで言及).
評価実験
以下はImagenetで事前学習済みのGoogleNetとInceptionV3に対する各層でのTCAV Score (縦軸)です. アスタリスクがついているところは, T検定によって棄却されたTCAV Scoreなので, 値を0としています.左図では, ピンポンボールという予測に対する人種のConceptのスコアです. 人種によって大きく偏りがあることがわかります (TCAV Scoreの計算方法上, スコアが0.5上回ったときが重要なConceptであると判断して良いでしょう). また, 右図では, シマウマの予測に対するジグザグ模様, 縞模様, ドット模様の重要度です. 縞模様のスコアが他の模様に比べて大きいことがわかり直感的に正しそうな結果が得られました.
他にも興味深い実験を行っているので気になる方は論文の方を読んでみてください.
解説は以上となります!長くなってしまいましたがここまで読んで頂きありがとうございました!