From 0740dffd6bf736cbb41d20136745ca3bb618a71b Mon Sep 17 00:00:00 2001 From: tiennm99 Date: Wed, 22 Apr 2026 23:53:36 +0700 Subject: [PATCH] refactor(doantu): swap ConceptNet for Workers AI bge-m3 embeddings MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Mirror the semantle migration but with @cf/baai/bge-m3 — BAAI's multilingual embedding model — because the English-only BGE variants can't produce meaningful Vietnamese vectors (their tokenizer shreds diacritics into noisy byte-level subwords). bge-m3 is trained across 194 languages incl. Vietnamese and is actually cheaper in Neurons (1,075 vs 1,841 per M tokens for bge-small-en-v1.5). Vocab check reuses the local Viet22K wordlist as an in-memory Set — O(1) OOV detection, no upstream call. Also add a test file for the module (mirrors semantle coverage plus Vietnamese-specific cases: diacritics, multi-syllable compounds). --- src/modules/doantu/api-client.js | 165 +++++++++--------------- src/modules/doantu/index.js | 13 +- tests/modules/doantu/api-client.test.js | 144 +++++++++++++++++++++ 3 files changed, 214 insertions(+), 108 deletions(-) create mode 100644 tests/modules/doantu/api-client.test.js diff --git a/src/modules/doantu/api-client.js b/src/modules/doantu/api-client.js index 89d5b1f..665a6aa 100644 --- a/src/modules/doantu/api-client.js +++ b/src/modules/doantu/api-client.js @@ -1,21 +1,29 @@ /** - * @file ConceptNet API client for the doantu module (Vietnamese). + * @file Cloudflare Workers AI client for the doantu module (Vietnamese semantle). * - * Mirrors semantle/api-client.js one-for-one — same endpoints, same - * response shape — with two Vietnamese-specific changes: - * 1. Concept URIs use `/c/vi/…` instead of `/c/en/…`. - * 2. Multi-word terms are joined with an underscore for URL building - * (`con chó` → `/c/vi/con_chó`), so the board can still display the - * space-separated form while ConceptNet resolves the canonical URI. 
+ * Mirrors semantle/api-client.js but uses `@cf/baai/bge-m3` — BAAI's + * multilingual embedding model — because the English-only BGE variants + * can't produce meaningful Vietnamese vectors (their tokenizer is + * English-centric and Vietnamese diacritics get shredded into noisy + * byte-level subwords). + * + * Vocabulary: the curated `words-data.js` list (duyet/vietnamese-wordlist + * Viet22K) doubles as the in/out-of-vocabulary set. Lookups are O(1) via + * Set.has(), so OOV detection needs no extra round-trip. + * + * The returned `similarity(a, b)` shape matches the semantle sibling so + * handlers/render/state can be reused unchanged. */ import { randomLine } from "./wordlist.js"; +import WORDS from "./words-data.js"; -const DEFAULT_API_BASE = "https://api.conceptnet.io"; -const DEFAULT_TIMEOUT_MS = 5000; -const USER_AGENT = "miti99bot/doantu"; -const MAX_RANDOM_ATTEMPTS = 5; -const LANG = "vi"; +// BGE-M3: multilingual (194 languages incl. Vietnamese), 1024 dimensions, +// ~1,075 Neurons per M input tokens — cheaper than bge-small-en-v1.5. +const DEFAULT_MODEL = "@cf/baai/bge-m3"; + +// O(1) membership lookup for OOV detection. Built once per isolate. +const VOCAB = new Set(WORDS); export class UpstreamError extends Error { /** @param {string} message @param {{status?: number, body?: string, cause?: unknown}} [meta] */ @@ -28,117 +36,70 @@ export class UpstreamError extends Error { } } -/** `con chó` → `con_chó` — ConceptNet's concept-URI convention for phrases. 
*/ -function toConceptTerm(word) { - return String(word).trim().replace(/\s+/g, "_"); -} - -function buildUrl(base, path, params = {}) { - const normalized = String(base).replace(/\/+$/, ""); - const url = new URL(`${normalized}${path}`); - for (const [k, v] of Object.entries(params)) { - if (v === undefined || v === null) continue; - url.searchParams.set(k, String(v)); +function cosineSimilarity(a, b) { + if (!a || !b || a.length !== b.length) return null; + let dot = 0; + let normA = 0; + let normB = 0; + for (let i = 0; i < a.length; i++) { + dot += a[i] * b[i]; + normA += a[i] * a[i]; + normB += b[i] * b[i]; } - return url.toString(); -} - -async function fetchJson(url, timeoutMs) { - const controller = new AbortController(); - const timer = setTimeout(() => controller.abort(), timeoutMs); - let res; - try { - res = await fetch(url, { - headers: { "User-Agent": USER_AGENT, Accept: "application/json" }, - signal: controller.signal, - }); - } catch (err) { - clearTimeout(timer); - throw new UpstreamError("conceptnet fetch failed", { cause: err }); - } - clearTimeout(timer); - const text = await res.text(); - if (!res.ok) { - throw new UpstreamError(`conceptnet HTTP ${res.status}`, { - status: res.status, - body: text.slice(0, 500), - }); - } - try { - return JSON.parse(text); - } catch (err) { - throw new UpstreamError("conceptnet non-JSON response", { cause: err }); - } -} - -function hasEdges(concept) { - return Array.isArray(concept?.edges) && concept.edges.length > 0; + const denom = Math.sqrt(normA) * Math.sqrt(normB); + return denom === 0 ? null : dot / denom; } /** - * @param {string} [apiBase] — override for mirrors/tests. - * @param {{ timeoutMs?: number }} [opts] + * @param {{ run: (model: string, input: { text: string[] }) => Promise<{ data: number[][] }> }} ai + * — Workers AI binding (`env.AI`). Tests pass a fake with the same `.run()` shape. 
+ * @param {{ model?: string }} [opts] */ -export function createClient(apiBase = DEFAULT_API_BASE, { timeoutMs = DEFAULT_TIMEOUT_MS } = {}) { - /** @param {string} term */ - function concept(term) { - const cn = toConceptTerm(term); - return fetchJson(buildUrl(apiBase, `/c/${LANG}/${encodeURIComponent(cn)}`), timeoutMs); +export function createClient(ai, { model = DEFAULT_MODEL } = {}) { + if (!ai || typeof ai.run !== "function") { + throw new TypeError("createClient: ai binding with .run(model, input) is required"); } - /** @param {string} a @param {string} b */ - function relatedness(a, b) { - return fetchJson( - buildUrl(apiBase, "/relatedness", { - node1: `/c/${LANG}/${toConceptTerm(a)}`, - node2: `/c/${LANG}/${toConceptTerm(b)}`, - }), - timeoutMs, - ); + async function embedPair(a, b) { + let resp; + try { + resp = await ai.run(model, { text: [a, b] }); + } catch (err) { + throw new UpstreamError("workers-ai embedding failed", { cause: err }); + } + const data = resp?.data; + if (!Array.isArray(data) || data.length < 2) { + throw new UpstreamError("workers-ai returned malformed embedding payload"); + } + return [data[0], data[1]]; } return { - concept, - relatedness, - /** - * Pick a target word from the local pool, verify via the concept endpoint, - * fall back to an unverified pick after a few misses. + * Pick a target word from the local Vietnamese pool. The pool IS the + * vocabulary, so every pick is trivially verified. * @returns {Promise<{ word: string, verified: boolean }>} */ async randomWord() { - for (let i = 0; i < MAX_RANDOM_ATTEMPTS; i++) { - const candidate = randomLine(); - try { - const c = await concept(candidate); - if (hasEdges(c)) return { word: candidate, verified: true }; - } catch { - // swallow — try the next candidate - } - } - return { word: randomLine(), verified: false }; + return { word: randomLine(), verified: true }; }, /** - * Cosine-like similarity. 
Runs concept edge-check for `b` in parallel - with the relatedness call so OOV guesses are caught on the same round-trip. - * Shape mirrors the semantle sibling. + * Cosine similarity between `a` (target) and `b` (guess). Uses the local + * Vietnamese wordlist as vocabulary — unknown words return + * `in_vocab_b: false` with `similarity: null` and skip inference. + * * @param {string} a * @param {string} b */ async similarity(a, b) { - const [conceptB, rel] = await Promise.all([concept(b), relatedness(a, b)]); - const inVocabB = hasEdges(conceptB); - const value = typeof rel?.value === "number" ? rel.value : null; - return { - a, - b, - canonical_a: a, - canonical_b: b, - in_vocab_a: true, - in_vocab_b: inVocabB, - similarity: inVocabB ? value : null, - }; + const base = { a, b, canonical_a: a, canonical_b: b, in_vocab_a: true }; + if (!VOCAB.has(b)) { + return { ...base, in_vocab_b: false, similarity: null }; + } + const [vecA, vecB] = await embedPair(a, b); + const sim = cosineSimilarity(vecA, vecB); + return { ...base, in_vocab_b: true, similarity: sim }; }, }; } diff --git a/src/modules/doantu/index.js b/src/modules/doantu/index.js index 0d08f29..07fdc9a 100644 --- a/src/modules/doantu/index.js +++ b/src/modules/doantu/index.js @@ -1,10 +1,11 @@ /** * @file Doantu module — Vietnamese semantle. * - * Targets from a curated local wordlist (duyet/vietnamese-wordlist Viet22K); - * similarity scores from api.conceptnet.io's `/relatedness` endpoint against - * `/c/vi/` concept URIs. All commands are `protected` — listed in - * /help but hidden from Telegram's native / autocomplete menu. + * Targets from a curated local wordlist (duyet/vietnamese-wordlist Viet22K — + * same list doubles as the vocabulary for OOV detection). Similarity scores + * come from cosine similarity between `@cf/baai/bge-m3` multilingual embeddings + * produced by the `env.AI` binding. All commands are `protected` — listed + * in /help but hidden from Telegram's native / autocomplete menu. 
*/ import { createClient } from "./api-client.js"; @@ -18,9 +19,9 @@ let client = null; /** @type {import("../registry.js").BotModule} */ const doantuModule = { name: "doantu", - init: async ({ db: store }) => { + init: async ({ db: store, env }) => { db = store; - client = createClient(); + client = createClient(env.AI); }, commands: [ { diff --git a/tests/modules/doantu/api-client.test.js b/tests/modules/doantu/api-client.test.js new file mode 100644 index 0000000..7ff9577 --- /dev/null +++ b/tests/modules/doantu/api-client.test.js @@ -0,0 +1,144 @@ +import { describe, expect, it, vi } from "vitest"; +import { UpstreamError, createClient } from "../../../src/modules/doantu/api-client.js"; + +/** + * Build a deterministic 1024-dim vector from a seed so cosine scores are + * reproducible in tests without hardcoding floats. bge-m3 produces 1024-dim + * vectors; tests use the same width for realism. + */ +function fakeVector(seed, dim = 1024) { + const out = new Array(dim); + for (let i = 0; i < dim; i++) out[i] = Math.sin(seed * (i + 1)); + return out; +} + +/** + * Minimal Workers AI binding fake. `impl(model, input)` returns the payload + * `env.AI.run()` would normally resolve to. 
+ */ +function fakeAi(impl) { + return { run: vi.fn(impl) }; +} + +describe("doantu/api-client", () => { + describe("UpstreamError", () => { + it("stores status and body metadata", () => { + const err = new UpstreamError("test", { status: 404, body: "not found" }); + expect(err.message).toBe("test"); + expect(err.status).toBe(404); + expect(err.body).toBe("not found"); + expect(err.name).toBe("UpstreamError"); + }); + + it("stores cause when provided", () => { + const cause = new Error("underlying"); + const err = new UpstreamError("wrapper", { cause }); + expect(err.cause).toBe(cause); + }); + }); + + describe("createClient", () => { + it("throws without a valid AI binding", () => { + expect(() => createClient(null)).toThrow(TypeError); + expect(() => createClient({})).toThrow(TypeError); + expect(() => createClient({ run: "not a function" })).toThrow(TypeError); + }); + + it("similarity batches target + guess in a single run() call with bge-m3", async () => { + const ai = fakeAi(async (_model, { text }) => ({ + shape: [text.length, 1024], + data: text.map((_, i) => fakeVector(i + 1)), + })); + const client = createClient(ai); + await client.similarity("chó", "mèo"); + expect(ai.run).toHaveBeenCalledTimes(1); + const [model, input] = ai.run.mock.calls[0]; + expect(model).toBe("@cf/baai/bge-m3"); + expect(input).toEqual({ text: ["chó", "mèo"] }); + }); + + it("similarity returns cosine score for an in-vocab Vietnamese guess", async () => { + const ai = fakeAi(async (_model, { text }) => ({ + data: text.map((_, i) => fakeVector(i + 1)), + })); + const client = createClient(ai); + const res = await client.similarity("chó", "mèo"); + expect(res.in_vocab_a).toBe(true); + expect(res.in_vocab_b).toBe(true); + expect(res.canonical_a).toBe("chó"); + expect(res.canonical_b).toBe("mèo"); + expect(typeof res.similarity).toBe("number"); + expect(res.similarity).toBeGreaterThan(-1); + expect(res.similarity).toBeLessThanOrEqual(1); + }); + + it('similarity accepts multi-syllable 
Vietnamese words in vocab ("a dua")', async () => { + const ai = fakeAi(async () => ({ data: [fakeVector(1), fakeVector(2)] })); + const client = createClient(ai); + const res = await client.similarity("chó", "a dua"); + expect(res.in_vocab_b).toBe(true); + expect(res.similarity).not.toBeNull(); + }); + + it("similarity returns 1 for identical vectors", async () => { + const vec = fakeVector(7); + const ai = fakeAi(async () => ({ data: [vec, vec] })); + const client = createClient(ai); + const res = await client.similarity("chó", "mèo"); + expect(res.similarity).toBeCloseTo(1, 10); + }); + + it("similarity skips AI call for OOV guess and flags in_vocab_b:false", async () => { + const ai = fakeAi(async () => ({ data: [fakeVector(1), fakeVector(2)] })); + const client = createClient(ai); + const res = await client.similarity("chó", "zzzkhôngcótrongtừđiển"); + expect(res.in_vocab_b).toBe(false); + expect(res.similarity).toBe(null); + expect(ai.run).not.toHaveBeenCalled(); + }); + + it("similarity wraps AI.run rejection as UpstreamError", async () => { + const ai = fakeAi(async () => { + throw new Error("boom"); + }); + const client = createClient(ai); + await expect(client.similarity("chó", "mèo")).rejects.toMatchObject({ + name: "UpstreamError", + }); + }); + + it("similarity throws UpstreamError on malformed payload", async () => { + const ai = fakeAi(async () => ({ data: [fakeVector(1)] })); + const client = createClient(ai); + await expect(client.similarity("chó", "mèo")).rejects.toMatchObject({ + name: "UpstreamError", + }); + }); + + it("similarity returns null score when a vector norm is zero", async () => { + const zero = new Array(1024).fill(0); + const ai = fakeAi(async () => ({ data: [zero, fakeVector(1)] })); + const client = createClient(ai); + const res = await client.similarity("chó", "mèo"); + expect(res.in_vocab_b).toBe(true); + expect(res.similarity).toBe(null); + }); + + it("randomWord returns a verified pick from the local pool", async () => { + 
const ai = fakeAi(async () => ({ data: [] })); + const client = createClient(ai); + const res = await client.randomWord(); + expect(typeof res.word).toBe("string"); + expect(res.word.length).toBeGreaterThan(0); + expect(res.verified).toBe(true); + expect(ai.run).not.toHaveBeenCalled(); + }); + + it("supports model override via options", async () => { + const ai = fakeAi(async () => ({ data: [fakeVector(1), fakeVector(2)] })); + const client = createClient(ai, { model: "@cf/baai/bge-large-en-v1.5" }); + await client.similarity("chó", "mèo"); + expect(ai.run.mock.calls[0][0]).toBe("@cf/baai/bge-large-en-v1.5"); + }); + }); +});