feat: add llamaagent

This commit is contained in:
Clelia (Astra) Bertelli
2025-11-23 18:23:27 +01:00
parent eef87c7b68
commit 3951d87f62
45 changed files with 3037 additions and 17 deletions
Vendored
BIN
View File
Binary file not shown.
+10
View File
@@ -1 +1,11 @@
# Python-generated files
__pycache__/
*.py[oc]
build/
dist/
wheels/
*.egg-info
# Virtual environments
.venv
.env
-10
View File
@@ -1,10 +0,0 @@
# Python-generated files
__pycache__/
*.py[oc]
build/
dist/
wheels/
*.egg-info
# Virtual environments
.venv
-7
View File
@@ -1,7 +0,0 @@
[project]
name = "backend"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.13"
dependencies = []
+27
View File
@@ -0,0 +1,27 @@
package files
import (
"context"
"os"
_ "embed"
"github.com/jackc/pgx/v5"
)
//go:embed schema.sql
var ddl string
// CreateNewDb opens a new PostgreSQL connection using the connection string in
// the POSTGRES_CONNECTION_STRING environment variable and executes the
// embedded schema DDL on it before returning.
//
// The caller owns the returned connection and is responsible for closing it.
// When connecting or applying the DDL fails, the returned connection is nil
// and no connection is left open.
func CreateNewDb() (*pgx.Conn, error) {
	ctx := context.Background()
	connString := os.Getenv("POSTGRES_CONNECTION_STRING")

	db, err := pgx.Connect(ctx, connString)
	if err != nil {
		return nil, err
	}

	if _, err := db.Exec(ctx, ddl); err != nil {
		// Do not leak the connection when the schema cannot be applied. The
		// close error is deliberately dropped: the DDL failure is the error
		// the caller needs to see.
		_ = db.Close(ctx)
		return nil, err
	}
	return db, nil
}
+66
View File
@@ -0,0 +1,66 @@
package files
import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"mime/multipart"
	"net/http"
	"os"
)
// UploadedFile is the shape into which UploadFile decodes the JSON response of
// the file upload endpoint.
//
// Fields declared as pointers decode to nil when the corresponding JSON value
// is absent or null.
type UploadedFile struct {
	// Timestamps and provenance, kept as the raw strings found in the JSON.
	CreatedAt      *string `json:"created_at"`
	DataSourceID   *string `json:"data_source_id"`
	ExternalFileID *string `json:"external_file_id"`

	// Size and type of the stored file.
	FileSize *int64  `json:"file_size"`
	FileType *string `json:"file_type"`

	// ID is the identifier UploadFile returns to its caller.
	ID             string  `json:"id"`
	LastModifiedAt *string `json:"last_modified_at"`
	Name           string  `json:"name"`

	// Free-form JSON objects; their keys are not interpreted here.
	PermissionInfo map[string]any `json:"permission_info"`
	ProjectID      string         `json:"project_id"`
	ResourceInfo   map[string]any `json:"resource_info"`
	UpdatedAt      *string        `json:"updated_at"`
}
// UploadFile sends the contents of file to the LlamaCloud files endpoint as a
// multipart form upload (form field "upload_file", file name fileName) and
// returns the ID of the file reported in the JSON response.
//
// The API key is read from the LLAMA_CLOUD_API_KEY environment variable and
// sent as a bearer token. An error is returned when the form body cannot be
// built, the request fails, the server answers with a non-2xx status, or the
// response body is not valid JSON.
func UploadFile(file io.Reader, fileName string) (string, error) {
	apiKey := os.Getenv("LLAMA_CLOUD_API_KEY")

	var requestBody bytes.Buffer
	writer := multipart.NewWriter(&requestBody)
	fileWriter, err := writer.CreateFormFile("upload_file", fileName)
	if err != nil {
		return "", fmt.Errorf("creating form file: %w", err)
	}
	if _, err := io.Copy(fileWriter, file); err != nil {
		return "", fmt.Errorf("copying file contents: %w", err)
	}
	// Close writes the trailing boundary; a failure here means the body is
	// incomplete and must not be sent.
	if err := writer.Close(); err != nil {
		return "", fmt.Errorf("finalizing multipart body: %w", err)
	}

	req, err := http.NewRequest(http.MethodPost, "https://api.cloud.llamaindex.ai/api/v1/files", &requestBody)
	if err != nil {
		return "", err
	}
	// The Content-Type has to carry the boundary chosen by the writer,
	// otherwise the receiver cannot split the body into its parts.
	req.Header.Set("Content-Type", writer.FormDataContentType())
	req.Header.Add("Accept", "application/json")
	req.Header.Add("Authorization", "Bearer "+apiKey)

	res, err := http.DefaultClient.Do(req)
	if err != nil {
		return "", err
	}
	defer res.Body.Close()

	body, err := io.ReadAll(res.Body)
	if err != nil {
		return "", err
	}
	// An error payload would otherwise decode into a zero UploadedFile and be
	// reported as a successful upload with an empty ID.
	if res.StatusCode < 200 || res.StatusCode > 299 {
		return "", fmt.Errorf("uploading file: unexpected status %s", res.Status)
	}

	var fl UploadedFile
	if err := json.Unmarshal(body, &fl); err != nil {
		return "", fmt.Errorf("decoding upload response: %w", err)
	}
	return fl.ID, nil
}
+7
View File
@@ -0,0 +1,7 @@
-- Files table
-- IF NOT EXISTS turns the statement into a no-op when the table is already
-- present, so it can be executed repeatedly.
CREATE TABLE IF NOT EXISTS files (
-- Auto-incrementing integer key.
id SERIAL PRIMARY KEY,
username TEXT NOT NULL,
file_name TEXT NOT NULL,
-- Nullable; NULL is the default when no category is supplied.
file_category TEXT DEFAULT NULL
);
+32
View File
@@ -0,0 +1,32 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
package filesdb
import (
"context"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
)
// DBTX lists the three query methods the generated code calls on its database
// handle. New accepts any implementation, and WithTx stores a pgx.Tx in the
// same slot.
type DBTX interface {
Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error)
Query(context.Context, string, ...interface{}) (pgx.Rows, error)
QueryRow(context.Context, string, ...interface{}) pgx.Row
}
// New returns a Queries whose statements are executed through db.
func New(db DBTX) *Queries {
return &Queries{db: db}
}
// Queries carries the generated query methods; db is the handle they run on.
type Queries struct {
db DBTX
}
// WithTx returns a new Queries that runs its statements on tx. The receiver is
// left unchanged.
func (q *Queries) WithTx(tx pgx.Tx) *Queries {
return &Queries{
db: tx,
}
}
+16
View File
@@ -0,0 +1,16 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
package filesdb
import (
"github.com/jackc/pgx/v5/pgtype"
)
// File holds one row of the files table; GetFiles scans id, username,
// file_name and file_category into its fields in that order.
type File struct {
ID int32
Username string
FileName string
// The file_category column has no NOT NULL constraint, hence pgtype.Text
// rather than string.
FileCategory pgtype.Text
}
+50
View File
@@ -0,0 +1,50 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
// source: query.files.sql
package filesdb
import (
"context"
)
// deleteFile is the SQL statement executed by DeleteFile.
const deleteFile = `-- name: DeleteFile :exec
DELETE FROM files
WHERE id = $1
`
// DeleteFile deletes the files row whose id equals id and returns the error
// reported by Exec, if any.
//
// The statement matches on id alone; it does not restrict the delete to a
// particular username, so callers must verify ownership themselves.
func (q *Queries) DeleteFile(ctx context.Context, id int32) error {
_, err := q.db.Exec(ctx, deleteFile, id)
return err
}
// getFiles is the SQL statement executed by GetFiles.
const getFiles = `-- name: GetFiles :many
SELECT id, username, file_name, file_category FROM files
WHERE username = $1
`
// GetFiles returns the files rows whose username equals username.
//
// The returned slice is nil when no row matches. The first error from the
// query, from scanning a row, or from the row iterator is returned as is.
func (q *Queries) GetFiles(ctx context.Context, username string) ([]File, error) {
rows, err := q.db.Query(ctx, getFiles, username)
if err != nil {
return nil, err
}
defer rows.Close()
var items []File
for rows.Next() {
var i File
// Scan targets follow the column order of the SELECT list.
if err := rows.Scan(
&i.ID,
&i.Username,
&i.FileName,
&i.FileCategory,
); err != nil {
return nil, err
}
items = append(items, i)
}
// Next returning false does not distinguish "no more rows" from a failure,
// so the iterator error is checked separately.
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
+60
View File
@@ -12,6 +12,8 @@ import (
"github.com/jackc/pgx/v5/pgtype"
"github.com/run-llama/study-llama/frontend/auth"
db "github.com/run-llama/study-llama/frontend/authdb"
"github.com/run-llama/study-llama/frontend/files"
"github.com/run-llama/study-llama/frontend/filesdb"
"github.com/run-llama/study-llama/frontend/rules"
"github.com/run-llama/study-llama/frontend/rulesdb"
"github.com/run-llama/study-llama/frontend/templates"
@@ -202,6 +204,64 @@ func HandleDeleteRule(c *fiber.Ctx) error {
return templates.RulesList(rules).Render(c.Context(), c.Response().BodyWriter())
}
// HandleUploadFile authorizes the request, forwards the file found in the
// "upload_file" form field to files.UploadFile and then renders the
// authenticated user's file list as an HTML fragment.
//
// Every failure is rendered as a templates.StatusBanner instead of the list.
func HandleUploadFile(c *fiber.Ctx) error {
	user, err := auth.AuthorizePost(c)
	c.Set("Content-Type", "text/html")
	if err != nil {
		return templates.StatusBanner(err).Render(c.Context(), c.Response().BodyWriter())
	}

	file, err := c.FormFile("upload_file")
	if err != nil {
		return templates.StatusBanner(err).Render(c.Context(), c.Response().BodyWriter())
	}
	src, err := file.Open()
	if err != nil {
		return templates.StatusBanner(err).Render(c.Context(), c.Response().BodyWriter())
	}
	defer src.Close()

	if _, err = files.UploadFile(src, file.Filename); err != nil {
		return templates.StatusBanner(err).Render(c.Context(), c.Response().BodyWriter())
	}

	// Locals are named conn/userFiles so they do not shadow the imported
	// db and files packages.
	conn, err := files.CreateNewDb()
	if err != nil {
		return templates.StatusBanner(err).Render(c.Context(), c.Response().BodyWriter())
	}
	// CreateNewDb opens a dedicated connection for this request; release it
	// once the response has been rendered.
	defer conn.Close(context.Background())

	queries := filesdb.New(conn)
	userFiles, err := queries.GetFiles(context.Background(), user.Username)
	if err != nil {
		return templates.StatusBanner(err).Render(c.Context(), c.Response().BodyWriter())
	}
	return templates.FilesList(userFiles).Render(c.Context(), c.Response().BodyWriter())
}
// HandleDeleteFile authorizes the request, deletes the file whose id is given
// by the "id" route parameter if it belongs to the authenticated user, and
// renders that user's remaining files as an HTML fragment.
//
// Every failure, including an id that is not among the user's files, is
// rendered as a templates.StatusBanner instead of the list.
func HandleDeleteFile(c *fiber.Ctx) error {
	user, err := auth.AuthorizePost(c)
	c.Set("Content-Type", "text/html")
	if err != nil {
		return templates.StatusBanner(err).Render(c.Context(), c.Response().BodyWriter())
	}

	// Parse with a 32-bit size so that out-of-range ids are rejected instead
	// of silently wrapping in the int32 conversion below.
	parsedID, err := strconv.ParseInt(c.Params("id"), 10, 32)
	if err != nil {
		return templates.StatusBanner(err).Render(c.Context(), c.Response().BodyWriter())
	}
	fileID := int32(parsedID)

	conn, err := files.CreateNewDb()
	if err != nil {
		return templates.StatusBanner(err).Render(c.Context(), c.Response().BodyWriter())
	}
	// CreateNewDb opens a dedicated connection for this request; release it
	// once the response has been rendered.
	defer conn.Close(context.Background())

	ctx := context.Background()
	queries := filesdb.New(conn)

	// DeleteFile matches on id only, so confirm the file belongs to the
	// caller before deleting it.
	userFiles, err := queries.GetFiles(ctx, user.Username)
	if err != nil {
		return templates.StatusBanner(err).Render(c.Context(), c.Response().BodyWriter())
	}
	owned := false
	for i := range userFiles {
		if userFiles[i].ID == fileID {
			owned = true
			break
		}
	}
	if !owned {
		// Answer "not found" rather than "forbidden" so the response does not
		// reveal whether the id exists for another user.
		return templates.StatusBanner(fiber.ErrNotFound).Render(c.Context(), c.Response().BodyWriter())
	}

	if err := queries.DeleteFile(ctx, fileID); err != nil {
		return templates.StatusBanner(err).Render(c.Context(), c.Response().BodyWriter())
	}
	remaining, err := queries.GetFiles(ctx, user.Username)
	if err != nil {
		return templates.StatusBanner(err).Render(c.Context(), c.Response().BodyWriter())
	}
	return templates.FilesList(remaining).Render(c.Context(), c.Response().BodyWriter())
}
func LoginRoute(c *fiber.Ctx) error {
if c.Method() != fiber.MethodGet {
return c.SendStatus(fiber.StatusMethodNotAllowed)
+7
View File
@@ -0,0 +1,7 @@
-- name: GetFiles :many
-- Returns all columns of every files row that has the given username.
SELECT * FROM files
WHERE username = $1;
-- name: DeleteFile :exec
-- Deletes the row with the given id. The match is on id alone, with no
-- username filter, so the caller has to check ownership.
DELETE FROM files
WHERE id = $1;
+7
View File
@@ -0,0 +1,7 @@
-- Files table
CREATE TABLE files (
-- Auto-incrementing integer key.
id SERIAL PRIMARY KEY,
username TEXT NOT NULL,
file_name TEXT NOT NULL,
-- Nullable; NULL is the default when no category is supplied.
file_category TEXT DEFAULT NULL
);
+8
View File
@@ -15,4 +15,12 @@ sql:
go:
package: "rulesdb"
out: "rulesdb"
sql_package: "pgx/v5"
- engine: "postgresql"
queries: "query.files.sql"
schema: "schema.files.sql"
gen:
go:
package: "filesdb"
out: "filesdb"
sql_package: "pgx/v5"
+218
View File
@@ -0,0 +1,218 @@
package templates
import "github.com/run-llama/study-llama/frontend/filesdb"
import "strconv"
import "slices"
// FilesPage renders the complete notes page: a head element that loads htmx,
// daisyUI and Tailwind from jsDelivr, a header with the button that opens the
// upload modal, an empty #status-message placeholder, the #files-container
// filled with FilesList(files), and the UploadFileModal dialog.
templ FilesPage(files []filesdb.File) {
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Study Llama - Upload Your Notes</title>
<script src="https://cdn.jsdelivr.net/npm/htmx.org@2.0.7/dist/htmx.min.js"></script>
<link href="https://cdn.jsdelivr.net/npm/daisyui@5" rel="stylesheet" type="text/css" />
<script src="https://cdn.jsdelivr.net/npm/@tailwindcss/browser@4"></script>
</head>
<div class="container mx-auto p-6 max-w-6xl">
<div class="flex justify-between items-center mb-6">
<h1 class="text-3xl font-bold">Notes Management</h1>
<button
class="btn btn-primary"
onclick="upload_file_modal.showModal()"
>
<svg xmlns="http://www.w3.org/2000/svg" class="h-5 w-5 mr-2" viewBox="0 0 20 20" fill="currentColor">
<path fill-rule="evenodd" d="M3 17a1 1 0 011-1h12a1 1 0 110 2H4a1 1 0 01-1-1zM6.293 6.707a1 1 0 010-1.414l3-3a1 1 0 011.414 0l3 3a1 1 0 01-1.414 1.414L11 5.414V13a1 1 0 11-2 0V5.414L7.707 6.707a1 1 0 01-1.414 0z" clip-rule="evenodd"></path>
</svg>
Upload File
</button>
</div>
<div id="status-message"></div>
<div id="files-container" class="space-y-6">
@FilesList(files)
</div>
@UploadFileModal()
</div>
}
// FilesList renders the content of the files container: the per-category
// sections when there is at least one file, otherwise an informational alert
// inviting the user to upload a first file.
templ FilesList(files []filesdb.File) {
	if len(files) > 0 {
		@FilesByCategory(files)
	} else {
		<div class="alert alert-info">
			<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" class="stroke-current shrink-0 w-6 h-6"><path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M13 16h-1v-4h-1m1-4h.01M21 12a9 9 0 11-18 0 9 9 0 0118 0z"></path></svg>
			<span>No files yet. Upload your first file to get started!</span>
		</div>
	}
}
// FilesByCategory renders one section per category, each with a heading, a
// badge holding the number of files in that category, and a grid of FileCard
// components.
//
// Files are grouped in a single pass. A file whose FileCategory is not Valid
// is grouped under the empty category, which is shown as "Uncategorized".
// Sections are emitted in sorted category order so the output is stable from
// one render to the next (the empty category therefore comes first).
templ FilesByCategory(files []filesdb.File) {
	{{
		grouped := map[string][]filesdb.File{}
		for _, fl := range files {
			key := ""
			if fl.FileCategory.Valid {
				key = fl.FileCategory.String
			}
			grouped[key] = append(grouped[key], fl)
		}
		categories := make([]string, 0, len(grouped))
		for category := range grouped {
			categories = append(categories, category)
		}
		slices.Sort(categories)
	}}
	for _, category := range categories {
		<div class="mb-6">
			<h2 class="text-2xl font-semibold mb-4 flex items-center gap-2">
				<svg xmlns="http://www.w3.org/2000/svg" class="h-6 w-6" fill="none" viewBox="0 0 24 24" stroke="currentColor">
					<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M3 7v10a2 2 0 002 2h14a2 2 0 002-2V9a2 2 0 00-2-2h-6l-2-2H5a2 2 0 00-2 2z"></path>
				</svg>
				if category == "" {
					<span>Uncategorized</span>
				} else {
					<span>{ category }</span>
				}
				<span class="badge badge-ghost">{ strconv.Itoa(len(grouped[category])) }</span>
			</h2>
			<div class="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-4">
				for _, file := range grouped[category] {
					@FileCard(file)
				}
			</div>
		</div>
	}
}
// FileCard renders the card for one file: its name (also set as the title
// attribute so a truncated name can be read on hover) and a dropdown menu
// whose Delete entry issues an htmx DELETE to /notes/<id> after confirmation
// and swaps the response into #files-container.
templ FileCard(file filesdb.File) {
	{{
		deleteURL := "/notes/" + strconv.Itoa(int(file.ID))
	}}
	<div class="card bg-base-100 shadow-lg border border-base-300 hover:shadow-xl transition-shadow">
		<div class="card-body p-4">
			<div class="flex items-start justify-between">
				<div class="flex items-start gap-3 flex-1 min-w-0">
					<div class="flex-1 min-w-0">
						<h3 class="font-semibold text-sm truncate" title={ file.FileName }>
							{ file.FileName }
						</h3>
					</div>
				</div>
				<div class="dropdown dropdown-end">
					<label tabindex="0" class="btn btn-ghost btn-xs btn-square">
						<svg xmlns="http://www.w3.org/2000/svg" class="h-5 w-5" viewBox="0 0 20 20" fill="currentColor">
							<path d="M10 6a2 2 0 110-4 2 2 0 010 4zM10 12a2 2 0 110-4 2 2 0 010 4zM10 18a2 2 0 110-4 2 2 0 010 4z"></path>
						</svg>
					</label>
					<ul tabindex="0" class="dropdown-content z-[1] menu p-2 shadow bg-base-100 rounded-box w-52">
						<li>
							<button
								hx-delete={ deleteURL }
								hx-confirm="Are you sure you want to delete this file?"
								hx-target="#files-container"
								hx-swap="innerHTML"
								class="text-error"
							>
								<svg xmlns="http://www.w3.org/2000/svg" class="h-4 w-4" viewBox="0 0 20 20" fill="currentColor">
									<path fill-rule="evenodd" d="M9 2a1 1 0 00-.894.553L7.382 4H4a1 1 0 000 2v10a2 2 0 002 2h8a2 2 0 002-2V6a1 1 0 100-2h-3.382l-.724-1.447A1 1 0 0011 2H9zM7 8a1 1 0 012 0v6a1 1 0 11-2 0V8zm5-1a1 1 0 00-1 1v6a1 1 0 102 0V8a1 1 0 00-1-1z" clip-rule="evenodd"></path>
								</svg>
								Delete
							</button>
						</li>
					</ul>
				</div>
			</div>
		</div>
	</div>
}
// UploadFileModal renders the #upload_file_modal dialog. Its form posts
// multipart data to /notes through htmx and swaps the response into
// #files-container; after a successful request the dialog is closed and the
// form reset. The form carries a required file input named upload_file and an
// optional file_category select. The trailing script defines updateFileName,
// which writes the selected file's name and size in MB into #file-size-info.
//
// NOTE(review): file_category is submitted with the form, but the
// HandleUploadFile handler added in this change does not appear to read it.
templ UploadFileModal() {
<dialog id="upload_file_modal" class="modal">
<div class="modal-box">
<h3 class="font-bold text-lg mb-4">Upload File</h3>
<form
hx-post="/notes"
hx-encoding="multipart/form-data"
hx-target="#files-container"
hx-swap="innerHTML"
hx-on::after-request="if(event.detail.successful) { upload_file_modal.close(); this.reset(); }"
>
<div class="form-control w-full mb-4">
<label class="label">
<span class="label-text">Select File</span>
</label>
<input
type="file"
name="upload_file"
class="file-input file-input-bordered w-full"
required
onchange="updateFileName(this)"
/>
<label class="label">
<span class="label-text-alt" id="file-size-info"></span>
</label>
</div>
<div class="form-control w-full mb-4">
<label class="label">
<span class="label-text">Category (Optional)</span>
</label>
<select name="file_category" class="select select-bordered w-full">
<option value="">Uncategorized</option>
<option value="Documents">Documents</option>
<option value="Images">Images</option>
<option value="Videos">Videos</option>
<option value="Audio">Audio</option>
<option value="Archives">Archives</option>
<option value="Other">Other</option>
</select>
</div>
<div class="modal-action">
<button type="button" class="btn" onclick="upload_file_modal.close(); this.closest('form').reset();">Cancel</button>
<button type="submit" class="btn btn-primary">
<svg xmlns="http://www.w3.org/2000/svg" class="h-5 w-5 mr-2" viewBox="0 0 20 20" fill="currentColor">
<path fill-rule="evenodd" d="M3 17a1 1 0 011-1h12a1 1 0 110 2H4a1 1 0 01-1-1zM6.293 6.707a1 1 0 010-1.414l3-3a1 1 0 011.414 0l3 3a1 1 0 01-1.414 1.414L11 5.414V13a1 1 0 11-2 0V5.414L7.707 6.707a1 1 0 01-1.414 0z" clip-rule="evenodd"></path>
</svg>
Upload
</button>
</div>
</form>
</div>
<form method="dialog" class="modal-backdrop">
<button>close</button>
</form>
</dialog>
<script>
function updateFileName(input) {
const fileInfo = document.getElementById('file-size-info');
if (input.files && input.files[0]) {
const file = input.files[0];
const sizeMB = (file.size / (1024 * 1024)).toFixed(2);
fileInfo.textContent = `${file.name} (${sizeMB} MB)`;
} else {
fileInfo.textContent = '';
}
}
</script>
}
+306
View File
@@ -0,0 +1,306 @@
// Code generated by templ - DO NOT EDIT.
// templ: version: v0.3.960
package templates
//lint:file-ignore SA4006 This context is only used if a nested component is present.
import "github.com/a-h/templ"
import templruntime "github.com/a-h/templ/runtime"
import "github.com/run-llama/study-llama/frontend/filesdb"
import "strconv"
import "slices"
// FilesPage is the main page component for managing files.
//
// This is templ output (see the "Code generated by templ" header): it writes
// the static page markup in three chunks and renders FilesList(files) and
// UploadFileModal() between them. Change the template and regenerate rather
// than editing this function.
func FilesPage(files []filesdb.File) templ.Component {
return templruntime.GeneratedTemplate(func(templ_7745c5c3_Input templruntime.GeneratedComponentInput) (templ_7745c5c3_Err error) {
templ_7745c5c3_W, ctx := templ_7745c5c3_Input.Writer, templ_7745c5c3_Input.Context
// Write nothing if the context already reports an error.
if templ_7745c5c3_CtxErr := ctx.Err(); templ_7745c5c3_CtxErr != nil {
return templ_7745c5c3_CtxErr
}
templ_7745c5c3_Buffer, templ_7745c5c3_IsBuffer := templruntime.GetBuffer(templ_7745c5c3_W)
if !templ_7745c5c3_IsBuffer {
// Release the buffer on return; its error is reported only when rendering
// itself succeeded.
defer func() {
templ_7745c5c3_BufErr := templruntime.ReleaseBuffer(templ_7745c5c3_Buffer)
if templ_7745c5c3_Err == nil {
templ_7745c5c3_Err = templ_7745c5c3_BufErr
}
}()
}
ctx = templ.InitializeContext(ctx)
templ_7745c5c3_Var1 := templ.GetChildren(ctx)
if templ_7745c5c3_Var1 == nil {
templ_7745c5c3_Var1 = templ.NopComponent
}
ctx = templ.ClearChildren(ctx)
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 1, "<head><meta charset=\"UTF-8\"><meta name=\"viewport\" content=\"width=device-width, initial-scale=1.0\"><title>Study Llama - Upload Your Notes</title><script src=\"https://cdn.jsdelivr.net/npm/htmx.org@2.0.7/dist/htmx.min.js\"></script><link href=\"https://cdn.jsdelivr.net/npm/daisyui@5\" rel=\"stylesheet\" type=\"text/css\"><script src=\"https://cdn.jsdelivr.net/npm/@tailwindcss/browser@4\"></script></head><div class=\"container mx-auto p-6 max-w-6xl\"><div class=\"flex justify-between items-center mb-6\"><h1 class=\"text-3xl font-bold\">Notes Management</h1><button class=\"btn btn-primary\" onclick=\"upload_file_modal.showModal()\"><svg xmlns=\"http://www.w3.org/2000/svg\" class=\"h-5 w-5 mr-2\" viewBox=\"0 0 20 20\" fill=\"currentColor\"><path fill-rule=\"evenodd\" d=\"M3 17a1 1 0 011-1h12a1 1 0 110 2H4a1 1 0 01-1-1zM6.293 6.707a1 1 0 010-1.414l3-3a1 1 0 011.414 0l3 3a1 1 0 01-1.414 1.414L11 5.414V13a1 1 0 11-2 0V5.414L7.707 6.707a1 1 0 01-1.414 0z\" clip-rule=\"evenodd\"></path></svg> Upload File</button></div><div id=\"status-message\"></div><div id=\"files-container\" class=\"space-y-6\">")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
templ_7745c5c3_Err = FilesList(files).Render(ctx, templ_7745c5c3_Buffer)
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 2, "</div>")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
templ_7745c5c3_Err = UploadFileModal().Render(ctx, templ_7745c5c3_Buffer)
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 3, "</div>")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
return nil
})
}
// FilesList displays files grouped by category.
//
// This is templ output (see the "Code generated by templ" header). It writes
// a fixed "No files yet" alert when files is empty and otherwise renders
// FilesByCategory(files). Change the template and regenerate rather than
// editing this function.
func FilesList(files []filesdb.File) templ.Component {
return templruntime.GeneratedTemplate(func(templ_7745c5c3_Input templruntime.GeneratedComponentInput) (templ_7745c5c3_Err error) {
templ_7745c5c3_W, ctx := templ_7745c5c3_Input.Writer, templ_7745c5c3_Input.Context
// Write nothing if the context already reports an error.
if templ_7745c5c3_CtxErr := ctx.Err(); templ_7745c5c3_CtxErr != nil {
return templ_7745c5c3_CtxErr
}
templ_7745c5c3_Buffer, templ_7745c5c3_IsBuffer := templruntime.GetBuffer(templ_7745c5c3_W)
if !templ_7745c5c3_IsBuffer {
// Release the buffer on return; its error is reported only when rendering
// itself succeeded.
defer func() {
templ_7745c5c3_BufErr := templruntime.ReleaseBuffer(templ_7745c5c3_Buffer)
if templ_7745c5c3_Err == nil {
templ_7745c5c3_Err = templ_7745c5c3_BufErr
}
}()
}
ctx = templ.InitializeContext(ctx)
templ_7745c5c3_Var2 := templ.GetChildren(ctx)
if templ_7745c5c3_Var2 == nil {
templ_7745c5c3_Var2 = templ.NopComponent
}
ctx = templ.ClearChildren(ctx)
if len(files) == 0 {
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 4, "<div class=\"alert alert-info\"><svg xmlns=\"http://www.w3.org/2000/svg\" fill=\"none\" viewBox=\"0 0 24 24\" class=\"stroke-current shrink-0 w-6 h-6\"><path stroke-linecap=\"round\" stroke-linejoin=\"round\" stroke-width=\"2\" d=\"M13 16h-1v-4h-1m1-4h.01M21 12a9 9 0 11-18 0 9 9 0 0118 0z\"></path></svg> <span>No files yet. Upload your first file to get started!</span></div>")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
} else {
templ_7745c5c3_Err = FilesByCategory(files).Render(ctx, templ_7745c5c3_Buffer)
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
}
return nil
})
}
// FilesByCategory groups and displays files by category.
//
// This is templ output for templates/notes.templ (the file named in the
// templ.Error values below). Change the template and regenerate rather than
// editing this function.
//
// NOTE(review): as written, files whose FileCategory is not Valid are never
// added to a group, numFiles is the number of categories rather than the
// number of files in the section being rendered, and sections follow Go's
// unspecified map iteration order.
func FilesByCategory(files []filesdb.File) templ.Component {
return templruntime.GeneratedTemplate(func(templ_7745c5c3_Input templruntime.GeneratedComponentInput) (templ_7745c5c3_Err error) {
templ_7745c5c3_W, ctx := templ_7745c5c3_Input.Writer, templ_7745c5c3_Input.Context
// Write nothing if the context already reports an error.
if templ_7745c5c3_CtxErr := ctx.Err(); templ_7745c5c3_CtxErr != nil {
return templ_7745c5c3_CtxErr
}
templ_7745c5c3_Buffer, templ_7745c5c3_IsBuffer := templruntime.GetBuffer(templ_7745c5c3_W)
if !templ_7745c5c3_IsBuffer {
// Release the buffer on return; its error is reported only when rendering
// itself succeeded.
defer func() {
templ_7745c5c3_BufErr := templruntime.ReleaseBuffer(templ_7745c5c3_Buffer)
if templ_7745c5c3_Err == nil {
templ_7745c5c3_Err = templ_7745c5c3_BufErr
}
}()
}
ctx = templ.InitializeContext(ctx)
templ_7745c5c3_Var3 := templ.GetChildren(ctx)
if templ_7745c5c3_Var3 == nil {
templ_7745c5c3_Var3 = templ.NopComponent
}
ctx = templ.ClearChildren(ctx)
// groupFilesByCategory first collects the distinct valid category names, then
// rescans files once per name to fill the map.
groupFilesByCategory := func(files []filesdb.File) map[string][]filesdb.File {
categories := []string{}
for _, fl := range files {
if fl.FileCategory.Valid && !slices.Contains(categories, fl.FileCategory.String) {
categories = append(categories, fl.FileCategory.String)
}
}
categoriesMap := map[string][]filesdb.File{}
for _, cat := range categories {
for _, fl := range files {
if fl.FileCategory.Valid && fl.FileCategory.String == cat {
ls, ok := categoriesMap[cat]
if !ok {
categoriesMap[cat] = []filesdb.File{fl}
} else {
ls = append(ls, fl)
categoriesMap[cat] = ls
}
}
}
}
return categoriesMap
}
// categoryFiles here is the whole map, so numFiles is its number of keys.
categoryFiles := groupFilesByCategory(files)
numFiles := strconv.Itoa(len(categoryFiles))
// The loop variable categoryFiles shadows the map above with one category's
// slice; the grouping is computed a second time for the range expression.
for category, categoryFiles := range groupFilesByCategory(files) {
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 5, "<div class=\"mb-6\"><h2 class=\"text-2xl font-semibold mb-4 flex items-center gap-2\"><svg xmlns=\"http://www.w3.org/2000/svg\" class=\"h-6 w-6\" fill=\"none\" viewBox=\"0 0 24 24\" stroke=\"currentColor\"><path stroke-linecap=\"round\" stroke-linejoin=\"round\" stroke-width=\"2\" d=\"M3 7v10a2 2 0 002 2h14a2 2 0 002-2V9a2 2 0 00-2-2h-6l-2-2H5a2 2 0 00-2 2z\"></path></svg> ")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
if category == "" {
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 6, "<span>Uncategorized</span> ")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
} else {
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 7, "<span>")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
var templ_7745c5c3_Var4 string
templ_7745c5c3_Var4, templ_7745c5c3_Err = templ.JoinStringErrs(category)
if templ_7745c5c3_Err != nil {
return templ.Error{Err: templ_7745c5c3_Err, FileName: `templates/notes.templ`, Line: 91, Col: 21}
}
// The category name is escaped before it is written into the markup.
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var4))
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 8, "</span> ")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
}
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 9, "<span class=\"badge badge-ghost\">")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
var templ_7745c5c3_Var5 string
templ_7745c5c3_Var5, templ_7745c5c3_Err = templ.JoinStringErrs(numFiles)
if templ_7745c5c3_Err != nil {
return templ.Error{Err: templ_7745c5c3_Err, FileName: `templates/notes.templ`, Line: 93, Col: 46}
}
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var5))
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 10, "</span></h2><div class=\"grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-4\">")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
for _, file := range categoryFiles {
templ_7745c5c3_Err = FileCard(file).Render(ctx, templ_7745c5c3_Buffer)
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
}
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 11, "</div></div>")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
}
return nil
})
}
// FileCard displays a single file.
//
// This is templ output for templates/notes.templ (the file named in the
// templ.Error values below). It writes the card markup with the escaped file
// name in both the title attribute and the heading, and the escaped
// "/notes/<id>" path in the delete button's hx-delete attribute. Change the
// template and regenerate rather than editing this function.
func FileCard(file filesdb.File) templ.Component {
return templruntime.GeneratedTemplate(func(templ_7745c5c3_Input templruntime.GeneratedComponentInput) (templ_7745c5c3_Err error) {
templ_7745c5c3_W, ctx := templ_7745c5c3_Input.Writer, templ_7745c5c3_Input.Context
// Write nothing if the context already reports an error.
if templ_7745c5c3_CtxErr := ctx.Err(); templ_7745c5c3_CtxErr != nil {
return templ_7745c5c3_CtxErr
}
templ_7745c5c3_Buffer, templ_7745c5c3_IsBuffer := templruntime.GetBuffer(templ_7745c5c3_W)
if !templ_7745c5c3_IsBuffer {
// Release the buffer on return; its error is reported only when rendering
// itself succeeded.
defer func() {
templ_7745c5c3_BufErr := templruntime.ReleaseBuffer(templ_7745c5c3_Buffer)
if templ_7745c5c3_Err == nil {
templ_7745c5c3_Err = templ_7745c5c3_BufErr
}
}()
}
ctx = templ.InitializeContext(ctx)
templ_7745c5c3_Var6 := templ.GetChildren(ctx)
if templ_7745c5c3_Var6 == nil {
templ_7745c5c3_Var6 = templ.NopComponent
}
ctx = templ.ClearChildren(ctx)
// Decimal form of the row id, used to build the delete URL below.
fileId := strconv.Itoa(int(file.ID))
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 12, "<div class=\"card bg-base-100 shadow-lg border border-base-300 hover:shadow-xl transition-shadow\"><div class=\"card-body p-4\"><div class=\"flex items-start justify-between\"><div class=\"flex items-start gap-3 flex-1 min-w-0\"><div class=\"flex-1 min-w-0\"><h3 class=\"font-semibold text-sm truncate\" title=\"")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
var templ_7745c5c3_Var7 string
templ_7745c5c3_Var7, templ_7745c5c3_Err = templ.JoinStringErrs(file.FileName)
if templ_7745c5c3_Err != nil {
return templ.Error{Err: templ_7745c5c3_Err, FileName: `templates/notes.templ`, Line: 114, Col: 70}
}
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var7))
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 13, "\">")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
var templ_7745c5c3_Var8 string
templ_7745c5c3_Var8, templ_7745c5c3_Err = templ.JoinStringErrs(file.FileName)
if templ_7745c5c3_Err != nil {
return templ.Error{Err: templ_7745c5c3_Err, FileName: `templates/notes.templ`, Line: 115, Col: 22}
}
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var8))
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 14, "</h3></div></div><div class=\"dropdown dropdown-end\"><label tabindex=\"0\" class=\"btn btn-ghost btn-xs btn-square\"><svg xmlns=\"http://www.w3.org/2000/svg\" class=\"h-5 w-5\" viewBox=\"0 0 20 20\" fill=\"currentColor\"><path d=\"M10 6a2 2 0 110-4 2 2 0 010 4zM10 12a2 2 0 110-4 2 2 0 010 4zM10 18a2 2 0 110-4 2 2 0 010 4z\"></path></svg></label><ul tabindex=\"0\" class=\"dropdown-content z-[1] menu p-2 shadow bg-base-100 rounded-box w-52\"><li><button hx-delete=\"")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
var templ_7745c5c3_Var9 string
templ_7745c5c3_Var9, templ_7745c5c3_Err = templ.JoinStringErrs("/notes/" + fileId)
if templ_7745c5c3_Err != nil {
return templ.Error{Err: templ_7745c5c3_Err, FileName: `templates/notes.templ`, Line: 128, Col: 38}
}
_, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var9))
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 15, "\" hx-confirm=\"Are you sure you want to delete this file?\" hx-target=\"#files-container\" hx-swap=\"innerHTML\" class=\"text-error\"><svg xmlns=\"http://www.w3.org/2000/svg\" class=\"h-4 w-4\" viewBox=\"0 0 20 20\" fill=\"currentColor\"><path fill-rule=\"evenodd\" d=\"M9 2a1 1 0 00-.894.553L7.382 4H4a1 1 0 000 2v10a2 2 0 002 2h8a2 2 0 002-2V6a1 1 0 100-2h-3.382l-.724-1.447A1 1 0 0011 2H9zM7 8a1 1 0 012 0v6a1 1 0 11-2 0V8zm5-1a1 1 0 00-1 1v6a1 1 0 102 0V8a1 1 0 00-1-1z\" clip-rule=\"evenodd\"></path></svg> Delete</button></li></ul></div></div></div></div>")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
return nil
})
}
// UploadFileModal is the modal for uploading new files.
//
// This is templ output (see the "Code generated by templ" header): the whole
// dialog, its form and the updateFileName script are written as one static
// string. Change the template and regenerate rather than editing this
// function.
func UploadFileModal() templ.Component {
return templruntime.GeneratedTemplate(func(templ_7745c5c3_Input templruntime.GeneratedComponentInput) (templ_7745c5c3_Err error) {
templ_7745c5c3_W, ctx := templ_7745c5c3_Input.Writer, templ_7745c5c3_Input.Context
// Write nothing if the context already reports an error.
if templ_7745c5c3_CtxErr := ctx.Err(); templ_7745c5c3_CtxErr != nil {
return templ_7745c5c3_CtxErr
}
templ_7745c5c3_Buffer, templ_7745c5c3_IsBuffer := templruntime.GetBuffer(templ_7745c5c3_W)
if !templ_7745c5c3_IsBuffer {
// Release the buffer on return; its error is reported only when rendering
// itself succeeded.
defer func() {
templ_7745c5c3_BufErr := templruntime.ReleaseBuffer(templ_7745c5c3_Buffer)
if templ_7745c5c3_Err == nil {
templ_7745c5c3_Err = templ_7745c5c3_BufErr
}
}()
}
ctx = templ.InitializeContext(ctx)
templ_7745c5c3_Var10 := templ.GetChildren(ctx)
if templ_7745c5c3_Var10 == nil {
templ_7745c5c3_Var10 = templ.NopComponent
}
ctx = templ.ClearChildren(ctx)
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 16, "<dialog id=\"upload_file_modal\" class=\"modal\"><div class=\"modal-box\"><h3 class=\"font-bold text-lg mb-4\">Upload File</h3><form hx-post=\"/notes\" hx-encoding=\"multipart/form-data\" hx-target=\"#files-container\" hx-swap=\"innerHTML\" hx-on::after-request=\"if(event.detail.successful) { upload_file_modal.close(); this.reset(); }\"><div class=\"form-control w-full mb-4\"><label class=\"label\"><span class=\"label-text\">Select File</span></label> <input type=\"file\" name=\"upload_file\" class=\"file-input file-input-bordered w-full\" required onchange=\"updateFileName(this)\"> <label class=\"label\"><span class=\"label-text-alt\" id=\"file-size-info\"></span></label></div><div class=\"form-control w-full mb-4\"><label class=\"label\"><span class=\"label-text\">Category (Optional)</span></label> <select name=\"file_category\" class=\"select select-bordered w-full\"><option value=\"\">Uncategorized</option> <option value=\"Documents\">Documents</option> <option value=\"Images\">Images</option> <option value=\"Videos\">Videos</option> <option value=\"Audio\">Audio</option> <option value=\"Archives\">Archives</option> <option value=\"Other\">Other</option></select></div><div class=\"modal-action\"><button type=\"button\" class=\"btn\" onclick=\"upload_file_modal.close(); this.closest('form').reset();\">Cancel</button> <button type=\"submit\" class=\"btn btn-primary\"><svg xmlns=\"http://www.w3.org/2000/svg\" class=\"h-5 w-5 mr-2\" viewBox=\"0 0 20 20\" fill=\"currentColor\"><path fill-rule=\"evenodd\" d=\"M3 17a1 1 0 011-1h12a1 1 0 110 2H4a1 1 0 01-1-1zM6.293 6.707a1 1 0 010-1.414l3-3a1 1 0 011.414 0l3 3a1 1 0 01-1.414 1.414L11 5.414V13a1 1 0 11-2 0V5.414L7.707 6.707a1 1 0 01-1.414 0z\" clip-rule=\"evenodd\"></path></svg> Upload</button></div></form></div><form method=\"dialog\" class=\"modal-backdrop\"><button>close</button></form></dialog><script>\n\t\tfunction updateFileName(input) 
{\n\t\t\tconst fileInfo = document.getElementById('file-size-info');\n\t\t\tif (input.files && input.files[0]) {\n\t\t\t\tconst file = input.files[0];\n\t\t\t\tconst sizeMB = (file.size / (1024 * 1024)).toFixed(2);\n\t\t\t\tfileInfo.textContent = `${file.name} (${sizeMB} MB)`;\n\t\t\t} else {\n\t\t\t\tfileInfo.textContent = '';\n\t\t\t}\n\t\t}\n\t</script>")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
return nil
})
}
var _ = templruntime.GeneratedTemplate
+30
View File
@@ -0,0 +1,30 @@
[build-system]
requires = ["uv_build>=0.9.10,<0.10.0"]
build-backend = "uv_build"
[project]
name = "study-llama-backend"
version = "0.1.0"
description = "Backend workflows to orchestrate file classification and ingestion into a vector database"
readme = "README.md"
requires-python = ">=3.13"
dependencies = [
"llama-cloud-services>=0.6.81",
"llama-index-workflows>=2.11.4",
"openai>=2.8.1",
"psycopg2-binary>=2.9.11",
"qdrant-client>=1.16.0",
"sqlalchemy>=2.0.44",
]
[tool.uv.build-backend]
module-name = "study_llama"
[tool.llamadeploy]
name = "study-llama"
env_files = [".env"]
llama_cloud = true
[tool.llamadeploy.workflows]
classify-and-extract = "study_llama.classify_and_extract.workflow:workflow"
search = "study_llama.search.workflow:workflow"
+7
View File
@@ -0,0 +1,7 @@
-- name: CreateFile :one
-- Inserts one row (username, file_name, file_category) into files and returns the inserted row.
INSERT INTO files (
    username, file_name, file_category
) VALUES (
    $1, $2, $3
)
RETURNING *;
+3
View File
@@ -0,0 +1,3 @@
-- name: GetRules :many
-- Returns every row of rules whose username equals the given parameter.
SELECT * FROM rules
WHERE username = $1;
+7
View File
@@ -0,0 +1,7 @@
-- Files table
-- One row per file recorded for a user; file_category is optional and defaults to NULL.
CREATE TABLE files (
    id SERIAL PRIMARY KEY,          -- auto-incrementing row identifier
    username TEXT NOT NULL,         -- user associated with the file
    file_name TEXT NOT NULL,        -- name of the file
    file_category TEXT DEFAULT NULL -- category of the file, NULL when not set
);
+8
View File
@@ -0,0 +1,8 @@
-- Rules table
-- One row per rule defined for a user; every column is mandatory.
CREATE TABLE rules (
    id SERIAL PRIMARY KEY,          -- auto-incrementing row identifier
    username TEXT NOT NULL,         -- user associated with the rule
    rule_name TEXT NOT NULL,        -- name of the rule
    rule_type TEXT NOT NULL,        -- type label of the rule
    rule_description TEXT NOT NULL  -- free-text description of the rule
);
+28
View File
@@ -0,0 +1,28 @@
import os
from qdrant_client import AsyncQdrantClient
from qdrant_client.models import VectorParams, Distance
async def create_collections():
    """Create the ``summaries`` and ``faqs`` Qdrant collections when missing.

    The client connects over HTTPS on port 443 using ``QDRANT_HOST`` and
    ``QDRANT_API_KEY`` from the environment. Each missing collection is
    created with 768-dimensional vectors and cosine distance, and the
    outcome of every creation attempt is printed.
    """
    qdrant = AsyncQdrantClient(
        host=os.getenv("QDRANT_HOST"),
        port=443,
        https=True,
        api_key=os.getenv("QDRANT_API_KEY"),
        check_compatibility=False,
    )
    for name in ("summaries", "faqs"):
        # Existing collections are left untouched.
        if await qdrant.collection_exists(name):
            continue
        created = await qdrant.create_collection(
            collection_name=name,
            vectors_config=VectorParams(size=768, distance=Distance.COSINE),
        )
        if created:
            print(f"Successfully created {name}")
        else:
            print(f"Something went wrong while creating {name}")


