import torch
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the model and tokenizer once at startup so every request reuses the same weights.
model_name = "databricks/dolly-v2-3b"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,  # half precision so the 3B model fits in GPU memory
    device_map="auto",          # let accelerate place the weights (GPU if available)
)

# FastAPI app
app = FastAPI()

# Request body schema: a single user message.
class ChatInput(BaseModel):
    user_input: str

@app.post("/chat")
async def chat(chat_input: ChatInput):
    # Move inputs to wherever device_map placed the model instead of hard-coding "cuda".
    inputs = tokenizer(chat_input.user_input, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=200,  # cap *generated* tokens; max_length would count the prompt too
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,  # avoid the missing-pad-token warning
    )
    # Decode only the newly generated tokens, not the echoed prompt.
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    response_text = tokenizer.decode(new_tokens, skip_special_tokens=True)
    return {"response": response_text}
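
# A minimal sketch for serving and exercising the endpoint, assuming uvicorn is
# installed (`pip install uvicorn`); the host/port below are illustrative defaults,
# not part of the original snippet.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)

# Example request once the server is running (matches the port chosen above):
#   curl -X POST http://localhost:8000/chat \
#        -H "Content-Type: application/json" \
#        -d '{"user_input": "Explain attention in one sentence."}'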