Files

546 lines
21 KiB
Python

from typing import Any, Dict, List
from PIL import Image
from io import BytesIO
import requests
import numpy as np
from PIL import Image
import torch
import replicate
from .composition_utils import *
try:
from replicate.helpers import FileOutput as ReplicateFileOutput
except Exception:
ReplicateFileOutput = None
# ------------------------------
# Model versions
# ------------------------------
# Model identifiers handed to `replicate.run(...)`.
# REPLICATE_MODEL is the default `model_version` of ReplicateGroundingDetector
# (object detection); REPLICATE_SAM_MODEL is the model called by
# `replicate_sam` (segmentation).
REPLICATE_MODEL = "adirik/grounding-dino:efd10a8ddc57ea28773327e881ce95e20cc1d734c589f7dd01d2036921ed78aa"
REPLICATE_SAM_MODEL = "meta/sam-2:fe97b453a6455861e3bac769b441ca1f1086110da7466dbb65cf1eecfd60dc83"
# ------------------------------
# Detector
# ------------------------------
class ReplicateGroundingDetector:
    """Callable object detector that runs a Grounding DINO model on Replicate.

    Calling an instance uploads the image (downscaled if large, JPEG-encoded
    as a base64 data URI) together with a text query built from the candidate
    labels, and returns the detections with their boxes mapped back to the
    coordinate system of the image that was passed in.
    """

    def __init__(self, model_version: str = REPLICATE_MODEL, box_threshold: float = 0.2, text_threshold: float = 0.2):
        """Store the model identifier and the default thresholds.

        Args:
            model_version: Identifier passed to `replicate.run`.
            box_threshold: Default `box_threshold` input when a call does not
                supply its own `threshold`.
            text_threshold: Default `text_threshold` input when a call does
                not supply its own `threshold`.
        """
        self.model_version = model_version
        self.box_threshold = box_threshold
        self.text_threshold = text_threshold

    def __call__(self, image: Image.Image, candidate_labels: List[str], threshold: float = None) -> List[Dict[str, Any]]:
        """Run detection on `image` for the given labels.

        Args:
            image: PIL image to analyse. It is not modified.
            candidate_labels: Labels to look for; each is terminated with "."
                and the labels are joined with spaces to form the query.
            threshold: When not None, used for both `box_threshold` and
                `text_threshold`; otherwise the instance defaults apply.

        Returns:
            A list of dicts with keys "score", "label" and "box", where "box"
            holds integer "xmin"/"ymin"/"xmax"/"ymax" in original-image pixels.

        Raises:
            Exception: Whatever `replicate.run` raises is logged and re-raised.
        """
        import base64

        # Prepare query
        labels = [l if l.endswith(".") else (l + ".") for l in candidate_labels]
        query = " ".join(labels)

        # JPEG cannot store alpha or palette data (Pillow raises OSError for
        # RGBA/P/LA...), so normalise the mode before resizing and encoding.
        image_upload = image if image.mode in ("RGB", "L") else image.convert("RGB")

        # OPTIMIZATION: Downscale image before uploading to reduce payload size
        max_side = 1280
        W, H = image.size
        scale_x = scale_y = 1.0
        if max(W, H) > max_side:
            scale = max_side / max(W, H)
            # Clamp to 1px so extreme aspect ratios never yield a zero-sized image.
            new_size = (max(1, int(W * scale)), max(1, int(H * scale)))
            image_upload = image_upload.resize(new_size, Image.BILINEAR)
            # Integer truncation makes the effective scale differ per axis;
            # use the real ratios when mapping boxes back.
            scale_x = new_size[0] / W
            scale_y = new_size[1] / H
            print(f"[GroundingDINO] Downscaling for upload: {(W, H)} -> {new_size}")

        # Convert image to base64 data URI with JPEG (smaller than PNG)
        buffered = BytesIO()
        image_upload.save(buffered, format="JPEG", quality=85)
        img_str = base64.b64encode(buffered.getvalue()).decode()
        data_uri = f"data:image/jpeg;base64,{img_str}"

        # Replicate input
        inputs = {
            "image": data_uri,
            "query": query,
            "box_threshold": threshold if threshold is not None else self.box_threshold,
            "text_threshold": threshold if threshold is not None else self.text_threshold,
        }
        print(f"[Replicate][GroundingDINO] model={self.model_version}\n query=\"{query}\"\n box_threshold={inputs['box_threshold']} text_threshold={inputs['text_threshold']}")
        print(f"[Replicate][GroundingDINO] payload size: {len(img_str) / 1024:.1f} KB")

        try:
            out = replicate.run(self.model_version, input=inputs)
        except Exception as e:
            print(f"[Replicate][GroundingDINO] Error: {e}")
            raise

        # Parse detections and scale boxes back to original size.
        # The expected shape is {"detections": [...]}; a bare list is tolerated
        # and anything else is treated as "no detections" instead of crashing.
        if isinstance(out, dict):
            detections = out.get("detections") or []
        elif isinstance(out, (list, tuple)):
            detections = list(out)
        else:
            detections = []
        print(f"[Replicate][GroundingDINO] raw detections: {len(detections)}")

        results = []
        for d in detections:
            bbox = d.get("bbox") if isinstance(d, dict) else None
            if not isinstance(bbox, (list, tuple)) or len(bbox) < 4:
                print(f"[Replicate][GroundingDINO] skipping malformed detection: {d!r}")
                continue
            results.append(
                {
                    "score": d.get("score", 0.0),
                    "label": d.get("label", ""),
                    "box": {
                        "xmin": int(bbox[0] / scale_x),
                        "ymin": int(bbox[1] / scale_y),
                        "xmax": int(bbox[2] / scale_x),
                        "ymax": int(bbox[3] / scale_y),
                    },
                }
            )

        if results:
            print("[Replicate][GroundingDINO] parsed detections:")
            for r in results:
                b = r["box"]
                print(f" - {r['label']} score={r['score']:.2f} box=[{b['xmin']:.1f},{b['ymin']:.1f},{b['xmax']:.1f},{b['ymax']:.1f}]")
        return results
