|
import os |
|
import torch |
|
import gradio as gr |
|
import datetime |
|
from transformers import AutoTokenizer, AutoModelForCausalLM, TextGenerationPipeline |
|
|
|
import spaces |
|
|
|
|
|
# Selectable models: UI label -> Hugging Face Hub repository id.
# Dict insertion order is the order the choices appear in the UI, and the
# first entry is the one selected by default.
MODEL_CONFIG = {
    "G0-Release": "FlameF0X/SnowflakeCore-G0-Release",
    "G0-Release-2": "FlameF0X/SnowflakeCore-G0-Release-2",
    "G0-Release-2.5": "FlameF0X/SnowflakeCore-G0-Release-2.5",
}

# Passed to the text-generation pipeline as ``max_length``.
MAX_LENGTH = 384

# Initial values of the generation-parameter sliders.
TEMPERATURE_DEFAULT = 0.7
TOP_P_DEFAULT = 0.9
TOP_K_DEFAULT = 40
MAX_NEW_TOKENS_DEFAULT = 256

# Slider ranges (minimum, maximum) for each generation parameter.
TEMPERATURE_MIN = 0.1
TEMPERATURE_MAX = 2.0
TOP_P_MIN = 0.1
TOP_P_MAX = 1.0
TOP_K_MIN = 1
TOP_K_MAX = 100
MAX_NEW_TOKENS_MIN = 16
MAX_NEW_TOKENS_MAX = 1024
|
|
|
# Stylesheet handed to gr.Blocks(css=...). The class names below are used by
# the UI via HTML class attributes (header, snowflake-icon, footer) and via
# elem_classes (parameter-section, example-section, model-select).
css = """
.gradio-container { background-color: #1e1e2f !important; color: #e0e0e0 !important; }
.header { background-color: #2b2b3c; padding: 20px; margin-bottom: 20px; border-radius: 10px; text-align: center; }
.header h1 { color: #66ccff; margin-bottom: 10px; }
.snowflake-icon { font-size: 24px; margin-right: 10px; }
.footer { text-align: center; margin-top: 20px; font-size: 0.9em; color: #999; }
.parameter-section { background-color: #2a2a3a; padding: 15px; border-radius: 8px; margin-bottom: 15px; }
.parameter-section h3 { margin-top: 0; color: #66ccff; }
.example-section { background-color: #223344; padding: 15px; border-radius: 8px; margin-bottom: 15px; }
.example-section h3 { margin-top: 0; color: #66ffaa; }
.model-select { background-color: #2a2a4a; padding: 10px; border-radius: 8px; margin-bottom: 15px; }
"""

# Process-wide cache of loaded models, filled lazily on first use:
# MODEL_CONFIG key (model version label) -> (model, tokenizer).
model_registry = {}
|
|
|
def load_model_cpu(model_id):
    """Load a tokenizer and causal language model for ``model_id`` on the CPU.

    Weights are requested as float32 with no device map, so loading does not
    initialise CUDA; the caller decides later where the model runs.

    Args:
        model_id: Repository id passed to ``from_pretrained``.

    Returns:
        tuple: ``(model, tokenizer)``. When the tokenizer defines no pad token,
        its EOS token is assigned as the pad token.
    """
    print(f"Loading model on CPU: {model_id}")

    tok = AutoTokenizer.from_pretrained(model_id)
    # Generation passes pad_token_id, so make sure one exists.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    cpu_load_options = {
        "torch_dtype": torch.float32,
        "device_map": None,
        "low_cpu_mem_usage": True,
    }
    lm = AutoModelForCausalLM.from_pretrained(model_id, **cpu_load_options)

    return lm, tok
