|
import datetime
import html
import os

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextGenerationPipeline
|
|
|
# Hugging Face Hub repository of the model served by this demo.
# The model and tokenizer are loaded from here by load_model_and_tokenizer().
MODEL_ID = "FlameF0X/Snowflake-G0-Release"

# Passed as `max_length` to the text-generation pipeline; equals the model's
# maximum sequence length listed in the "About" panel (384).
MAX_LENGTH = 384

# Ranges and defaults for the "Generation Parameters" sliders and the
# keyword defaults of generate_text().
TEMPERATURE_MIN = 0.1
TEMPERATURE_MAX = 2.0
TEMPERATURE_DEFAULT = 0.7

TOP_P_MIN = 0.1
TOP_P_MAX = 1.0
TOP_P_DEFAULT = 0.9

TOP_K_MIN = 1
TOP_K_MAX = 100
TOP_K_DEFAULT = 40

MAX_NEW_TOKENS_MIN = 16
# NOTE(review): this is larger than MAX_LENGTH (384). If the model cannot
# attend beyond 384 positions, long generations may fail - confirm.
MAX_NEW_TOKENS_MAX = 1024
MAX_NEW_TOKENS_DEFAULT = 256
|
|
|
|
|
# Custom stylesheet passed to gr.Blocks(css=...). It sets a dark background and
# light text on the app container, and styles the classes used by the UI:
# `header`, `snowflake-icon` and `footer` (in gr.HTML snippets) and
# `parameter-section` / `example-section` (via elem_classes).
css = """
.gradio-container {
    background-color: #1e1e2f !important;
    color: #e0e0e0 !important;
}
.header {
    background-color: #2b2b3c;
    padding: 20px;
    margin-bottom: 20px;
    border-radius: 10px;
    text-align: center;
}
.header h1 {
    color: #66ccff;
    margin-bottom: 10px;
}
.snowflake-icon {
    font-size: 24px;
    margin-right: 10px;
}
.footer {
    text-align: center;
    margin-top: 20px;
    font-size: 0.9em;
    color: #999;
}
.parameter-section {
    background-color: #2a2a3a;
    padding: 15px;
    border-radius: 8px;
    margin-bottom: 15px;
}
.parameter-section h3 {
    margin-top: 0;
    color: #66ccff;
}
.example-section {
    background-color: #223344;
    padding: 15px;
    border-radius: 8px;
    margin-bottom: 15px;
}
.example-section h3 {
    margin-top: 0;
    color: #66ffaa;
}
"""
|
|
|
|
|
|
|
def load_model_and_tokenizer():
    """Load the tokenizer, model and text-generation pipeline for MODEL_ID.

    Besides returning them, this rebinds the module-level globals ``model``,
    ``tokenizer`` and ``pipeline``; ``generate_text`` reads ``pipeline`` and
    ``tokenizer`` from module scope.

    Returns:
        tuple: ``(model, tokenizer, pipeline)``.

    Raises:
        Any exception from the ``from_pretrained`` calls or pipeline
        construction (e.g. download failures) propagates to the caller.
    """
    global model, tokenizer, pipeline

    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    # generate() needs a pad token id; fall back to EOS if none is defined.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        # float16 when CUDA is available, float32 otherwise.
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        # NOTE(review): device_map="auto" requires the `accelerate` package.
        device_map="auto",
    )

    pipeline = TextGenerationPipeline(
        model=model,
        tokenizer=tokenizer,
        # Return only the generated continuation, not the prompt.
        return_full_text=False,
        # The per-call max_new_tokens passed by generate_text takes precedence
        # over this when both are set.
        max_length=MAX_LENGTH,
    )

    return model, tokenizer, pipeline
|
|
|
def _format_history(history):
    """Render role/content message dicts as the conversation display text.

    "user" turns get the "👤 User: " prefix and every other role gets
    "❄️ Snowflake: "; turns are separated by a blank line.
    """
    lines = []
    for entry in history:
        role_prefix = "👤 User: " if entry["role"] == "user" else "❄️ Snowflake: "
        lines.append(f"{role_prefix}{entry['content']}")
    return "\n\n".join(lines)