if __name__ == "__main__":
    import asyncio

    asyncio.run(create_collections())
+29
View File
@@ -0,0 +1,29 @@
version: "2"
plugins:
- name: py
wasm:
url: https://downloads.sqlc.dev/plugin/sqlc-gen-python_1.3.0.wasm
sha256: fbedae96b5ecae2380a70fb5b925fd4bff58a6cfb1f3140375d098fbab7b3a3c
sql:
- schema: "schema.files.sql"
queries: "query.files.sql"
engine: "postgresql"
codegen:
- out: src/study_llama/filesdb
plugin: py
options:
package: study_llama.filesdb
emit_sync_querier: false
emit_async_querier: true
emit_pydantic_models: true
- schema: "schema.rules.sql"
queries: "query.rules.sql"
engine: "postgresql"
codegen:
- out: src/study_llama/rulesdb
plugin: py
options:
package: study_llama.rulesdb
emit_sync_querier: false
emit_async_querier: true
emit_pydantic_models: true
@@ -0,0 +1,21 @@
from workflows.events import StartEvent, StopEvent, Event
from pydantic import ConfigDict
from .models import QuestionAndAnswer
# Start event of the classify-and-extract workflow: carries the id and name
# of the file to process together with a username.
class InputFileEvent(StartEvent):
    file_id: str
    file_name: str
    username: str
