Commit c05ccb6 (verified) · 0 Parent(s)
Super-squash branch 'main' using huggingface_hub
Co-authored-by: FranckAbgrall <[email protected]>
- .gitattributes +35 -0
- Dockerfile +34 -0
- README.md +19 -0
- freqs.json +0 -0
- requirements.txt +6 -0
- src/streamlit_app.py +188 -0
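
Note: the commit message indicates the history was collapsed with huggingface_hub. For reference only, a squash like this is typically produced with `HfApi.super_squash_history`; the sketch below is a hedged illustration, not code from this repository, and the repo id is a placeholder.

# Sketch only: how a history squash of this kind is typically triggered with
# huggingface_hub. The repo id below is a placeholder, not taken from this commit.
from huggingface_hub import HfApi

api = HfApi()  # picks up a saved token from `huggingface-cli login`, if any
api.super_squash_history(
    repo_id="<user>/<space-name>",  # placeholder Space id
    repo_type="space",
    branch="main",
)

The default commit message generated by this call has the form seen in this commit ("Super-squash branch 'main' using huggingface_hub").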
.gitattributes
ADDED
@@ -0,0 +1,35 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
Dockerfile
ADDED
@@ -0,0 +1,34 @@
+FROM python:3.9-slim
+
+RUN apt-get update && apt-get install -y \
+    build-essential \
+    curl \
+    software-properties-common \
+    git \
+    && rm -rf /var/lib/apt/lists/*
+
+
+# Set up a new user named "user" with user ID 1000
+RUN useradd -m -u 1000 user
+
+# Switch to the "user" user
+USER user
+
+# Set home to the user's home directory
+ENV HOME=/home/user \
+    PATH=/home/user/.local/bin:$PATH
+
+# Set the working directory to the user's home directory
+WORKDIR $HOME/app
+
+# Copy the current directory contents into the container at $HOME/app setting the owner to the user
+COPY --chown=user . $HOME/app
+
+# Try and run pip command after setting the user with `USER user` to avoid permission issues with Python
+RUN pip3 install -r requirements.txt
+
+EXPOSE 8501
+
+HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
+
+ENTRYPOINT ["streamlit", "run", "src/streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0", "--server.enableXsrfProtection=false", "--server.enableCORS=false"]
README.md
ADDED
@@ -0,0 +1,19 @@
+---
+title: Dnbr Tagger Preview1 Demo
+emoji: 🚀
+colorFrom: red
+colorTo: red
+sdk: docker
+app_port: 8501
+tags:
+- streamlit
+pinned: false
+short_description: Demo of a Danbooru tagger
+---
+
+# Welcome to Streamlit!
+
+Edit `/src/streamlit_app.py` to customize this app to your heart's desire. :heart:
+
+If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
+forums](https://discuss.streamlit.io).
freqs.json
ADDED
The diff for this file is too large to render.
See raw diff
requirements.txt
ADDED
@@ -0,0 +1,6 @@
+altair
+pandas
+streamlit
+timm
+torch
+torchvision
src/streamlit_app.py
ADDED
@@ -0,0 +1,188 @@
+import streamlit as st
+import numpy as np
+from PIL import Image
+import io  # Used to handle image bytes
+import torch
+import timm
+import json
+from torchvision.transforms.v2 import (
+    ToImage,
+    Compose,
+    ToDtype,
+    Normalize,
+)
+import pandas as pd
+import requests
+
+
+st.set_page_config(layout="wide")
+
+device = ["cpu", "cuda"][torch.cuda.is_available()]
+
+
+def NativeResize(patch_size, seq_len_range):
+    p, lo, hi = patch_size, *seq_len_range
+    refs = sorted(
+        [
+            (i / j, i * p, j * p)
+            for i in range(4, 100)
+            for j in range(4, 100)
+            if 0.33 <= i / j <= 3 and lo <= i * j < hi
+        ]
+    )
+
+    def get_ratio(r):
+        return min(refs, key=lambda rr: max(r, rr[0]) / min(r, rr[0]) - 1)
+
+    def f(im: Image.Image):
+        w, h = im.size
+        _, sw, sh = get_ratio(w / h)
+        return im.resize((sw, sh), resample=Image.Resampling.BICUBIC)
+
+    return f
+
+def load_json_from_url(url):
+    try:
+        response = requests.get(url)
+        response.raise_for_status()  # Raise an exception for bad status codes
+        parsed_json = json.loads(response.text)
+        return parsed_json
+    except requests.exceptions.RequestException as e:
+        print(f"Error fetching data from URL: {e}")
+        return None
+    except json.JSONDecodeError as e:
+        print(f"Error decoding JSON data: {e}")
+        return None
+
+
+@st.cache_data
+def load_tags():
+    freqs = load_json_from_url("https://huggingface.co/gustproof/dnbr-tagger-preview1/raw/main/freqs.json")
+    freqs = [*freqs, ("PLACEHOLDER", 0)]
+    return freqs
+
+
+tags = load_tags()
+
+
+@st.cache_resource
+def load_model():
+    torch.set_grad_enabled(False)
+    model = (
+        timm.create_model(
+            "hf_hub:gustproof/dnbr-tagger-preview1",
+            pretrained=True,
+            dynamic_img_size=True,
+        )
+        .eval()
+        .to(device)
+    )
+    print("loaded model")
+    tf = Compose(
+        [
+            ToImage(),
+            ToDtype(torch.float, scale=True),
+            Normalize(mean=[0.4850, 0.4560, 0.4060], std=[0.2290, 0.2240, 0.2250]),
+        ]
+    )
+
+    class Model:
+        def __init__(self):
+            self.class_names = load_tags()
+
+        def predict(self, img):
+            x = tf(img).unsqueeze(0).to(device)
+            probs = model(x).squeeze(0).sigmoid().cpu()
+            return probs
+
+    return Model()
+
+
+model = load_model()
+
+# --- Streamlit App Layout ---
+st.title("Tagger demo")
+st.write("Model: [gustproof/dnbr-tagger-preview1](https://huggingface.co/gustproof/dnbr-tagger-preview1)")
+st.write(
+    "Upload an image to see predicted labels."
+)
+st.write("---")
+
+# --- Sidebar for Controls ---
+st.sidebar.header("Settings")
+# Confidence Threshold Slider
+confidence_threshold = st.sidebar.slider(
+    "Threshold (recommended: ~0.4-~0.6)",
+    min_value=0.0,
+    max_value=1.0,
+    value=0.5,  # Default threshold
+    step=0.01,
+)
+
+# --- Main Area ---
+col1, col2 = st.columns(2)
+
+with col1:
+    st.header("Upload Image")
+    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
+
+    if uploaded_file is not None:
+        # Read the image bytes
+        image_bytes = uploaded_file.getvalue()
+        # Display the uploaded image
+        try:
+            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
+            image = NativeResize(14, (270, 301))(image)
+            st.info(f"Resized (width, height): {image.size}")
+            st.image(image, caption="Uploaded Image.")
+        except Exception as e:
+            st.error(f"Error opening image: {e}")
+            st.warning("Please upload a valid image file (JPG, JPEG, PNG).")
+            uploaded_file = None  # Reset uploaded_file so processing stops
+
+with col2:
+    st.header("Predictions")
+    if uploaded_file is not None:
+        with st.spinner("Computing..."):
+            try:
+                scores = model.predict(image)
+                filtered_results = [
+                    (i, p) for i, p in enumerate(scores) if p >= confidence_threshold
+                ]
+                filtered_results.sort(key=lambda x: x[1], reverse=True)
+
+                if filtered_results:
+                    get_category = lambda ti: (
+                        "Rating" if ti < 4 else "General" if ti < 8856 else "Character"
+                    )
+                    df = pd.DataFrame(
+                        [
+                            (i, tags[ti][0], f"{p:.4f}", get_category(ti), tags[ti][1])
+                            for i, (ti, p) in enumerate(filtered_results[:150], 1)
+                        ],
+                        columns=[
+                            "Rank",
+                            "Label",
+                            "Score",
+                            "Category",
+                            "Dataset frequency",
+                        ],
+                    )
+                    st.dataframe(
+                        df,
+                        hide_index=True,
+                        column_config={
+                            "Dataset frequency": st.column_config.NumberColumn(
+                                format="localized"
+                            )
+                        },
+                    )
+                else:
+                    st.info("No labels meet the current threshold.")
+
+            except Exception as e:
+                st.error("An error occurred during prediction or processing:")
+                st.exception(e)  # Shows the full traceback
+
+    else:
+        st.info("Upload an image using the panel on the left to see predictions.")
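
For reference, the same inference pipeline can be exercised outside Streamlit. The sketch below reuses the model id, normalization constants, tag-list URL, and sigmoid thresholding from src/streamlit_app.py; "example.jpg" is a placeholder path, and the fixed 238x238 resize (17x17 patches of 14 px, about 289 tokens) is a simplification of the app's aspect-ratio-aware NativeResize, not the app's exact behavior.

# Standalone sketch of the inference pipeline used in src/streamlit_app.py.
# Assumptions: "example.jpg" is a placeholder image path; the fixed 238x238
# resize stands in for the app's NativeResize(14, (270, 301)).
import json

import requests
import timm
import torch
from PIL import Image
from torchvision.transforms.v2 import Compose, Normalize, ToDtype, ToImage

device = "cuda" if torch.cuda.is_available() else "cpu"

model = (
    timm.create_model(
        "hf_hub:gustproof/dnbr-tagger-preview1",
        pretrained=True,
        dynamic_img_size=True,
    )
    .eval()
    .to(device)
)

# Tag names and dataset frequencies, same source as the app.
tags = json.loads(
    requests.get(
        "https://huggingface.co/gustproof/dnbr-tagger-preview1/raw/main/freqs.json"
    ).text
)
tags = [*tags, ("PLACEHOLDER", 0)]  # mirrors the app's extra placeholder entry

tf = Compose(
    [
        ToImage(),
        ToDtype(torch.float, scale=True),
        Normalize(mean=[0.4850, 0.4560, 0.4060], std=[0.2290, 0.2240, 0.2250]),
    ]
)

image = (
    Image.open("example.jpg")
    .convert("RGB")
    .resize((238, 238), resample=Image.Resampling.BICUBIC)
)

with torch.no_grad():
    probs = model(tf(image).unsqueeze(0).to(device)).squeeze(0).sigmoid().cpu()

threshold = 0.5  # same default as the app's sidebar slider
for ti, p in sorted(enumerate(probs.tolist()), key=lambda t: t[1], reverse=True):
    if p >= threshold:
        print(f"{tags[ti][0]}\t{p:.4f}")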