mirror of
https://github.com/BillyOutlast/stash-box.git
synced 2026-02-04 11:01:17 +01:00
356 lines
7.6 KiB
Go
356 lines
7.6 KiB
Go
package sqlx
|
|
|
|
import (
|
|
"database/sql"
|
|
"fmt"
|
|
"reflect"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/gofrs/uuid"
|
|
)
|
|
|
|
type queryBuilder struct {
|
|
Table table
|
|
Body string
|
|
|
|
whereClauses []string
|
|
havingClauses []string
|
|
args []interface{}
|
|
|
|
Sort string
|
|
Pagination string
|
|
}
|
|
|
|
func newQueryBuilder(t table) *queryBuilder {
|
|
ret := &queryBuilder{
|
|
Table: t,
|
|
}
|
|
|
|
tableName := t.Name()
|
|
ret.Body = "SELECT " + tableName + ".* FROM " + tableName + " "
|
|
|
|
return ret
|
|
}
|
|
|
|
func newDeleteQueryBuilder(t table) *queryBuilder {
|
|
ret := &queryBuilder{
|
|
Table: t,
|
|
}
|
|
|
|
tableName := t.Name()
|
|
ret.Body = "DELETE FROM " + tableName + " "
|
|
|
|
return ret
|
|
}
|
|
|
|
func (qb *queryBuilder) AddJoin(jt table, on string) {
|
|
qb.Body += " JOIN " + jt.Name() + " ON " + on
|
|
}
|
|
|
|
func (qb *queryBuilder) AddJoinTableFilter(tj tableJoin, query string, group bool, having *string, not bool, args ...interface{}) {
|
|
clause := fmt.Sprintf(" JOIN (SELECT %[1]s.%[2]s FROM %[1]s WHERE %[3]s", tj.Name(), tj.joinColumn, query)
|
|
if group {
|
|
clause += fmt.Sprintf(" GROUP BY %s.%s", tj.Name(), tj.joinColumn)
|
|
|
|
if having != nil {
|
|
clause += fmt.Sprintf(" HAVING %s", *having)
|
|
}
|
|
}
|
|
|
|
clause += fmt.Sprintf(") %[1]s ON %[3]s.id = %[1]s.%[2]s", tj.Name(), tj.joinColumn, tj.primaryTable)
|
|
|
|
if not {
|
|
clause = " LEFT" + clause
|
|
qb.AddWhere(fmt.Sprintf("%s.%s IS NULL", tj.Name(), tj.joinColumn))
|
|
}
|
|
|
|
qb.Body += clause
|
|
qb.AddArg(args...)
|
|
}
|
|
|
|
func (qb *queryBuilder) AddWhere(clauses ...string) {
|
|
qb.whereClauses = append(qb.whereClauses, clauses...)
|
|
}
|
|
|
|
func (qb *queryBuilder) Eq(column string, arg interface{}) {
|
|
qb.AddWhere(column + " = ?")
|
|
qb.AddArg(arg)
|
|
}
|
|
|
|
func (qb *queryBuilder) NotEq(column string, arg interface{}) {
|
|
qb.AddWhere(column + " != ?")
|
|
qb.AddArg(arg)
|
|
}
|
|
|
|
func (qb *queryBuilder) IsNull(column string) {
|
|
qb.AddWhere(column + " is NULL")
|
|
}
|
|
|
|
func (qb *queryBuilder) IsNotNull(column string) {
|
|
qb.AddWhere(column + " is not NULL")
|
|
}
|
|
|
|
func (qb *queryBuilder) AddHaving(clauses ...string) {
|
|
if len(clauses) == 1 && clauses[0] == "" {
|
|
return
|
|
}
|
|
qb.havingClauses = append(qb.havingClauses, clauses...)
|
|
}
|
|
|
|
func (qb *queryBuilder) AddArg(args ...interface{}) {
|
|
qb.args = append(qb.args, args...)
|
|
}
|
|
|
|
func (qb queryBuilder) buildBody(isCount bool) string {
|
|
body := qb.Body
|
|
|
|
if len(qb.whereClauses) > 0 {
|
|
body = body + " WHERE " + strings.Join(qb.whereClauses, " AND ") // TODO handle AND or OR
|
|
}
|
|
if len(qb.havingClauses) > 0 {
|
|
body = body + " GROUP BY " + qb.Table.Name() + ".id HAVING " + strings.Join(qb.havingClauses, " AND ") // TODO handle AND or OR
|
|
}
|
|
|
|
if !isCount {
|
|
body += qb.Sort
|
|
}
|
|
|
|
return body
|
|
}
|
|
|
|
func (qb queryBuilder) buildCountQuery() string {
|
|
return "SELECT COUNT(*) as count FROM (" + qb.buildBody(true) + ") as temp"
|
|
}
|
|
|
|
func (qb queryBuilder) buildQuery() string {
|
|
return qb.buildBody(false) + qb.Pagination
|
|
}
|
|
|
|
// optionalValue is implemented by field types that can report whether they
// currently hold a value. The key generators in this file include such a
// field only when IsValid returns true (or, for updates, when empty values
// are written as well).
type optionalValue interface {
	// IsValid reports whether the value is set.
	IsValid() bool
}
|
|
|
|
func ensureTx(txn *txnState) {
|
|
if !txn.InTxn() {
|
|
panic("must use a transaction")
|
|
}
|
|
}
|
|
|
|
func getByID(txn *txnState, t string, id uuid.UUID, object interface{}) error {
|
|
query := txn.DB().Rebind(`SELECT * FROM ` + t + ` WHERE id = ? LIMIT 1`)
|
|
return txn.DB().GetContext(txn.ctx, object, query, id)
|
|
}
|
|
|
|
func insertObject(txn *txnState, t string, object interface{}, conflictHandling *string) error {
|
|
ensureTx(txn)
|
|
fields, values := sqlGenKeysCreate(object)
|
|
|
|
conflictClause := ""
|
|
if conflictHandling != nil {
|
|
conflictClause = *conflictHandling
|
|
}
|
|
|
|
_, err := txn.DB().NamedExecContext(
|
|
txn.ctx,
|
|
`INSERT INTO `+t+` (`+fields+`)
|
|
VALUES (`+values+`)
|
|
`+conflictClause+`
|
|
`,
|
|
object,
|
|
)
|
|
|
|
return err
|
|
}
|
|
|
|
func updateObjectByID(txn *txnState, t string, object interface{}, updateEmptyValues bool) error {
|
|
ensureTx(txn)
|
|
_, err := txn.DB().NamedExecContext(
|
|
txn.ctx,
|
|
`UPDATE `+t+` SET `+sqlGenKeys(object, updateEmptyValues)+` WHERE `+t+`.id = :id`,
|
|
object,
|
|
)
|
|
|
|
return err
|
|
}
|
|
|
|
func executeDeleteQuery(tableName string, id uuid.UUID, txn *txnState) error {
|
|
ensureTx(txn)
|
|
idColumnName := getColumn(tableName, "id")
|
|
query := txn.DB().Rebind(`DELETE FROM ` + tableName + ` WHERE ` + idColumnName + ` = ?`)
|
|
_, err := txn.DB().ExecContext(txn.ctx, query, id)
|
|
return err
|
|
}
|
|
|
|
func softDeleteObjectByID(txn *txnState, t string, id uuid.UUID) error {
|
|
ensureTx(txn)
|
|
idColumnName := getColumn(t, "id")
|
|
query := txn.DB().Rebind(`UPDATE ` + t + ` SET deleted=TRUE WHERE ` + idColumnName + ` = ?`)
|
|
_, err := txn.DB().ExecContext(txn.ctx, query, id)
|
|
return err
|
|
}
|
|
|
|
func deleteObjectsByColumn(txn *txnState, t string, column string, value interface{}) error {
|
|
ensureTx(txn)
|
|
query := txn.DB().Rebind(`DELETE FROM ` + t + ` WHERE ` + column + ` = ?`)
|
|
_, err := txn.DB().ExecContext(txn.ctx, query, value)
|
|
return err
|
|
}
|
|
|
|
// getColumn returns the table-qualified column reference
// "<tableName>.<columnName>". Neither part is quoted or validated.
func getColumn(tableName string, columnName string) string {
	return strings.Join([]string{tableName, columnName}, ".")
}
|
|
|
|
func sqlGenKeysCreate(i interface{}) (string, string) {
|
|
var fields []string
|
|
var values []string
|
|
|
|
addPlaceholder := func(key string) {
|
|
fields = append(fields, fieldQuote(key))
|
|
values = append(values, ":"+key)
|
|
}
|
|
|
|
v := reflect.ValueOf(i)
|
|
for i := 0; i < v.NumField(); i++ {
|
|
// get key for struct tag
|
|
rawKey := v.Type().Field(i).Tag.Get("db")
|
|
key := strings.Split(rawKey, ",")[0]
|
|
switch t := v.Field(i).Interface().(type) {
|
|
case string:
|
|
if t != "" {
|
|
addPlaceholder(key)
|
|
}
|
|
case int, int64, float64:
|
|
if t != 0 {
|
|
addPlaceholder(key)
|
|
}
|
|
case uuid.UUID:
|
|
if t != uuid.Nil {
|
|
addPlaceholder(key)
|
|
}
|
|
case bool:
|
|
addPlaceholder(key)
|
|
case time.Time:
|
|
if !t.IsZero() {
|
|
addPlaceholder(key)
|
|
}
|
|
case optionalValue:
|
|
if t.IsValid() {
|
|
addPlaceholder(key)
|
|
}
|
|
case sql.NullString:
|
|
if t.Valid {
|
|
addPlaceholder(key)
|
|
}
|
|
case sql.NullBool:
|
|
if t.Valid {
|
|
addPlaceholder(key)
|
|
}
|
|
case sql.NullInt64:
|
|
if t.Valid {
|
|
addPlaceholder(key)
|
|
}
|
|
case uuid.NullUUID:
|
|
if t.Valid {
|
|
addPlaceholder(key)
|
|
}
|
|
case sql.NullFloat64:
|
|
if t.Valid {
|
|
addPlaceholder(key)
|
|
}
|
|
case sql.NullTime:
|
|
if t.Valid {
|
|
addPlaceholder(key)
|
|
}
|
|
default:
|
|
reflectValue := reflect.ValueOf(t)
|
|
isNil := reflectValue.IsNil()
|
|
if !isNil {
|
|
addPlaceholder(key)
|
|
}
|
|
}
|
|
}
|
|
return strings.Join(fields, ", "), strings.Join(values, ", ")
|
|
}
|
|
|
|
func sqlGenKeys(i interface{}, partial bool) string {
|
|
var query []string
|
|
|
|
addKey := func(key string) {
|
|
query = append(query, fmt.Sprintf("%s=:%s", fieldQuote(key), key))
|
|
}
|
|
|
|
v := reflect.ValueOf(i)
|
|
for i := 0; i < v.NumField(); i++ {
|
|
// get key for struct tag
|
|
rawKey := v.Type().Field(i).Tag.Get("db")
|
|
key := strings.Split(rawKey, ",")[0]
|
|
if key == "id" {
|
|
continue
|
|
}
|
|
switch t := v.Field(i).Interface().(type) {
|
|
case string:
|
|
if partial || t != "" {
|
|
addKey(key)
|
|
}
|
|
case uuid.UUID:
|
|
if partial || t != uuid.Nil {
|
|
addKey(key)
|
|
}
|
|
case int, int64, float64:
|
|
if partial || t != 0 {
|
|
addKey(key)
|
|
}
|
|
case bool:
|
|
addKey(key)
|
|
case time.Time:
|
|
if !t.IsZero() {
|
|
addKey(key)
|
|
}
|
|
case optionalValue:
|
|
if partial || t.IsValid() {
|
|
addKey(key)
|
|
}
|
|
case sql.NullString:
|
|
if partial || t.Valid {
|
|
addKey(key)
|
|
}
|
|
case sql.NullBool:
|
|
if partial || t.Valid {
|
|
addKey(key)
|
|
}
|
|
case sql.NullInt64:
|
|
if partial || t.Valid {
|
|
addKey(key)
|
|
}
|
|
case uuid.NullUUID:
|
|
if partial || t.Valid {
|
|
addKey(key)
|
|
}
|
|
case sql.NullFloat64:
|
|
if partial || t.Valid {
|
|
addKey(key)
|
|
}
|
|
case sql.NullTime:
|
|
if partial || t.Valid {
|
|
addKey(key)
|
|
}
|
|
default:
|
|
reflectValue := reflect.ValueOf(t)
|
|
isNil := reflectValue.IsNil()
|
|
if !isNil {
|
|
addKey(key)
|
|
}
|
|
}
|
|
}
|
|
return strings.Join(query, ", ")
|
|
}
|
|
|
|
// fieldQuote returns field wrapped in double quotes so it can be used as a
// quoted SQL identifier. A double quote inside field is doubled, which is how
// SQL escapes a quote within a quoted identifier; names without quotes are
// returned exactly as before.
func fieldQuote(field string) string {
	return `"` + strings.ReplaceAll(field, `"`, `""`) + `"`
}
|
|
|
|
// nullsLast returns the " NULLS LAST " ordering modifier, padded with a
// space on both sides so it can be concatenated into an ORDER BY expression.
func nullsLast() string {
	const modifier = " NULLS LAST "
	return modifier
}
|