# Intermediate event carrying the type assigned to the file.
class ClassifiedFileEvent(Event):
    file_type: str
# Intermediate event carrying the extracted summary and the list of
# question/answer pairs.
class ExtractedFileEvent(Event):
    summary: str
    faqs: list[QuestionAndAnswer]
    # Allows field types that pydantic cannot validate natively.
    model_config = ConfigDict(arbitrary_types_allowed=True)
# Stop event: reports whether processing succeeded and, optionally, an
# error message (None by default).
class IngestedFileEvent(StopEvent):
    success: bool
    error: str | None = None
@@ -0,0 +1,23 @@
from pydantic import BaseModel, Field, ConfigDict
from llama_cloud_services.extract import ExtractConfig
from llama_cloud.types.extract_mode import ExtractMode
# A single question/answer pair. Documented with `#` comments rather than a
# docstring so the generated pydantic schema is left unchanged.
class QuestionAndAnswer(BaseModel):
    question: str = Field(description="Question related to the main document")
    answer: str = Field(description="Answer to the question")
# Schema of the study notes extracted from a document: a summary plus a list
# of question/answer pairs. Documented with `#` comments rather than a
# docstring so the generated pydantic schema is left unchanged.
class StudyNotes(BaseModel):
    summary: str = Field(description="Summary of the study notes in the document")
    faqs: list[QuestionAndAnswer] = Field(description="List of potential 'Frequently Asked Questions' (with associated answer) to help a student review and prepare with the study notes")
    # Allows field types that pydantic cannot validate natively.
    model_config = ConfigDict(arbitrary_types_allowed=True)
