import torch from torch import nn from .LMConfig import LMConfig import math import torch.nn.functional as F from typing import Optional from transformers import PreTrainedModel from transformers.modeling_outputs import CausalLMOutputWithPast class RMSNorm(nn.Module): def __init__(self, dim: int, eps: float) -> None: super().__init__() self.weight = nn.Parameter(torch.ones(dim)) self.eps = eps def _norm(self, x): return x * torch.rsqrt(self.eps + x.pow(2).mean(-1, keepdim = True)) def forward(self, x): x = self._norm(x.float()).type_as(x) # 用 float 提高精确度,防止溢出 x = x * self.weight return x def repeat_kv(x: torch.Tensor, n_rep: int): ''' x 是 key 或者 value ,大小是 (batch_size, seq_len, kv_heads, head_dim) 要把它复制 n_rep 遍,变成 (batch_size, seq_len, kv_heads * n_rep, head_dim) ''' if n_rep == 1: return x else: bs, seq_len, kv_heads, head_dim = x.shape return x[:,:,:,None,:].expand(bs, seq_len, kv_heads, n_rep, head_dim).reshape(bs, seq_len, kv_heads * n_rep, head_dim) # expand 的用法:只能拓展大小为1的维度,或者增加维度。并且expand并没有实际占用内存,它只是用广播而已 # 这句不能用 view,要不然报错: # RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead. def get_rotation(dim: int, seq_len: int, base: float = 10000.0): ''' 获得旋转矩阵,就是一个(seq_len, dim // 2)大小的矩阵W。 W[a][b] = cos(a*θ_b) + i*sin(a*θ_b) ,实际上就是模长为 1 ,旋转角度为 a*θ_b 的虚数向量 但是要注意,这里的 dim 并不是模型的大小,而是在每个注意力头里的 tensor 的大小。也就是 args.dim // args.n_heads ''' angles = 1.0 / (base ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim)) seq = torch.arange(0, seq_len, device = angles.device) angle_matrix = torch.outer(seq, angles).float() weight = torch.polar(torch.ones_like(angle_matrix), angle_matrix) return weight def position_encoding(xq, xk, weight): # 先把 xq 和 xk 转化成虚数 # xq.shape = [bsz, seq_len, n_heads, head_dim] # xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # reshape 能处理内存不连续情况,view 不行 # xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) xq = xq.float() xk = xk.float() xq = torch.view_as_complex(xq.reshape(*xq.shape[:-1], -1, 2)) # reshape 能处理内存不连续情况,view 不行 xk = torch.view_as_complex(xk.reshape(*xk.shape[:-1], -1, 2)) # 相乘,然后转化成实数 # xq_ 变成[bsz, seq_len, n_heads, head_dim // 2],把weight变成[1, seq_len, 1, head_dim // 2] # xq_pos = torch.view_as_real(weight[None, :, None, :] * xq_).flatten(3) # xk_pos = torch.view_as_real(weight[None, :, None, :] * xk_).flatten(3) xq = torch.view_as_real(weight[None, :, None, :] * xq).flatten(3) xk = torch.view_as_real(weight[None, :, None, :] * xk).flatten(3) # flatten(3)是把第三维度后面的内容全部合并成一维,因为虚数变实数之后就变成(b, s, n_h, h // 2, 2)了 # assert xq_pos.shape == xq.shape # assert xk_pos.shape == xk.shape return xq, xk class Attention(nn.Module): def __init__(self, args: LMConfig) -> None: super().__init__() self.dim = args.dim # 模型维度 512 self.n_heads = args.n_heads # 注意力头数 16 self.n_kv_heads = args.n_kv_heads # kv 头数 8 assert self.n_heads % self.n_kv_heads == 0 self.n_rep = self.n_heads // self.n_kv_heads # kv 重复次数 assert self.dim % self.n_heads == 0 self.head_dim = self.dim // self.n_heads # 每个注意力头里面的张量维度 self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias = False) self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias = False) self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias = False) self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias = False) # self.attn_dropout = nn.Dropout(args.dropout) # 注意力 dropout self.resid_dropout = nn.Dropout(args.dropout) # 残差 dropout self.dropout = args.dropout # 给 flash attn 用的 self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn # 判断是否使用 Flash Attention。后者令 is_causal=True 可以实现掩码注意力功能。 if not self.flash: mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf")) mask = torch.triu(mask, diagonal = 1) # upper triangular 和 lower triangular self.register_buffer("mask", mask) # 这样 mask 就不会反向更新了 # kv 缓存。因为测试的时候参数不再更新,所以每个 token 生成的 xk 和 xv 都不变。因此可以直接复用 self.k_cache, self.v_cache = None, None def forward(self, x: torch.Tensor, weight: torch.Tensor, use_kv_cache = False): # x 是(seq_len, dim)的输入,weight是旋转矩阵 # print("进来了 FORWARD!!") bsz, seq_len, _ = x.shape # print("进来了 FORWARD!!") if use_kv_cache and self.eval(): # 评估模式,就是测试阶段的意思 # if self.k_cache is None or self.k_cache.shape[1] == x.shape[1] - 1: # x 的词数量比 k 缓存多一个 if self.k_cache is None or self.k_cache.shape[1] != x.shape[1] - 1: # print("缓冲是 None!") # self.k_cache.shape[1] != x.shape[1] - 1 这一句不能不写! # 因为你每处理一段新的上下文,是不会创建新模型对象的。换言之只用一个模型,处理若干个问题 # 那么当你切换到新的上下文的时候,你的 kv 缓冲按理必须要清空。 # 那么怎么判断你是否切换了新的上下文呢?就用 self.k_cache.shape[1] != x.shape[1] - 1 方法 # 否则你会出现 reshape 大小不匹配的问题! xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) else: token = x[:, -1:, :] # print("1 号concat:") xq = torch.concat((torch.zeros_like(x[:, : -1, :]), self.wq(token)), dim = 1) # 只更新最后一个 token 的值,因为后面要有残差,所以相当于对于前面的词向量什么都不做 # print("2 号concat:") xk = torch.concat((self.k_cache, self.wk(token)), dim = 1) # print("3 号concat:") xv = torch.concat((self.v_cache, self.wv(token)), dim = 1) # 复用之前的 xw 和 xv self.k_cache, self.v_cache = xk, xv else: xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) xq = xq.reshape(bsz, seq_len, self.n_heads, self.head_dim) xk = xk.reshape(bsz, seq_len, self.n_kv_heads, self.head_dim) xv = xv.reshape(bsz, seq_len, self.n_kv_heads, self.head_dim) xq, xk = position_encoding(xq, xk, weight) # 给 q 和 k 加位置编码 xk, xv = repeat_kv(xk, self.n_rep), repeat_kv(xv, self.n_rep) # 把 k 和 v 重复 n_rep 遍 xq = xq.transpose(1, 2) xk = xk.transpose(1, 2) xv = xv.transpose(1, 2) if self.flash: # 直接算出 softmax(...) @ xv # output = F.scaled_dot_product_attention(xq, xk, xv, attn_mask = None, # dropout_p = self.dropout if self.training else 0.0, # is_causal = True) # is_causal = True 表示使用掩码 x = F.scaled_dot_product_attention(xq, xk, xv, attn_mask = None, dropout_p = self.dropout if self.training else 0.0, is_causal = True) # is_causal = True 表示使用掩码 # self.training 用来指示模型当前是否处于训练模式 else: # scores = xq @ xk.transpose(2, 3) / math.sqrt(self.head_dim) # (bs, n_head, seq_len, seq_len) # assert hasattr(self, "mask") # scores = scores + self.mask[:, :, : seq_len, : seq_len] # 掩码,把后文盖住 # output = F.softmax(scores, dim = -1) @ xv # (bs, n_head, seq_len, head_dim) x = xq @ xk.transpose(2, 3) / math.sqrt(self.head_dim) # (bs, n_head, seq_len, seq_len) assert hasattr(self, "mask") x = x + self.mask[:, :, : seq_len, : seq_len] # 掩码,把后文盖住 x = F.softmax(x, dim = -1) @ xv # (bs, n_head, seq_len, head_dim) x = x.transpose(1, 2).contiguous().view(bsz, seq_len, -1) # (bs, seq_len, dim) x = self.resid_dropout(self.wo(x)) return x class FeedForward(nn.Module): def __init__(self, dim: int, hidden_dim: int, multi: int, dropout: float) -> None: # hidden_dim 默认是 None # multi 隐藏层维度的倍数,默认为 64 # dropout 默认是 0.0 super().__init__() if hidden_dim is None: hidden_dim = 4 * dim hidden_dim = int(2 * hidden_dim / 3) hidden_dim = multi * ((hidden_dim + multi - 1) // multi) # 没理解这么做的目的 # 最后算出来是 1408 self.w1 = nn.Linear(dim, hidden_dim, bias = False) self.w2 = nn.Linear(dim, hidden_dim, bias = False) self.w3 = nn.Linear(hidden_dim, dim, bias = False) self.dropout = nn.Dropout(dropout) def forward(self, x: torch.Tensor): # return self.dropout(self.w3(F.silu(self.w1(x)) * self.w2(x))) # return self.w3(F.silu(self.w1(x)) * self.w2(x)) x_2 = self.w2(x) x = self.w1(x) x = F.silu(x) x = x * x_2 x = self.w3(x) return x class MoEGate(nn.Module): def __init__(self, args: LMConfig) -> None: super().__init__() self.topk = args.num_experts_per_tok # top-k 里面的 k ,也就是选择的专家个数 self.gating_dim = args.dim # 门控维度,跟模型大小是一样的 self.n_routed_experts = args.num_experts_per_tok # 专家个数 self.scoring_func = args.scoring_func # 评分函数 self.norm_topk_prob = args.norm_topk_prob # 标准化 top-k 概率 self.alpha = args.aux_loss_alpha # 辅助损失函数的 alpha 参数 self.seq_aux = args.seq_aux # 是否在序列级别上计算辅助损失,默认为 True self.w = nn.Linear(self.gating_dim, self.n_routed_experts, bias = False) self.reset_parameters() def reset_parameters(self) -> None: import torch.nn.init as init init.kaiming_normal_(self.w.weight) # 初始化参数 def forward(self, x: torch.Tensor): bsz, seq_len, dim = x.shape hidden_states = x.view(-1, dim) scores = self.w(hidden_states) # (bsz * seq_len, n_routed_experts) if self.scoring_func == "softmax": scores = F.softmax(scores, dim = -1) else: raise NotImplementedError(f'insupportable scoring function for MoE gating: {self.scoring_func}') # (bsz * seq_len, n_routed_experts),score[i][j] 表示每个序列里第 j 个专家的权重 / 评分 topk_weight, topk_idx = torch.topk(scores, self.topk, dim = -1, sorted = False) # 获得k个最大的权重和对应的专家 (bsz * seq_len, k) if self.norm_topk_prob: # 原文里还有判断self.topk > 1,我认为没有必要 denominator = topk_weight.sum(dim = -1) + 1e-20 topk_weight = topk_weight / denominator # 归一化权重 if self.training and self.alpha > 0: # 训练阶段,并且 alpha > 0。要是 alpha <= 0 那 aux_loss 就是非正数,就是不合法的 loss scores_for_aux = scores # (bsz * seq_len, n_routed_experts) aux_topk = self.topk topk_idx_for_aux_loss = topk_idx.view(bsz, -1) # (bsz, seq_len * k) if self.seq_aux: # 在序列级别上计算辅助损失 scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1).mean(dim = 1) # 第一步:算出 ce (bsz * n_routed_experts) ce = torch.zeros(bsz, self.n_routed_experts) ce.scatter_add_(1, topk_idx_for_aux_loss, torch.ones(bsz, seq_len * aux_topk, device = hidden_states.device).div_( seq_len * aux_topk / self.n_routed_experts )) # 保留 topk_idx 里面的bsz,用 idx 作为第二维度在 ce 里进行累加 # 每个 batch 里使用的专家总数就是 seq_len * k ,这可能就是为什么要除以 seq_len * aux_topk # 最后还要乘一个专家个数,这个在下面不在序列级别算损失的时候也要用 # 第二步:ce 和 scores_for_seq_aux 按位相乘 # 第三步:按位相乘的效果按专家求和,得到长为 bsz 的序列,然后求均值,再乘以 alpha aux_loss = (ce * scores_for_seq_aux).sum(dim = -1).mean() * self.alpha else: # 第一步:算出 ce (1, n_routed_experts) # 具体方法是把 idx 展平做出独热编码,然后求均值;然后还要乘以专家个数 ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes = self.n_routed_experts).mean(dim = 0) ce = ce * self.n_routed_experts # 保证维度是专家个数。因为不一定所有专家都被选上,所以不指定的话有可能独热码维度比专家小 # 独热码返回的维度是 (bsz * seq_len * k, n_routed_experts) # 第二步:算出每个专家权重的均值 (1, n_routed_experts) # 具体方法是对 scores_for_aux (也就是上面求出的 scores )求均值 # 第三步:对上面二者求和再乘以 alpha aux_loss = (ce * scores_for_aux.mean(dim = 0)).sum() * self.alpha else: aux_loss = None return topk_weight, topk_idx, aux_loss class MOEFeedForward(nn.Module): def __init__(self, args: LMConfig) -> None: super().__init__() self.topk = args.num_experts_per_tok # top-k 里面的 k ,也就是选择的专家个数 self.n_routed_experts = args.num_experts_per_tok # 专家个数 self.experts = nn.ModuleList([ FeedForward(dim = args.dim, hidden_dim = args.hidden_dim, multi = args.multiple_of, dropout = args.dropout) for _ in range(self.n_routed_experts) ]) self.gate = MoEGate(args) if args.n_shared_experts is not None: self.shared_experts = FeedForward( dim = args.dim, hidden_dim = args.hidden_dim, multi = args.multiple_of, dropout = args.dropout ) def work(self, x, topk_weight, topk_idx): bsz, seq_len, dim = x.shape # 先把 x 复制 k 份 x = x.view(-1, dim) x = x.repeat_interleave(self.topk, dim = 0) # (bsz * seq_len * k, dim) # 把权重展平 flat_topk_idx = topk_idx.view(-1) # (bsz * seq_len * k) # 过专家 y = torch.empty_like(x, dtype = torch.float16) # (bsz * seq_len * k, dim) for i in range(self.n_routed_experts): y[flat_topk_idx == i] = self.experts[i](x[flat_topk_idx == i]) # 乘以权重 y = y.view(bsz, seq_len, self.topk, -1) y = y * topk_weight.unsqueeze(-1).sum(dim = 1) # (bsz * seq_len, dim) # 恢复成输入的形状 y = y.view(bsz, seq_len, -1) return y def forward(self, x): # x 是(bsz, seq_len, dim) topk_weight, topk_idx, _ = self.gate(x) # 确定选哪些专家及其权重 # (bsz * seq_len, k) if self.training: y = self.work(x, topk_weight, topk_idx) else: with torch.no_grad: y = self.work(x, topk_weight, topk_idx) if self.args.n_shared_experts is not None: y = y + self.shared_experts(y) return y class TransformerBlock(nn.Module): def __init__(self, layer_id: int, args: LMConfig) -> None: # layer_id 是当前块的编号 super().__init__() self.attn_norm = RMSNorm(dim = args.dim, eps = args.norm_eps) self.attn = Attention(args) self.ffn_norm = RMSNorm(dim = args.dim, eps = args.norm_eps) if args.use_moe: self.feed_forward = MOEFeedForward(args) else: self.feed_forward = FeedForward(dim = args.dim, hidden_dim = args.hidden_dim, multi = args.multiple_of, dropout = args.dropout) def forward(self, x, weight, use_kv_cache = False): # print("吾来也!!") # print("第一步") # def forward(self, x: torch.Tensor, weight: torch.Tensor, use_kv_cache = False) x = x + self.attn(self.attn_norm(x), weight, use_kv_cache) # print("第二步") x = x + self.feed_forward(self.ffn_norm(x)) # print("第三步") return x # class Transformer(nn.Module): class Transformer(PreTrainedModel): config_class = LMConfig last_loss: Optional[torch.Tensor] def __init__(self, args: LMConfig = None) -> None: super().__init__(args) if not args: args = LMConfig() self.args = args self.embedding = nn.Embedding(args.vocab_size, args.dim) self.dropout = nn.Dropout(args.dropout) # 在 embedding 之后就要进行一个 dropout self.layers = nn.ModuleList() # Transformer 块 for i in range(args.n_layers): self.layers.append(TransformerBlock(i, args)) # 下面是旋转位置嵌入的权重,尺寸是 (max_seq_len, weight.dim // 2) rotation_weight = get_rotation(dim = args.dim // args.n_heads, seq_len = args.max_seq_len) self.register_buffer('rotation_weight', rotation_weight, persistent = False) self.norm = RMSNorm(dim = args.dim, eps = args.norm_eps) self.output = nn.Linear(args.dim, args.vocab_size, bias = False) # 最后的线性层 self.embedding.weight = self.output.weight self.OUT = CausalLMOutputWithPast() def forward(self, tokens: Optional[torch.Tensor] = None, targets: Optional[torch.Tensor] = None, use_kv_cache = False, **key_args): # Optional[torch.Tensor] 的意思就是可以传入张量,也可以传入 None if 'input_ids' in key_args: tokens = key_args['input_ids'] if 'attention_mask' in key_args: tokens = key_args['attention_mask'] _, seq_len = tokens.shape # 输入的文本,一共有 bsz 个 batch ,每个文本的长度是 seq_len x = self.embedding(tokens) # x 的尺寸是 (bsz, seq_len, dim) x = self.dropout(x) # print("embedding完成!") # 下面就是获得位置编码,然后过 transformer block r_w = self.rotation_weight[: seq_len] # print("旋转编码完成!") for layer in self.layers: x = layer(x, r_w, use_kv_cache) # print("正在训练......") # print("Transformer块完成!") x = self.norm(x) # 过归一化 if targets is not None: # 就是训练阶段的意思 logits = self.output(x) # (bsz, seq_len, vocal_size) # print("算出预测值!") # self.last_loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), targets.view(-1), ignore_index = -1) last_loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), targets.view(-1), ignore_index = -1) # print("算出误差") # targets 是每个输入序列的下一个词 # ignore_idx 表示填充值,这里就是 -1 ,表示遇到 tensor 里有 -1 的直接当作空值处理 # F.cross_entropy 会先自动进行 softmax else: # 就是评估阶段的意思 logits = self.output(x[:, [-1], :]) # (bsz, 1, vocal_size),也就是每个 batch 的最后一个序列 # self.last_loss = None last_loss = None # 没明白为什么一个是类变量,一个不是 self.OUT.__setitem__('logits', logits) self.OUT.__setitem__('last_loss', last_loss) # print("返回!") return self.OUT @torch.inference_mode() # 可参看 https://zhuanlan.zhihu.com/p/667025336 def generate(self, idx, eos, max_new_tokens, temperature = 0.7, top_k = None, stream = True, repetition_penalty = 1., use_kv_cache = True): # idx 是 (bsz, seq_len),每个 seq 里面都是文本的词的下标 # eos 是 结束符。如果最后推出来生成的内容是结束符,就停止生成 # max_new_tokens 是最多能生成的词的个数 # temperature 是用来平滑概率的。在原来概率的基础上 * temperature 再进行 softmax 就可以缩小各词概率之间的差距,让选择概率小的词的几率增大 # top_k 是 Top-K Sampling 的参数。如果 top_k 是 None,那就不进行 Top-K Sampling ;否则就让 Top-K Sampling 里面的 k 是 top_k # stream 指的是流式输出。如果要流式输出,那就是说每次生成新词就直接输出;否则就是全生成完毕再输出 # repetition_penalty 是惩罚项,用来降低前文出现过的词的概率,否则可能会出现循环文本 bsz, seq_len = idx.shape while idx.shape[1] < max_new_tokens - 1: # 文本的大小不超过最大的文本长度,就可以继续 res = self(idx, use_kv_cache = use_kv_cache) logits = res.logits # (bsz, vocal_size) logits = logits[:, -1, :] # (bsz, 1, vocal_size) # 降低前文出现过的词的概率 for b in range(bsz): # 遍历每一个 batch for token in set(idx.tolist()[b]): # 获得不重复的 token 序列 logits[b, token] /= repetition_penalty # 利用 temperature 进行概率的平滑 if temperature == 0.0: # 直接选概率最高的 token ,idx_nxt 的尺寸是 (bsz, 1) _, idx_nxt = torch.topk(logits, k = 1, dim = -1) else: logits = logits / temperature if top_k is not None: # 把概率排名在 k 以外的概率都设置成 0 ,这样就防止选择到概率过低的 token v, _ = torch.topk(logits, k = min(top_k, logits.shape[-1]), dim = -1) # v 在每个 batch 里是从大到小排好序的,尺寸是 (bsz, top_k) logits[logits < v[:, [-1]]] = -float("Inf") # v[:, [-1]] 就是每一个 batch 里的最小概率的概率值,大小是 (bsz, 1) # logits < v[:, [-1]] 返回一个大小和 logits 一样的,由 True 和 False 组成的矩阵 # 具体来说,如果 logits[i][j] < v[i][0],那[i][j]位置就返回 True ,否则是 False # 设置成负无穷,这样用 softmax 转成概率就是 0 了 probs = F.softmax(logits, dim = -1) idx_nxt = torch.multinomial(probs, num_samples = 1, generator = None) # 根据 prob 随机选择一个 token if idx_nxt == eos: break # 可能有问题 idx = torch.concat((idx, idx_nxt), dim = -1) # 放入新生成的内容 if stream: yield idx[:, seq_len:] # 每次新生成内容就s输出所有新生成的东西 if not stream: yield idx[:, seq_len:] @torch.inference_mode() def eval_answer(self, idx): # 没看出有什么作用 idx_cond = idx if idx.shape[1] < self.args.max_seq_len else idx[:, -self.args.max_seq_len:] res = self(idx_cond) logits = res.logits[:, -1, :] return logits