Tousifahamed committed
Commit 8c5b923 · verified · 1 Parent(s): a222095

Upload app.py

Files changed (1)
  1. app.py +29 -67
app.py CHANGED
@@ -1,78 +1,40 @@
  import torch
- torch.backends.quantized.engine = 'fbgemm'
-
- print("PyTorch version:", torch.__version__)
- print("Supported quantized engines:", torch.backends.quantized.supported_engines)
-
  import torch.nn as nn
  from transformers import AutoTokenizer
- from model import TransformerModel
  import gradio as gr

- # Load the tokenizer
- tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/cosmo2-tokenizer")
-
- def load_quantized_model(checkpoint_path):
-     # 1. Create the float model
-     model = TransformerModel(
-         vocab_size=49152,
-         hidden_size=576,
-         num_hidden_layers=30,
-         num_attention_heads=9,
-         intermediate_size=1536,
-         num_key_value_heads=3,
-         max_position_embeddings=2048,
-         rms_norm_eps=1e-5,
-         hidden_act="silu",
-         tie_word_embeddings=True,
-     )
-
-     # 2. Load the actual checkpoint weights
-     # If "quantized_model.pt" is a state_dict, do:
-     checkpoint = torch.load(checkpoint_path, map_location="cpu")
-     model.load_state_dict(checkpoint)  # or checkpoint["model_state_dict"] if saved that way
-     model.eval()
-
-     # 3. Dynamically quantize relevant layers
-     # For embeddings, we typically use torch.quint8
-     # so we don't run into any embedding dtype errors
-     quantized_model = torch.quantization.quantize_dynamic(
-         model,
-         {nn.Linear, nn.Embedding},
-         dtype=torch.quint8
-     )
-
-     return quantized_model
-
-
- # 4. Load the quantized model
- model = load_quantized_model("quantized_model.pt")
-
- # 5. Inference function
- def generate_text(prompt, max_length=50, temperature=1.0, top_k=50):
-     input_ids = tokenizer.encode(prompt, return_tensors="pt")
      with torch.no_grad():
-         output_ids = model.generate(
-             input_ids,
-             max_length=max_length,
-             temperature=temperature,
-             top_k=top_k,
-             do_sample=True,
-         )
-     generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
-     return generated_text
-
- # 6. Gradio interface
- interface = gr.Interface(
-     fn=generate_text,
-     inputs=[
-         gr.Textbox(label="Prompt", placeholder="Enter your prompt here..."),
-         gr.Slider(minimum=10, maximum=200, value=50, label="Max Length"),
-         gr.Slider(minimum=0.1, maximum=2.0, value=1.0, label="Temperature"),
-         gr.Slider(minimum=1, maximum=100, value=50, label="Top-k Sampling"),
-     ],
-     outputs=gr.Textbox(label="Generated Text"),
-     title="Text Generation with Quantized SMOL-LM2",
-     description="Generate text using a dynamically quantized SMOL-LM2 model.",
- )
-
- interface.launch()
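
For reference, the removed path relied on PyTorch dynamic quantization (it also pinned torch.backends.quantized.engine to 'fbgemm' and passed nn.Embedding with dtype=torch.quint8). A minimal sketch of that technique, restricted to nn.Linear layers, the most common dynamic-quantization target; the toy module below is purely illustrative and is not the repository's TransformerModel:

import torch
import torch.nn as nn

# Toy float model standing in for the real network.
float_model = nn.Sequential(
    nn.Linear(16, 32),
    nn.ReLU(),
    nn.Linear(32, 4),
).eval()

# Dynamic quantization converts the weights of the listed module types to
# int8 ahead of time and quantizes activations on the fly at inference time.
quantized = torch.quantization.quantize_dynamic(
    float_model,
    {nn.Linear},        # module types to replace with quantized equivalents
    dtype=torch.qint8,  # default weight dtype for dynamically quantized Linear
)

with torch.no_grad():
    print(quantized(torch.randn(1, 16)).shape)  # torch.Size([1, 4])

No calibration pass is needed, and the quantized Linear weights take roughly a quarter of their FP32 storage. The new version of app.py, shown below, drops quantization entirely and instead loads half-precision (FP16) weights.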
 
  import torch
  import torch.nn as nn
+ from model import TransformerModel  # or however you define your model classes
  from transformers import AutoTokenizer
  import gradio as gr

+ # Load half-precision state_dict
+ checkpoint = torch.load("model_weights_fp16.pt", map_location="cpu")
+ state_dict_fp16 = checkpoint["model_state_dict"]
+
+ # Create model in FP16
+ model = TransformerModel(
+     vocab_size=49152,
+     hidden_size=576,
+     num_hidden_layers=30,
+     num_attention_heads=9,
+     intermediate_size=1536,
+     num_key_value_heads=3,
+     max_position_embeddings=2048,
+     rms_norm_eps=1e-5,
+     hidden_act="silu",
+     tie_word_embeddings=True,
+ )
+
+ # Convert model to half precision
+ model.half()
+
+ # Load the half-precision weights
+ model.load_state_dict(state_dict_fp16, strict=False)
+ model.eval()
+
+ tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/cosmo2-tokenizer")
+
+ def generate_text(prompt, max_length=50):
+     input_ids = tokenizer.encode(prompt, return_tensors="pt").half()  # match model dtype
      with torch.no_grad():
+         output_ids = model.generate(input_ids, max_length=max_length, do_sample=True)
+     return tokenizer.decode(output_ids[0], skip_special_tokens=True)
+
+ gr.Interface(fn=generate_text, inputs="text", outputs="text").launch()
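
The tokenizer returns integer token IDs, and embedding lookups expect long tensors, so in the sketch below the IDs are left in their integer dtype and only the model weights carry the FP16 savings. This is a minimal sketch of the same load-then-generate flow, assuming model.py exposes TransformerModel with a Hugging Face-style generate() and that the checkpoint stores its weights under "model_state_dict":

import torch
from transformers import AutoTokenizer
from model import TransformerModel  # assumed to provide the model class used above

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/cosmo2-tokenizer")

# Same hyperparameters as the checkpoint; cast the parameters to FP16 before loading.
model = TransformerModel(
    vocab_size=49152, hidden_size=576, num_hidden_layers=30,
    num_attention_heads=9, intermediate_size=1536, num_key_value_heads=3,
    max_position_embeddings=2048, rms_norm_eps=1e-5,
    hidden_act="silu", tie_word_embeddings=True,
).half().eval()

checkpoint = torch.load("model_weights_fp16.pt", map_location="cpu")
model.load_state_dict(checkpoint["model_state_dict"], strict=False)

def generate_text(prompt, max_length=50):
    # Token IDs stay as long integers; only the parameters are half precision.
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    with torch.no_grad():
        output_ids = model.generate(input_ids, max_length=max_length, do_sample=True)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)

The Gradio wiring at the end of app.py can call this function unchanged.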