Spaces:
Runtime error
Runtime error
freemt
commited on
Commit
Β·
c471598
1
Parent(s):
0af688a
Update async aconvbot
Browse files- README.md +10 -1
- convbot/__init__.py +5 -2
- convbot/convbot.py +21 -1
- convbot/force_async.py +43 -0
- dist/convbot-0.1.0-py3-none-any.whl +0 -0
- dist/convbot-0.1.0.tar.gz +0 -0
- tests/test_convbot.py +21 -2
README.md
CHANGED
@@ -24,6 +24,15 @@ prin(convertbot("How are you?"))
|
|
24 |
# I am good # or along that line
|
25 |
```
|
26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
Interactive
|
28 |
|
29 |
```bash
|
@@ -31,4 +40,4 @@ python -m convbot
|
|
31 |
```
|
32 |
## Not tested in Windows 10 and Mac
|
33 |
|
34 |
-
The module uses pytorch that is installed differently in Windows than in Linux. To run in Windows or Mac, you
|
|
|
24 |
# I am good # or along that line
|
25 |
```
|
26 |
|
27 |
+
The async version `aconvbot`, potentialy for `fastapi` or `Nonebot` plugins and such, is rather artificial since it's based on `ThreadPoolExecutor`. Hence it's not intended for production. You probably should not spawn too many instances.
|
28 |
+
```python
|
29 |
+
from convbot import aconvbot
|
30 |
+
|
31 |
+
async def afunc(text):
|
32 |
+
resp = await aconvbot(text)
|
33 |
+
...
|
34 |
+
```
|
35 |
+
|
36 |
Interactive
|
37 |
|
38 |
```bash
|
|
|
40 |
```
|
41 |
## Not tested in Windows 10 and Mac
|
42 |
|
43 |
+
The module uses pytorch that is installed differently in Windows than in Linux. To run `convbot` in Windows or Mac, you may give it a spin by cloning the repo (git clone [https://github.com/ffreemt/convbot](https://github.com/ffreemt/convbot)) and installing pytorch manually.
|
convbot/__init__.py
CHANGED
@@ -1,5 +1,8 @@
|
|
1 |
"""Init."""
|
2 |
__version__ = "0.1.0"
|
3 |
-
from .convbot import convbot
|
4 |
|
5 |
-
__all__ = (
|
|
|
|
|
|
|
|
1 |
"""Init."""
|
2 |
__version__ = "0.1.0"
|
3 |
+
from .convbot import convbot, aconvbot
|
4 |
|
5 |
+
__all__ = (
|
6 |
+
"convbot",
|
7 |
+
"aconvbot",
|
8 |
+
)
|
convbot/convbot.py
CHANGED
@@ -4,6 +4,8 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
4 |
import torch
|
5 |
from logzero import logger
|
6 |
|
|
|
|
|
7 |
# model_name = "microsoft/DialoGPT-large"
|
8 |
# model_name = "microsoft/DialoGPT-small"
|
9 |
# pylint: disable=invalid-name
|
@@ -39,7 +41,7 @@ def _convbot(
|
|
39 |
chat_history_ids = ""
|
40 |
|
41 |
input_ids = tokenizer.encode(text + tokenizer.eos_token, return_tensors="pt")
|
42 |
-
if chat_history_ids:
|
43 |
bot_input_ids = torch.cat([chat_history_ids, input_ids], dim=-1)
|
44 |
else:
|
45 |
bot_input_ids = input_ids
|
@@ -113,6 +115,24 @@ def convbot(
|
|
113 |
return resp
|
114 |
|
115 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
116 |
def main():
|
117 |
print("Bot: Talk to me")
|
118 |
while 1:
|
|
|
4 |
import torch
|
5 |
from logzero import logger
|
6 |
|
7 |
+
from .force_async import force_async
|
8 |
+
|
9 |
# model_name = "microsoft/DialoGPT-large"
|
10 |
# model_name = "microsoft/DialoGPT-small"
|
11 |
# pylint: disable=invalid-name
|
|
|
41 |
chat_history_ids = ""
|
42 |
|
43 |
input_ids = tokenizer.encode(text + tokenizer.eos_token, return_tensors="pt")
|
44 |
+
if isinstance(chat_history_ids, torch.Tensor):
|
45 |
bot_input_ids = torch.cat([chat_history_ids, input_ids], dim=-1)
|
46 |
else:
|
47 |
bot_input_ids = input_ids
|
|
|
115 |
return resp
|
116 |
|
117 |
|
118 |
+
@force_async
|
119 |
+
def aconvbot(
|
120 |
+
text: str,
|
121 |
+
n_retries: int = 3,
|
122 |
+
max_length: int = 1000,
|
123 |
+
do_sample: bool = True,
|
124 |
+
top_p: float = 0.95,
|
125 |
+
top_k: int = 0,
|
126 |
+
temperature: float = 0.75,
|
127 |
+
) -> str:
|
128 |
+
try:
|
129 |
+
_ = convbot(text,n_retries, max_length, do_sample, top_p, top_k, temperature)
|
130 |
+
except Exception as e:
|
131 |
+
logger.error(e)
|
132 |
+
raise
|
133 |
+
return _
|
134 |
+
|
135 |
+
|
136 |
def main():
|
137 |
print("Bot: Talk to me")
|
138 |
while 1:
|
convbot/force_async.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Turn a sync func to async."""
|
2 |
+
from concurrent.futures import ThreadPoolExecutor
|
3 |
+
import asyncio
|
4 |
+
import functools
|
5 |
+
|
6 |
+
def force_async(func):
|
7 |
+
"""Turn a sync func to async.
|
8 |
+
|
9 |
+
Args:
|
10 |
+
func: a sync func
|
11 |
+
|
12 |
+
Return:
|
13 |
+
async func
|
14 |
+
|
15 |
+
Usage:
|
16 |
+
@force_async
|
17 |
+
def normal_func():
|
18 |
+
...
|
19 |
+
loop = asyncio.get_event_loop()
|
20 |
+
#~ tasks = [sync_loop1(1, 5), sync_loop1(2, 10)]
|
21 |
+
#~ res = loop.run_until_complete(asyncio.gather(*tasks)) # OK
|
22 |
+
res = loop.run_until_complete(
|
23 |
+
asyncio.gather(
|
24 |
+
*[
|
25 |
+
sync_loop1(1, 7),
|
26 |
+
sync_loop1(2, 6),
|
27 |
+
sync_loop1(2, 6),
|
28 |
+
async_func,
|
29 |
+
]
|
30 |
+
)
|
31 |
+
)
|
32 |
+
"""
|
33 |
+
# executor = ThreadPoolExecutor()
|
34 |
+
# from concurrent.futures import ThreadPoolExecutor
|
35 |
+
executor = ThreadPoolExecutor(max_workers=10)
|
36 |
+
|
37 |
+
@functools.wraps(func)
|
38 |
+
def wrapper(*args, **kwargs):
|
39 |
+
"""Preserve func info."""
|
40 |
+
future = executor.submit(func, *args, **kwargs)
|
41 |
+
return asyncio.wrap_future(future) # make it awaitable
|
42 |
+
|
43 |
+
return wrapper
|
dist/convbot-0.1.0-py3-none-any.whl
ADDED
Binary file (4.83 kB). View file
|
|
dist/convbot-0.1.0.tar.gz
ADDED
Binary file (3.97 kB). View file
|
|
tests/test_convbot.py
CHANGED
@@ -1,6 +1,11 @@
|
|
1 |
"""Test convbot."""
|
|
|
|
|
|
|
2 |
from convbot import __version__
|
3 |
-
from convbot import convbot
|
|
|
|
|
4 |
|
5 |
|
6 |
def test_version():
|
@@ -11,12 +16,26 @@ def test_version():
|
|
11 |
def test_sanity():
|
12 |
"""Sanity check."""
|
13 |
try:
|
14 |
-
assert not convbot()
|
15 |
except Exception:
|
16 |
assert True
|
17 |
|
18 |
|
19 |
def test_convbot():
|
|
|
20 |
resp = convbot("How are you?")
|
21 |
assert len(resp) > 3
|
22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
"""Test convbot."""
|
2 |
+
# import asyncio
|
3 |
+
import pytest
|
4 |
+
|
5 |
from convbot import __version__
|
6 |
+
from convbot import convbot, aconvbot
|
7 |
+
|
8 |
+
pytestmark = pytest.mark.asyncio
|
9 |
|
10 |
|
11 |
def test_version():
|
|
|
16 |
def test_sanity():
|
17 |
"""Sanity check."""
|
18 |
try:
|
19 |
+
assert not convbot("")
|
20 |
except Exception:
|
21 |
assert True
|
22 |
|
23 |
|
24 |
def test_convbot():
|
25 |
+
"""Test convbot."""
|
26 |
resp = convbot("How are you?")
|
27 |
assert len(resp) > 3
|
28 |
|
29 |
+
# 2nd call uses chat_history_ids
|
30 |
+
resp = convbot("How old are you?")
|
31 |
+
assert len(resp) > 3
|
32 |
+
|
33 |
+
|
34 |
+
async def tests_aconvbot():
|
35 |
+
"""Test aconvbot."""
|
36 |
+
resp = await aconvbot("How are you?")
|
37 |
+
assert len(resp) > 3
|
38 |
+
|
39 |
+
# 2nd call uses chat_history_ids
|
40 |
+
resp = await aconvbot("How old are you?")
|
41 |
+
assert len(resp) > 3
|