FlameF0X committed
Commit 650d6db · verified · 1 Parent(s): c32adac

Update app.py

Files changed (1)
  1. app.py +44 -301

app.py CHANGED
@@ -1,310 +1,53 @@
  import os
- import torch
- import gradio as gr
- from transformers import AutoTokenizer, AutoModelForCausalLM, TextGenerationPipeline
- import datetime
-
- # Model Constants
- MODEL_ID = "./model"  # Local folder containing model files
- MAX_LENGTH = 384
- TEMPERATURE_MIN = 0.1
- TEMPERATURE_MAX = 2.0
- TEMPERATURE_DEFAULT = 0.7
- TOP_P_MIN = 0.1
- TOP_P_MAX = 1.0
- TOP_P_DEFAULT = 0.9
- TOP_K_MIN = 1
- TOP_K_MAX = 100
- TOP_K_DEFAULT = 40
- MAX_NEW_TOKENS_MIN = 16
- MAX_NEW_TOKENS_MAX = 1024
- MAX_NEW_TOKENS_DEFAULT = 256
-
- # CSS for the app
- css = """
- .gradio-container {
-     background-color: #1e1e2f !important;
-     color: #e0e0e0 !important;
- }
- .header {
-     background-color: #2b2b3c;
-     padding: 20px;
-     margin-bottom: 20px;
-     border-radius: 10px;
-     text-align: center;
- }
- .header h1 {
-     color: #66ccff;
-     margin-bottom: 10px;
- }
- .snowflake-icon {
-     font-size: 24px;
-     margin-right: 10px;
- }
- .footer {
-     text-align: center;
-     margin-top: 20px;
-     font-size: 0.9em;
-     color: #999;
- }
- .parameter-section {
-     background-color: #2a2a3a;
-     padding: 15px;
-     border-radius: 8px;
-     margin-bottom: 15px;
- }
- .parameter-section h3 {
-     margin-top: 0;
-     color: #66ccff;
- }
- .example-section {
-     background-color: #223344;
-     padding: 15px;
-     border-radius: 8px;
-     margin-bottom: 15px;
- }
- .example-section h3 {
-     margin-top: 0;
-     color: #66ffaa;
- }
- """
-
- # Helper function to load model and tokenizer
- def load_model_and_tokenizer():
-     global model, tokenizer, pipeline
-
-     print("Loading Snowflake-G0-Release model and tokenizer...")
-
-     tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
-
-     if tokenizer.pad_token is None:
-         tokenizer.pad_token = tokenizer.eos_token
-
-     model = AutoModelForCausalLM.from_pretrained(
-         MODEL_ID,
-         torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-         device_map="auto"
-     )
-
-     pipeline = TextGenerationPipeline(
-         model=model,
-         tokenizer=tokenizer,
-         return_full_text=False,
-         max_length=MAX_LENGTH
-     )
-
-     print("Model loaded successfully!")
-     return model, tokenizer, pipeline
-
- # Helper functions for generation
- def generate_text(
-     prompt,
-     temperature=TEMPERATURE_DEFAULT,
-     top_p=TOP_P_DEFAULT,
-     top_k=TOP_K_DEFAULT,
-     max_new_tokens=MAX_NEW_TOKENS_DEFAULT,
-     history=None
- ):
-     if history is None:
-         history = []
-
-     history.append({"role": "user", "content": prompt})
-
-     try:
-         outputs = pipeline(
-             prompt,
-             do_sample=temperature > 0,
-             temperature=temperature,
-             top_p=top_p,
-             top_k=top_k,
-             max_new_tokens=max_new_tokens,
-             pad_token_id=tokenizer.pad_token_id,
-             num_return_sequences=1
-         )
-
-         response = outputs[0]["generated_text"]
-         history.append({"role": "assistant", "content": response})
-
-         formatted_history = []
-         for entry in history:
-             role_prefix = "👤 User: " if entry["role"] == "user" else "❄️ Snowflake: "
-             formatted_history.append(f"{role_prefix}{entry['content']}")
-
-         return response, history, "\n\n".join(formatted_history)
-
-     except Exception as e:
-         error_msg = f"Error generating response: {str(e)}"
-         history.append({"role": "assistant", "content": f"[ERROR] {error_msg}"})
-         return error_msg, history, str(history)
-
- def clear_conversation():
-     return "", [], ""
-
- def apply_preset_example(example, history):
-     return example, history
-
- # Example prompts
- examples = [
-     "Write a short story about a snowflake that comes to life.",
-     "Explain the concept of artificial neural networks to a 10-year-old.",
-     "What are some interesting applications of natural language processing?",
-     "Write a haiku about programming.",
-     "Create a dialogue between two AI researchers discussing the future of language models."
- ]
-
- # Main app creation
- def create_demo():
-     with gr.Blocks(css=css) as demo:
-         # Header
-         gr.HTML("""
-         <div class="header">
-             <h1><span class="snowflake-icon">❄️</span> Snowflake-G0-Release Demo</h1>
-             <p>Experience the capabilities of the Snowflake-G0-Release language model</p>
-         </div>
-         """)
-
-         # About accordion
-         with gr.Accordion("About Snowflake-G0-Release", open=False):
-             gr.Markdown("""
-             ## Snowflake-G0-Release
-
-             Initial release of the Snowflake series trained on DialogMLM-50K.
-
-             ### Model Details
-             - Architecture: SnowflakeCore
-             - Hidden size: 384
-             - Attention heads: 6
-             - Layers: 4
-             - Feed-forward dim: 768
-             - Max seq length: 384
-             - Vocabulary size: 30522 (BERT tokenizer)
-
-             ### Features
-             - Memory-efficient
-             - Fused QKV for faster inference
-             - Pre-norm for stability
-             - Hugging Face compatible
-             """)
-
-         # Chat interface
-         with gr.Column():
-             chat_history_display = gr.Textbox(
-                 value="",
-                 label="Conversation History",
-                 lines=10,
-                 max_lines=30,
-                 interactive=False
-             )
-             history_state = gr.State([])
-
-             with gr.Row():
-                 with gr.Column(scale=4):
-                     prompt = gr.Textbox(
-                         placeholder="Type your message here...",
-                         label="Your Input",
-                         lines=2
-                     )
-                 with gr.Column(scale=1):
-                     submit_btn = gr.Button("Send", variant="primary")
-                     clear_btn = gr.Button("Clear Conversation")
-
-             response_output = gr.Textbox(
-                 value="",
-                 label="Model Response",
-                 lines=5,
-                 max_lines=10,
-                 interactive=False
-             )
-
-         # Generation Parameters
-         with gr.Accordion("Generation Parameters", open=False):
-             with gr.Column(elem_classes="parameter-section"):
-                 with gr.Row():
-                     with gr.Column():
-                         temperature = gr.Slider(
-                             minimum=TEMPERATURE_MIN,
-                             maximum=TEMPERATURE_MAX,
-                             value=TEMPERATURE_DEFAULT,
-                             step=0.05,
-                             label="Temperature"
-                         )
-                         top_p = gr.Slider(
-                             minimum=TOP_P_MIN,
-                             maximum=TOP_P_MAX,
-                             value=TOP_P_DEFAULT,
-                             step=0.05,
-                             label="Top-p (nucleus sampling)"
-                         )
-                     with gr.Column():
-                         top_k = gr.Slider(
-                             minimum=TOP_K_MIN,
-                             maximum=TOP_K_MAX,
-                             value=TOP_K_DEFAULT,
-                             step=1,
-                             label="Top-k"
-                         )
-                         max_new_tokens = gr.Slider(
-                             minimum=MAX_NEW_TOKENS_MIN,
-                             maximum=MAX_NEW_TOKENS_MAX,
-                             value=MAX_NEW_TOKENS_DEFAULT,
-                             step=8,
-                             label="Maximum New Tokens"
-                         )

