Files
goclaw/internal/store/pg/knowledge_graph.go
T
Viet Tran ce333c70f3 fix(security): followup hardening — ILIKE ESCAPE, allowlist logging, shell deny, tests (#251)
- Add explicit ESCAPE '\' clause to all ILIKE queries (knowledge_graph,
  custom_tools, channel_instances, channel_contacts) for correct wildcard
  escaping regardless of PostgreSQL standard_conforming_strings setting
- Log warning when filterAllowedKeys drops unknown fields for debuggability
- Widen base64 decode shell deny pattern to catch -di, -dw0 variants
- Add unit tests for filterAllowedKeys, pqStringArray, scanStringArray,
  pqStringArray↔scanStringArray roundtrip, limitedBuffer, base64 deny
2026-03-18 07:48:48 +07:00

404 lines
12 KiB
Go

package pg
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/google/uuid"
"github.com/nextlevelbuilder/goclaw/internal/store"
)
// PGKnowledgeGraphStore is the Postgres-backed implementation of
// store.KnowledgeGraphStore. Entities live in the kg_entities table and
// relations in kg_relations; all access goes through the *sql.DB handle
// supplied at construction.
type PGKnowledgeGraphStore struct {
	db *sql.DB
}

// NewPGKnowledgeGraphStore returns a store that issues its queries through db.
// The handle is used as given and is not closed by the store.
func NewPGKnowledgeGraphStore(db *sql.DB) *PGKnowledgeGraphStore {
	s := new(PGKnowledgeGraphStore)
	s.db = db
	return s
}
// UpsertEntity stores entity in kg_entities. If a row already exists for the
// same (agent_id, user_id, external_id), its name, entity_type, description,
// properties, source_id, confidence and updated_at are overwritten while its id
// and created_at are left untouched. A newly inserted row gets a freshly
// generated UUIDv7 id and identical created_at/updated_at timestamps.
//
// Properties that cannot be JSON-encoded are stored as "{}" rather than
// failing the call. entity itself is not modified.
func (s *PGKnowledgeGraphStore) UpsertEntity(ctx context.Context, entity *store.Entity) error {
	const upsertSQL = `
INSERT INTO kg_entities
(id, agent_id, user_id, external_id, name, entity_type, description, properties, source_id, confidence, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $11)
ON CONFLICT (agent_id, user_id, external_id) DO UPDATE SET
name = EXCLUDED.name,
entity_type = EXCLUDED.entity_type,
description = EXCLUDED.description,
properties = EXCLUDED.properties,
source_id = EXCLUDED.source_id,
confidence = EXCLUDED.confidence,
updated_at = EXCLUDED.updated_at`

	agent := mustParseUUID(entity.AgentID)
	propsJSON, encErr := json.Marshal(entity.Properties)
	if encErr != nil {
		propsJSON = []byte("{}")
	}
	stamp := time.Now()
	newID := uuid.Must(uuid.NewV7())

	_, err := s.db.ExecContext(ctx, upsertSQL,
		newID, agent, entity.UserID, entity.ExternalID, entity.Name, entity.EntityType,
		entity.Description, propsJSON, entity.SourceID, entity.Confidence, stamp,
	)
	return err
}
// GetEntity loads the entity whose id is entityID, provided it belongs to
// agentID and userID. Any query or scan error (including the no-rows error
// from database/sql when nothing matches) is propagated from scanEntity.
func (s *PGKnowledgeGraphStore) GetEntity(ctx context.Context, agentID, userID, entityID string) (*store.Entity, error) {
	const selectSQL = `
SELECT id, agent_id, user_id, external_id, name, entity_type, description,
properties, source_id, confidence, created_at, updated_at
FROM kg_entities WHERE id = $1 AND agent_id = $2 AND user_id = $3`

	agent := mustParseUUID(agentID)
	target := mustParseUUID(entityID)
	return scanEntity(s.db.QueryRowContext(ctx, selectSQL, target, agent, userID))
}
// DeleteEntity removes the entity with id entityID if it belongs to agentID
// and userID. The affected-row count is not inspected, so deleting a
// non-matching id succeeds silently.
func (s *PGKnowledgeGraphStore) DeleteEntity(ctx context.Context, agentID, userID, entityID string) error {
	agent := mustParseUUID(agentID)
	target := mustParseUUID(entityID)
	const deleteSQL = `DELETE FROM kg_entities WHERE id = $1 AND agent_id = $2 AND user_id = $3`
	if _, err := s.db.ExecContext(ctx, deleteSQL, target, agent, userID); err != nil {
		return err
	}
	return nil
}
// ListEntities pages through an agent's entities, most recently updated first.
// A non-empty userID restricts results to that user, and a non-empty
// opts.EntityType restricts them to that type. opts.Limit defaults to 50 when
// non-positive; opts.Offset is passed through unchanged.
func (s *PGKnowledgeGraphStore) ListEntities(ctx context.Context, agentID, userID string, opts store.EntityListOptions) ([]store.Entity, error) {
	args := []any{mustParseUUID(agentID)}
	conds := []string{"agent_id = $1"}

	// addFilter appends an equality predicate whose placeholder number follows
	// the arguments collected so far.
	addFilter := func(column string, value any) {
		args = append(args, value)
		conds = append(conds, fmt.Sprintf("%s = $%d", column, len(args)))
	}
	if userID != "" {
		addFilter("user_id", userID)
	}
	if opts.EntityType != "" {
		addFilter("entity_type", opts.EntityType)
	}

	limit := opts.Limit
	if limit <= 0 {
		limit = 50
	}
	limitPos := len(args) + 1
	args = append(args, limit, opts.Offset)

	query := fmt.Sprintf(`
SELECT id, agent_id, user_id, external_id, name, entity_type, description,
properties, source_id, confidence, created_at, updated_at
FROM kg_entities WHERE %s
ORDER BY updated_at DESC LIMIT $%d OFFSET $%d`, strings.Join(conds, " AND "), limitPos, limitPos+1)

	rows, err := s.db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	return scanEntities(rows)
}
// SearchEntities performs a case-insensitive substring search over entity
// name and description for an agent, optionally narrowed to one user (an
// empty userID searches all users). Results are ordered by confidence, then
// by most recent update. limit defaults to 20 when non-positive.
//
// query is matched literally. The LIKE escape character (backslash) and the
// wildcards % and _ are all escaped before the text is wrapped in %...%.
func (s *PGKnowledgeGraphStore) SearchEntities(ctx context.Context, agentID, userID, query string, limit int) ([]store.Entity, error) {
	aid := mustParseUUID(agentID)
	if limit <= 0 {
		limit = 20
	}

	// The backslash must be escaped as well as the wildcards. With ESCAPE '\',
	// an unescaped backslash in user input would consume the following
	// character, so "a\b" would match "ab". A trailing backslash would turn the
	// closing % wildcard into a literal. strings.Replacer makes a single pass,
	// so the backslashes it inserts are not themselves re-escaped.
	escaped := strings.NewReplacer(`\`, `\\`, "%", `\%`, "_", `\_`).Replace(query)
	pattern := "%" + escaped + "%"

	where := "agent_id = $1"
	args := []any{aid}
	idx := 2
	if userID != "" {
		where += fmt.Sprintf(" AND user_id = $%d", idx)
		args = append(args, userID)
		idx++
	}
	args = append(args, pattern, limit)

	q := fmt.Sprintf(`
SELECT id, agent_id, user_id, external_id, name, entity_type, description,
properties, source_id, confidence, created_at, updated_at
FROM kg_entities
WHERE %s AND (name ILIKE $%d ESCAPE '\' OR description ILIKE $%d ESCAPE '\')
ORDER BY confidence DESC, updated_at DESC LIMIT $%d`, where, idx, idx, idx+1)

	rows, err := s.db.QueryContext(ctx, q, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	return scanEntities(rows)
}
// UpsertRelation stores an edge from relation.SourceEntityID to
// relation.TargetEntityID with type relation.RelationType. If an edge with the
// same (agent_id, user_id, source_entity_id, relation_type, target_entity_id)
// already exists, only its confidence and properties are refreshed; id and
// created_at are kept from the original insert.
//
// Properties that cannot be JSON-encoded are stored as "{}".
func (s *PGKnowledgeGraphStore) UpsertRelation(ctx context.Context, relation *store.Relation) error {
	const upsertSQL = `
INSERT INTO kg_relations
(id, agent_id, user_id, source_entity_id, relation_type, target_entity_id, confidence, properties, created_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
ON CONFLICT (agent_id, user_id, source_entity_id, relation_type, target_entity_id) DO UPDATE SET
confidence = EXCLUDED.confidence,
properties = EXCLUDED.properties`

	agent := mustParseUUID(relation.AgentID)
	from := mustParseUUID(relation.SourceEntityID)
	to := mustParseUUID(relation.TargetEntityID)

	propsJSON, encErr := json.Marshal(relation.Properties)
	if encErr != nil {
		propsJSON = []byte("{}")
	}
	newID := uuid.Must(uuid.NewV7())
	stamp := time.Now()

	_, err := s.db.ExecContext(ctx, upsertSQL,
		newID, agent, relation.UserID, from, relation.RelationType, to,
		relation.Confidence, propsJSON, stamp,
	)
	return err
}
// DeleteRelation removes the relation with id relationID if it belongs to
// agentID and userID. The affected-row count is not checked, so a
// non-matching id is not reported as an error.
func (s *PGKnowledgeGraphStore) DeleteRelation(ctx context.Context, agentID, userID, relationID string) error {
	agent := mustParseUUID(agentID)
	target := mustParseUUID(relationID)
	const deleteSQL = `DELETE FROM kg_relations WHERE id = $1 AND agent_id = $2 AND user_id = $3`
	if _, err := s.db.ExecContext(ctx, deleteSQL, target, agent, userID); err != nil {
		return err
	}
	return nil
}
// ListRelations returns every relation in which entityID is either the source
// or the target, newest first. Results are scoped to agentID and userID; note
// that the user_id predicate is always applied, even when userID is empty.
func (s *PGKnowledgeGraphStore) ListRelations(ctx context.Context, agentID, userID, entityID string) ([]store.Relation, error) {
	const selectSQL = `
SELECT id, agent_id, user_id, source_entity_id, relation_type, target_entity_id,
confidence, properties, created_at
FROM kg_relations
WHERE agent_id = $1 AND user_id = $2
AND (source_entity_id = $3 OR target_entity_id = $3)
ORDER BY created_at DESC`

	agent := mustParseUUID(agentID)
	endpoint := mustParseUUID(entityID)

	rows, err := s.db.QueryContext(ctx, selectSQL, agent, userID, endpoint)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	return scanRelations(rows)
}
// ListAllRelations returns an agent's relations, newest first, optionally
// restricted to one user (an empty userID returns all users' relations).
// limit defaults to 200 when non-positive.
func (s *PGKnowledgeGraphStore) ListAllRelations(ctx context.Context, agentID, userID string, limit int) ([]store.Relation, error) {
	args := []any{mustParseUUID(agentID)}
	where := "agent_id = $1"
	if userID != "" {
		args = append(args, userID)
		where += " AND user_id = $2"
	}

	if limit <= 0 {
		limit = 200
	}
	args = append(args, limit)

	// The LIMIT placeholder is always the last argument appended.
	q := fmt.Sprintf(`
SELECT id, agent_id, user_id, source_entity_id, relation_type, target_entity_id,
confidence, properties, created_at
FROM kg_relations WHERE %s
ORDER BY created_at DESC LIMIT $%d`, where, len(args))

	rows, err := s.db.QueryContext(ctx, q, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	return scanRelations(rows)
}
// IngestExtraction writes a batch of extracted entities and relations for one
// agent/user inside a single transaction. Any error rolls the whole batch
// back.
//
// Entities are upserted on (agent_id, user_id, external_id). On input,
// relations name their endpoints by external ID (SourceEntityID and
// TargetEntityID hold external IDs). Those are resolved to database ids using
// only the entities in this batch; relations whose endpoints are not in the
// batch are skipped. AgentID and UserID on every element of entities and
// relations are overwritten with the arguments.
//
// Properties that cannot be JSON-encoded are stored as "{}", the same fallback
// UpsertEntity and UpsertRelation use.
func (s *PGKnowledgeGraphStore) IngestExtraction(ctx context.Context, agentID, userID string, entities []store.Entity, relations []store.Relation) error {
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("begin kg ingest: %w", err)
	}
	// Rollback is a harmless no-op once Commit has succeeded.
	defer tx.Rollback() //nolint:errcheck

	aid := mustParseUUID(agentID)
	now := time.Now()

	// marshalProps encodes a properties value. It falls back to an empty JSON
	// object instead of handing the driver a nil slice when encoding fails.
	marshalProps := func(v any) []byte {
		b, merr := json.Marshal(v)
		if merr != nil {
			return []byte("{}")
		}
		return b
	}

	// Upsert entities and build an external_id -> DB UUID lookup for relations.
	extIDToUUID := make(map[string]uuid.UUID, len(entities))
	for i := range entities {
		e := &entities[i]
		e.AgentID = agentID
		e.UserID = userID
		props := marshalProps(e.Properties)
		id := uuid.Must(uuid.NewV7())

		// RETURNING id yields the existing row's id when the upsert hits a
		// conflict, not the freshly generated one.
		var actualID uuid.UUID
		if err := tx.QueryRowContext(ctx, `
INSERT INTO kg_entities
(id, agent_id, user_id, external_id, name, entity_type, description, properties, source_id, confidence, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $11)
ON CONFLICT (agent_id, user_id, external_id) DO UPDATE SET
name = EXCLUDED.name,
entity_type = EXCLUDED.entity_type,
description = EXCLUDED.description,
properties = EXCLUDED.properties,
source_id = EXCLUDED.source_id,
confidence = EXCLUDED.confidence,
updated_at = EXCLUDED.updated_at
RETURNING id`,
			id, aid, userID, e.ExternalID, e.Name, e.EntityType,
			e.Description, props, e.SourceID, e.Confidence, now,
		).Scan(&actualID); err != nil {
			return fmt.Errorf("upsert kg entity %q: %w", e.ExternalID, err)
		}
		extIDToUUID[e.ExternalID] = actualID
	}

	for i := range relations {
		r := &relations[i]
		r.AgentID = agentID
		r.UserID = userID

		// Resolve external_id references to actual DB UUIDs.
		src, ok1 := extIDToUUID[r.SourceEntityID]
		tgt, ok2 := extIDToUUID[r.TargetEntityID]
		if !ok1 || !ok2 {
			continue // skip relations referencing entities outside this batch
		}

		props := marshalProps(r.Properties)
		id := uuid.Must(uuid.NewV7())
		if _, err := tx.ExecContext(ctx, `
INSERT INTO kg_relations
(id, agent_id, user_id, source_entity_id, relation_type, target_entity_id, confidence, properties, created_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
ON CONFLICT (agent_id, user_id, source_entity_id, relation_type, target_entity_id) DO UPDATE SET
confidence = EXCLUDED.confidence,
properties = EXCLUDED.properties`,
			id, aid, userID, src, r.RelationType, tgt, r.Confidence, props, now,
		); err != nil {
			return fmt.Errorf("upsert kg relation %q -[%s]-> %q: %w",
				r.SourceEntityID, r.RelationType, r.TargetEntityID, err)
		}
	}

	if err := tx.Commit(); err != nil {
		return fmt.Errorf("commit kg ingest: %w", err)
	}
	return nil
}
// PruneByConfidence deletes the agent/user's entities whose confidence is
// strictly below minConfidence and reports how many rows the driver says were
// removed.
func (s *PGKnowledgeGraphStore) PruneByConfidence(ctx context.Context, agentID, userID string, minConfidence float64) (int, error) {
	const pruneSQL = `DELETE FROM kg_entities WHERE agent_id = $1 AND user_id = $2 AND confidence < $3`
	res, err := s.db.ExecContext(ctx, pruneSQL, mustParseUUID(agentID), userID, minConfidence)
	if err != nil {
		return 0, err
	}
	// The DELETE has already succeeded at this point. If the driver cannot
	// report a count, its returned value is used as-is and the error ignored.
	deleted, _ := res.RowsAffected()
	return int(deleted), nil
}
// Stats returns entity and relation counts plus a per-entity_type histogram
// for an agent, optionally narrowed to one user (an empty userID covers all
// users). The three figures come from separate queries run outside a
// transaction, so they are not guaranteed to be a consistent snapshot.
//
// Any query, scan, or row-iteration error is returned; a partial histogram is
// never reported as success.
func (s *PGKnowledgeGraphStore) Stats(ctx context.Context, agentID, userID string) (*store.GraphStats, error) {
	aid := mustParseUUID(agentID)
	stats := &store.GraphStats{EntityTypes: make(map[string]int)}

	// All three queries share the same optional user filter and argument list.
	userFilter := ""
	args := []any{aid}
	if userID != "" {
		userFilter = " AND user_id = $2"
		args = append(args, userID)
	}

	if err := s.db.QueryRowContext(ctx,
		`SELECT COUNT(*) FROM kg_entities WHERE agent_id = $1`+userFilter, args...,
	).Scan(&stats.EntityCount); err != nil {
		return nil, fmt.Errorf("count kg entities: %w", err)
	}

	if err := s.db.QueryRowContext(ctx,
		`SELECT COUNT(*) FROM kg_relations WHERE agent_id = $1`+userFilter, args...,
	).Scan(&stats.RelationCount); err != nil {
		return nil, fmt.Errorf("count kg relations: %w", err)
	}

	rows, err := s.db.QueryContext(ctx,
		`SELECT entity_type, COUNT(*) FROM kg_entities WHERE agent_id = $1`+userFilter+` GROUP BY entity_type`, args...,
	)
	if err != nil {
		return nil, fmt.Errorf("count kg entity types: %w", err)
	}
	defer rows.Close()

	for rows.Next() {
		var entityType string
		var count int
		if err := rows.Scan(&entityType, &count); err != nil {
			return nil, fmt.Errorf("scan kg entity type count: %w", err)
		}
		stats.EntityTypes[entityType] = count
	}
	// Next returns false both at the end of the rows and on failure
	// (e.g. connection loss or cancellation); Err distinguishes the two.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("iterate kg entity type counts: %w", err)
	}
	return stats, nil
}
// Close releases nothing and always returns nil. The *sql.DB handle was
// supplied by the caller of NewPGKnowledgeGraphStore and is not closed here.
func (s *PGKnowledgeGraphStore) Close() error {
	return nil
}
// --- scan helpers ---

// rowScanner is the single Scan method shared by *sql.Row and *sql.Rows. It
// lets one row decoder serve both QueryRow results and iterated result sets.
type rowScanner interface {
	Scan(args ...any) error
}
// scanEntity decodes one row into a store.Entity. The row must have the
// column order used by the kg_entities SELECTs in this file. Timestamps are
// converted to Unix milliseconds. Decoding of the properties JSON is
// best-effort, and its error is deliberately ignored. A Scan error is
// returned unchanged.
func scanEntity(row rowScanner) (*store.Entity, error) {
	var ent store.Entity
	var rawProps []byte
	var created, updated time.Time

	err := row.Scan(
		&ent.ID, &ent.AgentID, &ent.UserID, &ent.ExternalID, &ent.Name, &ent.EntityType,
		&ent.Description, &rawProps, &ent.SourceID, &ent.Confidence, &created, &updated,
	)
	if err != nil {
		return nil, err
	}

	_ = json.Unmarshal(rawProps, &ent.Properties) // best-effort; see doc comment
	ent.CreatedAt, ent.UpdatedAt = created.UnixMilli(), updated.UnixMilli()
	return &ent, nil
}
// scanEntities drains rows into a slice of entities, decoding each row with
// scanEntity. A row that fails to scan aborts the read and its error is
// returned, matching scanEntity's single-row behaviour, so callers never
// receive a silently truncated list with a nil error. An iteration error
// reported by rows.Err is returned the same way. The caller remains
// responsible for closing rows.
func scanEntities(rows *sql.Rows) ([]store.Entity, error) {
	var result []store.Entity
	for rows.Next() {
		e, err := scanEntity(rows)
		if err != nil {
			return nil, fmt.Errorf("scan kg entity: %w", err)
		}
		result = append(result, *e)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return result, nil
}
// scanRelations drains rows into a slice of relations. Each row must have the
// column order used by the kg_relations SELECTs in this file. created_at is
// converted to Unix milliseconds, and decoding of the properties JSON is
// best-effort.
//
// A row that fails to scan aborts the read and its error is returned, so
// callers never receive a silently truncated list with a nil error. An
// iteration error reported by rows.Err is returned the same way. The caller
// remains responsible for closing rows.
func scanRelations(rows *sql.Rows) ([]store.Relation, error) {
	var result []store.Relation
	for rows.Next() {
		var r store.Relation
		var props []byte
		var createdAt time.Time
		if err := rows.Scan(
			&r.ID, &r.AgentID, &r.UserID, &r.SourceEntityID, &r.RelationType,
			&r.TargetEntityID, &r.Confidence, &props, &createdAt,
		); err != nil {
			return nil, fmt.Errorf("scan kg relation: %w", err)
		}
		// Properties are best-effort: a decode error is deliberately ignored.
		json.Unmarshal(props, &r.Properties) //nolint:errcheck
		r.CreatedAt = createdAt.UnixMilli()
		result = append(result, r)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return result, nil
}