TCAVの日本一わかりやすい解説

今回は, TCAV (Testing with Concept Activation Vectors)について詳しく&わかりやすく解説していきます.

タイトルを世界一にしようかと思いましたが, 日本語の記事なので日本一としました (日本一というのもあくまで主観的な見解です).

TCAVとは

一言でまとめると

TCAVは, 画像分類に対して, あるConcept (人間が簡単に理解できるような概念)があるクラスを予測する際にどの程度重要であるかを判断する手法である.
→ → わかりやすく → →
画像分類するときにConcept (人間が簡単に理解できるような概念)をどれだけ重要視するかがわかるようになった!

Conceptとは?

本文では, "human-friendly concepts", "high-level concepts that humans easily understand"と記述されているため, Conceptとは「人間が直感的に理解できる何らかの概念」と表現するすることができます. より平易に説明すると, その言葉を聞いたときに, ほとんどの人が同じようなイメージを持つことができるものがConceptです.

本文で具体例として, 模様 (ストライプ, 縞模様, ドット柄)が挙げられており, 評価実験では, 色や人種などをConceptとして用いています.

既存研究に比べて優れているところ

予測に対する説明性の手法は非常に多く提案されていますが, それらと比較してTCAVが特に優れている点は以下の2点であると考えています.

  1. Conceptベースな説明法の獲得 (⇆ Pixelベース)
  2. あるクラス全体を同じ指標で評価可能 (Classベース) (⇆ Instanceベース)

これまでの説明方法は以下の図ように, ある予測に対して, どのピクセルが重要かどうか (Pixelベース)な説明方法が主流でした (色が白くなっているところが予測に重要). その一方でTCAVは, Conceptベースな説明法を提案しました.

また, 以下の図のように, Instanceベース (各サンプルについて説明を行う)ではなく, あるクラス全体に説明を行う(Classベース)ことで, より一貫性のある説明ができるようになりました. ここでClassベースな説明とは, 例えばシマウマの一つ一つの画像に対して説明を行うのではなくて, シマウマというクラス全体で1つの説明を行うということです (e.g. シマウマクラスは縞模様が重要である).

f:id:munemakun:20200523153727p:plain
Original image and Saliency Maps [1]

Conceptの重要度の求め方

Conceptの重要度の求め方の概要は以下の図のようになります. これからこの図について説明していきます.
f:id:munemakun:20200523155309p:plain
Testing with Concept Activation Vectors

TCAVは, 訓練済みモデルの予測に対するConceptの重要度を計算します (Post-hoc). つまり訓練済みモデルが重要度を計算するために必要です (あたりまえ). 重要度の計算方法の前に, 入力として必要なものを考えてみます. 入力は以下のようになります (アルファベットは上記の図と対応しております). ここからわかりやすいように, あるクラス=シマウマ, あるConcept=縞模様として具体例も同時に考えていきます.

  • (a) あるConceptを持つデータセット + あるConceptを持たないランダムなデータセット (e.g. 縞模様画像のデータセット+縞模様を含まない適当なデータセット)
  • (b) あるクラスのサンプル (シマウマの画像データセット)
  • (c) 訓練済みモデル \displaystyle f ※ (シマウマ, ウマ, ...を予測するネットワーク ※シマウマクラスが含まれていれば多クラスでも良い)

重要度の計算方法は, 大きく分けて3つの工程から構成されます.

  1. あるクラスに対して, 各サンプルのConceptの重要度 (Sensitivity)を求める. (e.g. シマウマクラスの各データに対して縞模様の重要度を計算)
  2. あるクラスに対するConceptの重要度を求める (e.g. シマウマクラス全体に対して縞模様の重要度を計算)
  3. 統計的仮説検定 (Two side t test)を用いて意味のある重要度かどうか検定して, 意味のある場合は, 重要度を出力 (意味のない場合は何も出力しない)

各サンプルベースの重要度を求めた後に, その結果をまとめてあるクラスの重要度を算出するようなイメージです. 3の工程については, 後ほど詳しく説明します.

重要度 (Sensitivity)の算出

ここでは, 重要度を定義します. ここで, 訓練済みモデルについて, (c)のように, 入力→中間層 \displaystyle l : \displaystyle f_l, 中間層 \displaystyle l→出力 : \displaystyle f_{l,k}と表します.

まずは自然言語で重要度の定義を行います.

