mirror of
https://github.com/tiennm99/goclaw.git
synced 2026-06-10 08:11:23 +00:00
49f51da81c
Zero temperature was too rigid, causing LLM to miss implied entities and relations. 0.2 allows picking up contextual connections while staying deterministic for structured JSON output.
290 lines
7.9 KiB
Go
290 lines
7.9 KiB
Go
package knowledgegraph
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"log/slog"
|
|
"strings"
|
|
|
|
"github.com/nextlevelbuilder/goclaw/internal/providers"
|
|
"github.com/nextlevelbuilder/goclaw/internal/store"
|
|
)
|
|
|
|
// ExtractionResult holds entities and relations extracted from text.
|
|
type ExtractionResult struct {
|
|
Entities []store.Entity `json:"entities"`
|
|
Relations []store.Relation `json:"relations"`
|
|
}
|
|
|
|
// Extractor extracts entities and relations from text using an LLM.
|
|
type Extractor struct {
|
|
provider providers.Provider
|
|
model string
|
|
minConfidence float64
|
|
}
|
|
|
|
// NewExtractor creates a new Extractor with the given provider, model, and confidence threshold.
|
|
func NewExtractor(provider providers.Provider, model string, minConfidence float64) *Extractor {
|
|
if minConfidence <= 0 {
|
|
minConfidence = 0.75
|
|
}
|
|
return &Extractor{provider: provider, model: model, minConfidence: minConfidence}
|
|
}
|
|
|
|
const maxChunkChars = 12000
|
|
|
|
// Extract calls the LLM to extract entities and relations from text.
|
|
// For long texts, it splits into chunks, extracts from each, and merges results.
|
|
func (e *Extractor) Extract(ctx context.Context, text string) (*ExtractionResult, error) {
|
|
// Short text: single extraction
|
|
if len(text) <= maxChunkChars {
|
|
return e.extractChunk(ctx, text)
|
|
}
|
|
|
|
// Long text: split into chunks and merge
|
|
chunks := splitChunks(text, maxChunkChars)
|
|
slog.Info("kg extraction: splitting long input", "chunks", len(chunks), "total_len", len(text))
|
|
|
|
merged := &ExtractionResult{}
|
|
for i, chunk := range chunks {
|
|
result, err := e.extractChunk(ctx, chunk)
|
|
if err != nil {
|
|
slog.Warn("kg extraction: chunk failed", "chunk", i+1, "total", len(chunks), "error", err)
|
|
continue // skip failed chunk, extract what we can
|
|
}
|
|
merged = mergeResults(merged, result)
|
|
}
|
|
return merged, nil
|
|
}
|
|
|
|
// extractChunk extracts entities from a single chunk of text.
|
|
func (e *Extractor) extractChunk(ctx context.Context, text string) (*ExtractionResult, error) {
|
|
req := providers.ChatRequest{
|
|
Messages: []providers.Message{
|
|
{Role: "system", Content: extractionSystemPrompt},
|
|
{Role: "user", Content: text},
|
|
},
|
|
Model: e.model,
|
|
Options: map[string]any{
|
|
"max_tokens": 8192,
|
|
"temperature": 0.2,
|
|
},
|
|
}
|
|
|
|
resp, err := e.provider.Chat(ctx, req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("kg extraction LLM call: %w", err)
|
|
}
|
|
|
|
// If response was truncated, retry with shorter input
|
|
if resp.FinishReason == "length" {
|
|
slog.Warn("kg extraction: response truncated, retrying with shorter input")
|
|
const retryMaxChars = 8000
|
|
if len(text) > retryMaxChars {
|
|
text = text[:retryMaxChars] + "\n\n[...truncated]"
|
|
}
|
|
req.Messages[1].Content = text
|
|
resp, err = e.provider.Chat(ctx, req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("kg extraction LLM retry: %w", err)
|
|
}
|
|
if resp.FinishReason == "length" {
|
|
return nil, fmt.Errorf("kg extraction: response still truncated after retry")
|
|
}
|
|
}
|
|
|
|
// Parse JSON response
|
|
var result ExtractionResult
|
|
content := strings.TrimSpace(resp.Content)
|
|
content = stripCodeBlock(content)
|
|
|
|
originalContent := content
|
|
content = sanitizeJSON(content)
|
|
if content != originalContent {
|
|
slog.Debug("kg extraction: sanitized JSON output",
|
|
"original_len", len(originalContent),
|
|
"sanitized_len", len(content),
|
|
)
|
|
}
|
|
|
|
if err := json.Unmarshal([]byte(content), &result); err != nil {
|
|
preview := originalContent
|
|
if len(preview) > 300 {
|
|
preview = preview[:300] + "..."
|
|
}
|
|
slog.Warn("kg extraction: failed to parse LLM response",
|
|
"error", err,
|
|
"content_len", len(originalContent),
|
|
"finish_reason", resp.FinishReason,
|
|
"preview", preview,
|
|
)
|
|
return nil, fmt.Errorf("parse extraction result: %w", err)
|
|
}
|
|
|
|
// Filter by confidence threshold and normalize
|
|
filtered := &ExtractionResult{}
|
|
for _, ent := range result.Entities {
|
|
if ent.Confidence >= e.minConfidence {
|
|
ent.ExternalID = strings.ToLower(strings.TrimSpace(ent.ExternalID))
|
|
ent.Name = strings.TrimSpace(ent.Name)
|
|
ent.EntityType = strings.ToLower(strings.TrimSpace(ent.EntityType))
|
|
filtered.Entities = append(filtered.Entities, ent)
|
|
}
|
|
}
|
|
for _, rel := range result.Relations {
|
|
if rel.Confidence >= e.minConfidence {
|
|
rel.SourceEntityID = strings.ToLower(strings.TrimSpace(rel.SourceEntityID))
|
|
rel.TargetEntityID = strings.ToLower(strings.TrimSpace(rel.TargetEntityID))
|
|
rel.RelationType = strings.ToLower(strings.TrimSpace(rel.RelationType))
|
|
filtered.Relations = append(filtered.Relations, rel)
|
|
}
|
|
}
|
|
return filtered, nil
|
|
}
|
|
|
|
// sanitizeJSON fixes common LLM JSON issues while preserving string values.
|
|
// It walks the JSON character-by-character, only applying fixes outside quoted strings:
|
|
// - Malformed decimals: "0. 85" → "0.85"
|
|
// - Trailing commas: [1, 2,] → [1, 2]
|
|
func sanitizeJSON(s string) string {
|
|
var b strings.Builder
|
|
b.Grow(len(s))
|
|
|
|
inString := false
|
|
escaped := false
|
|
|
|
for i := 0; i < len(s); i++ {
|
|
ch := s[i]
|
|
|
|
if escaped {
|
|
b.WriteByte(ch)
|
|
escaped = false
|
|
continue
|
|
}
|
|
|
|
if ch == '\\' && inString {
|
|
b.WriteByte(ch)
|
|
escaped = true
|
|
continue
|
|
}
|
|
|
|
if ch == '"' {
|
|
inString = !inString
|
|
b.WriteByte(ch)
|
|
continue
|
|
}
|
|
|
|
if inString {
|
|
b.WriteByte(ch)
|
|
continue
|
|
}
|
|
|
|
// Fix malformed decimals: "0. 85" → "0.85"
|
|
if ch == '.' && i > 0 && isDigit(s[i-1]) {
|
|
b.WriteByte('.')
|
|
for i+1 < len(s) && s[i+1] == ' ' {
|
|
i++
|
|
}
|
|
continue
|
|
}
|
|
|
|
// Fix trailing commas: skip comma if next non-whitespace is } or ]
|
|
if ch == ',' {
|
|
j := i + 1
|
|
for j < len(s) && (s[j] == ' ' || s[j] == '\t' || s[j] == '\n' || s[j] == '\r') {
|
|
j++
|
|
}
|
|
if j < len(s) && (s[j] == '}' || s[j] == ']') {
|
|
continue
|
|
}
|
|
}
|
|
|
|
b.WriteByte(ch)
|
|
}
|
|
|
|
return b.String()
|
|
}
|
|
|
|
func isDigit(c byte) bool {
|
|
return c >= '0' && c <= '9'
|
|
}
|
|
|
|
// splitChunks splits text into chunks at paragraph boundaries (\n\n).
|
|
func splitChunks(text string, maxChars int) []string {
|
|
if len(text) <= maxChars {
|
|
return []string{text}
|
|
}
|
|
|
|
var chunks []string
|
|
for len(text) > 0 {
|
|
if len(text) <= maxChars {
|
|
chunks = append(chunks, text)
|
|
break
|
|
}
|
|
// Find last paragraph break within limit
|
|
cut := maxChars
|
|
if idx := strings.LastIndex(text[:cut], "\n\n"); idx > cut/2 {
|
|
cut = idx
|
|
}
|
|
chunks = append(chunks, strings.TrimSpace(text[:cut]))
|
|
text = strings.TrimSpace(text[cut:])
|
|
}
|
|
return chunks
|
|
}
|
|
|
|
// mergeResults merges two extraction results, deduplicating entities by external_id
|
|
// (keeping higher confidence) and relations by source+type+target.
|
|
func mergeResults(a, b *ExtractionResult) *ExtractionResult {
|
|
// Deduplicate entities — keep higher confidence
|
|
entityMap := make(map[string]store.Entity, len(a.Entities)+len(b.Entities))
|
|
for _, ent := range a.Entities {
|
|
entityMap[ent.ExternalID] = ent
|
|
}
|
|
for _, ent := range b.Entities {
|
|
if existing, ok := entityMap[ent.ExternalID]; !ok || ent.Confidence > existing.Confidence {
|
|
entityMap[ent.ExternalID] = ent
|
|
}
|
|
}
|
|
|
|
// Deduplicate relations
|
|
type relKey struct{ src, rel, tgt string }
|
|
relMap := make(map[relKey]store.Relation, len(a.Relations)+len(b.Relations))
|
|
for _, rel := range a.Relations {
|
|
relMap[relKey{rel.SourceEntityID, rel.RelationType, rel.TargetEntityID}] = rel
|
|
}
|
|
for _, rel := range b.Relations {
|
|
k := relKey{rel.SourceEntityID, rel.RelationType, rel.TargetEntityID}
|
|
if existing, ok := relMap[k]; !ok || rel.Confidence > existing.Confidence {
|
|
relMap[k] = rel
|
|
}
|
|
}
|
|
|
|
result := &ExtractionResult{
|
|
Entities: make([]store.Entity, 0, len(entityMap)),
|
|
Relations: make([]store.Relation, 0, len(relMap)),
|
|
}
|
|
for _, ent := range entityMap {
|
|
result.Entities = append(result.Entities, ent)
|
|
}
|
|
for _, rel := range relMap {
|
|
result.Relations = append(result.Relations, rel)
|
|
}
|
|
return result
|
|
}
|
|
|
|
// stripCodeBlock removes ```json ... ``` wrapper if present.
|
|
func stripCodeBlock(s string) string {
|
|
s = strings.TrimSpace(s)
|
|
if strings.HasPrefix(s, "```") {
|
|
if idx := strings.Index(s, "\n"); idx >= 0 {
|
|
s = s[idx+1:]
|
|
}
|
|
if idx := strings.LastIndex(s, "```"); idx >= 0 {
|
|
s = s[:idx]
|
|
}
|
|
}
|
|
return strings.TrimSpace(s)
|
|
}
|
|
|