MjolnirThor committed on
Commit
13b79a4
·
verified ·
1 Parent(s): 232a012

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -20
app.py CHANGED
@@ -1,24 +1,53 @@
1
- from fastapi import FastAPI, HTTPException
2
- from handler import EndpointHandler
3
- from pydantic import BaseModel
 
 
 
 
 
 
4
 
5
class Input(BaseModel):
    # Request payload for POST /generate: the raw text prompt to run
    # through the model handler.
    inputs: str
7
 
8
# Module-level singletons: the FastAPI application and the inference handler.
# EndpointHandler is project-local — presumably loads the model once at
# startup; verify against handler.py.
app = FastAPI()
handler = EndpointHandler()
 
 
10
 
11
@app.post("/generate")
async def generate(input_data: Input):
    """Run the handler on the submitted text and return its raw result.

    Any exception raised inside the handler is surfaced to the client as
    an HTTP 500 whose detail is the stringified exception.
    """
    payload = {"inputs": input_data.inputs}
    try:
        return handler(payload)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
 
18
 
19
@app.get("/")
async def root():
    """Landing endpoint: identify the API and show how to call it."""
    info = {
        "message": "FLAN-T5 Custom Handler API",
        "usage": "POST /generate with {'inputs': 'your text here'}",
    }
    return info
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import load_dataset
2
+ from transformers import (
3
+ AutoModelForSeq2SeqLM,
4
+ AutoTokenizer,
5
+ Trainer,
6
+ DataCollatorForSeq2Seq
7
+ )
8
+ from training_config import training_args
9
+ import os
10
 
11
# Load the first 100k rows of the Healix-Shot corpus for fine-tuning.
# Fixed: the original wrote f"train[:100000]" — an f-string with no
# placeholders; a plain string literal is what was intended.
dataset = load_dataset("health360/Healix-Shot", split="train[:100000]")
13
 
14
# Initialize model and tokenizer.
# AutoModelForSeq2SeqLM loads the encoder-decoder (T5-style) checkpoint;
# the tokenizer comes from the same repo so vocab/special tokens match.
model_name = "google/flan-t5-large"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
18
 
19
def tokenize_function(examples):
    """Tokenize a batch of raw-text examples for seq2seq fine-tuning.

    Pads/truncates every example to a fixed 512-token window and mirrors
    the input ids into a ``labels`` column. Fixed: the original produced
    no ``labels`` at all, so ``Trainer`` had nothing to compute a loss
    against and ``trainer.train()`` would fail on this dataset.

    NOTE(review): copying inputs to labels is a self-supervised
    reconstruction objective — confirm this matches the intended
    fine-tuning task for Healix-Shot.
    """
    tokens = tokenizer(
        examples['text'],
        padding="max_length",
        truncation=True,
        max_length=512,
        return_attention_mask=True
    )
    # Trainer's loss computation requires a "labels" key per example.
    tokens["labels"] = [list(ids) for ids in tokens["input_ids"]]
    return tokens
27
 
28
# Split 90/10 into train/eval, then tokenize each side in batches,
# dropping the raw text columns so only model inputs remain.
splits = dataset.train_test_split(test_size=0.1)
map_kwargs = dict(batched=True, remove_columns=dataset.column_names)
tokenized_train = splits["train"].map(tokenize_function, **map_kwargs)
tokenized_eval = splits["test"].map(tokenize_function, **map_kwargs)
40
+
41
# Initialize trainer.
# NOTE(review): plain Trainer is paired with a seq2seq collator here;
# Seq2SeqTrainer would additionally support generation-based evaluation —
# confirm plain Trainer is intended.
trainer = Trainer(
    model=model,
    args=training_args,  # hyperparameters come from training_config.py
    train_dataset=tokenized_train,
    eval_dataset=tokenized_eval,
    # Dynamically pads inputs and aligns labels for encoder-decoder batches.
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)
)

# Train and save: fine-tune, then publish both the model weights and the
# tokenizer to the same Hub repository (requires HF auth to be configured).
trainer.train()
model.push_to_hub("MjolnirThor/flan-t5-custom-handler")
tokenizer.push_to_hub("MjolnirThor/flan-t5-custom-handler")