-         # Example Prompts
-         with gr.Accordion("Example Prompts", open=True):
-             with gr.Column(elem_classes="example-section"):
-                 gr.Examples(
-                     examples=examples,
-                     inputs=prompt,
-                     label="Click an example to try",
-                     examples_per_page=5
-                 )

-         # Footer
-         gr.HTML(f"""
-         <div class="footer">
-             <p>Snowflake-G0-Release Demo • Created with Gradio • {datetime.datetime.now().year}</p>
-         </div>
-         """)

-         # Interactions
-         submit_btn.click(
-             fn=generate_text,
-             inputs=[prompt, temperature, top_p, top_k, max_new_tokens, history_state],
-             outputs=[response_output, history_state, chat_history_display]
-         )

-         prompt.submit(
-             fn=generate_text,
-             inputs=[prompt, temperature, top_p, top_k, max_new_tokens, history_state],
-             outputs=[response_output, history_state, chat_history_display]
-         )

-         clear_btn.click(
-             fn=clear_conversation,
-             inputs=[],
-             outputs=[prompt, history_state, chat_history_display]
-         )

-     return demo

- # Initialize model
- try:
-     model, tokenizer, pipeline = load_model_and_tokenizer()
- except Exception as e:
-     print(f"Error loading model: {str(e)}")
-     with gr.Blocks(css=css) as error_demo:
-         gr.HTML(f"""
-         <div class="header" style="background-color: #ffebee;">
-             <h1><span class="snowflake-icon">⚠️</span> Error Loading Model</h1>
-             <p>There was a problem loading the model: {str(e)}</p>
-         </div>
-         """)
-     demo = error_demo
- else:
-     demo = create_demo()

