r"""Utility to backfill channel titles/names inside the Qdrant payloads.

Usage:
    python -m python_app.sync_qdrant_channels \
        --batch-size 512 \
        --max-points 10000 \
        --dry-run
"""

from __future__ import annotations

import argparse
import logging
import time
from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple

import requests

from .config import CONFIG
from .search_app import _ensure_client

LOGGER = logging.getLogger(__name__)

# Consecutive connection-level failures tolerated for a single scroll page
# before the error is propagated instead of being retried again.
_MAX_TRANSIENT_RETRIES = 5


def chunked(iterable: Iterable, size: int) -> Iterator[List]:
    """Yield successive lists of up to ``size`` items taken from ``iterable``.

    The last list may be shorter than ``size``; nothing is yielded for an
    empty iterable.
    """
    chunk: List = []
    for item in iterable:
        chunk.append(item)
        if len(chunk) >= size:
            yield chunk
            chunk = []
    if chunk:
        yield chunk


def resolve_channels(channel_ids: Iterable[str]) -> Dict[str, str]:
    """Look up the channel name for each channel id in the search index.

    Args:
        channel_ids: Channel ids to resolve; duplicates are ignored.

    Returns:
        Mapping of channel id to channel name. Ids with no indexed document
        carrying a non-empty ``channel_name`` are absent from the mapping.
    """
    ids = list(set(channel_ids))
    if not ids:
        return {}
    client = _ensure_client(CONFIG)
    body = {
        # One collapsed hit per channel, so len(ids) hits are always enough.
        "size": len(ids),
        "_source": ["channel_id", "channel_name"],
        "query": {
            "bool": {
                "filter": [
                    {"terms": {"channel_id.keyword": ids}},
                    # Collapsing keeps a single document per channel, so make
                    # sure that document actually carries a name.
                    {"exists": {"field": "channel_name"}},
                ]
            }
        },
        # Without collapsing, a channel with many documents can fill the
        # whole result window and leave the other channels unresolved.
        "collapse": {"field": "channel_id.keyword"},
    }
    response = client.search(index=CONFIG.elastic.index, body=body)
    resolved: Dict[str, str] = {}
    for hit in response.get("hits", {}).get("hits", []):
        source = hit.get("_source") or {}
        cid = source.get("channel_id")
        cname = source.get("channel_name")
        if cid and cname and cid not in resolved:
            resolved[cid] = cname
    return resolved


def upsert_channel_payload(
    qdrant_url: str,
    collection: str,
    channel_id: str,
    channel_name: str,
    *,
    dry_run: bool = False,
) -> bool:
    """Set channel_name/channel_title for all vectors with this channel_id.

    Args:
        qdrant_url: Base URL of the Qdrant HTTP API.
        collection: Name of the collection to update.
        channel_id: Value matched against the ``channel_id`` payload key.
        channel_name: Value written to both ``channel_name`` and
            ``channel_title``.
        dry_run: When true, only log the intended update and send nothing.

    Returns:
        True when the update was accepted (or skipped because of ``dry_run``),
        False when the request failed or Qdrant answered with a status >= 400.
    """
    payload = {"channel_name": channel_name, "channel_title": channel_name}
    body = {
        "payload": payload,
        "filter": {"must": [{"key": "channel_id", "match": {"value": channel_id}}]},
    }
    LOGGER.info("Updating channel_id=%s -> %s", channel_id, channel_name)
    if dry_run:
        return True
    try:
        resp = requests.post(
            f"{qdrant_url}/collections/{collection}/points/payload",
            json=body,
            timeout=120,
        )
    except requests.RequestException as exc:
        # Failures are reported through the return value; a timeout on one
        # channel must not abort the whole backfill.
        LOGGER.error("Failed to update %s: %s", channel_id, exc)
        return False
    if resp.status_code >= 400:
        LOGGER.error("Failed to update %s: %s", channel_id, resp.text)
        return False
    return True


