mirror of
https://github.com/tiennm99/goclaw.git
synced 2026-06-10 08:11:23 +00:00
ce333c70f3
- Add explicit ESCAPE '\' clause to all ILIKE queries (knowledge_graph, custom_tools, channel_instances, channel_contacts) for correct wildcard escaping regardless of PostgreSQL standard_conforming_strings setting - Log warning when filterAllowedKeys drops unknown fields for debuggability - Widen base64 decode shell deny pattern to catch -di, -dw0 variants - Add unit tests for filterAllowedKeys, pqStringArray, scanStringArray, pqStringArray↔scanStringArray roundtrip, limitedBuffer, base64 deny
404 lines
12 KiB
Go
404 lines
12 KiB
Go
package pg
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"encoding/json"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/nextlevelbuilder/goclaw/internal/store"
|
|
)
|
|
|
|
// PGKnowledgeGraphStore is the Postgres implementation of
// store.KnowledgeGraphStore. Every method issues its SQL through the
// wrapped *sql.DB handle.
type PGKnowledgeGraphStore struct {
	// db is the database handle supplied to NewPGKnowledgeGraphStore.
	db *sql.DB
}
|
|
|
|
// NewPGKnowledgeGraphStore creates a new PG-backed knowledge graph store.
|
|
func NewPGKnowledgeGraphStore(db *sql.DB) *PGKnowledgeGraphStore {
|
|
return &PGKnowledgeGraphStore{db: db}
|
|
}
|
|
|
|
func (s *PGKnowledgeGraphStore) UpsertEntity(ctx context.Context, entity *store.Entity) error {
|
|
aid := mustParseUUID(entity.AgentID)
|
|
props, err := json.Marshal(entity.Properties)
|
|
if err != nil {
|
|
props = []byte("{}")
|
|
}
|
|
now := time.Now()
|
|
id := uuid.Must(uuid.NewV7())
|
|
_, err = s.db.ExecContext(ctx, `
|
|
INSERT INTO kg_entities
|
|
(id, agent_id, user_id, external_id, name, entity_type, description, properties, source_id, confidence, created_at, updated_at)
|
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $11)
|
|
ON CONFLICT (agent_id, user_id, external_id) DO UPDATE SET
|
|
name = EXCLUDED.name,
|
|
entity_type = EXCLUDED.entity_type,
|
|
description = EXCLUDED.description,
|
|
properties = EXCLUDED.properties,
|
|
source_id = EXCLUDED.source_id,
|
|
confidence = EXCLUDED.confidence,
|
|
updated_at = EXCLUDED.updated_at`,
|
|
id, aid, entity.UserID, entity.ExternalID, entity.Name, entity.EntityType,
|
|
entity.Description, props, entity.SourceID, entity.Confidence, now,
|
|
)
|
|
return err
|
|
}
|
|
|
|
func (s *PGKnowledgeGraphStore) GetEntity(ctx context.Context, agentID, userID, entityID string) (*store.Entity, error) {
|
|
aid := mustParseUUID(agentID)
|
|
eid := mustParseUUID(entityID)
|
|
row := s.db.QueryRowContext(ctx, `
|
|
SELECT id, agent_id, user_id, external_id, name, entity_type, description,
|
|
properties, source_id, confidence, created_at, updated_at
|
|
FROM kg_entities WHERE id = $1 AND agent_id = $2 AND user_id = $3`,
|
|
eid, aid, userID,
|
|
)
|
|
return scanEntity(row)
|
|
}
|
|
|
|
func (s *PGKnowledgeGraphStore) DeleteEntity(ctx context.Context, agentID, userID, entityID string) error {
|
|
aid := mustParseUUID(agentID)
|
|
eid := mustParseUUID(entityID)
|
|
_, err := s.db.ExecContext(ctx,
|
|
`DELETE FROM kg_entities WHERE id = $1 AND agent_id = $2 AND user_id = $3`,
|
|
eid, aid, userID,
|
|
)
|
|
return err
|
|
}
|
|
|
|
func (s *PGKnowledgeGraphStore) ListEntities(ctx context.Context, agentID, userID string, opts store.EntityListOptions) ([]store.Entity, error) {
|
|
aid := mustParseUUID(agentID)
|
|
limit := opts.Limit
|
|
if limit <= 0 {
|
|
limit = 50
|
|
}
|
|
|
|
// Build dynamic WHERE clause: always filter by agent_id, optionally by user_id and entity_type
|
|
where := "agent_id = $1"
|
|
args := []any{aid}
|
|
idx := 2
|
|
if userID != "" {
|
|
where += fmt.Sprintf(" AND user_id = $%d", idx)
|
|
args = append(args, userID)
|
|
idx++
|
|
}
|
|
if opts.EntityType != "" {
|
|
where += fmt.Sprintf(" AND entity_type = $%d", idx)
|
|
args = append(args, opts.EntityType)
|
|
idx++
|
|
}
|
|
args = append(args, limit, opts.Offset)
|
|
query := fmt.Sprintf(`
|
|
SELECT id, agent_id, user_id, external_id, name, entity_type, description,
|
|
properties, source_id, confidence, created_at, updated_at
|
|
FROM kg_entities WHERE %s
|
|
ORDER BY updated_at DESC LIMIT $%d OFFSET $%d`, where, idx, idx+1)
|
|
|
|
rows, err := s.db.QueryContext(ctx, query, args...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
return scanEntities(rows)
|
|
}
|
|
|
|
func (s *PGKnowledgeGraphStore) SearchEntities(ctx context.Context, agentID, userID, query string, limit int) ([]store.Entity, error) {
|
|
aid := mustParseUUID(agentID)
|
|
if limit <= 0 {
|
|
limit = 20
|
|
}
|
|
// Escape LIKE wildcards to prevent pattern injection.
|
|
escaped := strings.NewReplacer("%", "\\%", "_", "\\_").Replace(query)
|
|
pattern := "%" + escaped + "%"
|
|
|
|
where := "agent_id = $1"
|
|
args := []any{aid}
|
|
idx := 2
|
|
if userID != "" {
|
|
where += fmt.Sprintf(" AND user_id = $%d", idx)
|
|
args = append(args, userID)
|
|
idx++
|
|
}
|
|
args = append(args, pattern, limit)
|
|
q := fmt.Sprintf(`
|
|
SELECT id, agent_id, user_id, external_id, name, entity_type, description,
|
|
properties, source_id, confidence, created_at, updated_at
|
|
FROM kg_entities
|
|
WHERE %s AND (name ILIKE $%d ESCAPE '\' OR description ILIKE $%d ESCAPE '\')
|
|
ORDER BY confidence DESC, updated_at DESC LIMIT $%d`, where, idx, idx, idx+1)
|
|
|
|
rows, err := s.db.QueryContext(ctx, q, args...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
return scanEntities(rows)
|
|
}
|
|
|
|
func (s *PGKnowledgeGraphStore) UpsertRelation(ctx context.Context, relation *store.Relation) error {
|
|
aid := mustParseUUID(relation.AgentID)
|
|
src := mustParseUUID(relation.SourceEntityID)
|
|
tgt := mustParseUUID(relation.TargetEntityID)
|
|
props, err := json.Marshal(relation.Properties)
|
|
if err != nil {
|
|
props = []byte("{}")
|
|
}
|
|
id := uuid.Must(uuid.NewV7())
|
|
now := time.Now()
|
|
_, err = s.db.ExecContext(ctx, `
|
|
INSERT INTO kg_relations
|
|
(id, agent_id, user_id, source_entity_id, relation_type, target_entity_id, confidence, properties, created_at)
|
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
|
ON CONFLICT (agent_id, user_id, source_entity_id, relation_type, target_entity_id) DO UPDATE SET
|
|
confidence = EXCLUDED.confidence,
|
|
properties = EXCLUDED.properties`,
|
|
id, aid, relation.UserID, src, relation.RelationType, tgt, relation.Confidence, props, now,
|
|
)
|
|
return err
|
|
}
|
|
|
|
func (s *PGKnowledgeGraphStore) DeleteRelation(ctx context.Context, agentID, userID, relationID string) error {
|
|
aid := mustParseUUID(agentID)
|
|
rid := mustParseUUID(relationID)
|
|
_, err := s.db.ExecContext(ctx,
|
|
`DELETE FROM kg_relations WHERE id = $1 AND agent_id = $2 AND user_id = $3`,
|
|
rid, aid, userID,
|
|
)
|
|
return err
|
|
}
|
|
|
|
func (s *PGKnowledgeGraphStore) ListRelations(ctx context.Context, agentID, userID, entityID string) ([]store.Relation, error) {
|
|
aid := mustParseUUID(agentID)
|
|
eid := mustParseUUID(entityID)
|
|
rows, err := s.db.QueryContext(ctx, `
|
|
SELECT id, agent_id, user_id, source_entity_id, relation_type, target_entity_id,
|
|
confidence, properties, created_at
|
|
FROM kg_relations
|
|
WHERE agent_id = $1 AND user_id = $2
|
|
AND (source_entity_id = $3 OR target_entity_id = $3)
|
|
ORDER BY created_at DESC`,
|
|
aid, userID, eid,
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
return scanRelations(rows)
|
|
}
|
|
|
|
func (s *PGKnowledgeGraphStore) ListAllRelations(ctx context.Context, agentID, userID string, limit int) ([]store.Relation, error) {
|
|
aid := mustParseUUID(agentID)
|
|
if limit <= 0 {
|
|
limit = 200
|
|
}
|
|
where := "agent_id = $1"
|
|
args := []any{aid}
|
|
idx := 2
|
|
if userID != "" {
|
|
where += fmt.Sprintf(" AND user_id = $%d", idx)
|
|
args = append(args, userID)
|
|
idx++
|
|
}
|
|
args = append(args, limit)
|
|
q := fmt.Sprintf(`
|
|
SELECT id, agent_id, user_id, source_entity_id, relation_type, target_entity_id,
|
|
confidence, properties, created_at
|
|
FROM kg_relations WHERE %s
|
|
ORDER BY created_at DESC LIMIT $%d`, where, idx)
|
|
rows, err := s.db.QueryContext(ctx, q, args...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
return scanRelations(rows)
|
|
}
|
|
|
|
func (s *PGKnowledgeGraphStore) IngestExtraction(ctx context.Context, agentID, userID string, entities []store.Entity, relations []store.Relation) error {
|
|
tx, err := s.db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer tx.Rollback() //nolint:errcheck
|
|
|
|
aid := mustParseUUID(agentID)
|
|
now := time.Now()
|
|
|
|
// Upsert entities and build external_id → DB UUID lookup for relations
|
|
extIDToUUID := make(map[string]uuid.UUID, len(entities))
|
|
for i := range entities {
|
|
e := &entities[i]
|
|
e.AgentID = agentID
|
|
e.UserID = userID
|
|
props, _ := json.Marshal(e.Properties)
|
|
id := uuid.Must(uuid.NewV7())
|
|
// Use RETURNING to get the actual ID (could be existing row on conflict)
|
|
var actualID uuid.UUID
|
|
if err := tx.QueryRowContext(ctx, `
|
|
INSERT INTO kg_entities
|
|
(id, agent_id, user_id, external_id, name, entity_type, description, properties, source_id, confidence, created_at, updated_at)
|
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $11)
|
|
ON CONFLICT (agent_id, user_id, external_id) DO UPDATE SET
|
|
name = EXCLUDED.name,
|
|
entity_type = EXCLUDED.entity_type,
|
|
description = EXCLUDED.description,
|
|
properties = EXCLUDED.properties,
|
|
source_id = EXCLUDED.source_id,
|
|
confidence = EXCLUDED.confidence,
|
|
updated_at = EXCLUDED.updated_at
|
|
RETURNING id`,
|
|
id, aid, userID, e.ExternalID, e.Name, e.EntityType,
|
|
e.Description, props, e.SourceID, e.Confidence, now,
|
|
).Scan(&actualID); err != nil {
|
|
return err
|
|
}
|
|
extIDToUUID[e.ExternalID] = actualID
|
|
}
|
|
|
|
for i := range relations {
|
|
r := &relations[i]
|
|
r.AgentID = agentID
|
|
r.UserID = userID
|
|
// Resolve external_id references to actual DB UUIDs
|
|
src, ok1 := extIDToUUID[r.SourceEntityID]
|
|
tgt, ok2 := extIDToUUID[r.TargetEntityID]
|
|
if !ok1 || !ok2 {
|
|
continue // skip relations referencing unknown entities
|
|
}
|
|
props, _ := json.Marshal(r.Properties)
|
|
id := uuid.Must(uuid.NewV7())
|
|
if _, err := tx.ExecContext(ctx, `
|
|
INSERT INTO kg_relations
|
|
(id, agent_id, user_id, source_entity_id, relation_type, target_entity_id, confidence, properties, created_at)
|
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
|
ON CONFLICT (agent_id, user_id, source_entity_id, relation_type, target_entity_id) DO UPDATE SET
|
|
confidence = EXCLUDED.confidence,
|
|
properties = EXCLUDED.properties`,
|
|
id, aid, userID, src, r.RelationType, tgt, r.Confidence, props, now,
|
|
); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return tx.Commit()
|
|
}
|
|
|
|
func (s *PGKnowledgeGraphStore) PruneByConfidence(ctx context.Context, agentID, userID string, minConfidence float64) (int, error) {
|
|
aid := mustParseUUID(agentID)
|
|
res, err := s.db.ExecContext(ctx,
|
|
`DELETE FROM kg_entities WHERE agent_id = $1 AND user_id = $2 AND confidence < $3`,
|
|
aid, userID, minConfidence,
|
|
)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
n, _ := res.RowsAffected()
|
|
return int(n), nil
|
|
}
|
|
|
|
func (s *PGKnowledgeGraphStore) Stats(ctx context.Context, agentID, userID string) (*store.GraphStats, error) {
|
|
aid := mustParseUUID(agentID)
|
|
stats := &store.GraphStats{EntityTypes: make(map[string]int)}
|
|
|
|
userFilter := ""
|
|
args := []any{aid}
|
|
if userID != "" {
|
|
userFilter = " AND user_id = $2"
|
|
args = append(args, userID)
|
|
}
|
|
|
|
if err := s.db.QueryRowContext(ctx,
|
|
`SELECT COUNT(*) FROM kg_entities WHERE agent_id = $1`+userFilter, args...,
|
|
).Scan(&stats.EntityCount); err != nil {
|
|
return nil, err
|
|
}
|
|
if err := s.db.QueryRowContext(ctx,
|
|
`SELECT COUNT(*) FROM kg_relations WHERE agent_id = $1`+userFilter, args...,
|
|
).Scan(&stats.RelationCount); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
rows, err := s.db.QueryContext(ctx,
|
|
`SELECT entity_type, COUNT(*) FROM kg_entities WHERE agent_id = $1`+userFilter+` GROUP BY entity_type`, args...,
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
for rows.Next() {
|
|
var t string
|
|
var c int
|
|
if err := rows.Scan(&t, &c); err != nil {
|
|
continue
|
|
}
|
|
stats.EntityTypes[t] = c
|
|
}
|
|
return stats, nil
|
|
}
|
|
|
|
// Close is a no-op that always returns nil; it does not close the wrapped
// *sql.DB handle.
func (s *PGKnowledgeGraphStore) Close() error {
	return nil
}
|
|
|
|
// --- scan helpers ---
|
|
|
|
// rowScanner is the single Scan method that scanEntity needs from a result
// row, so it can accept anything that exposes Scan with this signature.
type rowScanner interface {
	// Scan copies the current row's columns into targets.
	Scan(targets ...any) error
}
|
|
|
|
func scanEntity(row rowScanner) (*store.Entity, error) {
|
|
var e store.Entity
|
|
var props []byte
|
|
var createdAt, updatedAt time.Time
|
|
if err := row.Scan(
|
|
&e.ID, &e.AgentID, &e.UserID, &e.ExternalID, &e.Name, &e.EntityType,
|
|
&e.Description, &props, &e.SourceID, &e.Confidence, &createdAt, &updatedAt,
|
|
); err != nil {
|
|
return nil, err
|
|
}
|
|
json.Unmarshal(props, &e.Properties) //nolint:errcheck
|
|
e.CreatedAt = createdAt.UnixMilli()
|
|
e.UpdatedAt = updatedAt.UnixMilli()
|
|
return &e, nil
|
|
}
|
|
|
|
func scanEntities(rows *sql.Rows) ([]store.Entity, error) {
|
|
var result []store.Entity
|
|
for rows.Next() {
|
|
var e store.Entity
|
|
var props []byte
|
|
var createdAt, updatedAt time.Time
|
|
if err := rows.Scan(
|
|
&e.ID, &e.AgentID, &e.UserID, &e.ExternalID, &e.Name, &e.EntityType,
|
|
&e.Description, &props, &e.SourceID, &e.Confidence, &createdAt, &updatedAt,
|
|
); err != nil {
|
|
continue
|
|
}
|
|
json.Unmarshal(props, &e.Properties) //nolint:errcheck
|
|
e.CreatedAt = createdAt.UnixMilli()
|
|
e.UpdatedAt = updatedAt.UnixMilli()
|
|
result = append(result, e)
|
|
}
|
|
return result, rows.Err()
|
|
}
|
|
|
|
func scanRelations(rows *sql.Rows) ([]store.Relation, error) {
|
|
var result []store.Relation
|
|
for rows.Next() {
|
|
var r store.Relation
|
|
var props []byte
|
|
var createdAt time.Time
|
|
if err := rows.Scan(
|
|
&r.ID, &r.AgentID, &r.UserID, &r.SourceEntityID, &r.RelationType,
|
|
&r.TargetEntityID, &r.Confidence, &props, &createdAt,
|
|
); err != nil {
|
|
continue
|
|
}
|
|
json.Unmarshal(props, &r.Properties) //nolint:errcheck
|
|
r.CreatedAt = createdAt.UnixMilli()
|
|
result = append(result, r)
|
|
}
|
|
return result, rows.Err()
|
|
}
|