Add studio, tags, scene (#3)

* Add tags
* Add studio
* Add scenes support
* Add integration test
* Add gitattributes. Add test targets
* Add DB interface and refactor
* Add performer image
* Replace checksums with fingerprints
* Update dependencies
* Make performers unique on name/disambiguation
* Add first draft of README
This commit is contained in:
WithoutPants
2019-11-27 13:17:46 +11:00
committed by GitHub
parent 0f29cf0c84
commit 856ebd6c94
51 changed files with 4390 additions and 502 deletions

2
.gitattributes vendored Normal file
View File

@@ -0,0 +1,2 @@
go.mod text eol=lf
go.sum text eol=lf

View File

@@ -20,6 +20,14 @@ generate:
go generate
packr2
.PHONY: test
test:
go test ./...
.PHONY: it
it:
go test -tags=integration ./...
# Runs gofmt -w on the project's source code, modifying any files that do not match its style.
.PHONY: fmt
fmt:

View File

@@ -1,2 +1,98 @@
# stash-box
Stash App's own OpenSource video indexing and Perceptual Hashing MetaData API
[![Discord](https://img.shields.io/discord/559159668438728723.svg?logo=discord)](https://discord.gg/2TsNFKt)
**stash-box is Stash App's own OpenSource video indexing and Perceptual Hashing MetaData API for porn.**
The intent of stash-box is to provide a collaborative, crowd-sourced database of porn metadata, in the same way as [MusicBrainz](https://musicbrainz.org/) does for music. The submission and editing of metadata is expected to follow the same principle as that of the MusicBrainz database. [See here](https://musicbrainz.org/doc/Editing_FAQ) for how MusicBrainz does it.
Currently, stash-box provides a graphql backend API only. There is no built in UI. The graphql playground can be accessed at `host:port/playground`. The graphql interface is at `host:port/graphql`.
# Docker install
TODO
# Bare-metal Install
Stash-box supports macOS, Windows, and Linux.
Releases TODO
## CLI
Stash-box provides some command line options. See what is currently available by running `stashdb --help`.
For example, to run stash locally on port 80 run it like this (OSX / Linux) `stashdb --host 127.0.0.1 --port 80`
## Configuration
Stash-box generates a configuration file in the current working directory when it is first started up. This configuration file is generated with the following defaults:
- running on `0.0.0.0` port `9998`
- sqlite3 database generated in the current working directory named `stashdb-go.sqlite`
- generated read (`read_api_key`) and write (`modify_api_key`) API keys. These can be deleted to disable read/write authentication (all requests will be allowed without API key)
### API keys
These are a very basic authorization method. When set, the `ApiKey` header must be set to the correct value to read/write the data. The write API key allows reading and writing. The read API key allows only reading.
## SSL (HTTPS)
Stash-box supports HTTPS with some additional work. First you must generate a SSL certificate and key combo. Here is an example using openssl:
`openssl req -x509 -newkey rsa:4096 -sha256 -days 7300 -nodes -keyout stashdb.key -out stashdb.crt -extensions san -config <(echo "[req]"; echo distinguished_name=req; echo "[san]"; echo subjectAltName=DNS:stashdb.server,IP:127.0.0.1) -subj /CN=stashdb.server`
This command would need customizing for your environment. [This link](https://stackoverflow.com/questions/10175812/how-to-create-a-self-signed-certificate-with-openssl) might be useful.
Once you have a certificate and key file name them `stashdb.crt` and `stashdb.key` and place them in the directory where stash-box is run from. Stash-box detects these and starts up using HTTPS rather than HTTP.
# FAQ
> I have a question not answered here.
Join the [Discord server](https://discord.gg/2TsNFKt).
# Development
## Install
* [Revive](https://github.com/mgechev/revive) - Configurable linter
* Go Install: `go get github.com/mgechev/revive`
* [Packr2](https://github.com/gobuffalo/packr/tree/v2.0.2/v2) - Static asset bundler
* Go Install: `go get github.com/gobuffalo/packr/v2/packr2@v2.0.2`
* [Binary Download](https://github.com/gobuffalo/packr/releases)
* [Yarn](https://yarnpkg.com/en/docs/install) - Yarn package manager
NOTE: You may need to run the `go get` commands outside the project directory to avoid modifying the projects module file.
## Environment
### macOS
TODO
### Windows
1. Download and install [Go for Windows](https://golang.org/dl/)
2. Download and install [MingW](https://sourceforge.net/projects/mingw-w64/)
3. Search for "advanced system settings" and open the system properties dialog.
1. Click the `Environment Variables` button
2. Add `GO111MODULE=on`
3. Under system variables find the `Path`. Edit and add `C:\Program Files\mingw-w64\*\mingw64\bin` (replace * with the correct path).
## Commands
* `make generate` - Generate Go GraphQL and packr2 files. This should be run if the graphql schema or schema migration files have changed.
* `make build` - Builds the binary
* `make vet` - Run `go vet`
* `make lint` - Run the linter
* `make test` - Runs the unit tests
* `make it` - Runs the unit and integration tests
## Building a release
1. Run `make generate` to create generated files
2. Run `make build` to build the executable for your current platform
## Cross compiling
TODO

1
go.mod
View File

@@ -9,7 +9,6 @@ require (
github.com/golang-migrate/migrate/v4 v4.3.1
github.com/gorilla/websocket v1.4.0
github.com/h2non/filetype v1.0.8
github.com/inconshreveable/mousetrap v1.0.0 // indirect
github.com/jmoiron/sqlx v1.2.0
github.com/mattn/go-sqlite3 v1.10.0
github.com/pkg/errors v0.8.1

1
go.sum
View File

@@ -243,6 +243,7 @@ github.com/gobuffalo/packr v1.15.0/go.mod h1:t5gXzEhIviQwVlNx/+3SfS07GS+cZ2hn76W
github.com/gobuffalo/packr v1.15.1/go.mod h1:IeqicJ7jm8182yrVmNbM6PR4g79SjN9tZLH8KduZZwE=
github.com/gobuffalo/packr v1.19.0/go.mod h1:MstrNkfCQhd5o+Ct4IJ0skWlxN8emOq8DsoT1G98VIU=
github.com/gobuffalo/packr v1.20.0/go.mod h1:JDytk1t2gP+my1ig7iI4NcVaXr886+N0ecUga6884zw=
github.com/gobuffalo/packr v1.21.0 h1:p2ujcDJQp2QTiYWcI0ByHbr/gMoCouok6M0vXs/yTYQ=
github.com/gobuffalo/packr v1.21.0/go.mod h1:H00jGfj1qFKxscFJSw8wcL4hpQtPe1PfU2wa6sg/SR0=
github.com/gobuffalo/packr/v2 v2.0.0-rc.8/go.mod h1:y60QCdzwuMwO2R49fdQhsjCPv7tLQFR0ayzxxla9zes=
github.com/gobuffalo/packr/v2 v2.0.0-rc.9/go.mod h1:fQqADRfZpEsgkc7c/K7aMew3n4aF1Kji7+lIZeR98Fc=

View File

@@ -16,4 +16,10 @@ struct_tag: gqlgen
models:
Performer:
model: github.com/stashapp/stashdb/pkg/models.Performer
Tag:
model: github.com/stashapp/stashdb/pkg/models.Tag
Studio:
model: github.com/stashapp/stashdb/pkg/models.Studio
Scene:
model: github.com/stashapp/stashdb/pkg/models.Scene

View File

@@ -29,9 +29,12 @@ type Query {
#### Scenes ####
# ids and checksums should be unique
"""Find a scene by ID or checksum"""
findScene(id: ID, checksum: String): Scene
# ids should be unique
"""Find a scene by ID"""
findScene(id: ID!): Scene
"""Finds a scene by an algorithm-specific checksum"""
findSceneByFingerprint(fingerprint: FingerprintInput!): [Scene!]!
queryScenes(scene_filter: SceneFilterType, filter: QuerySpec): QueryScenesResultType!

View File

@@ -109,6 +109,8 @@ input PerformerCreateInput {
career_end_year: Int
tattoos: [BodyModificationInput!]
piercings: [BodyModificationInput!]
"""Should be base64 encoded"""
image: String
}
input PerformerUpdateInput {
@@ -130,6 +132,8 @@ input PerformerUpdateInput {
career_end_year: Int
tattoos: [BodyModificationInput!]
piercings: [BodyModificationInput!]
"""Should be base64 encoded"""
image: String
}
input PerformerDestroyInput {
@@ -154,6 +158,8 @@ input PerformerEditDetailsInput {
career_end_year: Int
tattoos: [BodyModificationInput!]
piercings: [BodyModificationInput!]
"""Should be base64 encoded"""
image: String
}
input PerformerEditInput {

View File

@@ -10,6 +10,20 @@ input PerformerAppearanceInput {
as: String
}
enum FingerprintAlgorithm {
MD5
}
type Fingerprint {
hash: String!
algorithm: FingerprintAlgorithm!
}
input FingerprintInput {
hash: String!
algorithm: FingerprintAlgorithm!
}
type Scene {
id: ID!
title: String
@@ -20,7 +34,7 @@ type Scene {
studio: Studio
tags: [Tag!]!
performers: [PerformerAppearance!]!
checksums: [String!]!
fingerprints: [Fingerprint!]!
}
input SceneCreateInput {
@@ -31,7 +45,7 @@ input SceneCreateInput {
studio_id: ID
performers: [PerformerAppearanceInput!]
tag_ids: [ID!]
checksums: [String!]!
fingerprints: [FingerprintInput!]!
}
input SceneUpdateInput {
@@ -43,6 +57,7 @@ input SceneUpdateInput {
studio_id: ID
performers: [PerformerAppearanceInput!]
tag_ids: [ID!]
fingerprints: [FingerprintInput!]
}
input SceneDestroyInput {
@@ -57,7 +72,7 @@ input SceneEditDetailsInput {
studio_id: ID
performers: [PerformerAppearanceInput!]
tag_ids: [ID!]
checksums: [String!]
fingerprints: [FingerprintInput!]
}
input SceneEditInput {
@@ -82,6 +97,8 @@ type SceneEdit {
removed_performers: [PerformerAppearance!]
added_tags: [Tag!]
removed_tags: [Tag!]
added_fingerprints: [Fingerprint!]
removed_fingerprints: [Fingerprint!]
}
type QueryScenesResultType {

View File

@@ -15,7 +15,7 @@ input StudioCreateInput {
input StudioUpdateInput {
id: ID!
name: String!
name: String
urls: [URLInput!]
parent_id: ID
child_studio_ids: [ID!]

11
pkg/api/context_keys.go Normal file
View File

@@ -0,0 +1,11 @@
package api
// https://stackoverflow.com/questions/40891345/fix-should-not-use-basic-type-string-as-key-in-context-withvalue-golint
type key int
const (
performerKey key = 1
sceneKey key = 2
studioKey key = 3
)

176
pkg/api/integration_test.go Normal file
View File

@@ -0,0 +1,176 @@
// +build integration
package api_test
import (
"context"
"strconv"
"testing"
"github.com/stashapp/stashdb/pkg/api"
dbtest "github.com/stashapp/stashdb/pkg/database/databasetest"
"github.com/stashapp/stashdb/pkg/models"
"github.com/99designs/gqlgen/graphql"
_ "github.com/golang-migrate/migrate/v4/database/sqlite3"
)
func TestMain(m *testing.M) {
dbtest.TestWithDatabase(m, nil)
}
type testRunner struct {
t *testing.T
resolver api.Resolver
ctx context.Context
err error
}
var performerSuffix int
var studioSuffix int
var tagSuffix int
var sceneChecksumSuffix int
func createTestRunner(t *testing.T) *testRunner {
resolver := api.Resolver{}
ctx := context.TODO()
ctx = context.WithValue(ctx, api.ContextRole, api.ModifyRole)
return &testRunner{
t: t,
resolver: resolver,
ctx: ctx,
}
}
func (t *testRunner) doTest(test func()) {
if t.t.Failed() {
return
}
test()
}
func (t *testRunner) fieldMismatch(expected interface{}, actual interface{}, field string) {
t.t.Helper()
t.t.Errorf("%s mismatch: %+v != %+v", field, actual, expected)
}
func (t *testRunner) updateContext(fields []string) context.Context {
variables := make(map[string]interface{})
for _, v := range fields {
variables[v] = true
}
rctx := &graphql.RequestContext{
Variables: variables,
}
return graphql.WithRequestContext(t.ctx, rctx)
}
func (s *testRunner) generatePerformerName() string {
performerSuffix += 1
return "performer-" + strconv.Itoa(performerSuffix)
}
func (s *testRunner) createTestPerformer(input *models.PerformerCreateInput) (*models.Performer, error) {
s.t.Helper()
if input == nil {
input = &models.PerformerCreateInput{
Name: s.generatePerformerName(),
}
}
createdPerformer, err := s.resolver.Mutation().PerformerCreate(s.ctx, *input)
if err != nil {
s.t.Errorf("Error creating performer: %s", err.Error())
return nil, err
}
return createdPerformer, nil
}
func (s *testRunner) generateStudioName() string {
studioSuffix += 1
return "studio-" + strconv.Itoa(studioSuffix)
}
func (s *testRunner) createTestStudio(input *models.StudioCreateInput) (*models.Studio, error) {
s.t.Helper()
if input == nil {
input = &models.StudioCreateInput{
Name: s.generateStudioName(),
}
}
createdStudio, err := s.resolver.Mutation().StudioCreate(s.ctx, *input)
if err != nil {
s.t.Errorf("Error creating studio: %s", err.Error())
return nil, err
}
return createdStudio, nil
}
func (s *testRunner) generateTagName() string {
tagSuffix += 1
return "tag-" + strconv.Itoa(tagSuffix)
}
func (s *testRunner) createTestTag(input *models.TagCreateInput) (*models.Tag, error) {
s.t.Helper()
if input == nil {
input = &models.TagCreateInput{
Name: s.generateTagName(),
}
}
createdTag, err := s.resolver.Mutation().TagCreate(s.ctx, *input)
if err != nil {
s.t.Errorf("Error creating tag: %s", err.Error())
return nil, err
}
return createdTag, nil
}
func (s *testRunner) createTestScene(input *models.SceneCreateInput) (*models.Scene, error) {
s.t.Helper()
if input == nil {
title := "title"
input = &models.SceneCreateInput{
Title: &title,
Fingerprints: []*models.FingerprintInput{
s.generateSceneFingerprint(),
},
}
}
createdScene, err := s.resolver.Mutation().SceneCreate(s.ctx, *input)
if err != nil {
s.t.Errorf("Error creating scene: %s", err.Error())
return nil, err
}
return createdScene, nil
}
func (s *testRunner) generateSceneFingerprint() *models.FingerprintInput {
sceneChecksumSuffix += 1
return &models.FingerprintInput{
Algorithm: "MD5",
Hash: "scene-" + strconv.Itoa(sceneChecksumSuffix),
}
}
func oneNil(l interface{}, r interface{}) bool {
return l != r && (l == nil || r == nil)
}
func bothNil(l interface{}, r interface{}) bool {
return l == nil && r == nil
}

View File

@@ -0,0 +1,504 @@
// +build integration
package api_test
import (
"reflect"
"strconv"
"testing"
"github.com/stashapp/stashdb/pkg/models"
_ "github.com/golang-migrate/migrate/v4/database/sqlite3"
)
type performerTestRunner struct {
testRunner
}
func createPerformerTestRunner(t *testing.T) *performerTestRunner {
return &performerTestRunner{
testRunner: *createTestRunner(t),
}
}
func (s *performerTestRunner) testCreatePerformer() {
disambiguation := "Disambiguation"
country := "USA"
height := 182
cupSize := "C"
bandSize := 32
careerStartYear := 2000
tattooDesc := "Foobar"
gender := models.GenderEnumFemale
ethnicity := models.EthnicityEnumCaucasian
eyeColor := models.EyeColorEnumBlue
hairColor := models.HairColorEnumBlonde
breastType := models.BreastTypeEnumNatural
input := models.PerformerCreateInput{
Name: s.generatePerformerName(),
Disambiguation: &disambiguation,
Aliases: []string{"Alias1", "Alias2"},
Gender: &gender,
Urls: []*models.URLInput{
&models.URLInput{
URL: "URL",
Type: "Type",
},
},
Birthdate: &models.FuzzyDateInput{
Date: "2001-02-03",
Accuracy: models.DateAccuracyEnumDay,
},
Ethnicity: &ethnicity,
Country: &country,
EyeColor: &eyeColor,
HairColor: &hairColor,
Height: &height,
Measurements: &models.MeasurementsInput{
CupSize: &cupSize,
BandSize: &bandSize,
Waist: &bandSize,
Hip: &bandSize,
},
BreastType: &breastType,
CareerStartYear: &careerStartYear,
CareerEndYear: nil,
Tattoos: []*models.BodyModificationInput{
&models.BodyModificationInput{
Location: "Inner thigh",
Description: &tattooDesc,
},
},
Piercings: []*models.BodyModificationInput{
&models.BodyModificationInput{
Location: "Nose",
Description: nil,
},
},
}
performer, err := s.resolver.Mutation().PerformerCreate(s.ctx, input)
if err != nil {
s.t.Errorf("Error creating performer: %s", err.Error())
return
}
s.verifyCreatedPerformer(input, performer)
}
func compareBodyMods(input []*models.BodyModificationInput, bodyMods []*models.BodyModification) bool {
if len(bodyMods) != len(input) {
return false
}
for i, v := range bodyMods {
if v.Location != input[i].Location {
return false
}
if v.Description != input[i].Description {
if v.Description == nil || input[i].Description == nil {
return false
}
if *v.Description != *input[i].Description {
return false
}
}
}
return true
}
func compareUrls(input []*models.URLInput, urls []*models.URL) bool {
if len(urls) != len(input) {
return false
}
for i, v := range urls {
if v.URL != input[i].URL || v.Type != input[i].Type {
return false
}
}
return true
}
func (s *performerTestRunner) verifyCreatedPerformer(input models.PerformerCreateInput, performer *models.Performer) {
// ensure basic attributes are set correctly
if input.Name != performer.Name {
s.fieldMismatch(input.Name, performer.Name, "Name")
}
r := s.resolver.Performer()
id, _ := r.ID(s.ctx, performer)
if id == "" {
s.t.Errorf("Expected created performer id to be non-zero")
}
if v, _ := r.Disambiguation(s.ctx, performer); !reflect.DeepEqual(v, input.Disambiguation) {
s.fieldMismatch(*input.Disambiguation, v, "Disambiguation")
}
if v, _ := r.Aliases(s.ctx, performer); !reflect.DeepEqual(v, input.Aliases) {
s.fieldMismatch(input.Aliases, v, "Aliases")
}
if v, _ := r.Gender(s.ctx, performer); !reflect.DeepEqual(v, input.Gender) {
s.fieldMismatch(*input.Gender, v, "Gender")
}
// ensure urls were set correctly
urls, _ := s.resolver.Performer().Urls(s.ctx, performer)
if !compareUrls(input.Urls, urls) {
s.fieldMismatch(input.Urls, urls, "Urls")
}
birthdate, _ := r.Birthdate(s.ctx, performer)
if !bothNil(birthdate, input.Birthdate) && (oneNil(birthdate, input.Birthdate) || birthdate.Date != input.Birthdate.Date || birthdate.Accuracy != input.Birthdate.Accuracy) {
s.fieldMismatch(input.Birthdate, birthdate, "Birthdate")
}
if v, _ := r.Ethnicity(s.ctx, performer); !reflect.DeepEqual(v, input.Ethnicity) {
s.fieldMismatch(*input.Ethnicity, v, "Ethnicity")
}
if v, _ := r.Country(s.ctx, performer); !reflect.DeepEqual(v, input.Country) {
s.fieldMismatch(*input.Country, v, "Country")
}
if v, _ := r.EyeColor(s.ctx, performer); !reflect.DeepEqual(v, input.EyeColor) {
s.fieldMismatch(*input.HairColor, v, "EyeColor")
}
if v, _ := r.HairColor(s.ctx, performer); !reflect.DeepEqual(v, input.HairColor) {
s.fieldMismatch(*input.HairColor, v, "HairColor")
}
if v, _ := r.Height(s.ctx, performer); !reflect.DeepEqual(v, input.Height) {
s.fieldMismatch(*input.Height, v, "Height")
}
measurements, _ := r.Measurements(s.ctx, performer)
if !bothNil(measurements, input.Measurements) && (oneNil(measurements, input.Measurements) ||
*measurements.CupSize != *input.Measurements.CupSize ||
*measurements.BandSize != *input.Measurements.BandSize ||
*measurements.Waist != *input.Measurements.Waist ||
*measurements.Hip != *input.Measurements.Hip) {
s.fieldMismatch(input.Measurements, measurements, "Measurements")
}
if v, _ := r.BreastType(s.ctx, performer); !reflect.DeepEqual(v, input.BreastType) {
s.fieldMismatch(*input.BreastType, v, "BreastType")
}
if v, _ := r.CareerStartYear(s.ctx, performer); !reflect.DeepEqual(v, input.CareerStartYear) {
s.fieldMismatch(*input.CareerStartYear, v, "CareerStartYear")
}
if v, _ := r.CareerEndYear(s.ctx, performer); !reflect.DeepEqual(v, input.CareerEndYear) {
s.fieldMismatch(nil, v, "CareerEndYear")
}
tattoos, _ := s.resolver.Performer().Tattoos(s.ctx, performer)
if !compareBodyMods(input.Tattoos, tattoos) {
s.fieldMismatch(input.Tattoos, tattoos, "Tattoos")
}
piercings, _ := s.resolver.Performer().Piercings(s.ctx, performer)
if !compareBodyMods(input.Piercings, piercings) {
s.fieldMismatch(input.Piercings, piercings, "Piercings")
}
}
func (s *performerTestRunner) testFindPerformer() {
createdPerformer, err := s.createTestPerformer(nil)
if err != nil {
return
}
performer, err := s.resolver.Query().FindPerformer(s.ctx, strconv.FormatInt(createdPerformer.ID, 10))
if err != nil {
s.t.Errorf("Error finding performer: %s", err.Error())
return
}
// ensure returned performer is not nil
if performer == nil {
s.t.Error("Did not find performer by id")
return
}
// ensure values were set
if createdPerformer.Name != performer.Name {
s.fieldMismatch(createdPerformer.Name, performer.Name, "Name")
}
}
func (s *performerTestRunner) testUpdatePerformer() {
cupSize := "C"
bandSize := 32
tattooDesc := "Foobar"
input := &models.PerformerCreateInput{
Name: s.generatePerformerName(),
Aliases: []string{"Alias1", "Alias2"},
Urls: []*models.URLInput{
&models.URLInput{
URL: "URL",
Type: "Type",
},
},
Birthdate: &models.FuzzyDateInput{
Date: "2001-02-03",
Accuracy: models.DateAccuracyEnumDay,
},
Measurements: &models.MeasurementsInput{
CupSize: &cupSize,
BandSize: &bandSize,
Waist: &bandSize,
Hip: &bandSize,
},
Tattoos: []*models.BodyModificationInput{
&models.BodyModificationInput{
Location: "Inner thigh",
Description: &tattooDesc,
},
},
Piercings: []*models.BodyModificationInput{
&models.BodyModificationInput{
Location: "Nose",
Description: nil,
},
},
}
createdPerformer, err := s.createTestPerformer(input)
if err != nil {
return
}
performerID := strconv.FormatInt(createdPerformer.ID, 10)
updateInput := models.PerformerUpdateInput{
ID: performerID,
Aliases: []string{"Alias3", "Alias4"},
Urls: []*models.URLInput{
&models.URLInput{
URL: "URL",
Type: "Type",
},
},
Birthdate: &models.FuzzyDateInput{
Date: "2001-02-03",
Accuracy: models.DateAccuracyEnumDay,
},
Measurements: &models.MeasurementsInput{
CupSize: &cupSize,
BandSize: &bandSize,
Waist: &bandSize,
Hip: &bandSize,
},
Tattoos: []*models.BodyModificationInput{
&models.BodyModificationInput{
Location: "Tramp stamp",
Description: &tattooDesc,
},
},
Piercings: []*models.BodyModificationInput{
&models.BodyModificationInput{
Location: "Navel",
Description: nil,
},
},
}
// need some mocking of the context to make the field ignore behaviour work
ctx := s.updateContext([]string{
"aliases",
"urls",
"birthdate",
"measurements",
"tattoos",
"piercings",
})
updatedPerformer, err := s.resolver.Mutation().PerformerUpdate(ctx, updateInput)
if err != nil {
s.t.Errorf("Error updating performer: %s", err.Error())
return
}
s.verifyUpdatedPerformer(updateInput, updatedPerformer)
}
func (s *performerTestRunner) testUpdatePerformerName() {
cupSize := "C"
bandSize := 32
tattooDesc := "Foobar"
input := &models.PerformerCreateInput{
Name: s.generatePerformerName(),
Aliases: []string{"Alias1", "Alias2"},
Urls: []*models.URLInput{
&models.URLInput{
URL: "URL",
Type: "Type",
},
},
Birthdate: &models.FuzzyDateInput{
Date: "2001-02-03",
Accuracy: models.DateAccuracyEnumDay,
},
Measurements: &models.MeasurementsInput{
CupSize: &cupSize,
BandSize: &bandSize,
Waist: &bandSize,
Hip: &bandSize,
},
Tattoos: []*models.BodyModificationInput{
&models.BodyModificationInput{
Location: "Inner thigh",
Description: &tattooDesc,
},
},
Piercings: []*models.BodyModificationInput{
&models.BodyModificationInput{
Location: "Nose",
Description: nil,
},
},
}
createdPerformer, err := s.createTestPerformer(input)
if err != nil {
return
}
performerID := strconv.FormatInt(createdPerformer.ID, 10)
updatedName := s.generatePerformerName()
updateInput := models.PerformerUpdateInput{
ID: performerID,
Name: &updatedName,
}
// need some mocking of the context to make the field ignore behaviour work
ctx := s.updateContext([]string{
"name",
})
updatedPerformer, err := s.resolver.Mutation().PerformerUpdate(ctx, updateInput)
if err != nil {
s.t.Errorf("Error updating performer: %s", err.Error())
return
}
input.Name = updatedName
s.verifyCreatedPerformer(*input, updatedPerformer)
}
func (s *performerTestRunner) verifyUpdatedPerformer(input models.PerformerUpdateInput, performer *models.Performer) {
// ensure basic attributes are set correctly
if input.Name != nil && *input.Name != performer.Name {
s.fieldMismatch(input.Name, performer.Name, "Name")
}
r := s.resolver.Performer()
if v, _ := r.Aliases(s.ctx, performer); !reflect.DeepEqual(v, input.Aliases) {
s.fieldMismatch(input.Aliases, v, "Aliases")
}
// ensure urls were set correctly
urls, _ := s.resolver.Performer().Urls(s.ctx, performer)
if !compareUrls(input.Urls, urls) {
s.fieldMismatch(input.Urls, urls, "Urls")
}
birthdate, _ := r.Birthdate(s.ctx, performer)
if birthdate != nil && (birthdate.Date != input.Birthdate.Date || birthdate.Accuracy != input.Birthdate.Accuracy) {
s.fieldMismatch(input.Birthdate, birthdate, "Birthdate")
}
measurements, _ := r.Measurements(s.ctx, performer)
if input.Measurements != nil && (*measurements.CupSize != *input.Measurements.CupSize ||
*measurements.BandSize != *input.Measurements.BandSize ||
*measurements.Waist != *input.Measurements.Waist ||
*measurements.Hip != *input.Measurements.Hip) {
s.fieldMismatch(input.Measurements, measurements, "Measurements")
}
tattoos, _ := s.resolver.Performer().Tattoos(s.ctx, performer)
if !compareBodyMods(input.Tattoos, tattoos) {
s.fieldMismatch(input.Tattoos, tattoos, "Tattoos")
}
piercings, _ := s.resolver.Performer().Piercings(s.ctx, performer)
if !compareBodyMods(input.Piercings, piercings) {
s.fieldMismatch(input.Piercings, piercings, "Piercings")
}
}
func (s *performerTestRunner) testDestroyPerformer() {
createdPerformer, err := s.createTestPerformer(nil)
if err != nil {
return
}
performerID := strconv.FormatInt(createdPerformer.ID, 10)
destroyed, err := s.resolver.Mutation().PerformerDestroy(s.ctx, models.PerformerDestroyInput{
ID: performerID,
})
if err != nil {
s.t.Errorf("Error destroying performer: %s", err.Error())
return
}
if !destroyed {
s.t.Error("Performer was not destroyed")
return
}
// ensure cannot find performer
foundPerformer, err := s.resolver.Query().FindPerformer(s.ctx, performerID)
if err != nil {
s.t.Errorf("Error finding performer after destroying: %s", err.Error())
return
}
if foundPerformer != nil {
s.t.Error("Found performer after destruction")
}
// TODO - ensure scene was not removed
}
func TestCreatePerformer(t *testing.T) {
pt := createPerformerTestRunner(t)
pt.testCreatePerformer()
}
func TestFindPerformer(t *testing.T) {
pt := createPerformerTestRunner(t)
pt.testFindPerformer()
}
func TestUpdatePerformer(t *testing.T) {
pt := createPerformerTestRunner(t)
pt.testUpdatePerformer()
}
func TestUpdatePerformerName(t *testing.T) {
pt := createPerformerTestRunner(t)
pt.testUpdatePerformerName()
}
func TestDestroyPerformer(t *testing.T) {
pt := createPerformerTestRunner(t)
pt.testDestroyPerformer()
}

View File

@@ -16,6 +16,15 @@ func (r *Resolver) Mutation() models.MutationResolver {
func (r *Resolver) Performer() models.PerformerResolver {
return &performerResolver{r}
}
func (r *Resolver) Tag() models.TagResolver {
return &tagResolver{r}
}
func (r *Resolver) Studio() models.StudioResolver {
return &studioResolver{r}
}
func (r *Resolver) Scene() models.SceneResolver {
return &sceneResolver{r}
}
func (r *Resolver) Query() models.QueryResolver {
return &queryResolver{r}
}
@@ -34,6 +43,10 @@ func (r *queryResolver) Version(ctx context.Context) (*models.Version, error) {
func wasFieldIncluded(ctx context.Context, field string) bool {
rctx := graphql.GetRequestContext(ctx)
_, ret := rctx.Variables[field]
return ret
if rctx != nil {
_, ret := rctx.Variables[field]
return ret
}
return false
}

View File

@@ -20,7 +20,7 @@ func (r *performerResolver) Disambiguation(ctx context.Context, obj *models.Perf
}
func (r *performerResolver) Aliases(ctx context.Context, obj *models.Performer) ([]string, error) {
qb := models.NewPerformerQueryBuilder()
qb := models.NewPerformerQueryBuilder(nil)
aliases, err := qb.GetAliases(obj.ID)
if err != nil {
@@ -40,7 +40,7 @@ func (r *performerResolver) Gender(ctx context.Context, obj *models.Performer) (
}
func (r *performerResolver) Urls(ctx context.Context, obj *models.Performer) ([]*models.URL, error) {
qb := models.NewPerformerQueryBuilder()
qb := models.NewPerformerQueryBuilder(nil)
urls, err := qb.GetUrls(obj.ID)
if err != nil {
@@ -140,7 +140,7 @@ func (r *performerResolver) CareerEndYear(ctx context.Context, obj *models.Perfo
}
func (r *performerResolver) Tattoos(ctx context.Context, obj *models.Performer) ([]*models.BodyModification, error) {
qb := models.NewPerformerQueryBuilder()
qb := models.NewPerformerQueryBuilder(nil)
tattoos, err := qb.GetTattoos(obj.ID)
if err != nil {
@@ -157,7 +157,7 @@ func (r *performerResolver) Tattoos(ctx context.Context, obj *models.Performer)
}
func (r *performerResolver) Piercings(ctx context.Context, obj *models.Performer) ([]*models.BodyModification, error) {
qb := models.NewPerformerQueryBuilder()
qb := models.NewPerformerQueryBuilder(nil)
piercings, err := qb.GetPiercings(obj.ID)
if err != nil {

View File

@@ -2,40 +2,76 @@ package api
import (
"context"
"strconv"
"github.com/stashapp/stashdb/pkg/models"
)
type sceneResolver struct{ *Resolver }
func (r *sceneResolver) ID(ctx context.Context, obj *models.Scene) (string, error) {
return strconv.FormatInt(obj.ID, 10), nil
}
func (r *sceneResolver) Title(ctx context.Context, obj *models.Scene) (*string, error) {
panic("not implemented")
return resolveNullString(obj.Title)
}
func (r *sceneResolver) Details(ctx context.Context, obj *models.Scene) (*string, error) {
panic("not implemented")
return resolveNullString(obj.Details)
}
func (r *sceneResolver) URL(ctx context.Context, obj *models.Scene) (*string, error) {
panic("not implemented")
return resolveNullString(obj.URL)
}
func (r *sceneResolver) Date(ctx context.Context, obj *models.Scene) (*string, error) {
panic("not implemented")
return resolveSQLiteDate(obj.Date)
}
func (r *sceneResolver) Studio(ctx context.Context, obj *models.Scene) (*models.Studio, error) {
panic("not implemented")
}
if !obj.StudioID.Valid {
return nil, nil
}
qb := models.NewStudioQueryBuilder(nil)
parent, err := qb.Find(obj.StudioID.Int64)
if err != nil {
return nil, err
}
return parent, nil
}
func (r *sceneResolver) Tags(ctx context.Context, obj *models.Scene) ([]*models.Tag, error) {
panic("not implemented")
qb := models.NewTagQueryBuilder(nil)
return qb.FindBySceneID(obj.ID)
}
func (r *sceneResolver) Performers(ctx context.Context, obj *models.Scene) ([]*models.PerformerAppearance, error) {
pqb := models.NewPerformerQueryBuilder(nil)
sqb := models.NewSceneQueryBuilder(nil)
performersScenes, err := sqb.GetPerformers(obj.ID)
func (r *sceneResolver) Performers(ctx context.Context, obj *models.Scene) ([]*models.Performer, error) {
panic("not implemented")
}
if err != nil {
return nil, err
}
func (r *sceneResolver) Checksums(ctx context.Context, obj *models.Scene) ([]string, error) {
panic("not implemented")
// TODO - probably a better way to do this
var ret []*models.PerformerAppearance
for _, appearance := range performersScenes {
performer, err := pqb.Find(appearance.PerformerID)
if err != nil {
return nil, err
}
as, _ := resolveNullString(appearance.As)
retApp := models.PerformerAppearance{
Performer: performer,
As: as,
}
ret = append(ret, &retApp)
}
return ret, nil
}
func (r *sceneResolver) Fingerprints(ctx context.Context, obj *models.Scene) ([]*models.Fingerprint, error) {
qb := models.NewSceneQueryBuilder(nil)
return qb.GetFingerprints(obj.ID)
}

View File

@@ -2,18 +2,56 @@ package api
import (
"context"
"strconv"
"github.com/stashapp/stashdb/pkg/models"
)
type studioResolver struct{ *Resolver }
func (r *studioResolver) ID(ctx context.Context, obj *models.Studio) (string, error) {
return strconv.FormatInt(obj.ID, 10), nil
}
func (r *studioResolver) Urls(ctx context.Context, obj *models.Studio) ([]*models.URL, error) {
panic("not implemented")
qb := models.NewStudioQueryBuilder(nil)
urls, err := qb.GetUrls(obj.ID)
if err != nil {
return nil, err
}
var ret []*models.URL
for _, url := range urls {
retURL := url.ToURL()
ret = append(ret, &retURL)
}
return ret, nil
}
func (r *studioResolver) Parent(ctx context.Context, obj *models.Studio) (*models.Studio, error) {
panic("not implemented")
if !obj.ParentStudioID.Valid {
return nil, nil
}
qb := models.NewStudioQueryBuilder(nil)
parent, err := qb.Find(obj.ParentStudioID.Int64)
if err != nil {
return nil, err
}
return parent, nil
}
func (r *studioResolver) AddedChildStudios(ctx context.Context, obj *models.Studio) ([]*models.Studio, error) {
panic("not implemented")
func (r *studioResolver) ChildStudios(ctx context.Context, obj *models.Studio) ([]*models.Studio, error) {
qb := models.NewStudioQueryBuilder(nil)
children, err := qb.FindByParentID(obj.ID)
if err != nil {
return nil, err
}
return children, nil
}

View File

@@ -2,15 +2,26 @@ package api
import (
"context"
"strconv"
"github.com/stashapp/stashdb/pkg/models"
)
type tagResolver struct{ *Resolver }
func (r *tagResolver) ID(ctx context.Context, obj *models.Tag) (string, error) {
return strconv.FormatInt(obj.ID, 10), nil
}
func (r *tagResolver) Description(ctx context.Context, obj *models.Tag) (*string, error) {
panic("not implemented")
return resolveNullString(obj.Description)
}
func (r *tagResolver) Aliases(ctx context.Context, obj *models.Tag) ([]string, error) {
panic("not implemented")
qb := models.NewTagQueryBuilder(nil)
aliases, err := qb.GetAliases(obj.ID)
if err != nil {
return nil, err
}
return aliases, nil
}

View File

@@ -2,6 +2,7 @@ package api
import (
"context"
"fmt"
"strconv"
"time"
@@ -27,12 +28,15 @@ func (r *mutationResolver) PerformerCreate(ctx context.Context, input models.Per
UpdatedAt: models.SQLiteTimestamp{Timestamp: currentTime},
}
newPerformer.CopyFromCreateInput(input)
err = newPerformer.CopyFromCreateInput(input)
if err != nil {
return nil, err
}
// Start the transaction and save the performer
tx := database.DB.MustBeginTx(ctx, nil)
qb := models.NewPerformerQueryBuilder()
performer, err := qb.Create(newPerformer, tx)
qb := models.NewPerformerQueryBuilder(tx)
performer, err := qb.Create(newPerformer)
if err != nil {
_ = tx.Rollback()
return nil, err
@@ -40,28 +44,28 @@ func (r *mutationResolver) PerformerCreate(ctx context.Context, input models.Per
// Save the aliases
performerAliases := models.CreatePerformerAliases(performer.ID, input.Aliases)
if err := qb.CreateAliases(performerAliases, tx); err != nil {
if err := qb.CreateAliases(performerAliases); err != nil {
_ = tx.Rollback()
return nil, err
}
// Save the URLs
performerUrls := models.CreatePerformerUrls(performer.ID, input.Urls)
if err := qb.CreateUrls(performerUrls, tx); err != nil {
if err := qb.CreateUrls(performerUrls); err != nil {
_ = tx.Rollback()
return nil, err
}
// Save the Tattoos
performerTattoos := models.CreatePerformerBodyMods(performer.ID, input.Tattoos)
if err := qb.CreateTattoos(performerTattoos, tx); err != nil {
if err := qb.CreateTattoos(performerTattoos); err != nil {
_ = tx.Rollback()
return nil, err
}
// Save the Piercings
performerPiercings := models.CreatePerformerBodyMods(performer.ID, input.Piercings)
if err := qb.CreatePiercings(performerPiercings, tx); err != nil {
if err := qb.CreatePiercings(performerPiercings); err != nil {
_ = tx.Rollback()
return nil, err
}
@@ -79,56 +83,74 @@ func (r *mutationResolver) PerformerUpdate(ctx context.Context, input models.Per
return nil, err
}
qb := models.NewPerformerQueryBuilder()
tx := database.DB.MustBeginTx(ctx, nil)
qb := models.NewPerformerQueryBuilder(tx)
// get the existing performer and modify it
performerID, _ := strconv.Atoi(input.ID)
performerID, _ := strconv.ParseInt(input.ID, 10, 64)
updatedPerformer, err := qb.Find(performerID)
if err != nil {
return nil, err
}
if updatedPerformer == nil {
return nil, fmt.Errorf("Performer with id %d cannot be found", performerID)
}
updatedPerformer.UpdatedAt = models.SQLiteTimestamp{Timestamp: time.Now()}
// Start the transaction and save the performer
tx := database.DB.MustBeginTx(ctx, nil)
// Populate performer from the input
updatedPerformer.CopyFromUpdateInput(input)
err = updatedPerformer.CopyFromUpdateInput(input)
if err != nil {
_ = tx.Rollback()
return nil, err
}
performer, err := qb.Update(*updatedPerformer, tx)
performer, err := qb.Update(*updatedPerformer)
if err != nil {
_ = tx.Rollback()
return nil, err
}
// Save the aliases
performerAliases := models.CreatePerformerAliases(performer.ID, input.Aliases)
if err := qb.UpdateAliases(performer.ID, performerAliases, tx); err != nil {
_ = tx.Rollback()
return nil, err
// only do this if provided
if wasFieldIncluded(ctx, "aliases") {
performerAliases := models.CreatePerformerAliases(performer.ID, input.Aliases)
if err := qb.UpdateAliases(performer.ID, performerAliases); err != nil {
_ = tx.Rollback()
return nil, err
}
}
// Save the URLs
performerUrls := models.CreatePerformerUrls(performer.ID, input.Urls)
if err := qb.UpdateUrls(performer.ID, performerUrls, tx); err != nil {
_ = tx.Rollback()
return nil, err
// only do this if provided
if wasFieldIncluded(ctx, "urls") {
performerUrls := models.CreatePerformerUrls(performer.ID, input.Urls)
if err := qb.UpdateUrls(performer.ID, performerUrls); err != nil {
_ = tx.Rollback()
return nil, err
}
}
// Save the Tattoos
performerTattoos := models.CreatePerformerBodyMods(performer.ID, input.Tattoos)
if err := qb.UpdateTattoos(performer.ID, performerTattoos, tx); err != nil {
_ = tx.Rollback()
return nil, err
// only do this if provided
if wasFieldIncluded(ctx, "tattoos") {
performerTattoos := models.CreatePerformerBodyMods(performer.ID, input.Tattoos)
if err := qb.UpdateTattoos(performer.ID, performerTattoos); err != nil {
_ = tx.Rollback()
return nil, err
}
}
// Save the Piercings
performerPiercings := models.CreatePerformerBodyMods(performer.ID, input.Piercings)
if err := qb.UpdatePiercings(performer.ID, performerPiercings, tx); err != nil {
_ = tx.Rollback()
return nil, err
// only do this if provided
if wasFieldIncluded(ctx, "piercings") {
performerPiercings := models.CreatePerformerBodyMods(performer.ID, input.Piercings)
if err := qb.UpdatePiercings(performer.ID, performerPiercings); err != nil {
_ = tx.Rollback()
return nil, err
}
}
// Commit
@@ -144,8 +166,8 @@ func (r *mutationResolver) PerformerDestroy(ctx context.Context, input models.Pe
return false, err
}
qb := models.NewPerformerQueryBuilder()
tx := database.DB.MustBeginTx(ctx, nil)
qb := models.NewPerformerQueryBuilder(tx)
// references have on delete cascade, so shouldn't be necessary
// to remove them explicitly
@@ -154,7 +176,7 @@ func (r *mutationResolver) PerformerDestroy(ctx context.Context, input models.Pe
if err != nil {
return false, err
}
if err = qb.Destroy(performerID, tx); err != nil {
if err = qb.Destroy(performerID); err != nil {
_ = tx.Rollback()
return false, err
}

View File

@@ -2,7 +2,10 @@ package api
import (
"context"
"strconv"
"time"
"github.com/stashapp/stashdb/pkg/database"
"github.com/stashapp/stashdb/pkg/models"
)
@@ -11,11 +14,124 @@ func (r *mutationResolver) SceneCreate(ctx context.Context, input models.SceneCr
return nil, err
}
return nil, nil
var err error
if err != nil {
return nil, err
}
// Populate a new scene from the input
currentTime := time.Now()
newScene := models.Scene{
CreatedAt: models.SQLiteTimestamp{Timestamp: currentTime},
UpdatedAt: models.SQLiteTimestamp{Timestamp: currentTime},
}
newScene.CopyFromCreateInput(input)
// Start the transaction and save the scene
tx := database.DB.MustBeginTx(ctx, nil)
qb := models.NewSceneQueryBuilder(tx)
scene, err := qb.Create(newScene)
if err != nil {
_ = tx.Rollback()
return nil, err
}
// Save the checksums
sceneFingerprints := models.CreateSceneFingerprints(scene.ID, input.Fingerprints)
if err := qb.CreateFingerprints(sceneFingerprints); err != nil {
_ = tx.Rollback()
return nil, err
}
// save the performers
scenePerformers := models.CreateScenePerformers(scene.ID, input.Performers)
jqb := models.NewJoinsQueryBuilder(tx)
if err := jqb.CreatePerformersScenes(scenePerformers); err != nil {
_ = tx.Rollback()
return nil, err
}
// Save the tags
tagJoins := models.CreateSceneTags(scene.ID, input.TagIds)
if err := jqb.CreateScenesTags(tagJoins); err != nil {
return nil, err
}
// Commit
if err := tx.Commit(); err != nil {
return nil, err
}
return scene, nil
}
func (r *mutationResolver) SceneUpdate(ctx context.Context, input models.SceneUpdateInput) (*models.Scene, error) {
return nil, nil
if err := validateModify(ctx); err != nil {
return nil, err
}
tx := database.DB.MustBeginTx(ctx, nil)
qb := models.NewSceneQueryBuilder(tx)
// get the existing scene and modify it
sceneID, _ := strconv.ParseInt(input.ID, 10, 64)
updatedScene, err := qb.Find(sceneID)
if err != nil {
return nil, err
}
updatedScene.UpdatedAt = models.SQLiteTimestamp{Timestamp: time.Now()}
// Populate scene from the input
updatedScene.CopyFromUpdateInput(input)
scene, err := qb.Update(*updatedScene)
if err != nil {
_ = tx.Rollback()
return nil, err
}
// Save the checksums
// only do this if provided
if wasFieldIncluded(ctx, "fingerprints") {
sceneFingerprints := models.CreateSceneFingerprints(scene.ID, input.Fingerprints)
if err := qb.UpdateFingerprints(scene.ID, sceneFingerprints); err != nil {
_ = tx.Rollback()
return nil, err
}
}
jqb := models.NewJoinsQueryBuilder(tx)
// only do this if provided
if wasFieldIncluded(ctx, "performers") {
scenePerformers := models.CreateScenePerformers(scene.ID, input.Performers)
if err := jqb.UpdatePerformersScenes(scene.ID, scenePerformers); err != nil {
_ = tx.Rollback()
return nil, err
}
}
// Save the tags
// only do this if provided
if wasFieldIncluded(ctx, "tagIds") {
tagJoins := models.CreateSceneTags(scene.ID, input.TagIds)
if err := jqb.UpdateScenesTags(scene.ID, tagJoins); err != nil {
return nil, err
}
}
// Commit
if err := tx.Commit(); err != nil {
return nil, err
}
return scene, nil
}
func (r *mutationResolver) SceneDestroy(ctx context.Context, input models.SceneDestroyInput) (bool, error) {
@@ -23,5 +139,23 @@ func (r *mutationResolver) SceneDestroy(ctx context.Context, input models.SceneD
return false, err
}
tx := database.DB.MustBeginTx(ctx, nil)
qb := models.NewSceneQueryBuilder(tx)
// references have on delete cascade, so shouldn't be necessary
// to remove them explicitly
sceneID, err := strconv.ParseInt(input.ID, 10, 64)
if err != nil {
return false, err
}
if err = qb.Destroy(sceneID); err != nil {
_ = tx.Rollback()
return false, err
}
if err := tx.Commit(); err != nil {
return false, err
}
return true, nil
}

View File

@@ -2,7 +2,10 @@ package api
import (
"context"
"strconv"
"time"
"github.com/stashapp/stashdb/pkg/database"
"github.com/stashapp/stashdb/pkg/models"
)
@@ -11,7 +14,45 @@ func (r *mutationResolver) StudioCreate(ctx context.Context, input models.Studio
return nil, err
}
return nil, nil
var err error
if err != nil {
return nil, err
}
// Populate a new studio from the input
currentTime := time.Now()
newStudio := models.Studio{
CreatedAt: models.SQLiteTimestamp{Timestamp: currentTime},
UpdatedAt: models.SQLiteTimestamp{Timestamp: currentTime},
}
newStudio.CopyFromCreateInput(input)
// Start the transaction and save the studio
tx := database.DB.MustBeginTx(ctx, nil)
qb := models.NewStudioQueryBuilder(tx)
studio, err := qb.Create(newStudio)
if err != nil {
_ = tx.Rollback()
return nil, err
}
// TODO - save child studios
// Save the URLs
studioUrls := models.CreateStudioUrls(studio.ID, input.Urls)
if err := qb.CreateUrls(studioUrls); err != nil {
_ = tx.Rollback()
return nil, err
}
// Commit
if err := tx.Commit(); err != nil {
return nil, err
}
return studio, nil
}
func (r *mutationResolver) StudioUpdate(ctx context.Context, input models.StudioUpdateInput) (*models.Studio, error) {
@@ -19,7 +60,44 @@ func (r *mutationResolver) StudioUpdate(ctx context.Context, input models.Studio
return nil, err
}
return nil, nil
tx := database.DB.MustBeginTx(ctx, nil)
qb := models.NewStudioQueryBuilder(tx)
// get the existing studio and modify it
studioID, _ := strconv.ParseInt(input.ID, 10, 64)
updatedStudio, err := qb.Find(studioID)
if err != nil {
return nil, err
}
updatedStudio.UpdatedAt = models.SQLiteTimestamp{Timestamp: time.Now()}
// Populate studio from the input
updatedStudio.CopyFromUpdateInput(input)
studio, err := qb.Update(*updatedStudio)
if err != nil {
_ = tx.Rollback()
return nil, err
}
// Save the URLs
// TODO - only do this if provided
studioUrls := models.CreateStudioUrls(studio.ID, input.Urls)
if err := qb.UpdateUrls(studio.ID, studioUrls); err != nil {
_ = tx.Rollback()
return nil, err
}
// TODO - handle child studios
// Commit
if err := tx.Commit(); err != nil {
return nil, err
}
return studio, nil
}
func (r *mutationResolver) StudioDestroy(ctx context.Context, input models.StudioDestroyInput) (bool, error) {
@@ -27,5 +105,23 @@ func (r *mutationResolver) StudioDestroy(ctx context.Context, input models.Studi
return false, err
}
tx := database.DB.MustBeginTx(ctx, nil)
qb := models.NewStudioQueryBuilder(tx)
// references have on delete cascade, so shouldn't be necessary
// to remove them explicitly
studioID, err := strconv.ParseInt(input.ID, 10, 64)
if err != nil {
return false, err
}
if err = qb.Destroy(studioID); err != nil {
_ = tx.Rollback()
return false, err
}
if err := tx.Commit(); err != nil {
return false, err
}
return true, nil
}

View File

@@ -2,7 +2,10 @@ package api
import (
"context"
"strconv"
"time"
"github.com/stashapp/stashdb/pkg/database"
"github.com/stashapp/stashdb/pkg/models"
)
@@ -11,7 +14,43 @@ func (r *mutationResolver) TagCreate(ctx context.Context, input models.TagCreate
return nil, err
}
return nil, nil
var err error
if err != nil {
return nil, err
}
// Populate a new performer from the input
currentTime := time.Now()
newTag := models.Tag{
CreatedAt: models.SQLiteTimestamp{Timestamp: currentTime},
UpdatedAt: models.SQLiteTimestamp{Timestamp: currentTime},
}
newTag.CopyFromCreateInput(input)
// Start the transaction and save the performer
tx := database.DB.MustBeginTx(ctx, nil)
qb := models.NewTagQueryBuilder(tx)
tag, err := qb.Create(newTag)
if err != nil {
_ = tx.Rollback()
return nil, err
}
// Save the aliases
tagAliases := models.CreateTagAliases(tag.ID, input.Aliases)
if err := qb.CreateAliases(tagAliases); err != nil {
_ = tx.Rollback()
return nil, err
}
// Commit
if err := tx.Commit(); err != nil {
return nil, err
}
return tag, nil
}
func (r *mutationResolver) TagUpdate(ctx context.Context, input models.TagUpdateInput) (*models.Tag, error) {
@@ -19,7 +58,42 @@ func (r *mutationResolver) TagUpdate(ctx context.Context, input models.TagUpdate
return nil, err
}
return nil, nil
tx := database.DB.MustBeginTx(ctx, nil)
qb := models.NewTagQueryBuilder(tx)
// get the existing tag and modify it
tagID, _ := strconv.ParseInt(input.ID, 10, 64)
updatedTag, err := qb.Find(tagID)
if err != nil {
return nil, err
}
updatedTag.UpdatedAt = models.SQLiteTimestamp{Timestamp: time.Now()}
// Populate performer from the input
updatedTag.CopyFromUpdateInput(input)
tag, err := qb.Update(*updatedTag)
if err != nil {
_ = tx.Rollback()
return nil, err
}
// Save the aliases
// TODO - only do this if provided
tagAliases := models.CreateTagAliases(tag.ID, input.Aliases)
if err := qb.UpdateAliases(tag.ID, tagAliases); err != nil {
_ = tx.Rollback()
return nil, err
}
// Commit
if err := tx.Commit(); err != nil {
return nil, err
}
return tag, nil
}
func (r *mutationResolver) TagDestroy(ctx context.Context, input models.TagDestroyInput) (bool, error) {
@@ -27,5 +101,23 @@ func (r *mutationResolver) TagDestroy(ctx context.Context, input models.TagDestr
return false, err
}
tx := database.DB.MustBeginTx(ctx, nil)
qb := models.NewTagQueryBuilder(tx)
// references have on delete cascade, so shouldn't be necessary
// to remove them explicitly
tagID, err := strconv.ParseInt(input.ID, 10, 64)
if err != nil {
return false, err
}
if err = qb.Destroy(tagID); err != nil {
_ = tx.Rollback()
return false, err
}
if err := tx.Commit(); err != nil {
return false, err
}
return true, nil
}

View File

@@ -12,9 +12,9 @@ func (r *queryResolver) FindPerformer(ctx context.Context, id string) (*models.P
return nil, err
}
qb := models.NewPerformerQueryBuilder()
qb := models.NewPerformerQueryBuilder(nil)
idInt, _ := strconv.Atoi(id)
idInt, _ := strconv.ParseInt(id, 10, 64)
return qb.Find(idInt)
}
func (r *queryResolver) QueryPerformers(ctx context.Context, performerFilter *models.PerformerFilterType, filter *models.QuerySpec) (*models.QueryPerformersResultType, error) {
@@ -22,7 +22,7 @@ func (r *queryResolver) QueryPerformers(ctx context.Context, performerFilter *mo
return nil, err
}
qb := models.NewPerformerQueryBuilder()
qb := models.NewPerformerQueryBuilder(nil)
performers, count := qb.Query(performerFilter, filter)
return &models.QueryPerformersResultType{

View File

@@ -2,38 +2,42 @@ package api
import (
"context"
"strconv"
"github.com/stashapp/stashdb/pkg/models"
)
func (r *queryResolver) FindScene(ctx context.Context, id *string, checksum *string) (*models.Scene, error) {
panic("not implemented")
func (r *queryResolver) FindScene(ctx context.Context, id string) (*models.Scene, error) {
if err := validateRead(ctx); err != nil {
return nil, err
}
qb := models.NewSceneQueryBuilder(nil)
idInt, _ := strconv.ParseInt(id, 10, 64)
return qb.Find(idInt)
}
func (r *queryResolver) FindSceneByFingerprint(ctx context.Context, fingerprint models.FingerprintInput) ([]*models.Scene, error) {
if err := validateRead(ctx); err != nil {
return nil, err
}
qb := models.NewSceneQueryBuilder(nil)
return qb.FindByFingerprint(fingerprint.Algorithm, fingerprint.Hash)
}
func (r *queryResolver) QueryScenes(ctx context.Context, sceneFilter *models.SceneFilterType, filter *models.QuerySpec) (*models.QueryScenesResultType, error) {
panic("not implemented")
if err := validateRead(ctx); err != nil {
return nil, err
}
qb := models.NewSceneQueryBuilder(nil)
scenes, count := qb.Query(sceneFilter, filter)
return &models.QueryScenesResultType{
Scenes: scenes,
Count: count,
}, nil
}
// func (r *queryResolver) FindScene(ctx context.Context, id string) (*models.Scene, error) {
// if err := validateRead(ctx); err != nil {
// return nil, err
// }
// qb := models.NewSceneQueryBuilder()
// idInt, _ := strconv.Atoi(id)
// var scene *models.Scene
// var err error
// scene, err = qb.Find(idInt)
// return scene, err
// }
// func (r *queryResolver) FindSceneByChecksum(ctx context.Context, checksum string) (*models.Scene, error) {
// if err := validateRead(ctx); err != nil {
// return nil, err
// }
// qb := models.NewSceneQueryBuilder()
// var scene *models.Scene
// var err error
// scene, err = qb.FindByChecksum(checksum)
// return scene, err
// }

View File

@@ -2,14 +2,38 @@ package api
import (
"context"
"strconv"
"github.com/stashapp/stashdb/pkg/models"
)
func (r *queryResolver) FindStudio(ctx context.Context, id *string, name *string) (*models.Studio, error) {
panic("not implemented")
if err := validateRead(ctx); err != nil {
return nil, err
}
qb := models.NewStudioQueryBuilder(nil)
if id != nil {
idInt, _ := strconv.ParseInt(*id, 10, 64)
return qb.Find(idInt)
} else if name != nil {
return qb.FindByName(*name)
}
return nil, nil
}
func (r *queryResolver) QueryStudios(ctx context.Context, studioFilter *models.StudioFilterType, filter *models.QuerySpec) (*models.QueryStudiosResultType, error) {
panic("not implemented")
if err := validateRead(ctx); err != nil {
return nil, err
}
qb := models.NewStudioQueryBuilder(nil)
studios, count := qb.Query(studioFilter, filter)
return &models.QueryStudiosResultType{
Studios: studios,
Count: count,
}, nil
}

View File

@@ -2,14 +2,38 @@ package api
import (
"context"
"strconv"
"github.com/stashapp/stashdb/pkg/models"
)
func (r *queryResolver) FindTag(ctx context.Context, id *string, name *string) (*models.Tag, error) {
panic("not implemented")
if err := validateRead(ctx); err != nil {
return nil, err
}
qb := models.NewTagQueryBuilder(nil)
if id != nil {
idInt, _ := strconv.ParseInt(*id, 10, 64)
return qb.Find(idInt)
} else if name != nil {
return qb.FindByNameOrAlias(*name)
}
return nil, nil
}
func (r *queryResolver) QueryTags(ctx context.Context, tagFilter *models.TagFilterType, filter *models.QuerySpec) (*models.QueryTagsResultType, error) {
panic("not implemented")
if err := validateRead(ctx); err != nil {
return nil, err
}
qb := models.NewTagQueryBuilder(nil)
tags, count := qb.Query(tagFilter, filter)
return &models.QueryTagsResultType{
Tags: tags,
Count: count,
}, nil
}

View File

@@ -0,0 +1,53 @@
package api
import (
"context"
"net/http"
"strconv"
"github.com/go-chi/chi"
"github.com/stashapp/stashdb/pkg/models"
)
type performerRoutes struct{}
func (rs performerRoutes) Routes() chi.Router {
r := chi.NewRouter()
r.Route("/{performerId}", func(r chi.Router) {
r.Use(PerformerCtx)
r.Get("/image", rs.Image)
})
return r
}
func (rs performerRoutes) Image(w http.ResponseWriter, r *http.Request) {
if err := validateRead(r.Context()); err != nil {
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
return
}
performer := r.Context().Value(performerKey).(*models.Performer)
_, _ = w.Write(performer.Image)
}
func PerformerCtx(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
performerID, err := strconv.ParseInt(chi.URLParam(r, "performerId"), 10, 64)
if err != nil {
http.Error(w, http.StatusText(404), 404)
return
}
qb := models.NewPerformerQueryBuilder(nil)
performer, err := qb.Find(performerID)
if err != nil {
http.Error(w, http.StatusText(404), 404)
return
}
ctx := context.WithValue(r.Context(), performerKey, performer)
next.ServeHTTP(w, r.WithContext(ctx))
})
}

View File

@@ -0,0 +1,486 @@
// +build integration
package api_test
import (
"reflect"
"strconv"
"testing"
"github.com/stashapp/stashdb/pkg/models"
_ "github.com/golang-migrate/migrate/v4/database/sqlite3"
)
type sceneTestRunner struct {
testRunner
}
func createSceneTestRunner(t *testing.T) *sceneTestRunner {
return &sceneTestRunner{
testRunner: *createTestRunner(t),
}
}
func (s *sceneTestRunner) testCreateScene() {
title := "Title"
details := "Details"
url := "URL"
date := "2003-02-01"
performer, _ := s.createTestPerformer(nil)
studio, _ := s.createTestScene(nil)
tag, _ := s.createTestTag(nil)
performerID := strconv.FormatInt(performer.ID, 10)
studioID := strconv.FormatInt(studio.ID, 10)
tagID := strconv.FormatInt(tag.ID, 10)
performerAlias := "alias"
input := models.SceneCreateInput{
Title: &title,
Details: &details,
URL: &url,
Date: &date,
Fingerprints: []*models.FingerprintInput{
s.generateSceneFingerprint(),
s.generateSceneFingerprint(),
},
StudioID: &studioID,
Performers: []*models.PerformerAppearanceInput{
&models.PerformerAppearanceInput{
PerformerID: performerID,
As: &performerAlias,
},
},
TagIds: []string{
tagID,
},
}
scene, err := s.resolver.Mutation().SceneCreate(s.ctx, input)
if err != nil {
s.t.Errorf("Error creating scene: %s", err.Error())
return
}
s.verifyCreatedScene(input, scene)
}
func comparePerformers(input []*models.PerformerAppearanceInput, performers []*models.PerformerAppearance) bool {
if len(performers) != len(input) {
return false
}
for i, v := range performers {
performerID := strconv.FormatInt(v.Performer.ID, 10)
if performerID != input[i].PerformerID {
return false
}
if v.As != input[i].As {
if v.As == nil || input[i].As == nil {
return false
}
if *v.As != *input[i].As {
return false
}
}
}
return true
}
func compareTags(tagIDs []string, tags []*models.Tag) bool {
if len(tags) != len(tagIDs) {
return false
}
for i, v := range tags {
tagID := strconv.FormatInt(v.ID, 10)
if tagID != tagIDs[i] {
return false
}
}
return true
}
func compareFingerprints(input []*models.FingerprintInput, fingerprints []*models.Fingerprint) bool {
if len(input) != len(fingerprints) {
return false
}
for i, v := range fingerprints {
if input[i].Algorithm != v.Algorithm || input[i].Hash != v.Hash {
return false
}
}
return true
}
func (s *sceneTestRunner) verifyCreatedScene(input models.SceneCreateInput, scene *models.Scene) {
// ensure basic attributes are set correctly
r := s.resolver.Scene()
id, _ := r.ID(s.ctx, scene)
if id == "" {
s.t.Errorf("Expected created scene id to be non-zero")
}
if v, _ := r.Title(s.ctx, scene); !reflect.DeepEqual(v, input.Title) {
s.fieldMismatch(*input.Title, v, "Title")
}
if v, _ := r.Details(s.ctx, scene); !reflect.DeepEqual(v, input.Details) {
s.fieldMismatch(input.Details, v, "Details")
}
if v, _ := r.URL(s.ctx, scene); !reflect.DeepEqual(v, input.URL) {
s.fieldMismatch(*input.URL, v, "URL")
}
if v, _ := r.Date(s.ctx, scene); !reflect.DeepEqual(v, input.Date) {
s.fieldMismatch(*input.Date, v, "Date")
}
if v, _ := r.Fingerprints(s.ctx, scene); !compareFingerprints(input.Fingerprints, v) {
s.fieldMismatch(input.Fingerprints, v, "Fingerprints")
}
performers, err := s.resolver.Scene().Performers(s.ctx, scene)
if err != nil {
s.t.Errorf("Error getting scene performers: %s", err.Error())
}
if !comparePerformers(input.Performers, performers) {
s.fieldMismatch(input.Performers, performers, "Performers")
}
tags, err := s.resolver.Scene().Tags(s.ctx, scene)
if err != nil {
s.t.Errorf("Error getting scene tags: %s", err.Error())
}
if !compareTags(input.TagIds, tags) {
s.fieldMismatch(input.TagIds, tags, "Tags")
}
}
func (s *sceneTestRunner) testFindSceneById() {
createdScene, err := s.createTestScene(nil)
if err != nil {
return
}
sceneID := strconv.FormatInt(createdScene.ID, 10)
scene, err := s.resolver.Query().FindScene(s.ctx, sceneID)
if err != nil {
s.t.Errorf("Error finding scene: %s", err.Error())
return
}
// ensure returned scene is not nil
if scene == nil {
s.t.Error("Did not find scene by id")
return
}
// ensure values were set
if createdScene.Title != scene.Title {
s.fieldMismatch(createdScene.Title, scene.Title, "Title")
}
}
func (s *sceneTestRunner) testFindSceneByFingerprint() {
createdScene, err := s.createTestScene(nil)
if err != nil {
return
}
fingerprints, err := s.resolver.Scene().Fingerprints(s.ctx, createdScene)
fingerprint := models.FingerprintInput{
Algorithm: fingerprints[0].Algorithm,
Hash: fingerprints[0].Hash,
}
scenes, err := s.resolver.Query().FindSceneByFingerprint(s.ctx, fingerprint)
if err != nil {
s.t.Errorf("Error finding scene: %s", err.Error())
return
}
// ensure returned scene is not nil
if len(scenes) == 0 {
s.t.Error("Did not find scene by fingerprint")
return
}
// ensure values were set
if createdScene.Title != scenes[0].Title {
s.fieldMismatch(createdScene.Title, scenes[0].Title, "Title")
}
}
func (s *sceneTestRunner) testUpdateScene() {
title := "Title"
details := "Details"
url := "URL"
date := "2003-02-01"
performer, _ := s.createTestPerformer(nil)
studio, _ := s.createTestScene(nil)
tag, _ := s.createTestTag(nil)
performerID := strconv.FormatInt(performer.ID, 10)
studioID := strconv.FormatInt(studio.ID, 10)
tagID := strconv.FormatInt(tag.ID, 10)
performerAlias := "alias"
input := models.SceneCreateInput{
Title: &title,
Details: &details,
URL: &url,
Date: &date,
Fingerprints: []*models.FingerprintInput{
s.generateSceneFingerprint(),
s.generateSceneFingerprint(),
},
StudioID: &studioID,
Performers: []*models.PerformerAppearanceInput{
&models.PerformerAppearanceInput{
PerformerID: performerID,
As: &performerAlias,
},
},
TagIds: []string{
tagID,
},
}
createdScene, err := s.createTestScene(&input)
if err != nil {
return
}
sceneID := strconv.FormatInt(createdScene.ID, 10)
newTitle := "NewTitle"
newDetails := "NewDetails"
newURL := "NewURL"
newDate := "2001-02-03"
performer, _ = s.createTestPerformer(nil)
studio, _ = s.createTestScene(nil)
tag, _ = s.createTestTag(nil)
performerID = strconv.FormatInt(performer.ID, 10)
studioID = strconv.FormatInt(studio.ID, 10)
tagID = strconv.FormatInt(tag.ID, 10)
performerAlias = "updatedAlias"
updateInput := models.SceneUpdateInput{
ID: sceneID,
Title: &newTitle,
Details: &newDetails,
URL: &newURL,
Date: &newDate,
Fingerprints: []*models.FingerprintInput{
s.generateSceneFingerprint(),
},
Performers: []*models.PerformerAppearanceInput{
&models.PerformerAppearanceInput{
PerformerID: performerID,
As: &performerAlias,
},
},
StudioID: &studioID,
TagIds: []string{
tagID,
},
}
// need some mocking of the context to make the field ignore behaviour work
ctx := s.updateContext([]string{
"fingerprints",
"performers",
"tagIds",
})
updatedScene, err := s.resolver.Mutation().SceneUpdate(ctx, updateInput)
if err != nil {
s.t.Errorf("Error updating scene: %s", err.Error())
return
}
s.verifyUpdatedScene(updateInput, updatedScene)
}
func (s *sceneTestRunner) testUpdateSceneTitle() {
title := "Title"
details := "Details"
url := "URL"
date := "2003-02-01"
performer, _ := s.createTestPerformer(nil)
studio, _ := s.createTestScene(nil)
tag, _ := s.createTestTag(nil)
performerID := strconv.FormatInt(performer.ID, 10)
studioID := strconv.FormatInt(studio.ID, 10)
tagID := strconv.FormatInt(tag.ID, 10)
performerAlias := "alias"
input := models.SceneCreateInput{
Title: &title,
Details: &details,
URL: &url,
Date: &date,
Fingerprints: []*models.FingerprintInput{
s.generateSceneFingerprint(),
s.generateSceneFingerprint(),
},
Performers: []*models.PerformerAppearanceInput{
&models.PerformerAppearanceInput{
PerformerID: performerID,
As: &performerAlias,
},
},
StudioID: &studioID,
TagIds: []string{
tagID,
},
}
createdScene, err := s.createTestScene(&input)
if err != nil {
return
}
sceneID := strconv.FormatInt(createdScene.ID, 10)
newTitle := "NewTitle"
updateInput := models.SceneUpdateInput{
ID: sceneID,
Title: &newTitle,
}
// need some mocking of the context to make the field ignore behaviour work
ctx := s.updateContext([]string{
"title",
})
updatedScene, err := s.resolver.Mutation().SceneUpdate(ctx, updateInput)
if err != nil {
s.t.Errorf("Error updating scene: %s", err.Error())
return
}
input.Title = &newTitle
s.verifyCreatedScene(input, updatedScene)
}
func (s *sceneTestRunner) verifyUpdatedScene(input models.SceneUpdateInput, scene *models.Scene) {
// ensure basic attributes are set correctly
r := s.resolver.Scene()
if v, _ := r.Title(s.ctx, scene); !reflect.DeepEqual(v, input.Title) {
s.fieldMismatch(input.Title, v, "Title")
}
if v, _ := r.Details(s.ctx, scene); !reflect.DeepEqual(v, input.Details) {
s.fieldMismatch(input.Details, v, "Details")
}
if v, _ := r.URL(s.ctx, scene); !reflect.DeepEqual(v, input.URL) {
s.fieldMismatch(input.URL, v, "URL")
}
if v, _ := r.Date(s.ctx, scene); !reflect.DeepEqual(v, input.Date) {
s.fieldMismatch(input.Date, v, "Date")
}
if v, _ := r.Fingerprints(s.ctx, scene); !compareFingerprints(input.Fingerprints, v) {
s.fieldMismatch(input.Fingerprints, v, "Fingerprints")
}
performers, _ := s.resolver.Scene().Performers(s.ctx, scene)
if !comparePerformers(input.Performers, performers) {
s.fieldMismatch(input.Performers, performers, "Performers")
}
tags, _ := s.resolver.Scene().Tags(s.ctx, scene)
if !compareTags(input.TagIds, tags) {
s.fieldMismatch(input.TagIds, tags, "Tags")
}
}
func (s *sceneTestRunner) testDestroyScene() {
createdScene, err := s.createTestScene(nil)
if err != nil {
return
}
sceneID := strconv.FormatInt(createdScene.ID, 10)
destroyed, err := s.resolver.Mutation().SceneDestroy(s.ctx, models.SceneDestroyInput{
ID: sceneID,
})
if err != nil {
s.t.Errorf("Error destroying scene: %s", err.Error())
return
}
if !destroyed {
s.t.Error("Scene was not destroyed")
return
}
// ensure cannot find scene
foundScene, err := s.resolver.Query().FindScene(s.ctx, sceneID)
if err != nil {
s.t.Errorf("Error finding scene after destroying: %s", err.Error())
return
}
if foundScene != nil {
s.t.Error("Found scene after destruction")
}
// TODO - ensure scene was not removed
}
func TestCreateScene(t *testing.T) {
pt := createSceneTestRunner(t)
pt.testCreateScene()
}
func TestFindSceneById(t *testing.T) {
pt := createSceneTestRunner(t)
pt.testFindSceneById()
}
func TestFindSceneByFingerprint(t *testing.T) {
pt := createSceneTestRunner(t)
pt.testFindSceneByFingerprint()
}
func TestUpdateScene(t *testing.T) {
pt := createSceneTestRunner(t)
pt.testUpdateScene()
}
func TestUpdateSceneTitle(t *testing.T) {
pt := createSceneTestRunner(t)
pt.testUpdateSceneTitle()
}
func TestDestroyScene(t *testing.T) {
pt := createSceneTestRunner(t)
pt.testDestroyScene()
}

View File

@@ -110,6 +110,8 @@ func Start() {
// TODO - this should be disabled in production
r.Handle("/playground", handler.Playground("GraphQL playground", "/graphql"))
r.Mount("/performer", performerRoutes{}.Routes())
address := config.GetHost() + ":" + strconv.Itoa(config.GetPort())
if tlsConfig := makeTLSConfig(); tlsConfig != nil {
httpsServer := &http.Server{

View File

@@ -0,0 +1,208 @@
// +build integration
package api_test
import (
"strconv"
"testing"
"github.com/stashapp/stashdb/pkg/models"
_ "github.com/golang-migrate/migrate/v4/database/sqlite3"
)
type studioTestRunner struct {
testRunner
studioSuffix int
}
func createStudioTestRunner(t *testing.T) *studioTestRunner {
return &studioTestRunner{
testRunner: *createTestRunner(t),
}
}
func (s *studioTestRunner) generateStudioName() string {
s.studioSuffix += 1
return "studioTestRunner-" + strconv.Itoa(s.studioSuffix)
}
func (s *studioTestRunner) testCreateStudio() {
input := models.StudioCreateInput{
Name: s.generateStudioName(),
}
studio, err := s.resolver.Mutation().StudioCreate(s.ctx, input)
if err != nil {
s.t.Errorf("Error creating studio: %s", err.Error())
return
}
s.verifyCreatedStudio(input, studio)
}
func (s *studioTestRunner) verifyCreatedStudio(input models.StudioCreateInput, studio *models.Studio) {
// ensure basic attributes are set correctly
if input.Name != studio.Name {
s.fieldMismatch(input.Name, studio.Name, "Name")
}
r := s.resolver.Studio()
id, _ := r.ID(s.ctx, studio)
if id == "" {
s.t.Errorf("Expected created studio id to be non-zero")
}
}
func (s *studioTestRunner) testFindStudioById() {
createdStudio, err := s.createTestStudio(nil)
if err != nil {
return
}
studioID := strconv.FormatInt(createdStudio.ID, 10)
studio, err := s.resolver.Query().FindStudio(s.ctx, &studioID, nil)
if err != nil {
s.t.Errorf("Error finding studio: %s", err.Error())
return
}
// ensure returned studio is not nil
if studio == nil {
s.t.Error("Did not find studio by id")
return
}
// ensure values were set
if createdStudio.Name != studio.Name {
s.fieldMismatch(createdStudio.Name, studio.Name, "Name")
}
}
func (s *studioTestRunner) testFindStudioByName() {
createdStudio, err := s.createTestStudio(nil)
if err != nil {
return
}
studioName := createdStudio.Name
studio, err := s.resolver.Query().FindStudio(s.ctx, nil, &studioName)
if err != nil {
s.t.Errorf("Error finding studio: %s", err.Error())
return
}
// ensure returned studio is not nil
if studio == nil {
s.t.Error("Did not find studio by name")
return
}
// ensure values were set
if createdStudio.Name != studio.Name {
s.fieldMismatch(createdStudio.Name, studio.Name, "Name")
}
}
func (s *studioTestRunner) testUpdateStudioName() {
input := &models.StudioCreateInput{
Name: s.generateStudioName(),
}
createdStudio, err := s.createTestStudio(input)
if err != nil {
return
}
studioID := strconv.FormatInt(createdStudio.ID, 10)
updatedName := s.generateStudioName()
updateInput := models.StudioUpdateInput{
ID: studioID,
Name: &updatedName,
}
// need some mocking of the context to make the field ignore behaviour work
ctx := s.updateContext([]string{
"name",
})
updatedStudio, err := s.resolver.Mutation().StudioUpdate(ctx, updateInput)
if err != nil {
s.t.Errorf("Error updating studio: %s", err.Error())
return
}
input.Name = updatedName
s.verifyCreatedStudio(*input, updatedStudio)
}
func (s *studioTestRunner) verifyUpdatedStudio(input models.StudioUpdateInput, studio *models.Studio) {
// ensure basic attributes are set correctly
if input.Name != nil && *input.Name != studio.Name {
s.fieldMismatch(input.Name, studio.Name, "Name")
}
}
func (s *studioTestRunner) testDestroyStudio() {
createdStudio, err := s.createTestStudio(nil)
if err != nil {
return
}
studioID := strconv.FormatInt(createdStudio.ID, 10)
destroyed, err := s.resolver.Mutation().StudioDestroy(s.ctx, models.StudioDestroyInput{
ID: studioID,
})
if err != nil {
s.t.Errorf("Error destroying studio: %s", err.Error())
return
}
if !destroyed {
s.t.Error("Studio was not destroyed")
return
}
// ensure cannot find studio
foundStudio, err := s.resolver.Query().FindStudio(s.ctx, &studioID, nil)
if err != nil {
s.t.Errorf("Error finding studio after destroying: %s", err.Error())
return
}
if foundStudio != nil {
s.t.Error("Found studio after destruction")
}
// TODO - ensure scene was not removed
}
func TestCreateStudio(t *testing.T) {
pt := createStudioTestRunner(t)
pt.testCreateStudio()
}
func TestFindStudioById(t *testing.T) {
pt := createStudioTestRunner(t)
pt.testFindStudioById()
}
func TestFindStudioByName(t *testing.T) {
pt := createStudioTestRunner(t)
pt.testFindStudioByName()
}
func TestUpdateStudioName(t *testing.T) {
pt := createStudioTestRunner(t)
pt.testUpdateStudioName()
}
func TestDestroyStudio(t *testing.T) {
pt := createStudioTestRunner(t)
pt.testDestroyStudio()
}
// TODO - test parent/children studios

View File

@@ -0,0 +1,209 @@
// +build integration
package api_test
import (
"reflect"
"strconv"
"testing"
"github.com/stashapp/stashdb/pkg/models"
_ "github.com/golang-migrate/migrate/v4/database/sqlite3"
)
type tagTestRunner struct {
testRunner
}
func createTagTestRunner(t *testing.T) *tagTestRunner {
return &tagTestRunner{
testRunner: *createTestRunner(t),
}
}
func (s *tagTestRunner) testCreateTag() {
description := "Description"
input := models.TagCreateInput{
Name: s.generateTagName(),
Description: &description,
}
tag, err := s.resolver.Mutation().TagCreate(s.ctx, input)
if err != nil {
s.t.Errorf("Error creating tag: %s", err.Error())
return
}
s.verifyCreatedTag(input, tag)
}
func (s *tagTestRunner) verifyCreatedTag(input models.TagCreateInput, tag *models.Tag) {
// ensure basic attributes are set correctly
if input.Name != tag.Name {
s.fieldMismatch(input.Name, tag.Name, "Name")
}
r := s.resolver.Tag()
id, _ := r.ID(s.ctx, tag)
if id == "" {
s.t.Errorf("Expected created tag id to be non-zero")
}
if v, _ := r.Description(s.ctx, tag); !reflect.DeepEqual(v, input.Description) {
s.fieldMismatch(*input.Description, v, "Description")
}
}
func (s *tagTestRunner) testFindTagById() {
createdTag, err := s.createTestTag(nil)
if err != nil {
return
}
tagID := strconv.FormatInt(createdTag.ID, 10)
tag, err := s.resolver.Query().FindTag(s.ctx, &tagID, nil)
if err != nil {
s.t.Errorf("Error finding tag: %s", err.Error())
return
}
// ensure returned tag is not nil
if tag == nil {
s.t.Error("Did not find tag by id")
return
}
// ensure values were set
if createdTag.Name != tag.Name {
s.fieldMismatch(createdTag.Name, tag.Name, "Name")
}
}
func (s *tagTestRunner) testFindTagByName() {
createdTag, err := s.createTestTag(nil)
if err != nil {
return
}
tagName := createdTag.Name
tag, err := s.resolver.Query().FindTag(s.ctx, nil, &tagName)
if err != nil {
s.t.Errorf("Error finding tag: %s", err.Error())
return
}
// ensure returned tag is not nil
if tag == nil {
s.t.Error("Did not find tag by name")
return
}
// ensure values were set
if createdTag.Name != tag.Name {
s.fieldMismatch(createdTag.Name, tag.Name, "Name")
}
}
func (s *tagTestRunner) testUpdateTag() {
createdTag, err := s.createTestTag(nil)
if err != nil {
return
}
tagID := strconv.FormatInt(createdTag.ID, 10)
newDescription := "newDescription"
updateInput := models.TagUpdateInput{
ID: tagID,
Description: &newDescription,
}
updatedTag, err := s.resolver.Mutation().TagUpdate(s.ctx, updateInput)
if err != nil {
s.t.Errorf("Error updating tag: %s", err.Error())
return
}
updateInput.Name = &createdTag.Name
s.verifyUpdatedTag(updateInput, updatedTag)
}
func (s *tagTestRunner) verifyUpdatedTag(input models.TagUpdateInput, tag *models.Tag) {
// ensure basic attributes are set correctly
if input.Name != nil && *input.Name != tag.Name {
s.fieldMismatch(input.Name, tag.Name, "Name")
}
r := s.resolver.Tag()
if v, _ := r.Description(s.ctx, tag); !reflect.DeepEqual(v, input.Description) {
s.fieldMismatch(input.Description, v, "Description")
}
}
func (s *tagTestRunner) testDestroyTag() {
createdTag, err := s.createTestTag(nil)
if err != nil {
return
}
tagID := strconv.FormatInt(createdTag.ID, 10)
destroyed, err := s.resolver.Mutation().TagDestroy(s.ctx, models.TagDestroyInput{
ID: tagID,
})
if err != nil {
s.t.Errorf("Error destroying tag: %s", err.Error())
return
}
if !destroyed {
s.t.Error("Tag was not destroyed")
return
}
// ensure cannot find tag
foundTag, err := s.resolver.Query().FindTag(s.ctx, &tagID, nil)
if err != nil {
s.t.Errorf("Error finding tag after destroying: %s", err.Error())
return
}
if foundTag != nil {
s.t.Error("Found tag after destruction")
}
// TODO - ensure scene was not removed
}
func TestCreateTag(t *testing.T) {
pt := createTagTestRunner(t)
pt.testCreateTag()
}
func TestFindTagById(t *testing.T) {
pt := createTagTestRunner(t)
pt.testFindTagById()
}
func TestFindTagByName(t *testing.T) {
pt := createTagTestRunner(t)
pt.testFindTagByName()
}
func TestUpdateTag(t *testing.T) {
pt := createTagTestRunner(t)
pt.testUpdateTag()
}
func TestDestroyTag(t *testing.T) {
pt := createTagTestRunner(t)
pt.testDestroyTag()
}

View File

@@ -3,7 +3,9 @@ package database
import (
"database/sql"
"fmt"
"os"
"regexp"
"github.com/gobuffalo/packr/v2"
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/source"
@@ -11,7 +13,6 @@ import (
sqlite3 "github.com/mattn/go-sqlite3"
"github.com/stashapp/stashdb/pkg/logger"
"github.com/stashapp/stashdb/pkg/utils"
"os"
)
var DB *sqlx.DB
@@ -68,6 +69,8 @@ func runMigrations(databasePath string) {
panic(err.Error())
}
}
m.Close()
}
func registerRegexpFunc() {

View File

@@ -0,0 +1,76 @@
package databasetest
import (
"context"
"fmt"
"io/ioutil"
"os"
"testing"
"github.com/stashapp/stashdb/pkg/database"
)
type DatabasePopulater interface {
PopulateDB() error
}
func testTeardown(databaseFile string) {
err := database.DB.Close()
if err != nil {
panic(err)
}
err = os.Remove(databaseFile)
if err != nil {
panic(err)
}
}
func runTests(m *testing.M, populater DatabasePopulater) int {
// create the database file
f, err := ioutil.TempFile("", "*.sqlite")
if err != nil {
panic(fmt.Sprintf("Could not create temporary file: %s", err.Error()))
}
f.Close()
databaseFile := f.Name()
database.Initialize(databaseFile)
// defer close and delete the database
defer testTeardown(databaseFile)
if populater != nil {
err = populater.PopulateDB()
if err != nil {
panic(fmt.Sprintf("Could not populate database: %s", err.Error()))
}
}
// run the tests
return m.Run()
}
func TestWithDatabase(m *testing.M, populater DatabasePopulater) {
ret := runTests(m, populater)
os.Exit(ret)
}
func WithTransientTransaction(ctx context.Context, fn database.TxFunc) {
txn := database.NewTransaction(ctx)
txn.Begin(ctx)
defer func() {
if p := recover(); p != nil {
// a panic occurred, rollback and repanic
txn.Rollback()
panic(p)
} else {
// something went wrong, rollback
txn.Rollback()
}
}()
fn(txn)
}

249
pkg/database/dbi.go Normal file
View File

@@ -0,0 +1,249 @@
package database
import (
"database/sql"
"fmt"
"reflect"
"github.com/jmoiron/sqlx"
"github.com/pkg/errors"
)
// The DBI interface is used to interface with the database.
type DBI interface {
// Insert inserts the provided object as a row into the database.
// It returns the new object.
Insert(model Model) (interface{}, error)
// InsertJoin inserts a join object into the provided join table.
InsertJoin(tableJoin TableJoin, object interface{}) error
// InsertJoins inserts multiple join objects into the provided join table.
InsertJoins(tableJoin TableJoin, joins Joins) error
// Update updates a database row based on the id and values of the provided
// object. It returns the updated object. Update will return an error if
// the object with id does not exist in the database table.
Update(model Model) (interface{}, error)
// ReplaceJoins replaces table join objects with the provided primary table
// id value with the provided join objects.
ReplaceJoins(tableJoin TableJoin, id int64, objects Joins) error
// Delete deletes the table row with the provided id. Delete returns an
// error if the id does not exist in the database table.
Delete(id int64, table Table) error
// DeleteJoins deletes all join objects with the provided primary table
// id value.
DeleteJoins(tableJoin TableJoin, id int64) error
// Find returns the row object with the provided id, or returns nil if not
// found.
Find(id int64, table Table) (interface{}, error)
// FindJoins returns join objects where the foreign key id is equal to the
// provided id. The join objects are output to the provided output slice.
FindJoins(tableJoin TableJoin, id int64, output Joins) error
// RawQuery performs a query on the provided table using the query string
// and argument slice. It outputs the results to the output slice.
RawQuery(table Table, query string, args []interface{}, output Models) error
}
type dbi struct {
tx *sqlx.Tx
}
// DBIWithTxn returns a DBI interface that is to operate within a transaction.
func DBIWithTxn(tx *sqlx.Tx) DBI {
return &dbi{
tx: tx,
}
}
// DBINoTxn returns a DBI interface that is to operate outside of a transaction.
// This DBI will not be able to mutate the database.
func DBINoTxn() DBI {
return &dbi{}
}
// Insert inserts the provided object as a row into the database.
// It returns the new object.
func (q dbi) Insert(model Model) (interface{}, error) {
tableName := model.GetTable().Name()
id, err := insertObject(q.tx, tableName, model)
if err != nil {
return nil, errors.Wrap(err, fmt.Sprintf("Error creating %s", reflect.TypeOf(model).Name()))
}
// don't want to modify the existing object
newModel := model.GetTable().NewObject()
if err := getByID(q.tx, tableName, id, newModel); err != nil {
return nil, errors.Wrap(err, fmt.Sprintf("Error getting %s after create", reflect.TypeOf(model).Name()))
}
return newModel, nil
}
// Update updates a database row based on the id and values of the provided
// object. It returns the updated object. Update will return an error if
// the object with id does not exist in the database table.
func (q dbi) Update(model Model) (interface{}, error) {
tableName := model.GetTable().Name()
err := updateObjectByID(q.tx, tableName, model)
if err != nil {
return nil, errors.Wrap(err, fmt.Sprintf("Error updating %s", reflect.TypeOf(model).Name()))
}
// don't want to modify the existing object
updatedModel := model.GetTable().NewObject()
if err := getByID(q.tx, tableName, model.GetID(), updatedModel); err != nil {
return nil, errors.Wrap(err, fmt.Sprintf("Error getting %s after update", reflect.TypeOf(model).Name()))
}
return updatedModel, nil
}
// Delete deletes the table row with the provided id. Delete returns an
// error if the id does not exist in the database table.
func (q dbi) Delete(id int64, table Table) error {
o, err := q.Find(id, table)
if err != nil {
return errors.Wrap(err, fmt.Sprintf("Error deleting from %s", table.Name()))
}
if o == nil {
return fmt.Errorf("Row with id %d not found in %s", id, table.Name())
}
return executeDeleteQuery(table.Name(), id, q.tx)
}
func selectStatement(table Table) string {
tableName := table.Name()
return fmt.Sprintf("SELECT %s.* FROM %s", tableName, tableName)
}
func (q dbi) queryx(query string, args ...interface{}) (*sqlx.Rows, error) {
if q.tx != nil {
return q.tx.Queryx(query, args...)
} else {
return DB.Queryx(query, args...)
}
}
// Find returns the row object with the provided id, or returns nil if not
// found.
func (q dbi) Find(id int64, table Table) (interface{}, error) {
query := selectStatement(table) + " WHERE id = ? LIMIT 1"
args := []interface{}{id}
var rows *sqlx.Rows
var err error
rows, err = q.queryx(query, args...)
if err != nil && err != sql.ErrNoRows {
return nil, err
}
defer rows.Close()
output := table.NewObject()
if rows.Next() {
if err := rows.StructScan(output); err != nil {
return nil, err
}
} else {
// not found
return nil, nil
}
if err := rows.Err(); err != nil {
return nil, err
}
return output, nil
}
// InsertJoin inserts a join object into the provided join table.
func (q dbi) InsertJoin(tableJoin TableJoin, object interface{}) error {
_, err := insertObject(q.tx, tableJoin.Name(), object)
if err != nil {
return errors.Wrap(err, fmt.Sprintf("Error creating %s", reflect.TypeOf(object).Name()))
}
return nil
}
// InsertJoins inserts multiple join objects into the provided join table.
func (q dbi) InsertJoins(tableJoin TableJoin, joins Joins) error {
var err error
joins.Each(func(ro interface{}) {
if err != nil {
return
}
err = q.InsertJoin(tableJoin, ro)
})
return err
}
// ReplaceJoins replaces table join objects with the provided primary table
// id value with the provided join objects.
func (q dbi) ReplaceJoins(tableJoin TableJoin, id int64, joins Joins) error {
err := q.DeleteJoins(tableJoin, id)
if err != nil {
return err
}
return q.InsertJoins(tableJoin, joins)
}
// DeleteJoins deletes all join objects with the provided primary table
// id value.
func (q dbi) DeleteJoins(tableJoin TableJoin, id int64) error {
return deleteObjectsByColumn(q.tx, tableJoin.Name(), tableJoin.joinColumn, id)
}
// FindJoins returns join objects where the foreign key id is equal to the
// provided id. The join objects are output to the provided output slice.
func (q dbi) FindJoins(tableJoin TableJoin, id int64, output Joins) error {
query := selectStatement(tableJoin.Table) + " WHERE " + tableJoin.joinColumn + " = ?"
args := []interface{}{id}
return q.RawQuery(tableJoin.Table, query, args, output)
}
// RawQuery performs a query on the provided table using the query string
// and argument slice. It outputs the results to the output slice.
func (q dbi) RawQuery(table Table, query string, args []interface{}, output Models) error {
var rows *sqlx.Rows
var err error
rows, err = q.queryx(query, args...)
if err != nil && err != sql.ErrNoRows {
return err
}
defer rows.Close()
for rows.Next() {
o := table.NewObject()
if err := rows.StructScan(o); err != nil {
return err
}
output.Add(o)
}
if err := rows.Err(); err != nil {
return err
}
return nil
}

View File

@@ -19,7 +19,8 @@ CREATE TABLE `performers` (
`career_start_year` integer,
`career_end_year` integer,
`created_at` datetime not null,
`updated_at` datetime not null
`updated_at` datetime not null,
unique (`name`, `disambiguation`)
);
CREATE TABLE `performer_aliases` (
@@ -59,3 +60,77 @@ CREATE INDEX `index_performers_on_alias` on `performer_aliases` (`alias`);
CREATE INDEX `index_performers_on_piercing_location` on `performer_piercings` (`location`);
CREATE INDEX `index_performers_on_tattoo_location` on `performer_tattoos` (`location`);
CREATE INDEX `index_performers_on_tattoo_description` on `performer_tattoos` (`description`);
CREATE TABLE `tags` (
`id` integer not null primary key autoincrement,
`name` varchar(255) not null,
`description` varchar(255),
`created_at` datetime not null,
`updated_at` datetime not null,
unique (`name`)
);
CREATE TABLE `tag_aliases` (
`tag_id` integer not null,
`alias` varchar(255) not null,
foreign key(`tag_id`) references `tags`(`id`) ON DELETE CASCADE,
unique (`alias`)
);
CREATE TABLE `studios` (
`id` integer not null primary key autoincrement,
`image` blob,
`name` varchar(255) not null,
`parent_studio_id` integer ,
`created_at` datetime not null,
`updated_at` datetime not null,
foreign key(`parent_studio_id`) references `studios`(`id`) ON DELETE CASCADE
);
CREATE TABLE `studio_urls` (
`studio_id` integer not null,
`url` varchar(255) not null,
`type` varchar(255) not null,
foreign key(`studio_id`) references `studios`(`id`) ON DELETE CASCADE,
unique (`studio_id`, `url`),
unique (`studio_id`, `type`)
);
CREATE TABLE `scenes` (
`id` integer not null primary key autoincrement,
`title` varchar(255),
`details` varchar(255),
`url` varchar(255),
`date` date,
`studio_id` integer,
`created_at` datetime not null,
`updated_at` datetime not null,
foreign key(`studio_id`) references `studios`(`id`) ON DELETE SET NULL
);
CREATE TABLE `scene_fingerprints` (
`scene_id` integer not null,
`hash` varchar(255) not null,
`algorithm` varchar(20) not null,
foreign key(`scene_id`) references `scenes`(`id`) ON DELETE CASCADE,
unique (`scene_id`, `algorithm`, `hash`)
);
CREATE INDEX `index_scene_fingerprints_on_hash` on `scene_fingerprints` (`algorithm`, `hash`);
CREATE TABLE `scene_performers` (
`scene_id` integer not null,
`as` varchar(255),
`performer_id` integer not null,
foreign key(`scene_id`) references `scenes`(`id`) ON DELETE CASCADE,
foreign key(`performer_id`) references `performers`(`id`) ON DELETE CASCADE,
unique(`scene_id`, `performer_id`)
);
CREATE TABLE `scene_tags` (
`scene_id` integer not null,
`tag_id` integer not null,
foreign key(`scene_id`) references `scenes`(`id`) ON DELETE CASCADE,
foreign key(`tag_id`) references `tags`(`id`) ON DELETE CASCADE,
unique(`scene_id`, `tag_id`)
);

187
pkg/database/sql.go Normal file
View File

@@ -0,0 +1,187 @@
package database
import (
"database/sql"
"fmt"
"reflect"
"strings"
"github.com/jmoiron/sqlx"
)
type optionalValue interface {
IsValid() bool
}
func ensureTx(tx *sqlx.Tx) {
if tx == nil {
panic("must use a transaction")
}
}
func getByID(tx *sqlx.Tx, table string, id int64, object interface{}) error {
return tx.Get(object, `SELECT * FROM `+table+` WHERE id = ? LIMIT 1`, id)
}
func insertObject(tx *sqlx.Tx, table string, object interface{}) (int64, error) {
ensureTx(tx)
fields, values := sqlGenKeysCreate(object)
result, err := tx.NamedExec(
`INSERT INTO `+table+` (`+fields+`)
VALUES (`+values+`)
`,
object,
)
if err != nil {
return 0, err
}
return result.LastInsertId()
}
func updateObjectByID(tx *sqlx.Tx, table string, object interface{}) error {
ensureTx(tx)
_, err := tx.NamedExec(
`UPDATE `+table+` SET `+sqlGenKeys(object, false)+` WHERE `+table+`.id = :id`,
object,
)
return err
}
func executeDeleteQuery(tableName string, id int64, tx *sqlx.Tx) error {
if tx == nil {
panic("must use a transaction")
}
idColumnName := getColumn(tableName, "id")
_, err := tx.Exec(
`DELETE FROM `+tableName+` WHERE `+idColumnName+` = ?`,
id,
)
return err
}
func deleteObjectsByColumn(tx *sqlx.Tx, table string, column string, value interface{}) error {
ensureTx(tx)
_, err := tx.Exec(`DELETE FROM `+table+` WHERE `+column+` = ?`, value)
return err
}
func getColumn(tableName string, columnName string) string {
return tableName + "." + columnName
}
func sqlGenKeysCreate(i interface{}) (string, string) {
var fields []string
var values []string
addPlaceholder := func(key string) {
fields = append(fields, "`"+key+"`")
values = append(values, ":"+key)
}
v := reflect.ValueOf(i)
for i := 0; i < v.NumField(); i++ {
//get key for struct tag
rawKey := v.Type().Field(i).Tag.Get("db")
key := strings.Split(rawKey, ",")[0]
if key == "id" {
continue
}
switch t := v.Field(i).Interface().(type) {
case string:
if t != "" {
addPlaceholder(key)
}
case int, int64, float64:
if t != 0 {
addPlaceholder(key)
}
case optionalValue:
if t.IsValid() {
addPlaceholder(key)
}
case sql.NullString:
if t.Valid {
addPlaceholder(key)
}
case sql.NullBool:
if t.Valid {
addPlaceholder(key)
}
case sql.NullInt64:
if t.Valid {
addPlaceholder(key)
}
case sql.NullFloat64:
if t.Valid {
addPlaceholder(key)
}
default:
reflectValue := reflect.ValueOf(t)
isNil := reflectValue.IsNil()
if !isNil {
fields = append(fields, key)
values = append(values, ":"+key)
}
}
}
return strings.Join(fields, ", "), strings.Join(values, ", ")
}
func sqlGenKeys(i interface{}, partial bool) string {
var query []string
addKey := func(key string) {
query = append(query, fmt.Sprintf("%s=:%s", key, key))
}
v := reflect.ValueOf(i)
for i := 0; i < v.NumField(); i++ {
//get key for struct tag
rawKey := v.Type().Field(i).Tag.Get("db")
key := strings.Split(rawKey, ",")[0]
if key == "id" {
continue
}
switch t := v.Field(i).Interface().(type) {
case string:
if partial || t != "" {
addKey(key)
}
case int, int64, float64:
if partial || t != 0 {
addKey(key)
}
case optionalValue:
if partial || t.IsValid() {
addKey(key)
}
case sql.NullString:
if partial || t.Valid {
addKey(key)
}
case sql.NullBool:
if partial || t.Valid {
addKey(key)
}
case sql.NullInt64:
if partial || t.Valid {
addKey(key)
}
case sql.NullFloat64:
if partial || t.Valid {
addKey(key)
}
default:
reflectValue := reflect.ValueOf(t)
isNil := reflectValue.IsNil()
if !isNil {
addKey(key)
}
}
}
return strings.Join(query, ", ")
}

96
pkg/database/table.go Normal file
View File

@@ -0,0 +1,96 @@
package database
// NewObjectFunc is a function that returns an instance of an object stored in
// a database table.
type NewObjectFunc func() interface{}
// Table represents a database table.
type Table struct {
name string
newObjectFn NewObjectFunc
}
// Name returns the name of the database table.
func (t Table) Name() string {
return t.name
}
// NewObject returns a new object model of the type that this table stores.
func (t Table) NewObject() interface{} {
return t.newObjectFn()
}
// NewTable creates a new Table object with the provided table name and new
// object function.
func NewTable(name string, newObjectFn NewObjectFunc) Table {
return Table{
name: name,
newObjectFn: newObjectFn,
}
}
// TableJoin represents a database Table that joins two other tables.
type TableJoin struct {
Table
// the primary table that will be joined to this table
primaryTable string
// the column in this table that stores the foreign key to the primary table.
joinColumn string
}
// Creates a new TableJoin instance. The primaryTable is the table that will join
// to the join table. The joinColumn is the name in the join table that stores
// the foreign key in the primary table.
func NewTableJoin(primaryTable string, name string, joinColumn string, newObjectFn func() interface{}) TableJoin {
return TableJoin{
Table: Table{
name: name,
newObjectFn: newObjectFn,
},
primaryTable: primaryTable,
joinColumn: joinColumn,
}
}
// Inverse creates a TableJoin object that is the inverse of this table join.
// The returns TableJoin object will have this table as the primary table.
func (t TableJoin) Inverse(joinColumn string) TableJoin {
return TableJoin{
Table: Table{
name: t.primaryTable,
newObjectFn: t.newObjectFn,
},
primaryTable: t.Name(),
joinColumn: joinColumn,
}
}
// Model is the interface implemented by objects that exist in the database
// that have an `id` column.
type Model interface {
// GetTable returns the table that stores objects of this type.
GetTable() Table
// GetID returns the ID of the object.
GetID() int64
}
// Models is the interface implemented by slices of Model objects.
type Models interface {
// Add adds a new object to the slice. It is assumed that the passed
// object can be type asserted to the correct type.
Add(interface{})
}
// Joins is the interface implemented by slices of join objects.
type Joins interface {
// Each calls the provided function on each of the concrete (not pointer)
// objects in the slice.
Each(func(interface{}))
// Add adds a new object to the slice. It is assumed that the passed
// object can be type asserted to the correct type.
Add(interface{})
}

View File

@@ -0,0 +1,93 @@
package database
import (
"context"
"github.com/jmoiron/sqlx"
)
type Transaction interface {
Begin(ctx context.Context) *sqlx.Tx
Commit() error
Rollback() error
GetTx() *sqlx.Tx
}
type transaction struct {
tx *sqlx.Tx
closed bool
}
func NewTransaction(ctx context.Context) Transaction {
return &transaction{}
}
func (t *transaction) close() {
t.closed = true
}
func (t *transaction) Begin(ctx context.Context) *sqlx.Tx {
if t.tx != nil {
panic("Begin called twice on the same Transaction")
}
if t.closed {
panic("Begin called on closed Transaction")
}
t.tx = DB.MustBeginTx(ctx, nil)
return t.tx
}
func (t *transaction) Commit() error {
if t.closed {
panic("Commit called on closed transaction")
}
if t.tx == nil {
panic("Commit called before Begin")
}
defer t.close()
return t.tx.Commit()
}
func (t *transaction) Rollback() error {
if t.tx == nil {
panic("Rollback called before begin")
}
defer t.close()
return t.tx.Rollback()
}
func (t *transaction) GetTx() *sqlx.Tx {
return t.tx
}
type TxFunc func(Transaction) error
func WithTransaction(ctx context.Context, fn TxFunc) error {
txn := NewTransaction(ctx)
txn.Begin(ctx)
var err error
defer func() {
if p := recover(); p != nil {
// a panic occurred, rollback and repanic
txn.Rollback()
panic(p)
} else if err != nil {
// something went wrong, rollback
txn.Rollback()
} else {
// all good, commit
err = txn.Commit()
}
}()
err = fn(txn)
return err
}

View File

@@ -1,16 +1,56 @@
package models
type PerformersScenes struct {
PerformerID int `db:"performer_id" json:"performer_id"`
SceneID int `db:"scene_id" json:"scene_id"`
import (
"database/sql"
"github.com/stashapp/stashdb/pkg/database"
)
var (
scenePerformerTable = database.NewTableJoin(sceneTable, "scene_performers", sceneJoinKey, func() interface{} {
return &PerformerScene{}
})
performerSceneTable = scenePerformerTable.Inverse(performerJoinKey)
sceneTagTable = database.NewTableJoin(sceneTable, "scene_tags", sceneJoinKey, func() interface{} {
return &SceneTag{}
})
tagSceneTable = sceneTagTable.Inverse(tagJoinKey)
)
type PerformerScene struct {
PerformerID int64 `db:"performer_id" json:"performer_id"`
As sql.NullString `db:"as" json:"as"`
SceneID int64 `db:"scene_id" json:"scene_id"`
}
type ScenesTags struct {
SceneID int `db:"scene_id" json:"scene_id"`
TagID int `db:"tag_id" json:"tag_id"`
type PerformersScenes []*PerformerScene
func (p PerformersScenes) Each(fn func(interface{})) {
for _, v := range p {
fn(*v)
}
}
type SceneMarkersTags struct {
SceneMarkerID int `db:"scene_marker_id" json:"scene_marker_id"`
TagID int `db:"tag_id" json:"tag_id"`
func (p *PerformersScenes) Add(o interface{}) {
*p = append(*p, o.(*PerformerScene))
}
type SceneTag struct {
SceneID int64 `db:"scene_id" json:"scene_id"`
TagID int64 `db:"tag_id" json:"tag_id"`
}
type ScenesTags []*SceneTag
func (p ScenesTags) Each(fn func(interface{})) {
for _, v := range p {
fn(*v)
}
}
func (p *ScenesTags) Add(o interface{}) {
*p = append(*p, o.(*SceneTag))
}

View File

@@ -2,6 +2,36 @@ package models
import (
"database/sql"
"github.com/stashapp/stashdb/pkg/database"
"github.com/stashapp/stashdb/pkg/utils"
)
const (
performerTable = "performers"
performerJoinKey = "performer_id"
)
var (
performerDBTable = database.NewTable(performerTable, func() interface{} {
return &Performer{}
})
performerAliasTable = database.NewTableJoin(performerTable, "performer_aliases", performerJoinKey, func() interface{} {
return &PerformerAlias{}
})
performerUrlTable = database.NewTableJoin(performerTable, "performer_urls", performerJoinKey, func() interface{} {
return &PerformerUrl{}
})
performerTattooTable = database.NewTableJoin(performerTable, "performer_tattoos", performerJoinKey, func() interface{} {
return &PerformerBodyMod{}
})
performerPiercingTable = database.NewTableJoin(performerTable, "performer_piercings", performerJoinKey, func() interface{} {
return &PerformerBodyMod{}
})
)
type Performer struct {
@@ -28,39 +58,92 @@ type Performer struct {
UpdatedAt SQLiteTimestamp `db:"updated_at" json:"updated_at"`
}
type PerformerAliases struct {
func (Performer) GetTable() database.Table {
return performerDBTable
}
func (p Performer) GetID() int64 {
return p.ID
}
type Performers []*Performer
func (p Performers) Each(fn func(interface{})) {
for _, v := range p {
fn(*v)
}
}
func (p *Performers) Add(o interface{}) {
*p = append(*p, o.(*Performer))
}
type PerformerAlias struct {
PerformerID int64 `db:"performer_id" json:"performer_id"`
Alias string `db:"alias" json:"alias"`
}
func CreatePerformerAliases(performerId int64, aliases []string) []PerformerAliases {
var ret []PerformerAliases
type PerformerAliases []*PerformerAlias
for _, alias := range aliases {
ret = append(ret, PerformerAliases{PerformerID: performerId, Alias: alias})
func (p PerformerAliases) Each(fn func(interface{})) {
for _, v := range p {
fn(*v)
}
}
func (p *PerformerAliases) Add(o interface{}) {
*p = append(*p, o.(*PerformerAlias))
}
func (p PerformerAliases) ToAliases() []string {
var ret []string
for _, v := range p {
ret = append(ret, v.Alias)
}
return ret
}
type PerformerUrls struct {
func CreatePerformerAliases(performerId int64, aliases []string) PerformerAliases {
var ret PerformerAliases
for _, alias := range aliases {
ret = append(ret, &PerformerAlias{PerformerID: performerId, Alias: alias})
}
return ret
}
type PerformerUrl struct {
PerformerID int64 `db:"performer_id" json:"performer_id"`
URL string `db:"url" json:"url"`
Type string `db:"type" json:"type"`
}
func (p *PerformerUrls) ToURL() URL {
func (p *PerformerUrl) ToURL() URL {
return URL{
URL: p.URL,
Type: p.Type,
}
}
func CreatePerformerUrls(performerId int64, urls []*URLInput) []PerformerUrls {
var ret []PerformerUrls
type PerformerUrls []*PerformerUrl
func (p PerformerUrls) Each(fn func(interface{})) {
for _, v := range p {
fn(*v)
}
}
func (p *PerformerUrls) Add(o interface{}) {
*p = append(*p, o.(*PerformerUrl))
}
func CreatePerformerUrls(performerId int64, urls []*URLInput) PerformerUrls {
var ret PerformerUrls
for _, urlInput := range urls {
ret = append(ret, PerformerUrls{
ret = append(ret, &PerformerUrl{
PerformerID: performerId,
URL: urlInput.URL,
Type: urlInput.Type,
@@ -70,13 +153,13 @@ func CreatePerformerUrls(performerId int64, urls []*URLInput) []PerformerUrls {
return ret
}
type PerformerBodyMods struct {
type PerformerBodyMod struct {
PerformerID int64 `db:"performer_id" json:"performer_id"`
Location string `db:"location" json:"location"`
Description sql.NullString `db:"description" json:"description"`
}
func (m PerformerBodyMods) ToBodyModification() BodyModification {
func (m PerformerBodyMod) ToBodyModification() BodyModification {
ret := BodyModification{
Location: m.Location,
}
@@ -87,8 +170,20 @@ func (m PerformerBodyMods) ToBodyModification() BodyModification {
return ret
}
func CreatePerformerBodyMods(performerId int64, urls []*BodyModificationInput) []PerformerBodyMods {
var ret []PerformerBodyMods
type PerformerBodyMods []*PerformerBodyMod
func (p PerformerBodyMods) Each(fn func(interface{})) {
for _, v := range p {
fn(*v)
}
}
func (p *PerformerBodyMods) Add(o interface{}) {
*p = append(*p, o.(*PerformerBodyMod))
}
func CreatePerformerBodyMods(performerId int64, urls []*BodyModificationInput) PerformerBodyMods {
var ret PerformerBodyMods
for _, bmInput := range urls {
description := sql.NullString{}
@@ -97,7 +192,7 @@ func CreatePerformerBodyMods(performerId int64, urls []*BodyModificationInput) [
description.String = *bmInput.Description
description.Valid = true
}
ret = append(ret, PerformerBodyMods{
ret = append(ret, &PerformerBodyMod{
PerformerID: performerId,
Location: bmInput.Location,
Description: description,
@@ -168,9 +263,27 @@ func (p Performer) ResolveMeasurements() Measurements {
return ret
}
func (p *Performer) CopyFromCreateInput(input PerformerCreateInput) {
func (p *Performer) TranslateImageData(inputData *string) ([]byte, error) {
var imageData []byte
var err error
_, imageData, err = utils.ProcessBase64Image(*inputData)
return imageData, err
}
func (p *Performer) CopyFromCreateInput(input PerformerCreateInput) error {
CopyFull(p, input)
if input.Image != nil {
var err error
p.Image, err = p.TranslateImageData(input.Image)
if err != nil {
return err
}
}
if input.Birthdate != nil {
p.setBirthdate(*input.Birthdate)
}
@@ -178,11 +291,22 @@ func (p *Performer) CopyFromCreateInput(input PerformerCreateInput) {
if input.Measurements != nil {
p.setMeasurements(*input.Measurements)
}
return nil
}
func (p *Performer) CopyFromUpdateInput(input PerformerUpdateInput) {
func (p *Performer) CopyFromUpdateInput(input PerformerUpdateInput) error {
CopyFull(p, input)
if input.Image != nil {
var err error
p.Image, err = p.TranslateImageData(input.Image)
if err != nil {
return err
}
}
if input.Birthdate != nil {
p.setBirthdate(*input.Birthdate)
}
@@ -190,4 +314,6 @@ func (p *Performer) CopyFromUpdateInput(input PerformerUpdateInput) {
if input.Measurements != nil {
p.setMeasurements(*input.Measurements)
}
return nil
}

158
pkg/models/model_scene.go Normal file
View File

@@ -0,0 +1,158 @@
package models
import (
"database/sql"
"strconv"
"github.com/stashapp/stashdb/pkg/database"
)
const (
sceneTable = "scenes"
sceneJoinKey = "scene_id"
)
var (
sceneDBTable = database.NewTable(sceneTable, func() interface{} {
return &Scene{}
})
sceneFingerprintTable = database.NewTableJoin(sceneTable, "scene_fingerprints", sceneJoinKey, func() interface{} {
return &SceneFingerprint{}
})
)
type Scene struct {
ID int64 `db:"id" json:"id"`
Title sql.NullString `db:"title" json:"title"`
Details sql.NullString `db:"details" json:"details"`
URL sql.NullString `db:"url" json:"url"`
Date SQLiteDate `db:"date" json:"date"`
StudioID sql.NullInt64 `db:"studio_id,omitempty" json:"studio_id"`
CreatedAt SQLiteTimestamp `db:"created_at" json:"created_at"`
UpdatedAt SQLiteTimestamp `db:"updated_at" json:"updated_at"`
}
func (Scene) GetTable() database.Table {
return sceneDBTable
}
func (p Scene) GetID() int64 {
return p.ID
}
type Scenes []*Scene
func (p Scenes) Each(fn func(interface{})) {
for _, v := range p {
fn(*v)
}
}
func (p *Scenes) Add(o interface{}) {
*p = append(*p, o.(*Scene))
}
type SceneFingerprint struct {
SceneID int64 `db:"scene_id" json:"scene_id"`
Hash string `db:"hash" json:"hash"`
Algorithm string `db:"algorithm" json:"algorithm"`
}
func (p SceneFingerprint) ToFingerprint() *Fingerprint {
return &Fingerprint{
Algorithm: FingerprintAlgorithm(p.Algorithm),
Hash: p.Hash,
}
}
type SceneFingerprints []*SceneFingerprint
func (p SceneFingerprints) Each(fn func(interface{})) {
for _, v := range p {
fn(*v)
}
}
func (p *SceneFingerprints) Add(o interface{}) {
*p = append(*p, o.(*SceneFingerprint))
}
func (p SceneFingerprints) ToFingerprints() []*Fingerprint {
var ret []*Fingerprint
for _, v := range p {
ret = append(ret, v.ToFingerprint())
}
return ret
}
func CreateSceneFingerprints(sceneID int64, fingerprints []*FingerprintInput) SceneFingerprints {
var ret SceneFingerprints
for _, fingerprint := range fingerprints {
ret = append(ret, &SceneFingerprint{
SceneID: sceneID,
Hash: fingerprint.Hash,
Algorithm: fingerprint.Algorithm.String(),
})
}
return ret
}
func CreateSceneTags(sceneID int64, tagIds []string) ScenesTags {
var tagJoins ScenesTags
for _, tid := range tagIds {
tagID, _ := strconv.ParseInt(tid, 10, 64)
tagJoin := &SceneTag{
SceneID: sceneID,
TagID: tagID,
}
tagJoins = append(tagJoins, tagJoin)
}
return tagJoins
}
func CreateScenePerformers(sceneID int64, appearances []*PerformerAppearanceInput) PerformersScenes {
var performerJoins PerformersScenes
for _, a := range appearances {
performerID, _ := strconv.ParseInt(a.PerformerID, 10, 64)
performerJoin := &PerformerScene{
SceneID: sceneID,
PerformerID: performerID,
}
if a.As != nil {
performerJoin.As = sql.NullString{Valid: true, String: *a.As}
}
performerJoins = append(performerJoins, performerJoin)
}
return performerJoins
}
func (p *Scene) IsEditTarget() {
}
func (p *Scene) setDate(date string) {
p.Date = SQLiteDate{String: date, Valid: true}
}
func (p *Scene) CopyFromCreateInput(input SceneCreateInput) {
CopyFull(p, input)
if input.Date != nil {
p.setDate(*input.Date)
}
}
func (p *Scene) CopyFromUpdateInput(input SceneUpdateInput) {
CopyFull(p, input)
if input.Date != nil {
p.setDate(*input.Date)
}
}

116
pkg/models/model_studio.go Normal file
View File

@@ -0,0 +1,116 @@
package models
import (
"database/sql"
"strconv"
"github.com/stashapp/stashdb/pkg/database"
)
const (
studioTable = "studios"
studioJoinKey = "studio_id"
)
var (
studioDBTable = database.NewTable(studioTable, func() interface{} {
return &Studio{}
})
studioUrlTable = database.NewTableJoin(studioTable, "studio_urls", studioJoinKey, func() interface{} {
return &StudioUrl{}
})
)
type Studio struct {
ID int64 `db:"id" json:"id"`
Name string `db:"name" json:"name"`
Image []byte `db:"image" json:"image"`
ParentStudioID sql.NullInt64 `db:"parent_studio_id,omitempty" json:"parent_studio_id"`
CreatedAt SQLiteTimestamp `db:"created_at" json:"created_at"`
UpdatedAt SQLiteTimestamp `db:"updated_at" json:"updated_at"`
}
func (Studio) GetTable() database.Table {
return studioDBTable
}
func (p Studio) GetID() int64 {
return p.ID
}
type Studios []*Studio
func (p Studios) Each(fn func(interface{})) {
for _, v := range p {
fn(v)
}
}
func (p *Studios) Add(o interface{}) {
*p = append(*p, o.(*Studio))
}
type StudioUrl struct {
StudioID int64 `db:"studio_id" json:"studio_id"`
URL string `db:"url" json:"url"`
Type string `db:"type" json:"type"`
}
func (p *StudioUrl) ToURL() URL {
return URL{
URL: p.URL,
Type: p.Type,
}
}
type StudioUrls []StudioUrl
func (p StudioUrls) Each(fn func(interface{})) {
for _, v := range p {
fn(v)
}
}
func (p *StudioUrls) Add(o interface{}) {
*p = append(*p, o.(StudioUrl))
}
func CreateStudioUrls(studioId int64, urls []*URLInput) []StudioUrl {
var ret []StudioUrl
for _, urlInput := range urls {
ret = append(ret, StudioUrl{
StudioID: studioId,
URL: urlInput.URL,
Type: urlInput.Type,
})
}
return ret
}
func (p *Studio) IsEditTarget() {
}
func (p *Studio) CopyFromCreateInput(input StudioCreateInput) {
CopyFull(p, input)
if input.ParentID != nil {
parentID, err := strconv.ParseInt(*input.ParentID, 10, 64)
if err == nil {
p.ParentStudioID = sql.NullInt64{Int64: parentID, Valid: true}
}
}
}
func (p *Studio) CopyFromUpdateInput(input StudioUpdateInput) {
CopyFull(p, input)
if input.ParentID != nil {
parentID, err := strconv.ParseInt(*input.ParentID, 10, 64)
if err == nil {
p.ParentStudioID = sql.NullInt64{Int64: parentID, Valid: true}
}
}
}

97
pkg/models/model_tag.go Normal file
View File

@@ -0,0 +1,97 @@
package models
import (
"database/sql"
"github.com/stashapp/stashdb/pkg/database"
)
const (
tagTable = "tags"
tagJoinKey = "tag_id"
)
var (
tagDBTable = database.NewTable(tagTable, func() interface{} {
return &Tag{}
})
tagAliasTable = database.NewTableJoin(tagTable, "tag_aliases", tagJoinKey, func() interface{} {
return &TagAlias{}
})
)
type Tag struct {
ID int64 `db:"id" json:"id"`
Name string `db:"name" json:"name"`
Description sql.NullString `db:"description" json:"description"`
CreatedAt SQLiteTimestamp `db:"created_at" json:"created_at"`
UpdatedAt SQLiteTimestamp `db:"updated_at" json:"updated_at"`
}
func (Tag) GetTable() database.Table {
return tagDBTable
}
func (p Tag) GetID() int64 {
return p.ID
}
type Tags []*Tag
func (p Tags) Each(fn func(interface{})) {
for _, v := range p {
fn(v)
}
}
func (p *Tags) Add(o interface{}) {
*p = append(*p, o.(*Tag))
}
type TagAlias struct {
TagID int64 `db:"tag_id" json:"tag_id"`
Alias string `db:"alias" json:"alias"`
}
type TagAliases []TagAlias
func (p TagAliases) Each(fn func(interface{})) {
for _, v := range p {
fn(v)
}
}
func (p *TagAliases) Add(o interface{}) {
*p = append(*p, o.(TagAlias))
}
func (p TagAliases) ToAliases() []string {
var ret []string
for _, v := range p {
ret = append(ret, v.Alias)
}
return ret
}
func CreateTagAliases(tagId int64, aliases []string) []TagAlias {
var ret []TagAlias
for _, alias := range aliases {
ret = append(ret, TagAlias{TagID: tagId, Alias: alias})
}
return ret
}
func (p *Tag) IsEditTarget() {
}
func (p *Tag) CopyFromCreateInput(input TagCreateInput) {
CopyFull(p, input)
}
func (p *Tag) CopyFromUpdateInput(input TagUpdateInput) {
CopyFull(p, input)
}

View File

@@ -1,130 +1,41 @@
package models
import "github.com/jmoiron/sqlx"
import (
"github.com/jmoiron/sqlx"
type JoinsQueryBuilder struct{}
"github.com/stashapp/stashdb/pkg/database"
)
func NewJoinsQueryBuilder() JoinsQueryBuilder {
return JoinsQueryBuilder{}
type JoinsQueryBuilder struct {
dbi database.DBI
}
func (qb *JoinsQueryBuilder) CreatePerformersScenes(newJoins []PerformersScenes, tx *sqlx.Tx) error {
ensureTx(tx)
for _, join := range newJoins {
_, err := tx.NamedExec(
`INSERT INTO performers_scenes (performer_id, scene_id) VALUES (:performer_id, :scene_id)`,
join,
)
if err != nil {
return err
}
func NewJoinsQueryBuilder(tx *sqlx.Tx) JoinsQueryBuilder {
return JoinsQueryBuilder{
dbi: database.DBIWithTxn(tx),
}
return nil
}
func (qb *JoinsQueryBuilder) UpdatePerformersScenes(sceneID int, updatedJoins []PerformersScenes, tx *sqlx.Tx) error {
ensureTx(tx)
// Delete the existing joins and then create new ones
_, err := tx.Exec("DELETE FROM performers_scenes WHERE scene_id = ?", sceneID)
if err != nil {
return err
}
return qb.CreatePerformersScenes(updatedJoins, tx)
func (qb *JoinsQueryBuilder) CreatePerformersScenes(newJoins PerformersScenes) error {
return qb.dbi.InsertJoins(scenePerformerTable, &newJoins)
}
func (qb *JoinsQueryBuilder) DestroyPerformersScenes(sceneID int, tx *sqlx.Tx) error {
ensureTx(tx)
// Delete the existing joins
_, err := tx.Exec("DELETE FROM performers_scenes WHERE scene_id = ?", sceneID)
return err
func (qb *JoinsQueryBuilder) UpdatePerformersScenes(sceneID int64, updatedJoins PerformersScenes) error {
return qb.dbi.ReplaceJoins(scenePerformerTable, sceneID, &updatedJoins)
}
func (qb *JoinsQueryBuilder) CreateScenesTags(newJoins []ScenesTags, tx *sqlx.Tx) error {
ensureTx(tx)
for _, join := range newJoins {
_, err := tx.NamedExec(
`INSERT INTO scenes_tags (scene_id, tag_id) VALUES (:scene_id, :tag_id)`,
join,
)
if err != nil {
return err
}
}
return nil
func (qb *JoinsQueryBuilder) DestroyPerformersScenes(sceneID int64) error {
return qb.dbi.DeleteJoins(scenePerformerTable, sceneID)
}
func (qb *JoinsQueryBuilder) UpdateScenesTags(sceneID int, updatedJoins []ScenesTags, tx *sqlx.Tx) error {
ensureTx(tx)
// Delete the existing joins and then create new ones
_, err := tx.Exec("DELETE FROM scenes_tags WHERE scene_id = ?", sceneID)
if err != nil {
return err
}
return qb.CreateScenesTags(updatedJoins, tx)
func (qb *JoinsQueryBuilder) CreateScenesTags(newJoins ScenesTags) error {
return qb.dbi.InsertJoins(sceneTagTable, &newJoins)
}
func (qb *JoinsQueryBuilder) DestroyScenesTags(sceneID int, tx *sqlx.Tx) error {
ensureTx(tx)
// Delete the existing joins
_, err := tx.Exec("DELETE FROM scenes_tags WHERE scene_id = ?", sceneID)
return err
func (qb *JoinsQueryBuilder) UpdateScenesTags(sceneID int64, updatedJoins ScenesTags) error {
return qb.dbi.ReplaceJoins(sceneTagTable, sceneID, &updatedJoins)
}
func (qb *JoinsQueryBuilder) CreateSceneMarkersTags(newJoins []SceneMarkersTags, tx *sqlx.Tx) error {
ensureTx(tx)
for _, join := range newJoins {
_, err := tx.NamedExec(
`INSERT INTO scene_markers_tags (scene_marker_id, tag_id) VALUES (:scene_marker_id, :tag_id)`,
join,
)
if err != nil {
return err
}
}
return nil
}
func (qb *JoinsQueryBuilder) UpdateSceneMarkersTags(sceneMarkerID int, updatedJoins []SceneMarkersTags, tx *sqlx.Tx) error {
ensureTx(tx)
// Delete the existing joins and then create new ones
_, err := tx.Exec("DELETE FROM scene_markers_tags WHERE scene_marker_id = ?", sceneMarkerID)
if err != nil {
return err
}
return qb.CreateSceneMarkersTags(updatedJoins, tx)
}
func (qb *JoinsQueryBuilder) DestroySceneMarkersTags(sceneMarkerID int, updatedJoins []SceneMarkersTags, tx *sqlx.Tx) error {
ensureTx(tx)
// Delete the existing joins
_, err := tx.Exec("DELETE FROM scene_markers_tags WHERE scene_marker_id = ?", sceneMarkerID)
return err
}
func (qb *JoinsQueryBuilder) DestroyScenesGalleries(sceneID int, tx *sqlx.Tx) error {
ensureTx(tx)
// Unset the existing scene id from galleries
_, err := tx.Exec("UPDATE galleries SET scene_id = null WHERE scene_id = ?", sceneID)
return err
}
func (qb *JoinsQueryBuilder) DestroyScenesMarkers(sceneID int, tx *sqlx.Tx) error {
ensureTx(tx)
// Delete the scene marker tags
_, err := tx.Exec("DELETE t FROM scene_markers_tags t join scene_markers m on t.scene_marker_id = m.id WHERE m.scene_id = ?", sceneID)
// Delete the existing joins
_, err = tx.Exec("DELETE FROM scene_markers WHERE scene_id = ?", sceneID)
return err
func (qb *JoinsQueryBuilder) DestroyScenesTags(sceneID int64) error {
return qb.dbi.DeleteJoins(sceneTagTable, sceneID)
}

View File

@@ -1,129 +1,84 @@
package models
import (
"database/sql"
"strconv"
"time"
"github.com/jmoiron/sqlx"
"github.com/pkg/errors"
"github.com/stashapp/stashdb/pkg/database"
"github.com/jmoiron/sqlx"
)
type PerformerQueryBuilder struct{}
const performerTable = "performers"
const performerAliasesJoinTable = "performer_aliases"
const performerUrlsJoinTable = "performer_urls"
const performerTattoosJoinTable = "performer_tattoos"
const performerPiercingsJoinTable = "performer_piercings"
const performerJoinKey = "performer_id"
func NewPerformerQueryBuilder() PerformerQueryBuilder {
return PerformerQueryBuilder{}
type PerformerQueryBuilder struct {
dbi database.DBI
}
func (qb *PerformerQueryBuilder) Create(newPerformer Performer, tx *sqlx.Tx) (*Performer, error) {
performerID, err := insertObject(tx, performerTable, newPerformer)
func NewPerformerQueryBuilder(tx *sqlx.Tx) PerformerQueryBuilder {
return PerformerQueryBuilder{
dbi: database.DBIWithTxn(tx),
}
}
if err != nil {
return nil, errors.Wrap(err, "Error creating performer")
func (qb *PerformerQueryBuilder) toModel(ro interface{}) *Performer {
if ro != nil {
return ro.(*Performer)
}
if err := getByID(tx, performerTable, performerID, &newPerformer); err != nil {
return nil, errors.Wrap(err, "Error getting performer after create")
}
return &newPerformer, nil
return nil
}
func (qb *PerformerQueryBuilder) Update(updatedPerformer Performer, tx *sqlx.Tx) (*Performer, error) {
err := updateObjectByID(tx, performerTable, updatedPerformer)
if err != nil {
return nil, errors.Wrap(err, "Error updating performer")
}
if err := getByID(tx, performerTable, updatedPerformer.ID, &updatedPerformer); err != nil {
return nil, errors.Wrap(err, "Error getting performer after update")
}
return &updatedPerformer, nil
func (qb *PerformerQueryBuilder) Create(newPerformer Performer) (*Performer, error) {
ret, err := qb.dbi.Insert(newPerformer)
return qb.toModel(ret), err
}
func (qb *PerformerQueryBuilder) Destroy(id int64, tx *sqlx.Tx) error {
return executeDeleteQuery(performerTable, id, tx)
func (qb *PerformerQueryBuilder) Update(updatedPerformer Performer) (*Performer, error) {
ret, err := qb.dbi.Update(updatedPerformer)
return qb.toModel(ret), err
}
func (qb *PerformerQueryBuilder) CreateAliases(newJoins []PerformerAliases, tx *sqlx.Tx) error {
return insertJoins(tx, performerAliasesJoinTable, newJoins)
func (qb *PerformerQueryBuilder) Destroy(id int64) error {
return qb.dbi.Delete(id, performerDBTable)
}
func (qb *PerformerQueryBuilder) UpdateAliases(performerID int64, updatedJoins []PerformerAliases, tx *sqlx.Tx) error {
ensureTx(tx)
// Delete the existing joins and then create new ones
err := deleteObjectsByColumn(tx, performerAliasesJoinTable, performerJoinKey, performerID)
if err != nil {
return err
}
return qb.CreateAliases(updatedJoins, tx)
func (qb *PerformerQueryBuilder) CreateAliases(newJoins PerformerAliases) error {
return qb.dbi.InsertJoins(performerAliasTable, &newJoins)
}
func (qb *PerformerQueryBuilder) CreateUrls(newJoins []PerformerUrls, tx *sqlx.Tx) error {
return insertJoins(tx, performerUrlsJoinTable, newJoins)
func (qb *PerformerQueryBuilder) UpdateAliases(performerID int64, updatedJoins PerformerAliases) error {
return qb.dbi.ReplaceJoins(performerAliasTable, performerID, &updatedJoins)
}
func (qb *PerformerQueryBuilder) UpdateUrls(performerID int64, updatedJoins []PerformerUrls, tx *sqlx.Tx) error {
ensureTx(tx)
// Delete the existing joins and then create new ones
err := deleteObjectsByColumn(tx, performerUrlsJoinTable, performerJoinKey, performerID)
if err != nil {
return err
}
return qb.CreateUrls(updatedJoins, tx)
func (qb *PerformerQueryBuilder) CreateUrls(newJoins PerformerUrls) error {
return qb.dbi.InsertJoins(performerUrlTable, &newJoins)
}
func (qb *PerformerQueryBuilder) CreateTattoos(newJoins []PerformerBodyMods, tx *sqlx.Tx) error {
return insertJoins(tx, performerTattoosJoinTable, newJoins)
func (qb *PerformerQueryBuilder) UpdateUrls(performerID int64, updatedJoins PerformerUrls) error {
return qb.dbi.ReplaceJoins(performerUrlTable, performerID, &updatedJoins)
}
func (qb *PerformerQueryBuilder) UpdateTattoos(performerID int64, updatedJoins []PerformerBodyMods, tx *sqlx.Tx) error {
ensureTx(tx)
// Delete the existing joins and then create new ones
err := deleteObjectsByColumn(tx, performerTattoosJoinTable, performerJoinKey, performerID)
if err != nil {
return err
}
return qb.CreateTattoos(updatedJoins, tx)
func (qb *PerformerQueryBuilder) CreateTattoos(newJoins PerformerBodyMods) error {
return qb.dbi.InsertJoins(performerTattooTable, &newJoins)
}
func (qb *PerformerQueryBuilder) CreatePiercings(newJoins []PerformerBodyMods, tx *sqlx.Tx) error {
return insertJoins(tx, performerPiercingsJoinTable, newJoins)
func (qb *PerformerQueryBuilder) UpdateTattoos(performerID int64, updatedJoins PerformerBodyMods) error {
return qb.dbi.ReplaceJoins(performerTattooTable, performerID, &updatedJoins)
}
func (qb *PerformerQueryBuilder) UpdatePiercings(performerID int64, updatedJoins []PerformerBodyMods, tx *sqlx.Tx) error {
ensureTx(tx)
// Delete the existing joins and then create new ones
err := deleteObjectsByColumn(tx, performerPiercingsJoinTable, performerJoinKey, performerID)
if err != nil {
return err
}
return qb.CreateTattoos(updatedJoins, tx)
func (qb *PerformerQueryBuilder) CreatePiercings(newJoins PerformerBodyMods) error {
return qb.dbi.InsertJoins(performerPiercingTable, &newJoins)
}
func (qb *PerformerQueryBuilder) Find(id int) (*Performer, error) {
query := "SELECT * FROM performers WHERE id = ? LIMIT 1"
args := []interface{}{id}
results, err := qb.queryPerformers(query, args, nil)
if err != nil || len(results) < 1 {
return nil, err
}
return results[0], nil
func (qb *PerformerQueryBuilder) UpdatePiercings(performerID int64, updatedJoins PerformerBodyMods) error {
return qb.dbi.ReplaceJoins(performerPiercingTable, performerID, &updatedJoins)
}
func (qb *PerformerQueryBuilder) FindBySceneID(sceneID int, tx *sqlx.Tx) ([]*Performer, error) {
func (qb *PerformerQueryBuilder) Find(id int64) (*Performer, error) {
ret, err := qb.dbi.Find(id, performerDBTable)
return qb.toModel(ret), err
}
func (qb *PerformerQueryBuilder) FindBySceneID(sceneID int) (Performers, error) {
query := `
SELECT performers.* FROM performers
LEFT JOIN performers_scenes as scenes_join on scenes_join.performer_id = performers.id
@@ -132,19 +87,19 @@ func (qb *PerformerQueryBuilder) FindBySceneID(sceneID int, tx *sqlx.Tx) ([]*Per
GROUP BY performers.id
`
args := []interface{}{sceneID}
return qb.queryPerformers(query, args, tx)
return qb.queryPerformers(query, args)
}
func (qb *PerformerQueryBuilder) FindByNames(names []string, tx *sqlx.Tx) ([]*Performer, error) {
func (qb *PerformerQueryBuilder) FindByNames(names []string) (Performers, error) {
query := "SELECT * FROM performers WHERE name IN " + getInBinding(len(names))
var args []interface{}
for _, name := range names {
args = append(args, name)
}
return qb.queryPerformers(query, args, tx)
return qb.queryPerformers(query, args)
}
func (qb *PerformerQueryBuilder) FindByAliases(names []string, tx *sqlx.Tx) ([]*Performer, error) {
func (qb *PerformerQueryBuilder) FindByAliases(names []string) (Performers, error) {
query := `SELECT performers.* FROM performers
left join performer_aliases on performers.id = performer_aliases.performer_id
WHERE performer_aliases.alias IN ` + getInBinding(len(names))
@@ -153,24 +108,24 @@ func (qb *PerformerQueryBuilder) FindByAliases(names []string, tx *sqlx.Tx) ([]*
for _, name := range names {
args = append(args, name)
}
return qb.queryPerformers(query, args, tx)
return qb.queryPerformers(query, args)
}
func (qb *PerformerQueryBuilder) FindByName(name string, tx *sqlx.Tx) ([]*Performer, error) {
func (qb *PerformerQueryBuilder) FindByName(name string) (Performers, error) {
query := "SELECT * FROM performers WHERE upper(name) = upper(?)"
var args []interface{}
args = append(args, name)
return qb.queryPerformers(query, args, tx)
return qb.queryPerformers(query, args)
}
func (qb *PerformerQueryBuilder) FindByAlias(name string, tx *sqlx.Tx) ([]*Performer, error) {
func (qb *PerformerQueryBuilder) FindByAlias(name string) (Performers, error) {
query := `SELECT performers.* FROM performers
left join performer_aliases on performers.id = performer_aliases.performer_id
WHERE upper(performer_aliases.alias) = UPPER(?)`
var args []interface{}
args = append(args, name)
return qb.queryPerformers(query, args, tx)
return qb.queryPerformers(query, args)
}
func (qb *PerformerQueryBuilder) Count() (int, error) {
@@ -233,27 +188,6 @@ func (qb *PerformerQueryBuilder) Query(performerFilter *PerformerFilterType, fin
return performers, countResult
}
func handleStringCriterion(column string, value *StringCriterionInput, query *queryBuilder) {
if value != nil {
if modifier := value.Modifier.String(); value.Modifier.IsValid() {
switch modifier {
case "EQUALS":
clause, thisArgs := getSearchBinding([]string{column}, value.Value, false)
query.addWhere(clause)
query.addArg(thisArgs...)
case "NOT_EQUALS":
clause, thisArgs := getSearchBinding([]string{column}, value.Value, true)
query.addWhere(clause)
query.addArg(thisArgs...)
case "IS_NULL":
query.addWhere(column + " IS NULL")
case "NOT_NULL":
query.addWhere(column + " IS NOT NULL")
}
}
}
}
func getBirthYearFilterClause(criterionModifier CriterionModifier, value int) ([]string, []interface{}) {
var clauses []string
var args []interface{}
@@ -338,142 +272,36 @@ func (qb *PerformerQueryBuilder) getPerformerSort(findFilter *QuerySpec) string
return getSort(sort, direction, "performers")
}
func (qb *PerformerQueryBuilder) queryPerformers(query string, args []interface{}, tx *sqlx.Tx) ([]*Performer, error) {
var rows *sqlx.Rows
var err error
if tx != nil {
rows, err = tx.Queryx(query, args...)
} else {
rows, err = database.DB.Queryx(query, args...)
}
if err != nil && err != sql.ErrNoRows {
return nil, err
}
defer rows.Close()
performers := make([]*Performer, 0)
for rows.Next() {
performer := Performer{}
if err := rows.StructScan(&performer); err != nil {
return nil, err
}
performers = append(performers, &performer)
}
if err := rows.Err(); err != nil {
return nil, err
}
return performers, nil
func (qb *PerformerQueryBuilder) queryPerformers(query string, args []interface{}) (Performers, error) {
output := Performers{}
err := qb.dbi.RawQuery(performerDBTable, query, args, &output)
return output, err
}
func (qb *PerformerQueryBuilder) GetAliases(id int64) ([]string, error) {
query := "SELECT alias FROM performer_aliases WHERE performer_id = ?"
args := []interface{}{id}
joins := PerformerAliases{}
err := qb.dbi.FindJoins(performerAliasTable, id, &joins)
var rows *sqlx.Rows
var err error
rows, err = database.DB.Queryx(query, args...)
if err != nil && err != sql.ErrNoRows {
return nil, err
}
defer rows.Close()
aliases := make([]string, 0)
for rows.Next() {
var alias string
if err := rows.Scan(&alias); err != nil {
return nil, err
}
aliases = append(aliases, alias)
}
if err := rows.Err(); err != nil {
return nil, err
}
return aliases, nil
return joins.ToAliases(), err
}
func (qb *PerformerQueryBuilder) GetUrls(id int64) ([]PerformerUrls, error) {
query := "SELECT url, type FROM performer_urls WHERE performer_id = ?"
args := []interface{}{id}
func (qb *PerformerQueryBuilder) GetUrls(id int64) (PerformerUrls, error) {
joins := PerformerUrls{}
err := qb.dbi.FindJoins(performerUrlTable, id, &joins)
var rows *sqlx.Rows
var err error
rows, err = database.DB.Queryx(query, args...)
if err != nil && err != sql.ErrNoRows {
return nil, err
}
defer rows.Close()
urls := make([]PerformerUrls, 0)
for rows.Next() {
var performerUrl PerformerUrls
if err := rows.Scan(&performerUrl); err != nil {
return nil, err
}
urls = append(urls, performerUrl)
}
if err := rows.Err(); err != nil {
return nil, err
}
return urls, nil
return joins, err
}
func translateBodyMods(rows *sqlx.Rows) ([]PerformerBodyMods, error) {
ret := make([]PerformerBodyMods, 0)
for rows.Next() {
var performerBodyMod PerformerBodyMods
func (qb *PerformerQueryBuilder) GetTattoos(id int64) (PerformerBodyMods, error) {
joins := PerformerBodyMods{}
err := qb.dbi.FindJoins(performerTattooTable, id, &joins)
if err := rows.Scan(&performerBodyMod); err != nil {
return nil, err
}
ret = append(ret, performerBodyMod)
}
if err := rows.Err(); err != nil {
return nil, err
}
return ret, nil
return joins, err
}
func (qb *PerformerQueryBuilder) GetTattoos(id int64) ([]PerformerBodyMods, error) {
query := "SELECT location, description FROM performer_tattoos WHERE performer_id = ?"
args := []interface{}{id}
func (qb *PerformerQueryBuilder) GetPiercings(id int64) (PerformerBodyMods, error) {
joins := PerformerBodyMods{}
err := qb.dbi.FindJoins(performerPiercingTable, id, &joins)
var rows *sqlx.Rows
var err error
rows, err = database.DB.Queryx(query, args...)
if err != nil && err != sql.ErrNoRows {
return nil, err
}
defer rows.Close()
return translateBodyMods(rows)
}
func (qb *PerformerQueryBuilder) GetPiercings(id int64) ([]PerformerBodyMods, error) {
query := "SELECT location, description FROM performer_piercings WHERE performer_id = ?"
args := []interface{}{id}
var rows *sqlx.Rows
var err error
rows, err = database.DB.Queryx(query, args...)
if err != nil && err != sql.ErrNoRows {
return nil, err
}
defer rows.Close()
return translateBodyMods(rows)
return joins, err
}

View File

@@ -0,0 +1,194 @@
package models
import (
"github.com/jmoiron/sqlx"
"github.com/stashapp/stashdb/pkg/database"
)
type SceneQueryBuilder struct {
dbi database.DBI
}
func NewSceneQueryBuilder(tx *sqlx.Tx) SceneQueryBuilder {
return SceneQueryBuilder{
dbi: database.DBIWithTxn(tx),
}
}
func (qb *SceneQueryBuilder) toModel(ro interface{}) *Scene {
if ro != nil {
return ro.(*Scene)
}
return nil
}
func (qb *SceneQueryBuilder) Create(newScene Scene) (*Scene, error) {
ret, err := qb.dbi.Insert(newScene)
return qb.toModel(ret), err
}
func (qb *SceneQueryBuilder) Update(updatedScene Scene) (*Scene, error) {
ret, err := qb.dbi.Update(updatedScene)
return qb.toModel(ret), err
}
func (qb *SceneQueryBuilder) Destroy(id int64) error {
return qb.dbi.Delete(id, sceneDBTable)
}
func (qb *SceneQueryBuilder) CreateFingerprints(newJoins SceneFingerprints) error {
return qb.dbi.InsertJoins(sceneFingerprintTable, &newJoins)
}
func (qb *SceneQueryBuilder) UpdateFingerprints(sceneID int64, updatedJoins SceneFingerprints) error {
return qb.dbi.ReplaceJoins(sceneFingerprintTable, sceneID, &updatedJoins)
}
func (qb *SceneQueryBuilder) Find(id int64) (*Scene, error) {
ret, err := qb.dbi.Find(id, sceneDBTable)
return qb.toModel(ret), err
}
func (qb *SceneQueryBuilder) FindByFingerprint(algorithm FingerprintAlgorithm, hash string) ([]*Scene, error) {
query := `
SELECT scenes.* FROM scenes
LEFT JOIN scene_fingerprints as scenes_join on scenes_join.scene_id = scenes.id
WHERE scenes_join.algorithm = ? AND scenes_join.hash = ?`
var args []interface{}
args = append(args, algorithm.String())
args = append(args, hash)
return qb.queryScenes(query, args)
}
// func (qb *SceneQueryBuilder) FindByStudioID(sceneID int) ([]*Scene, error) {
// query := `
// SELECT scenes.* FROM scenes
// LEFT JOIN scenes_scenes as scenes_join on scenes_join.scene_id = scenes.id
// LEFT JOIN scenes on scenes_join.scene_id = scenes.id
// WHERE scenes.id = ?
// GROUP BY scenes.id
// `
// args := []interface{}{sceneID}
// return qb.queryScenes(query, args)
// }
// func (qb *SceneQueryBuilder) FindByChecksum(checksum string) (*Scene, error) {
// query := `SELECT scenes.* FROM scenes
// left join scene_checksums on scenes.id = scene_checksums.scene_id
// WHERE scene_checksums.checksum = ?`
// var args []interface{}
// args = append(args, checksum)
// results, err := qb.queryScenes(query, args)
// if err != nil || len(results) < 1 {
// return nil, err
// }
// return results[0], nil
// }
// func (qb *SceneQueryBuilder) FindByChecksums(checksums []string) ([]*Scene, error) {
// query := `SELECT scenes.* FROM scenes
// left join scene_checksums on scenes.id = scene_checksums.scene_id
// WHERE scene_checksums.checksum IN ` + getInBinding(len(checksums))
// var args []interface{}
// for _, name := range checksums {
// args = append(args, name)
// }
// return qb.queryScenes(query, args)
// }
func (qb *SceneQueryBuilder) FindByTitle(name string) ([]*Scene, error) {
query := "SELECT * FROM scenes WHERE upper(title) = upper(?)"
var args []interface{}
args = append(args, name)
return qb.queryScenes(query, args)
}
func (qb *SceneQueryBuilder) Count() (int, error) {
return runCountQuery(buildCountQuery("SELECT scenes.id FROM scenes"), nil)
}
func (qb *SceneQueryBuilder) Query(sceneFilter *SceneFilterType, findFilter *QuerySpec) ([]*Scene, int) {
if sceneFilter == nil {
sceneFilter = &SceneFilterType{}
}
if findFilter == nil {
findFilter = &QuerySpec{}
}
query := queryBuilder{
tableName: sceneTable,
}
query.body = selectDistinctIDs(sceneTable)
if q := sceneFilter.Text; q != nil && *q != "" {
searchColumns := []string{"scenes.title", "scenes.details"}
clause, thisArgs := getSearchBinding(searchColumns, *q, false)
query.addWhere(clause)
query.addArg(thisArgs...)
}
if q := sceneFilter.Title; q != nil && *q != "" {
searchColumns := []string{"scenes.title"}
clause, thisArgs := getSearchBinding(searchColumns, *q, false)
query.addWhere(clause)
query.addArg(thisArgs...)
}
if q := sceneFilter.URL; q != nil && *q != "" {
searchColumns := []string{"scenes.url"}
clause, thisArgs := getSearchBinding(searchColumns, *q, false)
query.addWhere(clause)
query.addArg(thisArgs...)
}
// TODO - other filters
query.sortAndPagination = qb.getSceneSort(findFilter) + getPagination(findFilter)
idsResult, countResult := query.executeFind()
var scenes []*Scene
for _, id := range idsResult {
scene, _ := qb.Find(id)
scenes = append(scenes, scene)
}
return scenes, countResult
}
func (qb *SceneQueryBuilder) getSceneSort(findFilter *QuerySpec) string {
var sort string
var direction string
if findFilter == nil {
sort = "title"
direction = "ASC"
} else {
sort = findFilter.GetSort("title")
direction = findFilter.GetDirection()
}
return getSort(sort, direction, "scenes")
}
func (qb *SceneQueryBuilder) queryScenes(query string, args []interface{}) (Scenes, error) {
output := Scenes{}
err := qb.dbi.RawQuery(sceneDBTable, query, args, &output)
return output, err
}
func (qb *SceneQueryBuilder) GetFingerprints(id int64) ([]*Fingerprint, error) {
joins := SceneFingerprints{}
err := qb.dbi.FindJoins(sceneFingerprintTable, id, &joins)
return joins.ToFingerprints(), err
}
func (qb *SceneQueryBuilder) GetPerformers(id int64) (PerformersScenes, error) {
joins := PerformersScenes{}
err := qb.dbi.FindJoins(scenePerformerTable, id, &joins)
return joins, err
}

View File

@@ -27,7 +27,7 @@ type queryBuilder struct {
sortAndPagination string
}
func (qb queryBuilder) executeFind() ([]int, int) {
func (qb queryBuilder) executeFind() ([]int64, int) {
return executeFindQuery(qb.tableName, qb.body, qb.args, qb.sortAndPagination, qb.whereClauses, qb.havingClauses)
}
@@ -43,6 +43,27 @@ func (qb *queryBuilder) addArg(args ...interface{}) {
qb.args = append(qb.args, args...)
}
func handleStringCriterion(column string, value *StringCriterionInput, query *queryBuilder) {
if value != nil {
if modifier := value.Modifier.String(); value.Modifier.IsValid() {
switch modifier {
case "EQUALS":
clause, thisArgs := getSearchBinding([]string{column}, value.Value, false)
query.addWhere(clause)
query.addArg(thisArgs...)
case "NOT_EQUALS":
clause, thisArgs := getSearchBinding([]string{column}, value.Value, true)
query.addWhere(clause)
query.addArg(thisArgs...)
case "IS_NULL":
query.addWhere(column + " IS NULL")
case "NOT_NULL":
query.addWhere(column + " IS NOT NULL")
}
}
}
}
func insertObject(tx *sqlx.Tx, table string, object interface{}) (int64, error) {
ensureTx(tx)
fields, values := SQLGenKeysCreate(object)
@@ -96,7 +117,7 @@ func deleteObjectsByColumn(tx *sqlx.Tx, table string, column string, value inter
}
func getByID(tx *sqlx.Tx, table string, id int64, object interface{}) error {
return tx.Get(object, `SELECT * FROM performers WHERE id = ? LIMIT 1`, id)
return tx.Get(object, `SELECT * FROM `+table+` WHERE id = ? LIMIT 1`, id)
}
func selectAll(tableName string) string {
@@ -166,9 +187,7 @@ func getSort(sort string, direction string, tableName string) string {
} else {
colName := getColumn(tableName, sort)
var additional string
if tableName == "scenes" {
additional = ", bitrate DESC, framerate DESC, rating DESC, duration DESC"
} else if tableName == "scene_markers" {
if tableName == "scene_markers" {
additional = ", scene_markers.scene_id ASC, scene_markers.seconds ASC"
}
return " ORDER BY " + colName + " " + direction + additional
@@ -236,15 +255,15 @@ func getInBinding(length int) string {
return "(" + bindings + ")"
}
func runIdsQuery(query string, args []interface{}) ([]int, error) {
func runIdsQuery(query string, args []interface{}) ([]int64, error) {
var result []struct {
Int int `db:"id"`
Int int64 `db:"id"`
}
if err := database.DB.Select(&result, query, args...); err != nil && err != sql.ErrNoRows {
return []int{}, err
return []int64{}, err
}
vsm := make([]int, len(result))
vsm := make([]int64, len(result))
for i, v := range result {
vsm[i] = v.Int
}
@@ -263,7 +282,7 @@ func runCountQuery(query string, args []interface{}) (int, error) {
return result.Int, nil
}
func executeFindQuery(tableName string, body string, args []interface{}, sortAndPagination string, whereClauses []string, havingClauses []string) ([]int, int) {
func executeFindQuery(tableName string, body string, args []interface{}, sortAndPagination string, whereClauses []string, havingClauses []string) ([]int64, int) {
if len(whereClauses) > 0 {
body = body + " WHERE " + strings.Join(whereClauses, " AND ") // TODO handle AND or OR
}

View File

@@ -0,0 +1,152 @@
package models
import (
"github.com/jmoiron/sqlx"
"github.com/stashapp/stashdb/pkg/database"
)
type StudioQueryBuilder struct {
dbi database.DBI
}
func NewStudioQueryBuilder(tx *sqlx.Tx) StudioQueryBuilder {
return StudioQueryBuilder{
dbi: database.DBIWithTxn(tx),
}
}
func (qb *StudioQueryBuilder) toModel(ro interface{}) *Studio {
if ro != nil {
return ro.(*Studio)
}
return nil
}
func (qb *StudioQueryBuilder) Create(newStudio Studio) (*Studio, error) {
ret, err := qb.dbi.Insert(newStudio)
return qb.toModel(ret), err
}
func (qb *StudioQueryBuilder) Update(updatedStudio Studio) (*Studio, error) {
ret, err := qb.dbi.Update(updatedStudio)
return qb.toModel(ret), err
}
func (qb *StudioQueryBuilder) Destroy(id int64) error {
return qb.dbi.Delete(id, studioDBTable)
}
func (qb *StudioQueryBuilder) CreateUrls(newJoins StudioUrls) error {
return qb.dbi.InsertJoins(studioUrlTable, &newJoins)
}
func (qb *StudioQueryBuilder) UpdateUrls(studio int64, updatedJoins StudioUrls) error {
return qb.dbi.ReplaceJoins(studioUrlTable, studio, &updatedJoins)
}
func (qb *StudioQueryBuilder) Find(id int64) (*Studio, error) {
ret, err := qb.dbi.Find(id, studioDBTable)
return qb.toModel(ret), err
}
func (qb *StudioQueryBuilder) FindBySceneID(sceneID int) (Studios, error) {
query := `
SELECT studios.* FROM studios
LEFT JOIN scenes on scenes.studio_id = studios.id
WHERE scenes.id = ?
GROUP BY studios.id
`
args := []interface{}{sceneID}
return qb.queryStudios(query, args)
}
func (qb *StudioQueryBuilder) FindByNames(names []string) (Studios, error) {
query := "SELECT * FROM studios WHERE name IN " + getInBinding(len(names))
var args []interface{}
for _, name := range names {
args = append(args, name)
}
return qb.queryStudios(query, args)
}
func (qb *StudioQueryBuilder) FindByName(name string) (*Studio, error) {
query := "SELECT * FROM studios WHERE upper(name) = upper(?)"
var args []interface{}
args = append(args, name)
results, err := qb.queryStudios(query, args)
if err != nil || len(results) < 1 {
return nil, err
}
return results[0], nil
}
func (qb *StudioQueryBuilder) FindByParentID(id int64) (Studios, error) {
query := "SELECT * FROM studios WHERE parent_studio_id = ?"
var args []interface{}
args = append(args, id)
return qb.queryStudios(query, args)
}
func (qb *StudioQueryBuilder) Count() (int, error) {
return runCountQuery(buildCountQuery("SELECT studios.id FROM studios"), nil)
}
func (qb *StudioQueryBuilder) Query(studioFilter *StudioFilterType, findFilter *QuerySpec) (Studios, int) {
if studioFilter == nil {
studioFilter = &StudioFilterType{}
}
if findFilter == nil {
findFilter = &QuerySpec{}
}
query := queryBuilder{
tableName: studioTable,
}
query.body = selectDistinctIDs(studioTable)
if q := studioFilter.Name; q != nil && *q != "" {
searchColumns := []string{"studios.name"}
clause, thisArgs := getSearchBinding(searchColumns, *q, false)
query.addWhere(clause)
query.addArg(thisArgs...)
}
query.sortAndPagination = qb.getStudioSort(findFilter) + getPagination(findFilter)
idsResult, countResult := query.executeFind()
var studios []*Studio
for _, id := range idsResult {
studio, _ := qb.Find(id)
studios = append(studios, studio)
}
return studios, countResult
}
func (qb *StudioQueryBuilder) getStudioSort(findFilter *QuerySpec) string {
var sort string
var direction string
if findFilter == nil {
sort = "name"
direction = "ASC"
} else {
sort = findFilter.GetSort("name")
direction = findFilter.GetDirection()
}
return getSort(sort, direction, "studios")
}
func (qb *StudioQueryBuilder) queryStudios(query string, args []interface{}) (Studios, error) {
var output Studios
err := qb.dbi.RawQuery(studioDBTable, query, args, &output)
return output, err
}
func (qb *StudioQueryBuilder) GetUrls(id int64) (StudioUrls, error) {
joins := StudioUrls{}
err := qb.dbi.FindJoins(studioUrlTable, id, &joins)
return joins, err
}

View File

@@ -0,0 +1,178 @@
package models
import (
"github.com/jmoiron/sqlx"
"github.com/stashapp/stashdb/pkg/database"
)
type TagQueryBuilder struct {
dbi database.DBI
}
func NewTagQueryBuilder(tx *sqlx.Tx) TagQueryBuilder {
return TagQueryBuilder{
dbi: database.DBIWithTxn(tx),
}
}
func (qb *TagQueryBuilder) toModel(ro interface{}) *Tag {
if ro != nil {
return ro.(*Tag)
}
return nil
}
func (qb *TagQueryBuilder) Create(newTag Tag) (*Tag, error) {
ret, err := qb.dbi.Insert(newTag)
return qb.toModel(ret), err
}
func (qb *TagQueryBuilder) Update(updatedTag Tag) (*Tag, error) {
ret, err := qb.dbi.Update(updatedTag)
return qb.toModel(ret), err
}
func (qb *TagQueryBuilder) Destroy(id int64) error {
return qb.dbi.Delete(id, tagDBTable)
}
func (qb *TagQueryBuilder) CreateAliases(newJoins TagAliases) error {
return qb.dbi.InsertJoins(tagAliasTable, &newJoins)
}
func (qb *TagQueryBuilder) UpdateAliases(tagID int64, updatedJoins TagAliases) error {
return qb.dbi.ReplaceJoins(tagAliasTable, tagID, &updatedJoins)
}
func (qb *TagQueryBuilder) Find(id int64) (*Tag, error) {
ret, err := qb.dbi.Find(id, tagDBTable)
return qb.toModel(ret), err
}
func (qb *TagQueryBuilder) FindByNameOrAlias(name string) (*Tag, error) {
query := `SELECT tags.* FROM tags
left join tag_aliases on tags.id = tag_aliases.tag_id
WHERE tag_aliases.alias = ? OR tags.name = ?`
args := []interface{}{name, name}
results, err := qb.queryTags(query, args)
if err != nil || len(results) < 1 {
return nil, err
}
return results[0], nil
}
func (qb *TagQueryBuilder) FindBySceneID(sceneID int64) ([]*Tag, error) {
query := `
SELECT tags.* FROM tags
LEFT JOIN scene_tags as scenes_join on scenes_join.tag_id = tags.id
LEFT JOIN scenes on scenes_join.scene_id = scenes.id
WHERE scenes.id = ?
GROUP BY tags.id
`
args := []interface{}{sceneID}
return qb.queryTags(query, args)
}
func (qb *TagQueryBuilder) FindByNames(names []string) ([]*Tag, error) {
query := "SELECT * FROM tags WHERE name IN " + getInBinding(len(names))
var args []interface{}
for _, name := range names {
args = append(args, name)
}
return qb.queryTags(query, args)
}
func (qb *TagQueryBuilder) FindByAliases(names []string) ([]*Tag, error) {
query := `SELECT tags.* FROM tags
left join tag_aliases on tags.id = tag_aliases.tag_id
WHERE tag_aliases.alias IN ` + getInBinding(len(names))
var args []interface{}
for _, name := range names {
args = append(args, name)
}
return qb.queryTags(query, args)
}
func (qb *TagQueryBuilder) FindByName(name string) ([]*Tag, error) {
query := "SELECT * FROM tags WHERE upper(name) = upper(?)"
var args []interface{}
args = append(args, name)
return qb.queryTags(query, args)
}
func (qb *TagQueryBuilder) FindByAlias(name string) ([]*Tag, error) {
query := `SELECT tags.* FROM tags
left join tag_aliases on tag.id = tag_aliases.tag_id
WHERE upper(tag_aliases.alias) = UPPER(?)`
var args []interface{}
args = append(args, name)
return qb.queryTags(query, args)
}
func (qb *TagQueryBuilder) Count() (int, error) {
return runCountQuery(buildCountQuery("SELECT tags.id FROM tags"), nil)
}
func (qb *TagQueryBuilder) Query(tagFilter *TagFilterType, findFilter *QuerySpec) ([]*Tag, int) {
if tagFilter == nil {
tagFilter = &TagFilterType{}
}
if findFilter == nil {
findFilter = &QuerySpec{}
}
query := queryBuilder{
tableName: tagTable,
}
query.body = selectDistinctIDs(tagTable)
if q := tagFilter.Name; q != nil && *q != "" {
searchColumns := []string{"tags.name"}
clause, thisArgs := getSearchBinding(searchColumns, *q, false)
query.addWhere(clause)
query.addArg(thisArgs...)
}
query.sortAndPagination = qb.getTagSort(findFilter) + getPagination(findFilter)
idsResult, countResult := query.executeFind()
var tags []*Tag
for _, id := range idsResult {
tag, _ := qb.Find(id)
tags = append(tags, tag)
}
return tags, countResult
}
func (qb *TagQueryBuilder) getTagSort(findFilter *QuerySpec) string {
var sort string
var direction string
if findFilter == nil {
sort = "name"
direction = "ASC"
} else {
sort = findFilter.GetSort("name")
direction = findFilter.GetDirection()
}
return getSort(sort, direction, tagTable)
}
func (qb *TagQueryBuilder) queryTags(query string, args []interface{}) (Tags, error) {
var output Tags
err := qb.dbi.RawQuery(tagDBTable, query, args, &output)
return output, err
}
func (qb *TagQueryBuilder) GetAliases(id int64) ([]string, error) {
joins := TagAliases{}
err := qb.dbi.FindJoins(tagAliasTable, id, &joins)
return joins.ToAliases(), err
}

View File

@@ -2,9 +2,10 @@ package models
import (
"database/sql/driver"
"time"
"github.com/stashapp/stashdb/pkg/logger"
"github.com/stashapp/stashdb/pkg/utils"
"time"
)
type SQLiteDate struct {
@@ -38,3 +39,7 @@ func (t SQLiteDate) Value() (driver.Value, error) {
}
return result, nil
}
func (t SQLiteDate) IsValid() bool {
return t.Valid
}

View File

@@ -19,3 +19,7 @@ func (t *SQLiteTimestamp) Scan(value interface{}) error {
func (t SQLiteTimestamp) Value() (driver.Value, error) {
return t.Timestamp.Format(time.RFC3339), nil
}
func (t SQLiteTimestamp) IsValid() bool {
return !t.Timestamp.IsZero()
}