etrotta committed
Commit be12cc9 · 1 Parent(s): 510a6b4

First Release

.gitignore ADDED
@@ -0,0 +1,3 @@
+ .venv
+ .env
+ __pycache__
app.py ADDED
@@ -0,0 +1,53 @@
+ from PIL import Image
+ import gradio as gr
+
+ from config import (
+     description,
+     article,
+ )
+ from encode import get_embeddings
+ from database import search_vector, format_search_results
+
+ def search_images(values):
+     image = Image.new("RGBA", values["composite"].size, (255, 255, 255, 255))
+     image.paste(values["composite"], mask=values["composite"])
+     embedding = get_embeddings([image])[0]
+     results = search_vector(embedding, limit=100)
+     formatted = format_search_results(results)
+
+     _deduplicated = '\t'.join(dict.fromkeys(result.kanji for result in formatted))
+     # TODO Format the results better
+     # Huge boxes using the right font for each of them?
+
+     return f"Results: {_deduplicated}"
+
+ # TODO FIND OUT HOW TO CHANGE THE DEFAULT EDITOR TAB?
+ input_image = gr.ImageEditor(
+     label="Write the Kanji you want to search for",
+     show_label=False,
+     type="pil",
+     brush=gr.Brush(
+         default_size=3,
+         color_mode="fixed",
+         colors=["#000000", "#ffffff"],
+     ),
+ )
+
+
+ output_box = gr.Textbox()
+
+
+ demo = gr.Interface(
+     fn=search_images,
+     inputs=[input_image],
+     outputs=output_box,
+     title="Kanji Lookup",
+     description=description,
+     article=article,
+     examples="examples",
+     # cache_examples=False,
+     # live=True,
+ )
+
+ if __name__ == "__main__":
+     demo.launch()
config.py ADDED
@@ -0,0 +1,16 @@
+ import os
+
+ qdrant_location = os.getenv('QDRANT_URL', "localhost")
+ qdrant_api_key = os.getenv('QDRANT_API_KEY')
+
+ description = """This is a Kanji image search demo. Draw or upload an image of an individual Kanji character."""
+
+ article = """
+ ### About this project
+
+ You can find the source code, as well as more information, at https://github.com/etrotta/kanji_lookup
+
+ It uses the "kha-white/manga-ocr-base" ViT encoder model to create embeddings, then uses a vector database (Qdrant) to find similar characters.
+
+ The vector database has been populated with over 10k characters from [The KANJIDIC project](https://www.edrdg.org/wiki/index.php/KANJIDIC_Project), each rendered in multiple fonts downloaded from Google Fonts.
+ """
database.py ADDED
@@ -0,0 +1,34 @@
+ import dataclasses
+ import torch
+ from qdrant_client import QdrantClient, models
+
+ from config import qdrant_location, qdrant_api_key
+
+ qdrant = QdrantClient(qdrant_location, api_key=qdrant_api_key)
+
+ def search_vector(query_vector: torch.Tensor, limit: int = 20) -> list[models.ScoredPoint]:
+     hits = qdrant.search(
+         collection_name="kanji",
+         # query_vector=query_vector,
+         query_vector=query_vector.numpy(),
+         limit=limit,
+         with_payload=True,
+     )
+     return hits
+
+ @dataclasses.dataclass
+ class SearchResult:
+     kanji: str
+     font: str
+     score: float
+
+ def format_search_results(hits: list[models.ScoredPoint]) -> list[SearchResult]:
+     formatted = []
+     for point in hits:
+         kanji, font = point.payload["kanji"], point.payload["font"]
+         formatted.append(SearchResult(
+             kanji=kanji,
+             font=font,
+             score=point.score,
+         ))
+     return formatted
encode.py ADDED
@@ -0,0 +1,25 @@
+ from PIL import Image
+ import torch
+ from transformers import (
+     VisionEncoderDecoderModel,
+     ViTImageProcessor,  # Load extractor
+     ViTModel,  # Load ViT encoder
+ )
+ MODEL = "kha-white/manga-ocr-base"
+
+ print("Loading models")
+ feature_extractor: ViTImageProcessor = ViTImageProcessor.from_pretrained(MODEL, requires_grad=False)
+ encoder: ViTModel = VisionEncoderDecoderModel.from_pretrained(MODEL).encoder
+
+ if torch.cuda.is_available():
+     print('Using CUDA')
+     encoder.cuda()
+ else:
+     print('Using CPU')
+
+ def get_embeddings(images: list[Image.Image]) -> torch.Tensor:
+     """Processes the images and returns their Embeddings"""
+     images_rgb = [image.convert("RGB") for image in images]
+     with torch.inference_mode():
+         pixel_values: torch.Tensor = feature_extractor(images_rgb, return_tensors="pt")["pixel_values"]
+         return encoder(pixel_values.to(encoder.device))["pooler_output"].cpu()
examples/化_alpha_bg.png ADDED
examples/永_white_bg.png ADDED
requirements.txt ADDED
Binary file (2.99 kB).