/**
 * Database abstraction layer that's vaguely ORM-like.
 * Modern (Promises, strict types, tagged template literals), but ORMs
 * are a bit _too_ magical for me, so none of that magic here.
 *
 * @author Zarel
 */

import * as mysql from 'mysql2';
import * as pg from 'pg';

/** A value that can be bound directly as a query parameter. */
export type BasicSQLValue = string | number | null;
/** A row whose columns are all directly bindable values. */
export type SQLRow = { [k: string]: BasicSQLValue };
/** Anything that may be interpolated into a `SQL` tagged template. */
export type SQLValue = BasicSQLValue | SQLStatement | PartialOrSQL<SQLRow> | BasicSQLValue[] | undefined;

/**
 * Type guard for `SQLStatement`.
 *
 * This addresses a scenario where objects get out of sync due to hotpatching.
 * Table A is instantiated, and retains SQLStatement at that specific point in time. Consumer A is also
 * instantiated at the same time, and both can interact freely, since consumer A and table A share the same
 * reference to SQLStatement.
 * However, when consumer A is hotpatched, consumer A imports a new instance of SQLStatement. Thus, when
 * consumer A provides that new SQLStatement, it does not pass the `instanceof SQLStatement` check in Table A,
 * since table A is still referencing the old SQLStatement (checking that the new is an instance of the old).
 * This does not work. Thus, we're forced to check constructor name instead.
 */
export function isSQL(value: any): value is SQLStatement {
	if (value instanceof SQLStatement) return true;
	// assorted safety checks to be sure it'll actually work (theoretically preventing certain attacks);
	// `constructor` is optional-chained too, so prototype-less objects don't throw here
	return value?.constructor?.name === 'SQLStatement' &&
		Array.isArray(value.sql) && Array.isArray(value.values);
}

/**
 * A SQL query under construction.
 *
 * `sql` holds the literal text chunks and `values` the interpolated values that sit between them, so
 * `sql.length === values.length + 1` always holds. Whether a value ends up as a bound parameter or as an
 * escaped identifier is decided later, by the database's `_resolveSQL`.
 */
export class SQLStatement {
	sql: string[];
	values: BasicSQLValue[];

	constructor(strings: TemplateStringsArray, values: SQLValue[]) {
		this.sql = [strings[0]];
		this.values = [];
		// values[i - 1] sits between strings[i - 1] and strings[i]
		for (let i = 1; i < strings.length; i++) {
			this.append(values[i - 1], strings[i]);
		}
	}

	/**
	 * Appends one interpolated value, followed by the literal text `nextString`.
	 *
	 * - SQL statements are spliced in.
	 * - Strings, numbers and `null` become a value slot.
	 * - `undefined` contributes nothing.
	 * - Arrays become comma-separated lists (lists of quoted names if the preceding text ends in a quote).
	 * - Objects become `(cols) VALUES (vals)` after `(`, or `col = val` pairs after `SET`/`UPDATE`.
	 *   Properties whose value is `undefined` are skipped.
	 *
	 * Empty arrays and empty objects expand to nothing rather than corrupting the surrounding text.
	 *
	 * @throws Error if an object appears anywhere other than after `(`, ` SET ` or ` UPDATE `.
	 */
	append(value: SQLValue, nextString = ''): this {
		// helpers operating on whatever the last text chunk is at the time of the call
		const addText = (text: string) => {
			this.sql[this.sql.length - 1] += text;
		};
		const chop = (count: number) => {
			const last = this.sql.length - 1;
			this.sql[last] = this.sql[last].slice(0, -count);
		};

		if (isSQL(value)) {
			if (value.sql.length) {
				const [head, ...rest] = value.sql;
				addText(head);
				this.sql.push(...rest);
				this.values.push(...value.values);
			}
			addText(nextString);
		} else if (typeof value === 'string' || typeof value === 'number' || value === null) {
			this.values.push(value);
			this.sql.push(nextString);
		} else if (value === undefined) {
			addText(nextString);
		} else if (Array.isArray(value)) {
			const quoteChar = this.sql[this.sql.length - 1].slice(-1);
			// `quoteChar` is '' when the preceding chunk is empty, and ''.includes('') is true,
			// so the emptiness check is required
			const isNameList = quoteChar !== '' && '"`'.includes(quoteChar);
			// "`a`, `b`" syntax or "1, 2" syntax
			const separator = isNameList ? `${quoteChar}, ${quoteChar}` : `, `;
			if (value.length) {
				for (const item of value) {
					this.append(item, separator);
				}
				chop(separator.length); // drop the trailing separator
			}
			addText(nextString);
		} else {
			const entries = Object.entries(value).filter(([, val]) => val !== undefined);
			const context = this.sql[this.sql.length - 1];
			const upperContext = context.toUpperCase();
			if (context.endsWith('(')) {
				// "(`a`, `b`) VALUES (1, 2)" syntax
				if (entries.length) {
					addText(`"`);
					for (const [col] of entries) {
						this.append(col, `", "`);
					}
					chop(4);
					addText(`") VALUES (`);
					for (const [, val] of entries) {
						this.append(val, `, `);
					}
					chop(2);
				} else {
					addText(`) VALUES (`);
				}
				addText(nextString);
			} else if (upperContext.endsWith(' SET ') || upperContext.endsWith(' UPDATE ')) {
				// "`a` = 1, `b` = 2" syntax; UPDATE covers MySQL's "ON DUPLICATE KEY UPDATE a = 1"
				if (entries.length) {
					addText(`"`);
					for (const [col, val] of entries) {
						this.append(col, `" = `);
						this.append(val, `, "`);
					}
					chop(3);
				}
				addText(nextString);
			} else {
				throw new Error(
					`Objects can only appear in (obj) or after SET; ` +
					`unrecognized: ${context}[obj]${nextString}`
				);
			}
		}
		return this;
	}
}

