Files
stash-box/pkg/sqlx/sql.go
2025-01-02 14:43:41 +01:00

356 lines
7.6 KiB
Go

package sqlx
import (
"database/sql"
"fmt"
"reflect"
"strings"
"time"
"github.com/gofrs/uuid"
)
type queryBuilder struct {
Table table
Body string
whereClauses []string
havingClauses []string
args []interface{}
Sort string
Pagination string
}
func newQueryBuilder(t table) *queryBuilder {
ret := &queryBuilder{
Table: t,
}
tableName := t.Name()
ret.Body = "SELECT " + tableName + ".* FROM " + tableName + " "
return ret
}
func newDeleteQueryBuilder(t table) *queryBuilder {
ret := &queryBuilder{
Table: t,
}
tableName := t.Name()
ret.Body = "DELETE FROM " + tableName + " "
return ret
}
func (qb *queryBuilder) AddJoin(jt table, on string) {
qb.Body += " JOIN " + jt.Name() + " ON " + on
}
func (qb *queryBuilder) AddJoinTableFilter(tj tableJoin, query string, group bool, having *string, not bool, args ...interface{}) {
clause := fmt.Sprintf(" JOIN (SELECT %[1]s.%[2]s FROM %[1]s WHERE %[3]s", tj.Name(), tj.joinColumn, query)
if group {
clause += fmt.Sprintf(" GROUP BY %s.%s", tj.Name(), tj.joinColumn)
if having != nil {
clause += fmt.Sprintf(" HAVING %s", *having)
}
}
clause += fmt.Sprintf(") %[1]s ON %[3]s.id = %[1]s.%[2]s", tj.Name(), tj.joinColumn, tj.primaryTable)
if not {
clause = " LEFT" + clause
qb.AddWhere(fmt.Sprintf("%s.%s IS NULL", tj.Name(), tj.joinColumn))
}
qb.Body += clause
qb.AddArg(args...)
}
func (qb *queryBuilder) AddWhere(clauses ...string) {
qb.whereClauses = append(qb.whereClauses, clauses...)
}
func (qb *queryBuilder) Eq(column string, arg interface{}) {
qb.AddWhere(column + " = ?")
qb.AddArg(arg)
}
func (qb *queryBuilder) NotEq(column string, arg interface{}) {
qb.AddWhere(column + " != ?")
qb.AddArg(arg)
}
func (qb *queryBuilder) IsNull(column string) {
qb.AddWhere(column + " is NULL")
}
func (qb *queryBuilder) IsNotNull(column string) {
qb.AddWhere(column + " is not NULL")
}
func (qb *queryBuilder) AddHaving(clauses ...string) {
if len(clauses) == 1 && clauses[0] == "" {
return
}
qb.havingClauses = append(qb.havingClauses, clauses...)
}
func (qb *queryBuilder) AddArg(args ...interface{}) {
qb.args = append(qb.args, args...)
}
func (qb queryBuilder) buildBody(isCount bool) string {
body := qb.Body
if len(qb.whereClauses) > 0 {
body = body + " WHERE " + strings.Join(qb.whereClauses, " AND ") // TODO handle AND or OR
}
if len(qb.havingClauses) > 0 {
body = body + " GROUP BY " + qb.Table.Name() + ".id HAVING " + strings.Join(qb.havingClauses, " AND ") // TODO handle AND or OR
}
if !isCount {
body += qb.Sort
}
return body
}
func (qb queryBuilder) buildCountQuery() string {
return "SELECT COUNT(*) as count FROM (" + qb.buildBody(true) + ") as temp"
}
func (qb queryBuilder) buildQuery() string {
return qb.buildBody(false) + qb.Pagination
}
type optionalValue interface {
IsValid() bool
}
func ensureTx(txn *txnState) {
if !txn.InTxn() {
panic("must use a transaction")
}
}
func getByID(txn *txnState, t string, id uuid.UUID, object interface{}) error {
query := txn.DB().Rebind(`SELECT * FROM ` + t + ` WHERE id = ? LIMIT 1`)
return txn.DB().GetContext(txn.ctx, object, query, id)
}
func insertObject(txn *txnState, t string, object interface{}, conflictHandling *string) error {
ensureTx(txn)
fields, values := sqlGenKeysCreate(object)
conflictClause := ""
if conflictHandling != nil {
conflictClause = *conflictHandling
}
_, err := txn.DB().NamedExecContext(
txn.ctx,
`INSERT INTO `+t+` (`+fields+`)
VALUES (`+values+`)
`+conflictClause+`
`,
object,
)
return err
}
func updateObjectByID(txn *txnState, t string, object interface{}, updateEmptyValues bool) error {
ensureTx(txn)
_, err := txn.DB().NamedExecContext(
txn.ctx,
`UPDATE `+t+` SET `+sqlGenKeys(object, updateEmptyValues)+` WHERE `+t+`.id = :id`,
object,
)
return err
}
func executeDeleteQuery(tableName string, id uuid.UUID, txn *txnState) error {
ensureTx(txn)
idColumnName := getColumn(tableName, "id")
query := txn.DB().Rebind(`DELETE FROM ` + tableName + ` WHERE ` + idColumnName + ` = ?`)
_, err := txn.DB().ExecContext(txn.ctx, query, id)
return err
}
func softDeleteObjectByID(txn *txnState, t string, id uuid.UUID) error {
ensureTx(txn)
idColumnName := getColumn(t, "id")
query := txn.DB().Rebind(`UPDATE ` + t + ` SET deleted=TRUE WHERE ` + idColumnName + ` = ?`)
_, err := txn.DB().ExecContext(txn.ctx, query, id)
return err
}
func deleteObjectsByColumn(txn *txnState, t string, column string, value interface{}) error {
ensureTx(txn)
query := txn.DB().Rebind(`DELETE FROM ` + t + ` WHERE ` + column + ` = ?`)
_, err := txn.DB().ExecContext(txn.ctx, query, value)
return err
}
func getColumn(tableName string, columnName string) string {
return tableName + "." + columnName
}
func sqlGenKeysCreate(i interface{}) (string, string) {
var fields []string
var values []string
addPlaceholder := func(key string) {
fields = append(fields, fieldQuote(key))
values = append(values, ":"+key)
}
v := reflect.ValueOf(i)
for i := 0; i < v.NumField(); i++ {
// get key for struct tag
rawKey := v.Type().Field(i).Tag.Get("db")
key := strings.Split(rawKey, ",")[0]
switch t := v.Field(i).Interface().(type) {
case string:
if t != "" {
addPlaceholder(key)
}
case int, int64, float64:
if t != 0 {
addPlaceholder(key)
}
case uuid.UUID:
if t != uuid.Nil {
addPlaceholder(key)
}
case bool:
addPlaceholder(key)
case time.Time:
if !t.IsZero() {
addPlaceholder(key)
}
case optionalValue:
if t.IsValid() {
addPlaceholder(key)
}
case sql.NullString:
if t.Valid {
addPlaceholder(key)
}
case sql.NullBool:
if t.Valid {
addPlaceholder(key)
}
case sql.NullInt64:
if t.Valid {
addPlaceholder(key)
}
case uuid.NullUUID:
if t.Valid {
addPlaceholder(key)
}
case sql.NullFloat64:
if t.Valid {
addPlaceholder(key)
}
case sql.NullTime:
if t.Valid {
addPlaceholder(key)
}
default:
reflectValue := reflect.ValueOf(t)
isNil := reflectValue.IsNil()
if !isNil {
addPlaceholder(key)
}
}
}
return strings.Join(fields, ", "), strings.Join(values, ", ")
}
func sqlGenKeys(i interface{}, partial bool) string {
var query []string
addKey := func(key string) {
query = append(query, fmt.Sprintf("%s=:%s", fieldQuote(key), key))
}
v := reflect.ValueOf(i)
for i := 0; i < v.NumField(); i++ {
// get key for struct tag
rawKey := v.Type().Field(i).Tag.Get("db")
key := strings.Split(rawKey, ",")[0]
if key == "id" {
continue
}
switch t := v.Field(i).Interface().(type) {
case string:
if partial || t != "" {
addKey(key)
}
case uuid.UUID:
if partial || t != uuid.Nil {
addKey(key)
}
case int, int64, float64:
if partial || t != 0 {
addKey(key)
}
case bool:
addKey(key)
case time.Time:
if !t.IsZero() {
addKey(key)
}
case optionalValue:
if partial || t.IsValid() {
addKey(key)
}
case sql.NullString:
if partial || t.Valid {
addKey(key)
}
case sql.NullBool:
if partial || t.Valid {
addKey(key)
}
case sql.NullInt64:
if partial || t.Valid {
addKey(key)
}
case uuid.NullUUID:
if partial || t.Valid {
addKey(key)
}
case sql.NullFloat64:
if partial || t.Valid {
addKey(key)
}
case sql.NullTime:
if partial || t.Valid {
addKey(key)
}
default:
reflectValue := reflect.ValueOf(t)
isNil := reflectValue.IsNil()
if !isNil {
addKey(key)
}
}
}
return strings.Join(query, ", ")
}
func fieldQuote(field string) string {
return `"` + field + `"`
}
func nullsLast() string {
return " NULLS LAST "
}