import torch.nn as nn


class LoRA_qkv(nn.Module):
    """LoRA-adapted fused qkv projection for a Vision Transformer.

    Wraps the original qkv linear layer and adds low-rank updates
    (adapter A followed by adapter B) to the query and value parts
    of its output; the key part is left unchanged.
    """

    def __init__(
        self,
        qkv: nn.Module,
        linear_a_q: nn.Module,
        linear_b_q: nn.Module,
        linear_a_v: nn.Module,
        linear_b_v: nn.Module,
    ):
        super().__init__()
        self.qkv = qkv              # original fused qkv projection (dim -> 3 * dim)
        self.dim = qkv.in_features  # embedding dimension of the transformer
        self.q_lora = nn.Sequential(linear_a_q, linear_b_q)  # low-rank update for q
        self.v_lora = nn.Sequential(linear_a_v, linear_b_v)  # low-rank update for v

    def forward(self, x):
        qkv = self.qkv(x)       # (..., 3 * dim): [q | k | v] concatenated on the last axis
        new_q = self.q_lora(x)  # low-rank residual for the query slice
        new_v = self.v_lora(x)  # low-rank residual for the value slice
        qkv[:, :, : self.dim] += new_q   # first `dim` channels hold q
        qkv[:, :, -self.dim :] += new_v  # last `dim` channels hold v
        return qkv
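

# --- Minimal usage sketch (illustrative, not part of the original module) ---
# Assumptions: the wrapped layer is a timm-style fused qkv `nn.Linear(dim, 3 * dim)`,
# the input has shape (batch, tokens, dim), and `dim = 768`, rank `r = 4` are
# placeholder values chosen here for the example only.
import torch

if __name__ == "__main__":
    dim, r = 768, 4
    base_qkv = nn.Linear(dim, 3 * dim, bias=True)  # the (typically frozen) fused qkv layer

    # Rank-r adapters: A maps dim -> r, B maps r -> dim.
    linear_a_q = nn.Linear(dim, r, bias=False)
    linear_b_q = nn.Linear(r, dim, bias=False)
    linear_a_v = nn.Linear(dim, r, bias=False)
    linear_b_v = nn.Linear(r, dim, bias=False)

    lora_qkv = LoRA_qkv(base_qkv, linear_a_q, linear_b_q, linear_a_v, linear_b_v)

    x = torch.randn(2, 197, dim)  # (batch, tokens, dim) as in a standard ViT
    print(lora_qkv(x).shape)      # torch.Size([2, 197, 2304])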