/**
 * Tag function for SQL, with some magic.
 *
 * Plain values become bound parameters, never interpolated text:
 *
 * * `` SQL`UPDATE table SET a = ${'hello"'}` ``
 *   * `` `UPDATE table SET a = ?` `` with values `['hello"']`
 *
 * Values surrounded by `"` or `` ` `` become identifiers:
 *
 * * ``` SQL`SELECT * FROM "${'table'}"` ```
 *   * `` `SELECT * FROM "table"` ``
 *
 * (Make sure to use `"` for Postgres and `` ` `` for MySQL.)
 *
 * Objects preceded by SET become setters:
 *
 * * `` SQL`UPDATE table SET ${{a: 1, b: 2}}` ``
 *   * `` `UPDATE table SET "a" = 1, "b" = 2` ``
 *
 * Objects surrounded by `()` become keys and values:
 *
 * * `` SQL`INSERT INTO table (${{a: 1, b: 2}})` ``
 *   * `` `INSERT INTO table ("a", "b") VALUES (1, 2)` ``
 *
 * Arrays become lists; surrounding by `"` or `` ` `` turns them into lists of names:
 *
 * * `` SQL`INSERT INTO table ("${['a', 'b']}") VALUES (${[1, 2]})` ``
 *   * `` `INSERT INTO table ("a", "b") VALUES (1, 2)` ``
 */
export function SQL(strings: TemplateStringsArray, ...values: SQLValue[]) {
	return new SQLStatement(strings, values);
}

export interface ResultRow { [k: string]: BasicSQLValue }

/** Every `Database` ever constructed in this process; constructors push themselves here. */
export const connectedDatabases: Database[] = [];

/**
 * Turns a `SQLStatement` into the final query text plus the list of bound values.
 *
 * A value whose following text chunk starts with a quote character is an identifier: the surrounding
 * quotes are removed and the value is inlined via `escapeId`. Every other value is replaced by
 * `placeholder(n)`, where `n` is its 1-based position among the bound values.
 */
function resolveStatement(
	query: SQLStatement,
	escapeId: (id: string) => string,
	placeholder: (index: number) => string
): [query: string, values: BasicSQLValue[]] {
	let sql = query.sql[0];
	const values: BasicSQLValue[] = [];
	for (let i = 0; i < query.values.length; i++) {
		const value = query.values[i];
		const following = query.sql[i + 1];
		if (following.startsWith('`') || following.startsWith('"')) {
			sql = sql.slice(0, -1) + escapeId(`${value}`) + following.slice(1);
		} else {
			values.push(value);
			sql += placeholder(values.length) + following;
		}
	}
	return [sql, values];
}

/**
 * Common base for the supported database back-ends.
 *
 * `query`, `queryOne` and `queryExec` each accept either a ready `SQLStatement`, or no argument, in which
 * case they return a template tag: ``db.query()`SELECT ...` ``.
 */
