mirror of
https://github.com/tiennm99/ArtKrit.git
synced 2026-06-05 20:11:58 +00:00
546 lines
21 KiB
Python
546 lines
21 KiB
Python
from typing import Any, Dict, List
|
|
from PIL import Image
|
|
from io import BytesIO
|
|
import requests
|
|
import numpy as np
|
|
from PIL import Image
|
|
|
|
|
|
import torch
|
|
import replicate
|
|
|
|
from .composition_utils import *
|
|
|
|
try:
|
|
from replicate.helpers import FileOutput as ReplicateFileOutput
|
|
except Exception:
|
|
ReplicateFileOutput = None
|
|
|
|
# ------------------------------
|
|
# Model versions
|
|
# ------------------------------
|
|
# Replicate model reference for the Grounding DINO detector. It is the default
# `model_version` of ReplicateGroundingDetector and is also passed explicitly
# by init_models(). The part after ":" looks like a pinned version id -- TODO confirm.
REPLICATE_MODEL = "adirik/grounding-dino:efd10a8ddc57ea28773327e881ce95e20cc1d734c589f7dd01d2036921ed78aa"
|
|
# Replicate model reference for the SAM-2 segmenter; replicate_sam() passes it
# to replicate.run(). The part after ":" looks like a pinned version id -- TODO confirm.
REPLICATE_SAM_MODEL = "meta/sam-2:fe97b453a6455861e3bac769b441ca1f1086110da7466dbb65cf1eecfd60dc83"
|
|
|
|
# ------------------------------
|
|
# Detector
|
|
# ------------------------------
|
|
class ReplicateGroundingDetector:
    """Callable wrapper around a Grounding DINO model hosted on Replicate.

    The instance stores the model reference and the default thresholds.
    Calling it uploads an image together with a text query and returns the
    detections expressed in the pixel coordinates of the image passed in.
    """

    def __init__(self, model_version: str = REPLICATE_MODEL, box_threshold: float = 0.2, text_threshold: float = 0.2):
        """Store the model reference and the default detection thresholds."""
        self.model_version = model_version
        self.box_threshold = box_threshold
        self.text_threshold = text_threshold

    def __call__(self, image: Image.Image, candidate_labels: List[str], threshold: float = None) -> List[Dict[str, Any]]:
        """Detect `candidate_labels` in `image` via Replicate.

        Args:
            image: PIL image to analyse. It is never modified; a resized and/or
                mode-converted copy is uploaded when needed.
            candidate_labels: Text labels. Each one is terminated with "."
                (unless it already is) and the labels are joined with spaces
                to build the query string.
            threshold: When not None, overrides both the box threshold and the
                text threshold for this call only.

        Returns:
            A list of dicts with keys "score", "label" and "box". "box" is a
            dict of integer "xmin", "ymin", "xmax", "ymax" in the coordinate
            space of `image` (not of the downscaled upload).

        Raises:
            Exception: anything raised by `replicate.run` is logged and re-raised.
        """
        import base64

        # Prepare query
        labels = [label if label.endswith(".") else (label + ".") for label in candidate_labels]
        query = " ".join(labels)

        # OPTIMIZATION: Downscale image before uploading to reduce payload size
        max_side = 1280  # Replicate's GroundingDINO works well at this resolution
        W, H = image.size
        if max(W, H) > max_side:
            scale = max_side / max(W, H)
            # Clamp to 1 px so extreme aspect ratios never produce a 0-sized image.
            new_size = (max(1, int(W * scale)), max(1, int(H * scale)))
            image_upload = image.resize(new_size, Image.BILINEAR)
            print(f"[GroundingDINO] Downscaling for upload: {(W, H)} -> {new_size}")
        else:
            image_upload = image
            new_size = (W, H)

        # Integer truncation means the real resize ratio can differ per axis
        # from the nominal scale; use the actual ratios to map boxes back.
        scale_x = new_size[0] / W if W else 1.0
        scale_y = new_size[1] / H if H else 1.0

        # JPEG cannot store alpha / palette / integer / float modes; without this
        # conversion Pillow raises OSError for e.g. RGBA or P images.
        if image_upload.mode not in ("RGB", "L"):
            image_upload = image_upload.convert("RGB")

        # Convert image to base64 data URI with JPEG (smaller than PNG)
        buffered = BytesIO()
        image_upload.save(buffered, format="JPEG", quality=85)
        img_str = base64.b64encode(buffered.getvalue()).decode()
        data_uri = f"data:image/jpeg;base64,{img_str}"

        # Replicate input
        inputs = {
            "image": data_uri,
            "query": query,
            "box_threshold": threshold if threshold is not None else self.box_threshold,
            "text_threshold": threshold if threshold is not None else self.text_threshold,
        }

        print(f"[Replicate][GroundingDINO] model={self.model_version}\n query=\"{query}\"\n box_threshold={inputs['box_threshold']} text_threshold={inputs['text_threshold']}")
        print(f"[Replicate][GroundingDINO] payload size: {len(img_str) / 1024:.1f} KB")

        # Run model; log the failure here but let the caller decide how to handle it.
        try:
            out = replicate.run(self.model_version, input=inputs)
        except Exception as e:
            print(f"[Replicate][GroundingDINO] Error: {e}")
            raise

        # Parse detections and scale boxes back to original size.
        # `or []` also covers a response that carries "detections": None.
        detections = out.get("detections") or []
        print(f"[Replicate][GroundingDINO] raw detections: {len(detections)}")
        results = []
        for d in detections:
            bbox = d["bbox"]
            results.append(
                {
                    "score": d.get("score", 0.0),
                    "label": d.get("label", ""),
                    "box": {
                        "xmin": int(bbox[0] / scale_x),
                        "ymin": int(bbox[1] / scale_y),
                        "xmax": int(bbox[2] / scale_x),
                        "ymax": int(bbox[3] / scale_y),
                    },
                }
            )
        if results:
            print("[Replicate][GroundingDINO] parsed detections:")
            for r in results:
                b = r["box"]
                print(f" - {r['label']} score={r['score']:.2f} box=[{b['xmin']:.1f},{b['ymin']:.1f},{b['xmax']:.1f},{b['ymax']:.1f}]")
        return results
