// Spaces: Running (web-page status banner captured with the source; not part of the module)
/**
 * Database abstraction layer that's vaguely ORM-like.
 * Modern (Promises, strict types, tagged template literals), but ORMs
 * are a bit _too_ magical for me, so none of that magic here.
 *
 * @author Zarel
 */
import * as mysql from 'mysql2';
import * as pg from 'pg';
/** A primitive that is stored as-is in `SQLStatement.values`. */
export type BasicSQLValue = string | number | null;
/** A row as a plain object mapping column names to primitive values. */
export type SQLRow = { [k: string]: BasicSQLValue };
/**
 * Anything that may be interpolated into an SQL template:
 * a primitive, a nested `SQLStatement`, an object of columns (for the
 * `(obj)` and `SET obj` forms), an array of primitives (for lists), or
 * `undefined` (which interpolates nothing).
 */
export type SQLValue = BasicSQLValue | SQLStatement | PartialOrSQL<SQLRow> | BasicSQLValue[] | undefined;
/**
 * Type guard: is `value` an `SQLStatement`?
 *
 * This addresses a scenario where objects get out of sync due to hotpatching.
 * Table A is instantiated, and retains SQLStatement at that specific point in time. Consumer A is also instantiated at
 * the same time, and both can interact freely, since consumer A and table A share the same reference to SQLStatement.
 * However, when consumer A is hotpatched, consumer A imports a new instance of SQLStatement. Thus, when consumer A
 * provides that new SQLStatement, it does not pass the `instanceof SQLStatement` check in Table A,
 * since table A is still referencing the old SQLStatement (checking that the new is an instance of the old).
 * This does not work. Thus, we're forced to check constructor name instead.
 */
export function isSQL(value: any): value is SQLStatement {
	return value instanceof SQLStatement || (
		// assorted safety checks to be sure it'll actually work (theoretically preventing certain attacks)
		// `constructor` is optional-chained as well: null-prototype objects do not have one
		value?.constructor?.name === 'SQLStatement' && (Array.isArray(value.sql) && Array.isArray(value.values))
	);
}

/**
 * An SQL query under construction: literal SQL fragments (`sql`) interleaved
 * with the values that go between them (`values`).
 *
 * A value whose following fragment starts with a quote character is later
 * treated as an identifier by the databases' `_resolveSQL`; every other value
 * becomes a bound parameter.
 */