def generate_text(
    prompt,
    temperature=TEMPERATURE_DEFAULT,
    top_p=TOP_P_DEFAULT,
    top_k=TOP_K_DEFAULT,
    max_new_tokens=MAX_NEW_TOKENS_DEFAULT,
    history=None,
):
    """Generate a reply to ``prompt`` and append the exchange to the history.

    Only ``prompt`` itself is sent to the model; earlier turns are kept for
    display but are not used as context.

    Args:
        prompt: User message text.
        temperature: Sampling temperature; sampling is enabled when > 0.
        top_p: Nucleus-sampling cumulative probability cutoff.
        top_k: Number of highest-probability tokens kept for sampling.
        max_new_tokens: Maximum number of tokens to generate.
        history: List of ``{"role": ..., "content": ...}`` dicts, or None.
            It is copied, never mutated.

    Returns:
        tuple: ``(response_text, new_history, formatted_history_text)``.
        On a generation failure, ``response_text`` is the error message and
        an "[ERROR] ..." assistant entry is appended to the history.
    """
    # Copy so the caller's list (e.g. the gr.State value) is left untouched.
    history = list(history) if history else []

    # Nothing to send: keep the conversation unchanged.
    if not prompt or not prompt.strip():
        return "", history, _format_history(history)

    history.append({"role": "user", "content": prompt})

    try:
        outputs = pipeline(
            prompt,
            do_sample=temperature > 0,
            temperature=temperature,
            top_p=top_p,
            # UI sliders may deliver floats; generation expects integer counts.
            top_k=int(top_k),
            max_new_tokens=int(max_new_tokens),
            pad_token_id=tokenizer.pad_token_id,
            num_return_sequences=1,
        )
        response = outputs[0]["generated_text"]
    except Exception as e:  # UI boundary: report any failure in the chat
        error_msg = f"Error generating response: {str(e)}"
        history.append({"role": "assistant", "content": f"[ERROR] {error_msg}"})
        return error_msg, history, _format_history(history)

    history.append({"role": "assistant", "content": response})
    return response, history, _format_history(history)
|
|
|
def clear_conversation():
    """Reset the chat: blank prompt box, fresh empty history, blank transcript."""
    blank = ""
    fresh_history = []
    return blank, fresh_history, blank
|
|
|
def apply_preset_example(example, history):
    """Pass ``example`` and ``history`` through unchanged as a 2-tuple.

    NOTE(review): not referenced anywhere in the visible code; gr.Examples
    fills the prompt box directly.
    """
    passthrough = (example, history)
    return passthrough
