sanmmarr29 commited on
Commit
b57108e
·
verified ·
1 Parent(s): a2aaca5

Upload 4 files

Browse files
Files changed (3) hide show
  1. app/main.py +20 -29
  2. app/static/styles.css +27 -0
  3. app/templates/chat.html +34 -4
app/main.py CHANGED
@@ -2,14 +2,13 @@ from fastapi import FastAPI, Request
2
  from fastapi.templating import Jinja2Templates
3
  from fastapi.staticfiles import StaticFiles
4
  from fastapi.responses import HTMLResponse
5
- from transformers import AutoModelForCausalLM, AutoTokenizer
6
- import torch
7
  from .config import settings
8
  from pydantic import BaseModel
9
 
10
  app = FastAPI(
11
- title="Deepseek Chat API",
12
- description="A simple chat API using DeepSeek model",
13
  version="1.0.0"
14
  )
15
 
@@ -17,13 +16,12 @@ app = FastAPI(
17
  app.mount("/static", StaticFiles(directory="app/static"), name="static")
18
  templates = Jinja2Templates(directory="app/templates")
19
 
20
- # Initialize model and tokenizer
21
- tokenizer = AutoTokenizer.from_pretrained(settings.MODEL_NAME, token=settings.HUGGINGFACE_TOKEN)
22
- model = AutoModelForCausalLM.from_pretrained(
23
- settings.MODEL_NAME,
 
24
  token=settings.HUGGINGFACE_TOKEN,
25
- torch_dtype=torch.float16,
26
- device_map="auto",
27
  trust_remote_code=True
28
  )
29
 
@@ -36,28 +34,21 @@ async def home(request: Request):
36
 
37
  @app.post("/chat")
38
  async def chat(message: ChatMessage):
39
- # Prepare the prompt
40
-
41
- print(message.message)
42
- prompt = f"### Instruction: {message.message}\n\n### Response:"
43
 
44
- # Generate response
45
- print(prompt)
46
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
47
- outputs = model.generate(
48
- **inputs,
49
- max_new_tokens=512,
50
- temperature=0.7,
51
- do_sample=True,
52
- pad_token_id=tokenizer.eos_token_id
53
- )
54
 
55
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
56
- print(response)
57
- # Extract only the response part
58
- response = response.split("### Response:")[-1].strip()
 
59
 
60
- return {"response": response}
61
 
62
  if __name__ == "__main__":
63
  import uvicorn
 
2
  from fastapi.templating import Jinja2Templates
3
  from fastapi.staticfiles import StaticFiles
4
  from fastapi.responses import HTMLResponse
5
+ from transformers import pipeline
 
6
  from .config import settings
7
  from pydantic import BaseModel
8
 
9
  app = FastAPI(
10
+ title="DeepSeek Chat",
11
+ description="A chat API using DeepSeek model",
12
  version="1.0.0"
13
  )
14
 
 
16
  app.mount("/static", StaticFiles(directory="app/static"), name="static")
17
  templates = Jinja2Templates(directory="app/templates")
18
 
19
+ # Initialize pipeline
20
+ print("Loading model pipeline...")
21
+ pipe = pipeline(
22
+ "text-generation",
23
+ model=settings.MODEL_NAME,
24
  token=settings.HUGGINGFACE_TOKEN,
 
 
25
  trust_remote_code=True
26
  )
27
 
 
34
 
35
  @app.post("/chat")
36
  async def chat(message: ChatMessage):
37
+ # Prepare messages
38
+ messages = [
39
+ {"role": "user", "content": message.message}
40
+ ]
41
 
42
+ # Generate response using pipeline
43
+ response = pipe(messages)
 
 
 
 
 
 
 
 
44
 
45
+ # Extract the response text
46
+ if isinstance(response, list):
47
+ response_text = response[0].get('generated_text', '')
48
+ else:
49
+ response_text = response.get('generated_text', '')
50
 
51
+ return {"response": response_text}
52
 
53
  if __name__ == "__main__":
54
  import uvicorn
app/static/styles.css CHANGED
@@ -24,4 +24,31 @@
24
 
25
  #chat-messages > div {
26
  animation: fadeIn 0.3s ease-out forwards;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  }
 
24
 
25
  #chat-messages > div {
26
  animation: fadeIn 0.3s ease-out forwards;
27
+ }
28
+
29
+ /* Typing indicator */
30
+ .typing-indicator {
31
+ display: flex;
32
+ align-items: center;
33
+ justify-content: center;
34
+ }
35
+
36
+ .typing-indicator .dot {
37
+ animation: typingAnimation 1.4s infinite;
38
+ display: inline-block;
39
+ margin: 0 2px;
40
+ }
41
+
42
+ .typing-indicator .dot:nth-child(2) {
43
+ animation-delay: 0.2s;
44
+ }
45
+
46
+ .typing-indicator .dot:nth-child(3) {
47
+ animation-delay: 0.4s;
48
+ }
49
+
50
+ @keyframes typingAnimation {
51
+ 0% { transform: translateY(0px); }
52
+ 28% { transform: translateY(-6px); }
53
+ 44% { transform: translateY(0px); }
54
  }