export class SQLStatement {
	/** Literal SQL fragments; `sql[i + 1]` is the text that follows `values[i]`. */
	sql: string[];
	/** Primitive values, in the order they appear between the fragments. */
	values: BasicSQLValue[];
	/**
	 * @param strings The literal parts of a tagged template.
	 * @param values The interpolated values, one between each pair of literal parts.
	 */
	constructor(strings: TemplateStringsArray, values: SQLValue[]) {
		this.sql = [strings[0]];
		this.values = [];
		// The final pass sees no value and no string, which appends nothing.
		for (let i = 0; i < strings.length; i++) {
			this.append(values[i], strings[i + 1]);
		}
	}
	/**
	 * Appends one interpolated value followed by the literal text after it.
	 *
	 * - `SQLStatement`: spliced in, keeping its own values.
	 * - string / number / `null`: stored as a value.
	 * - `undefined`: contributes nothing; only `nextString` is appended.
	 * - array: comma-separated list; a list of names if the text so far ends
	 *   with `"` or a backtick.
	 * - object: `("a", "b") VALUES (1, 2)` if the text so far ends with `(`,
	 *   `"a" = 1, "b" = 2` if it ends with ` SET ` or ` UPDATE `
	 *   (case-insensitive). Keys whose value is `undefined` are skipped.
	 *
	 * @param value The value to append.
	 * @param nextString Literal SQL that follows the value.
	 * @returns `this`, for chaining.
	 * @throws Error if an array is empty, if an object has no defined values,
	 *   or if an object appears anywhere other than the positions listed above.
	 */
	append(value: SQLValue, nextString = ''): this {
		if (isSQL(value)) {
			this.appendStatement(value, nextString);
		} else if (typeof value === 'string' || typeof value === 'number' || value === null) {
			this.values.push(value);
			this.sql.push(nextString);
		} else if (value === undefined) {
			this.sql[this.sql.length - 1] += nextString;
		} else if (Array.isArray(value)) {
			this.appendList(value, nextString);
		} else {
			this.appendObject(value, nextString);
		}
		return this;
	}
	/** Splices a nested statement's fragments and values onto this one. */
	private appendStatement(statement: SQLStatement, nextString: string): void {
		if (statement.sql.length) {
			const oldLength = this.sql.length;
			this.sql = this.sql.concat(statement.sql.slice(1));
			this.sql[oldLength - 1] += statement.sql[0];
			this.values = this.values.concat(statement.values);
		}
		// The text after the placeholder is kept even when the nested statement is empty.
		this.sql[this.sql.length - 1] += nextString;
	}
	/** Appends an array as a comma-separated list of values or of quoted names. */
	private appendList(items: BasicSQLValue[], nextString: string): void {
		const preceding = this.sql[this.sql.length - 1];
		if (!items.length) {
			throw new Error(`Cannot interpolate an empty array into SQL: ${preceding}[]${nextString}`);
		}
		const quoteChar = preceding.slice(-1);
		// "`a`, `b`" syntax when directly preceded by a quote character, "1, 2" syntax otherwise
		const isNameList = quoteChar !== '' && '"`'.includes(quoteChar);
		const separator = isNameList ? `${quoteChar}, ${quoteChar}` : `, `;
		const lastIndex = items.length - 1;
		items.forEach((item, i) => {
			this.append(item, i < lastIndex ? separator : nextString);
		});
	}
	/** Appends an object as an INSERT column/value list or as SET-style assignments. */
	private appendObject(row: PartialOrSQL<SQLRow>, nextString: string): void {
		const preceding = this.sql[this.sql.length - 1];
		const upper = preceding.toUpperCase();
		const isInsert = preceding.endsWith('(');
		// ` UPDATE ` covers `ON DUPLICATE KEY UPDATE obj`
		const isSetter = !isInsert && (upper.endsWith(' SET ') || upper.endsWith(' UPDATE '));
		if (!isInsert && !isSetter) {
			throw new Error(
				`Objects can only appear in (obj) or after SET or UPDATE; ` +
				`unrecognized: ${preceding}[obj]${nextString}`
			);
		}
		// Optional columns may be present but undefined; they must not reach the SQL.
		const entries = Object.entries(row).filter(([, val]) => val !== undefined);
		if (!entries.length) {
			throw new Error(
				`Cannot interpolate an object with no defined values into SQL: ${preceding}[obj]${nextString}`
			);
		}
		const lastIndex = entries.length - 1;
		if (isInsert) {
			// "(`a`, `b`) VALUES (1, 2)" syntax
			this.sql[this.sql.length - 1] += `"`;
			entries.forEach(([col], i) => {
				this.append(col, i < lastIndex ? `", "` : `") VALUES (`);
			});
			entries.forEach(([, val], i) => {
				this.append(val, i < lastIndex ? `, ` : nextString);
			});
		} else {
			// "`a` = 1, `b` = 2" syntax
			entries.forEach(([col, val], i) => {
				this.sql[this.sql.length - 1] += `"`;
				this.append(col, `" = `);
				this.append(val, i < lastIndex ? `, ` : nextString);
			});
		}
	}
}
/**
 * Tag function for SQL, with some magic.
 *
 * In each pair below, the first line is the call and the second is the
 * SQL it is equivalent to.
 *
 * * `` SQL`UPDATE table SET a = ${'hello"'}` ``
 * * `` `UPDATE table SET a = 'hello"'` ``
 *
 * Values surrounded by `"` or `` ` `` become identifiers:
 *
 * * ``` SQL`SELECT * FROM "${'table'}"` ```
 * * `` `SELECT * FROM "table"` ``
 *
 * (Make sure to use `"` for Postgres and `` ` `` for MySQL.)
 *
 * Objects preceded by SET become setters:
 *
 * * `` SQL`UPDATE table SET ${{a: 1, b: 2}}` ``
 * * `` `UPDATE table SET "a" = 1, "b" = 2` ``
 *
 * Objects surrounded by `()` become keys and values:
 *
 * * `` SQL`INSERT INTO table (${{a: 1, b: 2}})` ``
 * * `` `INSERT INTO table ("a", "b") VALUES (1, 2)` ``
 *
 * Arrays become lists; surrounding by `"` or `` ` `` turns them into lists of names:
 *
 * * `` SQL`INSERT INTO table ("${['a', 'b']}") VALUES (${[1, 2]})` ``
 * * `` `INSERT INTO table ("a", "b") VALUES (1, 2)` ``
 *
 * @param strings The literal parts of the template.
 * @param values The interpolated values.
 * @returns A new `SQLStatement` built from the template.
 */
