Files
goclaw/internal/store/pg/memory_docs.go
T
Luan Vu 1b99406012 fix: resolve embedding provider from DB registry + per-agent config (#134)
The embedding provider resolution only matched 3 hardcoded names
(openai, openrouter, gemini), silently failing for DB-stored providers
like "openai-embedding". This caused memory chunks to be stored
without vectors even when a valid embedding provider was configured.

Changes:
- resolveEmbeddingProvider: fallback to provider registry for DB-stored
  provider names when hardcoded match fails
- gateway startup: read per-agent memory config from DB (priority over
  config file defaults) for embedding provider resolution
- memory IndexDocument: log embedding errors instead of swallowing them
- memory admin ListChunks: return full chunk text instead of truncating
  to 200 chars, avoiding confusing partial content in the UI

Co-authored-by: Luvu182 <208665161+Luvu182@users.noreply.github.com>
2026-03-11 14:31:00 +07:00

328 lines
8.2 KiB
Go

package pg
import (
"context"
"database/sql"
"fmt"
"log/slog"
"time"
"github.com/google/uuid"
"github.com/nextlevelbuilder/goclaw/internal/memory"
"github.com/nextlevelbuilder/goclaw/internal/store"
)
// PGMemoryStore implements store.MemoryStore backed by Postgres.
//
// Documents live in memory_documents and their indexed chunks in
// memory_chunks; embeddings are optional and stored via a pgvector cast.
type PGMemoryStore struct {
	// db is the handle every query runs against. Close does not close it.
	db *sql.DB
	// provider embeds chunk texts. When nil, IndexDocument stores chunks
	// without vectors and BackfillEmbeddings returns an error.
	provider store.EmbeddingProvider
	// cfg holds chunking (MaxChunkLen) and, presumably, search tuning
	// used by code outside this view.
	cfg PGMemoryConfig
}
// PGMemoryConfig holds the tuning knobs of a PGMemoryStore.
type PGMemoryConfig struct {
	// MaxChunkLen is the limit IndexDocument hands to memory.ChunkText;
	// MaxResults presumably caps search hits (used outside this view).
	MaxChunkLen, MaxResults int
	// VectorWeight and TextWeight presumably weight vector vs. full-text
	// scores in hybrid search (used outside this view — confirm).
	VectorWeight, TextWeight float64
}

// DefaultPGMemoryConfig returns the stock tuning: MaxChunkLen 1000,
// MaxResults 6, and a 0.7 / 0.3 vector/text weighting.
func DefaultPGMemoryConfig() PGMemoryConfig {
	var cfg PGMemoryConfig
	cfg.MaxChunkLen = 1000
	cfg.MaxResults = 6
	cfg.VectorWeight, cfg.TextWeight = 0.7, 0.3
	return cfg
}
// NewPGMemoryStore returns a store that runs its queries on db and is tuned
// by cfg. No embedding provider is attached; call SetEmbeddingProvider to
// have indexed chunks stored with vectors.
func NewPGMemoryStore(db *sql.DB, cfg PGMemoryConfig) *PGMemoryStore {
	s := new(PGMemoryStore)
	s.db = db
	s.cfg = cfg
	return s
}
// GetDocument returns the content of the agent's document at path.
//
// An empty userID addresses the shared document (user_id IS NULL); any other
// value addresses that user's own document. Errors from the query, including
// sql.ErrNoRows when no such document exists, are returned unchanged.
func (s *PGMemoryStore) GetDocument(ctx context.Context, agentID, userID, path string) (string, error) {
	query := "SELECT content FROM memory_documents WHERE agent_id = $1 AND path = $2 AND user_id IS NULL"
	args := []any{mustParseUUID(agentID), path}
	if userID != "" {
		query = "SELECT content FROM memory_documents WHERE agent_id = $1 AND path = $2 AND user_id = $3"
		args = append(args, userID)
	}

	var content string
	if err := s.db.QueryRowContext(ctx, query, args...).Scan(&content); err != nil {
		return "", err
	}
	return content, nil
}
// PutDocument inserts or replaces the agent's document at path.
//
// An empty userID stores the shared document (user_id NULL). On conflict with
// an existing (agent_id, user_id, path) row, only content, hash and
// updated_at are overwritten; the existing row keeps its id.
func (s *PGMemoryStore) PutDocument(ctx context.Context, agentID, userID, path, content string) error {
	const upsertDocument = `INSERT INTO memory_documents (id, agent_id, user_id, path, content, hash, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7)
ON CONFLICT (agent_id, COALESCE(user_id, ''), path)
DO UPDATE SET content = EXCLUDED.content, hash = EXCLUDED.hash, updated_at = EXCLUDED.updated_at`

	// A nil owner is written as SQL NULL, i.e. the shared document.
	var owner *string
	if userID != "" {
		owner = &userID
	}

	_, err := s.db.ExecContext(ctx, upsertDocument,
		uuid.Must(uuid.NewV7()),
		mustParseUUID(agentID),
		owner,
		path,
		content,
		memory.ContentHash(content),
		time.Now(),
	)
	return err
}
// DeleteDocument removes the agent's document at path: the shared one when
// userID is empty, otherwise that user's own copy. Deleting a missing
// document is not an error.
//
// NOTE(review): memory_chunks rows are not deleted here; presumably an
// ON DELETE CASCADE on document_id removes them — confirm against the schema.
func (s *PGMemoryStore) DeleteDocument(ctx context.Context, agentID, userID, path string) error {
	aid := mustParseUUID(agentID)

	var err error
	switch {
	case userID == "":
		_, err = s.db.ExecContext(ctx,
			"DELETE FROM memory_documents WHERE agent_id = $1 AND path = $2 AND user_id IS NULL",
			aid, path)
	default:
		_, err = s.db.ExecContext(ctx,
			"DELETE FROM memory_documents WHERE agent_id = $1 AND path = $2 AND user_id = $3",
			aid, path, userID)
	}
	return err
}
// ListDocuments returns metadata for the agent's documents.
//
// With an empty userID only shared documents (user_id IS NULL) are listed;
// otherwise both the shared documents and that user's own are returned, with
// UserID set on the user-owned entries. UpdatedAt is in Unix milliseconds.
// No ORDER BY is applied, so row order is whatever the database returns.
//
// A row that fails to scan, or an error while iterating, aborts the listing
// with an error rather than silently returning a truncated result.
func (s *PGMemoryStore) ListDocuments(ctx context.Context, agentID, userID string) ([]store.DocumentInfo, error) {
	aid := mustParseUUID(agentID)

	var (
		rows *sql.Rows
		err  error
	)
	if userID == "" {
		rows, err = s.db.QueryContext(ctx,
			"SELECT path, hash, user_id, updated_at FROM memory_documents WHERE agent_id = $1 AND user_id IS NULL", aid)
	} else {
		rows, err = s.db.QueryContext(ctx,
			"SELECT path, hash, user_id, updated_at FROM memory_documents WHERE agent_id = $1 AND (user_id IS NULL OR user_id = $2)", aid, userID)
	}
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var result []store.DocumentInfo
	for rows.Next() {
		var (
			path, hash string
			uid        *string // NULL for shared documents
			updatedAt  time.Time
		)
		if err := rows.Scan(&path, &hash, &uid, &updatedAt); err != nil {
			return nil, fmt.Errorf("scan memory document row: %w", err)
		}
		info := store.DocumentInfo{
			Path:      path,
			Hash:      hash,
			UpdatedAt: updatedAt.UnixMilli(),
		}
		if uid != nil {
			info.UserID = *uid
		}
		result = append(result, info)
	}
	// rows.Next also returns false when iteration fails (e.g. a dropped
	// connection); without this check that would look like a short listing.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("iterate memory documents: %w", err)
	}
	return result, nil
}
// IndexDocument chunks a document and stores chunks with embeddings.
//
// The document is addressed like GetDocument: an empty userID selects the
// shared document (user_id IS NULL), otherwise that user's own copy. If the
// lookup fails (sql.ErrNoRows when the document does not exist) the error is
// returned unwrapped, as before.
//
// Embeddings are requested before any chunk rows are touched. If no provider
// is set, or Embed fails (the failure is logged), chunks are stored without
// vectors so BackfillEmbeddings can fill them in later. The old chunks are
// then replaced inside one transaction, so a failed insert rolls back to the
// previous index instead of leaving the document partially indexed.
func (s *PGMemoryStore) IndexDocument(ctx context.Context, agentID, userID, path string) error {
	aid := mustParseUUID(agentID)

	// A nil uid is written as SQL NULL, i.e. the shared document.
	var uid *string
	if userID != "" {
		uid = &userID
	}

	// Read id and content in a single statement so both come from the same row.
	var (
		docID   uuid.UUID
		content string
		err     error
	)
	if userID == "" {
		err = s.db.QueryRowContext(ctx,
			"SELECT id, content FROM memory_documents WHERE agent_id = $1 AND path = $2 AND user_id IS NULL",
			aid, path).Scan(&docID, &content)
	} else {
		err = s.db.QueryRowContext(ctx,
			"SELECT id, content FROM memory_documents WHERE agent_id = $1 AND path = $2 AND user_id = $3",
			aid, path, userID).Scan(&docID, &content)
	}
	if err != nil {
		return err
	}

	chunks := memory.ChunkText(content, s.cfg.MaxChunkLen)

	// Embed before opening the transaction so a slow provider call does not
	// keep it open.
	var embeddings [][]float32
	if s.provider != nil && len(chunks) > 0 {
		texts := make([]string, len(chunks))
		for i, c := range chunks {
			texts[i] = c.Text
		}
		var embErr error
		embeddings, embErr = s.provider.Embed(ctx, texts)
		if embErr != nil {
			slog.Warn("memory embedding failed, storing chunks without vectors",
				"path", path, "chunks", len(chunks), "error", embErr)
			embeddings = nil
		}
	}

	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("begin index transaction for %q: %w", path, err)
	}
	// After a successful Commit, Rollback is a no-op (it returns sql.ErrTxDone).
	defer func() { _ = tx.Rollback() }()

	if _, err := tx.ExecContext(ctx, "DELETE FROM memory_chunks WHERE document_id = $1", docID); err != nil {
		return fmt.Errorf("delete old chunks for %q: %w", path, err)
	}

	now := time.Now()
	for i, tc := range chunks {
		hash := memory.ContentHash(tc.Text)
		chunkID := uuid.Must(uuid.NewV7())

		// Both inserts skip conflicts: inside a transaction a unique violation
		// would abort every later statement. An empty vector cannot be cast to
		// ::vector, so such chunks are stored without an embedding.
		var execErr error
		if i < len(embeddings) && len(embeddings[i]) > 0 {
			_, execErr = tx.ExecContext(ctx,
				`INSERT INTO memory_chunks (id, agent_id, document_id, user_id, path, start_line, end_line, hash, text, embedding, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10::vector, $11)
ON CONFLICT DO NOTHING`,
				chunkID, aid, docID, uid, path, tc.StartLine, tc.EndLine, hash, tc.Text,
				vectorToString(embeddings[i]), now,
			)
		} else {
			_, execErr = tx.ExecContext(ctx,
				`INSERT INTO memory_chunks (id, agent_id, document_id, user_id, path, start_line, end_line, hash, text, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
ON CONFLICT DO NOTHING`,
				chunkID, aid, docID, uid, path, tc.StartLine, tc.EndLine, hash, tc.Text, now,
			)
		}
		if execErr != nil {
			return fmt.Errorf("insert chunk %d of %q: %w", i, path, execErr)
		}
	}

	if err := tx.Commit(); err != nil {
		return fmt.Errorf("commit index of %q: %w", path, err)
	}
	return nil
}
// IndexAll re-indexes every document ListDocuments returns for
// (agentID, userID). Each document is indexed under its own owner, so shared
// documents use an empty user ID.
//
// Indexing is best-effort per document: a failure is logged and the remaining
// documents are still processed. IndexAll returns an error only when listing
// fails or ctx is cancelled, in which case it stops early.
func (s *PGMemoryStore) IndexAll(ctx context.Context, agentID, userID string) error {
	docs, err := s.ListDocuments(ctx, agentID, userID)
	if err != nil {
		return err
	}
	for _, doc := range docs {
		// Once ctx is done every remaining IndexDocument would fail anyway.
		if err := ctx.Err(); err != nil {
			return err
		}
		if err := s.IndexDocument(ctx, agentID, doc.UserID, doc.Path); err != nil {
			slog.Warn("memory index document failed",
				"agent", agentID, "user", doc.UserID, "path", doc.Path, "error", err)
		}
	}
	return nil
}
// SetEmbeddingProvider sets the provider used to embed chunk texts in
// IndexDocument and BackfillEmbeddings. Passing nil disables embedding:
// IndexDocument then stores chunks without vectors and BackfillEmbeddings
// returns an error.
//
// NOTE(review): the field is written without synchronization; call this
// before the store is shared across goroutines.
func (s *PGMemoryStore) SetEmbeddingProvider(provider store.EmbeddingProvider) {
	s.provider = provider
}
// BackfillEmbeddings finds all chunks without embeddings and generates them.
// Processes in batches to avoid memory spikes. Safe to call multiple times.
//
// It returns the number of chunks updated so far together with the first
// error encountered. A provider that yields fewer vectors than texts is also
// reported as an error. Otherwise the un-embedded chunks would be selected
// again on every pass, and an empty response would spin forever on the same
// batch.
func (s *PGMemoryStore) BackfillEmbeddings(ctx context.Context) (int, error) {
	if s.provider == nil {
		return 0, fmt.Errorf("no embedding provider configured")
	}
	const batchSize = 50

	type chunkRow struct {
		ID   uuid.UUID
		Text string
	}

	// nextBatch loads up to batchSize chunks still lacking an embedding. It is
	// a closure so the deferred rows.Close runs once per batch.
	nextBatch := func() ([]chunkRow, error) {
		rows, err := s.db.QueryContext(ctx,
			"SELECT id, text FROM memory_chunks WHERE embedding IS NULL ORDER BY id ASC LIMIT $1", batchSize)
		if err != nil {
			return nil, fmt.Errorf("query chunks without embeddings: %w", err)
		}
		defer rows.Close()

		batch := make([]chunkRow, 0, batchSize)
		for rows.Next() {
			var c chunkRow
			if err := rows.Scan(&c.ID, &c.Text); err != nil {
				return nil, fmt.Errorf("scan chunk without embedding: %w", err)
			}
			batch = append(batch, c)
		}
		if err := rows.Err(); err != nil {
			return nil, fmt.Errorf("iterate chunks without embeddings: %w", err)
		}
		return batch, nil
	}

	total := 0
	for {
		chunks, err := nextBatch()
		if err != nil {
			return total, err
		}
		if len(chunks) == 0 {
			break
		}

		texts := make([]string, len(chunks))
		for i, c := range chunks {
			texts[i] = c.Text
		}
		embeddings, err := s.provider.Embed(ctx, texts)
		if err != nil {
			return total, fmt.Errorf("generate embeddings: %w", err)
		}

		for i, chunk := range chunks {
			if i >= len(embeddings) {
				break
			}
			if _, err := s.db.ExecContext(ctx,
				"UPDATE memory_chunks SET embedding = $1::vector WHERE id = $2",
				vectorToString(embeddings[i]), chunk.ID,
			); err != nil {
				return total, fmt.Errorf("update chunk embedding id=%s: %w", chunk.ID, err)
			}
			total++
		}
		if len(embeddings) < len(chunks) {
			return total, fmt.Errorf("embedding provider returned %d vectors for %d chunks",
				len(embeddings), len(chunks))
		}

		// A short batch means no NULL-embedding rows remain beyond it.
		if len(chunks) < batchSize {
			break
		}
	}
	return total, nil
}
// Close is a no-op that always returns nil. It does not close the *sql.DB
// passed to NewPGMemoryStore; presumably that handle is owned and closed by
// the caller — confirm.
func (s *PGMemoryStore) Close() error { return nil }
// --- Helpers ---
// mustParseUUID parses s as a UUID. Despite the "must" prefix it never
// panics: a malformed s yields uuid.Nil, so callers end up querying with the
// nil UUID instead of receiving a parse error.
func mustParseUUID(s string) uuid.UUID {
	if id, err := uuid.Parse(s); err == nil {
		return id
	}
	return uuid.Nil
}
// vectorToString renders v as a bracketed, comma-separated list such as
// "[1,0.5,-2.25]", the text form fed to "$n::vector" casts in this file.
// Each element is formatted with %g (shortest float32 representation).
// An empty or nil v yields "", which such a cast presumably rejects, so
// callers should not pass one.
func vectorToString(v []float32) string {
	if len(v) == 0 {
		return ""
	}
	// Rough per-element estimate ("-0.12345678," is 12 bytes); append grows
	// the buffer if an element turns out longer.
	buf := make([]byte, 0, 2+len(v)*12)
	buf = append(buf, '[')
	for i, f := range v {
		if i > 0 {
			buf = append(buf, ',')
		}
		// Format straight into buf rather than into a throwaway slice per element.
		buf = fmt.Appendf(buf, "%g", f)
	}
	buf = append(buf, ']')
	return string(buf)
}