# State shared between workflow steps; every field defaults to an empty
# string until a step fills it in.
class WorkflowState(BaseModel):
    username: str = ""
    file_name: str = ""
    file_type: str = ""
    file_id: str = ""
# Module-level extraction configuration; selects the MULTIMODAL extraction mode.
EXTRACT_CONFIG = ExtractConfig(
    extraction_mode=ExtractMode.MULTIMODAL,
)
@@ -0,0 +1,38 @@
import os
from qdrant_client import AsyncQdrantClient
from llama_cloud_services.beta.classifier import LlamaClassify
from llama_cloud_services.extract import LlamaExtract
from sqlalchemy.ext.asyncio import create_async_engine, AsyncConnection
from study_llama.vectordb.vectordb import SummaryVectorDB, FaqsVectorDB
async def get_llama_classify(*args, **kwargs):
    """Resource factory returning a ``LlamaClassify`` client.

    The API key is read from ``LLAMA_CLOUD_API_KEY``; an empty string is
    used when the variable is unset. All arguments are accepted and ignored.
    """
    llama_api_key = os.getenv("LLAMA_CLOUD_API_KEY", "")
    return LlamaClassify.from_api_key(api_key=llama_api_key)
async def get_llama_extract(*args, **kwargs):
    """Resource factory returning a ``LlamaExtract`` client.

    The API key is read from ``LLAMA_CLOUD_API_KEY``; an empty string is
    used when the variable is unset. All arguments are accepted and ignored.
    """
    llama_api_key = os.getenv("LLAMA_CLOUD_API_KEY", "")
    return LlamaExtract(api_key=llama_api_key)