def download_mask(url, timeout=30.0):
    """Download a mask image and return it as a 2-D numpy array.

    Args:
        url: HTTP(S) location of the mask image.
        timeout: Seconds passed to `requests.get`; without it a stalled
            server would block the caller indefinitely.

    Returns:
        The alpha channel of the image when it has one (many mask PNGs encode
        the mask in alpha), otherwise the image converted to 8-bit grayscale.

    Raises:
        requests.RequestException: On connection problems, timeouts, or a
            non-success HTTP status.
    """
    resp = requests.get(url, timeout=timeout)
    # Fail loudly on 4xx/5xx instead of trying to decode an error page as an image.
    resp.raise_for_status()
    img = Image.open(BytesIO(resp.content))
    # Prefer alpha channel if present (many mask PNGs encode mask in alpha)
    if "A" in img.getbands():
        mask_pil = img.getchannel("A")
    else:
        mask_pil = img.convert("L")
    return np.array(mask_pil)
# ------------------------------
# Cloud SAM wrapper
# ------------------------------
def replicate_sam(image_file, boxes=None, **kwargs):
    """
    Run the Replicate-hosted SAM model on an image and return its raw output.

    `image_file` is forwarded as the "image" input. When `boxes` is given it
    is sent under several box-prompt field names at once, together with
    "box_format" = "xyxy" and "return_individual_masks" = True. Any extra
    keyword arguments are merged last, so they override the defaults below.
    Errors raised by `replicate.run` are logged and re-raised.
    """
    # Lightweight automatic-mask settings chosen to keep the remote call short.
    payload = {
        "image": image_file,
        "use_m2m": False,
        "points_per_side": 4,
        "pred_iou_thresh": 0.90,
        "stability_score_thresh": 0.92,
        "crop_n_layers": 0,
        "crop_n_points_downscale_factor": 2,
    }
    if boxes is not None:
        # Same boxes under every field name a SAM variant might read.
        payload.update(
            bboxes=boxes,
            input_boxes=boxes,
            boxes=boxes,
            box_format="xyxy",
            return_individual_masks=True,
        )
    payload.update(kwargs)

    print(f"[Replicate][SAM] model={REPLICATE_SAM_MODEL} calling with keys={list(payload.keys())}")
    try:
        result = replicate.run(REPLICATE_SAM_MODEL, input=payload)
    except Exception as e:
        print(f"[Replicate][SAM] Error: {e}")
        raise

    # Describing the result is best-effort: a failure here must not lose it.
    try:
        if isinstance(result, dict):
            print(f"[Replicate][SAM] returned type={type(result)} keys={list(result.keys())}")
        elif isinstance(result, (list, tuple)):
            print(f"[Replicate][SAM] returned {len(result)} items")
        else:
            print(f"[Replicate][SAM] returned type={type(result)}")
    except Exception as e:
        print(f"[Replicate][SAM] logging error: {e}")
    return result
