from llama_cpp import Llama
from concurrent.futures import ThreadPoolExecutor
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
import os
from dotenv import load_dotenv
from pydantic import BaseModel
import traceback

load_dotenv()
HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")

global_data = {
    'models': {},
    'tokens': {
        'eos': 'eos_token',
        'pad': 'pad_token',
        'padding': 'padding_token',
        'unk': 'unk_token',
        'bos': 'bos_token',
        'sep': 'sep_token',
        'cls': 'cls_token',
        'mask': 'mask_token'
    }
}

# GGUF models to download from the Hugging Face Hub and serve.
model_configs = [
    {
        "repo_id": "Hjgugugjhuhjggg/mergekit-ties-tzamfyy-Q2_K-GGUF",
        "filename": "mergekit-ties-tzamfyy-q2_k.gguf",
        "name": "my_model",
    }
]

models = {}


def load_model(model_config):
    """Download and load a GGUF model once, caching it in `models`."""
    model_name = model_config['name']
    if model_name in models:
        return models[model_name]
    try:
        model = Llama.from_pretrained(
            repo_id=model_config['repo_id'],
            filename=model_config['filename'],
            use_auth_token=HUGGINGFACE_TOKEN,
        )
        models[model_name] = model
        global_data['models'] = models
        return model
    except Exception as e:
        print(f"Error loading model {model_name}: {e}")
        traceback.print_exc()
        models[model_name] = None
        return None


# Load all configured models at startup.
for config in model_configs:
    load_model(config)


class ChatRequest(BaseModel):
    message: str
    max_tokens_per_part: int = 256


def normalize_input(input_text):
    return input_text.strip()


def remove_duplicates(text):
    """Drop repeated lines while preserving the order of first appearance."""
    unique_lines = []
    seen_lines = set()
    for line in text.split('\n'):
        line = line.strip()
        if line and line not in seen_lines:
            unique_lines.append(line)
            seen_lines.add(line)
    return '\n'.join(unique_lines)


def generate_model_response(model, inputs, max_tokens_per_part):
    """Run a single completion on `model` and return a list of response parts."""
    try:
        if model is None:
            return []
        responses = []
        response = model(inputs, max_tokens=max_tokens_per_part, stop=["\n\n"])
        choices = response.get('choices', [])
        if not choices or 'text' not in choices[0]:
            return ["Error: Invalid model response format"]
        text = choices[0]['text']
        if text:
            responses.append(remove_duplicates(text))
        return responses
    except Exception as e:
        print(f"Error generating response: {e}")
        traceback.print_exc()
        return [f"Error: {e}"]


app = FastAPI()

origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.post("/generate")
async def generate(request: ChatRequest):
    inputs = normalize_input(request.message)

    # Run every loaded model in parallel. Keep each model name paired with its
    # own future (iterating as_completed here would pair names with whichever
    # result happened to finish first and misattribute the responses).
    with ThreadPoolExecutor() as executor:
        futures = {
            model_name: executor.submit(
                generate_model_response, model, inputs, request.max_tokens_per_part
            )
            for model_name, model in models.items()
        }
        responses = [
            {'model': model_name, 'response': future.result()}
            for model_name, future in futures.items()
        ]

    # Group the response parts by model name.
    unique_responses = {}
    for response_set in responses:
        model_name = response_set['model']
        unique_responses.setdefault(model_name, []).extend(response_set['response'])

    # Format as a single markdown-style string, one section per model.
    formatted_response = ""
    for model, response_parts in unique_responses.items():
        formatted_response += f"**{model}:**\n"
        for i, part in enumerate(response_parts):
            formatted_response += f"Part {i+1}:\n{part}\n\n"

    return {"response": formatted_response}


if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)
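
# --- Example client (illustrative only) ---
# A minimal sketch of how the /generate endpoint above could be called,
# assuming the server is running locally on the default port 7860. The
# `requests` dependency, the prompt text, and the timeout value are
# assumptions for illustration, not part of the server itself; run this
# from a separate script or shell, not inside this module.
#
#     import requests
#
#     resp = requests.post(
#         "http://localhost:7860/generate",
#         json={
#             "message": "Summarize what a GGUF model file is.",
#             "max_tokens_per_part": 128,
#         },
#         timeout=300,
#     )
#     resp.raise_for_status()
#     print(resp.json()["response"])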