Spaces:
Running
Running
ffreemt
committed on
Commit
·
b9d6157
1
Parent(s):
34815a7
Add ruff.toml
Browse files- .gitignore +1 -0
- .ruff.toml +22 -0
- st_mlbee/gen_cmat.py +22 -10
- st_mlbee/info.py +0 -2
- st_mlbee/load_model_s.py +1 -1
- st_mlbee/utils.py +4 -2
.gitignore
CHANGED
@@ -145,3 +145,4 @@ install-sw.sh
|
|
145 |
install-sw1.sh
|
146 |
win10-install-memo.txt
|
147 |
model-s
|
|
|
|
145 |
install-sw1.sh
|
146 |
win10-install-memo.txt
|
147 |
model-s
|
148 |
+
model-s-v2
|
.ruff.toml
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Assume Python 3.10.
|
2 |
+
target-version = "py310"
|
3 |
+
# Set the maximum line length to 300 characters.
|
4 |
+
line-length = 300
|
5 |
+
|
6 |
+
# pyflakes, pycodestyle, isort
|
7 |
+
# flake8 YTT, pydocstyle D, pylint PLC
|
8 |
+
lint.select = ["F", "E", "W", "I001", "YTT", "D", "PLC"]
|
9 |
+
# select = ["ALL"]
|
10 |
+
|
11 |
+
# D103 Missing docstring in public function
|
12 |
+
# D101 Missing docstring in public class
|
13 |
+
# `multi-line-summary-first-line` (D212)
|
14 |
+
# `one-blank-line-before-class` (D203)
|
15 |
+
# imported but unused (F401)
|
16 |
+
# assigned to but never used (F841)
|
17 |
+
# Missing dashed underline after section ("Args") (D407)
|
18 |
+
# D413 [*] Missing blank line after last section ("Returns")
|
19 |
+
# D213 [*] Multi-line docstring summary should start at the second line
|
20 |
+
lint.extend-ignore = ["D103", "D101", "D212", "D203", "F401", "F841", "D407", "D413", "D213"]
|
21 |
+
|
22 |
+
exclude = [".venv"]
|
st_mlbee/gen_cmat.py
CHANGED
@@ -1,20 +1,18 @@
|
|
1 |
"""Gen cmat for de/en text."""
|
2 |
# pylint: disable=invalid-name, too-many-branches
|
3 |
-
|
4 |
from typing import List, Optional
|
5 |
|
6 |
import more_itertools as mit
|
7 |
import numpy as np
|
8 |
|
|
|
|
|
9 |
from tqdm import tqdm
|
10 |
|
11 |
# from model_pool import load_model_s
|
12 |
# from hf_model_s_cpu import model_s # load_model_s directly
|
13 |
from st_mlbee.load_model_s import load_model_s
|
14 |
|
15 |
-
# from logzero import logger
|
16 |
-
from loguru import logger
|
17 |
-
|
18 |
# from st_mlbee.cos_matrix2 import cos_matrix2
|
19 |
from .cos_matrix2 import cos_matrix2
|
20 |
|
@@ -29,7 +27,10 @@ except Exception as exc:
|
|
29 |
try:
|
30 |
# model = model_s()
|
31 |
# model = model_s(alive_bar_on=True)
|
32 |
-
|
|
|
|
|
|
|
33 |
except Exception as _:
|
34 |
logger.error(_)
|
35 |
raise
|
@@ -38,20 +39,31 @@ except Exception as _:
|
|
38 |
def gen_cmat(
|
39 |
text1: List[str],
|
40 |
text2: List[str],
|
41 |
-
bsize: int =
|
|
|
42 |
) -> np.ndarray:
|
43 |
"""Gen corr matrix for texts.
|
44 |
|
45 |
Args:
|
46 |
-
|
47 |
-
|
48 |
-
|
|
|
|
|
|
|
49 |
text1 = 'this is a test'
|
50 |
text2 = 'another test'
|
|
|
|
|
|
|
|
|
|
|
51 |
"""
|
|
|
|
|
52 |
bsize = int(bsize)
|
53 |
if bsize <= 0:
|
54 |
-
bsize =
|
55 |
|
56 |
if isinstance(text1, str):
|
57 |
text1 = [text1]
|
|
|
1 |
"""Gen cmat for de/en text."""
|
2 |
# pylint: disable=invalid-name, too-many-branches
|
|
|
3 |
from typing import List, Optional
|
4 |
|
5 |
import more_itertools as mit
|
6 |
import numpy as np
|
7 |
|
8 |
+
# from logzero import logger
|
9 |
+
from loguru import logger
|
10 |
from tqdm import tqdm
|
11 |
|
12 |
# from model_pool import load_model_s
|
13 |
# from hf_model_s_cpu import model_s # load_model_s directly
|
14 |
from st_mlbee.load_model_s import load_model_s
|
15 |
|
|
|
|
|
|
|
16 |
# from st_mlbee.cos_matrix2 import cos_matrix2
|
17 |
from .cos_matrix2 import cos_matrix2
|
18 |
|
|
|
27 |
try:
|
28 |
# model = model_s()
|
29 |
# model = model_s(alive_bar_on=True)
|
30 |
+
|
31 |
+
# default model-s mikeee/model_s_512
|
32 |
+
model_s = load_model_s()
|
33 |
+
# model_s_v2 = load_model_s("model_s_512v2") # model-s mikeee/model-s-512v2
|
34 |
except Exception as _:
|
35 |
logger.error(_)
|
36 |
raise
|
|
|
39 |
def gen_cmat(
|
40 |
text1: List[str],
|
41 |
text2: List[str],
|
42 |
+
bsize: int = 32, # default batch_size of model.encode
|
43 |
+
model=None,
|
44 |
) -> np.ndarray:
|
45 |
"""Gen corr matrix for texts.
|
46 |
|
47 |
Args:
|
48 |
+
----
|
49 |
+
text1: typically '''...''' splitlines()
|
50 |
+
text2: typically '''...''' splitlines()
|
51 |
+
bsize: batch size, default 32
|
52 |
+
model: for encoding list of strings, default model-s of mikeee/model_s_512
|
53 |
+
|
54 |
text1 = 'this is a test'
|
55 |
text2 = 'another test'
|
56 |
+
|
57 |
+
Returns:
|
58 |
+
-------
|
59 |
+
numpy array of cmat
|
60 |
+
|
61 |
"""
|
62 |
+
if model is None:
|
63 |
+
model = model_s
|
64 |
bsize = int(bsize)
|
65 |
if bsize <= 0:
|
66 |
+
bsize = 32
|
67 |
|
68 |
if isinstance(text1, str):
|
69 |
text1 = [text1]
|
st_mlbee/info.py
CHANGED
@@ -1,6 +1,5 @@
|
|
1 |
"""Present info about st-mlbee."""
|
2 |
from textwrap import dedent
|
3 |
-
|
4 |
import streamlit as st
|
5 |
|
6 |
from st_mlbee import __version__
|
@@ -10,7 +9,6 @@ from st_mlbee.utils import msg
|
|
10 |
|
11 |
def info():
|
12 |
"""Prep info page."""
|
13 |
-
|
14 |
st.subheader(f"st-mlbee {__version__}")
|
15 |
|
16 |
st.markdown(msg, unsafe_allow_html=True)
|
|
|
1 |
"""Present info about st-mlbee."""
|
2 |
from textwrap import dedent
|
|
|
3 |
import streamlit as st
|
4 |
|
5 |
from st_mlbee import __version__
|
|
|
9 |
|
10 |
def info():
|
11 |
"""Prep info page."""
|
|
|
12 |
st.subheader(f"st-mlbee {__version__}")
|
13 |
|
14 |
st.markdown(msg, unsafe_allow_html=True)
|
st_mlbee/load_model_s.py
CHANGED
@@ -3,13 +3,13 @@ Load model_s from hf.
|
|
3 |
|
4 |
cf. also align-model-pool\model_pool\load_model.py and ycco make-upload-model-s.ipynb.
|
5 |
"""
|
6 |
-
import torch
|
7 |
import joblib
|
8 |
from huggingface_hub import hf_hub_download
|
9 |
from loguru import logger
|
10 |
|
11 |
try:
|
12 |
loc = hf_hub_download("mikeee/model_s_512", "model-s", local_dir=".")
|
|
|
13 |
except Exception as exc:
|
14 |
logger.error(exc)
|
15 |
raise SystemExit(1) from exc
|
|
|
3 |
|
4 |
cf. also align-model-pool\model_pool\load_model.py and ycco make-upload-model-s.ipynb.
|
5 |
"""
|
|
|
6 |
import joblib
|
7 |
from huggingface_hub import hf_hub_download
|
8 |
from loguru import logger
|
9 |
|
10 |
try:
|
11 |
loc = hf_hub_download("mikeee/model_s_512", "model-s", local_dir=".")
|
12 |
+
# loc2 = hf_hub_download("mikeee/model_s_512v2", "model-s-v2", local_dir=".")
|
13 |
except Exception as exc:
|
14 |
logger.error(exc)
|
15 |
raise SystemExit(1) from exc
|
st_mlbee/utils.py
CHANGED
@@ -290,7 +290,8 @@ def to_excel(df):
|
|
290 |
|
291 |
|
292 |
def get_table_download_link(df):
|
293 |
-
"""
|
|
|
294 |
|
295 |
Allowing the data in a given pandas DataFrame
|
296 |
to be downloaded.
|
@@ -307,7 +308,8 @@ def get_table_download_link(df):
|
|
307 |
|
308 |
|
309 |
def get_table_download_link_sents(df):
|
310 |
-
"""
|
|
|
311 |
|
312 |
Allowing the data in a given pandas DataFrame to be
|
313 |
downloaded for sents aligned.
|
|
|
290 |
|
291 |
|
292 |
def get_table_download_link(df):
|
293 |
+
"""
|
294 |
+
Generate a link.
|
295 |
|
296 |
Allowing the data in a given pandas DataFrame
|
297 |
to be downloaded.
|
|
|
308 |
|
309 |
|
310 |
def get_table_download_link_sents(df):
|
311 |
+
"""
|
312 |
+
Generate a link.
|
313 |
|
314 |
Allowing the data in a given pandas DataFrame to be
|
315 |
downloaded for sents aligned.
|