# ------------------------------
# Initialize models
# ------------------------------
def init_models():
    """
    Create the detector and segmenter used by the pipeline.

    Returns a 3-tuple `(object_detector, segmentator, processor)`:
    a ReplicateGroundingDetector built with the module's default model and
    0.2 thresholds, the `replicate_sam` function itself, and None (no
    processor object is needed because segmentation runs on Replicate).
    """
    detector = ReplicateGroundingDetector(
        model_version=REPLICATE_MODEL,
        box_threshold=0.2,
        text_threshold=0.2,
    )
    status_lines = (
        "✅ Initialized: Using Replicate for detection and segmentation",
        f" - Detector: {REPLICATE_MODEL}",
        f" - SAM: {REPLICATE_SAM_MODEL}",
    )
    for line in status_lines:
        print(line)
    # The segmenter is just a function reference; there is no processor.
    return detector, replicate_sam, None
# ------------------------------
# Detection pipeline
# ------------------------------
def detect(
    image: Image.Image,
    labels: List[str],
    detector: ReplicateGroundingDetector,
    threshold: float = 0.3,
) -> List[Dict[str, Any]]:
    """
    Run the detector and drop boxes that cover most of the image.

    The detector is called with `candidate_labels=labels` and the given
    threshold. A raw result is kept only when its box area is below 80% of
    the image area; kept results are converted with `DetectionResult.from_dict`.
    """
    print(f"[Detect] labels={labels} threshold={threshold}")
    raw_results = detector(image, candidate_labels=labels, threshold=threshold)

    width, height = image.size
    image_area = width * height

    def _covers_less_than_limit(result):
        # True when the box occupies under 80% of the frame.
        box = result["box"]
        xmin, ymin, xmax, ymax = box["xmin"], box["ymin"], box["xmax"], box["ymax"]
        box_area = (xmax - xmin) * (ymax - ymin)
        return image_area > 0 and (box_area / image_area) < 0.8

    kept = [DetectionResult.from_dict(r) for r in raw_results if _covers_less_than_limit(r)]

    print(f"[Detect] raw={len(raw_results)} filtered={len(kept)}")
    for found in kept:
        print(f"[Detect] keep {found.label} score={found.score:.2f} box={found.box.xyxy}")
    return kept
# ------------------------------
# Segmentation pipeline
# ------------------------------
def _parse_sam_output(out):
    """Normalize the various possible Replicate SAM outputs to a list of mask payloads.

    Handling, in order:
      * None -> empty list (no masks), rather than a list holding None.
      * dict containing "combined_mask" -> the entries of "individual_masks"
        (when it is a list) followed by the combined mask; entries that are
        dicts with a "mask" field are unwrapped to that field.
      * other dicts -> the first non-None value found under one of the known
        keys (wrapped in a list when it is not already one), else the first
        list-valued entry, else an empty list.
      * list / tuple -> a new list with the same items.
      * anything else -> a single-item list.
    """
    if out is None:
        return []

    if isinstance(out, dict):
        print(f"[Segment] SAM dict keys: {list(out.keys())}")

        # Special-case common schema from meta/sam-2 on Replicate
        if "combined_mask" in out:
            parsed = []
            individual = out.get("individual_masks")
            # Individual masks go first so they are preferred over the combined one.
            if isinstance(individual, list):
                for entry in individual:
                    if isinstance(entry, dict) and "mask" in entry:
                        parsed.append(entry["mask"])
                    else:
                        parsed.append(entry)
            combined = out.get("combined_mask")
            if isinstance(combined, dict) and "mask" in combined:
                parsed.append(combined["mask"])
            elif combined is not None:
                parsed.append(combined)
            return parsed

        for key in ("masks", "mask", "segments", "segmentations", "output", "data"):
            value = out.get(key)
            # A key that is present but None carries no mask; keep looking.
            if value is None:
                continue
            return value if isinstance(value, list) else [value]

        # fallback: first list-like value
        for value in out.values():
            if isinstance(value, list):
                return value
        return []

    if isinstance(out, (list, tuple)):
        return list(out)
    return [out]