Conceptの重要度 (Sensitivity)Class kのサンプル xにおけるConcept Cの重要度 \displaystyle S_{C,k,l} (x)とは, 潜在空間 \displaystyle f_l内で, Class kの入力画像の特徴量 \displaystyle f_l(x)を微小にあるConcept方向に変化させたときに, Class kの予測確率の変化率である.
これをより平易な言葉で説明すると, 例えば, シマウマの画像に対して縞模様具合 (シマシマ度?)を高めたときに, どれだけ予測器がシマウマと分類しやすくなったかがシマウマの縞模様の重要度になります.

これを数式で定義すると, 以下のようになります (e). ここで, Concept C方向のベクトルを\displaystyle v_C^lとします.
\begin{eqnarray}
S_{C,k,l} (x) &=& \lim_{\epsilon \rightarrow 0} \frac{h_{l,k}(f_l(x)+\epsilon v_C^l) - h_{l,k}(f_l(x))}{\epsilon} \\ \nonumber
&=& \Delta h_{l,k} (f_l(x)) \cdot v_C^l
\end{eqnarray}

以上より, 工程1 (各サンプルのConceptの重要度算出)が達成されました. 工程2を求めるために以下を行います.
\begin{eqnarray}
&&\mathrm{TCAV_Q}_{C,k,l} = \frac{|\{x \in X_k : S_{C,k,l} (x) > 0 \}|}{|X_k|} \\ \nonumber
&&where X_k~is~all~inputs~with~that~given~label
\end{eqnarray}

つまり, Class kの全ての画像に対してSensitivityを計算して, そのsensitivityが正になったものの割合をTCAV Score \displaystyle \mathrm{TCAV_Q}_{C,k,l}とします.

以下の図は潜在空間が2次元の場合で単純化したときの計算方法です.

f:id:munemakun:20200523215719p:plain


Concept方向のベクトル (Concept Activation Vector)の求め方

計算方法は示しましたが, Concept C方向のベクトル\displaystyle v_C^lの求め方がわからないので計算することができませんよね. 次に実際に\displaystyle v_C^lをどのように求めるかを説明します. ここで, \displaystyle v_C^lをConcept Activation Vector (CAV)と呼びます.
(a)のデータセットを用いて計算します. 求め方自体は非常に簡単で, Concept Cを持つデータセットとConcept Cを持たないランダムなデータセットの線形分離を考えて, その分離平面の法線ベクトルが\displaystyle v_C^lとなります(d). これでTCAV Scoreを計算することができるようになりました (工程1, 2). ここで, TCAV Scoreは訓練済みモデル・Concept set (Conceptを含むデータセット)・Random set (Conceptを含まないデータセット)が与えられたら計算可能なので本文では, 便宜上, 関数\displaystyle {TCAV_Q}_{k,l}(f,C_1,C_2)と表現します.

最終的なTCAV Scoreの算出 (平均値+T検定(意味があるScoreかどうか検定))

CAVは, Concept set (Conceptを含むデータセット)とRandom set (Conceptを含まないデータセット)の線形分離によって求まります. つまり, Random setの選び方によって, CAVは大きく異なります (実装コードでは, Random setはImagenetのランダムなClassのデータセットとしている). そこで, CAVを安定させるために, 複数のRandom set (\displaystyle Rand_1,Rand_2,\cdots,Rand_m)を使用して, 求められたTCAV Scoreの平均値を最終的な重要度として使用します (評価実験では500セット).

\begin{eqnarray}
&&\mathrm{TCAV}_{C,k,l} =\frac{1}{m} \sum_{i=1}^m {TCAV_Q}_{k,l}(f,C,Rand_i)
&&where~ m~is~number~of~Random~set.
\end{eqnarray}

f:id:munemakun:20200523220945p:plain

Random setの選択方法によりCAVは大きく変化して, TCAV Scoreも大きく変化します. つまり, Random setの選択方法によって無意味なCAVが学習されて, 無意味なTCAV Scoreを出力される可能性があります. そこで, 意味のないTCAV Scoreを検定によって排除します (工程3). ここで, 無意味なTCAV ScoreをRandom sets同士によって得られるCAVによって算出されるTCAV Scoreと定義します (e.g. \displaystyle {TCAV_Q}_{k,l}(f,Rand_1,Rand_2)).

得られたTCAV Scoreが無意味である場合, 値を出力しません. 無意味なTCAV Scoreかどうか判別するために, 以下に定義する集合 \displaystyle Set_{Concept},Set_{Rand}で両側T検定を行います.

