Upload 13 files

- app.py +728 -0
- codegen_torch.py +187 -0
- gpt2_pytorch.py +210 -0
- image_to_3d_openlrm.py +31 -0
- imagegen_vae_unet.py +164 -0
- lipsync_wav2lip.py +57 -0
- musicgen_torch.py +36 -0
- sentiment_roberta.py +195 -0
- stt_wav2vec2.py +46 -0
- summarization_bart.py +34 -0
- text_to_video_clip4clip.py +34 -0
- translation_mbart.py +267 -0
- tts_vits.py +57 -0
app.py
ADDED
@@ -0,0 +1,728 @@
import os
import sys
import torch
import random
import re
import json
import math
import copy
import requests
from functools import lru_cache
from tqdm import tqdm
from torch.nn.parameter import Parameter
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier
import time
import threading
import queue
import httpx
import asyncio
import torch.nn as nn
import torch.nn.functional as F
import uuid
import wget
from duckduckgo_search import DDGS
import warnings
from datetime import datetime
import unicodedata
import nltk
import torchaudio
import logging
from PIL import Image
from io import BytesIO
import sentencepiece as spm
from flask import Flask, request, jsonify, send_file, Response
from flask_cors import CORS

nltk.download('punkt', quiet=True)

GPT2_FOLDER = "./GPT2"
MODEL_FILE = "gpt2-pytorch_model.bin"
ENCODER_FILE = "encoder.json"
VOCAB_FILE = "vocab.bpe"
MODEL_URL = "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin"
ENCODER_URL = "https://raw.githubusercontent.com/graykode/gpt-2-Pytorch/refs/heads/master/GPT2/GPT2/encoder.json"
VOCAB_URL = "https://raw.githubusercontent.com/graykode/gpt-2-Pytorch/refs/heads/master/GPT2/GPT2/vocab.bpe"
GPT2_FILES_URLS = [
    (MODEL_URL, MODEL_FILE),
    (ENCODER_URL, ENCODER_FILE),
    (VOCAB_URL, VOCAB_FILE),
]

TEXT_GENERATION_RATE = 40000
MAX_LENGTH = 1024
MAX_XDD = 5
END_OF_TEXT_TOKEN = "<|endoftext|>"

html_code = """<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>AI Text Generation</title>
    <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/animate.css/4.1.1/animate.min.css"/>
    <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/css/all.min.css" integrity="sha512-9usAa10IRO0HhonpyAIVpjrylPvoDwiPUiKdWk5t3PyolY1cOd4DSE0Ga+ri4AuTroPR5aQvXU9xC6qOPnzFeg==" crossorigin="anonymous" referrerpolicy="no-referrer" />
    <script src="https://cdn.jsdelivr.net/npm/marked/marked.min.js"></script>
    <style>
        body {
            font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
            background: #f0f0f0;
            color: #333;
            margin: 0;
            padding: 0;
            display: flex;
            flex-direction: column;
            align-items: center;
            min-height: 100vh;
        }
        .container {
            width: 95%;
            max-width: 900px;
            padding: 20px;
            background-color: #fff;
            box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
            border-radius: 8px;
            margin-top: 20px;
            margin-bottom: 20px;
            display: flex;
            flex-direction: column;
        }
        .header {
            text-align: center;
            margin-bottom: 20px;
        }
        .header h1 {
            font-size: 2em;
            color: #333;
        }
        .form-group {
            margin-bottom: 15px;
        }
        .form-group textarea {
            width: 100%;
            padding: 10px;
            border: 1px solid #ccc;
            border-radius: 5px;
            font-size: 16px;
            box-sizing: border-box;
            resize: vertical;
        }
        button {
            padding: 10px 15px;
            border: none;
            border-radius: 5px;
            background-color: #007bff;
            color: white;
            font-size: 18px;
            cursor: pointer;
            transition: background-color 0.3s ease;
        }
        button:hover {
            background-color: #0056b3;
        }
        #output {
            margin-top: 20px;
            padding: 15px;
            border: 1px solid #ddd;
            border-radius: 5px;
            background-color: #f9f9f9;
            white-space: pre-wrap;
            word-break: break-word;
            overflow-y: auto;
            max-height: 100vh;
        }
        #output strong {
            font-weight: bold;
        }
        .animated-text {
            position: fixed;
            top: 20px;
            left: 20px;
            font-size: 1.5em;
            color: rgba(0, 0, 0, 0.1);
            pointer-events: none;
            z-index: -1;
        }
        @media (max-width: 768px) {
            .container {
                width: 98%;
                margin-top: 10px;
                margin-bottom: 10px;
                padding: 15px;
            }
            .header h1 {
                font-size: 1.8em;
            }
            .form-group textarea, .form-group input[type="text"] {
                font-size: 14px;
                padding: 8px;
            }
            button {
                font-size: 16px;
                padding: 8px 12px;
            }
            #output {
                font-size: 14px;
                padding: 10px;
                margin-top: 15px;
            }
        }
    </style>
</head>
<body>
    <div class="animated-text animate__animated animate__fadeIn animate__infinite infinite">AI POWERED</div>
    <div class="container">
        <div class="header animate__animated animate__fadeInDown">
        </div>
        <div class="form-group animate__animated animate__fadeInLeft">
            <textarea id="text" rows="5" placeholder="Enter text"></textarea>
        </div>
        <button onclick="generateText()" class="animate__animated animate__fadeInUp">Generate Reasoning</button>
        <div id="output" class="animate__animated">
            <strong>Response:</strong><br>
            <div id="generatedText"></div>
        </div>
    </div>
    <script>
        let eventSource = null;
        let accumulatedText = "";
        let lastResponse = "";
        let currentSpan = null;
        let messageCounter = 0;

        async function generateText() {
            const inputText = document.getElementById("text").value;
            let generatedTextDiv = document.getElementById("generatedText"); // let, not const: reassigned for continued responses below
            generatedTextDiv.innerHTML = "";
            accumulatedText = "";
            lastResponse = "";
            currentSpan = null;
            messageCounter = 0;

            if (eventSource) {
                eventSource.close();
            }
            const temp = 0.7;
            const top_k_val = 40;
            const top_p_val = 0.0;
            const repetition_penalty_val = 1.2;
            eventSource = new EventSource(`/generate_stream?text=${encodeURIComponent(inputText)}&temp=${temp}&top_k=${top_k_val}&top_p=${top_p_val}&reppenalty=${repetition_penalty_val}`);
            eventSource.onmessage = function(event) {
                if (event.data === "<END_STREAM>") {
                    eventSource.close();
                    // JavaScript regex here (the original mistakenly used Python's re.compile)
                    const currentResponse = accumulatedText.replace("<|endoftext|>", "").replace(/\s+(?=[.,,。])/g, '').trim();
                    if (currentResponse === lastResponse.trim()) {
                        accumulatedText = "**Response is repetitive. Please try again or rephrase your query.**";
                    } else {
                        lastResponse = currentResponse;
                    }
                    document.getElementById("generatedText").innerHTML = marked.parse(accumulatedText);
                    return;
                }
                try {
                    const jsonData = JSON.parse(event.data);
                    const token = jsonData.token;
                    if (token === "<|endoftext|>" || token === "<END_STREAM>") {
                        return;
                    }
                    if (token === "<NEW_MESSAGE>") {
                        messageCounter++;
                        if (messageCounter > 1) {
                            generatedTextDiv.innerHTML += "<br><br><hr style='border-top: 1px dashed #8c8b8b; margin-top: 10px; margin-bottom: 10px;'><strong>Continued Response:</strong><br><div id='generatedText_" + messageCounter + "'></div>";
                            generatedTextDiv = document.getElementById("generatedText_" + messageCounter);
                            accumulatedText = "";
                        }
                        return;
                    }
                    accumulatedText += token + " ";
                } catch (e) {
                    console.error("Error parsing SSE data:", event.data, e);
                }
            };
            eventSource.onerror = function(error) {
                console.error("SSE error", error);
                eventSource.close();
            };
            const outputDiv = document.getElementById("output");
            outputDiv.classList.add("show");
        }
    </script>
</body>
</html>
"""

TRANSLATION_FOLDER = "./TranslationModel"
TRANSLATION_MODEL_WEIGHTS_FILE = "pytorch_model.bin"
TRANSLATION_MODEL_CONFIG_FILE = "config.json"
TRANSLATION_MODEL_VOCAB_FILE = "sentencepiece.bpe.model"
TRANSLATION_MODEL_WEIGHTS_URL = "https://huggingface.co/facebook/mbart-large-50-many-to-many-mmt/resolve/main/pytorch_model.bin"
TRANSLATION_MODEL_CONFIG_URL = "https://huggingface.co/facebook/mbart-large-50-many-to-many-mmt/resolve/main/config.json"
TRANSLATION_MODEL_VOCAB_URL = "https://huggingface.co/facebook/mbart-large-50-many-to-many-mmt/resolve/main/sentencepiece.bpe.model"
TRANSLATION_MODEL_FILES_URLS = [
    (TRANSLATION_MODEL_WEIGHTS_URL, TRANSLATION_MODEL_WEIGHTS_FILE),
    (TRANSLATION_MODEL_CONFIG_URL, TRANSLATION_MODEL_CONFIG_FILE),
    (TRANSLATION_MODEL_VOCAB_URL, TRANSLATION_MODEL_VOCAB_FILE),
]

CODEGEN_FOLDER = "./CodeGenModel"
CODEGEN_MODEL_NAME = "codegen-350M-multi"
CODEGEN_MODEL_WEIGHTS = "pytorch_model.bin"
CODEGEN_CONFIG = "config.json"
CODEGEN_VOCAB = "vocab.json"
CODEGEN_MERGES = "merges.txt"
CODEGEN_MODEL_WEIGHTS_URL = "https://huggingface.co/Salesforce/codegen-350M-multi/resolve/main/pytorch_model.bin"
CODEGEN_CONFIG_URL = "https://huggingface.co/Salesforce/codegen-350M-multi/resolve/main/config.json"
CODEGEN_VOCAB_URL = "https://huggingface.co/Salesforce/codegen-350M-multi/resolve/main/vocab.json"
CODEGEN_MERGES_URL = "https://huggingface.co/Salesforce/codegen-350M-multi/resolve/main/merges.txt"
CODEGEN_FILES_URLS = [
    (CODEGEN_MODEL_WEIGHTS_URL, CODEGEN_MODEL_WEIGHTS),
    (CODEGEN_CONFIG_URL, CODEGEN_CONFIG),
    (CODEGEN_VOCAB_URL, CODEGEN_VOCAB),
    (CODEGEN_MERGES_URL, CODEGEN_MERGES),
]

TTS_FOLDER = "./TTSModel"
TTS_MODEL_NAME = "vits"
TTS_MODEL_CONFIG = "config.json"
TTS_MODEL_WEIGHTS = "pytorch_model.bin"
TTS_VOCAB = "vocab.json"
TTS_CONFIG_URL = "https://huggingface.co/kakao-enterprise/vits-vctk/resolve/main/config.json"
TTS_MODEL_WEIGHTS_URL = "https://huggingface.co/kakao-enterprise/vits-vctk/resolve/main/pytorch_model.bin"
TTS_VOCAB_URL = "https://huggingface.co/kakao-enterprise/vits-vctk/resolve/main/vocab.json"
TTS_FILES_URLS = [
    (TTS_CONFIG_URL, TTS_MODEL_CONFIG),
    (TTS_MODEL_WEIGHTS_URL, TTS_MODEL_WEIGHTS),
    (TTS_VOCAB_URL, TTS_VOCAB),
]

STT_FOLDER = "./STTModel"
STT_MODEL_NAME = "wav2vec2"
STT_MODEL_WEIGHTS = "pytorch_model.bin"
STT_CONFIG = "config.json"
STT_VOCAB = "vocab.json"
STT_MODEL_WEIGHTS_URL = "https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/pytorch_model.bin"
STT_CONFIG_URL = "https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/config.json"
STT_VOCAB_URL = "https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/vocab.json"
STT_FILES_URLS = [
    (STT_MODEL_WEIGHTS_URL, STT_MODEL_WEIGHTS),
    (STT_CONFIG_URL, STT_CONFIG),
    (STT_VOCAB_URL, STT_VOCAB),
]

SENTIMENT_FOLDER = "./SentimentModel"
SENTIMENT_MODEL_WEIGHTS = "pytorch_model.bin"
SENTIMENT_VOCAB = "sentiment_vocab.json"
SENTIMENT_CONFIG = "config.json"
SENTIMENT_MODEL_WEIGHTS_URL = "https://huggingface.co/cardiffnlp/distilroberta-base-sentiment/resolve/main/pytorch_model.bin"
SENTIMENT_VOCAB_URL = "https://huggingface.co/cardiffnlp/distilroberta-base-sentiment/resolve/main/vocab.json"
SENTIMENT_CONFIG_URL = "https://huggingface.co/cardiffnlp/distilroberta-base-sentiment/resolve/main/config.json"
SENTIMENT_FILES_URLS = [
    (SENTIMENT_MODEL_WEIGHTS_URL, SENTIMENT_MODEL_WEIGHTS),
    (SENTIMENT_VOCAB_URL, SENTIMENT_VOCAB),
    (SENTIMENT_CONFIG_URL, SENTIMENT_CONFIG),
]

IMAGEGEN_FOLDER = "./ImageGenModel"
IMAGEGEN_MODEL_WEIGHTS = "diffusion_pytorch_model.bin"
IMAGEGEN_CONFIG = "config.json"
IMAGEGEN_MODEL_WEIGHTS_URL = "https://huggingface.co/stabilityai/sd-vae-ft-mse/resolve/main/diffusion_pytorch_model.bin"
IMAGEGEN_CONFIG_URL = "https://huggingface.co/stabilityai/sd-vae-ft-mse/resolve/main/config.json"
IMAGEGEN_FILES_URLS = [
    (IMAGEGEN_MODEL_WEIGHTS_URL, IMAGEGEN_MODEL_WEIGHTS),
    (IMAGEGEN_CONFIG_URL, IMAGEGEN_CONFIG),
]

LIPSYNC_FOLDER = "./LipSyncModel"
LIPSYNC_MODEL_WEIGHTS = "lipsync_expert.pth"
LIPSYNC_MODEL_WEIGHTS_URL = "https://iiitaphyd-my.sharepoint.com/personal/radrabha_m_research_iiit_ac_in/_layouts/15/download.aspx?SourceUrl=%2Fpersonal%2Fradrabha%5Fm%5Fresearch%5Fiiit%5Fac%5Fin%2FDocuments%2FWav2Lip%5FModels%2Flipsync%5Fexpert%2Epth"
LIPSYNC_FILES_URLS = [
    (LIPSYNC_MODEL_WEIGHTS_URL, LIPSYNC_MODEL_WEIGHTS),
]

WAV2LIP_FOLDER = "./Wav2LipModel"
WAV2LIP_MODEL_WEIGHTS = "wav2lip_gan.pth"
WAV2LIP_MODEL_WEIGHTS_URL = "https://iiitaphyd-my.sharepoint.com/personal/radrabha_m_research_iiit_ac_in/_layouts/15/download.aspx?SourceUrl=%2Fpersonal%2Fradrabha%5Fm%5Fresearch%5Fiiit%5Fac%5Fin%2FDocuments%2FWav2Lip%5FModels%2Fwav2lip%5Fgan%2Epth"
WAV2LIP_FILES_URLS = [
    (WAV2LIP_MODEL_WEIGHTS_URL, WAV2LIP_MODEL_WEIGHTS),
]

MUSICGEN_FOLDER = "./MusicGenModel"
MUSICGEN_MODEL_NAME = "melody"
MUSICGEN_MODEL_WEIGHTS = "pytorch_model.bin"
MUSICGEN_CONFIG = "config.json"
MUSICGEN_SAMPLE_RATE = 32000
MUSICGEN_DURATION = 8
MUSICGEN_MODEL_WEIGHTS_URL = "https://huggingface.co/facebook/musicgen-small/resolve/main/pytorch_model.bin"
MUSICGEN_CONFIG_URL = "https://huggingface.co/facebook/musicgen-small/resolve/main/config.json"
MUSICGEN_FILES_URLS = [
    (MUSICGEN_MODEL_WEIGHTS_URL, MUSICGEN_MODEL_WEIGHTS),
    (MUSICGEN_CONFIG_URL, MUSICGEN_CONFIG),
]

CODEGEN_SPM_URL = "https://huggingface.co/Salesforce/codegen-350M-multi/resolve/main/spm.model"
CODEGEN_SPM = "spm.model"

TRANSLATION_SPM_URL = "https://huggingface.co/facebook/mbart-large-50-many-to-many-mmt/resolve/main/sentencepiece.bpe.model"
TRANSLATION_SPM = "sentencepiece.bpe.model"

TEXT_TO_VIDEO_FOLDER = "./TextToVideoModel"
TEXT_TO_VIDEO_MODEL_WEIGHTS = "pytorch_model.bin"
TEXT_TO_VIDEO_CONFIG = "config.json"
TEXT_TO_VIDEO_VOCAB = "vocab.json"
TEXT_TO_VIDEO_MODEL_WEIGHTS_URL = "https://huggingface.co/Searchium-ai/clip4clip-webvid150k/resolve/main/pytorch_model.bin"
TEXT_TO_VIDEO_CONFIG_URL = "https://huggingface.co/Searchium-ai/clip4clip-webvid150k/resolve/main/config.json"
TEXT_TO_VIDEO_VOCAB_URL = "https://huggingface.co/Searchium-ai/clip4clip-webvid150k/resolve/main/vocab.json"
TEXT_TO_VIDEO_FILES_URLS = [
    (TEXT_TO_VIDEO_MODEL_WEIGHTS_URL, TEXT_TO_VIDEO_MODEL_WEIGHTS),
    (TEXT_TO_VIDEO_CONFIG_URL, TEXT_TO_VIDEO_CONFIG),
    (TEXT_TO_VIDEO_VOCAB_URL, TEXT_TO_VIDEO_VOCAB),
]

SUMMARIZATION_FOLDER = "./SummarizationModel"
SUMMARIZATION_MODEL_WEIGHTS = "pytorch_model.bin"
SUMMARIZATION_CONFIG = "config.json"
SUMMARIZATION_VOCAB = "vocab.json"
SUMMARIZATION_MODEL_WEIGHTS_URL = "https://huggingface.co/facebook/bart-large-cnn/resolve/main/pytorch_model.bin"
SUMMARIZATION_CONFIG_URL = "https://huggingface.co/facebook/bart-large-cnn/resolve/main/config.json"
SUMMARIZATION_VOCAB_URL = "https://huggingface.co/facebook/bart-large-cnn/resolve/main/vocab.json"
SUMMARIZATION_FILES_URLS = [
    (SUMMARIZATION_MODEL_WEIGHTS_URL, SUMMARIZATION_MODEL_WEIGHTS),
    (SUMMARIZATION_CONFIG_URL, SUMMARIZATION_CONFIG),
    (SUMMARIZATION_VOCAB_URL, SUMMARIZATION_VOCAB),
]

IMAGE_TO_3D_FOLDER = "./ImageTo3DModel"
IMAGE_TO_3D_MODEL_WEIGHTS = "pytorch_model.bin"
IMAGE_TO_3D_CONFIG = "config.json"
IMAGE_TO_3D_MODEL_URL = "https://huggingface.co/zxhezexin/openlrm-obj-base-1.1/resolve/main/pytorch_model.bin"
IMAGE_TO_3D_CONFIG_URL = "https://huggingface.co/zxhezexin/openlrm-obj-base-1.1/resolve/main/config.json"
IMAGE_TO_3D_FILES_URLS = [
    (IMAGE_TO_3D_MODEL_URL, IMAGE_TO_3D_MODEL_WEIGHTS),
    (IMAGE_TO_3D_CONFIG_URL, IMAGE_TO_3D_CONFIG),
]


state_dict = None
enc = None
config = None
model = None
device = torch.device("cpu")
news_clf = None
tfidf_vectorizer = None
text_queue = queue.Queue()
categories = None
is_training = False
background_threads = []
feedback_queue = queue.Queue()
reasoning_queue = queue.Queue()
seen_responses = set()
tts_model = None
stt_model = None
sentiment_model = None
imagegen_model = None
lipsync_model = None
wav2lip_model = None
musicgen_model = None
translation_model = None
codegen_model = None
text_to_video_model = None
summarization_model = None
image_to_3d_model = None
tts_pipeline = False
stt_pipeline = False
sentiment_pipeline = False
imagegen_pipeline = False
translation_pipeline = False
codegen_pipeline = False
text_to_video_pipeline = False
summarization_pipeline = False
image_to_3d_pipeline = False
stt_tokenizer = None
stt_processor = None
sentiment_tokenizer = None
sentiment_model_instance = None
imagegen_vae = None
imagegen_unet = None
imagegen_scheduler = None
musicgen_model_instance = None
musicgen_tokenizer = None
musicgen_processor = None
translation_model_instance = None
translation_tokenizer = None
codegen_model_instance = None
codegen_tokenizer = None
codegen_sp = None
translation_sp = None
text_to_video_tokenizer = None
text_to_video_model_instance = None
summarization_tokenizer = None
summarization_model_instance = None
image_to_3d_config = None
image_to_3d_model_instance = None
app = Flask(__name__)
CORS(app)

from gpt2_pytorch import *
from tts_vits import *
from stt_wav2vec2 import *
from sentiment_roberta import *
from imagegen_vae_unet import *
from musicgen_torch import *
from translation_mbart import *
from codegen_torch import *
from text_to_video_clip4clip import *
from summarization_bart import *
from image_to_3d_openlrm import *

def download_file(url, filename):
    os.makedirs(os.path.dirname(filename), exist_ok=True)  # Ensure directory exists
    if not os.path.exists(filename):
        print(f"Downloading {filename} from {url}...")
        try:
            wget.download(url, out=filename)  # Specify output filename directly
            print(f"Downloaded {filename} successfully.")
        except Exception as e:
            print(f"Error downloading {filename}: {e}")

def ensure_folder_and_files_exist(folder_path, files_urls):
    if not os.path.exists(folder_path):
        os.makedirs(folder_path)
        print(f"Folder '{folder_path}' created.")

    for url, filename in files_urls:
        filepath = os.path.join(folder_path, filename)
        download_file(url, filepath)

def ensure_single_file_exists(folder_path, file_url, filename):
    if not os.path.exists(folder_path):
        os.makedirs(folder_path)
        print(f"Folder '{folder_path}' created.")
    filepath = os.path.join(folder_path, filename)
    download_file(file_url, filepath)


def ensure_gpt2_files_exist():
    ensure_folder_and_files_exist(GPT2_FOLDER, GPT2_FILES_URLS)

def ensure_translation_files_exist():
    ensure_folder_and_files_exist(TRANSLATION_FOLDER, TRANSLATION_MODEL_FILES_URLS)
    ensure_single_file_exists(TRANSLATION_FOLDER, TRANSLATION_SPM_URL, TRANSLATION_SPM)

def ensure_codegen_files_exist():
    ensure_folder_and_files_exist(CODEGEN_FOLDER, CODEGEN_FILES_URLS)
    ensure_single_file_exists(CODEGEN_FOLDER, CODEGEN_SPM_URL, CODEGEN_SPM)

def ensure_tts_files_exist():
    ensure_folder_and_files_exist(TTS_FOLDER, TTS_FILES_URLS)

def ensure_stt_files_exist():
    ensure_folder_and_files_exist(STT_FOLDER, STT_FILES_URLS)

def ensure_sentiment_files_exist():
    ensure_folder_and_files_exist(SENTIMENT_FOLDER, SENTIMENT_FILES_URLS)

def ensure_imagegen_files_exist():
    ensure_folder_and_files_exist(IMAGEGEN_FOLDER, IMAGEGEN_FILES_URLS)

def ensure_lipsync_files_exist():
    ensure_folder_and_files_exist(LIPSYNC_FOLDER, LIPSYNC_FILES_URLS)

def ensure_wav2lip_files_exist():
    ensure_folder_and_files_exist(WAV2LIP_FOLDER, WAV2LIP_FILES_URLS)

def ensure_musicgen_files_exist():
    ensure_folder_and_files_exist(MUSICGEN_FOLDER, MUSICGEN_FILES_URLS)

def ensure_text_to_video_files_exist():
    ensure_folder_and_files_exist(TEXT_TO_VIDEO_FOLDER, TEXT_TO_VIDEO_FILES_URLS)

def ensure_summarization_files_exist():
    ensure_folder_and_files_exist(SUMMARIZATION_FOLDER, SUMMARIZATION_FILES_URLS)

def ensure_image_to_3d_files_exist():
    ensure_folder_and_files_exist(IMAGE_TO_3D_FOLDER, IMAGE_TO_3D_FILES_URLS)

def ensure_all_model_files_exist():  # Define the function here, before it's called
    ensure_gpt2_files_exist()
    ensure_translation_files_exist()
    ensure_codegen_files_exist()
    ensure_tts_files_exist()
    ensure_stt_files_exist()
    ensure_sentiment_files_exist()
    ensure_imagegen_files_exist()
    ensure_lipsync_files_exist()
    ensure_wav2lip_files_exist()
    ensure_musicgen_files_exist()
    ensure_text_to_video_files_exist()
    ensure_summarization_files_exist()
    ensure_image_to_3d_files_exist()


@app.route("/", methods=['GET'])
async def html_handler():
    return html_code

@app.route("/generate_stream", methods=['GET'])
async def generate_stream_api():
    text_input = request.args.get("text")
    temperature = float(request.args.get("temp", 0.7))
    top_k = int(request.args.get("top_k", 40))
    top_p = float(request.args.get("top_p", 0.0))
    reppenalty = float(request.args.get("reppenalty", 1.2))
    return Response(generate_stream_generator(text_input, temperature, top_k, top_p, reppenalty), mimetype='text/event-stream')

@app.route("/tts", methods=['POST'])
def tts_api():
    data = request.get_json()
    text = data.get('text')
    if not text:
        return jsonify({"error": "Text is required"}), 400
    output_file = text_to_speech(text)
    if output_file == "Error generating speech.":
        return jsonify({"error": "TTS generation failed"}), 500
    return send_file(output_file, mimetype="audio/wav", as_attachment=True, download_name="output.wav")

@app.route("/stt", methods=['POST'])
def stt_api():
    if 'audio' not in request.files:
        return jsonify({"error": "Audio file is required"}), 400
    audio_file = request.files['audio']
    temp_audio_path = f"temp_audio_{uuid.uuid4()}.wav"
    audio_file.save(temp_audio_path)
    output_file = speech_to_text(temp_audio_path)
    os.remove(temp_audio_path)
    if output_file == "Error transcribing audio.":
        return jsonify({"error": "STT failed"}), 500
    return send_file(output_file, mimetype="text/plain", as_attachment=True, download_name="output.txt")

@app.route("/sentiment", methods=['POST'])
def sentiment_api():
    data = request.get_json()
    text = data.get('text')
    if not text:
        return jsonify({"error": "Text is required"}), 400
    output_file = analyze_sentiment(text)
    if output_file == "Sentiment model not initialized.":
        return jsonify({"error": "Sentiment analysis failed"}), 500
    return jsonify(output_file)

@app.route("/imagegen", methods=['POST'])
def imagegen_api():
    data = request.get_json()
    prompt = data.get('prompt')
    if not prompt:
        return jsonify({"error": "Prompt is required"}), 400
    output_file = generate_image(prompt)
    if output_file == "Error generating image.":
        return jsonify({"error": "Image generation failed"}), 500
    image_io = BytesIO()
    output_file.save(image_io, 'PNG')
    image_io.seek(0)
    return send_file(image_io, mimetype='image/png', as_attachment=True, download_name="output.png")

@app.route("/musicgen", methods=['POST'])
def musicgen_api():
    data = request.get_json()
    prompt = data.get('prompt')
    if not prompt:
        return jsonify({"error": "Prompt is required"}), 400
    output_file = generate_music(prompt)
    if output_file == "Error generating music.":
        return jsonify({"error": "Music generation failed"}), 500
    return send_file(output_file, mimetype="audio/wav", as_attachment=True, download_name="output.wav")

@app.route("/translation", methods=['POST'])
def translation_api():
    data = request.get_json()
    text = data.get('text')
    target_lang = data.get('target_lang', 'es')
    source_lang = data.get('source_lang', 'en')
    if not text:
        return jsonify({"error": "Text is required"}), 400
    output_file = perform_translation(text, target_language_code=f'{target_lang}_XX', source_language_code=f'{source_lang}_XX')
    if output_file == "Error during translation.":
        return jsonify({"error": "Translation failed"}), 500
    return send_file(output_file, mimetype="text/plain", as_attachment=True, download_name="output_translation.txt")

@app.route("/codegen", methods=['POST'])
def codegen_api():
    data = request.get_json()
    prompt = data.get('prompt')
    if not prompt:
        return jsonify({"error": "Prompt is required"}), 400
    output_file = generate_code(prompt)
    if output_file == "Error generating code.":
        return jsonify({"error": "Code generation failed"}), 500
    return send_file(output_file, mimetype="text/x-python", as_attachment=True, download_name="output.py")

@app.route("/text_to_video", methods=['POST'])
def text_to_video_api():
    data = request.get_json()
    prompt = data.get('prompt')
    if not prompt:
        return jsonify({"error": "Prompt is required"}), 400
    output_file = text_to_video(prompt)
    if output_file == "Error generating video representation.":
        return jsonify({"error": "Text to video failed"}), 500
    return send_file(output_file, mimetype="application/octet-stream", as_attachment=True, download_name="output_video_representation.pt")

@app.route("/summarization", methods=['POST'])
def summarization_api():
    data = request.get_json()
    text = data.get('text')
    if not text:
        return jsonify({"error": "Text is required"}), 400
    output_file = summarize_text(text)
    if output_file == "Error during summarization.":
        return jsonify({"error": "Summarization failed"}), 500
    return send_file(output_file, mimetype="text/plain", as_attachment=True, download_name="output_summary.txt")

@app.route("/image_to_3d", methods=['POST'])
def image_to_3d_api():
    if 'image' not in request.files:
        return jsonify({"error": "Image file is required"}), 400
    image_file = request.files['image']
    temp_image_path = f"temp_image_{uuid.uuid4()}.png"
    image_file.save(temp_image_path)
    output_file = image_to_3d(temp_image_path)
    os.remove(temp_image_path)
    if output_file == "Error converting image to 3D.":
        return jsonify({"error": "Image to 3D failed"}), 500
    return send_file(output_file, mimetype="model/obj", as_attachment=True, download_name="output_3d.obj")


async def main():
    global background_threads, response_queue
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
    response_queue = queue.Queue()

    ensure_all_model_files_exist()
    initialize_model()
    await initialize_sklearn()
    initialize_tts_model()
    initialize_stt_model()
    initialize_sentiment_model()
    initialize_imagegen_model()
    ensure_lipsync_files_exist()
    ensure_wav2lip_files_exist()
    initialize_musicgen_model()
    initialize_translation_model()
    initialize_codegen_model()
    initialize_text_to_video_model()
    initialize_summarization_model()
    initialize_image_to_3d_model()

    background_threads.append(threading.Thread(target=generate_and_queue_text, args=('en',), daemon=True))
    background_threads.append(threading.Thread(target=generate_and_queue_text, args=('es',), daemon=True))
    background_threads.append(threading.Thread(target=background_training, daemon=True))
    for thread in background_threads:
        thread.start()

    asyncio.create_task(background_reasoning_queue())

    app.run(host="127.0.0.1", port=7860, debug=False)

if __name__ == '__main__':
    asyncio.run(main())
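The routes above are ordinary Flask endpoints, so once the app is running they can be exercised with any HTTP client. The following is only an illustrative sketch (not part of the upload), assuming the server is reachable at the host and port passed to app.run() above and that the JSON keys match the route handlers:

# Hypothetical client-side check; BASE comes from app.run(host="127.0.0.1", port=7860) above.
import requests

BASE = "http://127.0.0.1:7860"

# /sentiment returns JSON, so the payload can be inspected directly.
r = requests.post(f"{BASE}/sentiment", json={"text": "I love this demo"})
print(r.status_code, r.json())

# /tts returns a WAV attachment; write the raw bytes to disk.
r = requests.post(f"{BASE}/tts", json={"text": "Hello world"})
with open("output.wav", "wb") as f:
    f.write(r.content)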
codegen_torch.py
ADDED
@@ -0,0 +1,187 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import wget
import json
import os
import math  # needed: math.sqrt is used for attention scaling below
import sentencepiece as spm
import re

CODEGEN_FOLDER = "./CodeGenModel"
CODEGEN_MODEL_NAME = "codegen-350M-multi"
CODEGEN_MODEL_WEIGHTS = "pytorch_model.bin"
CODEGEN_CONFIG = "config.json"
CODEGEN_VOCAB = "vocab.json"
CODEGEN_MERGES = "merges.txt"
CODEGEN_MODEL_WEIGHTS_URL = "https://huggingface.co/Salesforce/codegen-350M-multi/resolve/main/pytorch_model.bin"
CODEGEN_CONFIG_URL = "https://huggingface.co/Salesforce/codegen-350M-multi/resolve/main/config.json"
CODEGEN_VOCAB_URL = "https://huggingface.co/Salesforce/codegen-350M-multi/resolve/main/vocab.json"
CODEGEN_MERGES_URL = "https://huggingface.co/Salesforce/codegen-350M-multi/resolve/main/merges.txt"
CODEGEN_FILES_URLS = [
    (CODEGEN_MODEL_WEIGHTS_URL, CODEGEN_MODEL_WEIGHTS),
    (CODEGEN_CONFIG_URL, CODEGEN_CONFIG),
    (CODEGEN_VOCAB_URL, CODEGEN_VOCAB),
    (CODEGEN_MERGES_URL, CODEGEN_MERGES),
]
CODEGEN_SPM_URL = "https://huggingface.co/Salesforce/codegen-350M-multi/resolve/main/spm.model"
CODEGEN_SPM = "spm.model"

def ensure_codegen_files_exist():
    os.makedirs(CODEGEN_FOLDER, exist_ok=True)
    for url, filename in CODEGEN_FILES_URLS:
        filepath = os.path.join(CODEGEN_FOLDER, filename)
        if not os.path.exists(filepath):
            wget.download(url, out=filepath)
    filepath_spm = os.path.join(CODEGEN_FOLDER, CODEGEN_SPM)
    if not os.path.exists(filepath_spm):
        wget.download(CODEGEN_SPM_URL, out=filepath_spm)

class CodeGenConfig:
    def __init__(self, vocab_size, n_positions=2048, n_ctx=2048, n_embd=1024, n_layer=24, n_head=16, n_inner=None, activation_function="gelu_new", resid_pdrop=0.1, embd_pdrop=0.1, attn_pdrop=0.1, layer_norm_epsilon=1e-05, initializer_range=0.02, scale_attn_weights=True, use_cache=True, bos_token_id=50256, eos_token_id=50256, **kwargs):
        self.vocab_size = vocab_size
        self.n_positions = n_positions
        self.n_ctx = n_ctx
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_inner = n_inner
        self.activation_function = activation_function
        self.resid_pdrop = resid_pdrop
        self.embd_pdrop = embd_pdrop
        self.attn_pdrop = attn_pdrop
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.scale_attn_weights = scale_attn_weights
        self.use_cache = use_cache
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        for key, value in kwargs.items():
            setattr(self, key, value)

    @classmethod
    def from_dict(cls, config_dict):
        return cls(**config_dict)

class CodeGenForCausalLM(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.transformer = CodeGenModel(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

    def forward(self, input_ids, attention_mask=None):
        transformer_outputs = self.transformer(input_ids, attention_mask=attention_mask)
        logits = self.lm_head(transformer_outputs)
        return logits

class CodeGenModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)
        self.drop = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList([CodeGenBlock(config) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

    def forward(self, input_ids, attention_mask=None):
        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_ids.size(-1))
        position_ids = torch.arange(0, input_shape[-1], dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        hidden_states = inputs_embeds + position_embeds
        hidden_states = self.drop(hidden_states)
        output_shape = input_shape + (hidden_states.size(-1),)
        for block in self.h:
            hidden_states = block(hidden_states, attention_mask=attention_mask)
        hidden_states = self.ln_f(hidden_states)
        return hidden_states.view(*output_shape)

class CodeGenBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)  # nn.LayerNorm; a bare LayerNorm is undefined in this module
        self.attn = CodeGenAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.mlp = CodeGenMLP(config)

    def forward(self, hidden_states, attention_mask=None):
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_outputs = self.attn(hidden_states, attention_mask=attention_mask)
        hidden_states = residual + attn_outputs
        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
        feedforward_hidden_states = self.mlp(hidden_states)
        hidden_states = residual + feedforward_hidden_states
        return hidden_states

class CodeGenMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, config.n_inner)
        self.c_proj = nn.Linear(config.n_inner, config.n_embd)
        self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, hidden_states):
        hidden_states = self.c_fc(hidden_states)
        hidden_states = F.gelu(hidden_states)
        hidden_states = self.c_proj(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return hidden_states

class CodeGenAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        self.n_head = config.n_head
        self.embed_dim = config.n_embd
        self.split_size = self.embed_dim
        self.c_attn = nn.Linear(self.embed_dim, 3 * self.embed_dim)
        self.c_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.scale_attn_weights = config.scale_attn_weights
        self.use_cache = config.use_cache
        self.register_buffer("bias", torch.tril(torch.ones((config.n_ctx, config.n_ctx), dtype=torch.uint8)).view((1, 1, config.n_ctx, config.n_ctx)))

    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
        attn_weights = torch.matmul(query, key.transpose(-1, -2))
        if self.scale_attn_weights:
            attn_weights = attn_weights / math.sqrt(value.size(-1))

        mask = self.bias[:, :, :attn_weights.size(-2), :attn_weights.size(-1)]
        attn_weights = torch.where(mask.bool(), attn_weights, torch.tensor(-1e4, device=attn_weights.device))

        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.Softmax(dim=-1)(attn_weights)
        attn_weights = self.attn_dropout(attn_weights)
        attn_output = torch.matmul(attn_weights, value)
        return attn_output

    def _split_heads(self, tensor, num_heads, attn_head_size):
        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
        tensor = tensor.view(*new_shape)
        return tensor.permute(0, 2, 1, 3)

    def _merge_heads(self, tensor, num_heads, attn_head_size):
        tensor = tensor.permute(0, 2, 1, 3).contiguous()  # move heads back beside the feature dim before flattening
        new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
        return tensor.view(*new_shape)

    def forward(self, hidden_states, attention_mask=None, head_mask=None, past_key_value=None, use_cache=False):
        query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
        query = self._split_heads(query, self.n_head, self.embed_dim // self.n_head)
        key = self._split_heads(key, self.n_head, self.embed_dim // self.n_head)
        value = self._split_heads(value, self.n_head, self.embed_dim // self.n_head)
        if past_key_value is not None:
            past_key, past_value = past_key_value
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)
        present_key_value = (key, value) if use_cache else None
        attn_output = self._attn(query, key, value, attention_mask, head_mask)
        attn_output = self._merge_heads(attn_output, self.n_head, self.embed_dim // self.n_head)
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)
        outputs = (attn_output, present_key_value)
        return outputs[0]
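CodeGenForCausalLM only returns logits, so a quick way to sanity-check the wiring of the classes above is a dummy forward pass. This is a sketch with toy hyperparameters chosen for illustration (not the real checkpoint); note that n_inner has to be given explicitly because CodeGenMLP uses it directly:

# Toy-sized smoke test for the CodeGen classes above (hypothetical sizes).
import torch
from codegen_torch import CodeGenConfig, CodeGenForCausalLM

config = CodeGenConfig(vocab_size=1000, n_positions=128, n_ctx=128,
                       n_embd=64, n_layer=2, n_head=4, n_inner=256)
model = CodeGenForCausalLM(config).eval()

input_ids = torch.randint(0, config.vocab_size, (1, 16))  # batch of 1, 16 tokens
with torch.no_grad():
    logits = model(input_ids)
print(logits.shape)  # expected: torch.Size([1, 16, 1000])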
gpt2_pytorch.py
ADDED
@@ -0,0 +1,210 @@
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import wget
import json
import copy  # needed: copy.deepcopy is used below
import math  # needed: math.sqrt / math.pi are used below
from tqdm import tqdm
from torch.nn.parameter import Parameter  # needed: Parameter is used in Conv1D

GPT2_FOLDER = "./GPT2"
MODEL_FILE = "gpt2-pytorch_model.bin"
ENCODER_FILE = "encoder.json"
VOCAB_FILE = "vocab.bpe"
MODEL_URL = "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin"
ENCODER_URL = "https://raw.githubusercontent.com/graykode/gpt-2-Pytorch/refs/heads/master/GPT2/GPT2/encoder.json"
VOCAB_URL = "https://raw.githubusercontent.com/graykode/gpt-2-Pytorch/refs/heads/master/GPT2/GPT2/vocab.bpe"
MAX_LENGTH = 1024
END_OF_TEXT_TOKEN = "<|endoftext|>"

def ensure_gpt2_files_exist():
    os.makedirs(GPT2_FOLDER, exist_ok=True)  # make sure the target folder exists before downloading
    if not os.path.exists(os.path.join(GPT2_FOLDER, MODEL_FILE)):
        wget.download(MODEL_URL, out=os.path.join(GPT2_FOLDER, MODEL_FILE))
    if not os.path.exists(os.path.join(GPT2_FOLDER, ENCODER_FILE)):
        wget.download(ENCODER_URL, out=os.path.join(GPT2_FOLDER, ENCODER_FILE))
    if not os.path.exists(os.path.join(GPT2_FOLDER, VOCAB_FILE)):
        wget.download(VOCAB_URL, out=os.path.join(GPT2_FOLDER, VOCAB_FILE))

class GPT2Config:
    def __init__(self, vocab_size_or_config_json_file=50257, n_positions=MAX_LENGTH, n_ctx=MAX_LENGTH, n_embd=768, n_layer=12, n_head=12, layer_norm_epsilon=1e-5, initializer_range=0.02):
        self.vocab_size = vocab_size_or_config_json_file
        self.n_ctx = n_ctx
        self.n_positions = n_positions
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range

class GPT2LMHeadModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.transformer = GPT2Model(config)
        self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)

    def forward(self, input_ids, position_ids=None, token_type_ids=None, past=None):
        hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, past)
        lm_logits = self.lm_head(hidden_states)  # project hidden states to vocabulary logits
        return lm_logits, presents

class GPT2Model(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.n_layer = config.n_layer
        self.n_embd = config.n_embd
        self.n_vocab = config.vocab_size
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)
        block = Block(config.n_ctx, config, scale=True)
        self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
        self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

    def forward(self, input_ids, position_ids=None, token_type_ids=None, past=None):
        if past is None:
            past_length = 0
            past = [None] * len(self.h)
        else:
            past_length = past[0][0].size(-2)
        if position_ids is None:
            position_ids = torch.arange(past_length, input_ids.size(-1) + past_length, dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)

        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_ids.size(-1))
        position_ids = position_ids.view(-1, position_ids.size(-1))

        inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
            token_type_embeds = self.wte(token_type_ids)
        else:
            token_type_embeds = 0
        hidden_states = inputs_embeds + position_embeds + token_type_embeds
        presents = []
        for block, layer_past in zip(self.h, past):
            hidden_states, present = block(hidden_states, layer_past)
            presents.append(present)
        hidden_states = self.ln_f(hidden_states)
        output_shape = input_shape + (hidden_states.size(-1),)
        return hidden_states.view(*output_shape), presents

class GPT2LMHead(nn.Module):
    def __init__(self, model_embeddings_weights, config):
        super().__init__()
        self.n_embd = config.n_embd
        self.decoder = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.decoder.weight = model_embeddings_weights

    def forward(self, hidden_state):
        lm_logits = self.decoder(hidden_state)
        return lm_logits

class Block(nn.Module):
    def __init__(self, n_ctx, config, scale=False):
        super().__init__()
        nx = config.n_embd
        self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon)
        self.attn = Attention(nx, n_ctx, config, scale)
        self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon)
        self.mlp = MLP(4 * nx, config)

    def forward(self, x, layer_past=None):
        a, present = self.attn(self.ln_1(x), layer_past=layer_past)
        x = x + a
        m = self.mlp(self.ln_2(x))
        x = x + m
        return x, present

class Attention(nn.Module):
    def __init__(self, nx, n_ctx, config, scale=False):
        super().__init__()
        n_state = nx
        assert n_state % config.n_head == 0
        self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
        self.n_head = config.n_head
        self.split_size = n_state
        self.scale = scale
        self.c_attn = Conv1D(n_state * 3, nx)
        self.c_proj = Conv1D(n_state, nx)

    def _attn(self, q, k, v):
        w = torch.matmul(q, k)
        if self.scale:
            w = w / math.sqrt(v.size(-1))
        nd, ns = w.size(-2), w.size(-1)
        b = self.bias[:, :, ns - nd:ns, :ns]
        w = w * b - 1e10 * (1 - b)  # large negative value masks future positions (was 1e-10, which would not mask)
        w = nn.Softmax(dim=-1)(w)
        return torch.matmul(w, v)

    def merge_heads(self, x):
        x = x.permute(0, 2, 1, 3).contiguous()
        new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
        return x.view(*new_x_shape)

    def split_heads(self, x, k=False):
        new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
        x = x.view(*new_x_shape)
        if k:
            return x.permute(0, 2, 3, 1)
        else:
            return x.permute(0, 2, 1, 3)

    def forward(self, x, layer_past=None):
        x = self.c_attn(x)
        query, key, value = x.split(self.split_size, dim=2)
        query = self.split_heads(query)
        key = self.split_heads(key, k=True)
        value = self.split_heads(value)
        if layer_past is not None:
            past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1]
            key = torch.cat((past_key, key), dim=-1)
            value = torch.cat((past_value, value), dim=-2)
        present = torch.stack((key.transpose(-2, -1), value))
        a = self._attn(query, key, value)
        a = self.merge_heads(a)
        a = self.c_proj(a)
        return a, present

class MLP(nn.Module):
    def __init__(self, n_state, config):
        super().__init__()
        nx = config.n_embd
        self.c_fc = Conv1D(n_state, nx)
        self.c_proj = Conv1D(nx, n_state)
        self.act = gelu

    def forward(self, x):
        h = self.act(self.c_fc(x))
        h2 = self.c_proj(h)
        return h2

class Conv1D(nn.Module):
    def __init__(self, nf, nx):
        super().__init__()
        self.nf = nf
        w = torch.empty(nx, nf)
        nn.init.normal_(w, std=0.02)
        self.weight = Parameter(w)
        self.bias = Parameter(torch.zeros(nf))

    def forward(self, x):
        size_out = x.size()[:-1] + (self.nf,)
        x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
        x = x.view(*size_out)
        return x

class LayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-12):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.weight * x + self.bias

def gelu(x):
    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
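A similar shape check can be done for the GPT-2 classes with the default GPT2Config. This sketch uses randomly initialized weights and does not load the downloaded gpt2-pytorch_model.bin checkpoint (weight loading happens elsewhere in the app and may require key remapping, which is an assumption not shown here):

# Shape check for the GPT-2 classes above with randomly initialized weights.
import torch
from gpt2_pytorch import GPT2Config, GPT2LMHeadModel

config = GPT2Config()                  # 12 layers, 768 hidden units, 50257-token vocab
model = GPT2LMHeadModel(config).eval()

input_ids = torch.randint(0, config.vocab_size, (1, 8))
with torch.no_grad():
    lm_logits, presents = model(input_ids)
print(lm_logits.shape, len(presents))  # expected: torch.Size([1, 8, 50257]) 12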
image_to_3d_openlrm.py
ADDED
@@ -0,0 +1,31 @@
import torch
import torch.nn as nn
import wget
import json
import os

IMAGE_TO_3D_FOLDER = "./ImageTo3DModel"
IMAGE_TO_3D_MODEL_WEIGHTS = "pytorch_model.bin"
IMAGE_TO_3D_CONFIG = "config.json"
IMAGE_TO_3D_MODEL_URL = "https://huggingface.co/zxhezexin/openlrm-obj-base-1.1/resolve/main/pytorch_model.bin"
IMAGE_TO_3D_CONFIG_URL = "https://huggingface.co/zxhezexin/openlrm-obj-base-1.1/resolve/main/config.json"
IMAGE_TO_3D_FILES_URLS = [
    (IMAGE_TO_3D_MODEL_URL, IMAGE_TO_3D_MODEL_WEIGHTS),
    (IMAGE_TO_3D_CONFIG_URL, IMAGE_TO_3D_CONFIG),
]

def ensure_image_to_3d_files_exist():
    os.makedirs(IMAGE_TO_3D_FOLDER, exist_ok=True)
    for url, filename in IMAGE_TO_3D_FILES_URLS:
        filepath = os.path.join(IMAGE_TO_3D_FOLDER, filename)
        if not os.path.exists(filepath):
            wget.download(url, out=filepath)

class OpenLRM(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.fc = nn.Linear(100, num_classes)

    def forward(self, x):
        logits = self.fc(x)
        return logits
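OpenLRM as defined here is just a linear head over a 100-dimensional feature vector. A minimal usage sketch (the num_classes value is hypothetical, chosen only for illustration):

# Minimal usage of the placeholder OpenLRM head.
import torch
from image_to_3d_openlrm import OpenLRM

model = OpenLRM(num_classes=10)
features = torch.randn(1, 100)   # one 100-dimensional feature vector
logits = model(features)
print(logits.shape)              # expected: torch.Size([1, 10])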
imagegen_vae_unet.py
ADDED
@@ -0,0 +1,164 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import wget
import json
import os

IMAGEGEN_FOLDER = "./ImageGenModel"
IMAGEGEN_MODEL_WEIGHTS = "diffusion_pytorch_model.bin"
IMAGEGEN_CONFIG = "config.json"
IMAGEGEN_MODEL_URL = "https://huggingface.co/stabilityai/sd-vae-ft-mse/resolve/main/diffusion_pytorch_model.bin"
IMAGEGEN_CONFIG_URL = "https://huggingface.co/stabilityai/sd-vae-ft-mse/resolve/main/config.json"
IMAGEGEN_FILES_URLS = [
    (IMAGEGEN_MODEL_URL, IMAGEGEN_MODEL_WEIGHTS),
    (IMAGEGEN_CONFIG_URL, IMAGEGEN_CONFIG),
]

def ensure_imagegen_files_exist():
    os.makedirs(IMAGEGEN_FOLDER, exist_ok=True)
    for url, filename in IMAGEGEN_FILES_URLS:
        filepath = os.path.join(IMAGEGEN_FOLDER, filename)
        if not os.path.exists(filepath):
            wget.download(url, out=filepath)

class UNet2DConditionModelConfig:
    def __init__(self, **kwargs):
        self.sample_size = 64
        self.layers_per_block = 2
        self.block_out_channels = [320, 640, 1280, 1280]
        self.downsample = [2, 2, 2, 2]
        self.upsample = [2, 2, 2, 2]
        self.cross_attention_dim = 768
        self.act_fn = "silu"
        self.norm_num_groups = 32
        self.num_attention_heads = 8
        for key, value in kwargs.items():
            setattr(self, key, value)

    @classmethod
    def from_dict(cls, config_dict):
        return cls(**config_dict)

class UNet2DConditionModel(nn.Module):
    def __init__(self, config: UNet2DConditionModelConfig):
        super().__init__()
        self.conv_in = nn.Conv2d(4, config.block_out_channels[0], kernel_size=3, padding=1)
        self.down_blocks = nn.ModuleList([])
        for i in range(len(config.block_out_channels)):
            is_final_block = i == len(config.block_out_channels) - 1
            downsample_factor = 1 if is_final_block else config.downsample[i]
            # Track the incoming channel count so each block can project to its own width.
            in_channels = config.block_out_channels[max(i - 1, 0)]
            out_channels = config.block_out_channels[i]
            layers_per_block = config.layers_per_block
            self.down_blocks.append(DownBlock(in_channels, out_channels, layers_per_block, downsample_factor))
        self.mid_block = MidBlock(config.block_out_channels[-1])
        self.up_blocks = nn.ModuleList([])
        reversed_block_out_channels = list(reversed(config.block_out_channels))
        reversed_upsample_factors = list(reversed(config.upsample))
        for i in range(len(config.block_out_channels)):
            is_final_block = i == len(config.block_out_channels) - 1
            upsample_factor = 1 if is_final_block else reversed_upsample_factors[i]
            in_channels = reversed_block_out_channels[max(i - 1, 0)]
            out_channels = reversed_block_out_channels[i]
            layers_per_block = config.layers_per_block
            self.up_blocks.append(UpBlock(in_channels, out_channels, layers_per_block, upsample_factor))
        self.norm_out = nn.GroupNorm(num_groups=config.norm_num_groups, num_channels=config.block_out_channels[0])
        self.conv_norm_out = nn.Conv2d(config.block_out_channels[0], config.block_out_channels[0], kernel_size=3, padding=1)
        self.conv_out = nn.Conv2d(config.block_out_channels[0], 4, kernel_size=3, padding=1)

    def forward(self, sample: torch.FloatTensor, timestep: torch.IntTensor, encoder_hidden_states: torch.FloatTensor):
        # Note: timestep and encoder_hidden_states are accepted for API compatibility but unused in this simplified UNet.
        sample = self.conv_in(sample)
        for down_block in self.down_blocks:
            sample = down_block(sample)
        sample = self.mid_block(sample)
        for up_block in self.up_blocks:
            sample = up_block(sample)
        sample = self.norm_out(sample)
        sample = F.silu(sample)
        sample = self.conv_norm_out(sample)
        sample = F.silu(sample)
        sample = self.conv_out(sample)
        return {"sample": sample}

class DownBlock(nn.Module):
    def __init__(self, in_channels, out_channels, layers_per_block, downsample_factor):
        super().__init__()
        # Project to the block width when the incoming channel count differs, so stacked blocks line up.
        self.proj = nn.Conv2d(in_channels, out_channels, kernel_size=1) if in_channels != out_channels else nn.Identity()
        self.layers = nn.ModuleList([ResnetBlock(out_channels) for _ in range(layers_per_block)])
        if downsample_factor > 1:
            self.downsample = Downsample2D(out_channels, downsample_factor)
        else:
            self.downsample = nn.Identity()

    def forward(self, x):
        x = self.proj(x)
        for layer in self.layers:
            x = layer(x)
        x = self.downsample(x)
        return x

class UpBlock(nn.Module):
    def __init__(self, in_channels, out_channels, layers_per_block, upsample_factor):
        super().__init__()
        self.proj = nn.Conv2d(in_channels, out_channels, kernel_size=1) if in_channels != out_channels else nn.Identity()
        self.layers = nn.ModuleList([ResnetBlock(out_channels) for _ in range(layers_per_block)])
        if upsample_factor > 1:
            self.upsample = Upsample2D(out_channels, upsample_factor)
        else:
            self.upsample = nn.Identity()

    def forward(self, x):
        x = self.proj(x)
        for layer in self.layers:
            x = layer(x)
        x = self.upsample(x)
        return x

class ResnetBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.norm1 = nn.GroupNorm(num_groups=32, num_channels=channels)
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.norm2 = nn.GroupNorm(num_groups=32, num_channels=channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.residual_conv = nn.Conv2d(channels, channels, kernel_size=1)

    def forward(self, x):
        residual = x
        x = self.norm1(x)
        x = F.silu(x)
        x = self.conv1(x)
        x = self.norm2(x)
        x = F.silu(x)
        x = self.conv2(x)
        return x + self.residual_conv(residual)

class MidBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.norm1 = nn.GroupNorm(num_groups=32, num_channels=channels)
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.norm2 = nn.GroupNorm(num_groups=32, num_channels=channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        x = self.norm1(x)
        x = F.silu(x)
        x = self.conv1(x)
        x = self.norm2(x)
        x = F.silu(x)
        x = self.conv2(x)
        return x

class Downsample2D(nn.Module):
    def __init__(self, channels, factor):
        super().__init__()
        self.factor = factor
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, stride=factor, padding=1)

    def forward(self, x):
        return self.conv(x)

class Upsample2D(nn.Module):
    def __init__(self, channels, factor):
        super().__init__()
        self.factor = factor
        self.conv = nn.ConvTranspose2d(channels, channels, kernel_size=factor, stride=factor)

    def forward(self, x):
        return self.conv(x)
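A shape-level smoke test for the simplified UNet, assuming the channel-projection variant above; the timestep and text embedding are placeholders because this forward pass ignores them.

import torch
from imagegen_vae_unet import UNet2DConditionModelConfig, UNet2DConditionModel

config = UNet2DConditionModelConfig()
unet = UNet2DConditionModel(config).eval()
latents = torch.randn(1, 4, 64, 64)                        # 4-channel latent at the default sample_size
timestep = torch.tensor([10])                              # accepted but unused by this simplified forward
text_emb = torch.randn(1, 77, config.cross_attention_dim)  # dummy conditioning, also unused
with torch.no_grad():
    out = unet(latents, timestep, text_emb)["sample"]
print(out.shape)  # torch.Size([1, 4, 64, 64])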
lipsync_wav2lip.py
ADDED
@@ -0,0 +1,57 @@
import torch
import torch.nn as nn
import wget
import os

LIPSYNC_FOLDER = "./LipSyncModel"
LIPSYNC_MODEL_WEIGHTS = "lipsync_expert.pth"
LIPSYNC_MODEL_WEIGHTS_URL = "https://iiitaphyd-my.sharepoint.com/personal/radrabha_m_research_iiit_ac_in/_layouts/15/download.aspx?SourceUrl=%2Fpersonal%2Fradrabha%5Fm%5Fresearch%5Fiiit%5Fac%5Fin%2FDocuments%2FWav2Lip%5FModels%2Flipsync%5Fexpert%2Epth"
LIPSYNC_FILES_URLS = [
    (LIPSYNC_MODEL_WEIGHTS_URL, LIPSYNC_MODEL_WEIGHTS),
]

WAV2LIP_FOLDER = "./Wav2LipModel"
WAV2LIP_MODEL_WEIGHTS = "wav2lip_gan.pth"
WAV2LIP_MODEL_WEIGHTS_URL = "https://iiitaphyd-my.sharepoint.com/personal/radrabha_m_research_iiit_ac_in/_layouts/15/download.aspx?SourceUrl=%2Fpersonal%2Fradrabha%5Fm%5Fresearch%5Fiiit%5Fac%5Fin%2FDocuments%2FWav2Lip%5FModels%2Fwav2lip%5Fgan%2Epth"
WAV2LIP_FILES_URLS = [
    (WAV2LIP_MODEL_WEIGHTS_URL, WAV2LIP_MODEL_WEIGHTS),
]

def ensure_lipsync_files_exist():
    os.makedirs(LIPSYNC_FOLDER, exist_ok=True)
    for url, filename in LIPSYNC_FILES_URLS:
        filepath = os.path.join(LIPSYNC_FOLDER, filename)
        if not os.path.exists(filepath):
            try:
                wget.download(url, out=filepath)
            except Exception:
                print(f"Warning: Download for {filename} failed, likely due to link restrictions. You may need to download it manually.")

def ensure_wav2lip_files_exist():
    os.makedirs(WAV2LIP_FOLDER, exist_ok=True)
    for url, filename in WAV2LIP_FILES_URLS:
        filepath = os.path.join(WAV2LIP_FOLDER, filename)
        if not os.path.exists(filepath):
            try:
                wget.download(url, out=filepath)
            except Exception:
                print(f"Warning: Download for {filename} failed, likely due to link restrictions. You may need to download it manually.")


class LipSyncModel(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.fc = nn.Linear(100, num_classes)

    def forward(self, x):
        logits = self.fc(x)
        return logits

class Wav2LipModel(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.fc = nn.Linear(100, num_classes)

    def forward(self, x):
        logits = self.fc(x)
        return logits
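Because the SharePoint links above often refuse non-interactive downloads, a small guard like this (a sketch; the names are taken from the file above) can confirm the checkpoints actually landed before the app tries to load them.

import os
from lipsync_wav2lip import (
    ensure_lipsync_files_exist, ensure_wav2lip_files_exist,
    LIPSYNC_FOLDER, LIPSYNC_MODEL_WEIGHTS, WAV2LIP_FOLDER, WAV2LIP_MODEL_WEIGHTS,
)

ensure_lipsync_files_exist()
ensure_wav2lip_files_exist()

for folder, weights in [(LIPSYNC_FOLDER, LIPSYNC_MODEL_WEIGHTS), (WAV2LIP_FOLDER, WAV2LIP_MODEL_WEIGHTS)]:
    path = os.path.join(folder, weights)
    if not os.path.isfile(path) or os.path.getsize(path) == 0:
        # An empty or missing file means the automated download failed; place the checkpoint manually.
        print(f"{path} is missing or empty; download it manually into {folder}")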
musicgen_torch.py
ADDED
@@ -0,0 +1,36 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import wget
import json
import os

MUSICGEN_FOLDER = "./MusicGenModel"
MUSICGEN_MODEL_NAME = "melody"
MUSICGEN_MODEL_WEIGHTS = "pytorch_model.bin"
MUSICGEN_CONFIG = "config.json"
MUSICGEN_SAMPLE_RATE = 32000
MUSICGEN_DURATION = 8
MUSICGEN_MODEL_WEIGHTS_URL = "https://huggingface.co/facebook/musicgen-small/resolve/main/pytorch_model.bin"
MUSICGEN_CONFIG_URL = "https://huggingface.co/facebook/musicgen-small/resolve/main/config.json"
MUSICGEN_FILES_URLS = [
    (MUSICGEN_MODEL_WEIGHTS_URL, MUSICGEN_MODEL_WEIGHTS),
    (MUSICGEN_CONFIG_URL, MUSICGEN_CONFIG),
]

def ensure_musicgen_files_exist():
    os.makedirs(MUSICGEN_FOLDER, exist_ok=True)
    for url, filename in MUSICGEN_FILES_URLS:
        filepath = os.path.join(MUSICGEN_FOLDER, filename)
        if not os.path.exists(filepath):
            wget.download(url, out=filepath)

class MusicGenModel(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.fc = nn.Linear(100, num_classes)

    def forward(self, x):
        logits = self.fc(x)
        return logits
sentiment_roberta.py
ADDED
@@ -0,0 +1,195 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import wget
import json
import math
import os

SENTIMENT_FOLDER = "./SentimentModel"
SENTIMENT_MODEL_WEIGHTS = "pytorch_model.bin"
SENTIMENT_VOCAB = "sentiment_vocab.json"
SENTIMENT_CONFIG = "config.json"
SENTIMENT_MODEL_WEIGHTS_URL = "https://huggingface.co/cardiffnlp/distilroberta-base-sentiment/resolve/main/pytorch_model.bin"
SENTIMENT_VOCAB_URL = "https://huggingface.co/cardiffnlp/distilroberta-base-sentiment/resolve/main/vocab.json"
SENTIMENT_CONFIG_URL = "https://huggingface.co/cardiffnlp/distilroberta-base-sentiment/resolve/main/config.json"
SENTIMENT_FILES_URLS = [
    (SENTIMENT_MODEL_WEIGHTS_URL, SENTIMENT_MODEL_WEIGHTS),
    (SENTIMENT_VOCAB_URL, SENTIMENT_VOCAB),
    (SENTIMENT_CONFIG_URL, SENTIMENT_CONFIG),
]

def ensure_sentiment_files_exist():
    os.makedirs(SENTIMENT_FOLDER, exist_ok=True)
    for url, filename in SENTIMENT_FILES_URLS:
        filepath = os.path.join(SENTIMENT_FOLDER, filename)
        if not os.path.exists(filepath):
            wget.download(url, out=filepath)

class RobertaForSequenceClassification(nn.Module):
    def __init__(self, num_labels):
        super().__init__()
        self.dense = nn.Linear(768, 768)
        self.dropout = nn.Dropout(0.1)
        self.out_proj = nn.Linear(768, num_labels)

    def forward(self, sequence_output):
        x = sequence_output[:, 0, :]
        x = self.dropout(x)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x

class RobertaModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.embeddings = RobertaEmbeddings(config)
        self.encoder = RobertaEncoder(config)

    def forward(self, input_ids, attention_mask=None):
        embedding_output = self.embeddings(input_ids)
        encoder_outputs = self.encoder(embedding_output, attention_mask=attention_mask)
        return (encoder_outputs[0], )

class RobertaEmbeddings(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Registered as a buffer so the position ids follow the module across devices.
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))

    def forward(self, input_ids, token_type_ids=None, position_ids=None):
        input_shape = input_ids.size()
        seq_length = input_shape[1]
        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=input_ids.device)

        embeddings = self.word_embeddings(input_ids) + self.position_embeddings(position_ids) + self.token_type_embeddings(token_type_ids)
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

class RobertaEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.layer = nn.ModuleList([RobertaLayer(config) for _ in range(config.num_hidden_layers)])

    def forward(self, hidden_states, attention_mask=None):
        all_encoder_layers = []
        for layer_module in self.layer:
            hidden_states = layer_module(hidden_states, attention_mask=attention_mask)
            all_encoder_layers.append(hidden_states)
        return (hidden_states, all_encoder_layers)

class RobertaLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attention = RobertaAttention(config)
        self.intermediate = RobertaIntermediate(config)
        self.output = RobertaOutput(config)

    def forward(self, hidden_states, attention_mask=None):
        attention_output = self.attention(hidden_states, attention_mask=attention_mask)
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output

class RobertaAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.self_attn = RobertaSelfAttention(config)
        self.output = RobertaSelfOutput(config)

    def forward(self, hidden_states, attention_mask=None):
        self_output = self.self_attn(hidden_states, attention_mask=attention_mask)
        attention_output = self.output(self_output, hidden_states)
        return attention_output

class RobertaSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask=None):
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            attention_scores = attention_scores + attention_mask

        attention_probs = nn.Softmax(dim=-1)(attention_scores)
        attention_probs = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        return context_layer

class RobertaSelfOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        # all_head_size always equals hidden_size here, so the projection is hidden_size -> hidden_size.
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states

class RobertaIntermediate(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        self.intermediate_act_fn = F.gelu

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states

class RobertaOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states
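The RoBERTa stack above only needs a config object exposing the attributes it reads, so a tiny ad-hoc config is enough for a shape check; the sizes below are made up for the test (except hidden_size=768, which the classification head hard-codes), not the cardiffnlp checkpoint's real dimensions.

import torch
from types import SimpleNamespace
from sentiment_roberta import RobertaModel, RobertaForSequenceClassification

config = SimpleNamespace(
    vocab_size=1000, hidden_size=768, num_hidden_layers=2, num_attention_heads=12,
    intermediate_size=1024, max_position_embeddings=128, type_vocab_size=1,
    pad_token_id=1, layer_norm_eps=1e-5, hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
)
backbone = RobertaModel(config).eval()
head = RobertaForSequenceClassification(num_labels=3).eval()
input_ids = torch.randint(0, config.vocab_size, (1, 16))
with torch.no_grad():
    sequence_output = backbone(input_ids)[0]   # (1, 16, 768)
    logits = head(sequence_output)             # classify from the first token
print(logits.shape)  # torch.Size([1, 3])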
stt_wav2vec2.py
ADDED
@@ -0,0 +1,46 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import wget
import json
import os

STT_FOLDER = "./STTModel"
STT_MODEL_NAME = "wav2vec2"
STT_MODEL_WEIGHTS = "pytorch_model.bin"
STT_CONFIG = "config.json"
STT_VOCAB = "vocab.json"
STT_MODEL_WEIGHTS_URL = "https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/pytorch_model.bin"
STT_CONFIG_URL = "https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/config.json"
STT_VOCAB_URL = "https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/vocab.json"
STT_FILES_URLS = [
    (STT_MODEL_WEIGHTS_URL, STT_MODEL_WEIGHTS),
    (STT_CONFIG_URL, STT_CONFIG),
    (STT_VOCAB_URL, STT_VOCAB),
]

def ensure_stt_files_exist():
    os.makedirs(STT_FOLDER, exist_ok=True)
    for url, filename in STT_FILES_URLS:
        filepath = os.path.join(STT_FOLDER, filename)
        if not os.path.exists(filepath):
            wget.download(url, out=filepath)

class Wav2Vec2ForCTC(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.conv1 = nn.Conv1d(1, 16, kernel_size=5, stride=2, padding=2)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool1d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv1d(16, 32, kernel_size=3, stride=2, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool1d(kernel_size=2, stride=2)
        self.fc = nn.Linear(32 * 39 * 40, num_classes)  # Adjusted input size

    def forward(self, x):
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        x = x.view(x.size(0), -1)
        logits = self.fc(x)
        return logits
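The flattened size 32 * 39 * 40 in the linear layer only matches one input length; working the conv/pool arithmetic backwards, a mono waveform of 24960 samples collapses to exactly 32 channels x 1560 frames, so a shape check might look like this (a sketch; the real wav2vec2 checkpoint uses a different architecture entirely).

import torch
from stt_wav2vec2 import Wav2Vec2ForCTC

model = Wav2Vec2ForCTC(num_classes=32).eval()
waveform = torch.randn(1, 1, 24960)   # 24960 -> 12480 -> 6240 -> 3120 -> 1560 frames; 32 * 1560 == 32 * 39 * 40
with torch.no_grad():
    logits = model(waveform)
print(logits.shape)  # torch.Size([1, 32])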
summarization_bart.py
ADDED
@@ -0,0 +1,34 @@
import torch
import torch.nn as nn
import wget
import json
import os

SUMMARIZATION_FOLDER = "./SummarizationModel"
SUMMARIZATION_MODEL_WEIGHTS = "pytorch_model.bin"
SUMMARIZATION_CONFIG = "config.json"
SUMMARIZATION_VOCAB = "vocab.json"
SUMMARIZATION_MODEL_WEIGHTS_URL = "https://huggingface.co/facebook/bart-large-cnn/resolve/main/pytorch_model.bin"
SUMMARIZATION_CONFIG_URL = "https://huggingface.co/facebook/bart-large-cnn/resolve/main/config.json"
SUMMARIZATION_VOCAB_URL = "https://huggingface.co/facebook/bart-large-cnn/resolve/main/vocab.json"
SUMMARIZATION_FILES_URLS = [
    (SUMMARIZATION_MODEL_WEIGHTS_URL, SUMMARIZATION_MODEL_WEIGHTS),
    (SUMMARIZATION_CONFIG_URL, SUMMARIZATION_CONFIG),
    (SUMMARIZATION_VOCAB_URL, SUMMARIZATION_VOCAB),
]

def ensure_summarization_files_exist():
    os.makedirs(SUMMARIZATION_FOLDER, exist_ok=True)
    for url, filename in SUMMARIZATION_FILES_URLS:
        filepath = os.path.join(SUMMARIZATION_FOLDER, filename)
        if not os.path.exists(filepath):
            wget.download(url, out=filepath)

class BartForConditionalGeneration(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.fc = nn.Linear(100, num_classes)

    def forward(self, x):
        logits = self.fc(x)
        return logits
text_to_video_clip4clip.py
ADDED
@@ -0,0 +1,34 @@
import torch
import torch.nn as nn
import wget
import json
import os

TEXT_TO_VIDEO_FOLDER = "./TextToVideoModel"
TEXT_TO_VIDEO_MODEL_WEIGHTS = "pytorch_model.bin"
TEXT_TO_VIDEO_CONFIG = "config.json"
TEXT_TO_VIDEO_VOCAB = "vocab.json"
TEXT_TO_VIDEO_MODEL_WEIGHTS_URL = "https://huggingface.co/Searchium-ai/clip4clip-webvid150k/resolve/main/pytorch_model.bin"
TEXT_TO_VIDEO_CONFIG_URL = "https://huggingface.co/Searchium-ai/clip4clip-webvid150k/resolve/main/config.json"
TEXT_TO_VIDEO_VOCAB_URL = "https://huggingface.co/Searchium-ai/clip4clip-webvid150k/resolve/main/vocab.json"
TEXT_TO_VIDEO_FILES_URLS = [
    (TEXT_TO_VIDEO_MODEL_WEIGHTS_URL, TEXT_TO_VIDEO_MODEL_WEIGHTS),
    (TEXT_TO_VIDEO_CONFIG_URL, TEXT_TO_VIDEO_CONFIG),
    (TEXT_TO_VIDEO_VOCAB_URL, TEXT_TO_VIDEO_VOCAB),
]

def ensure_text_to_video_files_exist():
    os.makedirs(TEXT_TO_VIDEO_FOLDER, exist_ok=True)
    for url, filename in TEXT_TO_VIDEO_FILES_URLS:
        filepath = os.path.join(TEXT_TO_VIDEO_FOLDER, filename)
        if not os.path.exists(filepath):
            wget.download(url, out=filepath)

class Clip4ClipModel(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.fc = nn.Linear(100, num_classes)

    def forward(self, x):
        logits = self.fc(x)
        return logits
translation_mbart.py
ADDED
@@ -0,0 +1,267 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import wget
import json
import math
import os
import sentencepiece as spm
import re

TRANSLATION_FOLDER = "./TranslationModel"
TRANSLATION_MODEL_WEIGHTS_FILE = "pytorch_model.bin"
TRANSLATION_MODEL_CONFIG_FILE = "config.json"
TRANSLATION_MODEL_VOCAB_FILE = "sentencepiece.bpe.model"
TRANSLATION_MODEL_WEIGHTS_URL = "https://huggingface.co/facebook/mbart-large-50-many-to-many-mmt/resolve/main/pytorch_model.bin"
TRANSLATION_MODEL_CONFIG_URL = "https://huggingface.co/facebook/mbart-large-50-many-to-many-mmt/resolve/main/config.json"
TRANSLATION_MODEL_VOCAB_URL = "https://huggingface.co/facebook/mbart-large-50-many-to-many-mmt/resolve/main/sentencepiece.bpe.model"
TRANSLATION_MODEL_FILES_URLS = [
    (TRANSLATION_MODEL_WEIGHTS_URL, TRANSLATION_MODEL_WEIGHTS_FILE),
    (TRANSLATION_MODEL_CONFIG_URL, TRANSLATION_MODEL_CONFIG_FILE),
    (TRANSLATION_MODEL_VOCAB_URL, TRANSLATION_MODEL_VOCAB_FILE),
]
TRANSLATION_SPM_URL = "https://huggingface.co/facebook/mbart-large-50-many-to-many-mmt/resolve/main/sentencepiece.bpe.model"
TRANSLATION_SPM = "sentencepiece.bpe.model"

def ensure_translation_files_exist():
    os.makedirs(TRANSLATION_FOLDER, exist_ok=True)
    for url, filename in TRANSLATION_MODEL_FILES_URLS:
        filepath = os.path.join(TRANSLATION_FOLDER, filename)
        if not os.path.exists(filepath):
            wget.download(url, out=filepath)
    filepath_spm = os.path.join(TRANSLATION_FOLDER, TRANSLATION_SPM)
    if not os.path.exists(filepath_spm):
        wget.download(TRANSLATION_SPM_URL, out=filepath_spm)

class MBartConfig:
    def __init__(self, vocab_size, hidden_size=1024, num_hidden_layers=12, num_attention_heads=16, intermediate_size=4096, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, layer_norm_eps=1e-05, initializer_range=0.02, pad_token_id=1, bos_token_id=0, eos_token_id=2, n_positions=1024, n_ctx=1024, decoder_layers=12, decoder_attention_heads=16, decoder_ffn_dim=4096, encoder_layers=12, encoder_attention_heads=16, encoder_ffn_dim=4096, **kwargs):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.layer_norm_eps = layer_norm_eps
        self.initializer_range = initializer_range
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.n_positions = n_positions
        self.n_ctx = n_ctx
        self.decoder_layers = decoder_layers
        self.decoder_attention_heads = decoder_attention_heads
        self.decoder_ffn_dim = decoder_ffn_dim
        self.encoder_layers = encoder_layers
        self.encoder_attention_heads = encoder_attention_heads
        self.encoder_ffn_dim = encoder_ffn_dim
        for key, value in kwargs.items():
            setattr(self, key, value)

    @classmethod
    def from_dict(cls, config_dict):
        return cls(**config_dict)

class MBartForConditionalGeneration(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.model = MBartModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
        self.final_logits_bias = nn.Parameter(torch.zeros((1, config.vocab_size)))

    def forward(self, input_ids, attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None):
        outputs = self.model(input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask)
        lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
        return lm_logits

class MBartModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.encoder = MBartEncoder(config)
        self.decoder = MBartDecoder(config)

    def forward(self, input_ids, attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None):
        encoder_outputs = self.encoder(input_ids, attention_mask=attention_mask)
        decoder_outputs = self.decoder(decoder_input_ids, encoder_outputs=encoder_outputs, decoder_attention_mask=decoder_attention_mask)
        return decoder_outputs

class MBartEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.embed_positions = MBartSinusoidalPositionalEmbedding(config.hidden_size, config.pad_token_id)
        self.layers = nn.ModuleList([MBartEncoderLayer(config) for _ in range(config.encoder_layers)])
        self.layernorm_embedding = nn.LayerNorm(config.hidden_size)

    def forward(self, input_ids, attention_mask=None):
        inputs_embeds = self.embed_tokens(input_ids)
        position_embeddings = self.embed_positions(input_ids)
        embeddings = inputs_embeds + position_embeddings
        embeddings = self.layernorm_embedding(embeddings)
        encoder_states = embeddings
        all_encoder_layers = []
        for layer_module in self.layers:
            encoder_states = layer_module(encoder_states, encoder_padding_mask=attention_mask)
            all_encoder_layers.append(encoder_states)
        return (encoder_states, all_encoder_layers)

class MBartDecoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.embed_positions = MBartSinusoidalPositionalEmbedding(config.hidden_size, config.pad_token_id)
        self.layers = nn.ModuleList([MBartDecoderLayer(config) for _ in range(config.decoder_layers)])
        self.layernorm_embedding = nn.LayerNorm(config.hidden_size)

    def forward(self, decoder_input_ids, encoder_outputs, decoder_attention_mask=None):
        inputs_embeds = self.embed_tokens(decoder_input_ids)
        position_embeddings = self.embed_positions(decoder_input_ids)
        embeddings = inputs_embeds + position_embeddings
        embeddings = self.layernorm_embedding(embeddings)
        decoder_states = embeddings
        all_decoder_layers = []
        all_cross_attention_layers = []
        for layer_module in self.layers:
            # No separate encoder padding mask is threaded through here, so cross-attention runs unmasked.
            decoder_states, cross_attn_weights = layer_module(decoder_states, encoder_outputs[0], decoder_padding_mask=decoder_attention_mask, encoder_padding_mask=None)
            all_decoder_layers.append(decoder_states)
            all_cross_attention_layers.append(cross_attn_weights)
        return (decoder_states, all_decoder_layers, all_cross_attention_layers)

class MBartSinusoidalPositionalEmbedding(nn.Module):
    def __init__(self, embedding_dim, padding_idx):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx

    def forward(self, input_ids):
        seq_len = input_ids.size(1)
        positions = torch.arange(self.padding_idx + 1, seq_len + self.padding_idx + 1, dtype=torch.long, device=input_ids.device)
        return self.get_embedding(positions)

    def get_embedding(self, positions):
        half_dim = self.embedding_dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float, device=positions.device) * -emb)
        emb = torch.outer(positions.float(), emb)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
        if self.embedding_dim % 2 == 1:
            emb = F.pad(emb, (0, 1, 0, 0))
        return emb

class MBartEncoderLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.self_attn = MBartAttention(config, embed_dim=config.hidden_size, num_heads=config.encoder_attention_heads)
        self.self_attn_layer_norm = nn.LayerNorm(config.hidden_size)
        self.fc1 = nn.Linear(config.hidden_size, config.encoder_ffn_dim)
        self.fc2 = nn.Linear(config.encoder_ffn_dim, config.hidden_size)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size)

    def forward(self, hidden_states, encoder_padding_mask=None):
        residual = hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states, _ = self.self_attn(hidden_states, hidden_states, hidden_states, attention_mask=encoder_padding_mask)
        hidden_states = residual + hidden_states
        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.fc2(F.relu(self.fc1(hidden_states)))
        hidden_states = residual + hidden_states
        return hidden_states

class MBartDecoderLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.self_attn = MBartAttention(config, embed_dim=config.hidden_size, num_heads=config.decoder_attention_heads)
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.self_attn_layer_norm = nn.LayerNorm(config.hidden_size)
        self.encoder_attn = MBartAttention(config, embed_dim=config.hidden_size, num_heads=config.decoder_attention_heads)
        self.encoder_attn_layer_norm = nn.LayerNorm(config.hidden_size)
        self.fc1 = nn.Linear(config.hidden_size, config.decoder_ffn_dim)
        self.fc2 = nn.Linear(config.decoder_ffn_dim, config.hidden_size)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size)

    def forward(self, hidden_states, encoder_hidden_states, decoder_padding_mask=None, encoder_padding_mask=None):
        residual = hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states, _ = self.self_attn(hidden_states, hidden_states, hidden_states, attention_mask=decoder_padding_mask)
        hidden_states = residual + hidden_states
        residual = hidden_states
        hidden_states = self.encoder_attn_layer_norm(hidden_states)
        hidden_states, cross_attn_weights = self.encoder_attn(hidden_states, encoder_hidden_states, encoder_hidden_states, attention_mask=encoder_padding_mask)
        hidden_states = residual + hidden_states
        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.fc2(F.relu(self.fc1(hidden_states)))
        hidden_states = residual + hidden_states
        return hidden_states, cross_attn_weights

class MBartAttention(nn.Module):
    def __init__(self, config, embed_dim, num_heads):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scaling = self.head_dim ** -0.5
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def _shape(self, tensor, seq_len, bsz):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(self, query, key, value, attention_mask=None):
        bsz, tgt_len, _ = query.size()
        bsz, src_len, _ = key.size()
        query = self.q_proj(query)
        key = self.k_proj(key)
        value = self.v_proj(value)
        query = self._shape(query, tgt_len, bsz)
        key = self._shape(key, src_len, bsz)
        value = self._shape(value, src_len, bsz)
        attn_weights = torch.matmul(query, key.transpose(-1, -2)) * self.scaling

        if attention_mask is not None:
            # Broadcast a (batch, src_len) padding mask over heads and query positions.
            attention_mask = attention_mask[:, None, None, :]
            attention_mask = attention_mask.float().masked_fill(attention_mask == 0, float('-inf')).masked_fill(attention_mask == 1, float(0.0))
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.Softmax(dim=-1)(attn_weights)
        attn_weights = self.dropout(attn_weights)
        attn_output = torch.matmul(attn_weights, value)
        attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, tgt_len, self.embed_dim)
        attn_output = self.out_proj(attn_output)
        return attn_output, attn_weights

class MBartTokenizer:
    def __init__(self, sentencepiece_processor):
        self.sp = sentencepiece_processor
        self.pad_token = "<pad>"
        self.bos_token = "<s>"
        self.eos_token = "</s>"
        self.pad_token_id = 1
        self.bos_token_id = 0
        self.eos_token_id = 2
        self.model_max_length = 1024

    def __call__(self, text, return_tensors="pt", padding=True, truncation=True, max_length=None, src_lang="en_XX", tgt_lang="es_XX", **kwargs):
        max_length = max_length if max_length is not None else self.model_max_length
        # SentencePiece expects the extra-option string in "bos:eos" form.
        self.sp.SetEncodeExtraOptions("bos:eos")
        input_ids = self.sp.EncodeAsIds(f"{src_lang} {text}")
        if truncation and len(input_ids) > max_length:
            input_ids = input_ids[:max_length]
        if padding:
            input_ids += [self.pad_token_id] * (max_length - len(input_ids))
        if return_tensors == "pt":
            return {"input_ids": torch.tensor([input_ids]), "attention_mask": torch.ones(len(input_ids)).unsqueeze(0)}
        return input_ids

    def batch_decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True, target_lang="es_XX"):
        decoded_texts = []
        for ids in token_ids:
            text = self.sp.DecodeIds(list(ids))
            if skip_special_tokens:
                text = re.sub(r'(<s>|</s>|<pad>)', '', text).strip()
            if clean_up_tokenization_spaces:
                text = re.sub(r'\s+', ' ', text).strip()  # collapse repeated whitespace left over from decoding
            decoded_texts.append(text.replace(f"{target_lang} ", ""))
        return decoded_texts
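A usage sketch for the tokenizer wrapper, assuming the SentencePiece model has been downloaded by ensure_translation_files_exist(); note that prefixing the raw language code to the text is a simplification of how mBART-50 actually encodes language tags, so the decoded output is only illustrative.

import os
import sentencepiece as spm
from translation_mbart import ensure_translation_files_exist, MBartTokenizer, TRANSLATION_FOLDER, TRANSLATION_SPM

ensure_translation_files_exist()
sp = spm.SentencePieceProcessor()
sp.Load(os.path.join(TRANSLATION_FOLDER, TRANSLATION_SPM))

tokenizer = MBartTokenizer(sp)
batch = tokenizer("Hello world", max_length=32, src_lang="en_XX", tgt_lang="es_XX")
print(batch["input_ids"].shape)  # (1, 32) after truncation and padding
print(tokenizer.batch_decode(batch["input_ids"].tolist(), skip_special_tokens=True))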
tts_vits.py
ADDED
@@ -0,0 +1,57 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import wget
import json
import os

TTS_FOLDER = "./TTSModel"
TTS_MODEL_NAME = "vits"
TTS_MODEL_CONFIG = "config.json"
TTS_MODEL_WEIGHTS = "pytorch_model.bin"
TTS_VOCAB = "vocab.json"
TTS_CONFIG_URL = "https://huggingface.co/kakao-enterprise/vits-vctk/resolve/main/config.json"
TTS_MODEL_WEIGHTS_URL = "https://huggingface.co/kakao-enterprise/vits-vctk/resolve/main/pytorch_model.bin"
TTS_VOCAB_URL = "https://huggingface.co/kakao-enterprise/vits-vctk/resolve/main/vocab.json"
TTS_FILES_URLS = [
    (TTS_CONFIG_URL, TTS_MODEL_CONFIG),
    (TTS_MODEL_WEIGHTS_URL, TTS_MODEL_WEIGHTS),
    (TTS_VOCAB_URL, TTS_VOCAB),
]

def ensure_tts_files_exist():
    os.makedirs(TTS_FOLDER, exist_ok=True)
    for url, filename in TTS_FILES_URLS:
        filepath = os.path.join(TTS_FOLDER, filename)
        if not os.path.exists(filepath):
            wget.download(url, out=filepath)

class VITS(nn.Module):
    def __init__(self, spec_channels, segment_size, num_speakers, num_languages, num_symbols):
        super().__init__()
        self.spec_channels = spec_channels
        self.segment_size = segment_size
        self.num_speakers = num_speakers
        self.num_languages = num_languages
        self.num_symbols = num_symbols
        self.embedding = nn.Embedding(num_symbols, 192)
        self.decoder = Generator(spec_channels)

    def forward(self, text):
        x = self.embedding(text)
        audio = self.decoder(x)
        return audio

class Generator(nn.Module):
    def __init__(self, spec_channels):
        super().__init__()
        self.spec_channels = spec_channels
        self.initial_conv = nn.ConvTranspose2d(192, spec_channels, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        self.final_conv = nn.Conv2d(spec_channels, 1, kernel_size=(7, 7), padding=(3, 3))

    def forward(self, encoder_outputs):
        # (batch, time, channels) -> (batch, channels, time, 1) so the 2D convs see 192 input channels.
        x = encoder_outputs.transpose(1, 2).unsqueeze(3)
        x = self.initial_conv(x)
        x = self.final_conv(x)
        return x.squeeze(1)
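A final shape check for the simplified VITS above, assuming the channel-reordering variant of Generator.forward; the symbol count and token ids are arbitrary, and the output is a raw tensor rather than audio, since this Generator is only a stand-in for the real VITS decoder.

import torch
from tts_vits import VITS

model = VITS(spec_channels=80, segment_size=8192, num_speakers=109, num_languages=1, num_symbols=256).eval()
token_ids = torch.randint(0, 256, (1, 50))   # a batch of 50 arbitrary symbol ids
with torch.no_grad():
    out = model(token_ids)
print(out.shape)  # (1, 100, 2): time axis doubled by the transposed conv, width expanded from 1 to 2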