- # Launch the app
- if __name__ == "__main__":
-     demo.launch()
+ import json
+ from safetensors.torch import load_file, save_file
+
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ # Path to your model folder
+ model_dir = "./model"
+
+ # Step 1: Fix config.json if missing model_type
+ config_path = os.path.join(model_dir, "config.json")
+ if os.path.exists(config_path):
+     with open(config_path, "r") as f:
+         config = json.load(f)
+
+     if "model_type" not in config:
+         print("⚙️ Adding missing 'model_type' to config.json...")
+         # You can adjust 'gpt2' to whatever your real model type is
+         config["model_type"] = "gpt2"
+         with open(config_path, "w") as f:
+             json.dump(config, f, indent=2)
+     else:
+         print("✅ 'model_type' already exists in config.json.")
+ else:
+     raise FileNotFoundError("config.json not found in model directory!")
+
+ # Step 2: Fix .safetensors file metadata
+ safetensors_files = [f for f in os.listdir(model_dir) if f.endswith(".safetensors")]
+ if safetensors_files:
+     safetensors_path = os.path.join(model_dir, safetensors_files[0])
+     print(f"🛠 Fixing metadata in: {safetensors_path}")
+
+     state_dict = load_file(safetensors_path)
+     fixed_path = os.path.join(model_dir, "model_fixed.safetensors")
+     save_file(state_dict, fixed_path, metadata={"format": "pt"})
+
+     print(f"✅ Saved fixed safetensors: {fixed_path}")
+ else:
+     print("⚠️ No .safetensors file found to fix.")
+
+ # Step 3: Load model to verify it works now
+ print("🚀 Trying to load the model...")
+
+ model = AutoModelForCausalLM.from_pretrained(model_dir, trust_remote_code=True)
+ tokenizer = AutoTokenizer.from_pretrained(model_dir)
+
+ print("🎉 Model loaded successfully!")
+
+ # Step 4 (optional): Save model again safely
+ model.save_pretrained(model_dir, safe_serialization=True)
+ tokenizer.save_pretrained(model_dir)
+
+ print("✅ Model and tokenizer saved safely with correct format!")