mirror of
https://github.com/run-llama/study-llama.git
synced 2026-06-30 21:07:53 -04:00
feat: add llamaagent
This commit is contained in:
+10
@@ -1 +1,11 @@
|
||||
# Python-generated files
|
||||
__pycache__/
|
||||
*.py[oc]
|
||||
build/
|
||||
dist/
|
||||
wheels/
|
||||
*.egg-info
|
||||
|
||||
# Virtual environments
|
||||
.venv
|
||||
.env
|
||||
@@ -1,10 +0,0 @@
|
||||
# Python-generated files
|
||||
__pycache__/
|
||||
*.py[oc]
|
||||
build/
|
||||
dist/
|
||||
wheels/
|
||||
*.egg-info
|
||||
|
||||
# Virtual environments
|
||||
.venv
|
||||
@@ -1,7 +0,0 @@
|
||||
[project]
|
||||
name = "backend"
|
||||
version = "0.1.0"
|
||||
description = "Add your description here"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.13"
|
||||
dependencies = []
|
||||
@@ -0,0 +1,27 @@
|
||||
package files
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
|
||||
_ "embed"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
)
|
||||
|
||||
//go:embed schema.sql
|
||||
var ddl string
|
||||
|
||||
func CreateNewDb() (*pgx.Conn, error) {
|
||||
ctx := context.Background()
|
||||
connString := os.Getenv("POSTGRES_CONNECTION_STRING")
|
||||
db, err := pgx.Connect(ctx, connString)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_, err = db.Exec(ctx, ddl)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return db, nil
|
||||
}
|
||||
@@ -0,0 +1,66 @@
|
||||
package files
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"os"
|
||||
)
|
||||
|
||||
type UploadedFile struct {
|
||||
CreatedAt *string `json:"created_at"`
|
||||
DataSourceID *string `json:"data_source_id"`
|
||||
ExternalFileID *string `json:"external_file_id"`
|
||||
FileSize *int64 `json:"file_size"`
|
||||
FileType *string `json:"file_type"`
|
||||
ID string `json:"id"`
|
||||
LastModifiedAt *string `json:"last_modified_at"`
|
||||
Name string `json:"name"`
|
||||
PermissionInfo map[string]interface{} `json:"permission_info"`
|
||||
ProjectID string `json:"project_id"`
|
||||
ResourceInfo map[string]interface{} `json:"resource_info"`
|
||||
UpdatedAt *string `json:"updated_at"`
|
||||
}
|
||||
|
||||
func UploadFile(file io.Reader, fileName string) (string, error) {
|
||||
apiKey := os.Getenv("LLAMA_CLOUD_API_KEY")
|
||||
var requestBody bytes.Buffer
|
||||
writer := multipart.NewWriter(&requestBody)
|
||||
|
||||
fileWriter, _ := writer.CreateFormFile("upload_file", fileName)
|
||||
|
||||
io.Copy(fileWriter, file)
|
||||
|
||||
writer.Close()
|
||||
url := "https://api.cloud.llamaindex.ai/api/v1/files"
|
||||
method := "POST"
|
||||
|
||||
client := &http.Client{}
|
||||
req, err := http.NewRequest(method, url, &requestBody)
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Add("Content-Type", "multipart/form-data")
|
||||
req.Header.Add("Accept", "application/json")
|
||||
req.Header.Add("Authorization", "Bearer "+apiKey)
|
||||
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
var fl UploadedFile
|
||||
err = json.Unmarshal(body, &fl)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return fl.ID, nil
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
-- Files table
|
||||
CREATE TABLE IF NOT EXISTS files (
|
||||
id SERIAL PRIMARY KEY,
|
||||
username TEXT NOT NULL,
|
||||
file_name TEXT NOT NULL,
|
||||
file_category TEXT DEFAULT NULL
|
||||
);
|
||||
@@ -0,0 +1,32 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.30.0
|
||||
|
||||
package filesdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
)
|
||||
|
||||
type DBTX interface {
|
||||
Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error)
|
||||
Query(context.Context, string, ...interface{}) (pgx.Rows, error)
|
||||
QueryRow(context.Context, string, ...interface{}) pgx.Row
|
||||
}
|
||||
|
||||
func New(db DBTX) *Queries {
|
||||
return &Queries{db: db}
|
||||
}
|
||||
|
||||
type Queries struct {
|
||||
db DBTX
|
||||
}
|
||||
|
||||
func (q *Queries) WithTx(tx pgx.Tx) *Queries {
|
||||
return &Queries{
|
||||
db: tx,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.30.0
|
||||
|
||||
package filesdb
|
||||
|
||||
import (
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
)
|
||||
|
||||
type File struct {
|
||||
ID int32
|
||||
Username string
|
||||
FileName string
|
||||
FileCategory pgtype.Text
|
||||
}
|
||||
@@ -0,0 +1,50 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.30.0
|
||||
// source: query.files.sql
|
||||
|
||||
package filesdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
const deleteFile = `-- name: DeleteFile :exec
|
||||
DELETE FROM files
|
||||
WHERE id = $1
|
||||
`
|
||||
|
||||
func (q *Queries) DeleteFile(ctx context.Context, id int32) error {
|
||||
_, err := q.db.Exec(ctx, deleteFile, id)
|
||||
return err
|
||||
}
|
||||
|
||||
const getFiles = `-- name: GetFiles :many
|
||||
SELECT id, username, file_name, file_category FROM files
|
||||
WHERE username = $1
|
||||
`
|
||||
|
||||
func (q *Queries) GetFiles(ctx context.Context, username string) ([]File, error) {
|
||||
rows, err := q.db.Query(ctx, getFiles, username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []File
|
||||
for rows.Next() {
|
||||
var i File
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.Username,
|
||||
&i.FileName,
|
||||
&i.FileCategory,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
@@ -12,6 +12,8 @@ import (
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
"github.com/run-llama/study-llama/frontend/auth"
|
||||
db "github.com/run-llama/study-llama/frontend/authdb"
|
||||
"github.com/run-llama/study-llama/frontend/files"
|
||||
"github.com/run-llama/study-llama/frontend/filesdb"
|
||||
"github.com/run-llama/study-llama/frontend/rules"
|
||||
"github.com/run-llama/study-llama/frontend/rulesdb"
|
||||
"github.com/run-llama/study-llama/frontend/templates"
|
||||
@@ -202,6 +204,64 @@ func HandleDeleteRule(c *fiber.Ctx) error {
|
||||
return templates.RulesList(rules).Render(c.Context(), c.Response().BodyWriter())
|
||||
}
|
||||
|
||||
func HandleUploadFile(c *fiber.Ctx) error {
|
||||
user, err := auth.AuthorizePost(c)
|
||||
c.Set("Content-Type", "text/html")
|
||||
if err != nil {
|
||||
return templates.StatusBanner(err).Render(c.Context(), c.Response().BodyWriter())
|
||||
}
|
||||
file, err := c.FormFile("upload_file")
|
||||
if err != nil {
|
||||
return templates.StatusBanner(err).Render(c.Context(), c.Response().BodyWriter())
|
||||
}
|
||||
src, err := file.Open()
|
||||
if err != nil {
|
||||
return templates.StatusBanner(err).Render(c.Context(), c.Response().BodyWriter())
|
||||
}
|
||||
defer src.Close()
|
||||
_, err = files.UploadFile(src, file.Filename)
|
||||
if err != nil {
|
||||
return templates.StatusBanner(err).Render(c.Context(), c.Response().BodyWriter())
|
||||
}
|
||||
db, err := files.CreateNewDb()
|
||||
if err != nil {
|
||||
return templates.StatusBanner(err).Render(c.Context(), c.Response().BodyWriter())
|
||||
}
|
||||
queries := filesdb.New(db)
|
||||
files, err := queries.GetFiles(context.Background(), user.Username)
|
||||
if err != nil {
|
||||
return templates.StatusBanner(err).Render(c.Context(), c.Response().BodyWriter())
|
||||
}
|
||||
return templates.FilesList(files).Render(c.Context(), c.Response().BodyWriter())
|
||||
}
|
||||
|
||||
func HandleDeleteFile(c *fiber.Ctx) error {
|
||||
user, err := auth.AuthorizePost(c)
|
||||
c.Set("Content-Type", "text/html")
|
||||
if err != nil {
|
||||
return templates.StatusBanner(err).Render(c.Context(), c.Response().BodyWriter())
|
||||
}
|
||||
fileId := c.Params("id")
|
||||
fileIdInt, err := strconv.Atoi(fileId)
|
||||
if err != nil {
|
||||
return templates.StatusBanner(err).Render(c.Context(), c.Response().BodyWriter())
|
||||
}
|
||||
db, err := files.CreateNewDb()
|
||||
if err != nil {
|
||||
return templates.StatusBanner(err).Render(c.Context(), c.Response().BodyWriter())
|
||||
}
|
||||
queries := filesdb.New(db)
|
||||
err = queries.DeleteFile(context.Background(), int32(fileIdInt))
|
||||
if err != nil {
|
||||
return templates.StatusBanner(err).Render(c.Context(), c.Response().BodyWriter())
|
||||
}
|
||||
files, err := queries.GetFiles(context.Background(), user.Username)
|
||||
if err != nil {
|
||||
return templates.StatusBanner(err).Render(c.Context(), c.Response().BodyWriter())
|
||||
}
|
||||
return templates.FilesList(files).Render(c.Context(), c.Response().BodyWriter())
|
||||
}
|
||||
|
||||
func LoginRoute(c *fiber.Ctx) error {
|
||||
if c.Method() != fiber.MethodGet {
|
||||
return c.SendStatus(fiber.StatusMethodNotAllowed)
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
-- name: GetFiles :many
|
||||
SELECT * FROM files
|
||||
WHERE username = $1;
|
||||
|
||||
-- name: DeleteFile :exec
|
||||
DELETE FROM files
|
||||
WHERE id = $1;
|
||||
@@ -0,0 +1,7 @@
|
||||
-- Files table
|
||||
CREATE TABLE files (
|
||||
id SERIAL PRIMARY KEY,
|
||||
username TEXT NOT NULL,
|
||||
file_name TEXT NOT NULL,
|
||||
file_category TEXT DEFAULT NULL
|
||||
);
|
||||
@@ -15,4 +15,12 @@ sql:
|
||||
go:
|
||||
package: "rulesdb"
|
||||
out: "rulesdb"
|
||||
sql_package: "pgx/v5"
|
||||
- engine: "postgresql"
|
||||
queries: "query.files.sql"
|
||||
schema: "schema.files.sql"
|
||||
gen:
|
||||
go:
|
||||
package: "filesdb"
|
||||
out: "filesdb"
|
||||
sql_package: "pgx/v5"
|
||||
@@ -0,0 +1,218 @@
|
||||
package templates
|
||||
|
||||
import "github.com/run-llama/study-llama/frontend/filesdb"
|
||||
import "strconv"
|
||||
import "slices"
|
||||
|
||||
// FilesPage is the main page component for managing files
|
||||
templ FilesPage(files []filesdb.File) {
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Study Llama - Upload Your Notes</title>
|
||||
<script src="https://cdn.jsdelivr.net/npm/htmx.org@2.0.7/dist/htmx.min.js"></script>
|
||||
<link href="https://cdn.jsdelivr.net/npm/daisyui@5" rel="stylesheet" type="text/css" />
|
||||
<script src="https://cdn.jsdelivr.net/npm/@tailwindcss/browser@4"></script>
|
||||
</head>
|
||||
<div class="container mx-auto p-6 max-w-6xl">
|
||||
<div class="flex justify-between items-center mb-6">
|
||||
<h1 class="text-3xl font-bold">Notes Management</h1>
|
||||
<button
|
||||
class="btn btn-primary"
|
||||
onclick="upload_file_modal.showModal()"
|
||||
>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" class="h-5 w-5 mr-2" viewBox="0 0 20 20" fill="currentColor">
|
||||
<path fill-rule="evenodd" d="M3 17a1 1 0 011-1h12a1 1 0 110 2H4a1 1 0 01-1-1zM6.293 6.707a1 1 0 010-1.414l3-3a1 1 0 011.414 0l3 3a1 1 0 01-1.414 1.414L11 5.414V13a1 1 0 11-2 0V5.414L7.707 6.707a1 1 0 01-1.414 0z" clip-rule="evenodd"></path>
|
||||
</svg>
|
||||
Upload File
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<div id="status-message"></div>
|
||||
|
||||
<div id="files-container" class="space-y-6">
|
||||
@FilesList(files)
|
||||
</div>
|
||||
|
||||
@UploadFileModal()
|
||||
</div>
|
||||
}
|
||||
|
||||
// FilesList displays files grouped by category
|
||||
templ FilesList(files []filesdb.File) {
|
||||
if len(files) == 0 {
|
||||
<div class="alert alert-info">
|
||||
<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" class="stroke-current shrink-0 w-6 h-6"><path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M13 16h-1v-4h-1m1-4h.01M21 12a9 9 0 11-18 0 9 9 0 0118 0z"></path></svg>
|
||||
<span>No files yet. Upload your first file to get started!</span>
|
||||
</div>
|
||||
} else {
|
||||
@FilesByCategory(files)
|
||||
}
|
||||
}
|
||||
|
||||
// FilesByCategory groups and displays files by category
|
||||
templ FilesByCategory(files []filesdb.File) {
|
||||
{{
|
||||
groupFilesByCategory := func(files []filesdb.File) map[string][]filesdb.File {
|
||||
categories := []string{}
|
||||
for _, fl := range files {
|
||||
if fl.FileCategory.Valid && !slices.Contains(categories, fl.FileCategory.String) {
|
||||
categories = append(categories, fl.FileCategory.String)
|
||||
}
|
||||
}
|
||||
categoriesMap := map[string][]filesdb.File{}
|
||||
for _, cat := range categories {
|
||||
for _, fl := range files {
|
||||
if fl.FileCategory.Valid && fl.FileCategory.String == cat {
|
||||
ls, ok := categoriesMap[cat]
|
||||
if !ok {
|
||||
categoriesMap[cat] = []filesdb.File{fl}
|
||||
} else {
|
||||
ls = append(ls, fl)
|
||||
categoriesMap[cat] = ls
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return categoriesMap
|
||||
}
|
||||
categoryFiles := groupFilesByCategory(files)
|
||||
numFiles := strconv.Itoa(len(categoryFiles))
|
||||
}}
|
||||
for category, categoryFiles := range groupFilesByCategory(files) {
|
||||
<div class="mb-6">
|
||||
<h2 class="text-2xl font-semibold mb-4 flex items-center gap-2">
|
||||
<svg xmlns="http://www.w3.org/2000/svg" class="h-6 w-6" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M3 7v10a2 2 0 002 2h14a2 2 0 002-2V9a2 2 0 00-2-2h-6l-2-2H5a2 2 0 00-2 2z"></path>
|
||||
</svg>
|
||||
if category == "" {
|
||||
<span>Uncategorized</span>
|
||||
} else {
|
||||
<span>{ category }</span>
|
||||
}
|
||||
<span class="badge badge-ghost">{ numFiles }</span>
|
||||
</h2>
|
||||
<div class="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-4">
|
||||
for _, file := range categoryFiles {
|
||||
@FileCard(file)
|
||||
}
|
||||
</div>
|
||||
</div>
|
||||
}
|
||||
}
|
||||
|
||||
// FileCard displays a single file
|
||||
templ FileCard(file filesdb.File) {
|
||||
{{
|
||||
fileId := strconv.Itoa(int(file.ID))
|
||||
}}
|
||||
<div class="card bg-base-100 shadow-lg border border-base-300 hover:shadow-xl transition-shadow">
|
||||
<div class="card-body p-4">
|
||||
<div class="flex items-start justify-between">
|
||||
<div class="flex items-start gap-3 flex-1 min-w-0">
|
||||
<div class="flex-1 min-w-0">
|
||||
<h3 class="font-semibold text-sm truncate" title={ file.FileName }>
|
||||
{ file.FileName }
|
||||
</h3>
|
||||
</div>
|
||||
</div>
|
||||
<div class="dropdown dropdown-end">
|
||||
<label tabindex="0" class="btn btn-ghost btn-xs btn-square">
|
||||
<svg xmlns="http://www.w3.org/2000/svg" class="h-5 w-5" viewBox="0 0 20 20" fill="currentColor">
|
||||
<path d="M10 6a2 2 0 110-4 2 2 0 010 4zM10 12a2 2 0 110-4 2 2 0 010 4zM10 18a2 2 0 110-4 2 2 0 010 4z"></path>
|
||||
</svg>
|
||||
</label>
|
||||
<ul tabindex="0" class="dropdown-content z-[1] menu p-2 shadow bg-base-100 rounded-box w-52">
|
||||
<li>
|
||||
<button
|
||||
hx-delete={ "/notes/" + fileId }
|
||||
hx-confirm="Are you sure you want to delete this file?"
|
||||
hx-target="#files-container"
|
||||
hx-swap="innerHTML"
|
||||
class="text-error"
|
||||
>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" class="h-4 w-4" viewBox="0 0 20 20" fill="currentColor">
|
||||
<path fill-rule="evenodd" d="M9 2a1 1 0 00-.894.553L7.382 4H4a1 1 0 000 2v10a2 2 0 002 2h8a2 2 0 002-2V6a1 1 0 100-2h-3.382l-.724-1.447A1 1 0 0011 2H9zM7 8a1 1 0 012 0v6a1 1 0 11-2 0V8zm5-1a1 1 0 00-1 1v6a1 1 0 102 0V8a1 1 0 00-1-1z" clip-rule="evenodd"></path>
|
||||
</svg>
|
||||
Delete
|
||||
</button>
|
||||
</li>
|
||||
</ul>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
}
|
||||
|
||||
// UploadFileModal is the modal for uploading new files
|
||||
templ UploadFileModal() {
|
||||
<dialog id="upload_file_modal" class="modal">
|
||||
<div class="modal-box">
|
||||
<h3 class="font-bold text-lg mb-4">Upload File</h3>
|
||||
<form
|
||||
hx-post="/notes"
|
||||
hx-encoding="multipart/form-data"
|
||||
hx-target="#files-container"
|
||||
hx-swap="innerHTML"
|
||||
hx-on::after-request="if(event.detail.successful) { upload_file_modal.close(); this.reset(); }"
|
||||
>
|
||||
<div class="form-control w-full mb-4">
|
||||
<label class="label">
|
||||
<span class="label-text">Select File</span>
|
||||
</label>
|
||||
<input
|
||||
type="file"
|
||||
name="upload_file"
|
||||
class="file-input file-input-bordered w-full"
|
||||
required
|
||||
onchange="updateFileName(this)"
|
||||
/>
|
||||
<label class="label">
|
||||
<span class="label-text-alt" id="file-size-info"></span>
|
||||
</label>
|
||||
</div>
|
||||
|
||||
<div class="form-control w-full mb-4">
|
||||
<label class="label">
|
||||
<span class="label-text">Category (Optional)</span>
|
||||
</label>
|
||||
<select name="file_category" class="select select-bordered w-full">
|
||||
<option value="">Uncategorized</option>
|
||||
<option value="Documents">Documents</option>
|
||||
<option value="Images">Images</option>
|
||||
<option value="Videos">Videos</option>
|
||||
<option value="Audio">Audio</option>
|
||||
<option value="Archives">Archives</option>
|
||||
<option value="Other">Other</option>
|
||||
</select>
|
||||
</div>
|
||||
|
||||
<div class="modal-action">
|
||||
<button type="button" class="btn" onclick="upload_file_modal.close(); this.closest('form').reset();">Cancel</button>
|
||||
<button type="submit" class="btn btn-primary">
|
||||
<svg xmlns="http://www.w3.org/2000/svg" class="h-5 w-5 mr-2" viewBox="0 0 20 20" fill="currentColor">
|
||||
<path fill-rule="evenodd" d="M3 17a1 1 0 011-1h12a1 1 0 110 2H4a1 1 0 01-1-1zM6.293 6.707a1 1 0 010-1.414l3-3a1 1 0 011.414 0l3 3a1 1 0 01-1.414 1.414L11 5.414V13a1 1 0 11-2 0V5.414L7.707 6.707a1 1 0 01-1.414 0z" clip-rule="evenodd"></path>
|
||||
</svg>
|
||||
Upload
|
||||
</button>
|
||||
</div>
|
||||
</form>
|
||||
</div>
|
||||
<form method="dialog" class="modal-backdrop">
|
||||
<button>close</button>
|
||||
</form>
|
||||
</dialog>
|
||||
|
||||
<script>
|
||||
function updateFileName(input) {
|
||||
const fileInfo = document.getElementById('file-size-info');
|
||||
if (input.files && input.files[0]) {
|
||||
const file = input.files[0];
|
||||
const sizeMB = (file.size / (1024 * 1024)).toFixed(2);
|
||||
fileInfo.textContent = `${file.name} (${sizeMB} MB)`;
|
||||
} else {
|
||||
fileInfo.textContent = '';
|
||||
}
|
||||
}
|
||||
</script>
|
||||
}
|
||||
@@ -0,0 +1,306 @@
|
||||
// Code generated by templ - DO NOT EDIT.
|
||||
|
||||
// templ: version: v0.3.960
|
||||
package templates
|
||||
|
||||
//lint:file-ignore SA4006 This context is only used if a nested component is present.
|
||||
|
||||
import "github.com/a-h/templ"
|
||||
import templruntime "github.com/a-h/templ/runtime"
|
||||
|
||||
import "github.com/run-llama/study-llama/frontend/filesdb"
|
||||
import "strconv"
|
||||
import "slices"
|
||||
|
||||
// FilesPage is the main page component for managing files
|
||||
func FilesPage(files []filesdb.File) templ.Component {
|
||||
return templruntime.GeneratedTemplate(func(templ_7745c5c3_Input templruntime.GeneratedComponentInput) (templ_7745c5c3_Err error) {
|
||||
templ_7745c5c3_W, ctx := templ_7745c5c3_Input.Writer, templ_7745c5c3_Input.Context
|
||||
if templ_7745c5c3_CtxErr := ctx.Err(); templ_7745c5c3_CtxErr != nil {
|
||||
return templ_7745c5c3_CtxErr
|
||||
}
|
||||
templ_7745c5c3_Buffer, templ_7745c5c3_IsBuffer := templruntime.GetBuffer(templ_7745c5c3_W)
|
||||
if !templ_7745c5c3_IsBuffer {
|
||||
defer func() {
|
||||
templ_7745c5c3_BufErr := templruntime.ReleaseBuffer(templ_7745c5c3_Buffer)
|
||||
if templ_7745c5c3_Err == nil {
|
||||
templ_7745c5c3_Err = templ_7745c5c3_BufErr
|
||||
}
|
||||
}()
|
||||
}
|
||||
ctx = templ.InitializeContext(ctx)
|
||||
templ_7745c5c3_Var1 := templ.GetChildren(ctx)
|
||||
if templ_7745c5c3_Var1 == nil {
|
||||
templ_7745c5c3_Var1 = templ.NopComponent
|
||||
}
|
||||
ctx = templ.ClearChildren(ctx)
|
||||
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 1, "<head><meta charset=\"UTF-8\"><meta name=\"viewport\" content=\"width=device-width, initial-scale=1.0\"><title>Study Llama - Upload Your Notes</title><script src=\"https://cdn.jsdelivr.net/npm/htmx.org@2.0.7/dist/htmx.min.js\"></script><link href=\"https://cdn.jsdelivr.net/npm/daisyui@5\" rel=\"stylesheet\" type=\"text/css\"><script src=\"https://cdn.jsdelivr.net/npm/@tailwindcss/browser@4\"></script></head><div class=\"container mx-auto p-6 max-w-6xl\"><div class=\"flex justify-between items-center mb-6\"><h1 class=\"text-3xl font-bold\">Notes Management</h1><button class=\"btn btn-primary\" onclick=\"upload_file_modal.showModal()\"><svg xmlns=\"http://www.w3.org/2000/svg\" class=\"h-5 w-5 mr-2\" viewBox=\"0 0 20 20\" fill=\"currentColor\"><path fill-rule=\"evenodd\" d=\"M3 17a1 1 0 011-1h12a1 1 0 110 2H4a1 1 0 01-1-1zM6.293 6.707a1 1 0 010-1.414l3-3a1 1 0 011.414 0l3 3a1 1 0 01-1.414 1.414L11 5.414V13a1 1 0 11-2 0V5.414L7.707 6.707a1 1 0 01-1.414 0z\" clip-rule=\"evenodd\"></path></svg> Upload File</button></div><div id=\"status-message\"></div><div id=\"files-container\" class=\"space-y-6\">")
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
templ_7745c5c3_Err = FilesList(files).Render(ctx, templ_7745c5c3_Buffer)
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 2, "</div>")
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
templ_7745c5c3_Err = UploadFileModal().Render(ctx, templ_7745c5c3_Buffer)
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 3, "</div>")
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// FilesList displays files grouped by category
|
||||
func FilesList(files []filesdb.File) templ.Component {
|
||||
return templruntime.GeneratedTemplate(func(templ_7745c5c3_Input templruntime.GeneratedComponentInput) (templ_7745c5c3_Err error) {
|
||||
templ_7745c5c3_W, ctx := templ_7745c5c3_Input.Writer, templ_7745c5c3_Input.Context
|
||||
if templ_7745c5c3_CtxErr := ctx.Err(); templ_7745c5c3_CtxErr != nil {
|
||||
return templ_7745c5c3_CtxErr
|
||||
}
|
||||
templ_7745c5c3_Buffer, templ_7745c5c3_IsBuffer := templruntime.GetBuffer(templ_7745c5c3_W)
|
||||
if !templ_7745c5c3_IsBuffer {
|
||||
defer func() {
|
||||
templ_7745c5c3_BufErr := templruntime.ReleaseBuffer(templ_7745c5c3_Buffer)
|
||||
if templ_7745c5c3_Err == nil {
|
||||
templ_7745c5c3_Err = templ_7745c5c3_BufErr
|
||||
}
|
||||
}()
|
||||
}
|
||||
ctx = templ.InitializeContext(ctx)
|
||||
templ_7745c5c3_Var2 := templ.GetChildren(ctx)
|
||||
if templ_7745c5c3_Var2 == nil {
|
||||
templ_7745c5c3_Var2 = templ.NopComponent
|
||||
}
|
||||
ctx = templ.ClearChildren(ctx)
|
||||
if len(files) == 0 {
|
||||
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 4, "<div class=\"alert alert-info\"><svg xmlns=\"http://www.w3.org/2000/svg\" fill=\"none\" viewBox=\"0 0 24 24\" class=\"stroke-current shrink-0 w-6 h-6\"><path stroke-linecap=\"round\" stroke-linejoin=\"round\" stroke-width=\"2\" d=\"M13 16h-1v-4h-1m1-4h.01M21 12a9 9 0 11-18 0 9 9 0 0118 0z\"></path></svg> <span>No files yet. Upload your first file to get started!</span></div>")
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
} else {
|
||||
templ_7745c5c3_Err = FilesByCategory(files).Render(ctx, templ_7745c5c3_Buffer)
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// FilesByCategory groups and displays files by category
|
||||
func FilesByCategory(files []filesdb.File) templ.Component {
|
||||
return templruntime.GeneratedTemplate(func(templ_7745c5c3_Input templruntime.GeneratedComponentInput) (templ_7745c5c3_Err error) {
|
||||
templ_7745c5c3_W, ctx := templ_7745c5c3_Input.Writer, templ_7745c5c3_Input.Context
|
||||
if templ_7745c5c3_CtxErr := ctx.Err(); templ_7745c5c3_CtxErr != nil {
|
||||
return templ_7745c5c3_CtxErr
|
||||
}
|
||||
templ_7745c5c3_Buffer, templ_7745c5c3_IsBuffer := templruntime.GetBuffer(templ_7745c5c3_W)
|
||||
if !templ_7745c5c3_IsBuffer {
|
||||
defer func() {
|
||||
templ_7745c5c3_BufErr := templruntime.ReleaseBuffer(templ_7745c5c3_Buffer)
|
||||
if templ_7745c5c3_Err == nil {
|
||||
templ_7745c5c3_Err = templ_7745c5c3_BufErr
|
||||
}
|
||||
}()
|
||||
}
|
||||
ctx = templ.InitializeContext(ctx)
|
||||
templ_7745c5c3_Var3 := templ.GetChildren(ctx)
|
||||
if templ_7745c5c3_Var3 == nil {
|
||||
templ_7745c5c3_Var3 = templ.NopComponent
|
||||
}
|
||||
ctx = templ.ClearChildren(ctx)
|
||||
groupFilesByCategory := func(files []filesdb.File) map[string][]filesdb.File {
|
||||
categories := []string{}
|
||||
for _, fl := range files {
|
||||
if fl.FileCategory.Valid && !slices.Contains(categories, fl.FileCategory.String) {
|
||||
categories = append(categories, fl.FileCategory.String)
|
||||
}
|
||||
}
|
||||
categoriesMap := map[string][]filesdb.File{}
|
||||
for _, cat := range categories {
|
||||
for _, fl := range files {
|
||||
if fl.FileCategory.Valid && fl.FileCategory.String == cat {
|
||||
ls, ok := categoriesMap[cat]
|
||||
if !ok {
|
||||
categoriesMap[cat] = []filesdb.File{fl}
|
||||
} else {
|
||||
ls = append(ls, fl)
|
||||
categoriesMap[cat] = ls
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return categoriesMap
|
||||
}
|
||||
categoryFiles := groupFilesByCategory(files)
|
||||
numFiles := strconv.Itoa(len(categoryFiles))
|
||||
for category, categoryFiles := range groupFilesByCategory(files) {
|
||||
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 5, "<div class=\"mb-6\"><h2 class=\"text-2xl font-semibold mb-4 flex items-center gap-2\"><svg xmlns=\"http://www.w3.org/2000/svg\" class=\"h-6 w-6\" fill=\"none\" viewBox=\"0 0 24 24\" stroke=\"currentColor\"><path stroke-linecap=\"round\" stroke-linejoin=\"round\" stroke-width=\"2\" d=\"M3 7v10a2 2 0 002 2h14a2 2 0 002-2V9a2 2 0 00-2-2h-6l-2-2H5a2 2 0 00-2 2z\"></path></svg> ")
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
if category == "" {
|
||||
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 6, "<span>Uncategorized</span> ")
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
} else {
|
||||
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 7, "<span>")
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
var templ_7745c5c3_Var4 string
|
||||
templ_7745c5c3_Var4, templ_7745c5c3_Err = templ.JoinStringErrs(category)
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ.Error{Err: templ_7745c5c3_Err, FileName: `templates/notes.templ`, Line: 91, Col: 21}
|
||||
}
|
||||
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var4))
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 8, "</span> ")
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
}
|
||||
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 9, "<span class=\"badge badge-ghost\">")
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
var templ_7745c5c3_Var5 string
|
||||
templ_7745c5c3_Var5, templ_7745c5c3_Err = templ.JoinStringErrs(numFiles)
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ.Error{Err: templ_7745c5c3_Err, FileName: `templates/notes.templ`, Line: 93, Col: 46}
|
||||
}
|
||||
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var5))
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 10, "</span></h2><div class=\"grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-4\">")
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
for _, file := range categoryFiles {
|
||||
templ_7745c5c3_Err = FileCard(file).Render(ctx, templ_7745c5c3_Buffer)
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
}
|
||||
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 11, "</div></div>")
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// FileCard displays a single file
|
||||
func FileCard(file filesdb.File) templ.Component {
|
||||
return templruntime.GeneratedTemplate(func(templ_7745c5c3_Input templruntime.GeneratedComponentInput) (templ_7745c5c3_Err error) {
|
||||
templ_7745c5c3_W, ctx := templ_7745c5c3_Input.Writer, templ_7745c5c3_Input.Context
|
||||
if templ_7745c5c3_CtxErr := ctx.Err(); templ_7745c5c3_CtxErr != nil {
|
||||
return templ_7745c5c3_CtxErr
|
||||
}
|
||||
templ_7745c5c3_Buffer, templ_7745c5c3_IsBuffer := templruntime.GetBuffer(templ_7745c5c3_W)
|
||||
if !templ_7745c5c3_IsBuffer {
|
||||
defer func() {
|
||||
templ_7745c5c3_BufErr := templruntime.ReleaseBuffer(templ_7745c5c3_Buffer)
|
||||
if templ_7745c5c3_Err == nil {
|
||||
templ_7745c5c3_Err = templ_7745c5c3_BufErr
|
||||
}
|
||||
}()
|
||||
}
|
||||
ctx = templ.InitializeContext(ctx)
|
||||
templ_7745c5c3_Var6 := templ.GetChildren(ctx)
|
||||
if templ_7745c5c3_Var6 == nil {
|
||||
templ_7745c5c3_Var6 = templ.NopComponent
|
||||
}
|
||||
ctx = templ.ClearChildren(ctx)
|
||||
fileId := strconv.Itoa(int(file.ID))
|
||||
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 12, "<div class=\"card bg-base-100 shadow-lg border border-base-300 hover:shadow-xl transition-shadow\"><div class=\"card-body p-4\"><div class=\"flex items-start justify-between\"><div class=\"flex items-start gap-3 flex-1 min-w-0\"><div class=\"flex-1 min-w-0\"><h3 class=\"font-semibold text-sm truncate\" title=\"")
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
var templ_7745c5c3_Var7 string
|
||||
templ_7745c5c3_Var7, templ_7745c5c3_Err = templ.JoinStringErrs(file.FileName)
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ.Error{Err: templ_7745c5c3_Err, FileName: `templates/notes.templ`, Line: 114, Col: 70}
|
||||
}
|
||||
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var7))
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 13, "\">")
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
var templ_7745c5c3_Var8 string
|
||||
templ_7745c5c3_Var8, templ_7745c5c3_Err = templ.JoinStringErrs(file.FileName)
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ.Error{Err: templ_7745c5c3_Err, FileName: `templates/notes.templ`, Line: 115, Col: 22}
|
||||
}
|
||||
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var8))
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 14, "</h3></div></div><div class=\"dropdown dropdown-end\"><label tabindex=\"0\" class=\"btn btn-ghost btn-xs btn-square\"><svg xmlns=\"http://www.w3.org/2000/svg\" class=\"h-5 w-5\" viewBox=\"0 0 20 20\" fill=\"currentColor\"><path d=\"M10 6a2 2 0 110-4 2 2 0 010 4zM10 12a2 2 0 110-4 2 2 0 010 4zM10 18a2 2 0 110-4 2 2 0 010 4z\"></path></svg></label><ul tabindex=\"0\" class=\"dropdown-content z-[1] menu p-2 shadow bg-base-100 rounded-box w-52\"><li><button hx-delete=\"")
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
var templ_7745c5c3_Var9 string
|
||||
templ_7745c5c3_Var9, templ_7745c5c3_Err = templ.JoinStringErrs("/notes/" + fileId)
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ.Error{Err: templ_7745c5c3_Err, FileName: `templates/notes.templ`, Line: 128, Col: 38}
|
||||
}
|
||||
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var9))
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 15, "\" hx-confirm=\"Are you sure you want to delete this file?\" hx-target=\"#files-container\" hx-swap=\"innerHTML\" class=\"text-error\"><svg xmlns=\"http://www.w3.org/2000/svg\" class=\"h-4 w-4\" viewBox=\"0 0 20 20\" fill=\"currentColor\"><path fill-rule=\"evenodd\" d=\"M9 2a1 1 0 00-.894.553L7.382 4H4a1 1 0 000 2v10a2 2 0 002 2h8a2 2 0 002-2V6a1 1 0 100-2h-3.382l-.724-1.447A1 1 0 0011 2H9zM7 8a1 1 0 012 0v6a1 1 0 11-2 0V8zm5-1a1 1 0 00-1 1v6a1 1 0 102 0V8a1 1 0 00-1-1z\" clip-rule=\"evenodd\"></path></svg> Delete</button></li></ul></div></div></div></div>")
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// UploadFileModal is the modal for uploading new files
|
||||
func UploadFileModal() templ.Component {
|
||||
return templruntime.GeneratedTemplate(func(templ_7745c5c3_Input templruntime.GeneratedComponentInput) (templ_7745c5c3_Err error) {
|
||||
templ_7745c5c3_W, ctx := templ_7745c5c3_Input.Writer, templ_7745c5c3_Input.Context
|
||||
if templ_7745c5c3_CtxErr := ctx.Err(); templ_7745c5c3_CtxErr != nil {
|
||||
return templ_7745c5c3_CtxErr
|
||||
}
|
||||
templ_7745c5c3_Buffer, templ_7745c5c3_IsBuffer := templruntime.GetBuffer(templ_7745c5c3_W)
|
||||
if !templ_7745c5c3_IsBuffer {
|
||||
defer func() {
|
||||
templ_7745c5c3_BufErr := templruntime.ReleaseBuffer(templ_7745c5c3_Buffer)
|
||||
if templ_7745c5c3_Err == nil {
|
||||
templ_7745c5c3_Err = templ_7745c5c3_BufErr
|
||||
}
|
||||
}()
|
||||
}
|
||||
ctx = templ.InitializeContext(ctx)
|
||||
templ_7745c5c3_Var10 := templ.GetChildren(ctx)
|
||||
if templ_7745c5c3_Var10 == nil {
|
||||
templ_7745c5c3_Var10 = templ.NopComponent
|
||||
}
|
||||
ctx = templ.ClearChildren(ctx)
|
||||
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 16, "<dialog id=\"upload_file_modal\" class=\"modal\"><div class=\"modal-box\"><h3 class=\"font-bold text-lg mb-4\">Upload File</h3><form hx-post=\"/notes\" hx-encoding=\"multipart/form-data\" hx-target=\"#files-container\" hx-swap=\"innerHTML\" hx-on::after-request=\"if(event.detail.successful) { upload_file_modal.close(); this.reset(); }\"><div class=\"form-control w-full mb-4\"><label class=\"label\"><span class=\"label-text\">Select File</span></label> <input type=\"file\" name=\"upload_file\" class=\"file-input file-input-bordered w-full\" required onchange=\"updateFileName(this)\"> <label class=\"label\"><span class=\"label-text-alt\" id=\"file-size-info\"></span></label></div><div class=\"form-control w-full mb-4\"><label class=\"label\"><span class=\"label-text\">Category (Optional)</span></label> <select name=\"file_category\" class=\"select select-bordered w-full\"><option value=\"\">Uncategorized</option> <option value=\"Documents\">Documents</option> <option value=\"Images\">Images</option> <option value=\"Videos\">Videos</option> <option value=\"Audio\">Audio</option> <option value=\"Archives\">Archives</option> <option value=\"Other\">Other</option></select></div><div class=\"modal-action\"><button type=\"button\" class=\"btn\" onclick=\"upload_file_modal.close(); this.closest('form').reset();\">Cancel</button> <button type=\"submit\" class=\"btn btn-primary\"><svg xmlns=\"http://www.w3.org/2000/svg\" class=\"h-5 w-5 mr-2\" viewBox=\"0 0 20 20\" fill=\"currentColor\"><path fill-rule=\"evenodd\" d=\"M3 17a1 1 0 011-1h12a1 1 0 110 2H4a1 1 0 01-1-1zM6.293 6.707a1 1 0 010-1.414l3-3a1 1 0 011.414 0l3 3a1 1 0 01-1.414 1.414L11 5.414V13a1 1 0 11-2 0V5.414L7.707 6.707a1 1 0 01-1.414 0z\" clip-rule=\"evenodd\"></path></svg> Upload</button></div></form></div><form method=\"dialog\" class=\"modal-backdrop\"><button>close</button></form></dialog><script>\n\t\tfunction updateFileName(input) {\n\t\t\tconst fileInfo = document.getElementById('file-size-info');\n\t\t\tif (input.files && input.files[0]) {\n\t\t\t\tconst file = input.files[0];\n\t\t\t\tconst sizeMB = (file.size / (1024 * 1024)).toFixed(2);\n\t\t\t\tfileInfo.textContent = `${file.name} (${sizeMB} MB)`;\n\t\t\t} else {\n\t\t\t\tfileInfo.textContent = '';\n\t\t\t}\n\t\t}\n\t</script>")
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
var _ = templruntime.GeneratedTemplate
|
||||
@@ -0,0 +1,30 @@
|
||||
[build-system]
|
||||
requires = ["uv_build>=0.9.10,<0.10.0"]
|
||||
build-backend = "uv_build"
|
||||
|
||||
[project]
|
||||
name = "study-llama-backend"
|
||||
version = "0.1.0"
|
||||
description = "Backend workflows to orchestrate file classification and ingestion into a vector database"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.13"
|
||||
dependencies = [
|
||||
"llama-cloud-services>=0.6.81",
|
||||
"llama-index-workflows>=2.11.4",
|
||||
"openai>=2.8.1",
|
||||
"psycopg2-binary>=2.9.11",
|
||||
"qdrant-client>=1.16.0",
|
||||
"sqlalchemy>=2.0.44",
|
||||
]
|
||||
|
||||
[tool.uv.build-backend]
|
||||
module-name = "study_llama"
|
||||
|
||||
[tool.llamadeploy]
|
||||
name = "study-llama"
|
||||
env_files = [".env"]
|
||||
llama_cloud = true
|
||||
|
||||
[tool.llamadeploy.workflows]
|
||||
classify-and-extract = "study_llama.classify_and_extract.workflow:workflow"
|
||||
search = "study_llama.search.workflow:workflow"
|
||||
@@ -0,0 +1,7 @@
|
||||
-- name: CreateFile :one
|
||||
INSERT INTO files (
|
||||
username, file_name, file_category
|
||||
) VALUES (
|
||||
$1, $2, $3
|
||||
)
|
||||
RETURNING *;
|
||||
@@ -0,0 +1,3 @@
|
||||
-- name: GetRules :many
|
||||
SELECT * FROM rules
|
||||
WHERE username = $1;
|
||||
@@ -0,0 +1,7 @@
|
||||
-- Files table
|
||||
CREATE TABLE files (
|
||||
id SERIAL PRIMARY KEY,
|
||||
username TEXT NOT NULL,
|
||||
file_name TEXT NOT NULL,
|
||||
file_category TEXT DEFAULT NULL
|
||||
);
|
||||
@@ -0,0 +1,8 @@
|
||||
-- Rules table
|
||||
CREATE TABLE rules (
|
||||
id SERIAL PRIMARY KEY,
|
||||
username TEXT NOT NULL,
|
||||
rule_name TEXT NOT NULL,
|
||||
rule_type TEXT NOT NULL,
|
||||
rule_description TEXT NOT NULL
|
||||
);
|
||||
@@ -0,0 +1,28 @@
|
||||
import os
|
||||
from qdrant_client import AsyncQdrantClient
|
||||
from qdrant_client.models import VectorParams, Distance
|
||||
|
||||
async def create_collections():
|
||||
client = AsyncQdrantClient(
|
||||
api_key=os.getenv("QDRANT_API_KEY"),
|
||||
https=True,
|
||||
port=443,
|
||||
host=os.getenv("QDRANT_HOST"),
|
||||
check_compatibility=False,
|
||||
)
|
||||
for coll in ("summaries", "faqs"):
|
||||
if not (await client.collection_exists(coll)):
|
||||
succ = await client.create_collection(
|
||||
collection_name=coll,
|
||||
vectors_config=VectorParams(
|
||||
size=768,
|
||||
distance=Distance.COSINE
|
||||
)
|
||||
)
|
||||
print(f"Successfully created {coll}" if succ else f"Something went wrong while creating {coll}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
asyncio.run(create_collections())
|
||||
@@ -0,0 +1,29 @@
|
||||
version: "2"
|
||||
plugins:
|
||||
- name: py
|
||||
wasm:
|
||||
url: https://downloads.sqlc.dev/plugin/sqlc-gen-python_1.3.0.wasm
|
||||
sha256: fbedae96b5ecae2380a70fb5b925fd4bff58a6cfb1f3140375d098fbab7b3a3c
|
||||
sql:
|
||||
- schema: "schema.files.sql"
|
||||
queries: "query.files.sql"
|
||||
engine: "postgresql"
|
||||
codegen:
|
||||
- out: src/study_llama/filesdb
|
||||
plugin: py
|
||||
options:
|
||||
package: study_llama.filesdb
|
||||
emit_sync_querier: false
|
||||
emit_async_querier: true
|
||||
emit_pydantic_models: true
|
||||
- schema: "schema.rules.sql"
|
||||
queries: "query.rules.sql"
|
||||
engine: "postgresql"
|
||||
codegen:
|
||||
- out: src/study_llama/rulesdb
|
||||
plugin: py
|
||||
options:
|
||||
package: study_llama.rulesdb
|
||||
emit_sync_querier: false
|
||||
emit_async_querier: true
|
||||
emit_pydantic_models: true
|
||||
@@ -0,0 +1,21 @@
|
||||
from workflows.events import StartEvent, StopEvent, Event
|
||||
from pydantic import ConfigDict
|
||||
from .models import QuestionAndAnswer
|
||||
|
||||
class InputFileEvent(StartEvent):
|
||||
file_id: str
|
||||
file_name: str
|
||||
username: str
|
||||
|
||||
class ClassifiedFileEvent(Event):
|
||||
file_type: str
|
||||
|
||||
class ExtractedFileEvent(Event):
|
||||
summary: str
|
||||
faqs: list[QuestionAndAnswer]
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
class IngestedFileEvent(StopEvent):
|
||||
success: bool
|
||||
error: str | None = None
|
||||
@@ -0,0 +1,23 @@
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from llama_cloud_services.extract import ExtractConfig
|
||||
from llama_cloud.types.extract_mode import ExtractMode
|
||||
|
||||
class QuestionAndAnswer(BaseModel):
|
||||
question: str = Field(description="Question related to the main document")
|
||||
answer: str = Field(description="Answer to the question")
|
||||
|
||||
class StudyNotes(BaseModel):
|
||||
summary: str = Field(description="Summary of the study notes in the document")
|
||||
faqs: list[QuestionAndAnswer] = Field(description="List of potential 'Frequently Asked Questions' (with associated answer) to help a student review and prepare with the study notes")
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
class WorkflowState(BaseModel):
|
||||
username: str = ""
|
||||
file_name: str = ""
|
||||
file_type: str = ""
|
||||
file_id: str = ""
|
||||
|
||||
EXTRACT_CONFIG = ExtractConfig(
|
||||
extraction_mode=ExtractMode.MULTIMODAL,
|
||||
)
|
||||
@@ -0,0 +1,38 @@
|
||||
import os
|
||||
from qdrant_client import AsyncQdrantClient
|
||||
from llama_cloud_services.beta.classifier import LlamaClassify
|
||||
from llama_cloud_services.extract import LlamaExtract
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncConnection
|
||||
from study_llama.vectordb.vectordb import SummaryVectorDB, FaqsVectorDB
|
||||
|
||||
async def get_llama_classify(*args, **kwargs):
|
||||
return LlamaClassify.from_api_key(api_key=os.getenv("LLAMA_CLOUD_API_KEY", ""))
|
||||
|
||||
async def get_llama_extract(*args, **kwargs):
|
||||
return LlamaExtract(
|
||||
api_key=os.getenv("LLAMA_CLOUD_API_KEY", ""),
|
||||
)
|
||||
|
||||
async def get_db_conn(*args, **kwargs):
|
||||
eng = create_async_engine(url=os.getenv("POSTGRES_CONNECTION_STRING", ""))
|
||||
return AsyncConnection(async_engine=eng)
|
||||
|
||||
async def get_vector_db_summaries(*args, **kwargs):
|
||||
client = AsyncQdrantClient(
|
||||
api_key=os.getenv("QDRANT_API_KEY"),
|
||||
https=True,
|
||||
port=443,
|
||||
host=os.getenv("QDRANT_HOST"),
|
||||
check_compatibility=False,
|
||||
)
|
||||
return SummaryVectorDB(client=client, collection_name="summaries")
|
||||
|
||||
async def get_vector_db_faqs(*args, **kwargs):
|
||||
client = AsyncQdrantClient(
|
||||
api_key=os.getenv("QDRANT_API_KEY"),
|
||||
https=True,
|
||||
port=443,
|
||||
host=os.getenv("QDRANT_HOST"),
|
||||
check_compatibility=False,
|
||||
)
|
||||
return FaqsVectorDB(client=client, collection_name="faqs")
|
||||
@@ -0,0 +1,12 @@
|
||||
import re
|
||||
from llama_cloud.types.classifier_rule import ClassifierRule
|
||||
from study_llama.rulesdb.models import Rule
|
||||
|
||||
def rules_to_classify_rules(rules: list[Rule]) -> list[ClassifierRule]:
|
||||
class_rules: list[ClassifierRule] = []
|
||||
for rule in rules:
|
||||
class_rules.append(ClassifierRule(
|
||||
type=re.sub(r"\s+", "_", rule.rule_type.lower().strip()),
|
||||
description=rule.rule_description,
|
||||
))
|
||||
return class_rules
|
||||
@@ -0,0 +1,85 @@
|
||||
import os
|
||||
from workflows import Workflow, Context, step
|
||||
from workflows.resource import Resource
|
||||
from llama_cloud.types.file import File
|
||||
from llama_cloud.types.extract_run import ExtractRun
|
||||
from typing import Annotated, TYPE_CHECKING, cast
|
||||
from study_llama.filesdb.query_files import AsyncQuerier as AsyncFilesQuerier
|
||||
from study_llama.rulesdb.query_rules import AsyncQuerier as AsyncRulesQuerier
|
||||
from .events import InputFileEvent, IngestedFileEvent, ClassifiedFileEvent, ExtractedFileEvent
|
||||
from .resources import get_db_conn, get_llama_classify, get_llama_extract, get_vector_db_faqs, get_vector_db_summaries
|
||||
from .models import StudyNotes, WorkflowState, EXTRACT_CONFIG
|
||||
from .utils import rules_to_classify_rules
|
||||
if TYPE_CHECKING:
|
||||
from llama_cloud_services.beta.classifier.client import LlamaClassify
|
||||
from llama_cloud_services.extract import LlamaExtract
|
||||
from sqlalchemy.ext.asyncio import AsyncConnection
|
||||
from study_llama.vectordb.vectordb import SummaryVectorDB, FaqsVectorDB
|
||||
|
||||
class ClassifyExtractWorkflow(Workflow):
|
||||
@step
|
||||
async def classify_file(self, ev: InputFileEvent, ctx: Context[WorkflowState], classifier: Annotated[LlamaClassify, Resource(get_llama_classify)], db_conn: Annotated[AsyncConnection, Resource(get_db_conn)]) -> ClassifiedFileEvent | IngestedFileEvent:
|
||||
querier = AsyncRulesQuerier(conn=db_conn)
|
||||
response = querier.get_rules(username=ev.username)
|
||||
rules = []
|
||||
async for rule in response:
|
||||
rules.append(rule)
|
||||
class_rules = rules_to_classify_rules(rules=rules)
|
||||
result = await classifier.aclassify_file_ids(rules=class_rules, file_ids=[ev.file_id])
|
||||
file_type: str | None = None
|
||||
for item in result.items:
|
||||
if (class_res := item.result) is not None:
|
||||
if class_res.type is not None:
|
||||
file_type = class_res.type
|
||||
break
|
||||
if file_type is not None:
|
||||
async with ctx.store.edit_state() as state:
|
||||
state.file_name = ev.file_name
|
||||
state.username = ev.username
|
||||
state.file_id = ev.file_id
|
||||
state.file_type = file_type
|
||||
querier_files = AsyncFilesQuerier(conn=db_conn)
|
||||
await querier_files.create_file(username=ev.username, file_name=ev.file_name, file_category=file_type)
|
||||
return ClassifiedFileEvent(file_type=file_type)
|
||||
else:
|
||||
return IngestedFileEvent(success=False, error="It was not possible to classify the provided file based on the existing categories")
|
||||
|
||||
|
||||
@step
|
||||
async def extract_file_details(self, ev: ClassifiedFileEvent, extractor: Annotated[LlamaExtract, Resource(get_llama_extract)], ctx: Context[WorkflowState]) -> ExtractedFileEvent | IngestedFileEvent:
|
||||
state = await ctx.store.get_state()
|
||||
result = await extractor.aextract(
|
||||
data_schema=StudyNotes,
|
||||
config=EXTRACT_CONFIG,
|
||||
files=File(
|
||||
id=state.file_id,
|
||||
name=state.file_name,
|
||||
project_id=os.getenv("LLAMA_CLOUD_PROJECT_ID"),
|
||||
data_source_id=None,
|
||||
created_at=None,
|
||||
external_file_id=None,
|
||||
file_size=None,
|
||||
file_type=None,
|
||||
last_modified_at=None,
|
||||
permission_info=None,
|
||||
resource_info=None,
|
||||
updated_at=None
|
||||
),
|
||||
)
|
||||
if (data := cast(ExtractRun, result).data) is not None:
|
||||
extraction_result = StudyNotes.model_validate(data)
|
||||
return ExtractedFileEvent(summary=extraction_result.summary, faqs=extraction_result.faqs)
|
||||
else:
|
||||
return IngestedFileEvent(success=False, error="It was not possible to extract details from the provided file")
|
||||
|
||||
@step
|
||||
async def ingest_file_details(self, ev: ExtractedFileEvent, summaries_vdb: Annotated[SummaryVectorDB, Resource(get_vector_db_summaries)], faqs_vdb: Annotated[FaqsVectorDB, Resource(get_vector_db_faqs)], ctx: Context[WorkflowState]) -> IngestedFileEvent:
|
||||
state = await ctx.store.get_state()
|
||||
questions = [faq.question for faq in ev.faqs]
|
||||
answers = [faq.answer for faq in ev.faqs]
|
||||
await faqs_vdb.upload(questions, answers, state.username, state.file_type, state.file_name)
|
||||
await summaries_vdb.upload(ev.summary, state.username, state.file_type, state.file_name)
|
||||
return IngestedFileEvent(success=True)
|
||||
|
||||
|
||||
workflow = ClassifyExtractWorkflow(timeout=1000)
|
||||
@@ -0,0 +1,12 @@
|
||||
# Code generated by sqlc. DO NOT EDIT.
|
||||
# versions:
|
||||
# sqlc v1.30.0
|
||||
import pydantic
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class File(pydantic.BaseModel):
|
||||
id: int
|
||||
username: str
|
||||
file_name: str
|
||||
file_category: Optional[str]
|
||||
@@ -0,0 +1,36 @@
|
||||
# Code generated by sqlc. DO NOT EDIT.
|
||||
# versions:
|
||||
# sqlc v1.30.0
|
||||
# source: query.files.sql
|
||||
from typing import Optional
|
||||
|
||||
import sqlalchemy
|
||||
import sqlalchemy.ext.asyncio
|
||||
|
||||
from study_llama.filesdb import models
|
||||
|
||||
|
||||
CREATE_FILE = """-- name: create_file \\:one
|
||||
INSERT INTO files (
|
||||
username, file_name, file_category
|
||||
) VALUES (
|
||||
:p1, :p2, :p3
|
||||
)
|
||||
RETURNING id, username, file_name, file_category
|
||||
"""
|
||||
|
||||
|
||||
class AsyncQuerier:
|
||||
def __init__(self, conn: sqlalchemy.ext.asyncio.AsyncConnection):
|
||||
self._conn = conn
|
||||
|
||||
async def create_file(self, *, username: str, file_name: str, file_category: Optional[str]) -> Optional[models.File]:
|
||||
row = (await self._conn.execute(sqlalchemy.text(CREATE_FILE), {"p1": username, "p2": file_name, "p3": file_category})).first()
|
||||
if row is None:
|
||||
return None
|
||||
return models.File(
|
||||
id=row[0],
|
||||
username=row[1],
|
||||
file_name=row[2],
|
||||
file_category=row[3],
|
||||
)
|
||||
@@ -0,0 +1,12 @@
|
||||
# Code generated by sqlc. DO NOT EDIT.
|
||||
# versions:
|
||||
# sqlc v1.30.0
|
||||
import pydantic
|
||||
|
||||
|
||||
class Rule(pydantic.BaseModel):
|
||||
id: int
|
||||
username: str
|
||||
rule_name: str
|
||||
rule_type: str
|
||||
rule_description: str
|
||||
@@ -0,0 +1,32 @@
|
||||
# Code generated by sqlc. DO NOT EDIT.
|
||||
# versions:
|
||||
# sqlc v1.30.0
|
||||
# source: query.rules.sql
|
||||
from typing import AsyncIterator
|
||||
|
||||
import sqlalchemy
|
||||
import sqlalchemy.ext.asyncio
|
||||
|
||||
from study_llama.rulesdb import models
|
||||
|
||||
|
||||
GET_RULES = """-- name: get_rules \\:many
|
||||
SELECT id, username, rule_name, rule_type, rule_description FROM rules
|
||||
WHERE username = :p1
|
||||
"""
|
||||
|
||||
|
||||
class AsyncQuerier:
|
||||
def __init__(self, conn: sqlalchemy.ext.asyncio.AsyncConnection):
|
||||
self._conn = conn
|
||||
|
||||
async def get_rules(self, *, username: str) -> AsyncIterator[models.Rule]:
|
||||
result = await self._conn.stream(sqlalchemy.text(GET_RULES), {"p1": username})
|
||||
async for row in result:
|
||||
yield models.Rule(
|
||||
id=row[0],
|
||||
username=row[1],
|
||||
rule_name=row[2],
|
||||
rule_type=row[3],
|
||||
rule_description=row[4],
|
||||
)
|
||||
@@ -0,0 +1,16 @@
|
||||
from workflows.events import StartEvent, StopEvent
|
||||
from typing import Literal
|
||||
from pydantic import ConfigDict
|
||||
from study_llama.vectordb.vectordb import Result
|
||||
|
||||
class SearchInputEvent(StartEvent):
|
||||
search_type: Literal["summary", "faqs"]
|
||||
search_input: str
|
||||
username: str
|
||||
file_name: str | None = None
|
||||
category: str | None = None
|
||||
|
||||
class SearchOutputEvent(StopEvent):
|
||||
results: list[Result]
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
@@ -0,0 +1,23 @@
|
||||
import os
|
||||
from qdrant_client import AsyncQdrantClient
|
||||
from study_llama.vectordb.vectordb import SummaryVectorDB, FaqsVectorDB
|
||||
|
||||
async def get_vector_db_summaries(*args, **kwargs):
|
||||
client = AsyncQdrantClient(
|
||||
api_key=os.getenv("QDRANT_API_KEY"),
|
||||
https=True,
|
||||
port=443,
|
||||
host=os.getenv("QDRANT_HOST"),
|
||||
check_compatibility=False,
|
||||
)
|
||||
return SummaryVectorDB(client=client, collection_name="summaries")
|
||||
|
||||
async def get_vector_db_faqs(*args, **kwargs):
|
||||
client = AsyncQdrantClient(
|
||||
api_key=os.getenv("QDRANT_API_KEY"),
|
||||
https=True,
|
||||
port=443,
|
||||
host=os.getenv("QDRANT_HOST"),
|
||||
check_compatibility=False,
|
||||
)
|
||||
return FaqsVectorDB(client=client, collection_name="faqs")
|
||||
@@ -0,0 +1,19 @@
|
||||
from workflows import Workflow, Context, step
|
||||
from workflows.resource import Resource
|
||||
from typing import Annotated, TYPE_CHECKING
|
||||
from .resources import get_vector_db_faqs, get_vector_db_summaries
|
||||
from .events import SearchInputEvent, SearchOutputEvent
|
||||
if TYPE_CHECKING:
|
||||
from study_llama.vectordb.vectordb import SummaryVectorDB, FaqsVectorDB
|
||||
|
||||
class SearchWorkflow(Workflow):
|
||||
@step
|
||||
async def search(self, ev: SearchInputEvent, summaries_vdb: Annotated[SummaryVectorDB, Resource(get_vector_db_summaries)], faqs_vdb: Annotated[FaqsVectorDB, Resource(get_vector_db_faqs)]) -> SearchOutputEvent:
|
||||
if ev.search_type == "faqs":
|
||||
results = await faqs_vdb.search(ev.search_input, ev.username, ev.category, ev.file_name)
|
||||
return SearchOutputEvent(results=results)
|
||||
else:
|
||||
results = await summaries_vdb.search(ev.search_input, ev.username, ev.category, ev.file_name)
|
||||
return SearchOutputEvent(results=results)
|
||||
|
||||
workflow = SearchWorkflow(timeout=600)
|
||||
@@ -0,0 +1,17 @@
|
||||
import os
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
class OpenAIEmbedder:
|
||||
def __init__(self, client: AsyncOpenAI) -> None:
|
||||
self._client = client
|
||||
|
||||
async def embed(self, texts: list[str]) -> list[list[float]]:
|
||||
response = await self._client.embeddings.create(
|
||||
input=texts,
|
||||
model="text-embedding-3-small",
|
||||
dimensions=768
|
||||
)
|
||||
return [d.embedding for d in response.data]
|
||||
|
||||
openai_client = AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY"))
|
||||
embedder = OpenAIEmbedder(client=openai_client)
|
||||
@@ -0,0 +1,97 @@
|
||||
import uuid
|
||||
from pydantic import BaseModel
|
||||
from qdrant_client import AsyncQdrantClient
|
||||
from qdrant_client.models import PointStruct, Filter, FieldCondition, MatchValue
|
||||
from typing import cast, Literal
|
||||
from .embeddings import embedder
|
||||
|
||||
class Result(BaseModel):
|
||||
result_type: Literal["answer", "summary"]
|
||||
text: str
|
||||
file_name: str
|
||||
category: str
|
||||
similarity: float
|
||||
|
||||
class SummaryVectorDB:
|
||||
def __init__(self, client: AsyncQdrantClient, collection_name: str):
|
||||
self._client = client
|
||||
self.collection_name = collection_name
|
||||
|
||||
async def upload(self, summary: str, username: str, category: str, file_name: str) -> None:
|
||||
vec = await embedder.embed([summary])
|
||||
point = PointStruct(
|
||||
id=uuid.uuid4(),
|
||||
vector=vec[0],
|
||||
payload={"category": category, "file_name": file_name, "summary": summary, "username": username}
|
||||
)
|
||||
self._client.upload_points(self.collection_name, points=[point])
|
||||
return None
|
||||
|
||||
async def search(self, text: str, username: str, category: str | None = None, file_name: str | None = None) -> list[Result]:
|
||||
filters = Filter(must=[
|
||||
FieldCondition(
|
||||
key="username",
|
||||
match=MatchValue(value=username)
|
||||
)
|
||||
])
|
||||
if category is not None:
|
||||
(cast(list[FieldCondition], filters.must)).append(
|
||||
FieldCondition(
|
||||
key="category",
|
||||
match=MatchValue(value=category)
|
||||
)
|
||||
)
|
||||
if file_name is not None:
|
||||
(cast(list[FieldCondition], filters.must)).append(
|
||||
FieldCondition(
|
||||
key="file_name",
|
||||
match=MatchValue(value=file_name)
|
||||
)
|
||||
)
|
||||
vec = await embedder.embed([text])
|
||||
results = await self._client.query_points(self.collection_name, query=vec[0], query_filter=filters, score_threshold=0.75)
|
||||
points = results.points
|
||||
return [Result(text=point.payload["summary"], similarity=point.score, file_name=point.payload["file_name"], category=point.payload["category"], result_type="summary") for point in points if point.payload is not None]
|
||||
|
||||
class FaqsVectorDB:
|
||||
def __init__(self, client: AsyncQdrantClient, collection_name: str):
|
||||
self._client = client
|
||||
self.collection_name = collection_name
|
||||
|
||||
async def upload(self, questions: list[str], answers: list[str], username: str, category: str, file_name: str) -> None:
|
||||
vecs = await embedder.embed(questions)
|
||||
points = []
|
||||
for i,vec in enumerate(vecs):
|
||||
points.append(PointStruct(
|
||||
id=uuid.uuid4(),
|
||||
vector=vec,
|
||||
payload={"category": category, "file_name": file_name, "question": questions[i], "answer": answers[i], "username": username}
|
||||
))
|
||||
self._client.upload_points(self.collection_name, points=points)
|
||||
return None
|
||||
|
||||
async def search(self, text: str, username: str, category: str | None = None, file_name: str | None = None) -> list[Result]:
|
||||
filters = Filter(must=[
|
||||
FieldCondition(
|
||||
key="username",
|
||||
match=MatchValue(value=username)
|
||||
)
|
||||
])
|
||||
if category is not None:
|
||||
(cast(list[FieldCondition], filters.must)).append(
|
||||
FieldCondition(
|
||||
key="category",
|
||||
match=MatchValue(value=category)
|
||||
)
|
||||
)
|
||||
if file_name is not None:
|
||||
(cast(list[FieldCondition], filters.must)).append(
|
||||
FieldCondition(
|
||||
key="file_name",
|
||||
match=MatchValue(value=file_name)
|
||||
)
|
||||
)
|
||||
vec = await embedder.embed([text])
|
||||
results = await self._client.query_points(self.collection_name, query=vec[0], query_filter=filters, score_threshold=0.75)
|
||||
points = results.points
|
||||
return [Result(text=point.payload["answer"], similarity=point.score, file_name=point.payload["file_name"], category=point.payload["category"], result_type="answer") for point in points if point.payload is not None]
|
||||
Reference in New Issue
Block a user