|
|
|
|
def download_mask(url, timeout=30):
    """Download a mask image and return it as a numpy array.

    Args:
        url: HTTP(S) location of the mask image.
        timeout: Seconds passed to `requests.get` so a stalled server cannot
            block the pipeline indefinitely.

    Returns:
        A numpy array built from the image's alpha channel when it has one,
        otherwise from the image converted to 8-bit grayscale ("L").

    Raises:
        requests.RequestException: on connection errors, timeouts, or a
            non-success HTTP status.
    """
    resp = requests.get(url, timeout=timeout)
    # Fail with a clear HTTP error instead of handing an error page to Pillow.
    resp.raise_for_status()
    img = Image.open(BytesIO(resp.content))
    # Prefer alpha channel if present (many mask PNGs encode mask in alpha)
    if "A" in img.getbands():
        mask_pil = img.getchannel("A")
    else:
        mask_pil = img.convert("L")
    return np.array(mask_pil)
|
|
|
|
|
|
# ------------------------------
|
|
# Cloud SAM wrapper
|
|
# ------------------------------
|
|
def replicate_sam(image_file, boxes=None, **kwargs):
    """
    Run the Replicate-hosted SAM model on an image and return its raw output.

    `image_file` is a local file-like object (BytesIO or an open file).
    When `boxes` is given (xyxy pixel coordinates) it is sent under several
    candidate field names to request box-prompted segmentation.
    Extra keyword arguments are merged into the request last, so they
    override any of the defaults below. Errors from Replicate are logged
    and re-raised.
    """
    # Deliberately light settings to keep the remote call from timing out.
    inputs = dict(
        image=image_file,
        use_m2m=False,
        # Far below the model default of 32 sampling points per side.
        points_per_side=4,
        # Strict thresholds keep the number of returned masks small.
        pred_iou_thresh=0.90,
        stability_score_thresh=0.92,
        # No crop-based refinement.
        crop_n_layers=0,
        crop_n_points_downscale_factor=2,
    )

    if boxes is not None:
        # Different Replicate SAM variants use different names for box prompts.
        for box_field in ("bboxes", "input_boxes", "boxes"):
            inputs[box_field] = boxes
        inputs["box_format"] = "xyxy"
        inputs["return_individual_masks"] = True

    inputs.update(kwargs)
    print(f"[Replicate][SAM] model={REPLICATE_SAM_MODEL} calling with keys={list(inputs.keys())}")

    try:
        out = replicate.run(REPLICATE_SAM_MODEL, input=inputs)
        # Describing the result is best-effort and must never fail the call.
        try:
            if isinstance(out, dict):
                print(f"[Replicate][SAM] returned type={type(out)} keys={list(out.keys())}")
            elif isinstance(out, (list, tuple)):
                print(f"[Replicate][SAM] returned {len(out)} items")
            else:
                print(f"[Replicate][SAM] returned type={type(out)}")
        except Exception as log_err:
            print(f"[Replicate][SAM] logging error: {log_err}")
        return out
    except Exception as run_err:
        print(f"[Replicate][SAM] Error: {run_err}")
        raise