|
|
|
@spaces.GPU
def generate_text_gpu(prompt, model_version, temperature, top_p, top_k, max_new_tokens):
    """Generate a continuation of ``prompt`` with the selected model.

    Decorated with ``@spaces.GPU``. The model/tokenizer pair for
    ``model_version`` is loaded on first use via ``load_model_cpu`` and cached
    in ``model_registry``; the model is moved to CUDA when a GPU is available.

    Args:
        prompt: Text to continue.
        model_version: Key into ``MODEL_CONFIG``.
        temperature: Sampling temperature; sampling is enabled only when > 0.
        top_p: Nucleus-sampling threshold (used only when sampling).
        top_k: Top-k cutoff (used only when sampling); coerced to ``int``.
        max_new_tokens: Maximum number of generated tokens; coerced to ``int``.

    Returns:
        tuple: ``(generated_text, None)`` on success, or
        ``("Error generating response: <msg>", "<msg>")`` on any exception.
        Exceptions are reported through the return value, not raised.
    """
    try:
        if model_version not in model_registry:
            model_id = MODEL_CONFIG[model_version]
            model_registry[model_version] = load_model_cpu(model_id)

        model, tokenizer = model_registry[model_version]

        if torch.cuda.is_available():
            model = model.cuda()
            device = "cuda"
        else:
            device = "cpu"

        # NOTE(review): max_length is kept because the pipeline may also use it
        # to cap prompt tokenization; at generate() time max_new_tokens is what
        # bounds output length. Confirm against the installed transformers.
        pipeline = TextGenerationPipeline(
            model=model,
            tokenizer=tokenizer,
            return_full_text=False,
            max_length=MAX_LENGTH,
            device=device,
        )

        do_sample = temperature > 0
        generation_kwargs = {
            "do_sample": do_sample,
            # Slider/API values may arrive as floats; token counts must be ints.
            "max_new_tokens": int(max_new_tokens),
            "pad_token_id": tokenizer.pad_token_id,
            "num_return_sequences": 1,
        }
        if do_sample:
            # Sampling knobs only matter when sampling; top_k must be an int.
            generation_kwargs.update(
                temperature=temperature,
                top_p=top_p,
                top_k=int(top_k),
            )

        outputs = pipeline(prompt, **generation_kwargs)
        return outputs[0]["generated_text"], None

    except Exception as e:
        return f"Error generating response: {str(e)}", str(e)
|
|
|
def _format_history(history):
    """Render history entries as the plain-text transcript shown in the UI."""
    lines = []
    for entry in history:
        prefix = "👤 User: " if entry["role"] == "user" else f"❄️ [{entry.get('model', 'Model')}]: "
        lines.append(f"{prefix}{entry['content']}")
    return "\n\n".join(lines)


def generate_text(prompt, model_version, temperature, top_p, top_k, max_new_tokens, history=None):
    """Run one chat turn and return the updated conversation.

    Appends the user turn, calls ``generate_text_gpu``, appends the assistant
    turn (prefixed with ``[ERROR] `` when generation failed) tagged with the
    model version, and renders the whole conversation as text.

    Args:
        prompt: The user's message.
        model_version: Key into ``MODEL_CONFIG``.
        temperature, top_p, top_k, max_new_tokens: Generation parameters
            forwarded unchanged to ``generate_text_gpu``.
        history: Prior turns as dicts with ``role``/``content`` (and ``model``
            for assistant turns). Not mutated; ``None`` means empty.

    Returns:
        tuple: ``(response, new_history, transcript)`` where ``response`` is
        the generated text or an error message, ``new_history`` is a new list
        including this turn, and ``transcript`` is the formatted conversation.
    """
    # Work on a copy so the caller's list is never modified in place.
    history = list(history) if history is not None else []
    history.append({"role": "user", "content": prompt})

    try:
        response, error = generate_text_gpu(
            prompt, model_version, temperature, top_p, top_k, max_new_tokens
        )
    except Exception as e:
        # generate_text_gpu reports failures via its second return value;
        # anything that still escapes it is reported the same way here.
        response = f"Error in generation pipeline: {str(e)}"
        error = str(e)

    # `is not None` so an exception with an empty message is still flagged.
    content = f"[ERROR] {response}" if error is not None else response
    history.append({"role": "assistant", "content": content, "model": model_version})

    # Both the success and the failure path render the same formatted
    # transcript (previously the failure path showed a raw list repr).
    return response, history, _format_history(history)
|
|
|
def clear_conversation():
    """Produce reset values for a fresh conversation.

    Returns:
        tuple: ``(prompt_text, history, transcript)``: an empty prompt, a new
        empty history list, and an empty transcript, in that order.
    """
    blank_text = ""
    fresh_history = []
    return blank_text, fresh_history, blank_text
