"""Gradio demo for the FlameF0X Snowflake G0 text-generation language models."""

import datetime

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextGenerationPipeline

# Model repositories on the Hugging Face Hub
MODEL_ID_V1 = "FlameF0X/Snowflake-G0-Release"
MODEL_ID_V2 = "FlameF0X/Snowflake-G0-Release-2"

# Generation limits and slider bounds/defaults for the UI
MAX_LENGTH = 384
TEMPERATURE_MIN = 0.1
TEMPERATURE_MAX = 2.0
TEMPERATURE_DEFAULT = 0.7
TOP_P_MIN = 0.1
TOP_P_MAX = 1.0
TOP_P_DEFAULT = 0.9
TOP_K_MIN = 1
TOP_K_MAX = 100
TOP_K_DEFAULT = 40
MAX_NEW_TOKENS_MIN = 16
MAX_NEW_TOKENS_MAX = 1024
MAX_NEW_TOKENS_DEFAULT = 256

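# Note on the sampling controls: temperature rescales the logits before sampling,
# top_p and top_k truncate the candidate token set, and max_new_tokens bounds only
# the generated continuation, not the prompt.
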
css = """
.gradio-container {
    background-color: #1e1e2f !important;
    color: #e0e0e0 !important;
}
.header {
    background-color: #2b2b3c;
    padding: 20px;
    margin-bottom: 20px;
    border-radius: 10px;
    text-align: center;
}
.header h1 {
    color: #66ccff;
    margin-bottom: 10px;
}
.snowflake-icon {
    font-size: 24px;
    margin-right: 10px;
}
.footer {
    text-align: center;
    margin-top: 20px;
    font-size: 0.9em;
    color: #999;
}
.parameter-section {
    background-color: #2a2a3a;
    padding: 15px;
    border-radius: 8px;
    margin-bottom: 15px;
}
.parameter-section h3 {
    margin-top: 0;
    color: #66ccff;
}
.example-section {
    background-color: #223344;
    padding: 15px;
    border-radius: 8px;
    margin-bottom: 15px;
}
.example-section h3 {
    margin-top: 0;
    color: #66ffaa;
}
.model-select {
    background-color: #2a2a4a;
    padding: 10px;
    border-radius: 8px;
    margin-bottom: 15px;
}
"""

# Model/tokenizer/pipeline handles, populated by load_models_and_tokenizers()
model_v1 = None
tokenizer_v1 = None
pipeline_v1 = None
model_v2 = None
tokenizer_v2 = None
pipeline_v2 = None

def _load_single_model(model_id):
    """Load the tokenizer, model, and text-generation pipeline for one model ID."""
    print(f"Loading model from {model_id}...")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # from_pretrained picks up safetensors weights automatically when they are present,
    # so no separate safetensors loading path is needed.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        device_map="auto"
    )

    pipeline = TextGenerationPipeline(
        model=model,
        tokenizer=tokenizer,
        return_full_text=False,
        max_length=MAX_LENGTH
    )
    return model, tokenizer, pipeline


def load_models_and_tokenizers():
    """Load both Snowflake model versions into the module-level globals."""
    global model_v1, tokenizer_v1, pipeline_v1, model_v2, tokenizer_v2, pipeline_v2

    model_v1, tokenizer_v1, pipeline_v1 = _load_single_model(MODEL_ID_V1)
    model_v2, tokenizer_v2, pipeline_v2 = _load_single_model(MODEL_ID_V2)

    return (model_v1, tokenizer_v1, pipeline_v1), (model_v2, tokenizer_v2, pipeline_v2)

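# A minimal usage sketch (commented out; assumes the model downloads and loads):
#
#   _, _, pipe = _load_single_model(MODEL_ID_V2)
#   print(pipe("Hello, Snowflake!", max_new_tokens=32)[0]["generated_text"])
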
def generate_text(
    prompt,
    model_version,
    temperature=TEMPERATURE_DEFAULT,
    top_p=TOP_P_DEFAULT,
    top_k=TOP_K_DEFAULT,
    max_new_tokens=MAX_NEW_TOKENS_DEFAULT,
    history=None
):
    """Generate a response with the selected model and append it to the conversation history."""
    if history is None:
        history = []

    history.append({"role": "user", "content": prompt})

    try:
        # Select the pipeline and tokenizer for the requested model version.
        if model_version == "G0-Release":
            pipeline = pipeline_v1
            tokenizer = tokenizer_v1
            model_name = "Snowflake-G0-Release"
        else:
            pipeline = pipeline_v2
            tokenizer = tokenizer_v2
            model_name = "Snowflake-G0-Release-2"

        outputs = pipeline(
            prompt,
            do_sample=temperature > 0,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.pad_token_id,
            num_return_sequences=1
        )

        response = outputs[0]["generated_text"]

        history.append({"role": "assistant", "content": response, "model": model_name})

        # Render the conversation as plain text for the read-only history textbox.
        formatted_history = []
        for entry in history:
            if entry["role"] == "user":
                role_prefix = "👤 User: "
            else:
                model_indicator = f"[{entry.get('model', 'Snowflake')}]"
                role_prefix = f"❄️ {model_indicator}: "
            formatted_history.append(f"{role_prefix}{entry['content']}")

        return response, history, "\n\n".join(formatted_history)

    except Exception as e:
        error_msg = f"Error generating response: {str(e)}"
        history.append({"role": "assistant", "content": f"[ERROR] {error_msg}", "model": model_version})
        return error_msg, history, str(history)

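# For reference: the pipeline call above returns a list of dicts shaped like
# [{"generated_text": "..."}]; with return_full_text=False the string contains only
# the newly generated continuation, not the original prompt.
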
def clear_conversation():
    """Reset the prompt box, stored history, and history display."""
    return "", [], ""


def apply_preset_example(example, history):
    """Copy a preset example into the prompt box without modifying the history."""
    return example, history

examples = [
    "Write a short story about a snowflake that comes to life.",
    "Explain the concept of artificial neural networks to a 10-year-old.",
    "What are some interesting applications of natural language processing?",
    "Write a haiku about programming.",
    "Create a dialogue between two AI researchers discussing the future of language models."
]

def create_demo():
    with gr.Blocks(css=css) as demo:
        gr.HTML("""
        <div class="header">
            <h1><span class="snowflake-icon">❄️</span> Snowflake Models Demo</h1>
            <p>Experience the capabilities of the Snowflake series language models</p>
        </div>
        """)

        with gr.Column():
            with gr.Row(elem_classes="model-select"):
                model_version = gr.Radio(
                    ["G0-Release", "G0-Release-2"],
                    label="Select Model Version",
                    value="G0-Release-2",
                    info="Choose which Snowflake model to use"
                )

            chat_history_display = gr.Textbox(
                value="",
                label="Conversation History",
                lines=10,
                max_lines=30,
                interactive=False
            )

            # Holds the structured conversation history across interactions.
            history_state = gr.State([])

            with gr.Row():
                with gr.Column(scale=4):
                    prompt = gr.Textbox(
                        placeholder="Type your message here...",
                        label="Your Input",
                        lines=2
                    )
                with gr.Column(scale=1):
                    submit_btn = gr.Button("Send", variant="primary")
                    clear_btn = gr.Button("Clear Conversation")

            response_output = gr.Textbox(
                value="",
                label="Model Response",
                lines=5,
                max_lines=10,
                interactive=False
            )

            with gr.Accordion("Generation Parameters", open=False):
                with gr.Column(elem_classes="parameter-section"):
                    with gr.Row():
                        with gr.Column():
                            temperature = gr.Slider(
                                minimum=TEMPERATURE_MIN,
                                maximum=TEMPERATURE_MAX,
                                value=TEMPERATURE_DEFAULT,
                                step=0.05,
                                label="Temperature",
                                info="Higher = more creative, Lower = more deterministic"
                            )

                            top_p = gr.Slider(
                                minimum=TOP_P_MIN,
                                maximum=TOP_P_MAX,
                                value=TOP_P_DEFAULT,
                                step=0.05,
                                label="Top-p (nucleus sampling)",
                                info="Controls diversity via cumulative probability"
                            )

                        with gr.Column():
                            top_k = gr.Slider(
                                minimum=TOP_K_MIN,
                                maximum=TOP_K_MAX,
                                value=TOP_K_DEFAULT,
                                step=1,
                                label="Top-k",
                                info="Limits word choice to the top k options"
                            )

                            max_new_tokens = gr.Slider(
                                minimum=MAX_NEW_TOKENS_MIN,
                                maximum=MAX_NEW_TOKENS_MAX,
                                value=MAX_NEW_TOKENS_DEFAULT,
                                step=8,
                                label="Maximum New Tokens",
                                info="Controls the length of the generated response"
                            )

            with gr.Accordion("Example Prompts", open=True):
                with gr.Column(elem_classes="example-section"):
                    gr.Examples(
                        examples=examples,
                        inputs=prompt,
                        label="Click on an example to try it",
                        examples_per_page=5
                    )

        gr.HTML(f"""
        <div class="footer">
            <p>Snowflake Models Demo • Created with Gradio • {datetime.datetime.now().year}</p>
        </div>
        """)

        # The Send button and pressing Enter in the prompt box run the same handler.
        submit_btn.click(
            fn=generate_text,
            inputs=[prompt, model_version, temperature, top_p, top_k, max_new_tokens, history_state],
            outputs=[response_output, history_state, chat_history_display]
        )

        prompt.submit(
            fn=generate_text,
            inputs=[prompt, model_version, temperature, top_p, top_k, max_new_tokens, history_state],
            outputs=[response_output, history_state, chat_history_display]
        )

        clear_btn.click(
            fn=clear_conversation,
            inputs=[],
            outputs=[prompt, history_state, chat_history_display]
        )

    return demo

print("Loading Snowflake models and tokenizers...") |
|
try: |
|
(model_v1, tokenizer_v1, pipeline_v1), (model_v2, tokenizer_v2, pipeline_v2) = load_models_and_tokenizers() |
|
print("Models loaded successfully!") |
|
except Exception as e: |
|
print(f"Error loading models: {str(e)}") |
|
|
|
with gr.Blocks(css=css) as error_demo: |
|
gr.HTML(f""" |
|
<div class="header" style="background-color: #ffebee;"> |
|
<h1><span class="snowflake-icon">⚠️</span> Error Loading Models</h1> |
|
<p>There was a problem loading the Snowflake models: {str(e)}</p> |
|
</div> |
|
""") |
|
demo = error_demo |
|
|
|
|
|
demo = create_demo() |
|
|
|
|
|
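# Note: on Hugging Face Spaces the bare demo.launch() below is sufficient; when
# running locally, demo.launch(share=True) is one option for a temporary public link.
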
if __name__ == "__main__":
    demo.launch()