|
|
|
|
# ------------------------------
|
|
# Initialize models
|
|
# ------------------------------
|
|
def init_models():
    """
    Build the detection/segmentation trio used by the pipeline.

    Returns a tuple `(object_detector, segmentator, processor)`:
    a ReplicateGroundingDetector instance, the `replicate_sam` function
    itself, and `None` (the cloud SAM path needs no processor).
    """
    # Everything runs on Replicate, so no local device is selected here.
    detector = ReplicateGroundingDetector(
        box_threshold=0.2,
        text_threshold=0.2,
        model_version=REPLICATE_MODEL,
    )

    for line in (
        "✅ Initialized: Using Replicate for detection and segmentation",
        f" - Detector: {REPLICATE_MODEL}",
        f" - SAM: {REPLICATE_SAM_MODEL}",
    ):
        print(line)

    # The segmenter is just a function reference; no processor is required.
    return detector, replicate_sam, None
|
|
|
|
# ------------------------------
|
|
# Detection pipeline
|
|
# ------------------------------
|
|
def detect(
    image: Image.Image,
    labels: List[str],
    detector: ReplicateGroundingDetector,
    threshold: float = 0.3,
) -> List[Dict[str, Any]]:
    """
    Run the detector and drop boxes that cover most of the image.

    A detection is kept only when its box area is below 80% of the image
    area; kept detections are converted with `DetectionResult.from_dict`.
    """
    print(f"[Detect] labels={labels} threshold={threshold}")
    raw_results = detector(image, candidate_labels=labels, threshold=threshold)
    image_area = image.size[0] * image.size[1]

    def _covers_less_than_limit(result):
        # True when the box occupies under 80% of the frame.
        box = result["box"]
        width = box["xmax"] - box["xmin"]
        height = box["ymax"] - box["ymin"]
        return image_area > 0 and (width * height / image_area) < 0.8

    filtered_results = [
        DetectionResult.from_dict(result)
        for result in raw_results
        if _covers_less_than_limit(result)
    ]

    print(f"[Detect] raw={len(raw_results)} filtered={len(filtered_results)}")
    for kept in filtered_results:
        print(f"[Detect] keep {kept.label} score={kept.score:.2f} box={kept.box.xyxy}")
    return filtered_results
|
|
|
|
# ------------------------------
|
|
# Segmentation pipeline
|
|
# ------------------------------
|
|
|
|
def _parse_sam_output(out):
    """Normalize various possible Replicate SAM outputs to a list.

    Handled shapes:
      * None                      -> []
      * list / tuple              -> list copy of the items
      * dict with "combined_mask" -> individual masks first (unwrapping an
        inner "mask" field when present), then the combined mask
      * dict with one of the known keys ("masks", "mask", "segments",
        "segmentations", "output", "data") -> that value, wrapped in a list
        when it is not already one; keys whose value is None are skipped so
        a later key can still supply the masks
      * any other dict            -> its first list-valued entry, else []
      * anything else             -> [out]
    """
    if out is None:
        # Nothing came back; avoid reporting a phantom [None] entry.
        return []

    if isinstance(out, (list, tuple)):
        return list(out)

    if not isinstance(out, dict):
        return [out]

    print(f"[Segment] SAM dict keys: {list(out.keys())}")

    # Special-case common schema from meta/sam-2 on Replicate
    if "combined_mask" in out:
        parsed = []
        # prefer individual masks first
        individual = out.get("individual_masks")
        if isinstance(individual, list):
            for entry in individual:
                if isinstance(entry, dict) and "mask" in entry:
                    parsed.append(entry["mask"])  # unwrap inner mask field
                else:
                    parsed.append(entry)
        # then include combined if present
        combined = out["combined_mask"]
        if isinstance(combined, dict) and "mask" in combined:
            parsed.append(combined["mask"])  # unwrap
        elif combined is not None:
            parsed.append(combined)
        return parsed

    for key in ("masks", "mask", "segments", "segmentations", "output", "data"):
        value = out.get(key)
        if value is None:
            # Missing or explicitly null: keep looking at the remaining keys.
            continue
        return value if isinstance(value, list) else [value]

    # fallback: first list-like value
    for value in out.values():
        if isinstance(value, list):
            return value
    return []