export abstract class Database<Pool extends mysql.Pool | pg.Pool = mysql.Pool | pg.Pool, OkPacket = unknown> {
	connection: Pool;
	prefix: string;
	type = '';

	constructor(connection: Pool, prefix = '') {
		this.prefix = prefix;
		this.connection = connection;
		connectedDatabases.push(this);
	}

	abstract _resolveSQL(query: SQLStatement): [query: string, values: BasicSQLValue[]];
	abstract _query(sql: string, values: BasicSQLValue[]): Promise<any>;
	abstract _queryExec(sql: string, values: BasicSQLValue[]): Promise<OkPacket>;
	abstract escapeId(param: string): string;

	/** Runs a query and resolves to its rows. */
	query<T = ResultRow>(sql: SQLStatement): Promise<T[]>;
	query<T = ResultRow>(): (strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise<T[]>;
	query<T = ResultRow>(sql?: SQLStatement) {
		if (!sql) return (strings: any, ...rest: any) => this.query<T>(new SQLStatement(strings, rest));

		const [query, values] = this._resolveSQL(sql);
		return this._query(query, values);
	}

	/** Runs a query and resolves to its first row (`undefined` if there are none). */
	queryOne<T = ResultRow>(sql: SQLStatement): Promise<T | undefined>;
	queryOne<T = ResultRow>(): (strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise<T | undefined>;
	queryOne<T = ResultRow>(sql?: SQLStatement) {
		if (!sql) return (strings: any, ...rest: any) => this.queryOne<T>(new SQLStatement(strings, rest));

		return this.query<T>(sql).then(res => Array.isArray(res) ? res[0] : res);
	}

	/** Runs a statement through the back-end's `_queryExec` and resolves to whatever it reports. */
	queryExec(sql: SQLStatement): Promise<OkPacket>;
	queryExec(): (strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise<OkPacket>;
	queryExec(sql?: SQLStatement) {
		if (!sql) return (strings: any, ...rest: any) => this.queryExec(new SQLStatement(strings, rest));

		const [query, values] = this._resolveSQL(sql);
		return this._queryExec(query, values);
	}

	/** Creates a table helper; `name` is automatically prefixed with this database's `prefix`. */
	getTable<Row>(name: string, primaryKeyName: keyof Row & string | null = null): DatabaseTable<Row, this> {
		return new DatabaseTable<Row, this>(this, name, primaryKeyName);
	}

	close() {
		void this.connection.end();
	}
}

type PartialOrSQL<T> = {
	[P in keyof T]?: T[P] | SQLStatement;
};

type OkPacketOf<DB extends Database> = DB extends Database<any, infer T> ? T : never;

/** Normalizes a column selection to a statement fragment: nothing means `*`, an array means a name list. */
function toColumnList(entries?: string[] | SQLStatement): SQLStatement {
	if (!entries) return SQL`*`;
	if (Array.isArray(entries)) return SQL`"${entries}"`;
	return entries;
}

// Row extends SQLRow but TS doesn't support closed types so we can't express this
/**
 * Query helpers bound to one table of a `Database`.
 *
 * Methods documented as returning a tag are used like
 * ``table.selectAll(['id'])`WHERE x = ${1}` ``; the template supplies the tail of the query.
 */
export class DatabaseTable<Row, DB extends Database> {
	db: DB;
	name: string;
	primaryKeyName: keyof Row & string | null;

	constructor(
		db: DB,
		name: string,
		primaryKeyName: keyof Row & string | null = null
	) {
		this.db = db;
		this.name = db.prefix + name;
		this.primaryKeyName = primaryKeyName;
	}

	escapeId(param: string) {
		return this.db.escapeId(param);
	}

	// raw

	query<T = Row>(sql: SQLStatement): Promise<T[]>;
	query<T = Row>(): (strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise<T[]>;
	query<T = Row>(sql?: SQLStatement) {
		return this.db.query<T>(sql as any) as any;
	}

	queryOne<T = Row>(sql: SQLStatement): Promise<T | undefined>;
	queryOne<T = Row>(): (strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise<T | undefined>;
	queryOne<T = Row>(sql?: SQLStatement) {
		return this.db.queryOne<T>(sql as any) as any;
	}

	queryExec(sql: SQLStatement): Promise<OkPacketOf<DB>>;
	queryExec(): (strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise<OkPacketOf<DB>>;
	queryExec(sql?: SQLStatement) {
		return this.db.queryExec(sql as any) as any;
	}

	// low-level

	/** Tag: `SELECT entries FROM table <template>`. */
	selectAll<T = Row>(
		entries?: (keyof Row & string)[] | SQLStatement
	): (strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise<T[]> {
		const columns = toColumnList(entries);
		return (strings, ...rest) =>
			this.query<T>()`SELECT ${columns} FROM "${this.name}" ${new SQLStatement(strings, rest)}`;
	}

	/** Tag: `SELECT entries FROM table <template> LIMIT 1`, resolving to the first row. */
	selectOne<T = Row>(
		entries?: (keyof Row & string)[] | SQLStatement
	): (strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise<T | undefined> {
		const columns = toColumnList(entries);
		return (strings, ...rest) =>
			this.queryOne<T>()`SELECT ${columns} FROM "${this.name}" ${new SQLStatement(strings, rest)} LIMIT 1`;
	}

	/** Tag: `UPDATE table SET partialRow <template>`. */
	updateAll(
		partialRow: PartialOrSQL<Row>
	): (strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise<OkPacketOf<DB>> {
		return (strings, ...rest) =>
			this.queryExec()`UPDATE "${this.name}" SET ${partialRow as any} ${new SQLStatement(strings, rest)}`;
	}

	/** Tag: `UPDATE table SET partialRow <template> LIMIT 1`. */
	updateOne(
		partialRow: PartialOrSQL<Row>
	): (strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise<OkPacketOf<DB>> {
		return (s, ...r) =>
			this.queryExec()`UPDATE "${this.name}" SET ${partialRow as any} ${new SQLStatement(s, r)} LIMIT 1`;
	}

	/** Tag: `DELETE FROM table <template>`. */
	deleteAll(): (strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise<OkPacketOf<DB>> {
		return (strings, ...rest) =>
			this.queryExec()`DELETE FROM "${this.name}" ${new SQLStatement(strings, rest)}`;
	}

	/** Tag: `DELETE FROM table <template> LIMIT 1`. */
	deleteOne(): (strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise<OkPacketOf<DB>> {
		return (strings, ...rest) =>
			this.queryExec()`DELETE FROM "${this.name}" ${new SQLStatement(strings, rest)} LIMIT 1`;
	}

	/** Tag: evaluates the template as an expression against the table and resolves to its value. */
	eval<T>(): (strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise<T | undefined> {
		return (strings, ...rest) =>
			this.queryOne<{ result: T }>(
			)`SELECT ${new SQLStatement(strings, rest)} AS result FROM "${this.name}" LIMIT 1`
				.then(row => row?.result);
	}

	// high-level

	insert(partialRow: PartialOrSQL<Row>, where?: SQLStatement) {
		return this.queryExec()`INSERT INTO "${this.name}" (${partialRow as SQLValue}) ${where}`;
	}

	/** Emits `INSERT IGNORE`, which is MySQL syntax. */
	insertIgnore(partialRow: PartialOrSQL<Row>, where?: SQLStatement) {
		return this.queryExec()`INSERT IGNORE INTO "${this.name}" (${partialRow as SQLValue}) ${where}`;
	}

	/**
	 * Like `insert`, but resolves to `undefined` instead of rejecting when the row collides with an
	 * existing unique key (MySQL `ER_DUP_ENTRY`, Postgres SQLSTATE `23505`). Other errors are rethrown.
	 */
	async tryInsert(partialRow: PartialOrSQL<Row>, where?: SQLStatement) {
		try {
			return await this.insert(partialRow, where);
		} catch (err: unknown) {
			const code = (err as { code?: unknown } | null | undefined)?.code;
			if (code === 'ER_DUP_ENTRY' || code === '23505') {
				return undefined;
			}
			throw err;
		}
	}

	/**
	 * Inserts `partialRow`, or applies `partialUpdate` (default: the same columns) if the key exists.
	 *
	 * @throws Error on Postgres when the table has no single-column primary key, since `ON CONFLICT`
	 * needs a conflict target.
	 */
	upsert(partialRow: PartialOrSQL<Row>, partialUpdate = partialRow, where?: SQLStatement) {
		if (this.db.type === 'pg') {
			if (!this.primaryKeyName) throw new Error(`Cannot upsert() without a single-column primary key`);
			// the key is quoted so it resolves as an identifier rather than a bound value
			return this.queryExec(
			)`INSERT INTO "${this.name}" (${partialRow as any}) ON CONFLICT ("${this.primaryKeyName}") DO UPDATE SET ${partialUpdate as any} ${where}`;
		}
		return this.queryExec(
		)`INSERT INTO "${this.name}" (${partialRow as any}) ON DUPLICATE KEY UPDATE ${partialUpdate as any} ${where}`;
	}

	/** `REPLACE INTO` with the primary key column set to `primaryKey`; the caller's object is not modified. */
	set(primaryKey: BasicSQLValue, partialRow: PartialOrSQL<Row>, where?: SQLStatement) {
		if (!this.primaryKeyName) throw new Error(`Cannot set() without a single-column primary key`);
		const row = { ...partialRow, [this.primaryKeyName]: primaryKey } as PartialOrSQL<Row>;
		return this.replace(row, where);
	}

	replace(partialRow: PartialOrSQL<Row>, where?: SQLStatement) {
		return this.queryExec()`REPLACE INTO "${this.name}" (${partialRow as SQLValue}) ${where}`;
	}

	get(primaryKey: BasicSQLValue, entries?: (keyof Row & string)[] | SQLStatement) {
		if (!this.primaryKeyName) throw new Error(`Cannot get() without a single-column primary key`);
		return this.selectOne(entries)`WHERE "${this.primaryKeyName}" = ${primaryKey}`;
	}

	delete(primaryKey: BasicSQLValue) {
		if (!this.primaryKeyName) throw new Error(`Cannot delete() without a single-column primary key`);
		return this.deleteAll()`WHERE "${this.primaryKeyName}" = ${primaryKey} LIMIT 1`;
	}

	update(primaryKey: BasicSQLValue, data: PartialOrSQL<Row>) {
		if (!this.primaryKeyName) throw new Error(`Cannot update() without a single-column primary key`);
		return this.updateAll(data)`WHERE "${this.primaryKeyName}" = ${primaryKey} LIMIT 1`;
	}
}

export class MySQLDatabase extends Database<mysql.Pool, mysql.OkPacket> {
	override type = 'mysql' as const;

	constructor(config: mysql.PoolOptions & { prefix?: string }) {
		// `prefix` is our own option; strip it from a copy before handing the rest to mysql2
		const { prefix, ...poolOptions } = config;
		super(mysql.createPool(poolOptions), prefix || "");
	}

	override _resolveSQL(query: SQLStatement): [query: string, values: BasicSQLValue[]] {
		return resolveStatement(query, id => this.escapeId(id), () => '?');
	}

	override _query(query: string, values: BasicSQLValue[]): Promise<any> {
		return new Promise((resolve, reject) => {
			this.connection.query(query, values, (e, results: any) => {
				if (e) {
					const wrapped = new Error(`${e.message} (${query}) (${values}) [${e.code}]`);
					// keep the driver's error code on the wrapper so callers can still branch on `err.code`
					return reject(Object.assign(wrapped, { code: e.code }));
				}
				if (Array.isArray(results)) {
					// binary columns come back as Buffers; callers expect strings
					for (const row of results) {
						for (const col in row) {
							if (Buffer.isBuffer(row[col])) row[col] = row[col].toString();
						}
					}
				}
				return resolve(results);
			});
		});
	}

	override _queryExec(sql: string, values: BasicSQLValue[]): Promise<mysql.OkPacket> {
		return this._query(sql, values);
	}

	override escapeId(id: string) {
		return mysql.escapeId(id);
	}
}

export class PGDatabase extends Database<pg.Pool, { affectedRows: number | null }> {
	override type = 'pg' as const;

	constructor(config: pg.PoolConfig) {
		super(new pg.Pool(config));
	}

	override _resolveSQL(query: SQLStatement): [query: string, values: BasicSQLValue[]] {
		// Postgres placeholders are numbered: $1, $2, ...
		return resolveStatement(query, id => this.escapeId(id), index => `$${index}`);
	}

	override _query(query: string, values: BasicSQLValue[]) {
		return this.connection.query(query, values).then(res => res.rows);
	}

	override _queryExec(query: string, values: BasicSQLValue[]) {
		return this.connection.query(query, values).then(res => ({ affectedRows: res.rowCount }));
	}

	override escapeId(id: string) {
		// @ts-expect-error @types/pg really needs to be updated
		return pg.escapeIdentifier(id);
	}
}