Files
litellm/tests/ocr_tests/base_ocr_unit_tests.py
T
Yuneng Jiang 2f9519d286 [Fix] Tests: Reduce VCR cassette bloat and fix multipart caching
- Add `_strip_image_b64_payloads` filter: rewrites `data[*].b64_json` in
  image-gen responses to a 4-byte placeholder before the cassette is saved.
  Image-edit and image-gen cassettes (193 MB / 184 MB / 104 MB / ...) will
  shrink to <100 KB on next record. Tests assert response shape only, so
  coverage is preserved.
- Add `_normalize_multipart_boundary` filter: replaces httpx's per-request
  random multipart boundary with a fixed string in both Content-Type header
  and body bytes. Audio-transcription / Whisper tests have been effectively
  unmocked — every CI run hit live providers and was silently capped at
  MAX_EPISODES_PER_CASSETTE=50. Both record and replay now see identical
  bytes; the safe_body matcher works.
- Fix test_evals_api.py body poisoning: replace `int(time.time())` in eval
  names with `hashlib.sha1(test_node_name)[:12]`, add a function-scoped
  `managed_eval` fixture that creates and deletes the eval, and switch
  `get_eval` / `update_eval` from `list_evals().data[0].id` (which made
  the URL vary by run) to `managed_eval.id`. Net coverage gain: delete is
  now actually exercised.
- Swap arxiv PDF URL in BaseOCRTest for the in-repo `dummy.pdf` (589 B)
  served via sha-pinned jsdelivr.
- Swap etsystatic image URL in BaseLLMChatTest.test_image_url for the
  in-repo LiteLLM logo (9.2 KB) served via the same jsdelivr pin.
- Add `tests/llm_translation/test_vcr_filters.py` with 14 unit tests
  covering both new filters: replacement, idempotency, nesting, content-
  length update, two-distinct-boundaries-converge-after-normalize, etc.

Cassettes recorded with the prior patterns will mismatch on the first CI
run after merge; recommend flushing the cassette Redis once (post-merge)
so re-records save under the new format from the start.
2026-05-07 11:54:19 -07:00

204 lines
8.2 KiB
Python

"""
Base test class for OCR functionality across different providers.
This follows the same pattern as BaseLLMChatTest in tests/llm_translation/base_llm_unit_tests.py
"""
import os
from abc import ABC, abstractmethod
from typing import NoReturn

import pytest

import litellm
# Test resources
TEST_IMAGE_PATH = "test_image_edit.png"

# OCR input document: the tiny in-repo dummy.pdf, served by jsdelivr from a
# commit-SHA-pinned path so its bytes never change. The multi-MB arxiv PDF
# used previously, once base64-encoded into the Vertex OCR request, inflated
# cassettes past 100 MB per test. A fixed URL also keeps recorded cassettes
# from churning between runs.
_PDF_FIXTURE_COMMIT = "d769e81c90d453240c61fc572cdb27fae06a89d0"
TEST_PDF_URL = (
    f"https://cdn.jsdelivr.net/gh/BerriAI/litellm@{_PDF_FIXTURE_COMMIT}"
    "/tests/llm_translation/fixtures/dummy.pdf"
)
def _skip_or_fail_on_ocr_error(error: Exception, failure_prefix: str) -> NoReturn:
    """
    Convert an exception raised by an OCR provider call into a pytest outcome.

    Rate-limit / quota errors, ``InternalServerError`` (provider overloaded) and
    ``BadRequestError`` whose message says the provider rejected or could not
    fetch the document URL are treated as environmental and skip the test.
    Anything else fails the test with ``"<failure_prefix>: <error>"``.

    Args:
        error: The exception raised by ``litellm.ocr`` / ``litellm.aocr``.
        failure_prefix: Prefix for the ``pytest.fail`` message.

    Raises:
        Always raises, via ``pytest.skip`` or ``pytest.fail``.
    """
    error_msg = str(error)
    if isinstance(error, litellm.RateLimitError):
        if "Quota exceeded" in error_msg or "RESOURCE_EXHAUSTED" in error_msg:
            pytest.skip(f"Quota exceeded - {error_msg}")
        pytest.skip(f"Rate limit exceeded - {error_msg}")
    if isinstance(error, litellm.InternalServerError):
        pytest.skip("Model is overloaded")
    if isinstance(error, litellm.BadRequestError) and (
        "URL_REJECTED" in error_msg
        or "Cannot fetch content from the provided URL" in error_msg
    ):
        pytest.skip(f"URL rejected by provider - {error_msg}")
    pytest.fail(f"{failure_prefix}: {error_msg}")


