mirror of
https://github.com/tiennm99/goclaw.git
synced 2026-06-10 08:11:23 +00:00
1b99406012
The embedding provider resolution only matched three hardcoded names (openai, openrouter, gemini), silently failing for DB-stored providers such as "openai-embedding". As a result, memory chunks were stored without vectors even when a valid embedding provider was configured.

Changes:
- resolveEmbeddingProvider: fall back to the provider registry for DB-stored provider names when the hardcoded match fails
- gateway startup: read the per-agent memory config from the DB (taking priority over config-file defaults) when resolving the embedding provider
- memory IndexDocument: log embedding errors instead of swallowing them
- memory admin ListChunks: return the full chunk text instead of truncating it to 200 characters, avoiding confusing partial content in the UI

Co-authored-by: Luvu182 <208665161+Luvu182@users.noreply.github.com>
328 lines
8.2 KiB
Go
328 lines
8.2 KiB
Go
package pg
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
"log/slog"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
|
|
"github.com/nextlevelbuilder/goclaw/internal/memory"
|
|
"github.com/nextlevelbuilder/goclaw/internal/store"
|
|
)
|
|
|
|
// PGMemoryStore implements store.MemoryStore backed by Postgres.
//
// Documents are kept in the memory_documents table and their indexed chunks
// (optionally carrying pgvector embeddings) in memory_chunks. Rows whose
// user_id is NULL are agent-level documents; ListDocuments returns them
// alongside a user's own documents.
type PGMemoryStore struct {
	// db is the handle passed to NewPGMemoryStore. Close does not close it.
	db *sql.DB

	// provider generates chunk embeddings. While nil, IndexDocument stores
	// chunks without vectors and BackfillEmbeddings returns an error.
	provider store.EmbeddingProvider

	// cfg holds chunking/search tuning; see PGMemoryConfig.
	cfg PGMemoryConfig
}
|
|
|
|
// PGMemoryConfig holds the tuning knobs of a PGMemoryStore.
//
// MaxChunkLen is handed to memory.ChunkText whenever a document is indexed.
// MaxResults, VectorWeight and TextWeight are not read in this file;
// presumably they drive hybrid search elsewhere in the package — verify there.
type PGMemoryConfig struct {
	MaxChunkLen, MaxResults  int
	VectorWeight, TextWeight float64
}

