MjolnirThor commited on
Commit
3aa3e20
·
verified ·
1 Parent(s): a96669d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -42
app.py CHANGED
@@ -1,53 +1,56 @@
1
- from datasets import load_dataset
2
- from transformers import (
3
- AutoModelForSeq2SeqLM,
4
- AutoTokenizer,
5
- Trainer,
6
- DataCollatorForSeq2Seq
7
- )
8
- from training_config import training_args
9
- import os
10
 
11
- # Load dataset
12
- dataset = load_dataset("health360/Healix-Shot", split=f"train[:100000]")
13
 
14
  # Initialize model and tokenizer
15
  model_name = "google/flan-t5-large"
16
  tokenizer = AutoTokenizer.from_pretrained(model_name)
17
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
18
 
19
- def tokenize_function(examples):
20
- return tokenizer(
21
- examples['text'],
22
- padding="max_length",
23
- truncation=True,
24
- max_length=512,
25
- return_attention_mask=True
26
- )
27
 
28
- # Process dataset
29
- train_test_split = dataset.train_test_split(test_size=0.1)
30
- tokenized_train = train_test_split['train'].map(
31
- tokenize_function,
32
- batched=True,
33
- remove_columns=dataset.column_names
34
- )
35
- tokenized_eval = train_test_split['test'].map(
36
- tokenize_function,
37
- batched=True,
38
- remove_columns=dataset.column_names
39
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
- # Initialize trainer
42
- trainer = Trainer(
43
- model=model,
44
- args=training_args,
45
- train_dataset=tokenized_train,
46
- eval_dataset=tokenized_eval,
47
- data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)
48
  )
49
 
50
- # Train and save
51
- trainer.train()
52
- model.push_to_hub("MjolnirThor/flan-t5-custom-handler")
53
- tokenizer.push_to_hub("MjolnirThor/flan-t5-custom-handler")
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
4
+ import torch
5
+ import gradio as gr
 
 
 
 
6
 
7
+ app = FastAPI()
 
8
 
9
  # Initialize model and tokenizer
10
  model_name = "google/flan-t5-large"
11
  tokenizer = AutoTokenizer.from_pretrained(model_name)
12
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
13
 
14
+ class Query(BaseModel):
15
+ inputs: str
 
 
 
 
 
 
16
 
17
+ @app.post("/")
18
+ async def generate(query: Query):
19
+ try:
20
+ # Tokenize input
21
+ inputs = tokenizer(query.inputs, return_tensors="pt", max_length=512, truncation=True)
22
+
23
+ # Generate response
24
+ outputs = model.generate(
25
+ inputs.input_ids,
26
+ max_length=512,
27
+ num_beams=4,
28
+ temperature=0.7,
29
+ top_p=0.9,
30
+ repetition_penalty=1.2,
31
+ early_stopping=True
32
+ )
33
+
34
+ # Decode response
35
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
36
+ return {"generated_text": response}
37
+
38
+ except Exception as e:
39
+ raise HTTPException(status_code=500, detail=str(e))
40
+
41
+ # Gradio interface
42
+ def generate_text(prompt):
43
+ query = Query(inputs=prompt)
44
+ response = generate(query)
45
+ return response["generated_text"]
46
 
47
+ iface = gr.Interface(
48
+ fn=generate_text,
49
+ inputs=gr.Textbox(lines=2, placeholder="Enter your text here..."),
50
+ outputs="text",
51
+ title="Medical Assistant",
52
+ description="Ask me anything about medical topics!"
 
53
  )
54
 
55
+ # Mount the Gradio app
56
+ app = gr.mount_gradio_app(app, iface, path="/")