Upload 4 files
- llm_chat.js +636 -0
- sentencepiece.js +0 -0
- tvmjs.bundle.js +16 -11
llm_chat.js
ADDED
@@ -0,0 +1,636 @@
/**
 * Helper to keep track of history conversations.
 */
class Conversation {
  constructor(config) {
    this.system = config.system;
    this.roles = config.roles;
    this.offset = config.offset;
    this.seps = config.seps;
    this.convId = null;
    this.contextWindowStart = 0;
  }

  /**
   * Get the prompt array, with the system prompt as the first entry.
   *
   * @returns The prompt array.
   */
  getPromptArray() {
    if (this.seps.length == 0) {
      throw Error("Need seps to work");
    }
    let ret = [this.system + this.seps[0]];

    for (let i = 0; i < tvmjsGlobalEnv.workerHistoryMsg.length; ++i) {
      const item = tvmjsGlobalEnv.workerHistoryMsg[i];
      const role = item[0];
      const message = item[1];
      if (message !== undefined && message != "") {
        ret.push(role + ": " + message + this.seps[i % this.seps.length]);
      } else {
        ret.push(role + ":");
      }
    }
    return ret;
  }

  /**
   * Get the portion of the prompt array that has not yet been fed as input.
   *
   * @returns The prompt array.
   */
  getPromptArrayUnprocessed() {
    if (this.seps.length == 0) {
      throw Error("Need seps to work");
    }
    if (tvmjsGlobalEnv.workerHistoryMsg.length < 3) {
      throw Error("Needs to call getLastPromptArray for the first message");
    }
    let ret = [this.seps[this.seps.length - 1]];
    for (let i = tvmjsGlobalEnv.workerHistoryMsg.length - 2; i < tvmjsGlobalEnv.workerHistoryMsg.length; ++i) {
      const item = tvmjsGlobalEnv.workerHistoryMsg[i];
      const role = item[0];
      const message = item[1];
      if (message !== undefined && message != "") {
        ret.push(role + ": " + message + this.seps[i % this.seps.length]);
      } else {
        ret.push(role + ":");
      }
    }
    return ret;
  }

  /**
   * Get the last prompt array, prefixed with the system prompt.
   *
   * @returns The prompt array.
   */
  getLastPromptArray() {
    if (this.seps.length == 0) {
      throw Error("Need seps to work");
    }
    let ret = [this.system + this.seps[0]];

    for (let i = tvmjsGlobalEnv.workerHistoryMsg.length - 2; i < tvmjsGlobalEnv.workerHistoryMsg.length; ++i) {
      const item = tvmjsGlobalEnv.workerHistoryMsg[i];
      const role = item[0];
      const message = item[1];
      if (message !== undefined && message != "") {
        ret.push(role + ": " + message + this.seps[i % this.seps.length]);
      } else {
        ret.push(role + ":");
      }
    }
    return ret;
  }

  reset() {
    tvmjsGlobalEnv.workerHistoryMsg = [];
    this.convId = null;
  }

  getStopStr() {
    return this.seps[this.seps.length - 1];
  }

  appendMessage(role, message) {
    tvmjsGlobalEnv.workerHistoryMsg.push([role, message]);
  }

  switchConversation(message) {
    tvmjsGlobalEnv.workerHistoryMsg = message;
    this.convId = tvmjsGlobalEnv.covId;
  }
}

function defaultConversation(maxWindowLength = 2048) {
  return new Conversation({
    system: "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. Follow the user's instructions carefully. Respond using markdown.",
    roles: ["user", "assistant"],
    maxWindowLength: maxWindowLength,
    offset: 0,
    seps: [" ", "</s>"],
  });
}

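// Illustrative example (added for clarity, not in the original upload): with
// the default seps [" ", "</s>"] and a history of
// [["user", "Hello"], ["assistant", ""]], getPromptArray() returns
//   [system + " ", "user: Hello ", "assistant:"]
// so the joined prompt ends with "assistant:" for the model to complete, and
// "</s>" (the last sep) doubles as the stop string via getStopStr().
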
class LLMChatPipeline {
  constructor(tvm, tokenizer, cacheMetadata, config) {
    if (cacheMetadata == undefined) {
      throw Error("Expect cacheMetadata");
    }
    this.tvm = tvm;
    this.logger = console.log;
    this.tokenizer = tokenizer;
    this.bosTokenId = 1;
    this.eosTokenId = 2;

    this.maxWindowLength = config.maxWindowLength;
    this.maxGenLength = config.maxGenLength;
    this.meanGenLength = config.meanGenLength;
    this.streamInterval = 1;

    this.decodingTotalTime = 0;
    this.decodingTotalTokens = 0;
    this.encodingTotalTime = 0;
    this.encodingTotalTokens = 0;

    this.conversation = defaultConversation(this.maxWindowLength);

    this.device = this.tvm.webgpu();
    this.vm = this.tvm.detachFromCurrentScope(
      this.tvm.createVirtualMachine(this.device)
    );
    this.encoding = this.tvm.detachFromCurrentScope(
      this.vm.getFunction("encoding")
    );
    this.decoding = this.tvm.detachFromCurrentScope(
      this.vm.getFunction("decoding")
    );
    this.params = this.tvm.detachFromCurrentScope(
      this.tvm.getParamsFromCache("param", cacheMetadata.ParamSize)
    );
    const fcreateCache = this.vm.getFunction("create_kv_cache");
    this.fclearKVCaches = this.tvm.detachFromCurrentScope(
      this.tvm.getGlobalFunc("vm.builtin.attention_kv_cache_array_clear")
    );

    // use extern config for now
    this.kvCache = this.tvm.detachFromCurrentScope(fcreateCache());
    // fill with pad token
    this.logitsOnCPU = undefined;

    this.kvCacheLength = 0;
    this.clearCache = true;
  }

  dispose() {
    // note: tvm instance is not owned by this class
    this.params.dispose();
    this.decoding.dispose();
    this.encoding.dispose();
    this.vm.dispose();
    this.kvCache.dispose();
    this.fclearKVCaches.dispose();
    if (this.logitsOnCPU != undefined) {
      this.logitsOnCPU.dispose();
    }
  }

  #clearKVCache() {
    this.fclearKVCaches(this.kvCache);
    this.kvCacheLength = 0;
  }

  #forward(inputs, curPos) {
    this.tvm.beginScope();
    var retValue;
    const seqLenShape = this.tvm.makeShapeTuple([curPos]);
    if (inputs.shape[1] > 1) {
      retValue = this.encoding(
        inputs, seqLenShape, this.kvCache, this.params
      );
    } else {
      retValue = this.decoding(
        inputs, seqLenShape, this.kvCache, this.params
      );
    }
    const logits = this.tvm.detachFromCurrentScope(retValue.get(0));
    this.tvm.endScope();
    this.tvm.attachToCurrentScope(logits);
    return logits;
  }

  // NOTE: caller must call device.sync()
  #updateLogitsOnCPU(logits) {
    if (this.logitsOnCPU == undefined) {
      this.logitsOnCPU = this.tvm.detachFromCurrentScope(
        this.tvm.empty(logits.shape, logits.dtype, this.tvm.cpu())
      );
    } else {
      if (logits.shape[0] != this.logitsOnCPU.shape[0]) {
        throw Error("We expect the size of logits to remain unchanged");
      }
    }
    this.logitsOnCPU.copyFrom(logits);
  }

  async sampleTokenFromLogits(logits, temperature = 0.8, top_p = 0.95) {
    this.tvm.beginScope();
    this.#updateLogitsOnCPU(logits);
    this.tvm.endScope();
    await this.device.sync();
    return this.tvm.sampleTopPFromLogits(this.logitsOnCPU, temperature, top_p);
  }

  async getInputTokens() {
    let tokens = [this.bosTokenId];
    let prompts;
    if (tvmjsGlobalEnv.workerHistoryMsg.length <= 2) {
      prompts = this.conversation.getPromptArray();
    } else {
      tokens.pop();
      prompts = this.conversation.getPromptArrayUnprocessed();
    }
    tokens.push(...await this.tokenizer.encodeIds(prompts[0]));
    let ctxLength = tokens.length;
    let context = [];
    let need_shift_window = false;
    for (let i = prompts.length - 1; i > 0; --i) {
      const encoded = await this.tokenizer.encodeIds(prompts[i]);
      ctxLength += encoded.length;
      if (this.kvCacheLength + ctxLength + this.meanGenLength >= this.maxWindowLength) {
        need_shift_window = true;
        break;
      }
      context.unshift(encoded);
    }
    if (!need_shift_window) {
      for (const ctx of context) {
        tokens.push(...ctx);
      }
      return tokens;
    }
    // need to shift the window and re-encode
    this.logger("need shift window");
    this.kvCacheLength = 0;
    this.clearCache = true;
    // abandon all tokens we collected
    tokens = [this.bosTokenId];
    let all_prompts = this.conversation.getPromptArray();
    tokens.push(...await this.tokenizer.encodeIds(all_prompts[0]));
    context = [];
    ctxLength = tokens.length;
    // only keep 10% of the window context
    const fill_factor = 0.1;
    for (let i = all_prompts.length - 1; i > 0; --i) {
      const encoded = await this.tokenizer.encodeIds(all_prompts[i]);
      ctxLength += encoded.length;
      if (ctxLength >= fill_factor * this.maxWindowLength && i + 2 < all_prompts.length) {
        break;
      }
      context.unshift(encoded);
    }
    for (const ctx of context) {
      tokens.push(...ctx);
    }
    if (tokens.length + this.meanGenLength >= this.maxWindowLength) {
      throw Error("Exceeded max window length, curr=" + tokens.length);
    }
    return tokens;
  }
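
  // Illustrative arithmetic (example numbers, not values from config.json):
  // with maxWindowLength = 2048 and meanGenLength = 128, a turn is accepted
  // only while kvCacheLength + ctxLength + 128 < 2048. Once that budget is
  // exceeded, the KV cache is dropped and the history is re-encoded
  // newest-first until roughly 0.1 * 2048 ~= 205 tokens of context remain,
  // leaving the rest of the window for generation.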

  resetChat() {
    if (this.conversation) {
      this.conversation.reset();
    }
    this.#clearKVCache();
    this.decodingTotalTime = 0;
    this.encodingTotalTime = 0;
    this.decodingTotalTokens = 0;
    this.encodingTotalTokens = 0;
  }

  async generate(inputPrompt, callbackUpdateResponse) {
    // switch to new Conversation
    if (this.conversation.convId !== tvmjsGlobalEnv.covId) {}
    this.conversation.appendMessage(this.conversation.roles[0], inputPrompt);
    this.conversation.appendMessage(this.conversation.roles[1], "");
    const stopStr = this.conversation.getStopStr();
    const tokens = await this.getInputTokens();
    const inputTokenLength = tokens.length;

    var outputPrompt = "";
    if (this.clearCache) {
      this.#clearKVCache();
      this.clearCache = false;
    }
    const maxGenLen = Math.min(this.maxGenLength, this.maxWindowLength - tokens.length);
    if (maxGenLen < this.meanGenLength) {
      throw Error("The window size config is too small");
    }
    let step = 0;
    for (; step < maxGenLen && this.kvCacheLength + inputTokenLength + step < this.maxWindowLength; ++step) {
      this.tvm.beginScope();
      var inputData;

      let tstart = performance.now();
      if (step == 0) {
        inputData = this.tvm.empty([1, tokens.length], "int32", this.device);
        inputData.copyFrom(tokens);
      } else {
        inputData = this.tvm.empty([1, 1], "int32", this.device);
        inputData.copyFrom(tokens.slice(tokens.length - 1));
      }
      const logits = this.tvm.detachFromCurrentScope(
        this.#forward(inputData, this.kvCacheLength + inputTokenLength + step)
      );
      this.tvm.endScope();

      const nextToken = await this.sampleTokenFromLogits(logits);
      logits.dispose();

      tokens.push(nextToken);
      const outputTokens = tokens.slice(inputTokenLength);
      outputPrompt = this.tokenizer.decodeIds(outputTokens);

      if (nextToken == this.eosTokenId) break;

      const stopPos = outputPrompt.lastIndexOf(stopStr);
      if (stopPos != -1) {
        outputPrompt = outputPrompt.substring(0, stopPos);
        break;
      }
      let tend = performance.now();
      if (step != 0) {
        this.decodingTotalTokens += 1;
        this.decodingTotalTime += (tend - tstart) / 1000;
      } else {
        this.encodingTotalTime += (tend - tstart) / 1000;
        this.encodingTotalTokens += inputTokenLength;
      }

      if (step % this.streamInterval == 0) {
        callbackUpdateResponse(step, outputPrompt);
      }
    }
    this.kvCacheLength += tokens.length - 1;
    tvmjsGlobalEnv.workerHistoryMsg[tvmjsGlobalEnv.workerHistoryMsg.length - 1][1] = outputPrompt;
    return outputPrompt;
  }

  async evaluate() {
    // run a canonical evaluation of the flow
    this.#clearKVCache();
    const testPrompt = "The capital of Canada is";
    const ids = await this.tokenizer.encodeIds(testPrompt);
    const inputPromptSize = ids.length;
    const tokens = Array.from(ids);
    tokens.unshift(this.bosTokenId);
    if (tokens.length == 0) {
      throw Error("empty token");
    }

    this.tvm.beginScope();
    const inputData = this.tvm.empty([1, tokens.length], "int32", this.device);
    inputData.copyFrom(tokens);
    const encodingStart = performance.now();
    this.#forward(inputData, tokens.length);
    this.tvm.endScope();
    await this.device.sync();

    const decodingStart = performance.now();

    this.tvm.beginScope();
    const firstSampleToken = this.tvm.empty([1, 1], "int32", this.device).copyFrom([6234]);
    this.#updateLogitsOnCPU(this.#forward(firstSampleToken, tokens.length + 1));
    await this.device.sync();
    this.tvm.endScope();

    const decodingEnd = performance.now();
    const msg = (
      `encoding-time=${((decodingStart - encodingStart) / 1000).toFixed(4)} sec, ` +
      `decoding-time=${((decodingEnd - decodingStart) / 1000).toFixed(4)} sec`
    );

    // simply log tokens for eyeballing.
    console.log("Logits:");
    console.log(this.logitsOnCPU.toArray());
    console.log(msg);
  }

  /**
   * Asynchronously preload WebGPU pipelines when possible.
   */
  async asyncLoadWebGPUPiplines() {
    await this.tvm.asyncLoadWebGPUPiplines(this.vm.getInternalModule());
  }

  runtimeStatsText() {
    return (
      `encoding: ${(this.encodingTotalTokens / this.encodingTotalTime).toFixed(4)} tokens/sec, ` +
      `decoding: ${(this.decodingTotalTokens / this.decodingTotalTime).toFixed(4)} tokens/sec`
    );
  }
}

/**
 * An instance that can be used to facilitate deployment.
 */
class LLMChatInstance {
  constructor() {
    this.requestInProgress = false;
    this.config = undefined;
    this.tvm = undefined;
    this.pipeline = undefined;
    this.logger = console.log;
    this.debugTest = false;
  }
  /**
   * Initialize TVM.
   * @param wasmUrl URL to the wasm source.
   * @param cacheUrl URL to the NDArray cache.
   */
  async #asyncInitTVM(wasmUrl, cacheUrl) {
    if (this.tvm !== undefined) {
      return;
    }
    this.logger = console.log;

    const wasmSource = await (
      await fetch(wasmUrl)
    ).arrayBuffer();
    const tvm = await tvmjs.instantiate(
      new Uint8Array(wasmSource),
      new EmccWASI(),
      this.logger
    );
    // initialize WebGPU
    try {
      const output = await tvmjs.detectGPUDevice();
      if (output !== undefined) {
        var label = "WebGPU";
        if (output.adapterInfo.description.length != 0) {
          label += " - " + output.adapterInfo.description;
        } else {
          label += " - " + output.adapterInfo.vendor;
        }
        this.appendMessage("init", "Initialize GPU device: " + label);
        tvm.initWebGPU(output.device);
      } else {
        this.appendMessage("error", "This browser env does not support WebGPU");
        this.reset();
        throw Error("This browser env does not support WebGPU");
      }
    } catch (err) {
      this.appendMessage("error", "Error initializing the WebGPU device: " + err.toString());
      console.log(err);
      this.reset();
      throw Error("Error initializing WebGPU: " + err.toString());
    }
    this.tvm = tvm;
    const initProgressCallback = (report) => {
      this.updateLastMessage("initing", report.text);
    };
    tvm.registerInitProgressCallback(initProgressCallback);

    await tvm.fetchNDArrayCache(cacheUrl, tvm.webgpu());
  }
  /**
   * Async initialize instance.
   */
  async asyncInit() {
    if (this.pipeline !== undefined) return;
    await this.#asyncInitConfig();
    await this.#asyncInitTVM(this.config.wasmUrl, this.config.cacheUrl);
    await this.#asyncInitPipeline();
  }

  /**
   * Async initialize config.
   */
  async #asyncInitConfig() {
    if (this.config !== undefined) return;
    this.config = await (await fetch("/lib/WebLLM/config.json")).json();
  }

  /**
   * Initialize the pipeline.
   */
  async #asyncInitPipeline() {
    if (this.pipeline !== undefined) return;
    // initialize UX and tokenizer
    const tokenizer = await tvmjsGlobalEnv.sentencePieceProcessor(this.config.tokenizer);
    this.pipeline = this.tvm.withNewScope(() => {
      return new LLMChatPipeline(this.tvm, tokenizer, this.tvm.cacheMetadata, this.config);
    });
    await this.pipeline.asyncLoadWebGPUPiplines();
    this.appendMessage("initing", "All initialization finished.", true);
  }

  appendMessage(kind, text, ifFinish) {
    if (kind == "initing") {
      text = "[System Initialize] " + text;
    }
    console.log(`[${kind}] ${text}`);
    globalThis.postMessage({
      type: 'initing',
      action: 'append',
      msg: text,
      ifError: kind == 'error',
      ifFinish: !!ifFinish
    });
  }

  updateLastMessage(type, text, ifFinish) {
    if (type == "initing") {
      text = `[System Initialize] ${text}`;
    }
    globalThis.postMessage({
      type,
      action: 'updateLast',
      msg: text,
      ifFinish: !!ifFinish
    });
  }

  async respondTestMessage(repeat) {
    const testMessage = "I am a friendly bot. Please ask questions.";
    const encodedResult = await this.pipeline.tokenizer.encodeIds(testMessage);

    const currentIds = [];
    for (let k = 0; k < repeat; ++k) {
      for (let i = 0; i < encodedResult.length; ++i) {
        currentIds.push(encodedResult[i]);
        const msg = this.pipeline.tokenizer.decodeIds(currentIds);
        this.updateLastMessage("chatting", msg);
        await new Promise(resolve => setTimeout(resolve, 50));
      }
    }
  }

  resetChat() {
    if (this.pipeline) {
      this.pipeline.resetChat();
    }
  }

  /**
   * Run generate.
   */
  async generate() {
    if (this.requestInProgress) {
      return;
    }

    this.requestInProgress = true;

    try {
      await this.asyncInit();
    } catch (err) {
      this.appendMessage("error", "Init error, " + err.toString());
      console.log(err);
      this.reset();
      this.requestInProgress = false;
      return;
    }

    if (this.debugTest) {
      await this.pipeline.evaluate();
      this.requestInProgress = false;
      return;
    }

    const prompt = tvmjsGlobalEnv.message;
    if (prompt == "") {
      this.requestInProgress = false;
      return;
    }

    const callbackUpdateResponse = (step, msg) => {
      if (msg.endsWith("##")) {
        msg = msg.substring(0, msg.length - 2);
      } else if (msg.endsWith("#")) {
        msg = msg.substring(0, msg.length - 1);
      }
      this.updateLastMessage("chatting", msg);
    };
    try {
      const output = await this.pipeline.generate(prompt, callbackUpdateResponse);
      this.updateLastMessage("chatting", output, true);
      this.updateLastMessage("stats", this.pipeline.runtimeStatsText());
      console.log(this.pipeline.runtimeStatsText());
    } catch (err) {
      this.appendMessage("error", "Generate error, " + err.toString());
      console.log(err);
      this.reset();
    }
    this.requestInProgress = false;
  }

  /**
   * Reset the instance.
   */
  reset() {
    this.tvm = undefined;
    if (this.pipeline !== undefined) {
      this.pipeline.dispose();
    }
    this.pipeline = undefined;
  }
}

const localLLMChatInstance = new LLMChatInstance();

tvmjsGlobalEnv.asyncOnGenerate = async function () {
  await localLLMChatInstance.generate();
};

tvmjsGlobalEnv.asyncOnReset = async function () {
  await localLLMChatInstance.resetChat();
};
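Note on usage: llm_chat.js expects the hosting worker script to populate tvmjsGlobalEnv (message, covId, workerHistoryMsg) before awaiting tvmjsGlobalEnv.asyncOnGenerate(), and it reports progress back through globalThis.postMessage. That glue code is not part of this upload, so the sketch below is a hypothetical host-page wiring; the "generate" command name and the worker path are assumptions.

// host page (sketch; appendBubble/updateLastBubble/enableInput are
// placeholder UI helpers, and the worker path is hypothetical)
const worker = new Worker("/lib/WebLLM/llm_chat_worker.js");
worker.onmessage = (ev) => {
  const { type, action, msg, ifError, ifFinish } = ev.data;
  if (action === "append") appendBubble(type, msg, ifError);
  else if (action === "updateLast") updateLastBubble(type, msg);
  if (ifFinish) enableInput();
};
// the worker script is assumed to copy these fields into tvmjsGlobalEnv
// and then call tvmjsGlobalEnv.asyncOnGenerate()
worker.postMessage({ command: "generate", message: "Hello!", covId: 1, history: [] });
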
sentencepiece.js
ADDED
The diff for this file is too large to render.
tvmjs.bundle.js
CHANGED
@@ -678,7 +678,11 @@ fn fragment_clear(@location(0) uv : vec2<f32>) -> @location(0) vec4<f32> {
 class CanvaRenderManager {
     constructor(device, canvas) {
         this.device = device;
-        const ctx = canvas.getContext("webgpu");
+        const ctx = canvas.getContext("webgpu", {
+            alpha: false,
+            antialias: false,
+            powerPreference: "high-performance",
+        });
         if (ctx == null) {
             throw Error("Cannot bind WebGPU context");
         }
@@ -2022,7 +2026,7 @@ fn fragment_clear(@location(0) uv : vec2<f32>) -> @location(0) vec4<f32> {
     */
     constructor(mod, device) {
         this.mod = mod;
-        this.mod.getFunction("vm_initialization")(new Scalar(device.deviceType, "int"), new Scalar(device.deviceId, "int"), new Scalar(2 /* POOLED_ALLOCATOR */, "int"),
+        this.mod.getFunction("vm_initialization")(new Scalar(device.deviceType, "int"), new Scalar(device.deviceId, "int"), new Scalar(2 /* POOLED_ALLOCATOR */, "int"),
         // explicitly specify host device type
         new Scalar(DeviceStrToEnum.cpu, "int"), new Scalar(0, "int"), new Scalar(2 /* POOLED_ALLOCATOR */, "int"));
     }
@@ -2388,7 +2392,10 @@ fn fragment_clear(@location(0) uv : vec2<f32>) -> @location(0) vec4<f32> {
     */
     fetchNDArrayCache(ndarrayCacheUrl, device) {
         return __awaiter(this, void 0, void 0, function* () {
-            const jsonUrl = new URL("ndarray-cache.json", ndarrayCacheUrl).href;
+            const cacheExists = yield caches.has("tvmjs");
+            const jsonUrl = cacheExists
+                ? "/lib/WebLLM/vicuna-7b/ndarray-cache.json"
+                : new URL("ndarray-cache.json", ndarrayCacheUrl).href;
             var list;
             try {
                 list = yield (yield fetch(jsonUrl)).json();
@@ -2420,12 +2427,10 @@ fn fragment_clear(@location(0) uv : vec2<f32>) -> @location(0) vec4<f32> {
         const reportCallback = (iter) => {
             // report
             for (let j = 0; j < this.initProgressCallback.length; ++j) {
-                let text = "Fetching param cache[" + iter + "/" + list.length + "]: ";
-                text += Math.ceil(fetchedBytes / (1024 * 1024)).toString() + "MB fetched. ";
-                text += Math.floor(fetchedBytes * 100 / totalBytes).toString() + "% completed, ";
-                text += timeElapsed + " secs elapsed.";
-                text += " It can take a while when we first visit this page to populate the cache.";
-                text += " Later refreshes will become faster.";
+                let text = "[" + iter + "/" + list.length + "]: ";
+                text += Math.ceil(fetchedBytes / (1024 * 1024)).toString() + " MB, ";
+                text += Math.floor(fetchedBytes * 100 / totalBytes).toString() + "% complete, ";
+                text += timeElapsed + " secs";
                 this.initProgressCallback[j]({
                     progress: fetchedBytes / totalBytes,
                     timeElapsed: timeElapsed,
@@ -2977,7 +2982,7 @@ fn fragment_clear(@location(0) uv : vec2<f32>) -> @location(0) vec4<f32> {
         }
         wrapJSFuncAsPackedCFunc(func) {
             const lib = this.lib;
-            return (argValues, argCodes, nargs, ret,
+            return (argValues, argCodes, nargs, ret,
             // eslint-disable-next-line @typescript-eslint/no-unused-vars
             _handle) => {
                 const jsArgs = [];
@@ -3454,7 +3459,7 @@ fn fragment_clear(@location(0) uv : vec2<f32>) -> @location(0) vec4<f32> {
             const localSession = flocal();
             support.assert(localSession instanceof runtime.Module);
             // eslint-disable-next-line @typescript-eslint/no-unused-vars
-            this.inst.registerFunc("rpc.WasmSession",
+            this.inst.registerFunc("rpc.WasmSession",
             // eslint-disable-next-line @typescript-eslint/no-unused-vars
             (_args) => {
                 return localSession;