mirror of
https://github.com/GH05TCREW/eidolon.git
synced 2026-07-01 11:55:39 -04:00
first commit
This commit is contained in:
@@ -0,0 +1,27 @@
|
||||
.git
|
||||
.pytest_cache
|
||||
.ruff_cache
|
||||
.mypy_cache
|
||||
.coverage
|
||||
venv
|
||||
**/__pycache__
|
||||
**/*.pyc
|
||||
**/*.pyo
|
||||
**/*.pyd
|
||||
eidolon/ui/node_modules
|
||||
eidolon/ui/dist
|
||||
eidolon/ui/build
|
||||
eidolon/ui/.vite
|
||||
eidolon/ui/.turbo
|
||||
node_modules
|
||||
dist
|
||||
build
|
||||
*.log
|
||||
|
||||
# Database volumes
|
||||
postgres_data
|
||||
neo4j_data
|
||||
docker-data
|
||||
|
||||
# Test artifacts
|
||||
.tooling-test
|
||||
@@ -0,0 +1,17 @@
|
||||
# Database connections (required at startup)
|
||||
EIDOLON_NEO4J__URI=bolt://localhost:7687
|
||||
EIDOLON_NEO4J__USER=neo4j
|
||||
EIDOLON_NEO4J__PASSWORD=password
|
||||
EIDOLON_NEO4J__DATABASE=neo4j
|
||||
EIDOLON_POSTGRES__URL=postgresql://postgres:password@localhost:5432/eidolon
|
||||
|
||||
# Docker compose variables
|
||||
NEO4J_USER=neo4j
|
||||
NEO4J_PASSWORD=password
|
||||
NEO4J_DB=neo4j
|
||||
POSTGRES_USER=postgres
|
||||
POSTGRES_PASSWORD=password
|
||||
POSTGRES_DB=eidolon
|
||||
|
||||
# Debug flags
|
||||
EIDOLON_LLM_DEBUG=1
|
||||
+100
@@ -0,0 +1,100 @@
|
||||
# Python
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
*.so
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# Ruff cache
|
||||
.ruff_cache/
|
||||
|
||||
# Virtual environments
|
||||
venv/
|
||||
env/
|
||||
ENV/
|
||||
.venv
|
||||
|
||||
# PyCharm
|
||||
.idea/
|
||||
|
||||
# VS Code
|
||||
.vscode/
|
||||
|
||||
# Pytest
|
||||
.pytest_cache/
|
||||
.coverage
|
||||
htmlcov/
|
||||
*.cover
|
||||
|
||||
# Environment variables
|
||||
.env
|
||||
.env.local
|
||||
|
||||
# Database files
|
||||
*.db
|
||||
*.sqlite
|
||||
*.sqlite3
|
||||
|
||||
# Docker volumes/data
|
||||
postgres_data/
|
||||
neo4j_data/
|
||||
docker-data/
|
||||
|
||||
# Node.js (UI)
|
||||
eidolon/ui/node_modules/
|
||||
eidolon/ui/dist/
|
||||
eidolon/ui/build/
|
||||
eidolon/ui/.vite/
|
||||
eidolon/ui/.turbo/
|
||||
|
||||
# API keys
|
||||
.env.vars
|
||||
|
||||
# Tooling test artifacts
|
||||
.tooling-test/
|
||||
|
||||
# Logs
|
||||
*.log
|
||||
logs/
|
||||
|
||||
# OS files
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
Desktop.ini
|
||||
|
||||
# Temporary files
|
||||
*.tmp
|
||||
*.bak
|
||||
*.swp
|
||||
*~
|
||||
|
||||
# JetBrains IDEs
|
||||
*.iml
|
||||
.idea/
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Jupyter
|
||||
.ipynb_checkpoints
|
||||
|
||||
# pytest
|
||||
.tox/
|
||||
+43
@@ -0,0 +1,43 @@
|
||||
FROM python:3.12-slim AS builder
|
||||
|
||||
ENV PYTHONDONTWRITEBYTECODE=1 \
|
||||
PYTHONUNBUFFERED=1 \
|
||||
PIP_DISABLE_PIP_VERSION_CHECK=1 \
|
||||
PIP_NO_CACHE_DIR=1 \
|
||||
VIRTUAL_ENV=/opt/venv
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y --no-install-recommends build-essential libpq-dev \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
RUN python -m venv $VIRTUAL_ENV
|
||||
ENV PATH="$VIRTUAL_ENV/bin:$PATH"
|
||||
|
||||
WORKDIR /app
|
||||
COPY pyproject.toml README.md ./
|
||||
COPY eidolon ./eidolon
|
||||
|
||||
RUN pip install --upgrade pip \
|
||||
&& pip install .
|
||||
|
||||
FROM python:3.12-slim AS runtime
|
||||
|
||||
ENV PYTHONDONTWRITEBYTECODE=1 \
|
||||
PYTHONUNBUFFERED=1 \
|
||||
PIP_DISABLE_PIP_VERSION_CHECK=1 \
|
||||
VIRTUAL_ENV=/opt/venv
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y --no-install-recommends libpq5 nmap \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
COPY --from=builder /opt/venv /opt/venv
|
||||
ENV PATH="$VIRTUAL_ENV/bin:$PATH"
|
||||
|
||||
WORKDIR /app
|
||||
COPY eidolon ./eidolon
|
||||
COPY pyproject.toml README.md ./
|
||||
|
||||
EXPOSE 8080
|
||||
|
||||
CMD ["uvicorn", "eidolon.api.app:app", "--host", "0.0.0.0", "--port", "8080"]
|
||||
+21
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2026 Masic
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
@@ -0,0 +1,52 @@
|
||||
# Eidolon
|
||||
|
||||
Network scanner with AI-powered analysis and automation. Scans your infrastructure with nmap, stores it in a graph database (Neo4j), and lets you query and operate on it using natural language. LLM agents generate plans, execute approved actions, and log everything for audit.
|
||||
|
||||
## Features
|
||||
|
||||
- **Network scanning**: Automated nmap scans build a real-time map of hosts, ports, and services
|
||||
- **Graph database**: Neo4j stores assets, networks, and connectivity relationships
|
||||
- **Natural language queries**: Ask "What paths exist from internet to database X?" and get answers
|
||||
- **Plan generation**: LLM translates intents like "isolate host X" into executable steps
|
||||
- **Execution runtime**: Sandboxed tools (terminal, browser, file edit) with permission controls
|
||||
- **Audit trail**: Every scan, query, plan, and execution logged to Postgres
|
||||
- **Interactive UI**: React frontend for graph visualization, chat, and approval workflows
|
||||
|
||||
## Quick Start
|
||||
|
||||
```powershell
|
||||
# Windows PowerShell
|
||||
.\scripts\dev.ps1
|
||||
|
||||
# Linux/macOS
|
||||
chmod +x scripts/dev.sh && ./scripts/dev.sh
|
||||
```
|
||||
|
||||
This starts all services:
|
||||
- Postgres (audit logs, chat history): `localhost:5432`
|
||||
- Neo4j (knowledge graph): `localhost:7474` (browser), `localhost:7687` (bolt)
|
||||
- API server: `http://localhost:8080`
|
||||
- React UI: `http://localhost:5173`
|
||||
|
||||
## Configuration
|
||||
|
||||
Optional: Copy `.env.example` to `.env` and configure:
|
||||
- Database credentials (if changing defaults)
|
||||
- LLM settings: `EIDOLON_LLM__MODEL`, `EIDOLON_LLM__API_KEY`, `EIDOLON_LLM__API_BASE`
|
||||
|
||||
## Development
|
||||
|
||||
**Run tests**:
|
||||
```bash
|
||||
pytest
|
||||
```
|
||||
|
||||
**Lint**:
|
||||
```bash
|
||||
ruff check .
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
MIT
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
name: eidolon
|
||||
services:
|
||||
neo4j:
|
||||
image: neo4j:5.25
|
||||
environment:
|
||||
- NEO4J_AUTH=${NEO4J_USER:-neo4j}/${NEO4J_PASSWORD:-password}
|
||||
- NEO4J_dbms_default__database=${NEO4J_DB:-neo4j}
|
||||
ports:
|
||||
- "7474:7474"
|
||||
- "7687:7687"
|
||||
volumes:
|
||||
- neo4j_data:/data
|
||||
- neo4j_logs:/logs
|
||||
restart: unless-stopped
|
||||
|
||||
postgres:
|
||||
image: postgres:16
|
||||
environment:
|
||||
- POSTGRES_USER=${POSTGRES_USER:-postgres}
|
||||
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
|
||||
- POSTGRES_DB=${POSTGRES_DB:-eidolon}
|
||||
ports:
|
||||
- "5432:5432"
|
||||
volumes:
|
||||
- postgres_data:/var/lib/postgresql/data
|
||||
- ./eidolon/db/postgres/schema.sql:/docker-entrypoint-initdb.d/schema.sql:ro
|
||||
restart: unless-stopped
|
||||
|
||||
volumes:
|
||||
neo4j_data:
|
||||
neo4j_logs:
|
||||
postgres_data:
|
||||
@@ -0,0 +1 @@
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
|
||||
@@ -0,0 +1,109 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from contextlib import suppress
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from eidolon.api.dependencies import get_audit_store, get_graph_repository
|
||||
from eidolon.api.handlers import tasks
|
||||
from eidolon.api.middleware.auth import AuthMiddleware
|
||||
from eidolon.api.middleware.rate_limit import RateLimitMiddleware
|
||||
from eidolon.api.routes import (
|
||||
agent,
|
||||
approvals,
|
||||
audit,
|
||||
chat,
|
||||
collector,
|
||||
graph,
|
||||
ingest,
|
||||
permissions,
|
||||
plan,
|
||||
query,
|
||||
)
|
||||
from eidolon.api.routes import settings as settings_router
|
||||
from eidolon.config.settings import get_settings
|
||||
from eidolon.runtime.task_events import task_event_bus
|
||||
from eidolon.worker.retention import RetentionWorker
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
settings = get_settings()
|
||||
|
||||
app = FastAPI(
|
||||
title="Eidolon API",
|
||||
version="0.1.0",
|
||||
description="Evidence-backed infrastructure graph and agent runtime.",
|
||||
)
|
||||
|
||||
# Add routers first
|
||||
app.include_router(query.router)
|
||||
app.include_router(plan.router)
|
||||
app.include_router(graph.router)
|
||||
app.include_router(chat.router)
|
||||
app.include_router(collector.router)
|
||||
app.include_router(ingest.router)
|
||||
app.include_router(permissions.router)
|
||||
app.include_router(settings_router.router)
|
||||
app.include_router(audit.router)
|
||||
app.include_router(approvals.router)
|
||||
app.include_router(agent.router)
|
||||
app.include_router(tasks.router)
|
||||
|
||||
# Add middleware in reverse order (they execute in reverse)
|
||||
# CORS must be added LAST so it executes FIRST
|
||||
app.add_middleware(RateLimitMiddleware, capacity=300, window_seconds=60)
|
||||
app.add_middleware(AuthMiddleware)
|
||||
|
||||
# CORS middleware LAST = executes FIRST
|
||||
# For development: allow all origins without credentials
|
||||
# For production: specify exact origins in settings.api.cors_origins
|
||||
if "*" in settings.api.cors_origins:
|
||||
# Development mode: wildcard origins, no credentials
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.api.cors_origins,
|
||||
allow_credentials=False,
|
||||
allow_methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
else:
|
||||
# Production mode: specific origins, allow credentials
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.api.cors_origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
@app.get("/healthz")
|
||||
def health() -> dict:
|
||||
return {"status": "ok"}
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup() -> None:
|
||||
# Start retention worker to clean up old audit events
|
||||
audit_store = get_audit_store()
|
||||
retention_worker = RetentionWorker(audit_store, retention_days=90)
|
||||
app.state.retention_task = asyncio.create_task(
|
||||
retention_worker.run_forever(interval_hours=24)
|
||||
)
|
||||
|
||||
@app.on_event("shutdown")
|
||||
def shutdown() -> None:
|
||||
# Signal task event bus to shutdown streaming connections
|
||||
task_event_bus.shutdown()
|
||||
|
||||
# Close graph repository
|
||||
with suppress(Exception):
|
||||
repo = get_graph_repository()
|
||||
close = getattr(repo, "close", None)
|
||||
if close:
|
||||
close()
|
||||
|
||||
return app
|
||||
|
||||
|
||||
app = create_app()
|
||||
@@ -0,0 +1,110 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import lru_cache
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
|
||||
from eidolon.config.settings import get_settings
|
||||
from eidolon.core.graph.neo4j import Neo4jGraphRepository
|
||||
from eidolon.core.graph.repository import GraphRepository
|
||||
from eidolon.core.reasoning.entity import EntityResolver
|
||||
from eidolon.core.reasoning.llm import LiteLLMClient
|
||||
from eidolon.core.stores import (
|
||||
ApprovalStore,
|
||||
AuditStore,
|
||||
ChatStore,
|
||||
InMemoryApprovalStore,
|
||||
InMemoryAuditStore,
|
||||
InMemoryChatStore,
|
||||
InMemoryScannerStore,
|
||||
InMemorySettingsStore,
|
||||
ScannerStore,
|
||||
SettingsStore,
|
||||
)
|
||||
from eidolon.db.postgres.store import (
|
||||
PostgresApprovalStore,
|
||||
PostgresAuditStore,
|
||||
PostgresChatStore,
|
||||
PostgresScannerStore,
|
||||
PostgresSettingsStore,
|
||||
postgres_available,
|
||||
)
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_graph_repository() -> GraphRepository:
|
||||
return Neo4jGraphRepository()
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_entity_resolver() -> EntityResolver:
|
||||
return EntityResolver()
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_llm_client() -> LiteLLMClient:
|
||||
store = get_settings_store()
|
||||
app_settings = store.get_app_settings()
|
||||
return LiteLLMClient(settings=app_settings.llm)
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_audit_store() -> AuditStore:
|
||||
settings = get_settings()
|
||||
fallback = InMemoryAuditStore()
|
||||
if postgres_available():
|
||||
return PostgresAuditStore(settings.postgres.url, fallback=fallback)
|
||||
return fallback
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_approval_store() -> ApprovalStore:
|
||||
settings = get_settings()
|
||||
fallback = InMemoryApprovalStore()
|
||||
if postgres_available():
|
||||
return PostgresApprovalStore(settings.postgres.url, fallback=fallback)
|
||||
return fallback
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_chat_store() -> ChatStore:
|
||||
settings = get_settings()
|
||||
fallback = InMemoryChatStore()
|
||||
if postgres_available():
|
||||
return PostgresChatStore(settings.postgres.url, fallback=fallback)
|
||||
return fallback
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_settings_store() -> SettingsStore:
|
||||
settings = get_settings()
|
||||
if postgres_available():
|
||||
return PostgresSettingsStore(settings.postgres.url)
|
||||
# Fallback: return in-memory that always returns defaults
|
||||
store = InMemorySettingsStore()
|
||||
store.update_settings(settings.sandbox)
|
||||
return store
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_scanner_store() -> ScannerStore:
|
||||
settings = get_settings()
|
||||
fallback = InMemoryScannerStore()
|
||||
if postgres_available():
|
||||
return PostgresScannerStore(settings.postgres.url, fallback=fallback)
|
||||
return fallback
|
||||
|
||||
|
||||
def require_roles(*roles: str):
|
||||
def _dependency(request: Request):
|
||||
auth_error = getattr(request.state, "auth_error", None)
|
||||
if auth_error:
|
||||
raise HTTPException(status_code=401, detail=str(auth_error))
|
||||
identity = getattr(request.state, "identity", None)
|
||||
if not identity:
|
||||
raise HTTPException(status_code=403, detail="missing identity")
|
||||
if not any(identity.has_role(role) for role in roles):
|
||||
raise HTTPException(status_code=403, detail="insufficient role")
|
||||
return identity
|
||||
|
||||
return _dependency
|
||||
@@ -0,0 +1 @@
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from eidolon.api.dependencies import require_roles
|
||||
from eidolon.api.middleware.auth import IdentityContext
|
||||
from eidolon.runtime.task_events import task_event_bus
|
||||
|
||||
router = APIRouter(prefix="/tasks", tags=["tasks"])
|
||||
_TASK_STREAM_IDENTITY = Depends(require_roles("viewer", "planner", "executor"))
|
||||
|
||||
|
||||
async def _stream() -> AsyncGenerator[bytes, None]:
|
||||
subscriber = task_event_bus.subscribe_async()
|
||||
try:
|
||||
# Send history first
|
||||
for event in task_event_bus.history():
|
||||
payload = json.dumps(event.to_payload())
|
||||
yield f"data: {payload}\n\n".encode()
|
||||
|
||||
# Stream live events with proper cancellation support
|
||||
while True:
|
||||
try:
|
||||
# Wait for event with timeout for keepalive
|
||||
event = await asyncio.wait_for(subscriber.get(), timeout=15.0)
|
||||
|
||||
# None is shutdown sentinel
|
||||
if event is None:
|
||||
break
|
||||
|
||||
payload = json.dumps(event.to_payload())
|
||||
yield f"data: {payload}\n\n".encode()
|
||||
except TimeoutError:
|
||||
# Send keepalive to prevent client timeout
|
||||
yield b": keepalive\n\n"
|
||||
except asyncio.CancelledError:
|
||||
# Client disconnected or server shutting down - clean exit
|
||||
pass
|
||||
finally:
|
||||
task_event_bus.unsubscribe_async(subscriber)
|
||||
|
||||
|
||||
@router.get("/stream")
|
||||
async def task_stream(identity: IdentityContext = _TASK_STREAM_IDENTITY) -> StreamingResponse:
|
||||
return StreamingResponse(_stream(), media_type="text/event-stream")
|
||||
@@ -0,0 +1 @@
|
||||
|
||||
@@ -0,0 +1,136 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import binascii
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from fastapi import Request
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
from eidolon.config.settings import AuthSettings, get_settings
|
||||
|
||||
|
||||
@dataclass
|
||||
class IdentityContext:
|
||||
user_id: str
|
||||
roles: list[str] = field(default_factory=lambda: ["viewer"])
|
||||
claims: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def has_role(self, role: str) -> bool:
|
||||
return role in self.roles
|
||||
|
||||
|
||||
class AuthError(RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
def _parse_roles(value: Any) -> list[str]:
|
||||
if isinstance(value, list):
|
||||
return [str(role).strip() for role in value if str(role).strip()]
|
||||
if isinstance(value, str):
|
||||
normalized = value.replace(",", " ")
|
||||
return [role for role in (item.strip() for item in normalized.split()) if role]
|
||||
return []
|
||||
|
||||
|
||||
def _b64url_decode(segment: str) -> bytes:
|
||||
padded = segment + "=" * (-len(segment) % 4)
|
||||
return base64.urlsafe_b64decode(padded.encode("utf-8"))
|
||||
|
||||
|
||||
def extract_bearer_token(headers: Mapping[str, str]) -> str | None:
|
||||
auth_header = headers.get("authorization") or headers.get("Authorization")
|
||||
if not auth_header:
|
||||
return None
|
||||
parts = auth_header.split()
|
||||
if len(parts) == 2 and parts[0].lower() == "bearer":
|
||||
return parts[1].strip()
|
||||
return None
|
||||
|
||||
|
||||
def _verify_jwt(token: str, settings: AuthSettings) -> dict[str, Any]:
|
||||
if not settings.jwt_secret:
|
||||
raise AuthError("JWT secret not configured")
|
||||
parts = token.split(".")
|
||||
if len(parts) != 3:
|
||||
raise AuthError("invalid token format")
|
||||
try:
|
||||
header = json.loads(_b64url_decode(parts[0]).decode("utf-8"))
|
||||
payload = json.loads(_b64url_decode(parts[1]).decode("utf-8"))
|
||||
except (binascii.Error, UnicodeDecodeError, json.JSONDecodeError, ValueError) as exc:
|
||||
raise AuthError("invalid token payload") from exc
|
||||
if header.get("alg") != "HS256":
|
||||
raise AuthError("unsupported JWT algorithm")
|
||||
signing_input = f"{parts[0]}.{parts[1]}".encode()
|
||||
expected_sig = hmac.new(
|
||||
settings.jwt_secret.encode("utf-8"),
|
||||
signing_input,
|
||||
hashlib.sha256,
|
||||
).digest()
|
||||
actual_sig = _b64url_decode(parts[2])
|
||||
if not hmac.compare_digest(actual_sig, expected_sig):
|
||||
raise AuthError("invalid token signature")
|
||||
now = int(time.time())
|
||||
exp = payload.get("exp")
|
||||
if isinstance(exp, (int, float)) and int(exp) < now:
|
||||
raise AuthError("token expired")
|
||||
nbf = payload.get("nbf")
|
||||
if isinstance(nbf, (int, float)) and int(nbf) > now:
|
||||
raise AuthError("token not yet valid")
|
||||
if settings.jwt_issuer and payload.get("iss") != settings.jwt_issuer:
|
||||
raise AuthError("token issuer mismatch")
|
||||
if settings.jwt_audience:
|
||||
aud = payload.get("aud")
|
||||
if isinstance(aud, list):
|
||||
if settings.jwt_audience not in [str(item) for item in aud]:
|
||||
raise AuthError("token audience mismatch")
|
||||
elif aud is None or str(aud) != settings.jwt_audience:
|
||||
raise AuthError("token audience mismatch")
|
||||
return payload
|
||||
|
||||
|
||||
def resolve_identity(
|
||||
headers: Mapping[str, str],
|
||||
settings: AuthSettings,
|
||||
token: str | None = None,
|
||||
) -> tuple[IdentityContext | None, str | None]:
|
||||
if settings.mode == "none":
|
||||
return IdentityContext(user_id="anonymous", roles=["viewer", "planner", "executor"]), None
|
||||
if settings.mode == "header":
|
||||
user_id = headers.get(settings.header_user_id, "anonymous")
|
||||
roles_header = headers.get(settings.header_roles, "viewer")
|
||||
roles = _parse_roles(roles_header) or ["viewer"]
|
||||
return IdentityContext(user_id=user_id, roles=roles), None
|
||||
|
||||
bearer = token or extract_bearer_token(headers)
|
||||
if not bearer:
|
||||
return None, "missing bearer token"
|
||||
try:
|
||||
claims = _verify_jwt(bearer, settings)
|
||||
except AuthError as exc:
|
||||
return None, str(exc)
|
||||
roles = _parse_roles(claims.get("roles") or claims.get("role") or claims.get("scope"))
|
||||
if not roles:
|
||||
roles = ["viewer"]
|
||||
user_id = str(claims.get("sub") or claims.get("user_id") or claims.get("uid") or "anonymous")
|
||||
return IdentityContext(user_id=user_id, roles=roles, claims=claims), None
|
||||
|
||||
|
||||
class AuthMiddleware(BaseHTTPMiddleware):
|
||||
"""Attach identity to request.state using configured auth mode."""
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
settings = get_settings().auth
|
||||
identity, error = resolve_identity(request.headers, settings)
|
||||
if identity:
|
||||
request.state.identity = identity
|
||||
if error:
|
||||
request.state.auth_error = error
|
||||
response = await call_next(request)
|
||||
return response
|
||||
@@ -0,0 +1,53 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
|
||||
from fastapi import Request
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
|
||||
class SlidingWindowLimiter:
|
||||
def __init__(self, capacity: int, window_seconds: int) -> None:
|
||||
self.capacity = capacity
|
||||
self.window_seconds = window_seconds
|
||||
self.buckets: dict[str, tuple[int, float]] = {}
|
||||
|
||||
def allow(self, key: str) -> bool:
|
||||
now = time.time()
|
||||
count, reset = self.buckets.get(key, (0, now + self.window_seconds))
|
||||
if now > reset:
|
||||
count = 0
|
||||
reset = now + self.window_seconds
|
||||
if count >= self.capacity:
|
||||
self.buckets[key] = (count, reset)
|
||||
return False
|
||||
self.buckets[key] = (count + 1, reset)
|
||||
return True
|
||||
|
||||
def reset_at(self, key: str) -> float:
|
||||
return self.buckets.get(key, (0, time.time() + self.window_seconds))[1]
|
||||
|
||||
|
||||
class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
Simple per-identity sliding window limiter for API routes.
|
||||
Replace with Redis-backed limiter in production.
|
||||
"""
|
||||
|
||||
def __init__(self, app, capacity: int = 60, window_seconds: int = 60) -> None:
|
||||
super().__init__(app)
|
||||
self.limiter = SlidingWindowLimiter(capacity, window_seconds)
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
identity = getattr(request.state, "identity", None)
|
||||
key = identity.user_id if identity else request.client.host
|
||||
if not self.limiter.allow(key):
|
||||
reset_at = self.limiter.reset_at(key)
|
||||
retry_after = max(1, int(reset_at - time.time()))
|
||||
return JSONResponse(
|
||||
status_code=429,
|
||||
content={"detail": "rate limit exceeded"},
|
||||
headers={"Retry-After": str(retry_after)},
|
||||
)
|
||||
return await call_next(request)
|
||||
@@ -0,0 +1 @@
|
||||
|
||||
@@ -0,0 +1,167 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from eidolon.api.dependencies import (
|
||||
get_approval_store,
|
||||
get_graph_repository,
|
||||
get_llm_client,
|
||||
)
|
||||
from eidolon.api.middleware.auth import extract_bearer_token, resolve_identity
|
||||
from eidolon.config.settings import get_settings
|
||||
from eidolon.core.graph.algorithms import blast_radius
|
||||
from eidolon.core.models.plan import (
|
||||
EntityRef,
|
||||
ExecutionRequest,
|
||||
ExecutionResponse,
|
||||
ToolExecutionResult,
|
||||
)
|
||||
from eidolon.core.reasoning.llm import LiteLLMClient
|
||||
from eidolon.core.reasoning.planner import Planner
|
||||
from eidolon.core.stores import ApprovalStore
|
||||
from eidolon.runtime.executor import ExecutionEngine
|
||||
|
||||
|
||||
class AgentPlanRequest(BaseModel):
|
||||
intent: str
|
||||
target: EntityRef
|
||||
|
||||
|
||||
class AgentRunRequest(BaseModel):
|
||||
intent: str
|
||||
target: EntityRef | None = None
|
||||
dry_run: bool = True
|
||||
approval_token: str | None = None
|
||||
|
||||
|
||||
router = APIRouter(prefix="/agents", tags=["agents"])
|
||||
|
||||
|
||||
@router.websocket("/ws")
|
||||
async def agent_ws(websocket: WebSocket) -> None:
|
||||
settings = get_settings()
|
||||
token = extract_bearer_token(websocket.headers)
|
||||
token = (
|
||||
token or websocket.query_params.get("token") or websocket.query_params.get("access_token")
|
||||
)
|
||||
identity, error = resolve_identity(websocket.headers, settings.auth, token=token)
|
||||
if error or not identity:
|
||||
await websocket.close(code=4401)
|
||||
return
|
||||
if not identity.has_role("executor"):
|
||||
await websocket.close(code=4403)
|
||||
return
|
||||
await websocket.accept()
|
||||
approval_store: ApprovalStore = get_approval_store()
|
||||
repository = get_graph_repository()
|
||||
llm_client: LiteLLMClient = get_llm_client()
|
||||
planner = Planner(llm_client=llm_client)
|
||||
|
||||
def _execute_request(request: ExecutionRequest) -> ExecutionResponse:
|
||||
needs_approval = request.requires_approval or any(
|
||||
step.requires_approval for step in request.steps
|
||||
)
|
||||
if not request.dry_run and needs_approval:
|
||||
if not request.approval_token:
|
||||
raise RuntimeError("approval token required for execution")
|
||||
approval = approval_store.get_by_token(request.approval_token)
|
||||
if not approval or approval.action != "execute":
|
||||
raise RuntimeError("invalid approval token")
|
||||
|
||||
engine = ExecutionEngine(repository, runtime_settings=settings.sandbox)
|
||||
results: list[ToolExecutionResult] = []
|
||||
for step in request.steps:
|
||||
results.append(engine.execute_step(step, dry_run=request.dry_run))
|
||||
status = "ok" if all(result.status != "error" for result in results) else "partial_failure"
|
||||
return ExecutionResponse(request=request, results=results, status=status)
|
||||
|
||||
try:
|
||||
await websocket.send_json({"type": "connected", "status": "ok"})
|
||||
async for message in websocket.iter_text():
|
||||
request_id = None
|
||||
try:
|
||||
payload = json.loads(message)
|
||||
except json.JSONDecodeError:
|
||||
await websocket.send_json(
|
||||
{"type": "error", "status": "error", "error": "invalid json"}
|
||||
)
|
||||
continue
|
||||
if isinstance(payload, dict):
|
||||
request_id = payload.get("request_id") or payload.get("id")
|
||||
message_type = payload.get("type") if isinstance(payload, dict) else None
|
||||
message_type = message_type or (
|
||||
payload.get("action") if isinstance(payload, dict) else None
|
||||
)
|
||||
data = payload.get("payload") if isinstance(payload, dict) else None
|
||||
data = data or (payload.get("request") if isinstance(payload, dict) else payload)
|
||||
response: dict[str, object]
|
||||
try:
|
||||
if message_type == "plan":
|
||||
request = AgentPlanRequest.model_validate(data)
|
||||
steps = planner.generate_plan(intent=request.intent, target=request.target)
|
||||
radius = (
|
||||
blast_radius(repository, [request.target.entity_id], depth=2)
|
||||
if request.target.entity_id
|
||||
else None
|
||||
)
|
||||
response = {
|
||||
"type": "plan",
|
||||
"status": "ok",
|
||||
"data": {
|
||||
"steps": [step.model_dump() for step in steps],
|
||||
"blast_radius": radius.model_dump() if radius else None,
|
||||
},
|
||||
}
|
||||
elif message_type == "execute":
|
||||
request = ExecutionRequest.model_validate(data)
|
||||
exec_response = _execute_request(request)
|
||||
response = {
|
||||
"type": "execute",
|
||||
"status": "ok",
|
||||
"data": exec_response.model_dump(),
|
||||
}
|
||||
elif message_type == "run":
|
||||
request = AgentRunRequest.model_validate(data)
|
||||
steps = planner.generate_plan(
|
||||
intent=request.intent,
|
||||
target=request.target
|
||||
or EntityRef(entity_type="Asset", display_name="unknown"),
|
||||
)
|
||||
exec_request = ExecutionRequest(
|
||||
dry_run=request.dry_run,
|
||||
steps=steps,
|
||||
approval_token=request.approval_token,
|
||||
requires_approval=any(step.requires_approval for step in steps),
|
||||
)
|
||||
if exec_request.dry_run:
|
||||
response = {
|
||||
"type": "run",
|
||||
"status": "ok",
|
||||
"data": {"steps": [step.model_dump() for step in steps]},
|
||||
}
|
||||
else:
|
||||
exec_response = _execute_request(exec_request)
|
||||
response = {
|
||||
"type": "run",
|
||||
"status": "ok",
|
||||
"data": exec_response.model_dump(),
|
||||
}
|
||||
elif message_type == "ping":
|
||||
response = {"type": "pong", "status": "ok"}
|
||||
else:
|
||||
response = {
|
||||
"type": "error",
|
||||
"status": "error",
|
||||
"error": "unsupported message type",
|
||||
}
|
||||
except (RuntimeError, ValidationError) as exc:
|
||||
response = {"type": "error", "status": "error", "error": str(exc)}
|
||||
|
||||
if request_id is not None:
|
||||
response["request_id"] = request_id
|
||||
await websocket.send_json(response)
|
||||
except WebSocketDisconnect:
|
||||
return
|
||||
@@ -0,0 +1,43 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from eidolon.api.dependencies import get_approval_store, require_roles
|
||||
from eidolon.api.middleware.auth import IdentityContext
|
||||
from eidolon.core.models.approval import ApprovalRecord
|
||||
from eidolon.core.stores import ApprovalStore
|
||||
|
||||
|
||||
class ApprovalRequest(BaseModel):
|
||||
action: str = Field(description="Action name requiring approval")
|
||||
ttl_seconds: int = Field(default=900, ge=60, le=86400)
|
||||
|
||||
|
||||
class ApprovalResponse(BaseModel):
|
||||
token: str
|
||||
action: str
|
||||
expires_at: datetime
|
||||
|
||||
|
||||
router = APIRouter(prefix="/approvals", tags=["approvals"])
|
||||
_APPROVAL_STORE = Depends(get_approval_store)
|
||||
_EXECUTOR_IDENTITY = Depends(require_roles("executor"))
|
||||
|
||||
|
||||
@router.post("/", response_model=ApprovalResponse)
|
||||
def create_approval(
|
||||
request: ApprovalRequest,
|
||||
store: ApprovalStore = _APPROVAL_STORE,
|
||||
identity: IdentityContext = _EXECUTOR_IDENTITY,
|
||||
) -> ApprovalResponse:
|
||||
approval: ApprovalRecord = store.create(
|
||||
user_id=identity.user_id,
|
||||
action=request.action,
|
||||
ttl_seconds=request.ttl_seconds,
|
||||
)
|
||||
return ApprovalResponse(
|
||||
token=approval.token, action=approval.action, expires_at=approval.expires_at
|
||||
)
|
||||
@@ -0,0 +1,89 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
|
||||
from eidolon.api.dependencies import get_audit_store, require_roles
|
||||
from eidolon.api.middleware.auth import IdentityContext
|
||||
from eidolon.core.models.event import AuditEvent
|
||||
from eidolon.core.stores import AuditStore
|
||||
|
||||
router = APIRouter(prefix="/audit", tags=["audit"])
|
||||
_AUDIT_STORE = Depends(get_audit_store)
|
||||
_VIEWER_IDENTITY = Depends(require_roles("viewer", "planner", "executor"))
|
||||
_EXECUTOR_IDENTITY = Depends(require_roles("executor"))
|
||||
_PAGE_QUERY = Query(1, ge=1, description="Page number (1-indexed)")
|
||||
_PAGE_SIZE_QUERY = Query(50, ge=1, le=500, description="Events per page")
|
||||
_EVENT_TYPE_QUERY = Query(None, description="Filter by event type")
|
||||
_START_DATE_QUERY = Query(None, description="Filter events after this date")
|
||||
_END_DATE_QUERY = Query(None, description="Filter events before this date")
|
||||
|
||||
|
||||
class AuditListResponse(BaseModel):
|
||||
events: list[AuditEvent]
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
has_more: bool
|
||||
|
||||
|
||||
class AuditClearResponse(BaseModel):
|
||||
status: str
|
||||
deleted: int
|
||||
|
||||
|
||||
@router.get("/", response_model=AuditListResponse)
|
||||
def list_events(
|
||||
page: int = _PAGE_QUERY,
|
||||
page_size: int = _PAGE_SIZE_QUERY,
|
||||
event_type: str | None = _EVENT_TYPE_QUERY,
|
||||
start_date: datetime | None = _START_DATE_QUERY,
|
||||
end_date: datetime | None = _END_DATE_QUERY,
|
||||
store: AuditStore = _AUDIT_STORE,
|
||||
identity: IdentityContext = _VIEWER_IDENTITY,
|
||||
) -> AuditListResponse:
|
||||
events = store.list_filtered(
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
event_type=event_type,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
)
|
||||
total = store.count_filtered(
|
||||
event_type=event_type,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
)
|
||||
has_more = (page * page_size) < total
|
||||
|
||||
return AuditListResponse(
|
||||
events=events,
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
has_more=has_more,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{audit_id}", response_model=AuditEvent)
|
||||
def get_event(
|
||||
audit_id: UUID,
|
||||
store: AuditStore = _AUDIT_STORE,
|
||||
identity: IdentityContext = _VIEWER_IDENTITY,
|
||||
) -> AuditEvent:
|
||||
event = store.get(audit_id)
|
||||
if not event:
|
||||
raise HTTPException(status_code=404, detail="audit event not found")
|
||||
return event
|
||||
|
||||
|
||||
@router.delete("/", response_model=AuditClearResponse)
|
||||
def clear_events(
|
||||
store: AuditStore = _AUDIT_STORE,
|
||||
identity: IdentityContext = _EXECUTOR_IDENTITY,
|
||||
) -> AuditClearResponse:
|
||||
deleted = store.delete_older_than(datetime.max)
|
||||
return AuditClearResponse(status="cleared", deleted=deleted)
|
||||
@@ -0,0 +1,465 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import anyio
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from eidolon.api.dependencies import (
|
||||
get_chat_store,
|
||||
get_graph_repository,
|
||||
get_llm_client,
|
||||
require_roles,
|
||||
)
|
||||
from eidolon.api.middleware.auth import IdentityContext
|
||||
from eidolon.core.graph.repository import GraphRepository
|
||||
from eidolon.core.models.chat import ChatMessage, ChatSession
|
||||
from eidolon.core.reasoning.llm import LiteLLMClient
|
||||
from eidolon.core.stores import ChatStore
|
||||
from eidolon.runtime.assistant import AssistantAgent, build_system_prompt
|
||||
from eidolon.runtime.sandbox import SandboxRuntime
|
||||
from eidolon.runtime.tools.browser import BrowserTool
|
||||
from eidolon.runtime.tools.file_edit import FileEditTool
|
||||
from eidolon.runtime.tools.finish import FinishTool
|
||||
from eidolon.runtime.tools.graph_query import GraphQueryTool
|
||||
from eidolon.runtime.tools.terminal import TerminalTool
|
||||
from eidolon.runtime.tools.thinking import ThinkingTool
|
||||
from eidolon.runtime.tools.todo import TodoTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/chat", tags=["chat"])
|
||||
_CHAT_STORE = Depends(get_chat_store)
|
||||
_LLM_CLIENT = Depends(get_llm_client)
|
||||
_GRAPH_REPOSITORY = Depends(get_graph_repository)
|
||||
_VIEWER_IDENTITY = Depends(require_roles("viewer", "planner", "executor"))
|
||||
_EXECUTOR_IDENTITY = Depends(require_roles("executor"))
|
||||
|
||||
|
||||
def _build_sandbox(
|
||||
repository: GraphRepository, settings_store: ChatStore | None = None
|
||||
) -> SandboxRuntime:
|
||||
"""Build the sandbox runtime with all tools and permission checking."""
|
||||
from eidolon.api.dependencies import get_settings_store
|
||||
|
||||
# Get runtime permissions from database if available, otherwise use config file defaults
|
||||
if settings_store is None:
|
||||
settings_store = get_settings_store()
|
||||
|
||||
sandbox_settings = settings_store.get_settings()
|
||||
runtime = SandboxRuntime(settings=sandbox_settings)
|
||||
|
||||
# Register all available tools
|
||||
runtime.register_tool(TerminalTool())
|
||||
runtime.register_tool(BrowserTool())
|
||||
runtime.register_tool(FileEditTool())
|
||||
runtime.register_tool(ThinkingTool())
|
||||
runtime.register_tool(TodoTool())
|
||||
runtime.register_tool(FinishTool())
|
||||
runtime.register_tool(GraphQueryTool(repository))
|
||||
|
||||
return runtime
|
||||
|
||||
|
||||
def _find_last_request_id(messages: list[ChatMessage]) -> str | None:
|
||||
for msg in reversed(messages):
|
||||
request_id = msg.metadata.get("request_id")
|
||||
if isinstance(request_id, str) and request_id:
|
||||
return request_id
|
||||
return None
|
||||
|
||||
|
||||
def _is_cancelled_message(msg: ChatMessage) -> bool:
|
||||
return msg.role == "assistant" and msg.metadata.get("cancelled") is True
|
||||
|
||||
|
||||
def _append_cancelled_tool_responses(
|
||||
store: ChatStore,
|
||||
session_id: UUID,
|
||||
user_id: str,
|
||||
request_id: str | None,
|
||||
tool_calls: list[Any],
|
||||
) -> None:
|
||||
for call in tool_calls:
|
||||
if not isinstance(call, dict):
|
||||
continue
|
||||
tool_call_id = call.get("id")
|
||||
if not tool_call_id:
|
||||
continue
|
||||
metadata = {"tool_call_id": tool_call_id, "tool_name": call.get("name", "unknown")}
|
||||
if request_id:
|
||||
metadata["request_id"] = request_id
|
||||
tool_response = ChatMessage(
|
||||
role="tool",
|
||||
content="Cancelled by user",
|
||||
metadata=metadata,
|
||||
)
|
||||
store.append_message(session_id, tool_response, user_id=user_id)
|
||||
|
||||
|
||||
def _append_cancellation_message(
|
||||
store: ChatStore,
|
||||
session_id: UUID,
|
||||
user_id: str,
|
||||
request_id: str | None,
|
||||
reason: str,
|
||||
) -> None:
|
||||
metadata: dict[str, Any] = {"kind": "internal", "cancelled": True}
|
||||
if request_id:
|
||||
metadata["request_id"] = request_id
|
||||
assistant_message = ChatMessage(
|
||||
role="assistant",
|
||||
content=reason,
|
||||
metadata=metadata,
|
||||
)
|
||||
store.append_message(session_id, assistant_message, user_id=user_id)
|
||||
|
||||
|
||||
def _finalize_cancelled_request(
|
||||
session: ChatSession | None,
|
||||
store: ChatStore,
|
||||
user_id: str,
|
||||
request_id: str | None,
|
||||
reason: str,
|
||||
) -> None:
|
||||
if not session or not session.messages:
|
||||
return
|
||||
last_msg = session.messages[-1]
|
||||
if _is_cancelled_message(last_msg):
|
||||
return
|
||||
if last_msg.role == "assistant" and "tool_calls" in last_msg.metadata:
|
||||
tool_calls = last_msg.metadata.get("tool_calls", [])
|
||||
if isinstance(tool_calls, list):
|
||||
_append_cancelled_tool_responses(
|
||||
store, session.session_id, user_id, request_id, tool_calls
|
||||
)
|
||||
_append_cancellation_message(store, session.session_id, user_id, request_id, reason)
|
||||
|
||||
|
||||
def _auto_cancel_pending_request(
|
||||
session: ChatSession,
|
||||
store: ChatStore,
|
||||
user_id: str,
|
||||
) -> None:
|
||||
if not session.messages:
|
||||
return
|
||||
last_msg = session.messages[-1]
|
||||
if last_msg.role == "assistant" and "tool_calls" not in last_msg.metadata:
|
||||
return
|
||||
if _is_cancelled_message(last_msg):
|
||||
return
|
||||
request_id = _find_last_request_id(session.messages)
|
||||
_finalize_cancelled_request(
|
||||
session,
|
||||
store,
|
||||
user_id,
|
||||
request_id,
|
||||
"Previous request cancelled by a newer message.",
|
||||
)
|
||||
|
||||
|
||||
class CancellationToken:
|
||||
def __init__(self, session_id: UUID, request_id: str, registry: CancellationRegistry) -> None:
|
||||
self._session_id = session_id
|
||||
self._request_id = request_id
|
||||
self._registry = registry
|
||||
|
||||
def is_set(self) -> bool:
|
||||
return self._registry.is_cancelled(self._session_id, self._request_id)
|
||||
|
||||
def set(self) -> None:
|
||||
self._registry.cancel(self._session_id, self._request_id)
|
||||
|
||||
|
||||
class CancellationRegistry:
|
||||
def __init__(self) -> None:
|
||||
self._lock = threading.Lock()
|
||||
self._events: dict[str, threading.Event] = {}
|
||||
|
||||
def register(self, session_id: UUID, request_id: str) -> CancellationToken:
|
||||
token_key = self._token_key(session_id, request_id)
|
||||
with self._lock:
|
||||
event = self._events.get(token_key)
|
||||
if event is None:
|
||||
event = threading.Event()
|
||||
self._events[token_key] = event
|
||||
return CancellationToken(session_id, request_id, self)
|
||||
|
||||
def cancel(self, session_id: UUID, request_id: str) -> bool:
|
||||
found = False
|
||||
token_key = self._token_key(session_id, request_id)
|
||||
with self._lock:
|
||||
event = self._events.get(token_key)
|
||||
if event is not None:
|
||||
event.set()
|
||||
found = True
|
||||
return found
|
||||
|
||||
def is_cancelled(self, session_id: UUID, request_id: str) -> bool:
|
||||
token_key = self._token_key(session_id, request_id)
|
||||
with self._lock:
|
||||
event = self._events.get(token_key)
|
||||
if event is not None and event.is_set():
|
||||
return True
|
||||
return False
|
||||
|
||||
def clear(self, session_id: UUID, request_id: str) -> None:
|
||||
token_key = self._token_key(session_id, request_id)
|
||||
with self._lock:
|
||||
self._events.pop(token_key, None)
|
||||
|
||||
def _token_key(self, session_id: UUID, request_id: str) -> str:
|
||||
return f"{session_id}:{request_id}"
|
||||
|
||||
|
||||
_cancellation_registry = CancellationRegistry()
|
||||
|
||||
|
||||
class CreateSessionRequest(BaseModel):
|
||||
title: str | None = Field(default=None)
|
||||
|
||||
|
||||
class ChatSessionSummary(BaseModel):
|
||||
session_id: UUID
|
||||
title: str | None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
message_count: int
|
||||
|
||||
|
||||
class ChatMessageRequest(BaseModel):
|
||||
role: Literal["user", "assistant", "system"] = Field(default="user")
|
||||
content: str
|
||||
metadata: dict[str, Any] | None = Field(default=None)
|
||||
request_id: str | None = Field(default=None)
|
||||
|
||||
|
||||
class BulkDeleteResponse(BaseModel):
|
||||
status: str
|
||||
deleted: int
|
||||
|
||||
|
||||
class CancelChatRequest(BaseModel):
|
||||
request_id: str = Field(..., min_length=1)
|
||||
|
||||
|
||||
@router.get("/sessions", response_model=list[ChatSessionSummary])
|
||||
def list_sessions(
|
||||
store: ChatStore = _CHAT_STORE,
|
||||
identity: IdentityContext = _VIEWER_IDENTITY,
|
||||
) -> list[ChatSessionSummary]:
|
||||
sessions = store.list_sessions(limit=50, user_id=identity.user_id)
|
||||
return [
|
||||
ChatSessionSummary(
|
||||
session_id=session.session_id,
|
||||
title=session.title,
|
||||
created_at=session.created_at,
|
||||
updated_at=session.updated_at,
|
||||
message_count=len(session.messages),
|
||||
)
|
||||
for session in sessions
|
||||
]
|
||||
|
||||
|
||||
@router.post("/sessions", response_model=ChatSession)
|
||||
def create_session(
|
||||
request: CreateSessionRequest,
|
||||
store: ChatStore = _CHAT_STORE,
|
||||
identity: IdentityContext = _VIEWER_IDENTITY,
|
||||
) -> ChatSession:
|
||||
return store.create_session(title=request.title, user_id=identity.user_id)
|
||||
|
||||
|
||||
@router.get("/sessions/{session_id}", response_model=ChatSession)
|
||||
def get_session(
|
||||
session_id: UUID,
|
||||
store: ChatStore = _CHAT_STORE,
|
||||
identity: IdentityContext = _VIEWER_IDENTITY,
|
||||
) -> ChatSession:
|
||||
session = store.get_session(session_id, user_id=identity.user_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="session not found")
|
||||
return session
|
||||
|
||||
|
||||
@router.delete("/sessions/{session_id}")
|
||||
def delete_session(
|
||||
session_id: UUID,
|
||||
store: ChatStore = _CHAT_STORE,
|
||||
identity: IdentityContext = _VIEWER_IDENTITY,
|
||||
) -> dict[str, str]:
|
||||
deleted = store.delete_session(session_id, user_id=identity.user_id)
|
||||
if not deleted:
|
||||
raise HTTPException(status_code=404, detail="session not found")
|
||||
return {"status": "deleted"}
|
||||
|
||||
|
||||
@router.delete("/sessions", response_model=BulkDeleteResponse)
|
||||
def delete_sessions(
|
||||
store: ChatStore = _CHAT_STORE,
|
||||
identity: IdentityContext = _EXECUTOR_IDENTITY,
|
||||
) -> BulkDeleteResponse:
|
||||
deleted = 0
|
||||
while True:
|
||||
sessions = store.list_sessions(limit=200, user_id=identity.user_id)
|
||||
if not sessions:
|
||||
break
|
||||
for session in sessions:
|
||||
if store.delete_session(session.session_id, user_id=identity.user_id):
|
||||
deleted += 1
|
||||
return BulkDeleteResponse(status="deleted", deleted=deleted)
|
||||
|
||||
|
||||
@router.post("/sessions/{session_id}/messages", response_model=ChatSession)
|
||||
def add_message(
|
||||
request: Request,
|
||||
session_id: UUID,
|
||||
payload: ChatMessageRequest,
|
||||
store: ChatStore = _CHAT_STORE,
|
||||
llm_client: LiteLLMClient = _LLM_CLIENT,
|
||||
repository: GraphRepository = _GRAPH_REPOSITORY,
|
||||
stream: bool = False,
|
||||
identity: IdentityContext = _VIEWER_IDENTITY,
|
||||
) -> ChatSession | StreamingResponse:
|
||||
from eidolon.api.dependencies import get_settings_store
|
||||
|
||||
request_id = payload.request_id
|
||||
if stream and not request_id:
|
||||
request_id = f"req_{uuid4()}"
|
||||
metadata = dict(payload.metadata or {})
|
||||
if request_id:
|
||||
metadata["request_id"] = request_id
|
||||
if payload.role == "user":
|
||||
existing_session = store.get_session(session_id, user_id=identity.user_id)
|
||||
if not existing_session:
|
||||
raise HTTPException(status_code=404, detail="session not found")
|
||||
_auto_cancel_pending_request(existing_session, store, identity.user_id)
|
||||
message = ChatMessage(
|
||||
role=payload.role,
|
||||
content=payload.content,
|
||||
metadata=metadata,
|
||||
)
|
||||
session = store.append_message(session_id, message, user_id=identity.user_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="session not found")
|
||||
|
||||
# If user message, run the agent loop to generate response
|
||||
if payload.role == "user" and llm_client.is_available():
|
||||
settings_store = get_settings_store()
|
||||
sandbox = _build_sandbox(repository, settings_store)
|
||||
system_prompt = build_system_prompt(
|
||||
sandbox.active_tools.values(), sandbox.settings, repository
|
||||
)
|
||||
agent = AssistantAgent(
|
||||
llm_client=llm_client,
|
||||
sandbox=sandbox,
|
||||
system_prompt=system_prompt,
|
||||
max_iterations=10,
|
||||
)
|
||||
if stream:
|
||||
cancellation_token = (
|
||||
_cancellation_registry.register(session_id, request_id) if request_id else None
|
||||
)
|
||||
|
||||
def event_stream():
|
||||
cancelled = False
|
||||
try:
|
||||
for msg in agent.run_iter(
|
||||
session.messages, cancellation_token=cancellation_token
|
||||
):
|
||||
if cancellation_token and cancellation_token.is_set():
|
||||
cancelled = True
|
||||
break
|
||||
if anyio.from_thread.run(request.is_disconnected):
|
||||
cancelled = True
|
||||
if cancellation_token:
|
||||
cancellation_token.set()
|
||||
break
|
||||
if request_id:
|
||||
msg.metadata["request_id"] = request_id
|
||||
stored = store.append_message(session_id, msg, user_id=identity.user_id)
|
||||
if stored:
|
||||
payload = {
|
||||
"type": "message",
|
||||
"message": msg.model_dump(mode="json"),
|
||||
}
|
||||
yield json.dumps(payload) + "\n"
|
||||
except Exception as e:
|
||||
logger.exception("Agent loop failed")
|
||||
error_metadata = {"kind": "error"}
|
||||
if request_id:
|
||||
error_metadata["request_id"] = request_id
|
||||
assistant_message = ChatMessage(
|
||||
role="assistant",
|
||||
content=f"I encountered an error: {e}",
|
||||
metadata=error_metadata,
|
||||
)
|
||||
store.append_message(session_id, assistant_message, user_id=identity.user_id)
|
||||
payload = {
|
||||
"type": "message",
|
||||
"message": assistant_message.model_dump(mode="json"),
|
||||
}
|
||||
yield json.dumps(payload) + "\n"
|
||||
finally:
|
||||
if request_id:
|
||||
was_cancelled = cancelled or (
|
||||
cancellation_token and cancellation_token.is_set()
|
||||
)
|
||||
_cancellation_registry.clear(session_id, request_id)
|
||||
if was_cancelled:
|
||||
current_session = store.get_session(
|
||||
session_id, user_id=identity.user_id
|
||||
)
|
||||
_finalize_cancelled_request(
|
||||
current_session,
|
||||
store,
|
||||
identity.user_id,
|
||||
request_id,
|
||||
"Request cancelled by user.",
|
||||
)
|
||||
yield json.dumps({"type": "done"}) + "\n"
|
||||
|
||||
return StreamingResponse(
|
||||
event_stream(),
|
||||
media_type="application/x-ndjson",
|
||||
headers={"Cache-Control": "no-cache"},
|
||||
)
|
||||
|
||||
try:
|
||||
for msg in agent.run_iter(session.messages):
|
||||
if request_id:
|
||||
msg.metadata["request_id"] = request_id
|
||||
session = store.append_message(session_id, msg, user_id=identity.user_id)
|
||||
except Exception as e:
|
||||
logger.exception("Agent loop failed")
|
||||
error_metadata = {"kind": "error"}
|
||||
if request_id:
|
||||
error_metadata["request_id"] = request_id
|
||||
assistant_message = ChatMessage(
|
||||
role="assistant",
|
||||
content=f"I encountered an error: {e}",
|
||||
metadata=error_metadata,
|
||||
)
|
||||
session = store.append_message(session_id, assistant_message, user_id=identity.user_id)
|
||||
|
||||
return session
|
||||
|
||||
|
||||
@router.post("/sessions/{session_id}/cancel")
|
||||
def cancel_request(
|
||||
session_id: UUID,
|
||||
payload: CancelChatRequest,
|
||||
store: ChatStore = _CHAT_STORE,
|
||||
identity: IdentityContext = _VIEWER_IDENTITY,
|
||||
) -> dict[str, str]:
|
||||
cancelled = _cancellation_registry.cancel(session_id, payload.request_id)
|
||||
if cancelled:
|
||||
return {"status": "cancelled"}
|
||||
return {"status": "not_found"}
|
||||
@@ -0,0 +1,586 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import ipaddress
|
||||
import threading
|
||||
from datetime import datetime
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from eidolon.api.dependencies import (
|
||||
get_audit_store,
|
||||
get_entity_resolver,
|
||||
get_graph_repository,
|
||||
get_scanner_store,
|
||||
require_roles,
|
||||
)
|
||||
from eidolon.api.middleware.auth import IdentityContext
|
||||
from eidolon.collectors.factory import build_manager
|
||||
from eidolon.collectors.network import ScanCancelledError
|
||||
from eidolon.core.graph.repository import GraphRepository
|
||||
from eidolon.core.models.event import AuditEvent
|
||||
from eidolon.core.models.scanner import ScannerConfig
|
||||
from eidolon.core.reasoning.entity import EntityResolver
|
||||
from eidolon.core.stores import AuditStore, ScannerStore
|
||||
from eidolon.runtime.task_events import TaskEvent, task_event_bus
|
||||
from eidolon.worker.ingest import IngestWorker
|
||||
|
||||
|
||||
class CollectorRunResponse(BaseModel):
|
||||
task_id: str
|
||||
status: str = "started"
|
||||
|
||||
|
||||
class ScanHistoryItem(BaseModel):
|
||||
id: str
|
||||
started_at: datetime | None = None
|
||||
completed_at: datetime | None = None
|
||||
status: str
|
||||
events_collected: int
|
||||
error_message: str | None = None
|
||||
config_summary: str | None = None
|
||||
|
||||
|
||||
class ScanHistoryResponse(BaseModel):
|
||||
scans: list[ScanHistoryItem]
|
||||
|
||||
|
||||
class CancelScanRequest(BaseModel):
|
||||
task_id: str
|
||||
|
||||
|
||||
router = APIRouter(prefix="/collector", tags=["collector"])
|
||||
_AUDIT_STORE = Depends(get_audit_store)
|
||||
_SCANNER_STORE = Depends(get_scanner_store)
|
||||
_GRAPH_REPOSITORY = Depends(get_graph_repository)
|
||||
_ENTITY_RESOLVER = Depends(get_entity_resolver)
|
||||
_VIEWER_IDENTITY = Depends(require_roles("viewer", "planner", "executor"))
|
||||
_PLANNER_EXECUTOR_IDENTITY = Depends(require_roles("planner", "executor"))
|
||||
|
||||
|
||||
PORT_PRESET_PORTS: dict[str, list[int]] = {
|
||||
"fast": [80, 443],
|
||||
"normal": [
|
||||
21,
|
||||
22,
|
||||
23,
|
||||
25,
|
||||
53,
|
||||
80,
|
||||
110,
|
||||
143,
|
||||
443,
|
||||
465,
|
||||
587,
|
||||
993,
|
||||
995,
|
||||
3306,
|
||||
3389,
|
||||
5432,
|
||||
8080,
|
||||
8443,
|
||||
],
|
||||
}
|
||||
VALID_PORT_PRESETS = {"fast", "normal", "full", "custom"}
|
||||
|
||||
|
||||
class _ScanRegistry:
|
||||
def __init__(self) -> None:
|
||||
self._active: set[str] = set()
|
||||
self._cancelled: set[str] = set()
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def register(self, task_id: str) -> None:
|
||||
with self._lock:
|
||||
self._active.add(task_id)
|
||||
|
||||
def cancel(self, task_id: str) -> bool:
|
||||
with self._lock:
|
||||
if task_id in self._active:
|
||||
self._cancelled.add(task_id)
|
||||
return True
|
||||
return False
|
||||
|
||||
def is_cancelled(self, task_id: str) -> bool:
|
||||
with self._lock:
|
||||
return task_id in self._cancelled
|
||||
|
||||
def clear(self, task_id: str) -> None:
|
||||
with self._lock:
|
||||
self._active.discard(task_id)
|
||||
self._cancelled.discard(task_id)
|
||||
|
||||
|
||||
_scan_registry = _ScanRegistry()
|
||||
|
||||
|
||||
def _parse_target_range(value: str) -> tuple[int, int]:
|
||||
if "/" in value:
|
||||
network = ipaddress.ip_network(value, strict=False)
|
||||
if network.version != 4:
|
||||
raise ValueError("Only IPv4 targets are supported")
|
||||
return int(network.network_address), int(network.broadcast_address)
|
||||
|
||||
if "-" in value:
|
||||
start_str, end_str = value.split("-", 1)
|
||||
start_ip = ipaddress.ip_address(start_str)
|
||||
if start_ip.version != 4:
|
||||
raise ValueError("Only IPv4 targets are supported")
|
||||
if "." in end_str:
|
||||
end_ip = ipaddress.ip_address(end_str)
|
||||
else:
|
||||
parts = start_str.split(".")
|
||||
if len(parts) != 4:
|
||||
raise ValueError("Invalid IP range")
|
||||
end_ip = ipaddress.ip_address(".".join([*parts[:3], end_str]))
|
||||
if end_ip.version != 4:
|
||||
raise ValueError("Only IPv4 targets are supported")
|
||||
start_val = int(start_ip)
|
||||
end_val = int(end_ip)
|
||||
if end_val < start_val:
|
||||
raise ValueError("Range end must be greater than start")
|
||||
return start_val, end_val
|
||||
|
||||
ip_val = ipaddress.ip_address(value)
|
||||
if ip_val.version != 4:
|
||||
raise ValueError("Only IPv4 targets are supported")
|
||||
ip_int = int(ip_val)
|
||||
return ip_int, ip_int
|
||||
|
||||
|
||||
def _validate_targets(targets: list[str]) -> None:
|
||||
if not targets:
|
||||
raise HTTPException(status_code=422, detail="At least one target is required")
|
||||
if len(targets) > 50:
|
||||
raise HTTPException(status_code=422, detail="Maximum of 50 targets allowed")
|
||||
normalized = [target.strip() for target in targets if target.strip()]
|
||||
if len(set(normalized)) != len(normalized):
|
||||
raise HTTPException(status_code=422, detail="Duplicate targets are not allowed")
|
||||
|
||||
ranges = []
|
||||
for target in normalized:
|
||||
try:
|
||||
start, end = _parse_target_range(target)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=422, detail=str(exc)) from exc
|
||||
ranges.append((start, end, target))
|
||||
|
||||
ranges.sort(key=lambda item: item[0])
|
||||
for idx in range(1, len(ranges)):
|
||||
_prev_start, prev_end, prev_target = ranges[idx - 1]
|
||||
curr_start, _curr_end, curr_target = ranges[idx]
|
||||
if curr_start <= prev_end:
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail=f"Target {curr_target} overlaps {prev_target}",
|
||||
)
|
||||
|
||||
|
||||
def _validate_ports(port_preset: str, ports: list[int]) -> list[int]:
|
||||
if port_preset not in VALID_PORT_PRESETS:
|
||||
raise HTTPException(status_code=422, detail="Invalid port preset")
|
||||
|
||||
if port_preset in PORT_PRESET_PORTS:
|
||||
return PORT_PRESET_PORTS[port_preset]
|
||||
|
||||
if port_preset == "full":
|
||||
return []
|
||||
|
||||
if not ports:
|
||||
raise HTTPException(status_code=422, detail="Custom ports are required")
|
||||
|
||||
seen: set[int] = set()
|
||||
normalized: list[int] = []
|
||||
for port in ports:
|
||||
if not isinstance(port, int):
|
||||
raise HTTPException(status_code=422, detail="Ports must be integers")
|
||||
if port < 1 or port > 65535:
|
||||
raise HTTPException(status_code=422, detail="Ports must be between 1 and 65535")
|
||||
if port in seen:
|
||||
raise HTTPException(status_code=422, detail="Duplicate ports are not allowed")
|
||||
seen.add(port)
|
||||
normalized.append(port)
|
||||
|
||||
if len(normalized) > 1000:
|
||||
raise HTTPException(status_code=422, detail="Maximum of 1000 ports allowed")
|
||||
|
||||
return normalized
|
||||
|
||||
|
||||
def _normalize_config(config: ScannerConfig) -> ScannerConfig:
|
||||
config.network_cidrs = [target.strip() for target in config.network_cidrs if target.strip()]
|
||||
_validate_targets(config.network_cidrs)
|
||||
|
||||
config.ports = _validate_ports(config.port_preset, config.ports)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def _format_config_summary(config: ScannerConfig) -> str:
|
||||
targets = ", ".join(config.network_cidrs)
|
||||
if config.port_preset == "full":
|
||||
port_label = "ports 1-65535"
|
||||
elif config.ports:
|
||||
head = ",".join(str(port) for port in config.ports[:5])
|
||||
port_label = f"ports {head}{'...' if len(config.ports) > 5 else ''}"
|
||||
else:
|
||||
port_label = "ports none"
|
||||
return " ".join([part for part in [targets, port_label] if part]).strip()
|
||||
|
||||
|
||||
def _build_scan_config(config: ScannerConfig) -> dict:
|
||||
return {
|
||||
"network": {
|
||||
"cidrs": config.network_cidrs,
|
||||
"ping_concurrency": config.options.ping_concurrency,
|
||||
"port_scan_workers": config.options.port_scan_workers,
|
||||
"ports": config.ports,
|
||||
"port_preset": config.port_preset,
|
||||
"dns_resolution": config.options.dns_resolution,
|
||||
"aggressive": config.options.aggressive,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def _run_scan_sync(
|
||||
task_id: str,
|
||||
config: dict,
|
||||
config_summary: str,
|
||||
repository: GraphRepository,
|
||||
resolver: EntityResolver,
|
||||
audit_store: AuditStore,
|
||||
) -> None:
|
||||
"""Synchronous scan logic that runs in background."""
|
||||
worker = IngestWorker(repository, resolver)
|
||||
|
||||
# Track stats per collector
|
||||
collector_stats: dict[str, dict] = {}
|
||||
current_collector: str | None = None
|
||||
|
||||
def emit_fn(event) -> None:
|
||||
nonlocal current_collector
|
||||
if current_collector:
|
||||
stats = collector_stats[current_collector]
|
||||
stats["events_processed"] += 1
|
||||
# Track entity types
|
||||
entity_type = event.entity_type
|
||||
stats["by_type"][entity_type] = stats["by_type"].get(entity_type, 0) + 1
|
||||
worker.process_event(event)
|
||||
|
||||
def progress_fn(line: str) -> None:
|
||||
"""Publish scan progress output to event bus."""
|
||||
task_event_bus.publish(
|
||||
TaskEvent(
|
||||
event_type="collector.scan",
|
||||
status="progress",
|
||||
payload={"task_id": task_id, "output": line},
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
manager = build_manager(
|
||||
config,
|
||||
emit_fn,
|
||||
cancellation_checker=lambda: _scan_registry.is_cancelled(task_id),
|
||||
progress_callback=progress_fn,
|
||||
)
|
||||
collectors = manager.list_collectors()
|
||||
|
||||
# Initialize stats for each collector
|
||||
for name in collectors:
|
||||
collector_stats[name] = {
|
||||
"events_processed": 0,
|
||||
"by_type": {},
|
||||
"status": "pending",
|
||||
}
|
||||
|
||||
# Emit scan start events
|
||||
audit_store.add(
|
||||
AuditEvent(
|
||||
event_type="collector.scan.started",
|
||||
details={
|
||||
"collectors": collectors,
|
||||
"task_id": task_id,
|
||||
"config_summary": config_summary,
|
||||
},
|
||||
status="running",
|
||||
)
|
||||
)
|
||||
task_event_bus.publish(
|
||||
TaskEvent(
|
||||
event_type="collector.scan",
|
||||
status="started",
|
||||
payload={"collectors": collectors, "task_id": task_id},
|
||||
)
|
||||
)
|
||||
|
||||
# Run each collector and track results
|
||||
errors: list[Exception] = []
|
||||
for collector_name in collectors:
|
||||
if _scan_registry.is_cancelled(task_id):
|
||||
audit_store.add(
|
||||
AuditEvent(
|
||||
event_type="collector.scan.cancelled",
|
||||
details={"task_id": task_id, "config_summary": config_summary},
|
||||
status="cancelled",
|
||||
)
|
||||
)
|
||||
task_event_bus.publish(
|
||||
TaskEvent(
|
||||
event_type="collector.scan",
|
||||
status="cancelled",
|
||||
payload={"task_id": task_id},
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
current_collector = collector_name
|
||||
collector = manager._collectors[collector_name]
|
||||
|
||||
# Publish task event for this collector starting
|
||||
task_event_bus.publish(
|
||||
TaskEvent(
|
||||
event_type="collector.scan",
|
||||
status="running",
|
||||
payload={"current_collector": collector_name, "task_id": task_id},
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
collector.run()
|
||||
collector_stats[collector_name]["status"] = "ok"
|
||||
|
||||
# Emit per-collector audit event
|
||||
stats = collector_stats[collector_name]
|
||||
audit_store.add(
|
||||
AuditEvent(
|
||||
event_type=f"collector.{collector_name}",
|
||||
details={
|
||||
"events_processed": stats["events_processed"],
|
||||
"by_type": stats["by_type"],
|
||||
"task_id": task_id,
|
||||
},
|
||||
status="ok",
|
||||
)
|
||||
)
|
||||
|
||||
# Publish task event for collector completion
|
||||
task_event_bus.publish(
|
||||
TaskEvent(
|
||||
event_type="collector.scan",
|
||||
status="progress",
|
||||
payload={
|
||||
"collector": collector_name,
|
||||
"events_processed": stats["events_processed"],
|
||||
"task_id": task_id,
|
||||
},
|
||||
)
|
||||
)
|
||||
except ScanCancelledError:
|
||||
# Scan was cancelled during this collector
|
||||
collector_stats[collector_name]["status"] = "cancelled"
|
||||
audit_store.add(
|
||||
AuditEvent(
|
||||
event_type="collector.scan.cancelled",
|
||||
details={
|
||||
"task_id": task_id,
|
||||
"config_summary": config_summary,
|
||||
"collector": collector_name,
|
||||
},
|
||||
status="cancelled",
|
||||
)
|
||||
)
|
||||
task_event_bus.publish(
|
||||
TaskEvent(
|
||||
event_type="collector.scan",
|
||||
status="cancelled",
|
||||
payload={"task_id": task_id, "collector": collector_name},
|
||||
)
|
||||
)
|
||||
return
|
||||
except Exception as exc: # noqa: BLE001
|
||||
errors.append(exc)
|
||||
collector_stats[collector_name]["status"] = "failed"
|
||||
collector_stats[collector_name]["error"] = str(exc)
|
||||
events_processed = collector_stats[collector_name]["events_processed"]
|
||||
|
||||
audit_store.add(
|
||||
AuditEvent(
|
||||
event_type=f"collector.{collector_name}",
|
||||
details={
|
||||
"events_processed": events_processed,
|
||||
"error": str(exc),
|
||||
"task_id": task_id,
|
||||
},
|
||||
status="failed",
|
||||
)
|
||||
)
|
||||
|
||||
current_collector = None
|
||||
total_events = sum(stats["events_processed"] for stats in collector_stats.values())
|
||||
if errors and total_events == 0:
|
||||
status = "failed"
|
||||
elif errors:
|
||||
status = "partial"
|
||||
else:
|
||||
status = "complete"
|
||||
|
||||
audit_store.add(
|
||||
AuditEvent(
|
||||
event_type="collector.scan.complete",
|
||||
details={
|
||||
"collectors": collectors,
|
||||
"total_events": total_events,
|
||||
"collector_stats": collector_stats,
|
||||
"errors": [str(err) for err in errors],
|
||||
"task_id": task_id,
|
||||
"config_summary": config_summary,
|
||||
"status": status,
|
||||
},
|
||||
status=(
|
||||
"ok"
|
||||
if status == "complete"
|
||||
else "failed" if status == "failed" else "partial_failure"
|
||||
),
|
||||
)
|
||||
)
|
||||
task_event_bus.publish(
|
||||
TaskEvent(
|
||||
event_type="collector.scan",
|
||||
status="complete" if status == "complete" else status,
|
||||
payload={
|
||||
"collectors": collectors,
|
||||
"total_events": total_events,
|
||||
"errors": [str(err) for err in errors],
|
||||
"task_id": task_id,
|
||||
},
|
||||
)
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
audit_store.add(
|
||||
AuditEvent(
|
||||
event_type="collector.scan.failed",
|
||||
details={"error": str(exc), "task_id": task_id, "config_summary": config_summary},
|
||||
status="failed",
|
||||
)
|
||||
)
|
||||
task_event_bus.publish(
|
||||
TaskEvent(
|
||||
event_type="collector.scan",
|
||||
status="failed",
|
||||
payload={"error": str(exc), "task_id": task_id},
|
||||
)
|
||||
)
|
||||
finally:
|
||||
_scan_registry.clear(task_id)
|
||||
|
||||
|
||||
@router.get("/config", response_model=ScannerConfig)
|
||||
async def get_config(
|
||||
scanner_store: ScannerStore = _SCANNER_STORE,
|
||||
identity: IdentityContext = _VIEWER_IDENTITY,
|
||||
) -> ScannerConfig:
|
||||
record = scanner_store.get_config(identity.user_id)
|
||||
return record.config
|
||||
|
||||
|
||||
@router.put("/config", response_model=ScannerConfig)
|
||||
async def update_config(
|
||||
payload: dict,
|
||||
scanner_store: ScannerStore = _SCANNER_STORE,
|
||||
identity: IdentityContext = _PLANNER_EXECUTOR_IDENTITY,
|
||||
) -> ScannerConfig:
|
||||
try:
|
||||
config = ScannerConfig.model_validate(payload)
|
||||
except ValidationError as exc:
|
||||
raise HTTPException(status_code=422, detail=str(exc)) from exc
|
||||
|
||||
config = _normalize_config(config)
|
||||
record = scanner_store.update_config(identity.user_id, config)
|
||||
return record.config
|
||||
|
||||
|
||||
@router.get("/scan/history", response_model=ScanHistoryResponse)
|
||||
async def scan_history(
|
||||
limit: int = 10,
|
||||
audit_store: AuditStore = _AUDIT_STORE,
|
||||
identity: IdentityContext = _VIEWER_IDENTITY,
|
||||
) -> ScanHistoryResponse:
|
||||
"""Get scan history from audit log instead of separate scan_runs table."""
|
||||
# Query audit events for scan completions
|
||||
events = audit_store.list_filtered(
|
||||
page=1,
|
||||
page_size=limit,
|
||||
event_type="collector.scan.complete",
|
||||
)
|
||||
|
||||
scans = []
|
||||
for event in events:
|
||||
details = event.details or {}
|
||||
scans.append(
|
||||
ScanHistoryItem(
|
||||
id=str(event.audit_id),
|
||||
started_at=event.timestamp,
|
||||
completed_at=event.timestamp,
|
||||
status=details.get("status", "complete"),
|
||||
events_collected=details.get("total_events", 0),
|
||||
error_message=(
|
||||
"; ".join(details.get("errors", [])) if details.get("errors") else None
|
||||
),
|
||||
config_summary=details.get("config_summary"),
|
||||
)
|
||||
)
|
||||
|
||||
return ScanHistoryResponse(scans=scans)
|
||||
|
||||
|
||||
@router.post("/scan", response_model=CollectorRunResponse)
|
||||
async def trigger_scan(
|
||||
background_tasks: BackgroundTasks,
|
||||
repository: GraphRepository = _GRAPH_REPOSITORY,
|
||||
resolver: EntityResolver = _ENTITY_RESOLVER,
|
||||
audit_store: AuditStore = _AUDIT_STORE,
|
||||
scanner_store: ScannerStore = _SCANNER_STORE,
|
||||
identity: IdentityContext = _PLANNER_EXECUTOR_IDENTITY,
|
||||
) -> CollectorRunResponse:
|
||||
"""Start a network scan in the background and return immediately."""
|
||||
task_id = str(uuid4())
|
||||
record = scanner_store.get_config(identity.user_id)
|
||||
config = _normalize_config(record.config)
|
||||
config_summary = _format_config_summary(config)
|
||||
|
||||
_scan_registry.register(task_id)
|
||||
|
||||
# Schedule background task
|
||||
background_tasks.add_task(
|
||||
_run_scan_sync,
|
||||
task_id=task_id,
|
||||
config=_build_scan_config(config),
|
||||
config_summary=config_summary,
|
||||
repository=repository,
|
||||
resolver=resolver,
|
||||
audit_store=audit_store,
|
||||
)
|
||||
|
||||
return CollectorRunResponse(
|
||||
task_id=task_id,
|
||||
status="started",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/scan/cancel")
|
||||
async def cancel_scan(
|
||||
payload: CancelScanRequest,
|
||||
identity: IdentityContext = _PLANNER_EXECUTOR_IDENTITY,
|
||||
) -> dict:
|
||||
if not _scan_registry.cancel(payload.task_id):
|
||||
raise HTTPException(status_code=404, detail="Scan not found")
|
||||
task_event_bus.publish(
|
||||
TaskEvent(
|
||||
event_type="collector.scan",
|
||||
status="cancelling",
|
||||
payload={"task_id": payload.task_id, "user_id": identity.user_id},
|
||||
)
|
||||
)
|
||||
return {"status": "cancelling"}
|
||||
@@ -0,0 +1,235 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from neo4j.exceptions import Neo4jError
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from eidolon.api.dependencies import get_graph_repository, require_roles
|
||||
from eidolon.api.middleware.auth import IdentityContext
|
||||
from eidolon.core.graph.repository import GraphRepository
|
||||
from eidolon.core.models.asset import Asset, NetworkContainer
|
||||
from eidolon.core.models.graph import GraphPath, Node
|
||||
|
||||
router = APIRouter(prefix="/graph", tags=["graph"])
|
||||
_GRAPH_REPOSITORY = Depends(get_graph_repository)
|
||||
_VIEWER_IDENTITY = Depends(require_roles("viewer", "planner", "executor"))
|
||||
_EXECUTOR_IDENTITY = Depends(require_roles("executor"))
|
||||
_LIMIT_QUERY = Query(100, ge=1, le=500)
|
||||
_MAX_DEPTH_QUERY = Query(4, ge=1, le=8)
|
||||
_NODE_LIMIT_QUERY = Query(200, ge=1, le=1000)
|
||||
_EDGE_LIMIT_QUERY = Query(400, ge=1, le=2000)
|
||||
|
||||
|
||||
class GraphClearResponse(BaseModel):
|
||||
status: str
|
||||
nodes_deleted: int
|
||||
|
||||
|
||||
class GraphOverviewNode(BaseModel):
|
||||
node_id: UUID
|
||||
label: str
|
||||
name: str | None = None
|
||||
kind: str | None = None
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class GraphOverviewEdge(BaseModel):
|
||||
source: UUID
|
||||
target: UUID
|
||||
type: str
|
||||
confidence: float | None = None
|
||||
|
||||
|
||||
class GraphOverviewResponse(BaseModel):
|
||||
nodes: list[GraphOverviewNode]
|
||||
edges: list[GraphOverviewEdge]
|
||||
|
||||
|
||||
class GraphQueryRequest(BaseModel):
|
||||
cypher: str
|
||||
parameters: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class GraphQueryResponse(BaseModel):
|
||||
records: list[dict[str, Any]]
|
||||
|
||||
|
||||
def _coerce_metadata(value: object) -> dict[str, Any]:
|
||||
if isinstance(value, dict):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
parsed = json.loads(value)
|
||||
except (json.JSONDecodeError, TypeError, ValueError):
|
||||
return {}
|
||||
if isinstance(parsed, dict):
|
||||
return parsed
|
||||
return {}
|
||||
|
||||
|
||||
def _parse_uuid(value: object) -> UUID | None:
|
||||
try:
|
||||
return UUID(str(value))
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
|
||||
@router.get("/assets", response_model=list[Asset])
|
||||
def list_assets(
|
||||
limit: int = _LIMIT_QUERY,
|
||||
repository: GraphRepository = _GRAPH_REPOSITORY,
|
||||
identity: IdentityContext = _VIEWER_IDENTITY,
|
||||
) -> list[Node]:
|
||||
return repository.list_nodes(label="Asset", limit=limit)
|
||||
|
||||
|
||||
@router.get("/networks", response_model=list[NetworkContainer])
|
||||
def list_networks(
|
||||
limit: int = _LIMIT_QUERY,
|
||||
repository: GraphRepository = _GRAPH_REPOSITORY,
|
||||
identity: IdentityContext = _VIEWER_IDENTITY,
|
||||
) -> list[Node]:
|
||||
return repository.list_nodes(label="NetworkContainer", limit=limit)
|
||||
|
||||
|
||||
@router.get("/assets/{asset_id}", response_model=Asset)
|
||||
def get_asset(
|
||||
asset_id: UUID,
|
||||
repository: GraphRepository = _GRAPH_REPOSITORY,
|
||||
identity: IdentityContext = _VIEWER_IDENTITY,
|
||||
) -> Node:
|
||||
node = repository.get_node(asset_id)
|
||||
if not node or node.label != "Asset":
|
||||
raise HTTPException(status_code=404, detail="asset not found")
|
||||
return node
|
||||
|
||||
|
||||
@router.get("/paths", response_model=list[GraphPath])
|
||||
def get_paths(
|
||||
source_id: UUID,
|
||||
target_id: UUID,
|
||||
max_depth: int = _MAX_DEPTH_QUERY,
|
||||
repository: GraphRepository = _GRAPH_REPOSITORY,
|
||||
identity: IdentityContext = _VIEWER_IDENTITY,
|
||||
) -> list[GraphPath]:
|
||||
return repository.find_paths(source_id, target_id, max_depth)
|
||||
|
||||
|
||||
@router.delete("/", response_model=GraphClearResponse)
|
||||
def clear_graph(
|
||||
repository: GraphRepository = _GRAPH_REPOSITORY,
|
||||
identity: IdentityContext = _EXECUTOR_IDENTITY,
|
||||
) -> GraphClearResponse:
|
||||
deleted = repository.clear()
|
||||
return GraphClearResponse(status="cleared", nodes_deleted=deleted)
|
||||
|
||||
|
||||
@router.get("/overview", response_model=GraphOverviewResponse)
|
||||
def graph_overview(
|
||||
node_limit: int = _NODE_LIMIT_QUERY,
|
||||
edge_limit: int = _EDGE_LIMIT_QUERY,
|
||||
repository: GraphRepository = _GRAPH_REPOSITORY,
|
||||
identity: IdentityContext = _VIEWER_IDENTITY,
|
||||
) -> GraphOverviewResponse:
|
||||
nodes_result = list(
|
||||
repository.run_cypher(
|
||||
"""
|
||||
MATCH (n)
|
||||
RETURN n.node_id AS node_id,
|
||||
head(labels(n)) AS label,
|
||||
n.cidr AS cidr,
|
||||
n.name AS name,
|
||||
n.kind AS kind,
|
||||
n.metadata AS metadata
|
||||
LIMIT $limit
|
||||
""",
|
||||
{"limit": node_limit},
|
||||
)
|
||||
)
|
||||
|
||||
nodes: list[GraphOverviewNode] = []
|
||||
node_ids: list[str] = []
|
||||
for record in nodes_result:
|
||||
node_id = _parse_uuid(record.get("node_id"))
|
||||
if not node_id:
|
||||
continue
|
||||
metadata = _coerce_metadata(record.get("metadata"))
|
||||
|
||||
# Build display name priority: IP -> CIDR -> hostname -> name -> UUID.
|
||||
display_name = None
|
||||
if metadata:
|
||||
display_name = metadata.get("ip") or metadata.get("hostname")
|
||||
if not display_name:
|
||||
display_name = record.get("cidr") or record.get("name") or str(node_id)
|
||||
|
||||
node = GraphOverviewNode(
|
||||
node_id=node_id,
|
||||
label=str(record.get("label") or "Node"),
|
||||
name=display_name,
|
||||
kind=record.get("kind"),
|
||||
metadata=metadata,
|
||||
)
|
||||
nodes.append(node)
|
||||
node_ids.append(str(node_id))
|
||||
|
||||
if not node_ids:
|
||||
return GraphOverviewResponse(nodes=[], edges=[])
|
||||
|
||||
edges_result = list(
|
||||
repository.run_cypher(
|
||||
"""
|
||||
MATCH (a)-[r]->(b)
|
||||
WHERE a.node_id IN $node_ids AND b.node_id IN $node_ids
|
||||
RETURN a.node_id AS source,
|
||||
b.node_id AS target,
|
||||
type(r) AS type,
|
||||
r.confidence AS confidence
|
||||
LIMIT $limit
|
||||
""",
|
||||
{"node_ids": node_ids, "limit": edge_limit},
|
||||
)
|
||||
)
|
||||
|
||||
edges: list[GraphOverviewEdge] = []
|
||||
for record in edges_result:
|
||||
source = _parse_uuid(record.get("source"))
|
||||
target = _parse_uuid(record.get("target"))
|
||||
if not source or not target:
|
||||
continue
|
||||
edges.append(
|
||||
GraphOverviewEdge(
|
||||
source=source,
|
||||
target=target,
|
||||
type=str(record.get("type") or "RELATED"),
|
||||
confidence=record.get("confidence"),
|
||||
)
|
||||
)
|
||||
|
||||
return GraphOverviewResponse(nodes=nodes, edges=edges)
|
||||
|
||||
|
||||
@router.post("/query", response_model=GraphQueryResponse)
|
||||
def execute_cypher_query(
|
||||
request: GraphQueryRequest,
|
||||
repository: GraphRepository = _GRAPH_REPOSITORY,
|
||||
identity: IdentityContext = _VIEWER_IDENTITY,
|
||||
) -> GraphQueryResponse:
|
||||
"""
|
||||
Execute a raw Cypher query against the graph database.
|
||||
|
||||
This endpoint allows direct Cypher queries for advanced use cases.
|
||||
Results are returned as a list of records (dictionaries).
|
||||
"""
|
||||
try:
|
||||
results = repository.run_cypher(request.cypher, request.parameters or {})
|
||||
records = [dict(record) for record in results]
|
||||
return GraphQueryResponse(records=records)
|
||||
except (Neo4jError, TypeError, ValueError, RuntimeError) as exc:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Query execution failed: {exc!s}",
|
||||
) from exc
|
||||
@@ -0,0 +1,56 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from pydantic import BaseModel
|
||||
|
||||
from eidolon.api.dependencies import (
|
||||
get_audit_store,
|
||||
get_entity_resolver,
|
||||
get_graph_repository,
|
||||
require_roles,
|
||||
)
|
||||
from eidolon.api.middleware.auth import IdentityContext
|
||||
from eidolon.core.graph.repository import GraphRepository
|
||||
from eidolon.core.models.event import AuditEvent, CollectorEvent
|
||||
from eidolon.core.reasoning.entity import EntityResolver
|
||||
from eidolon.core.stores import AuditStore
|
||||
from eidolon.runtime.task_events import TaskEvent, task_event_bus
|
||||
from eidolon.worker.ingest import IngestWorker
|
||||
|
||||
|
||||
class IngestResponse(BaseModel):
|
||||
accepted: int
|
||||
|
||||
|
||||
router = APIRouter(prefix="/ingest", tags=["ingest"])
|
||||
_GRAPH_REPOSITORY = Depends(get_graph_repository)
|
||||
_ENTITY_RESOLVER = Depends(get_entity_resolver)
|
||||
_AUDIT_STORE = Depends(get_audit_store)
|
||||
_EXECUTOR_IDENTITY = Depends(require_roles("executor"))
|
||||
|
||||
|
||||
@router.post("/events", response_model=IngestResponse)
|
||||
def ingest_events(
|
||||
events: list[CollectorEvent],
|
||||
repository: GraphRepository = _GRAPH_REPOSITORY,
|
||||
resolver: EntityResolver = _ENTITY_RESOLVER,
|
||||
audit_store: AuditStore = _AUDIT_STORE,
|
||||
identity: IdentityContext = _EXECUTOR_IDENTITY,
|
||||
) -> IngestResponse:
|
||||
worker = IngestWorker(repository, resolver)
|
||||
worker.process(events)
|
||||
audit_store.add(
|
||||
AuditEvent(
|
||||
event_type="ingest",
|
||||
details={"accepted": len(events)},
|
||||
status="ok",
|
||||
)
|
||||
)
|
||||
task_event_bus.publish(
|
||||
TaskEvent(
|
||||
event_type="ingest",
|
||||
status="ok",
|
||||
payload={"accepted": len(events)},
|
||||
)
|
||||
)
|
||||
return IngestResponse(accepted=len(events))
|
||||
@@ -0,0 +1,38 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from pydantic import BaseModel
|
||||
|
||||
from eidolon.api.dependencies import get_settings_store, require_roles
|
||||
from eidolon.api.middleware.auth import IdentityContext
|
||||
from eidolon.config.settings import SandboxPermissions
|
||||
from eidolon.core.stores import SettingsStore
|
||||
|
||||
router = APIRouter(prefix="/permissions", tags=["permissions"])
|
||||
_SETTINGS_STORE = Depends(get_settings_store)
|
||||
_VIEWER_IDENTITY = Depends(require_roles("viewer", "planner", "executor"))
|
||||
_EXECUTOR_IDENTITY = Depends(require_roles("executor"))
|
||||
|
||||
|
||||
class PermissionsResponse(BaseModel):
|
||||
sandbox: SandboxPermissions
|
||||
|
||||
|
||||
@router.get("/", response_model=PermissionsResponse)
|
||||
def get_permissions(
|
||||
identity: IdentityContext = _VIEWER_IDENTITY,
|
||||
store: SettingsStore = _SETTINGS_STORE,
|
||||
) -> PermissionsResponse:
|
||||
"""Get sandbox permissions."""
|
||||
return PermissionsResponse(sandbox=store.get_settings())
|
||||
|
||||
|
||||
@router.put("/", response_model=PermissionsResponse)
|
||||
def update_permissions(
|
||||
permissions: SandboxPermissions,
|
||||
identity: IdentityContext = _EXECUTOR_IDENTITY,
|
||||
store: SettingsStore = _SETTINGS_STORE,
|
||||
) -> PermissionsResponse:
|
||||
"""Update sandbox permissions."""
|
||||
store.update_settings(permissions)
|
||||
return PermissionsResponse(sandbox=store.get_settings())
|
||||
@@ -0,0 +1,160 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from eidolon.api.dependencies import (
|
||||
get_approval_store,
|
||||
get_audit_store,
|
||||
get_graph_repository,
|
||||
get_llm_client,
|
||||
require_roles,
|
||||
)
|
||||
from eidolon.api.middleware.auth import IdentityContext
|
||||
from eidolon.config.settings import get_settings
|
||||
from eidolon.core.graph.algorithms import blast_radius
|
||||
from eidolon.core.graph.repository import GraphRepository
|
||||
from eidolon.core.models.event import AuditEvent
|
||||
from eidolon.core.models.plan import (
|
||||
BlastRadius,
|
||||
EntityRef,
|
||||
ExecutionRequest,
|
||||
ExecutionResponse,
|
||||
PlanStep,
|
||||
ToolExecutionResult,
|
||||
)
|
||||
from eidolon.core.reasoning.llm import LiteLLMClient
|
||||
from eidolon.core.reasoning.planner import Planner
|
||||
from eidolon.core.stores import ApprovalStore, AuditStore
|
||||
from eidolon.runtime.executor import ExecutionEngine
|
||||
from eidolon.runtime.task_events import TaskEvent, task_event_bus
|
||||
|
||||
|
||||
class PlanRequest(BaseModel):
|
||||
intent: str = Field(description="Intent to satisfy (natural language)")
|
||||
target: EntityRef = Field(description="Primary target entity for the plan")
|
||||
|
||||
|
||||
class PlanResponse(BaseModel):
|
||||
steps: list[PlanStep]
|
||||
blast_radius: BlastRadius | None = None
|
||||
|
||||
|
||||
router = APIRouter(prefix="/plan", tags=["plan"])
|
||||
_APPROVAL_STORE = Depends(get_approval_store)
|
||||
_AUDIT_STORE = Depends(get_audit_store)
|
||||
_GRAPH_REPOSITORY = Depends(get_graph_repository)
|
||||
_LLM_CLIENT = Depends(get_llm_client)
|
||||
_PLANNER_EXECUTOR_IDENTITY = Depends(require_roles("planner", "executor"))
|
||||
_EXECUTOR_IDENTITY = Depends(require_roles("executor"))
|
||||
|
||||
|
||||
@router.post("/", response_model=PlanResponse)
|
||||
def plan_endpoint(
|
||||
request: PlanRequest,
|
||||
repository: GraphRepository = _GRAPH_REPOSITORY,
|
||||
audit_store: AuditStore = _AUDIT_STORE,
|
||||
llm_client: LiteLLMClient = _LLM_CLIENT,
|
||||
identity: IdentityContext = _PLANNER_EXECUTOR_IDENTITY,
|
||||
) -> PlanResponse:
|
||||
planner = Planner(llm_client=llm_client)
|
||||
steps = planner.generate_plan(intent=request.intent, target=request.target)
|
||||
radius = None
|
||||
if request.target.entity_id:
|
||||
radius = blast_radius(repository, [request.target.entity_id], depth=2)
|
||||
audit_store.add(
|
||||
AuditEvent(
|
||||
event_type="plan",
|
||||
details={
|
||||
"intent": request.intent,
|
||||
"target": request.target.model_dump(mode="json"),
|
||||
"steps": len(steps),
|
||||
},
|
||||
status="ok",
|
||||
)
|
||||
)
|
||||
return PlanResponse(steps=steps, blast_radius=radius)
|
||||
|
||||
|
||||
@router.post("/execute", response_model=ExecutionResponse)
|
||||
def execute_endpoint(
|
||||
request: ExecutionRequest,
|
||||
approval_store: ApprovalStore = _APPROVAL_STORE,
|
||||
audit_store: AuditStore = _AUDIT_STORE,
|
||||
repository: GraphRepository = _GRAPH_REPOSITORY,
|
||||
identity: IdentityContext = _EXECUTOR_IDENTITY,
|
||||
) -> ExecutionResponse:
|
||||
needs_approval = request.requires_approval or any(
|
||||
step.requires_approval for step in request.steps
|
||||
)
|
||||
if not request.dry_run and needs_approval:
|
||||
if not request.approval_token:
|
||||
raise HTTPException(status_code=403, detail="approval token required for execution")
|
||||
approval = approval_store.get_by_token(request.approval_token)
|
||||
if not approval or approval.action != "execute":
|
||||
raise HTTPException(status_code=403, detail="invalid approval token")
|
||||
|
||||
settings = get_settings()
|
||||
engine = ExecutionEngine(repository, runtime_settings=settings.sandbox)
|
||||
task_event_bus.publish(
|
||||
TaskEvent(
|
||||
event_type="execute",
|
||||
status="started",
|
||||
payload={
|
||||
"dry_run": request.dry_run,
|
||||
"steps": len(request.steps),
|
||||
},
|
||||
)
|
||||
)
|
||||
results: list[ToolExecutionResult] = []
|
||||
for step in request.steps:
|
||||
task_event_bus.publish(
|
||||
TaskEvent(
|
||||
event_type="execute.step",
|
||||
status="started",
|
||||
payload={"step_id": step.step_id, "action_type": step.action_type},
|
||||
)
|
||||
)
|
||||
result = engine.execute_step(step, dry_run=request.dry_run)
|
||||
results.append(result)
|
||||
task_event_bus.publish(
|
||||
TaskEvent(
|
||||
event_type="execute.step",
|
||||
status=result.status,
|
||||
payload={
|
||||
"step_id": step.step_id,
|
||||
"action_type": step.action_type,
|
||||
"tool": result.tool,
|
||||
"error": result.error,
|
||||
},
|
||||
)
|
||||
)
|
||||
status = "ok" if all(result.status != "error" for result in results) else "partial_failure"
|
||||
task_event_bus.publish(
|
||||
TaskEvent(
|
||||
event_type="execute",
|
||||
status=status,
|
||||
payload={
|
||||
"dry_run": request.dry_run,
|
||||
"steps": len(request.steps),
|
||||
"status": status,
|
||||
},
|
||||
)
|
||||
)
|
||||
audit_event = AuditEvent(
|
||||
event_type="execute",
|
||||
details={
|
||||
"dry_run": request.dry_run,
|
||||
"steps": len(request.steps),
|
||||
"status": status,
|
||||
"results": [result.model_dump() for result in results],
|
||||
},
|
||||
status=status,
|
||||
)
|
||||
audit_store.add(audit_event)
|
||||
return ExecutionResponse(
|
||||
request=request,
|
||||
results=results,
|
||||
status=status,
|
||||
audit_id=audit_event.audit_id,
|
||||
)
|
||||
@@ -0,0 +1,151 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from eidolon.api.dependencies import (
|
||||
get_graph_repository,
|
||||
get_llm_client,
|
||||
require_roles,
|
||||
)
|
||||
from eidolon.api.middleware.auth import IdentityContext
|
||||
from eidolon.core.graph.repository import GraphRepository
|
||||
from eidolon.core.models.graph import GraphPath
|
||||
from eidolon.core.models.plan import GraphQuery
|
||||
from eidolon.core.reasoning.llm import LiteLLMClient
|
||||
from eidolon.core.reasoning.prompts import QUERY_PROMPT_TEMPLATE
|
||||
|
||||
|
||||
class QueryRequest(BaseModel):
|
||||
question: str = Field(description="Natural language query")
|
||||
source_id: UUID | None = Field(
|
||||
default=None, description="Optional source node for path queries"
|
||||
)
|
||||
target_id: UUID | None = Field(
|
||||
default=None, description="Optional target node for path queries"
|
||||
)
|
||||
max_depth: int = Field(default=4, ge=1, le=8)
|
||||
|
||||
|
||||
class QueryResponse(BaseModel):
|
||||
answer: str
|
||||
paths: list[GraphPath] = Field(default_factory=list)
|
||||
citations: list[dict] = Field(default_factory=list)
|
||||
graph_query: GraphQuery | None = None
|
||||
records: list[dict] = Field(default_factory=list)
|
||||
|
||||
|
||||
class NLQueryPlan(BaseModel):
|
||||
answer: str
|
||||
graph_query: GraphQuery | None = None
|
||||
citations: list[dict] = Field(default_factory=list)
|
||||
|
||||
|
||||
router = APIRouter(prefix="/query", tags=["query"])
|
||||
_GRAPH_REPOSITORY = Depends(get_graph_repository)
|
||||
_LLM_CLIENT = Depends(get_llm_client)
|
||||
_VIEWER_IDENTITY = Depends(require_roles("viewer", "planner", "executor"))
|
||||
|
||||
|
||||
class NaturalLanguageQueryInterpreter:
|
||||
"""
|
||||
Minimal rule-based interpreter to keep queries deterministic until the NL->Cypher LLM layer
|
||||
is added.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, repository: GraphRepository, llm_client: LiteLLMClient | None = None
|
||||
) -> None:
|
||||
self.repository = repository
|
||||
self.llm_client = llm_client
|
||||
|
||||
def _parse_rules(self, question: str) -> NLQueryPlan:
|
||||
q = question.lower()
|
||||
|
||||
path_match = re.search(r"from\s+(?P<src>[\w\.-]+)\s+to\s+(?P<dst>[\w\.-]+)", q)
|
||||
if "path" in q and path_match:
|
||||
src = path_match.group("src")
|
||||
dst = path_match.group("dst")
|
||||
cypher = (
|
||||
"MATCH (src:Asset)-[r:CAN_REACH*1..4]->(dst:Asset) "
|
||||
"WHERE ($src IN src.identifiers OR src.node_id = $src) "
|
||||
"AND ($dst IN dst.identifiers OR dst.node_id = $dst) "
|
||||
"RETURN src, dst, r LIMIT 10"
|
||||
)
|
||||
return NLQueryPlan(
|
||||
answer=f"Finding paths from {src} to {dst}.",
|
||||
graph_query=GraphQuery(
|
||||
cypher=cypher,
|
||||
parameters={"src": src, "dst": dst},
|
||||
),
|
||||
)
|
||||
|
||||
assets_in_network = re.search(r"assets? in network\s+(?P<net>[\w\./-]+)", q)
|
||||
if assets_in_network:
|
||||
network = assets_in_network.group("net")
|
||||
cypher = (
|
||||
"MATCH (a:Asset)-[:MEMBER_OF]->(n:NetworkContainer) "
|
||||
"WHERE n.cidr = $network OR n.name = $network "
|
||||
"RETURN a LIMIT 100"
|
||||
)
|
||||
return NLQueryPlan(
|
||||
answer=f"Listing assets in network {network}.",
|
||||
graph_query=GraphQuery(cypher=cypher, parameters={"network": network}),
|
||||
)
|
||||
|
||||
if "policy" in q and ("govern" in q or "attached" in q):
|
||||
cypher = "MATCH (a:Asset)-[:GOVERNED_BY]->(p:Policy) " "RETURN a, p LIMIT 100"
|
||||
return NLQueryPlan(
|
||||
answer="Fetching governed assets and their policies.",
|
||||
graph_query=GraphQuery(cypher=cypher, parameters={}),
|
||||
)
|
||||
|
||||
return NLQueryPlan(
|
||||
answer="Query interpreted but no matching pattern; no graph query executed."
|
||||
)
|
||||
|
||||
def parse(self, question: str) -> NLQueryPlan:
|
||||
plan = self._parse_rules(question)
|
||||
if plan.graph_query or not self.llm_client or not self.llm_client.is_available():
|
||||
return plan
|
||||
prompt = QUERY_PROMPT_TEMPLATE.format(question=question)
|
||||
try:
|
||||
llm_plan = self.llm_client.generate_structured(prompt, NLQueryPlan)
|
||||
except (RuntimeError, TypeError, ValueError):
|
||||
return plan
|
||||
return llm_plan if llm_plan.answer else plan
|
||||
|
||||
|
||||
@router.post("/", response_model=QueryResponse)
|
||||
def handle_query(
|
||||
request: QueryRequest,
|
||||
repository: GraphRepository = _GRAPH_REPOSITORY,
|
||||
llm_client: LiteLLMClient = _LLM_CLIENT,
|
||||
identity: IdentityContext = _VIEWER_IDENTITY,
|
||||
) -> QueryResponse:
|
||||
if request.source_id and request.target_id:
|
||||
paths = repository.find_paths(request.source_id, request.target_id, request.max_depth)
|
||||
return QueryResponse(
|
||||
answer="Path search completed.",
|
||||
paths=paths,
|
||||
citations=[],
|
||||
)
|
||||
if not request.question:
|
||||
raise HTTPException(status_code=400, detail="question is required")
|
||||
|
||||
interpreter = NaturalLanguageQueryInterpreter(repository, llm_client=llm_client)
|
||||
plan = interpreter.parse(request.question)
|
||||
records: list[dict] = []
|
||||
if plan.graph_query:
|
||||
records = list(repository.run_cypher(plan.graph_query.cypher, plan.graph_query.parameters))
|
||||
|
||||
return QueryResponse(
|
||||
answer=plan.answer,
|
||||
paths=[],
|
||||
citations=plan.citations,
|
||||
graph_query=plan.graph_query,
|
||||
records=records,
|
||||
)
|
||||
@@ -0,0 +1,70 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from eidolon.api.dependencies import get_llm_client, get_settings_store, require_roles
|
||||
from eidolon.api.middleware.auth import IdentityContext
|
||||
from eidolon.config.settings import LLMSettings, get_settings
|
||||
from eidolon.core.models.settings import AppSettings, ThemeSettings
|
||||
from eidolon.core.stores import SettingsStore
|
||||
|
||||
router = APIRouter(prefix="/settings", tags=["settings"])
|
||||
_SETTINGS_STORE = Depends(get_settings_store)
|
||||
_VIEWER_IDENTITY = Depends(require_roles("viewer", "planner", "executor"))
|
||||
_EXECUTOR_IDENTITY = Depends(require_roles("executor"))
|
||||
|
||||
|
||||
class LLMSettingsUpdate(BaseModel):
|
||||
model: str | None = None
|
||||
api_base: str | None = None
|
||||
api_key: str | None = None
|
||||
temperature: float | None = Field(default=None, ge=0.0, le=1.0)
|
||||
max_tokens: int | None = Field(default=None, ge=128)
|
||||
|
||||
|
||||
class AppSettingsUpdate(BaseModel):
|
||||
theme: ThemeSettings | None = None
|
||||
llm: LLMSettingsUpdate | None = None
|
||||
|
||||
|
||||
class AppSettingsResponse(BaseModel):
|
||||
theme: ThemeSettings
|
||||
llm: LLMSettings
|
||||
|
||||
|
||||
@router.get("/", response_model=AppSettingsResponse)
|
||||
def get_app_settings(
|
||||
store: SettingsStore = _SETTINGS_STORE,
|
||||
identity: IdentityContext = _VIEWER_IDENTITY,
|
||||
) -> AppSettingsResponse:
|
||||
settings = store.get_app_settings()
|
||||
return AppSettingsResponse(theme=settings.theme, llm=settings.llm)
|
||||
|
||||
|
||||
@router.put("/", response_model=AppSettingsResponse)
|
||||
def update_app_settings(
|
||||
payload: AppSettingsUpdate,
|
||||
store: SettingsStore = _SETTINGS_STORE,
|
||||
identity: IdentityContext = _EXECUTOR_IDENTITY,
|
||||
) -> AppSettingsResponse:
|
||||
current = store.get_app_settings()
|
||||
theme = payload.theme or current.theme
|
||||
llm = current.llm
|
||||
|
||||
if payload.llm:
|
||||
defaults = get_settings().llm.model_dump()
|
||||
data = llm.model_dump()
|
||||
for key, value in payload.llm.model_dump(exclude_unset=True).items():
|
||||
if isinstance(value, str):
|
||||
value = value.strip()
|
||||
if value is None or value == "":
|
||||
data[key] = defaults.get(key)
|
||||
else:
|
||||
data[key] = value
|
||||
llm = LLMSettings(**data)
|
||||
|
||||
updated = AppSettings(theme=theme, llm=llm)
|
||||
store.update_app_settings(updated)
|
||||
get_llm_client.cache_clear()
|
||||
return AppSettingsResponse(theme=updated.theme, llm=updated.llm)
|
||||
+173
@@ -0,0 +1,173 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
|
||||
import uvicorn
|
||||
|
||||
from eidolon.api.app import app
|
||||
from eidolon.api.dependencies import get_scanner_store
|
||||
from eidolon.collectors.factory import build_manager
|
||||
from eidolon.core.graph.neo4j import Neo4jGraphRepository
|
||||
from eidolon.core.models.scanner import ScannerConfig
|
||||
from eidolon.core.reasoning.entity import EntityResolver
|
||||
from eidolon.worker.ingest import IngestWorker
|
||||
|
||||
|
||||
def _make_help_handler(parser: argparse.ArgumentParser):
|
||||
def _handler(_args: argparse.Namespace) -> int:
|
||||
parser.print_help()
|
||||
return 0
|
||||
|
||||
return _handler
|
||||
|
||||
|
||||
def _build_scan_config(config: ScannerConfig) -> dict:
|
||||
return {
|
||||
"network": {
|
||||
"cidrs": config.network_cidrs,
|
||||
"ping_concurrency": config.options.ping_concurrency,
|
||||
"port_scan_workers": config.options.port_scan_workers,
|
||||
"ports": config.ports,
|
||||
"port_preset": config.port_preset,
|
||||
"dns_resolution": config.options.dns_resolution,
|
||||
"aggressive": config.options.aggressive,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def cmd_scan(args: argparse.Namespace) -> int:
|
||||
store = get_scanner_store()
|
||||
record = store.get_config("cli-user")
|
||||
config = _build_scan_config(record.config)
|
||||
repository = Neo4jGraphRepository()
|
||||
try:
|
||||
resolver = EntityResolver()
|
||||
worker = IngestWorker(repository, resolver)
|
||||
|
||||
event_count = 0
|
||||
|
||||
def emit_fn(event) -> None:
|
||||
nonlocal event_count
|
||||
event_count += 1
|
||||
worker.process_event(event)
|
||||
|
||||
manager = build_manager(config, emit_fn)
|
||||
collectors = manager.list_collectors()
|
||||
print(f"Starting network scan: {', '.join(collectors)}")
|
||||
|
||||
errors = manager.run_all()
|
||||
|
||||
print(f"Scan complete: {event_count} events ingested into Neo4j")
|
||||
|
||||
if errors:
|
||||
print(f"Encountered {len(errors)} error(s):")
|
||||
for err in errors:
|
||||
print(f" - {err}")
|
||||
return 1
|
||||
|
||||
return 0
|
||||
finally:
|
||||
repository.close()
|
||||
|
||||
|
||||
def cmd_db(args: argparse.Namespace) -> int:
|
||||
repository = Neo4jGraphRepository()
|
||||
try:
|
||||
if args.action == "stats":
|
||||
result = list(
|
||||
repository.run_cypher("MATCH (n) RETURN labels(n) as label, count(*) as count")
|
||||
)
|
||||
print("")
|
||||
print("Node counts by label:")
|
||||
total = 0
|
||||
for record in result:
|
||||
labels = record.get("label") or []
|
||||
count = record.get("count", 0)
|
||||
label_str = ":".join(labels) if labels else "(unlabeled)"
|
||||
print(f" {label_str}: {count}")
|
||||
total += count
|
||||
print("")
|
||||
print(f"Total nodes: {total}")
|
||||
|
||||
result = list(
|
||||
repository.run_cypher("MATCH ()-[r]->() RETURN type(r) as type, count(*) as count")
|
||||
)
|
||||
print("")
|
||||
print("Relationship counts by type:")
|
||||
rel_total = 0
|
||||
for record in result:
|
||||
print(f" {record.get('type')}: {record.get('count')}")
|
||||
rel_total += record.get("count", 0)
|
||||
print("")
|
||||
print(f"Total relationships: {rel_total}")
|
||||
|
||||
elif args.action == "clear":
|
||||
confirm = input(
|
||||
"WARNING: This will delete ALL data from Neo4j. Type 'yes' to confirm: "
|
||||
)
|
||||
if confirm.lower() == "yes":
|
||||
repository.clear()
|
||||
print("Database cleared")
|
||||
else:
|
||||
print("Cancelled")
|
||||
|
||||
elif args.action == "query":
|
||||
if not args.cypher:
|
||||
print("cypher is required for db query")
|
||||
return 1
|
||||
result = list(repository.run_cypher(args.cypher))
|
||||
for record in result:
|
||||
print(json.dumps(dict(record), indent=2, default=str))
|
||||
finally:
|
||||
repository.close()
|
||||
return 0
|
||||
|
||||
|
||||
def cmd_ui(args: argparse.Namespace) -> int:
|
||||
uvicorn.run(app, host=args.host, port=args.port, reload=args.reload)
|
||||
return 0
|
||||
|
||||
|
||||
def build_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(prog="eidolon", description="Eidolon CLI")
|
||||
|
||||
# Default behavior: start the API server
|
||||
parser.add_argument(
|
||||
"--host", default="0.0.0.0", help="API server host (default: 0.0.0.0)" # noqa: S104
|
||||
)
|
||||
parser.add_argument("--port", type=int, default=8080, help="API server port (default: 8080)")
|
||||
parser.add_argument("--reload", action="store_true", help="Enable autoreload for development")
|
||||
parser.set_defaults(func=cmd_ui)
|
||||
|
||||
sub = parser.add_subparsers(dest="command", required=False)
|
||||
|
||||
help_cmd = sub.add_parser("help", help="Show help")
|
||||
help_cmd.set_defaults(func=_make_help_handler(parser))
|
||||
|
||||
scan = sub.add_parser("scan", help="Run the network scan once")
|
||||
scan.set_defaults(func=cmd_scan)
|
||||
|
||||
db_cmd = sub.add_parser("db", help="Database operations")
|
||||
db_cmd.add_argument("action", choices=["stats", "clear", "query"], help="Action to perform")
|
||||
db_cmd.add_argument("--cypher", help="Cypher query (for 'query' action)")
|
||||
db_cmd.set_defaults(func=cmd_db)
|
||||
|
||||
ui = sub.add_parser("ui", help="Serve the web UI/API (alias for default behavior)")
|
||||
ui.add_argument("--host", default="0.0.0.0") # noqa: S104
|
||||
ui.add_argument("--port", type=int, default=8080)
|
||||
ui.add_argument("--reload", action="store_true", help="Enable autoreload for development")
|
||||
ui.set_defaults(func=cmd_ui)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def main(argv: list[str] | None = None) -> int:
|
||||
parser = build_parser()
|
||||
args = parser.parse_args(argv)
|
||||
return args.func(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
@@ -0,0 +1 @@
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable, Iterable
|
||||
|
||||
from eidolon.core.models.event import CollectorEvent
|
||||
|
||||
|
||||
class BaseCollector(ABC):
|
||||
"""Base collector interface. Collectors run deterministically and emit normalized events."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
emit_fn: Callable[[CollectorEvent], None] | None = None,
|
||||
) -> None:
|
||||
self.name = name
|
||||
self.emit_fn = emit_fn
|
||||
|
||||
def emit(self, event: CollectorEvent) -> None:
|
||||
if self.emit_fn:
|
||||
self.emit_fn(event)
|
||||
|
||||
@abstractmethod
|
||||
def collect(self) -> Iterable[CollectorEvent]:
|
||||
"""Run the collector and return a stream of normalized events."""
|
||||
|
||||
def run(self) -> None:
|
||||
"""Execute the collector and emit events."""
|
||||
for event in self.collect():
|
||||
self.emit(event)
|
||||
@@ -0,0 +1,48 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
from eidolon.collectors.manager import CollectorManager
|
||||
from eidolon.collectors.network import NetworkCollector
|
||||
from eidolon.core.models.event import CollectorEvent
|
||||
|
||||
|
||||
def build_manager(
|
||||
config: dict,
|
||||
emit_fn,
|
||||
cancellation_checker: Callable[[], bool] | None = None,
|
||||
progress_callback: Callable[[str], None] | None = None,
|
||||
) -> CollectorManager:
|
||||
manager = CollectorManager(emit_fn=emit_fn)
|
||||
network_cfg: dict | None = config.get("network")
|
||||
if network_cfg is None:
|
||||
return manager
|
||||
|
||||
manager.register(
|
||||
NetworkCollector(
|
||||
cidrs=network_cfg.get("cidrs", []),
|
||||
ping_concurrency=network_cfg.get("ping_concurrency", 64),
|
||||
port_scan_workers=network_cfg.get("port_scan_workers", 32),
|
||||
ports=network_cfg.get("ports"),
|
||||
port_preset=network_cfg.get("port_preset"),
|
||||
dns_resolution=network_cfg.get("dns_resolution", True),
|
||||
aggressive=network_cfg.get("aggressive", False),
|
||||
nmap_path=network_cfg.get("nmap_path", "nmap"),
|
||||
cancellation_checker=cancellation_checker,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
)
|
||||
|
||||
return manager
|
||||
|
||||
|
||||
def collect_once(config: dict, emit_fn) -> int:
|
||||
events: list[CollectorEvent] = []
|
||||
|
||||
def _emit(event: CollectorEvent) -> None:
|
||||
events.append(event)
|
||||
emit_fn(event)
|
||||
|
||||
manager = build_manager(config, _emit)
|
||||
manager.run_all()
|
||||
return len(events)
|
||||
@@ -0,0 +1,41 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, Iterable
|
||||
|
||||
from eidolon.collectors.base import BaseCollector
|
||||
from eidolon.core.models.event import CollectorEvent
|
||||
|
||||
|
||||
class CollectorManager:
|
||||
"""Simple in-process collector orchestrator with start/stop hooks."""
|
||||
|
||||
def __init__(self, emit_fn: Callable[[CollectorEvent], None]) -> None:
|
||||
self.emit_fn = emit_fn
|
||||
self._collectors: dict[str, BaseCollector] = {}
|
||||
|
||||
def register(self, collector: BaseCollector) -> None:
|
||||
collector.emit_fn = self.emit_fn
|
||||
self._collectors[collector.name] = collector
|
||||
|
||||
def run_all(self) -> list[Exception]:
|
||||
errors: list[Exception] = []
|
||||
for collector in self._collectors.values():
|
||||
try:
|
||||
collector.run()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
errors.append(exc)
|
||||
return errors
|
||||
|
||||
def run_selected(self, names: Iterable[str]) -> list[Exception]:
|
||||
errors: list[Exception] = []
|
||||
for name in names:
|
||||
collector = self._collectors.get(name)
|
||||
if collector:
|
||||
try:
|
||||
collector.run()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
errors.append(exc)
|
||||
return errors
|
||||
|
||||
def list_collectors(self) -> list[str]:
|
||||
return list(self._collectors.keys())
|
||||
@@ -0,0 +1,328 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import subprocess
|
||||
from collections.abc import Callable, Iterable
|
||||
from contextlib import suppress
|
||||
from datetime import datetime
|
||||
|
||||
from defusedxml import ElementTree as DefusedET
|
||||
|
||||
from eidolon.collectors.base import BaseCollector
|
||||
from eidolon.core.models.event import CollectorEvent
|
||||
|
||||
|
||||
class ScanCancelledError(Exception):
|
||||
"""Raised when a scan is cancelled."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class NetworkCollector(BaseCollector):
|
||||
"""
|
||||
Network scanning collector backed by nmap.
|
||||
|
||||
Performs a ping sweep (-sn) to identify live hosts, then an optional targeted port scan of
|
||||
discovered hosts. No synthetic data is emitted; results mirror nmap output.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cidrs: list[str],
|
||||
ping_concurrency: int = 64,
|
||||
port_scan_workers: int = 32,
|
||||
ports: list[int] | None = None,
|
||||
port_preset: str | None = None,
|
||||
dns_resolution: bool = True,
|
||||
aggressive: bool = False,
|
||||
nmap_path: str = "nmap",
|
||||
cancellation_checker: Callable[[], bool] | None = None,
|
||||
progress_callback: Callable[[str], None] | None = None,
|
||||
) -> None:
|
||||
super().__init__(name="network")
|
||||
self.cidrs = cidrs
|
||||
self.ping_concurrency = ping_concurrency
|
||||
self.port_scan_workers = port_scan_workers
|
||||
self.ports = ports or []
|
||||
self.port_preset = port_preset
|
||||
self.dns_resolution = dns_resolution
|
||||
self.aggressive = aggressive
|
||||
self.nmap_path = nmap_path
|
||||
self.cancellation_checker = cancellation_checker
|
||||
self.progress_callback = progress_callback
|
||||
self._active_process: subprocess.Popen | None = None
|
||||
|
||||
def _send_progress(self, message: str) -> None:
|
||||
"""Send formatted progress message to callback."""
|
||||
if self.progress_callback:
|
||||
self.progress_callback(message)
|
||||
|
||||
def collect(self) -> Iterable[CollectorEvent]:
|
||||
discovered_hosts: list[str] = []
|
||||
host_to_cidr: dict[str, str] = {}
|
||||
|
||||
self._send_progress(f"Starting scan of {len(self.cidrs)} network(s)...")
|
||||
|
||||
for cidr in self.cidrs:
|
||||
# Check cancellation before each target
|
||||
self._check_cancellation()
|
||||
|
||||
self._send_progress(f"Discovering hosts in {cidr}...")
|
||||
sweep_args = ["-sn", "-oX", "-", cidr]
|
||||
sweep_args = self._with_dns_flag(sweep_args)
|
||||
sweep_args = self._with_parallelism(sweep_args, self.ping_concurrency)
|
||||
sweep_xml = self._run_nmap(sweep_args, show_output=False)
|
||||
hosts = self._parse_ping_sweep(sweep_xml, cidr)
|
||||
|
||||
if hosts:
|
||||
self._send_progress(f"Found {len(hosts)} live host(s) in {cidr}")
|
||||
for host in hosts:
|
||||
ip = host.get("ip")
|
||||
hostname = host.get("hostname")
|
||||
if ip:
|
||||
discovered_hosts.append(ip)
|
||||
host_to_cidr[ip] = cidr
|
||||
host_desc = f"{ip} ({hostname})" if hostname else ip
|
||||
self._send_progress(f" → {host_desc}")
|
||||
else:
|
||||
self._send_progress(f"No live hosts found in {cidr}")
|
||||
|
||||
for host in hosts:
|
||||
yield self._build_event(host)
|
||||
|
||||
# Check cancellation before port scan
|
||||
self._check_cancellation()
|
||||
|
||||
port_spec = self._build_port_spec()
|
||||
if port_spec and discovered_hosts:
|
||||
self._send_progress(f"\nScanning ports on {len(discovered_hosts)} host(s)...")
|
||||
port_scan_args = ["-Pn", *port_spec, "-oX", "-", *discovered_hosts]
|
||||
port_scan_args = self._with_dns_flag(port_scan_args)
|
||||
port_scan_args = self._with_parallelism(port_scan_args, self.port_scan_workers)
|
||||
if self.aggressive:
|
||||
port_scan_args.extend(["-O", "-sV"])
|
||||
self._send_progress("Using aggressive scan (OS detection + version detection)")
|
||||
|
||||
port_scan_xml = self._run_nmap(port_scan_args, show_output=False)
|
||||
|
||||
for host_payload in self._parse_port_scan(port_scan_xml):
|
||||
cidr = host_to_cidr.get(host_payload.get("ip", ""))
|
||||
if cidr:
|
||||
host_payload["cidr"] = cidr
|
||||
|
||||
# Report open ports
|
||||
ip = host_payload.get("ip")
|
||||
ports = host_payload.get("ports", [])
|
||||
open_ports = [p for p in ports if p.get("state") == "open"]
|
||||
if open_ports:
|
||||
self._send_progress(f" {ip}: {len(open_ports)} open port(s)")
|
||||
for port in open_ports:
|
||||
service = port.get("service", "unknown")
|
||||
self._send_progress(f" → {port['port']}/{service}")
|
||||
else:
|
||||
self._send_progress(f" {ip}: No open ports found")
|
||||
|
||||
yield self._build_event(host_payload)
|
||||
|
||||
self._send_progress("\nScan complete!")
|
||||
|
||||
def _check_cancellation(self) -> None:
|
||||
"""Check if scan was cancelled and raise exception if so."""
|
||||
if self.cancellation_checker and self.cancellation_checker():
|
||||
# Kill active process if running
|
||||
if self._active_process:
|
||||
with suppress(Exception):
|
||||
self._active_process.terminate()
|
||||
self._active_process.wait(timeout=2)
|
||||
with suppress(Exception):
|
||||
self._active_process.kill()
|
||||
self._active_process = None
|
||||
raise ScanCancelledError("Scan was cancelled")
|
||||
|
||||
def _run_nmap(self, args: list[str], show_output: bool = True) -> str:
|
||||
cmd = [self.nmap_path, *args]
|
||||
try:
|
||||
self._active_process = subprocess.Popen( # noqa: S603
|
||||
cmd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
bufsize=1, # Line buffered
|
||||
)
|
||||
|
||||
# Capture output line by line
|
||||
stdout_lines = []
|
||||
# Read stdout in real-time
|
||||
if self._active_process.stdout:
|
||||
for line in iter(self._active_process.stdout.readline, ""):
|
||||
if not line:
|
||||
break
|
||||
stdout_lines.append(line)
|
||||
# Only send raw output to UI if show_output is True (for debugging)
|
||||
# Normally we send custom formatted messages via _send_progress
|
||||
# Check cancellation while reading
|
||||
if self.cancellation_checker and self.cancellation_checker():
|
||||
self._check_cancellation()
|
||||
|
||||
# Wait for process to complete and get stderr
|
||||
_, stderr = self._active_process.communicate()
|
||||
returncode = self._active_process.returncode
|
||||
self._active_process = None
|
||||
|
||||
stdout = "".join(stdout_lines)
|
||||
|
||||
if returncode != 0:
|
||||
raise RuntimeError(f"nmap failed ({returncode}): {stderr.strip()}")
|
||||
return stdout
|
||||
except ScanCancelledError:
|
||||
# Re-raise cancellation without wrapping
|
||||
raise
|
||||
except (OSError, subprocess.SubprocessError, RuntimeError):
|
||||
self._active_process = None
|
||||
raise
|
||||
|
||||
def _build_port_spec(self) -> list[str]:
|
||||
if self.port_preset == "full":
|
||||
return ["-p-"]
|
||||
if self.ports:
|
||||
return ["-p", ",".join(str(p) for p in self.ports)]
|
||||
return []
|
||||
|
||||
def _with_dns_flag(self, args: list[str]) -> list[str]:
|
||||
return args + (["-R"] if self.dns_resolution else ["-n"])
|
||||
|
||||
def _with_parallelism(self, args: list[str], value: int) -> list[str]:
|
||||
if value <= 0:
|
||||
return args
|
||||
return [*args, "--min-parallelism", str(value), "--max-parallelism", str(value)]
|
||||
|
||||
def _parse_ping_sweep(self, xml_text: str, cidr: str) -> list[dict]:
|
||||
hosts: list[dict] = []
|
||||
root = DefusedET.fromstring(xml_text)
|
||||
for host in root.findall("host"):
|
||||
status = host.find("status")
|
||||
if status is None or status.attrib.get("state") != "up":
|
||||
continue
|
||||
addr = host.find("address")
|
||||
if addr is None:
|
||||
continue
|
||||
ip = addr.attrib.get("addr")
|
||||
if not ip:
|
||||
continue
|
||||
|
||||
# Extract hostname if available
|
||||
hostname = None
|
||||
hostnames_elem = host.find("hostnames")
|
||||
if hostnames_elem is not None:
|
||||
hostname_elem = hostnames_elem.find("hostname")
|
||||
if hostname_elem is not None:
|
||||
hostname = hostname_elem.attrib.get("name")
|
||||
|
||||
# Extract MAC address and vendor
|
||||
mac_address = None
|
||||
mac_vendor = None
|
||||
for address_elem in host.findall("address"):
|
||||
addr_type = address_elem.attrib.get("addrtype")
|
||||
if addr_type == "mac":
|
||||
mac_address = address_elem.attrib.get("addr")
|
||||
mac_vendor = address_elem.attrib.get("vendor")
|
||||
break
|
||||
|
||||
host_data = {"ip": ip, "cidr": cidr, "status": "online"}
|
||||
if hostname:
|
||||
host_data["hostname"] = hostname
|
||||
if mac_address:
|
||||
host_data["mac_address"] = mac_address
|
||||
if mac_vendor:
|
||||
host_data["vendor"] = mac_vendor
|
||||
hosts.append(host_data)
|
||||
return hosts
|
||||
|
||||
def _parse_port_scan(self, xml_text: str) -> list[dict]:
|
||||
results: list[dict] = []
|
||||
root = DefusedET.fromstring(xml_text)
|
||||
for host in root.findall("host"):
|
||||
addr = host.find("address")
|
||||
if addr is None:
|
||||
continue
|
||||
ip = addr.attrib.get("addr")
|
||||
if not ip:
|
||||
continue
|
||||
|
||||
host_data = {"ip": ip}
|
||||
|
||||
# Extract MAC address and vendor
|
||||
mac_address = None
|
||||
mac_vendor = None
|
||||
for address_elem in host.findall("address"):
|
||||
addr_type = address_elem.attrib.get("addrtype")
|
||||
if addr_type == "mac":
|
||||
mac_address = address_elem.attrib.get("addr")
|
||||
mac_vendor = address_elem.attrib.get("vendor")
|
||||
break
|
||||
|
||||
if mac_address:
|
||||
host_data["mac_address"] = mac_address
|
||||
if mac_vendor:
|
||||
host_data["vendor"] = mac_vendor
|
||||
|
||||
# Parse ports with detailed service information
|
||||
ports = []
|
||||
ports_element = host.find("ports")
|
||||
if ports_element is not None:
|
||||
for port_elem in ports_element.findall("port"):
|
||||
port_id = int(port_elem.attrib.get("portid", "0"))
|
||||
state_elem = port_elem.find("state")
|
||||
state = state_elem.attrib.get("state") if state_elem is not None else "unknown"
|
||||
|
||||
# Extract detailed service information
|
||||
service_elem = port_elem.find("service")
|
||||
service_name = None
|
||||
service_product = None
|
||||
service_version = None
|
||||
if service_elem is not None:
|
||||
service_name = service_elem.attrib.get("name")
|
||||
service_product = service_elem.attrib.get("product")
|
||||
service_version = service_elem.attrib.get("version")
|
||||
|
||||
port_data = {"port": port_id, "state": state, "service": service_name}
|
||||
if service_product:
|
||||
port_data["product"] = service_product
|
||||
if service_version:
|
||||
port_data["version"] = service_version
|
||||
|
||||
ports.append(port_data)
|
||||
|
||||
host_data["ports"] = ports
|
||||
|
||||
# Parse OS detection information
|
||||
os_element = host.find("os")
|
||||
if os_element is not None:
|
||||
osmatch = os_element.find("osmatch")
|
||||
if osmatch is not None:
|
||||
os_name = osmatch.attrib.get("name")
|
||||
os_accuracy = osmatch.attrib.get("accuracy")
|
||||
if os_name:
|
||||
host_data["os"] = os_name
|
||||
if os_accuracy:
|
||||
host_data["os_accuracy"] = f"{os_accuracy}%"
|
||||
|
||||
# Get OS class for more general info
|
||||
osclass = os_element.find("osclass")
|
||||
if osclass is not None:
|
||||
os_family = osclass.attrib.get("osfamily")
|
||||
if os_family and "os" not in host_data:
|
||||
host_data["os"] = os_family
|
||||
|
||||
results.append(host_data)
|
||||
return results
|
||||
|
||||
def _build_event(self, payload: dict) -> CollectorEvent:
|
||||
now = datetime.utcnow()
|
||||
return CollectorEvent(
|
||||
source_type="network",
|
||||
source_id=payload.get("ip", "network-scan"),
|
||||
entity_type="Asset",
|
||||
payload=payload,
|
||||
collected_at=now,
|
||||
confidence=0.8,
|
||||
)
|
||||
@@ -0,0 +1 @@
|
||||
|
||||
@@ -0,0 +1,122 @@
|
||||
from functools import lru_cache
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Neo4jSettings(BaseModel):
|
||||
uri: str = Field(default="bolt://localhost:7687", description="Neo4j Bolt URI")
|
||||
user: str = Field(default="neo4j", description="Neo4j user")
|
||||
password: str = Field(default="password", description="Neo4j password")
|
||||
database: str = Field(default="neo4j", description="Neo4j database name")
|
||||
|
||||
|
||||
class PostgresSettings(BaseModel):
|
||||
url: str = Field(
|
||||
default="postgresql://postgres:password@localhost:5432/eidolon",
|
||||
description="Postgres connection string for audit/session data",
|
||||
)
|
||||
|
||||
|
||||
class LLMSettings(BaseModel):
|
||||
model: str = Field(default="gpt-5-mini", description="Default LiteLLM model name")
|
||||
api_base: str | None = Field(default=None, description="Optional LiteLLM proxy base URL")
|
||||
api_key: str | None = Field(default=None, description="API key used by LiteLLM")
|
||||
temperature: float = Field(
|
||||
default=0.0, ge=0.0, le=1.0, description="Default generation temperature"
|
||||
)
|
||||
max_tokens: int = Field(default=1024, ge=128, description="Token cap for responses")
|
||||
reasoning_effort: Literal["low", "medium", "high"] | None = Field(
|
||||
default=None,
|
||||
description="Optional reasoning effort hint for supported models",
|
||||
)
|
||||
top_p: float = Field(default=1.0, ge=0.0, le=1.0, description="Nucleus sampling")
|
||||
frequency_penalty: float = Field(default=0.0, ge=-2.0, le=2.0, description="Frequency penalty")
|
||||
presence_penalty: float = Field(default=0.0, ge=-2.0, le=2.0, description="Presence penalty")
|
||||
max_context_tokens: int = Field(default=128000, ge=1024, description="Max context window")
|
||||
max_retries: int = Field(default=5, ge=0, description="Retry attempts for rate limits")
|
||||
retry_delay: float = Field(default=2.0, ge=0.1, description="Base retry delay in seconds")
|
||||
|
||||
@field_validator("reasoning_effort", mode="before")
|
||||
@classmethod
|
||||
def normalize_reasoning_effort(cls, v):
|
||||
if v is None:
|
||||
return None
|
||||
if isinstance(v, str):
|
||||
cleaned = v.strip().lower()
|
||||
return cleaned or None
|
||||
return v
|
||||
|
||||
|
||||
class APISettings(BaseModel):
|
||||
host: str = Field(default="0.0.0.0", description="API bind host") # noqa: S104
|
||||
port: int = Field(default=8080, ge=1, le=65535, description="API bind port")
|
||||
cors_origins: list[str] = Field(
|
||||
default_factory=lambda: ["*"], description="Allowed CORS origins"
|
||||
)
|
||||
|
||||
|
||||
class AuthSettings(BaseModel):
|
||||
mode: Literal["header", "jwt", "none"] = Field(
|
||||
default="header",
|
||||
description="Auth mode: header (dev), jwt (HS256), or none (no auth checks).",
|
||||
)
|
||||
jwt_secret: str | None = Field(
|
||||
default=None,
|
||||
description="Shared secret for HS256 JWT verification when auth.mode=jwt.",
|
||||
)
|
||||
jwt_issuer: str | None = Field(default=None, description="Expected JWT issuer (iss claim).")
|
||||
jwt_audience: str | None = Field(default=None, description="Expected JWT audience (aud claim).")
|
||||
header_user_id: str = Field(default="x-user-id", description="Header for user identity.")
|
||||
header_roles: str = Field(default="x-roles", description="Header for roles list.")
|
||||
|
||||
|
||||
class SandboxPermissions(BaseModel):
|
||||
allow_unsafe_tools: bool = Field(
|
||||
default=True,
|
||||
description="Allow system tools (graph queries, planning, reasoning tools).",
|
||||
)
|
||||
allow_shell: bool = Field(default=True, description="Allow terminal tool usage.")
|
||||
allow_network: bool = Field(default=True, description="Allow browser tool usage.")
|
||||
allow_file_write: bool = Field(default=False, description="Allow file_edit write operations.")
|
||||
allowed_tools: list[str] | None = Field(
|
||||
default=None,
|
||||
description="Optional allowlist of tool names; when set, only these are permitted.",
|
||||
)
|
||||
blocked_tools: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="Optional blocklist of tool names to deny.",
|
||||
)
|
||||
|
||||
@field_validator("allowed_tools", "blocked_tools", mode="before")
|
||||
@classmethod
|
||||
def parse_list_from_env(cls, v, info):
|
||||
"""Handle empty strings from environment variables for list fields."""
|
||||
if v == "":
|
||||
return None if info.field_name == "allowed_tools" else []
|
||||
return v
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
environment: str = Field(default="local", description="Runtime environment label")
|
||||
log_level: str = Field(default="INFO", description="Structured log level")
|
||||
neo4j: Neo4jSettings = Field(default_factory=Neo4jSettings)
|
||||
postgres: PostgresSettings = Field(default_factory=PostgresSettings)
|
||||
llm: LLMSettings = Field(default_factory=LLMSettings)
|
||||
api: APISettings = Field(default_factory=APISettings)
|
||||
auth: AuthSettings = Field(default_factory=AuthSettings)
|
||||
sandbox: SandboxPermissions = Field(default_factory=SandboxPermissions)
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_prefix="EIDOLON_",
|
||||
env_file=".env",
|
||||
env_nested_delimiter="__",
|
||||
extra="ignore",
|
||||
)
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_settings() -> Settings:
|
||||
"""Return a cached Settings instance to avoid repeated env parsing."""
|
||||
return Settings()
|
||||
@@ -0,0 +1 @@
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
|
||||
@@ -0,0 +1,59 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import deque
|
||||
from collections.abc import Sequence
|
||||
from uuid import UUID
|
||||
|
||||
from neo4j.exceptions import Neo4jError
|
||||
|
||||
from eidolon.core.graph.repository import GraphRepository
|
||||
from eidolon.core.models.plan import BlastRadius
|
||||
|
||||
|
||||
def blast_radius(
|
||||
repository: GraphRepository, targets: Sequence[UUID], depth: int = 2
|
||||
) -> BlastRadius:
|
||||
"""
|
||||
Estimate blast radius by traversing outbound relationships up to the given depth.
|
||||
|
||||
This uses repository neighbor lookups to avoid binding to a specific graph backend.
|
||||
"""
|
||||
visited: set[UUID] = set()
|
||||
queue: deque[tuple[UUID, int]] = deque([(target, 0) for target in targets])
|
||||
paths = []
|
||||
|
||||
while queue:
|
||||
node_id, level = queue.popleft()
|
||||
if node_id in visited or level > depth:
|
||||
continue
|
||||
visited.add(node_id)
|
||||
neighbors = repository.get_neighbors(node_id)
|
||||
for neighbor in neighbors:
|
||||
if neighbor not in visited and level + 1 <= depth:
|
||||
queue.append((neighbor, level + 1))
|
||||
|
||||
return BlastRadius(affected_nodes=list(visited), paths=paths, score=float(len(visited)))
|
||||
|
||||
|
||||
def min_cut_edges(repository: GraphRepository, source: UUID, target: UUID) -> list[dict]:
|
||||
"""
|
||||
Wrapper for computing min-cut edges between two nodes.
|
||||
|
||||
For Neo4j, this attempts to call GDS. Implementations without GDS support can
|
||||
override run_cypher to provide results or return an empty list.
|
||||
"""
|
||||
cypher = """
|
||||
CALL gds.alpha.minCut.stream({
|
||||
nodeProjection: '*',
|
||||
relationshipProjection: '*'
|
||||
})
|
||||
YIELD sourceNodeId, targetNodeId, cutCost
|
||||
RETURN gds.util.asNode(sourceNodeId).node_id AS source_id,
|
||||
gds.util.asNode(targetNodeId).node_id AS target_id,
|
||||
cutCost
|
||||
"""
|
||||
try:
|
||||
result = list(repository.run_cypher(cypher))
|
||||
except (Neo4jError, RuntimeError, TypeError, ValueError):
|
||||
result = []
|
||||
return result
|
||||
@@ -0,0 +1,445 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Iterable, Sequence
|
||||
from contextlib import suppress
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from neo4j import GraphDatabase, Session
|
||||
from pydantic import ValidationError
|
||||
|
||||
from eidolon.config.settings import get_settings
|
||||
from eidolon.core.graph.repository import GraphRepository
|
||||
from eidolon.core.models.asset import (
|
||||
ActionType,
|
||||
Asset,
|
||||
Capability,
|
||||
EvidenceSource,
|
||||
Identity,
|
||||
NetworkContainer,
|
||||
Policy,
|
||||
Tool,
|
||||
)
|
||||
from eidolon.core.models.graph import Edge, EvidenceRef, GraphPath, Node
|
||||
|
||||
|
||||
class Neo4jGraphRepository(GraphRepository):
|
||||
"""Neo4j implementation of the GraphRepository using Cypher queries."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
uri: str | None = None,
|
||||
user: str | None = None,
|
||||
password: str | None = None,
|
||||
database: str | None = None,
|
||||
) -> None:
|
||||
settings = get_settings()
|
||||
self._uri = uri or settings.neo4j.uri
|
||||
self._user = user or settings.neo4j.user
|
||||
self._password = password or settings.neo4j.password
|
||||
self._database = database or settings.neo4j.database
|
||||
self._driver = GraphDatabase.driver(
|
||||
self._uri,
|
||||
auth=(self._user, self._password),
|
||||
notifications_disabled_categories=["UNRECOGNIZED"],
|
||||
)
|
||||
|
||||
def close(self) -> None:
|
||||
self._driver.close()
|
||||
|
||||
def _session(self) -> Session:
|
||||
return self._driver.session(database=self._database)
|
||||
|
||||
def _build_node(self, label: str, props: dict, evidence: list[EvidenceRef]) -> Node:
|
||||
from neo4j.time import DateTime as Neo4jDateTime
|
||||
|
||||
payload = dict(props or {})
|
||||
payload["label"] = label
|
||||
payload["evidence"] = evidence
|
||||
|
||||
# Convert Neo4j DateTime objects to Python datetime
|
||||
for key in ("created_at", "updated_at"):
|
||||
if key in payload and isinstance(payload[key], Neo4jDateTime):
|
||||
payload[key] = payload[key].to_native()
|
||||
|
||||
# Deserialize known JSON fields
|
||||
payload = self._deserialize_from_neo4j(payload)
|
||||
|
||||
model_map = {
|
||||
"Asset": Asset,
|
||||
"NetworkContainer": NetworkContainer,
|
||||
"Identity": Identity,
|
||||
"Policy": Policy,
|
||||
"Tool": Tool,
|
||||
"Capability": Capability,
|
||||
"ActionType": ActionType,
|
||||
"EvidenceSource": EvidenceSource,
|
||||
}
|
||||
model = model_map.get(label, Node)
|
||||
try:
|
||||
return model.model_validate(payload)
|
||||
except ValidationError:
|
||||
return Node.model_validate(payload)
|
||||
|
||||
@staticmethod
|
||||
def _parse_evidence(raw: list[dict]) -> list[EvidenceRef]:
|
||||
from neo4j.time import DateTime as Neo4jDateTime
|
||||
|
||||
evidence: list[EvidenceRef] = []
|
||||
for item in raw or []:
|
||||
if not item or not item.get("source_type"):
|
||||
continue
|
||||
# Convert Neo4j DateTime to Python datetime
|
||||
if "collected_at" in item and isinstance(item["collected_at"], Neo4jDateTime):
|
||||
item["collected_at"] = item["collected_at"].to_native()
|
||||
# Deserialize metadata field only
|
||||
if "metadata" in item and isinstance(item["metadata"], str):
|
||||
try:
|
||||
item["metadata"] = json.loads(item["metadata"])
|
||||
except (json.JSONDecodeError, TypeError, ValueError):
|
||||
item["metadata"] = {}
|
||||
evidence.append(EvidenceRef.model_validate(item))
|
||||
return evidence
|
||||
|
||||
@staticmethod
|
||||
def _serialize_for_neo4j(data: dict) -> dict:
|
||||
"""Recursively serialize nested dicts to JSON strings for Neo4j."""
|
||||
result = {}
|
||||
for key, value in data.items():
|
||||
if isinstance(value, dict) or (
|
||||
isinstance(value, list) and value and all(isinstance(item, dict) for item in value)
|
||||
):
|
||||
result[key] = json.dumps(value)
|
||||
else:
|
||||
result[key] = value
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _deserialize_from_neo4j(data: dict, json_fields: set[str] | None = None) -> dict:
|
||||
"""Deserialize JSON strings back to dicts for known JSON fields."""
|
||||
if json_fields is None:
|
||||
json_fields = {"metadata", "rules", "privileges", "groups"}
|
||||
|
||||
result = dict(data)
|
||||
for key, value in result.items():
|
||||
if key in json_fields and isinstance(value, str):
|
||||
with suppress(json.JSONDecodeError, TypeError, ValueError):
|
||||
result[key] = json.loads(value)
|
||||
return result
|
||||
|
||||
def upsert_node(self, node: Node) -> None:
|
||||
props = self._serialize_for_neo4j(node.to_properties())
|
||||
|
||||
evidence = []
|
||||
for ev in node.evidence:
|
||||
evidence.append(self._serialize_for_neo4j(ev.model_dump()))
|
||||
|
||||
cypher = f"""
|
||||
MERGE (n:{node.label} {{node_id: $node_id}})
|
||||
ON CREATE SET n.created_at = datetime()
|
||||
SET n += $props, n.updated_at = datetime()
|
||||
WITH n
|
||||
UNWIND $evidence AS ev
|
||||
MERGE (e:Evidence {{source_type: ev.source_type, source_id: ev.source_id}})
|
||||
SET e.collected_at = datetime(ev.collected_at),
|
||||
e.metadata = ev.metadata,
|
||||
e.weight = ev.weight,
|
||||
e.confidence = ev.confidence,
|
||||
e.inferred = ev.inferred
|
||||
MERGE (n)-[r:HAS_EVIDENCE]->(e)
|
||||
SET r.weight = ev.weight, r.confidence = ev.confidence
|
||||
"""
|
||||
with self._session() as session:
|
||||
session.execute_write(
|
||||
lambda tx: tx.run(
|
||||
cypher,
|
||||
node_id=str(node.node_id),
|
||||
props=props,
|
||||
evidence=evidence,
|
||||
)
|
||||
)
|
||||
|
||||
def upsert_edge(self, edge: Edge) -> None:
|
||||
evidence = []
|
||||
for ev in edge.evidence:
|
||||
evidence.append(self._serialize_for_neo4j(ev.model_dump()))
|
||||
|
||||
cypher = f"""
|
||||
MATCH (source {{node_id: $source_id}})
|
||||
MATCH (target {{node_id: $target_id}})
|
||||
MERGE (source)-[r:{edge.type}]->(target)
|
||||
SET r.edge_id = $edge_id,
|
||||
r.confidence = $confidence,
|
||||
r.first_seen = coalesce(r.first_seen, datetime($first_seen)),
|
||||
r.last_seen = datetime($last_seen)
|
||||
WITH r
|
||||
UNWIND $evidence AS ev
|
||||
MERGE (e:Evidence {{source_type: ev.source_type, source_id: ev.source_id}})
|
||||
SET e.collected_at = datetime(ev.collected_at),
|
||||
e.metadata = ev.metadata,
|
||||
e.weight = ev.weight,
|
||||
e.confidence = ev.confidence,
|
||||
e.inferred = ev.inferred
|
||||
MERGE (:EdgeEvidence {{edge_id: $edge_id}})-[re:RECORDED_BY]->(e)
|
||||
SET re.weight = ev.weight, re.confidence = ev.confidence
|
||||
"""
|
||||
with self._session() as session:
|
||||
session.execute_write(
|
||||
lambda tx: tx.run(
|
||||
cypher,
|
||||
source_id=str(edge.source),
|
||||
target_id=str(edge.target),
|
||||
edge_id=str(edge.edge_id),
|
||||
confidence=edge.confidence,
|
||||
first_seen=edge.first_seen.isoformat(),
|
||||
last_seen=edge.last_seen.isoformat(),
|
||||
evidence=evidence,
|
||||
)
|
||||
)
|
||||
|
||||
def find_paths(self, source: UUID, target: UUID, max_depth: int = 4) -> list[GraphPath]:
|
||||
cypher = """
|
||||
MATCH (source {node_id: $source_id}), (target {node_id: $target_id})
|
||||
MATCH p = (source)-[*1..5]->(target)
|
||||
WHERE length(p) <= $max_depth
|
||||
WITH p,
|
||||
reduce(
|
||||
cost = 0.0,
|
||||
rel in relationships(p) | cost + coalesce(rel.confidence, 1.0)
|
||||
) AS path_cost
|
||||
ORDER BY length(p), path_cost
|
||||
LIMIT 10
|
||||
RETURN [n IN nodes(p) | n.node_id] AS node_ids,
|
||||
[r IN relationships(p) | type(r)] AS rels,
|
||||
path_cost
|
||||
"""
|
||||
with self._session() as session:
|
||||
result = session.execute_read(
|
||||
lambda tx: tx.run(
|
||||
cypher, source_id=str(source), target_id=str(target), max_depth=max_depth
|
||||
).data()
|
||||
)
|
||||
paths: list[GraphPath] = []
|
||||
for record in result:
|
||||
paths.append(
|
||||
GraphPath(
|
||||
nodes=[UUID(node_id) for node_id in record["node_ids"]],
|
||||
edges=record["rels"],
|
||||
cost=record["path_cost"],
|
||||
)
|
||||
)
|
||||
return paths
|
||||
|
||||
def get_neighbors(
|
||||
self, node_id: UUID, relationship_types: Sequence[str] | None = None
|
||||
) -> list[UUID]:
|
||||
rel_filter = ""
|
||||
params: dict = {"node_id": str(node_id)}
|
||||
if relationship_types:
|
||||
rel_filter = "WHERE type(r) IN $rel_types"
|
||||
params["rel_types"] = list(relationship_types)
|
||||
cypher = f"""
|
||||
MATCH (n {{node_id: $node_id}})-[r]->(neighbor)
|
||||
{rel_filter}
|
||||
RETURN DISTINCT neighbor.node_id AS node_id
|
||||
"""
|
||||
with self._session() as session:
|
||||
result = session.execute_read(lambda tx: tx.run(cypher, **params).data())
|
||||
return [UUID(record["node_id"]) for record in result]
|
||||
|
||||
def upsert_asset(self, asset: Asset) -> None:
|
||||
self.upsert_node(asset)
|
||||
|
||||
def upsert_network(self, network: NetworkContainer) -> None:
|
||||
self.upsert_node(network)
|
||||
|
||||
def upsert_identity(self, identity: Identity) -> None:
|
||||
self.upsert_node(identity)
|
||||
|
||||
def upsert_policy(self, policy: Policy) -> None:
|
||||
self.upsert_node(policy)
|
||||
|
||||
def run_cypher(self, cypher: str, parameters: dict | None = None) -> Iterable[dict]:
|
||||
with self._session() as session:
|
||||
result = session.execute_read(lambda tx: tx.run(cypher, parameters or {}).data())
|
||||
return result
|
||||
|
||||
def find_asset_by_identifier(self, identifier: str) -> Asset | None:
|
||||
cypher = """
|
||||
MATCH (n:Asset)
|
||||
WHERE $identifier IN n.identifiers OR n.node_id = $identifier
|
||||
OPTIONAL MATCH (n)-[:HAS_EVIDENCE]->(e:Evidence)
|
||||
RETURN properties(n) AS props,
|
||||
head(labels(n)) AS label,
|
||||
collect({
|
||||
source_type: e.source_type,
|
||||
source_id: e.source_id,
|
||||
collected_at: e.collected_at,
|
||||
weight: e.weight,
|
||||
confidence: e.confidence,
|
||||
inferred: e.inferred,
|
||||
metadata: e.metadata
|
||||
}) AS evidence
|
||||
LIMIT 1
|
||||
"""
|
||||
with self._session() as session:
|
||||
record = session.execute_read(lambda tx: tx.run(cypher, identifier=identifier).single())
|
||||
if not record:
|
||||
return None
|
||||
props = record["props"] or {}
|
||||
evidence = self._parse_evidence(record.get("evidence", []))
|
||||
node = self._build_node("Asset", props, evidence)
|
||||
return node if isinstance(node, Asset) else None
|
||||
|
||||
def find_network_by_cidr_or_name(self, cidr_or_name: str) -> NetworkContainer | None:
|
||||
cypher = """
|
||||
MATCH (n:NetworkContainer)
|
||||
WHERE n.cidr = $value OR n.name = $value
|
||||
OPTIONAL MATCH (n)-[:HAS_EVIDENCE]->(e:Evidence)
|
||||
RETURN properties(n) AS props,
|
||||
head(labels(n)) AS label,
|
||||
collect({
|
||||
source_type: e.source_type,
|
||||
source_id: e.source_id,
|
||||
collected_at: e.collected_at,
|
||||
weight: e.weight,
|
||||
confidence: e.confidence,
|
||||
inferred: e.inferred,
|
||||
metadata: e.metadata
|
||||
}) AS evidence
|
||||
LIMIT 1
|
||||
"""
|
||||
with self._session() as session:
|
||||
record = session.execute_read(lambda tx: tx.run(cypher, value=cidr_or_name).single())
|
||||
if not record:
|
||||
return None
|
||||
props = record["props"] or {}
|
||||
evidence = self._parse_evidence(record.get("evidence", []))
|
||||
node = self._build_node("NetworkContainer", props, evidence)
|
||||
return node if isinstance(node, NetworkContainer) else None
|
||||
|
||||
def find_identity_by_name(self, name: str) -> Identity | None:
|
||||
cypher = """
|
||||
MATCH (n:Identity)
|
||||
WHERE n.name = $name
|
||||
OPTIONAL MATCH (n)-[:HAS_EVIDENCE]->(e:Evidence)
|
||||
RETURN properties(n) AS props,
|
||||
head(labels(n)) AS label,
|
||||
collect({
|
||||
source_type: e.source_type,
|
||||
source_id: e.source_id,
|
||||
collected_at: e.collected_at,
|
||||
weight: e.weight,
|
||||
confidence: e.confidence,
|
||||
inferred: e.inferred,
|
||||
metadata: e.metadata
|
||||
}) AS evidence
|
||||
LIMIT 1
|
||||
"""
|
||||
with self._session() as session:
|
||||
record = session.execute_read(lambda tx: tx.run(cypher, name=name).single())
|
||||
if not record:
|
||||
return None
|
||||
props = record["props"] or {}
|
||||
evidence = self._parse_evidence(record.get("evidence", []))
|
||||
node = self._build_node("Identity", props, evidence)
|
||||
return node if isinstance(node, Identity) else None
|
||||
|
||||
def get_edge_evidence(self, edge_id: UUID) -> list[EvidenceRef]:
|
||||
cypher = """
|
||||
MATCH (:EdgeEvidence {edge_id: $edge_id})-[:RECORDED_BY]->(e:Evidence)
|
||||
RETURN collect({
|
||||
source_type: e.source_type,
|
||||
source_id: e.source_id,
|
||||
collected_at: e.collected_at,
|
||||
weight: e.weight,
|
||||
confidence: e.confidence,
|
||||
inferred: e.inferred,
|
||||
metadata: e.metadata
|
||||
}) AS evidence
|
||||
"""
|
||||
with self._session() as session:
|
||||
record = session.execute_read(lambda tx: tx.run(cypher, edge_id=str(edge_id)).single())
|
||||
return self._parse_evidence(record["evidence"]) if record else []
|
||||
|
||||
def clear(self) -> int:
|
||||
cypher = """
|
||||
MATCH (n)
|
||||
WITH count(n) AS node_count
|
||||
MATCH (n)
|
||||
DETACH DELETE n
|
||||
RETURN node_count
|
||||
"""
|
||||
with self._session() as session:
|
||||
record = session.execute_write(lambda tx: tx.run(cypher).single())
|
||||
if not record:
|
||||
return 0
|
||||
return int(record.get("node_count") or 0)
|
||||
|
||||
def get_node(self, node_id: UUID) -> Node | None:
|
||||
cypher = """
|
||||
MATCH (n {node_id: $node_id})
|
||||
OPTIONAL MATCH (n)-[:HAS_EVIDENCE]->(e:Evidence)
|
||||
RETURN properties(n) AS props,
|
||||
head(labels(n)) AS label,
|
||||
collect({
|
||||
source_type: e.source_type,
|
||||
source_id: e.source_id,
|
||||
collected_at: e.collected_at,
|
||||
weight: e.weight,
|
||||
confidence: e.confidence,
|
||||
inferred: e.inferred,
|
||||
metadata: e.metadata
|
||||
}) AS evidence
|
||||
"""
|
||||
with self._session() as session:
|
||||
record = session.execute_read(lambda tx: tx.run(cypher, node_id=str(node_id)).single())
|
||||
if not record:
|
||||
return None
|
||||
props = record["props"] or {}
|
||||
label = record["label"] or "Node"
|
||||
evidence = self._parse_evidence(record.get("evidence", []))
|
||||
return self._build_node(label, props, evidence)
|
||||
|
||||
def list_nodes(self, label: str | None = None, limit: int = 100) -> list[Node]:
|
||||
where = ""
|
||||
params: dict = {"limit": limit}
|
||||
if label:
|
||||
where = f"WHERE n:{label}"
|
||||
cypher = f"""
|
||||
MATCH (n)
|
||||
{where}
|
||||
OPTIONAL MATCH (n)-[:HAS_EVIDENCE]->(e:Evidence)
|
||||
WITH n, collect({{
|
||||
source_type: e.source_type,
|
||||
source_id: e.source_id,
|
||||
collected_at: e.collected_at,
|
||||
weight: e.weight,
|
||||
confidence: e.confidence,
|
||||
inferred: e.inferred,
|
||||
metadata: e.metadata
|
||||
}}) AS evidence
|
||||
RETURN properties(n) AS props, head(labels(n)) AS label, evidence
|
||||
LIMIT $limit
|
||||
"""
|
||||
with self._session() as session:
|
||||
records = session.execute_read(lambda tx: tx.run(cypher, **params).data())
|
||||
nodes: list[Node] = []
|
||||
for record in records:
|
||||
props = record["props"] or {}
|
||||
lbl = record["label"] or "Node"
|
||||
evidence = self._parse_evidence(record.get("evidence", []))
|
||||
nodes.append(self._build_node(lbl, props, evidence))
|
||||
return nodes
|
||||
|
||||
@staticmethod
|
||||
def _coerce_datetime(value: object | None) -> datetime | None:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, datetime):
|
||||
return value
|
||||
try:
|
||||
return datetime.fromisoformat(str(value))
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
@@ -0,0 +1,78 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Iterable, Sequence
|
||||
from uuid import UUID
|
||||
|
||||
from eidolon.core.models.asset import Asset, Identity, NetworkContainer, Policy
|
||||
from eidolon.core.models.graph import Edge, EvidenceRef, GraphPath, Node
|
||||
|
||||
|
||||
class GraphRepository(ABC):
|
||||
"""Abstract repository interface for Eidolon's evidence-backed graph."""
|
||||
|
||||
@abstractmethod
|
||||
def upsert_node(self, node: Node) -> None:
|
||||
"""Create or update a node with evidence."""
|
||||
|
||||
@abstractmethod
|
||||
def upsert_edge(self, edge: Edge) -> None:
|
||||
"""Create or update a relationship with evidence."""
|
||||
|
||||
@abstractmethod
|
||||
def find_paths(self, source: UUID, target: UUID, max_depth: int = 4) -> list[GraphPath]:
|
||||
"""Return paths between two nodes."""
|
||||
|
||||
@abstractmethod
|
||||
def get_neighbors(
|
||||
self, node_id: UUID, relationship_types: Sequence[str] | None = None
|
||||
) -> list[UUID]:
|
||||
"""Return neighbor node IDs for the given node."""
|
||||
|
||||
@abstractmethod
|
||||
def get_node(self, node_id: UUID) -> Node | None:
|
||||
"""Return a node by id if present."""
|
||||
|
||||
@abstractmethod
|
||||
def list_nodes(self, label: str | None = None, limit: int = 100) -> list[Node]:
|
||||
"""List nodes, optionally filtered by label."""
|
||||
|
||||
@abstractmethod
|
||||
def upsert_asset(self, asset: Asset) -> None:
|
||||
"""Helper to persist assets with standard label."""
|
||||
|
||||
@abstractmethod
|
||||
def upsert_network(self, network: NetworkContainer) -> None:
|
||||
"""Helper to persist network containers."""
|
||||
|
||||
@abstractmethod
|
||||
def upsert_identity(self, identity: Identity) -> None:
|
||||
"""Helper to persist identities."""
|
||||
|
||||
@abstractmethod
|
||||
def upsert_policy(self, policy: Policy) -> None:
|
||||
"""Helper to persist policies."""
|
||||
|
||||
@abstractmethod
|
||||
def run_cypher(self, cypher: str, parameters: dict | None = None) -> Iterable[dict]:
|
||||
"""Execute arbitrary Cypher for advanced queries (used sparingly)."""
|
||||
|
||||
@abstractmethod
|
||||
def find_asset_by_identifier(self, identifier: str) -> Asset | None:
|
||||
"""Return an Asset node that matches the identifier (IP, hostname, MAC)."""
|
||||
|
||||
@abstractmethod
|
||||
def find_network_by_cidr_or_name(self, cidr_or_name: str) -> NetworkContainer | None:
|
||||
"""Return a NetworkContainer node by CIDR or name."""
|
||||
|
||||
@abstractmethod
|
||||
def find_identity_by_name(self, name: str) -> Identity | None:
|
||||
"""Return an Identity node by canonical name."""
|
||||
|
||||
@abstractmethod
|
||||
def get_edge_evidence(self, edge_id: UUID) -> list[EvidenceRef]:
|
||||
"""Return evidence references attached to the edge."""
|
||||
|
||||
@abstractmethod
|
||||
def clear(self) -> int:
|
||||
"""Delete all nodes and edges from the graph and return count of removed nodes."""
|
||||
@@ -0,0 +1 @@
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ApprovalRecord(BaseModel):
|
||||
"""Approval token record stored in Postgres."""
|
||||
|
||||
approval_id: UUID = Field(default_factory=uuid4, alias="id")
|
||||
user_id: str
|
||||
token: str
|
||||
action: str
|
||||
expires_at: datetime
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
@classmethod
|
||||
def create(cls, user_id: str, action: str, ttl_seconds: int) -> ApprovalRecord:
|
||||
token = str(uuid4())
|
||||
expires_at = datetime.utcnow() + timedelta(seconds=ttl_seconds)
|
||||
return cls(user_id=user_id, token=token, action=action, expires_at=expires_at)
|
||||
|
||||
def is_expired(self) -> bool:
|
||||
return datetime.utcnow() >= self.expires_at
|
||||
@@ -0,0 +1,113 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import ConfigDict, Field
|
||||
|
||||
from eidolon.core.models.graph import Node
|
||||
|
||||
|
||||
class Asset(Node):
|
||||
"""Infrastructure asset node."""
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True, extra="ignore")
|
||||
|
||||
label: str = Field(default="Asset")
|
||||
kind: str = Field(
|
||||
default="host",
|
||||
description="Type of asset (host, vm, router, firewall, service, db, saas, etc.)",
|
||||
)
|
||||
env: str | None = Field(default=None, description="Environment tag (prod, staging, dev, etc.)")
|
||||
criticality: str | None = Field(default=None, description="Business impact rating")
|
||||
owner_team: str | None = Field(default=None)
|
||||
lifecycle_state: str | None = Field(
|
||||
default=None, description="active, deprecated, retired, etc."
|
||||
)
|
||||
identifiers: list[str] = Field(
|
||||
default_factory=list, description="Hostnames, IPs, MACs, instance IDs"
|
||||
)
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class NetworkContainer(Node):
|
||||
"""Network container such as VPC, subnet, VLAN, or security zone."""
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True, extra="ignore")
|
||||
|
||||
label: str = Field(default="NetworkContainer")
|
||||
cidr: str = Field(description="CIDR range for the network container")
|
||||
name: str | None = Field(default=None)
|
||||
network_type: str | None = Field(
|
||||
default=None, description="VPC, subnet, vlan, segment, zone, etc."
|
||||
)
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class Identity(Node):
|
||||
"""Identity representation for users, service accounts, or groups."""
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True, extra="ignore")
|
||||
|
||||
label: str = Field(default="Identity")
|
||||
identity_type: str = Field(
|
||||
default="user", description="user, service_account, role, group, etc."
|
||||
)
|
||||
name: str = Field(description="Canonical identity name")
|
||||
groups: list[str] = Field(default_factory=list)
|
||||
privileges: list[str] = Field(default_factory=list)
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class Policy(Node):
|
||||
"""Policy node capturing access/change rules."""
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True, extra="ignore")
|
||||
|
||||
label: str = Field(default="Policy")
|
||||
policy_type: str = Field(default="access")
|
||||
description: str | None = Field(default=None)
|
||||
rules: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class Tool(Node):
|
||||
"""Execution tool representation."""
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True, extra="ignore")
|
||||
|
||||
label: str = Field(default="Tool")
|
||||
name: str = Field(description="Tool name (ansible, terraform, ssm, etc.)")
|
||||
version: str | None = Field(default=None)
|
||||
sandbox_execution: bool = Field(default=True)
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class Capability(Node):
|
||||
"""Capability required by actions (ssh, winrm, api-scope)."""
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True, extra="ignore")
|
||||
|
||||
label: str = Field(default="Capability")
|
||||
name: str = Field(description="Capability identifier")
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class ActionType(Node):
|
||||
"""Action type that tools implement (run_command, change_firewall_rule)."""
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True, extra="ignore")
|
||||
|
||||
label: str = Field(default="ActionType")
|
||||
name: str = Field(description="Action name")
|
||||
description: str | None = Field(default=None)
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class EvidenceSource(Node):
|
||||
"""Evidence source metadata."""
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True, extra="ignore")
|
||||
|
||||
label: str = Field(default="EvidenceSource")
|
||||
source_type: str = Field(description="flow_logs, snmp, lldp, cloud_api, iam_export, etc.")
|
||||
reference: str | None = Field(default=None, description="Opaque reference to the source record")
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
@@ -0,0 +1,24 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
message_id: UUID = Field(default_factory=uuid4)
|
||||
role: Literal["user", "assistant", "system", "tool"] = Field(default="user")
|
||||
content: str
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
timestamp: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
class ChatSession(BaseModel):
|
||||
session_id: UUID = Field(default_factory=uuid4)
|
||||
user_id: str = Field(default="anonymous")
|
||||
title: str | None = Field(default=None)
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
updated_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
messages: list[ChatMessage] = Field(default_factory=list)
|
||||
@@ -0,0 +1,31 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class CollectorEvent(BaseModel):
|
||||
"""Normalized collector event before graph ingestion."""
|
||||
|
||||
event_id: UUID = Field(default_factory=uuid4)
|
||||
source_type: str = Field(description="Collector type (network, cloud, identity, traffic)")
|
||||
source_id: str | None = Field(default=None, description="Opaque identifier from the collector")
|
||||
entity_type: str = Field(
|
||||
description="Target entity type: Asset, NetworkContainer, Identity, etc."
|
||||
)
|
||||
payload: dict[str, Any] = Field(default_factory=dict, description="Normalized event payload")
|
||||
collected_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
confidence: float = Field(default=1.0, ge=0.0, le=1.0)
|
||||
|
||||
|
||||
class AuditEvent(BaseModel):
|
||||
"""Audit event stored in Postgres for traceability."""
|
||||
|
||||
audit_id: UUID = Field(default_factory=uuid4)
|
||||
event_type: str = Field(description="prompt, tool_call, execution, approval")
|
||||
details: dict[str, Any] = Field(default_factory=dict)
|
||||
timestamp: datetime = Field(default_factory=datetime.utcnow)
|
||||
status: str = Field(default="pending")
|
||||
@@ -0,0 +1,80 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class EvidenceRef(BaseModel):
|
||||
source_type: str = Field(
|
||||
description="Authoritative source type (cloud_api, flow_logs, nmap, etc.)"
|
||||
)
|
||||
source_id: str = Field(description="Opaque source identifier (ARN, filename, record pointer)")
|
||||
collected_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
weight: float = Field(default=1.0, ge=0.0, description="Relative weighting for evidence fusion")
|
||||
confidence: float = Field(default=1.0, ge=0.0, le=1.0, description="Confidence score 0-1")
|
||||
inferred: bool = Field(
|
||||
default=False,
|
||||
description="True when derived from inference rather than direct observation",
|
||||
)
|
||||
metadata: dict[str, Any] = Field(
|
||||
default_factory=dict, description="Source-specific metadata (ports, protocol, rule IDs)"
|
||||
)
|
||||
|
||||
|
||||
class Node(BaseModel):
|
||||
"""Base graph node with evidence tracking."""
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True)
|
||||
|
||||
node_id: UUID = Field(default_factory=uuid4, alias="id")
|
||||
label: str = Field(description="Primary graph label, e.g. Asset or NetworkContainer")
|
||||
evidence: list[EvidenceRef] = Field(default_factory=list)
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
updated_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
def to_properties(self) -> dict[str, Any]:
|
||||
"""Return serialisable properties for persistence layers."""
|
||||
props = self.model_dump(
|
||||
exclude={"evidence", "label"},
|
||||
by_alias=False,
|
||||
exclude_none=True,
|
||||
)
|
||||
props["node_id"] = str(self.node_id)
|
||||
props["created_at"] = self.created_at.isoformat()
|
||||
props["updated_at"] = self.updated_at.isoformat()
|
||||
return props
|
||||
|
||||
|
||||
class Edge(BaseModel):
|
||||
"""Typed edge with provenance and temporal bounds."""
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True)
|
||||
|
||||
edge_id: UUID = Field(default_factory=uuid4, alias="id")
|
||||
type: str = Field(description="Relationship type such as MEMBER_OF or CAN_REACH")
|
||||
source: UUID = Field(description="Source node id")
|
||||
target: UUID = Field(description="Target node id")
|
||||
confidence: float = Field(default=1.0, ge=0.0, le=1.0)
|
||||
first_seen: datetime = Field(default_factory=datetime.utcnow)
|
||||
last_seen: datetime = Field(default_factory=datetime.utcnow)
|
||||
evidence: list[EvidenceRef] = Field(default_factory=list)
|
||||
|
||||
def to_properties(self) -> dict[str, Any]:
|
||||
props = self.model_dump(exclude={"evidence"}, exclude_none=True)
|
||||
props["edge_id"] = str(self.edge_id)
|
||||
props["source"] = str(self.source)
|
||||
props["target"] = str(self.target)
|
||||
props["first_seen"] = self.first_seen.isoformat()
|
||||
props["last_seen"] = self.last_seen.isoformat()
|
||||
return props
|
||||
|
||||
|
||||
class GraphPath(BaseModel):
|
||||
"""Path result used for queries and planning."""
|
||||
|
||||
nodes: list[UUID]
|
||||
edges: list[str]
|
||||
cost: float | None = None
|
||||
@@ -0,0 +1,87 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from eidolon.core.models.graph import GraphPath
|
||||
|
||||
|
||||
class EntityRef(BaseModel):
|
||||
"""Resolved entity reference used throughout planning and execution."""
|
||||
|
||||
entity_id: UUID | None = Field(default=None, alias="id")
|
||||
entity_type: str = Field(description="Node label such as Asset, NetworkContainer, Identity")
|
||||
display_name: str | None = Field(default=None)
|
||||
confidence: float = Field(default=1.0, ge=0.0, le=1.0)
|
||||
|
||||
|
||||
class GraphQuery(BaseModel):
|
||||
"""Representation of a graph query to keep LLM outputs typed."""
|
||||
|
||||
cypher: str
|
||||
parameters: dict = Field(default_factory=dict)
|
||||
|
||||
|
||||
class PlanStep(BaseModel):
|
||||
"""Single step in a generated plan."""
|
||||
|
||||
step_id: str = Field(default_factory=lambda: str(uuid4()))
|
||||
action_type: str = Field(
|
||||
description="Action type identifier (run_command, change_firewall_rule, etc.)"
|
||||
)
|
||||
target: EntityRef = Field(description="Primary entity or resource targeted by the step")
|
||||
tool_hint: str | None = Field(default=None, description="Preferred tool for execution")
|
||||
rationale: str = Field(default="", description="Why this step exists")
|
||||
rollback: str | None = Field(default=None, description="Rollback guidance or command")
|
||||
risk: str | None = Field(default=None, description="Risk or blast radius summary")
|
||||
requires_approval: bool = Field(default=True)
|
||||
parameters: dict = Field(
|
||||
default_factory=dict,
|
||||
description="Execution payload or tool parameters for this step",
|
||||
)
|
||||
|
||||
|
||||
class BlastRadius(BaseModel):
|
||||
"""Output of blast radius estimation."""
|
||||
|
||||
affected_nodes: list[UUID] = Field(default_factory=list)
|
||||
paths: list[GraphPath] = Field(default_factory=list)
|
||||
score: float = Field(default=0.0, ge=0.0, description="Higher is riskier")
|
||||
|
||||
|
||||
class ExecutionRequest(BaseModel):
|
||||
"""Request to execute a plan in a gated runtime."""
|
||||
|
||||
dry_run: bool = Field(default=True)
|
||||
steps: list[PlanStep] = Field(default_factory=list)
|
||||
requires_approval: bool = Field(default=True)
|
||||
approval_reason: str | None = Field(default=None)
|
||||
approval_token: str | None = Field(
|
||||
default=None, description="Token proving approval for execution"
|
||||
)
|
||||
|
||||
|
||||
class PlanDraft(BaseModel):
|
||||
"""LLM-friendly wrapper for plan steps."""
|
||||
|
||||
steps: list[PlanStep] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ToolExecutionResult(BaseModel):
|
||||
"""Result of executing a single plan step."""
|
||||
|
||||
step_id: str
|
||||
tool: str | None = None
|
||||
status: str = Field(description="ok, skipped, dry_run, error")
|
||||
output: dict = Field(default_factory=dict)
|
||||
error: str | None = None
|
||||
|
||||
|
||||
class ExecutionResponse(BaseModel):
|
||||
"""Response for execution requests, including per-step results."""
|
||||
|
||||
request: ExecutionRequest
|
||||
results: list[ToolExecutionResult] = Field(default_factory=list)
|
||||
status: str = Field(default="ok")
|
||||
audit_id: UUID | None = None
|
||||
@@ -0,0 +1,60 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class ScannerOptions(BaseModel):
|
||||
ping_concurrency: int = Field(default=128, ge=32, le=512)
|
||||
port_scan_workers: int = Field(default=32, ge=8, le=64)
|
||||
dns_resolution: bool = True
|
||||
aggressive: bool = False
|
||||
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
|
||||
|
||||
class ScannerConfig(BaseModel):
|
||||
network_cidrs: list[str] = Field(default_factory=list)
|
||||
ports: list[int] = Field(default_factory=list)
|
||||
port_preset: str = Field(default="normal")
|
||||
options: ScannerOptions = Field(default_factory=ScannerOptions)
|
||||
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
|
||||
|
||||
class ScannerConfigRecord(BaseModel):
|
||||
id: int
|
||||
user_id: str
|
||||
config: ScannerConfig
|
||||
updated_at: datetime | None = None
|
||||
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
|
||||
|
||||
def default_scanner_config() -> ScannerConfig:
|
||||
return ScannerConfig(
|
||||
network_cidrs=["192.168.1.0/24"],
|
||||
ports=[
|
||||
21,
|
||||
22,
|
||||
23,
|
||||
25,
|
||||
53,
|
||||
80,
|
||||
110,
|
||||
143,
|
||||
443,
|
||||
465,
|
||||
587,
|
||||
993,
|
||||
995,
|
||||
3306,
|
||||
3389,
|
||||
5432,
|
||||
8080,
|
||||
8443,
|
||||
],
|
||||
port_preset="normal",
|
||||
options=ScannerOptions(),
|
||||
)
|
||||
@@ -0,0 +1,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from eidolon.config.settings import LLMSettings, get_settings
|
||||
|
||||
|
||||
class ThemeSettings(BaseModel):
|
||||
mode: Literal["dark", "light"] = Field(default="dark")
|
||||
|
||||
|
||||
class AppSettings(BaseModel):
|
||||
theme: ThemeSettings = Field(default_factory=ThemeSettings)
|
||||
llm: LLMSettings = Field(default_factory=lambda: get_settings().llm)
|
||||
@@ -0,0 +1 @@
|
||||
|
||||
@@ -0,0 +1,85 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterable
|
||||
from difflib import SequenceMatcher
|
||||
|
||||
from eidolon.core.models.asset import Asset, Identity, NetworkContainer
|
||||
from eidolon.core.models.graph import EvidenceRef
|
||||
|
||||
|
||||
class EntityResolver:
|
||||
"""Entity resolution with lightweight fuzzy matching and confidence scoring."""
|
||||
|
||||
def __init__(self, match_threshold: float = 0.6) -> None:
|
||||
self.match_threshold = match_threshold
|
||||
|
||||
@staticmethod
|
||||
def _similarity(a: str, b: str) -> float:
|
||||
return SequenceMatcher(None, a.lower(), b.lower()).ratio()
|
||||
|
||||
def best_identifier_match(self, candidates: Iterable[str], target: str) -> float:
|
||||
return max((self._similarity(candidate, target) for candidate in candidates), default=0.0)
|
||||
|
||||
def build_evidence(
|
||||
self, source_type: str, source_id: str, confidence: float, metadata: dict | None = None
|
||||
) -> EvidenceRef:
|
||||
return EvidenceRef(
|
||||
source_type=source_type,
|
||||
source_id=source_id,
|
||||
confidence=confidence,
|
||||
metadata=metadata or {},
|
||||
)
|
||||
|
||||
def resolve_asset(
|
||||
self, payload: dict, source_type: str, source_id: str, confidence: float
|
||||
) -> Asset:
|
||||
identifiers: list[str] = []
|
||||
for key in ("ip", "hostname", "mac"):
|
||||
value = payload.get(key)
|
||||
if value:
|
||||
identifiers.append(str(value))
|
||||
evidence = self.build_evidence(
|
||||
source_type, source_id, confidence, metadata={"payload": payload}
|
||||
)
|
||||
metadata = {k: v for k, v in payload.items() if k not in {"ports"}}
|
||||
if "ports" in payload:
|
||||
metadata["ports"] = payload.get("ports")
|
||||
return Asset(
|
||||
kind=payload.get("kind", "host"),
|
||||
env=payload.get("env"),
|
||||
criticality=payload.get("criticality"),
|
||||
owner_team=payload.get("owner_team"),
|
||||
lifecycle_state=payload.get("status"),
|
||||
identifiers=identifiers,
|
||||
metadata=metadata,
|
||||
evidence=[evidence],
|
||||
)
|
||||
|
||||
def resolve_network(
|
||||
self, payload: dict, source_type: str, source_id: str, confidence: float
|
||||
) -> NetworkContainer:
|
||||
evidence = self.build_evidence(
|
||||
source_type, source_id, confidence, metadata={"payload": payload}
|
||||
)
|
||||
return NetworkContainer(
|
||||
cidr=payload["cidr"],
|
||||
name=payload.get("name"),
|
||||
network_type=payload.get("network_type", "subnet"),
|
||||
metadata=payload.get("metadata", {}),
|
||||
evidence=[evidence],
|
||||
)
|
||||
|
||||
def resolve_identity(
|
||||
self, payload: dict, source_type: str, source_id: str, confidence: float
|
||||
) -> Identity:
|
||||
evidence = self.build_evidence(
|
||||
source_type, source_id, confidence, metadata={"payload": payload}
|
||||
)
|
||||
return Identity(
|
||||
name=payload["name"],
|
||||
identity_type=payload.get("identity_type", "user"),
|
||||
groups=payload.get("groups", []),
|
||||
privileges=payload.get("privileges", []),
|
||||
metadata=payload.get("metadata", {}),
|
||||
evidence=[evidence],
|
||||
)
|
||||
@@ -0,0 +1,327 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from contextlib import suppress
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from eidolon.config.settings import LLMSettings, get_settings
|
||||
from eidolon.core.reasoning.memory import ConversationMemory
|
||||
|
||||
try:
|
||||
import litellm
|
||||
|
||||
# Enable automatic dropping of unsupported parameters for model compatibility
|
||||
litellm.drop_params = True
|
||||
# Enable debug logging if EIDOLON_LLM_DEBUG env var is set
|
||||
if os.getenv("EIDOLON_LLM_DEBUG"):
|
||||
os.environ["LITELLM_LOG"] = "DEBUG"
|
||||
except ImportError: # pragma: no cover - dependency is optional in early phases
|
||||
litellm = None
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
LLM_ENV_KEYS = (
|
||||
"OPENAI_API_KEY",
|
||||
"AZURE_OPENAI_API_KEY",
|
||||
"ANTHROPIC_API_KEY",
|
||||
"COHERE_API_KEY",
|
||||
"MISTRAL_API_KEY",
|
||||
"GROQ_API_KEY",
|
||||
"TOGETHER_API_KEY",
|
||||
"OPENROUTER_API_KEY",
|
||||
"FIREWORKS_API_KEY",
|
||||
"GOOGLE_API_KEY",
|
||||
)
|
||||
|
||||
|
||||
def _env_has_llm_key() -> bool:
|
||||
return any(os.getenv(key) for key in LLM_ENV_KEYS)
|
||||
|
||||
|
||||
class LiteLLMClient:
|
||||
"""Thin wrapper around LiteLLM to enforce structured JSON outputs."""
|
||||
|
||||
def __init__(self, settings: LLMSettings | None = None) -> None:
|
||||
self.settings = settings or get_settings().llm
|
||||
self.memory = ConversationMemory(max_tokens=self.settings.max_context_tokens)
|
||||
if litellm is not None:
|
||||
litellm.drop_params = True
|
||||
|
||||
def is_available(self) -> bool:
|
||||
if litellm is None:
|
||||
return False
|
||||
return bool(self.settings.api_key or self.settings.api_base or _env_has_llm_key())
|
||||
|
||||
def _is_rate_limit_error(self, error: Exception) -> bool:
|
||||
error_str = str(error).lower()
|
||||
error_type = type(error).__name__.lower()
|
||||
return (
|
||||
("rate" in error_str and "limit" in error_str)
|
||||
or "ratelimit" in error_type
|
||||
or "429" in error_str
|
||||
or "too many requests" in error_str
|
||||
)
|
||||
|
||||
def _retry_with_backoff(self, call_fn, max_retries: int | None = None):
|
||||
retries = max_retries if max_retries is not None else self.settings.max_retries
|
||||
base_delay = self.settings.retry_delay
|
||||
|
||||
for attempt in range(retries + 1):
|
||||
try:
|
||||
return call_fn()
|
||||
except Exception as exc:
|
||||
if not self._is_rate_limit_error(exc) or attempt >= retries:
|
||||
raise
|
||||
delay = base_delay * (2**attempt) + random.uniform(0, 1)
|
||||
time.sleep(delay)
|
||||
raise RuntimeError("Retry attempts exhausted")
|
||||
|
||||
def chat(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Send a chat completion request, optionally with tool definitions.
|
||||
|
||||
Returns the raw response dict with 'choices' containing message and optional tool_calls.
|
||||
"""
|
||||
if litellm is None:
|
||||
raise RuntimeError("litellm is not installed in this environment")
|
||||
|
||||
completion_args: dict[str, Any] = {
|
||||
"model": self.settings.model,
|
||||
"messages": messages,
|
||||
"temperature": self.settings.temperature,
|
||||
"max_tokens": self.settings.max_tokens,
|
||||
}
|
||||
if self.settings.reasoning_effort:
|
||||
completion_args["reasoning_effort"] = self.settings.reasoning_effort
|
||||
if tools:
|
||||
completion_args["tools"] = tools
|
||||
completion_args["tool_choice"] = "auto"
|
||||
if self.settings.api_base:
|
||||
completion_args["api_base"] = self.settings.api_base
|
||||
if self.settings.api_key:
|
||||
completion_args["api_key"] = self.settings.api_key
|
||||
|
||||
response = litellm.completion(**completion_args)
|
||||
return response
|
||||
|
||||
def generate(
|
||||
self,
|
||||
system_prompt: str,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[Any] | None = None,
|
||||
memory: ConversationMemory | None = None,
|
||||
) -> LLMResponse:
|
||||
"""Generate a response with optional tool calling support."""
|
||||
if litellm is None:
|
||||
raise RuntimeError("litellm is not installed in this environment")
|
||||
|
||||
llm_messages = [{"role": "system", "content": system_prompt}]
|
||||
mem = memory or self.memory
|
||||
history = mem.get_messages_with_summary(messages, llm_call=self._summarize_call)
|
||||
llm_messages.extend(history)
|
||||
|
||||
tool_schemas = None
|
||||
if tools:
|
||||
tool_schemas = [tool.to_openai_function() for tool in tools]
|
||||
|
||||
# Check if history contains any tool calls (Anthropic requires tools= in this case)
|
||||
history_has_tool_calls = any(
|
||||
msg.get("tool_calls") or msg.get("role") == "tool" for msg in llm_messages
|
||||
)
|
||||
|
||||
call_kwargs: dict[str, Any] = {
|
||||
"model": self.settings.model,
|
||||
"messages": llm_messages,
|
||||
"temperature": self.settings.temperature,
|
||||
"max_tokens": self.settings.max_tokens,
|
||||
}
|
||||
if self.settings.reasoning_effort:
|
||||
call_kwargs["reasoning_effort"] = self.settings.reasoning_effort
|
||||
if tool_schemas:
|
||||
call_kwargs["tools"] = tool_schemas
|
||||
call_kwargs["tool_choice"] = "auto"
|
||||
elif history_has_tool_calls and self.settings.model.startswith("claude"):
|
||||
# Anthropic requires tools= param if history contains tool calls
|
||||
# Pass empty tools list to satisfy the requirement
|
||||
call_kwargs["tools"] = []
|
||||
if self.settings.api_base:
|
||||
call_kwargs["api_base"] = self.settings.api_base
|
||||
if self.settings.api_key:
|
||||
call_kwargs["api_key"] = self.settings.api_key
|
||||
if self.settings.top_p != 1.0:
|
||||
call_kwargs["top_p"] = self.settings.top_p
|
||||
if self.settings.frequency_penalty != 0.0:
|
||||
call_kwargs["frequency_penalty"] = self.settings.frequency_penalty
|
||||
if self.settings.presence_penalty != 0.0:
|
||||
call_kwargs["presence_penalty"] = self.settings.presence_penalty
|
||||
|
||||
def _extract_response(
|
||||
raw_response: Any,
|
||||
) -> tuple[str | None, list[Any] | None, dict | None, str, str]:
|
||||
choices = (
|
||||
raw_response["choices"] if isinstance(raw_response, dict) else raw_response.choices
|
||||
)
|
||||
choice = choices[0]
|
||||
message = choice["message"] if isinstance(choice, dict) else choice.message
|
||||
tool_calls = (
|
||||
message.get("tool_calls")
|
||||
if isinstance(message, dict)
|
||||
else getattr(message, "tool_calls", None)
|
||||
)
|
||||
content = message.get("content") if isinstance(message, dict) else message.content
|
||||
usage = (
|
||||
raw_response.get("usage") if isinstance(raw_response, dict) else raw_response.usage
|
||||
)
|
||||
usage_dict = None
|
||||
if usage:
|
||||
try:
|
||||
usage_dict = dict(usage)
|
||||
except (TypeError, ValueError):
|
||||
usage_dict = {
|
||||
"prompt_tokens": getattr(usage, "prompt_tokens", 0),
|
||||
"completion_tokens": getattr(usage, "completion_tokens", 0),
|
||||
"total_tokens": getattr(usage, "total_tokens", 0),
|
||||
}
|
||||
finish_reason = (
|
||||
choice.get("finish_reason") if isinstance(choice, dict) else choice.finish_reason
|
||||
)
|
||||
model = (
|
||||
raw_response.get("model") if isinstance(raw_response, dict) else raw_response.model
|
||||
)
|
||||
return (
|
||||
content,
|
||||
tool_calls,
|
||||
usage_dict,
|
||||
finish_reason or "",
|
||||
model or self.settings.model,
|
||||
)
|
||||
|
||||
def _call_llm(kwargs: dict[str, Any]) -> Any:
|
||||
return self._retry_with_backoff(lambda: litellm.completion(**kwargs))
|
||||
|
||||
try:
|
||||
response = _call_llm(call_kwargs)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
return LLMResponse(
|
||||
content=f"LLM Error: {exc}",
|
||||
tool_calls=None,
|
||||
usage=None,
|
||||
model=self.settings.model,
|
||||
finish_reason="error",
|
||||
)
|
||||
|
||||
content, tool_calls, usage_dict, finish_reason, model = _extract_response(response)
|
||||
|
||||
# Retry once if the model hit output length and produced no visible content.
|
||||
if not tool_calls and not content and finish_reason == "length":
|
||||
fallback_kwargs = dict(call_kwargs)
|
||||
fallback_kwargs["max_tokens"] = max(self.settings.max_tokens, 4096)
|
||||
if not self.settings.reasoning_effort:
|
||||
fallback_kwargs["reasoning_effort"] = "low"
|
||||
with suppress(Exception):
|
||||
response = _call_llm(fallback_kwargs)
|
||||
content, tool_calls, usage_dict, finish_reason, model = _extract_response(response)
|
||||
|
||||
return LLMResponse(
|
||||
content=content,
|
||||
tool_calls=tool_calls,
|
||||
usage=usage_dict,
|
||||
model=model or self.settings.model,
|
||||
finish_reason=finish_reason or "",
|
||||
)
|
||||
|
||||
def generate_structured(self, prompt: str, schema: type[T]) -> T:
|
||||
if litellm is None:
|
||||
raise RuntimeError("litellm is not installed in this environment")
|
||||
|
||||
schema_json = schema.model_json_schema()
|
||||
schema_prompt = f"{prompt}\n\nSchema:\n{json.dumps(schema_json, indent=2)}"
|
||||
completion_args = {
|
||||
"model": self.settings.model,
|
||||
"messages": [
|
||||
{"role": "system", "content": "Return JSON that matches the provided schema."},
|
||||
{"role": "user", "content": schema_prompt},
|
||||
],
|
||||
"response_format": {
|
||||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"name": schema.__name__,
|
||||
"schema": schema_json,
|
||||
"strict": True,
|
||||
},
|
||||
},
|
||||
"temperature": self.settings.temperature,
|
||||
"max_tokens": self.settings.max_tokens,
|
||||
}
|
||||
if self.settings.api_base:
|
||||
completion_args["api_base"] = self.settings.api_base
|
||||
if self.settings.api_key:
|
||||
completion_args["api_key"] = self.settings.api_key
|
||||
try:
|
||||
response = litellm.completion(**completion_args)
|
||||
except Exception: # noqa: BLE001
|
||||
completion_args["response_format"] = {"type": "json_object"}
|
||||
response = litellm.completion(**completion_args)
|
||||
content = response["choices"][0]["message"]["content"]
|
||||
data = content if isinstance(content, dict) else json.loads(content)
|
||||
return schema.model_validate(data)
|
||||
|
||||
def _summarize_call(self, prompt: str) -> str:
|
||||
if litellm is None:
|
||||
raise RuntimeError("litellm is not installed in this environment")
|
||||
|
||||
completion_args = {
|
||||
"model": self.settings.model,
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a terse summarizer for an infrastructure assistant.",
|
||||
},
|
||||
{"role": "user", "content": prompt},
|
||||
],
|
||||
"temperature": 0.3,
|
||||
"max_tokens": 1000,
|
||||
}
|
||||
if self.settings.api_base:
|
||||
completion_args["api_base"] = self.settings.api_base
|
||||
if self.settings.api_key:
|
||||
completion_args["api_key"] = self.settings.api_key
|
||||
try:
|
||||
response = litellm.completion(**completion_args)
|
||||
message = (
|
||||
response["choices"][0]["message"]
|
||||
if isinstance(response, dict)
|
||||
else response.choices[0].message
|
||||
)
|
||||
if isinstance(message, dict):
|
||||
return message.get("content", "") or ""
|
||||
return message.content or ""
|
||||
except Exception as exc: # noqa: BLE001
|
||||
return f"[Summarization failed: {exc}]"
|
||||
|
||||
def clear_memory(self) -> None:
|
||||
"""Clear conversation memory and summary cache."""
|
||||
self.memory.clear_summary_cache()
|
||||
|
||||
def get_memory_stats(self) -> dict:
|
||||
"""Get memory usage statistics."""
|
||||
return self.memory.get_stats()
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMResponse:
|
||||
"""Response from LLM."""
|
||||
|
||||
content: str | None
|
||||
tool_calls: list[Any] | None
|
||||
usage: dict | None
|
||||
model: str = ""
|
||||
finish_reason: str = ""
|
||||
@@ -0,0 +1,248 @@
|
||||
"""Conversation memory management with summarization support."""
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
SUMMARY_PROMPT = """Summarize the following conversation segment for an infrastructure assistant.
|
||||
The summary will be used to continue the task, so preserve operational details and decisions.
|
||||
|
||||
What to preserve:
|
||||
- Targets, systems, and environment details
|
||||
- Tools executed and their outcomes
|
||||
- Findings, errors, and important observations
|
||||
- Decisions made and next steps
|
||||
- Paths, commands, parameters, and identifiers
|
||||
|
||||
Compression approach:
|
||||
- Consolidate repetition
|
||||
- Keep technical precision
|
||||
- Remove conversational back-and-forth
|
||||
|
||||
Conversation segment:
|
||||
{conversation}
|
||||
|
||||
Provide a concise technical summary:"""
|
||||
|
||||
|
||||
class ConversationMemory:
|
||||
"""Manages conversation history with token limits and summarization."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_tokens: int = 128000,
|
||||
reserve_ratio: float = 0.8,
|
||||
recent_to_keep: int = 10,
|
||||
summarize_threshold: float = 0.6,
|
||||
):
|
||||
self.max_tokens = max_tokens
|
||||
self.reserve_ratio = reserve_ratio
|
||||
self.recent_to_keep = recent_to_keep
|
||||
self.summarize_threshold = summarize_threshold
|
||||
self._encoder = None
|
||||
self._cached_summary: str | None = None
|
||||
self._summarized_count: int = 0
|
||||
|
||||
@property
|
||||
def encoder(self):
|
||||
"""Lazy load the tokenizer."""
|
||||
if self._encoder is None:
|
||||
try:
|
||||
import tiktoken
|
||||
|
||||
self._encoder = tiktoken.get_encoding("cl100k_base")
|
||||
except ImportError:
|
||||
self._encoder = None
|
||||
return self._encoder
|
||||
|
||||
def _count_tokens_with_litellm(self, text: str, model: str) -> int | None:
|
||||
"""Try to count tokens using litellm for better accuracy."""
|
||||
try:
|
||||
import litellm
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
try:
|
||||
count = litellm.token_counter(model=model, text=text)
|
||||
return int(count)
|
||||
except (RuntimeError, TypeError, ValueError):
|
||||
return None
|
||||
|
||||
@property
|
||||
def token_budget(self) -> int:
|
||||
"""Available tokens for history."""
|
||||
return int(self.max_tokens * self.reserve_ratio)
|
||||
|
||||
def get_messages(self, messages: list[dict]) -> list[dict]:
|
||||
"""
|
||||
Get messages that fit within token limit (sync, no summarization).
|
||||
Falls back to truncation if over budget.
|
||||
"""
|
||||
if not messages:
|
||||
return []
|
||||
|
||||
if self._cached_summary and len(messages) > self._summarized_count:
|
||||
summary_msg = {
|
||||
"role": "system",
|
||||
"content": f"Previous conversation summary:\n{self._cached_summary}",
|
||||
}
|
||||
recent = messages[self._summarized_count :]
|
||||
return [
|
||||
summary_msg,
|
||||
*self._truncate_to_fit(recent, self.token_budget - self._count_tokens(summary_msg)),
|
||||
]
|
||||
|
||||
return self._truncate_to_fit(messages, self.token_budget)
|
||||
|
||||
def get_messages_with_summary(
|
||||
self, messages: list[dict], llm_call: Callable[[str], str]
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Get messages, summarizing older ones if needed.
|
||||
"""
|
||||
if not messages:
|
||||
return []
|
||||
|
||||
total_tokens = self.get_total_tokens(messages)
|
||||
threshold_tokens = int(self.token_budget * self.summarize_threshold)
|
||||
|
||||
if total_tokens <= threshold_tokens:
|
||||
return messages
|
||||
|
||||
if len(messages) <= self.recent_to_keep:
|
||||
return self._truncate_to_fit(messages, self.token_budget)
|
||||
|
||||
split_point = len(messages) - self.recent_to_keep
|
||||
older = messages[:split_point]
|
||||
recent = messages[-self.recent_to_keep :]
|
||||
|
||||
if split_point <= self._summarized_count and self._cached_summary:
|
||||
result = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": f"Previous conversation summary:\n{self._cached_summary}",
|
||||
}
|
||||
]
|
||||
result.extend(recent)
|
||||
return result
|
||||
|
||||
summary = self._summarize(older, llm_call)
|
||||
|
||||
self._cached_summary = summary
|
||||
self._summarized_count = split_point
|
||||
|
||||
result = [{"role": "system", "content": f"Previous conversation summary:\n{summary}"}]
|
||||
result.extend(recent)
|
||||
return result
|
||||
|
||||
def _summarize(self, messages: list[dict], llm_call: Callable[[str], str]) -> str:
|
||||
"""Summarize a list of messages using chunked approach."""
|
||||
if not messages:
|
||||
return "[No messages to summarize]"
|
||||
|
||||
chunk_size = 10
|
||||
summaries = []
|
||||
|
||||
for i in range(0, len(messages), chunk_size):
|
||||
chunk = messages[i : i + chunk_size]
|
||||
conversation_text = self._format_for_summary(chunk)
|
||||
prompt = SUMMARY_PROMPT.format(conversation=conversation_text)
|
||||
|
||||
try:
|
||||
chunk_summary = llm_call(prompt)
|
||||
if chunk_summary and chunk_summary.strip():
|
||||
summaries.append(chunk_summary.strip())
|
||||
except Exception as exc: # noqa: BLE001
|
||||
summaries.append(
|
||||
f"[{len(chunk)} messages from segment {i // chunk_size + 1} - "
|
||||
f"summary failed: {exc}]"
|
||||
)
|
||||
|
||||
if not summaries:
|
||||
return f"[{len(messages)} earlier messages - all summarization attempts failed]"
|
||||
|
||||
if len(summaries) == 1:
|
||||
return summaries[0]
|
||||
|
||||
combined = "\n\n".join(f"Segment {i + 1}: {summary}" for i, summary in enumerate(summaries))
|
||||
return combined
|
||||
|
||||
def _format_for_summary(self, messages: list[dict]) -> str:
|
||||
"""Format messages as text for summarization."""
|
||||
lines = []
|
||||
for msg in messages:
|
||||
role = msg.get("role", "unknown")
|
||||
content = msg.get("content", "")
|
||||
if not isinstance(content, str):
|
||||
content = str(content)
|
||||
|
||||
max_length = 4000 if role == "tool" else 2000
|
||||
if len(content) > max_length:
|
||||
if role == "tool":
|
||||
half = max_length // 2
|
||||
content = (
|
||||
content[:half]
|
||||
+ f"\n...[{len(content) - max_length} chars truncated]...\n"
|
||||
+ content[-half:]
|
||||
)
|
||||
else:
|
||||
content = content[:max_length] + "...[truncated]"
|
||||
|
||||
if role == "user":
|
||||
lines.append(f"User: {content}")
|
||||
elif role == "assistant":
|
||||
lines.append(f"Assistant: {content}")
|
||||
elif role == "tool":
|
||||
tool_name = msg.get("name", "tool")
|
||||
lines.append(f"Tool ({tool_name}): {content}")
|
||||
elif role == "system":
|
||||
continue
|
||||
|
||||
return "\n\n".join(lines)
|
||||
|
||||
def _truncate_to_fit(self, messages: list[dict], budget: int) -> list[dict]:
|
||||
"""Truncate messages from the beginning to fit budget."""
|
||||
total_tokens = 0
|
||||
result = []
|
||||
|
||||
for msg in reversed(messages):
|
||||
msg_tokens = self._count_tokens(msg)
|
||||
if total_tokens + msg_tokens > budget:
|
||||
break
|
||||
result.insert(0, msg)
|
||||
total_tokens += msg_tokens
|
||||
|
||||
return result
|
||||
|
||||
def _count_tokens(self, message: dict) -> int:
|
||||
"""Count tokens in a message."""
|
||||
content = message.get("content", "")
|
||||
|
||||
if isinstance(content, str):
|
||||
if self.encoder:
|
||||
return len(self.encoder.encode(content))
|
||||
return int(len(content.split()) * 1.3)
|
||||
|
||||
return 0
|
||||
|
||||
def get_total_tokens(self, messages: list[dict]) -> int:
|
||||
"""Get total token count for messages."""
|
||||
return sum(self._count_tokens(msg) for msg in messages)
|
||||
|
||||
def fits_in_context(self, messages: list[dict]) -> bool:
|
||||
"""Check if messages fit in context window."""
|
||||
return self.get_total_tokens(messages) <= self.token_budget
|
||||
|
||||
def clear_summary_cache(self):
|
||||
"""Clear the cached summary."""
|
||||
self._cached_summary = None
|
||||
self._summarized_count = 0
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
"""Get memory statistics."""
|
||||
return {
|
||||
"max_tokens": self.max_tokens,
|
||||
"token_budget": self.token_budget,
|
||||
"summarize_threshold": int(self.token_budget * self.summarize_threshold),
|
||||
"recent_to_keep": self.recent_to_keep,
|
||||
"has_summary": self._cached_summary is not None,
|
||||
"summarized_message_count": self._summarized_count,
|
||||
}
|
||||
@@ -0,0 +1,68 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from eidolon.core.models.plan import EntityRef, PlanStep
|
||||
from eidolon.core.reasoning.llm import LiteLLMClient
|
||||
from eidolon.core.reasoning.prompts import PLAN_PROMPT_TEMPLATE
|
||||
|
||||
|
||||
class LLMPlanStep(BaseModel):
|
||||
action_type: str
|
||||
tool_hint: str | None = None
|
||||
rationale: str = ""
|
||||
rollback: str | None = None
|
||||
risk: str | None = None
|
||||
requires_approval: bool = False
|
||||
parameters: dict = Field(default_factory=dict)
|
||||
|
||||
|
||||
class LLMPlanDraft(BaseModel):
|
||||
steps: list[LLMPlanStep] = Field(default_factory=list)
|
||||
|
||||
|
||||
class Planner:
|
||||
"""Intent -> Plan generator that defers heavy lifting to LLMs in later phases."""
|
||||
|
||||
def __init__(self, llm_client: LiteLLMClient | None = None) -> None:
|
||||
self.llm_client = llm_client
|
||||
|
||||
def generate_plan(self, intent: str, target: EntityRef) -> list[PlanStep]:
|
||||
fallback = [
|
||||
PlanStep(
|
||||
action_type="analyze",
|
||||
target=target,
|
||||
rationale=intent,
|
||||
rollback="No-op rollback for analysis step.",
|
||||
risk="low",
|
||||
requires_approval=False,
|
||||
)
|
||||
]
|
||||
|
||||
if not self.llm_client or not self.llm_client.is_available():
|
||||
return fallback
|
||||
|
||||
prompt = PLAN_PROMPT_TEMPLATE.format(intent=intent, target=target.model_dump())
|
||||
try:
|
||||
draft = self.llm_client.generate_structured(prompt, LLMPlanDraft)
|
||||
except Exception: # noqa: BLE001
|
||||
return fallback
|
||||
|
||||
if not draft.steps:
|
||||
return fallback
|
||||
|
||||
steps: list[PlanStep] = []
|
||||
for draft_step in draft.steps:
|
||||
steps.append(
|
||||
PlanStep(
|
||||
action_type=draft_step.action_type,
|
||||
target=target,
|
||||
tool_hint=draft_step.tool_hint,
|
||||
rationale=draft_step.rationale,
|
||||
rollback=draft_step.rollback,
|
||||
risk=draft_step.risk,
|
||||
requires_approval=draft_step.requires_approval,
|
||||
parameters=draft_step.parameters,
|
||||
)
|
||||
)
|
||||
return steps
|
||||
@@ -0,0 +1,23 @@
|
||||
AGENT_SYSTEM_PROMPT = """
|
||||
You are Eidolon, a cautious infrastructure co-pilot.
|
||||
- Prefer read/plan/simulate by default.
|
||||
- Never invent evidence; cite graph-backed facts.
|
||||
- Request approvals before execution.
|
||||
"""
|
||||
|
||||
QUERY_PROMPT_TEMPLATE = """
|
||||
You translate user questions into Cypher queries over an evidence-backed graph.
|
||||
Use only these labels: Asset, NetworkContainer, Identity, Policy.
|
||||
Use only known relationships: MEMBER_OF, CAN_REACH, DEPENDS_ON, AUTHENTICATES_TO, GOVERNED_BY.
|
||||
If you cannot map the request to a safe graph query, return an answer explaining why and omit
|
||||
graph_query.
|
||||
Question: {question}
|
||||
"""
|
||||
|
||||
PLAN_PROMPT_TEMPLATE = """
|
||||
You generate a cautious, low-risk plan for an infrastructure intent.
|
||||
Return minimal steps, defaulting to read-only analysis unless execution is explicitly required.
|
||||
Include rollback guidance and risk for each step. Use tool_hint only when necessary.
|
||||
Intent: {intent}
|
||||
Target entity: {target}
|
||||
"""
|
||||
@@ -0,0 +1,322 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from eidolon.config.settings import SandboxPermissions
|
||||
from eidolon.core.models.approval import ApprovalRecord
|
||||
from eidolon.core.models.chat import ChatMessage, ChatSession
|
||||
from eidolon.core.models.event import AuditEvent
|
||||
from eidolon.core.models.scanner import (
|
||||
ScannerConfig,
|
||||
ScannerConfigRecord,
|
||||
default_scanner_config,
|
||||
)
|
||||
from eidolon.core.models.settings import AppSettings
|
||||
|
||||
|
||||
class SettingsStore(ABC):
|
||||
"""Abstract persistence for sandbox permissions."""
|
||||
|
||||
@abstractmethod
|
||||
def get_settings(self) -> SandboxPermissions:
|
||||
"""Fetch current sandbox permissions."""
|
||||
|
||||
@abstractmethod
|
||||
def update_settings(self, settings: SandboxPermissions) -> None:
|
||||
"""Update sandbox permissions."""
|
||||
|
||||
@abstractmethod
|
||||
def get_app_settings(self) -> AppSettings:
|
||||
"""Fetch application settings like theme and LLM config."""
|
||||
|
||||
@abstractmethod
|
||||
def update_app_settings(self, settings: AppSettings) -> AppSettings:
|
||||
"""Update application settings."""
|
||||
|
||||
|
||||
class InMemorySettingsStore(SettingsStore):
|
||||
def __init__(self) -> None:
|
||||
self._sandbox = SandboxPermissions()
|
||||
self._app_settings = AppSettings()
|
||||
|
||||
def get_settings(self) -> SandboxPermissions:
|
||||
return self._sandbox
|
||||
|
||||
def update_settings(self, settings: SandboxPermissions) -> None:
|
||||
self._sandbox = settings
|
||||
|
||||
def get_app_settings(self) -> AppSettings:
|
||||
return self._app_settings
|
||||
|
||||
def update_app_settings(self, settings: AppSettings) -> AppSettings:
|
||||
self._app_settings = settings
|
||||
return settings
|
||||
|
||||
|
||||
class ScannerStore(ABC):
|
||||
"""Abstract persistence for scanner configuration and run history."""
|
||||
|
||||
@abstractmethod
|
||||
def get_config(self, user_id: str) -> ScannerConfigRecord:
|
||||
"""Fetch the current scanner config for a user."""
|
||||
|
||||
@abstractmethod
|
||||
def update_config(self, user_id: str, config: ScannerConfig) -> ScannerConfigRecord:
|
||||
"""Persist scanner config for a user."""
|
||||
|
||||
|
||||
class InMemoryScannerStore(ScannerStore):
|
||||
def __init__(self) -> None:
|
||||
self._configs: dict[str, ScannerConfigRecord] = {}
|
||||
self._next_config_id = 1
|
||||
|
||||
def _ensure_config(self, user_id: str) -> ScannerConfigRecord:
|
||||
record = self._configs.get(user_id)
|
||||
if record:
|
||||
return record
|
||||
config = default_scanner_config()
|
||||
record = ScannerConfigRecord(
|
||||
id=self._next_config_id,
|
||||
user_id=user_id,
|
||||
config=config,
|
||||
updated_at=datetime.utcnow(),
|
||||
)
|
||||
self._next_config_id += 1
|
||||
self._configs[user_id] = record
|
||||
return record
|
||||
|
||||
def get_config(self, user_id: str) -> ScannerConfigRecord:
|
||||
return self._ensure_config(user_id)
|
||||
|
||||
def update_config(self, user_id: str, config: ScannerConfig) -> ScannerConfigRecord:
|
||||
record = self._ensure_config(user_id)
|
||||
updated = ScannerConfigRecord(
|
||||
id=record.id,
|
||||
user_id=user_id,
|
||||
config=config,
|
||||
updated_at=datetime.utcnow(),
|
||||
)
|
||||
self._configs[user_id] = updated
|
||||
return updated
|
||||
|
||||
|
||||
class AuditStore(ABC):
|
||||
"""Abstract persistence for audit events."""
|
||||
|
||||
@abstractmethod
|
||||
def add(self, event: AuditEvent) -> None:
|
||||
"""Persist an audit event."""
|
||||
|
||||
@abstractmethod
|
||||
def get(self, audit_id: UUID) -> AuditEvent | None:
|
||||
"""Fetch a single audit event."""
|
||||
|
||||
@abstractmethod
|
||||
def list_all(self, limit: int = 100) -> list[AuditEvent]:
|
||||
"""Return recent audit events."""
|
||||
|
||||
@abstractmethod
|
||||
def list_filtered(
|
||||
self,
|
||||
page: int = 1,
|
||||
page_size: int = 50,
|
||||
event_type: str | None = None,
|
||||
start_date: datetime | None = None,
|
||||
end_date: datetime | None = None,
|
||||
) -> list[AuditEvent]:
|
||||
"""Return filtered and paginated audit events."""
|
||||
|
||||
@abstractmethod
|
||||
def count_filtered(
|
||||
self,
|
||||
event_type: str | None = None,
|
||||
start_date: datetime | None = None,
|
||||
end_date: datetime | None = None,
|
||||
) -> int:
|
||||
"""Count events matching filters."""
|
||||
|
||||
@abstractmethod
|
||||
def delete_older_than(self, cutoff_date: datetime) -> int:
|
||||
"""Delete events older than cutoff date. Returns count deleted."""
|
||||
|
||||
|
||||
class InMemoryAuditStore(AuditStore):
|
||||
def __init__(self) -> None:
|
||||
self._events: list[AuditEvent] = []
|
||||
|
||||
def add(self, event: AuditEvent) -> None:
|
||||
self._events.append(event)
|
||||
|
||||
def get(self, audit_id: UUID) -> AuditEvent | None:
|
||||
for ev in self._events:
|
||||
if ev.audit_id == audit_id:
|
||||
return ev
|
||||
return None
|
||||
|
||||
def list_all(self, limit: int = 100) -> list[AuditEvent]:
|
||||
return list(self._events)[-limit:]
|
||||
|
||||
def list_filtered(
|
||||
self,
|
||||
page: int = 1,
|
||||
page_size: int = 50,
|
||||
event_type: str | None = None,
|
||||
start_date: datetime | None = None,
|
||||
end_date: datetime | None = None,
|
||||
) -> list[AuditEvent]:
|
||||
filtered = self._events
|
||||
|
||||
if event_type:
|
||||
filtered = [e for e in filtered if e.event_type == event_type]
|
||||
if start_date:
|
||||
filtered = [e for e in filtered if e.timestamp >= start_date]
|
||||
if end_date:
|
||||
filtered = [e for e in filtered if e.timestamp <= end_date]
|
||||
|
||||
# Sort by timestamp desc
|
||||
filtered = sorted(filtered, key=lambda e: e.timestamp, reverse=True)
|
||||
|
||||
# Paginate
|
||||
offset = (page - 1) * page_size
|
||||
return filtered[offset : offset + page_size]
|
||||
|
||||
def count_filtered(
|
||||
self,
|
||||
event_type: str | None = None,
|
||||
start_date: datetime | None = None,
|
||||
end_date: datetime | None = None,
|
||||
) -> int:
|
||||
filtered = self._events
|
||||
|
||||
if event_type:
|
||||
filtered = [e for e in filtered if e.event_type == event_type]
|
||||
if start_date:
|
||||
filtered = [e for e in filtered if e.timestamp >= start_date]
|
||||
if end_date:
|
||||
filtered = [e for e in filtered if e.timestamp <= end_date]
|
||||
|
||||
return len(filtered)
|
||||
|
||||
def delete_older_than(self, cutoff_date: datetime) -> int:
|
||||
original_count = len(self._events)
|
||||
self._events = [e for e in self._events if e.timestamp >= cutoff_date]
|
||||
return original_count - len(self._events)
|
||||
|
||||
|
||||
class ApprovalStore(ABC):
|
||||
"""Abstract persistence for approval tokens."""
|
||||
|
||||
@abstractmethod
|
||||
def create(self, user_id: str, action: str, ttl_seconds: int) -> ApprovalRecord:
|
||||
"""Create an approval token for a specific action."""
|
||||
|
||||
@abstractmethod
|
||||
def get_by_token(self, token: str) -> ApprovalRecord | None:
|
||||
"""Lookup approval token details."""
|
||||
|
||||
|
||||
class InMemoryApprovalStore(ApprovalStore):
|
||||
def __init__(self) -> None:
|
||||
self._approvals: list[ApprovalRecord] = []
|
||||
|
||||
def create(self, user_id: str, action: str, ttl_seconds: int) -> ApprovalRecord:
|
||||
approval = ApprovalRecord.create(user_id=user_id, action=action, ttl_seconds=ttl_seconds)
|
||||
self._approvals.append(approval)
|
||||
return approval
|
||||
|
||||
def get_by_token(self, token: str) -> ApprovalRecord | None:
|
||||
for approval in self._approvals:
|
||||
if approval.token == token and not approval.is_expired():
|
||||
return approval
|
||||
return None
|
||||
|
||||
|
||||
class ChatStore(ABC):
|
||||
"""Abstract persistence for chat sessions."""
|
||||
|
||||
@abstractmethod
|
||||
def create_session(self, title: str | None = None, user_id: str | None = None) -> ChatSession:
|
||||
"""Create a new chat session."""
|
||||
|
||||
@abstractmethod
|
||||
def list_sessions(self, limit: int = 50, user_id: str | None = None) -> list[ChatSession]:
|
||||
"""Return recent chat sessions."""
|
||||
|
||||
@abstractmethod
|
||||
def get_session(self, session_id: UUID, user_id: str | None = None) -> ChatSession | None:
|
||||
"""Fetch a single chat session."""
|
||||
|
||||
@abstractmethod
|
||||
def delete_session(self, session_id: UUID, user_id: str | None = None) -> bool:
|
||||
"""Delete a chat session and its messages."""
|
||||
|
||||
@abstractmethod
|
||||
def append_message(
|
||||
self, session_id: UUID, message: ChatMessage, user_id: str | None = None
|
||||
) -> ChatSession | None:
|
||||
"""Append a message to an existing session."""
|
||||
|
||||
@abstractmethod
|
||||
def cleanup_request_messages(
|
||||
self, session_id: UUID, request_id: str, user_id: str | None = None
|
||||
) -> ChatSession | None:
|
||||
"""Remove messages associated with a specific request ID."""
|
||||
|
||||
|
||||
class InMemoryChatStore(ChatStore):
|
||||
def __init__(self) -> None:
|
||||
self._sessions: dict[UUID, ChatSession] = {}
|
||||
|
||||
def create_session(self, title: str | None = None, user_id: str | None = None) -> ChatSession:
|
||||
session = ChatSession(title=title, user_id=user_id or "anonymous")
|
||||
self._sessions[session.session_id] = session
|
||||
return session
|
||||
|
||||
def list_sessions(self, limit: int = 50, user_id: str | None = None) -> list[ChatSession]:
|
||||
sessions = list(self._sessions.values())
|
||||
if user_id:
|
||||
sessions = [session for session in sessions if session.user_id == user_id]
|
||||
return sessions[-limit:]
|
||||
|
||||
def get_session(self, session_id: UUID, user_id: str | None = None) -> ChatSession | None:
|
||||
session = self._sessions.get(session_id)
|
||||
if session and user_id and session.user_id != user_id:
|
||||
return None
|
||||
return session
|
||||
|
||||
def delete_session(self, session_id: UUID, user_id: str | None = None) -> bool:
|
||||
session = self._sessions.get(session_id)
|
||||
if not session:
|
||||
return False
|
||||
if user_id and session.user_id != user_id:
|
||||
return False
|
||||
del self._sessions[session_id]
|
||||
return True
|
||||
|
||||
def append_message(
|
||||
self, session_id: UUID, message: ChatMessage, user_id: str | None = None
|
||||
) -> ChatSession | None:
|
||||
session = self._sessions.get(session_id)
|
||||
if not session:
|
||||
return None
|
||||
if user_id and session.user_id != user_id:
|
||||
return None
|
||||
session.messages.append(message)
|
||||
session.updated_at = datetime.utcnow()
|
||||
return session
|
||||
|
||||
def cleanup_request_messages(
|
||||
self, session_id: UUID, request_id: str, user_id: str | None = None
|
||||
) -> ChatSession | None:
|
||||
session = self._sessions.get(session_id)
|
||||
if not session:
|
||||
return None
|
||||
if user_id and session.user_id != user_id:
|
||||
return None
|
||||
session.messages = [
|
||||
msg for msg in session.messages if msg.metadata.get("request_id") != request_id
|
||||
]
|
||||
session.updated_at = datetime.utcnow()
|
||||
return session
|
||||
@@ -0,0 +1 @@
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
// Node constraints
|
||||
CREATE CONSTRAINT asset_node_id IF NOT EXISTS FOR (n:Asset) REQUIRE n.node_id IS UNIQUE;
|
||||
CREATE CONSTRAINT network_node_id IF NOT EXISTS FOR (n:NetworkContainer) REQUIRE n.node_id IS UNIQUE;
|
||||
CREATE CONSTRAINT identity_node_id IF NOT EXISTS FOR (n:Identity) REQUIRE n.node_id IS UNIQUE;
|
||||
CREATE CONSTRAINT policy_node_id IF NOT EXISTS FOR (n:Policy) REQUIRE n.node_id IS UNIQUE;
|
||||
|
||||
// Evidence nodes
|
||||
CREATE CONSTRAINT evidence_source IF NOT EXISTS FOR (e:Evidence) REQUIRE (e.source_type, e.source_id) IS NODE KEY;
|
||||
CREATE CONSTRAINT edge_evidence_id IF NOT EXISTS FOR (e:EdgeEvidence) REQUIRE e.edge_id IS UNIQUE;
|
||||
|
||||
// Common indexes for traversal
|
||||
CREATE INDEX asset_identifier IF NOT EXISTS FOR (n:Asset) ON (n.identifiers);
|
||||
CREATE INDEX network_cidr IF NOT EXISTS FOR (n:NetworkContainer) ON (n.cidr);
|
||||
CREATE INDEX identity_name IF NOT EXISTS FOR (n:Identity) ON (n.name);
|
||||
@@ -0,0 +1 @@
|
||||
|
||||
@@ -0,0 +1,64 @@
|
||||
-- Postgres schema for audit trail, chat sessions, approvals, and rate limits.
|
||||
|
||||
CREATE TABLE IF NOT EXISTS audit_events (
|
||||
id UUID PRIMARY KEY,
|
||||
event_type TEXT NOT NULL,
|
||||
details JSONB NOT NULL,
|
||||
status TEXT NOT NULL,
|
||||
created_at TIMESTAMPTZ DEFAULT now()
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS chat_sessions (
|
||||
id UUID PRIMARY KEY,
|
||||
user_id TEXT NOT NULL,
|
||||
title TEXT,
|
||||
created_at TIMESTAMPTZ DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ DEFAULT now()
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS chat_messages (
|
||||
id UUID PRIMARY KEY,
|
||||
session_id UUID NOT NULL REFERENCES chat_sessions(id) ON DELETE CASCADE,
|
||||
role TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
metadata JSONB,
|
||||
created_at TIMESTAMPTZ DEFAULT now()
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS approvals (
|
||||
id UUID PRIMARY KEY,
|
||||
user_id TEXT NOT NULL,
|
||||
token TEXT NOT NULL,
|
||||
action TEXT NOT NULL,
|
||||
expires_at TIMESTAMPTZ NOT NULL,
|
||||
created_at TIMESTAMPTZ DEFAULT now()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_approvals_token ON approvals (token);
|
||||
CREATE INDEX IF NOT EXISTS idx_audit_event_type ON audit_events (event_type);
|
||||
CREATE INDEX IF NOT EXISTS idx_chat_messages_session ON chat_messages (session_id);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS sandbox_permissions (
|
||||
id TEXT PRIMARY KEY DEFAULT 'default',
|
||||
permissions JSONB NOT NULL,
|
||||
updated_at TIMESTAMPTZ DEFAULT now()
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS app_settings (
|
||||
id TEXT PRIMARY KEY DEFAULT 'default',
|
||||
settings JSONB NOT NULL,
|
||||
updated_at TIMESTAMPTZ DEFAULT now()
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS scan_configs (
|
||||
id SERIAL PRIMARY KEY,
|
||||
user_id TEXT NOT NULL UNIQUE,
|
||||
network_cidrs TEXT[] NOT NULL,
|
||||
ports INTEGER[] NOT NULL,
|
||||
port_preset TEXT NOT NULL,
|
||||
collectors JSONB NOT NULL,
|
||||
options JSONB NOT NULL,
|
||||
updated_at TIMESTAMPTZ DEFAULT now()
|
||||
);
|
||||
|
||||
-- Removed scan_runs table - using audit_log for scan history instead
|
||||
@@ -0,0 +1,855 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from eidolon.config.settings import SandboxPermissions
|
||||
from eidolon.core.models.approval import ApprovalRecord
|
||||
from eidolon.core.models.chat import ChatMessage, ChatSession
|
||||
from eidolon.core.models.event import AuditEvent
|
||||
from eidolon.core.models.scanner import (
|
||||
ScannerConfig,
|
||||
ScannerConfigRecord,
|
||||
default_scanner_config,
|
||||
)
|
||||
from eidolon.core.models.settings import AppSettings
|
||||
from eidolon.core.stores import ApprovalStore, AuditStore, ChatStore, ScannerStore, SettingsStore
|
||||
|
||||
try:
|
||||
import psycopg
|
||||
from psycopg import sql
|
||||
from psycopg.rows import dict_row
|
||||
except ImportError: # pragma: no cover - optional dependency
|
||||
psycopg = None
|
||||
sql = None
|
||||
dict_row = None
|
||||
|
||||
if psycopg is None:
|
||||
POSTGRES_ERRORS: tuple[type[Exception], ...] = (RuntimeError, TypeError, ValueError)
|
||||
else:
|
||||
POSTGRES_ERRORS = (psycopg.Error, RuntimeError, TypeError, ValueError)
|
||||
|
||||
|
||||
def _ensure_uuid(value) -> UUID:
|
||||
"""Convert database UUID value to UUID object if needed."""
|
||||
return value if isinstance(value, UUID) else UUID(value)
|
||||
|
||||
|
||||
def postgres_available() -> bool:
|
||||
return psycopg is not None
|
||||
|
||||
|
||||
class PostgresStoreBase:
|
||||
def __init__(self, dsn: str) -> None:
|
||||
self._dsn = dsn
|
||||
|
||||
@contextmanager
|
||||
def _connect(self) -> Iterator[psycopg.Connection]:
|
||||
if psycopg is None:
|
||||
raise RuntimeError("psycopg is not installed")
|
||||
conn = psycopg.connect(self._dsn, row_factory=dict_row)
|
||||
try:
|
||||
yield conn
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
class PostgresAuditStore(PostgresStoreBase, AuditStore):
|
||||
def __init__(self, dsn: str, fallback: AuditStore | None = None) -> None:
|
||||
super().__init__(dsn)
|
||||
self._fallback = fallback
|
||||
|
||||
def add(self, event: AuditEvent) -> None:
|
||||
try:
|
||||
with self._connect() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO audit_events (id, event_type, details, status, created_at)
|
||||
VALUES (%s, %s, %s, %s, %s)
|
||||
""",
|
||||
(
|
||||
str(event.audit_id),
|
||||
event.event_type,
|
||||
json.dumps(event.details),
|
||||
event.status,
|
||||
event.timestamp,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
except POSTGRES_ERRORS:
|
||||
if self._fallback:
|
||||
return self._fallback.add(event)
|
||||
raise
|
||||
|
||||
def get(self, audit_id: UUID) -> AuditEvent | None:
|
||||
try:
|
||||
with self._connect() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT id, event_type, details, status, created_at
|
||||
FROM audit_events
|
||||
WHERE id = %s
|
||||
""",
|
||||
(str(audit_id),),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if not row:
|
||||
return None
|
||||
payload = {
|
||||
"audit_id": UUID(row["id"]),
|
||||
"event_type": row["event_type"],
|
||||
"details": row["details"],
|
||||
"status": row["status"],
|
||||
"timestamp": row["created_at"],
|
||||
}
|
||||
return AuditEvent.model_validate(payload)
|
||||
except POSTGRES_ERRORS:
|
||||
if self._fallback:
|
||||
return self._fallback.get(audit_id)
|
||||
raise
|
||||
|
||||
def list_all(self, limit: int = 100) -> list[AuditEvent]:
|
||||
try:
|
||||
with self._connect() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT id, event_type, details, status, created_at
|
||||
FROM audit_events
|
||||
ORDER BY created_at DESC
|
||||
LIMIT %s
|
||||
""",
|
||||
(limit,),
|
||||
)
|
||||
rows = cur.fetchall()
|
||||
events: list[AuditEvent] = []
|
||||
for row in rows:
|
||||
payload = {
|
||||
"audit_id": _ensure_uuid(row["id"]),
|
||||
"event_type": row["event_type"],
|
||||
"details": row["details"],
|
||||
"status": row["status"],
|
||||
"timestamp": row["created_at"],
|
||||
}
|
||||
events.append(AuditEvent.model_validate(payload))
|
||||
return events
|
||||
except POSTGRES_ERRORS:
|
||||
if self._fallback:
|
||||
return self._fallback.list_all(limit=limit)
|
||||
raise
|
||||
|
||||
def list_filtered(
|
||||
self,
|
||||
page: int = 1,
|
||||
page_size: int = 50,
|
||||
event_type: str | None = None,
|
||||
start_date: datetime | None = None,
|
||||
end_date: datetime | None = None,
|
||||
) -> list[AuditEvent]:
|
||||
try:
|
||||
if sql is None:
|
||||
raise RuntimeError("psycopg is not installed")
|
||||
# Build dynamic WHERE clause
|
||||
conditions: list[sql.Composable] = []
|
||||
params: list = []
|
||||
|
||||
if event_type:
|
||||
conditions.append(sql.SQL("event_type = %s"))
|
||||
params.append(event_type)
|
||||
if start_date:
|
||||
conditions.append(sql.SQL("created_at >= %s"))
|
||||
params.append(start_date)
|
||||
if end_date:
|
||||
conditions.append(sql.SQL("created_at <= %s"))
|
||||
params.append(end_date)
|
||||
|
||||
where_clause = sql.SQL(" AND ").join(conditions) if conditions else sql.SQL("TRUE")
|
||||
offset = (page - 1) * page_size
|
||||
params.extend([page_size, offset])
|
||||
|
||||
with self._connect() as conn:
|
||||
with conn.cursor() as cur:
|
||||
query = sql.SQL("""
|
||||
SELECT id, event_type, details, status, created_at
|
||||
FROM audit_events
|
||||
WHERE {where_clause}
|
||||
ORDER BY created_at DESC
|
||||
LIMIT %s OFFSET %s
|
||||
""").format(where_clause=where_clause)
|
||||
cur.execute(query, tuple(params))
|
||||
rows = cur.fetchall()
|
||||
|
||||
events: list[AuditEvent] = []
|
||||
for row in rows:
|
||||
payload = {
|
||||
"audit_id": _ensure_uuid(row["id"]),
|
||||
"event_type": row["event_type"],
|
||||
"details": row["details"],
|
||||
"status": row["status"],
|
||||
"timestamp": row["created_at"],
|
||||
}
|
||||
events.append(AuditEvent.model_validate(payload))
|
||||
return events
|
||||
except POSTGRES_ERRORS:
|
||||
if self._fallback:
|
||||
return self._fallback.list_filtered(
|
||||
page, page_size, event_type, start_date, end_date
|
||||
)
|
||||
raise
|
||||
|
||||
def count_filtered(
|
||||
self,
|
||||
event_type: str | None = None,
|
||||
start_date: datetime | None = None,
|
||||
end_date: datetime | None = None,
|
||||
) -> int:
|
||||
try:
|
||||
if sql is None:
|
||||
raise RuntimeError("psycopg is not installed")
|
||||
conditions: list[sql.Composable] = []
|
||||
params: list = []
|
||||
|
||||
if event_type:
|
||||
conditions.append(sql.SQL("event_type = %s"))
|
||||
params.append(event_type)
|
||||
if start_date:
|
||||
conditions.append(sql.SQL("created_at >= %s"))
|
||||
params.append(start_date)
|
||||
if end_date:
|
||||
conditions.append(sql.SQL("created_at <= %s"))
|
||||
params.append(end_date)
|
||||
|
||||
where_clause = sql.SQL(" AND ").join(conditions) if conditions else sql.SQL("TRUE")
|
||||
|
||||
with self._connect() as conn:
|
||||
with conn.cursor() as cur:
|
||||
query = sql.SQL(
|
||||
"SELECT COUNT(*) as total FROM audit_events WHERE {where_clause}"
|
||||
).format(where_clause=where_clause)
|
||||
cur.execute(query, tuple(params))
|
||||
result = cur.fetchone()
|
||||
return int(result["total"]) if result else 0
|
||||
except POSTGRES_ERRORS:
|
||||
if self._fallback:
|
||||
return self._fallback.count_filtered(event_type, start_date, end_date)
|
||||
raise
|
||||
|
||||
def delete_older_than(self, cutoff_date: datetime) -> int:
|
||||
try:
|
||||
with self._connect() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"DELETE FROM audit_events WHERE created_at < %s",
|
||||
(cutoff_date,),
|
||||
)
|
||||
deleted = cur.rowcount
|
||||
conn.commit()
|
||||
return deleted
|
||||
except POSTGRES_ERRORS:
|
||||
if self._fallback:
|
||||
return self._fallback.delete_older_than(cutoff_date)
|
||||
raise
|
||||
|
||||
|
||||
class PostgresApprovalStore(PostgresStoreBase, ApprovalStore):
|
||||
def __init__(self, dsn: str, fallback: ApprovalStore | None = None) -> None:
|
||||
super().__init__(dsn)
|
||||
self._fallback = fallback
|
||||
|
||||
def create(self, user_id: str, action: str, ttl_seconds: int) -> ApprovalRecord:
|
||||
approval = ApprovalRecord.create(user_id=user_id, action=action, ttl_seconds=ttl_seconds)
|
||||
try:
|
||||
with self._connect() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO approvals (id, user_id, token, action, expires_at, created_at)
|
||||
VALUES (%s, %s, %s, %s, %s, %s)
|
||||
""",
|
||||
(
|
||||
str(approval.approval_id),
|
||||
approval.user_id,
|
||||
approval.token,
|
||||
approval.action,
|
||||
approval.expires_at,
|
||||
approval.created_at,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
except POSTGRES_ERRORS:
|
||||
if self._fallback:
|
||||
return self._fallback.create(
|
||||
user_id=user_id, action=action, ttl_seconds=ttl_seconds
|
||||
)
|
||||
raise
|
||||
return approval
|
||||
|
||||
def get_by_token(self, token: str) -> ApprovalRecord | None:
|
||||
try:
|
||||
with self._connect() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT id, user_id, token, action, expires_at, created_at
|
||||
FROM approvals
|
||||
WHERE token = %s
|
||||
""",
|
||||
(token,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if not row:
|
||||
return None
|
||||
record = ApprovalRecord.model_validate(
|
||||
{
|
||||
"approval_id": _ensure_uuid(row["id"]),
|
||||
"user_id": row["user_id"],
|
||||
"token": row["token"],
|
||||
"action": row["action"],
|
||||
"expires_at": row["expires_at"],
|
||||
"created_at": row["created_at"],
|
||||
}
|
||||
)
|
||||
return None if record.is_expired() else record
|
||||
except POSTGRES_ERRORS:
|
||||
if self._fallback:
|
||||
return self._fallback.get_by_token(token)
|
||||
raise
|
||||
|
||||
|
||||
class PostgresChatStore(PostgresStoreBase, ChatStore):
|
||||
def __init__(self, dsn: str, fallback: ChatStore | None = None) -> None:
|
||||
super().__init__(dsn)
|
||||
self._fallback = fallback
|
||||
self._supports_metadata: bool | None = None
|
||||
|
||||
def _metadata_supported(self, conn) -> bool:
|
||||
if self._supports_metadata is not None:
|
||||
return self._supports_metadata
|
||||
with conn.cursor() as cur:
|
||||
cur.execute("""
|
||||
SELECT 1
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema = 'public'
|
||||
AND table_name = 'chat_messages'
|
||||
AND column_name = 'metadata'
|
||||
""")
|
||||
self._supports_metadata = cur.fetchone() is not None
|
||||
return self._supports_metadata
|
||||
|
||||
def create_session(self, title: str | None = None, user_id: str | None = None) -> ChatSession:
|
||||
session = ChatSession(title=title, user_id=user_id or "anonymous")
|
||||
try:
|
||||
with self._connect() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO chat_sessions (id, user_id, title, created_at, updated_at)
|
||||
VALUES (%s, %s, %s, %s, %s)
|
||||
""",
|
||||
(
|
||||
str(session.session_id),
|
||||
session.user_id,
|
||||
session.title,
|
||||
session.created_at,
|
||||
session.updated_at,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
except POSTGRES_ERRORS:
|
||||
if self._fallback:
|
||||
return self._fallback.create_session(title=title, user_id=session.user_id)
|
||||
raise
|
||||
return session
|
||||
|
||||
def list_sessions(self, limit: int = 50, user_id: str | None = None) -> list[ChatSession]:
|
||||
try:
|
||||
with self._connect() as conn:
|
||||
with conn.cursor() as cur:
|
||||
if user_id:
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT id, user_id, title, created_at, updated_at
|
||||
FROM chat_sessions
|
||||
WHERE user_id = %s
|
||||
ORDER BY updated_at DESC
|
||||
LIMIT %s
|
||||
""",
|
||||
(user_id, limit),
|
||||
)
|
||||
else:
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT id, user_id, title, created_at, updated_at
|
||||
FROM chat_sessions
|
||||
ORDER BY updated_at DESC
|
||||
LIMIT %s
|
||||
""",
|
||||
(limit,),
|
||||
)
|
||||
rows = cur.fetchall()
|
||||
sessions: list[ChatSession] = []
|
||||
for row in rows:
|
||||
session_id = _ensure_uuid(row["id"])
|
||||
messages = self._get_messages(session_id)
|
||||
sessions.append(
|
||||
ChatSession(
|
||||
session_id=session_id,
|
||||
user_id=row["user_id"],
|
||||
title=row["title"],
|
||||
created_at=row["created_at"],
|
||||
updated_at=row["updated_at"],
|
||||
messages=messages,
|
||||
)
|
||||
)
|
||||
if not sessions and self._fallback:
|
||||
fallback_sessions = self._fallback.list_sessions(limit=limit, user_id=user_id)
|
||||
if fallback_sessions:
|
||||
return fallback_sessions
|
||||
return sessions
|
||||
except POSTGRES_ERRORS:
|
||||
if self._fallback:
|
||||
return self._fallback.list_sessions(limit=limit, user_id=user_id)
|
||||
raise
|
||||
|
||||
def get_session(self, session_id: UUID, user_id: str | None = None) -> ChatSession | None:
|
||||
try:
|
||||
with self._connect() as conn:
|
||||
with conn.cursor() as cur:
|
||||
if user_id:
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT id, user_id, title, created_at, updated_at
|
||||
FROM chat_sessions
|
||||
WHERE id = %s AND user_id = %s
|
||||
""",
|
||||
(str(session_id), user_id),
|
||||
)
|
||||
else:
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT id, user_id, title, created_at, updated_at
|
||||
FROM chat_sessions
|
||||
WHERE id = %s
|
||||
""",
|
||||
(str(session_id),),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if not row:
|
||||
if self._fallback:
|
||||
return self._fallback.get_session(session_id, user_id=user_id)
|
||||
return None
|
||||
messages = self._get_messages(session_id)
|
||||
return ChatSession(
|
||||
session_id=_ensure_uuid(row["id"]),
|
||||
user_id=row["user_id"],
|
||||
title=row["title"],
|
||||
created_at=row["created_at"],
|
||||
updated_at=row["updated_at"],
|
||||
messages=messages,
|
||||
)
|
||||
except POSTGRES_ERRORS:
|
||||
if self._fallback:
|
||||
return self._fallback.get_session(session_id, user_id=user_id)
|
||||
raise
|
||||
|
||||
def delete_session(self, session_id: UUID, user_id: str | None = None) -> bool:
|
||||
try:
|
||||
with self._connect() as conn:
|
||||
with conn.cursor() as cur:
|
||||
if user_id:
|
||||
cur.execute(
|
||||
"SELECT id FROM chat_sessions WHERE id = %s AND user_id = %s",
|
||||
(str(session_id), user_id),
|
||||
)
|
||||
else:
|
||||
cur.execute(
|
||||
"SELECT id FROM chat_sessions WHERE id = %s",
|
||||
(str(session_id),),
|
||||
)
|
||||
if not cur.fetchone():
|
||||
if self._fallback:
|
||||
return self._fallback.delete_session(session_id, user_id=user_id)
|
||||
return False
|
||||
cur.execute(
|
||||
"DELETE FROM chat_messages WHERE session_id = %s",
|
||||
(str(session_id),),
|
||||
)
|
||||
cur.execute(
|
||||
"DELETE FROM chat_sessions WHERE id = %s",
|
||||
(str(session_id),),
|
||||
)
|
||||
conn.commit()
|
||||
return True
|
||||
except POSTGRES_ERRORS:
|
||||
if self._fallback:
|
||||
return self._fallback.delete_session(session_id, user_id=user_id)
|
||||
raise
|
||||
|
||||
def append_message(
|
||||
self, session_id: UUID, message: ChatMessage, user_id: str | None = None
|
||||
) -> ChatSession | None:
|
||||
try:
|
||||
with self._connect() as conn:
|
||||
supports_metadata = self._metadata_supported(conn)
|
||||
with conn.cursor() as cur:
|
||||
# Check if session exists (with or without user_id constraint)
|
||||
if user_id:
|
||||
cur.execute(
|
||||
"SELECT id FROM chat_sessions WHERE id = %s AND user_id = %s",
|
||||
(str(session_id), user_id),
|
||||
)
|
||||
else:
|
||||
cur.execute(
|
||||
"SELECT id FROM chat_sessions WHERE id = %s",
|
||||
(str(session_id),),
|
||||
)
|
||||
|
||||
if not cur.fetchone():
|
||||
if self._fallback:
|
||||
fallback_session = self._fallback.append_message(
|
||||
session_id, message, user_id=user_id
|
||||
)
|
||||
if fallback_session:
|
||||
return fallback_session
|
||||
return None
|
||||
|
||||
# Session exists, insert message with or without metadata
|
||||
if supports_metadata:
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO chat_messages (
|
||||
id, session_id, role, content, metadata, created_at
|
||||
)
|
||||
VALUES (%s, %s, %s, %s, %s, %s)
|
||||
""",
|
||||
(
|
||||
str(message.message_id),
|
||||
str(session_id),
|
||||
message.role,
|
||||
message.content,
|
||||
json.dumps(message.metadata, default=str),
|
||||
message.timestamp,
|
||||
),
|
||||
)
|
||||
else:
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO chat_messages (
|
||||
id, session_id, role, content, created_at
|
||||
)
|
||||
VALUES (%s, %s, %s, %s, %s)
|
||||
""",
|
||||
(
|
||||
str(message.message_id),
|
||||
str(session_id),
|
||||
message.role,
|
||||
message.content,
|
||||
message.timestamp,
|
||||
),
|
||||
)
|
||||
cur.execute(
|
||||
"UPDATE chat_sessions SET updated_at = %s WHERE id = %s",
|
||||
(message.timestamp, str(session_id)),
|
||||
)
|
||||
conn.commit()
|
||||
return self.get_session(session_id, user_id=user_id)
|
||||
except POSTGRES_ERRORS:
|
||||
if self._fallback:
|
||||
return self._fallback.append_message(session_id, message, user_id=user_id)
|
||||
raise
|
||||
|
||||
def cleanup_request_messages(
|
||||
self, session_id: UUID, request_id: str, user_id: str | None = None
|
||||
) -> ChatSession | None:
|
||||
try:
|
||||
with self._connect() as conn:
|
||||
supports_metadata = self._metadata_supported(conn)
|
||||
if not supports_metadata:
|
||||
if self._fallback:
|
||||
return self._fallback.cleanup_request_messages(
|
||||
session_id, request_id, user_id=user_id
|
||||
)
|
||||
return self.get_session(session_id, user_id=user_id)
|
||||
with conn.cursor() as cur:
|
||||
if user_id:
|
||||
cur.execute(
|
||||
"SELECT id FROM chat_sessions WHERE id = %s AND user_id = %s",
|
||||
(str(session_id), user_id),
|
||||
)
|
||||
else:
|
||||
cur.execute(
|
||||
"SELECT id FROM chat_sessions WHERE id = %s",
|
||||
(str(session_id),),
|
||||
)
|
||||
if not cur.fetchone():
|
||||
if self._fallback:
|
||||
return self._fallback.cleanup_request_messages(
|
||||
session_id, request_id, user_id=user_id
|
||||
)
|
||||
return None
|
||||
cur.execute(
|
||||
"""
|
||||
DELETE FROM chat_messages
|
||||
WHERE session_id = %s
|
||||
AND metadata ->> 'request_id' = %s
|
||||
""",
|
||||
(str(session_id), request_id),
|
||||
)
|
||||
cur.execute(
|
||||
"UPDATE chat_sessions SET updated_at = %s WHERE id = %s",
|
||||
(datetime.utcnow(), str(session_id)),
|
||||
)
|
||||
conn.commit()
|
||||
return self.get_session(session_id, user_id=user_id)
|
||||
except POSTGRES_ERRORS:
|
||||
if self._fallback:
|
||||
return self._fallback.cleanup_request_messages(
|
||||
session_id, request_id, user_id=user_id
|
||||
)
|
||||
raise
|
||||
|
||||
def _get_messages(self, session_id: UUID) -> list[ChatMessage]:
|
||||
with self._connect() as conn:
|
||||
supports_metadata = self._metadata_supported(conn)
|
||||
with conn.cursor() as cur:
|
||||
if supports_metadata:
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT id, role, content, metadata, created_at
|
||||
FROM chat_messages
|
||||
WHERE session_id = %s
|
||||
ORDER BY created_at ASC
|
||||
""",
|
||||
(str(session_id),),
|
||||
)
|
||||
else:
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT id, role, content, created_at
|
||||
FROM chat_messages
|
||||
WHERE session_id = %s
|
||||
ORDER BY created_at ASC
|
||||
""",
|
||||
(str(session_id),),
|
||||
)
|
||||
rows = cur.fetchall()
|
||||
messages: list[ChatMessage] = []
|
||||
for row in rows:
|
||||
metadata = row.get("metadata") or {}
|
||||
if isinstance(metadata, str):
|
||||
try:
|
||||
metadata = json.loads(metadata)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
metadata = {}
|
||||
messages.append(
|
||||
ChatMessage(
|
||||
message_id=_ensure_uuid(row["id"]),
|
||||
role=row["role"],
|
||||
content=row["content"],
|
||||
metadata=metadata,
|
||||
timestamp=row["created_at"],
|
||||
)
|
||||
)
|
||||
return messages
|
||||
|
||||
|
||||
class PostgresSettingsStore(PostgresStoreBase, SettingsStore):
|
||||
"""Postgres-backed sandbox permissions store."""
|
||||
|
||||
def get_settings(self) -> SandboxPermissions:
|
||||
"""Fetch sandbox permissions from database, or return defaults if not found."""
|
||||
try:
|
||||
with self._connect() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"SELECT permissions FROM sandbox_permissions WHERE id = %s",
|
||||
("default",),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if not row:
|
||||
return SandboxPermissions()
|
||||
return SandboxPermissions.model_validate(row["permissions"])
|
||||
except POSTGRES_ERRORS:
|
||||
return SandboxPermissions()
|
||||
|
||||
def update_settings(self, settings: SandboxPermissions) -> None:
|
||||
"""Update sandbox permissions in database."""
|
||||
with self._connect() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO sandbox_permissions (id, permissions, updated_at)
|
||||
VALUES (%s, %s, now())
|
||||
ON CONFLICT (id) DO UPDATE
|
||||
SET permissions = EXCLUDED.permissions, updated_at = now()
|
||||
""",
|
||||
("default", json.dumps(settings.model_dump())),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def get_app_settings(self) -> AppSettings:
|
||||
try:
|
||||
with self._connect() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"SELECT settings FROM app_settings WHERE id = %s",
|
||||
("default",),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if not row:
|
||||
return AppSettings()
|
||||
payload = row.get("settings") if isinstance(row, dict) else row[0]
|
||||
if isinstance(payload, str):
|
||||
try:
|
||||
payload = json.loads(payload)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
payload = {}
|
||||
return AppSettings.model_validate(payload)
|
||||
except POSTGRES_ERRORS:
|
||||
return AppSettings()
|
||||
|
||||
def update_app_settings(self, settings: AppSettings) -> AppSettings:
|
||||
with self._connect() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO app_settings (id, settings, updated_at)
|
||||
VALUES (%s, %s, now())
|
||||
ON CONFLICT (id) DO UPDATE
|
||||
SET settings = EXCLUDED.settings, updated_at = now()
|
||||
""",
|
||||
("default", json.dumps(settings.model_dump())),
|
||||
)
|
||||
conn.commit()
|
||||
return settings
|
||||
|
||||
|
||||
class PostgresScannerStore(PostgresStoreBase, ScannerStore):
|
||||
def __init__(self, dsn: str, fallback: ScannerStore | None = None) -> None:
|
||||
super().__init__(dsn)
|
||||
self._fallback = fallback
|
||||
|
||||
@staticmethod
|
||||
def _coerce_json(value, default):
|
||||
if value is None:
|
||||
return default
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
return json.loads(value)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return default
|
||||
return value
|
||||
|
||||
def _row_to_record(self, row, user_id: str) -> ScannerConfigRecord:
|
||||
options = self._coerce_json(row.get("options"), {})
|
||||
config = ScannerConfig(
|
||||
network_cidrs=row.get("network_cidrs") or [],
|
||||
ports=row.get("ports") or [],
|
||||
port_preset=row.get("port_preset") or "custom",
|
||||
options=options or {},
|
||||
)
|
||||
return ScannerConfigRecord(
|
||||
id=int(row["id"]),
|
||||
user_id=user_id,
|
||||
config=config,
|
||||
updated_at=row.get("updated_at"),
|
||||
)
|
||||
|
||||
def get_config(self, user_id: str) -> ScannerConfigRecord:
|
||||
try:
|
||||
with self._connect() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT id, network_cidrs, ports, port_preset, collectors, options,
|
||||
updated_at
|
||||
FROM scan_configs
|
||||
WHERE user_id = %s
|
||||
""",
|
||||
(user_id,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if row:
|
||||
return self._row_to_record(row, user_id)
|
||||
|
||||
config = default_scanner_config()
|
||||
collectors_payload = json.dumps({"network": True})
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO scan_configs (
|
||||
user_id, network_cidrs, ports, port_preset, collectors, options,
|
||||
updated_at
|
||||
)
|
||||
VALUES (%s, %s, %s, %s, %s, %s, now())
|
||||
RETURNING id, updated_at
|
||||
""",
|
||||
(
|
||||
user_id,
|
||||
config.network_cidrs,
|
||||
config.ports,
|
||||
config.port_preset,
|
||||
collectors_payload,
|
||||
json.dumps(config.options.model_dump()),
|
||||
),
|
||||
)
|
||||
created = cur.fetchone()
|
||||
conn.commit()
|
||||
return ScannerConfigRecord(
|
||||
id=int(created["id"]),
|
||||
user_id=user_id,
|
||||
config=config,
|
||||
updated_at=created.get("updated_at"),
|
||||
)
|
||||
except POSTGRES_ERRORS:
|
||||
if self._fallback:
|
||||
return self._fallback.get_config(user_id)
|
||||
raise
|
||||
|
||||
def update_config(self, user_id: str, config: ScannerConfig) -> ScannerConfigRecord:
|
||||
try:
|
||||
collectors_payload = json.dumps({"network": True})
|
||||
with self._connect() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO scan_configs (
|
||||
user_id, network_cidrs, ports, port_preset, collectors, options,
|
||||
updated_at
|
||||
)
|
||||
VALUES (%s, %s, %s, %s, %s, %s, now())
|
||||
ON CONFLICT (user_id) DO UPDATE
|
||||
SET network_cidrs = EXCLUDED.network_cidrs,
|
||||
ports = EXCLUDED.ports,
|
||||
port_preset = EXCLUDED.port_preset,
|
||||
collectors = EXCLUDED.collectors,
|
||||
options = EXCLUDED.options,
|
||||
updated_at = now()
|
||||
RETURNING id, updated_at
|
||||
""",
|
||||
(
|
||||
user_id,
|
||||
config.network_cidrs,
|
||||
config.ports,
|
||||
config.port_preset,
|
||||
collectors_payload,
|
||||
json.dumps(config.options.model_dump()),
|
||||
),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
conn.commit()
|
||||
return ScannerConfigRecord(
|
||||
id=int(row["id"]),
|
||||
user_id=user_id,
|
||||
config=config,
|
||||
updated_at=row.get("updated_at"),
|
||||
)
|
||||
except POSTGRES_ERRORS:
|
||||
if self._fallback:
|
||||
return self._fallback.update_config(user_id, config)
|
||||
raise
|
||||
@@ -0,0 +1 @@
|
||||
|
||||
@@ -0,0 +1,129 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
|
||||
from eidolon.config.settings import SandboxPermissions
|
||||
from eidolon.core.graph.repository import GraphRepository
|
||||
from eidolon.core.models.plan import (
|
||||
EntityRef,
|
||||
ExecutionRequest,
|
||||
ExecutionResponse,
|
||||
ToolExecutionResult,
|
||||
)
|
||||
from eidolon.core.reasoning.llm import LiteLLMClient
|
||||
from eidolon.core.reasoning.planner import Planner
|
||||
from eidolon.core.stores import ApprovalStore
|
||||
from eidolon.runtime.executor import ExecutionEngine
|
||||
from eidolon.runtime.tools.base import Tool
|
||||
|
||||
|
||||
class AgentState(str, Enum):
|
||||
WAITING = "waiting"
|
||||
RUNNING = "running"
|
||||
STOPPED = "stopped"
|
||||
|
||||
|
||||
class Agent:
|
||||
"""Single-agent runtime that plans and optionally executes with policy/approval checks."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tools: dict[str, Tool] | None,
|
||||
llm_client: LiteLLMClient,
|
||||
max_iterations: int = 8,
|
||||
*,
|
||||
repository: GraphRepository | None = None,
|
||||
approval_store: ApprovalStore | None = None,
|
||||
runtime_settings: SandboxPermissions | None = None,
|
||||
) -> None:
|
||||
self.tools = tools or {}
|
||||
self.llm_client = llm_client
|
||||
self.repository = repository
|
||||
self.approval_store = approval_store
|
||||
self.runtime_settings = runtime_settings
|
||||
self.max_iterations = max_iterations
|
||||
self.state = AgentState.WAITING
|
||||
self.trace: list[dict] = []
|
||||
|
||||
def _execute(
|
||||
self,
|
||||
request: ExecutionRequest,
|
||||
) -> tuple[list[ToolExecutionResult], str]:
|
||||
if not self.repository:
|
||||
raise RuntimeError("repository required for execution")
|
||||
engine = ExecutionEngine(
|
||||
self.repository,
|
||||
runtime_settings=self.runtime_settings,
|
||||
extra_tools=self.tools.values(),
|
||||
)
|
||||
results: list[ToolExecutionResult] = []
|
||||
for step in request.steps:
|
||||
result = engine.execute_step(step, dry_run=request.dry_run)
|
||||
results.append(result)
|
||||
status = "ok" if all(result.status != "error" for result in results) else "partial_failure"
|
||||
self.trace.append(
|
||||
{
|
||||
"event": "execute",
|
||||
"dry_run": request.dry_run,
|
||||
"status": status,
|
||||
}
|
||||
)
|
||||
return results, status
|
||||
|
||||
def run_intent(
|
||||
self,
|
||||
intent: str,
|
||||
target: EntityRef | None = None,
|
||||
*,
|
||||
dry_run: bool = True,
|
||||
approval_token: str | None = None,
|
||||
) -> dict:
|
||||
self.state = AgentState.RUNNING
|
||||
self.trace = []
|
||||
resolved_target = target or EntityRef(entity_type="Asset", display_name="unknown")
|
||||
planner = Planner(llm_client=self.llm_client)
|
||||
steps = planner.generate_plan(intent=intent, target=resolved_target)
|
||||
if len(steps) > self.max_iterations:
|
||||
steps = steps[: self.max_iterations]
|
||||
self.trace.append({"event": "plan.truncated", "limit": self.max_iterations})
|
||||
self.trace.append({"event": "plan", "steps": len(steps)})
|
||||
|
||||
if dry_run:
|
||||
self.state = AgentState.STOPPED
|
||||
return {
|
||||
"intent": intent,
|
||||
"status": "planned",
|
||||
"steps": [step.model_dump() for step in steps],
|
||||
"trace": self.trace,
|
||||
}
|
||||
|
||||
# Execute the plan
|
||||
request = ExecutionRequest(
|
||||
dry_run=False,
|
||||
steps=steps,
|
||||
requires_approval=any(step.requires_approval for step in steps),
|
||||
approval_token=approval_token,
|
||||
)
|
||||
|
||||
if request.requires_approval:
|
||||
if not approval_token:
|
||||
self.state = AgentState.STOPPED
|
||||
raise RuntimeError("approval token required for execution")
|
||||
if not self.approval_store:
|
||||
self.state = AgentState.STOPPED
|
||||
raise RuntimeError("approval store unavailable")
|
||||
approval = self.approval_store.get_by_token(approval_token)
|
||||
if not approval or approval.action != "execute":
|
||||
self.state = AgentState.STOPPED
|
||||
raise RuntimeError("invalid approval token")
|
||||
|
||||
results, status = self._execute(request)
|
||||
self.state = AgentState.STOPPED
|
||||
response = ExecutionResponse(request=request, results=results, status=status)
|
||||
return {
|
||||
"intent": intent,
|
||||
"status": status,
|
||||
"steps": [step.model_dump() for step in steps],
|
||||
"execution": response.model_dump(),
|
||||
"trace": self.trace,
|
||||
}
|
||||
@@ -0,0 +1,728 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import platform
|
||||
import shutil
|
||||
import sys
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from eidolon.config.settings import SandboxPermissions
|
||||
from eidolon.core.models.chat import ChatMessage
|
||||
from eidolon.core.reasoning.llm import LiteLLMClient
|
||||
from eidolon.core.reasoning.memory import ConversationMemory
|
||||
from eidolon.runtime.sandbox import SandboxRuntime
|
||||
from eidolon.runtime.tools.base import Tool
|
||||
from eidolon.runtime.tools.todo import TodoTool
|
||||
|
||||
# Infrastructure and network tools to detect
|
||||
INFRASTRUCTURE_TOOLS = {
|
||||
"network_discovery": ["nmap", "arp-scan", "masscan", "rustscan", "zmap"],
|
||||
"network_analysis": ["tcpdump", "tshark", "wireshark", "ngrep"],
|
||||
"dns_tools": ["dig", "nslookup", "host", "dnsenum", "dnsrecon"],
|
||||
"cloud_cli": ["aws", "az", "gcloud", "kubectl", "terraform", "ansible", "pulumi"],
|
||||
"container_tools": ["docker", "podman", "kubectl", "helm", "docker-compose"],
|
||||
"monitoring": ["top", "htop", "netstat", "ss", "lsof", "iotop", "iftop"],
|
||||
"network_utilities": [
|
||||
"ping",
|
||||
"traceroute",
|
||||
"mtr",
|
||||
"curl",
|
||||
"wget",
|
||||
"nc",
|
||||
"netcat",
|
||||
"telnet",
|
||||
"ssh",
|
||||
],
|
||||
"system_info": ["ps", "systemctl", "service", "uptime", "df", "free", "uname"],
|
||||
}
|
||||
|
||||
|
||||
def detect_available_tools() -> dict[str, list[str]]:
|
||||
"""Detect which infrastructure tools are available on the system."""
|
||||
available = {}
|
||||
for category, tools in INFRASTRUCTURE_TOOLS.items():
|
||||
found = []
|
||||
for tool in tools:
|
||||
if shutil.which(tool):
|
||||
found.append(tool)
|
||||
if found:
|
||||
available[category] = found
|
||||
return available
|
||||
|
||||
|
||||
def get_graph_summary(repository: Any) -> str:
|
||||
"""Generate a lightweight summary of the graph for system prompt injection."""
|
||||
try:
|
||||
# Get node counts by label
|
||||
count_query = """
|
||||
MATCH (n)
|
||||
WHERE n.node_id IS NOT NULL
|
||||
RETURN labels(n)[0] AS label, count(n) AS count
|
||||
ORDER BY count DESC
|
||||
"""
|
||||
counts = list(repository.run_cypher(count_query, {}))
|
||||
|
||||
# Get sample node IDs (first 3)
|
||||
sample_query = """
|
||||
MATCH (n)
|
||||
WHERE n.node_id IS NOT NULL
|
||||
RETURN n.node_id AS id
|
||||
LIMIT 3
|
||||
"""
|
||||
samples = list(repository.run_cypher(sample_query, {}))
|
||||
|
||||
# Get active networks
|
||||
network_query = """
|
||||
MATCH (n:NetworkContainer)
|
||||
WHERE n.cidr IS NOT NULL
|
||||
RETURN n.cidr AS cidr
|
||||
LIMIT 5
|
||||
"""
|
||||
networks = list(repository.run_cypher(network_query, {}))
|
||||
|
||||
# Build summary
|
||||
total = sum(record.get("count", 0) for record in counts)
|
||||
node_breakdown = ", ".join(
|
||||
f"{record.get('count', 0)} {record.get('label', 'Unknown')}" for record in counts[:4]
|
||||
)
|
||||
sample_ids = [str(record.get("id", ""))[:8] + "..." for record in samples[:3]]
|
||||
network_list = [record.get("cidr", "") for record in networks[:5]]
|
||||
|
||||
timestamp = datetime.utcnow().strftime("%Y-%m-%d %H:%M UTC")
|
||||
|
||||
summary = f"""## Infrastructure Graph Summary (as of {timestamp})
|
||||
- Total nodes: {total} ({node_breakdown})
|
||||
- Sample node IDs: {', '.join(sample_ids) if sample_ids else 'none'}
|
||||
- Active networks: {', '.join(network_list) if network_list else 'none'}
|
||||
"""
|
||||
return summary
|
||||
except Exception: # noqa: BLE001
|
||||
# If graph query fails, return minimal summary
|
||||
return """## Infrastructure Graph Summary
|
||||
- Graph data available via graph_query tool
|
||||
- Use queries to explore nodes, relationships, and metadata
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolResult:
|
||||
tool_call_id: str
|
||||
tool_name: str
|
||||
result: dict | None = None
|
||||
error: str | None = None
|
||||
success: bool = True
|
||||
|
||||
|
||||
def build_system_prompt(
|
||||
tools: Iterable[Tool], permissions: SandboxPermissions, repository: Any | None = None
|
||||
) -> str:
|
||||
tool_lines = "\n".join(f"- {tool.name}: {tool.description}" for tool in tools)
|
||||
allowed = ", ".join(permissions.allowed_tools) if permissions.allowed_tools else "all"
|
||||
blocked = ", ".join(permissions.blocked_tools) if permissions.blocked_tools else "none"
|
||||
|
||||
# Capture system environment info
|
||||
os_type = platform.system() # Windows, Linux, Darwin (macOS)
|
||||
os_release = platform.release()
|
||||
arch = platform.machine() # x86_64, ARM64, etc.
|
||||
python_version = f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
|
||||
hostname = platform.node()
|
||||
|
||||
# Detect available CLI tools
|
||||
available_tools = detect_available_tools()
|
||||
tools_summary = []
|
||||
for category, tools in available_tools.items():
|
||||
tools_summary.append(f" {category}: {', '.join(tools)}")
|
||||
tools_str = (
|
||||
"\n".join(tools_summary) if tools_summary else " (no infrastructure tools detected)"
|
||||
)
|
||||
|
||||
# Get graph summary if repository is available
|
||||
graph_summary = get_graph_summary(repository) if repository else ""
|
||||
|
||||
return f"""You are Eidolon, a network and infrastructure assistant.
|
||||
|
||||
## Operating Environment
|
||||
- OS: {os_type} {os_release}
|
||||
- Architecture: {arch}
|
||||
- Python: {python_version}
|
||||
- Hostname: {hostname}
|
||||
- Shell: {'PowerShell' if os_type == 'Windows' else 'bash'}
|
||||
|
||||
## Available CLI Tools
|
||||
{tools_str}
|
||||
|
||||
{graph_summary}
|
||||
|
||||
## IMPORTANT: Always Check the Graph First
|
||||
|
||||
**The infrastructure graph contains discovered network data from scans.** Before running manual
|
||||
network discovery:
|
||||
1. Query the graph using `graph_query` to see what's already discovered
|
||||
2. Check for networks, assets, and their metadata
|
||||
3. Only use manual tools (nmap, etc.) if the graph lacks the specific information needed
|
||||
|
||||
The graph is your PRIMARY data source for network infrastructure information.
|
||||
|
||||
## Available Tools
|
||||
{tool_lines}
|
||||
|
||||
## Neo4j Graph Reference
|
||||
|
||||
The infrastructure graph uses Neo4j 5.x with the following structure:
|
||||
|
||||
**Node Types:**
|
||||
- Asset (hosts, services) - has `node_id`, `kind`, `metadata` (JSON string)
|
||||
- NetworkContainer (networks) - has `node_id`, `cidr`
|
||||
- Identity (users, accounts) - has `node_id`, `name`, `kind`
|
||||
- Policy (rules) - has `node_id`, `name`, `rule_type`
|
||||
|
||||
**Relationships:**
|
||||
- MEMBER_OF (Asset → NetworkContainer)
|
||||
- CONNECTS_TO (Asset → Asset)
|
||||
- HAS_IDENTITY (Asset → Identity)
|
||||
- GOVERNED_BY (Asset/Network → Policy)
|
||||
|
||||
**Metadata Handling:**
|
||||
The `metadata` field is a JSON string, not a map. Parse after retrieval:
|
||||
```cypher
|
||||
MATCH (a:Asset)
|
||||
WHERE a.metadata IS NOT NULL
|
||||
RETURN a.node_id, a.metadata
|
||||
```
|
||||
|
||||
**Common Metadata Fields** (populated by collectors):
|
||||
- `ip` - IPv4/IPv6 address
|
||||
- `hostname` - DNS hostname
|
||||
- `mac_address` - MAC address (from ARP/nmap)
|
||||
- `vendor` - Network interface vendor (from MAC OUI lookup via nmap)
|
||||
- `ports` - Array of port objects:
|
||||
`[{{"port": 22, "state": "open", "service": "ssh", "version": "..."}}]`
|
||||
- `status` - Host status: "online", "offline", "idle"
|
||||
- `os` - Operating system fingerprint (if available from nmap)
|
||||
- `cidr` - Network CIDR the host belongs to
|
||||
|
||||
**Example - Find hosts by vendor:**
|
||||
```cypher
|
||||
MATCH (a:Asset)
|
||||
WHERE a.metadata CONTAINS '"vendor"'
|
||||
RETURN a.node_id, a.metadata
|
||||
```
|
||||
|
||||
Note: Metadata is stored as JSON string. To search for specific values, use CONTAINS with flexible
|
||||
patterns:
|
||||
- `WHERE a.metadata CONTAINS 'Samsung'` (case-sensitive substring)
|
||||
- Use CONTAINS with just the value to avoid JSON formatting issues
|
||||
|
||||
**Query Examples:**
|
||||
|
||||
List assets:
|
||||
```cypher
|
||||
MATCH (a:Asset) WHERE a.node_id IS NOT NULL
|
||||
RETURN a.node_id, a.metadata LIMIT 100
|
||||
```
|
||||
|
||||
Blast radius (nodes within N hops):
|
||||
```cypher
|
||||
MATCH path = (start:Asset)-[*1..2]-(affected)
|
||||
WHERE start.node_id = $target_id AND affected.node_id IS NOT NULL
|
||||
RETURN DISTINCT affected.node_id, affected.metadata, length(path) AS distance
|
||||
ORDER BY distance
|
||||
```
|
||||
|
||||
Network membership:
|
||||
```cypher
|
||||
MATCH (a:Asset)-[:MEMBER_OF]->(n:NetworkContainer)
|
||||
WHERE n.cidr = $cidr
|
||||
RETURN a.node_id, a.metadata
|
||||
```
|
||||
|
||||
**Notes:**
|
||||
- Use `WHERE n.node_id IS NOT NULL` to filter auxiliary nodes
|
||||
- Use `IS NOT NULL` instead of deprecated `exists()`
|
||||
- Parameterize inputs: `WHERE a.node_id = $param`
|
||||
- Add LIMIT to prevent overwhelming results
|
||||
|
||||
## Output Guidelines
|
||||
|
||||
When presenting technical data to users:
|
||||
- Use human-readable identifiers (IPs, hostnames) from metadata, not raw UUIDs
|
||||
- Format results clearly (tables, lists, summaries)
|
||||
- Parse JSON metadata strings to extract meaningful fields
|
||||
|
||||
## Todo Workflow
|
||||
|
||||
When working with todo items:
|
||||
1. **Create** todos for multi-step tasks by calling the `todo` tool with action "set"
|
||||
2. **Execute immediately** - once a todo is created, start working on it in the next iteration
|
||||
3. **Mark progress** - call `todo` with action "complete" when finishing each item, or "skip"
|
||||
if blocked
|
||||
4. **Never ask permission** - do not ask "Want me to run it now?" or similar questions
|
||||
5. **Stay focused** - work through the todo list without unnecessary intermediate messages
|
||||
|
||||
The todo tool is for YOUR planning and tracking, not for soliciting user input.
|
||||
|
||||
## Sandbox Permissions
|
||||
- allow_shell: {{permissions.allow_shell}}
|
||||
- allow_network: {{permissions.allow_network}}
|
||||
- allow_file_write: {{permissions.allow_file_write}}
|
||||
- allowed_tools: {allowed}
|
||||
- blocked_tools: {blocked}
|
||||
"""
|
||||
|
||||
|
||||
class AssistantAgent:
|
||||
"""Single-mode assistant loop with tool calling and todo-driven iterations."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm_client: LiteLLMClient,
|
||||
sandbox: SandboxRuntime,
|
||||
system_prompt: str,
|
||||
max_iterations: int = 8,
|
||||
) -> None:
|
||||
self.llm_client = llm_client
|
||||
self.sandbox = sandbox
|
||||
self.system_prompt = system_prompt
|
||||
self.max_iterations = max_iterations
|
||||
self.memory = ConversationMemory(max_tokens=self.llm_client.settings.max_context_tokens)
|
||||
|
||||
def run(
|
||||
self, history: list[ChatMessage], cancellation_token: Any | None = None
|
||||
) -> list[ChatMessage]:
|
||||
"""Run the assistant loop and return new messages."""
|
||||
return list(self.run_iter(history, cancellation_token=cancellation_token))
|
||||
|
||||
def run_iter(
|
||||
self, history: list[ChatMessage], cancellation_token: Any | None = None
|
||||
) -> Iterable[ChatMessage]:
|
||||
"""Run the assistant loop and yield messages as they are produced."""
|
||||
working_history = list(history)
|
||||
todo_tool = self._get_todo_tool()
|
||||
if todo_tool:
|
||||
self._restore_todo_state(todo_tool, working_history)
|
||||
todo_engaged = bool(todo_tool and todo_tool.items)
|
||||
completed = False
|
||||
iterations = 0
|
||||
|
||||
while iterations < self.max_iterations:
|
||||
if self._is_cancelled(cancellation_token):
|
||||
break
|
||||
iterations += 1
|
||||
response = self.llm_client.generate(
|
||||
system_prompt=self.system_prompt,
|
||||
messages=self._format_messages_for_llm(working_history),
|
||||
tools=list(self.sandbox.active_tools.values()),
|
||||
memory=self.memory,
|
||||
)
|
||||
if self._is_cancelled(cancellation_token):
|
||||
break
|
||||
|
||||
if not response.tool_calls:
|
||||
content = response.content or ""
|
||||
if not content:
|
||||
error_summary = self._summarize_recent_tool_errors(working_history)
|
||||
if error_summary:
|
||||
if self._is_cancelled(cancellation_token):
|
||||
break
|
||||
error_msg = ChatMessage(
|
||||
role="assistant",
|
||||
content=error_summary,
|
||||
metadata={"kind": "message", "tool_error": True},
|
||||
)
|
||||
working_history.append(error_msg)
|
||||
yield error_msg
|
||||
# Check if there are pending todos before breaking
|
||||
todo_pending = todo_tool.has_pending() if todo_tool else False
|
||||
if not todo_pending:
|
||||
completed = True
|
||||
break
|
||||
# Continue loop to execute pending todos
|
||||
continue
|
||||
if self._is_cancelled(cancellation_token):
|
||||
break
|
||||
# Check if there are pending todos before breaking
|
||||
todo_pending = todo_tool.has_pending() if todo_tool else False
|
||||
if not todo_pending:
|
||||
empty_msg = ChatMessage(
|
||||
role="assistant",
|
||||
content="Agent returned an empty response.",
|
||||
metadata={"kind": "thinking", "empty_response": True},
|
||||
)
|
||||
working_history.append(empty_msg)
|
||||
yield empty_msg
|
||||
break
|
||||
# Continue loop to execute pending todos
|
||||
continue
|
||||
|
||||
todo_pending = todo_tool.has_pending() if todo_tool else False
|
||||
kind = "thinking" if todo_pending else "message"
|
||||
if self._is_cancelled(cancellation_token):
|
||||
break
|
||||
msg = ChatMessage(
|
||||
role="assistant",
|
||||
content=content,
|
||||
metadata={
|
||||
"kind": kind,
|
||||
"intermediate": todo_pending,
|
||||
"usage": response.usage,
|
||||
},
|
||||
)
|
||||
working_history.append(msg)
|
||||
yield msg
|
||||
if not todo_pending:
|
||||
completed = True
|
||||
break
|
||||
continue
|
||||
|
||||
if response.content:
|
||||
if self._is_cancelled(cancellation_token):
|
||||
break
|
||||
thinking_msg = ChatMessage(
|
||||
role="assistant",
|
||||
content=response.content,
|
||||
metadata={
|
||||
"kind": "thinking",
|
||||
"intermediate": True,
|
||||
"transient": True,
|
||||
"usage": response.usage,
|
||||
},
|
||||
)
|
||||
working_history.append(thinking_msg)
|
||||
yield thinking_msg
|
||||
|
||||
tool_calls = self._normalize_tool_calls(response.tool_calls)
|
||||
if not tool_calls:
|
||||
content = response.content or ""
|
||||
if not content:
|
||||
if self._is_cancelled(cancellation_token):
|
||||
break
|
||||
empty_msg = ChatMessage(
|
||||
role="assistant",
|
||||
content="Agent returned an empty tool call payload.",
|
||||
metadata={"kind": "warning", "empty_tool_calls": True},
|
||||
)
|
||||
working_history.append(empty_msg)
|
||||
yield empty_msg
|
||||
break
|
||||
if self._is_cancelled(cancellation_token):
|
||||
break
|
||||
msg = ChatMessage(
|
||||
role="assistant",
|
||||
content=content,
|
||||
metadata={"kind": "message", "usage": response.usage},
|
||||
)
|
||||
working_history.append(msg)
|
||||
yield msg
|
||||
completed = True
|
||||
break
|
||||
|
||||
if any(call.get("name") == "todo" for call in tool_calls):
|
||||
todo_engaged = True
|
||||
|
||||
if self._is_cancelled(cancellation_token):
|
||||
break
|
||||
tool_call_msg = ChatMessage(
|
||||
role="assistant",
|
||||
content=response.content or "",
|
||||
metadata={"kind": "tool_call", "tool_calls": tool_calls, "usage": response.usage},
|
||||
)
|
||||
working_history.append(tool_call_msg)
|
||||
yield tool_call_msg
|
||||
|
||||
if self._is_cancelled(cancellation_token):
|
||||
break
|
||||
tool_results = self._execute_tools(tool_calls)
|
||||
for result in tool_results:
|
||||
if self._is_cancelled(cancellation_token):
|
||||
break
|
||||
content = self._serialize_tool_output(result.result, result.error)
|
||||
tool_msg = ChatMessage(
|
||||
role="tool",
|
||||
content=content,
|
||||
metadata={
|
||||
"kind": "tool_result",
|
||||
"tool_call_id": result.tool_call_id,
|
||||
"tool_name": result.tool_name,
|
||||
"success": result.success,
|
||||
"result": self._safe_json(result.result),
|
||||
"error": result.error,
|
||||
},
|
||||
)
|
||||
working_history.append(tool_msg)
|
||||
yield tool_msg
|
||||
|
||||
if any(call.get("name") == "finish" for call in tool_calls):
|
||||
completed = True
|
||||
break
|
||||
|
||||
for call, result in zip(tool_calls, tool_results, strict=False):
|
||||
if self._is_cancelled(cancellation_token):
|
||||
break
|
||||
if call.get("name") != "todo":
|
||||
continue
|
||||
action = call.get("arguments", {}).get("action")
|
||||
if action not in {"set", "add"}:
|
||||
continue
|
||||
items = []
|
||||
if isinstance(result.result, dict):
|
||||
items = result.result.get("items", [])
|
||||
steps = [
|
||||
str(item.get("text"))
|
||||
for item in items
|
||||
if isinstance(item, dict) and item.get("text")
|
||||
]
|
||||
if not steps:
|
||||
continue
|
||||
plan_content = "\n".join(f"{idx + 1}. {step}" for idx, step in enumerate(steps))
|
||||
if self._is_cancelled(cancellation_token):
|
||||
break
|
||||
plan_msg = ChatMessage(
|
||||
role="assistant",
|
||||
content=plan_content,
|
||||
metadata={"kind": "plan", "steps": steps},
|
||||
)
|
||||
working_history.append(plan_msg)
|
||||
yield plan_msg
|
||||
|
||||
if todo_engaged and todo_tool and not todo_tool.has_pending():
|
||||
summary_msg = self._generate_summary(working_history)
|
||||
if summary_msg:
|
||||
if self._is_cancelled(cancellation_token):
|
||||
break
|
||||
working_history.append(summary_msg)
|
||||
yield summary_msg
|
||||
completed = True
|
||||
break
|
||||
|
||||
if not completed and iterations >= self.max_iterations:
|
||||
if self._is_cancelled(cancellation_token):
|
||||
return
|
||||
warning = ChatMessage(
|
||||
role="assistant",
|
||||
content=f"Reached iteration limit ({self.max_iterations}).",
|
||||
metadata={"kind": "warning", "max_iterations": True},
|
||||
)
|
||||
yield warning
|
||||
|
||||
def _is_cancelled(self, cancellation_token: Any | None) -> bool:
|
||||
if not cancellation_token:
|
||||
return False
|
||||
is_set = getattr(cancellation_token, "is_set", None)
|
||||
if callable(is_set):
|
||||
try:
|
||||
return bool(is_set())
|
||||
except Exception: # noqa: BLE001
|
||||
return False
|
||||
return False
|
||||
|
||||
def _get_todo_tool(self) -> TodoTool | None:
|
||||
tool = self.sandbox.active_tools.get("todo")
|
||||
return tool if isinstance(tool, TodoTool) else None
|
||||
|
||||
def _summarize_recent_tool_errors(self, history: list[ChatMessage]) -> str | None:
|
||||
errors: list[str] = []
|
||||
for msg in reversed(history[-6:]):
|
||||
if msg.role != "tool":
|
||||
continue
|
||||
error = msg.metadata.get("error")
|
||||
if error:
|
||||
errors.append(str(error))
|
||||
if not errors:
|
||||
return None
|
||||
if len(errors) == 1:
|
||||
return f"Tool error: {errors[0]}"
|
||||
joined = "\n".join(f"- {err}" for err in errors)
|
||||
return f"Multiple tool errors:\n{joined}"
|
||||
|
||||
def _restore_todo_state(self, todo_tool: TodoTool, history: list[ChatMessage]) -> None:
|
||||
if todo_tool.items:
|
||||
return
|
||||
for msg in reversed(history):
|
||||
if msg.role != "tool":
|
||||
continue
|
||||
if msg.metadata.get("tool_name") == "finish":
|
||||
return
|
||||
if msg.metadata.get("tool_name") != "todo":
|
||||
continue
|
||||
result = msg.metadata.get("result")
|
||||
if not isinstance(result, dict):
|
||||
break
|
||||
items = result.get("items")
|
||||
if not isinstance(items, list):
|
||||
break
|
||||
restored = []
|
||||
for item in items:
|
||||
if isinstance(item, dict) and "text" in item:
|
||||
restored.append(item)
|
||||
else:
|
||||
restored.append(
|
||||
{"id": len(restored) + 1, "text": str(item), "status": "pending"}
|
||||
)
|
||||
todo_tool.items = restored
|
||||
max_id = max((item.get("id", 0) for item in restored), default=0)
|
||||
todo_tool._next_id = int(max_id) + 1
|
||||
break
|
||||
|
||||
def _generate_summary(self, history: list[ChatMessage]) -> ChatMessage | None:
|
||||
response = self.llm_client.generate(
|
||||
system_prompt="You are a helpful assistant. Provide a concise summary of the results.",
|
||||
messages=self._format_messages_for_llm(history),
|
||||
tools=None,
|
||||
memory=self.memory,
|
||||
)
|
||||
content = response.content or ""
|
||||
if not content.strip():
|
||||
content = "Task complete."
|
||||
return ChatMessage(
|
||||
role="assistant",
|
||||
content=content,
|
||||
metadata={"kind": "message", "summary": True, "usage": response.usage},
|
||||
)
|
||||
|
||||
def _format_messages_for_llm(self, messages: list[ChatMessage]) -> list[dict[str, Any]]:
|
||||
formatted: list[dict[str, Any]] = []
|
||||
for msg in messages:
|
||||
if msg.metadata.get("transient"):
|
||||
continue
|
||||
if msg.role == "tool":
|
||||
tool_call_id = msg.metadata.get("tool_call_id")
|
||||
entry = {"role": "tool", "content": msg.content}
|
||||
if tool_call_id:
|
||||
entry["tool_call_id"] = tool_call_id
|
||||
tool_name = msg.metadata.get("tool_name")
|
||||
if tool_name:
|
||||
entry["name"] = tool_name
|
||||
formatted.append(entry)
|
||||
continue
|
||||
|
||||
entry = {"role": msg.role, "content": msg.content}
|
||||
tool_calls = msg.metadata.get("tool_calls")
|
||||
if tool_calls:
|
||||
entry["tool_calls"] = [
|
||||
{
|
||||
"id": call["id"],
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": call["name"],
|
||||
"arguments": json.dumps(call["arguments"], default=str),
|
||||
},
|
||||
}
|
||||
for call in tool_calls
|
||||
]
|
||||
formatted.append(entry)
|
||||
return formatted
|
||||
|
||||
def _normalize_tool_calls(self, raw_calls: list[Any]) -> list[dict[str, Any]]:
|
||||
tool_calls: list[dict[str, Any]] = []
|
||||
for index, call in enumerate(raw_calls or []):
|
||||
call_id = getattr(call, "id", None)
|
||||
if not call_id and isinstance(call, dict):
|
||||
call_id = call.get("id")
|
||||
if not call_id:
|
||||
call_id = f"call_{index}"
|
||||
|
||||
func = getattr(call, "function", None)
|
||||
if func is None and isinstance(call, dict):
|
||||
func = call.get("function", {})
|
||||
if hasattr(func, "name"):
|
||||
name = func.name
|
||||
args_raw = func.arguments
|
||||
else:
|
||||
func_dict = func if isinstance(func, dict) else {}
|
||||
name = func_dict.get("name") or (
|
||||
call.get("name", "") if isinstance(call, dict) else ""
|
||||
)
|
||||
args_raw = func_dict.get("arguments")
|
||||
if args_raw is None and isinstance(call, dict):
|
||||
args_raw = call.get("arguments", {})
|
||||
|
||||
if not name:
|
||||
continue
|
||||
|
||||
arguments = self._parse_arguments(args_raw)
|
||||
tool_calls.append({"id": call_id, "name": name, "arguments": arguments})
|
||||
return tool_calls
|
||||
|
||||
def _parse_arguments(self, args: Any) -> dict:
|
||||
if isinstance(args, dict):
|
||||
return args
|
||||
if isinstance(args, str):
|
||||
try:
|
||||
return json.loads(args)
|
||||
except json.JSONDecodeError:
|
||||
return {"raw": args}
|
||||
return {}
|
||||
|
||||
def _execute_tools(self, tool_calls: list[dict[str, Any]]) -> list[ToolResult]:
|
||||
results: list[ToolResult] = []
|
||||
todo_tool = self._get_todo_tool()
|
||||
todo_locked = bool(todo_tool and todo_tool.items)
|
||||
for call in tool_calls:
|
||||
tool_name = call["name"]
|
||||
tool_call_id = call["id"]
|
||||
arguments = call.get("arguments", {})
|
||||
if tool_name == "todo" and todo_locked:
|
||||
action = str(arguments.get("action", "")).lower()
|
||||
if action not in {"list", "complete", "skip"}:
|
||||
results.append(
|
||||
ToolResult(
|
||||
tool_call_id=tool_call_id,
|
||||
tool_name=tool_name,
|
||||
error=(
|
||||
"todo list is already initialized; only 'complete', 'skip', or "
|
||||
"'list' allowed until finish"
|
||||
),
|
||||
success=False,
|
||||
)
|
||||
)
|
||||
continue
|
||||
try:
|
||||
result = self.sandbox.execute(tool_name, arguments)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
results.append(
|
||||
ToolResult(
|
||||
tool_call_id=tool_call_id,
|
||||
tool_name=tool_name,
|
||||
error=str(exc),
|
||||
success=False,
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
if isinstance(result, dict) and result.get("error"):
|
||||
results.append(
|
||||
ToolResult(
|
||||
tool_call_id=tool_call_id,
|
||||
tool_name=tool_name,
|
||||
error=str(result.get("error")),
|
||||
success=False,
|
||||
)
|
||||
)
|
||||
else:
|
||||
results.append(
|
||||
ToolResult(
|
||||
tool_call_id=tool_call_id,
|
||||
tool_name=tool_name,
|
||||
result=result if isinstance(result, dict) else {"result": result},
|
||||
success=True,
|
||||
)
|
||||
)
|
||||
return results
|
||||
|
||||
def _serialize_tool_output(self, result: dict | None, error: str | None) -> str:
|
||||
payload: dict[str, Any] = {"error": error} if error else result or {"result": "ok"}
|
||||
try:
|
||||
return json.dumps(payload, ensure_ascii=True)
|
||||
except TypeError:
|
||||
return json.dumps({"result": str(payload)}, ensure_ascii=True)
|
||||
|
||||
def _safe_json(self, value: Any) -> Any:
|
||||
if value is None:
|
||||
return None
|
||||
try:
|
||||
return json.loads(json.dumps(value, default=str))
|
||||
except TypeError:
|
||||
return {"result": str(value)}
|
||||
@@ -0,0 +1,83 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterable
|
||||
|
||||
from eidolon.config.settings import SandboxPermissions, get_settings
|
||||
from eidolon.core.graph.repository import GraphRepository
|
||||
from eidolon.core.models.plan import PlanStep, ToolExecutionResult
|
||||
from eidolon.runtime.sandbox import SandboxRuntime
|
||||
from eidolon.runtime.tools.base import Tool
|
||||
from eidolon.runtime.tools.browser import BrowserTool
|
||||
from eidolon.runtime.tools.file_edit import FileEditTool
|
||||
from eidolon.runtime.tools.finish import FinishTool
|
||||
from eidolon.runtime.tools.graph_query import GraphQueryTool
|
||||
from eidolon.runtime.tools.terminal import TerminalTool
|
||||
from eidolon.runtime.tools.thinking import ThinkingTool
|
||||
from eidolon.runtime.tools.todo import TodoTool
|
||||
|
||||
DEFAULT_ACTION_TOOL: dict[str, str] = {
|
||||
"run_command": "terminal",
|
||||
"open_url": "browser",
|
||||
"edit_file": "file_edit",
|
||||
"graph_query": "graph_query",
|
||||
}
|
||||
|
||||
|
||||
class ExecutionEngine:
|
||||
"""Execute plan steps using the registered tool runtime."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
repository: GraphRepository,
|
||||
runtime_settings: SandboxPermissions | None = None,
|
||||
extra_tools: Iterable[Tool] | None = None,
|
||||
) -> None:
|
||||
settings = runtime_settings or get_settings().sandbox
|
||||
self.runtime = SandboxRuntime(settings=settings)
|
||||
self.runtime.register_tool(TerminalTool())
|
||||
self.runtime.register_tool(BrowserTool())
|
||||
self.runtime.register_tool(FileEditTool())
|
||||
self.runtime.register_tool(ThinkingTool())
|
||||
self.runtime.register_tool(TodoTool())
|
||||
self.runtime.register_tool(FinishTool())
|
||||
self.runtime.register_tool(GraphQueryTool(repository))
|
||||
for tool in extra_tools or []:
|
||||
self.runtime.register_tool(tool)
|
||||
|
||||
def _resolve_tool(self, step: PlanStep) -> str | None:
|
||||
if step.tool_hint:
|
||||
return step.tool_hint
|
||||
return DEFAULT_ACTION_TOOL.get(step.action_type)
|
||||
|
||||
def execute_step(self, step: PlanStep, dry_run: bool = True) -> ToolExecutionResult:
|
||||
tool = self._resolve_tool(step)
|
||||
if not tool:
|
||||
return ToolExecutionResult(
|
||||
step_id=step.step_id,
|
||||
tool=None,
|
||||
status="skipped",
|
||||
error="no tool mapped for action_type",
|
||||
)
|
||||
|
||||
if dry_run:
|
||||
return ToolExecutionResult(step_id=step.step_id, tool=tool, status="dry_run")
|
||||
|
||||
payload = step.parameters or {}
|
||||
output = self.runtime.execute(tool, payload)
|
||||
status = "ok"
|
||||
error = None
|
||||
if isinstance(output, dict):
|
||||
if output.get("error"):
|
||||
status = "error"
|
||||
error = str(output.get("error"))
|
||||
elif output.get("returncode") not in (None, 0):
|
||||
status = "error"
|
||||
error = output.get("stderr") or "command failed"
|
||||
|
||||
return ToolExecutionResult(
|
||||
step_id=step.step_id,
|
||||
tool=tool,
|
||||
status=status,
|
||||
output=output if isinstance(output, dict) else {"result": output},
|
||||
error=error,
|
||||
)
|
||||
@@ -0,0 +1,46 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from eidolon.config.settings import SandboxPermissions, get_settings
|
||||
from eidolon.runtime.tools.base import Tool
|
||||
|
||||
|
||||
class SandboxRuntime:
|
||||
"""Sandbox runtime that enforces capability gates before dispatching tools."""
|
||||
|
||||
def __init__(self, settings: SandboxPermissions | None = None) -> None:
|
||||
self.settings = settings or get_settings().sandbox
|
||||
self.active_tools: dict[str, Tool] = {}
|
||||
|
||||
def register_tool(self, tool: Tool) -> None:
|
||||
self.active_tools[tool.name] = tool
|
||||
|
||||
def _is_tool_allowed(self, tool: Tool, payload: dict[str, Any]) -> tuple[bool, str | None]:
|
||||
settings = self.settings
|
||||
name = tool.name
|
||||
if settings.allowed_tools is not None and name not in settings.allowed_tools:
|
||||
return False, f"tool {name} is not in allowlist"
|
||||
if name in settings.blocked_tools:
|
||||
return False, f"tool {name} is blocked"
|
||||
if not tool.sandbox_execution and not settings.allow_unsafe_tools:
|
||||
return False, f"tool {name} is not permitted in the sandbox"
|
||||
if name == "terminal" and not settings.allow_shell:
|
||||
return False, "terminal tool is disabled"
|
||||
if name == "browser" and not settings.allow_network:
|
||||
return False, "browser tool is disabled"
|
||||
if name == "file_edit":
|
||||
action = str(payload.get("action", "read")).lower()
|
||||
if action == "write" and not settings.allow_file_write:
|
||||
return False, "file write operations are disabled"
|
||||
return True, None
|
||||
|
||||
def execute(self, tool_name: str, payload: dict[str, Any] | None) -> dict[str, Any]:
|
||||
tool = self.active_tools.get(tool_name)
|
||||
if not tool:
|
||||
return {"error": f"tool {tool_name} not registered"}
|
||||
safe_payload = payload or {}
|
||||
allowed, reason = self._is_tool_allowed(tool, safe_payload)
|
||||
if not allowed:
|
||||
return {"error": reason or "tool execution not permitted"}
|
||||
return tool.run(safe_payload)
|
||||
@@ -0,0 +1,105 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import queue
|
||||
import threading
|
||||
from collections import deque
|
||||
from collections.abc import Iterable
|
||||
from contextlib import suppress
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TaskEvent:
|
||||
event_type: str
|
||||
status: str
|
||||
payload: dict[str, Any] = field(default_factory=dict)
|
||||
message: str | None = None
|
||||
event_id: UUID = field(default_factory=uuid4)
|
||||
timestamp: datetime = field(default_factory=lambda: datetime.now(UTC))
|
||||
|
||||
def to_payload(self) -> dict[str, Any]:
|
||||
return {
|
||||
"event_id": str(self.event_id),
|
||||
"event_type": self.event_type,
|
||||
"status": self.status,
|
||||
"payload": self.payload,
|
||||
"message": self.message,
|
||||
"timestamp": self.timestamp.isoformat(),
|
||||
}
|
||||
|
||||
|
||||
class TaskEventBus:
|
||||
def __init__(self, history_size: int = 200, queue_size: int = 200) -> None:
|
||||
self._history: deque[TaskEvent] = deque(maxlen=history_size)
|
||||
self._subscribers: set[queue.Queue[TaskEvent]] = set()
|
||||
self._async_subscribers: set[asyncio.Queue[TaskEvent]] = set()
|
||||
self._lock = threading.Lock()
|
||||
self._queue_size = queue_size
|
||||
self._shutdown = False
|
||||
|
||||
def publish(self, event: TaskEvent) -> None:
|
||||
with self._lock:
|
||||
self._history.append(event)
|
||||
subscribers = list(self._subscribers)
|
||||
async_subscribers = list(self._async_subscribers)
|
||||
|
||||
# Publish to sync subscribers
|
||||
for subscriber in subscribers:
|
||||
try:
|
||||
subscriber.put_nowait(event)
|
||||
except queue.Full:
|
||||
with suppress(queue.Empty):
|
||||
subscriber.get_nowait()
|
||||
with suppress(queue.Full):
|
||||
subscriber.put_nowait(event)
|
||||
|
||||
# Publish to async subscribers
|
||||
for subscriber in async_subscribers:
|
||||
try:
|
||||
subscriber.put_nowait(event)
|
||||
except asyncio.QueueFull:
|
||||
with suppress(asyncio.QueueEmpty):
|
||||
subscriber.get_nowait()
|
||||
with suppress(asyncio.QueueFull):
|
||||
subscriber.put_nowait(event)
|
||||
|
||||
def subscribe(self) -> queue.Queue[TaskEvent]:
|
||||
subscriber: queue.Queue[TaskEvent] = queue.Queue(maxsize=self._queue_size)
|
||||
with self._lock:
|
||||
self._subscribers.add(subscriber)
|
||||
return subscriber
|
||||
|
||||
def unsubscribe(self, subscriber: queue.Queue[TaskEvent]) -> None:
|
||||
with self._lock:
|
||||
self._subscribers.discard(subscriber)
|
||||
|
||||
def subscribe_async(self) -> asyncio.Queue[TaskEvent]:
|
||||
"""Subscribe with an async queue for proper cancellation support."""
|
||||
subscriber: asyncio.Queue[TaskEvent] = asyncio.Queue(maxsize=self._queue_size)
|
||||
with self._lock:
|
||||
self._async_subscribers.add(subscriber)
|
||||
return subscriber
|
||||
|
||||
def unsubscribe_async(self, subscriber: asyncio.Queue[TaskEvent]) -> None:
|
||||
with self._lock:
|
||||
self._async_subscribers.discard(subscriber)
|
||||
|
||||
def history(self) -> Iterable[TaskEvent]:
|
||||
with self._lock:
|
||||
return list(self._history)
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""Signal shutdown to all async subscribers."""
|
||||
with self._lock:
|
||||
self._shutdown = True
|
||||
# Wake up all async subscribers with None sentinel
|
||||
for subscriber in self._async_subscribers:
|
||||
with suppress(asyncio.QueueFull):
|
||||
subscriber.put_nowait(None) # type: ignore
|
||||
|
||||
|
||||
task_event_bus = TaskEventBus()
|
||||
@@ -0,0 +1 @@
|
||||
|
||||
@@ -0,0 +1,36 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
|
||||
class Tool(ABC):
|
||||
"""Base class for agent tools."""
|
||||
|
||||
name: str = "tool"
|
||||
description: str = ""
|
||||
sandbox_execution: bool = True
|
||||
|
||||
@property
|
||||
def parameters_schema(self) -> dict[str, Any]:
|
||||
"""JSON Schema for tool parameters. Override in subclasses for specific schemas."""
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
}
|
||||
|
||||
def to_openai_function(self) -> dict[str, Any]:
|
||||
"""Convert to OpenAI function calling format."""
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"parameters": self.parameters_schema,
|
||||
},
|
||||
}
|
||||
|
||||
@abstractmethod
|
||||
def run(self, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Execute the tool with a typed payload."""
|
||||
@@ -0,0 +1,112 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import suppress
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from eidolon.runtime.tools.base import Tool
|
||||
|
||||
|
||||
class BrowserTool(Tool):
|
||||
name = "browser"
|
||||
description = "Issue HTTP requests against web endpoints."
|
||||
sandbox_execution = True
|
||||
|
||||
@property
|
||||
def parameters_schema(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "URL to request",
|
||||
},
|
||||
"method": {
|
||||
"type": "string",
|
||||
"description": "HTTP method (GET, POST, PUT, PATCH, DELETE, HEAD, OPTIONS)",
|
||||
},
|
||||
"headers": {
|
||||
"type": "object",
|
||||
"description": "Optional request headers",
|
||||
},
|
||||
"params": {
|
||||
"type": "object",
|
||||
"description": "Optional query parameters",
|
||||
},
|
||||
"json": {
|
||||
"type": "object",
|
||||
"description": "Optional JSON body",
|
||||
},
|
||||
"data": {
|
||||
"type": "string",
|
||||
"description": "Optional form/body payload as a string",
|
||||
},
|
||||
"timeout": {
|
||||
"type": "number",
|
||||
"description": "Timeout in seconds",
|
||||
},
|
||||
"follow_redirects": {
|
||||
"type": "boolean",
|
||||
"description": "Follow HTTP redirects",
|
||||
},
|
||||
"max_chars": {
|
||||
"type": "integer",
|
||||
"description": "Max response characters to return",
|
||||
},
|
||||
},
|
||||
"required": ["url"],
|
||||
}
|
||||
|
||||
def run(self, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
url = payload.get("url")
|
||||
if not url:
|
||||
return {"error": "url is required"}
|
||||
|
||||
method = str(payload.get("method", "GET")).upper()
|
||||
headers = payload.get("headers") or {}
|
||||
params = payload.get("params") or {}
|
||||
json_body = payload.get("json")
|
||||
data = payload.get("data")
|
||||
timeout = payload.get("timeout", 10)
|
||||
follow_redirects = bool(payload.get("follow_redirects", True))
|
||||
max_chars_raw = payload.get("max_chars", 2000)
|
||||
try:
|
||||
max_chars = int(max_chars_raw)
|
||||
except (TypeError, ValueError):
|
||||
max_chars = 2000
|
||||
|
||||
if method not in {"GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"}:
|
||||
return {"error": f"unsupported method {method}"}
|
||||
|
||||
try:
|
||||
with httpx.Client(timeout=timeout, follow_redirects=follow_redirects) as client:
|
||||
response = client.request(
|
||||
method,
|
||||
url,
|
||||
headers=headers,
|
||||
params=params,
|
||||
json=json_body,
|
||||
data=data,
|
||||
)
|
||||
except httpx.HTTPError as exc:
|
||||
return {"url": url, "error": str(exc)}
|
||||
|
||||
content_type = response.headers.get("content-type", "")
|
||||
text = response.text
|
||||
if max_chars > 0 and len(text) > max_chars:
|
||||
text = f"{text[:max_chars]}...(truncated)"
|
||||
|
||||
result: dict[str, Any] = {
|
||||
"url": url,
|
||||
"status_code": response.status_code,
|
||||
"content_type": content_type,
|
||||
"headers": dict(response.headers),
|
||||
"text": text,
|
||||
}
|
||||
if "application/json" in content_type:
|
||||
with suppress(ValueError):
|
||||
result["json"] = response.json()
|
||||
if response.status_code >= 400:
|
||||
result["error"] = f"HTTP {response.status_code}"
|
||||
return result
|
||||
@@ -0,0 +1,51 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from eidolon.runtime.tools.base import Tool
|
||||
|
||||
|
||||
class FileEditTool(Tool):
|
||||
name = "file_edit"
|
||||
description = "Read or write files with explicit intents."
|
||||
sandbox_execution = True
|
||||
|
||||
@property
|
||||
def parameters_schema(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"action": {
|
||||
"type": "string",
|
||||
"enum": ["read", "write"],
|
||||
"description": "Action to perform on the file",
|
||||
},
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Path to the file",
|
||||
},
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "Content to write (for write action)",
|
||||
},
|
||||
},
|
||||
"required": ["action", "path"],
|
||||
}
|
||||
|
||||
def run(self, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
action = str(payload.get("action", "read")).lower()
|
||||
path = payload.get("path")
|
||||
if not path:
|
||||
return {"error": "path is required"}
|
||||
target = Path(path)
|
||||
if action == "read":
|
||||
if not target.exists():
|
||||
return {"error": f"{path} not found"}
|
||||
return {"path": path, "content": target.read_text(encoding="utf-8")}
|
||||
if action == "write":
|
||||
content = payload.get("content", "")
|
||||
target.parent.mkdir(parents=True, exist_ok=True)
|
||||
target.write_text(content, encoding="utf-8")
|
||||
return {"path": path, "status": "written"}
|
||||
return {"error": f"unsupported action {action}"}
|
||||
@@ -0,0 +1,31 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from eidolon.runtime.tools.base import Tool
|
||||
|
||||
|
||||
class FinishTool(Tool):
|
||||
name = "finish"
|
||||
description = "Signal task completion and return final payload."
|
||||
sandbox_execution = False
|
||||
|
||||
@property
|
||||
def parameters_schema(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"summary": {
|
||||
"type": "string",
|
||||
"description": "Brief completion summary",
|
||||
},
|
||||
"details": {
|
||||
"type": "object",
|
||||
"description": "Optional structured completion payload",
|
||||
},
|
||||
},
|
||||
"required": [],
|
||||
}
|
||||
|
||||
def run(self, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
return {"result": payload}
|
||||
@@ -0,0 +1,48 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from eidolon.core.graph.repository import GraphRepository
|
||||
from eidolon.runtime.tools.base import Tool
|
||||
|
||||
|
||||
class GraphQueryTool(Tool):
|
||||
name = "graph_query"
|
||||
description = """Execute Cypher queries against the Eidolon infrastructure graph (Neo4j 5.x).
|
||||
|
||||
CRITICAL syntax requirements:
|
||||
- Use `n.property IS NOT NULL` instead of `exists(n.property)` (deprecated)
|
||||
- Available labels: NetworkContainer, Asset, Identity, Policy
|
||||
- Common patterns:
|
||||
* List networks: MATCH (n:NetworkContainer) WHERE n.cidr IS NOT NULL RETURN n
|
||||
* Find assets: MATCH (a:Asset) WHERE a.asset_id IS NOT NULL RETURN a
|
||||
* Get relationships: MATCH (a)-[r]->(b) RETURN a, type(r), b"""
|
||||
sandbox_execution = False
|
||||
|
||||
def __init__(self, repository: GraphRepository) -> None:
|
||||
self.repository = repository
|
||||
|
||||
@property
|
||||
def parameters_schema(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"cypher": {
|
||||
"type": "string",
|
||||
"description": "The Cypher query to execute against the Neo4j graph",
|
||||
},
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"description": "Optional parameters for the Cypher query",
|
||||
},
|
||||
},
|
||||
"required": ["cypher"],
|
||||
}
|
||||
|
||||
def run(self, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
cypher = payload.get("cypher")
|
||||
parameters = payload.get("parameters") or {}
|
||||
if not cypher:
|
||||
return {"error": "cypher is required"}
|
||||
records = list(self.repository.run_cypher(cypher, parameters))
|
||||
return {"records": records}
|
||||
@@ -0,0 +1,47 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import subprocess
|
||||
from typing import Any
|
||||
|
||||
from eidolon.runtime.tools.base import Tool
|
||||
|
||||
|
||||
class TerminalTool(Tool):
|
||||
name = "terminal"
|
||||
description = "Execute shell commands in a sandboxed environment."
|
||||
sandbox_execution = True
|
||||
|
||||
@property
|
||||
def parameters_schema(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"command": {
|
||||
"type": "string",
|
||||
"description": "The shell command to execute",
|
||||
},
|
||||
"workdir": {
|
||||
"type": "string",
|
||||
"description": "Working directory for the command (optional)",
|
||||
},
|
||||
},
|
||||
"required": ["command"],
|
||||
}
|
||||
|
||||
def run(self, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
command = payload.get("command")
|
||||
workdir = payload.get("workdir")
|
||||
if not command:
|
||||
return {"error": "command is required"}
|
||||
result = subprocess.run( # noqa: S602
|
||||
command,
|
||||
shell=True,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
cwd=workdir,
|
||||
)
|
||||
return {
|
||||
"stdout": result.stdout,
|
||||
"stderr": result.stderr,
|
||||
"returncode": result.returncode,
|
||||
}
|
||||
@@ -0,0 +1,28 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from eidolon.runtime.tools.base import Tool
|
||||
|
||||
|
||||
class ThinkingTool(Tool):
|
||||
name = "thinking"
|
||||
description = "Structured reasoning scratchpad."
|
||||
sandbox_execution = False
|
||||
|
||||
@property
|
||||
def parameters_schema(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"thoughts": {
|
||||
"type": "string",
|
||||
"description": "Reasoning or plan notes",
|
||||
}
|
||||
},
|
||||
"required": ["thoughts"],
|
||||
}
|
||||
|
||||
def run(self, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
thoughts = payload.get("thoughts", "")
|
||||
return {"thoughts": thoughts, "status": "captured"}
|
||||
@@ -0,0 +1,137 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from eidolon.runtime.tools.base import Tool
|
||||
|
||||
|
||||
class TodoTool(Tool):
|
||||
name = "todo"
|
||||
description = "Manage a task list during a session."
|
||||
sandbox_execution = False
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.items: list[dict[str, Any]] = []
|
||||
self._next_id = 1
|
||||
|
||||
@property
|
||||
def parameters_schema(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"action": {
|
||||
"type": "string",
|
||||
"enum": ["set", "complete", "skip", "list"],
|
||||
"description": (
|
||||
"Action: 'set' to initialize list once, 'complete'/'skip' to update "
|
||||
"status, 'list' to view"
|
||||
),
|
||||
},
|
||||
"item": {
|
||||
"type": "string",
|
||||
"description": "Single task item (for add/set)",
|
||||
},
|
||||
"items": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Multiple task items (for add/set)",
|
||||
},
|
||||
"id": {
|
||||
"type": "integer",
|
||||
"description": "Task id to complete or remove",
|
||||
},
|
||||
"result": {
|
||||
"type": "string",
|
||||
"description": "Optional completion result or note",
|
||||
},
|
||||
},
|
||||
"required": ["action"],
|
||||
}
|
||||
|
||||
def has_pending(self) -> bool:
|
||||
return any(item.get("status") == "pending" for item in self.items)
|
||||
|
||||
def _normalize_items(self, payload: dict[str, Any]) -> list[str]:
|
||||
items = payload.get("items")
|
||||
if isinstance(items, list):
|
||||
return [str(item).strip() for item in items if str(item).strip()]
|
||||
item = payload.get("item")
|
||||
if isinstance(item, str) and item.strip():
|
||||
return [item.strip()]
|
||||
return []
|
||||
|
||||
def _add_item(self, text: str) -> dict[str, Any]:
|
||||
item = {"id": self._next_id, "text": text, "status": "pending"}
|
||||
self.items.append(item)
|
||||
self._next_id += 1
|
||||
return item
|
||||
|
||||
def run(self, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
action = str(payload.get("action", "list")).lower()
|
||||
|
||||
if action == "add":
|
||||
for item in self._normalize_items(payload):
|
||||
self._add_item(item)
|
||||
return {"items": list(self.items)}
|
||||
|
||||
if action == "set":
|
||||
self.items = []
|
||||
self._next_id = 1
|
||||
for item in self._normalize_items(payload):
|
||||
self._add_item(item)
|
||||
return {"items": list(self.items)}
|
||||
|
||||
if action == "complete":
|
||||
item_id = payload.get("id")
|
||||
if item_id is None:
|
||||
return {"error": "id is required for complete"}
|
||||
try:
|
||||
item_id = int(item_id)
|
||||
except (TypeError, ValueError):
|
||||
return {"error": "id must be an integer"}
|
||||
target = next((item for item in self.items if item["id"] == item_id), None)
|
||||
if not target:
|
||||
return {"error": f"task id {item_id} not found"}
|
||||
target["status"] = "complete"
|
||||
result = payload.get("result")
|
||||
if isinstance(result, str) and result.strip():
|
||||
target["result"] = result.strip()
|
||||
return {"items": list(self.items), "completed": target}
|
||||
|
||||
if action == "skip":
|
||||
item_id = payload.get("id")
|
||||
if item_id is None:
|
||||
return {"error": "id is required for skip"}
|
||||
try:
|
||||
item_id = int(item_id)
|
||||
except (TypeError, ValueError):
|
||||
return {"error": "id must be an integer"}
|
||||
target = next((item for item in self.items if item["id"] == item_id), None)
|
||||
if not target:
|
||||
return {"error": f"task id {item_id} not found"}
|
||||
target["status"] = "skipped"
|
||||
result = payload.get("result")
|
||||
if isinstance(result, str) and result.strip():
|
||||
target["result"] = result.strip()
|
||||
return {"items": list(self.items), "skipped": target}
|
||||
|
||||
if action == "remove":
|
||||
item_id = payload.get("id")
|
||||
if item_id is None:
|
||||
return {"error": "id is required for remove"}
|
||||
try:
|
||||
item_id = int(item_id)
|
||||
except (TypeError, ValueError):
|
||||
return {"error": "id must be an integer"}
|
||||
self.items = [item for item in self.items if item["id"] != item_id]
|
||||
return {"items": list(self.items)}
|
||||
|
||||
if action == "clear":
|
||||
self.items = []
|
||||
self._next_id = 1
|
||||
return {"items": []}
|
||||
|
||||
if action == "list":
|
||||
return {"items": list(self.items)}
|
||||
|
||||
return {"error": f"unsupported action {action}"}
|
||||
@@ -0,0 +1 @@
|
||||
|
||||
@@ -0,0 +1,134 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict, deque
|
||||
from collections.abc import Iterable, Sequence
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
|
||||
from eidolon.core.graph.repository import GraphRepository
|
||||
from eidolon.core.models.asset import Asset, Identity, NetworkContainer, Policy
|
||||
from eidolon.core.models.graph import Edge, EvidenceRef, GraphPath, Node
|
||||
|
||||
|
||||
class InMemoryGraphRepository(GraphRepository):
|
||||
def __init__(self) -> None:
|
||||
self.nodes: dict[UUID, Node] = {}
|
||||
self.edges: list[Edge] = []
|
||||
self.adjacency: dict[UUID, list[Edge]] = defaultdict(list)
|
||||
|
||||
def upsert_node(self, node: Node) -> None:
|
||||
self.nodes[node.node_id] = node
|
||||
|
||||
def upsert_edge(self, edge: Edge) -> None:
|
||||
self.edges.append(edge)
|
||||
self.adjacency[edge.source].append(edge)
|
||||
|
||||
def find_paths(self, source: UUID, target: UUID, max_depth: int = 4) -> list[GraphPath]:
|
||||
queue: deque[tuple[UUID, list[UUID], list[str]]] = deque()
|
||||
queue.append((source, [source], []))
|
||||
paths: list[GraphPath] = []
|
||||
while queue:
|
||||
node_id, path, rels = queue.popleft()
|
||||
if node_id == target:
|
||||
paths.append(GraphPath(nodes=path, edges=rels, cost=float(len(rels))))
|
||||
continue
|
||||
if len(path) > max_depth:
|
||||
continue
|
||||
for edge in self.adjacency.get(node_id, []):
|
||||
if edge.target not in path:
|
||||
queue.append((edge.target, [*path, edge.target], [*rels, edge.type]))
|
||||
return paths
|
||||
|
||||
def get_neighbors(
|
||||
self, node_id: UUID, relationship_types: Sequence[str] | None = None
|
||||
) -> list[UUID]:
|
||||
neighbors = []
|
||||
for edge in self.adjacency.get(node_id, []):
|
||||
if relationship_types and edge.type not in relationship_types:
|
||||
continue
|
||||
neighbors.append(edge.target)
|
||||
return neighbors
|
||||
|
||||
def upsert_asset(self, asset: Asset) -> None:
|
||||
self.upsert_node(asset)
|
||||
|
||||
def upsert_network(self, network: NetworkContainer) -> None:
|
||||
self.upsert_node(network)
|
||||
|
||||
def upsert_identity(self, identity: Identity) -> None:
|
||||
self.upsert_node(identity)
|
||||
|
||||
def upsert_policy(self, policy: Policy) -> None:
|
||||
self.upsert_node(policy)
|
||||
|
||||
def get_node(self, node_id: UUID) -> Node | None:
|
||||
return self.nodes.get(node_id)
|
||||
|
||||
def list_nodes(self, label: str | None = None, limit: int = 100) -> list[Node]:
|
||||
nodes = list(self.nodes.values())
|
||||
if label:
|
||||
nodes = [n for n in nodes if n.label == label]
|
||||
return nodes[:limit]
|
||||
|
||||
def run_cypher(self, cypher: str, parameters: dict | None = None) -> Iterable[dict]:
|
||||
return []
|
||||
|
||||
def find_asset_by_identifier(self, identifier: str) -> Asset | None:
|
||||
for node in self.nodes.values():
|
||||
if isinstance(node, Asset) and identifier in node.identifiers:
|
||||
return node
|
||||
if (
|
||||
node.label == "Asset"
|
||||
and getattr(node, "identifiers", None)
|
||||
and identifier in node.identifiers
|
||||
):
|
||||
return node # type: ignore[return-value]
|
||||
return None
|
||||
|
||||
def find_network_by_cidr_or_name(self, cidr_or_name: str) -> NetworkContainer | None:
|
||||
for node in self.nodes.values():
|
||||
if isinstance(node, NetworkContainer) and (
|
||||
node.cidr == cidr_or_name or node.name == cidr_or_name
|
||||
):
|
||||
return node
|
||||
return None
|
||||
|
||||
def find_identity_by_name(self, name: str) -> Identity | None:
|
||||
for node in self.nodes.values():
|
||||
if isinstance(node, Identity) and node.name == name:
|
||||
return node
|
||||
return None
|
||||
|
||||
def get_edge_evidence(self, edge_id: UUID) -> list[EvidenceRef]:
|
||||
for edge in self.edges:
|
||||
if edge.edge_id == edge_id:
|
||||
return list(edge.evidence)
|
||||
return []
|
||||
|
||||
def clear(self) -> int:
|
||||
count = len(self.nodes)
|
||||
self.nodes.clear()
|
||||
self.edges.clear()
|
||||
self.adjacency.clear()
|
||||
return count
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def in_memory_repo() -> InMemoryGraphRepository:
|
||||
return InMemoryGraphRepository()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def viewer_headers() -> dict:
|
||||
return {"x-roles": "viewer"}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def planner_headers() -> dict:
|
||||
return {"x-roles": "planner"}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def executor_headers() -> dict:
|
||||
return {"x-roles": "executor"}
|
||||
@@ -0,0 +1,25 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from eidolon.core.graph.algorithms import blast_radius
|
||||
from eidolon.core.models.graph import Edge, Node
|
||||
|
||||
|
||||
def test_blast_radius_traversal(in_memory_repo) -> None:
|
||||
root = Node(label="Asset")
|
||||
child = Node(label="Asset")
|
||||
in_memory_repo.upsert_node(root)
|
||||
in_memory_repo.upsert_node(child)
|
||||
in_memory_repo.upsert_edge(
|
||||
Edge(
|
||||
type="CAN_REACH",
|
||||
source=root.node_id,
|
||||
target=child.node_id,
|
||||
first_seen=datetime.utcnow(),
|
||||
last_seen=datetime.utcnow(),
|
||||
)
|
||||
)
|
||||
|
||||
radius = blast_radius(in_memory_repo, [root.node_id], depth=1)
|
||||
assert set(radius.affected_nodes) == {root.node_id, child.node_id}
|
||||
@@ -0,0 +1,44 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from eidolon.api.app import create_app
|
||||
from eidolon.api.dependencies import get_graph_repository
|
||||
from eidolon.tests.conftest import InMemoryGraphRepository
|
||||
|
||||
|
||||
def test_plan_endpoint_generates_steps(planner_headers) -> None:
|
||||
app = create_app()
|
||||
app.dependency_overrides[get_graph_repository] = InMemoryGraphRepository
|
||||
client = TestClient(app)
|
||||
|
||||
payload = {
|
||||
"intent": "Explain how to isolate subnet X safely.",
|
||||
"target": {
|
||||
"entity_type": "NetworkContainer",
|
||||
"display_name": "subnet-x",
|
||||
"confidence": 0.7,
|
||||
},
|
||||
}
|
||||
|
||||
response = client.post("/plan/", json=payload, headers=planner_headers)
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["steps"]
|
||||
assert body["decision"]["effect"] in {"allow", "needs_approval"}
|
||||
|
||||
|
||||
def test_execute_requires_token_for_non_dry_run(executor_headers) -> None:
|
||||
app = create_app()
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.post(
|
||||
"/plan/execute",
|
||||
json={
|
||||
"dry_run": False,
|
||||
"requires_approval": True,
|
||||
"steps": [],
|
||||
},
|
||||
headers=executor_headers,
|
||||
)
|
||||
assert response.status_code == 403
|
||||
@@ -0,0 +1,61 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from eidolon.api.app import create_app
|
||||
from eidolon.api.dependencies import get_graph_repository
|
||||
from eidolon.core.models.graph import Edge, Node
|
||||
|
||||
|
||||
def test_query_path_endpoint(in_memory_repo, viewer_headers) -> None:
|
||||
app = create_app()
|
||||
app.dependency_overrides[get_graph_repository] = lambda: in_memory_repo
|
||||
client = TestClient(app)
|
||||
|
||||
source = Node(label="Asset")
|
||||
target = Node(label="Asset")
|
||||
in_memory_repo.upsert_node(source)
|
||||
in_memory_repo.upsert_node(target)
|
||||
in_memory_repo.upsert_edge(
|
||||
Edge(
|
||||
type="CAN_REACH",
|
||||
source=source.node_id,
|
||||
target=target.node_id,
|
||||
first_seen=datetime.utcnow(),
|
||||
last_seen=datetime.utcnow(),
|
||||
)
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/query/",
|
||||
json={
|
||||
"question": "find path",
|
||||
"source_id": str(source.node_id),
|
||||
"target_id": str(target.node_id),
|
||||
"max_depth": 3,
|
||||
},
|
||||
headers=viewer_headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["answer"] == "Path search completed."
|
||||
assert data["paths"][0]["nodes"][0] == str(source.node_id)
|
||||
|
||||
|
||||
def test_nl_query_generates_cypher(in_memory_repo, viewer_headers) -> None:
|
||||
app = create_app()
|
||||
app.dependency_overrides[get_graph_repository] = lambda: in_memory_repo
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.post(
|
||||
"/query/",
|
||||
json={"question": "list assets in network 10.0.0.0/24"},
|
||||
headers=viewer_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["graph_query"] is not None
|
||||
assert "network" in data["graph_query"]["parameters"]
|
||||
@@ -0,0 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from eidolon.api.app import create_app
|
||||
from eidolon.api.dependencies import get_entity_resolver, get_graph_repository
|
||||
|
||||
|
||||
def test_collector_scan_route(in_memory_repo, planner_headers) -> None:
|
||||
app = create_app()
|
||||
app.dependency_overrides[get_graph_repository] = lambda: in_memory_repo
|
||||
app.dependency_overrides[get_entity_resolver] = get_entity_resolver
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.post("/collector/scan", headers=planner_headers)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["events_processed"] >= 0
|
||||
assert data["status"] in {"ok", "partial_failure"}
|
||||
@@ -0,0 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from eidolon.core.reasoning.entity import EntityResolver
|
||||
|
||||
|
||||
def test_resolve_asset() -> None:
|
||||
resolver = EntityResolver()
|
||||
payload = {"ip": "10.0.0.5", "hostname": "app-server", "env": "prod", "criticality": "high"}
|
||||
asset = resolver.resolve_asset(
|
||||
payload, source_type="network", source_id="scan-1", confidence=0.8
|
||||
)
|
||||
|
||||
assert "10.0.0.5" in asset.identifiers
|
||||
assert asset.evidence[0].source_type == "network"
|
||||
assert asset.kind == "host"
|
||||
@@ -0,0 +1,26 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from eidolon.core.models.graph import Edge, Node
|
||||
|
||||
|
||||
def test_upsert_and_find_paths(in_memory_repo) -> None:
|
||||
source = Node(label="Asset")
|
||||
target = Node(label="Asset")
|
||||
in_memory_repo.upsert_node(source)
|
||||
in_memory_repo.upsert_node(target)
|
||||
|
||||
edge = Edge(
|
||||
type="CAN_REACH",
|
||||
source=source.node_id,
|
||||
target=target.node_id,
|
||||
first_seen=datetime.utcnow(),
|
||||
last_seen=datetime.utcnow(),
|
||||
)
|
||||
in_memory_repo.upsert_edge(edge)
|
||||
|
||||
paths = in_memory_repo.find_paths(source.node_id, target.node_id)
|
||||
assert len(paths) == 1
|
||||
assert paths[0].nodes == [source.node_id, target.node_id]
|
||||
assert paths[0].edges == ["CAN_REACH"]
|
||||
@@ -0,0 +1,45 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from eidolon.api.app import create_app
|
||||
from eidolon.api.dependencies import get_graph_repository
|
||||
from eidolon.core.models.graph import Edge, Node
|
||||
|
||||
|
||||
def test_graph_assets_endpoints(in_memory_repo, viewer_headers) -> None:
|
||||
app = create_app()
|
||||
app.dependency_overrides[get_graph_repository] = lambda: in_memory_repo
|
||||
client = TestClient(app)
|
||||
|
||||
asset = Node(label="Asset")
|
||||
in_memory_repo.upsert_node(asset)
|
||||
|
||||
resp = client.get(f"/graph/assets/{asset.node_id}", headers=viewer_headers)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["label"] == "Asset"
|
||||
|
||||
resp_list = client.get("/graph/assets", headers=viewer_headers)
|
||||
assert resp_list.status_code == 200
|
||||
assert len(resp_list.json()) >= 1
|
||||
|
||||
|
||||
def test_graph_paths_endpoint(in_memory_repo, viewer_headers) -> None:
|
||||
app = create_app()
|
||||
app.dependency_overrides[get_graph_repository] = lambda: in_memory_repo
|
||||
client = TestClient(app)
|
||||
|
||||
a = Node(label="Asset")
|
||||
b = Node(label="Asset")
|
||||
in_memory_repo.upsert_node(a)
|
||||
in_memory_repo.upsert_node(b)
|
||||
in_memory_repo.upsert_edge(Edge(type="CAN_REACH", source=a.node_id, target=b.node_id))
|
||||
|
||||
resp = client.get(
|
||||
"/graph/paths",
|
||||
params={"source_id": str(a.node_id), "target_id": str(b.node_id), "max_depth": 3},
|
||||
headers=viewer_headers,
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
paths = resp.json()
|
||||
assert paths and paths[0]["nodes"][0] == str(a.node_id)
|
||||
@@ -0,0 +1,35 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from eidolon.api.app import create_app
|
||||
from eidolon.api.dependencies import get_entity_resolver, get_graph_repository
|
||||
from eidolon.core.models.event import CollectorEvent
|
||||
|
||||
|
||||
def test_ingest_events(in_memory_repo, executor_headers) -> None:
|
||||
app = create_app()
|
||||
app.dependency_overrides[get_graph_repository] = lambda: in_memory_repo
|
||||
app.dependency_overrides[get_entity_resolver] = get_entity_resolver
|
||||
client = TestClient(app)
|
||||
|
||||
event = CollectorEvent(
|
||||
source_type="network",
|
||||
source_id="ingest-test",
|
||||
entity_type="Asset",
|
||||
payload={"ip": "10.0.0.10", "cidr": "10.0.0.0/24"},
|
||||
collected_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/ingest/events",
|
||||
json=[event.model_dump(mode="json")],
|
||||
headers=executor_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["accepted"] == 1
|
||||
assert len(in_memory_repo.nodes) == 2
|
||||
assert any(edge.type == "MEMBER_OF" for edge in in_memory_repo.edges)
|
||||
@@ -0,0 +1,3 @@
|
||||
node_modules
|
||||
dist
|
||||
*.log
|
||||
@@ -0,0 +1,18 @@
|
||||
FROM node:20-alpine AS build
|
||||
|
||||
WORKDIR /app
|
||||
COPY package*.json ./
|
||||
RUN npm ci
|
||||
|
||||
COPY . .
|
||||
ARG VITE_API_BASE=http://localhost:8080
|
||||
ENV VITE_API_BASE=$VITE_API_BASE
|
||||
RUN npm run build
|
||||
|
||||
FROM nginx:1.27-alpine AS runtime
|
||||
|
||||
COPY nginx.conf /etc/nginx/conf.d/default.conf
|
||||
COPY --from=build /app/dist /usr/share/nginx/html
|
||||
|
||||
EXPOSE 80
|
||||
CMD ["nginx", "-g", "daemon off;"]
|
||||
@@ -0,0 +1,12 @@
|
||||
<!doctype html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>Eidolon</title>
|
||||
</head>
|
||||
<body>
|
||||
<div id="root"></div>
|
||||
<script type="module" src="/src/main.tsx"></script>
|
||||
</body>
|
||||
</html>
|
||||
@@ -0,0 +1,11 @@
|
||||
server {
|
||||
listen 80;
|
||||
server_name _;
|
||||
|
||||
root /usr/share/nginx/html;
|
||||
index index.html;
|
||||
|
||||
location / {
|
||||
try_files $uri /index.html;
|
||||
}
|
||||
}
|
||||
Generated
+3699
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,27 @@
|
||||
{
|
||||
"name": "eidolon-ui",
|
||||
"version": "0.1.0",
|
||||
"private": true,
|
||||
"scripts": {
|
||||
"dev": "vite",
|
||||
"build": "vite build",
|
||||
"preview": "vite preview",
|
||||
"lint": "echo \"add lint tooling (eslint/tailwind) later\""
|
||||
},
|
||||
"dependencies": {
|
||||
"lucide-react": "^0.562.0",
|
||||
"react": "^18.3.1",
|
||||
"react-dom": "^18.3.1",
|
||||
"react-force-graph-2d": "^1.29.0",
|
||||
"react-markdown": "^9.1.0",
|
||||
"remark-gfm": "^4.0.1"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/d3-force": "^3.0.10",
|
||||
"@types/react": "^18.3.12",
|
||||
"@types/react-dom": "^18.3.1",
|
||||
"@vitejs/plugin-react": "^4.3.4",
|
||||
"typescript": "^5.7.2",
|
||||
"vite": "^6.0.1"
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,423 @@
|
||||
import React, { useCallback, useEffect, useMemo, useState } from "react";
|
||||
import { Sidebar } from "./components/Sidebar";
|
||||
import { StatsRow } from "./components/StatsRow";
|
||||
import { NetworkMap } from "./components/NetworkMap";
|
||||
import { DeviceList } from "./components/DeviceList";
|
||||
import { NetworkList } from "./components/NetworkList";
|
||||
import { EventLog } from "./components/EventLog";
|
||||
import { ScannerControl } from "./components/ScannerControl";
|
||||
import { ChatInterface } from "./components/ChatInterface";
|
||||
import { RightPanel } from "./components/RightPanel";
|
||||
import { ToastContainer, useToast } from "./components/Toast";
|
||||
import { SettingsPanel } from "./components/SettingsPanel";
|
||||
import { GraphView } from "./components/GraphView";
|
||||
import {
|
||||
clearAuditEvents,
|
||||
createChatSession,
|
||||
deleteAllChatSessions,
|
||||
deleteChatSession,
|
||||
getAppSettings,
|
||||
getChatSession,
|
||||
getNodeId,
|
||||
listAssets,
|
||||
listAuditEvents,
|
||||
listChatSessions,
|
||||
listNetworks,
|
||||
resetGraph,
|
||||
updateAppSettings,
|
||||
type AppSettings,
|
||||
type AppSettingsUpdate,
|
||||
type AuditEvent,
|
||||
type ChatSession,
|
||||
type ChatSessionSummary,
|
||||
type GraphNode,
|
||||
} from "./api";
|
||||
|
||||
const normalizeTitle = (value: string) => value.replace(/\s+/g, " ").trim();
|
||||
|
||||
const buildSessionTitle = (value: string, maxLength = 56) => {
|
||||
const normalized = normalizeTitle(value);
|
||||
if (!normalized) {
|
||||
return "New chat";
|
||||
}
|
||||
if (normalized.length <= maxLength) {
|
||||
return normalized;
|
||||
}
|
||||
return `${normalized.slice(0, maxLength).trim()}...`;
|
||||
};
|
||||
|
||||
const toSessionSummary = (session: ChatSession): ChatSessionSummary => ({
|
||||
session_id: session.session_id,
|
||||
title: session.title ?? null,
|
||||
created_at: session.created_at,
|
||||
updated_at: session.updated_at,
|
||||
message_count: session.messages.length,
|
||||
});
|
||||
|
||||
const buildExportFilename = (prefix: string) => {
|
||||
const stamp = new Date().toISOString().replace(/[:.]/g, "-");
|
||||
return `${prefix}-${stamp}.json`;
|
||||
};
|
||||
|
||||
const downloadJson = (filename: string, payload: unknown) => {
|
||||
const blob = new Blob([JSON.stringify(payload, null, 2)], { type: "application/json" });
|
||||
const url = URL.createObjectURL(blob);
|
||||
const anchor = document.createElement("a");
|
||||
anchor.href = url;
|
||||
anchor.download = filename;
|
||||
anchor.click();
|
||||
URL.revokeObjectURL(url);
|
||||
};
|
||||
|
||||
export default function App() {
|
||||
const [isSidebarCollapsed, setIsSidebarCollapsed] = useState(false);
|
||||
const [isRightPanelCollapsed, setIsRightPanelCollapsed] = useState(false);
|
||||
const [activeTab, setActiveTab] = useState("chat");
|
||||
const [appSettings, setAppSettings] = useState<AppSettings | null>(null);
|
||||
const [isAppSettingsLoading, setIsAppSettingsLoading] = useState(true);
|
||||
const [sessions, setSessions] = useState<ChatSessionSummary[]>([]);
|
||||
const [activeSessionId, setActiveSessionId] = useState<string | null>(null);
|
||||
const [assets, setAssets] = useState<GraphNode[]>([]);
|
||||
const [networks, setNetworks] = useState<GraphNode[]>([]);
|
||||
const [auditEvents, setAuditEvents] = useState<AuditEvent[]>([]);
|
||||
const [auditTotal, setAuditTotal] = useState(0);
|
||||
const [selectedNetworkId, setSelectedNetworkId] = useState<string | null>(null);
|
||||
const [isLoading, setIsLoading] = useState(true);
|
||||
const [isWipingChats, setIsWipingChats] = useState(false);
|
||||
const [isClearingAudit, setIsClearingAudit] = useState(false);
|
||||
const [isResettingGraph, setIsResettingGraph] = useState(false);
|
||||
const [isExportingChats, setIsExportingChats] = useState(false);
|
||||
const [isExportingGraph, setIsExportingGraph] = useState(false);
|
||||
const [dataRefreshCount, setDataRefreshCount] = useState(0);
|
||||
const { toasts, showToast, dismissToast} = useToast();
|
||||
|
||||
const loadData = useCallback(async () => {
|
||||
try {
|
||||
const [assetsData, networksData, auditData] = await Promise.all([
|
||||
listAssets(),
|
||||
listNetworks(),
|
||||
listAuditEvents({ page: 1, page_size: 50 }),
|
||||
]);
|
||||
setAssets(assetsData);
|
||||
setNetworks(networksData);
|
||||
setAuditEvents(auditData.events);
|
||||
setAuditTotal(auditData.total);
|
||||
setSelectedNetworkId((current) => current ?? getNodeId(networksData[0] ?? {}));
|
||||
setDataRefreshCount((prev) => prev + 1);
|
||||
} catch (err) {
|
||||
const message = err instanceof Error ? err.message : "Failed to load data";
|
||||
showToast(`API error: ${message}`, "error");
|
||||
} finally {
|
||||
setIsLoading(false);
|
||||
}
|
||||
}, [showToast]);
|
||||
|
||||
const loadSessions = useCallback(async () => {
|
||||
try {
|
||||
const data = await listChatSessions();
|
||||
setSessions(data);
|
||||
setActiveSessionId((current) => {
|
||||
if (current && data.some((session) => session.session_id === current)) {
|
||||
return current;
|
||||
}
|
||||
return null;
|
||||
});
|
||||
} catch (err) {
|
||||
const message = err instanceof Error ? err.message : "Failed to load chat sessions";
|
||||
showToast(`Chat error: ${message}`, "error");
|
||||
}
|
||||
}, [showToast]);
|
||||
|
||||
const loadAppSettings = useCallback(async () => {
|
||||
setIsAppSettingsLoading(true);
|
||||
try {
|
||||
const settings = await getAppSettings();
|
||||
setAppSettings(settings);
|
||||
} catch (err) {
|
||||
const message = err instanceof Error ? err.message : "Failed to load settings";
|
||||
showToast(`Settings error: ${message}`, "error");
|
||||
} finally {
|
||||
setIsAppSettingsLoading(false);
|
||||
}
|
||||
}, [showToast]);
|
||||
|
||||
const handleUpdateAppSettings = useCallback(
|
||||
async (payload: AppSettingsUpdate) => {
|
||||
try {
|
||||
const updated = await updateAppSettings(payload);
|
||||
setAppSettings(updated);
|
||||
showToast("Settings saved", "success");
|
||||
} catch (err) {
|
||||
const message = err instanceof Error ? err.message : "Failed to update settings";
|
||||
showToast(`Settings error: ${message}`, "error");
|
||||
}
|
||||
},
|
||||
[showToast]
|
||||
);
|
||||
|
||||
const handleWipeChats = useCallback(async () => {
|
||||
if (!window.confirm("Delete all chat sessions? This cannot be undone.")) {
|
||||
return;
|
||||
}
|
||||
setIsWipingChats(true);
|
||||
try {
|
||||
await deleteAllChatSessions();
|
||||
setSessions([]);
|
||||
setActiveSessionId(null);
|
||||
showToast("All chats deleted", "success");
|
||||
} catch (err) {
|
||||
const message = err instanceof Error ? err.message : "Failed to wipe chats";
|
||||
showToast(`Chat error: ${message}`, "error");
|
||||
} finally {
|
||||
setIsWipingChats(false);
|
||||
}
|
||||
}, [showToast]);
|
||||
|
||||
const handleClearAudit = useCallback(async () => {
|
||||
if (!window.confirm("Clear all audit/event logs? This cannot be undone.")) {
|
||||
return;
|
||||
}
|
||||
setIsClearingAudit(true);
|
||||
try {
|
||||
await clearAuditEvents();
|
||||
setAuditEvents([]);
|
||||
setAuditTotal(0);
|
||||
showToast("Audit logs cleared", "success");
|
||||
} catch (err) {
|
||||
const message = err instanceof Error ? err.message : "Failed to clear audit logs";
|
||||
showToast(`Audit error: ${message}`, "error");
|
||||
} finally {
|
||||
setIsClearingAudit(false);
|
||||
}
|
||||
}, [showToast]);
|
||||
|
||||
const handleResetGraph = useCallback(async () => {
|
||||
if (!window.confirm("Reset all graph data? This will remove networks and assets.")) {
|
||||
return;
|
||||
}
|
||||
setIsResettingGraph(true);
|
||||
try {
|
||||
await resetGraph();
|
||||
setAssets([]);
|
||||
setNetworks([]);
|
||||
setSelectedNetworkId(null);
|
||||
showToast("Graph data cleared", "success");
|
||||
} catch (err) {
|
||||
const message = err instanceof Error ? err.message : "Failed to reset graph data";
|
||||
showToast(`Graph error: ${message}`, "error");
|
||||
} finally {
|
||||
setIsResettingGraph(false);
|
||||
}
|
||||
}, [showToast]);
|
||||
|
||||
const handleExportChats = useCallback(async () => {
|
||||
setIsExportingChats(true);
|
||||
try {
|
||||
const summaries = await listChatSessions();
|
||||
const sessionsData = await Promise.all(
|
||||
summaries.map((session) => getChatSession(session.session_id))
|
||||
);
|
||||
downloadJson(buildExportFilename("eidolon-chats"), sessionsData);
|
||||
showToast("Chat export ready", "success");
|
||||
} catch (err) {
|
||||
const message = err instanceof Error ? err.message : "Failed to export chats";
|
||||
showToast(`Chat error: ${message}`, "error");
|
||||
} finally {
|
||||
setIsExportingChats(false);
|
||||
}
|
||||
}, [showToast]);
|
||||
|
||||
const handleExportGraph = useCallback(async () => {
|
||||
setIsExportingGraph(true);
|
||||
try {
|
||||
const [assetsData, networksData] = await Promise.all([listAssets(500), listNetworks(500)]);
|
||||
downloadJson(buildExportFilename("eidolon-graph"), {
|
||||
assets: assetsData,
|
||||
networks: networksData,
|
||||
});
|
||||
showToast("Graph export ready", "success");
|
||||
} catch (err) {
|
||||
const message = err instanceof Error ? err.message : "Failed to export graph data";
|
||||
showToast(`Graph error: ${message}`, "error");
|
||||
} finally {
|
||||
setIsExportingGraph(false);
|
||||
}
|
||||
}, [showToast]);
|
||||
|
||||
const handleStartNewChat = useCallback(() => {
|
||||
setActiveSessionId(null);
|
||||
setActiveTab("chat");
|
||||
}, []);
|
||||
|
||||
const handleCreateSession = useCallback(async (initialMessage: string) => {
|
||||
const title = buildSessionTitle(initialMessage);
|
||||
try {
|
||||
const session = await createChatSession(title);
|
||||
const summary = toSessionSummary(session);
|
||||
setSessions((prev) => [summary, ...prev.filter((item) => item.session_id !== summary.session_id)]);
|
||||
setActiveSessionId(session.session_id);
|
||||
setActiveTab("chat");
|
||||
return session;
|
||||
} catch (err) {
|
||||
const message = err instanceof Error ? err.message : "Failed to create chat session";
|
||||
showToast(`Chat error: ${message}`, "error");
|
||||
return null;
|
||||
}
|
||||
}, [showToast]);
|
||||
|
||||
const handleSelectSession = useCallback((sessionId: string) => {
|
||||
setActiveSessionId(sessionId);
|
||||
setActiveTab("chat");
|
||||
}, []);
|
||||
|
||||
const handleDeleteSession = useCallback(
|
||||
async (sessionId: string) => {
|
||||
const previousSessions = sessions;
|
||||
const previousActive = activeSessionId;
|
||||
const nextSessions = previousSessions.filter((session) => session.session_id !== sessionId);
|
||||
setSessions(nextSessions);
|
||||
if (previousActive === sessionId) {
|
||||
setActiveSessionId(nextSessions[0]?.session_id ?? null);
|
||||
}
|
||||
try {
|
||||
await deleteChatSession(sessionId);
|
||||
} catch (err) {
|
||||
setSessions(previousSessions);
|
||||
setActiveSessionId(previousActive);
|
||||
const message = err instanceof Error ? err.message : "Failed to delete chat session";
|
||||
showToast(`Chat error: ${message}`, "error");
|
||||
}
|
||||
},
|
||||
[sessions, activeSessionId, showToast]
|
||||
);
|
||||
|
||||
const handleSessionUpdated = useCallback((session: ChatSession) => {
|
||||
const summary = toSessionSummary(session);
|
||||
setSessions((prev) => [summary, ...prev.filter((item) => item.session_id !== summary.session_id)]);
|
||||
}, []);
|
||||
|
||||
// Auto-collapse panels on narrow screens
|
||||
useEffect(() => {
|
||||
const handleResize = () => {
|
||||
if (window.innerWidth <= 1024) {
|
||||
setIsSidebarCollapsed(true);
|
||||
} else {
|
||||
setIsSidebarCollapsed(false);
|
||||
}
|
||||
|
||||
if (window.innerWidth <= 768) {
|
||||
setIsRightPanelCollapsed(true);
|
||||
} else {
|
||||
setIsRightPanelCollapsed(false);
|
||||
}
|
||||
};
|
||||
|
||||
// Set initial state
|
||||
handleResize();
|
||||
|
||||
window.addEventListener('resize', handleResize);
|
||||
return () => window.removeEventListener('resize', handleResize);
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
loadData();
|
||||
}, [loadData]);
|
||||
|
||||
useEffect(() => {
|
||||
loadSessions();
|
||||
}, [loadSessions]);
|
||||
|
||||
useEffect(() => {
|
||||
loadAppSettings();
|
||||
}, [loadAppSettings]);
|
||||
|
||||
useEffect(() => {
|
||||
const theme = appSettings?.theme.mode;
|
||||
if (!theme) {
|
||||
return;
|
||||
}
|
||||
document.documentElement.dataset.theme = theme;
|
||||
}, [appSettings?.theme.mode]);
|
||||
|
||||
const selectedNetwork = useMemo(
|
||||
() => networks.find((network) => getNodeId(network) === selectedNetworkId) ?? null,
|
||||
[networks, selectedNetworkId]
|
||||
);
|
||||
|
||||
return (
|
||||
<div className="layout">
|
||||
<ToastContainer toasts={toasts} onDismiss={dismissToast} />
|
||||
<Sidebar
|
||||
isCollapsed={isSidebarCollapsed}
|
||||
onToggle={() => setIsSidebarCollapsed(!isSidebarCollapsed)}
|
||||
activeTab={activeTab}
|
||||
onTabChange={setActiveTab}
|
||||
sessions={sessions}
|
||||
activeSessionId={activeSessionId}
|
||||
onNewChat={handleStartNewChat}
|
||||
onSelectSession={handleSelectSession}
|
||||
onDeleteSession={handleDeleteSession}
|
||||
/>
|
||||
|
||||
<main className="main">
|
||||
{activeTab === "chat" && (
|
||||
<ChatInterface
|
||||
sessionId={activeSessionId}
|
||||
onSessionUpdated={handleSessionUpdated}
|
||||
onCreateSession={handleCreateSession}
|
||||
/>
|
||||
)}
|
||||
{activeTab === "networks" && (
|
||||
<>
|
||||
<NetworkList
|
||||
networks={networks}
|
||||
assets={assets}
|
||||
selectedNetworkId={selectedNetworkId}
|
||||
onSelectNetwork={setSelectedNetworkId}
|
||||
isLoading={isLoading}
|
||||
/>
|
||||
<StatsRow assets={assets} network={selectedNetwork} isLoading={isLoading} />
|
||||
<NetworkMap assets={assets} network={selectedNetwork} isLoading={isLoading} />
|
||||
<DeviceList assets={assets} network={selectedNetwork} isLoading={isLoading} />
|
||||
</>
|
||||
)}
|
||||
{activeTab === "graph" && <GraphView key="graph-view" showToast={showToast} refreshTrigger={dataRefreshCount} />}
|
||||
{activeTab === "audit" && (
|
||||
<EventLog events={auditEvents} isLoading={isLoading} limit={50} showFilters={true} total={auditTotal} />
|
||||
)}
|
||||
{activeTab === "settings" && (
|
||||
<SettingsPanel
|
||||
settings={appSettings}
|
||||
isLoading={isAppSettingsLoading}
|
||||
onSave={handleUpdateAppSettings}
|
||||
actions={{
|
||||
wipeChats: handleWipeChats,
|
||||
clearAudit: handleClearAudit,
|
||||
resetGraph: handleResetGraph,
|
||||
exportChats: handleExportChats,
|
||||
exportGraph: handleExportGraph,
|
||||
}}
|
||||
busy={{
|
||||
wipeChats: isWipingChats,
|
||||
clearAudit: isClearingAudit,
|
||||
resetGraph: isResettingGraph,
|
||||
exportChats: isExportingChats,
|
||||
exportGraph: isExportingGraph,
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
</main>
|
||||
|
||||
<RightPanel
|
||||
isCollapsed={isRightPanelCollapsed}
|
||||
onToggle={() => setIsRightPanelCollapsed(!isRightPanelCollapsed)}
|
||||
>
|
||||
<ScannerControl
|
||||
onRefreshData={loadData}
|
||||
onOpenAudit={() => setActiveTab("audit")}
|
||||
showToast={showToast}
|
||||
/>
|
||||
</RightPanel>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,642 @@
|
||||
export type GraphNode = {
|
||||
id?: string;
|
||||
node_id?: string;
|
||||
label: string;
|
||||
identifiers?: string[];
|
||||
metadata?: Record<string, unknown>;
|
||||
cidr?: string;
|
||||
name?: string;
|
||||
network_type?: string;
|
||||
lifecycle_state?: string;
|
||||
kind?: string;
|
||||
env?: string;
|
||||
criticality?: string;
|
||||
owner_team?: string;
|
||||
};
|
||||
|
||||
export type GraphPath = {
|
||||
nodes: string[];
|
||||
edges: string[];
|
||||
cost?: number;
|
||||
};
|
||||
|
||||
export type GraphOverviewNode = {
|
||||
node_id: string;
|
||||
label: string;
|
||||
name?: string | null;
|
||||
kind?: string | null;
|
||||
metadata?: Record<string, unknown>;
|
||||
};
|
||||
|
||||
export type GraphOverviewEdge = {
|
||||
source: string;
|
||||
target: string;
|
||||
type: string;
|
||||
confidence?: number | null;
|
||||
};
|
||||
|
||||
export type GraphOverviewResponse = {
|
||||
nodes: GraphOverviewNode[];
|
||||
edges: GraphOverviewEdge[];
|
||||
};
|
||||
|
||||
export type GraphQuery = {
|
||||
cypher: string;
|
||||
parameters: Record<string, unknown>;
|
||||
};
|
||||
|
||||
export type QueryResponse = {
|
||||
answer: string;
|
||||
paths?: GraphPath[];
|
||||
citations?: Record<string, unknown>[];
|
||||
graph_query?: GraphQuery;
|
||||
records?: Record<string, unknown>[];
|
||||
};
|
||||
|
||||
export type AuditEvent = {
|
||||
audit_id?: string;
|
||||
id?: string;
|
||||
event_type: string;
|
||||
details: Record<string, unknown>;
|
||||
timestamp: string;
|
||||
status: string;
|
||||
};
|
||||
|
||||
export type AuditListResponse = {
|
||||
events: AuditEvent[];
|
||||
total: number;
|
||||
page: number;
|
||||
page_size: number;
|
||||
has_more: boolean;
|
||||
};
|
||||
|
||||
export type ChatMessage = {
|
||||
message_id: string;
|
||||
role: "user" | "assistant" | "system" | "tool";
|
||||
content: string;
|
||||
timestamp: string;
|
||||
metadata?: Record<string, unknown>;
|
||||
};
|
||||
|
||||
export type ChatStreamEvent =
|
||||
| { type: "message"; message: ChatMessage }
|
||||
| { type: "done" };
|
||||
|
||||
export type ChatSession = {
|
||||
session_id: string;
|
||||
user_id?: string;
|
||||
title?: string | null;
|
||||
created_at: string;
|
||||
updated_at: string;
|
||||
messages: ChatMessage[];
|
||||
};
|
||||
|
||||
export type ChatSessionSummary = {
|
||||
session_id: string;
|
||||
title?: string | null;
|
||||
created_at: string;
|
||||
updated_at: string;
|
||||
message_count: number;
|
||||
};
|
||||
|
||||
export type BulkDeleteResponse = {
|
||||
status: string;
|
||||
deleted: number;
|
||||
};
|
||||
|
||||
export type AuditClearResponse = {
|
||||
status: string;
|
||||
deleted: number;
|
||||
};
|
||||
|
||||
export type GraphClearResponse = {
|
||||
status: string;
|
||||
nodes_deleted: number;
|
||||
};
|
||||
|
||||
export type SandboxPermissions = {
|
||||
allow_shell: boolean;
|
||||
allow_network: boolean;
|
||||
allow_file_write: boolean;
|
||||
allow_unsafe_tools: boolean;
|
||||
allowed_tools: string[] | null;
|
||||
blocked_tools: string[];
|
||||
};
|
||||
|
||||
export type ThemeSettings = {
|
||||
mode: "dark" | "light";
|
||||
};
|
||||
|
||||
export type LLMSettings = {
|
||||
model: string;
|
||||
api_base?: string | null;
|
||||
api_key?: string | null;
|
||||
temperature: number;
|
||||
max_tokens: number;
|
||||
top_p: number;
|
||||
frequency_penalty: number;
|
||||
presence_penalty: number;
|
||||
max_context_tokens: number;
|
||||
max_retries: number;
|
||||
retry_delay: number;
|
||||
};
|
||||
|
||||
export type AppSettings = {
|
||||
theme: ThemeSettings;
|
||||
llm: LLMSettings;
|
||||
};
|
||||
|
||||
export type LLMSettingsUpdate = {
|
||||
model?: string | null;
|
||||
api_base?: string | null;
|
||||
api_key?: string | null;
|
||||
temperature?: number | null;
|
||||
max_tokens?: number | null;
|
||||
};
|
||||
|
||||
export type AppSettingsUpdate = {
|
||||
theme?: ThemeSettings;
|
||||
llm?: LLMSettingsUpdate;
|
||||
};
|
||||
|
||||
export type PermissionsResponse = {
|
||||
sandbox: SandboxPermissions;
|
||||
};
|
||||
|
||||
export type CollectorRunResponse = {
|
||||
task_id: string;
|
||||
status: string;
|
||||
};
|
||||
|
||||
export type ScannerOptions = {
|
||||
ping_concurrency: number;
|
||||
port_scan_workers: number;
|
||||
dns_resolution: boolean;
|
||||
aggressive: boolean;
|
||||
};
|
||||
|
||||
export type ScannerConfig = {
|
||||
network_cidrs: string[];
|
||||
ports: number[];
|
||||
port_preset: string;
|
||||
options: ScannerOptions;
|
||||
};
|
||||
|
||||
export type ScanHistoryItem = {
|
||||
id: number;
|
||||
started_at?: string | null;
|
||||
completed_at?: string | null;
|
||||
status: string;
|
||||
events_collected: number;
|
||||
error_message?: string | null;
|
||||
config_summary?: string | null;
|
||||
};
|
||||
|
||||
export type ScanHistoryResponse = {
|
||||
scans: ScanHistoryItem[];
|
||||
};
|
||||
|
||||
const API_TOKEN = import.meta.env.VITE_API_TOKEN as string | undefined;
|
||||
|
||||
const DEFAULT_HEADERS: Record<string, string> = {
|
||||
"Content-Type": "application/json",
|
||||
"x-user-id": "ui-user",
|
||||
"x-roles": "viewer,planner,executor",
|
||||
...(API_TOKEN ? { Authorization: `Bearer ${API_TOKEN}` } : {}),
|
||||
};
|
||||
|
||||
const IP_REGEX = /^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$/;
|
||||
const MAC_REGEX = /^([0-9A-Fa-f]{2}[:-]){5}([0-9A-Fa-f]{2})$/;
|
||||
|
||||
export const API_BASE =
|
||||
(import.meta.env.VITE_API_BASE as string | undefined) ??
|
||||
(import.meta.env.DEV ? "http://localhost:8080" : window.location.origin);
|
||||
|
||||
async function fetchJson<T>(path: string, init: RequestInit = {}): Promise<T> {
|
||||
const response = await fetch(`${API_BASE}${path}`, {
|
||||
...init,
|
||||
headers: { ...DEFAULT_HEADERS, ...(init.headers || {}) },
|
||||
});
|
||||
const text = await response.text();
|
||||
if (!response.ok) {
|
||||
let detail = text;
|
||||
try {
|
||||
const payload = JSON.parse(text) as { detail?: string };
|
||||
if (payload.detail) {
|
||||
detail = payload.detail;
|
||||
}
|
||||
} catch {
|
||||
// ignore JSON parse errors
|
||||
}
|
||||
throw new Error(detail || `Request failed (${response.status})`);
|
||||
}
|
||||
if (!text) {
|
||||
return {} as T;
|
||||
}
|
||||
return JSON.parse(text) as T;
|
||||
}
|
||||
|
||||
export function listAssets(limit = 200): Promise<GraphNode[]> {
|
||||
return fetchJson(`/graph/assets?limit=${limit}`);
|
||||
}
|
||||
|
||||
export function listNetworks(limit = 100): Promise<GraphNode[]> {
|
||||
return fetchJson(`/graph/networks?limit=${limit}`);
|
||||
}
|
||||
|
||||
export function getGraphOverview(params?: {
|
||||
node_limit?: number;
|
||||
edge_limit?: number;
|
||||
}): Promise<GraphOverviewResponse> {
|
||||
const queryParams = new URLSearchParams();
|
||||
if (params?.node_limit) queryParams.set("node_limit", params.node_limit.toString());
|
||||
if (params?.edge_limit) queryParams.set("edge_limit", params.edge_limit.toString());
|
||||
const url = `/graph/overview${queryParams.toString() ? `?${queryParams.toString()}` : ""}`;
|
||||
return fetchJson(url);
|
||||
}
|
||||
|
||||
export function listAuditEvents(params?: {
|
||||
page?: number;
|
||||
page_size?: number;
|
||||
event_type?: string;
|
||||
start_date?: string;
|
||||
end_date?: string;
|
||||
}): Promise<AuditListResponse> {
|
||||
const queryParams = new URLSearchParams();
|
||||
if (params?.page) queryParams.set("page", params.page.toString());
|
||||
if (params?.page_size) queryParams.set("page_size", params.page_size.toString());
|
||||
if (params?.event_type) queryParams.set("event_type", params.event_type);
|
||||
if (params?.start_date) queryParams.set("start_date", params.start_date);
|
||||
if (params?.end_date) queryParams.set("end_date", params.end_date);
|
||||
|
||||
const url = `/audit/${queryParams.toString() ? '?' + queryParams.toString() : ''}`;
|
||||
return fetchJson(url);
|
||||
}
|
||||
|
||||
export function clearAuditEvents(): Promise<AuditClearResponse> {
|
||||
return fetchJson("/audit/", { method: "DELETE" });
|
||||
}
|
||||
|
||||
export function getSandboxPermissions(): Promise<SandboxPermissions> {
|
||||
return fetchJson("/permissions/").then((data: PermissionsResponse) => data.sandbox);
|
||||
}
|
||||
|
||||
export function updateSandboxPermissions(permissions: SandboxPermissions): Promise<SandboxPermissions> {
|
||||
return fetchJson("/permissions/", {
|
||||
method: "PUT",
|
||||
body: JSON.stringify(permissions),
|
||||
}).then((data: PermissionsResponse) => data.sandbox);
|
||||
}
|
||||
|
||||
export function getAppSettings(): Promise<AppSettings> {
|
||||
return fetchJson("/settings/");
|
||||
}
|
||||
|
||||
export function updateAppSettings(payload: AppSettingsUpdate): Promise<AppSettings> {
|
||||
return fetchJson("/settings/", {
|
||||
method: "PUT",
|
||||
body: JSON.stringify(payload),
|
||||
});
|
||||
}
|
||||
|
||||
export function runQuery(question: string): Promise<QueryResponse> {
|
||||
return fetchJson("/query", {
|
||||
method: "POST",
|
||||
body: JSON.stringify({ question }),
|
||||
});
|
||||
}
|
||||
|
||||
export function startScan(): Promise<CollectorRunResponse> {
|
||||
return fetchJson("/collector/scan", { method: "POST" });
|
||||
}
|
||||
|
||||
export function cancelScan(taskId: string): Promise<{ status: string }> {
|
||||
return fetchJson("/collector/scan/cancel", {
|
||||
method: "POST",
|
||||
body: JSON.stringify({ task_id: taskId }),
|
||||
});
|
||||
}
|
||||
|
||||
export function getCollectorConfig(): Promise<ScannerConfig> {
|
||||
return fetchJson("/collector/config");
|
||||
}
|
||||
|
||||
export function updateCollectorConfig(payload: ScannerConfig): Promise<ScannerConfig> {
|
||||
return fetchJson("/collector/config", {
|
||||
method: "PUT",
|
||||
body: JSON.stringify(payload),
|
||||
});
|
||||
}
|
||||
|
||||
export function listScanHistory(limit = 5): Promise<ScanHistoryResponse> {
|
||||
return fetchJson(`/collector/scan/history?limit=${limit}`);
|
||||
}
|
||||
|
||||
export function listChatSessions(): Promise<ChatSessionSummary[]> {
|
||||
return fetchJson("/chat/sessions");
|
||||
}
|
||||
|
||||
export function deleteAllChatSessions(): Promise<BulkDeleteResponse> {
|
||||
return fetchJson("/chat/sessions", { method: "DELETE" });
|
||||
}
|
||||
|
||||
export function createChatSession(title?: string): Promise<ChatSession> {
|
||||
return fetchJson("/chat/sessions", {
|
||||
method: "POST",
|
||||
body: JSON.stringify({ title: title ?? null }),
|
||||
});
|
||||
}
|
||||
|
||||
export function getChatSession(sessionId: string): Promise<ChatSession> {
|
||||
return fetchJson(`/chat/sessions/${sessionId}`);
|
||||
}
|
||||
|
||||
export function deleteChatSession(sessionId: string): Promise<{ status: string }> {
|
||||
return fetchJson(`/chat/sessions/${sessionId}`, { method: "DELETE" });
|
||||
}
|
||||
|
||||
export function resetGraph(): Promise<GraphClearResponse> {
|
||||
return fetchJson("/graph/", { method: "DELETE" });
|
||||
}
|
||||
|
||||
export function getGraphPaths(
|
||||
sourceId: string,
|
||||
targetId: string,
|
||||
maxDepth = 4
|
||||
): Promise<GraphPath[]> {
|
||||
const params = new URLSearchParams({
|
||||
source_id: sourceId,
|
||||
target_id: targetId,
|
||||
max_depth: maxDepth.toString(),
|
||||
});
|
||||
return fetchJson(`/graph/paths?${params.toString()}`);
|
||||
}
|
||||
|
||||
export function appendChatMessage(
|
||||
sessionId: string,
|
||||
message: {
|
||||
role: "user" | "assistant" | "system";
|
||||
content: string;
|
||||
metadata?: Record<string, unknown>;
|
||||
request_id?: string;
|
||||
}
|
||||
): Promise<ChatSession> {
|
||||
return fetchJson(`/chat/sessions/${sessionId}/messages`, {
|
||||
method: "POST",
|
||||
body: JSON.stringify(message),
|
||||
});
|
||||
}
|
||||
|
||||
export async function appendChatMessageStream(
|
||||
sessionId: string,
|
||||
message: { role: "user" | "assistant" | "system"; content: string; metadata?: Record<string, unknown> },
|
||||
onEvent: (event: ChatStreamEvent) => void,
|
||||
signal?: AbortSignal,
|
||||
requestId?: string
|
||||
): Promise<void> {
|
||||
const payload = requestId ? { ...message, request_id: requestId } : message;
|
||||
const response = await fetch(`${API_BASE}/chat/sessions/${sessionId}/messages?stream=true`, {
|
||||
method: "POST",
|
||||
headers: { ...DEFAULT_HEADERS },
|
||||
body: JSON.stringify(payload),
|
||||
signal,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const text = await response.text();
|
||||
throw new Error(text || `Request failed (${response.status})`);
|
||||
}
|
||||
if (!response.body) {
|
||||
throw new Error("No response body");
|
||||
}
|
||||
|
||||
const reader = response.body.getReader();
|
||||
const decoder = new TextDecoder();
|
||||
let buffer = "";
|
||||
|
||||
while (true) {
|
||||
const { value, done } = await reader.read();
|
||||
if (done) {
|
||||
break;
|
||||
}
|
||||
buffer += decoder.decode(value, { stream: true });
|
||||
const lines = buffer.split("\n");
|
||||
buffer = lines.pop() ?? "";
|
||||
for (const line of lines) {
|
||||
const trimmed = line.trim();
|
||||
if (!trimmed) {
|
||||
continue;
|
||||
}
|
||||
try {
|
||||
const event = JSON.parse(trimmed) as ChatStreamEvent;
|
||||
onEvent(event);
|
||||
if (event.type === "done") {
|
||||
return;
|
||||
}
|
||||
} catch {
|
||||
// Ignore parse errors for partial lines
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export function cancelChatRequest(
|
||||
sessionId: string,
|
||||
requestId: string
|
||||
): Promise<{ status: string }> {
|
||||
return fetchJson(`/chat/sessions/${sessionId}/cancel`, {
|
||||
method: "POST",
|
||||
body: JSON.stringify({ request_id: requestId }),
|
||||
});
|
||||
}
|
||||
|
||||
export function isIpv4(value: string): boolean {
|
||||
if (!IP_REGEX.test(value)) {
|
||||
return false;
|
||||
}
|
||||
return value.split(".").every((part) => {
|
||||
const num = Number(part);
|
||||
return Number.isInteger(num) && num >= 0 && num <= 255;
|
||||
});
|
||||
}
|
||||
|
||||
export function getAssetIp(asset: GraphNode): string | null {
|
||||
const identifiers = asset.identifiers ?? [];
|
||||
const fromIdentifiers = identifiers.find((id) => isIpv4(id));
|
||||
if (fromIdentifiers) {
|
||||
return fromIdentifiers;
|
||||
}
|
||||
const metadata = (asset.metadata ?? {}) as Record<string, unknown>;
|
||||
const candidate =
|
||||
(metadata.ip as string | undefined) ??
|
||||
(metadata.ip_address as string | undefined) ??
|
||||
(metadata.public_ip as string | undefined);
|
||||
if (candidate && isIpv4(candidate)) {
|
||||
return candidate;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
export function getAssetName(asset: GraphNode): string {
|
||||
// Try to get hostname from metadata first
|
||||
const metadata = (asset.metadata ?? {}) as Record<string, unknown>;
|
||||
const hostname =
|
||||
(metadata.hostname as string | undefined) ??
|
||||
(metadata.name as string | undefined) ??
|
||||
(metadata.host as string | undefined);
|
||||
|
||||
if (hostname && typeof hostname === "string" && hostname.trim()) {
|
||||
return hostname.trim();
|
||||
}
|
||||
|
||||
// Fall back to hostname from identifiers
|
||||
const identifiers = asset.identifiers ?? [];
|
||||
const hostnameId = identifiers.find((id) => !isIpv4(id) && !MAC_REGEX.test(id));
|
||||
if (hostnameId) {
|
||||
return hostnameId;
|
||||
}
|
||||
|
||||
// Fall back to IP address
|
||||
const ip = getAssetIp(asset);
|
||||
if (ip) {
|
||||
return ip;
|
||||
}
|
||||
|
||||
// Last resort: use node_id
|
||||
return getNodeId(asset) ?? "unknown";
|
||||
}
|
||||
|
||||
export function getAssetMac(asset: GraphNode): string | null {
|
||||
const identifiers = asset.identifiers ?? [];
|
||||
const fromIdentifiers = identifiers.find((id) => MAC_REGEX.test(id));
|
||||
if (fromIdentifiers) {
|
||||
return fromIdentifiers;
|
||||
}
|
||||
const metadata = (asset.metadata ?? {}) as Record<string, unknown>;
|
||||
const candidate =
|
||||
(metadata.mac as string | undefined) ?? (metadata.mac_address as string | undefined);
|
||||
if (candidate && MAC_REGEX.test(candidate)) {
|
||||
return candidate;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
export function getAssetStatus(asset: GraphNode): "online" | "idle" | "offline" {
|
||||
const metadata = (asset.metadata ?? {}) as Record<string, unknown>;
|
||||
const rawStatus = String(
|
||||
asset.lifecycle_state ?? metadata.status ?? metadata.state ?? ""
|
||||
).toLowerCase();
|
||||
if (
|
||||
rawStatus.includes("online") ||
|
||||
rawStatus.includes("up") ||
|
||||
rawStatus.includes("active") ||
|
||||
rawStatus.includes("reachable")
|
||||
) {
|
||||
return "online";
|
||||
}
|
||||
if (rawStatus.includes("idle") || rawStatus.includes("warning") || rawStatus.includes("degraded")) {
|
||||
return "idle";
|
||||
}
|
||||
return "offline";
|
||||
}
|
||||
|
||||
export function getNetworkName(network: GraphNode | null): string {
|
||||
if (!network) {
|
||||
return "unknown";
|
||||
}
|
||||
|
||||
// Try name property first
|
||||
if (network.name && typeof network.name === "string" && network.name.trim()) {
|
||||
return network.name.trim();
|
||||
}
|
||||
|
||||
// Fall back to CIDR
|
||||
if (network.cidr && typeof network.cidr === "string") {
|
||||
return network.cidr;
|
||||
}
|
||||
|
||||
// Fall back to network_type
|
||||
const metadata = (network.metadata ?? {}) as Record<string, unknown>;
|
||||
const networkType = metadata.network_type as string | undefined;
|
||||
if (networkType) {
|
||||
return `${networkType} network`;
|
||||
}
|
||||
|
||||
return "Unnamed network";
|
||||
}
|
||||
|
||||
export function filterAssetsByNetwork(
|
||||
assets: GraphNode[],
|
||||
network: GraphNode | null
|
||||
): GraphNode[] {
|
||||
const cidr = network?.cidr;
|
||||
if (!cidr) {
|
||||
return assets;
|
||||
}
|
||||
return assets.filter((asset) => {
|
||||
const ip = getAssetIp(asset);
|
||||
return ip ? isIpInCidr(ip, cidr) : false;
|
||||
});
|
||||
}
|
||||
|
||||
function ipToNumber(ip: string): number | null {
|
||||
if (!isIpv4(ip)) {
|
||||
return null;
|
||||
}
|
||||
const [a, b, c, d] = ip.split(".").map((part) => Number(part));
|
||||
return (((a << 24) >>> 0) + (b << 16) + (c << 8) + d) >>> 0;
|
||||
}
|
||||
|
||||
function numberToIp(value: number): string {
|
||||
return [
|
||||
(value >>> 24) & 255,
|
||||
(value >>> 16) & 255,
|
||||
(value >>> 8) & 255,
|
||||
value & 255,
|
||||
].join(".");
|
||||
}
|
||||
|
||||
function parseCidr(cidr: string): { base: number; maskBits: number } | null {
|
||||
const parts = cidr.split("/");
|
||||
if (parts.length !== 2) {
|
||||
return null;
|
||||
}
|
||||
const base = ipToNumber(parts[0]);
|
||||
const maskBits = Number(parts[1]);
|
||||
if (base === null || !Number.isInteger(maskBits) || maskBits < 0 || maskBits > 32) {
|
||||
return null;
|
||||
}
|
||||
const mask = maskBits === 0 ? 0 : (~0 << (32 - maskBits)) >>> 0;
|
||||
return { base: base & mask, maskBits };
|
||||
}
|
||||
|
||||
export function isIpInCidr(ip: string, cidr: string): boolean {
|
||||
const parsed = parseCidr(cidr);
|
||||
const ipValue = ipToNumber(ip);
|
||||
if (!parsed || ipValue === null) {
|
||||
return false;
|
||||
}
|
||||
const mask = parsed.maskBits === 0 ? 0 : (~0 << (32 - parsed.maskBits)) >>> 0;
|
||||
return (ipValue & mask) === parsed.base;
|
||||
}
|
||||
|
||||
export function expandCidr(cidr: string, maxHosts = 256): string[] {
|
||||
const parsed = parseCidr(cidr);
|
||||
if (!parsed) {
|
||||
return [];
|
||||
}
|
||||
const hostCount = Math.min(2 ** (32 - parsed.maskBits), maxHosts);
|
||||
const addresses: string[] = [];
|
||||
for (let i = 0; i < hostCount; i += 1) {
|
||||
addresses.push(numberToIp((parsed.base + i) >>> 0));
|
||||
}
|
||||
return addresses;
|
||||
}
|
||||
|
||||
export function getNodeId(node: GraphNode | Record<string, unknown>): string | null {
|
||||
if (!node) {
|
||||
return null;
|
||||
}
|
||||
const asGraphNode = node as GraphNode;
|
||||
return (asGraphNode.node_id ?? asGraphNode.id ?? null) as string | null;
|
||||
}
|
||||
@@ -0,0 +1,620 @@
|
||||
import React, { useEffect, useMemo, useRef, useState } from "react";
|
||||
import { Send, Square, Bot, User, Wrench, Kanban } from "lucide-react";
|
||||
import ReactMarkdown from "react-markdown";
|
||||
import remarkGfm from "remark-gfm";
|
||||
import {
|
||||
appendChatMessage,
|
||||
appendChatMessageStream,
|
||||
cancelChatRequest,
|
||||
getChatSession,
|
||||
type ChatMessage,
|
||||
type ChatSession,
|
||||
} from "../api";
|
||||
import { useToast } from "./Toast";
|
||||
|
||||
type PlanItem = {
|
||||
id: number;
|
||||
text: string;
|
||||
status: string;
|
||||
result?: string;
|
||||
};
|
||||
|
||||
type PlanState = {
|
||||
items: PlanItem[];
|
||||
currentIndex: number;
|
||||
completedCount: number;
|
||||
};
|
||||
|
||||
type ChatInterfaceProps = {
|
||||
sessionId: string | null;
|
||||
onSessionUpdated?: (session: ChatSession) => void;
|
||||
onCreateSession?: (firstMessage: string) => Promise<ChatSession | null> | ChatSession | null;
|
||||
};
|
||||
|
||||
export function ChatInterface({ sessionId, onSessionUpdated, onCreateSession }: ChatInterfaceProps) {
|
||||
const [input, setInput] = useState("");
|
||||
const [isSending, setIsSending] = useState(false);
|
||||
const [isLoading, setIsLoading] = useState(false);
|
||||
const [messages, setMessages] = useState<ChatMessage[]>([]);
|
||||
const [planAnchorId, setPlanAnchorId] = useState<string | null>(null);
|
||||
const [messageHistory, setMessageHistory] = useState<string[]>([]);
|
||||
const [historyIndex, setHistoryIndex] = useState(-1);
|
||||
const messagesEndRef = useRef<HTMLDivElement>(null);
|
||||
const abortControllerRef = useRef<AbortController | null>(null);
|
||||
const currentRequestRef = useRef<string | null>(null);
|
||||
const currentSessionRef = useRef<string | null>(null);
|
||||
const suppressLoadRef = useRef<string | null>(null);
|
||||
const { showToast } = useToast();
|
||||
|
||||
const upsertMessage = (msg: ChatMessage) => {
|
||||
setMessages((prev) => {
|
||||
const index = prev.findIndex((item) => item.message_id === msg.message_id);
|
||||
if (index === -1) {
|
||||
return [...prev, msg];
|
||||
}
|
||||
const next = [...prev];
|
||||
next[index] = msg;
|
||||
return next;
|
||||
});
|
||||
};
|
||||
|
||||
const formatPayload = (value: unknown): string => {
|
||||
if (value === null || value === undefined) {
|
||||
return "";
|
||||
}
|
||||
if (typeof value === "string") {
|
||||
const trimmed = value.trim();
|
||||
if ((trimmed.startsWith("{") && trimmed.endsWith("}")) || (trimmed.startsWith("[") && trimmed.endsWith("]"))) {
|
||||
try {
|
||||
const parsed = JSON.parse(trimmed);
|
||||
return JSON.stringify(parsed, null, 2);
|
||||
} catch {
|
||||
return value;
|
||||
}
|
||||
}
|
||||
return value;
|
||||
}
|
||||
try {
|
||||
return JSON.stringify(value, null, 2);
|
||||
} catch {
|
||||
return String(value);
|
||||
}
|
||||
};
|
||||
|
||||
const summarizeText = (value: string, maxLength = 120) => {
|
||||
const trimmed = value.replace(/\s+/g, " ").trim();
|
||||
if (trimmed.length <= maxLength) {
|
||||
return trimmed;
|
||||
}
|
||||
return `${trimmed.slice(0, maxLength)}...`;
|
||||
};
|
||||
|
||||
const getPlanState = (itemsRaw: unknown): PlanState => {
|
||||
if (!Array.isArray(itemsRaw)) {
|
||||
return { items: [], currentIndex: -1, completedCount: 0 };
|
||||
}
|
||||
const items = itemsRaw
|
||||
.map((item, index) => {
|
||||
if (typeof item === "string") {
|
||||
return { id: index + 1, text: item, status: "pending" };
|
||||
}
|
||||
if (item && typeof item === "object") {
|
||||
const record = item as Record<string, unknown>;
|
||||
const text = String(record.text || record.item || "");
|
||||
if (!text) {
|
||||
return null;
|
||||
}
|
||||
return {
|
||||
id: Number(record.id || index + 1),
|
||||
text,
|
||||
status: String(record.status || "pending"),
|
||||
result: record.result ? String(record.result) : undefined,
|
||||
};
|
||||
}
|
||||
return null;
|
||||
})
|
||||
.filter((item): item is PlanItem => Boolean(item));
|
||||
|
||||
const isComplete = (status: string) => status === "complete" || status === "skip";
|
||||
const completedCount = items.filter((item) => isComplete(item.status)).length;
|
||||
const currentIndex = items.findIndex((item) => item.status === "pending");
|
||||
return {
|
||||
items,
|
||||
currentIndex,
|
||||
completedCount,
|
||||
};
|
||||
};
|
||||
|
||||
const planWindowMessages = useMemo(() => {
|
||||
if (!planAnchorId) {
|
||||
return messages;
|
||||
}
|
||||
const anchorIndex = messages.findIndex((msg) => msg.message_id === planAnchorId);
|
||||
if (anchorIndex === -1) {
|
||||
return messages;
|
||||
}
|
||||
return messages.slice(anchorIndex + 1);
|
||||
}, [messages, planAnchorId]);
|
||||
|
||||
const planState = useMemo(() => {
|
||||
for (let i = planWindowMessages.length - 1; i >= 0; i -= 1) {
|
||||
const msg = planWindowMessages[i];
|
||||
if (msg.role !== "tool") {
|
||||
continue;
|
||||
}
|
||||
if (msg.metadata?.tool_name !== "todo") {
|
||||
continue;
|
||||
}
|
||||
const result = msg.metadata?.result as Record<string, unknown> | undefined;
|
||||
const items = result?.items;
|
||||
const state = getPlanState(items);
|
||||
if (state.items.length > 0) {
|
||||
return state;
|
||||
}
|
||||
}
|
||||
return { items: [], currentIndex: -1, completedCount: 0 };
|
||||
}, [planWindowMessages]);
|
||||
|
||||
const shouldHideMessage = (msg: ChatMessage) => {
|
||||
const kind = (msg.metadata?.kind as string | undefined) || "";
|
||||
if (kind === "plan") {
|
||||
return true;
|
||||
}
|
||||
if (kind === "internal" || msg.metadata?.cancelled) {
|
||||
return true;
|
||||
}
|
||||
if (msg.role === "tool" && msg.metadata?.tool_name === "todo") {
|
||||
return true;
|
||||
}
|
||||
if (kind === "tool_result" && msg.metadata?.tool_name === "todo") {
|
||||
return true;
|
||||
}
|
||||
if (kind === "tool_call") {
|
||||
const toolCalls = (msg.metadata?.tool_calls as Array<Record<string, unknown>>) || [];
|
||||
if (
|
||||
toolCalls.length > 0 &&
|
||||
toolCalls.every((call) => String(call.name || "") === "todo")
|
||||
) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
const visibleMessages = useMemo(
|
||||
() => messages.filter((msg) => !shouldHideMessage(msg)),
|
||||
[messages]
|
||||
);
|
||||
|
||||
const inputPlaceholder = "Ask Eidolon";
|
||||
|
||||
const renderMessageBody = (msg: ChatMessage) => {
|
||||
const kind = (msg.metadata?.kind as string | undefined) || "";
|
||||
if (kind === "tool_call") {
|
||||
const toolCalls = (msg.metadata?.tool_calls as Array<Record<string, unknown>>) || [];
|
||||
const grouped = toolCalls.reduce<Record<string, number>>((acc, call) => {
|
||||
const name = String(call.name || "tool");
|
||||
acc[name] = (acc[name] || 0) + 1;
|
||||
return acc;
|
||||
}, {});
|
||||
const summaryLabel = Object.entries(grouped)
|
||||
.map(([name, count]) => (count > 1 ? `${name} x${count}` : name))
|
||||
.join(", ");
|
||||
return (
|
||||
<details className="tool-details">
|
||||
<summary className="tool-summary">
|
||||
<span className="tool-summary-title">Tool call</span>
|
||||
<span className="tool-summary-name">{summaryLabel || "tool"}</span>
|
||||
</summary>
|
||||
<div className="tool-block">
|
||||
{toolCalls.map((call, index) => {
|
||||
const name = String(call.name || "tool");
|
||||
const args = call.arguments || {};
|
||||
return (
|
||||
<div key={`${name}-${index}`} className="tool-entry">
|
||||
<div className="tool-title">{name}</div>
|
||||
<pre>{formatPayload(args)}</pre>
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</details>
|
||||
);
|
||||
}
|
||||
|
||||
if (kind === "tool_result") {
|
||||
const toolName = String(msg.metadata?.tool_name || "tool");
|
||||
const success = msg.metadata?.success !== false;
|
||||
const payload = msg.metadata?.result || msg.metadata?.error || msg.content;
|
||||
const preview = summarizeText(formatPayload(payload), 140);
|
||||
return (
|
||||
<details className="tool-details">
|
||||
<summary className="tool-summary">
|
||||
<span className="tool-summary-title">Tool result</span>
|
||||
<span className={`tool-summary-name ${success ? "ok" : "error"}`}>{toolName}</span>
|
||||
<span className="tool-summary-preview">{preview}</span>
|
||||
</summary>
|
||||
<div className="tool-block">
|
||||
<div className={`tool-title ${success ? "ok" : "error"}`}>{toolName}</div>
|
||||
<pre>{formatPayload(payload)}</pre>
|
||||
</div>
|
||||
</details>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="message-text">
|
||||
<ReactMarkdown remarkPlugins={[remarkGfm]}>
|
||||
{msg.content}
|
||||
</ReactMarkdown>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
const labelForKind = (kind: string) => {
|
||||
switch (kind) {
|
||||
case "thinking":
|
||||
return "Thinking";
|
||||
case "warning":
|
||||
return "Warning";
|
||||
case "error":
|
||||
return "Error";
|
||||
default:
|
||||
return "";
|
||||
}
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
let active = true;
|
||||
if (abortControllerRef.current) {
|
||||
// Only abort if the session ID has changed to something other than what we're currently processing
|
||||
// This prevents aborting the initial request when a new session is created
|
||||
if (currentSessionRef.current !== sessionId) {
|
||||
abortControllerRef.current.abort();
|
||||
abortControllerRef.current = null;
|
||||
currentRequestRef.current = null;
|
||||
currentSessionRef.current = null;
|
||||
setIsSending(false);
|
||||
}
|
||||
}
|
||||
setPlanAnchorId(null);
|
||||
if (!sessionId) {
|
||||
currentSessionRef.current = null;
|
||||
setMessages([]);
|
||||
setIsLoading(false);
|
||||
return () => {
|
||||
active = false;
|
||||
};
|
||||
}
|
||||
// Update ref to match new session
|
||||
currentSessionRef.current = sessionId;
|
||||
setIsLoading(true);
|
||||
const loadSession = async () => {
|
||||
try {
|
||||
const session = await getChatSession(sessionId);
|
||||
if (!active) {
|
||||
return;
|
||||
}
|
||||
if (suppressLoadRef.current === sessionId && session.messages.length === 0) {
|
||||
return;
|
||||
}
|
||||
setMessages(session.messages);
|
||||
} catch (err) {
|
||||
const message = err instanceof Error ? err.message : "Failed to load chat session";
|
||||
if (active) {
|
||||
showToast(`Chat error: ${message}`, "error");
|
||||
}
|
||||
} finally {
|
||||
if (active) {
|
||||
setIsLoading(false);
|
||||
}
|
||||
}
|
||||
};
|
||||
loadSession();
|
||||
return () => {
|
||||
active = false;
|
||||
};
|
||||
}, [sessionId, showToast, onSessionUpdated]);
|
||||
|
||||
const refreshSession = async (activeSessionId: string) => {
|
||||
const session = await getChatSession(activeSessionId);
|
||||
const shouldUpdate =
|
||||
sessionId === activeSessionId || currentSessionRef.current === activeSessionId;
|
||||
if (!shouldUpdate) {
|
||||
return session;
|
||||
}
|
||||
setMessages(session.messages);
|
||||
if (onSessionUpdated) {
|
||||
onSessionUpdated(session);
|
||||
}
|
||||
return session;
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
if (messages.length === 0) {
|
||||
return;
|
||||
}
|
||||
const lastUser = [...messages].reverse().find((msg) => msg.role === "user");
|
||||
if (!lastUser) {
|
||||
return;
|
||||
}
|
||||
const anchorExists = planAnchorId
|
||||
? messages.some((msg) => msg.message_id === planAnchorId)
|
||||
: false;
|
||||
if (!planAnchorId || !anchorExists) {
|
||||
setPlanAnchorId(lastUser.message_id);
|
||||
}
|
||||
}, [messages, planAnchorId]);
|
||||
|
||||
// Auto-scroll to bottom when messages change
|
||||
useEffect(() => {
|
||||
messagesEndRef.current?.scrollIntoView({ behavior: "smooth" });
|
||||
}, [visibleMessages]);
|
||||
|
||||
const handleSend = async (e: React.FormEvent) => {
|
||||
e.preventDefault();
|
||||
const trimmed = input.trim();
|
||||
if (!trimmed || isSending) {
|
||||
return;
|
||||
}
|
||||
setIsSending(true);
|
||||
|
||||
let activeSessionId = sessionId;
|
||||
let createdSession = false;
|
||||
if (!activeSessionId) {
|
||||
if (!onCreateSession) {
|
||||
showToast("Unable to start a new chat", "error");
|
||||
setIsSending(false);
|
||||
return;
|
||||
}
|
||||
let created: ChatSession | null = null;
|
||||
try {
|
||||
created = await onCreateSession(trimmed);
|
||||
} catch (err) {
|
||||
const message = err instanceof Error ? err.message : "Failed to start a new chat";
|
||||
showToast(message, "error");
|
||||
setIsSending(false);
|
||||
return;
|
||||
}
|
||||
activeSessionId = created?.session_id ?? null;
|
||||
if (!activeSessionId) {
|
||||
showToast("Failed to start a new chat", "error");
|
||||
setIsSending(false);
|
||||
return;
|
||||
}
|
||||
createdSession = true;
|
||||
}
|
||||
|
||||
const requestId = `req_${Date.now()}_${Math.random().toString(36).slice(2, 8)}`;
|
||||
currentRequestRef.current = requestId;
|
||||
currentSessionRef.current = activeSessionId;
|
||||
if (createdSession) {
|
||||
suppressLoadRef.current = activeSessionId;
|
||||
}
|
||||
const userMsg = { role: "user" as const, content: trimmed };
|
||||
const optimistic = {
|
||||
message_id: `temp-${Date.now()}`,
|
||||
role: "user" as const,
|
||||
content: trimmed,
|
||||
timestamp: new Date().toISOString(),
|
||||
};
|
||||
setMessages((prev) => [...prev, optimistic]);
|
||||
setPlanAnchorId(optimistic.message_id);
|
||||
|
||||
// Add to message history
|
||||
setMessageHistory((prev) => [...prev, trimmed]);
|
||||
setHistoryIndex(-1);
|
||||
|
||||
setInput("");
|
||||
const abortController = new AbortController();
|
||||
abortControllerRef.current = abortController;
|
||||
|
||||
let receivedAny = false;
|
||||
try {
|
||||
await appendChatMessageStream(
|
||||
activeSessionId,
|
||||
userMsg,
|
||||
(event) => {
|
||||
if (event.type === "message") {
|
||||
receivedAny = true;
|
||||
upsertMessage(event.message);
|
||||
}
|
||||
},
|
||||
abortController.signal,
|
||||
requestId
|
||||
);
|
||||
try {
|
||||
await refreshSession(activeSessionId);
|
||||
} catch {
|
||||
// Ignore refresh errors; stream already rendered messages
|
||||
}
|
||||
} catch (err) {
|
||||
if (err instanceof Error && err.name === "AbortError") {
|
||||
try {
|
||||
await refreshSession(activeSessionId);
|
||||
} catch {
|
||||
// Ignore refresh errors after abort
|
||||
}
|
||||
return;
|
||||
}
|
||||
const message = err instanceof Error ? err.message : "Request failed";
|
||||
if (!receivedAny) {
|
||||
try {
|
||||
const session = await appendChatMessage(activeSessionId, {
|
||||
...userMsg,
|
||||
request_id: requestId,
|
||||
});
|
||||
setMessages(session.messages);
|
||||
if (onSessionUpdated) {
|
||||
onSessionUpdated(session);
|
||||
}
|
||||
} catch {
|
||||
setMessages((prev) => [
|
||||
...prev,
|
||||
{
|
||||
message_id: `temp-${Date.now() + 1}`,
|
||||
role: "assistant",
|
||||
content: `Unable to complete request: ${message}`,
|
||||
timestamp: new Date().toISOString(),
|
||||
},
|
||||
]);
|
||||
showToast(`Query error: ${message}`, "error");
|
||||
}
|
||||
} else {
|
||||
showToast(`Streaming error: ${message}`, "warning");
|
||||
}
|
||||
} finally {
|
||||
currentRequestRef.current = null;
|
||||
currentSessionRef.current = null;
|
||||
suppressLoadRef.current = null;
|
||||
abortControllerRef.current = null;
|
||||
setIsSending(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleStop = async () => {
|
||||
const activeSessionId = currentSessionRef.current;
|
||||
if (!activeSessionId || !abortControllerRef.current || !currentRequestRef.current) {
|
||||
return;
|
||||
}
|
||||
const requestId = currentRequestRef.current;
|
||||
abortControllerRef.current.abort();
|
||||
try {
|
||||
const result = await cancelChatRequest(activeSessionId, requestId);
|
||||
await refreshSession(activeSessionId);
|
||||
if (result.status === "cancelled") {
|
||||
showToast("Request cancelled", "warning");
|
||||
} else {
|
||||
showToast("No active request to cancel", "info");
|
||||
}
|
||||
} catch (err) {
|
||||
const message = err instanceof Error ? err.message : "Failed to cancel request";
|
||||
showToast(message, "error");
|
||||
}
|
||||
};
|
||||
|
||||
const handleKeyDown = (e: React.KeyboardEvent<HTMLInputElement>) => {
|
||||
if (e.key === "ArrowUp") {
|
||||
e.preventDefault();
|
||||
if (messageHistory.length === 0) return;
|
||||
|
||||
const newIndex = historyIndex === -1
|
||||
? messageHistory.length - 1
|
||||
: Math.max(0, historyIndex - 1);
|
||||
|
||||
setHistoryIndex(newIndex);
|
||||
setInput(messageHistory[newIndex]);
|
||||
} else if (e.key === "ArrowDown") {
|
||||
e.preventDefault();
|
||||
if (historyIndex === -1) return;
|
||||
|
||||
const newIndex = historyIndex + 1;
|
||||
if (newIndex >= messageHistory.length) {
|
||||
setHistoryIndex(-1);
|
||||
setInput("");
|
||||
} else {
|
||||
setHistoryIndex(newIndex);
|
||||
setInput(messageHistory[newIndex]);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="chat-view">
|
||||
{planState.items.length > 0 && (
|
||||
<div className="plan-dock">
|
||||
<details className="plan-details">
|
||||
<summary className="plan-summary">
|
||||
<span className="plan-title">Todo</span>
|
||||
<span className="plan-progress">
|
||||
{planState.completedCount}/{planState.items.length}
|
||||
</span>
|
||||
<span className="plan-current">
|
||||
{planState.currentIndex >= 0
|
||||
? planState.items[planState.currentIndex]?.text
|
||||
: "All steps complete"}
|
||||
</span>
|
||||
</summary>
|
||||
<ol className="plan-list">
|
||||
{planState.items.map((item, index) => {
|
||||
const isCurrent = index === planState.currentIndex;
|
||||
const status = item.status;
|
||||
return (
|
||||
<li
|
||||
key={item.id}
|
||||
className={`plan-step ${status} ${isCurrent ? "current" : ""}`}
|
||||
>
|
||||
<span className="plan-step-title">{item.text}</span>
|
||||
{item.result ? (
|
||||
<span className="plan-step-result">{item.result}</span>
|
||||
) : null}
|
||||
</li>
|
||||
);
|
||||
})}
|
||||
</ol>
|
||||
</details>
|
||||
</div>
|
||||
)}
|
||||
<div className="chat-scroll-container">
|
||||
<div className="chat-content">
|
||||
{visibleMessages.length === 0 && !isSending && (
|
||||
<div className="chat-empty">
|
||||
<Kanban className="chat-empty-icon" strokeWidth={1.5} style={{ transform: 'rotate(270deg)' }} />
|
||||
</div>
|
||||
)}
|
||||
|
||||
{visibleMessages.map((msg) => {
|
||||
const kind = (msg.metadata?.kind as string | undefined) || "";
|
||||
const kindClass = kind && kind !== "message" ? kind : "";
|
||||
const label = labelForKind(kind);
|
||||
return (
|
||||
<div
|
||||
key={msg.message_id}
|
||||
className={`message ${msg.role} ${kindClass}`}
|
||||
>
|
||||
<div className="message-content">
|
||||
{label ? <div className="message-label">{label}</div> : null}
|
||||
{renderMessageBody(msg)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
<div ref={messagesEndRef} />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="chat-footer">
|
||||
<form className="chat-input-area" onSubmit={handleSend}>
|
||||
<input
|
||||
type="text"
|
||||
className="chat-input"
|
||||
placeholder={inputPlaceholder}
|
||||
value={input}
|
||||
onChange={(e) => setInput(e.target.value)}
|
||||
onKeyDown={handleKeyDown}
|
||||
disabled={isLoading}
|
||||
/>
|
||||
{isSending ? (
|
||||
<button
|
||||
type="button"
|
||||
className="chat-send-btn chat-stop-btn"
|
||||
onClick={handleStop}
|
||||
title="Stop AI"
|
||||
>
|
||||
<Square size={18} />
|
||||
</button>
|
||||
) : (
|
||||
<button
|
||||
type="submit"
|
||||
className="chat-send-btn"
|
||||
disabled={isLoading}
|
||||
title="Send message"
|
||||
>
|
||||
<Send size={18} />
|
||||
</button>
|
||||
)}
|
||||
</form>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,59 @@
|
||||
import React, { useMemo } from "react";
|
||||
import { Globe, Terminal } from "lucide-react";
|
||||
import {
|
||||
filterAssetsByNetwork,
|
||||
getAssetIp,
|
||||
getAssetMac,
|
||||
getAssetName,
|
||||
getAssetStatus,
|
||||
getNodeId,
|
||||
type GraphNode,
|
||||
} from "../api";
|
||||
|
||||
interface DeviceListProps {
|
||||
assets: GraphNode[];
|
||||
network: GraphNode | null;
|
||||
isLoading?: boolean;
|
||||
}
|
||||
|
||||
export function DeviceList({ assets, network, isLoading = false }: DeviceListProps) {
|
||||
const scopedAssets = useMemo(
|
||||
() => filterAssetsByNetwork(assets, network),
|
||||
[assets, network]
|
||||
);
|
||||
|
||||
return (
|
||||
<>
|
||||
<div className="section-title">Network Devices</div>
|
||||
{isLoading && <div className="text-muted">Loading devices...</div>}
|
||||
{!isLoading && scopedAssets.length === 0 && (
|
||||
<div className="text-muted">No devices discovered yet.</div>
|
||||
)}
|
||||
{!isLoading && scopedAssets.length > 0 && (
|
||||
<div className="device-grid">
|
||||
{scopedAssets.map((asset) => {
|
||||
const name = getAssetName(asset);
|
||||
const ip = getAssetIp(asset);
|
||||
const mac = getAssetMac(asset);
|
||||
const status = getAssetStatus(asset);
|
||||
const metadata = (asset.metadata ?? {}) as Record<string, unknown>;
|
||||
const vendor = metadata.vendor as string | undefined;
|
||||
const hostname = metadata.hostname as string | undefined;
|
||||
return (
|
||||
<div key={getNodeId(asset) ?? ip} className={`device-card ${status}`}>
|
||||
<div className="device-ip">{ip ?? mac ?? "unknown"}</div>
|
||||
{hostname && hostname !== ip && <div className="device-mac">{hostname}</div>}
|
||||
{mac && <div className="device-mac" style={{ fontSize: "10px", opacity: 0.7 }}>{mac}</div>}
|
||||
{vendor && <div className="device-mac" style={{ fontSize: "9px", opacity: 0.6 }}>{vendor}</div>}
|
||||
<div className="device-icons">
|
||||
<Globe size={12} />
|
||||
<Terminal size={12} />
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,213 @@
|
||||
import React, { useState } from "react";
|
||||
import { type AuditEvent, listAuditEvents } from "../api";
|
||||
|
||||
interface EventLogProps {
|
||||
events: AuditEvent[];
|
||||
isLoading?: boolean;
|
||||
limit?: number;
|
||||
showFilters?: boolean;
|
||||
total?: number;
|
||||
}
|
||||
|
||||
function formatRelativeTime(timestamp: string): string {
|
||||
const now = Date.now();
|
||||
const then = new Date(timestamp).getTime();
|
||||
if (Number.isNaN(then)) {
|
||||
return "unknown";
|
||||
}
|
||||
const diffSeconds = Math.max(0, Math.floor((now - then) / 1000));
|
||||
if (diffSeconds < 60) {
|
||||
return `${diffSeconds}s ago`;
|
||||
}
|
||||
const diffMinutes = Math.floor(diffSeconds / 60);
|
||||
if (diffMinutes < 60) {
|
||||
return `${diffMinutes}m ago`;
|
||||
}
|
||||
const diffHours = Math.floor(diffMinutes / 60);
|
||||
if (diffHours < 24) {
|
||||
return `${diffHours}h ago`;
|
||||
}
|
||||
const diffDays = Math.floor(diffHours / 24);
|
||||
return `${diffDays}d ago`;
|
||||
}
|
||||
|
||||
function describeEvent(event: AuditEvent): string {
|
||||
const details = event.details || {};
|
||||
const eventsProcessed =
|
||||
typeof details.events_processed === "number" ? details.events_processed : 0;
|
||||
const totalEvents =
|
||||
typeof details.total_events === "number" ? details.total_events : 0;
|
||||
const accepted = typeof details.accepted === "number" ? details.accepted : 0;
|
||||
const steps = typeof details.steps === "number" ? details.steps : 0;
|
||||
const pathsFound = typeof details.paths_found === "number" ? details.paths_found : 0;
|
||||
const status = typeof details.status === "string" ? details.status : event.status;
|
||||
const question = typeof details.question === "string" ? details.question : "request";
|
||||
const collectors = Array.isArray(details.collectors) ? details.collectors.join(", ") : "";
|
||||
const configSummary = typeof details.config_summary === "string" ? details.config_summary : "";
|
||||
|
||||
switch (event.event_type) {
|
||||
case "collector.scan.started":
|
||||
return `Scan started: ${collectors}`;
|
||||
case "collector.scan.complete":
|
||||
return configSummary ? `${configSummary} (${totalEvents} events)` : `Scan complete (${totalEvents} events)`;
|
||||
case "collector.scan.cancelled":
|
||||
return configSummary ? `Scan cancelled: ${configSummary}` : "Scan cancelled";
|
||||
case "collector.scan.failed":
|
||||
return "Scan failed";
|
||||
case "collector.scan":
|
||||
return `Collector scan (${eventsProcessed} events)`;
|
||||
case "ingest":
|
||||
return `Ingested ${accepted} events`;
|
||||
case "plan":
|
||||
return `Plan generated (${steps} steps)`;
|
||||
case "execute":
|
||||
return `Execution ${status}`;
|
||||
case "query":
|
||||
return `Query: ${question}`;
|
||||
case "query.paths":
|
||||
return `Path query (${pathsFound} paths)`;
|
||||
default:
|
||||
// Skip collector.* events that are per-collector (redundant with scan.complete)
|
||||
if (event.event_type.startsWith("collector.") && !event.event_type.includes("scan")) {
|
||||
return null; // Filter out in component
|
||||
}
|
||||
return event.event_type.replace(/\./g, " ");
|
||||
}
|
||||
}
|
||||
|
||||
export function EventLog({ events: initialEvents, isLoading = false, limit = 15, showFilters = false, total = 0 }: EventLogProps) {
|
||||
const [events, setEvents] = useState(initialEvents);
|
||||
const [filteredTotal, setFilteredTotal] = useState(total);
|
||||
const [page, setPage] = useState(1);
|
||||
const [eventTypeFilter, setEventTypeFilter] = useState<string>("");
|
||||
const [dateRangeFilter, setDateRangeFilter] = useState<string>("all");
|
||||
const [loading, setLoading] = useState(false);
|
||||
const pageSize = showFilters ? 50 : limit;
|
||||
|
||||
// Update local events when props change
|
||||
React.useEffect(() => {
|
||||
setEvents(initialEvents);
|
||||
}, [initialEvents]);
|
||||
|
||||
const applyFilters = async () => {
|
||||
setLoading(true);
|
||||
try {
|
||||
let start_date: string | undefined;
|
||||
let end_date: string | undefined;
|
||||
|
||||
if (dateRangeFilter === "24h") {
|
||||
start_date = new Date(Date.now() - 24 * 60 * 60 * 1000).toISOString();
|
||||
} else if (dateRangeFilter === "7d") {
|
||||
start_date = new Date(Date.now() - 7 * 24 * 60 * 60 * 1000).toISOString();
|
||||
} else if (dateRangeFilter === "30d") {
|
||||
start_date = new Date(Date.now() - 30 * 24 * 60 * 60 * 1000).toISOString();
|
||||
}
|
||||
|
||||
const response = await listAuditEvents({
|
||||
page,
|
||||
page_size: pageSize,
|
||||
event_type: eventTypeFilter || undefined,
|
||||
start_date,
|
||||
end_date,
|
||||
});
|
||||
setEvents(response.events);
|
||||
setFilteredTotal(response.total);
|
||||
} catch (err) {
|
||||
console.error("Failed to fetch filtered events:", err);
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
React.useEffect(() => {
|
||||
if (showFilters) {
|
||||
applyFilters();
|
||||
}
|
||||
}, [page, eventTypeFilter, dateRangeFilter, showFilters]);
|
||||
|
||||
const recentLogs = showFilters ? events : [...events]
|
||||
.sort((a, b) => new Date(b.timestamp).getTime() - new Date(a.timestamp).getTime())
|
||||
.slice(0, limit);
|
||||
|
||||
const eventTypes = Array.from(new Set(initialEvents.map(e => e.event_type)));
|
||||
const totalPages = Math.ceil((showFilters ? filteredTotal : (total || events.length)) / pageSize);
|
||||
|
||||
return (
|
||||
<div className="event-log-section">
|
||||
<div className="section-title">Event Logs</div>
|
||||
|
||||
{showFilters && (
|
||||
<div style={{ marginBottom: "1rem", display: "flex", gap: "1rem", flexWrap: "wrap" }}>
|
||||
<div>
|
||||
<label style={{ marginRight: "0.5rem", fontSize: "0.9rem" }}>Event Type:</label>
|
||||
<select
|
||||
value={eventTypeFilter}
|
||||
onChange={(e) => { setEventTypeFilter(e.target.value); setPage(1); }}
|
||||
style={{ padding: "0.3rem", borderRadius: "4px", border: "1px solid #ddd" }}
|
||||
>
|
||||
<option value="">All Types</option>
|
||||
{eventTypes.map(type => (
|
||||
<option key={type} value={type}>{type}</option>
|
||||
))}
|
||||
</select>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<label style={{ marginRight: "0.5rem", fontSize: "0.9rem" }}>Date Range:</label>
|
||||
<select
|
||||
value={dateRangeFilter}
|
||||
onChange={(e) => { setDateRangeFilter(e.target.value); setPage(1); }}
|
||||
style={{ padding: "0.3rem", borderRadius: "4px", border: "1px solid #ddd" }}
|
||||
>
|
||||
<option value="all">All Time</option>
|
||||
<option value="24h">Last 24 Hours</option>
|
||||
<option value="7d">Last 7 Days</option>
|
||||
<option value="30d">Last 30 Days</option>
|
||||
</select>
|
||||
</div>
|
||||
|
||||
{totalPages > 1 && (
|
||||
<div style={{ marginLeft: "auto", display: "flex", alignItems: "center", gap: "0.25rem", fontSize: "0.8rem" }}>
|
||||
<button
|
||||
onClick={() => setPage(p => Math.max(1, p - 1))}
|
||||
disabled={page === 1}
|
||||
style={{ padding: "0.15rem 0.4rem", cursor: page === 1 ? "not-allowed" : "pointer", fontSize: "0.75rem" }}
|
||||
>
|
||||
←
|
||||
</button>
|
||||
<span style={{ fontSize: "0.75rem", opacity: 0.7 }}>Page {page} of {totalPages}</span>
|
||||
<button
|
||||
onClick={() => setPage(p => Math.min(totalPages, p + 1))}
|
||||
disabled={page === totalPages}
|
||||
style={{ padding: "0.15rem 0.4rem", cursor: page === totalPages ? "not-allowed" : "pointer", fontSize: "0.75rem" }}
|
||||
>
|
||||
→
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{(isLoading || loading) && <div className="text-muted">Loading events...</div>}
|
||||
{!isLoading && !loading && recentLogs.length === 0 && (
|
||||
<div className="text-muted">No events yet.</div>
|
||||
)}
|
||||
{!isLoading && !loading && recentLogs.length > 0 && (
|
||||
<ul className="log-list">
|
||||
{recentLogs.map((event) => {
|
||||
const message = describeEvent(event);
|
||||
if (!message) return null; // Skip filtered events
|
||||
const time = formatRelativeTime(event.timestamp);
|
||||
return (
|
||||
<li key={event.audit_id ?? event.id ?? message} className="log-item" title={message}>
|
||||
<span className="log-icon"></span>
|
||||
<span className="log-msg">{message}</span>
|
||||
<span className="log-time">({time})</span>
|
||||
</li>
|
||||
);
|
||||
})}
|
||||
</ul>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,634 @@
|
||||
import React, { useCallback, useEffect, useMemo, useRef, useState } from "react";
|
||||
import ForceGraph2D, { type ForceGraphMethods, type NodeObject, type LinkObject } from "react-force-graph-2d";
|
||||
import {
|
||||
RefreshCw,
|
||||
Search,
|
||||
ZoomIn,
|
||||
ZoomOut,
|
||||
Maximize2,
|
||||
Target,
|
||||
X,
|
||||
ChevronRight,
|
||||
AlertTriangle,
|
||||
Wifi,
|
||||
Server,
|
||||
Network,
|
||||
Shield,
|
||||
User,
|
||||
Info,
|
||||
} from "lucide-react";
|
||||
import {
|
||||
getGraphOverview,
|
||||
type GraphOverviewEdge,
|
||||
type GraphOverviewNode,
|
||||
} from "../api";
|
||||
|
||||
interface GraphViewProps {
|
||||
showToast: (message: string, type?: "error" | "warning" | "info" | "success") => void;
|
||||
refreshTrigger?: number;
|
||||
}
|
||||
|
||||
// Extended node type for force graph
|
||||
interface GraphNode extends NodeObject {
|
||||
id: string;
|
||||
node_id: string;
|
||||
label: string;
|
||||
name: string | null;
|
||||
kind: string | null;
|
||||
metadata: Record<string, unknown>;
|
||||
// Force graph adds these
|
||||
x?: number;
|
||||
y?: number;
|
||||
vx?: number;
|
||||
vy?: number;
|
||||
}
|
||||
|
||||
interface GraphLink extends LinkObject {
|
||||
source: string | GraphNode;
|
||||
target: string | GraphNode;
|
||||
type: string;
|
||||
confidence: number | null;
|
||||
}
|
||||
|
||||
// Use CSS variables for theme-aware colors
|
||||
const getThemeColor = (varName: string): string => {
|
||||
return getComputedStyle(document.documentElement).getPropertyValue(varName).trim();
|
||||
};
|
||||
|
||||
const LABEL_COLORS: Record<string, string> = {
|
||||
Asset: "#7fbf9c",
|
||||
NetworkContainer: "#6b9bd1",
|
||||
Identity: "#d3a86a",
|
||||
Policy: "#d07a7a",
|
||||
Service: "#a78bfa",
|
||||
Vulnerability: "#e07b8f",
|
||||
};
|
||||
|
||||
const LABEL_ICONS: Record<string, React.ElementType> = {
|
||||
Asset: Server,
|
||||
NetworkContainer: Network,
|
||||
Identity: User,
|
||||
Policy: Shield,
|
||||
Service: Wifi,
|
||||
Vulnerability: AlertTriangle,
|
||||
};
|
||||
|
||||
const getNodeColor = (label: string): string => {
|
||||
return LABEL_COLORS[label] || getThemeColor("--muted") || "#9aa3b2";
|
||||
};
|
||||
|
||||
const formatLabel = (label: string): string =>
|
||||
label.replace(/([a-z])([A-Z])/g, "$1 $2").replace(/_/g, " ");
|
||||
|
||||
const formatMetadataValue = (key: string, value: unknown): string => {
|
||||
if (Array.isArray(value)) {
|
||||
if (key === "ports" && value.length > 0 && typeof value[0] === "object") {
|
||||
// Filter to only show open ports
|
||||
const openPorts = value.filter((p: any) => p.state === "open");
|
||||
if (openPorts.length === 0) {
|
||||
return "No open ports";
|
||||
}
|
||||
return openPorts
|
||||
.map((p: any) => `${p.port}/${p.service || "unknown"}`)
|
||||
.join(", ");
|
||||
}
|
||||
return value.slice(0, 5).join(", ") + (value.length > 5 ? ` +${value.length - 5} more` : "");
|
||||
}
|
||||
if (typeof value === "object" && value !== null) {
|
||||
return JSON.stringify(value).slice(0, 50);
|
||||
}
|
||||
return String(value);
|
||||
};
|
||||
|
||||
export function GraphView({ showToast, refreshTrigger }: GraphViewProps) {
|
||||
const [rawGraph, setRawGraph] = useState<{ nodes: GraphOverviewNode[]; edges: GraphOverviewEdge[] } | null>(null);
|
||||
const [isLoading, setIsLoading] = useState(true);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
|
||||
// Filters
|
||||
const [search, setSearch] = useState("");
|
||||
const [labelFilter, setLabelFilter] = useState<Record<string, boolean>>({});
|
||||
const [edgeFilter, setEdgeFilter] = useState<Record<string, boolean>>({});
|
||||
|
||||
// Selection & interaction
|
||||
const [selectedNode, setSelectedNode] = useState<GraphNode | null>(null);
|
||||
const [hoveredNode, setHoveredNode] = useState<GraphNode | null>(null);
|
||||
const [highlightedNodes, setHighlightedNodes] = useState<Set<string>>(new Set());
|
||||
const [highlightedLinks, setHighlightedLinks] = useState<Set<string>>(new Set());
|
||||
|
||||
// Graph ref
|
||||
const graphRef = useRef<ForceGraphMethods<GraphNode, GraphLink>>();
|
||||
const containerRef = useRef<HTMLDivElement>(null);
|
||||
const [dimensions, setDimensions] = useState({ width: 800, height: 600 });
|
||||
|
||||
// Load graph data
|
||||
const loadGraph = useCallback(async () => {
|
||||
setIsLoading(true);
|
||||
setError(null);
|
||||
try {
|
||||
const data = await getGraphOverview({ node_limit: 500, edge_limit: 1500 });
|
||||
setRawGraph(data);
|
||||
} catch (err) {
|
||||
const message = err instanceof Error ? err.message : "Failed to load graph data";
|
||||
setError(message);
|
||||
showToast(`Graph error: ${message}`, "error");
|
||||
} finally {
|
||||
setIsLoading(false);
|
||||
}
|
||||
}, [showToast]);
|
||||
|
||||
useEffect(() => {
|
||||
loadGraph();
|
||||
}, [loadGraph]);
|
||||
|
||||
// Reload graph when data refreshes (e.g., after scan completes)
|
||||
useEffect(() => {
|
||||
if (refreshTrigger && refreshTrigger > 0) {
|
||||
loadGraph();
|
||||
}
|
||||
}, [refreshTrigger, loadGraph]);
|
||||
|
||||
// Handle container resize
|
||||
useEffect(() => {
|
||||
if (!containerRef.current) return;
|
||||
|
||||
const observer = new ResizeObserver((entries) => {
|
||||
for (const entry of entries) {
|
||||
const { width, height } = entry.contentRect;
|
||||
if (width && height) {
|
||||
setDimensions({ width, height });
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
observer.observe(containerRef.current);
|
||||
return () => observer.disconnect();
|
||||
}, []);
|
||||
|
||||
// Extract unique labels and relationship types
|
||||
const labels = useMemo(() => {
|
||||
const set = new Set<string>();
|
||||
rawGraph?.nodes.forEach((n) => set.add(n.label));
|
||||
return Array.from(set).sort();
|
||||
}, [rawGraph]);
|
||||
|
||||
const relationshipTypes = useMemo(() => {
|
||||
const set = new Set<string>();
|
||||
rawGraph?.edges.forEach((e) => set.add(e.type));
|
||||
return Array.from(set).sort();
|
||||
}, [rawGraph]);
|
||||
|
||||
// Initialize filters when data loads
|
||||
useEffect(() => {
|
||||
if (!rawGraph) return;
|
||||
setLabelFilter((prev) => {
|
||||
const next: Record<string, boolean> = {};
|
||||
labels.forEach((l) => (next[l] = prev[l] ?? true));
|
||||
return next;
|
||||
});
|
||||
setEdgeFilter((prev) => {
|
||||
const next: Record<string, boolean> = {};
|
||||
relationshipTypes.forEach((r) => (next[r] = prev[r] ?? true));
|
||||
return next;
|
||||
});
|
||||
}, [rawGraph, labels, relationshipTypes]);
|
||||
|
||||
// Build filtered graph for force layout
|
||||
const graphData = useMemo(() => {
|
||||
if (!rawGraph) return { nodes: [], links: [] };
|
||||
|
||||
const normalizedSearch = search.trim().toLowerCase();
|
||||
|
||||
const filteredNodes: GraphNode[] = rawGraph.nodes
|
||||
.filter((n) => {
|
||||
if (labelFilter[n.label] === false) return false;
|
||||
if (!normalizedSearch) return true;
|
||||
const haystack = `${n.name ?? ""} ${n.node_id}`.toLowerCase();
|
||||
return haystack.includes(normalizedSearch);
|
||||
})
|
||||
.map((n) => ({
|
||||
id: n.node_id,
|
||||
node_id: n.node_id,
|
||||
label: n.label,
|
||||
name: n.name ?? null,
|
||||
kind: n.kind ?? null,
|
||||
metadata: n.metadata ?? {},
|
||||
}));
|
||||
|
||||
const nodeIds = new Set(filteredNodes.map((n) => n.id));
|
||||
|
||||
const filteredLinks: GraphLink[] = rawGraph.edges
|
||||
.filter(
|
||||
(e) =>
|
||||
edgeFilter[e.type] !== false &&
|
||||
nodeIds.has(e.source) &&
|
||||
nodeIds.has(e.target)
|
||||
)
|
||||
.map((e) => ({
|
||||
source: e.source,
|
||||
target: e.target,
|
||||
type: e.type,
|
||||
confidence: e.confidence ?? null,
|
||||
}));
|
||||
|
||||
return { nodes: filteredNodes, links: filteredLinks };
|
||||
}, [rawGraph, search, labelFilter, edgeFilter]);
|
||||
|
||||
// Build adjacency for highlighting
|
||||
const adjacency = useMemo(() => {
|
||||
const map = new Map<string, Set<string>>();
|
||||
graphData.links.forEach((link) => {
|
||||
const sourceId = typeof link.source === "string" ? link.source : link.source.id;
|
||||
const targetId = typeof link.target === "string" ? link.target : link.target.id;
|
||||
if (!map.has(sourceId)) map.set(sourceId, new Set());
|
||||
if (!map.has(targetId)) map.set(targetId, new Set());
|
||||
map.get(sourceId)!.add(targetId);
|
||||
map.get(targetId)!.add(sourceId);
|
||||
});
|
||||
return map;
|
||||
}, [graphData.links]);
|
||||
|
||||
// Highlight neighbors on hover
|
||||
useEffect(() => {
|
||||
if (!hoveredNode && !selectedNode) {
|
||||
setHighlightedNodes(new Set());
|
||||
setHighlightedLinks(new Set());
|
||||
return;
|
||||
}
|
||||
|
||||
const focusNode = hoveredNode || selectedNode;
|
||||
if (!focusNode) return;
|
||||
|
||||
const neighbors = adjacency.get(focusNode.id) ?? new Set<string>();
|
||||
const nodeSet = new Set([focusNode.id, ...neighbors]);
|
||||
|
||||
const linkSet = new Set<string>();
|
||||
graphData.links.forEach((link) => {
|
||||
const sourceId = typeof link.source === "string" ? link.source : link.source.id;
|
||||
const targetId = typeof link.target === "string" ? link.target : link.target.id;
|
||||
if (sourceId === focusNode.id || targetId === focusNode.id) {
|
||||
linkSet.add(`${sourceId}-${targetId}`);
|
||||
}
|
||||
});
|
||||
|
||||
setHighlightedNodes(nodeSet);
|
||||
setHighlightedLinks(linkSet);
|
||||
}, [hoveredNode, selectedNode, adjacency, graphData.links]);
|
||||
|
||||
// Node rendering
|
||||
const paintNode = useCallback(
|
||||
(node: GraphNode, ctx: CanvasRenderingContext2D, globalScale: number) => {
|
||||
const { x, y, label, name, id } = node;
|
||||
if (x === undefined || y === undefined) return;
|
||||
|
||||
const isHighlighted = highlightedNodes.has(id);
|
||||
const isDimmed = highlightedNodes.size > 0 && !isHighlighted;
|
||||
const isSelected = selectedNode?.id === id;
|
||||
|
||||
const baseRadius = 6;
|
||||
const radius = isSelected ? baseRadius * 1.5 : baseRadius;
|
||||
const color = getNodeColor(label);
|
||||
|
||||
ctx.beginPath();
|
||||
ctx.arc(x, y, radius, 0, 2 * Math.PI);
|
||||
ctx.fillStyle = isDimmed ? `${color}40` : color;
|
||||
ctx.fill();
|
||||
|
||||
// Border for selected node
|
||||
if (isSelected) {
|
||||
ctx.strokeStyle = getThemeColor("--text") || "#e6edf3";
|
||||
ctx.lineWidth = 2;
|
||||
ctx.stroke();
|
||||
}
|
||||
|
||||
// Label for selected node or when zoomed in
|
||||
if ((isSelected || globalScale > 1.5) && name) {
|
||||
ctx.font = `${11 / globalScale}px sans-serif`;
|
||||
const textColor = getThemeColor("--text") || "#e6edf3";
|
||||
ctx.fillStyle = isDimmed ? `${textColor}40` : textColor;
|
||||
ctx.textAlign = "center";
|
||||
ctx.fillText(name, x, y + radius + 10 / globalScale);
|
||||
}
|
||||
},
|
||||
[highlightedNodes, selectedNode]
|
||||
);
|
||||
|
||||
// Link rendering
|
||||
const paintLink = useCallback(
|
||||
(link: GraphLink, ctx: CanvasRenderingContext2D, globalScale: number) => {
|
||||
const source = link.source as GraphNode;
|
||||
const target = link.target as GraphNode;
|
||||
if (!source.x || !source.y || !target.x || !target.y) return;
|
||||
|
||||
const linkId = `${source.id}-${target.id}`;
|
||||
const isHighlighted = highlightedLinks.has(linkId);
|
||||
const isDimmed = highlightedLinks.size > 0 && !isHighlighted;
|
||||
|
||||
ctx.beginPath();
|
||||
ctx.moveTo(source.x, source.y);
|
||||
ctx.lineTo(target.x, target.y);
|
||||
const linkColor = getThemeColor("--muted") || "#9aa3b2";
|
||||
ctx.strokeStyle = isDimmed ? `${linkColor}10` : `${linkColor}50`;
|
||||
ctx.lineWidth = isDimmed ? 0.5 : 1;
|
||||
ctx.stroke();
|
||||
|
||||
// Arrow
|
||||
if (isHighlighted || globalScale > 2) {
|
||||
const angle = Math.atan2(target.y - source.y, target.x - source.x);
|
||||
const arrowLength = 6 / globalScale;
|
||||
const endX = target.x - Math.cos(angle) * 8;
|
||||
const endY = target.y - Math.sin(angle) * 8;
|
||||
|
||||
ctx.beginPath();
|
||||
ctx.moveTo(endX, endY);
|
||||
ctx.lineTo(
|
||||
endX - arrowLength * Math.cos(angle - Math.PI / 6),
|
||||
endY - arrowLength * Math.sin(angle - Math.PI / 6)
|
||||
);
|
||||
ctx.lineTo(
|
||||
endX - arrowLength * Math.cos(angle + Math.PI / 6),
|
||||
endY - arrowLength * Math.sin(angle + Math.PI / 6)
|
||||
);
|
||||
ctx.closePath();
|
||||
ctx.fillStyle = `${linkColor}80`;
|
||||
ctx.fill();
|
||||
}
|
||||
},
|
||||
[highlightedLinks]
|
||||
);
|
||||
|
||||
// Zoom controls
|
||||
const handleZoomIn = () => graphRef.current?.zoom(graphRef.current.zoom() * 1.5, 300);
|
||||
const handleZoomOut = () => graphRef.current?.zoom(graphRef.current.zoom() / 1.5, 300);
|
||||
const handleFitView = () => graphRef.current?.zoomToFit(400, 50);
|
||||
const handleCenterOnNode = (node: GraphNode) => {
|
||||
graphRef.current?.centerAt(node.x, node.y, 500);
|
||||
graphRef.current?.zoom(2.5, 500);
|
||||
};
|
||||
|
||||
// Statistics
|
||||
const stats = useMemo(() => ({
|
||||
nodes: graphData.nodes.length,
|
||||
edges: graphData.links.length,
|
||||
labels: labels.filter((l) => labelFilter[l] !== false).length,
|
||||
relationships: relationshipTypes.filter((r) => edgeFilter[r] !== false).length,
|
||||
}), [graphData, labels, labelFilter, relationshipTypes, edgeFilter]);
|
||||
|
||||
const labelCounts = useMemo(() => {
|
||||
const counts: Record<string, number> = {};
|
||||
graphData.nodes.forEach((n) => {
|
||||
counts[n.label] = (counts[n.label] ?? 0) + 1;
|
||||
});
|
||||
return counts;
|
||||
}, [graphData.nodes]);
|
||||
|
||||
const relationshipCounts = useMemo(() => {
|
||||
const counts: Record<string, number> = {};
|
||||
graphData.links.forEach((l) => {
|
||||
counts[l.type] = (counts[l.type] ?? 0) + 1;
|
||||
});
|
||||
return counts;
|
||||
}, [graphData.links]);
|
||||
|
||||
return (
|
||||
<div className="graph-view">
|
||||
<div className="section-title">Knowledge Graph</div>
|
||||
|
||||
{/* Stats bar */}
|
||||
<div className="graph-stats">
|
||||
<div className="graph-stat">
|
||||
<span className="graph-stat-value">{stats.nodes}</span>
|
||||
<span className="graph-stat-label">Nodes</span>
|
||||
</div>
|
||||
<div className="graph-stat">
|
||||
<span className="graph-stat-value">{stats.edges}</span>
|
||||
<span className="graph-stat-label">Edges</span>
|
||||
</div>
|
||||
<div className="graph-stat">
|
||||
<span className="graph-stat-value">{stats.labels}</span>
|
||||
<span className="graph-stat-label">Types</span>
|
||||
</div>
|
||||
<div className="graph-stat">
|
||||
<span className="graph-stat-value">{stats.relationships}</span>
|
||||
<span className="graph-stat-label">Relations</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Main content */}
|
||||
<div className="graph-main">
|
||||
{/* Filters bar */}
|
||||
<div className="graph-filters">
|
||||
<div className="graph-filter-section">
|
||||
<label className="graph-search-box">
|
||||
<Search size={14} />
|
||||
<input
|
||||
type="text"
|
||||
placeholder="Search nodes..."
|
||||
value={search}
|
||||
onChange={(e) => setSearch(e.target.value)}
|
||||
/>
|
||||
{search && (
|
||||
<button onClick={() => setSearch("")} className="graph-search-clear">
|
||||
<X size={12} />
|
||||
</button>
|
||||
)}
|
||||
</label>
|
||||
</div>
|
||||
|
||||
<div className="graph-filter-section">
|
||||
<div className="graph-filter-header">Node Types</div>
|
||||
<div className="graph-filter-list">
|
||||
{labels.map((label) => {
|
||||
const Icon = LABEL_ICONS[label] || Server;
|
||||
const count = labelCounts[label] ?? 0;
|
||||
const isActive = labelFilter[label] !== false;
|
||||
return (
|
||||
<button
|
||||
key={label}
|
||||
className={`graph-filter-item ${isActive ? "active" : ""}`}
|
||||
onClick={() =>
|
||||
setLabelFilter((p) => ({ ...p, [label]: !isActive }))
|
||||
}
|
||||
>
|
||||
<span
|
||||
className="graph-filter-dot"
|
||||
style={{ backgroundColor: getNodeColor(label) }}
|
||||
/>
|
||||
<Icon size={14} />
|
||||
<span className="graph-filter-name">{formatLabel(label)}</span>
|
||||
<span className="graph-filter-count">{count}</span>
|
||||
</button>
|
||||
);
|
||||
})}
|
||||
{labels.length === 0 && (
|
||||
<div className="graph-filter-empty">No node types</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="graph-filter-section">
|
||||
<div className="graph-filter-header">Relationships</div>
|
||||
<div className="graph-filter-list">
|
||||
{relationshipTypes.map((rel) => {
|
||||
const count = relationshipCounts[rel] ?? 0;
|
||||
const isActive = edgeFilter[rel] !== false;
|
||||
return (
|
||||
<button
|
||||
key={rel}
|
||||
className={`graph-filter-item ${isActive ? "active" : ""}`}
|
||||
onClick={() =>
|
||||
setEdgeFilter((p) => ({ ...p, [rel]: !isActive }))
|
||||
}
|
||||
>
|
||||
<ChevronRight size={14} />
|
||||
<span className="graph-filter-name">{formatLabel(rel)}</span>
|
||||
<span className="graph-filter-count">{count}</span>
|
||||
</button>
|
||||
);
|
||||
})}
|
||||
{relationshipTypes.length === 0 && (
|
||||
<div className="graph-filter-empty">No relationships</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Graph and details */}
|
||||
<div className="graph-content">
|
||||
{/* Graph canvas */}
|
||||
<div className="graph-canvas-container" ref={containerRef}>
|
||||
{isLoading && (
|
||||
<div className="graph-overlay">
|
||||
<div className="graph-loading">Loading graph...</div>
|
||||
</div>
|
||||
)}
|
||||
{error && (
|
||||
<div className="graph-overlay">
|
||||
<div className="graph-error">{error}</div>
|
||||
</div>
|
||||
)}
|
||||
{!isLoading && !error && graphData.nodes.length === 0 && (
|
||||
<div className="graph-overlay">
|
||||
<div className="graph-empty-state">
|
||||
<Network size={48} />
|
||||
<p>No graph data yet</p>
|
||||
<span>Run a network scan to populate the graph</span>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
{!isLoading && !error && graphData.nodes.length > 0 && (
|
||||
<ForceGraph2D
|
||||
ref={graphRef}
|
||||
graphData={graphData}
|
||||
width={dimensions.width}
|
||||
height={dimensions.height}
|
||||
nodeCanvasObject={paintNode}
|
||||
linkCanvasObject={paintLink}
|
||||
onNodeClick={(node) => {
|
||||
setSelectedNode(node as GraphNode);
|
||||
}}
|
||||
onNodeHover={(node) => setHoveredNode(node as GraphNode | null)}
|
||||
onBackgroundClick={() => setSelectedNode(null)}
|
||||
nodeLabel={(node) => `${(node as GraphNode).name || (node as GraphNode).id}`}
|
||||
linkDirectionalArrowLength={0}
|
||||
cooldownTicks={100}
|
||||
d3AlphaDecay={0.02}
|
||||
d3VelocityDecay={0.3}
|
||||
enableNodeDrag={true}
|
||||
enableZoomInteraction={true}
|
||||
enablePanInteraction={true}
|
||||
/>
|
||||
)}
|
||||
|
||||
{/* Zoom controls */}
|
||||
<div className="graph-zoom-controls">
|
||||
<button onClick={handleZoomIn} title="Zoom in">
|
||||
<ZoomIn size={18} />
|
||||
</button>
|
||||
<button onClick={handleZoomOut} title="Zoom out">
|
||||
<ZoomOut size={18} />
|
||||
</button>
|
||||
<button onClick={handleFitView} title="Fit to view">
|
||||
<Maximize2 size={18} />
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Right sidebar - Node details */}
|
||||
<div className="graph-details">
|
||||
<div className="graph-details-header">
|
||||
<Info size={16} />
|
||||
<span>Node Details</span>
|
||||
</div>
|
||||
{!selectedNode && (
|
||||
<div className="graph-details-empty">
|
||||
<p>Select a node to view details</p>
|
||||
<span>Click on any node in the graph</span>
|
||||
</div>
|
||||
)}
|
||||
{selectedNode && (
|
||||
<div className="graph-details-content">
|
||||
<div className="graph-details-title">
|
||||
<span
|
||||
className="graph-details-dot"
|
||||
style={{ backgroundColor: getNodeColor(selectedNode.label) }}
|
||||
/>
|
||||
<span>{selectedNode.name || `Unnamed ${selectedNode.label}`}</span>
|
||||
</div>
|
||||
|
||||
<div className="graph-details-row">
|
||||
<span className="graph-details-label">Type</span>
|
||||
<span className="graph-details-value">
|
||||
{formatLabel(selectedNode.label)}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
{selectedNode.kind && (
|
||||
<div className="graph-details-row">
|
||||
<span className="graph-details-label">Kind</span>
|
||||
<span className="graph-details-value">{selectedNode.kind}</span>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="graph-details-row">
|
||||
<span className="graph-details-label">ID</span>
|
||||
<span className="graph-details-value graph-details-id">
|
||||
{selectedNode.id}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<div className="graph-details-row">
|
||||
<span className="graph-details-label">Connections</span>
|
||||
<span className="graph-details-value">
|
||||
{adjacency.get(selectedNode.id)?.size ?? 0} neighbors
|
||||
</span>
|
||||
</div>
|
||||
|
||||
{Object.keys(selectedNode.metadata).length > 0 && (
|
||||
<>
|
||||
<div className="graph-details-divider" />
|
||||
<div className="graph-details-section-title">Metadata</div>
|
||||
{Object.entries(selectedNode.metadata).map(([key, value]) => (
|
||||
<div key={key} className="graph-details-row">
|
||||
<span className="graph-details-label">{key}</span>
|
||||
<span className="graph-details-value" title={String(value)}>
|
||||
{formatMetadataValue(key, value)}
|
||||
</span>
|
||||
</div>
|
||||
))}
|
||||
</>
|
||||
)}
|
||||
|
||||
<div className="graph-details-actions">
|
||||
<button onClick={() => handleCenterOnNode(selectedNode)}>
|
||||
<Target size={14} /> Center
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user