Tousifahamed committed
Commit 825827f · verified · 1 parent: 92671c4

Upload app.py

Files changed (1): app.py +19 -12
app.py CHANGED

@@ -1,6 +1,6 @@
 import torch
-import torch.nn as nn  # Added missing import
-import torch.ao.quantization as quantization
+import torch.nn as nn
+import torch.quantization  # <--- Use the older namespace for default_qconfig
 from transformers import AutoTokenizer
 from model import TransformerModel
 import gradio as gr
@@ -9,7 +9,6 @@ import gradio as gr
 tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/cosmo2-tokenizer")
 
 def load_quantized_model(checkpoint_path):
-    # Define the model architecture
     model = TransformerModel(
         vocab_size=49152,
         hidden_size=576,
@@ -23,29 +22,37 @@ def load_quantized_model(checkpoint_path):
         tie_word_embeddings=True,
     )
 
-    # Dynamic quantization for embeddings
-    model.embed_tokens = torch.ao.quantization.quantize_dynamic(
+    # Dynamic quant for embeddings
+    model.embed_tokens = torch.quantization.quantize_dynamic(
         model.embed_tokens, {nn.Embedding}, dtype=torch.qint8
     )
-    model.embed_positions = torch.ao.quantization.quantize_dynamic(
+    model.embed_positions = torch.quantization.quantize_dynamic(
         model.embed_positions, {nn.Embedding}, dtype=torch.qint8
     )
 
-    # Static quantization for other layers
-    model.qconfig = torch.ao.quantization.default_qconfig
-    model = torch.ao.quantization.prepare(model, inplace=False)
-    model = torch.ao.quantization.convert(model, inplace=False)
+    # Static quant config for the rest of the model
+    model.qconfig = torch.quantization.get_default_qconfig("fbgemm")  # CPU
+    model = torch.quantization.prepare(model, inplace=False)
+
+    #
+    # >>> RUN CALIBRATION HERE (forward pass with sample data) <<<
+    # e.g. with torch.no_grad():
+    #     for input_ids in some_calibration_loader:
+    #         outputs = model(input_ids)
+    #
+
+    model = torch.quantization.convert(model, inplace=False)
 
     # Load checkpoint
     checkpoint = torch.load(checkpoint_path, map_location="cpu")
     model.load_state_dict(checkpoint)
-
+
     model.eval()
     return model
 
 
 # Load the quantized model
-model = load_quantized_model("checkpoint_quantized.pt")
+model = load_quantized_model("quantized_model.pt")
 
 # Function to generate text
 def generate_text(prompt, max_length=50, temperature=1.0, top_k=50):
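The commit leaves calibration as a placeholder between prepare() and convert(); static post-training quantization needs that forward-pass step so the inserted observers can record activation ranges before convert() turns them into scale/zero-point parameters. A minimal sketch of what that step could look like in this app, reusing the tokenizer loaded above; the sample_prompts list is an illustrative assumption, not part of the commit:

# Hypothetical calibration pass, to run between prepare() and convert().
# A few representative prompts are pushed through the prepared model so
# its observers can record activation ranges. `sample_prompts` is an
# illustrative stand-in for some_calibration_loader, not from the commit.
sample_prompts = [
    "Once upon a time",
    "The capital of France is",
]

with torch.no_grad():
    for prompt in sample_prompts:
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids
        model(input_ids)  # forward pass only; outputs are discarded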