async def get_db_conn(*args, **kwargs):
    """Resource factory returning a started async SQLAlchemy connection.

    The engine URL is read from ``POSTGRES_CONNECTION_STRING`` (empty string
    when unset). All arguments are accepted and ignored.

    Returns:
        An ``AsyncConnection`` that has already been started and can execute
        statements immediately.
    """
    # NOTE(review): create_async_engine needs a URL naming an async driver --
    # confirm the configured connection string does.
    engine = create_async_engine(url=os.getenv("POSTGRES_CONNECTION_STRING", ""))
    # Awaiting connect() starts the connection. Instantiating
    # AsyncConnection(...) directly yields an un-started object on which
    # execute()/stream() fail.
    return await engine.connect()
async def get_vector_db_summaries(*args, **kwargs):
    """Resource factory returning a ``SummaryVectorDB`` bound to ``summaries``.

    The Qdrant client connects over HTTPS on port 443 using ``QDRANT_HOST``
    and ``QDRANT_API_KEY``. All arguments are accepted and ignored.
    """
    connection_settings = {
        "api_key": os.getenv("QDRANT_API_KEY"),
        "https": True,
        "port": 443,
        "host": os.getenv("QDRANT_HOST"),
        "check_compatibility": False,
    }
    qdrant = AsyncQdrantClient(**connection_settings)
    return SummaryVectorDB(client=qdrant, collection_name="summaries")
