在多頭注意力機制中,通常輸入的數(shù)據(jù)包括查詢(Q)、鍵(K)和值(V)。這些數(shù)據(jù)的維度以及權(quán)重矩陣的維度在多頭注意力機制中扮演關鍵角色。下面對數(shù)據(jù)及權(quán)重的維度進行解釋:
-
輸入數(shù)據(jù)(Queries, Keys, Values):文章來源:http://www.zghlxwxcb.cn/news/detail-805928.html
-
Queries (Q): 表示待查詢的信息,通常對應輸入序列的每個位置。其維度通常為 (batch_size, seq_length, q_dim),其中
q_dim
是查詢向量的維度。 -
Keys (K): 表示用于計算注意力分數(shù)的信息,也通常對應輸入序列的每個位置。其維度通常為 (batch_size, seq_length, key_dim),其中
key_dim
是鍵向量的維度。 -
Values (V): 表示待加權(quán)求和的信息,同樣對應輸入序列的每個位置。其維度通常為 (batch_size, seq_length, value_dim),其中
value_dim
是值向量的維度。
-
Queries (Q): 表示待查詢的信息,通常對應輸入序列的每個位置。其維度通常為 (batch_size, seq_length, q_dim),其中
-
權(quán)重矩陣:文章來源地址http://www.zghlxwxcb.cn/news/detail-805928.html
-
查詢權(quán)重矩陣 (Q_weights): 用于對查詢(Q)進行線性變換,將其映射到多個注意力頭的維度。其維度通常為 (q_dim, num_heads,?head_dim),其中
num_heads
是注意力頭的數(shù)量,head_dim
是每個注意力頭的維度。 - 鍵權(quán)重矩陣 (K_weights): 用于對鍵(K)進行線性變換,同樣映射到多個注意力頭的維度。其維度通常為 (key_dim, num_heads,?head_dim)。
- 值權(quán)重矩陣 (V_weights): 用于對值(V)進行線性變換,映射到多個注意力頭的維度。其維度通常為 (value_dim, num_heads,?head_dim)。
-
查詢權(quán)重矩陣 (Q_weights): 用于對查詢(Q)進行線性變換,將其映射到多個注意力頭的維度。其維度通常為 (q_dim, num_heads,?head_dim),其中
def glorot_uniform():
? return hk.initializers.VarianceScaling(scale=1.0,
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?mode='fan_avg',
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?distribution='uniform')
def stable_softmax(logits: jax.Array) -> jax.Array:
? """Numerically stable softmax for (potential) bfloat 16."""
? if logits.dtype == jnp.float32:
? ? output = jax.nn.softmax(logits)
? elif logits.dtype == jnp.bfloat16:
? ? # Need to explicitly do softmax in float32 to avoid numerical issues
? ? # with large negatives. Large negatives can occur if trying to mask
? ? # by adding on large negative logits so that things softmax to zero.
? ? output = jax.nn.softmax(logits.astype(jnp.float32)).astype(jnp.bfloat16)
? else:
? ? raise ValueError(f'Unexpected input dtype {logits.dtype}')
? return output
class Attention(hk.Module):
? """Multihead attention."""
? def __init__(self, config, global_config, output_dim, name='attention'):
? ? super().__init__(name=name)
? ? self.config = config
? ? self.global_config = global_config
? ? self.output_dim = output_dim
? def __call__(self, q_data, m_data, mask, nonbatched_bias=None):
? ? """Builds Attention module.
? ? Arguments:
? ? ? q_data: A tensor of queries, shape [batch_size, N_queries, q_channels].
? ? ? m_data: A tensor of memories from which the keys and values are
? ? ? ? projected, shape [batch_size, N_keys, m_channels].
? ? ? mask: A mask for the attention, shape [batch_size, N_queries, N_keys].
? ? ? nonbatched_bias: Shared bias, shape [N_queries, N_keys].
? ? Returns:
? ? ? A float32 tensor of shape [batch_size, N_queries, output_dim].
? ? """
? ? # Sensible default for when the config keys are missing
? ? key_dim = self.config.get('key_dim', int(q_data.shape[-1]))
? ? value_dim = self.config.get('value_dim', int(m_data.shape[-1]))
? ? num_head = self.config.num_head
? ? assert key_dim % num_head == 0
? ? assert value_dim % num_head == 0
? ? key_dim = key_dim // num_head
? ? value_dim = value_dim // num_head
? ? # weights維度(數(shù)據(jù)最后一維的維度數(shù),注意力頭數(shù)量,每個注意力頭映射的數(shù)據(jù)維度)
? ? q_weights = hk.get_parameter(
? ? ? ? 'query_w', shape=(q_data.shape[-1], num_head, key_dim),
? ? ? ? dtype=q_data.dtype,
? ? ? ? init=glorot_uniform())
? ? k_weights = hk.get_parameter(
? ? ? ? 'key_w', shape=(m_data.shape[-1], num_head, key_dim),
? ? ? ? dtype=q_data.dtype,
? ? ? ? init=glorot_uniform())
? ? v_weights = hk.get_parameter(
? ? ? ? 'value_w', shape=(m_data.shape[-1], num_head, value_dim),
? ? ? ? dtype=q_data.dtype,
? ? ? ? init=glorot_uniform())
? ? # bqa: 輸入張量 q_data 的軸的標記。(batch_size, seq_length, q_dim)
? ? # 'b' :batch 維度,'q':查詢序列維度,'a' 查詢向量的維度。所以,'bqa' 表示 q_data 的三個軸。
? ? # ahc:查詢權(quán)重矩陣的形狀, a:查詢向量的維度,h:注意力頭的數(shù)量,c: 每個注意力頭中查詢的維度。
? ? # key_dim**(-0.5) 注意力縮放,避免注意力分數(shù)過大或過小
? ??
? ? # jnp.einsum:Einstein Summation Notation(愛因斯坦求和約定)。
? ? # 一種緊湊、靈活的方式來指定和計算張量的乘積、求和和轉(zhuǎn)置等操作。
? ? q = jnp.einsum('bqa,ahc->bqhc', q_data, q_weights) * key_dim**(-0.5)
? ? k = jnp.einsum('bka,ahc->bkhc', m_data, k_weights)
? ? v = jnp.einsum('bka,ahc->bkhc', m_data, v_weights)
? ??
? ? # 注意力分數(shù),計算每個查詢(q)和鍵(k)之間的點積,以獲得注意力分數(shù)。
? ? # 結(jié)果維度為bhqk (batch_size, num_heads, num_q, num_k),?
? ? # num_q/num_k為查詢/鍵的數(shù)量,一般為 seq_length。
? ? logits = jnp.einsum('bqhc,bkhc->bhqk', q, k)
? ? if nonbatched_bias is not None:
? ? ? logits += jnp.expand_dims(nonbatched_bias, axis=0)
? ??
? ? # 注意力分數(shù)中加入mask
? ? logits = jnp.where(mask, logits, _SOFTMAX_MASK)
? ??
? ? # 對注意力分數(shù)進行softmax操作,我們得到每個位置對輸入序列的權(quán)重分配。
? ? weights = stable_softmax(logits)
? ??
? ? # 注意力分數(shù)對值進行加權(quán)求和,得到多頭注意力機制的輸出
? ? # 兩個向量的點積可以用于度量它們之間的相似性。如果兩個向量越相似,它們的點積就越大
? ? weighted_avg = jnp.einsum('bhqk,bkhc->bqhc', weights, v)
? ? if self.global_config.zero_init:
? ? ? init = hk.initializers.Constant(0.0)
? ? else:
? ? ? init = glorot_uniform()
? ??
? ? # 帶有bias的門控注意力
? ? if self.config.gating:
? ? ? gating_weights = hk.get_parameter(
? ? ? ? ? 'gating_w',
? ? ? ? ? shape=(q_data.shape[-1], num_head, value_dim),
? ? ? ? ? dtype=q_data.dtype,
? ? ? ? ? init=hk.initializers.Constant(0.0))
? ? ? gating_bias = hk.get_parameter(
? ? ? ? ? 'gating_b',
? ? ? ? ? shape=(num_head, value_dim),
? ? ? ? ? dtype=q_data.dtype,
? ? ? ? ? init=hk.initializers.Constant(1.0))
? ? ? gate_values = jnp.einsum('bqc, chv->bqhv', q_data,
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?gating_weights) + gating_bias
? ? ? gate_values = jax.nn.sigmoid(gate_values)
? ? ? # ⊙ 對應元素相乘
? ? ? weighted_avg *= gate_values
? ? o_weights = hk.get_parameter(
? ? ? ? 'output_w', shape=(num_head, value_dim, self.output_dim),
? ? ? ? dtype=q_data.dtype,
? ? ? ? init=init)
? ? o_bias = hk.get_parameter(
? ? ? ? 'output_b', shape=(self.output_dim,),
? ? ? ? dtype=q_data.dtype,
? ? ? ? init=hk.initializers.Constant(0.0))
? ? # 線性變換到輸出維度大小
? ? output = jnp.einsum('bqhc,hco->bqo', weighted_avg, o_weights) + o_bias
? ? return output
到了這里,關于haiku實現(xiàn)門控多頭注意力模塊的文章就介紹完了。如果您還想了解更多內(nèi)容,請在右上角搜索TOY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關文章,希望大家以后多多支持TOY模板網(wǎng)!