diff --git a/README.md b/README.md index 6a2784e..17b365e 100644 --- a/README.md +++ b/README.md @@ -77,13 +77,13 @@ result, err := client.Search(graphiti.SearchQuery{ ```go messages := []graphiti.Message{ { - Content: "Hello, how are you?", - RoleType: graphiti.RoleTypeUser, + Content: "Hello, how are you?", + Author: "User", Timestamp: time.Now(), }, { - Content: "I'm doing great, thank you!", - RoleType: graphiti.RoleTypeAssistant, + Content: "I'm doing great, thank you!", + Author: "Assistant", Timestamp: time.Now(), }, } @@ -130,8 +130,8 @@ fmt.Printf("Created node: %s\n", node.UUID) ```go messages := []graphiti.Message{ { - Content: "What were my settings?", - RoleType: graphiti.RoleTypeUser, + Content: "What were my settings?", + Author: "User", Timestamp: time.Now(), }, } @@ -198,8 +198,7 @@ type Message struct { Content string // The message content UUID *string // Optional UUID Name string // Optional name for episodic node - RoleType RoleType // user, assistant, or system - Role *string // Optional custom role (user name, bot name, etc.) + Author string // The author/entity that created this message Timestamp time.Time // Message timestamp SourceDescription string // Optional source description } diff --git a/client.go b/client.go new file mode 100644 index 0000000..5148a2a --- /dev/null +++ b/client.go @@ -0,0 +1,195 @@ +package graphiti + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "time" +) + +// Client represents a Graphiti API client +type Client struct { + baseURL string + httpClient *http.Client +} + +// ClientOption is a functional option for configuring the Client +type ClientOption func(*Client) + +// WithHTTPClient sets a custom HTTP client +func WithHTTPClient(httpClient *http.Client) ClientOption { + return func(c *Client) { + c.httpClient = httpClient + } +} + +// WithTimeout sets the HTTP client timeout +func WithTimeout(timeout time.Duration) ClientOption { + return func(c *Client) { + c.httpClient.Timeout = timeout + } +} + +// NewClient creates a new Graphiti API client +func NewClient(baseURL string, opts ...ClientOption) *Client { + client := &Client{ + baseURL: baseURL, + httpClient: &http.Client{ + Timeout: 30 * time.Second, + }, + } + + for _, opt := range opts { + opt(client) + } + + return client +} + +// do performs an HTTP request and decodes the response +func (c *Client) do(method, path string, body interface{}, result interface{}) error { + var reqBody io.Reader + if body != nil { + jsonData, err := json.Marshal(body) + if err != nil { + return fmt.Errorf("failed to marshal request body: %w", err) + } + reqBody = bytes.NewBuffer(jsonData) + } + + reqURL := c.baseURL + path + req, err := http.NewRequest(method, reqURL, reqBody) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + if body != nil { + req.Header.Set("Content-Type", "application/json") + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to perform request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + bodyBytes, _ := io.ReadAll(resp.Body) + return fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(bodyBytes)) + } + + if result != nil { + if err := json.NewDecoder(resp.Body).Decode(result); err != nil { + return fmt.Errorf("failed to decode response: %w", err) + } + } + + return nil +} + +// HealthCheck performs a health check on the API +func (c *Client) HealthCheck() (*HealthCheckResponse, error) { + var result HealthCheckResponse + if err := c.do(http.MethodGet, "/healthcheck", nil, &result); err != nil { + return nil, err + } + return &result, nil +} + +// Search searches for facts in the graph +func (c *Client) Search(query SearchQuery) (*SearchResults, error) { + var result SearchResults + if err := c.do(http.MethodPost, "/search", query, &result); err != nil { + return nil, err + } + return &result, nil +} + +// GetEntityEdge retrieves a specific entity edge by UUID +func (c *Client) GetEntityEdge(uuid string) (*FactResult, error) { + var result FactResult + path := fmt.Sprintf("/entity-edge/%s", url.PathEscape(uuid)) + if err := c.do(http.MethodGet, path, nil, &result); err != nil { + return nil, err + } + return &result, nil +} + +// GetEpisodes retrieves episodes for a group +func (c *Client) GetEpisodes(groupID string, lastN int) ([]Episode, error) { + var result []Episode + path := fmt.Sprintf("/episodes/%s?last_n=%d", url.PathEscape(groupID), lastN) + if err := c.do(http.MethodGet, path, nil, &result); err != nil { + return nil, err + } + return result, nil +} + +// GetMemory retrieves memory based on messages +func (c *Client) GetMemory(request GetMemoryRequest) (*GetMemoryResponse, error) { + var result GetMemoryResponse + if err := c.do(http.MethodPost, "/get-memory", request, &result); err != nil { + return nil, err + } + return &result, nil +} + +// AddMessages adds messages to the graph (asynchronous operation) +func (c *Client) AddMessages(request AddMessagesRequest) (*Result, error) { + var result Result + if err := c.do(http.MethodPost, "/messages", request, &result); err != nil { + return nil, err + } + return &result, nil +} + +// AddEntityNode adds an entity node to the graph +func (c *Client) AddEntityNode(request AddEntityNodeRequest) (*EntityNode, error) { + var result EntityNode + if err := c.do(http.MethodPost, "/entity-node", request, &result); err != nil { + return nil, err + } + return &result, nil +} + +// DeleteEntityEdge deletes an entity edge by UUID +func (c *Client) DeleteEntityEdge(uuid string) (*Result, error) { + var result Result + path := fmt.Sprintf("/entity-edge/%s", url.PathEscape(uuid)) + if err := c.do(http.MethodDelete, path, nil, &result); err != nil { + return nil, err + } + return &result, nil +} + +// DeleteGroup deletes a group by ID +func (c *Client) DeleteGroup(groupID string) (*Result, error) { + var result Result + path := fmt.Sprintf("/group/%s", url.PathEscape(groupID)) + if err := c.do(http.MethodDelete, path, nil, &result); err != nil { + return nil, err + } + return &result, nil +} + +// DeleteEpisode deletes an episode by UUID +func (c *Client) DeleteEpisode(uuid string) (*Result, error) { + var result Result + path := fmt.Sprintf("/episode/%s", url.PathEscape(uuid)) + if err := c.do(http.MethodDelete, path, nil, &result); err != nil { + return nil, err + } + return &result, nil +} + +// Clear clears all data from the graph +func (c *Client) Clear() (*Result, error) { + var result Result + if err := c.do(http.MethodPost, "/clear", nil, &result); err != nil { + return nil, err + } + return &result, nil +} diff --git a/example/go.mod b/example/go.mod new file mode 100644 index 0000000..b34676d --- /dev/null +++ b/example/go.mod @@ -0,0 +1,11 @@ +module github.com/pentagi/graphiti-go-client/example + +go 1.23 + +replace github.com/pentagi/graphiti-go-client => ../ + +require ( + github.com/google/uuid v1.6.0 + github.com/pentagi/graphiti-go-client v0.0.0-00010101000000-000000000000 +) + diff --git a/example/go.sum b/example/go.sum new file mode 100644 index 0000000..f846c7a --- /dev/null +++ b/example/go.sum @@ -0,0 +1,3 @@ +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= + diff --git a/example/main.go b/example/main.go new file mode 100644 index 0000000..84d91fb --- /dev/null +++ b/example/main.go @@ -0,0 +1,163 @@ +package main + +import ( + "fmt" + "log" + "time" + + "github.com/google/uuid" + graphiti "github.com/pentagi/graphiti-go-client" +) + +// This example demonstrates how to use the Graphiti Go client. +// +// Important: The /messages endpoint processes data asynchronously. This example +// polls for episodes to verify data was successfully created before searching. +// +// Troubleshooting: If you see "No episodes were created" errors: +// 1. Check server logs for "Error executing Neo4j query: Driver closed" +// 2. Ensure Neo4j is running and properly configured +// 3. Verify the Graphiti server has a persistent database connection +// 4. Check that the async worker is processing jobs successfully + +func main() { + // Create a client with extended timeout for long-running operations + client := graphiti.NewClient("http://localhost:8000", graphiti.WithTimeout(60*time.Second)) + + // Health check + fmt.Println("=== Health Check ===") + health, err := client.HealthCheck() + if err != nil { + log.Fatalf("Health check failed: %v", err) + } + fmt.Printf("Status: %s\n\n", health.Status) + + // Create a unique group ID for this example + groupID := uuid.New().String() + fmt.Printf("Using group ID: %s\n\n", groupID) + + // Add messages + fmt.Println("=== Adding Messages ===") + messages := []graphiti.Message{ + { + Content: "I love hiking in the mountains on weekends.", + Author: "Alice", + Timestamp: time.Now().Add(-2 * time.Hour), + }, + { + Content: "That sounds great! Do you have a favorite trail?", + Author: "Assistant", + Timestamp: time.Now().Add(-90 * time.Minute), + }, + { + Content: "Yes, I particularly enjoy the Pacific Crest Trail. I try to go there every summer.", + Author: "Alice", + Timestamp: time.Now().Add(-60 * time.Minute), + }, + } + + addResult, err := client.AddMessages(graphiti.AddMessagesRequest{ + GroupID: groupID, + Messages: messages, + }) + if err != nil { + log.Fatalf("Failed to add messages: %v", err) + } + fmt.Printf("%s: %v\n\n", addResult.Message, addResult.Success) + + // Wait for processing and verify data exists (poll for episodes) + fmt.Println("Waiting for messages to be processed...") + maxAttempts := 10 + pollInterval := 5 * time.Second + var episodes []graphiti.Episode + + for attempt := 1; attempt <= maxAttempts; attempt++ { + fmt.Printf(" Polling for episodes (attempt %d/%d)...\n", attempt, maxAttempts) + episodes, err = client.GetEpisodes(groupID, 10) + if err != nil { + log.Printf(" Warning: Failed to get episodes: %v", err) + } else if len(episodes) > 0 { + fmt.Printf(" ✓ Found %d episodes, processing complete!\n\n", len(episodes)) + break + } + + if attempt < maxAttempts { + time.Sleep(pollInterval) + } + } + + if len(episodes) == 0 { + log.Fatalf("Timeout: No episodes were created after %v. The async job may have failed.", time.Duration(maxAttempts)*pollInterval) + } + + // Search for facts + fmt.Println("=== Searching for Facts ===") + searchResult, err := client.Search(graphiti.SearchQuery{ + Query: "What does the user like to do?", + MaxFacts: 5, + GroupIDs: &[]string{groupID}, + }) + if err != nil { + log.Fatalf("Search failed: %v", err) + } + fmt.Printf("Found %d facts:\n", len(searchResult.Facts)) + for i, fact := range searchResult.Facts { + fmt.Printf("%d. %s\n (from: %s, created: %s)\n", + i+1, fact.Fact, fact.Name, fact.CreatedAt.Format(time.RFC3339)) + } + fmt.Println() + + // Get memory from messages + fmt.Println("=== Getting Memory ===") + memoryMessages := []graphiti.Message{ + { + Content: "What hobbies does the user have?", + Author: "User", + Timestamp: time.Now(), + }, + } + memoryResponse, err := client.GetMemory(graphiti.GetMemoryRequest{ + GroupID: groupID, + MaxFacts: 10, + Messages: memoryMessages, + }) + if err != nil { + log.Fatalf("Failed to get memory: %v", err) + } + fmt.Printf("Retrieved %d facts from memory:\n", len(memoryResponse.Facts)) + for i, fact := range memoryResponse.Facts { + fmt.Printf("%d. %s\n", i+1, fact.Fact) + } + fmt.Println() + + // Add an entity node + fmt.Println("=== Adding Entity Node ===") + entityUUID := uuid.New().String() + node, err := client.AddEntityNode(graphiti.AddEntityNodeRequest{ + UUID: entityUUID, + GroupID: groupID, + Name: "User Interests", + Summary: "The user's hobbies and interests", + }) + if err != nil { + log.Fatalf("Failed to add entity node: %v", err) + } + fmt.Printf("Created entity node: %s (UUID: %s)\n\n", node.Name, node.UUID) + + // Display episodes (already fetched during polling) + fmt.Println("=== Episodes Summary ===") + fmt.Printf("Total episodes: %d\n", len(episodes)) + for i, episode := range episodes { + fmt.Printf("%d. %s: %s\n", i+1, episode.Name, episode.Content) + } + fmt.Println() + + // Cleanup: delete the group + fmt.Println("=== Cleanup ===") + deleteResult, err := client.DeleteGroup(groupID) + if err != nil { + log.Printf("Warning: Failed to delete group: %v", err) + } else { + fmt.Printf("%s: %v\n", deleteResult.Message, deleteResult.Success) + } +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..1dd4edc --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module github.com/pentagi/graphiti-go-client + +go 1.23 diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..e69de29 diff --git a/types.go b/types.go new file mode 100644 index 0000000..8617d31 --- /dev/null +++ b/types.go @@ -0,0 +1,98 @@ +package graphiti + +import "time" + +// Message represents a message in the system +type Message struct { + Content string `json:"content"` + UUID *string `json:"uuid,omitempty"` + Name string `json:"name,omitempty"` + Author string `json:"author"` + Timestamp time.Time `json:"timestamp"` + SourceDescription string `json:"source_description,omitempty"` +} + +// Result represents a generic result response +type Result struct { + Message string `json:"message"` + Success bool `json:"success"` +} + +// HealthCheckResponse represents the health check response +type HealthCheckResponse struct { + Status string `json:"status"` +} + +// SearchQuery represents a search query request +type SearchQuery struct { + GroupIDs *[]string `json:"group_ids,omitempty"` + Query string `json:"query"` + MaxFacts int `json:"max_facts,omitempty"` +} + +// FactResult represents a fact result from the graph +type FactResult struct { + UUID string `json:"uuid"` + Name string `json:"name"` + Fact string `json:"fact"` + ValidAt *time.Time `json:"valid_at,omitempty"` + InvalidAt *time.Time `json:"invalid_at,omitempty"` + CreatedAt time.Time `json:"created_at"` + ExpiredAt *time.Time `json:"expired_at,omitempty"` +} + +// SearchResults represents the results of a search query +type SearchResults struct { + Facts []FactResult `json:"facts"` +} + +// GetMemoryRequest represents a request to get memory +type GetMemoryRequest struct { + GroupID string `json:"group_id"` + MaxFacts int `json:"max_facts,omitempty"` + CenterNodeUUID *string `json:"center_node_uuid"` + Messages []Message `json:"messages"` +} + +// GetMemoryResponse represents the response from getting memory +type GetMemoryResponse struct { + Facts []FactResult `json:"facts"` +} + +// AddMessagesRequest represents a request to add messages +type AddMessagesRequest struct { + GroupID string `json:"group_id"` + Messages []Message `json:"messages"` +} + +// AddEntityNodeRequest represents a request to add an entity node +type AddEntityNodeRequest struct { + UUID string `json:"uuid"` + GroupID string `json:"group_id"` + Name string `json:"name"` + Summary string `json:"summary,omitempty"` +} + +// EntityNode represents an entity node in the graph +type EntityNode struct { + UUID string `json:"uuid"` + GroupID string `json:"group_id"` + Name string `json:"name"` + Summary string `json:"summary,omitempty"` + CreatedAt time.Time `json:"created_at"` + Labels []string `json:"labels,omitempty"` + Metadata map[string]interface{} `json:"metadata,omitempty"` +} + +// Episode represents an episode in the graph +type Episode struct { + UUID string `json:"uuid"` + GroupID string `json:"group_id"` + Name string `json:"name"` + Content string `json:"content"` + Source string `json:"source"` + SourceDescription string `json:"source_description,omitempty"` + CreatedAt time.Time `json:"created_at"` + ValidAt time.Time `json:"valid_at"` + Metadata map[string]interface{} `json:"metadata,omitempty"` +}