async def get_vector_db_faqs(*args, **kwargs):
    """Resource factory returning a ``FaqsVectorDB`` bound to ``faqs``.

    The Qdrant client connects over HTTPS on port 443 using ``QDRANT_HOST``
    and ``QDRANT_API_KEY``. All arguments are accepted and ignored.
    """
    connection_settings = {
        "api_key": os.getenv("QDRANT_API_KEY"),
        "https": True,
        "port": 443,
        "host": os.getenv("QDRANT_HOST"),
        "check_compatibility": False,
    }
    qdrant = AsyncQdrantClient(**connection_settings)
    return FaqsVectorDB(client=qdrant, collection_name="faqs")
@@ -0,0 +1,12 @@
import re
from llama_cloud.types.classifier_rule import ClassifierRule
from study_llama.rulesdb.models import Rule
def rules_to_classify_rules(rules: list[Rule]) -> list[ClassifierRule]:
    """Convert stored rules into classifier rules.

    Each rule's ``rule_type`` is lower-cased, stripped of surrounding
    whitespace, and has every remaining whitespace run replaced by a single
    underscore; ``rule_description`` is passed through unchanged.

    Args:
        rules: Rules to convert.

    Returns:
        One ``ClassifierRule`` per input rule, in the same order.
    """
    def _as_type_label(raw_type: str) -> str:
        return re.sub(r"\s+", "_", raw_type.lower().strip())

    return [
        ClassifierRule(type=_as_type_label(r.rule_type), description=r.rule_description)
        for r in rules
    ]