def _coerce_mask_to_numpy(mask_data, target_hw):
    """
    Convert one mask payload to a HxW uint8 binary numpy array (0 or 255).

    Supported payloads: numpy arrays, nested numeric lists/tuples, dicts
    (uncompressed COCO-style RLE, or nested URLs / data URIs / arrays), PIL
    images, Replicate FileOutput objects, and strings (http(s) URL, image
    data URI, or raw base64-encoded image bytes).

    Args:
        mask_data: The mask in any of the forms above.
        target_hw: (H, W) of the original image; masks of a different size
            are resized to it with nearest-neighbour sampling.

    Returns:
        A (H, W) uint8 array containing only 0 and 255, or None when the
        payload cannot be interpreted.
    """
    import base64

    H, W = target_hw

    def _ensure_size_u8(m):
        # Binarize first, then resize. Thresholding the raw values avoids the
        # uint8 wrap-around of values >= 256 and lets PIL resize a plain uint8
        # image whatever the input dtype was. With nearest-neighbour sampling
        # the result equals "resize then threshold".
        m = np.asarray(m)
        if m.ndim == 3:
            m = m[..., 0]  # keep the first channel
        if m.ndim != 2 or m.size == 0 or m.dtype.kind not in "biuf":
            print(f"[Segment] Unsupported mask array (shape={m.shape}, dtype={m.dtype}); skipping")
            return None
        binary = (m > 0).astype(np.uint8) * 255
        if binary.shape != (H, W):
            resized = Image.fromarray(binary).resize((W, H), resample=Image.NEAREST)
            binary = np.array(resized)
        return binary

    def _image_bytes_to_array(img_bytes):
        # Decode encoded image bytes; prefer the alpha channel if present
        # (many mask PNGs encode the mask in alpha).
        img = Image.open(BytesIO(img_bytes))
        if "A" in img.getbands():
            return np.array(img.getchannel("A"))
        return np.array(img.convert("L"))

    def _decode_coco_rle(rle_obj):
        # Uncompressed COCO RLE only: "counts" is a list of run lengths that
        # alternate background/foreground, starting with background.
        try:
            if not (isinstance(rle_obj, dict) and isinstance(rle_obj.get("counts"), list) and "size" in rle_obj):
                # Compressed RLE (string counts) not supported without pycocotools
                return None
            Hh, Ww = (int(v) for v in rle_obj["size"])
            flat = np.zeros(Hh * Ww, dtype=np.uint8)
            idx = 0
            val = 0
            for c in rle_obj["counts"]:
                end = idx + int(c)
                flat[idx:end] = val
                idx = end
                val = 255 - val
            # COCO/SAM runs are laid out column by column (Fortran order), so
            # reshape as (W, H) and transpose to get the (H, W) image.
            return np.ascontiguousarray(flat.reshape((Ww, Hh)).T)
        except Exception as e:
            print(f"[Segment] Failed to decode RLE: {e}")
            return None

    def _find_any_image_string(obj):
        # Depth-first search for the first URL or image data-URI string.
        if isinstance(obj, str):
            return obj if obj.startswith(("http", "data:image")) else None
        if isinstance(obj, dict):
            children = obj.values()
        elif isinstance(obj, (list, tuple)):
            children = obj
        else:
            return None
        for child in children:
            found = _find_any_image_string(child)
            if found:
                return found
        return None

    # Pure-data forms (arrays, dicts, lists) are handled first; the input
    # types are disjoint, so the order of these checks does not affect results.

    # numpy array
    if isinstance(mask_data, np.ndarray):
        return _ensure_size_u8(mask_data)

    # dict forms
    if isinstance(mask_data, dict):
        # RLE nested under "rle"
        if isinstance(mask_data.get("rle"), dict):
            decoded = _decode_coco_rle(mask_data["rle"])
            if decoded is not None:
                return _ensure_size_u8(decoded)
        # Some schemas put counts/size at top-level
        if "counts" in mask_data and "size" in mask_data:
            decoded = _decode_coco_rle({"counts": mask_data["counts"], "size": mask_data["size"]})
            if decoded is not None:
                return _ensure_size_u8(decoded)

        tried = set()

        def _attempt(value):
            # Recurse on a candidate, never retrying the same string (so a dead
            # URL is fetched at most once per dict).
            if isinstance(value, str):
                if value in tried:
                    return None
                tried.add(value)
            return _coerce_mask_to_numpy(value, target_hw)

        # Common fields carrying a URL, data URI, file object or nested payload
        for k in ("mask", "url", "image", "overlay", "png", "combined_mask"):
            v = mask_data.get(k)
            if v is None:
                continue
            result = _attempt(v)
            if result is not None:
                return result

        # Numeric arrays (or nested payloads) under known keys
        for k in ("data", "array", "segmentation", "mask_array"):
            v = mask_data.get(k)
            if isinstance(v, (list, tuple, dict, np.ndarray)):
                result = _attempt(v)
                if result is not None:
                    return result

        # Any remaining string value; keep going when one fails to decode
        for v in mask_data.values():
            if isinstance(v, str):
                result = _attempt(v)
                if result is not None:
                    return result

        # Final attempt: recursively search any nested url or data-uri string
        s_any = _find_any_image_string(mask_data)
        if s_any:
            result = _attempt(s_any)
            if result is not None:
                return result

        print("[Segment] Unknown dict mask format; skipping")
        return None

    # list-of-lists (numeric mask)
    if isinstance(mask_data, (list, tuple)):
        try:
            arr = np.array(mask_data)
        except Exception:
            # e.g. ragged nesting that numpy refuses to turn into an array
            return None
        if arr.ndim >= 2:
            return _ensure_size_u8(arr)
        return None

    # PIL image
    if isinstance(mask_data, Image.Image):
        return _ensure_size_u8(np.array(mask_data.convert("L")))

    # Replicate FileOutput (URL-like)
    if ReplicateFileOutput is not None and isinstance(mask_data, ReplicateFileOutput):
        try:
            url = getattr(mask_data, "url", None)
            if isinstance(url, str):
                if url.startswith("http"):
                    return _ensure_size_u8(download_mask(url))
                if url.startswith("data:") and "," in url:
                    # The file content is embedded in the URL itself.
                    _, b64 = url.split(",", 1)
                    return _ensure_size_u8(_image_bytes_to_array(base64.b64decode(b64)))
        except Exception as e:
            print(f"[Segment] Failed to read Replicate FileOutput: {e}")
        return None

    # string forms
    if isinstance(mask_data, str):
        s = mask_data.strip()
        if s.startswith(("http://", "https://")):
            try:
                return _ensure_size_u8(download_mask(s))
            except Exception as e:
                print(f"[Segment] Failed to download mask: {e}")
                return None
        if s.startswith("data:image"):
            try:
                _, b64 = s.split(",", 1)
                return _ensure_size_u8(_image_bytes_to_array(base64.b64decode(b64)))
            except Exception as e:
                print(f"[Segment] Failed to decode data URI mask: {e}")
                return None
        # Fallback: some providers return raw base64-encoded PNG without data URI prefix
        try:
            return _ensure_size_u8(_image_bytes_to_array(base64.b64decode(s)))
        except Exception:
            # Not base64 / not an image: report it as unknown below.
            pass
        print("[Segment] Unknown mask string format; skipping")
        return None

    # unknown
    return None
