matthoffner committed on
Commit cd842ff · 1 Parent(s): 2465248

Update main.py

Files changed (1)
  1. main.py +14 -6
main.py CHANGED
@@ -2,11 +2,12 @@ import fastapi
 import json
 import markdown
 import uvicorn
-from fastapi.responses import StreamingResponse, HTMLResponse
+from fastapi.responses import HTMLResponse
 from fastapi.middleware.cors import CORSMiddleware
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 from ctransformers import AutoModelForCausalLM
 from pydantic import BaseModel
+from sse_starlette.sse import EventSourceResponse

 config = {"max_seq_len": 4096}
 llm = AutoModelForCausalLM.from_pretrained('TheBloke/MPT-7B-Storywriter-GGML',
@@ -45,13 +46,20 @@ async def chat(prompt = "Once upon a time there was a "):
 @app.post("/v1/chat/completions")
 async def chat(request: ChatCompletionRequest, response_mode=None):
     completion = llm(request.prompt)
-    async def server_sent_events(chat_chunks):
+
+    async def server_sent_events(
+        chat_chunks,
+    ):
         for chat_chunk in chat_chunks:
-            print(chat_chunk)
-            yield chat_chunk
-        yield "[DONE]"
+            yield dict(data=json.dumps(chat_chunk))
+        yield dict(data="[DONE]")
+
+    chunks = completion_or_chunks
+
+    return EventSourceResponse(
+        server_sent_events(chunks),
+    )

-    return StreamingResponse(server_sent_events(completion))

 if __name__ == "__main__":
     uvicorn.run(app, host="0.0.0.0", port=8000)
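
For reference, a minimal client-side sketch (not part of this commit) of how the new /v1/chat/completions SSE stream can be consumed. The URL, port, and the "prompt" request field are assumptions inferred from the handler above, and httpx stands in for any HTTP client that supports response streaming.

# Hypothetical client for the SSE endpoint above; assumes the server from
# main.py is running locally on port 8000 and that ChatCompletionRequest
# exposes a "prompt" field (inferred from request.prompt in the diff).
import json

import httpx

with httpx.stream(
    "POST",
    "http://localhost:8000/v1/chat/completions",
    json={"prompt": "Once upon a time there was a "},
    timeout=None,  # generation can be slow; disable the read timeout
) as response:
    for line in response.iter_lines():
        # sse-starlette renders each yielded dict(data=...) as a "data: ..." frame.
        if not line.startswith("data: "):
            continue
        payload = line[len("data: "):]
        if payload == "[DONE]":  # sentinel yielded by server_sent_events()
            break
        print(json.loads(payload), end="", flush=True)

Because the handler wraps each chunk in json.dumps before yielding it, the client decodes every data frame with json.loads and stops at the [DONE] sentinel, which is sent as plain text.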