// DefaultPGMemoryConfig returns the stock configuration: MaxChunkLen 1000,
// MaxResults 6, and a 0.7 / 0.3 split between vector and text weight.
func DefaultPGMemoryConfig() PGMemoryConfig {
	var cfg PGMemoryConfig
	cfg.MaxChunkLen = 1000
	cfg.MaxResults = 6
	// The two weights sum to 1.0.
	cfg.VectorWeight, cfg.TextWeight = 0.7, 0.3
	return cfg
}
|
|
|
|
func NewPGMemoryStore(db *sql.DB, cfg PGMemoryConfig) *PGMemoryStore {
|
|
return &PGMemoryStore{db: db, cfg: cfg}
|
|
}
|
|
|
|
func (s *PGMemoryStore) GetDocument(ctx context.Context, agentID, userID, path string) (string, error) {
|
|
aid := mustParseUUID(agentID)
|
|
var content string
|
|
|
|
var err error
|
|
if userID == "" {
|
|
err = s.db.QueryRowContext(ctx,
|
|
"SELECT content FROM memory_documents WHERE agent_id = $1 AND path = $2 AND user_id IS NULL",
|
|
aid, path).Scan(&content)
|
|
} else {
|
|
err = s.db.QueryRowContext(ctx,
|
|
"SELECT content FROM memory_documents WHERE agent_id = $1 AND path = $2 AND user_id = $3",
|
|
aid, path, userID).Scan(&content)
|
|
}
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return content, nil
|
|
}
|
|
|
|
func (s *PGMemoryStore) PutDocument(ctx context.Context, agentID, userID, path, content string) error {
|
|
aid := mustParseUUID(agentID)
|
|
hash := memory.ContentHash(content)
|
|
id := uuid.Must(uuid.NewV7())
|
|
now := time.Now()
|
|
|
|
var uid *string
|
|
if userID != "" {
|
|
uid = &userID
|
|
}
|
|
|
|
_, err := s.db.ExecContext(ctx,
|
|
`INSERT INTO memory_documents (id, agent_id, user_id, path, content, hash, updated_at)
|
|
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
|
ON CONFLICT (agent_id, COALESCE(user_id, ''), path)
|
|
DO UPDATE SET content = EXCLUDED.content, hash = EXCLUDED.hash, updated_at = EXCLUDED.updated_at`,
|
|
id, aid, uid, path, content, hash, now,
|
|
)
|
|
return err
|
|
}
|
|
|
|
func (s *PGMemoryStore) DeleteDocument(ctx context.Context, agentID, userID, path string) error {
|
|
aid := mustParseUUID(agentID)
|
|
if userID == "" {
|
|
_, err := s.db.ExecContext(ctx,
|
|
"DELETE FROM memory_documents WHERE agent_id = $1 AND path = $2 AND user_id IS NULL",
|
|
aid, path)
|
|
return err
|
|
}
|
|
_, err := s.db.ExecContext(ctx,
|
|
"DELETE FROM memory_documents WHERE agent_id = $1 AND path = $2 AND user_id = $3",
|
|
aid, path, userID)
|
|
return err
|
|
}
|
|
|
|
func (s *PGMemoryStore) ListDocuments(ctx context.Context, agentID, userID string) ([]store.DocumentInfo, error) {
|
|
aid := mustParseUUID(agentID)
|
|
|
|
var rows *sql.Rows
|
|
var err error
|
|
if userID == "" {
|
|
rows, err = s.db.QueryContext(ctx,
|
|
"SELECT path, hash, user_id, updated_at FROM memory_documents WHERE agent_id = $1 AND user_id IS NULL", aid)
|
|
} else {
|
|
rows, err = s.db.QueryContext(ctx,
|
|
"SELECT path, hash, user_id, updated_at FROM memory_documents WHERE agent_id = $1 AND (user_id IS NULL OR user_id = $2)", aid, userID)
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var result []store.DocumentInfo
|
|
for rows.Next() {
|
|
var path, hash string
|
|
var uid *string
|
|
var updatedAt time.Time
|
|
if err := rows.Scan(&path, &hash, &uid, &updatedAt); err != nil {
|
|
continue
|
|
}
|
|
info := store.DocumentInfo{
|
|
Path: path,
|
|
Hash: hash,
|
|
UpdatedAt: updatedAt.UnixMilli(),
|
|
}
|
|
if uid != nil {
|
|
info.UserID = *uid
|
|
}
|
|
result = append(result, info)
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
// IndexDocument chunks a document and stores chunks with embeddings.
|
|
func (s *PGMemoryStore) IndexDocument(ctx context.Context, agentID, userID, path string) error {
|
|
aid := mustParseUUID(agentID)
|
|
|
|
// Get document content
|
|
content, err := s.GetDocument(ctx, agentID, userID, path)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Get document ID
|
|
var docID uuid.UUID
|
|
if userID == "" {
|
|
err = s.db.QueryRowContext(ctx,
|
|
"SELECT id FROM memory_documents WHERE agent_id = $1 AND path = $2 AND user_id IS NULL",
|
|
aid, path).Scan(&docID)
|
|
} else {
|
|
err = s.db.QueryRowContext(ctx,
|
|
"SELECT id FROM memory_documents WHERE agent_id = $1 AND path = $2 AND user_id = $3",
|
|
aid, path, userID).Scan(&docID)
|
|
}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Delete old chunks
|
|
s.db.ExecContext(ctx, "DELETE FROM memory_chunks WHERE document_id = $1", docID)
|
|
|
|
// Chunk text
|
|
chunks := memory.ChunkText(content, s.cfg.MaxChunkLen)
|
|
if len(chunks) == 0 {
|
|
return nil
|
|
}
|
|
|
|
// Generate embeddings
|
|
var embeddings [][]float32
|
|
if s.provider != nil {
|
|
texts := make([]string, len(chunks))
|
|
for i, c := range chunks {
|
|
texts[i] = c.Text
|
|
}
|
|
var embErr error
|
|
embeddings, embErr = s.provider.Embed(ctx, texts)
|
|
if embErr != nil {
|
|
slog.Warn("memory embedding failed, storing chunks without vectors",
|
|
"path", path, "chunks", len(chunks), "error", embErr)
|
|
}
|
|
}
|
|
|
|
// Insert chunks
|
|
for i, tc := range chunks {
|
|
hash := memory.ContentHash(tc.Text)
|
|
chunkID := uuid.Must(uuid.NewV7())
|
|
now := time.Now()
|
|
|
|
var uid *string
|
|
if userID != "" {
|
|
uid = &userID
|
|
}
|
|
|
|
if embeddings != nil && i < len(embeddings) {
|
|
// Insert with embedding via raw SQL (pgvector)
|
|
s.db.ExecContext(ctx,
|
|
`INSERT INTO memory_chunks (id, agent_id, document_id, user_id, path, start_line, end_line, hash, text, embedding, updated_at)
|
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10::vector, $11)`,
|
|
chunkID, aid, docID, uid, path, tc.StartLine, tc.EndLine, hash, tc.Text,
|
|
vectorToString(embeddings[i]), now,
|
|
)
|
|
} else {
|
|
s.db.ExecContext(ctx,
|
|
`INSERT INTO memory_chunks (id, agent_id, document_id, user_id, path, start_line, end_line, hash, text, updated_at)
|
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
|
|
ON CONFLICT DO NOTHING`,
|
|
chunkID, aid, docID, uid, path, tc.StartLine, tc.EndLine, hash, tc.Text, now,
|
|
)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *PGMemoryStore) IndexAll(ctx context.Context, agentID, userID string) error {
|
|
docs, err := s.ListDocuments(ctx, agentID, userID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for _, doc := range docs {
|
|
s.IndexDocument(ctx, agentID, doc.UserID, doc.Path)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *PGMemoryStore) SetEmbeddingProvider(provider store.EmbeddingProvider) {
|
|
s.provider = provider
|
|
}
|
|
|
|
// BackfillEmbeddings finds all chunks without embeddings and generates them.
|
|
// Processes in batches to avoid memory spikes. Safe to call multiple times.
|
|
func (s *PGMemoryStore) BackfillEmbeddings(ctx context.Context) (int, error) {
|
|
if s.provider == nil {
|
|
return 0, fmt.Errorf("no embedding provider configured")
|
|
}
|
|
|
|
const batchSize = 50
|
|
total := 0
|
|
|
|
for {
|
|
rows, err := s.db.QueryContext(ctx,
|
|
"SELECT id, text FROM memory_chunks WHERE embedding IS NULL ORDER BY id ASC LIMIT $1", batchSize)
|
|
if err != nil {
|
|
return total, fmt.Errorf("query chunks without embeddings: %w", err)
|
|
}
|
|
|
|
type chunkRow struct {
|
|
ID uuid.UUID
|
|
Text string
|
|
}
|
|
var chunks []chunkRow
|
|
for rows.Next() {
|
|
var c chunkRow
|
|
if err := rows.Scan(&c.ID, &c.Text); err != nil {
|
|
continue
|
|
}
|
|
chunks = append(chunks, c)
|
|
}
|
|
rows.Close()
|
|
|
|
if len(chunks) == 0 {
|
|
break
|
|
}
|
|
|
|
texts := make([]string, len(chunks))
|
|
for i, c := range chunks {
|
|
texts[i] = c.Text
|
|
}
|
|
|
|
embeddings, err := s.provider.Embed(ctx, texts)
|
|
if err != nil {
|
|
return total, fmt.Errorf("generate embeddings: %w", err)
|
|
}
|
|
|
|
for i, chunk := range chunks {
|
|
if i >= len(embeddings) {
|
|
break
|
|
}
|
|
vecStr := vectorToString(embeddings[i])
|
|
if _, err := s.db.ExecContext(ctx,
|
|
"UPDATE memory_chunks SET embedding = $1::vector WHERE id = $2",
|
|
vecStr, chunk.ID,
|
|
); err != nil {
|
|
return total, fmt.Errorf("update chunk embedding id=%s: %w", chunk.ID, err)
|
|
}
|
|
total++
|
|
}
|
|
|
|
if len(chunks) < batchSize {
|
|
break
|
|
}
|
|
}
|
|
|
|
return total, nil
|
|
}
|
|
|
|
// Close is a no-op that always returns nil. In particular it does not close
// the *sql.DB that was passed to NewPGMemoryStore.
func (s *PGMemoryStore) Close() error {
	return nil
}
|
|
|
|
// --- Helpers ---
|
|
|
|
func mustParseUUID(s string) uuid.UUID {
|
|
id, err := uuid.Parse(s)
|
|
if err != nil {
|
|
return uuid.Nil
|
|
}
|
|
return id
|
|
}
|
|
|
|
// vectorToString renders v as a pgvector text literal such as "[0.1,-2,3.5e-05]".
//
// Each component uses %g formatting, i.e. the shortest representation that
// round-trips the float32 value. A nil or empty v yields "" (not "[]"), which
// is not a well-formed vector literal, so callers must not cast it to ::vector.
func vectorToString(v []float32) string {
	if len(v) == 0 {
		return ""
	}
	// Rough guess of ~10 bytes per component plus the brackets, to avoid most regrowth.
	buf := make([]byte, 0, len(v)*10+2)
	buf = append(buf, '[')
	for i, f := range v {
		if i > 0 {
			buf = append(buf, ',')
		}
		// Append straight into buf rather than into a temporary slice per component.
		buf = fmt.Appendf(buf, "%g", f)
	}
	buf = append(buf, ']')
	return string(buf)
}
|