app/templates/chat.html CHANGED
@@ -25,7 +25,8 @@
25
  class="flex-1 rounded-lg border border-gray-300 px-4 py-2 focus:outline-none focus:border-blue-500"
26
  placeholder="Type your message...">
27
  <button type="submit"
28
- class="bg-blue-500 text-white px-6 py-2 rounded-lg hover:bg-blue-600 focus:outline-none">
 
29
  Send
30
  </button>
31
  </form>
@@ -37,20 +38,27 @@
37
  const chatMessages = document.getElementById('chat-messages');
38
  const chatForm = document.getElementById('chat-form');
39
  const messageInput = document.getElementById('message-input');
 
40
 
41
- function appendMessage(content, isUser) {
42
  const messageDiv = document.createElement('div');
43
  messageDiv.className = `flex ${isUser ? 'justify-end' : 'justify-start'}`;
44
 
45
  const bubble = document.createElement('div');
46
  bubble.className = `max-w-[70%] rounded-lg p-3 ${
47
  isUser ? 'bg-blue-500 text-white' : 'bg-gray-100 text-gray-800'
48
- }`;
49
- bubble.textContent = content;
 
 
 
 
 
50
 
51
  messageDiv.appendChild(bubble);
52
  chatMessages.appendChild(messageDiv);
53
  chatMessages.scrollTop = chatMessages.scrollHeight;
 
54
  }
55
 
56
  chatForm.addEventListener('submit', async (e) => {
@@ -58,12 +66,20 @@
58
  const message = messageInput.value.trim();
59
  if (!message) return;
60
 
 
 
 
 
 
61
  // Clear input
62
  messageInput.value = '';
63
 
64
  // Add user message
65
  appendMessage(message, true);
66
 
 
 
 
67
  try {
68
  const response = await fetch('/chat', {
69
  method: 'POST',
@@ -74,10 +90,24 @@
74
  });
75
 
76
  const data = await response.json();
 
 
 
 
 
77
  appendMessage(data.response, false);
78
  } catch (error) {
 
 
 
79
  appendMessage('Sorry, something went wrong. Please try again.', false);
80
  console.error('Error:', error);
 
 
 
 
 
 
81
  }
82
  });
83
  </script>
 
25
  class="flex-1 rounded-lg border border-gray-300 px-4 py-2 focus:outline-none focus:border-blue-500"
26
  placeholder="Type your message...">
27
  <button type="submit"
28
+ id="send-button"
29
+ class="bg-blue-500 text-white px-6 py-2 rounded-lg hover:bg-blue-600 focus:outline-none transition-colors duration-200">
30
  Send
31
  </button>
32
  </form>
 
38
  const chatMessages = document.getElementById('chat-messages');
39
  const chatForm = document.getElementById('chat-form');
40
  const messageInput = document.getElementById('message-input');
41
+ const sendButton = document.getElementById('send-button');
42
 
43
+ function appendMessage(content, isUser, isLoading = false) {
44
  const messageDiv = document.createElement('div');
45
  messageDiv.className = `flex ${isUser ? 'justify-end' : 'justify-start'}`;
46
 
47
  const bubble = document.createElement('div');
48
  bubble.className = `max-w-[70%] rounded-lg p-3 ${
49
  isUser ? 'bg-blue-500 text-white' : 'bg-gray-100 text-gray-800'
50
+ } ${isLoading ? 'typing-indicator' : ''}`;
51
+
52
+ if (isLoading) {
53
+ bubble.innerHTML = '<span class="dot">.</span><span class="dot">.</span><span class="dot">.</span>';
54
+ } else {
55
+ bubble.textContent = content;
56
+ }
57
 
58
  messageDiv.appendChild(bubble);
59
  chatMessages.appendChild(messageDiv);
60
  chatMessages.scrollTop = chatMessages.scrollHeight;
61
+ return messageDiv;
62
  }
63
 
64
  chatForm.addEventListener('submit', async (e) => {
 
66
  const message = messageInput.value.trim();
67
  if (!message) return;
68
 
69
+ // Disable input and button
70
+ messageInput.disabled = true;
71
+ sendButton.disabled = true;
72
+ sendButton.classList.add('opacity-50');
73
+
74
  // Clear input
75
  messageInput.value = '';
76
 
77
  // Add user message
78
  appendMessage(message, true);
79
 
80
+ // Add loading message
81
+ const loadingDiv = appendMessage('', false, true);
82
+
83
  try {
84
  const response = await fetch('/chat', {
85
  method: 'POST',
 
90
  });
91
 
92
  const data = await response.json();
93
+
94
+ // Remove loading message
95
+ loadingDiv.remove();
96
+
97
+ // Add AI response
98
  appendMessage(data.response, false);
99
  } catch (error) {
100
+ // Remove loading message
101
+ loadingDiv.remove();
102
+
103
  appendMessage('Sorry, something went wrong. Please try again.', false);
104
  console.error('Error:', error);
105
+ } finally {
106
+ // Re-enable input and button
107
+ messageInput.disabled = false;
108
+ sendButton.disabled = false;
109
+ sendButton.classList.remove('opacity-50');
110
+ messageInput.focus();
111
  }
112
  });
113
  </script>