mirror of
https://github.com/tiennm99/goclaw.git
synced 2026-06-10 12:10:53 +00:00
6dcac3d7e3
Add sanitizeJSON() to fix malformed decimal numbers (e.g. '0. 85' → '0.85') and trailing commas before closing brackets. Fixes extraction failures with Gemini 2.5 Flash which occasionally produces invalid JSON.
290 lines
7.9 KiB
Go
290 lines
7.9 KiB
Go
package knowledgegraph
|
|
|
|
import (
	"context"
	"encoding/json"
	"fmt"
	"log/slog"
	"strings"
	"unicode/utf8"

	"github.com/nextlevelbuilder/goclaw/internal/providers"
	"github.com/nextlevelbuilder/goclaw/internal/store"
)
|
|
|
|
// ExtractionResult holds entities and relations extracted from text.
|
|
type ExtractionResult struct {
|
|
Entities []store.Entity `json:"entities"`
|
|
Relations []store.Relation `json:"relations"`
|
|
}
|
|
|
|
// Extractor extracts entities and relations from text using an LLM.
|
|
type Extractor struct {
|
|
provider providers.Provider
|
|
model string
|
|
minConfidence float64
|
|
}
|
|
|
|
// NewExtractor creates a new Extractor with the given provider, model, and confidence threshold.
|
|
func NewExtractor(provider providers.Provider, model string, minConfidence float64) *Extractor {
|
|
if minConfidence <= 0 {
|
|
minConfidence = 0.75
|
|
}
|
|
return &Extractor{provider: provider, model: model, minConfidence: minConfidence}
|
|
}
|
|
|
|
// maxChunkChars is the largest input, measured in bytes (it is compared
// against len(text)), that is sent to the LLM in a single extraction request.
const maxChunkChars = 12_000
|
|
|
|
// Extract calls the LLM to extract entities and relations from text.
|
|
// For long texts, it splits into chunks, extracts from each, and merges results.
|
|
func (e *Extractor) Extract(ctx context.Context, text string) (*ExtractionResult, error) {
|
|
// Short text: single extraction
|
|
if len(text) <= maxChunkChars {
|
|
return e.extractChunk(ctx, text)
|
|
}
|
|
|
|
// Long text: split into chunks and merge
|
|
chunks := splitChunks(text, maxChunkChars)
|
|
slog.Info("kg extraction: splitting long input", "chunks", len(chunks), "total_len", len(text))
|
|
|
|
merged := &ExtractionResult{}
|
|
for i, chunk := range chunks {
|
|
result, err := e.extractChunk(ctx, chunk)
|
|
if err != nil {
|
|
slog.Warn("kg extraction: chunk failed", "chunk", i+1, "total", len(chunks), "error", err)
|
|
continue // skip failed chunk, extract what we can
|
|
}
|
|
merged = mergeResults(merged, result)
|
|
}
|
|
return merged, nil
|
|
}
|
|
|
|
// extractChunk extracts entities from a single chunk of text.
|
|
func (e *Extractor) extractChunk(ctx context.Context, text string) (*ExtractionResult, error) {
|
|
req := providers.ChatRequest{
|
|
Messages: []providers.Message{
|
|
{Role: "system", Content: extractionSystemPrompt},
|
|
{Role: "user", Content: text},
|
|
},
|
|
Model: e.model,
|
|
Options: map[string]any{
|
|
"max_tokens": 8192,
|
|
"temperature": 0.0,
|
|
},
|
|
}
|
|
|
|
resp, err := e.provider.Chat(ctx, req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("kg extraction LLM call: %w", err)
|
|
}
|
|
|
|
// If response was truncated, retry with shorter input
|
|
if resp.FinishReason == "length" {
|
|
slog.Warn("kg extraction: response truncated, retrying with shorter input")
|
|
const retryMaxChars = 8000
|
|
if len(text) > retryMaxChars {
|
|
text = text[:retryMaxChars] + "\n\n[...truncated]"
|
|
}
|
|
req.Messages[1].Content = text
|
|
resp, err = e.provider.Chat(ctx, req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("kg extraction LLM retry: %w", err)
|
|
}
|
|
if resp.FinishReason == "length" {
|
|
return nil, fmt.Errorf("kg extraction: response still truncated after retry")
|
|
}
|
|
}
|
|
|
|
// Parse JSON response
|
|
var result ExtractionResult
|
|
content := strings.TrimSpace(resp.Content)
|
|
content = stripCodeBlock(content)
|
|
|
|
originalContent := content
|
|
content = sanitizeJSON(content)
|
|
if content != originalContent {
|
|
slog.Debug("kg extraction: sanitized JSON output",
|
|
"original_len", len(originalContent),
|
|
"sanitized_len", len(content),
|
|
)
|
|
}
|
|
|
|
if err := json.Unmarshal([]byte(content), &result); err != nil {
|
|
preview := originalContent
|
|
if len(preview) > 300 {
|
|
preview = preview[:300] + "..."
|
|
}
|
|
slog.Warn("kg extraction: failed to parse LLM response",
|
|
"error", err,
|
|
"content_len", len(originalContent),
|
|
"finish_reason", resp.FinishReason,
|
|
"preview", preview,
|
|
)
|
|
return nil, fmt.Errorf("parse extraction result: %w", err)
|
|
}
|
|
|
|
// Filter by confidence threshold and normalize
|
|
filtered := &ExtractionResult{}
|
|
for _, ent := range result.Entities {
|
|
if ent.Confidence >= e.minConfidence {
|
|
ent.ExternalID = strings.ToLower(strings.TrimSpace(ent.ExternalID))
|
|
ent.Name = strings.TrimSpace(ent.Name)
|
|
ent.EntityType = strings.ToLower(strings.TrimSpace(ent.EntityType))
|
|
filtered.Entities = append(filtered.Entities, ent)
|
|
}
|
|
}
|
|
for _, rel := range result.Relations {
|
|
if rel.Confidence >= e.minConfidence {
|
|
rel.SourceEntityID = strings.ToLower(strings.TrimSpace(rel.SourceEntityID))
|
|
rel.TargetEntityID = strings.ToLower(strings.TrimSpace(rel.TargetEntityID))
|
|
rel.RelationType = strings.ToLower(strings.TrimSpace(rel.RelationType))
|
|
filtered.Relations = append(filtered.Relations, rel)
|
|
}
|
|
}
|
|
return filtered, nil
|
|
}
|
|
|
|
// sanitizeJSON fixes common LLM JSON issues while preserving string values.
|
|
// It walks the JSON character-by-character, only applying fixes outside quoted strings:
|
|
// - Malformed decimals: "0. 85" → "0.85"
|
|
// - Trailing commas: [1, 2,] → [1, 2]
|
|
func sanitizeJSON(s string) string {
|
|
var b strings.Builder
|
|
b.Grow(len(s))
|
|
|
|
inString := false
|
|
escaped := false
|
|
|
|
for i := 0; i < len(s); i++ {
|
|
ch := s[i]
|
|
|
|
if escaped {
|
|
b.WriteByte(ch)
|
|
escaped = false
|
|
continue
|
|
}
|
|
|
|
if ch == '\\' && inString {
|
|
b.WriteByte(ch)
|
|
escaped = true
|
|
continue
|
|
}
|
|
|
|
if ch == '"' {
|
|
inString = !inString
|
|
b.WriteByte(ch)
|
|
continue
|
|
}
|
|
|
|
if inString {
|
|
b.WriteByte(ch)
|
|
continue
|
|
}
|
|
|
|
// Fix malformed decimals: "0. 85" → "0.85"
|
|
if ch == '.' && i > 0 && isDigit(s[i-1]) {
|
|
b.WriteByte('.')
|
|
for i+1 < len(s) && s[i+1] == ' ' {
|
|
i++
|
|
}
|
|
continue
|
|
}
|
|
|
|
// Fix trailing commas: skip comma if next non-whitespace is } or ]
|
|
if ch == ',' {
|
|
j := i + 1
|
|
for j < len(s) && (s[j] == ' ' || s[j] == '\t' || s[j] == '\n' || s[j] == '\r') {
|
|
j++
|
|
}
|
|
if j < len(s) && (s[j] == '}' || s[j] == ']') {
|
|
continue
|
|
}
|
|
}
|
|
|
|
b.WriteByte(ch)
|
|
}
|
|
|
|
return b.String()
|
|
}
|
|
|
|
// isDigit reports whether c is an ASCII decimal digit ('0' through '9').
func isDigit(c byte) bool {
	return '0' <= c && c <= '9'
}
|
|
|
|
// splitChunks splits text into pieces of at most maxChars bytes each.
//
// Text that already fits is returned as a single, unmodified chunk. Otherwise
// each cut is placed at the last paragraph break ("\n\n") inside the limit,
// provided that break lies in the second half of the window; if there is no
// such break the cut falls at the limit itself, moved back to the nearest
// UTF-8 rune boundary so that a multi-byte character is never split. Chunks
// are whitespace-trimmed and chunks that are empty after trimming are dropped.
//
// A non-positive maxChars disables splitting: the text is returned as one
// chunk.
func splitChunks(text string, maxChars int) []string {
	if maxChars <= 0 || len(text) <= maxChars {
		return []string{text}
	}

	var chunks []string
	for len(text) > maxChars {
		cut := maxChars
		if idx := strings.LastIndex(text[:cut], "\n\n"); idx > cut/2 {
			// Prefer a paragraph break; "\n" is ASCII, so idx is a rune boundary.
			cut = idx
		} else {
			// Hard cut: step back to the first byte of the rune at the limit.
			// A rune is at most utf8.UTFMax bytes, which bounds the walk even
			// for invalid UTF-8.
			for back := 0; back < utf8.UTFMax-1 && cut > 0 && !utf8.RuneStart(text[cut]); back++ {
				cut--
			}
			if cut == 0 {
				// maxChars is smaller than the first rune; emit that rune
				// whole so the loop always makes progress.
				_, size := utf8.DecodeRuneInString(text)
				cut = size
			}
		}

		if chunk := strings.TrimSpace(text[:cut]); chunk != "" {
			chunks = append(chunks, chunk)
		}
		text = strings.TrimSpace(text[cut:])
	}

	if text != "" {
		chunks = append(chunks, text)
	}
	return chunks
}
|
|
|
|
// mergeResults merges two extraction results, deduplicating entities by external_id
|
|
// (keeping higher confidence) and relations by source+type+target.
|
|
func mergeResults(a, b *ExtractionResult) *ExtractionResult {
|
|
// Deduplicate entities — keep higher confidence
|
|
entityMap := make(map[string]store.Entity, len(a.Entities)+len(b.Entities))
|
|
for _, ent := range a.Entities {
|
|
entityMap[ent.ExternalID] = ent
|
|
}
|
|
for _, ent := range b.Entities {
|
|
if existing, ok := entityMap[ent.ExternalID]; !ok || ent.Confidence > existing.Confidence {
|
|
entityMap[ent.ExternalID] = ent
|
|
}
|
|
}
|
|
|
|
// Deduplicate relations
|
|
type relKey struct{ src, rel, tgt string }
|
|
relMap := make(map[relKey]store.Relation, len(a.Relations)+len(b.Relations))
|
|
for _, rel := range a.Relations {
|
|
relMap[relKey{rel.SourceEntityID, rel.RelationType, rel.TargetEntityID}] = rel
|
|
}
|
|
for _, rel := range b.Relations {
|
|
k := relKey{rel.SourceEntityID, rel.RelationType, rel.TargetEntityID}
|
|
if existing, ok := relMap[k]; !ok || rel.Confidence > existing.Confidence {
|
|
relMap[k] = rel
|
|
}
|
|
}
|
|
|
|
result := &ExtractionResult{
|
|
Entities: make([]store.Entity, 0, len(entityMap)),
|
|
Relations: make([]store.Relation, 0, len(relMap)),
|
|
}
|
|
for _, ent := range entityMap {
|
|
result.Entities = append(result.Entities, ent)
|
|
}
|
|
for _, rel := range relMap {
|
|
result.Relations = append(result.Relations, rel)
|
|
}
|
|
return result
|
|
}
|
|
|
|
// stripCodeBlock unwraps a Markdown code fence around s, if there is one.
//
// When the trimmed input starts with ``` the whole first line (the fence and
// any language tag such as "json") is removed, and everything from the last
// ``` onwards is cut off. The result is whitespace-trimmed. Input that does
// not start with a fence is only trimmed.
func stripCodeBlock(s string) string {
	body := strings.TrimSpace(s)
	if !strings.HasPrefix(body, "```") {
		return body
	}

	// Drop the opening fence line, if the fence is followed by a newline.
	if _, afterFence, ok := strings.Cut(body, "\n"); ok {
		body = afterFence
	}
	// Drop the closing fence and anything after it.
	if end := strings.LastIndex(body, "```"); end != -1 {
		body = body[:end]
	}
	return strings.TrimSpace(body)
}
|
|
|