ffreemt committed on
Commit
b9d6157
·
1 Parent(s): 34815a7

Add ruff.toml

Browse files
.gitignore CHANGED
@@ -145,3 +145,4 @@ install-sw.sh
145
  install-sw1.sh
146
  win10-install-memo.txt
147
  model-s
 
 
145
  install-sw1.sh
146
  win10-install-memo.txt
147
  model-s
148
+ model-s-v2
.ruff.toml ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Assume Python 3.10.
2
+ target-version = "py310"
3
+ # Set the maximum line length to 300 characters.
4
+ line-length = 300
5
+
6
+ # pyflakes, pycodestyle, isort
7
+ # flake8 YTT, pydocstyle D, pylint PLC
8
+ lint.select = ["F", "E", "W", "I001", "YTT", "D", "PLC"]
9
+ # select = ["ALL"]
10
+
11
+ # D103 Missing docstring in public function
12
+ # D101 Missing docstring in public class
13
+ # `multi-line-summary-first-line` (D212)
14
+ # `one-blank-line-before-class` (D203)
15
+ # imported but unused (F401)
16
+ # assigned to but never used (F841)
17
+ # Missing dashed underline after section ("Args") (D407)
18
+ # D413 [*] Missing blank line after last section ("Returns")
19
+ # D213 [*] Multi-line docstring summary should start at the second line
20
+ lint.extend-ignore = ["D103", "D101", "D212", "D203", "F401", "F841", "D407", "D413", "D213"]
21
+
22
+ exclude = [".venv"]
st_mlbee/gen_cmat.py CHANGED
@@ -1,20 +1,18 @@
1
  """Gen cmat for de/en text."""
2
  # pylint: disable=invalid-name, too-many-branches
3
-
4
  from typing import List, Optional
5
 
6
  import more_itertools as mit
7
  import numpy as np
8
 
 
 
9
  from tqdm import tqdm
10
 
11
  # from model_pool import load_model_s
12
  # from hf_model_s_cpu import model_s # load_model_s directly
13
  from st_mlbee.load_model_s import load_model_s
14
 
15
- # from logzero import logger
16
- from loguru import logger
17
-
18
  # from st_mlbee.cos_matrix2 import cos_matrix2
19
  from .cos_matrix2 import cos_matrix2
20
 
@@ -29,7 +27,10 @@ except Exception as exc:
29
  try:
30
  # model = model_s()
31
  # model = model_s(alive_bar_on=True)
32
- model = load_model_s()
 
 
 
33
  except Exception as _:
34
  logger.error(_)
35
  raise
@@ -38,20 +39,31 @@ except Exception as _:
38
  def gen_cmat(
39
  text1: List[str],
40
  text2: List[str],
41
- bsize: int = 50
 
42
  ) -> np.ndarray:
43
  """Gen corr matrix for texts.
44
 
45
  Args:
46
- text1: typically '''...''' splitlines()
47
- text2: typically '''...''' splitlines()
48
- bsize: batch size, default 50
 
 
 
49
  text1 = 'this is a test'
50
  text2 = 'another test'
 
 
 
 
 
51
  """
 
 
52
  bsize = int(bsize)
53
  if bsize <= 0:
54
- bsize = 50
55
 
56
  if isinstance(text1, str):
57
  text1 = [text1]
 
1
  """Gen cmat for de/en text."""
2
  # pylint: disable=invalid-name, too-many-branches
 
3
  from typing import List, Optional
4
 
5
  import more_itertools as mit
6
  import numpy as np
7
 
8
+ # from logzero import logger
9
+ from loguru import logger
10
  from tqdm import tqdm
11
 
12
  # from model_pool import load_model_s
13
  # from hf_model_s_cpu import model_s # load_model_s directly
14
  from st_mlbee.load_model_s import load_model_s
15
 
 
 
 
16
  # from st_mlbee.cos_matrix2 import cos_matrix2
17
  from .cos_matrix2 import cos_matrix2
18
 
 
27
  try:
28
  # model = model_s()
29
  # model = model_s(alive_bar_on=True)
