Files
viettranx 49f51da81c fix(kg): raise extraction temperature from 0.0 to 0.2
Zero temperature was too rigid, causing LLM to miss implied entities
and relations. 0.2 allows picking up contextual connections while
staying deterministic for structured JSON output.
2026-03-31 19:06:38 +07:00

290 lines
7.9 KiB
Go

package knowledgegraph
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"strings"
"github.com/nextlevelbuilder/goclaw/internal/providers"
"github.com/nextlevelbuilder/goclaw/internal/store"
)
// ExtractionResult holds entities and relations extracted from text.
// Its JSON shape ({"entities": [...], "relations": [...]}) is the format the
// LLM reply is decoded into by extractChunk.
type ExtractionResult struct {
	Entities  []store.Entity   `json:"entities"`
	Relations []store.Relation `json:"relations"`
}
// Extractor extracts entities and relations from text using an LLM.
// Construct it with NewExtractor.
type Extractor struct {
	provider      providers.Provider // LLM backend that receives every extraction request
	model         string             // model name placed in each ChatRequest
	minConfidence float64            // entities/relations with Confidence below this are discarded
}
// NewExtractor returns an Extractor that sends requests to provider using the
// given model name. Extracted items whose confidence is below minConfidence are
// dropped; passing a non-positive minConfidence selects the default of 0.75.
func NewExtractor(provider providers.Provider, model string, minConfidence float64) *Extractor {
	const defaultMinConfidence = 0.75

	ex := &Extractor{
		provider:      provider,
		model:         model,
		minConfidence: minConfidence,
	}
	if ex.minConfidence <= 0 {
		ex.minConfidence = defaultMinConfidence
	}
	return ex
}
// maxChunkChars is the largest input, in bytes, sent to the LLM in a single
// extraction request; longer inputs are split with splitChunks.
const maxChunkChars = 12000

