Confusion in the model architecture

#4
by Ink - opened

If I run

for i, layer in enumerate(model.modules()):
    print(i,":\t",layer)

I get the first two entries as:

0 :      RavenForCausalLM(
  (transformer): ModuleDict(
    (wte): Embedding(65536, 5280)
    (prelude): ModuleList(
      (0-1): 2 x SandwichBlock(
        (norm_1): RMSNorm()
        (attn): CausalSelfAttention(
          (Wqkv): Linear(in_features=5280, out_features=15840, bias=False)
          (proj): Linear(in_features=5280, out_features=5280, bias=False) 
        )
        (norm_2): RMSNorm()
        (mlp): GatedMLP(
          (fc): Linear(in_features=5280, out_features=35840, bias=False)  
          (proj): Linear(in_features=17920, out_features=5280, bias=False)
          (nonlin): SiLU()
        )
        (norm_3): RMSNorm()
        (norm_4): RMSNorm()
      )
    )
    (adapter): Linear(in_features=10560, out_features=5280, bias=False)   
    (core_block): ModuleList(
      (0-3): 4 x SandwichBlock(
        (norm_1): RMSNorm()
        (attn): CausalSelfAttention(
          (Wqkv): Linear(in_features=5280, out_features=15840, bias=False)
          (proj): Linear(in_features=5280, out_features=5280, bias=False)
        )
        (norm_2): RMSNorm()
        (mlp): GatedMLP(
          (fc): Linear(in_features=5280, out_features=35840, bias=False)
          (proj): Linear(in_features=17920, out_features=5280, bias=False)
          (nonlin): SiLU()
        )
        (norm_3): RMSNorm()
        (norm_4): RMSNorm()
      )
    )
    (coda): ModuleList(
      (0-1): 2 x SandwichBlock(
        (norm_1): RMSNorm()
        (attn): CausalSelfAttention(
          (Wqkv): Linear(in_features=5280, out_features=15840, bias=False)
          (proj): Linear(in_features=5280, out_features=5280, bias=False)
        )
        (norm_2): RMSNorm()
        (mlp): GatedMLP(
          (fc): Linear(in_features=5280, out_features=35840, bias=False)
          (proj): Linear(in_features=17920, out_features=5280, bias=False)
          (nonlin): SiLU()
        )
        (norm_3): RMSNorm()
        (norm_4): RMSNorm()
      )
    )
    (ln_f): RMSNorm()
  )
  (lm_head): Linear(in_features=5280, out_features=65536, bias=False)
)
-------------------
1 :      ModuleDict(
  (wte): Embedding(65536, 5280)
  (prelude): ModuleList(
    (0-1): 2 x SandwichBlock(
      (norm_1): RMSNorm()
      (attn): CausalSelfAttention(
        (Wqkv): Linear(in_features=5280, out_features=15840, bias=False)
        (proj): Linear(in_features=5280, out_features=5280, bias=False)
      )
      (norm_2): RMSNorm()
      (mlp): GatedMLP(
        (fc): Linear(in_features=5280, out_features=35840, bias=False)
        (proj): Linear(in_features=17920, out_features=5280, bias=False)
        (nonlin): SiLU()
      )
      (norm_3): RMSNorm()
      (norm_4): RMSNorm()
    )
  )
  (adapter): Linear(in_features=10560, out_features=5280, bias=False)
  (core_block): ModuleList(
    (0-3): 4 x SandwichBlock(
      (norm_1): RMSNorm()
      (attn): CausalSelfAttention(
        (Wqkv): Linear(in_features=5280, out_features=15840, bias=False)
        (proj): Linear(in_features=5280, out_features=5280, bias=False)
      )
      (norm_2): RMSNorm()
      (mlp): GatedMLP(
        (fc): Linear(in_features=5280, out_features=35840, bias=False)
        (proj): Linear(in_features=17920, out_features=5280, bias=False)
        (nonlin): SiLU()
      )
      (norm_3): RMSNorm()
      (norm_4): RMSNorm()
    )
  )
  (coda): ModuleList(
    (0-1): 2 x SandwichBlock(
      (norm_1): RMSNorm()
      (attn): CausalSelfAttention(
        (Wqkv): Linear(in_features=5280, out_features=15840, bias=False)
        (proj): Linear(in_features=5280, out_features=5280, bias=False)
      )
      (norm_2): RMSNorm()
      (mlp): GatedMLP(
        (fc): Linear(in_features=5280, out_features=35840, bias=False)
        (proj): Linear(in_features=17920, out_features=5280, bias=False)
        (nonlin): SiLU()
      )
      (norm_3): RMSNorm()
      (norm_4): RMSNorm()
    )
  )
  (ln_f): RMSNorm()
)
...

where 0: and 1: are the enumeration indices from model.modules().
I am a little confused. Why are the entire prelude, core, and coda modules repeated for each item in model.modules()? Shouldn't it just be modules 0, 1: prelude, modules 2-5: core, and modules 6, 7: coda?

I'm just a little confused about how this architecture is structured.

Tom Goldstein's Lab at University of Maryland, College Park org

Hi! What is your use case for model.modules()? You're getting confusing, repeated references because model.modules() recurses through the whole module tree: index 0 is the full model, index 1 is its transformer ModuleDict, and every submodule nested inside it is then yielded again on its own.
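
For comparison, named_children() only lists the direct children of a module, which gives the top-level view you were expecting (a minimal sketch, assuming model is the loaded checkpoint from above):

for name, module in model.named_children():
    print(name, ":", type(module).__name__)
# transformer : ModuleDict
# lm_head : Linear

for name, module in model.transformer.named_children():
    print(name, ":", type(module).__name__)
# wte : Embedding
# prelude : ModuleList
# adapter : Linear
# core_block : ModuleList
# coda : ModuleList
# ln_f : RMSNorm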

The layout of the modules follows the Hugging Face convention of separating the transformer from the language-model head:

        self.transformer = torch.nn.ModuleDict(
            dict(
                wte=torch.nn.Embedding(config.padded_vocab_size, config.n_embd),
                prelude=prelude,
                adapter=adapter,
                core_block=core_block,
                coda=coda,
                ln_f=RMSNorm(config.n_embd, eps=config.norm_eps),
            )
        )
        self.lm_head = torch.nn.Linear(config.n_embd, config.padded_vocab_size, bias=False)

(see e.g. https://github.com/seal-rg/recurrent-pretraining/blob/bfd495b8ccc77b5c63674717c168a4b62b3c3e2a/recpre/raven_modeling_minimal.py#L331)
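
With this layout, the individual stages can be addressed directly through the ModuleDict (a minimal sketch, assuming the checkpoint above is loaded as model):

prelude = model.transformer.prelude            # ModuleList of 2 SandwichBlocks
core_block = model.transformer.core_block      # ModuleList of 4 SandwichBlocks
coda = model.transformer.coda                  # ModuleList of 2 SandwichBlocks
adapter = model.transformer.adapter            # Linear: 10560 -> 5280 (10560 = 2 x 5280)
first_core_attn = core_block[0].attn           # CausalSelfAttention of the first core block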

Ok, thanks. I was just trying to see if I can "Frankenstein" this model by adding additional layers to this setup. This is a very cool architecture!
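
For example, something along these lines (just a rough sketch: the appended block keeps copied weights and would still need training, and any depth-dependent settings in the config or forward pass may need adjusting):

import copy

# Duplicate the last core SandwichBlock and append it to the core ModuleList.
# The copy starts out with the same weights as the block it was cloned from.
extra_block = copy.deepcopy(model.transformer.core_block[-1])
model.transformer.core_block.append(extra_block)

print(len(model.transformer.core_block))  # 5 instead of the original 4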

Tom Goldstein's Lab at University of Maryland, College Park org

ok! I'm closing this for now then, feel free to come back if there are more questions, though!

JonasGeiping changed discussion status to closed