mirror of
https://github.com/tiennm99/word2sim.git
synced 2026-05-24 19:35:27 +00:00
2e3e61dcbb
Stateless FastAPI service exposing word2vec cosine similarity, nearest neighbors, vocab lookup, and random-word picker. Dockerized with gensim GoogleNews pretrained model support.
90 lines
3.0 KiB
Python
90 lines
3.0 KiB
Python
"""Word2vec model loader and similarity primitives.
|
|
|
|
Process-wide KeyedVectors singleton; loaded lazily on first use.
|
|
Supports a gensim-downloader id (MODEL_NAME) or a local .bin file (MODEL_PATH).
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
import random as _random
|
|
import string
|
|
from typing import Optional
|
|
|
|
from gensim.models import KeyedVectors
|
|
|
|
# Module-level cache for the loaded vectors; stays None until load_model() fills it.
_MODEL: Optional[KeyedVectors] = None
|
|
|
|
|
|
def load_model() -> KeyedVectors:
    """Return the shared KeyedVectors instance, loading it the first time it is needed.

    Source selection:
      * If the MODEL_PATH environment variable is set and that path exists,
        the file is read as a binary word2vec file.
      * Otherwise the model named by MODEL_NAME (default
        "word2vec-google-news-300") is fetched through gensim's downloader.

    The loaded object is stored in the module-level `_MODEL`, so later calls
    return it without loading again.
    """
    global _MODEL
    if _MODEL is None:
        path = os.environ.get("MODEL_PATH")
        if path and os.path.exists(path):
            _MODEL = KeyedVectors.load_word2vec_format(path, binary=True)
        else:
            # Imported here, not at module top, so the downloader module is
            # only pulled in when a local file is not being used.
            import gensim.downloader as api

            name = os.environ.get("MODEL_NAME", "word2vec-google-news-300")
            _MODEL = api.load(name)
    return _MODEL
|
|
|
|
|
|
def canonicalize(kv: KeyedVectors, word: str) -> Optional[str]:
    """Map `word` onto the spelling under which it appears in `kv`, if any.

    Three spellings are tried in order: the word as given, its lower-cased
    form, and its `str.capitalize()` form. The first one for which
    `candidate in kv` holds is returned; None is returned when none match.

    This lets callers pass user input without knowing how the vocabulary
    happens to be cased.
    """
    variants = (word, word.lower(), word.capitalize())
    return next((form for form in variants if form in kv), None)
|
|
|
|
|
|
def similarity(kv: KeyedVectors, a: str, b: str) -> float:
    """Return the similarity score `kv.similarity(a, b)` as a plain Python float.

    Both words are passed through unchanged, so the caller is responsible for
    resolving them to in-vocabulary forms beforehand.
    """
    score = kv.similarity(a, b)
    return float(score)
|
|
|
|
|
|
def neighbors(kv: KeyedVectors, word: str, topn: int) -> list[tuple[str, float]]:
    """Return the `topn` entries of `kv.most_similar(word)` as (word, score) pairs.

    Scores are converted to plain Python floats. `word` is passed through
    unchanged, so the caller must resolve it to an in-vocabulary form first.
    """
    hits = kv.most_similar(word, topn=topn)
    return [(term, float(score)) for term, score in hits]
|
|
|
|
|
|
def random_word(
    kv: KeyedVectors,
    *,
    min_rank: int = 0,
    max_rank: Optional[int] = None,
    alpha_only: bool = True,
    min_len: int = 1,
    max_len: int = 64,
    max_attempts: int = 1000,
) -> Optional[str]:
    """Return a random vocab word matching the filters, or None if none is found.

    `index_to_key` is frequency-ordered for word2vec .bin files, so
    `min_rank`/`max_rank` act as a frequency window: e.g. min_rank=100 skips
    the most common function words, max_rank=50000 avoids the rare/noisy tail.

    Args:
        kv: Object exposing the vocabulary as `kv.index_to_key`.
        min_rank: First index (inclusive) of the window. Negative values are
            treated as 0.
        max_rank: End index (exclusive) of the window; None means the end of
            the vocabulary. Values past the end are clamped.
        alpha_only: When True, accept only words made entirely of ASCII
            letters. This rejects phrases (`new_york` has `_`), digits, and
            punctuation.
        min_len: Minimum accepted word length (inclusive).
        max_len: Maximum accepted word length (inclusive).
        max_attempts: Upper bound on how many words are examined.

    Returns:
        A matching word, or None when the window is empty, nothing in it
        matches, or (for windows larger than `max_attempts`) no match was
        drawn within the attempt budget.
    """
    vocab = kv.index_to_key
    upper = len(vocab) if max_rank is None else min(max_rank, len(vocab))
    # A negative rank would otherwise be used as a from-the-end list index and
    # return words from outside the requested window.
    lower = max(min_rank, 0)
    if lower >= upper:
        return None

    allowed = frozenset(string.ascii_letters) if alpha_only else None

    def matches(word: str) -> bool:
        if not (min_len <= len(word) <= max_len):
            return False
        return allowed is None or allowed.issuperset(word)

    if upper - lower <= max_attempts:
        # Small window: checking every word costs no more than the attempt
        # budget and cannot miss a match the way sampling with replacement can.
        candidates = [word for word in vocab[lower:upper] if matches(word)]
        return _random.choice(candidates) if candidates else None

    # Large window: rejection-sample so the vocabulary is never scanned in full.
    for _ in range(max_attempts):
        word = vocab[_random.randrange(lower, upper)]
        if matches(word):
            return word
    return None
|