// Extract calls the LLM to extract entities and relations from text.
//
// Inputs of at most maxChunkChars bytes are handled by a single extractChunk
// call whose result and error are returned as-is. Longer inputs are split with
// splitChunks, each chunk is extracted independently, and the per-chunk results
// are combined with mergeResults.
//
// A chunk that fails is logged and skipped so the remaining chunks still
// contribute (best effort). An error is returned only when ctx is cancelled —
// partial results are then discarded — or when every chunk fails, in which case
// the last chunk's error is wrapped.
func (e *Extractor) Extract(ctx context.Context, text string) (*ExtractionResult, error) {
	// Short text: single extraction
	if len(text) <= maxChunkChars {
		return e.extractChunk(ctx, text)
	}

	// Long text: split into chunks and merge
	chunks := splitChunks(text, maxChunkChars)
	slog.Info("kg extraction: splitting long input", "chunks", len(chunks), "total_len", len(text))

	merged := &ExtractionResult{}
	var lastErr error
	succeeded := 0
	for i, chunk := range chunks {
		// Stop on cancellation instead of issuing (and logging) a doomed
		// request for every remaining chunk.
		if err := ctx.Err(); err != nil {
			return nil, fmt.Errorf("kg extraction: %w", err)
		}
		result, err := e.extractChunk(ctx, chunk)
		if err != nil {
			if ctxErr := ctx.Err(); ctxErr != nil {
				return nil, fmt.Errorf("kg extraction: %w", ctxErr)
			}
			slog.Warn("kg extraction: chunk failed", "chunk", i+1, "total", len(chunks), "error", err)
			lastErr = err
			continue // skip failed chunk, extract what we can
		}
		merged = mergeResults(merged, result)
		succeeded++
	}

	// Nothing worked: surface the failure rather than an indistinguishable
	// empty result.
	if succeeded == 0 && lastErr != nil {
		return nil, fmt.Errorf("kg extraction: all %d chunks failed: %w", len(chunks), lastErr)
	}
	return merged, nil
}
// extractChunk sends one chunk of text to the LLM and returns the entities and
// relations it reports, filtered and normalized.
//
// Steps:
//  1. Send extractionSystemPrompt plus text with max_tokens 8192 and
//     temperature 0.2.
//  2. If the reply stopped with FinishReason "length" (output cut off), retry
//     once with the input shortened to at most 8000 bytes — cut on a UTF-8
//     rune boundary and marked "[...truncated]". Inputs already that short are
//     resent unchanged. A second "length" finish is an error.
//  3. Strip a Markdown code fence, repair common JSON slips with sanitizeJSON,
//     and decode into ExtractionResult.
//  4. Drop entities and relations whose Confidence is below e.minConfidence;
//     trim and lower-case IDs and types, and trim entity names.
//
// Provider and JSON-decoding errors are wrapped and returned.
func (e *Extractor) extractChunk(ctx context.Context, text string) (*ExtractionResult, error) {
	// truncateUTF8 shortens s to at most n bytes without splitting a multi-byte
	// rune, so the truncated text stays valid UTF-8.
	truncateUTF8 := func(s string, n int) string {
		if len(s) <= n {
			return s
		}
		// UTF-8 continuation bytes have the form 10xxxxxx; back up over them
		// until s[n] is the first byte of a rune.
		for n > 0 && s[n]&0xC0 == 0x80 {
			n--
		}
		return s[:n]
	}

	// Temperature 0.0 proved too rigid and missed implied entities/relations;
	// 0.2 picks up contextual links while keeping the JSON output stable.
	req := providers.ChatRequest{
		Messages: []providers.Message{
			{Role: "system", Content: extractionSystemPrompt},
			{Role: "user", Content: text},
		},
		Model: e.model,
		Options: map[string]any{
			"max_tokens":  8192,
			"temperature": 0.2,
		},
	}
	resp, err := e.provider.Chat(ctx, req)
	if err != nil {
		return nil, fmt.Errorf("kg extraction LLM call: %w", err)
	}

	// If response was truncated, retry with shorter input
	if resp.FinishReason == "length" {
		slog.Warn("kg extraction: response truncated, retrying with shorter input")
		const retryMaxChars = 8000
		if len(text) > retryMaxChars {
			text = truncateUTF8(text, retryMaxChars) + "\n\n[...truncated]"
		}
		req.Messages[1].Content = text
		resp, err = e.provider.Chat(ctx, req)
		if err != nil {
			return nil, fmt.Errorf("kg extraction LLM retry: %w", err)
		}
		if resp.FinishReason == "length" {
			return nil, fmt.Errorf("kg extraction: response still truncated after retry")
		}
	}

	// Parse JSON response
	var result ExtractionResult
	content := strings.TrimSpace(resp.Content)
	content = stripCodeBlock(content)
	originalContent := content
	content = sanitizeJSON(content)
	if content != originalContent {
		slog.Debug("kg extraction: sanitized JSON output",
			"original_len", len(originalContent),
			"sanitized_len", len(content),
		)
	}
	if err := json.Unmarshal([]byte(content), &result); err != nil {
		preview := originalContent
		if len(preview) > 300 {
			preview = truncateUTF8(preview, 300) + "..."
		}
		slog.Warn("kg extraction: failed to parse LLM response",
			"error", err,
			"content_len", len(originalContent),
			"finish_reason", resp.FinishReason,
			"preview", preview,
		)
		return nil, fmt.Errorf("parse extraction result: %w", err)
	}

	// Filter by confidence threshold and normalize
	filtered := &ExtractionResult{}
	for _, ent := range result.Entities {
		if ent.Confidence >= e.minConfidence {
			ent.ExternalID = strings.ToLower(strings.TrimSpace(ent.ExternalID))
			ent.Name = strings.TrimSpace(ent.Name)
			ent.EntityType = strings.ToLower(strings.TrimSpace(ent.EntityType))
			filtered.Entities = append(filtered.Entities, ent)
		}
	}
	for _, rel := range result.Relations {
		if rel.Confidence >= e.minConfidence {
			rel.SourceEntityID = strings.ToLower(strings.TrimSpace(rel.SourceEntityID))
			rel.TargetEntityID = strings.ToLower(strings.TrimSpace(rel.TargetEntityID))
			rel.RelationType = strings.ToLower(strings.TrimSpace(rel.RelationType))
			filtered.Relations = append(filtered.Relations, rel)
		}
	}
	return filtered, nil
}
// sanitizeJSON repairs two common defects in LLM-generated JSON. It scans the
// input byte by byte, tracking whether it is inside a quoted string (honouring
// backslash escapes), and applies repairs only outside strings so quoted
// values pass through untouched:
//   - spaces right after a decimal point that follows a digit are removed:
//     "0. 85" → "0.85"
//   - a comma whose next non-whitespace byte is } or ] is removed:
//     [1, 2,] → [1, 2]
func sanitizeJSON(s string) string {
	var out strings.Builder
	out.Grow(len(s))

	inQuotes, afterBackslash := false, false
	for pos := 0; pos < len(s); pos++ {
		c := s[pos]
		switch {
		case afterBackslash:
			// The byte following a backslash in a string is copied verbatim.
			afterBackslash = false
		case inQuotes && c == '\\':
			afterBackslash = true
		case c == '"':
			inQuotes = !inQuotes
		case inQuotes:
			// Plain string content: copy as-is.
		case c == '.' && pos > 0 && isDigit(s[pos-1]):
			// Keep the '.', but swallow the spaces that follow it.
			for pos+1 < len(s) && s[pos+1] == ' ' {
				pos++
			}
		case c == ',':
			next := pos + 1
			for next < len(s) && strings.IndexByte(" \t\n\r", s[next]) >= 0 {
				next++
			}
			if next < len(s) && (s[next] == '}' || s[next] == ']') {
				continue // trailing comma: drop it
			}
		}
		out.WriteByte(c)
	}
	return out.String()
}