\begin{eqnarray}
Set_{Concept} &=& \{{TCAV_Q}_{k,l}(f,C,Rand_1), {TCAV_Q}_{k,l}(f,C,Rand_2),\\&& {TCAV_Q}_{k,l}(f,C,Rand_3), \cdots \} \\
Set_{Rand} &=& \{ {TCAV_Q}_{k,l}(f,Rand_1,Rand_2), {TCAV_Q}_{k,l}(f,Rand_1,Rand_3),\\&& {TCAV_Q}_{k,l}(f,Rand_2,Rand_3), \cdots \}
\end{eqnarray}
ここで, 暗黙の仮定として, Random set同士のCAVおよびTCAV Scoreは意味のないスコアであるとしています. そして, 帰無仮説は「Random set同士のTCAV ScoreとConcept setとRandom setのTCAV Scoreの平均値に有意な差がない」となります. 最終的なTCAV Scorは平均値を使うので, 母集団が正規分布に従っているという仮定を満たしているかということは, 考慮しないとすると, T検定を行う妥当性はあると考えらます.

よって, 帰無仮説が棄却される (Random set同士のTCAV ScoreとConcept setとRandom setのTCAV Scoreの平均値に有意な差がないと言えない)場合, Concept setとRandom setのTCAV Scoreの平均値は (Randomと比べて)意味のある値であるとして, TCAV Scoreの平均値を出力します. その一方で, 帰無仮説が棄却されない場合は, Concept setとRandom setのTCAV Scoreの平均値は (Randomと比べて)意味のない値であるとして, 何も出力しません.

つまり, 検定によって, できるだけ無意味であると思われるようなTCAV Scoreを排除しているということになります.

Conceptの相対的な比較 (Relative TCAV)

ここで, Concept同士に何らかの関連がある (茶髪 vs 黒髪)場合は, Conceptの相対的な比較が可能なRelative TCAVを使用します. 手法自体は, 至ってシンプルで, 先ほどTCAVで用いていたRandom setをConcept setに置き換えるだけです. つまり, Concept set同士の比較になります.

具体例を考えます. 比較したいConceptを{茶髪, 黒髪, 白髪, 金髪}として, 茶髪のConceptの相対的な重要度が知りたいとします. その場合は, 以下の集合 \displaystyle Set_{own},Set_{another}を用意して, T検定を行い棄却された場合に, \displaystyle Set_{own}の平均値を出力します.

\begin{eqnarray}
Set_{own} &=& \{{TCAV_Q}_{k,l}(f,茶髪,黒髪), {TCAV_Q}_{k,l}(f,茶髪, 白髪), \\&& {TCAV_Q}_{k,l}(f,茶髪, 金髪)\} \\
Set_{another} &=& \{ {TCAV_Q}_{k,l}(f,黒髪, 白髪), {TCAV_Q}_{k,l}(f,黒髪, 金髪),\\&& {TCAV_Q}_{k,l}(f,白髪, 黒髪), \cdots\}
\end{eqnarray}

※ TCAVに比べて極端にSetのサンプル数が少なくなることを防ぐために, Conceptの各データセットの一部を除去するなどをしてデータセットを水増しします (Githubのissueで言及).

評価実験

以下はImagenetで事前学習済みのGoogleNetとInceptionV3に対する各層でのTCAV Score (縦軸)です. アスタリスクがついているところは, T検定によって棄却されたTCAV Scoreなので, 値を0としています.
左図では, ピンポンボールという予測に対する人種のConceptのスコアです. 人種によって大きく偏りがあることがわかります (TCAV Scoreの計算方法上, スコアが0.5上回ったときが重要なConceptであると判断して良いでしょう). また, 右図では, シマウマの予測に対するジグザグ模様, 縞模様, ドット模様の重要度です. 縞模様のスコアが他の模様に比べて大きいことがわかり直感的に正しそうな結果が得られました.

f:id:munemakun:20200523225532p:plain

他にも興味深い実験を行っているので気になる方は論文の方を読んでみてください.

解説は以上となります!長くなってしまいましたがここまで読んで頂きありがとうございました!

参考文献

[1] Karen Simonyan, Andrea Vedaldi and Andrew Zisserman. Deep Inside Convolutional Networks: Visualising Image Classification Models and Saliency Maps. arXiv:1312.6034, 2013.