class BaseOCRTest(ABC):
    """
    Abstract base test class that enforces common OCR tests across all providers.

    Each provider-specific test class should inherit from this and implement
    get_base_ocr_call_args() to return provider-specific configuration.

    Only the provider call itself is wrapped in error handling
    (see _skip_or_fail_on_ocr_error). Response assertions run outside it, so
    a failed assertion surfaces as a normal pytest assertion failure.
    """

    @abstractmethod
    def get_base_ocr_call_args(self) -> dict:
        """Must return the base OCR call args for the specific provider"""
        pass

    @pytest.mark.parametrize("sync_mode", [True, False])
    @pytest.mark.asyncio
    async def test_basic_ocr_with_url(self, sync_mode):
        """
        Test basic OCR with a public URL.

        Calls ``litellm.ocr`` when ``sync_mode`` is True, otherwise
        ``litellm.aocr``. Validates the response/page structure and checks
        that a positive ``response_cost`` is reported in
        ``response._hidden_params``.
        """
        litellm._turn_on_debug()
        base_ocr_call_args = self.get_base_ocr_call_args()
        print("BASE OCR Call args=", base_ocr_call_args)
        # NOTE(review): this mutates process-wide state (os.environ and
        # litellm.model_cost). Nothing here restores it, so it persists for
        # later tests in the same session.
        os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
        litellm.model_cost = litellm.get_model_cost_map(url="")

        document = {"type": "document_url", "document_url": TEST_PDF_URL}
        # Keep the try body to the provider call only. Assertion failures below
        # must not be rewritten into "OCR call failed" messages.
        try:
            if sync_mode:
                response = litellm.ocr(document=document, **base_ocr_call_args)
            else:
                response = await litellm.aocr(
                    document=document, **base_ocr_call_args
                )
        except Exception as e:
            _skip_or_fail_on_ocr_error(e, "OCR call failed")

        print(f"\n{'='*80}")
        print(f"Sync Mode: {sync_mode}")
        print(f"Response type: {type(response)}")
        print(
            f"Response object: {response.object if hasattr(response, 'object') else 'N/A'}"
        )

        # Check if response has expected OCR format
        assert hasattr(response, "pages"), "Response should have 'pages' attribute"
        assert hasattr(response, "model"), "Response should have 'model' attribute"
        assert hasattr(
            response, "object"
        ), "Response should have 'object' attribute"
        assert (
            response.object == "ocr"
        ), f"Expected object='ocr', got '{response.object}'"

        # Validate pages structure
        assert isinstance(response.pages, list), "pages should be a list"
        assert len(response.pages) > 0, "Should have at least one page"

        # Check first page structure
        first_page = response.pages[0]
        assert hasattr(first_page, "index"), "Page should have 'index' attribute"
        assert hasattr(
            first_page, "markdown"
        ), "Page should have 'markdown' attribute"

        # Extract text from all pages for validation (pages with empty/None
        # markdown are skipped)
        total_text = "\n\n".join(
            page.markdown for page in response.pages if page.markdown
        )
        print(f"Total pages: {len(response.pages)}")
        print(f"Total extracted text length: {len(total_text)} characters")
        print(f"First 200 chars: {total_text[:200]}")
        print(f"Model: {response.model}")
        if response.usage_info:
            print(f"Pages processed: {response.usage_info.pages_processed}")
        print(f"{'='*80}\n")
        assert len(total_text) > 0, "Should extract some text from the document"

        #########################################################
        # validate we get a response cost in hidden parameters
        #########################################################
        hidden_params = response._hidden_params
        assert isinstance(
            hidden_params, dict
        ), "Hidden parameters should be a dictionary"
        print("response usage_info:", response.usage_info)
        response_cost = hidden_params.get("response_cost")
        assert (
            response_cost is not None
        ), "Response cost should be in hidden parameters"
        assert response_cost > 0, "Response cost should be greater than 0"
        print("response_cost=", response_cost)

    def test_ocr_response_structure(self):
        """
        Test that the OCR response has the correct structure.

        Uses the synchronous ``litellm.ocr`` entry point. Checks the
        top-level attributes (pages, model, object, usage_info) and the first
        page's ``index`` / ``markdown`` fields.
        """
        litellm.set_verbose = True
        base_ocr_call_args = self.get_base_ocr_call_args()

        # Keep the try body to the provider call only (see
        # _skip_or_fail_on_ocr_error for the skip/fail policy).
        try:
            response = litellm.ocr(
                document={"type": "document_url", "document_url": TEST_PDF_URL},
                **base_ocr_call_args,
            )
        except Exception as e:
            _skip_or_fail_on_ocr_error(e, "OCR response structure test failed")

        # Validate response structure
        assert hasattr(response, "pages"), "Response should have 'pages' attribute"
        assert hasattr(response, "model"), "Response should have 'model' attribute"
        assert hasattr(
            response, "object"
        ), "Response should have 'object' attribute"
        assert hasattr(
            response, "usage_info"
        ), "Response should have 'usage_info' attribute"
        assert isinstance(response.pages, list), "pages should be a list"
        assert len(response.pages) > 0, "Should have at least one page"
        assert response.object == "ocr", "object should be 'ocr'"

        # Validate first page structure
        first_page = response.pages[0]
        assert hasattr(first_page, "index"), "Page should have 'index' attribute"
        assert hasattr(
            first_page, "markdown"
        ), "Page should have 'markdown' attribute"
        assert isinstance(first_page.markdown, str), "markdown should be a string"

        print("\nResponse structure validated:")
        print(f" - object: {response.object}")
        print(f" - model: {response.model}")
        print(f" - pages: {len(response.pages)}")
        if response.usage_info:
            print(f" - pages_processed: {response.usage_info.pages_processed}")
            print(f" - doc_size_bytes: {response.usage_info.doc_size_bytes}")