@@ -0,0 +1,85 @@
import os
from workflows import Workflow, Context, step
from workflows.resource import Resource
from llama_cloud.types.file import File
from llama_cloud.types.extract_run import ExtractRun
from typing import Annotated, TYPE_CHECKING, cast
from study_llama.filesdb.query_files import AsyncQuerier as AsyncFilesQuerier
from study_llama.rulesdb.query_rules import AsyncQuerier as AsyncRulesQuerier
from .events import InputFileEvent, IngestedFileEvent, ClassifiedFileEvent, ExtractedFileEvent
from .resources import get_db_conn, get_llama_classify, get_llama_extract, get_vector_db_faqs, get_vector_db_summaries
from .models import StudyNotes, WorkflowState, EXTRACT_CONFIG
from .utils import rules_to_classify_rules
if TYPE_CHECKING:
from llama_cloud_services.beta.classifier.client import LlamaClassify
from llama_cloud_services.extract import LlamaExtract
from sqlalchemy.ext.asyncio import AsyncConnection
from study_llama.vectordb.vectordb import SummaryVectorDB, FaqsVectorDB
class ClassifyExtractWorkflow(Workflow):
    """Classify an uploaded file, extract study notes from it and ingest them.

    Steps run in this order: ``classify_file`` -> ``extract_file_details`` ->
    ``ingest_file_details``. The first two steps end the run early with a
    failed ``IngestedFileEvent`` when they cannot produce a result.
    """

    @step
    async def classify_file(self, ev: InputFileEvent, ctx: Context[WorkflowState], classifier: Annotated[LlamaClassify, Resource(get_llama_classify)], db_conn: Annotated[AsyncConnection, Resource(get_db_conn)]) -> ClassifiedFileEvent | IngestedFileEvent:
        """Classify the file against the user's rules and record it in the database.

        Loads the rules stored for ``ev.username``, classifies ``ev.file_id``
        with them and keeps the first non-null type found in the result
        items. On success the workflow state is filled in, a row is inserted
        into the files table and the transaction is committed.

        Returns:
            ``ClassifiedFileEvent`` with the detected type, or a failed
            ``IngestedFileEvent`` when no type could be determined.
        """
        querier = AsyncRulesQuerier(conn=db_conn)
        rules = [rule async for rule in querier.get_rules(username=ev.username)]
        class_rules = rules_to_classify_rules(rules=rules)
        result = await classifier.aclassify_file_ids(rules=class_rules, file_ids=[ev.file_id])
        file_type: str | None = None
        for item in result.items:
            class_res = item.result
            if class_res is not None and class_res.type is not None:
                file_type = class_res.type
                break
        if file_type is None:
            return IngestedFileEvent(success=False, error="It was not possible to classify the provided file based on the existing categories")
        async with ctx.store.edit_state() as state:
            state.file_name = ev.file_name
            state.username = ev.username
            state.file_id = ev.file_id
            state.file_type = file_type
        querier_files = AsyncFilesQuerier(conn=db_conn)
        await querier_files.create_file(username=ev.username, file_name=ev.file_name, file_category=file_type)
        # The connection does not autocommit: without an explicit commit the
        # inserted row is discarded when the connection is released.
        await db_conn.commit()
        return ClassifiedFileEvent(file_type=file_type)

    @step
    async def extract_file_details(self, ev: ClassifiedFileEvent, extractor: Annotated[LlamaExtract, Resource(get_llama_extract)], ctx: Context[WorkflowState]) -> ExtractedFileEvent | IngestedFileEvent:
        """Extract a summary and FAQs from the file referenced in the workflow state.

        Returns:
            ``ExtractedFileEvent`` with the validated extraction data, or a
            failed ``IngestedFileEvent`` when the extraction run has no data.
        """
        state = await ctx.store.get_state()
        # Only id, name and project id are known here; every other File
        # field is passed explicitly as None.
        file_ref = File(
            id=state.file_id,
            name=state.file_name,
            project_id=os.getenv("LLAMA_CLOUD_PROJECT_ID"),
            data_source_id=None,
            created_at=None,
            external_file_id=None,
            file_size=None,
            file_type=None,
            last_modified_at=None,
            permission_info=None,
            resource_info=None,
            updated_at=None,
        )
        result = await extractor.aextract(
            data_schema=StudyNotes,
            config=EXTRACT_CONFIG,
            files=file_ref,
        )
        data = cast(ExtractRun, result).data
        if data is None:
            return IngestedFileEvent(success=False, error="It was not possible to extract details from the provided file")
        extraction_result = StudyNotes.model_validate(data)
        return ExtractedFileEvent(summary=extraction_result.summary, faqs=extraction_result.faqs)

    @step
    async def ingest_file_details(self, ev: ExtractedFileEvent, summaries_vdb: Annotated[SummaryVectorDB, Resource(get_vector_db_summaries)], faqs_vdb: Annotated[FaqsVectorDB, Resource(get_vector_db_faqs)], ctx: Context[WorkflowState]) -> IngestedFileEvent:
        """Upload the extracted FAQs and summary to their vector stores.

        Returns:
            A successful ``IngestedFileEvent`` once both uploads have returned.
        """
        state = await ctx.store.get_state()
        questions = [faq.question for faq in ev.faqs]
        answers = [faq.answer for faq in ev.faqs]
        await faqs_vdb.upload(questions, answers, state.username, state.file_type, state.file_name)
        await summaries_vdb.upload(ev.summary, state.username, state.file_type, state.file_name)
        return IngestedFileEvent(success=True)


workflow = ClassifyExtractWorkflow(timeout=1000)
View File
+12
View File
@@ -0,0 +1,12 @@
# Code generated by sqlc. DO NOT EDIT.
# versions:
# sqlc v1.30.0
import pydantic
from typing import Optional
# Row model for the files table; file_category may be None.
class File(pydantic.BaseModel):
    id: int
    username: str
    file_name: str
    file_category: Optional[str]
+36
View File
@@ -0,0 +1,36 @@
# Code generated by sqlc. DO NOT EDIT.
# versions:
# sqlc v1.30.0
# source: query.files.sql
from typing import Optional
import sqlalchemy
import sqlalchemy.ext.asyncio
from study_llama.filesdb import models
# SQL text executed by AsyncQuerier.create_file; :p1..:p3 are bound parameters.
CREATE_FILE = """-- name: create_file \\:one
INSERT INTO files (
username, file_name, file_category
) VALUES (
:p1, :p2, :p3
)
RETURNING id, username, file_name, file_category
"""


class AsyncQuerier:
    """Async query wrapper around a SQLAlchemy ``AsyncConnection``."""

    def __init__(self, conn: sqlalchemy.ext.asyncio.AsyncConnection):
        self._conn = conn

    async def create_file(self, *, username: str, file_name: str, file_category: Optional[str]) -> Optional[models.File]:
        """Insert a file row and return it as a ``models.File``.

        Returns ``None`` when the statement yields no row. This method does
        not commit; committing is left to the caller.
        """
        row = (await self._conn.execute(sqlalchemy.text(CREATE_FILE), {"p1": username, "p2": file_name, "p3": file_category})).first()
        if row is None:
            return None
        return models.File(
            id=row[0],
            username=row[1],
            file_name=row[2],
            file_category=row[3],
        )
View File
+12
View File
@@ -0,0 +1,12 @@
# Code generated by sqlc. DO NOT EDIT.
# versions:
# sqlc v1.30.0
import pydantic
# Row model for the rules table.
class Rule(pydantic.BaseModel):
    id: int
    username: str
    rule_name: str
    rule_type: str
    rule_description: str
+32
View File
@@ -0,0 +1,32 @@
# Code generated by sqlc. DO NOT EDIT.
# versions:
# sqlc v1.30.0
# source: query.rules.sql
from typing import AsyncIterator
import sqlalchemy
import sqlalchemy.ext.asyncio
from study_llama.rulesdb import models
# SQL text executed by AsyncQuerier.get_rules; :p1 is the bound username.
GET_RULES = """-- name: get_rules \\:many
SELECT id, username, rule_name, rule_type, rule_description FROM rules
WHERE username = :p1
"""


class AsyncQuerier:
    """Async query wrapper around a SQLAlchemy ``AsyncConnection``."""

    def __init__(self, conn: sqlalchemy.ext.asyncio.AsyncConnection):
        self._conn = conn

    async def get_rules(self, *, username: str) -> AsyncIterator[models.Rule]:
        """Stream the rules stored for ``username``.

        This is an async generator: rows are streamed from the connection
        and yielded one by one as ``models.Rule`` instances.
        """
        result = await self._conn.stream(sqlalchemy.text(GET_RULES), {"p1": username})
        async for row in result:
            yield models.Rule(
                id=row[0],
                username=row[1],
                rule_name=row[2],
                rule_type=row[3],
                rule_description=row[4],
            )
View File
+16
View File
@@ -0,0 +1,16 @@
from workflows.events import StartEvent, StopEvent
from typing import Literal
from pydantic import ConfigDict
from study_llama.vectordb.vectordb import Result
# Start event of the search workflow: which store to search ("summary" or
# "faqs"), the query text, the username, and optional file-name / category
# filters (both None by default).
class SearchInputEvent(StartEvent):
    search_type: Literal["summary", "faqs"]
    search_input: str
    username: str
    file_name: str | None = None
    category: str | None = None
