Larrytech committed on
Commit
4f6b66b
·
1 Parent(s): fd84066

Summarizer_agent Update

Browse files
Files changed (1) hide show
  1. main.py +22 -12
main.py CHANGED
@@ -1,12 +1,13 @@
1
  from fastapi import FastAPI, Form
2
  from fastapi.responses import HTMLResponse
3
- from transformers import AutoTokenizer, AutoModelForCausalLM
4
  import torch
5
 
6
  app = FastAPI()
7
 
8
- # Load Granite 2B model
9
  MODEL_ID = "ibm-granite/granite-3.3-2b-instruct"
 
 
10
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
11
  model = AutoModelForCausalLM.from_pretrained(
12
  MODEL_ID,
@@ -14,6 +15,9 @@ model = AutoModelForCausalLM.from_pretrained(
14
  device_map="auto"
15
  )
16
 
 
 
 
17
  @app.get("/", response_class=HTMLResponse)
18
  def index():
19
  return """
@@ -31,15 +35,21 @@ def index():
31
 
32
  @app.post("/summarize", response_class=HTMLResponse)
33
  def summarize(text: str = Form(...)):
34
- prompt = f"Summarize the following text:\n{text.strip()}\nSummary:"
35
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
36
- summary_ids = model.generate(
37
- **inputs,
 
 
 
38
  max_new_tokens=150,
39
- do_sample=False,
40
- temperature=0.7
 
 
 
41
  )
42
- output = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
43
- # Remove the prompt from the output
44
- summary = output.replace(prompt, "").strip()
45
- return f"<h2>Summary</h2><pre>{summary}</pre><a href='/'>Back</a>"
 
1
  from fastapi import FastAPI, Form
2
  from fastapi.responses import HTMLResponse
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
4
  import torch
5
 
6
  app = FastAPI()
7
 
 
8
  MODEL_ID = "ibm-granite/granite-3.3-2b-instruct"
9
+
10
+ # Load tokenizer and model
11
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
12
  model = AutoModelForCausalLM.from_pretrained(
13
  MODEL_ID,
 
15
  device_map="auto"
16
  )
17
 
18
+ # Use pipeline for easier text generation
19
+ pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=0 if torch.cuda.is_available() else -1)
20
+
21
  @app.get("/", response_class=HTMLResponse)
22
  def index():
23
  return """
 
35
 
36
  @app.post("/summarize", response_class=HTMLResponse)
37
  def summarize(text: str = Form(...)):
38
+ prompt = (
39
+ "Below is a passage of text. Please provide a concise summary in 2-4 sentences.\n\n"
40
+ f"Text:\n{text.strip()}\n\nSummary:"
41
+ )
42
+ # Generate output using the pipeline
43
+ outputs = pipe(
44
+ prompt,
45
  max_new_tokens=150,
46
+ do_sample=True,
47
+ temperature=0.7,
48
+ top_p=0.95,
49
+ eos_token_id=tokenizer.eos_token_id,
50
+ pad_token_id=tokenizer.eos_token_id
51
  )
52
+ output_text = outputs[0]['generated_text']
53
+ # Extract only the summary after 'Summary:'
54
+ summary = output_text.split("Summary:")[-1].strip()
55
+ return f"<h2>Summary</h2><pre>{summary}</pre><a href='/'>Back</a>"