def segment(
    image: Image.Image,
    detection_results: List[Any],
    segmentator,
    processor=None,
    device=None,
    polygon_refinement=False,
):
    """
    Attach a segmentation mask to each detection using cloud-based SAM.

    The image (downscaled so its longer side is at most 768 px) is sent to
    `segmentator` once. Every returned mask is converted to a full-resolution
    binary array, and each detection receives the mask with the highest IoU
    against its box. Masks covering more than 80% of the frame are ignored.
    When SAM fails or no mask overlaps a box, the detection gets a mask that
    simply fills its box.

    Args:
        image: Source PIL image.
        detection_results: Objects exposing `.box` with `xmin`, `ymin`,
            `xmax`, `ymax`; `.mask` is assigned on each one with a valid box.
        segmentator: Callable taking a PNG file-like object and returning the
            SAM output.
        processor, device, polygon_refinement: Accepted for interface
            compatibility; not used by this implementation.

    Returns:
        A list with the same detection objects in the same order. Detections
        whose clamped box is empty are returned without a mask being set.
    """
    W, H = image.size

    # OPTIMIZATION: More aggressive downscaling
    max_side_img = 768
    image_for_sam = image
    if max(W, H) > max_side_img:
        scale_img = max_side_img / float(max(W, H))
        # Clamp to 1px so extreme aspect ratios never yield a zero-sized image.
        new_size = (max(1, int(round(W * scale_img))), max(1, int(round(H * scale_img))))
        image_for_sam = image.resize(new_size, Image.BILINEAR)
        print(f"[Segment] Using downscaled image for SAM: {(W, H)} -> {new_size}")

    buf_img = BytesIO()
    image_for_sam.save(buf_img, format="PNG")
    buf_img.seek(0)

    # Run SAM once; a failure here means every detection gets a box-fill mask.
    try:
        output = segmentator(buf_img)
        parsed = _parse_sam_output(output)
    except Exception as e:
        print(f"[Segment] SAM failed or timed out: {e}")
        print("[Segment] Falling back to box-fill masks for all detections")
        parsed = []
    else:
        print(f"[Segment] global SAM outputs={len(parsed)}")

    # Coerce all masks to original size (H, W). Each mask is decoded on its
    # own so one bad payload cannot discard the masks that decoded fine.
    masks_np: List[np.ndarray] = []
    mask_areas: List[int] = []
    for j, md in enumerate(parsed):
        try:
            m = _coerce_mask_to_numpy(md, target_hw=(H, W))
        except Exception as e:
            print(f"[Segment] mask[{j}] could not be decoded: {e}")
            continue
        if m is None:
            continue
        nz = int((m > 0).sum())
        print(f"[Segment] mask[{j}] nz={nz}")
        masks_np.append(m)
        mask_areas.append(nz)

    # Mask areas do not depend on the detection, so compute the candidate set
    # once. Skip masks that are basically full-frame (likely combined mask).
    frame_area = float(W * H)
    candidates = [
        (k, m, area)
        for k, (m, area) in enumerate(zip(masks_np, mask_areas))
        if frame_area > 0 and area / frame_area <= 0.8
    ]

    results_with_masks = []
    # Assign best-overlap mask to each detection
    for idx, det in enumerate(detection_results):
        box = det.box
        xmin, ymin, xmax, ymax = map(int, [box.xmin, box.ymin, box.xmax, box.ymax])
        xmin = max(0, min(xmin, W - 1))
        xmax = max(0, min(xmax, W))
        ymin = max(0, min(ymin, H - 1))
        ymax = max(0, min(ymax, H))
        if xmax <= xmin or ymax <= ymin:
            print(f"[Segment] skip invalid box at idx {idx}: {(xmin, ymin, xmax, ymax)}")
            results_with_masks.append(det)
            continue

        box_area = max(1, (xmax - xmin) * (ymax - ymin))
        best_idx = -1
        best_iou = 0.0
        best_metrics = None
        for k, m, mask_area in candidates:
            overlap = int((m[ymin:ymax, xmin:xmax] > 0).sum())
            if overlap == 0:
                continue
            # IoU between the mask and the detection box region
            iou = overlap / float(mask_area + box_area - overlap + 1e-6)
            if iou > best_iou:
                best_iou = iou
                best_idx = k
                best_metrics = (overlap, mask_area)

        if best_idx >= 0 and best_iou > 0:
            det.mask = masks_np[best_idx]
            if idx < 5:
                ov, ma = best_metrics if best_metrics else (0, 0)
                print(f"[Segment] attach mask {best_idx} to det {idx}, overlap={ov}, mask_area={ma}, box_area={box_area}, iou={best_iou:.4f}")
        else:
            print(f"[Segment] no suitable mask for det {idx} — box fill fallback (box_area={box_area})")
            full_mask = np.zeros((H, W), dtype=np.uint8)
            full_mask[ymin:ymax, xmin:xmax] = 255
            det.mask = full_mask
        results_with_masks.append(det)
    return results_with_masks
# ------------------------------
# Device helper
# ------------------------------
def get_device():
    """Return the preferred torch device: CUDA, then Apple MPS, then CPU.

    Prints which device was selected.

    Returns:
        A `torch.device` of type "cuda", "mps" or "cpu".
    """
    if torch.cuda.is_available():
        print("CUDA is available. Using GPU.")
        return torch.device("cuda")

    # `torch.backends.mps` does not exist on older PyTorch builds; look it up
    # defensively so those installs fall back to CPU instead of raising.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        print("MPS is available! Using Apple GPU.")
        return torch.device("mps")

    print("Using CPU.")
    return torch.device("cpu")