diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..4fea60b --- /dev/null +++ b/.gitattributes @@ -0,0 +1,2 @@ +go.mod text eol=lf +go.sum text eol=lf \ No newline at end of file diff --git a/Makefile b/Makefile index 532a12e..cced337 100644 --- a/Makefile +++ b/Makefile @@ -20,6 +20,14 @@ generate: go generate packr2 +.PHONY: test +test: + go test ./... + +.PHONY: it +it: + go test -tags=integration ./... + # Runs gofmt -w on the project's source code, modifying any files that do not match its style. .PHONY: fmt fmt: diff --git a/README.md b/README.md index f85a146..235483d 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,98 @@ # stash-box -Stash App's own OpenSource video indexing and Perceptual Hashing MetaData API + +[![Discord](https://img.shields.io/discord/559159668438728723.svg?logo=discord)](https://discord.gg/2TsNFKt) + +**stash-box is Stash App's own OpenSource video indexing and Perceptual Hashing MetaData API for porn.** + +The intent of stash-box is to provide a collaborative, crowd-sourced database of porn metadata, in the same way as [MusicBrainz](https://musicbrainz.org/) does for music. The submission and editing of metadata is expected to follow the same principle as that of the MusicBrainz database. [See here](https://musicbrainz.org/doc/Editing_FAQ) for how MusicBrainz does it. + +Currently, stash-box provides a graphql backend API only. There is no built in UI. The graphql playground can be accessed at `host:port/playground`. The graphql interface is at `host:port/graphql`. + +# Docker install + +TODO + +# Bare-metal Install + +Stash-box supports macOS, Windows, and Linux. + +Releases TODO + +## CLI + +Stash-box provides some command line options. See what is currently available by running `stashdb --help`. + +For example, to run stash locally on port 80 run it like this (OSX / Linux) `stashdb --host 127.0.0.1 --port 80` + +## Configuration + +Stash-box generates a configuration file in the current working directory when it is first started up. This configuration file is generated with the following defaults: +- running on `0.0.0.0` port `9998` +- sqlite3 database generated in the current working directory named `stashdb-go.sqlite` +- generated read (`read_api_key`) and write (`modify_api_key`) API keys. These can be deleted to disable read/write authentication (all requests will be allowed without API key) + +### API keys + +These are a very basic authorization method. When set, the `ApiKey` header must be set to the correct value to read/write the data. The write API key allows reading and writing. The read API key allows only reading. + +## SSL (HTTPS) + +Stash-box supports HTTPS with some additional work. First you must generate a SSL certificate and key combo. Here is an example using openssl: + +`openssl req -x509 -newkey rsa:4096 -sha256 -days 7300 -nodes -keyout stashdb.key -out stashdb.crt -extensions san -config <(echo "[req]"; echo distinguished_name=req; echo "[san]"; echo subjectAltName=DNS:stashdb.server,IP:127.0.0.1) -subj /CN=stashdb.server` + +This command would need customizing for your environment. [This link](https://stackoverflow.com/questions/10175812/how-to-create-a-self-signed-certificate-with-openssl) might be useful. + +Once you have a certificate and key file name them `stashdb.crt` and `stashdb.key` and place them in the directory where stash-box is run from. Stash-box detects these and starts up using HTTPS rather than HTTP. + +# FAQ + +> I have a question not answered here. + +Join the [Discord server](https://discord.gg/2TsNFKt). + +# Development + +## Install + +* [Revive](https://github.com/mgechev/revive) - Configurable linter + * Go Install: `go get github.com/mgechev/revive` +* [Packr2](https://github.com/gobuffalo/packr/tree/v2.0.2/v2) - Static asset bundler + * Go Install: `go get github.com/gobuffalo/packr/v2/packr2@v2.0.2` + * [Binary Download](https://github.com/gobuffalo/packr/releases) +* [Yarn](https://yarnpkg.com/en/docs/install) - Yarn package manager + +NOTE: You may need to run the `go get` commands outside the project directory to avoid modifying the projects module file. + +## Environment + +### macOS + +TODO + +### Windows + +1. Download and install [Go for Windows](https://golang.org/dl/) +2. Download and install [MingW](https://sourceforge.net/projects/mingw-w64/) +3. Search for "advanced system settings" and open the system properties dialog. + 1. Click the `Environment Variables` button + 2. Add `GO111MODULE=on` + 3. Under system variables find the `Path`. Edit and add `C:\Program Files\mingw-w64\*\mingw64\bin` (replace * with the correct path). + +## Commands + +* `make generate` - Generate Go GraphQL and packr2 files. This should be run if the graphql schema or schema migration files have changed. +* `make build` - Builds the binary +* `make vet` - Run `go vet` +* `make lint` - Run the linter +* `make test` - Runs the unit tests +* `make it` - Runs the unit and integration tests + +## Building a release + +1. Run `make generate` to create generated files +2. Run `make build` to build the executable for your current platform + +## Cross compiling + +TODO diff --git a/go.mod b/go.mod index 6170531..da7805a 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,6 @@ require ( github.com/golang-migrate/migrate/v4 v4.3.1 github.com/gorilla/websocket v1.4.0 github.com/h2non/filetype v1.0.8 - github.com/inconshreveable/mousetrap v1.0.0 // indirect github.com/jmoiron/sqlx v1.2.0 github.com/mattn/go-sqlite3 v1.10.0 github.com/pkg/errors v0.8.1 diff --git a/go.sum b/go.sum index bbca084..147db74 100644 --- a/go.sum +++ b/go.sum @@ -243,6 +243,7 @@ github.com/gobuffalo/packr v1.15.0/go.mod h1:t5gXzEhIviQwVlNx/+3SfS07GS+cZ2hn76W github.com/gobuffalo/packr v1.15.1/go.mod h1:IeqicJ7jm8182yrVmNbM6PR4g79SjN9tZLH8KduZZwE= github.com/gobuffalo/packr v1.19.0/go.mod h1:MstrNkfCQhd5o+Ct4IJ0skWlxN8emOq8DsoT1G98VIU= github.com/gobuffalo/packr v1.20.0/go.mod h1:JDytk1t2gP+my1ig7iI4NcVaXr886+N0ecUga6884zw= +github.com/gobuffalo/packr v1.21.0 h1:p2ujcDJQp2QTiYWcI0ByHbr/gMoCouok6M0vXs/yTYQ= github.com/gobuffalo/packr v1.21.0/go.mod h1:H00jGfj1qFKxscFJSw8wcL4hpQtPe1PfU2wa6sg/SR0= github.com/gobuffalo/packr/v2 v2.0.0-rc.8/go.mod h1:y60QCdzwuMwO2R49fdQhsjCPv7tLQFR0ayzxxla9zes= github.com/gobuffalo/packr/v2 v2.0.0-rc.9/go.mod h1:fQqADRfZpEsgkc7c/K7aMew3n4aF1Kji7+lIZeR98Fc= diff --git a/gqlgen.yml b/gqlgen.yml index b637540..516161b 100644 --- a/gqlgen.yml +++ b/gqlgen.yml @@ -16,4 +16,10 @@ struct_tag: gqlgen models: Performer: model: github.com/stashapp/stashdb/pkg/models.Performer + Tag: + model: github.com/stashapp/stashdb/pkg/models.Tag + Studio: + model: github.com/stashapp/stashdb/pkg/models.Studio + Scene: + model: github.com/stashapp/stashdb/pkg/models.Scene diff --git a/graphql/schema/schema.graphql b/graphql/schema/schema.graphql index 9a67a7b..6df5f85 100644 --- a/graphql/schema/schema.graphql +++ b/graphql/schema/schema.graphql @@ -29,9 +29,12 @@ type Query { #### Scenes #### - # ids and checksums should be unique - """Find a scene by ID or checksum""" - findScene(id: ID, checksum: String): Scene + # ids should be unique + """Find a scene by ID""" + findScene(id: ID!): Scene + + """Finds a scene by an algorithm-specific checksum""" + findSceneByFingerprint(fingerprint: FingerprintInput!): [Scene!]! queryScenes(scene_filter: SceneFilterType, filter: QuerySpec): QueryScenesResultType! diff --git a/graphql/schema/types/performer.graphql b/graphql/schema/types/performer.graphql index 322079f..b0a9c86 100644 --- a/graphql/schema/types/performer.graphql +++ b/graphql/schema/types/performer.graphql @@ -109,6 +109,8 @@ input PerformerCreateInput { career_end_year: Int tattoos: [BodyModificationInput!] piercings: [BodyModificationInput!] + """Should be base64 encoded""" + image: String } input PerformerUpdateInput { @@ -130,6 +132,8 @@ input PerformerUpdateInput { career_end_year: Int tattoos: [BodyModificationInput!] piercings: [BodyModificationInput!] + """Should be base64 encoded""" + image: String } input PerformerDestroyInput { @@ -154,6 +158,8 @@ input PerformerEditDetailsInput { career_end_year: Int tattoos: [BodyModificationInput!] piercings: [BodyModificationInput!] + """Should be base64 encoded""" + image: String } input PerformerEditInput { diff --git a/graphql/schema/types/scene.graphql b/graphql/schema/types/scene.graphql index 82511b7..4c960b0 100644 --- a/graphql/schema/types/scene.graphql +++ b/graphql/schema/types/scene.graphql @@ -10,6 +10,20 @@ input PerformerAppearanceInput { as: String } +enum FingerprintAlgorithm { + MD5 +} + +type Fingerprint { + hash: String! + algorithm: FingerprintAlgorithm! +} + +input FingerprintInput { + hash: String! + algorithm: FingerprintAlgorithm! +} + type Scene { id: ID! title: String @@ -20,7 +34,7 @@ type Scene { studio: Studio tags: [Tag!]! performers: [PerformerAppearance!]! - checksums: [String!]! + fingerprints: [Fingerprint!]! } input SceneCreateInput { @@ -31,7 +45,7 @@ input SceneCreateInput { studio_id: ID performers: [PerformerAppearanceInput!] tag_ids: [ID!] - checksums: [String!]! + fingerprints: [FingerprintInput!]! } input SceneUpdateInput { @@ -43,6 +57,7 @@ input SceneUpdateInput { studio_id: ID performers: [PerformerAppearanceInput!] tag_ids: [ID!] + fingerprints: [FingerprintInput!] } input SceneDestroyInput { @@ -57,7 +72,7 @@ input SceneEditDetailsInput { studio_id: ID performers: [PerformerAppearanceInput!] tag_ids: [ID!] - checksums: [String!] + fingerprints: [FingerprintInput!] } input SceneEditInput { @@ -82,6 +97,8 @@ type SceneEdit { removed_performers: [PerformerAppearance!] added_tags: [Tag!] removed_tags: [Tag!] + added_fingerprints: [Fingerprint!] + removed_fingerprints: [Fingerprint!] } type QueryScenesResultType { diff --git a/graphql/schema/types/studio.graphql b/graphql/schema/types/studio.graphql index c30cdd1..de12b4a 100644 --- a/graphql/schema/types/studio.graphql +++ b/graphql/schema/types/studio.graphql @@ -15,7 +15,7 @@ input StudioCreateInput { input StudioUpdateInput { id: ID! - name: String! + name: String urls: [URLInput!] parent_id: ID child_studio_ids: [ID!] diff --git a/pkg/api/context_keys.go b/pkg/api/context_keys.go new file mode 100644 index 0000000..0196d9d --- /dev/null +++ b/pkg/api/context_keys.go @@ -0,0 +1,11 @@ +package api + +// https://stackoverflow.com/questions/40891345/fix-should-not-use-basic-type-string-as-key-in-context-withvalue-golint + +type key int + +const ( + performerKey key = 1 + sceneKey key = 2 + studioKey key = 3 +) diff --git a/pkg/api/integration_test.go b/pkg/api/integration_test.go new file mode 100644 index 0000000..5aeeac9 --- /dev/null +++ b/pkg/api/integration_test.go @@ -0,0 +1,176 @@ +// +build integration + +package api_test + +import ( + "context" + "strconv" + "testing" + + "github.com/stashapp/stashdb/pkg/api" + dbtest "github.com/stashapp/stashdb/pkg/database/databasetest" + "github.com/stashapp/stashdb/pkg/models" + + "github.com/99designs/gqlgen/graphql" + _ "github.com/golang-migrate/migrate/v4/database/sqlite3" +) + +func TestMain(m *testing.M) { + dbtest.TestWithDatabase(m, nil) +} + +type testRunner struct { + t *testing.T + resolver api.Resolver + ctx context.Context + err error +} + +var performerSuffix int +var studioSuffix int +var tagSuffix int +var sceneChecksumSuffix int + +func createTestRunner(t *testing.T) *testRunner { + resolver := api.Resolver{} + ctx := context.TODO() + ctx = context.WithValue(ctx, api.ContextRole, api.ModifyRole) + + return &testRunner{ + t: t, + resolver: resolver, + ctx: ctx, + } +} + +func (t *testRunner) doTest(test func()) { + if t.t.Failed() { + return + } + + test() +} + +func (t *testRunner) fieldMismatch(expected interface{}, actual interface{}, field string) { + t.t.Helper() + t.t.Errorf("%s mismatch: %+v != %+v", field, actual, expected) +} + +func (t *testRunner) updateContext(fields []string) context.Context { + variables := make(map[string]interface{}) + for _, v := range fields { + variables[v] = true + } + + rctx := &graphql.RequestContext{ + Variables: variables, + } + return graphql.WithRequestContext(t.ctx, rctx) +} + +func (s *testRunner) generatePerformerName() string { + performerSuffix += 1 + return "performer-" + strconv.Itoa(performerSuffix) +} + +func (s *testRunner) createTestPerformer(input *models.PerformerCreateInput) (*models.Performer, error) { + s.t.Helper() + if input == nil { + input = &models.PerformerCreateInput{ + Name: s.generatePerformerName(), + } + } + + createdPerformer, err := s.resolver.Mutation().PerformerCreate(s.ctx, *input) + + if err != nil { + s.t.Errorf("Error creating performer: %s", err.Error()) + return nil, err + } + + return createdPerformer, nil +} + +func (s *testRunner) generateStudioName() string { + studioSuffix += 1 + return "studio-" + strconv.Itoa(studioSuffix) +} + +func (s *testRunner) createTestStudio(input *models.StudioCreateInput) (*models.Studio, error) { + s.t.Helper() + if input == nil { + input = &models.StudioCreateInput{ + Name: s.generateStudioName(), + } + } + + createdStudio, err := s.resolver.Mutation().StudioCreate(s.ctx, *input) + + if err != nil { + s.t.Errorf("Error creating studio: %s", err.Error()) + return nil, err + } + + return createdStudio, nil +} + +func (s *testRunner) generateTagName() string { + tagSuffix += 1 + return "tag-" + strconv.Itoa(tagSuffix) +} + +func (s *testRunner) createTestTag(input *models.TagCreateInput) (*models.Tag, error) { + s.t.Helper() + if input == nil { + input = &models.TagCreateInput{ + Name: s.generateTagName(), + } + } + + createdTag, err := s.resolver.Mutation().TagCreate(s.ctx, *input) + + if err != nil { + s.t.Errorf("Error creating tag: %s", err.Error()) + return nil, err + } + + return createdTag, nil +} + +func (s *testRunner) createTestScene(input *models.SceneCreateInput) (*models.Scene, error) { + s.t.Helper() + if input == nil { + title := "title" + input = &models.SceneCreateInput{ + Title: &title, + Fingerprints: []*models.FingerprintInput{ + s.generateSceneFingerprint(), + }, + } + } + + createdScene, err := s.resolver.Mutation().SceneCreate(s.ctx, *input) + + if err != nil { + s.t.Errorf("Error creating scene: %s", err.Error()) + return nil, err + } + + return createdScene, nil +} + +func (s *testRunner) generateSceneFingerprint() *models.FingerprintInput { + sceneChecksumSuffix += 1 + return &models.FingerprintInput{ + Algorithm: "MD5", + Hash: "scene-" + strconv.Itoa(sceneChecksumSuffix), + } +} + +func oneNil(l interface{}, r interface{}) bool { + return l != r && (l == nil || r == nil) +} + +func bothNil(l interface{}, r interface{}) bool { + return l == nil && r == nil +} diff --git a/pkg/api/performer_integration_test.go b/pkg/api/performer_integration_test.go new file mode 100644 index 0000000..8880c12 --- /dev/null +++ b/pkg/api/performer_integration_test.go @@ -0,0 +1,504 @@ +// +build integration + +package api_test + +import ( + "reflect" + "strconv" + "testing" + + "github.com/stashapp/stashdb/pkg/models" + + _ "github.com/golang-migrate/migrate/v4/database/sqlite3" +) + +type performerTestRunner struct { + testRunner +} + +func createPerformerTestRunner(t *testing.T) *performerTestRunner { + return &performerTestRunner{ + testRunner: *createTestRunner(t), + } +} + +func (s *performerTestRunner) testCreatePerformer() { + disambiguation := "Disambiguation" + country := "USA" + height := 182 + cupSize := "C" + bandSize := 32 + careerStartYear := 2000 + tattooDesc := "Foobar" + gender := models.GenderEnumFemale + ethnicity := models.EthnicityEnumCaucasian + eyeColor := models.EyeColorEnumBlue + hairColor := models.HairColorEnumBlonde + breastType := models.BreastTypeEnumNatural + + input := models.PerformerCreateInput{ + Name: s.generatePerformerName(), + Disambiguation: &disambiguation, + Aliases: []string{"Alias1", "Alias2"}, + Gender: &gender, + Urls: []*models.URLInput{ + &models.URLInput{ + URL: "URL", + Type: "Type", + }, + }, + Birthdate: &models.FuzzyDateInput{ + Date: "2001-02-03", + Accuracy: models.DateAccuracyEnumDay, + }, + Ethnicity: ðnicity, + Country: &country, + EyeColor: &eyeColor, + HairColor: &hairColor, + Height: &height, + Measurements: &models.MeasurementsInput{ + CupSize: &cupSize, + BandSize: &bandSize, + Waist: &bandSize, + Hip: &bandSize, + }, + BreastType: &breastType, + CareerStartYear: &careerStartYear, + CareerEndYear: nil, + Tattoos: []*models.BodyModificationInput{ + &models.BodyModificationInput{ + Location: "Inner thigh", + Description: &tattooDesc, + }, + }, + Piercings: []*models.BodyModificationInput{ + &models.BodyModificationInput{ + Location: "Nose", + Description: nil, + }, + }, + } + + performer, err := s.resolver.Mutation().PerformerCreate(s.ctx, input) + + if err != nil { + s.t.Errorf("Error creating performer: %s", err.Error()) + return + } + + s.verifyCreatedPerformer(input, performer) +} + +func compareBodyMods(input []*models.BodyModificationInput, bodyMods []*models.BodyModification) bool { + if len(bodyMods) != len(input) { + return false + } + + for i, v := range bodyMods { + if v.Location != input[i].Location { + return false + } + + if v.Description != input[i].Description { + if v.Description == nil || input[i].Description == nil { + return false + } + + if *v.Description != *input[i].Description { + return false + } + } + } + + return true +} + +func compareUrls(input []*models.URLInput, urls []*models.URL) bool { + if len(urls) != len(input) { + return false + } + + for i, v := range urls { + if v.URL != input[i].URL || v.Type != input[i].Type { + return false + } + } + + return true +} + +func (s *performerTestRunner) verifyCreatedPerformer(input models.PerformerCreateInput, performer *models.Performer) { + // ensure basic attributes are set correctly + if input.Name != performer.Name { + s.fieldMismatch(input.Name, performer.Name, "Name") + } + + r := s.resolver.Performer() + + id, _ := r.ID(s.ctx, performer) + if id == "" { + s.t.Errorf("Expected created performer id to be non-zero") + } + + if v, _ := r.Disambiguation(s.ctx, performer); !reflect.DeepEqual(v, input.Disambiguation) { + s.fieldMismatch(*input.Disambiguation, v, "Disambiguation") + } + + if v, _ := r.Aliases(s.ctx, performer); !reflect.DeepEqual(v, input.Aliases) { + s.fieldMismatch(input.Aliases, v, "Aliases") + } + + if v, _ := r.Gender(s.ctx, performer); !reflect.DeepEqual(v, input.Gender) { + s.fieldMismatch(*input.Gender, v, "Gender") + } + + // ensure urls were set correctly + urls, _ := s.resolver.Performer().Urls(s.ctx, performer) + if !compareUrls(input.Urls, urls) { + s.fieldMismatch(input.Urls, urls, "Urls") + } + + birthdate, _ := r.Birthdate(s.ctx, performer) + if !bothNil(birthdate, input.Birthdate) && (oneNil(birthdate, input.Birthdate) || birthdate.Date != input.Birthdate.Date || birthdate.Accuracy != input.Birthdate.Accuracy) { + s.fieldMismatch(input.Birthdate, birthdate, "Birthdate") + } + + if v, _ := r.Ethnicity(s.ctx, performer); !reflect.DeepEqual(v, input.Ethnicity) { + s.fieldMismatch(*input.Ethnicity, v, "Ethnicity") + } + + if v, _ := r.Country(s.ctx, performer); !reflect.DeepEqual(v, input.Country) { + s.fieldMismatch(*input.Country, v, "Country") + } + + if v, _ := r.EyeColor(s.ctx, performer); !reflect.DeepEqual(v, input.EyeColor) { + s.fieldMismatch(*input.HairColor, v, "EyeColor") + } + + if v, _ := r.HairColor(s.ctx, performer); !reflect.DeepEqual(v, input.HairColor) { + s.fieldMismatch(*input.HairColor, v, "HairColor") + } + + if v, _ := r.Height(s.ctx, performer); !reflect.DeepEqual(v, input.Height) { + s.fieldMismatch(*input.Height, v, "Height") + } + + measurements, _ := r.Measurements(s.ctx, performer) + if !bothNil(measurements, input.Measurements) && (oneNil(measurements, input.Measurements) || + *measurements.CupSize != *input.Measurements.CupSize || + *measurements.BandSize != *input.Measurements.BandSize || + *measurements.Waist != *input.Measurements.Waist || + *measurements.Hip != *input.Measurements.Hip) { + s.fieldMismatch(input.Measurements, measurements, "Measurements") + } + + if v, _ := r.BreastType(s.ctx, performer); !reflect.DeepEqual(v, input.BreastType) { + s.fieldMismatch(*input.BreastType, v, "BreastType") + } + + if v, _ := r.CareerStartYear(s.ctx, performer); !reflect.DeepEqual(v, input.CareerStartYear) { + s.fieldMismatch(*input.CareerStartYear, v, "CareerStartYear") + } + + if v, _ := r.CareerEndYear(s.ctx, performer); !reflect.DeepEqual(v, input.CareerEndYear) { + s.fieldMismatch(nil, v, "CareerEndYear") + } + + tattoos, _ := s.resolver.Performer().Tattoos(s.ctx, performer) + if !compareBodyMods(input.Tattoos, tattoos) { + s.fieldMismatch(input.Tattoos, tattoos, "Tattoos") + } + + piercings, _ := s.resolver.Performer().Piercings(s.ctx, performer) + if !compareBodyMods(input.Piercings, piercings) { + s.fieldMismatch(input.Piercings, piercings, "Piercings") + } +} + +func (s *performerTestRunner) testFindPerformer() { + createdPerformer, err := s.createTestPerformer(nil) + if err != nil { + return + } + + performer, err := s.resolver.Query().FindPerformer(s.ctx, strconv.FormatInt(createdPerformer.ID, 10)) + if err != nil { + s.t.Errorf("Error finding performer: %s", err.Error()) + return + } + + // ensure returned performer is not nil + if performer == nil { + s.t.Error("Did not find performer by id") + return + } + + // ensure values were set + if createdPerformer.Name != performer.Name { + s.fieldMismatch(createdPerformer.Name, performer.Name, "Name") + } +} + +func (s *performerTestRunner) testUpdatePerformer() { + cupSize := "C" + bandSize := 32 + tattooDesc := "Foobar" + + input := &models.PerformerCreateInput{ + Name: s.generatePerformerName(), + Aliases: []string{"Alias1", "Alias2"}, + Urls: []*models.URLInput{ + &models.URLInput{ + URL: "URL", + Type: "Type", + }, + }, + Birthdate: &models.FuzzyDateInput{ + Date: "2001-02-03", + Accuracy: models.DateAccuracyEnumDay, + }, + Measurements: &models.MeasurementsInput{ + CupSize: &cupSize, + BandSize: &bandSize, + Waist: &bandSize, + Hip: &bandSize, + }, + Tattoos: []*models.BodyModificationInput{ + &models.BodyModificationInput{ + Location: "Inner thigh", + Description: &tattooDesc, + }, + }, + Piercings: []*models.BodyModificationInput{ + &models.BodyModificationInput{ + Location: "Nose", + Description: nil, + }, + }, + } + + createdPerformer, err := s.createTestPerformer(input) + if err != nil { + return + } + + performerID := strconv.FormatInt(createdPerformer.ID, 10) + + updateInput := models.PerformerUpdateInput{ + ID: performerID, + Aliases: []string{"Alias3", "Alias4"}, + Urls: []*models.URLInput{ + &models.URLInput{ + URL: "URL", + Type: "Type", + }, + }, + Birthdate: &models.FuzzyDateInput{ + Date: "2001-02-03", + Accuracy: models.DateAccuracyEnumDay, + }, + Measurements: &models.MeasurementsInput{ + CupSize: &cupSize, + BandSize: &bandSize, + Waist: &bandSize, + Hip: &bandSize, + }, + Tattoos: []*models.BodyModificationInput{ + &models.BodyModificationInput{ + Location: "Tramp stamp", + Description: &tattooDesc, + }, + }, + Piercings: []*models.BodyModificationInput{ + &models.BodyModificationInput{ + Location: "Navel", + Description: nil, + }, + }, + } + + // need some mocking of the context to make the field ignore behaviour work + ctx := s.updateContext([]string{ + "aliases", + "urls", + "birthdate", + "measurements", + "tattoos", + "piercings", + }) + + updatedPerformer, err := s.resolver.Mutation().PerformerUpdate(ctx, updateInput) + if err != nil { + s.t.Errorf("Error updating performer: %s", err.Error()) + return + } + + s.verifyUpdatedPerformer(updateInput, updatedPerformer) +} + +func (s *performerTestRunner) testUpdatePerformerName() { + cupSize := "C" + bandSize := 32 + tattooDesc := "Foobar" + + input := &models.PerformerCreateInput{ + Name: s.generatePerformerName(), + Aliases: []string{"Alias1", "Alias2"}, + Urls: []*models.URLInput{ + &models.URLInput{ + URL: "URL", + Type: "Type", + }, + }, + Birthdate: &models.FuzzyDateInput{ + Date: "2001-02-03", + Accuracy: models.DateAccuracyEnumDay, + }, + Measurements: &models.MeasurementsInput{ + CupSize: &cupSize, + BandSize: &bandSize, + Waist: &bandSize, + Hip: &bandSize, + }, + Tattoos: []*models.BodyModificationInput{ + &models.BodyModificationInput{ + Location: "Inner thigh", + Description: &tattooDesc, + }, + }, + Piercings: []*models.BodyModificationInput{ + &models.BodyModificationInput{ + Location: "Nose", + Description: nil, + }, + }, + } + + createdPerformer, err := s.createTestPerformer(input) + if err != nil { + return + } + + performerID := strconv.FormatInt(createdPerformer.ID, 10) + + updatedName := s.generatePerformerName() + updateInput := models.PerformerUpdateInput{ + ID: performerID, + Name: &updatedName, + } + + // need some mocking of the context to make the field ignore behaviour work + ctx := s.updateContext([]string{ + "name", + }) + updatedPerformer, err := s.resolver.Mutation().PerformerUpdate(ctx, updateInput) + if err != nil { + s.t.Errorf("Error updating performer: %s", err.Error()) + return + } + + input.Name = updatedName + s.verifyCreatedPerformer(*input, updatedPerformer) +} + +func (s *performerTestRunner) verifyUpdatedPerformer(input models.PerformerUpdateInput, performer *models.Performer) { + // ensure basic attributes are set correctly + if input.Name != nil && *input.Name != performer.Name { + s.fieldMismatch(input.Name, performer.Name, "Name") + } + + r := s.resolver.Performer() + + if v, _ := r.Aliases(s.ctx, performer); !reflect.DeepEqual(v, input.Aliases) { + s.fieldMismatch(input.Aliases, v, "Aliases") + } + + // ensure urls were set correctly + urls, _ := s.resolver.Performer().Urls(s.ctx, performer) + if !compareUrls(input.Urls, urls) { + s.fieldMismatch(input.Urls, urls, "Urls") + } + + birthdate, _ := r.Birthdate(s.ctx, performer) + if birthdate != nil && (birthdate.Date != input.Birthdate.Date || birthdate.Accuracy != input.Birthdate.Accuracy) { + s.fieldMismatch(input.Birthdate, birthdate, "Birthdate") + } + + measurements, _ := r.Measurements(s.ctx, performer) + if input.Measurements != nil && (*measurements.CupSize != *input.Measurements.CupSize || + *measurements.BandSize != *input.Measurements.BandSize || + *measurements.Waist != *input.Measurements.Waist || + *measurements.Hip != *input.Measurements.Hip) { + s.fieldMismatch(input.Measurements, measurements, "Measurements") + } + + tattoos, _ := s.resolver.Performer().Tattoos(s.ctx, performer) + if !compareBodyMods(input.Tattoos, tattoos) { + s.fieldMismatch(input.Tattoos, tattoos, "Tattoos") + } + + piercings, _ := s.resolver.Performer().Piercings(s.ctx, performer) + if !compareBodyMods(input.Piercings, piercings) { + s.fieldMismatch(input.Piercings, piercings, "Piercings") + } +} + +func (s *performerTestRunner) testDestroyPerformer() { + createdPerformer, err := s.createTestPerformer(nil) + if err != nil { + return + } + + performerID := strconv.FormatInt(createdPerformer.ID, 10) + + destroyed, err := s.resolver.Mutation().PerformerDestroy(s.ctx, models.PerformerDestroyInput{ + ID: performerID, + }) + if err != nil { + s.t.Errorf("Error destroying performer: %s", err.Error()) + return + } + + if !destroyed { + s.t.Error("Performer was not destroyed") + return + } + + // ensure cannot find performer + foundPerformer, err := s.resolver.Query().FindPerformer(s.ctx, performerID) + if err != nil { + s.t.Errorf("Error finding performer after destroying: %s", err.Error()) + return + } + + if foundPerformer != nil { + s.t.Error("Found performer after destruction") + } + + // TODO - ensure scene was not removed +} + +func TestCreatePerformer(t *testing.T) { + pt := createPerformerTestRunner(t) + pt.testCreatePerformer() +} + +func TestFindPerformer(t *testing.T) { + pt := createPerformerTestRunner(t) + pt.testFindPerformer() +} + +func TestUpdatePerformer(t *testing.T) { + pt := createPerformerTestRunner(t) + pt.testUpdatePerformer() +} + +func TestUpdatePerformerName(t *testing.T) { + pt := createPerformerTestRunner(t) + pt.testUpdatePerformerName() +} + +func TestDestroyPerformer(t *testing.T) { + pt := createPerformerTestRunner(t) + pt.testDestroyPerformer() +} diff --git a/pkg/api/resolver.go b/pkg/api/resolver.go index be2a712..4b6be4c 100644 --- a/pkg/api/resolver.go +++ b/pkg/api/resolver.go @@ -16,6 +16,15 @@ func (r *Resolver) Mutation() models.MutationResolver { func (r *Resolver) Performer() models.PerformerResolver { return &performerResolver{r} } +func (r *Resolver) Tag() models.TagResolver { + return &tagResolver{r} +} +func (r *Resolver) Studio() models.StudioResolver { + return &studioResolver{r} +} +func (r *Resolver) Scene() models.SceneResolver { + return &sceneResolver{r} +} func (r *Resolver) Query() models.QueryResolver { return &queryResolver{r} } @@ -34,6 +43,10 @@ func (r *queryResolver) Version(ctx context.Context) (*models.Version, error) { func wasFieldIncluded(ctx context.Context, field string) bool { rctx := graphql.GetRequestContext(ctx) - _, ret := rctx.Variables[field] - return ret + if rctx != nil { + _, ret := rctx.Variables[field] + return ret + } + + return false } diff --git a/pkg/api/resolver_model_performer.go b/pkg/api/resolver_model_performer.go index 78a95bd..8e33c3e 100644 --- a/pkg/api/resolver_model_performer.go +++ b/pkg/api/resolver_model_performer.go @@ -20,7 +20,7 @@ func (r *performerResolver) Disambiguation(ctx context.Context, obj *models.Perf } func (r *performerResolver) Aliases(ctx context.Context, obj *models.Performer) ([]string, error) { - qb := models.NewPerformerQueryBuilder() + qb := models.NewPerformerQueryBuilder(nil) aliases, err := qb.GetAliases(obj.ID) if err != nil { @@ -40,7 +40,7 @@ func (r *performerResolver) Gender(ctx context.Context, obj *models.Performer) ( } func (r *performerResolver) Urls(ctx context.Context, obj *models.Performer) ([]*models.URL, error) { - qb := models.NewPerformerQueryBuilder() + qb := models.NewPerformerQueryBuilder(nil) urls, err := qb.GetUrls(obj.ID) if err != nil { @@ -140,7 +140,7 @@ func (r *performerResolver) CareerEndYear(ctx context.Context, obj *models.Perfo } func (r *performerResolver) Tattoos(ctx context.Context, obj *models.Performer) ([]*models.BodyModification, error) { - qb := models.NewPerformerQueryBuilder() + qb := models.NewPerformerQueryBuilder(nil) tattoos, err := qb.GetTattoos(obj.ID) if err != nil { @@ -157,7 +157,7 @@ func (r *performerResolver) Tattoos(ctx context.Context, obj *models.Performer) } func (r *performerResolver) Piercings(ctx context.Context, obj *models.Performer) ([]*models.BodyModification, error) { - qb := models.NewPerformerQueryBuilder() + qb := models.NewPerformerQueryBuilder(nil) piercings, err := qb.GetPiercings(obj.ID) if err != nil { diff --git a/pkg/api/resolver_model_scene.go b/pkg/api/resolver_model_scene.go index 89ca42c..85bae07 100644 --- a/pkg/api/resolver_model_scene.go +++ b/pkg/api/resolver_model_scene.go @@ -2,40 +2,76 @@ package api import ( "context" + "strconv" "github.com/stashapp/stashdb/pkg/models" ) type sceneResolver struct{ *Resolver } +func (r *sceneResolver) ID(ctx context.Context, obj *models.Scene) (string, error) { + return strconv.FormatInt(obj.ID, 10), nil +} func (r *sceneResolver) Title(ctx context.Context, obj *models.Scene) (*string, error) { - panic("not implemented") + return resolveNullString(obj.Title) } - func (r *sceneResolver) Details(ctx context.Context, obj *models.Scene) (*string, error) { - panic("not implemented") + return resolveNullString(obj.Details) } - func (r *sceneResolver) URL(ctx context.Context, obj *models.Scene) (*string, error) { - panic("not implemented") + return resolveNullString(obj.URL) } - func (r *sceneResolver) Date(ctx context.Context, obj *models.Scene) (*string, error) { - panic("not implemented") + return resolveSQLiteDate(obj.Date) } - func (r *sceneResolver) Studio(ctx context.Context, obj *models.Scene) (*models.Studio, error) { - panic("not implemented") -} + if !obj.StudioID.Valid { + return nil, nil + } + qb := models.NewStudioQueryBuilder(nil) + parent, err := qb.Find(obj.StudioID.Int64) + + if err != nil { + return nil, err + } + + return parent, nil +} func (r *sceneResolver) Tags(ctx context.Context, obj *models.Scene) ([]*models.Tag, error) { - panic("not implemented") + qb := models.NewTagQueryBuilder(nil) + return qb.FindBySceneID(obj.ID) } +func (r *sceneResolver) Performers(ctx context.Context, obj *models.Scene) ([]*models.PerformerAppearance, error) { + pqb := models.NewPerformerQueryBuilder(nil) + sqb := models.NewSceneQueryBuilder(nil) + performersScenes, err := sqb.GetPerformers(obj.ID) -func (r *sceneResolver) Performers(ctx context.Context, obj *models.Scene) ([]*models.Performer, error) { - panic("not implemented") -} + if err != nil { + return nil, err + } -func (r *sceneResolver) Checksums(ctx context.Context, obj *models.Scene) ([]string, error) { - panic("not implemented") + // TODO - probably a better way to do this + var ret []*models.PerformerAppearance + for _, appearance := range performersScenes { + performer, err := pqb.Find(appearance.PerformerID) + + if err != nil { + return nil, err + } + + as, _ := resolveNullString(appearance.As) + + retApp := models.PerformerAppearance{ + Performer: performer, + As: as, + } + ret = append(ret, &retApp) + } + + return ret, nil +} +func (r *sceneResolver) Fingerprints(ctx context.Context, obj *models.Scene) ([]*models.Fingerprint, error) { + qb := models.NewSceneQueryBuilder(nil) + return qb.GetFingerprints(obj.ID) } diff --git a/pkg/api/resolver_model_studio.go b/pkg/api/resolver_model_studio.go index ad1225c..d8efbae 100644 --- a/pkg/api/resolver_model_studio.go +++ b/pkg/api/resolver_model_studio.go @@ -2,18 +2,56 @@ package api import ( "context" + "strconv" "github.com/stashapp/stashdb/pkg/models" ) type studioResolver struct{ *Resolver } +func (r *studioResolver) ID(ctx context.Context, obj *models.Studio) (string, error) { + return strconv.FormatInt(obj.ID, 10), nil +} + func (r *studioResolver) Urls(ctx context.Context, obj *models.Studio) ([]*models.URL, error) { - panic("not implemented") + qb := models.NewStudioQueryBuilder(nil) + urls, err := qb.GetUrls(obj.ID) + + if err != nil { + return nil, err + } + + var ret []*models.URL + for _, url := range urls { + retURL := url.ToURL() + ret = append(ret, &retURL) + } + + return ret, nil } + func (r *studioResolver) Parent(ctx context.Context, obj *models.Studio) (*models.Studio, error) { - panic("not implemented") + if !obj.ParentStudioID.Valid { + return nil, nil + } + + qb := models.NewStudioQueryBuilder(nil) + parent, err := qb.Find(obj.ParentStudioID.Int64) + + if err != nil { + return nil, err + } + + return parent, nil } -func (r *studioResolver) AddedChildStudios(ctx context.Context, obj *models.Studio) ([]*models.Studio, error) { - panic("not implemented") + +func (r *studioResolver) ChildStudios(ctx context.Context, obj *models.Studio) ([]*models.Studio, error) { + qb := models.NewStudioQueryBuilder(nil) + children, err := qb.FindByParentID(obj.ID) + + if err != nil { + return nil, err + } + + return children, nil } diff --git a/pkg/api/resolver_model_tag.go b/pkg/api/resolver_model_tag.go index 40b320f..acca8be 100644 --- a/pkg/api/resolver_model_tag.go +++ b/pkg/api/resolver_model_tag.go @@ -2,15 +2,26 @@ package api import ( "context" + "strconv" "github.com/stashapp/stashdb/pkg/models" ) type tagResolver struct{ *Resolver } +func (r *tagResolver) ID(ctx context.Context, obj *models.Tag) (string, error) { + return strconv.FormatInt(obj.ID, 10), nil +} func (r *tagResolver) Description(ctx context.Context, obj *models.Tag) (*string, error) { - panic("not implemented") + return resolveNullString(obj.Description) } func (r *tagResolver) Aliases(ctx context.Context, obj *models.Tag) ([]string, error) { - panic("not implemented") + qb := models.NewTagQueryBuilder(nil) + aliases, err := qb.GetAliases(obj.ID) + + if err != nil { + return nil, err + } + + return aliases, nil } diff --git a/pkg/api/resolver_mutation_performer.go b/pkg/api/resolver_mutation_performer.go index ccff0e2..b50190b 100644 --- a/pkg/api/resolver_mutation_performer.go +++ b/pkg/api/resolver_mutation_performer.go @@ -2,6 +2,7 @@ package api import ( "context" + "fmt" "strconv" "time" @@ -27,12 +28,15 @@ func (r *mutationResolver) PerformerCreate(ctx context.Context, input models.Per UpdatedAt: models.SQLiteTimestamp{Timestamp: currentTime}, } - newPerformer.CopyFromCreateInput(input) + err = newPerformer.CopyFromCreateInput(input) + if err != nil { + return nil, err + } // Start the transaction and save the performer tx := database.DB.MustBeginTx(ctx, nil) - qb := models.NewPerformerQueryBuilder() - performer, err := qb.Create(newPerformer, tx) + qb := models.NewPerformerQueryBuilder(tx) + performer, err := qb.Create(newPerformer) if err != nil { _ = tx.Rollback() return nil, err @@ -40,28 +44,28 @@ func (r *mutationResolver) PerformerCreate(ctx context.Context, input models.Per // Save the aliases performerAliases := models.CreatePerformerAliases(performer.ID, input.Aliases) - if err := qb.CreateAliases(performerAliases, tx); err != nil { + if err := qb.CreateAliases(performerAliases); err != nil { _ = tx.Rollback() return nil, err } // Save the URLs performerUrls := models.CreatePerformerUrls(performer.ID, input.Urls) - if err := qb.CreateUrls(performerUrls, tx); err != nil { + if err := qb.CreateUrls(performerUrls); err != nil { _ = tx.Rollback() return nil, err } // Save the Tattoos performerTattoos := models.CreatePerformerBodyMods(performer.ID, input.Tattoos) - if err := qb.CreateTattoos(performerTattoos, tx); err != nil { + if err := qb.CreateTattoos(performerTattoos); err != nil { _ = tx.Rollback() return nil, err } // Save the Piercings performerPiercings := models.CreatePerformerBodyMods(performer.ID, input.Piercings) - if err := qb.CreatePiercings(performerPiercings, tx); err != nil { + if err := qb.CreatePiercings(performerPiercings); err != nil { _ = tx.Rollback() return nil, err } @@ -79,56 +83,74 @@ func (r *mutationResolver) PerformerUpdate(ctx context.Context, input models.Per return nil, err } - qb := models.NewPerformerQueryBuilder() + tx := database.DB.MustBeginTx(ctx, nil) + qb := models.NewPerformerQueryBuilder(tx) // get the existing performer and modify it - performerID, _ := strconv.Atoi(input.ID) + performerID, _ := strconv.ParseInt(input.ID, 10, 64) updatedPerformer, err := qb.Find(performerID) if err != nil { return nil, err } + if updatedPerformer == nil { + return nil, fmt.Errorf("Performer with id %d cannot be found", performerID) + } + updatedPerformer.UpdatedAt = models.SQLiteTimestamp{Timestamp: time.Now()} - // Start the transaction and save the performer - tx := database.DB.MustBeginTx(ctx, nil) - // Populate performer from the input - updatedPerformer.CopyFromUpdateInput(input) + err = updatedPerformer.CopyFromUpdateInput(input) + if err != nil { + _ = tx.Rollback() + return nil, err + } - performer, err := qb.Update(*updatedPerformer, tx) + performer, err := qb.Update(*updatedPerformer) if err != nil { _ = tx.Rollback() return nil, err } // Save the aliases - performerAliases := models.CreatePerformerAliases(performer.ID, input.Aliases) - if err := qb.UpdateAliases(performer.ID, performerAliases, tx); err != nil { - _ = tx.Rollback() - return nil, err + // only do this if provided + if wasFieldIncluded(ctx, "aliases") { + performerAliases := models.CreatePerformerAliases(performer.ID, input.Aliases) + if err := qb.UpdateAliases(performer.ID, performerAliases); err != nil { + _ = tx.Rollback() + return nil, err + } } // Save the URLs - performerUrls := models.CreatePerformerUrls(performer.ID, input.Urls) - if err := qb.UpdateUrls(performer.ID, performerUrls, tx); err != nil { - _ = tx.Rollback() - return nil, err + // only do this if provided + if wasFieldIncluded(ctx, "urls") { + performerUrls := models.CreatePerformerUrls(performer.ID, input.Urls) + if err := qb.UpdateUrls(performer.ID, performerUrls); err != nil { + _ = tx.Rollback() + return nil, err + } } // Save the Tattoos - performerTattoos := models.CreatePerformerBodyMods(performer.ID, input.Tattoos) - if err := qb.UpdateTattoos(performer.ID, performerTattoos, tx); err != nil { - _ = tx.Rollback() - return nil, err + // only do this if provided + if wasFieldIncluded(ctx, "tattoos") { + performerTattoos := models.CreatePerformerBodyMods(performer.ID, input.Tattoos) + if err := qb.UpdateTattoos(performer.ID, performerTattoos); err != nil { + _ = tx.Rollback() + return nil, err + } } // Save the Piercings - performerPiercings := models.CreatePerformerBodyMods(performer.ID, input.Piercings) - if err := qb.UpdatePiercings(performer.ID, performerPiercings, tx); err != nil { - _ = tx.Rollback() - return nil, err + // only do this if provided + if wasFieldIncluded(ctx, "piercings") { + performerPiercings := models.CreatePerformerBodyMods(performer.ID, input.Piercings) + if err := qb.UpdatePiercings(performer.ID, performerPiercings); err != nil { + _ = tx.Rollback() + return nil, err + } } // Commit @@ -144,8 +166,8 @@ func (r *mutationResolver) PerformerDestroy(ctx context.Context, input models.Pe return false, err } - qb := models.NewPerformerQueryBuilder() tx := database.DB.MustBeginTx(ctx, nil) + qb := models.NewPerformerQueryBuilder(tx) // references have on delete cascade, so shouldn't be necessary // to remove them explicitly @@ -154,7 +176,7 @@ func (r *mutationResolver) PerformerDestroy(ctx context.Context, input models.Pe if err != nil { return false, err } - if err = qb.Destroy(performerID, tx); err != nil { + if err = qb.Destroy(performerID); err != nil { _ = tx.Rollback() return false, err } diff --git a/pkg/api/resolver_mutation_scene.go b/pkg/api/resolver_mutation_scene.go index 723d327..6b84b8f 100644 --- a/pkg/api/resolver_mutation_scene.go +++ b/pkg/api/resolver_mutation_scene.go @@ -2,7 +2,10 @@ package api import ( "context" + "strconv" + "time" + "github.com/stashapp/stashdb/pkg/database" "github.com/stashapp/stashdb/pkg/models" ) @@ -11,11 +14,124 @@ func (r *mutationResolver) SceneCreate(ctx context.Context, input models.SceneCr return nil, err } - return nil, nil + var err error + + if err != nil { + return nil, err + } + + // Populate a new scene from the input + currentTime := time.Now() + newScene := models.Scene{ + CreatedAt: models.SQLiteTimestamp{Timestamp: currentTime}, + UpdatedAt: models.SQLiteTimestamp{Timestamp: currentTime}, + } + + newScene.CopyFromCreateInput(input) + + // Start the transaction and save the scene + tx := database.DB.MustBeginTx(ctx, nil) + qb := models.NewSceneQueryBuilder(tx) + scene, err := qb.Create(newScene) + if err != nil { + _ = tx.Rollback() + return nil, err + } + + // Save the checksums + sceneFingerprints := models.CreateSceneFingerprints(scene.ID, input.Fingerprints) + if err := qb.CreateFingerprints(sceneFingerprints); err != nil { + _ = tx.Rollback() + return nil, err + } + + // save the performers + scenePerformers := models.CreateScenePerformers(scene.ID, input.Performers) + jqb := models.NewJoinsQueryBuilder(tx) + if err := jqb.CreatePerformersScenes(scenePerformers); err != nil { + _ = tx.Rollback() + return nil, err + } + + // Save the tags + tagJoins := models.CreateSceneTags(scene.ID, input.TagIds) + + if err := jqb.CreateScenesTags(tagJoins); err != nil { + return nil, err + } + + // Commit + if err := tx.Commit(); err != nil { + return nil, err + } + + return scene, nil } func (r *mutationResolver) SceneUpdate(ctx context.Context, input models.SceneUpdateInput) (*models.Scene, error) { - return nil, nil + if err := validateModify(ctx); err != nil { + return nil, err + } + + tx := database.DB.MustBeginTx(ctx, nil) + qb := models.NewSceneQueryBuilder(tx) + + // get the existing scene and modify it + sceneID, _ := strconv.ParseInt(input.ID, 10, 64) + updatedScene, err := qb.Find(sceneID) + + if err != nil { + return nil, err + } + + updatedScene.UpdatedAt = models.SQLiteTimestamp{Timestamp: time.Now()} + + // Populate scene from the input + updatedScene.CopyFromUpdateInput(input) + + scene, err := qb.Update(*updatedScene) + if err != nil { + _ = tx.Rollback() + return nil, err + } + + // Save the checksums + // only do this if provided + if wasFieldIncluded(ctx, "fingerprints") { + sceneFingerprints := models.CreateSceneFingerprints(scene.ID, input.Fingerprints) + if err := qb.UpdateFingerprints(scene.ID, sceneFingerprints); err != nil { + _ = tx.Rollback() + return nil, err + } + } + + jqb := models.NewJoinsQueryBuilder(tx) + + // only do this if provided + if wasFieldIncluded(ctx, "performers") { + scenePerformers := models.CreateScenePerformers(scene.ID, input.Performers) + if err := jqb.UpdatePerformersScenes(scene.ID, scenePerformers); err != nil { + _ = tx.Rollback() + return nil, err + } + } + + // Save the tags + // only do this if provided + if wasFieldIncluded(ctx, "tagIds") { + tagJoins := models.CreateSceneTags(scene.ID, input.TagIds) + + if err := jqb.UpdateScenesTags(scene.ID, tagJoins); err != nil { + return nil, err + } + } + + // Commit + if err := tx.Commit(); err != nil { + return nil, err + } + + return scene, nil } func (r *mutationResolver) SceneDestroy(ctx context.Context, input models.SceneDestroyInput) (bool, error) { @@ -23,5 +139,23 @@ func (r *mutationResolver) SceneDestroy(ctx context.Context, input models.SceneD return false, err } + tx := database.DB.MustBeginTx(ctx, nil) + qb := models.NewSceneQueryBuilder(tx) + + // references have on delete cascade, so shouldn't be necessary + // to remove them explicitly + + sceneID, err := strconv.ParseInt(input.ID, 10, 64) + if err != nil { + return false, err + } + if err = qb.Destroy(sceneID); err != nil { + _ = tx.Rollback() + return false, err + } + + if err := tx.Commit(); err != nil { + return false, err + } return true, nil } diff --git a/pkg/api/resolver_mutation_studio.go b/pkg/api/resolver_mutation_studio.go index 3e11225..d3a9827 100644 --- a/pkg/api/resolver_mutation_studio.go +++ b/pkg/api/resolver_mutation_studio.go @@ -2,7 +2,10 @@ package api import ( "context" + "strconv" + "time" + "github.com/stashapp/stashdb/pkg/database" "github.com/stashapp/stashdb/pkg/models" ) @@ -11,7 +14,45 @@ func (r *mutationResolver) StudioCreate(ctx context.Context, input models.Studio return nil, err } - return nil, nil + var err error + + if err != nil { + return nil, err + } + + // Populate a new studio from the input + currentTime := time.Now() + newStudio := models.Studio{ + CreatedAt: models.SQLiteTimestamp{Timestamp: currentTime}, + UpdatedAt: models.SQLiteTimestamp{Timestamp: currentTime}, + } + + newStudio.CopyFromCreateInput(input) + + // Start the transaction and save the studio + tx := database.DB.MustBeginTx(ctx, nil) + qb := models.NewStudioQueryBuilder(tx) + studio, err := qb.Create(newStudio) + if err != nil { + _ = tx.Rollback() + return nil, err + } + + // TODO - save child studios + + // Save the URLs + studioUrls := models.CreateStudioUrls(studio.ID, input.Urls) + if err := qb.CreateUrls(studioUrls); err != nil { + _ = tx.Rollback() + return nil, err + } + + // Commit + if err := tx.Commit(); err != nil { + return nil, err + } + + return studio, nil } func (r *mutationResolver) StudioUpdate(ctx context.Context, input models.StudioUpdateInput) (*models.Studio, error) { @@ -19,7 +60,44 @@ func (r *mutationResolver) StudioUpdate(ctx context.Context, input models.Studio return nil, err } - return nil, nil + tx := database.DB.MustBeginTx(ctx, nil) + qb := models.NewStudioQueryBuilder(tx) + + // get the existing studio and modify it + studioID, _ := strconv.ParseInt(input.ID, 10, 64) + updatedStudio, err := qb.Find(studioID) + + if err != nil { + return nil, err + } + + updatedStudio.UpdatedAt = models.SQLiteTimestamp{Timestamp: time.Now()} + + // Populate studio from the input + updatedStudio.CopyFromUpdateInput(input) + + studio, err := qb.Update(*updatedStudio) + if err != nil { + _ = tx.Rollback() + return nil, err + } + + // Save the URLs + // TODO - only do this if provided + studioUrls := models.CreateStudioUrls(studio.ID, input.Urls) + if err := qb.UpdateUrls(studio.ID, studioUrls); err != nil { + _ = tx.Rollback() + return nil, err + } + + // TODO - handle child studios + + // Commit + if err := tx.Commit(); err != nil { + return nil, err + } + + return studio, nil } func (r *mutationResolver) StudioDestroy(ctx context.Context, input models.StudioDestroyInput) (bool, error) { @@ -27,5 +105,23 @@ func (r *mutationResolver) StudioDestroy(ctx context.Context, input models.Studi return false, err } + tx := database.DB.MustBeginTx(ctx, nil) + qb := models.NewStudioQueryBuilder(tx) + + // references have on delete cascade, so shouldn't be necessary + // to remove them explicitly + + studioID, err := strconv.ParseInt(input.ID, 10, 64) + if err != nil { + return false, err + } + if err = qb.Destroy(studioID); err != nil { + _ = tx.Rollback() + return false, err + } + + if err := tx.Commit(); err != nil { + return false, err + } return true, nil } diff --git a/pkg/api/resolver_mutation_tag.go b/pkg/api/resolver_mutation_tag.go index edff591..63d9cda 100644 --- a/pkg/api/resolver_mutation_tag.go +++ b/pkg/api/resolver_mutation_tag.go @@ -2,7 +2,10 @@ package api import ( "context" + "strconv" + "time" + "github.com/stashapp/stashdb/pkg/database" "github.com/stashapp/stashdb/pkg/models" ) @@ -11,7 +14,43 @@ func (r *mutationResolver) TagCreate(ctx context.Context, input models.TagCreate return nil, err } - return nil, nil + var err error + + if err != nil { + return nil, err + } + + // Populate a new performer from the input + currentTime := time.Now() + newTag := models.Tag{ + CreatedAt: models.SQLiteTimestamp{Timestamp: currentTime}, + UpdatedAt: models.SQLiteTimestamp{Timestamp: currentTime}, + } + + newTag.CopyFromCreateInput(input) + + // Start the transaction and save the performer + tx := database.DB.MustBeginTx(ctx, nil) + qb := models.NewTagQueryBuilder(tx) + tag, err := qb.Create(newTag) + if err != nil { + _ = tx.Rollback() + return nil, err + } + + // Save the aliases + tagAliases := models.CreateTagAliases(tag.ID, input.Aliases) + if err := qb.CreateAliases(tagAliases); err != nil { + _ = tx.Rollback() + return nil, err + } + + // Commit + if err := tx.Commit(); err != nil { + return nil, err + } + + return tag, nil } func (r *mutationResolver) TagUpdate(ctx context.Context, input models.TagUpdateInput) (*models.Tag, error) { @@ -19,7 +58,42 @@ func (r *mutationResolver) TagUpdate(ctx context.Context, input models.TagUpdate return nil, err } - return nil, nil + tx := database.DB.MustBeginTx(ctx, nil) + qb := models.NewTagQueryBuilder(tx) + + // get the existing tag and modify it + tagID, _ := strconv.ParseInt(input.ID, 10, 64) + updatedTag, err := qb.Find(tagID) + + if err != nil { + return nil, err + } + + updatedTag.UpdatedAt = models.SQLiteTimestamp{Timestamp: time.Now()} + + // Populate performer from the input + updatedTag.CopyFromUpdateInput(input) + + tag, err := qb.Update(*updatedTag) + if err != nil { + _ = tx.Rollback() + return nil, err + } + + // Save the aliases + // TODO - only do this if provided + tagAliases := models.CreateTagAliases(tag.ID, input.Aliases) + if err := qb.UpdateAliases(tag.ID, tagAliases); err != nil { + _ = tx.Rollback() + return nil, err + } + + // Commit + if err := tx.Commit(); err != nil { + return nil, err + } + + return tag, nil } func (r *mutationResolver) TagDestroy(ctx context.Context, input models.TagDestroyInput) (bool, error) { @@ -27,5 +101,23 @@ func (r *mutationResolver) TagDestroy(ctx context.Context, input models.TagDestr return false, err } + tx := database.DB.MustBeginTx(ctx, nil) + qb := models.NewTagQueryBuilder(tx) + + // references have on delete cascade, so shouldn't be necessary + // to remove them explicitly + + tagID, err := strconv.ParseInt(input.ID, 10, 64) + if err != nil { + return false, err + } + if err = qb.Destroy(tagID); err != nil { + _ = tx.Rollback() + return false, err + } + + if err := tx.Commit(); err != nil { + return false, err + } return true, nil } diff --git a/pkg/api/resolver_query_find_performer.go b/pkg/api/resolver_query_find_performer.go index 2319c5d..92d987f 100644 --- a/pkg/api/resolver_query_find_performer.go +++ b/pkg/api/resolver_query_find_performer.go @@ -12,9 +12,9 @@ func (r *queryResolver) FindPerformer(ctx context.Context, id string) (*models.P return nil, err } - qb := models.NewPerformerQueryBuilder() + qb := models.NewPerformerQueryBuilder(nil) - idInt, _ := strconv.Atoi(id) + idInt, _ := strconv.ParseInt(id, 10, 64) return qb.Find(idInt) } func (r *queryResolver) QueryPerformers(ctx context.Context, performerFilter *models.PerformerFilterType, filter *models.QuerySpec) (*models.QueryPerformersResultType, error) { @@ -22,7 +22,7 @@ func (r *queryResolver) QueryPerformers(ctx context.Context, performerFilter *mo return nil, err } - qb := models.NewPerformerQueryBuilder() + qb := models.NewPerformerQueryBuilder(nil) performers, count := qb.Query(performerFilter, filter) return &models.QueryPerformersResultType{ diff --git a/pkg/api/resolver_query_find_scene.go b/pkg/api/resolver_query_find_scene.go index 88a1452..8a66dad 100644 --- a/pkg/api/resolver_query_find_scene.go +++ b/pkg/api/resolver_query_find_scene.go @@ -2,38 +2,42 @@ package api import ( "context" + "strconv" "github.com/stashapp/stashdb/pkg/models" ) -func (r *queryResolver) FindScene(ctx context.Context, id *string, checksum *string) (*models.Scene, error) { - panic("not implemented") +func (r *queryResolver) FindScene(ctx context.Context, id string) (*models.Scene, error) { + if err := validateRead(ctx); err != nil { + return nil, err + } + + qb := models.NewSceneQueryBuilder(nil) + + idInt, _ := strconv.ParseInt(id, 10, 64) + return qb.Find(idInt) } + +func (r *queryResolver) FindSceneByFingerprint(ctx context.Context, fingerprint models.FingerprintInput) ([]*models.Scene, error) { + if err := validateRead(ctx); err != nil { + return nil, err + } + + qb := models.NewSceneQueryBuilder(nil) + + return qb.FindByFingerprint(fingerprint.Algorithm, fingerprint.Hash) +} + func (r *queryResolver) QueryScenes(ctx context.Context, sceneFilter *models.SceneFilterType, filter *models.QuerySpec) (*models.QueryScenesResultType, error) { - panic("not implemented") + if err := validateRead(ctx); err != nil { + return nil, err + } + + qb := models.NewSceneQueryBuilder(nil) + + scenes, count := qb.Query(sceneFilter, filter) + return &models.QueryScenesResultType{ + Scenes: scenes, + Count: count, + }, nil } - -// func (r *queryResolver) FindScene(ctx context.Context, id string) (*models.Scene, error) { -// if err := validateRead(ctx); err != nil { -// return nil, err -// } - -// qb := models.NewSceneQueryBuilder() -// idInt, _ := strconv.Atoi(id) -// var scene *models.Scene -// var err error -// scene, err = qb.Find(idInt) -// return scene, err -// } - -// func (r *queryResolver) FindSceneByChecksum(ctx context.Context, checksum string) (*models.Scene, error) { -// if err := validateRead(ctx); err != nil { -// return nil, err -// } - -// qb := models.NewSceneQueryBuilder() -// var scene *models.Scene -// var err error -// scene, err = qb.FindByChecksum(checksum) -// return scene, err -// } diff --git a/pkg/api/resolver_query_find_studio.go b/pkg/api/resolver_query_find_studio.go index e22748f..9469e58 100644 --- a/pkg/api/resolver_query_find_studio.go +++ b/pkg/api/resolver_query_find_studio.go @@ -2,14 +2,38 @@ package api import ( "context" + "strconv" "github.com/stashapp/stashdb/pkg/models" ) func (r *queryResolver) FindStudio(ctx context.Context, id *string, name *string) (*models.Studio, error) { - panic("not implemented") + if err := validateRead(ctx); err != nil { + return nil, err + } + + qb := models.NewStudioQueryBuilder(nil) + + if id != nil { + idInt, _ := strconv.ParseInt(*id, 10, 64) + return qb.Find(idInt) + } else if name != nil { + return qb.FindByName(*name) + } + + return nil, nil } func (r *queryResolver) QueryStudios(ctx context.Context, studioFilter *models.StudioFilterType, filter *models.QuerySpec) (*models.QueryStudiosResultType, error) { - panic("not implemented") + if err := validateRead(ctx); err != nil { + return nil, err + } + + qb := models.NewStudioQueryBuilder(nil) + + studios, count := qb.Query(studioFilter, filter) + return &models.QueryStudiosResultType{ + Studios: studios, + Count: count, + }, nil } diff --git a/pkg/api/resolver_query_find_tag.go b/pkg/api/resolver_query_find_tag.go index bdf1552..209eb03 100644 --- a/pkg/api/resolver_query_find_tag.go +++ b/pkg/api/resolver_query_find_tag.go @@ -2,14 +2,38 @@ package api import ( "context" + "strconv" "github.com/stashapp/stashdb/pkg/models" ) func (r *queryResolver) FindTag(ctx context.Context, id *string, name *string) (*models.Tag, error) { - panic("not implemented") + if err := validateRead(ctx); err != nil { + return nil, err + } + + qb := models.NewTagQueryBuilder(nil) + + if id != nil { + idInt, _ := strconv.ParseInt(*id, 10, 64) + return qb.Find(idInt) + } else if name != nil { + return qb.FindByNameOrAlias(*name) + } + + return nil, nil } func (r *queryResolver) QueryTags(ctx context.Context, tagFilter *models.TagFilterType, filter *models.QuerySpec) (*models.QueryTagsResultType, error) { - panic("not implemented") + if err := validateRead(ctx); err != nil { + return nil, err + } + + qb := models.NewTagQueryBuilder(nil) + + tags, count := qb.Query(tagFilter, filter) + return &models.QueryTagsResultType{ + Tags: tags, + Count: count, + }, nil } diff --git a/pkg/api/routes_performer.go b/pkg/api/routes_performer.go new file mode 100644 index 0000000..c770b8d --- /dev/null +++ b/pkg/api/routes_performer.go @@ -0,0 +1,53 @@ +package api + +import ( + "context" + "net/http" + "strconv" + + "github.com/go-chi/chi" + "github.com/stashapp/stashdb/pkg/models" +) + +type performerRoutes struct{} + +func (rs performerRoutes) Routes() chi.Router { + r := chi.NewRouter() + + r.Route("/{performerId}", func(r chi.Router) { + r.Use(PerformerCtx) + r.Get("/image", rs.Image) + }) + + return r +} + +func (rs performerRoutes) Image(w http.ResponseWriter, r *http.Request) { + if err := validateRead(r.Context()); err != nil { + http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden) + return + } + + performer := r.Context().Value(performerKey).(*models.Performer) + _, _ = w.Write(performer.Image) +} + +func PerformerCtx(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + performerID, err := strconv.ParseInt(chi.URLParam(r, "performerId"), 10, 64) + if err != nil { + http.Error(w, http.StatusText(404), 404) + return + } + + qb := models.NewPerformerQueryBuilder(nil) + performer, err := qb.Find(performerID) + if err != nil { + http.Error(w, http.StatusText(404), 404) + return + } + + ctx := context.WithValue(r.Context(), performerKey, performer) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} diff --git a/pkg/api/scene_integration_test.go b/pkg/api/scene_integration_test.go new file mode 100644 index 0000000..384b714 --- /dev/null +++ b/pkg/api/scene_integration_test.go @@ -0,0 +1,486 @@ +// +build integration + +package api_test + +import ( + "reflect" + "strconv" + "testing" + + "github.com/stashapp/stashdb/pkg/models" + + _ "github.com/golang-migrate/migrate/v4/database/sqlite3" +) + +type sceneTestRunner struct { + testRunner +} + +func createSceneTestRunner(t *testing.T) *sceneTestRunner { + return &sceneTestRunner{ + testRunner: *createTestRunner(t), + } +} + +func (s *sceneTestRunner) testCreateScene() { + title := "Title" + details := "Details" + url := "URL" + date := "2003-02-01" + + performer, _ := s.createTestPerformer(nil) + studio, _ := s.createTestScene(nil) + tag, _ := s.createTestTag(nil) + + performerID := strconv.FormatInt(performer.ID, 10) + studioID := strconv.FormatInt(studio.ID, 10) + tagID := strconv.FormatInt(tag.ID, 10) + + performerAlias := "alias" + + input := models.SceneCreateInput{ + Title: &title, + Details: &details, + URL: &url, + Date: &date, + Fingerprints: []*models.FingerprintInput{ + s.generateSceneFingerprint(), + s.generateSceneFingerprint(), + }, + StudioID: &studioID, + Performers: []*models.PerformerAppearanceInput{ + &models.PerformerAppearanceInput{ + PerformerID: performerID, + As: &performerAlias, + }, + }, + TagIds: []string{ + tagID, + }, + } + + scene, err := s.resolver.Mutation().SceneCreate(s.ctx, input) + + if err != nil { + s.t.Errorf("Error creating scene: %s", err.Error()) + return + } + + s.verifyCreatedScene(input, scene) +} + +func comparePerformers(input []*models.PerformerAppearanceInput, performers []*models.PerformerAppearance) bool { + if len(performers) != len(input) { + return false + } + + for i, v := range performers { + performerID := strconv.FormatInt(v.Performer.ID, 10) + if performerID != input[i].PerformerID { + return false + } + + if v.As != input[i].As { + if v.As == nil || input[i].As == nil { + return false + } + + if *v.As != *input[i].As { + return false + } + } + } + + return true +} + +func compareTags(tagIDs []string, tags []*models.Tag) bool { + if len(tags) != len(tagIDs) { + return false + } + + for i, v := range tags { + tagID := strconv.FormatInt(v.ID, 10) + if tagID != tagIDs[i] { + return false + } + } + + return true +} + +func compareFingerprints(input []*models.FingerprintInput, fingerprints []*models.Fingerprint) bool { + if len(input) != len(fingerprints) { + return false + } + + for i, v := range fingerprints { + if input[i].Algorithm != v.Algorithm || input[i].Hash != v.Hash { + return false + } + } + + return true +} + +func (s *sceneTestRunner) verifyCreatedScene(input models.SceneCreateInput, scene *models.Scene) { + // ensure basic attributes are set correctly + r := s.resolver.Scene() + + id, _ := r.ID(s.ctx, scene) + if id == "" { + s.t.Errorf("Expected created scene id to be non-zero") + } + + if v, _ := r.Title(s.ctx, scene); !reflect.DeepEqual(v, input.Title) { + s.fieldMismatch(*input.Title, v, "Title") + } + + if v, _ := r.Details(s.ctx, scene); !reflect.DeepEqual(v, input.Details) { + s.fieldMismatch(input.Details, v, "Details") + } + + if v, _ := r.URL(s.ctx, scene); !reflect.DeepEqual(v, input.URL) { + s.fieldMismatch(*input.URL, v, "URL") + } + + if v, _ := r.Date(s.ctx, scene); !reflect.DeepEqual(v, input.Date) { + s.fieldMismatch(*input.Date, v, "Date") + } + + if v, _ := r.Fingerprints(s.ctx, scene); !compareFingerprints(input.Fingerprints, v) { + s.fieldMismatch(input.Fingerprints, v, "Fingerprints") + } + + performers, err := s.resolver.Scene().Performers(s.ctx, scene) + if err != nil { + s.t.Errorf("Error getting scene performers: %s", err.Error()) + } + + if !comparePerformers(input.Performers, performers) { + s.fieldMismatch(input.Performers, performers, "Performers") + } + + tags, err := s.resolver.Scene().Tags(s.ctx, scene) + if err != nil { + s.t.Errorf("Error getting scene tags: %s", err.Error()) + } + + if !compareTags(input.TagIds, tags) { + s.fieldMismatch(input.TagIds, tags, "Tags") + } +} + +func (s *sceneTestRunner) testFindSceneById() { + createdScene, err := s.createTestScene(nil) + if err != nil { + return + } + + sceneID := strconv.FormatInt(createdScene.ID, 10) + scene, err := s.resolver.Query().FindScene(s.ctx, sceneID) + if err != nil { + s.t.Errorf("Error finding scene: %s", err.Error()) + return + } + + // ensure returned scene is not nil + if scene == nil { + s.t.Error("Did not find scene by id") + return + } + + // ensure values were set + if createdScene.Title != scene.Title { + s.fieldMismatch(createdScene.Title, scene.Title, "Title") + } +} + +func (s *sceneTestRunner) testFindSceneByFingerprint() { + createdScene, err := s.createTestScene(nil) + if err != nil { + return + } + + fingerprints, err := s.resolver.Scene().Fingerprints(s.ctx, createdScene) + fingerprint := models.FingerprintInput{ + Algorithm: fingerprints[0].Algorithm, + Hash: fingerprints[0].Hash, + } + scenes, err := s.resolver.Query().FindSceneByFingerprint(s.ctx, fingerprint) + if err != nil { + s.t.Errorf("Error finding scene: %s", err.Error()) + return + } + + // ensure returned scene is not nil + if len(scenes) == 0 { + s.t.Error("Did not find scene by fingerprint") + return + } + + // ensure values were set + if createdScene.Title != scenes[0].Title { + s.fieldMismatch(createdScene.Title, scenes[0].Title, "Title") + } +} + +func (s *sceneTestRunner) testUpdateScene() { + title := "Title" + details := "Details" + url := "URL" + date := "2003-02-01" + + performer, _ := s.createTestPerformer(nil) + studio, _ := s.createTestScene(nil) + tag, _ := s.createTestTag(nil) + + performerID := strconv.FormatInt(performer.ID, 10) + studioID := strconv.FormatInt(studio.ID, 10) + tagID := strconv.FormatInt(tag.ID, 10) + + performerAlias := "alias" + + input := models.SceneCreateInput{ + Title: &title, + Details: &details, + URL: &url, + Date: &date, + Fingerprints: []*models.FingerprintInput{ + s.generateSceneFingerprint(), + s.generateSceneFingerprint(), + }, + StudioID: &studioID, + Performers: []*models.PerformerAppearanceInput{ + &models.PerformerAppearanceInput{ + PerformerID: performerID, + As: &performerAlias, + }, + }, + TagIds: []string{ + tagID, + }, + } + + createdScene, err := s.createTestScene(&input) + if err != nil { + return + } + + sceneID := strconv.FormatInt(createdScene.ID, 10) + + newTitle := "NewTitle" + newDetails := "NewDetails" + newURL := "NewURL" + newDate := "2001-02-03" + + performer, _ = s.createTestPerformer(nil) + studio, _ = s.createTestScene(nil) + tag, _ = s.createTestTag(nil) + + performerID = strconv.FormatInt(performer.ID, 10) + studioID = strconv.FormatInt(studio.ID, 10) + tagID = strconv.FormatInt(tag.ID, 10) + + performerAlias = "updatedAlias" + + updateInput := models.SceneUpdateInput{ + ID: sceneID, + Title: &newTitle, + Details: &newDetails, + URL: &newURL, + Date: &newDate, + Fingerprints: []*models.FingerprintInput{ + s.generateSceneFingerprint(), + }, + Performers: []*models.PerformerAppearanceInput{ + &models.PerformerAppearanceInput{ + PerformerID: performerID, + As: &performerAlias, + }, + }, + StudioID: &studioID, + TagIds: []string{ + tagID, + }, + } + + // need some mocking of the context to make the field ignore behaviour work + ctx := s.updateContext([]string{ + "fingerprints", + "performers", + "tagIds", + }) + + updatedScene, err := s.resolver.Mutation().SceneUpdate(ctx, updateInput) + if err != nil { + s.t.Errorf("Error updating scene: %s", err.Error()) + return + } + + s.verifyUpdatedScene(updateInput, updatedScene) +} + +func (s *sceneTestRunner) testUpdateSceneTitle() { + title := "Title" + details := "Details" + url := "URL" + date := "2003-02-01" + + performer, _ := s.createTestPerformer(nil) + studio, _ := s.createTestScene(nil) + tag, _ := s.createTestTag(nil) + + performerID := strconv.FormatInt(performer.ID, 10) + studioID := strconv.FormatInt(studio.ID, 10) + tagID := strconv.FormatInt(tag.ID, 10) + + performerAlias := "alias" + + input := models.SceneCreateInput{ + Title: &title, + Details: &details, + URL: &url, + Date: &date, + Fingerprints: []*models.FingerprintInput{ + s.generateSceneFingerprint(), + s.generateSceneFingerprint(), + }, + Performers: []*models.PerformerAppearanceInput{ + &models.PerformerAppearanceInput{ + PerformerID: performerID, + As: &performerAlias, + }, + }, + StudioID: &studioID, + TagIds: []string{ + tagID, + }, + } + + createdScene, err := s.createTestScene(&input) + if err != nil { + return + } + + sceneID := strconv.FormatInt(createdScene.ID, 10) + newTitle := "NewTitle" + + updateInput := models.SceneUpdateInput{ + ID: sceneID, + Title: &newTitle, + } + + // need some mocking of the context to make the field ignore behaviour work + ctx := s.updateContext([]string{ + "title", + }) + updatedScene, err := s.resolver.Mutation().SceneUpdate(ctx, updateInput) + if err != nil { + s.t.Errorf("Error updating scene: %s", err.Error()) + return + } + + input.Title = &newTitle + s.verifyCreatedScene(input, updatedScene) +} + +func (s *sceneTestRunner) verifyUpdatedScene(input models.SceneUpdateInput, scene *models.Scene) { + // ensure basic attributes are set correctly + r := s.resolver.Scene() + + if v, _ := r.Title(s.ctx, scene); !reflect.DeepEqual(v, input.Title) { + s.fieldMismatch(input.Title, v, "Title") + } + + if v, _ := r.Details(s.ctx, scene); !reflect.DeepEqual(v, input.Details) { + s.fieldMismatch(input.Details, v, "Details") + } + + if v, _ := r.URL(s.ctx, scene); !reflect.DeepEqual(v, input.URL) { + s.fieldMismatch(input.URL, v, "URL") + } + + if v, _ := r.Date(s.ctx, scene); !reflect.DeepEqual(v, input.Date) { + s.fieldMismatch(input.Date, v, "Date") + } + + if v, _ := r.Fingerprints(s.ctx, scene); !compareFingerprints(input.Fingerprints, v) { + s.fieldMismatch(input.Fingerprints, v, "Fingerprints") + } + + performers, _ := s.resolver.Scene().Performers(s.ctx, scene) + if !comparePerformers(input.Performers, performers) { + s.fieldMismatch(input.Performers, performers, "Performers") + } + + tags, _ := s.resolver.Scene().Tags(s.ctx, scene) + if !compareTags(input.TagIds, tags) { + s.fieldMismatch(input.TagIds, tags, "Tags") + } +} + +func (s *sceneTestRunner) testDestroyScene() { + createdScene, err := s.createTestScene(nil) + if err != nil { + return + } + + sceneID := strconv.FormatInt(createdScene.ID, 10) + + destroyed, err := s.resolver.Mutation().SceneDestroy(s.ctx, models.SceneDestroyInput{ + ID: sceneID, + }) + if err != nil { + s.t.Errorf("Error destroying scene: %s", err.Error()) + return + } + + if !destroyed { + s.t.Error("Scene was not destroyed") + return + } + + // ensure cannot find scene + foundScene, err := s.resolver.Query().FindScene(s.ctx, sceneID) + if err != nil { + s.t.Errorf("Error finding scene after destroying: %s", err.Error()) + return + } + + if foundScene != nil { + s.t.Error("Found scene after destruction") + } + + // TODO - ensure scene was not removed +} + +func TestCreateScene(t *testing.T) { + pt := createSceneTestRunner(t) + pt.testCreateScene() +} + +func TestFindSceneById(t *testing.T) { + pt := createSceneTestRunner(t) + pt.testFindSceneById() +} + +func TestFindSceneByFingerprint(t *testing.T) { + pt := createSceneTestRunner(t) + pt.testFindSceneByFingerprint() +} + +func TestUpdateScene(t *testing.T) { + pt := createSceneTestRunner(t) + pt.testUpdateScene() +} + +func TestUpdateSceneTitle(t *testing.T) { + pt := createSceneTestRunner(t) + pt.testUpdateSceneTitle() +} + +func TestDestroyScene(t *testing.T) { + pt := createSceneTestRunner(t) + pt.testDestroyScene() +} diff --git a/pkg/api/server.go b/pkg/api/server.go index 7a8ba9c..5277854 100644 --- a/pkg/api/server.go +++ b/pkg/api/server.go @@ -110,6 +110,8 @@ func Start() { // TODO - this should be disabled in production r.Handle("/playground", handler.Playground("GraphQL playground", "/graphql")) + r.Mount("/performer", performerRoutes{}.Routes()) + address := config.GetHost() + ":" + strconv.Itoa(config.GetPort()) if tlsConfig := makeTLSConfig(); tlsConfig != nil { httpsServer := &http.Server{ diff --git a/pkg/api/studio_integration_test.go b/pkg/api/studio_integration_test.go new file mode 100644 index 0000000..804c330 --- /dev/null +++ b/pkg/api/studio_integration_test.go @@ -0,0 +1,208 @@ +// +build integration + +package api_test + +import ( + "strconv" + "testing" + + "github.com/stashapp/stashdb/pkg/models" + + _ "github.com/golang-migrate/migrate/v4/database/sqlite3" +) + +type studioTestRunner struct { + testRunner + studioSuffix int +} + +func createStudioTestRunner(t *testing.T) *studioTestRunner { + return &studioTestRunner{ + testRunner: *createTestRunner(t), + } +} + +func (s *studioTestRunner) generateStudioName() string { + s.studioSuffix += 1 + return "studioTestRunner-" + strconv.Itoa(s.studioSuffix) +} + +func (s *studioTestRunner) testCreateStudio() { + input := models.StudioCreateInput{ + Name: s.generateStudioName(), + } + + studio, err := s.resolver.Mutation().StudioCreate(s.ctx, input) + + if err != nil { + s.t.Errorf("Error creating studio: %s", err.Error()) + return + } + + s.verifyCreatedStudio(input, studio) +} + +func (s *studioTestRunner) verifyCreatedStudio(input models.StudioCreateInput, studio *models.Studio) { + // ensure basic attributes are set correctly + if input.Name != studio.Name { + s.fieldMismatch(input.Name, studio.Name, "Name") + } + + r := s.resolver.Studio() + + id, _ := r.ID(s.ctx, studio) + if id == "" { + s.t.Errorf("Expected created studio id to be non-zero") + } +} + +func (s *studioTestRunner) testFindStudioById() { + createdStudio, err := s.createTestStudio(nil) + if err != nil { + return + } + + studioID := strconv.FormatInt(createdStudio.ID, 10) + studio, err := s.resolver.Query().FindStudio(s.ctx, &studioID, nil) + if err != nil { + s.t.Errorf("Error finding studio: %s", err.Error()) + return + } + + // ensure returned studio is not nil + if studio == nil { + s.t.Error("Did not find studio by id") + return + } + + // ensure values were set + if createdStudio.Name != studio.Name { + s.fieldMismatch(createdStudio.Name, studio.Name, "Name") + } +} + +func (s *studioTestRunner) testFindStudioByName() { + createdStudio, err := s.createTestStudio(nil) + if err != nil { + return + } + + studioName := createdStudio.Name + studio, err := s.resolver.Query().FindStudio(s.ctx, nil, &studioName) + if err != nil { + s.t.Errorf("Error finding studio: %s", err.Error()) + return + } + + // ensure returned studio is not nil + if studio == nil { + s.t.Error("Did not find studio by name") + return + } + + // ensure values were set + if createdStudio.Name != studio.Name { + s.fieldMismatch(createdStudio.Name, studio.Name, "Name") + } +} + +func (s *studioTestRunner) testUpdateStudioName() { + input := &models.StudioCreateInput{ + Name: s.generateStudioName(), + } + + createdStudio, err := s.createTestStudio(input) + if err != nil { + return + } + + studioID := strconv.FormatInt(createdStudio.ID, 10) + + updatedName := s.generateStudioName() + updateInput := models.StudioUpdateInput{ + ID: studioID, + Name: &updatedName, + } + + // need some mocking of the context to make the field ignore behaviour work + ctx := s.updateContext([]string{ + "name", + }) + updatedStudio, err := s.resolver.Mutation().StudioUpdate(ctx, updateInput) + if err != nil { + s.t.Errorf("Error updating studio: %s", err.Error()) + return + } + + input.Name = updatedName + s.verifyCreatedStudio(*input, updatedStudio) +} + +func (s *studioTestRunner) verifyUpdatedStudio(input models.StudioUpdateInput, studio *models.Studio) { + // ensure basic attributes are set correctly + if input.Name != nil && *input.Name != studio.Name { + s.fieldMismatch(input.Name, studio.Name, "Name") + } +} + +func (s *studioTestRunner) testDestroyStudio() { + createdStudio, err := s.createTestStudio(nil) + if err != nil { + return + } + + studioID := strconv.FormatInt(createdStudio.ID, 10) + + destroyed, err := s.resolver.Mutation().StudioDestroy(s.ctx, models.StudioDestroyInput{ + ID: studioID, + }) + if err != nil { + s.t.Errorf("Error destroying studio: %s", err.Error()) + return + } + + if !destroyed { + s.t.Error("Studio was not destroyed") + return + } + + // ensure cannot find studio + foundStudio, err := s.resolver.Query().FindStudio(s.ctx, &studioID, nil) + if err != nil { + s.t.Errorf("Error finding studio after destroying: %s", err.Error()) + return + } + + if foundStudio != nil { + s.t.Error("Found studio after destruction") + } + + // TODO - ensure scene was not removed +} + +func TestCreateStudio(t *testing.T) { + pt := createStudioTestRunner(t) + pt.testCreateStudio() +} + +func TestFindStudioById(t *testing.T) { + pt := createStudioTestRunner(t) + pt.testFindStudioById() +} + +func TestFindStudioByName(t *testing.T) { + pt := createStudioTestRunner(t) + pt.testFindStudioByName() +} + +func TestUpdateStudioName(t *testing.T) { + pt := createStudioTestRunner(t) + pt.testUpdateStudioName() +} + +func TestDestroyStudio(t *testing.T) { + pt := createStudioTestRunner(t) + pt.testDestroyStudio() +} + +// TODO - test parent/children studios diff --git a/pkg/api/tag_integration_test.go b/pkg/api/tag_integration_test.go new file mode 100644 index 0000000..50e9713 --- /dev/null +++ b/pkg/api/tag_integration_test.go @@ -0,0 +1,209 @@ +// +build integration + +package api_test + +import ( + "reflect" + "strconv" + "testing" + + "github.com/stashapp/stashdb/pkg/models" + + _ "github.com/golang-migrate/migrate/v4/database/sqlite3" +) + +type tagTestRunner struct { + testRunner +} + +func createTagTestRunner(t *testing.T) *tagTestRunner { + return &tagTestRunner{ + testRunner: *createTestRunner(t), + } +} + +func (s *tagTestRunner) testCreateTag() { + description := "Description" + + input := models.TagCreateInput{ + Name: s.generateTagName(), + Description: &description, + } + + tag, err := s.resolver.Mutation().TagCreate(s.ctx, input) + + if err != nil { + s.t.Errorf("Error creating tag: %s", err.Error()) + return + } + + s.verifyCreatedTag(input, tag) +} + +func (s *tagTestRunner) verifyCreatedTag(input models.TagCreateInput, tag *models.Tag) { + // ensure basic attributes are set correctly + if input.Name != tag.Name { + s.fieldMismatch(input.Name, tag.Name, "Name") + } + + r := s.resolver.Tag() + + id, _ := r.ID(s.ctx, tag) + if id == "" { + s.t.Errorf("Expected created tag id to be non-zero") + } + + if v, _ := r.Description(s.ctx, tag); !reflect.DeepEqual(v, input.Description) { + s.fieldMismatch(*input.Description, v, "Description") + } + +} + +func (s *tagTestRunner) testFindTagById() { + createdTag, err := s.createTestTag(nil) + if err != nil { + return + } + + tagID := strconv.FormatInt(createdTag.ID, 10) + tag, err := s.resolver.Query().FindTag(s.ctx, &tagID, nil) + if err != nil { + s.t.Errorf("Error finding tag: %s", err.Error()) + return + } + + // ensure returned tag is not nil + if tag == nil { + s.t.Error("Did not find tag by id") + return + } + + // ensure values were set + if createdTag.Name != tag.Name { + s.fieldMismatch(createdTag.Name, tag.Name, "Name") + } +} + +func (s *tagTestRunner) testFindTagByName() { + createdTag, err := s.createTestTag(nil) + if err != nil { + return + } + + tagName := createdTag.Name + + tag, err := s.resolver.Query().FindTag(s.ctx, nil, &tagName) + if err != nil { + s.t.Errorf("Error finding tag: %s", err.Error()) + return + } + + // ensure returned tag is not nil + if tag == nil { + s.t.Error("Did not find tag by name") + return + } + + // ensure values were set + if createdTag.Name != tag.Name { + s.fieldMismatch(createdTag.Name, tag.Name, "Name") + } +} + +func (s *tagTestRunner) testUpdateTag() { + createdTag, err := s.createTestTag(nil) + if err != nil { + return + } + + tagID := strconv.FormatInt(createdTag.ID, 10) + + newDescription := "newDescription" + + updateInput := models.TagUpdateInput{ + ID: tagID, + Description: &newDescription, + } + + updatedTag, err := s.resolver.Mutation().TagUpdate(s.ctx, updateInput) + if err != nil { + s.t.Errorf("Error updating tag: %s", err.Error()) + return + } + + updateInput.Name = &createdTag.Name + s.verifyUpdatedTag(updateInput, updatedTag) +} + +func (s *tagTestRunner) verifyUpdatedTag(input models.TagUpdateInput, tag *models.Tag) { + // ensure basic attributes are set correctly + if input.Name != nil && *input.Name != tag.Name { + s.fieldMismatch(input.Name, tag.Name, "Name") + } + + r := s.resolver.Tag() + + if v, _ := r.Description(s.ctx, tag); !reflect.DeepEqual(v, input.Description) { + s.fieldMismatch(input.Description, v, "Description") + } +} + +func (s *tagTestRunner) testDestroyTag() { + createdTag, err := s.createTestTag(nil) + if err != nil { + return + } + + tagID := strconv.FormatInt(createdTag.ID, 10) + + destroyed, err := s.resolver.Mutation().TagDestroy(s.ctx, models.TagDestroyInput{ + ID: tagID, + }) + if err != nil { + s.t.Errorf("Error destroying tag: %s", err.Error()) + return + } + + if !destroyed { + s.t.Error("Tag was not destroyed") + return + } + + // ensure cannot find tag + foundTag, err := s.resolver.Query().FindTag(s.ctx, &tagID, nil) + if err != nil { + s.t.Errorf("Error finding tag after destroying: %s", err.Error()) + return + } + + if foundTag != nil { + s.t.Error("Found tag after destruction") + } + + // TODO - ensure scene was not removed +} + +func TestCreateTag(t *testing.T) { + pt := createTagTestRunner(t) + pt.testCreateTag() +} + +func TestFindTagById(t *testing.T) { + pt := createTagTestRunner(t) + pt.testFindTagById() +} + +func TestFindTagByName(t *testing.T) { + pt := createTagTestRunner(t) + pt.testFindTagByName() +} + +func TestUpdateTag(t *testing.T) { + pt := createTagTestRunner(t) + pt.testUpdateTag() +} + +func TestDestroyTag(t *testing.T) { + pt := createTagTestRunner(t) + pt.testDestroyTag() +} diff --git a/pkg/database/database.go b/pkg/database/database.go index 6c834c0..0dcb189 100644 --- a/pkg/database/database.go +++ b/pkg/database/database.go @@ -3,7 +3,9 @@ package database import ( "database/sql" "fmt" + "os" "regexp" + "github.com/gobuffalo/packr/v2" "github.com/golang-migrate/migrate/v4" "github.com/golang-migrate/migrate/v4/source" @@ -11,7 +13,6 @@ import ( sqlite3 "github.com/mattn/go-sqlite3" "github.com/stashapp/stashdb/pkg/logger" "github.com/stashapp/stashdb/pkg/utils" - "os" ) var DB *sqlx.DB @@ -68,6 +69,8 @@ func runMigrations(databasePath string) { panic(err.Error()) } } + + m.Close() } func registerRegexpFunc() { diff --git a/pkg/database/databasetest/database_test_utils.go b/pkg/database/databasetest/database_test_utils.go new file mode 100644 index 0000000..56e859d --- /dev/null +++ b/pkg/database/databasetest/database_test_utils.go @@ -0,0 +1,76 @@ +package databasetest + +import ( + "context" + "fmt" + "io/ioutil" + "os" + "testing" + + "github.com/stashapp/stashdb/pkg/database" +) + +type DatabasePopulater interface { + PopulateDB() error +} + +func testTeardown(databaseFile string) { + err := database.DB.Close() + + if err != nil { + panic(err) + } + + err = os.Remove(databaseFile) + if err != nil { + panic(err) + } +} + +func runTests(m *testing.M, populater DatabasePopulater) int { + // create the database file + f, err := ioutil.TempFile("", "*.sqlite") + if err != nil { + panic(fmt.Sprintf("Could not create temporary file: %s", err.Error())) + } + + f.Close() + databaseFile := f.Name() + database.Initialize(databaseFile) + + // defer close and delete the database + defer testTeardown(databaseFile) + + if populater != nil { + err = populater.PopulateDB() + if err != nil { + panic(fmt.Sprintf("Could not populate database: %s", err.Error())) + } + } + + // run the tests + return m.Run() +} + +func TestWithDatabase(m *testing.M, populater DatabasePopulater) { + ret := runTests(m, populater) + os.Exit(ret) +} + +func WithTransientTransaction(ctx context.Context, fn database.TxFunc) { + txn := database.NewTransaction(ctx) + txn.Begin(ctx) + + defer func() { + if p := recover(); p != nil { + // a panic occurred, rollback and repanic + txn.Rollback() + panic(p) + } else { + // something went wrong, rollback + txn.Rollback() + } + }() + + fn(txn) +} diff --git a/pkg/database/dbi.go b/pkg/database/dbi.go new file mode 100644 index 0000000..1679341 --- /dev/null +++ b/pkg/database/dbi.go @@ -0,0 +1,249 @@ +package database + +import ( + "database/sql" + "fmt" + "reflect" + + "github.com/jmoiron/sqlx" + "github.com/pkg/errors" +) + +// The DBI interface is used to interface with the database. +type DBI interface { + // Insert inserts the provided object as a row into the database. + // It returns the new object. + Insert(model Model) (interface{}, error) + + // InsertJoin inserts a join object into the provided join table. + InsertJoin(tableJoin TableJoin, object interface{}) error + + // InsertJoins inserts multiple join objects into the provided join table. + InsertJoins(tableJoin TableJoin, joins Joins) error + + // Update updates a database row based on the id and values of the provided + // object. It returns the updated object. Update will return an error if + // the object with id does not exist in the database table. + Update(model Model) (interface{}, error) + + // ReplaceJoins replaces table join objects with the provided primary table + // id value with the provided join objects. + ReplaceJoins(tableJoin TableJoin, id int64, objects Joins) error + + // Delete deletes the table row with the provided id. Delete returns an + // error if the id does not exist in the database table. + Delete(id int64, table Table) error + + // DeleteJoins deletes all join objects with the provided primary table + // id value. + DeleteJoins(tableJoin TableJoin, id int64) error + + // Find returns the row object with the provided id, or returns nil if not + // found. + Find(id int64, table Table) (interface{}, error) + + // FindJoins returns join objects where the foreign key id is equal to the + // provided id. The join objects are output to the provided output slice. + FindJoins(tableJoin TableJoin, id int64, output Joins) error + + // RawQuery performs a query on the provided table using the query string + // and argument slice. It outputs the results to the output slice. + RawQuery(table Table, query string, args []interface{}, output Models) error +} + +type dbi struct { + tx *sqlx.Tx +} + +// DBIWithTxn returns a DBI interface that is to operate within a transaction. +func DBIWithTxn(tx *sqlx.Tx) DBI { + return &dbi{ + tx: tx, + } +} + +// DBINoTxn returns a DBI interface that is to operate outside of a transaction. +// This DBI will not be able to mutate the database. +func DBINoTxn() DBI { + return &dbi{} +} + +// Insert inserts the provided object as a row into the database. +// It returns the new object. +func (q dbi) Insert(model Model) (interface{}, error) { + tableName := model.GetTable().Name() + id, err := insertObject(q.tx, tableName, model) + + if err != nil { + return nil, errors.Wrap(err, fmt.Sprintf("Error creating %s", reflect.TypeOf(model).Name())) + } + + // don't want to modify the existing object + newModel := model.GetTable().NewObject() + if err := getByID(q.tx, tableName, id, newModel); err != nil { + return nil, errors.Wrap(err, fmt.Sprintf("Error getting %s after create", reflect.TypeOf(model).Name())) + } + + return newModel, nil +} + +// Update updates a database row based on the id and values of the provided +// object. It returns the updated object. Update will return an error if +// the object with id does not exist in the database table. +func (q dbi) Update(model Model) (interface{}, error) { + tableName := model.GetTable().Name() + err := updateObjectByID(q.tx, tableName, model) + + if err != nil { + return nil, errors.Wrap(err, fmt.Sprintf("Error updating %s", reflect.TypeOf(model).Name())) + } + + // don't want to modify the existing object + updatedModel := model.GetTable().NewObject() + if err := getByID(q.tx, tableName, model.GetID(), updatedModel); err != nil { + return nil, errors.Wrap(err, fmt.Sprintf("Error getting %s after update", reflect.TypeOf(model).Name())) + } + + return updatedModel, nil +} + +// Delete deletes the table row with the provided id. Delete returns an +// error if the id does not exist in the database table. +func (q dbi) Delete(id int64, table Table) error { + o, err := q.Find(id, table) + + if err != nil { + return errors.Wrap(err, fmt.Sprintf("Error deleting from %s", table.Name())) + } + + if o == nil { + return fmt.Errorf("Row with id %d not found in %s", id, table.Name()) + } + + return executeDeleteQuery(table.Name(), id, q.tx) +} + +func selectStatement(table Table) string { + tableName := table.Name() + return fmt.Sprintf("SELECT %s.* FROM %s", tableName, tableName) +} + +func (q dbi) queryx(query string, args ...interface{}) (*sqlx.Rows, error) { + if q.tx != nil { + return q.tx.Queryx(query, args...) + } else { + return DB.Queryx(query, args...) + } +} + +// Find returns the row object with the provided id, or returns nil if not +// found. +func (q dbi) Find(id int64, table Table) (interface{}, error) { + query := selectStatement(table) + " WHERE id = ? LIMIT 1" + args := []interface{}{id} + + var rows *sqlx.Rows + var err error + rows, err = q.queryx(query, args...) + + if err != nil && err != sql.ErrNoRows { + return nil, err + } + defer rows.Close() + + output := table.NewObject() + if rows.Next() { + if err := rows.StructScan(output); err != nil { + return nil, err + } + } else { + // not found + return nil, nil + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return output, nil +} + +// InsertJoin inserts a join object into the provided join table. +func (q dbi) InsertJoin(tableJoin TableJoin, object interface{}) error { + _, err := insertObject(q.tx, tableJoin.Name(), object) + + if err != nil { + return errors.Wrap(err, fmt.Sprintf("Error creating %s", reflect.TypeOf(object).Name())) + } + + return nil +} + +// InsertJoins inserts multiple join objects into the provided join table. +func (q dbi) InsertJoins(tableJoin TableJoin, joins Joins) error { + var err error + joins.Each(func(ro interface{}) { + if err != nil { + return + } + + err = q.InsertJoin(tableJoin, ro) + }) + + return err +} + +// ReplaceJoins replaces table join objects with the provided primary table +// id value with the provided join objects. +func (q dbi) ReplaceJoins(tableJoin TableJoin, id int64, joins Joins) error { + err := q.DeleteJoins(tableJoin, id) + + if err != nil { + return err + } + + return q.InsertJoins(tableJoin, joins) +} + +// DeleteJoins deletes all join objects with the provided primary table +// id value. +func (q dbi) DeleteJoins(tableJoin TableJoin, id int64) error { + return deleteObjectsByColumn(q.tx, tableJoin.Name(), tableJoin.joinColumn, id) +} + +// FindJoins returns join objects where the foreign key id is equal to the +// provided id. The join objects are output to the provided output slice. +func (q dbi) FindJoins(tableJoin TableJoin, id int64, output Joins) error { + query := selectStatement(tableJoin.Table) + " WHERE " + tableJoin.joinColumn + " = ?" + args := []interface{}{id} + + return q.RawQuery(tableJoin.Table, query, args, output) +} + +// RawQuery performs a query on the provided table using the query string +// and argument slice. It outputs the results to the output slice. +func (q dbi) RawQuery(table Table, query string, args []interface{}, output Models) error { + var rows *sqlx.Rows + var err error + rows, err = q.queryx(query, args...) + + if err != nil && err != sql.ErrNoRows { + return err + } + defer rows.Close() + + for rows.Next() { + o := table.NewObject() + if err := rows.StructScan(o); err != nil { + return err + } + + output.Add(o) + } + + if err := rows.Err(); err != nil { + return err + } + + return nil +} diff --git a/pkg/database/migrations/1_initial.up.sql b/pkg/database/migrations/1_initial.up.sql index c82773f..c820f6a 100644 --- a/pkg/database/migrations/1_initial.up.sql +++ b/pkg/database/migrations/1_initial.up.sql @@ -19,7 +19,8 @@ CREATE TABLE `performers` ( `career_start_year` integer, `career_end_year` integer, `created_at` datetime not null, - `updated_at` datetime not null + `updated_at` datetime not null, + unique (`name`, `disambiguation`) ); CREATE TABLE `performer_aliases` ( @@ -59,3 +60,77 @@ CREATE INDEX `index_performers_on_alias` on `performer_aliases` (`alias`); CREATE INDEX `index_performers_on_piercing_location` on `performer_piercings` (`location`); CREATE INDEX `index_performers_on_tattoo_location` on `performer_tattoos` (`location`); CREATE INDEX `index_performers_on_tattoo_description` on `performer_tattoos` (`description`); + +CREATE TABLE `tags` ( + `id` integer not null primary key autoincrement, + `name` varchar(255) not null, + `description` varchar(255), + `created_at` datetime not null, + `updated_at` datetime not null, + unique (`name`) +); + +CREATE TABLE `tag_aliases` ( + `tag_id` integer not null, + `alias` varchar(255) not null, + foreign key(`tag_id`) references `tags`(`id`) ON DELETE CASCADE, + unique (`alias`) +); + +CREATE TABLE `studios` ( + `id` integer not null primary key autoincrement, + `image` blob, + `name` varchar(255) not null, + `parent_studio_id` integer , + `created_at` datetime not null, + `updated_at` datetime not null, + foreign key(`parent_studio_id`) references `studios`(`id`) ON DELETE CASCADE +); + +CREATE TABLE `studio_urls` ( + `studio_id` integer not null, + `url` varchar(255) not null, + `type` varchar(255) not null, + foreign key(`studio_id`) references `studios`(`id`) ON DELETE CASCADE, + unique (`studio_id`, `url`), + unique (`studio_id`, `type`) +); + +CREATE TABLE `scenes` ( + `id` integer not null primary key autoincrement, + `title` varchar(255), + `details` varchar(255), + `url` varchar(255), + `date` date, + `studio_id` integer, + `created_at` datetime not null, + `updated_at` datetime not null, + foreign key(`studio_id`) references `studios`(`id`) ON DELETE SET NULL +); + +CREATE TABLE `scene_fingerprints` ( + `scene_id` integer not null, + `hash` varchar(255) not null, + `algorithm` varchar(20) not null, + foreign key(`scene_id`) references `scenes`(`id`) ON DELETE CASCADE, + unique (`scene_id`, `algorithm`, `hash`) +); + +CREATE INDEX `index_scene_fingerprints_on_hash` on `scene_fingerprints` (`algorithm`, `hash`); + +CREATE TABLE `scene_performers` ( + `scene_id` integer not null, + `as` varchar(255), + `performer_id` integer not null, + foreign key(`scene_id`) references `scenes`(`id`) ON DELETE CASCADE, + foreign key(`performer_id`) references `performers`(`id`) ON DELETE CASCADE, + unique(`scene_id`, `performer_id`) +); + +CREATE TABLE `scene_tags` ( + `scene_id` integer not null, + `tag_id` integer not null, + foreign key(`scene_id`) references `scenes`(`id`) ON DELETE CASCADE, + foreign key(`tag_id`) references `tags`(`id`) ON DELETE CASCADE, + unique(`scene_id`, `tag_id`) +); diff --git a/pkg/database/sql.go b/pkg/database/sql.go new file mode 100644 index 0000000..ea5a904 --- /dev/null +++ b/pkg/database/sql.go @@ -0,0 +1,187 @@ +package database + +import ( + "database/sql" + "fmt" + "reflect" + "strings" + + "github.com/jmoiron/sqlx" +) + +type optionalValue interface { + IsValid() bool +} + +func ensureTx(tx *sqlx.Tx) { + if tx == nil { + panic("must use a transaction") + } +} + +func getByID(tx *sqlx.Tx, table string, id int64, object interface{}) error { + return tx.Get(object, `SELECT * FROM `+table+` WHERE id = ? LIMIT 1`, id) +} + +func insertObject(tx *sqlx.Tx, table string, object interface{}) (int64, error) { + ensureTx(tx) + fields, values := sqlGenKeysCreate(object) + + result, err := tx.NamedExec( + `INSERT INTO `+table+` (`+fields+`) + VALUES (`+values+`) + `, + object, + ) + + if err != nil { + return 0, err + } + + return result.LastInsertId() +} + +func updateObjectByID(tx *sqlx.Tx, table string, object interface{}) error { + ensureTx(tx) + _, err := tx.NamedExec( + `UPDATE `+table+` SET `+sqlGenKeys(object, false)+` WHERE `+table+`.id = :id`, + object, + ) + + return err +} + +func executeDeleteQuery(tableName string, id int64, tx *sqlx.Tx) error { + if tx == nil { + panic("must use a transaction") + } + idColumnName := getColumn(tableName, "id") + _, err := tx.Exec( + `DELETE FROM `+tableName+` WHERE `+idColumnName+` = ?`, + id, + ) + return err +} + +func deleteObjectsByColumn(tx *sqlx.Tx, table string, column string, value interface{}) error { + ensureTx(tx) + _, err := tx.Exec(`DELETE FROM `+table+` WHERE `+column+` = ?`, value) + return err +} + +func getColumn(tableName string, columnName string) string { + return tableName + "." + columnName +} + +func sqlGenKeysCreate(i interface{}) (string, string) { + var fields []string + var values []string + + addPlaceholder := func(key string) { + fields = append(fields, "`"+key+"`") + values = append(values, ":"+key) + } + + v := reflect.ValueOf(i) + for i := 0; i < v.NumField(); i++ { + //get key for struct tag + rawKey := v.Type().Field(i).Tag.Get("db") + key := strings.Split(rawKey, ",")[0] + if key == "id" { + continue + } + switch t := v.Field(i).Interface().(type) { + case string: + if t != "" { + addPlaceholder(key) + } + case int, int64, float64: + if t != 0 { + addPlaceholder(key) + } + case optionalValue: + if t.IsValid() { + addPlaceholder(key) + } + case sql.NullString: + if t.Valid { + addPlaceholder(key) + } + case sql.NullBool: + if t.Valid { + addPlaceholder(key) + } + case sql.NullInt64: + if t.Valid { + addPlaceholder(key) + } + case sql.NullFloat64: + if t.Valid { + addPlaceholder(key) + } + default: + reflectValue := reflect.ValueOf(t) + isNil := reflectValue.IsNil() + if !isNil { + fields = append(fields, key) + values = append(values, ":"+key) + } + } + } + return strings.Join(fields, ", "), strings.Join(values, ", ") +} + +func sqlGenKeys(i interface{}, partial bool) string { + var query []string + + addKey := func(key string) { + query = append(query, fmt.Sprintf("%s=:%s", key, key)) + } + + v := reflect.ValueOf(i) + for i := 0; i < v.NumField(); i++ { + //get key for struct tag + rawKey := v.Type().Field(i).Tag.Get("db") + key := strings.Split(rawKey, ",")[0] + if key == "id" { + continue + } + switch t := v.Field(i).Interface().(type) { + case string: + if partial || t != "" { + addKey(key) + } + case int, int64, float64: + if partial || t != 0 { + addKey(key) + } + case optionalValue: + if partial || t.IsValid() { + addKey(key) + } + case sql.NullString: + if partial || t.Valid { + addKey(key) + } + case sql.NullBool: + if partial || t.Valid { + addKey(key) + } + case sql.NullInt64: + if partial || t.Valid { + addKey(key) + } + case sql.NullFloat64: + if partial || t.Valid { + addKey(key) + } + default: + reflectValue := reflect.ValueOf(t) + isNil := reflectValue.IsNil() + if !isNil { + addKey(key) + } + } + } + return strings.Join(query, ", ") +} diff --git a/pkg/database/table.go b/pkg/database/table.go new file mode 100644 index 0000000..a3de566 --- /dev/null +++ b/pkg/database/table.go @@ -0,0 +1,96 @@ +package database + +// NewObjectFunc is a function that returns an instance of an object stored in +// a database table. +type NewObjectFunc func() interface{} + +// Table represents a database table. +type Table struct { + name string + newObjectFn NewObjectFunc +} + +// Name returns the name of the database table. +func (t Table) Name() string { + return t.name +} + +// NewObject returns a new object model of the type that this table stores. +func (t Table) NewObject() interface{} { + return t.newObjectFn() +} + +// NewTable creates a new Table object with the provided table name and new +// object function. +func NewTable(name string, newObjectFn NewObjectFunc) Table { + return Table{ + name: name, + newObjectFn: newObjectFn, + } +} + +// TableJoin represents a database Table that joins two other tables. +type TableJoin struct { + Table + + // the primary table that will be joined to this table + primaryTable string + + // the column in this table that stores the foreign key to the primary table. + joinColumn string +} + +// Creates a new TableJoin instance. The primaryTable is the table that will join +// to the join table. The joinColumn is the name in the join table that stores +// the foreign key in the primary table. +func NewTableJoin(primaryTable string, name string, joinColumn string, newObjectFn func() interface{}) TableJoin { + return TableJoin{ + Table: Table{ + name: name, + newObjectFn: newObjectFn, + }, + primaryTable: primaryTable, + joinColumn: joinColumn, + } +} + +// Inverse creates a TableJoin object that is the inverse of this table join. +// The returns TableJoin object will have this table as the primary table. +func (t TableJoin) Inverse(joinColumn string) TableJoin { + return TableJoin{ + Table: Table{ + name: t.primaryTable, + newObjectFn: t.newObjectFn, + }, + primaryTable: t.Name(), + joinColumn: joinColumn, + } +} + +// Model is the interface implemented by objects that exist in the database +// that have an `id` column. +type Model interface { + // GetTable returns the table that stores objects of this type. + GetTable() Table + + // GetID returns the ID of the object. + GetID() int64 +} + +// Models is the interface implemented by slices of Model objects. +type Models interface { + // Add adds a new object to the slice. It is assumed that the passed + // object can be type asserted to the correct type. + Add(interface{}) +} + +// Joins is the interface implemented by slices of join objects. +type Joins interface { + // Each calls the provided function on each of the concrete (not pointer) + // objects in the slice. + Each(func(interface{})) + + // Add adds a new object to the slice. It is assumed that the passed + // object can be type asserted to the correct type. + Add(interface{}) +} diff --git a/pkg/database/transaction.go b/pkg/database/transaction.go new file mode 100644 index 0000000..88595e8 --- /dev/null +++ b/pkg/database/transaction.go @@ -0,0 +1,93 @@ +package database + +import ( + "context" + + "github.com/jmoiron/sqlx" +) + +type Transaction interface { + Begin(ctx context.Context) *sqlx.Tx + Commit() error + Rollback() error + GetTx() *sqlx.Tx +} + +type transaction struct { + tx *sqlx.Tx + closed bool +} + +func NewTransaction(ctx context.Context) Transaction { + return &transaction{} +} + +func (t *transaction) close() { + t.closed = true +} + +func (t *transaction) Begin(ctx context.Context) *sqlx.Tx { + if t.tx != nil { + panic("Begin called twice on the same Transaction") + } + + if t.closed { + panic("Begin called on closed Transaction") + } + + t.tx = DB.MustBeginTx(ctx, nil) + return t.tx +} + +func (t *transaction) Commit() error { + if t.closed { + panic("Commit called on closed transaction") + } + + if t.tx == nil { + panic("Commit called before Begin") + } + + defer t.close() + + return t.tx.Commit() +} + +func (t *transaction) Rollback() error { + if t.tx == nil { + panic("Rollback called before begin") + } + + defer t.close() + + return t.tx.Rollback() +} + +func (t *transaction) GetTx() *sqlx.Tx { + return t.tx +} + +type TxFunc func(Transaction) error + +func WithTransaction(ctx context.Context, fn TxFunc) error { + txn := NewTransaction(ctx) + txn.Begin(ctx) + + var err error + defer func() { + if p := recover(); p != nil { + // a panic occurred, rollback and repanic + txn.Rollback() + panic(p) + } else if err != nil { + // something went wrong, rollback + txn.Rollback() + } else { + // all good, commit + err = txn.Commit() + } + }() + + err = fn(txn) + return err +} diff --git a/pkg/models/model_joins.go b/pkg/models/model_joins.go index 2a1129b..0790260 100644 --- a/pkg/models/model_joins.go +++ b/pkg/models/model_joins.go @@ -1,16 +1,56 @@ package models -type PerformersScenes struct { - PerformerID int `db:"performer_id" json:"performer_id"` - SceneID int `db:"scene_id" json:"scene_id"` +import ( + "database/sql" + + "github.com/stashapp/stashdb/pkg/database" +) + +var ( + scenePerformerTable = database.NewTableJoin(sceneTable, "scene_performers", sceneJoinKey, func() interface{} { + return &PerformerScene{} + }) + + performerSceneTable = scenePerformerTable.Inverse(performerJoinKey) + + sceneTagTable = database.NewTableJoin(sceneTable, "scene_tags", sceneJoinKey, func() interface{} { + return &SceneTag{} + }) + + tagSceneTable = sceneTagTable.Inverse(tagJoinKey) +) + +type PerformerScene struct { + PerformerID int64 `db:"performer_id" json:"performer_id"` + As sql.NullString `db:"as" json:"as"` + SceneID int64 `db:"scene_id" json:"scene_id"` } -type ScenesTags struct { - SceneID int `db:"scene_id" json:"scene_id"` - TagID int `db:"tag_id" json:"tag_id"` +type PerformersScenes []*PerformerScene + +func (p PerformersScenes) Each(fn func(interface{})) { + for _, v := range p { + fn(*v) + } } -type SceneMarkersTags struct { - SceneMarkerID int `db:"scene_marker_id" json:"scene_marker_id"` - TagID int `db:"tag_id" json:"tag_id"` +func (p *PerformersScenes) Add(o interface{}) { + *p = append(*p, o.(*PerformerScene)) +} + +type SceneTag struct { + SceneID int64 `db:"scene_id" json:"scene_id"` + TagID int64 `db:"tag_id" json:"tag_id"` +} + +type ScenesTags []*SceneTag + +func (p ScenesTags) Each(fn func(interface{})) { + for _, v := range p { + fn(*v) + } +} + +func (p *ScenesTags) Add(o interface{}) { + *p = append(*p, o.(*SceneTag)) } diff --git a/pkg/models/model_performer.go b/pkg/models/model_performer.go index d96c086..e6aa1d4 100644 --- a/pkg/models/model_performer.go +++ b/pkg/models/model_performer.go @@ -2,6 +2,36 @@ package models import ( "database/sql" + + "github.com/stashapp/stashdb/pkg/database" + "github.com/stashapp/stashdb/pkg/utils" +) + +const ( + performerTable = "performers" + performerJoinKey = "performer_id" +) + +var ( + performerDBTable = database.NewTable(performerTable, func() interface{} { + return &Performer{} + }) + + performerAliasTable = database.NewTableJoin(performerTable, "performer_aliases", performerJoinKey, func() interface{} { + return &PerformerAlias{} + }) + + performerUrlTable = database.NewTableJoin(performerTable, "performer_urls", performerJoinKey, func() interface{} { + return &PerformerUrl{} + }) + + performerTattooTable = database.NewTableJoin(performerTable, "performer_tattoos", performerJoinKey, func() interface{} { + return &PerformerBodyMod{} + }) + + performerPiercingTable = database.NewTableJoin(performerTable, "performer_piercings", performerJoinKey, func() interface{} { + return &PerformerBodyMod{} + }) ) type Performer struct { @@ -28,39 +58,92 @@ type Performer struct { UpdatedAt SQLiteTimestamp `db:"updated_at" json:"updated_at"` } -type PerformerAliases struct { +func (Performer) GetTable() database.Table { + return performerDBTable +} + +func (p Performer) GetID() int64 { + return p.ID +} + +type Performers []*Performer + +func (p Performers) Each(fn func(interface{})) { + for _, v := range p { + fn(*v) + } +} + +func (p *Performers) Add(o interface{}) { + *p = append(*p, o.(*Performer)) +} + +type PerformerAlias struct { PerformerID int64 `db:"performer_id" json:"performer_id"` Alias string `db:"alias" json:"alias"` } -func CreatePerformerAliases(performerId int64, aliases []string) []PerformerAliases { - var ret []PerformerAliases +type PerformerAliases []*PerformerAlias - for _, alias := range aliases { - ret = append(ret, PerformerAliases{PerformerID: performerId, Alias: alias}) +func (p PerformerAliases) Each(fn func(interface{})) { + for _, v := range p { + fn(*v) + } +} + +func (p *PerformerAliases) Add(o interface{}) { + *p = append(*p, o.(*PerformerAlias)) +} + +func (p PerformerAliases) ToAliases() []string { + var ret []string + for _, v := range p { + ret = append(ret, v.Alias) } return ret } -type PerformerUrls struct { +func CreatePerformerAliases(performerId int64, aliases []string) PerformerAliases { + var ret PerformerAliases + + for _, alias := range aliases { + ret = append(ret, &PerformerAlias{PerformerID: performerId, Alias: alias}) + } + + return ret +} + +type PerformerUrl struct { PerformerID int64 `db:"performer_id" json:"performer_id"` URL string `db:"url" json:"url"` Type string `db:"type" json:"type"` } -func (p *PerformerUrls) ToURL() URL { +func (p *PerformerUrl) ToURL() URL { return URL{ URL: p.URL, Type: p.Type, } } -func CreatePerformerUrls(performerId int64, urls []*URLInput) []PerformerUrls { - var ret []PerformerUrls +type PerformerUrls []*PerformerUrl + +func (p PerformerUrls) Each(fn func(interface{})) { + for _, v := range p { + fn(*v) + } +} + +func (p *PerformerUrls) Add(o interface{}) { + *p = append(*p, o.(*PerformerUrl)) +} + +func CreatePerformerUrls(performerId int64, urls []*URLInput) PerformerUrls { + var ret PerformerUrls for _, urlInput := range urls { - ret = append(ret, PerformerUrls{ + ret = append(ret, &PerformerUrl{ PerformerID: performerId, URL: urlInput.URL, Type: urlInput.Type, @@ -70,13 +153,13 @@ func CreatePerformerUrls(performerId int64, urls []*URLInput) []PerformerUrls { return ret } -type PerformerBodyMods struct { +type PerformerBodyMod struct { PerformerID int64 `db:"performer_id" json:"performer_id"` Location string `db:"location" json:"location"` Description sql.NullString `db:"description" json:"description"` } -func (m PerformerBodyMods) ToBodyModification() BodyModification { +func (m PerformerBodyMod) ToBodyModification() BodyModification { ret := BodyModification{ Location: m.Location, } @@ -87,8 +170,20 @@ func (m PerformerBodyMods) ToBodyModification() BodyModification { return ret } -func CreatePerformerBodyMods(performerId int64, urls []*BodyModificationInput) []PerformerBodyMods { - var ret []PerformerBodyMods +type PerformerBodyMods []*PerformerBodyMod + +func (p PerformerBodyMods) Each(fn func(interface{})) { + for _, v := range p { + fn(*v) + } +} + +func (p *PerformerBodyMods) Add(o interface{}) { + *p = append(*p, o.(*PerformerBodyMod)) +} + +func CreatePerformerBodyMods(performerId int64, urls []*BodyModificationInput) PerformerBodyMods { + var ret PerformerBodyMods for _, bmInput := range urls { description := sql.NullString{} @@ -97,7 +192,7 @@ func CreatePerformerBodyMods(performerId int64, urls []*BodyModificationInput) [ description.String = *bmInput.Description description.Valid = true } - ret = append(ret, PerformerBodyMods{ + ret = append(ret, &PerformerBodyMod{ PerformerID: performerId, Location: bmInput.Location, Description: description, @@ -168,9 +263,27 @@ func (p Performer) ResolveMeasurements() Measurements { return ret } -func (p *Performer) CopyFromCreateInput(input PerformerCreateInput) { +func (p *Performer) TranslateImageData(inputData *string) ([]byte, error) { + var imageData []byte + var err error + + _, imageData, err = utils.ProcessBase64Image(*inputData) + + return imageData, err +} + +func (p *Performer) CopyFromCreateInput(input PerformerCreateInput) error { CopyFull(p, input) + if input.Image != nil { + var err error + p.Image, err = p.TranslateImageData(input.Image) + + if err != nil { + return err + } + } + if input.Birthdate != nil { p.setBirthdate(*input.Birthdate) } @@ -178,11 +291,22 @@ func (p *Performer) CopyFromCreateInput(input PerformerCreateInput) { if input.Measurements != nil { p.setMeasurements(*input.Measurements) } + + return nil } -func (p *Performer) CopyFromUpdateInput(input PerformerUpdateInput) { +func (p *Performer) CopyFromUpdateInput(input PerformerUpdateInput) error { CopyFull(p, input) + if input.Image != nil { + var err error + p.Image, err = p.TranslateImageData(input.Image) + + if err != nil { + return err + } + } + if input.Birthdate != nil { p.setBirthdate(*input.Birthdate) } @@ -190,4 +314,6 @@ func (p *Performer) CopyFromUpdateInput(input PerformerUpdateInput) { if input.Measurements != nil { p.setMeasurements(*input.Measurements) } + + return nil } diff --git a/pkg/models/model_scene.go b/pkg/models/model_scene.go new file mode 100644 index 0000000..ea9a7f6 --- /dev/null +++ b/pkg/models/model_scene.go @@ -0,0 +1,158 @@ +package models + +import ( + "database/sql" + "strconv" + + "github.com/stashapp/stashdb/pkg/database" +) + +const ( + sceneTable = "scenes" + sceneJoinKey = "scene_id" +) + +var ( + sceneDBTable = database.NewTable(sceneTable, func() interface{} { + return &Scene{} + }) + + sceneFingerprintTable = database.NewTableJoin(sceneTable, "scene_fingerprints", sceneJoinKey, func() interface{} { + return &SceneFingerprint{} + }) +) + +type Scene struct { + ID int64 `db:"id" json:"id"` + Title sql.NullString `db:"title" json:"title"` + Details sql.NullString `db:"details" json:"details"` + URL sql.NullString `db:"url" json:"url"` + Date SQLiteDate `db:"date" json:"date"` + StudioID sql.NullInt64 `db:"studio_id,omitempty" json:"studio_id"` + CreatedAt SQLiteTimestamp `db:"created_at" json:"created_at"` + UpdatedAt SQLiteTimestamp `db:"updated_at" json:"updated_at"` +} + +func (Scene) GetTable() database.Table { + return sceneDBTable +} + +func (p Scene) GetID() int64 { + return p.ID +} + +type Scenes []*Scene + +func (p Scenes) Each(fn func(interface{})) { + for _, v := range p { + fn(*v) + } +} + +func (p *Scenes) Add(o interface{}) { + *p = append(*p, o.(*Scene)) +} + +type SceneFingerprint struct { + SceneID int64 `db:"scene_id" json:"scene_id"` + Hash string `db:"hash" json:"hash"` + Algorithm string `db:"algorithm" json:"algorithm"` +} + +func (p SceneFingerprint) ToFingerprint() *Fingerprint { + return &Fingerprint{ + Algorithm: FingerprintAlgorithm(p.Algorithm), + Hash: p.Hash, + } +} + +type SceneFingerprints []*SceneFingerprint + +func (p SceneFingerprints) Each(fn func(interface{})) { + for _, v := range p { + fn(*v) + } +} + +func (p *SceneFingerprints) Add(o interface{}) { + *p = append(*p, o.(*SceneFingerprint)) +} + +func (p SceneFingerprints) ToFingerprints() []*Fingerprint { + var ret []*Fingerprint + for _, v := range p { + ret = append(ret, v.ToFingerprint()) + } + + return ret +} + +func CreateSceneFingerprints(sceneID int64, fingerprints []*FingerprintInput) SceneFingerprints { + var ret SceneFingerprints + + for _, fingerprint := range fingerprints { + ret = append(ret, &SceneFingerprint{ + SceneID: sceneID, + Hash: fingerprint.Hash, + Algorithm: fingerprint.Algorithm.String(), + }) + } + + return ret +} + +func CreateSceneTags(sceneID int64, tagIds []string) ScenesTags { + var tagJoins ScenesTags + for _, tid := range tagIds { + tagID, _ := strconv.ParseInt(tid, 10, 64) + tagJoin := &SceneTag{ + SceneID: sceneID, + TagID: tagID, + } + tagJoins = append(tagJoins, tagJoin) + } + + return tagJoins +} + +func CreateScenePerformers(sceneID int64, appearances []*PerformerAppearanceInput) PerformersScenes { + var performerJoins PerformersScenes + for _, a := range appearances { + performerID, _ := strconv.ParseInt(a.PerformerID, 10, 64) + performerJoin := &PerformerScene{ + SceneID: sceneID, + PerformerID: performerID, + } + + if a.As != nil { + performerJoin.As = sql.NullString{Valid: true, String: *a.As} + } + + performerJoins = append(performerJoins, performerJoin) + } + + return performerJoins +} + +func (p *Scene) IsEditTarget() { +} + +func (p *Scene) setDate(date string) { + p.Date = SQLiteDate{String: date, Valid: true} +} + +func (p *Scene) CopyFromCreateInput(input SceneCreateInput) { + CopyFull(p, input) + + if input.Date != nil { + p.setDate(*input.Date) + } +} + +func (p *Scene) CopyFromUpdateInput(input SceneUpdateInput) { + CopyFull(p, input) + + if input.Date != nil { + p.setDate(*input.Date) + } +} diff --git a/pkg/models/model_studio.go b/pkg/models/model_studio.go new file mode 100644 index 0000000..46edb4d --- /dev/null +++ b/pkg/models/model_studio.go @@ -0,0 +1,116 @@ +package models + +import ( + "database/sql" + "strconv" + + "github.com/stashapp/stashdb/pkg/database" +) + +const ( + studioTable = "studios" + studioJoinKey = "studio_id" +) + +var ( + studioDBTable = database.NewTable(studioTable, func() interface{} { + return &Studio{} + }) + + studioUrlTable = database.NewTableJoin(studioTable, "studio_urls", studioJoinKey, func() interface{} { + return &StudioUrl{} + }) +) + +type Studio struct { + ID int64 `db:"id" json:"id"` + Name string `db:"name" json:"name"` + Image []byte `db:"image" json:"image"` + ParentStudioID sql.NullInt64 `db:"parent_studio_id,omitempty" json:"parent_studio_id"` + CreatedAt SQLiteTimestamp `db:"created_at" json:"created_at"` + UpdatedAt SQLiteTimestamp `db:"updated_at" json:"updated_at"` +} + +func (Studio) GetTable() database.Table { + return studioDBTable +} + +func (p Studio) GetID() int64 { + return p.ID +} + +type Studios []*Studio + +func (p Studios) Each(fn func(interface{})) { + for _, v := range p { + fn(v) + } +} + +func (p *Studios) Add(o interface{}) { + *p = append(*p, o.(*Studio)) +} + +type StudioUrl struct { + StudioID int64 `db:"studio_id" json:"studio_id"` + URL string `db:"url" json:"url"` + Type string `db:"type" json:"type"` +} + +func (p *StudioUrl) ToURL() URL { + return URL{ + URL: p.URL, + Type: p.Type, + } +} + +type StudioUrls []StudioUrl + +func (p StudioUrls) Each(fn func(interface{})) { + for _, v := range p { + fn(v) + } +} + +func (p *StudioUrls) Add(o interface{}) { + *p = append(*p, o.(StudioUrl)) +} + +func CreateStudioUrls(studioId int64, urls []*URLInput) []StudioUrl { + var ret []StudioUrl + + for _, urlInput := range urls { + ret = append(ret, StudioUrl{ + StudioID: studioId, + URL: urlInput.URL, + Type: urlInput.Type, + }) + } + + return ret +} + +func (p *Studio) IsEditTarget() { +} + +func (p *Studio) CopyFromCreateInput(input StudioCreateInput) { + CopyFull(p, input) + + if input.ParentID != nil { + parentID, err := strconv.ParseInt(*input.ParentID, 10, 64) + if err == nil { + p.ParentStudioID = sql.NullInt64{Int64: parentID, Valid: true} + } + } +} + +func (p *Studio) CopyFromUpdateInput(input StudioUpdateInput) { + CopyFull(p, input) + + if input.ParentID != nil { + parentID, err := strconv.ParseInt(*input.ParentID, 10, 64) + if err == nil { + p.ParentStudioID = sql.NullInt64{Int64: parentID, Valid: true} + } + } +} diff --git a/pkg/models/model_tag.go b/pkg/models/model_tag.go new file mode 100644 index 0000000..e8ebfd8 --- /dev/null +++ b/pkg/models/model_tag.go @@ -0,0 +1,97 @@ +package models + +import ( + "database/sql" + + "github.com/stashapp/stashdb/pkg/database" +) + +const ( + tagTable = "tags" + tagJoinKey = "tag_id" +) + +var ( + tagDBTable = database.NewTable(tagTable, func() interface{} { + return &Tag{} + }) + + tagAliasTable = database.NewTableJoin(tagTable, "tag_aliases", tagJoinKey, func() interface{} { + return &TagAlias{} + }) +) + +type Tag struct { + ID int64 `db:"id" json:"id"` + Name string `db:"name" json:"name"` + Description sql.NullString `db:"description" json:"description"` + CreatedAt SQLiteTimestamp `db:"created_at" json:"created_at"` + UpdatedAt SQLiteTimestamp `db:"updated_at" json:"updated_at"` +} + +func (Tag) GetTable() database.Table { + return tagDBTable +} + +func (p Tag) GetID() int64 { + return p.ID +} + +type Tags []*Tag + +func (p Tags) Each(fn func(interface{})) { + for _, v := range p { + fn(v) + } +} + +func (p *Tags) Add(o interface{}) { + *p = append(*p, o.(*Tag)) +} + +type TagAlias struct { + TagID int64 `db:"tag_id" json:"tag_id"` + Alias string `db:"alias" json:"alias"` +} + +type TagAliases []TagAlias + +func (p TagAliases) Each(fn func(interface{})) { + for _, v := range p { + fn(v) + } +} + +func (p *TagAliases) Add(o interface{}) { + *p = append(*p, o.(TagAlias)) +} + +func (p TagAliases) ToAliases() []string { + var ret []string + for _, v := range p { + ret = append(ret, v.Alias) + } + + return ret +} + +func CreateTagAliases(tagId int64, aliases []string) []TagAlias { + var ret []TagAlias + + for _, alias := range aliases { + ret = append(ret, TagAlias{TagID: tagId, Alias: alias}) + } + + return ret +} + +func (p *Tag) IsEditTarget() { +} + +func (p *Tag) CopyFromCreateInput(input TagCreateInput) { + CopyFull(p, input) +} + +func (p *Tag) CopyFromUpdateInput(input TagUpdateInput) { + CopyFull(p, input) +} diff --git a/pkg/models/querybuilder_joins.go b/pkg/models/querybuilder_joins.go index a16961c..5263939 100644 --- a/pkg/models/querybuilder_joins.go +++ b/pkg/models/querybuilder_joins.go @@ -1,130 +1,41 @@ package models -import "github.com/jmoiron/sqlx" +import ( + "github.com/jmoiron/sqlx" -type JoinsQueryBuilder struct{} + "github.com/stashapp/stashdb/pkg/database" +) -func NewJoinsQueryBuilder() JoinsQueryBuilder { - return JoinsQueryBuilder{} +type JoinsQueryBuilder struct { + dbi database.DBI } -func (qb *JoinsQueryBuilder) CreatePerformersScenes(newJoins []PerformersScenes, tx *sqlx.Tx) error { - ensureTx(tx) - for _, join := range newJoins { - _, err := tx.NamedExec( - `INSERT INTO performers_scenes (performer_id, scene_id) VALUES (:performer_id, :scene_id)`, - join, - ) - if err != nil { - return err - } +func NewJoinsQueryBuilder(tx *sqlx.Tx) JoinsQueryBuilder { + return JoinsQueryBuilder{ + dbi: database.DBIWithTxn(tx), } - return nil } -func (qb *JoinsQueryBuilder) UpdatePerformersScenes(sceneID int, updatedJoins []PerformersScenes, tx *sqlx.Tx) error { - ensureTx(tx) - - // Delete the existing joins and then create new ones - _, err := tx.Exec("DELETE FROM performers_scenes WHERE scene_id = ?", sceneID) - if err != nil { - return err - } - return qb.CreatePerformersScenes(updatedJoins, tx) +func (qb *JoinsQueryBuilder) CreatePerformersScenes(newJoins PerformersScenes) error { + return qb.dbi.InsertJoins(scenePerformerTable, &newJoins) } -func (qb *JoinsQueryBuilder) DestroyPerformersScenes(sceneID int, tx *sqlx.Tx) error { - ensureTx(tx) - - // Delete the existing joins - _, err := tx.Exec("DELETE FROM performers_scenes WHERE scene_id = ?", sceneID) - return err +func (qb *JoinsQueryBuilder) UpdatePerformersScenes(sceneID int64, updatedJoins PerformersScenes) error { + return qb.dbi.ReplaceJoins(scenePerformerTable, sceneID, &updatedJoins) } -func (qb *JoinsQueryBuilder) CreateScenesTags(newJoins []ScenesTags, tx *sqlx.Tx) error { - ensureTx(tx) - for _, join := range newJoins { - _, err := tx.NamedExec( - `INSERT INTO scenes_tags (scene_id, tag_id) VALUES (:scene_id, :tag_id)`, - join, - ) - if err != nil { - return err - } - } - return nil +func (qb *JoinsQueryBuilder) DestroyPerformersScenes(sceneID int64) error { + return qb.dbi.DeleteJoins(scenePerformerTable, sceneID) } -func (qb *JoinsQueryBuilder) UpdateScenesTags(sceneID int, updatedJoins []ScenesTags, tx *sqlx.Tx) error { - ensureTx(tx) - - // Delete the existing joins and then create new ones - _, err := tx.Exec("DELETE FROM scenes_tags WHERE scene_id = ?", sceneID) - if err != nil { - return err - } - return qb.CreateScenesTags(updatedJoins, tx) +func (qb *JoinsQueryBuilder) CreateScenesTags(newJoins ScenesTags) error { + return qb.dbi.InsertJoins(sceneTagTable, &newJoins) } -func (qb *JoinsQueryBuilder) DestroyScenesTags(sceneID int, tx *sqlx.Tx) error { - ensureTx(tx) - - // Delete the existing joins - _, err := tx.Exec("DELETE FROM scenes_tags WHERE scene_id = ?", sceneID) - - return err +func (qb *JoinsQueryBuilder) UpdateScenesTags(sceneID int64, updatedJoins ScenesTags) error { + return qb.dbi.ReplaceJoins(sceneTagTable, sceneID, &updatedJoins) } -func (qb *JoinsQueryBuilder) CreateSceneMarkersTags(newJoins []SceneMarkersTags, tx *sqlx.Tx) error { - ensureTx(tx) - for _, join := range newJoins { - _, err := tx.NamedExec( - `INSERT INTO scene_markers_tags (scene_marker_id, tag_id) VALUES (:scene_marker_id, :tag_id)`, - join, - ) - if err != nil { - return err - } - } - return nil -} - -func (qb *JoinsQueryBuilder) UpdateSceneMarkersTags(sceneMarkerID int, updatedJoins []SceneMarkersTags, tx *sqlx.Tx) error { - ensureTx(tx) - - // Delete the existing joins and then create new ones - _, err := tx.Exec("DELETE FROM scene_markers_tags WHERE scene_marker_id = ?", sceneMarkerID) - if err != nil { - return err - } - return qb.CreateSceneMarkersTags(updatedJoins, tx) -} - -func (qb *JoinsQueryBuilder) DestroySceneMarkersTags(sceneMarkerID int, updatedJoins []SceneMarkersTags, tx *sqlx.Tx) error { - ensureTx(tx) - - // Delete the existing joins - _, err := tx.Exec("DELETE FROM scene_markers_tags WHERE scene_marker_id = ?", sceneMarkerID) - return err -} - -func (qb *JoinsQueryBuilder) DestroyScenesGalleries(sceneID int, tx *sqlx.Tx) error { - ensureTx(tx) - - // Unset the existing scene id from galleries - _, err := tx.Exec("UPDATE galleries SET scene_id = null WHERE scene_id = ?", sceneID) - - return err -} - -func (qb *JoinsQueryBuilder) DestroyScenesMarkers(sceneID int, tx *sqlx.Tx) error { - ensureTx(tx) - - // Delete the scene marker tags - _, err := tx.Exec("DELETE t FROM scene_markers_tags t join scene_markers m on t.scene_marker_id = m.id WHERE m.scene_id = ?", sceneID) - - // Delete the existing joins - _, err = tx.Exec("DELETE FROM scene_markers WHERE scene_id = ?", sceneID) - - return err +func (qb *JoinsQueryBuilder) DestroyScenesTags(sceneID int64) error { + return qb.dbi.DeleteJoins(sceneTagTable, sceneID) } diff --git a/pkg/models/querybuilder_performer.go b/pkg/models/querybuilder_performer.go index 110413d..643ac20 100644 --- a/pkg/models/querybuilder_performer.go +++ b/pkg/models/querybuilder_performer.go @@ -1,129 +1,84 @@ package models import ( - "database/sql" "strconv" "time" - "github.com/jmoiron/sqlx" - "github.com/pkg/errors" "github.com/stashapp/stashdb/pkg/database" + + "github.com/jmoiron/sqlx" ) -type PerformerQueryBuilder struct{} - -const performerTable = "performers" -const performerAliasesJoinTable = "performer_aliases" -const performerUrlsJoinTable = "performer_urls" -const performerTattoosJoinTable = "performer_tattoos" -const performerPiercingsJoinTable = "performer_piercings" -const performerJoinKey = "performer_id" - -func NewPerformerQueryBuilder() PerformerQueryBuilder { - return PerformerQueryBuilder{} +type PerformerQueryBuilder struct { + dbi database.DBI } -func (qb *PerformerQueryBuilder) Create(newPerformer Performer, tx *sqlx.Tx) (*Performer, error) { - performerID, err := insertObject(tx, performerTable, newPerformer) +func NewPerformerQueryBuilder(tx *sqlx.Tx) PerformerQueryBuilder { + return PerformerQueryBuilder{ + dbi: database.DBIWithTxn(tx), + } +} - if err != nil { - return nil, errors.Wrap(err, "Error creating performer") +func (qb *PerformerQueryBuilder) toModel(ro interface{}) *Performer { + if ro != nil { + return ro.(*Performer) } - if err := getByID(tx, performerTable, performerID, &newPerformer); err != nil { - return nil, errors.Wrap(err, "Error getting performer after create") - } - return &newPerformer, nil + return nil } -func (qb *PerformerQueryBuilder) Update(updatedPerformer Performer, tx *sqlx.Tx) (*Performer, error) { - err := updateObjectByID(tx, performerTable, updatedPerformer) - - if err != nil { - return nil, errors.Wrap(err, "Error updating performer") - } - - if err := getByID(tx, performerTable, updatedPerformer.ID, &updatedPerformer); err != nil { - return nil, errors.Wrap(err, "Error getting performer after update") - } - return &updatedPerformer, nil +func (qb *PerformerQueryBuilder) Create(newPerformer Performer) (*Performer, error) { + ret, err := qb.dbi.Insert(newPerformer) + return qb.toModel(ret), err } -func (qb *PerformerQueryBuilder) Destroy(id int64, tx *sqlx.Tx) error { - return executeDeleteQuery(performerTable, id, tx) +func (qb *PerformerQueryBuilder) Update(updatedPerformer Performer) (*Performer, error) { + ret, err := qb.dbi.Update(updatedPerformer) + return qb.toModel(ret), err } -func (qb *PerformerQueryBuilder) CreateAliases(newJoins []PerformerAliases, tx *sqlx.Tx) error { - return insertJoins(tx, performerAliasesJoinTable, newJoins) +func (qb *PerformerQueryBuilder) Destroy(id int64) error { + return qb.dbi.Delete(id, performerDBTable) } -func (qb *PerformerQueryBuilder) UpdateAliases(performerID int64, updatedJoins []PerformerAliases, tx *sqlx.Tx) error { - ensureTx(tx) - - // Delete the existing joins and then create new ones - err := deleteObjectsByColumn(tx, performerAliasesJoinTable, performerJoinKey, performerID) - if err != nil { - return err - } - return qb.CreateAliases(updatedJoins, tx) +func (qb *PerformerQueryBuilder) CreateAliases(newJoins PerformerAliases) error { + return qb.dbi.InsertJoins(performerAliasTable, &newJoins) } -func (qb *PerformerQueryBuilder) CreateUrls(newJoins []PerformerUrls, tx *sqlx.Tx) error { - return insertJoins(tx, performerUrlsJoinTable, newJoins) +func (qb *PerformerQueryBuilder) UpdateAliases(performerID int64, updatedJoins PerformerAliases) error { + return qb.dbi.ReplaceJoins(performerAliasTable, performerID, &updatedJoins) } -func (qb *PerformerQueryBuilder) UpdateUrls(performerID int64, updatedJoins []PerformerUrls, tx *sqlx.Tx) error { - ensureTx(tx) - - // Delete the existing joins and then create new ones - err := deleteObjectsByColumn(tx, performerUrlsJoinTable, performerJoinKey, performerID) - if err != nil { - return err - } - return qb.CreateUrls(updatedJoins, tx) +func (qb *PerformerQueryBuilder) CreateUrls(newJoins PerformerUrls) error { + return qb.dbi.InsertJoins(performerUrlTable, &newJoins) } -func (qb *PerformerQueryBuilder) CreateTattoos(newJoins []PerformerBodyMods, tx *sqlx.Tx) error { - return insertJoins(tx, performerTattoosJoinTable, newJoins) +func (qb *PerformerQueryBuilder) UpdateUrls(performerID int64, updatedJoins PerformerUrls) error { + return qb.dbi.ReplaceJoins(performerUrlTable, performerID, &updatedJoins) } -func (qb *PerformerQueryBuilder) UpdateTattoos(performerID int64, updatedJoins []PerformerBodyMods, tx *sqlx.Tx) error { - ensureTx(tx) - - // Delete the existing joins and then create new ones - err := deleteObjectsByColumn(tx, performerTattoosJoinTable, performerJoinKey, performerID) - if err != nil { - return err - } - return qb.CreateTattoos(updatedJoins, tx) +func (qb *PerformerQueryBuilder) CreateTattoos(newJoins PerformerBodyMods) error { + return qb.dbi.InsertJoins(performerTattooTable, &newJoins) } -func (qb *PerformerQueryBuilder) CreatePiercings(newJoins []PerformerBodyMods, tx *sqlx.Tx) error { - return insertJoins(tx, performerPiercingsJoinTable, newJoins) +func (qb *PerformerQueryBuilder) UpdateTattoos(performerID int64, updatedJoins PerformerBodyMods) error { + return qb.dbi.ReplaceJoins(performerTattooTable, performerID, &updatedJoins) } -func (qb *PerformerQueryBuilder) UpdatePiercings(performerID int64, updatedJoins []PerformerBodyMods, tx *sqlx.Tx) error { - ensureTx(tx) - - // Delete the existing joins and then create new ones - err := deleteObjectsByColumn(tx, performerPiercingsJoinTable, performerJoinKey, performerID) - if err != nil { - return err - } - return qb.CreateTattoos(updatedJoins, tx) +func (qb *PerformerQueryBuilder) CreatePiercings(newJoins PerformerBodyMods) error { + return qb.dbi.InsertJoins(performerPiercingTable, &newJoins) } -func (qb *PerformerQueryBuilder) Find(id int) (*Performer, error) { - query := "SELECT * FROM performers WHERE id = ? LIMIT 1" - args := []interface{}{id} - results, err := qb.queryPerformers(query, args, nil) - if err != nil || len(results) < 1 { - return nil, err - } - return results[0], nil +func (qb *PerformerQueryBuilder) UpdatePiercings(performerID int64, updatedJoins PerformerBodyMods) error { + return qb.dbi.ReplaceJoins(performerPiercingTable, performerID, &updatedJoins) } -func (qb *PerformerQueryBuilder) FindBySceneID(sceneID int, tx *sqlx.Tx) ([]*Performer, error) { +func (qb *PerformerQueryBuilder) Find(id int64) (*Performer, error) { + ret, err := qb.dbi.Find(id, performerDBTable) + return qb.toModel(ret), err +} + +func (qb *PerformerQueryBuilder) FindBySceneID(sceneID int) (Performers, error) { query := ` SELECT performers.* FROM performers LEFT JOIN performers_scenes as scenes_join on scenes_join.performer_id = performers.id @@ -132,19 +87,19 @@ func (qb *PerformerQueryBuilder) FindBySceneID(sceneID int, tx *sqlx.Tx) ([]*Per GROUP BY performers.id ` args := []interface{}{sceneID} - return qb.queryPerformers(query, args, tx) + return qb.queryPerformers(query, args) } -func (qb *PerformerQueryBuilder) FindByNames(names []string, tx *sqlx.Tx) ([]*Performer, error) { +func (qb *PerformerQueryBuilder) FindByNames(names []string) (Performers, error) { query := "SELECT * FROM performers WHERE name IN " + getInBinding(len(names)) var args []interface{} for _, name := range names { args = append(args, name) } - return qb.queryPerformers(query, args, tx) + return qb.queryPerformers(query, args) } -func (qb *PerformerQueryBuilder) FindByAliases(names []string, tx *sqlx.Tx) ([]*Performer, error) { +func (qb *PerformerQueryBuilder) FindByAliases(names []string) (Performers, error) { query := `SELECT performers.* FROM performers left join performer_aliases on performers.id = performer_aliases.performer_id WHERE performer_aliases.alias IN ` + getInBinding(len(names)) @@ -153,24 +108,24 @@ func (qb *PerformerQueryBuilder) FindByAliases(names []string, tx *sqlx.Tx) ([]* for _, name := range names { args = append(args, name) } - return qb.queryPerformers(query, args, tx) + return qb.queryPerformers(query, args) } -func (qb *PerformerQueryBuilder) FindByName(name string, tx *sqlx.Tx) ([]*Performer, error) { +func (qb *PerformerQueryBuilder) FindByName(name string) (Performers, error) { query := "SELECT * FROM performers WHERE upper(name) = upper(?)" var args []interface{} args = append(args, name) - return qb.queryPerformers(query, args, tx) + return qb.queryPerformers(query, args) } -func (qb *PerformerQueryBuilder) FindByAlias(name string, tx *sqlx.Tx) ([]*Performer, error) { +func (qb *PerformerQueryBuilder) FindByAlias(name string) (Performers, error) { query := `SELECT performers.* FROM performers left join performer_aliases on performers.id = performer_aliases.performer_id WHERE upper(performer_aliases.alias) = UPPER(?)` var args []interface{} args = append(args, name) - return qb.queryPerformers(query, args, tx) + return qb.queryPerformers(query, args) } func (qb *PerformerQueryBuilder) Count() (int, error) { @@ -233,27 +188,6 @@ func (qb *PerformerQueryBuilder) Query(performerFilter *PerformerFilterType, fin return performers, countResult } -func handleStringCriterion(column string, value *StringCriterionInput, query *queryBuilder) { - if value != nil { - if modifier := value.Modifier.String(); value.Modifier.IsValid() { - switch modifier { - case "EQUALS": - clause, thisArgs := getSearchBinding([]string{column}, value.Value, false) - query.addWhere(clause) - query.addArg(thisArgs...) - case "NOT_EQUALS": - clause, thisArgs := getSearchBinding([]string{column}, value.Value, true) - query.addWhere(clause) - query.addArg(thisArgs...) - case "IS_NULL": - query.addWhere(column + " IS NULL") - case "NOT_NULL": - query.addWhere(column + " IS NOT NULL") - } - } - } -} - func getBirthYearFilterClause(criterionModifier CriterionModifier, value int) ([]string, []interface{}) { var clauses []string var args []interface{} @@ -338,142 +272,36 @@ func (qb *PerformerQueryBuilder) getPerformerSort(findFilter *QuerySpec) string return getSort(sort, direction, "performers") } -func (qb *PerformerQueryBuilder) queryPerformers(query string, args []interface{}, tx *sqlx.Tx) ([]*Performer, error) { - var rows *sqlx.Rows - var err error - if tx != nil { - rows, err = tx.Queryx(query, args...) - } else { - rows, err = database.DB.Queryx(query, args...) - } - - if err != nil && err != sql.ErrNoRows { - return nil, err - } - defer rows.Close() - - performers := make([]*Performer, 0) - for rows.Next() { - performer := Performer{} - if err := rows.StructScan(&performer); err != nil { - return nil, err - } - performers = append(performers, &performer) - } - - if err := rows.Err(); err != nil { - return nil, err - } - - return performers, nil +func (qb *PerformerQueryBuilder) queryPerformers(query string, args []interface{}) (Performers, error) { + output := Performers{} + err := qb.dbi.RawQuery(performerDBTable, query, args, &output) + return output, err } func (qb *PerformerQueryBuilder) GetAliases(id int64) ([]string, error) { - query := "SELECT alias FROM performer_aliases WHERE performer_id = ?" - args := []interface{}{id} + joins := PerformerAliases{} + err := qb.dbi.FindJoins(performerAliasTable, id, &joins) - var rows *sqlx.Rows - var err error - rows, err = database.DB.Queryx(query, args...) - - if err != nil && err != sql.ErrNoRows { - return nil, err - } - defer rows.Close() - - aliases := make([]string, 0) - for rows.Next() { - var alias string - - if err := rows.Scan(&alias); err != nil { - return nil, err - } - aliases = append(aliases, alias) - } - - if err := rows.Err(); err != nil { - return nil, err - } - - return aliases, nil + return joins.ToAliases(), err } -func (qb *PerformerQueryBuilder) GetUrls(id int64) ([]PerformerUrls, error) { - query := "SELECT url, type FROM performer_urls WHERE performer_id = ?" - args := []interface{}{id} +func (qb *PerformerQueryBuilder) GetUrls(id int64) (PerformerUrls, error) { + joins := PerformerUrls{} + err := qb.dbi.FindJoins(performerUrlTable, id, &joins) - var rows *sqlx.Rows - var err error - rows, err = database.DB.Queryx(query, args...) - - if err != nil && err != sql.ErrNoRows { - return nil, err - } - defer rows.Close() - - urls := make([]PerformerUrls, 0) - for rows.Next() { - var performerUrl PerformerUrls - - if err := rows.Scan(&performerUrl); err != nil { - return nil, err - } - urls = append(urls, performerUrl) - } - - if err := rows.Err(); err != nil { - return nil, err - } - - return urls, nil + return joins, err } -func translateBodyMods(rows *sqlx.Rows) ([]PerformerBodyMods, error) { - ret := make([]PerformerBodyMods, 0) - for rows.Next() { - var performerBodyMod PerformerBodyMods +func (qb *PerformerQueryBuilder) GetTattoos(id int64) (PerformerBodyMods, error) { + joins := PerformerBodyMods{} + err := qb.dbi.FindJoins(performerTattooTable, id, &joins) - if err := rows.Scan(&performerBodyMod); err != nil { - return nil, err - } - ret = append(ret, performerBodyMod) - } - - if err := rows.Err(); err != nil { - return nil, err - } - - return ret, nil + return joins, err } -func (qb *PerformerQueryBuilder) GetTattoos(id int64) ([]PerformerBodyMods, error) { - query := "SELECT location, description FROM performer_tattoos WHERE performer_id = ?" - args := []interface{}{id} +func (qb *PerformerQueryBuilder) GetPiercings(id int64) (PerformerBodyMods, error) { + joins := PerformerBodyMods{} + err := qb.dbi.FindJoins(performerPiercingTable, id, &joins) - var rows *sqlx.Rows - var err error - rows, err = database.DB.Queryx(query, args...) - - if err != nil && err != sql.ErrNoRows { - return nil, err - } - defer rows.Close() - - return translateBodyMods(rows) -} - -func (qb *PerformerQueryBuilder) GetPiercings(id int64) ([]PerformerBodyMods, error) { - query := "SELECT location, description FROM performer_piercings WHERE performer_id = ?" - args := []interface{}{id} - - var rows *sqlx.Rows - var err error - rows, err = database.DB.Queryx(query, args...) - - if err != nil && err != sql.ErrNoRows { - return nil, err - } - defer rows.Close() - - return translateBodyMods(rows) + return joins, err } diff --git a/pkg/models/querybuilder_scene.go b/pkg/models/querybuilder_scene.go new file mode 100644 index 0000000..2a9441c --- /dev/null +++ b/pkg/models/querybuilder_scene.go @@ -0,0 +1,194 @@ +package models + +import ( + "github.com/jmoiron/sqlx" + "github.com/stashapp/stashdb/pkg/database" +) + +type SceneQueryBuilder struct { + dbi database.DBI +} + +func NewSceneQueryBuilder(tx *sqlx.Tx) SceneQueryBuilder { + return SceneQueryBuilder{ + dbi: database.DBIWithTxn(tx), + } +} + +func (qb *SceneQueryBuilder) toModel(ro interface{}) *Scene { + if ro != nil { + return ro.(*Scene) + } + + return nil +} + +func (qb *SceneQueryBuilder) Create(newScene Scene) (*Scene, error) { + ret, err := qb.dbi.Insert(newScene) + return qb.toModel(ret), err +} + +func (qb *SceneQueryBuilder) Update(updatedScene Scene) (*Scene, error) { + ret, err := qb.dbi.Update(updatedScene) + return qb.toModel(ret), err +} + +func (qb *SceneQueryBuilder) Destroy(id int64) error { + return qb.dbi.Delete(id, sceneDBTable) +} + +func (qb *SceneQueryBuilder) CreateFingerprints(newJoins SceneFingerprints) error { + return qb.dbi.InsertJoins(sceneFingerprintTable, &newJoins) +} + +func (qb *SceneQueryBuilder) UpdateFingerprints(sceneID int64, updatedJoins SceneFingerprints) error { + return qb.dbi.ReplaceJoins(sceneFingerprintTable, sceneID, &updatedJoins) +} + +func (qb *SceneQueryBuilder) Find(id int64) (*Scene, error) { + ret, err := qb.dbi.Find(id, sceneDBTable) + return qb.toModel(ret), err +} + +func (qb *SceneQueryBuilder) FindByFingerprint(algorithm FingerprintAlgorithm, hash string) ([]*Scene, error) { + query := ` + SELECT scenes.* FROM scenes + LEFT JOIN scene_fingerprints as scenes_join on scenes_join.scene_id = scenes.id + WHERE scenes_join.algorithm = ? AND scenes_join.hash = ?` + var args []interface{} + args = append(args, algorithm.String()) + args = append(args, hash) + return qb.queryScenes(query, args) +} + +// func (qb *SceneQueryBuilder) FindByStudioID(sceneID int) ([]*Scene, error) { +// query := ` +// SELECT scenes.* FROM scenes +// LEFT JOIN scenes_scenes as scenes_join on scenes_join.scene_id = scenes.id +// LEFT JOIN scenes on scenes_join.scene_id = scenes.id +// WHERE scenes.id = ? +// GROUP BY scenes.id +// ` +// args := []interface{}{sceneID} +// return qb.queryScenes(query, args) +// } + +// func (qb *SceneQueryBuilder) FindByChecksum(checksum string) (*Scene, error) { +// query := `SELECT scenes.* FROM scenes +// left join scene_checksums on scenes.id = scene_checksums.scene_id +// WHERE scene_checksums.checksum = ?` + +// var args []interface{} +// args = append(args, checksum) + +// results, err := qb.queryScenes(query, args) +// if err != nil || len(results) < 1 { +// return nil, err +// } +// return results[0], nil +// } + +// func (qb *SceneQueryBuilder) FindByChecksums(checksums []string) ([]*Scene, error) { +// query := `SELECT scenes.* FROM scenes +// left join scene_checksums on scenes.id = scene_checksums.scene_id +// WHERE scene_checksums.checksum IN ` + getInBinding(len(checksums)) + +// var args []interface{} +// for _, name := range checksums { +// args = append(args, name) +// } +// return qb.queryScenes(query, args) +// } + +func (qb *SceneQueryBuilder) FindByTitle(name string) ([]*Scene, error) { + query := "SELECT * FROM scenes WHERE upper(title) = upper(?)" + var args []interface{} + args = append(args, name) + return qb.queryScenes(query, args) +} + +func (qb *SceneQueryBuilder) Count() (int, error) { + return runCountQuery(buildCountQuery("SELECT scenes.id FROM scenes"), nil) +} + +func (qb *SceneQueryBuilder) Query(sceneFilter *SceneFilterType, findFilter *QuerySpec) ([]*Scene, int) { + if sceneFilter == nil { + sceneFilter = &SceneFilterType{} + } + if findFilter == nil { + findFilter = &QuerySpec{} + } + + query := queryBuilder{ + tableName: sceneTable, + } + + query.body = selectDistinctIDs(sceneTable) + + if q := sceneFilter.Text; q != nil && *q != "" { + searchColumns := []string{"scenes.title", "scenes.details"} + clause, thisArgs := getSearchBinding(searchColumns, *q, false) + query.addWhere(clause) + query.addArg(thisArgs...) + } + + if q := sceneFilter.Title; q != nil && *q != "" { + searchColumns := []string{"scenes.title"} + clause, thisArgs := getSearchBinding(searchColumns, *q, false) + query.addWhere(clause) + query.addArg(thisArgs...) + } + + if q := sceneFilter.URL; q != nil && *q != "" { + searchColumns := []string{"scenes.url"} + clause, thisArgs := getSearchBinding(searchColumns, *q, false) + query.addWhere(clause) + query.addArg(thisArgs...) + } + + // TODO - other filters + + query.sortAndPagination = qb.getSceneSort(findFilter) + getPagination(findFilter) + idsResult, countResult := query.executeFind() + + var scenes []*Scene + for _, id := range idsResult { + scene, _ := qb.Find(id) + scenes = append(scenes, scene) + } + + return scenes, countResult +} + +func (qb *SceneQueryBuilder) getSceneSort(findFilter *QuerySpec) string { + var sort string + var direction string + if findFilter == nil { + sort = "title" + direction = "ASC" + } else { + sort = findFilter.GetSort("title") + direction = findFilter.GetDirection() + } + return getSort(sort, direction, "scenes") +} + +func (qb *SceneQueryBuilder) queryScenes(query string, args []interface{}) (Scenes, error) { + output := Scenes{} + err := qb.dbi.RawQuery(sceneDBTable, query, args, &output) + return output, err +} + +func (qb *SceneQueryBuilder) GetFingerprints(id int64) ([]*Fingerprint, error) { + joins := SceneFingerprints{} + err := qb.dbi.FindJoins(sceneFingerprintTable, id, &joins) + + return joins.ToFingerprints(), err +} + +func (qb *SceneQueryBuilder) GetPerformers(id int64) (PerformersScenes, error) { + joins := PerformersScenes{} + err := qb.dbi.FindJoins(scenePerformerTable, id, &joins) + + return joins, err +} diff --git a/pkg/models/querybuilder_sql.go b/pkg/models/querybuilder_sql.go index b141ed0..a5c2586 100644 --- a/pkg/models/querybuilder_sql.go +++ b/pkg/models/querybuilder_sql.go @@ -27,7 +27,7 @@ type queryBuilder struct { sortAndPagination string } -func (qb queryBuilder) executeFind() ([]int, int) { +func (qb queryBuilder) executeFind() ([]int64, int) { return executeFindQuery(qb.tableName, qb.body, qb.args, qb.sortAndPagination, qb.whereClauses, qb.havingClauses) } @@ -43,6 +43,27 @@ func (qb *queryBuilder) addArg(args ...interface{}) { qb.args = append(qb.args, args...) } +func handleStringCriterion(column string, value *StringCriterionInput, query *queryBuilder) { + if value != nil { + if modifier := value.Modifier.String(); value.Modifier.IsValid() { + switch modifier { + case "EQUALS": + clause, thisArgs := getSearchBinding([]string{column}, value.Value, false) + query.addWhere(clause) + query.addArg(thisArgs...) + case "NOT_EQUALS": + clause, thisArgs := getSearchBinding([]string{column}, value.Value, true) + query.addWhere(clause) + query.addArg(thisArgs...) + case "IS_NULL": + query.addWhere(column + " IS NULL") + case "NOT_NULL": + query.addWhere(column + " IS NOT NULL") + } + } + } +} + func insertObject(tx *sqlx.Tx, table string, object interface{}) (int64, error) { ensureTx(tx) fields, values := SQLGenKeysCreate(object) @@ -96,7 +117,7 @@ func deleteObjectsByColumn(tx *sqlx.Tx, table string, column string, value inter } func getByID(tx *sqlx.Tx, table string, id int64, object interface{}) error { - return tx.Get(object, `SELECT * FROM performers WHERE id = ? LIMIT 1`, id) + return tx.Get(object, `SELECT * FROM `+table+` WHERE id = ? LIMIT 1`, id) } func selectAll(tableName string) string { @@ -166,9 +187,7 @@ func getSort(sort string, direction string, tableName string) string { } else { colName := getColumn(tableName, sort) var additional string - if tableName == "scenes" { - additional = ", bitrate DESC, framerate DESC, rating DESC, duration DESC" - } else if tableName == "scene_markers" { + if tableName == "scene_markers" { additional = ", scene_markers.scene_id ASC, scene_markers.seconds ASC" } return " ORDER BY " + colName + " " + direction + additional @@ -236,15 +255,15 @@ func getInBinding(length int) string { return "(" + bindings + ")" } -func runIdsQuery(query string, args []interface{}) ([]int, error) { +func runIdsQuery(query string, args []interface{}) ([]int64, error) { var result []struct { - Int int `db:"id"` + Int int64 `db:"id"` } if err := database.DB.Select(&result, query, args...); err != nil && err != sql.ErrNoRows { - return []int{}, err + return []int64{}, err } - vsm := make([]int, len(result)) + vsm := make([]int64, len(result)) for i, v := range result { vsm[i] = v.Int } @@ -263,7 +282,7 @@ func runCountQuery(query string, args []interface{}) (int, error) { return result.Int, nil } -func executeFindQuery(tableName string, body string, args []interface{}, sortAndPagination string, whereClauses []string, havingClauses []string) ([]int, int) { +func executeFindQuery(tableName string, body string, args []interface{}, sortAndPagination string, whereClauses []string, havingClauses []string) ([]int64, int) { if len(whereClauses) > 0 { body = body + " WHERE " + strings.Join(whereClauses, " AND ") // TODO handle AND or OR } diff --git a/pkg/models/querybuilder_studio.go b/pkg/models/querybuilder_studio.go new file mode 100644 index 0000000..9ff8698 --- /dev/null +++ b/pkg/models/querybuilder_studio.go @@ -0,0 +1,152 @@ +package models + +import ( + "github.com/jmoiron/sqlx" + "github.com/stashapp/stashdb/pkg/database" +) + +type StudioQueryBuilder struct { + dbi database.DBI +} + +func NewStudioQueryBuilder(tx *sqlx.Tx) StudioQueryBuilder { + return StudioQueryBuilder{ + dbi: database.DBIWithTxn(tx), + } +} + +func (qb *StudioQueryBuilder) toModel(ro interface{}) *Studio { + if ro != nil { + return ro.(*Studio) + } + + return nil +} + +func (qb *StudioQueryBuilder) Create(newStudio Studio) (*Studio, error) { + ret, err := qb.dbi.Insert(newStudio) + return qb.toModel(ret), err +} + +func (qb *StudioQueryBuilder) Update(updatedStudio Studio) (*Studio, error) { + ret, err := qb.dbi.Update(updatedStudio) + return qb.toModel(ret), err +} + +func (qb *StudioQueryBuilder) Destroy(id int64) error { + return qb.dbi.Delete(id, studioDBTable) +} + +func (qb *StudioQueryBuilder) CreateUrls(newJoins StudioUrls) error { + return qb.dbi.InsertJoins(studioUrlTable, &newJoins) +} + +func (qb *StudioQueryBuilder) UpdateUrls(studio int64, updatedJoins StudioUrls) error { + return qb.dbi.ReplaceJoins(studioUrlTable, studio, &updatedJoins) +} + +func (qb *StudioQueryBuilder) Find(id int64) (*Studio, error) { + ret, err := qb.dbi.Find(id, studioDBTable) + return qb.toModel(ret), err +} + +func (qb *StudioQueryBuilder) FindBySceneID(sceneID int) (Studios, error) { + query := ` + SELECT studios.* FROM studios + LEFT JOIN scenes on scenes.studio_id = studios.id + WHERE scenes.id = ? + GROUP BY studios.id + ` + args := []interface{}{sceneID} + return qb.queryStudios(query, args) +} + +func (qb *StudioQueryBuilder) FindByNames(names []string) (Studios, error) { + query := "SELECT * FROM studios WHERE name IN " + getInBinding(len(names)) + var args []interface{} + for _, name := range names { + args = append(args, name) + } + return qb.queryStudios(query, args) +} + +func (qb *StudioQueryBuilder) FindByName(name string) (*Studio, error) { + query := "SELECT * FROM studios WHERE upper(name) = upper(?)" + var args []interface{} + args = append(args, name) + results, err := qb.queryStudios(query, args) + if err != nil || len(results) < 1 { + return nil, err + } + return results[0], nil +} + +func (qb *StudioQueryBuilder) FindByParentID(id int64) (Studios, error) { + query := "SELECT * FROM studios WHERE parent_studio_id = ?" + var args []interface{} + args = append(args, id) + return qb.queryStudios(query, args) +} + +func (qb *StudioQueryBuilder) Count() (int, error) { + return runCountQuery(buildCountQuery("SELECT studios.id FROM studios"), nil) +} + +func (qb *StudioQueryBuilder) Query(studioFilter *StudioFilterType, findFilter *QuerySpec) (Studios, int) { + if studioFilter == nil { + studioFilter = &StudioFilterType{} + } + if findFilter == nil { + findFilter = &QuerySpec{} + } + + query := queryBuilder{ + tableName: studioTable, + } + + query.body = selectDistinctIDs(studioTable) + + if q := studioFilter.Name; q != nil && *q != "" { + searchColumns := []string{"studios.name"} + clause, thisArgs := getSearchBinding(searchColumns, *q, false) + query.addWhere(clause) + query.addArg(thisArgs...) + } + + query.sortAndPagination = qb.getStudioSort(findFilter) + getPagination(findFilter) + idsResult, countResult := query.executeFind() + + var studios []*Studio + for _, id := range idsResult { + studio, _ := qb.Find(id) + studios = append(studios, studio) + } + + return studios, countResult +} + +func (qb *StudioQueryBuilder) getStudioSort(findFilter *QuerySpec) string { + var sort string + var direction string + if findFilter == nil { + sort = "name" + direction = "ASC" + } else { + sort = findFilter.GetSort("name") + direction = findFilter.GetDirection() + } + return getSort(sort, direction, "studios") +} + +func (qb *StudioQueryBuilder) queryStudios(query string, args []interface{}) (Studios, error) { + var output Studios + err := qb.dbi.RawQuery(studioDBTable, query, args, &output) + return output, err +} + +func (qb *StudioQueryBuilder) GetUrls(id int64) (StudioUrls, error) { + joins := StudioUrls{} + err := qb.dbi.FindJoins(studioUrlTable, id, &joins) + + return joins, err +} diff --git a/pkg/models/querybuilder_tag.go b/pkg/models/querybuilder_tag.go new file mode 100644 index 0000000..e8cf2e6 --- /dev/null +++ b/pkg/models/querybuilder_tag.go @@ -0,0 +1,178 @@ +package models + +import ( + "github.com/jmoiron/sqlx" + + "github.com/stashapp/stashdb/pkg/database" +) + +type TagQueryBuilder struct { + dbi database.DBI +} + +func NewTagQueryBuilder(tx *sqlx.Tx) TagQueryBuilder { + return TagQueryBuilder{ + dbi: database.DBIWithTxn(tx), + } +} + +func (qb *TagQueryBuilder) toModel(ro interface{}) *Tag { + if ro != nil { + return ro.(*Tag) + } + + return nil +} + +func (qb *TagQueryBuilder) Create(newTag Tag) (*Tag, error) { + ret, err := qb.dbi.Insert(newTag) + return qb.toModel(ret), err +} + +func (qb *TagQueryBuilder) Update(updatedTag Tag) (*Tag, error) { + ret, err := qb.dbi.Update(updatedTag) + return qb.toModel(ret), err +} + +func (qb *TagQueryBuilder) Destroy(id int64) error { + return qb.dbi.Delete(id, tagDBTable) +} + +func (qb *TagQueryBuilder) CreateAliases(newJoins TagAliases) error { + return qb.dbi.InsertJoins(tagAliasTable, &newJoins) +} + +func (qb *TagQueryBuilder) UpdateAliases(tagID int64, updatedJoins TagAliases) error { + return qb.dbi.ReplaceJoins(tagAliasTable, tagID, &updatedJoins) +} + +func (qb *TagQueryBuilder) Find(id int64) (*Tag, error) { + ret, err := qb.dbi.Find(id, tagDBTable) + return qb.toModel(ret), err +} + +func (qb *TagQueryBuilder) FindByNameOrAlias(name string) (*Tag, error) { + query := `SELECT tags.* FROM tags + left join tag_aliases on tags.id = tag_aliases.tag_id + WHERE tag_aliases.alias = ? OR tags.name = ?` + + args := []interface{}{name, name} + results, err := qb.queryTags(query, args) + if err != nil || len(results) < 1 { + return nil, err + } + return results[0], nil +} + +func (qb *TagQueryBuilder) FindBySceneID(sceneID int64) ([]*Tag, error) { + query := ` + SELECT tags.* FROM tags + LEFT JOIN scene_tags as scenes_join on scenes_join.tag_id = tags.id + LEFT JOIN scenes on scenes_join.scene_id = scenes.id + WHERE scenes.id = ? + GROUP BY tags.id + ` + args := []interface{}{sceneID} + return qb.queryTags(query, args) +} + +func (qb *TagQueryBuilder) FindByNames(names []string) ([]*Tag, error) { + query := "SELECT * FROM tags WHERE name IN " + getInBinding(len(names)) + var args []interface{} + for _, name := range names { + args = append(args, name) + } + return qb.queryTags(query, args) +} + +func (qb *TagQueryBuilder) FindByAliases(names []string) ([]*Tag, error) { + query := `SELECT tags.* FROM tags + left join tag_aliases on tags.id = tag_aliases.tag_id + WHERE tag_aliases.alias IN ` + getInBinding(len(names)) + + var args []interface{} + for _, name := range names { + args = append(args, name) + } + return qb.queryTags(query, args) +} + +func (qb *TagQueryBuilder) FindByName(name string) ([]*Tag, error) { + query := "SELECT * FROM tags WHERE upper(name) = upper(?)" + var args []interface{} + args = append(args, name) + return qb.queryTags(query, args) +} + +func (qb *TagQueryBuilder) FindByAlias(name string) ([]*Tag, error) { + query := `SELECT tags.* FROM tags + left join tag_aliases on tag.id = tag_aliases.tag_id + WHERE upper(tag_aliases.alias) = UPPER(?)` + + var args []interface{} + args = append(args, name) + return qb.queryTags(query, args) +} + +func (qb *TagQueryBuilder) Count() (int, error) { + return runCountQuery(buildCountQuery("SELECT tags.id FROM tags"), nil) +} + +func (qb *TagQueryBuilder) Query(tagFilter *TagFilterType, findFilter *QuerySpec) ([]*Tag, int) { + if tagFilter == nil { + tagFilter = &TagFilterType{} + } + if findFilter == nil { + findFilter = &QuerySpec{} + } + + query := queryBuilder{ + tableName: tagTable, + } + + query.body = selectDistinctIDs(tagTable) + + if q := tagFilter.Name; q != nil && *q != "" { + searchColumns := []string{"tags.name"} + clause, thisArgs := getSearchBinding(searchColumns, *q, false) + query.addWhere(clause) + query.addArg(thisArgs...) + } + + query.sortAndPagination = qb.getTagSort(findFilter) + getPagination(findFilter) + idsResult, countResult := query.executeFind() + + var tags []*Tag + for _, id := range idsResult { + tag, _ := qb.Find(id) + tags = append(tags, tag) + } + + return tags, countResult +} + +func (qb *TagQueryBuilder) getTagSort(findFilter *QuerySpec) string { + var sort string + var direction string + if findFilter == nil { + sort = "name" + direction = "ASC" + } else { + sort = findFilter.GetSort("name") + direction = findFilter.GetDirection() + } + return getSort(sort, direction, tagTable) +} + +func (qb *TagQueryBuilder) queryTags(query string, args []interface{}) (Tags, error) { + var output Tags + err := qb.dbi.RawQuery(tagDBTable, query, args, &output) + return output, err +} + +func (qb *TagQueryBuilder) GetAliases(id int64) ([]string, error) { + joins := TagAliases{} + err := qb.dbi.FindJoins(tagAliasTable, id, &joins) + + return joins.ToAliases(), err +} diff --git a/pkg/models/sqlite_date.go b/pkg/models/sqlite_date.go index 3916f43..331ec72 100644 --- a/pkg/models/sqlite_date.go +++ b/pkg/models/sqlite_date.go @@ -2,9 +2,10 @@ package models import ( "database/sql/driver" + "time" + "github.com/stashapp/stashdb/pkg/logger" "github.com/stashapp/stashdb/pkg/utils" - "time" ) type SQLiteDate struct { @@ -38,3 +39,7 @@ func (t SQLiteDate) Value() (driver.Value, error) { } return result, nil } + +func (t SQLiteDate) IsValid() bool { + return t.Valid +} diff --git a/pkg/models/sqlite_timestamp.go b/pkg/models/sqlite_timestamp.go index b8c84f7..8bcebb8 100644 --- a/pkg/models/sqlite_timestamp.go +++ b/pkg/models/sqlite_timestamp.go @@ -19,3 +19,7 @@ func (t *SQLiteTimestamp) Scan(value interface{}) error { func (t SQLiteTimestamp) Value() (driver.Value, error) { return t.Timestamp.Format(time.RFC3339), nil } + +func (t SQLiteTimestamp) IsValid() bool { + return !t.Timestamp.IsZero() +}