mrick committed
Commit dcd56c3 · 1 Parent(s): 43f611f

Upload 4 files

Files changed (3)
  1. llm_chat.js +636 -0
  2. sentencepiece.js +0 -0
  3. tvmjs.bundle.js +16 -11
llm_chat.js ADDED
@@ -0,0 +1,636 @@
1
+ /**
2
+ * Helper to keep track of the conversation history.
3
+ */
4
+ class Conversation {
5
+ constructor(config) {
6
+ this.system = config.system;
7
+ this.roles = config.roles;
8
+ this.offset = config.offset;
9
+ this.seps = config.seps;
10
+ this.convId = null;
11
+ this.contextWindowStart = 0;
12
+ }
13
+
14
+ /**
15
+ * Get the prompt array, with the first element being the system prompt.
16
+ *
17
+ * @returns The prompt array.
18
+ */
19
+ getPromptArray() {
20
+ if (this.seps.length == 0) {
21
+ throw Error("Need seps to work")
22
+ }
23
+ let ret = [this.system + this.seps[0]];
24
+
25
+ for (let i = 0; i < tvmjsGlobalEnv.workerHistoryMsg.length; ++i) {
26
+ const item = tvmjsGlobalEnv.workerHistoryMsg[i];
27
+ const role = item[0];
28
+ const message = item[1];
29
+ if (message !== undefined && message != "") {
30
+ ret.push(role + ": " + message + this.seps[i % this.seps.length]);
31
+ } else {
32
+ ret.push(role + ":");
33
+ }
34
+ }
35
+ return ret;
36
+ }
37
+
38
+ /**
39
+ * Get the prompt array that has not yet been fed as input.
40
+ *
41
+ * @returns The prompt array.
42
+ */
43
+ getPromptArrayUnprocessed() {
44
+ if (this.seps.length == 0) {
45
+ throw Error("Need seps to work")
46
+ }
47
+ if (tvmjsGlobalEnv.workerHistoryMsg.length < 3) {
48
+ throw Error("needs to call getLastPromptArray for the first message");
49
+ }
50
+ let ret = [this.seps[this.seps.length - 1]];
51
+ for (let i = tvmjsGlobalEnv.workerHistoryMsg.length - 2; i < tvmjsGlobalEnv.workerHistoryMsg.length; ++i) {
52
+ const item = tvmjsGlobalEnv.workerHistoryMsg[i];
53
+ const role = item[0];
54
+ const message = item[1];
55
+ if (message !== undefined && message != "") {
56
+ ret.push(role + ": " + message + this.seps[i % this.seps.length]);
57
+ } else {
58
+ ret.push(role + ":");
59
+ }
60
+ }
61
+ return ret;
62
+
63
+ }
64
+
65
+ /**
66
+ * Get the last prompt array, prefixed with the system prompt.
67
+ *
68
+ * @returns The prompt array.
69
+ */
70
+ getLastPromptArray() {
71
+ if (this.seps.length == 0) {
72
+ throw Error("Need seps to work")
73
+ }
74
+ let ret = [this.system + this.seps[0]];
75
+
76
+ for (let i = tvmjsGlobalEnv.workerHistoryMsg.length - 2; i < tvmjsGlobalEnv.workerHistoryMsg.length; ++i) {
77
+ const item = tvmjsGlobalEnv.workerHistoryMsg[i];
78
+ const role = item[0];
79
+ const message = item[1];
80
+ if (message !== undefined && message != "") {
81
+ ret.push(role + ": " + message + this.seps[i % this.seps.length]);
82
+ } else {
83
+ ret.push(role + ":");
84
+ }
85
+ }
86
+ return ret;
87
+ }
88
+
89
+ reset() {
90
+ tvmjsGlobalEnv.workerHistoryMsg = [];
91
+ this.convId = null;
92
+ }
93
+
94
+ getStopStr() {
95
+ return this.seps[this.seps.length - 1];
96
+ }
97
+
98
+ appendMessage(role, message) {
99
+ tvmjsGlobalEnv.workerHistoryMsg.push([role, message]);
100
+ }
101
+
102
+ switchConversation(message) {
103
+ tvmjsGlobalEnv.workerHistoryMsg = message
104
+ this.convId = tvmjsGlobalEnv.covId
105
+ }
106
+ }
107
+
108
+ function defaultConversation(maxWindowLength = 2048) {
109
+ return new Conversation({
110
+ system: "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. Follow the user's instructions carefully. Respond using markdown.",
111
+ roles: ["user", "assistant"],
112
+ maxWindowLength: maxWindowLength,
113
+ offset: 0,
114
+ seps: [" ", "</s>"],
115
+ });
116
+ };
117
+
118
+ class LLMChatPipeline {
119
+ constructor(tvm, tokenizer, cacheMetadata, config) {
120
+ if (cacheMetadata == undefined) {
121
+ throw Error("Expect cacheMetadata");
122
+ }
123
+ this.tvm = tvm;
124
+ this.logger = console.log;
125
+ this.tokenizer = tokenizer;
126
+ this.bosTokenId = 1;
127
+ this.eosTokenId = 2;
128
+
129
+ this.maxWindowLength = config.maxWindowLength;
130
+ this.maxGenLength = config.maxGenLength;
131
+ this.meanGenLength = config.meanGenLength;
132
+ this.streamInterval = 1;
133
+
134
+ this.decodingTotalTime = 0;
135
+ this.decodingTotalTokens = 0;
136
+ this.encodingTotalTime = 0;
137
+ this.encodingTotalTokens = 0;
138
+
139
+ this.conversation = defaultConversation(this.maxWindowLength);
140
+
141
+ this.device = this.tvm.webgpu();
142
+ this.vm = this.tvm.detachFromCurrentScope(
143
+ this.tvm.createVirtualMachine(this.device)
144
+ );
145
+ this.encoding = this.tvm.detachFromCurrentScope(
146
+ this.vm.getFunction("encoding")
147
+ );
148
+ this.decoding = this.tvm.detachFromCurrentScope(
149
+ this.vm.getFunction("decoding")
150
+ );
151
+ this.params = this.tvm.detachFromCurrentScope(
152
+ this.tvm.getParamsFromCache("param", cacheMetadata.ParamSize)
153
+ );
154
+ const fcreateCache = this.vm.getFunction("create_kv_cache");
155
+ this.fclearKVCaches = this.tvm.detachFromCurrentScope(
156
+ this.tvm.getGlobalFunc("vm.builtin.attention_kv_cache_array_clear")
157
+ );
158
+
159
+ // use extern config for now
160
+ this.kvCache = this.tvm.detachFromCurrentScope(fcreateCache());
161
+ // fill with pad token
162
+ this.logitsOnCPU = undefined;
163
+
164
+ this.kvCacheLength = 0;
165
+ this.clearCache = true
166
+ }
167
+
168
+
169
+ dispose() {
170
+ // note: tvm instance is not owned by this class
171
+ this.params.dispose();
172
+ this.decoding.dispose();
173
+ this.encoding.dispose();
174
+ this.vm.dispose();
175
+ this.kvCache.dispose();
176
+ this.fclearKVCaches.dispose();
177
+ if (this.logitsOnCPU != undefined) {
178
+ this.logitsOnCPU.dispose();
179
+ }
180
+ }
181
+
182
+ #clearKVCache() {
183
+ this.fclearKVCaches(this.kvCache);
184
+ this.kvCacheLength = 0;
185
+ }
186
+
187
+ #forward(inputs, curPos) {
188
+ this.tvm.beginScope();
189
+ var retValue;
190
+ const seqLenShape = this.tvm.makeShapeTuple([curPos]);
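+ // A multi-token input runs the prefill ("encoding") function; a single token runs "decoding".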
191
+ if (inputs.shape[1] > 1) {
192
+ retValue = this.encoding(
193
+ inputs, seqLenShape, this.kvCache, this.params
194
+ );
195
+ } else {
196
+ retValue = this.decoding(
197
+ inputs, seqLenShape, this.kvCache, this.params
198
+ );
199
+ }
200
+ const logits = this.tvm.detachFromCurrentScope(retValue.get(0));
201
+ this.tvm.endScope();
202
+ this.tvm.attachToCurrentScope(logits);
203
+ return logits;
204
+ }
205
+
206
+ // NOTE: caller must call device.sync()
207
+ #updateLogitsOnCPU(logits) {
208
+ if (this.logitsOnCPU == undefined) {
209
+ this.logitsOnCPU = this.tvm.detachFromCurrentScope(
210
+ this.tvm.empty(logits.shape, logits.dtype, this.tvm.cpu())
211
+ );
212
+ } else {
213
+ if (logits.shape[0] != this.logitsOnCPU.shape[0]) {
214
+ throw Error("We expect the size of logits to remain unchanged");
215
+ }
216
+ }
217
+ this.logitsOnCPU.copyFrom(logits);
218
+ }
219
+
220
+ async sampleTokenFromLogits(logits, temperature = 0.8, top_p = 0.95) {
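+ // Copy the GPU logits to a CPU buffer, wait for the device, then do temperature-scaled top-p sampling.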
221
+ this.tvm.beginScope();
222
+ this.#updateLogitsOnCPU(logits);
223
+ this.tvm.endScope();
224
+ await this.device.sync();
225
+ return this.tvm.sampleTopPFromLogits(this.logitsOnCPU, temperature, top_p);
226
+ }
227
+
228
+ async getInputTokens() {
229
+ let tokens = [this.bosTokenId];
230
+ let prompts = ""
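+ // First exchange: encode the full prompt (system message included). Later turns: only encode messages not yet fed to the model.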
231
+ if (tvmjsGlobalEnv.workerHistoryMsg.length <= 2) {
232
+ prompts = this.conversation.getPromptArray();
233
+ } else {
234
+ tokens.pop();
235
+ prompts = this.conversation.getPromptArrayUnprocessed();
236
+ }
237
+ tokens.push(...await this.tokenizer.encodeIds(prompts[0]));
238
+ let ctxLength = tokens.length;
239
+ let context = [];
240
+ let need_shift_window = false;
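+ // Walk messages from newest to oldest; stop if cached tokens + new context + expected generation length would overflow the window.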
241
+ for (let i = prompts.length - 1; i > 0; --i) {
242
+ const encoded = this.tokenizer.encodeIds(prompts[i]);
243
+ ctxLength += encoded.length;
244
+ if (this.kvCacheLength + ctxLength + this.meanGenLength >= this.maxWindowLength) {
245
+ need_shift_window = true;
246
+ break;
247
+ }
248
+ context.unshift(encoded);
249
+ }
250
+ if (!need_shift_window) {
251
+ for (const ctx of context) {
252
+ tokens.push(...ctx);
253
+ }
254
+ return tokens;
255
+ }
256
+ // need shift window and re-encode
257
+ this.logger("need shift window")
258
+ this.kvCacheLength = 0;
259
+ this.clearCache = true;
260
+ // abandon all tokens we collected
261
+ tokens = [this.bosTokenId]
262
+ let all_prompts = this.conversation.getPromptArray();
263
+ tokens.push(...await this.tokenizer.encodeIds(all_prompts[0]));
264
+ context = [];
265
+ ctxLength = tokens.length;
266
+ // only keep about 10% of the window for context
267
+ const fill_factor = 0.1
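+ // Re-add the most recent messages until roughly 10% of the window is filled, always keeping at least the last exchange.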
268
+ for (let i = all_prompts.length - 1; i > 0; --i) {
269
+ const encoded = this.tokenizer.encodeIds(all_prompts[i]);
270
+ ctxLength += encoded.length;
271
+ if (ctxLength >= fill_factor * this.maxWindowLength && i + 2 < all_prompts.length) {
272
+ break;
273
+ }
274
+ context.unshift(encoded);
275
+ }
276
+ for (const ctx of context) {
277
+ tokens.push(...ctx);
278
+ }
279
+ if (tokens.length + this.meanGenLength >= this.maxWindowLength) {
280
+ throw Error("Exceed max window length curr=" + tokens.length);
281
+ }
282
+ return tokens;
283
+ }
284
+
285
+ resetChat() {
286
+ if (this.conversation) {
287
+ this.conversation.reset();
288
+ }
289
+ this.#clearKVCache();
290
+ this.decodingTotalTime = 0;
291
+ this.encodingTotalTime = 0;
292
+ this.decodingTotalTokens = 0;
293
+ this.encodingTotalTokens = 0;
294
+ }
295
+
296
+ async generate(inputPrompt, callbackUpdateResponse) {
297
+ // switch to new Conversation
298
+ if (this.conversation.convId !== tvmjsGlobalEnv.covId) {
+ this.conversation.switchConversation(tvmjsGlobalEnv.workerHistoryMsg);
+ }
299
+ this.conversation.appendMessage(this.conversation.roles[0], inputPrompt);
300
+ this.conversation.appendMessage(this.conversation.roles[1], "");
301
+ const stopStr = this.conversation.getStopStr();
302
+ const tokens = await this.getInputTokens();
303
+ const inputTokenLength = tokens.length;
304
+
305
+ var outputPrompt = "";
306
+ if (this.clearCache) {
307
+ this.#clearKVCache();
308
+ this.clearCache = false;
309
+ }
310
+ const maxGenLen = Math.min(this.maxGenLength, this.maxWindowLength - tokens.length);
311
+ if (maxGenLen < this.meanGenLength) {
312
+ throw Error("Too small window size config");
313
+ }
314
+ let step = 0;
315
+ for (; step < maxGenLen && this.kvCacheLength + inputTokenLength + step < this.maxWindowLength; ++step) {
316
+ this.tvm.beginScope();
317
+ var inputData;
318
+
319
+ let tstart = performance.now();
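+ // Step 0 prefills the whole prompt; every later step feeds only the last sampled token.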
320
+ if (step == 0) {
321
+ inputData = this.tvm.empty([1, tokens.length], "int32", this.device);
322
+ inputData.copyFrom(tokens);
323
+ } else {
324
+ inputData = this.tvm.empty([1, 1], "int32", this.device);
325
+ inputData.copyFrom(tokens.slice(tokens.length - 1));
326
+ }
327
+ const logits = this.tvm.detachFromCurrentScope(
328
+ this.#forward(inputData, this.kvCacheLength + inputTokenLength + step)
329
+ );
330
+ this.tvm.endScope();
331
+
332
+ const nextToken = await this.sampleTokenFromLogits(logits);
333
+ logits.dispose();
334
+
335
+ tokens.push(nextToken);
336
+ const outputTokens = tokens.slice(inputTokenLength);
337
+ outputPrompt = this.tokenizer.decodeIds(outputTokens);
338
+
339
+ if (nextToken == this.eosTokenId) break;
340
+
341
+ const stopPos = outputPrompt.lastIndexOf(stopStr);
342
+ if (stopPos != -1) {
343
+ outputPrompt = outputPrompt.substring(0, stopPos);
344
+ break;
345
+ }
346
+ let tend = performance.now();
347
+ if (step != 0) {
348
+ this.decodingTotalTokens += 1;
349
+ this.decodingTotalTime += (tend - tstart) / 1000;
350
+ } else {
351
+ this.encodingTotalTime += (tend - tstart) / 1000;
352
+ this.encodingTotalTokens += inputTokenLength;
353
+ }
354
+
355
+ if (step % this.streamInterval == 0) {
356
+ callbackUpdateResponse(step, outputPrompt);
357
+ }
358
+ }
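+ // Every token except the last sampled one has been run through the model and is already in the KV cache.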
359
+ this.kvCacheLength += tokens.length - 1;
360
+ tvmjsGlobalEnv.workerHistoryMsg[tvmjsGlobalEnv.workerHistoryMsg.length - 1][1] = outputPrompt;
361
+ return outputPrompt;
362
+ }
363
+
364
+ async evaluate() {
365
+ // run a canonical evaluation of the flow
366
+ this.#clearKVCache();
367
+ const testPrompt = "The capital of Canada is";
368
+ const ids = await this.tokenizer.encodeIds(testPrompt);
369
+ const inputPromptSize = ids.length;
370
+ const tokens = Array.from(ids);
371
+ tokens.unshift(this.bosTokenId);
372
+ if (tokens.length == 0) {
373
+ throw Error("empty token");
374
+ }
375
+
376
+ this.tvm.beginScope();
377
+ const inputData = this.tvm.empty([1, tokens.length], "int32", this.device);
378
+ inputData.copyFrom(tokens);
379
+ const encodingStart = performance.now();
380
+ this.#forward(inputData, tokens.length);
381
+ this.tvm.endScope();
382
+ await this.device.sync();
383
+
384
+ const decodingStart = performance.now();
385
+
386
+ this.tvm.beginScope();
387
+ const firstSampleToken = this.tvm.empty([1, 1], "int32", this.device).copyFrom([6234]);
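+ // Time a single decode step with a fixed, arbitrary token id.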
388
+ this.#updateLogitsOnCPU(this.#forward(firstSampleToken, tokens.length + 1));
389
+ await this.device.sync();
390
+ this.tvm.endScope();
391
+
392
+ const decodingEnd = performance.now();
393
+ const msg = (
394
+ `encoding-time=${((decodingStart - encodingStart) / 1000).toFixed(4)} sec, ` +
395
+ `decoding-time=${((decodingEnd - decodingStart) / 1000).toFixed(4)} sec`
396
+ );
397
+
398
+ // simply log tokens for eyeballing.
399
+ console.log("Logits:");
400
+ console.log(this.logitsOnCPU.toArray());
401
+ console.log(msg);
402
+ }
403
+
404
+ /**
405
+ * Asynchronously preload WebGPU pipelines when possible.
406
+ */
407
+ async asyncLoadWebGPUPiplines() {
408
+ await this.tvm.asyncLoadWebGPUPiplines(this.vm.getInternalModule());
409
+ }
410
+
411
+ runtimeStatsText() {
412
+ return (
413
+ `encoding: ${(this.encodingTotalTokens / this.encodingTotalTime).toFixed(4)} tokens/sec, ` +
414
+ `decoding: ${(this.decodingTotalTokens / this.decodingTotalTime).toFixed(4)} tokens/sec`
415
+ )
416
+ }
417
+ }
418
+
419
+ /**
420
+ * An instance that can be used to facilitate deployment.
421
+ */
422
+ class LLMChatInstance {
423
+ constructor() {
424
+ this.requestInProgress = false;
425
+ this.config = undefined;
426
+ this.tvm = undefined;
427
+ this.pipeline = undefined;
428
+ this.logger = console.log;
429
+ this.debugTest = false;
430
+ }
431
+ /**
432
+ * Initialize TVM
433
+ * @param wasmUrl URL to wasm source.
434
+ * @param cacheUrl URL to NDArray cache.
436
+ */
437
+ async #asyncInitTVM(wasmUrl, cacheUrl) {
438
+ if (this.tvm !== undefined) {
439
+ return;
440
+ }
441
+ this.logger = console.log;
442
+
443
+ const wasmSource = await (
444
+ await fetch(wasmUrl)
445
+ ).arrayBuffer();
446
+ const tvm = await tvmjs.instantiate(
447
+ new Uint8Array(wasmSource),
448
+ new EmccWASI(),
449
+ this.logger
450
+ );
451
+ // initialize WebGPU
452
+ try {
453
+ const output = await tvmjs.detectGPUDevice();
454
+ if (output !== undefined) {
455
+ var label = "WebGPU";
456
+ if (output.adapterInfo.description.length != 0) {
457
+ label += " - " + output.adapterInfo.description;
458
+ } else {
459
+ label += " - " + output.adapterInfo.vendor;
460
+ }
461
+ this.appendMessage("init", "Initialize GPU device: " + label);
462
+ tvm.initWebGPU(output.device);
463
+ } else {
464
+ this.appendMessage("error", "This browser env do not support WebGPU");
465
+ this.reset();
466
+ throw Error("This browser env do not support WebGPU");
467
+ }
468
+ } catch (err) {
469
+ this.appendMessage("error", "Find an error initializing the WebGPU device " + err.toString());
470
+ console.log(err);
471
+ this.reset();
472
+ throw Error("Find an error initializing WebGPU: " + err.toString());
473
+ }
474
+ this.tvm = tvm;
475
+ const initProgressCallback = (report) => {
476
+ this.updateLastMessage("initing", report.text);
477
+ }
478
+ tvm.registerInitProgressCallback(initProgressCallback);
479
+
480
+ await tvm.fetchNDArrayCache(cacheUrl, tvm.webgpu());
481
+ }
482
+ /**
483
+ * Async initialize instance.
484
+ */
485
+ async asyncInit() {
486
+ if (this.pipeline !== undefined) return;
487
+ await this.#asyncInitConfig();
488
+ await this.#asyncInitTVM(this.config.wasmUrl, this.config.cacheUrl);
489
+ await this.#asyncInitPipeline();
490
+ }
491
+
492
+ /**
493
+ * Async initialize config
494
+ */
495
+ async #asyncInitConfig() {
496
+ if (this.config !== undefined) return;
497
+ this.config = await (await fetch("/lib/WebLLM/config.json")).json();
498
+ }
499
+
500
+ /**
501
+ * Initialize the pipeline
502
+ *
503
+ * The tokenizer model URL is taken from this.config.tokenizer.
504
+ */
505
+ async #asyncInitPipeline() {
506
+ if (this.pipeline !== undefined) return;
507
+ // initialize UX and tokenizer
508
+ const tokenizer = await tvmjsGlobalEnv.sentencePieceProcessor(this.config.tokenizer);
509
+ this.pipeline = this.tvm.withNewScope(() => {
510
+ return new LLMChatPipeline(this.tvm, tokenizer, this.tvm.cacheMetadata, this.config);
511
+ });
512
+ await this.pipeline.asyncLoadWebGPUPiplines();
513
+ this.appendMessage("initing", "All initialization finished.", true);
514
+ }
515
+
516
+ appendMessage(kind, text, ifFinish) {
517
+ if (kind == "initing") {
518
+ text = "[System Initalize] " + text;
519
+ }
520
+ console.log(`[${kind}] ${text}`);
521
+ globalThis.postMessage({
522
+ type: 'initing',
523
+ action: 'append',
524
+ msg: text,
525
+ ifError: kind == 'error',
526
+ ifFinish: !!ifFinish
527
+ })
528
+ }
529
+
530
+ updateLastMessage(type, text, ifFinish) {
531
+ if (type == "initing") {
532
+ text = `[System Initialize] ${text}`
533
+ }
534
+ globalThis.postMessage({
535
+ type,
536
+ action: 'updateLast',
537
+ msg: text,
538
+ ifFinish: !!ifFinish
539
+ })
540
+ }
541
+
542
+ async respondTestMessage(repeat) {
543
+ const testMessage = "I am a friendly bot. Please ask questions.";
544
+ const encodedResult = await this.pipeline.tokenizer.encodeIds(testMessage);
545
+
546
+ const currentIds = [];
547
+ for (let k = 0; k < repeat; ++k) {
548
+ for (let i = 0; i < encodedResult.length; ++i) {
549
+ currentIds.push(encodedResult[i]);
550
+ const msg = this.pipeline.tokenizer.decodeIds(currentIds);
551
+ this.updateLastMessage("chatting", msg);
552
+ await new Promise(resolve => setTimeout(resolve, 50));
553
+ }
554
+ }
555
+ }
556
+
557
+ resetChat() {
558
+ if (this.pipeline) {
559
+ this.pipeline.resetChat();
560
+ }
561
+ }
562
+
563
+ /**
564
+ * Run generate
565
+ */
566
+ async generate() {
567
+ if (this.requestInProgress) {
568
+ return;
569
+ }
570
+
571
+ this.requestInProgress = true;
572
+
573
+ try {
574
+ await this.asyncInit();
575
+ } catch (err) {
576
+ this.appendMessage("error", "Init error, " + err.toString());
577
+ console.log(err);
578
+ this.reset();
579
+ this.requestInProgress = false;
580
+ return;
581
+ }
582
+
583
+ if (this.debugTest) {
584
+ await this.pipeline.evaluate();
585
+ this.requestInProgress = false;
586
+ return;
587
+ }
588
+
589
+ const prompt = tvmjsGlobalEnv.message;
590
+ if (prompt == "") {
591
+ this.requestInProgress = false;
592
+ return;
593
+ }
594
+
595
+ const callbackUpdateResponse = (step, msg) => {
596
+ if (msg.endsWith("##")) {
597
+ msg = msg.substring(0, msg.length - 2);
598
+ } else if (msg.endsWith("#")) {
599
+ msg = msg.substring(0, msg.length - 1);
600
+ }
601
+ this.updateLastMessage("chatting", msg);
602
+ };
603
+ try {
604
+ const output = await this.pipeline.generate(prompt, callbackUpdateResponse);
605
+ this.updateLastMessage("chatting", output, true);
606
+ this.updateLastMessage("stats",this.pipeline.runtimeStatsText())
607
+ console.log(this.pipeline.runtimeStatsText());
608
+ } catch (err) {
609
+ this.appendMessage("error", "Generate error, " + err.toString());
610
+ console.log(err);
611
+ this.reset();
612
+ }
613
+ this.requestInProgress = false;
614
+ }
615
+
616
+ /**
617
+ * Reset the instance.
618
+ */
619
+ reset() {
620
+ this.tvm = undefined;
621
+ if (this.pipeline !== undefined) {
622
+ this.pipeline.dispose();
623
+ }
624
+ this.pipeline = undefined;
625
+ }
626
+ }
627
+
628
+ localLLMChatInstance = new LLMChatInstance();
629
+
630
+ tvmjsGlobalEnv.asyncOnGenerate = async function () {
631
+ await localLLMChatInstance.generate();
632
+ };
633
+
634
+ tvmjsGlobalEnv.asyncOnReset = async function () {
635
+ await localLLMChatInstance.resetChat();
636
+ };
sentencepiece.js ADDED
The diff for this file is too large to render. See raw diff
 
tvmjs.bundle.js CHANGED
@@ -678,7 +678,11 @@ fn fragment_clear(@location(0) uv : vec2<f32>) -> @location(0) vec4<f32> {
678
  class CanvaRenderManager {
679
  constructor(device, canvas) {
680
  this.device = device;
681
- const ctx = canvas.getContext("webgpu");
682
  if (ctx == null) {
683
  throw Error("Cannot bind WebGPU context");
684
  }
@@ -2022,7 +2026,7 @@ fn fragment_clear(@location(0) uv : vec2<f32>) -> @location(0) vec4<f32> {
2022
  */
2023
  constructor(mod, device) {
2024
  this.mod = mod;
2025
- this.mod.getFunction("vm_initialization")(new Scalar(device.deviceType, "int"), new Scalar(device.deviceId, "int"), new Scalar(2 /* POOLED_ALLOCATOR */, "int"),
2026
  // explicitly specify host device type
2027
  new Scalar(DeviceStrToEnum.cpu, "int"), new Scalar(0, "int"), new Scalar(2 /* POOLED_ALLOCATOR */, "int"));
2028
  }
@@ -2388,7 +2392,10 @@ fn fragment_clear(@location(0) uv : vec2<f32>) -> @location(0) vec4<f32> {
2388
  */
2389
  fetchNDArrayCache(ndarrayCacheUrl, device) {
2390
  return __awaiter(this, void 0, void 0, function* () {
2391
- const jsonUrl = new URL("ndarray-cache.json", ndarrayCacheUrl).href;
2392
  var list;
2393
  try {
2394
  list = yield (yield fetch(jsonUrl)).json();
@@ -2420,12 +2427,10 @@ fn fragment_clear(@location(0) uv : vec2<f32>) -> @location(0) vec4<f32> {
2420
  const reportCallback = (iter) => {
2421
  // report
2422
  for (let j = 0; j < this.initProgressCallback.length; ++j) {
2423
- let text = "Fetching param cache[" + iter + "/" + list.length + "]: ";
2424
- text += Math.ceil(fetchedBytes / (1024 * 1024)).toString() + "MB fetched. ";
2425
- text += Math.floor(fetchedBytes * 100 / totalBytes).toString() + "% completed, ";
2426
- text += timeElapsed + " secs elapsed.";
2427
- text += " It can take a while when we first visit this page to populate the cache.";
2428
- text += " Later refreshes will become faster.";
2429
  this.initProgressCallback[j]({
2430
  progress: fetchedBytes / totalBytes,
2431
  timeElapsed: timeElapsed,
@@ -2977,7 +2982,7 @@ fn fragment_clear(@location(0) uv : vec2<f32>) -> @location(0) vec4<f32> {
2977
  }
2978
  wrapJSFuncAsPackedCFunc(func) {
2979
  const lib = this.lib;
2980
- return (argValues, argCodes, nargs, ret,
2981
  // eslint-disable-next-line @typescript-eslint/no-unused-vars
2982
  _handle) => {
2983
  const jsArgs = [];
@@ -3454,7 +3459,7 @@ fn fragment_clear(@location(0) uv : vec2<f32>) -> @location(0) vec4<f32> {
3454
  const localSession = flocal();
3455
  support.assert(localSession instanceof runtime.Module);
3456
  // eslint-disable-next-line @typescript-eslint/no-unused-vars
3457
- this.inst.registerFunc("rpc.WasmSession",
3458
  // eslint-disable-next-line @typescript-eslint/no-unused-vars
3459
  (_args) => {
3460
  return localSession;
 
678
  class CanvaRenderManager {
679
  constructor(device, canvas) {
680
  this.device = device;
681
+ const ctx = canvas.getContext("webgpu", {
682
+ alpha: false,
683
+ antialias: false,
684
+ powerPreference: "high-performance",
685
+ });
686
  if (ctx == null) {
687
  throw Error("Cannot bind WebGPU context");
688
  }
 
2026
  */
2027
  constructor(mod, device) {
2028
  this.mod = mod;
2029
+ this.mod.getFunction("vm_initialization")(new Scalar(device.deviceType, "int"), new Scalar(device.deviceId, "int"), new Scalar(2 /* POOLED_ALLOCATOR */, "int"),
2030
  // explicitly specify host device type
2031
  new Scalar(DeviceStrToEnum.cpu, "int"), new Scalar(0, "int"), new Scalar(2 /* POOLED_ALLOCATOR */, "int"));
2032
  }
 
2392
  */
2393
  fetchNDArrayCache(ndarrayCacheUrl, device) {
2394
  return __awaiter(this, void 0, void 0, function* () {
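+ // If the browser's Cache Storage already contains a "tvmjs" cache, read the manifest from the bundled local path instead of the remote cache URL.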
2395
+ const cacheExists = yield caches.has("tvmjs");
2396
+ const jsonUrl = cacheExists
2397
+ ? "/lib/WebLLM/vicuna-7b/ndarray-cache.json"
2398
+ : new URL("ndarray-cache.json", ndarrayCacheUrl).href;
2399
  var list;
2400
  try {
2401
  list = yield (yield fetch(jsonUrl)).json();
 
2427
  const reportCallback = (iter) => {
2428
  // report
2429
  for (let j = 0; j < this.initProgressCallback.length; ++j) {
2430
+ let text = "[" + iter + "/" + list.length + "]: ";
2431
+ text += Math.ceil(fetchedBytes / (1024 * 1024)).toString() + " MB, ";
2432
+ text += Math.floor(fetchedBytes * 100 / totalBytes).toString() + "% complete, ";
2433
+ text += timeElapsed + " secs";
2434
  this.initProgressCallback[j]({
2435
  progress: fetchedBytes / totalBytes,
2436
  timeElapsed: timeElapsed,
 
2982
  }
2983
  wrapJSFuncAsPackedCFunc(func) {
2984
  const lib = this.lib;
2985
+ return (argValues, argCodes, nargs, ret,
2986
  // eslint-disable-next-line @typescript-eslint/no-unused-vars
2987
  _handle) => {
2988
  const jsArgs = [];
 
3459
  const localSession = flocal();
3460
  support.assert(localSession instanceof runtime.Module);
3461
  // eslint-disable-next-line @typescript-eslint/no-unused-vars
3462
+ this.inst.registerFunc("rpc.WasmSession",
3463
  // eslint-disable-next-line @typescript-eslint/no-unused-vars
3464
  (_args) => {
3465
  return localSession;