|
|
|
|
|
|
def _coerce_mask_to_numpy(mask_data, target_hw):
|
|
"""
|
|
Convert mask outputs (url, data-uri, PIL, ndarray, dict) to a HxW uint8 binary numpy array (0 or 255).
|
|
target_hw = (H, W) of the original image; masks will be resized to this.
|
|
"""
|
|
import base64
|
|
H, W = target_hw
|
|
|
|
def _ensure_size_u8(m):
|
|
# squeeze channel if needed
|
|
if m.ndim == 3:
|
|
m = m[..., 0]
|
|
if m.shape != (H, W):
|
|
pil = Image.fromarray(m)
|
|
pil = pil.resize((W, H), resample=Image.NEAREST)
|
|
m = np.array(pil)
|
|
# normalize to uint8 binary
|
|
if m.dtype != np.uint8:
|
|
if m.max() <= 1.0:
|
|
m = (m.astype(np.float32) * 255.0).astype(np.uint8)
|
|
else:
|
|
m = m.astype(np.uint8)
|
|
# Many mask PNGs encode binary mask with alpha 0/255; use >0 to be robust
|
|
m = (m > 0).astype(np.uint8) * 255
|
|
return m
|
|
|
|
# numpy array
|
|
if isinstance(mask_data, np.ndarray):
|
|
return _ensure_size_u8(mask_data)
|
|
# PIL image
|
|
if isinstance(mask_data, Image.Image):
|
|
return _ensure_size_u8(np.array(mask_data.convert("L")))
|
|
# Replicate FileOutput (URL-like)
|
|
if ReplicateFileOutput is not None and isinstance(mask_data, ReplicateFileOutput):
|
|
try:
|
|
url = getattr(mask_data, "url", None)
|
|
if isinstance(url, str) and url.startswith("http"):
|
|
m = download_mask(url)
|
|
return _ensure_size_u8(m)
|
|
except Exception as e:
|
|
print(f"[Segment] Failed to read Replicate FileOutput: {e}")
|
|
return None
|
|
# string forms
|
|
if isinstance(mask_data, str):
|
|
s = mask_data.strip()
|
|
if s.startswith("http://") or s.startswith("https://"):
|
|
try:
|
|
m = download_mask(s) # already 0..255 grayscale
|
|
return _ensure_size_u8(m)
|
|
except Exception as e:
|
|
print(f"[Segment] Failed to download mask: {e}")
|
|
return None
|
|
if s.startswith("data:image"):
|
|
try:
|
|
_, b64 = s.split(",", 1)
|
|
img_bytes = base64.b64decode(b64)
|
|
img = Image.open(BytesIO(img_bytes))
|
|
# Prefer alpha channel if present
|
|
if "A" in img.getbands():
|
|
m = np.array(img.getchannel("A"))
|
|
else:
|
|
m = np.array(img.convert("L"))
|
|
return _ensure_size_u8(m)
|
|
except Exception as e:
|
|
print(f"[Segment] Failed to decode data URI mask: {e}")
|
|
return None
|
|
# Fallback: some providers return raw base64-encoded PNG without data URI prefix
|
|
try:
|
|
img_bytes = base64.b64decode(s)
|
|
img = Image.open(BytesIO(img_bytes))
|
|
if "A" in img.getbands():
|
|
m = np.array(img.getchannel("A"))
|
|
else:
|
|
m = np.array(img.convert("L"))
|
|
return _ensure_size_u8(m)
|
|
except Exception:
|
|
pass
|
|
print("[Segment] Unknown mask string format; skipping")
|
|
return None
|
|
# dict forms
|
|
if isinstance(mask_data, dict):
|
|
# quick recursive search for any url/data string
|
|
def _find_any_image_string(obj):
|
|
try:
|
|
if isinstance(obj, str) and (obj.startswith("http") or obj.startswith("data:image")):
|
|
return obj
|
|
if isinstance(obj, dict):
|
|
for vv in obj.values():
|
|
s = _find_any_image_string(vv)
|
|
if s:
|
|
return s
|
|
if isinstance(obj, (list, tuple)):
|
|
for vv in obj:
|
|
s = _find_any_image_string(vv)
|
|
if s:
|
|
return s
|
|
except Exception:
|
|
pass
|
|
return None
|
|
|
|
# Handle COCO RLE formats (uncompressed only)
|
|
def _decode_coco_rle(rle_obj):
|
|
try:
|
|
if isinstance(rle_obj, dict) and isinstance(rle_obj.get("counts"), list) and "size" in rle_obj:
|
|
counts = rle_obj["counts"]
|
|
Hh, Ww = rle_obj["size"]
|
|
flat = np.zeros(Hh * Ww, dtype=np.uint8)
|
|
idx = 0
|
|
val = 0
|
|
for c in counts:
|
|
end = idx + int(c)
|
|
flat[idx:end] = val
|
|
idx = end
|
|
val = 255 - val
|
|
return flat.reshape((Hh, Ww))
|
|
# Compressed RLE (string counts) not supported without pycocotools
|
|
return None
|
|
except Exception as e:
|
|
print(f"[Segment] Failed to decode RLE: {e}")
|
|
return None
|
|
|
|
# Top-level RLE
|
|
if "rle" in mask_data and isinstance(mask_data["rle"], (dict,)):
|
|
decoded = _decode_coco_rle(mask_data["rle"])
|
|
if decoded is not None:
|
|
return _ensure_size_u8(decoded)
|
|
# Some schemas put counts/size at top-level
|
|
if "counts" in mask_data and "size" in mask_data:
|
|
decoded = _decode_coco_rle({"counts": mask_data["counts"], "size": mask_data["size"]})
|
|
if decoded is not None:
|
|
return _ensure_size_u8(decoded)
|
|
|
|
# Try common fields carrying URLs or data URIs
|
|
for k in ["mask", "url", "image", "overlay", "png", "combined_mask"]:
|
|
v = mask_data.get(k)
|
|
if isinstance(v, str):
|
|
return _coerce_mask_to_numpy(v, target_hw)
|
|
if isinstance(v, dict):
|
|
# nested dict possibly with url or data
|
|
for kk in ["url", "image", "overlay", "data", "png"]:
|
|
vv = v.get(kk)
|
|
if isinstance(vv, str):
|
|
return _coerce_mask_to_numpy(vv, target_hw)
|
|
# nested numeric array
|
|
if isinstance(vv, (list, tuple)):
|
|
arr = np.array(vv)
|
|
if arr.ndim >= 2:
|
|
return _ensure_size_u8(arr)
|
|
# If dict has numeric array directly under known keys
|
|
for k in ["data", "array", "segmentation", "mask_array"]:
|
|
v = mask_data.get(k)
|
|
if isinstance(v, (list, tuple)):
|
|
arr = np.array(v)
|
|
if arr.ndim >= 2:
|
|
return _ensure_size_u8(arr)
|
|
# If dict has a single string value somewhere, try it
|
|
for v in mask_data.values():
|
|
if isinstance(v, str):
|
|
return _coerce_mask_to_numpy(v, target_hw)
|
|
# Final attempt: recursively search any nested url or data-uri string
|
|
s_any = _find_any_image_string(mask_data)
|
|
if s_any:
|
|
return _coerce_mask_to_numpy(s_any, target_hw)
|
|
print("[Segment] Unknown dict mask format; skipping")
|
|
return None
|
|
# list-of-lists (numeric mask)
|
|
if isinstance(mask_data, (list, tuple)):
|
|
try:
|
|
arr = np.array(mask_data)
|
|
if arr.ndim >= 2:
|
|
return _ensure_size_u8(arr)
|
|
except Exception:
|
|
pass
|
|
return None
|
|
# unknown
|
|
return None
|
|
|
|
|
|
def segment(
    image: Image.Image,
    detection_results: List[Any],
    segmentator,
    processor=None,
    device=None,
    polygon_refinement=False,
):
    """
    Segment objects using cloud-based Replicate SAM.

    SAM is run once on a (possibly downscaled) copy of the image. Every
    returned mask is coerced to the original image size, and each detection
    receives the mask whose overlap with its box gives the best IoU. When SAM
    fails, or no mask overlaps a box, the detection gets a filled-rectangle
    mask of its box instead.

    Args:
        image: Source PIL image.
        detection_results: Detections exposing `.box` (with xmin/ymin/xmax/ymax)
            and accepting a `.mask` attribute; they are updated in place.
        segmentator: Callable taking a PNG file-like object and returning the
            raw SAM output.
        processor, device, polygon_refinement: Accepted for interface
            compatibility; not used by this implementation.

    Returns:
        The detections in their original order. Detections whose clamped box
        is empty are returned without a mask being assigned.
    """
    W, H = image.size

    # OPTIMIZATION: aggressive downscaling keeps the upload and the remote run small.
    max_side_img = 768  # Reduced from 1024
    image_for_sam = image
    if max(W, H) > max_side_img:
        scale_img = max_side_img / float(max(W, H))
        # Clamp to 1 px so extreme aspect ratios never round down to 0.
        new_size = (
            max(1, int(round(W * scale_img))),
            max(1, int(round(H * scale_img))),
        )
        image_for_sam = image.resize(new_size, Image.BILINEAR)
        print(f"[Segment] Using downscaled image for SAM: {(W, H)} -> {new_size}")

    # Run SAM once with timeout handling
    buf_img = BytesIO()
    image_for_sam.save(buf_img, format="PNG")
    buf_img.seek(0)

    masks_np: List[np.ndarray] = []
    try:
        output = segmentator(buf_img)
        parsed = _parse_sam_output(output)
        print(f"[Segment] global SAM outputs={len(parsed)}")

        # Coerce all masks to original size (H, W)
        for j, md in enumerate(parsed):
            m = _coerce_mask_to_numpy(md, target_hw=(H, W))
            if m is not None:
                print(f"[Segment] mask[{j}] nz={int((m > 0).sum())}")
                masks_np.append(m)
    except Exception as e:
        print(f"[Segment] SAM failed or timed out: {e}")
        print("[Segment] Falling back to box-fill masks for all detections")
        masks_np = []

    # Mask areas do not depend on the detection, so compute them once and
    # drop near-full-frame masks (likely the combined mask) up front.
    frame_area = float(W * H)
    candidates = []
    for k, m in enumerate(masks_np):
        mask_area = int((m > 0).sum())
        if mask_area / frame_area > 0.8:
            continue
        candidates.append((k, m, mask_area))

    results_with_masks = []
    # Assign best-overlap mask to each detection
    for idx, det in enumerate(detection_results):
        box = det.box
        xmin, ymin, xmax, ymax = map(int, [box.xmin, box.ymin, box.xmax, box.ymax])
        # Clamp the box to the image bounds.
        xmin = max(0, min(xmin, W - 1))
        xmax = max(0, min(xmax, W))
        ymin = max(0, min(ymin, H - 1))
        ymax = max(0, min(ymax, H))
        if xmax <= xmin or ymax <= ymin:
            print(f"[Segment] skip invalid box at idx {idx}: {(xmin, ymin, xmax, ymax)}")
            results_with_masks.append(det)
            continue

        box_area = max(1, (xmax - xmin) * (ymax - ymin))
        best_idx = -1
        best_iou = 0.0
        best_metrics = None

        for k, m, mask_area in candidates:
            overlap = int((m[ymin:ymax, xmin:xmax] > 0).sum())
            if overlap == 0:
                continue
            # IoU with the detection box region
            iou = overlap / float(mask_area + box_area - overlap + 1e-6)
            if iou > best_iou:
                best_iou = iou
                best_idx = k
                best_metrics = (overlap, mask_area)

        if best_idx >= 0 and best_iou > 0:
            det.mask = masks_np[best_idx]
            results_with_masks.append(det)
            if idx < 5:  # only log the first few assignments
                ov, ma = best_metrics if best_metrics else (0, 0)
                print(f"[Segment] attach mask {best_idx} to det {idx}, overlap={ov}, mask_area={ma}, box_area={box_area}, iou={best_iou:.4f}")
        else:
            print(f"[Segment] no suitable mask for det {idx} — box fill fallback (box_area={box_area})")
            full_mask = np.zeros((H, W), dtype=np.uint8)
            full_mask[ymin:ymax, xmin:xmax] = 255
            det.mask = full_mask
            results_with_masks.append(det)

    return results_with_masks
|
|
|
|
# ------------------------------
|
|
# Device helper
|
|
# ------------------------------
|
|
def get_device():
    """Return the preferred torch device: CUDA first, then Apple MPS, else CPU.

    The choice is also announced on stdout.
    """
    if torch.cuda.is_available():
        backend, announcement = "cuda", "CUDA is available. Using GPU."
    elif torch.backends.mps.is_available():
        backend, announcement = "mps", "MPS is available! Using Apple GPU."
    else:
        backend, announcement = "cpu", "Using CPU."

    chosen = torch.device(backend)
    print(announcement)
    return chosen
|