gustproof FranckAbgrall HF Staff commited on
Commit
c05ccb6
·
verified ·
0 Parent(s):

Super-squash branch 'main' using huggingface_hub

Browse files

Co-authored-by: FranckAbgrall <[email protected]>

Files changed (6) hide show
  1. .gitattributes +35 -0
  2. Dockerfile +34 -0
  3. README.md +19 -0
  4. freqs.json +0 -0
  5. requirements.txt +6 -0
  6. src/streamlit_app.py +188 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.9-slim
2
+
3
+ RUN apt-get update && apt-get install -y \
4
+ build-essential \
5
+ curl \
6
+ software-properties-common \
7
+ git \
8
+ && rm -rf /var/lib/apt/lists/*
9
+
10
+
11
+ # Set up a new user named "user" with user ID 1000
12
+ RUN useradd -m -u 1000 user
13
+
14
+ # Switch to the "user" user
15
+ USER user
16
+
17
+ # Set home to the user's home directory
18
+ ENV HOME=/home/user \
19
+ PATH=/home/user/.local/bin:$PATH
20
+
21
+ # Set the working directory to the user's home directory
22
+ WORKDIR $HOME/app
23
+
24
+ # Copy the current directory contents into the container at $HOME/app setting the owner to the user
25
+ COPY --chown=user . $HOME/app
26
+
27
+ # Try and run pip command after setting the user with `USER user` to avoid permission issues with Python
28
+ RUN pip3 install -r requirements.txt
29
+
30
+ EXPOSE 8501
31
+
32
+ HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
33
+
34
+ ENTRYPOINT ["streamlit", "run", "src/streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0", "--server.enableXsrfProtection=false", "--server.enableCORS=false"]
README.md ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Dnbr Tagger Preview1 Demo
3
+ emoji: 🚀
4
+ colorFrom: red
5
+ colorTo: red
6
+ sdk: docker
7
+ app_port: 8501
8
+ tags:
9
+ - streamlit
10
+ pinned: false
11
+ short_description: Demo of a Danbooru tagger
12
+ ---
13
+
14
+ # Welcome to Streamlit!
15
+
16
+ Edit `/src/streamlit_app.py` to customize this app to your heart's desire. :heart:
17
+
18
+ If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
19
+ forums](https://discuss.streamlit.io).
freqs.json ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ altair
2
+ pandas
3
+ streamlit
4
+ timm
5
+ torch
6
+ torchvision
src/streamlit_app.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import numpy as np
3
+ from PIL import Image
4
+ import io # Used to handle image bytes
5
+ import torch
6
+ import timm
7
+ import json
8
+ from torchvision.transforms.v2 import (
9
+ ToImage,
10
+ Compose,
11
+ ToDtype,
12
+ Normalize,
13
+ )
14
+ import pandas as pd
15
+ import requests
16
+
17
+
18
+ st.set_page_config(layout="wide")
19
+
20
+ device = ["cpu", "cuda"][torch.cuda.is_available()]
21
+
22
+
23
+ def NativeResize(patch_size, seq_len_range):
24
+ p, lo, hi = patch_size, *seq_len_range
25
+ refs = sorted(
26
+ [
27
+ (i / j, i * p, j * p)
28
+ for i in range(4, 100)
29
+ for j in range(4, 100)
30
+ if 0.33 <= i / j <= 3 and lo <= i * j < hi
31
+ ]
32
+ )
33
+
34
+ def get_ratio(r):
35
+ return min(refs, key=lambda rr: max(r, rr[0]) / min(r, rr[0]) - 1)
36
+
37
+ def f(im: Image.Image):
38
+ w, h = im.size
39
+ _, sw, sh = get_ratio(w / h)
40
+ return im.resize((sw, sh), resample=Image.Resampling.BICUBIC)
41
+
42
+ return f
43
+
44
+ def load_json_from_url(url):
45
+ try:
46
+ response = requests.get(url)
47
+ response.raise_for_status() # Raise an exception for bad status codes
48
+ parsed_json = json.loads(response.text)
49
+ return parsed_json
50
+ except requests.exceptions.RequestException as e:
51
+ print(f"Error fetching data from URL: {e}")
52
+ return None
53
+ except json.JSONDecodeError as e:
54
+ print(f"Error decoding JSON data: {e}")
55
+ return None
56
+
57
+
58
+ @st.cache_data
59
+ def load_tags():
60
+ freqs = load_json_from_url("https://huggingface.co/gustproof/dnbr-tagger-preview1/raw/main/freqs.json")
61
+ freqs = [*freqs, (("PLACEHOLDER", 0))]
62
+ return freqs
63
+
64
+
65
+ tags = load_tags()
66
+
67
+
68
+ @st.cache_resource
69
+ def load_model():
70
+ torch.set_grad_enabled(False)
71
+ model = (
72
+ timm.create_model(
73
+ "hf_hub:gustproof/dnbr-tagger-preview1",
74
+ pretrained=True,
75
+ dynamic_img_size=True,
76
+ )
77
+ .eval()
78
+ .to(device)
79
+ )
80
+ print("loaded model")
81
+ tf = Compose(
82
+ [
83
+ ToImage(),
84
+ ToDtype(torch.float, scale=True),
85
+ Normalize(mean=[0.4850, 0.4560, 0.4060], std=[0.2290, 0.2240, 0.2250]),
86
+ ]
87
+ )
88
+
89
+ class Model:
90
+ def __init__(self):
91
+ self.class_names = load_tags()
92
+
93
+ def predict(self, img):
94
+ x = tf(img).unsqueeze(0).to(device)
95
+ probs = model(x).squeeze(0).sigmoid().cpu()
96
+ return probs
97
+
98
+ return Model()
99
+
100
+
101
+ model = load_model()
102
+
103
+ # --- Streamlit App Layout ---
104
+ st.title("Tagger demo")
105
+ st.write("Model: [gustproof/dnbr-tagger-preview1](https://huggingface.co/gustproof/dnbr-tagger-preview1)")
106
+ st.write(
107
+ "Upload an image to see predicted labels."
108
+ )
109
+ st.write("---")
110
+
111
+ # --- Sidebar for Controls ---
112
+ st.sidebar.header("Settings")
113
+ # Confidence Threshold Slider
114
+ confidence_threshold = st.sidebar.slider(
115
+ "Threshold (recommended: ~0.4-~0.6)",
116
+ min_value=0.0,
117
+ max_value=1.0,
118
+ value=0.5, # Default threshold
119
+ step=0.01,
120
+ )
121
+
122
+ # --- Main Area ---
123
+ col1, col2 = st.columns(2)
124
+
125
+ with col1:
126
+ st.header("Upload Image")
127
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
128
+
129
+ if uploaded_file is not None:
130
+ # Read the image bytes
131
+ image_bytes = uploaded_file.getvalue()
132
+ # Display the uploaded image
133
+ try:
134
+ image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
135
+ image = NativeResize(14, (270, 301))(image)
136
+ st.info(f"Resied (width, height): {image.size}")
137
+ st.image(image, caption="Uploaded Image.")
138
+ except Exception as e:
139
+ st.error(f"Error opening image: {e}")
140
+ st.warning("Please upload a valid image file (JPG, JPEG, PNG).")
141
+ uploaded_file = None # Reset uploaded_file so processing stops
142
+
143
+ with col2:
144
+ st.header("Predictions")
145
+ if uploaded_file is not None:
146
+ with st.spinner("Computing..."):
147
+ try:
148
+ scores = model.predict(image)
149
+ filtered_results = [
150
+ (i, p) for i, p in enumerate(scores) if p >= confidence_threshold
151
+ ]
152
+ filtered_results.sort(key=lambda x: x[1], reverse=True)
153
+
154
+ if filtered_results:
155
+ get_category = lambda ti: (
156
+ "Rating" if ti < 4 else "General" if ti < 8856 else "Character"
157
+ )
158
+ df = pd.DataFrame(
159
+ [
160
+ (i, tags[ti][0], f"{p:.4f}", get_category(ti), tags[ti][1])
161
+ for i, (ti, p) in enumerate(filtered_results[:150], 1)
162
+ ],
163
+ columns=[
164
+ "Rank",
165
+ "Label",
166
+ "Score",
167
+ "Category",
168
+ "Dataset frequency",
169
+ ],
170
+ )
171
+ st.dataframe(
172
+ df,
173
+ hide_index=True,
174
+ column_config={
175
+ "Dataset frequency": st.column_config.NumberColumn(
176
+ format="localized"
177
+ )
178
+ },
179
+ )
180
+ else:
181
+ st.info("No labels meet the current threshold.")
182
+
183
+ except Exception as e:
184
+ st.error("An error occurred during prediction or processing:")
185
+ st.exception(e) # Shows the full traceback
186
+
187
+ else:
188
+ st.info("Upload an image using the panel on the left to see predictions.")