export function SQL(strings: TemplateStringsArray, ...values: SQLValue[]) {
	return new SQLStatement(strings, values);
}
/** Default row shape for query results: column names mapped to primitive values. */
export interface ResultRow { [k: string]: BasicSQLValue }
/** Every `Database` registers itself here in its constructor. */
export const connectedDatabases: Database[] = [];
/**
 * Base class for a database connection pool that runs `SQLStatement`s.
 *
 * Subclasses supply the driver-specific parts: turning a statement into a
 * query string plus parameter list, running it, and escaping identifiers.
 *
 * `query`, `queryOne` and `queryExec` can each be called two ways:
 * with a ready-made `SQLStatement`, or with no argument, in which case they
 * return a template tag (`` db.query()`SELECT ...` ``).
 */
export abstract class Database<Pool extends mysql.Pool | pg.Pool = mysql.Pool | pg.Pool, OkPacket = unknown> {
	connection: Pool;
	/** Prepended to table names by `DatabaseTable`. */
	prefix: string;
	/** Driver tag; subclasses in this file override it with `'mysql'` or `'pg'`. */
	type = '';
	/**
	 * @param connection The driver's connection pool.
	 * @param prefix Table-name prefix (default: none).
	 */
	constructor(connection: Pool, prefix = '') {
		this.prefix = prefix;
		this.connection = connection;
		connectedDatabases.push(this);
	}
	/** Converts a statement into the driver's query string and parameter list. */
	abstract _resolveSQL(query: SQLStatement): [query: string, values: BasicSQLValue[]];
	/** Runs a resolved query and resolves with its rows. */
	abstract _query(sql: string, values: BasicSQLValue[]): Promise<any>;
	/** Runs a resolved query and resolves with the driver's summary of it. */
	abstract _queryExec(sql: string, values: BasicSQLValue[]): Promise<OkPacket>;
	/** Escapes `param` for use as an identifier. */
	abstract escapeId(param: string): string;

