PAVULURI KIRAN commited on
Commit
9c37d23
·
1 Parent(s): 146a932

Initial commit

Browse files
Files changed (3) hide show
  1. Dockerfile +20 -0
  2. app.py +81 -0
  3. requirement.txt +6 -0
Dockerfile ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Use an official Python runtime
2
+ FROM python:3.10
3
+
4
+ # Set the working directory
5
+ WORKDIR /app
6
+
7
+ # Copy the requirements file
8
+ COPY requirements.txt .
9
+
10
+ # Install dependencies
11
+ RUN pip install --no-cache-dir -r requirements.txt
12
+
13
+ # Copy the FastAPI app file
14
+ COPY app.py .
15
+
16
+ # Expose the port FastAPI runs on
17
+ EXPOSE 7860
18
+
19
+ # Command to run FastAPI using Uvicorn
20
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
app.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, File, UploadFile
2
+ import torch
3
+ from transformers import AutoProcessor, LlavaForConditionalGeneration
4
+ from PIL import Image
5
+ import io
6
+ import base64
7
+
8
+ # Initialize FastAPI app
9
+ app = FastAPI()
10
+
11
+ # Load the model and processor from Hugging Face
12
+ model_name = "mervinpraison/Llama-3.2-11B-Vision-Radiology-mini"
13
+ device = "cuda" if torch.cuda.is_available() else "cpu"
14
+
15
+ processor = AutoProcessor.from_pretrained(model_name)
16
+ model = LlavaForConditionalGeneration.from_pretrained(model_name).to(device)
17
+
18
+ @app.post("/predict/")
19
+ async def predict(file: UploadFile = File(...)):
20
+ try:
21
+ # Read image
22
+ image_bytes = await file.read()
23
+ image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
24
+
25
+ # Convert image to base64 (for compatibility with reference implementation)
26
+ buffered = io.BytesIO()
27
+ image.save(buffered, format="JPEG")
28
+ base64_image = base64.b64encode(buffered.getvalue()).decode("utf-8")
29
+
30
+ # Step 1: Validate Image Type (Ensure it’s an X-ray or CT scan)
31
+ validation_prompt = "Is this a medical X-ray or CT scan? Answer only 'yes' or 'no'."
32
+ validation_inputs = processor(text=validation_prompt, images=image, return_tensors="pt").to(device)
33
+
34
+ with torch.no_grad():
35
+ validation_output = model.generate(**validation_inputs, max_new_tokens=10, temperature=0.1, top_p=0.7, top_k=50, repetition_penalty=1)
36
+
37
+ validation_result = processor.batch_decode(validation_output, skip_special_tokens=True)[0].strip().lower()
38
+
39
+ if "yes" not in validation_result:
40
+ return {"error": "Uploaded image is not an X-ray or CT scan. Please upload a valid medical imaging scan."}
41
+
42
+ # Step 2: Generate Structured Medical Analysis
43
+ analysis_prompt = """Please analyze this X-ray image and provide a detailed medical report using the following format:
44
+
45
+ Type of X-ray:
46
+ [Describe the type and orientation of the X-ray]
47
+
48
+ Key Findings:
49
+ • [List each finding on a new line with a bullet point]
50
+ • [Focus on normal and abnormal findings]
51
+ • [Include major anatomical structures]
52
+
53
+ Potential Conditions:
54
+ • [List potential conditions based on findings]
55
+ • [Include likelihood assessments]
56
+
57
+ Recommendations:
58
+ • [Provide any follow-up recommendations]
59
+
60
+ Please provide the analysis in plain text without any special characters or markdown formatting."""
61
+
62
+ analysis_inputs = processor(text=analysis_prompt, images=image, return_tensors="pt").to(device)
63
+
64
+ with torch.no_grad():
65
+ analysis_output = model.generate(**analysis_inputs, max_new_tokens=512, temperature=0.7, top_p=0.7, top_k=50, repetition_penalty=1)
66
+
67
+ analysis_content = processor.batch_decode(analysis_output, skip_special_tokens=True)[0]
68
+
69
+ # Step 3: Clean Up Response (Remove special characters, markdown formatting)
70
+ cleaned_analysis = (
71
+ analysis_content.replace("**", "") # Remove double asterisks
72
+ .replace("*", "•") # Replace single asterisks with bullet points
73
+ .replace("_", "") # Remove underscores
74
+ .replace("#", "") # Remove markdown headers
75
+ .strip()
76
+ )
77
+
78
+ return {"analysis": cleaned_analysis}
79
+
80
+ except Exception as e:
81
+ return {"error": str(e)}
requirement.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ torch
4
+ transformers
5
+ pillow
6
+ python-multipart