30
+
31
+ # default model-s mikeee/model_s_512
32
+ model_s = load_model_s()
33
+ # model_s_v2 = load_model_s("model_s_512v2") # model-s mikeee/model-s-512v2
34
  except Exception as _:
35
  logger.error(_)
36
  raise
 
39
  def gen_cmat(
40
  text1: List[str],
41
  text2: List[str],
42
+ bsize: int = 32, # default batch_size of model.encode
43
+ model=None,
44
  ) -> np.ndarray:
45
  """Gen corr matrix for texts.
46
 
47
  Args:
48
+ ----
49
+ text1: typically '''...''' splitlines()
50
+ text2: typically '''...''' splitlines()
51
+ bsize: batch size, default 32
52
+ model: for encoding list of strings, default model-s of mikeee/model_s_512
53
+
54
  text1 = 'this is a test'
55
  text2 = 'another test'
56
+
57
+ Returns:
58
+ -------
59
+ numpy array of cmat
60
+
61
  """
62
+ if model is None:
63
+ model = model_s
64
  bsize = int(bsize)
65
  if bsize <= 0:
66
+ bsize = 32
67
 
68
  if isinstance(text1, str):
69
  text1 = [text1]
st_mlbee/info.py CHANGED
@@ -1,6 +1,5 @@
1
  """Present info about st-mlbee."""
2
  from textwrap import dedent
3
-
4
  import streamlit as st
5
 
6
  from st_mlbee import __version__
@@ -10,7 +9,6 @@ from st_mlbee.utils import msg
10
 
11
  def info():
12
  """Prep info page."""
13
-
14
  st.subheader(f"st-mlbee {__version__}")
15
 
16
  st.markdown(msg, unsafe_allow_html=True)
 
1
  """Present info about st-mlbee."""
2
  from textwrap import dedent
 
3
  import streamlit as st
4
 
5
  from st_mlbee import __version__
 
9
 
10
  def info():
11
  """Prep info page."""
 
12
  st.subheader(f"st-mlbee {__version__}")
13
 
14
  st.markdown(msg, unsafe_allow_html=True)
st_mlbee/load_model_s.py CHANGED
@@ -3,13 +3,13 @@ Load model_s from hf.
3
 
4
  cf also align-model-pool\model_pool\load_model.py and ycco make-upload-model-s.ipynb.
5
  """
6
- import torch
7
  import joblib
8
  from huggingface_hub import hf_hub_download
9
  from loguru import logger
10
 
11
  try:
12
  loc = hf_hub_download("mikeee/model_s_512", "model-s", local_dir=".")
 
13
  except Exception as exc:
14
  logger.error(exc)
15
  raise SystemExit(1) from exc
 
3
 
4
  cf also align-model-pool\model_pool\load_model.py and ycco make-upload-model-s.ipynb.
5
  """
 
6
  import joblib
7
  from huggingface_hub import hf_hub_download
8
  from loguru import logger
9
 
10
  try:
11
  loc = hf_hub_download("mikeee/model_s_512", "model-s", local_dir=".")
12
+ # loc2 = hf_hub_download("mikeee/model_s_512v2", "model-s-v2", local_dir=".")
13
  except Exception as exc:
14
  logger.error(exc)
15
  raise SystemExit(1) from exc
st_mlbee/utils.py CHANGED
@@ -290,7 +290,8 @@ def to_excel(df):
290
 
291
 
292
  def get_table_download_link(df):
293
- """Generates a link.
 
294
 
295
  Allowing the data in a given pandas dataframe
296
  to be downloaded.
@@ -307,7 +308,8 @@ def get_table_download_link(df):
307
 
308
 
309
  def get_table_download_link_sents(df):
310
- """Generates a link.
 
311
 
312
  Allowing the data in a given pandas dataframe to be
313
  downloaded for sents aligned.
 
290
 
291
 
292
  def get_table_download_link(df):
293
+ """
294
+ Generate a link.
295
 
296
  Allowing the data in a given pandas dataframe
297
  to be downloaded.
 
308
 
309
 
310
  def get_table_download_link_sents(df):
311
+ """
312
+ Generate a link.
313
 
314
  Allowing the data in a given pandas dataframe to be
315
  downloaded for sents aligned.