# Stop event of the search workflow carrying the list of results.
class SearchOutputEvent(StopEvent):
    results: list[Result]
    # Allows field types that pydantic cannot validate natively.
    model_config = ConfigDict(arbitrary_types_allowed=True)
+23
View File
@@ -0,0 +1,23 @@
import os
from qdrant_client import AsyncQdrantClient
from study_llama.vectordb.vectordb import SummaryVectorDB, FaqsVectorDB
async def get_vector_db_summaries(*args, **kwargs):
    """Resource factory returning a ``SummaryVectorDB`` bound to ``summaries``.

    The Qdrant client connects over HTTPS on port 443 using ``QDRANT_HOST``
    and ``QDRANT_API_KEY``. All arguments are accepted and ignored.
    """
    connection_settings = {
        "api_key": os.getenv("QDRANT_API_KEY"),
        "https": True,
        "port": 443,
        "host": os.getenv("QDRANT_HOST"),
        "check_compatibility": False,
    }
    qdrant = AsyncQdrantClient(**connection_settings)
    return SummaryVectorDB(client=qdrant, collection_name="summaries")
async def get_vector_db_faqs(*args, **kwargs):
    """Resource factory returning a ``FaqsVectorDB`` bound to ``faqs``.

    The Qdrant client connects over HTTPS on port 443 using ``QDRANT_HOST``
    and ``QDRANT_API_KEY``. All arguments are accepted and ignored.
    """
    connection_settings = {
        "api_key": os.getenv("QDRANT_API_KEY"),
        "https": True,
        "port": 443,
        "host": os.getenv("QDRANT_HOST"),
        "check_compatibility": False,
    }
    qdrant = AsyncQdrantClient(**connection_settings)
    return FaqsVectorDB(client=qdrant, collection_name="faqs")
+19
View File
@@ -0,0 +1,19 @@
from workflows import Workflow, Context, step
from workflows.resource import Resource
from typing import Annotated, TYPE_CHECKING
from .resources import get_vector_db_faqs, get_vector_db_summaries
from .events import SearchInputEvent, SearchOutputEvent
if TYPE_CHECKING:
from study_llama.vectordb.vectordb import SummaryVectorDB, FaqsVectorDB
class SearchWorkflow(Workflow):
    """Single-step workflow that searches one of the two vector stores."""

    @step
    async def search(self, ev: SearchInputEvent, summaries_vdb: Annotated[SummaryVectorDB, Resource(get_vector_db_summaries)], faqs_vdb: Annotated[FaqsVectorDB, Resource(get_vector_db_faqs)]) -> SearchOutputEvent:
        """Search the FAQ store when ``ev.search_type`` is ``"faqs"``, else the summary store.

        The query text, username, category and file name from the event are
        forwarded to the selected store's ``search`` method.
        """
        target_store = faqs_vdb if ev.search_type == "faqs" else summaries_vdb
        matches = await target_store.search(ev.search_input, ev.username, ev.category, ev.file_name)
        return SearchOutputEvent(results=matches)


workflow = SearchWorkflow(timeout=600)
+17
View File
@@ -0,0 +1,17 @@
import os
from openai import AsyncOpenAI
class OpenAIEmbedder:
    """Thin wrapper producing 768-dimensional embeddings with an ``AsyncOpenAI`` client."""

    def __init__(self, client: AsyncOpenAI) -> None:
        self._client = client

    async def embed(self, texts: list[str]) -> list[list[float]]:
        """Embed ``texts`` with the ``text-embedding-3-small`` model.

        Args:
            texts: Strings to embed.

        Returns:
            One embedding per item of the response data; an empty list when
            ``texts`` is empty.
        """
        # The embeddings endpoint rejects an empty input list, so skip the
        # request when there is nothing to embed.
        if not texts:
            return []
        response = await self._client.embeddings.create(
            input=texts,
            model="text-embedding-3-small",
            dimensions=768,
        )
        return [d.embedding for d in response.data]


# Module-level singletons; the API key is read from OPENAI_API_KEY at import time.
openai_client = AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY"))
embedder = OpenAIEmbedder(client=openai_client)
+97
View File
@@ -0,0 +1,97 @@
import uuid
from pydantic import BaseModel
from qdrant_client import AsyncQdrantClient
from qdrant_client.models import PointStruct, Filter, FieldCondition, MatchValue
from typing import cast, Literal
from .embeddings import embedder
# A single search hit: the kind of text returned ("answer" or "summary"),
# the text itself, the file name and category it is associated with, and
# the similarity score.
class Result(BaseModel):
    result_type: Literal["answer", "summary"]
    text: str
    file_name: str
    category: str
    similarity: float
class SummaryVectorDB:
    """Stores and searches document summaries in a Qdrant collection."""

    def __init__(self, client: AsyncQdrantClient, collection_name: str):
        self._client = client
        self.collection_name = collection_name

    async def upload(self, summary: str, username: str, category: str, file_name: str) -> None:
        """Embed ``summary`` and upload it as a single point with a random UUID id.

        The payload stores the category, file name, summary text and username.
        """
        embeddings = await embedder.embed([summary])
        payload = {"category": category, "file_name": file_name, "summary": summary, "username": username}
        summary_point = PointStruct(id=uuid.uuid4(), vector=embeddings[0], payload=payload)
        # NOTE(review): upload_points is invoked without await -- confirm it
        # is a synchronous method on AsyncQdrantClient.
        self._client.upload_points(self.collection_name, points=[summary_point])
        return None

    async def search(self, text: str, username: str, category: str | None = None, file_name: str | None = None) -> list[Result]:
        """Return summaries similar to ``text`` for ``username``.

        Results are always filtered by username and, when given, by category
        and file name. Only points scoring at least 0.75 are requested, and
        points without a payload are dropped.
        """
        conditions = [FieldCondition(key="username", match=MatchValue(value=username))]
        if category is not None:
            conditions.append(FieldCondition(key="category", match=MatchValue(value=category)))
        if file_name is not None:
            conditions.append(FieldCondition(key="file_name", match=MatchValue(value=file_name)))
        filters = Filter(must=conditions)
        query_vectors = await embedder.embed([text])
        response = await self._client.query_points(self.collection_name, query=query_vectors[0], query_filter=filters, score_threshold=0.75)
        hits: list[Result] = []
        for point in response.points:
            if point.payload is None:
                continue
            hits.append(Result(text=point.payload["summary"], similarity=point.score, file_name=point.payload["file_name"], category=point.payload["category"], result_type="summary"))
        return hits
class FaqsVectorDB:
    """Stores and searches question/answer pairs in a Qdrant collection."""

    def __init__(self, client: AsyncQdrantClient, collection_name: str):
        self._client = client
        self.collection_name = collection_name

    async def upload(self, questions: list[str], answers: list[str], username: str, category: str, file_name: str) -> None:
        """Embed each question and upload one point per question/answer pair.

        Args:
            questions: Question texts; these are what gets embedded.
            answers: Answers, index-aligned with ``questions``.
            username: Stored in every point payload.
            category: Stored in every point payload.
            file_name: Stored in every point payload.

        Nothing is embedded or uploaded when ``questions`` is empty.
        """
        # An extraction may yield no FAQs; embedding an empty batch is
        # rejected by the embeddings endpoint, so there is nothing to do.
        if not questions:
            return None
        vecs = await embedder.embed(questions)
        points = []
        for i, vec in enumerate(vecs):
            points.append(PointStruct(
                id=uuid.uuid4(),
                vector=vec,
                payload={"category": category, "file_name": file_name, "question": questions[i], "answer": answers[i], "username": username},
            ))
        # NOTE(review): upload_points is invoked without await -- confirm it
        # is a synchronous method on AsyncQdrantClient.
        self._client.upload_points(self.collection_name, points=points)
        return None

    async def search(self, text: str, username: str, category: str | None = None, file_name: str | None = None) -> list[Result]:
        """Return answers whose question is similar to ``text`` for ``username``.

        Results are always filtered by username and, when given, by category
        and file name. Only points scoring at least 0.75 are requested, and
        points without a payload are dropped.
        """
        # Collect the conditions first so the Filter is built once, instead
        # of mutating its ``must`` list through a cast afterwards.
        conditions = [FieldCondition(key="username", match=MatchValue(value=username))]
        if category is not None:
            conditions.append(FieldCondition(key="category", match=MatchValue(value=category)))
        if file_name is not None:
            conditions.append(FieldCondition(key="file_name", match=MatchValue(value=file_name)))
        filters = Filter(must=conditions)
        vec = await embedder.embed([text])
        results = await self._client.query_points(self.collection_name, query=vec[0], query_filter=filters, score_threshold=0.75)
        return [
            Result(text=point.payload["answer"], similarity=point.score, file_name=point.payload["file_name"], category=point.payload["category"], result_type="answer")
            for point in results.points
            if point.payload is not None
        ]
Generated
+1668
View File
File diff suppressed because it is too large Load Diff