|
import { ESearchEngine, IChatInputMessage, IStreamHandler, Provider, SearchFunc, TMode } from '../interface'; |
|
import { searchWithBing, searchWithGoogle, searchWithSogou, searchWithSearXNG, searchWithChatGLM } from '../service'; |
|
import { DeepQueryPrompt, MoreQuestionsPrompt, RagQueryPrompt, TranslatePrompt } from './prompt'; |
|
import platform from '../provider'; |
|
import { Models } from '../utils/constant'; |
|
import { ESearXNGCategory } from '../search/searxng'; |
|
import util from 'util'; |
|
|
|
/** Construction options for the Rag class. */
interface RagOptions {
  /** Search engine to use; the Rag constructor defaults to SearXNG when omitted. */
  engine?: ESearchEngine
  /** When true (the constructor's default), results are streamed via callbacks. */
  stream?: boolean
  /** LLM model name; effectively required — the Rag constructor throws without it. */
  model?: string

  /** Together with `provider`, selects a local LLM provider instead of resolving by model name. */
  locally?: boolean
  /** Provider key used when `locally` is set. */
  provider?: Provider
}
|
|
|
|
|
|
|
export class Rag { |
|
private search: SearchFunc; |
|
private chat: (...args: any[]) => Promise<any>; |
|
private model: string; |
|
|
|
private stream: boolean; |
|
|
|
private engine: ESearchEngine; |
|
|
|
constructor(params?: RagOptions) { |
|
const { engine = ESearchEngine.SEARXNG, stream = true, model, locally, provider } = params || {}; |
|
if (!model) throw new Error('model is required'); |
|
if (locally && provider) { |
|
this.chat = platform[provider].chatStream.bind(platform[provider]); |
|
} else { |
|
const chat = processModel(model); |
|
if (!chat) throw new Error('model is not supported'); |
|
this.chat = chat; |
|
} |
|
this.model = model; |
|
this.stream = stream; |
|
console.info('[query with]:', engine, model); |
|
console.info('[query with local llm]:', locally); |
|
this.engine = engine; |
|
switch (engine) { |
|
case ESearchEngine.GOOGLE: |
|
this.search = searchWithGoogle; |
|
break; |
|
case ESearchEngine.BING: |
|
this.search = searchWithBing; |
|
break; |
|
case ESearchEngine.SOGOU: |
|
this.search = searchWithSogou; |
|
break; |
|
case ESearchEngine.SEARXNG: |
|
this.search = searchWithSearXNG; |
|
break; |
|
case ESearchEngine.CHATGLM: |
|
this.search = searchWithChatGLM; |
|
break; |
|
default: |
|
this.search = searchWithSearXNG; |
|
} |
|
} |
|
|
|
public async query(query: string, categories = [ESearXNGCategory.GENERAL], mode: TMode = 'simple', language = 'all', onMessage?: (...args: any[]) => void) { |
|
let searchQuery = query; |
|
|
|
if (categories.includes(ESearXNGCategory.SCIENCE) && this.engine === ESearchEngine.SEARXNG) { |
|
const rewrite = await this.translate(query); |
|
if (rewrite) searchQuery = rewrite; |
|
} |
|
|
|
|
|
const contexts = await this.search(searchQuery, categories, language); |
|
console.log(`[search [${categories}] results]`, contexts.length); |
|
console.log('[search mode]', mode); |
|
const REFERENCE_COUNT = process.env.REFERENCE_COUNT || 8; |
|
const limitContexts = contexts.slice(0, +REFERENCE_COUNT); |
|
if (!this.stream) { |
|
const relatedPromise = this.getRelatedQuestions(query, limitContexts); |
|
const answerPromise = this.getAiAnswer(query, contexts); |
|
const [related, answer] = await Promise.all([relatedPromise, answerPromise]); |
|
return { |
|
related, |
|
answer, |
|
contexts: limitContexts |
|
}; |
|
} |
|
|
|
if (this.engine === ESearchEngine.SEARXNG) { |
|
const res = await this.search(query, [ESearXNGCategory.IMAGES], language); |
|
const engines = process.env.SEARXNG_IMAGES_ENGINES ? process.env.SEARXNG_IMAGES_ENGINES.split(',') : []; |
|
|
|
const images = res.filter(item => { |
|
if (!item.thumbnail) return false; |
|
if (engines.length > 0) |
|
return engines.some(engine => item.engine?.includes(engine)); |
|
return item.engine?.includes('bing') || item.engine?.includes('google'); |
|
}); |
|
|
|
for (const image of images) { |
|
onMessage?.(JSON.stringify({ image })); |
|
} |
|
} |
|
|
|
for (const context of limitContexts) { |
|
onMessage?.(JSON.stringify({ context })); |
|
} |
|
await this.getAiAnswer(query, limitContexts, mode, (msg) => { |
|
onMessage?.(JSON.stringify({ answer: msg })); |
|
}); |
|
await this.getRelatedQuestions(query, limitContexts, (msg) => { |
|
onMessage?.(JSON.stringify({ related: msg })); |
|
}); |
|
onMessage?.(null, true); |
|
} |
|
|
|
|
|
private async getRelatedQuestions(query: string, contexts: any[], onMessage?: IStreamHandler) { |
|
try { |
|
const { messages } = this.paramsFormatter(query, undefined, contexts, 'related'); |
|
const { model, stream } = this; |
|
if (!stream) { |
|
const res = await this.chat(messages, this.model); |
|
return res.split('\n'); |
|
} |
|
await this.chat(messages, onMessage, model); |
|
} catch (err) { |
|
console.error('[LLM Error]:', err); |
|
return []; |
|
} |
|
} |
|
|
|
private async getAiAnswer(query: string, contexts: any[], mode: TMode = 'simple', onMessage?: IStreamHandler) { |
|
const { model, stream } = this; |
|
try { |
|
const { messages } = this.paramsFormatter(query, mode, contexts, 'answer'); |
|
if (!stream) { |
|
const res = await this.chat(messages, this.model); |
|
return res; |
|
} |
|
await this.chat(messages, (msg: string, done: boolean) => { |
|
onMessage?.(msg, done); |
|
}, model); |
|
} catch (err: any) { |
|
console.error('[LLM Error]:', err); |
|
const msg = `[Oops~ Some errors seem to have occurred]: ${err?.message || 'Please check the console'}`; |
|
if (!stream) return msg; |
|
else onMessage?.(msg, true); |
|
} |
|
} |
|
|
|
|
|
private async translate(text: string, targetLang = 'English'): Promise<string> { |
|
try { |
|
const content = util.format(TranslatePrompt, targetLang, text); |
|
const messages: IChatInputMessage[] = [ |
|
{ |
|
role: 'user', |
|
content |
|
} |
|
]; |
|
|
|
let translated = ''; |
|
if (!this.stream) { |
|
const res = await this.chat(messages, this.model); |
|
translated = res; |
|
} else { |
|
await this.chat(messages, (msg: string) => { |
|
if (msg) translated += msg; |
|
}, this.model); |
|
} |
|
return translated; |
|
} catch (err) { |
|
console.log('[RAG Translate error]', err); |
|
return text; |
|
} |
|
} |
|
|
|
private paramsFormatter(query: string, mode: TMode = 'simple', contexts: any[], type: 'answer' | 'related') { |
|
const context = contexts.map((item, index) => `[[citation:${index + 1}]] ${item.snippet}`).join('\n\n'); |
|
let prompt = type === 'answer' ? RagQueryPrompt : MoreQuestionsPrompt; |
|
|
|
|
|
if (mode === 'deep' && type === 'answer') { |
|
prompt = DeepQueryPrompt; |
|
} |
|
|
|
const system = util.format(prompt, context); |
|
const messages: IChatInputMessage[] = [ |
|
{ |
|
role: 'user', |
|
content: `${system} ${query}` |
|
} |
|
]; |
|
return { |
|
messages |
|
}; |
|
} |
|
} |
|
|
|
function processModel(model: string) { |
|
const targetModel = Models.find(item => { |
|
return item.models.includes(model); |
|
}); |
|
if (targetModel?.platform) { |
|
const target = platform[targetModel.platform]; |
|
return target.chatStream.bind(target); |
|
} |
|
} |
|
|