|
|
|
def create_demo():
    """Build the Gradio Blocks UI for the SnowflakeCore demo.

    The page contains:
    - a header;
    - a model-version radio (one option per ``MODEL_CONFIG`` key, defaulting
      to the first);
    - a read-only conversation transcript;
    - a prompt box with Send / Clear buttons;
    - the latest model response;
    - collapsible generation-parameter sliders;
    - clickable example prompts;
    - a footer showing the current year.

    Send (button click or Enter in the prompt box) calls ``generate_text``.
    Clear resets the prompt, the stored history, the transcript and the
    latest-response box.

    Returns:
        gr.Blocks: the constructed, not yet launched, demo.
    """
    with gr.Blocks(css=css) as demo:
        gr.HTML("""
        <div class="header">
            <h1><span class="snowflake-icon">❄️</span> SnowflakeCore Demo Interface</h1>
            <p>Experience the capabilities of the SnowflakeCore series language models</p>
        </div>
        """)

        with gr.Column():
            model_names = list(MODEL_CONFIG.keys())
            with gr.Row(elem_classes="model-select"):
                model_version = gr.Radio(
                    choices=model_names,
                    value=model_names[0],
                    label="Select Model Version",
                    info="Choose which SnowflakeCore model to use",
                )

            chat_history_display = gr.Textbox(
                value="",
                label="Conversation History",
                lines=10,
                max_lines=30,
                interactive=False,
            )

            # Per-session conversation (list of role/content dicts) that
            # generate_text reads and returns.
            history_state = gr.State([])

            with gr.Row():
                with gr.Column(scale=4):
                    prompt = gr.Textbox(
                        placeholder="Type your message here...",
                        label="Your Input",
                        lines=2,
                    )
                with gr.Column(scale=1):
                    submit_btn = gr.Button("Send", variant="primary")
                    clear_btn = gr.Button("Clear Conversation")

            response_output = gr.Textbox(
                value="",
                label="Model Response",
                lines=5,
                max_lines=10,
                interactive=False,
            )

            with gr.Accordion("Generation Parameters", open=False):
                with gr.Column(elem_classes="parameter-section"):
                    with gr.Row():
                        with gr.Column():
                            temperature = gr.Slider(
                                minimum=TEMPERATURE_MIN, maximum=TEMPERATURE_MAX,
                                value=TEMPERATURE_DEFAULT, step=0.05,
                                label="Temperature",
                            )
                            top_p = gr.Slider(
                                minimum=TOP_P_MIN, maximum=TOP_P_MAX,
                                value=TOP_P_DEFAULT, step=0.05,
                                label="Top-p",
                            )
                        with gr.Column():
                            top_k = gr.Slider(
                                minimum=TOP_K_MIN, maximum=TOP_K_MAX,
                                value=TOP_K_DEFAULT, step=1,
                                label="Top-k",
                            )
                            max_new_tokens = gr.Slider(
                                minimum=MAX_NEW_TOKENS_MIN, maximum=MAX_NEW_TOKENS_MAX,
                                value=MAX_NEW_TOKENS_DEFAULT, step=8,
                                label="Maximum New Tokens",
                            )

            examples = [
                "Write a short story about a snowflake that comes to life.",
                "Explain the concept of artificial neural networks to a 10-year-old.",
                "What are some interesting applications of natural language processing?",
                "Write a haiku about programming.",
                "Create a dialogue between two AI researchers discussing the future of language models.",
            ]

            with gr.Accordion("Example Prompts", open=True):
                with gr.Column(elem_classes="example-section"):
                    gr.Examples(
                        examples=examples,
                        inputs=prompt,
                        label="Click on an example to try it",
                        examples_per_page=5,
                    )

        gr.HTML(f"""
        <div class="footer">
            <p>Snowflake Models Demo • Created with Gradio • {datetime.datetime.now().year}</p>
        </div>
        """)

        # Send button and Enter in the prompt box trigger the same generation.
        generation_inputs = [prompt, model_version, temperature, top_p, top_k, max_new_tokens, history_state]
        generation_outputs = [response_output, history_state, chat_history_display]
        submit_btn.click(fn=generate_text, inputs=generation_inputs, outputs=generation_outputs)
        prompt.submit(fn=generate_text, inputs=generation_inputs, outputs=generation_outputs)

        # clear_conversation resets prompt/history/transcript; additionally
        # blank the response box so no stale answer survives a clear.
        clear_btn.click(
            fn=lambda: (*clear_conversation(), ""),
            inputs=[],
            outputs=[prompt, history_state, chat_history_display, response_output],
        )

    return demo
|
|
|
|
|
# Build the UI at import time so `demo` is available as a module-level name.
# NOTE(review): presumably a hosting runtime (e.g. Hugging Face Spaces) imports
# this module and serves `demo` directly — confirm before moving this into main.
print("Initializing Snowflake Models Demo...")
demo = create_demo()

# Start the server only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()