First Release

- .gitignore +3 -0
- app.py +53 -0
- config.py +16 -0
- database.py +34 -0
- encode.py +25 -0
- examples/化_alpha_bg.png +0 -0
- examples/永_white_bg.png +0 -0
- requirements.txt +0 -0
.gitignore
ADDED
@@ -0,0 +1,3 @@
+.venv
+.env
+__pycache__
app.py
ADDED
@@ -0,0 +1,53 @@
+from PIL import Image
+import gradio as gr
+
+from config import (
+    description,
+    article,
+)
+from encode import get_embeddings
+from database import search_vector, format_search_results
+
+def search_images(values):
+    image = Image.new("RGBA", values["composite"].size, (255, 255, 255, 255))
+    image.paste(values["composite"], mask=values["composite"])
+    embedding = get_embeddings([image])[0]
+    results = search_vector(embedding, limit=100)
+    formatted = format_search_results(results)
+
+    _deduplicated = '\t'.join(dict.fromkeys(result.kanji for result in formatted))
+    # TODO Format the results better
+    # Huge boxes using the right font for each of them?
+
+    return f"Results: {_deduplicated}"
+
+# TODO FIND OUT HOW TO CHANGE THE DEFAULT EDITOR TAB?
+input_image = gr.ImageEditor(
+    label="Write the Kanji you want to search for",
+    show_label=False,
+    type="pil",
+    brush=gr.Brush(
+        default_size=3,
+        color_mode="fixed",
+        colors=["#000000", "#ffffff"],
+    ),
+)
+
+
+output_box = gr.Textbox()
+
+
+demo = gr.Interface(
+    fn=search_images,
+    inputs=[input_image],
+    outputs=output_box,
+    title="Kanji Lookup",
+    description=description,
+    article=article,
+    examples="examples",
+    # cache_examples=False,
+    # live=True,
+)
+
+if __name__ == "__main__":
+    demo.launch()
config.py
ADDED
@@ -0,0 +1,16 @@
+import os
+
+qdrant_location = os.getenv('QDRANT_URL', "localhost")
+qdrant_api_key = os.getenv('QDRANT_API_KEY')
+
+description = """This is a Kanji image search demo. Draw or upload an image of an individual Kanji character."""
+
+article = """
+### About this project
+
+You can find the source code as well as more information at https://github.com/etrotta/kanji_lookup
+
+It uses the "kha-white/manga-ocr-base" ViT Encoder model to create embeddings, then uses a vector database (qdrant) to find similar characters.
+
+The vector database has been populated with over 10k characters from [The KANJIDIC project](https://www.edrdg.org/wiki/index.php/KANJIDIC_Project), each rendered in multiple fonts downloaded from Google Fonts.
+"""
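The article above says the database was populated with KANJIDIC characters rendered in multiple fonts, but that rendering script is not part of this commit. A minimal sketch of how one such image could be produced with Pillow, assuming a hypothetical local font file path:

from PIL import Image, ImageDraw, ImageFont

def render_kanji(character: str, font_path: str, size: int = 224) -> Image.Image:
    """Draw a single character centred on a white square canvas."""
    font = ImageFont.truetype(font_path, int(size * 0.8))
    image = Image.new("RGB", (size, size), "white")
    draw = ImageDraw.Draw(image)
    # anchor="mm" centres the glyph on the given point (Pillow >= 8.0)
    draw.text((size // 2, size // 2), character, font=font, fill="black", anchor="mm")
    return image

# "NotoSansJP-Regular.ttf" is a placeholder for one of the fonts downloaded from Google Fonts
image = render_kanji("永", "NotoSansJP-Regular.ttf")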
database.py
ADDED
@@ -0,0 +1,34 @@
+import dataclasses
+import torch
+from qdrant_client import QdrantClient, models
+
+from config import qdrant_location, qdrant_api_key
+
+qdrant = QdrantClient(qdrant_location, api_key=qdrant_api_key)
+
+def search_vector(query_vector: torch.Tensor, limit: int = 20) -> list[models.ScoredPoint]:
+    hits = qdrant.search(
+        collection_name="kanji",
+        # query_vector=query_vector,
+        query_vector=query_vector.numpy(),
+        limit=limit,
+        with_payload=True,
+    )
+    return hits
+
+@dataclasses.dataclass
+class SearchResult:
+    kanji: str
+    font: str
+    score: float
+
+def format_search_results(hits: list[models.ScoredPoint]) -> list[SearchResult]:
+    formatted = []
+    for point in hits:
+        kanji, font = point.payload["kanji"], point.payload["font"]
+        formatted.append(SearchResult(
+            kanji=kanji,
+            font=font,
+            score=point.score,
+        ))
+    return formatted
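database.py searches an existing "kanji" collection whose points carry "kanji" and "font" payload fields; creating and filling that collection happens outside this commit. A rough sketch of what the setup could look like with qdrant_client, assuming a 768-dimensional vector (the pooler output width of a ViT-base encoder) and cosine distance:

from qdrant_client import QdrantClient, models

client = QdrantClient("localhost")

# Vector size and distance metric are assumptions, not taken from this commit
client.create_collection(
    collection_name="kanji",
    vectors_config=models.VectorParams(size=768, distance=models.Distance.COSINE),
)

# One point per rendered character image; the payload keys match what format_search_results reads back
client.upsert(
    collection_name="kanji",
    points=[
        models.PointStruct(
            id=0,
            vector=[0.0] * 768,  # placeholder; in practice an embedding from encode.get_embeddings
            payload={"kanji": "永", "font": "ExampleFont"},  # "ExampleFont" is a placeholder name
        ),
    ],
)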
encode.py
ADDED
@@ -0,0 +1,25 @@
+from PIL import Image
+import torch
+from transformers import (
+    VisionEncoderDecoderModel,
+    ViTImageProcessor,  # Load extractor
+    ViTModel,  # Load ViT encoder
+)
+MODEL = "kha-white/manga-ocr-base"
+
+print("Loading models")
+feature_extractor: ViTImageProcessor = ViTImageProcessor.from_pretrained(MODEL, requires_grad=False)
+encoder: ViTModel = VisionEncoderDecoderModel.from_pretrained(MODEL).encoder
+
+if torch.cuda.is_available():
+    print('Using CUDA')
+    encoder.cuda()
+else:
+    print('Using CPU')
+
+def get_embeddings(images: list[Image.Image]) -> torch.Tensor:
+    """Processes the images and returns their Embeddings"""
+    images_rgb = [image.convert("RGB") for image in images]
+    with torch.inference_mode():
+        pixel_values: torch.Tensor = feature_extractor(images_rgb, return_tensors="pt")["pixel_values"]
+        return encoder(pixel_values.to(encoder.device))["pooler_output"].cpu()
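A quick usage sketch for encode.py and database.py together, using one of the example images bundled with this commit: get_embeddings takes a batch of PIL images and returns one embedding per image, and the first row is the query vector that app.py hands to search_vector. The 768 width is the expected ViT-base pooler output, not something asserted in the code.

from PIL import Image
from encode import get_embeddings
from database import search_vector, format_search_results

image = Image.open("examples/永_white_bg.png")
embedding = get_embeddings([image])[0]  # expected shape: torch.Size([768])
results = format_search_results(search_vector(embedding, limit=5))
print([r.kanji for r in results])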
examples/化_alpha_bg.png
ADDED
examples/永_white_bg.png
ADDED
requirements.txt
ADDED
Binary file (2.99 kB)