wav2vecで用いられるロス関数に関して

Facebook(現メタ) AIが公開した新しい音声フレームワークwav2vecに関して、ロス関数にフォーカスをあてた説明を行う。

wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations

2006.11477.pdf (arxiv.org)

音声からの文字起こしは様々なアプリケーションで必要とされているが、音声とそれから書き起こされた文字がペアになったデータセットを作成するのは容易ではない。そこで著者らはcontrastive learning(対照学習)に基づく音声データのための教師無し学習手法wav2vecを提案。

wav2vec 2.0を日本語で推論できるようにする - Fusic Tech Blog
wav2vecの概略図

wav2vecは、時間ごとに区切った音声データからCNNで特徴を抽出し、直後の特徴qを一旦保持しておき、さらにqをtransformerに入力することで得られた特徴を獲得し、transformer後の特徴とCNN後の特徴qのcontrastive lossを算出し、これを最小化することによって、音声データ表現を獲得するための汎用モデルを学習するアーキテクチャとなっている。

また、Transformerが使われていることからわかるように、wav2vecはNLPライクのアーキテクチャであり、自然言語でいうところの単語を区切られた音声データに置き換えたモデルだとも考えることができる。

wav2vecで用いられるロス関数に関して

本章では、論文から抜粋した記述を日本語訳してさらに補足説明を加える形でロス関数を理解していこうと思う。では以下より論文の抜粋を行っていく。

3.2. Objective

During pre-training, we learn representations of speech audio by solving a contrastive task Lm which requires to identify the true quantized latent speech representation for a masked time step within a set of distractors. This is augmented by a codebook diversity loss Ld to encourage the model to use the codebook entries equally often.

3.2. wav2vecの目的

事前学習(ラベル無し学習)では、マスクされた時間ステップ(トークン)の真の量子化された潜在的な音声表現を、distractorの集合の中から識別することを要求するcontrastive task: Lmを解くことによって、音声の表現を学習する。これはdiversity loss: Ldによって補強され、モデルがコードブックエントリーを等しく使用するように促す。

$$\mathcal{L} = \mathcal{L}_m + \alpha \mathcal{L}_d\;\;\;\;(2)$$

αはハイパーパラメータである。

  • 式(2)より、Lはcontrastive loss: Lmとdiversity loss: Ldで構成されていることわかる
  • distractorは (多項式選択問題の)正解以外の選択肢のことで、ここでは対照学習におけるネガティブサンプルのことをいっている。

Contrastive Loss

Given context network output ct centered over masked time step t, the model needs to identify the true quantized latent speech representation qt in a set of K + 1 quantized candidate representations ˜q ∈ Qt which includes qt and K distractors [23, 54]. Distractors are uniformly sampled from other masked time steps of the same utterance. The loss is defined as

where we compute the cosine similarity \(sim(\mathbf{a},\mathbf{b}) = \mathbf{a}^{T}\mathbf{b}/\|a\|\|b\|\) between context representations and quantized latent speech representations [19, 6].

Contrastive Loss (対照学習のためのロス)

マスクされた時間ステップtを中心とするcontext network output: ctが与えられると、モデルはqtとK個のディストラクターを含むK + 1個の量子化された候補表現〜q∈Qtの集合から真の量子化された潜在音声表現qtを特定する必要がある [23, 54].distractorは同じ発話のマスクされた他の時間ステップから一様にサンプリングされる。損失は以下のように定義される。

$$\mathcal{L}_m = -\mathrm{log} \frac{\mathrm{exp}(sim(\mathbf{c}_t, \mathbf{q}_t)/\kappa)}{\sum_{\mathbf{\tilde{q}}\in\mathbf{Q}_t}{\mathrm{exp}(sim(\mathbf{c}_t, \mathbf{\tilde{q}})/\kappa})} \;\;\;\;(3)$$

ここで、文脈表現と量子化された潜在音声表現の間の余弦類似度\(sim(\mathbf{a},\mathbf{b}) = \mathbf{a}^{T}\mathbf{b}/\|a\|\|b\|\)を計算する [19, 6].

  • context network outputはtransformerの出力のことである
  • この記述から対照学習におけるポジティブとネガティブペアの作り方がわかる。
  • wav2vecでは、transformerの特徴(出力)1つに対して、CNNから複数の特徴(出力)をサンプルするようだ。
  • ポジティブペア(正解ペア)は、同じ時間ステップtにおいて、CNNとtransformerのそれぞれから抽出された特徴のペアのことである。
  • ネガティブペア(不正解ペア)は、時間ステップtにおいてtransformerを用いて抽出された特徴と、それ以外の時間ステップからCNNを用いて得られた特徴とのペアである。
  • 式(3)のシグマの下の〜q∈Qtは量子化されたK + 1個の候補表現であり、K個のネガティブサンプルと、1個のポジティブサンプルで構成されている

Diversity Loss.

The contrastive task depends on the codebook to represent both positive and negative examples and the diversity loss Ld is designed to increase the use of the quantized codebook representations [10]. We encourage the equal use of the V entries in each of the G codebooks by maximizing the entropy of the averaged softmax distribution l over the codebook entries for each 3 codebook p¯g across a batch of utterances; the softmax disribution does not contain the gumbel noise nor a temperature:

Diversity Loss.

対照タスクはコードブックが正例と負例の両方を表現することに依存しており、多様性損失Ldは量子化されたコードブック表現の利用を増加させるように設計されている[10]。我々は、発話のバッチにわたる各 3 コードブック p¯g のコードブック・エントリに対する平均化ソフトマックス分布 l のエントロピーを最大化することによって、各 G コードブックの V エントリの均等使用を奨励する;ソフトマックス分布はガンベル雑音も温度も含まない。

$$\mathcal{L}_d = \frac{1}{GC}\sum^{G}_{g=1}{-H(\bar{p}_g)} = \frac{1}{GV}\sum_{g=1}^{G}\sum_{v=1}^{V}{\bar{p}_{g,v}\mathrm{log}\bar{p}_{g,v}}\;\;\;\;(4)$$

  • Diversity Lossは出力の偏りを抑えることが目的で、出力がなるべく等しくなるように設計されている
  • 役割としては正則化項なのではないか
  • 式(4)では、Gumbel-Softmax分布のエントロピーを最大化し、モデルが常に利用可能なすべてのコードブックエントリの小さなサブグループから選択するのを防いでいる

以上。wav2vecのロス関数に関する説明を終了する。

参考

An Illustrated Tour of Wav2vec 2.0 | Jonathan Bgn

【wav2vec 2.0】Facebook AIが新しい音声認識フレームワークを公開!自己教師あり学習により正解ラベルなしで高精度を達成!? | AI-SCHOLAR | AI:(人工知能)論文・技術情報メディア

Transformer

Posted by vastee