// isDigit reports whether c is an ASCII decimal digit.
func isDigit(c byte) bool {
	return '0' <= c && c <= '9'
}
// splitChunks splits text into chunks of at most maxChars bytes.
//
// For each chunk it looks for the last paragraph break ("\n\n"), then line
// break, then space within the maxChars window, using the first one that lies
// past the middle of the window. If there is none, it makes a hard cut, moved
// back to a UTF-8 rune boundary so multi-byte characters are never split
// across chunks. Chunks are whitespace-trimmed and empty chunks are dropped.
//
// Text of at most maxChars bytes — or any text when maxChars <= 0, which
// previously looped forever — is returned unchanged as a single chunk.
func splitChunks(text string, maxChars int) []string {
	if maxChars <= 0 || len(text) <= maxChars {
		return []string{text}
	}
	var chunks []string
	for len(text) > 0 {
		if len(text) <= maxChars {
			chunks = append(chunks, text)
			break
		}

		// Find the last natural break within the limit.
		cut := -1
		window := text[:maxChars]
		for _, sep := range []string{"\n\n", "\n", " "} {
			if idx := strings.LastIndex(window, sep); idx > maxChars/2 {
				cut = idx
				break
			}
		}
		if cut < 0 {
			// Hard cut: UTF-8 continuation bytes have the form 10xxxxxx, so
			// back up over them to the start of a rune.
			cut = maxChars
			for cut > 0 && text[cut]&0xC0 == 0x80 {
				cut--
			}
			if cut == 0 {
				// maxChars is smaller than the first rune: take that rune whole
				// so the loop always makes progress.
				cut = 1
				for cut < len(text) && text[cut]&0xC0 == 0x80 {
					cut++
				}
			}
		}

		if chunk := strings.TrimSpace(text[:cut]); chunk != "" {
			chunks = append(chunks, chunk)
		}
		text = strings.TrimSpace(text[cut:])
	}
	return chunks
}
// mergeResults combines two extraction results into a new one.
//
// Entities are deduplicated by ExternalID and relations by
// (SourceEntityID, RelationType, TargetEntityID). When a key occurs more than
// once, the item with the strictly higher Confidence wins; ties keep the item
// seen first.
//
// Output order is deterministic: items appear in the order their key was first
// seen, with a's items before b's. The original relied on map iteration order,
// so repeated runs could order results differently. A nil argument is treated
// as an empty result, and neither input is modified.
func mergeResults(a, b *ExtractionResult) *ExtractionResult {
	if a == nil {
		a = &ExtractionResult{}
	}
	if b == nil {
		b = &ExtractionResult{}
	}

	result := &ExtractionResult{
		Entities:  make([]store.Entity, 0, len(a.Entities)+len(b.Entities)),
		Relations: make([]store.Relation, 0, len(a.Relations)+len(b.Relations)),
	}

	// Deduplicate entities — keep higher confidence, preserve first-seen order.
	entityPos := make(map[string]int, len(a.Entities)+len(b.Entities))
	for _, group := range [][]store.Entity{a.Entities, b.Entities} {
		for _, ent := range group {
			if i, seen := entityPos[ent.ExternalID]; seen {
				if ent.Confidence > result.Entities[i].Confidence {
					result.Entities[i] = ent
				}
				continue
			}
			entityPos[ent.ExternalID] = len(result.Entities)
			result.Entities = append(result.Entities, ent)
		}
	}

	// Deduplicate relations by source+type+target the same way.
	type relKey struct{ src, rel, tgt string }
	relPos := make(map[relKey]int, len(a.Relations)+len(b.Relations))
	for _, group := range [][]store.Relation{a.Relations, b.Relations} {
		for _, rel := range group {
			k := relKey{rel.SourceEntityID, rel.RelationType, rel.TargetEntityID}
			if i, seen := relPos[k]; seen {
				if rel.Confidence > result.Relations[i].Confidence {
					result.Relations[i] = rel
				}
				continue
			}
			relPos[k] = len(result.Relations)
			result.Relations = append(result.Relations, rel)
		}
	}

	return result
}
// stripCodeBlock unwraps a Markdown code fence such as ```json … ``` around s.
// When the trimmed input starts with ```, the opening fence line is dropped and
// everything from the last ``` onward is discarded. The result is always
// whitespace-trimmed; input without a leading fence is only trimmed.
func stripCodeBlock(s string) string {
	body := strings.TrimSpace(s)
	if !strings.HasPrefix(body, "```") {
		return body
	}
	// Drop the opening fence line (e.g. "```json").
	if _, rest, found := strings.Cut(body, "\n"); found {
		body = rest
	}
	// Drop the closing fence and anything after it.
	if end := strings.LastIndex(body, "```"); end != -1 {
		body = body[:end]
	}
	return strings.TrimSpace(body)
}