	/** Runs a query and resolves with all of its rows. */
	query<T = ResultRow>(sql: SQLStatement): Promise<T[]>;
	query<T = ResultRow>(): (strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise<T[]>;
	query<T = ResultRow>(sql?: SQLStatement) {
		if (!sql) return (strings: any, ...rest: any) => this.query<T>(new SQLStatement(strings, rest));
		const [query, values] = this._resolveSQL(sql);
		return this._query(query, values);
	}

	/** Runs a query and resolves with its first row (`undefined` if there is none). */
	queryOne<T = ResultRow>(sql: SQLStatement): Promise<T | undefined>;
	queryOne<T = ResultRow>(): (strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise<T | undefined>;
	queryOne<T = ResultRow>(sql?: SQLStatement) {
		if (!sql) return (strings: any, ...rest: any) => this.queryOne<T>(new SQLStatement(strings, rest));
		// a non-array result is passed through unchanged
		return this.query<T>(sql).then(res => Array.isArray(res) ? res[0] : res);
	}

	/** Runs a statement through `_queryExec` and resolves with its result. */
	queryExec(sql: SQLStatement): Promise<OkPacket>;
	queryExec(): (strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise<OkPacket>;
	queryExec(sql?: SQLStatement) {
		if (!sql) return (strings: any, ...rest: any) => this.queryExec(new SQLStatement(strings, rest));
		const [query, values] = this._resolveSQL(sql);
		return this._queryExec(query, values);
	}

	/**
	 * Creates a `DatabaseTable` bound to this database.
	 *
	 * @param name Table name, without the prefix.
	 * @param primaryKeyName Name of the single-column primary key, if any.
	 */
	getTable<Row>(name: string, primaryKeyName: keyof Row & string | null = null): DatabaseTable<Row, this> {
		return new DatabaseTable<Row, this>(this, name, primaryKeyName);
	}
	/** Ends the connection pool without waiting for it to finish. */
	close() {
		void this.connection.end();
	}
}
/** Like `Partial<T>`, except that every property may also be given as an `SQLStatement`. */
type PartialOrSQL<T> = {
	[P in keyof T]?: T[P] | SQLStatement;
};
/** The `OkPacket` type argument of a `Database` type. */
type OkPacketOf<DB extends Database> = DB extends Database<any, infer T> ? T : never;
// Row extends SQLRow but TS doesn't support closed types so we can't express this
/**
 * Query helpers bound to one table of a `Database`.
 *
 * The table's full name is the database's prefix followed by `name`.
 * Methods that address a single row by key need `primaryKeyName`.
 */
export class DatabaseTable<Row, DB extends Database> {
	db: DB;
	/** Full table name, including the database prefix. */
	name: string;
	primaryKeyName: keyof Row & string | null;
	/**
	 * @param db The database the table lives in.
	 * @param name Table name, without the prefix.
	 * @param primaryKeyName Name of the single-column primary key, if any.
	 */
	constructor(
		db: DB,
		name: string,
		primaryKeyName: keyof Row & string | null = null
	) {
		this.db = db;
		this.name = db.prefix + name;
		this.primaryKeyName = primaryKeyName;
	}
	/** Escapes `param` as an identifier using the underlying database. */
	escapeId(param: string) {
		return this.db.escapeId(param);
	}

	// raw: these forward to the database unchanged; only the default row type differs

	query<T = Row>(sql: SQLStatement): Promise<T[]>;
	query<T = Row>(): (strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise<T[]>;
	query<T = Row>(sql?: SQLStatement) {
		return this.db.query<T>(sql as any) as any;
	}
	queryOne<T = Row>(sql: SQLStatement): Promise<T | undefined>;
	queryOne<T = Row>(): (strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise<T | undefined>;
	queryOne<T = Row>(sql?: SQLStatement) {
		return this.db.queryOne<T>(sql as any) as any;
	}
	queryExec(sql: SQLStatement): Promise<OkPacketOf<DB>>;
	queryExec(): (strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise<OkPacketOf<DB>>;
	queryExec(sql?: SQLStatement) {
		return this.db.queryExec(sql as any) as any;
	}

	// low-level: each returns a template tag that takes the rest of the query (e.g. a WHERE clause)

	/**
	 * `SELECT <entries> FROM <table> <rest>`.
	 *
	 * @param entries Column names, or a raw `SQLStatement`; defaults to `*`.
	 */
	selectAll<T = Row>(entries?: (keyof Row & string)[] | SQLStatement):
	(strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise<T[]> {
		if (!entries) entries = SQL`*`;
		if (Array.isArray(entries)) entries = SQL`"${entries}"`;
		return (strings, ...rest) =>
			this.query<T>()`SELECT ${entries} FROM "${this.name}" ${new SQLStatement(strings, rest)}`;
	}
	/** Like `selectAll`, but appends `LIMIT 1` and resolves with the first row only. */
	selectOne<T = Row>(entries?: (keyof Row & string)[] | SQLStatement):
	(strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise<T | undefined> {
		if (!entries) entries = SQL`*`;
		if (Array.isArray(entries)) entries = SQL`"${entries}"`;
		return (strings, ...rest) =>
			this.queryOne<T>()`SELECT ${entries} FROM "${this.name}" ${new SQLStatement(strings, rest)} LIMIT 1`;
	}
	/** `UPDATE <table> SET <partialRow> <rest>`. */
	updateAll(partialRow: PartialOrSQL<Row>):
	(strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise<OkPacketOf<DB>> {
		return (strings, ...rest) =>
			this.queryExec()`UPDATE "${this.name}" SET ${partialRow as any} ${new SQLStatement(strings, rest)}`;
	}
	/** Like `updateAll`, but appends `LIMIT 1`. */
	updateOne(partialRow: PartialOrSQL<Row>):
	(strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise<OkPacketOf<DB>> {
		return (s, ...r) =>
			this.queryExec()`UPDATE "${this.name}" SET ${partialRow as any} ${new SQLStatement(s, r)} LIMIT 1`;
	}
	/** `DELETE FROM <table> <rest>`. */
	deleteAll():
	(strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise<OkPacketOf<DB>> {
		return (strings, ...rest) =>
			this.queryExec()`DELETE FROM "${this.name}" ${new SQLStatement(strings, rest)}`;
	}
	/** Like `deleteAll`, but appends `LIMIT 1`. */
	deleteOne():
	(strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise<OkPacketOf<DB>> {
		return (strings, ...rest) =>
			this.queryExec()`DELETE FROM "${this.name}" ${new SQLStatement(strings, rest)} LIMIT 1`;
	}
	/**
	 * `SELECT <expression> AS result FROM <table> LIMIT 1`; resolves with the
	 * `result` column of the first row, or `undefined` if there is no row.
	 */
	eval<T>():
	(strings: TemplateStringsArray, ...rest: SQLValue[]) => Promise<T | undefined> {
		return (strings, ...rest) =>
			this.queryOne<{ result: T }>(
			)`SELECT ${new SQLStatement(strings, rest)} AS result FROM "${this.name}" LIMIT 1`
				.then(row => row?.result);
	}

	// high-level

	/** `INSERT INTO <table> (<columns>) VALUES (<values>) <where>`. */
	insert(partialRow: PartialOrSQL<Row>, where?: SQLStatement) {
		return this.queryExec()`INSERT INTO "${this.name}" (${partialRow as SQLValue}) ${where}`;
	}
	/**
	 * Same as `insert`, using `INSERT IGNORE`.
	 * NOTE(review): `INSERT IGNORE` looks MySQL-specific; confirm before using it with Postgres.
	 */
	insertIgnore(partialRow: PartialOrSQL<Row>, where?: SQLStatement) {
		return this.queryExec()`INSERT IGNORE INTO "${this.name}" (${partialRow as SQLValue}) ${where}`;
	}
	/**
	 * Inserts a row, resolving with `undefined` instead of rejecting when the
	 * row collides with an existing key.
	 *
	 * @throws Any other error raised by the insert.
	 */
	async tryInsert(partialRow: PartialOrSQL<Row>, where?: SQLStatement) {
		try {
			return await this.insert(partialRow, where);
		} catch (err: any) {
			if (DatabaseTable.isDuplicateKeyError(err)) {
				return undefined;
			}
			throw err;
		}
	}
	/** Recognises the duplicate-key errors `tryInsert` should swallow. */
	private static isDuplicateKeyError(err: any): boolean {
		const code = err?.code;
		// MySQL reports ER_DUP_ENTRY; '23505' is PostgreSQL's unique_violation.
		if (code === 'ER_DUP_ENTRY' || code === '23505') return true;
		// MySQLDatabase._query may re-wrap driver errors as a plain Error whose
		// message ends with `[<code>]`, so the code is looked for there too.
		return typeof err?.message === 'string' && err.message.endsWith('[ER_DUP_ENTRY]');
	}
	/**
	 * Inserts `partialRow`, or applies `partialUpdate` if the key already exists.
	 *
	 * @param partialRow Row to insert.
	 * @param partialUpdate Columns to set on conflict; defaults to `partialRow`.
	 * @param where Extra SQL appended to the statement.
	 * @throws Error on Postgres when the table has no `primaryKeyName`, since
	 *   the conflict target is the primary key column.
	 */
	upsert(partialRow: PartialOrSQL<Row>, partialUpdate = partialRow, where?: SQLStatement) {
		if (this.db.type === 'pg') {
			if (!this.primaryKeyName) throw new Error(`Cannot upsert() without a single-column primary key`);
			// The conflict target must be a quoted identifier (not a bound value),
			// and the assignments must follow SET.
			return this.queryExec(
			)`INSERT INTO "${this.name}" (${partialRow as any}) ON CONFLICT ("${this.primaryKeyName
			}") DO UPDATE SET ${partialUpdate as any} ${where}`;
		}
		// NOTE(review): the update object is interpolated directly after `UPDATE `;
		// this branch needs SQLStatement.append to accept an object in that position - verify.
		return this.queryExec(
		)`INSERT INTO "${this.name}" (${partialRow as any}) ON DUPLICATE KEY UPDATE ${partialUpdate as any} ${where}`;
	}
	/**
	 * Replaces the row with the given primary key by `partialRow`.
	 * `partialRow` itself is left untouched; the key is added to a copy.
	 *
	 * @throws Error if the table has no `primaryKeyName`.
	 */
	set(primaryKey: BasicSQLValue, partialRow: PartialOrSQL<Row>, where?: SQLStatement) {
		if (!this.primaryKeyName) throw new Error(`Cannot set() without a single-column primary key`);
		const row: PartialOrSQL<Row> = { ...partialRow };
		row[this.primaryKeyName] = primaryKey as any;
		return this.replace(row, where);
	}
	/**
	 * `REPLACE INTO <table> (<columns>) VALUES (<values>) <where>`.
	 * NOTE(review): `REPLACE INTO` looks MySQL-specific; confirm before using it with Postgres.
	 */
	replace(partialRow: PartialOrSQL<Row>, where?: SQLStatement) {
		return this.queryExec()`REPLACE INTO "${this.name}" (${partialRow as SQLValue}) ${where}`;
	}
	/**
	 * Fetches the row with the given primary key.
	 *
	 * @throws Error if the table has no `primaryKeyName`.
	 */
	get(primaryKey: BasicSQLValue, entries?: (keyof Row & string)[] | SQLStatement) {
		if (!this.primaryKeyName) throw new Error(`Cannot get() without a single-column primary key`);
		return this.selectOne(entries)`WHERE "${this.primaryKeyName}" = ${primaryKey}`;
	}
	/**
	 * Deletes the row with the given primary key.
	 *
	 * @throws Error if the table has no `primaryKeyName`.
	 */
	delete(primaryKey: BasicSQLValue) {
		if (!this.primaryKeyName) throw new Error(`Cannot delete() without a single-column primary key`);
		return this.deleteAll()`WHERE "${this.primaryKeyName}" = ${primaryKey} LIMIT 1`;
	}
	/**
	 * Updates the row with the given primary key.
	 *
	 * @throws Error if the table has no `primaryKeyName`.
	 */
	update(primaryKey: BasicSQLValue, data: PartialOrSQL<Row>) {
		if (!this.primaryKeyName) throw new Error(`Cannot update() without a single-column primary key`);
		return this.updateAll(data)`WHERE "${this.primaryKeyName}" = ${primaryKey} LIMIT 1`;
	}
}
/** `Database` backed by a `mysql2` connection pool. */
export class MySQLDatabase extends Database<mysql.Pool, mysql.OkPacket> {
	override type = 'mysql' as const;
	/**
	 * @param config `mysql2` pool options, plus an optional table-name `prefix`.
	 *   `prefix` is never forwarded to `mysql.createPool`.
	 */
	constructor(config: mysql.PoolOptions & { prefix?: string }) {
		// Split our own option off without touching the caller's object.
		const { prefix, ...poolConfig } = config;
		super(mysql.createPool(poolConfig), prefix || "");
	}
	/**
	 * Builds the query string and parameter list for a statement.
	 *
	 * A value whose following fragment starts with a quote character is an
	 * identifier: the opening quote already emitted and the closing quote are
	 * replaced by the escaped identifier. Every other value becomes `?`.
	 */
	override _resolveSQL(query: SQLStatement): [query: string, values: BasicSQLValue[]] {
		let sql = query.sql[0];
		const values: BasicSQLValue[] = [];
		for (let i = 0; i < query.values.length; i++) {
			const value = query.values[i];
			if (query.sql[i + 1].startsWith('`') || query.sql[i + 1].startsWith('"')) {
				sql = sql.slice(0, -1) + this.escapeId(`${value}`) + query.sql[i + 1].slice(1);
			} else {
				sql += '?' + query.sql[i + 1];
				values.push(value);
			}
		}
		return [sql, values];
	}
	/**
	 * Runs a query on the pool.
	 *
	 * Rejects with an `Error` whose message includes the query, its values and
	 * the driver's error code; the code is also available as `code` on the error.
	 * Buffer columns in array results are converted to strings in place.
	 */
	override _query(query: string, values: BasicSQLValue[]): Promise<any> {
		return new Promise((resolve, reject) => {
			this.connection.query(query, values, (e, results: any) => {
				if (e) {
					const wrapped = new Error(`${e.message} (${query}) (${values}) [${e.code}]`);
					// Keep the driver's code so callers can branch on `err.code`.
					return reject(Object.assign(wrapped, { code: e.code }));
				}
				if (Array.isArray(results)) {
					for (const row of results) {
						for (const col in row) {
							if (Buffer.isBuffer(row[col])) row[col] = row[col].toString();
						}
					}
				}
				return resolve(results);
			});
		});
	}
	/** Same as `_query`; the result is typed as `mysql.OkPacket`. */
	override _queryExec(sql: string, values: BasicSQLValue[]): Promise<mysql.OkPacket> {
		return this._query(sql, values);
	}
	/** Escapes an identifier with `mysql.escapeId`. */
	override escapeId(id: string) {
		return mysql.escapeId(id);
	}
}
/** `Database` backed by a `pg` connection pool. */
export class PGDatabase extends Database<pg.Pool, { affectedRows: number | null }> {
	override type = 'pg' as const;
	/** @param config Options passed to `pg.Pool`. */
	constructor(config: pg.PoolConfig) {
		super(new pg.Pool(config));
	}
	/**
	 * Builds the query string and parameter list for a statement.
	 *
	 * A value whose following fragment starts with a quote character is an
	 * identifier: the opening quote already emitted and the closing quote are
	 * replaced by the escaped identifier. Every other value becomes a numbered
	 * placeholder (`$1`, `$2`, ...) in order of appearance.
	 */
	override _resolveSQL(query: SQLStatement): [query: string, values: BasicSQLValue[]] {
		let sql = query.sql[0];
		const values: BasicSQLValue[] = [];
		let paramCount = 0;
		for (let i = 0; i < query.values.length; i++) {
			const value = query.values[i];
			if (query.sql[i + 1].startsWith('`') || query.sql[i + 1].startsWith('"')) {
				sql = sql.slice(0, -1) + this.escapeId(`${value}`) + query.sql[i + 1].slice(1);
			} else {
				paramCount++;
				sql += `$${paramCount}` + query.sql[i + 1];
				values.push(value);
			}
		}
		return [sql, values];
	}
	/** Runs a query and resolves with the result's `rows`. */
	override _query(query: string, values: BasicSQLValue[]) {
		return this.connection.query(query, values).then(res => res.rows);
	}
	/** Runs a query and resolves with `{ affectedRows }`, taken from the result's `rowCount`. */
	override _queryExec(query: string, values: BasicSQLValue[]) {
		return this.connection.query<never>(query, values).then(res => ({ affectedRows: res.rowCount }));
	}
	/** Escapes an identifier with `pg.escapeIdentifier`. */
	override escapeId(id: string) {
		// @ts-expect-error @types/pg really needs to be updated
		return pg.escapeIdentifier(id);
	}
}