def scroll_missing_payloads(
    qdrant_url: str,
    collection: str,
    batch_size: int,
    *,
    max_points: Optional[int] = None,
) -> Iterable[List[Tuple[str, Dict[str, Any]]]]:
    """Yield batches of (point_id, payload) missing channel names.

    Args:
        qdrant_url: Base URL of the Qdrant HTTP API.
        collection: Name of the collection to scroll.
        batch_size: Number of points requested per page. It is halved (down
            to a floor of 5) for the current page when Qdrant answers with an
            HTTP error.
        max_points: Stop after this many points; ``None`` or 0 means no limit.

    Yields:
        One list per page of ``(point_id, payload)`` tuples; a missing
        payload is replaced by an empty dict.

    Raises:
        requests.HTTPError: A page still fails at the minimum page size.
        requests.RequestException: A page hit ``_MAX_TRANSIENT_RETRIES``
            consecutive connection-level failures.
    """
    url = f"{qdrant_url}/collections/{collection}/points/scroll"
    fetched = 0
    next_page = None
    while True:
        current_limit = batch_size
        if max_points:
            # Do not request more points than are still allowed.
            current_limit = min(batch_size, max_points - fetched)
        transient_failures = 0
        while True:
            body = {
                "limit": current_limit,
                "with_payload": True,
                "filter": {"must": [{"is_empty": {"key": "channel_name"}}]},
            }
            # Compare against None: a numeric offset of 0 is falsy but valid.
            if next_page is not None:
                body["offset"] = next_page
            try:
                resp = requests.post(url, json=body, timeout=120)
                resp.raise_for_status()
                break
            except requests.HTTPError as exc:
                LOGGER.warning(
                    "Scroll request failed at limit=%s: %s", current_limit, exc
                )
                if current_limit <= 5:
                    raise
                current_limit = max(5, current_limit // 2)
                LOGGER.info("Reducing scroll batch size to %s", current_limit)
                time.sleep(2)
            except requests.RequestException as exc:
                transient_failures += 1
                if transient_failures >= _MAX_TRANSIENT_RETRIES:
                    # Give up instead of spinning forever on a dead endpoint.
                    raise
                LOGGER.warning("Transient scroll error: %s", exc)
                time.sleep(2)

        payload = resp.json().get("result") or {}
        points = payload.get("points") or []
        if not points:
            break

        batch: List[Tuple[str, Dict[str, Any]]] = [
            (point.get("id"), point.get("payload") or {}) for point in points
        ]
        yield batch

        fetched += len(points)
        if max_points and fetched >= max_points:
            break
        next_page = payload.get("next_page_offset")
        if next_page is None:
            break


def main() -> None:
    """Parse CLI arguments and backfill channel names into Qdrant payloads.

    Scrolls points that lack ``channel_name``, resolves the names of their
    channel ids through the search index and writes them back to Qdrant.
    With ``--dry-run`` the updates are only logged.
    """
    logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s")
    parser = argparse.ArgumentParser(
        description="Backfill missing channel_name/channel_title in Qdrant payloads"
    )
    parser.add_argument("--batch-size", type=int, default=512)
    parser.add_argument(
        "--max-points",
        type=int,
        default=None,
        help="Limit processing to the first N points for testing",
    )
    parser.add_argument("--dry-run", action="store_true")
    args = parser.parse_args()

    q_url = CONFIG.qdrant_url
    collection = CONFIG.qdrant_collection

    total_updates = 0
    # Each update is filter-based and covers every point of the channel, so
    # one successful request per channel is enough; later pages showing the
    # same channel (dry-run, or updates not yet applied) are skipped.
    updated_ids: Set[str] = set()
    for batch in scroll_missing_payloads(
        q_url, collection, args.batch_size, max_points=args.max_points
    ):
        channel_ids: Set[str] = set()
        for _, payload in batch:
            cid = payload.get("channel_id")
            if cid:
                channel_ids.add(str(cid))
        channel_ids -= updated_ids
        if not channel_ids:
            continue

        resolved = resolve_channels(channel_ids)
        if not resolved:
            LOGGER.warning("No channel names resolved for ids: %s", channel_ids)
            continue

        for cid, name in resolved.items():
            if upsert_channel_payload(
                q_url, collection, cid, name, dry_run=args.dry_run
            ):
                updated_ids.add(str(cid))
                total_updates += 1
        LOGGER.info("Updated %s channel payloads so far", total_updates)

    LOGGER.info("Finished. Total channel updates attempted: %s", total_updates)


if __name__ == "__main__":
    main()