|
|
|
|
|
# Clickable starter prompts shown under "Example Prompts" (fed to gr.Examples).
examples = [
    "Write a short story about a snowflake that comes to life.",
    "Explain the concept of artificial neural networks to a 10-year-old.",
    "What are some interesting applications of natural language processing?",
    "Write a haiku about programming.",
    "Create a dialogue between two AI researchers discussing the future of language models.",
]
|
|
|
|
|
def create_demo():
    """Build the Gradio Blocks UI for the Snowflake-G0-Release chat demo.

    Layout, top to bottom:
    - header banner
    - collapsible model description
    - conversation transcript
    - prompt box with Send / Clear buttons
    - latest model response
    - collapsible generation-parameter sliders
    - clickable example prompts
    - footer with the current year

    The Send button and submitting the prompt box both call ``generate_text``;
    the Clear button calls ``clear_conversation``.

    Returns:
        gr.Blocks: the assembled app (not launched).
    """
    with gr.Blocks(css=css) as demo:
        gr.HTML("""
        <div class="header">
            <h1><span class="snowflake-icon">❄️</span> Snowflake-G0-Release Demo</h1>
            <p>Experience the capabilities of the Snowflake-G0-Release language model</p>
        </div>
        """)

        with gr.Accordion("About Snowflake-G0-Release", open=False):
            # Joined from explicit lines so source indentation can never turn
            # the Markdown into an indented code block.
            gr.Markdown("\n".join([
                "## Snowflake-G0-Release",
                "",
                "This is the initial release of the Snowflake series language models, "
                "trained on the DialogMLM-50K dataset with optimized memory usage.",
                "",
                "### Model details",
                "",
                "- Architecture: SnowflakeCore",
                "- Hidden size: 384",
                "- Number of attention heads: 6",
                "- Number of layers: 4",
                "- Feed-forward dimension: 768",
                "- Maximum sequence length: 384",
                "- Vocabulary size: 30522 (BERT tokenizer)",
                "",
                "### Key Features",
                "",
                "- Efficient memory usage",
                "- Fused QKV projection for faster inference",
                "- Pre-norm architecture for stable training",
                "- Compatible with HuggingFace Transformers",
            ]))

        with gr.Column():
            chat_history_display = gr.Textbox(
                value="",
                label="Conversation History",
                lines=10,
                max_lines=30,
                interactive=False,
            )

        # Conversation messages (list of {"role", "content"} dicts).
        history_state = gr.State([])

        with gr.Row():
            with gr.Column(scale=4):
                prompt = gr.Textbox(
                    placeholder="Type your message here...",
                    label="Your Input",
                    lines=2,
                )
            with gr.Column(scale=1):
                submit_btn = gr.Button("Send", variant="primary")
                clear_btn = gr.Button("Clear Conversation")

        response_output = gr.Textbox(
            value="",
            label="Model Response",
            lines=5,
            max_lines=10,
            interactive=False,
        )

        with gr.Accordion("Generation Parameters", open=False):
            with gr.Column(elem_classes="parameter-section"):
                with gr.Row():
                    with gr.Column():
                        temperature = gr.Slider(
                            minimum=TEMPERATURE_MIN,
                            maximum=TEMPERATURE_MAX,
                            value=TEMPERATURE_DEFAULT,
                            step=0.05,
                            label="Temperature",
                            info="Higher = more creative, Lower = more deterministic",
                        )
                        top_p = gr.Slider(
                            minimum=TOP_P_MIN,
                            maximum=TOP_P_MAX,
                            value=TOP_P_DEFAULT,
                            step=0.05,
                            label="Top-p (nucleus sampling)",
                            info="Controls diversity via cumulative probability",
                        )
                    with gr.Column():
                        top_k = gr.Slider(
                            minimum=TOP_K_MIN,
                            maximum=TOP_K_MAX,
                            value=TOP_K_DEFAULT,
                            step=1,
                            label="Top-k",
                            info="Limits word choice to top k options",
                        )
                        max_new_tokens = gr.Slider(
                            minimum=MAX_NEW_TOKENS_MIN,
                            maximum=MAX_NEW_TOKENS_MAX,
                            value=MAX_NEW_TOKENS_DEFAULT,
                            step=8,
                            label="Maximum New Tokens",
                            info="Controls the length of generated response",
                        )

        with gr.Accordion("Example Prompts", open=True):
            with gr.Column(elem_classes="example-section"):
                # Clicking an example copies it into the prompt box.
                gr.Examples(
                    examples=examples,
                    inputs=prompt,
                    label="Click on an example to try it",
                    examples_per_page=5,
                )

        gr.HTML(f"""
        <div class="footer">
            <p>Snowflake-G0-Release Demo • Created with Gradio • {datetime.datetime.now().year}</p>
        </div>
        """)

        # Send-button and Enter-in-textbox share one handler configuration.
        generate_event = dict(
            fn=generate_text,
            inputs=[prompt, temperature, top_p, top_k, max_new_tokens, history_state],
            outputs=[response_output, history_state, chat_history_display],
        )
        submit_btn.click(**generate_event)
        prompt.submit(**generate_event)

        clear_btn.click(
            fn=clear_conversation,
            inputs=[],
            outputs=[prompt, history_state, chat_history_display],
        )

    return demo
|
|
|
|
|
# Startup: load the model once at import time. On success build the chat UI;
# on failure serve a standalone error page instead.
print("Loading Snowflake-G0-Release model and tokenizer...")

try:
    model, tokenizer, pipeline = load_model_and_tokenizer()
except Exception as e:
    print(f"Error loading model: {str(e)}")

    with gr.Blocks(css=css) as error_demo:
        # Escape the exception text before interpolating it into raw HTML.
        gr.HTML(f"""
        <div class="header" style="background-color: #ffebee;">
            <h1><span class="snowflake-icon">⚠️</span> Error Loading Model</h1>
            <p>There was a problem loading the Snowflake-G0-Release model: {html.escape(str(e))}</p>
        </div>
        """)
    demo = error_demo
else:
    # Only reached when loading succeeded. Previously create_demo() ran
    # unconditionally and replaced the error page.
    print("Model loaded successfully!")
    demo = create_demo()


if __name__ == "__main__":
    demo.launch()