FlameF0X committed
Commit 3e16642 · verified · 1 Parent(s): 8600265

Create app.py

Files changed (1)
  1. app.py +72 -0
app.py ADDED
import gradio as gr
from transformers import PreTrainedTokenizerFast, AutoConfig
from safetensors.torch import load_file  # load_file returns a plain state dict
import torch
from EnhancedTransformerModel import EnhancedTransformerModel  # Custom model class

# --- Load Model and Tokenizer ---
MODEL_PATH = "model.safetensors"  # Path to the safetensors weights file
CONFIG_PATH = "config.json"       # Path to the model configuration
TOKENIZER_PATH = "tokenizer"      # Path to the tokenizer directory

# Load configuration
config = AutoConfig.from_pretrained(CONFIG_PATH)

# Load tokenizer
tokenizer = PreTrainedTokenizerFast.from_pretrained(TOKENIZER_PATH)

# Initialize the custom model from the configuration's dimensions
model = EnhancedTransformerModel(
    vocab_size=config.vocab_size,
    max_seq_length=config.max_position_embeddings,
    d_model=config.hidden_size,
    num_heads=config.num_attention_heads,
    num_layers=config.num_hidden_layers,
    ff_dim=config.intermediate_size,
    dropout=0.1,
)

# Load the model weights from safetensors (load_file, not load_model,
# so we get a raw state dict to pass to load_state_dict)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
state_dict = load_file(MODEL_PATH)
model.load_state_dict(state_dict)
model.eval()
model.to(device)

# --- Inference Function ---
def generate_text(prompt, max_length=50):
    """
    Greedily generate up to max_length new tokens from the input prompt.
    """
    # Tokenize the input prompt
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, padding=True, max_length=384)
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    # Autoregressive greedy decoding: append the argmax token at each step
    with torch.no_grad():
        for _ in range(max_length):
            if input_ids.size(1) >= config.max_position_embeddings:
                break  # Stay within the model's positional capacity
            outputs = model(input_ids, attention_mask)
            logits = outputs[:, -1, :]  # Logits for the last position
            next_token = torch.argmax(logits, dim=-1, keepdim=True)  # Greedy pick
            input_ids = torch.cat([input_ids, next_token], dim=-1)
            attention_mask = torch.cat([attention_mask, torch.ones_like(next_token)], dim=-1)
            if tokenizer.eos_token_id is not None and next_token.item() == tokenizer.eos_token_id:
                break  # Stop once the model emits an end-of-sequence token

    # Decode the full sequence (prompt plus generated continuation)
    return tokenizer.decode(input_ids[0], skip_special_tokens=True)

# --- Gradio Interface ---
with gr.Blocks() as demo:
    gr.Markdown("# Snowflake-G0-stable Language Model")
    gr.Markdown("This is an enhanced transformer language model trained on the DialogMLM-50K dataset. Try it out below!")

    with gr.Row():
        input_prompt = gr.Textbox(label="Input Prompt", placeholder="Enter your text here...")
        output_text = gr.Textbox(label="Generated Text")

    submit_button = gr.Button("Generate")

    # Wire the button straight to the inference function
    submit_button.click(generate_text, inputs=input_prompt, outputs=output_text)

# Launch the app
demo.launch()
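
app.py imports EnhancedTransformerModel from a sibling module that is not part of this commit. Below is a minimal sketch of the interface the app relies on: the constructor keywords used above, plus a forward(input_ids, attention_mask) that returns logits of shape (batch, seq_len, vocab_size), since generate_text reads outputs[:, -1, :]. The internals here are hypothetical stand-ins, not the repo's actual implementation.

import torch
import torch.nn as nn

class EnhancedTransformerModel(nn.Module):
    """Hypothetical sketch of the interface app.py assumes."""

    def __init__(self, vocab_size, max_seq_length, d_model, num_heads,
                 num_layers, ff_dim, dropout=0.1):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(max_seq_length, d_model)
        layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=num_heads, dim_feedforward=ff_dim,
            dropout=dropout, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
        self.lm_head = nn.Linear(d_model, vocab_size)

    def forward(self, input_ids, attention_mask):
        # Must return logits of shape (batch, seq_len, vocab_size),
        # because generate_text indexes outputs[:, -1, :].
        positions = torch.arange(input_ids.size(1), device=input_ids.device)
        x = self.embed(input_ids) + self.pos_embed(positions)
        x = self.encoder(x, src_key_padding_mask=(attention_mask == 0))
        return self.lm_head(x)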
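
generate_text uses greedy argmax decoding, which is deterministic and prone to repetitive continuations. A drop-in alternative for the argmax step is temperature sampling over the same last-position logits; a sketch, with the temperature value an arbitrary choice:

def sample_next_token(logits, temperature=0.8):
    # Scale the last-position logits, convert to probabilities,
    # and draw one token id per batch row: result shape (batch, 1).
    probs = torch.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1)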
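
Because the paths are relative, config.json, model.safetensors, the tokenizer/ directory, and EnhancedTransformerModel.py are expected next to app.py (the usual Gradio Space layout). For a quick local check before launching the UI, generate_text can be called directly with a hypothetical prompt:

print(generate_text("Hello, how are you?", max_length=20))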