# TLC-Search/sync_qdrant_channels.py
"""
Utility to backfill channel titles/names inside the Qdrant payloads.
Usage:
python -m python_app.sync_qdrant_channels \
--batch-size 512 \
--max-batches 200 \
--dry-run
"""
from __future__ import annotations

import argparse
import logging
import time
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple

import requests

from .config import CONFIG
from .search_app import _ensure_client
LOGGER = logging.getLogger(__name__)
def chunked(iterable: Iterable, size: int):
    """Yield successive lists of at most *size* items drawn from *iterable*.

    The final list may be shorter than *size*; an empty iterable yields
    nothing.
    """
    buffer: List = []
    for element in iterable:
        buffer.append(element)
        if len(buffer) < size:
            continue
        yield buffer
        buffer = []
    if buffer:
        yield buffer
def resolve_channels(channel_ids: Iterable[str]) -> Dict[str, str]:
    """Resolve channel ids to display names via Elasticsearch.

    Args:
        channel_ids: Channel ids to look up (duplicates are collapsed).

    Returns:
        Mapping of ``channel_id`` -> ``channel_name`` for every id that could
        be resolved; ids with no matching document (or no name) are omitted.

    Uses a ``terms`` aggregation with one ``top_hits`` sample per channel so
    each requested id contributes exactly one bucket. The previous flat
    search capped hits at ``len(ids) * 2``, so channels with many documents
    could crowd other ids out of the result window and leave them unresolved.
    """
    client = _ensure_client(CONFIG)
    ids = list(set(channel_ids))
    if not ids:
        return {}
    body = {
        "size": 0,  # hits are not needed; everything comes from the aggregation
        "query": {"terms": {"channel_id.keyword": ids}},
        "aggs": {
            "by_channel": {
                "terms": {"field": "channel_id.keyword", "size": len(ids)},
                "aggs": {
                    "sample": {
                        "top_hits": {
                            "size": 1,
                            "_source": ["channel_id", "channel_name"],
                        }
                    }
                },
            }
        },
    }
    response = client.search(index=CONFIG.elastic.index, body=body)
    resolved: Dict[str, str] = {}
    buckets = (
        response.get("aggregations", {}).get("by_channel", {}).get("buckets", [])
    )
    for bucket in buckets:
        hits = bucket.get("sample", {}).get("hits", {}).get("hits", [])
        if not hits:
            continue
        source = hits[0].get("_source") or {}
        cid = source.get("channel_id")
        cname = source.get("channel_name")
        # Keep the first non-empty name seen for each id.
        if cid and cname and cid not in resolved:
            resolved[cid] = cname
    return resolved
def upsert_channel_payload(
    qdrant_url: str,
    collection: str,
    channel_id: str,
    channel_name: str,
    *,
    dry_run: bool = False,
) -> bool:
    """Set channel_name/channel_title for all vectors with this channel_id."""
    LOGGER.info("Updating channel_id=%s -> %s", channel_id, channel_name)
    if dry_run:
        # Report success without touching Qdrant.
        return True
    # Both payload keys receive the same resolved name.
    request_body = {
        "payload": {"channel_name": channel_name, "channel_title": channel_name},
        "filter": {"must": [{"key": "channel_id", "match": {"value": channel_id}}]},
    }
    resp = requests.post(
        f"{qdrant_url}/collections/{collection}/points/payload",
        json=request_body,
        timeout=120,
    )
    succeeded = resp.status_code < 400
    if not succeeded:
        LOGGER.error("Failed to update %s: %s", channel_id, resp.text)
    return succeeded
def scroll_missing_payloads(
    qdrant_url: str,
    collection: str,
    batch_size: int,
    *,
    max_points: Optional[int] = None,
) -> Iterable[List[Tuple[str, Dict[str, Any]]]]:
    """Yield batches of ``(point_id, payload)`` for points missing ``channel_name``.

    Scrolls the Qdrant collection with an ``is_empty`` filter on
    ``channel_name``. On an HTTP error the page size is halved (floor of 5)
    before retrying; other transient request errors are retried after a short
    sleep, up to a bounded budget.

    Args:
        qdrant_url: Base URL of the Qdrant HTTP API.
        collection: Name of the collection to scroll.
        batch_size: Initial scroll page size.
        max_points: If set, stop after roughly this many points (checked
            per batch, so the last batch may overshoot slightly).

    Raises:
        requests.HTTPError: If a scroll request still fails at the minimum
            page size.
        requests.RequestException: If transient errors exhaust the retry
            budget (previously this path retried forever).
    """
    fetched = 0
    next_page = None
    transient_budget = 5  # bound retries so a dead endpoint cannot loop forever
    while True:
        # Reset the page size for each new page; shrink only within a page's retries.
        current_limit = batch_size
        while True:
            body: Dict[str, Any] = {
                "limit": current_limit,
                "with_payload": True,
                "filter": {"must": [{"is_empty": {"key": "channel_name"}}]},
            }
            if next_page:
                body["offset"] = next_page
            try:
                resp = requests.post(
                    f"{qdrant_url}/collections/{collection}/points/scroll",
                    json=body,
                    timeout=120,
                )
                resp.raise_for_status()
                break
            except requests.HTTPError as exc:
                LOGGER.warning(
                    "Scroll request failed at limit=%s: %s", current_limit, exc
                )
                if current_limit <= 5:
                    raise
                current_limit = max(5, current_limit // 2)
                LOGGER.info("Reducing scroll batch size to %s", current_limit)
                time.sleep(2)
            except requests.RequestException as exc:
                transient_budget -= 1
                if transient_budget <= 0:
                    raise
                LOGGER.warning("Transient scroll error: %s", exc)
                time.sleep(2)
        payload = resp.json().get("result", {})
        points = payload.get("points", [])
        if not points:
            break
        yield [(point.get("id"), point.get("payload") or {}) for point in points]
        fetched += len(points)
        if max_points and fetched >= max_points:
            break
        next_page = payload.get("next_page_offset")
        if not next_page:
            break
def main() -> None:
    """CLI entry point: backfill missing channel names in Qdrant payloads
    by resolving them from Elasticsearch, batch by batch."""
    logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s")
    parser = argparse.ArgumentParser(
        description="Backfill missing channel_name/channel_title in Qdrant payloads"
    )
    parser.add_argument("--batch-size", type=int, default=512)
    parser.add_argument(
        "--max-points",
        type=int,
        default=None,
        help="Limit processing to the first N points for testing",
    )
    parser.add_argument("--dry-run", action="store_true")
    args = parser.parse_args()

    q_url = CONFIG.qdrant_url
    collection = CONFIG.qdrant_collection
    total_updates = 0
    batches = scroll_missing_payloads(
        q_url, collection, args.batch_size, max_points=args.max_points
    )
    for batch in batches:
        # Collect the distinct channel ids present in this batch of points.
        channel_ids: Set[str] = {
            str(cid) for _, payload in batch if (cid := payload.get("channel_id"))
        }
        if not channel_ids:
            continue
        name_by_id = resolve_channels(channel_ids)
        if not name_by_id:
            LOGGER.warning("No channel names resolved for ids: %s", channel_ids)
            continue
        for cid, name in name_by_id.items():
            updated = upsert_channel_payload(
                q_url, collection, cid, name, dry_run=args.dry_run
            )
            if updated:
                total_updates += 1
        LOGGER.info("Updated %s channel payloads so far", total_updates)
    LOGGER.info("Finished. Total channel updates attempted: %s", total_updates)


if __name__ == "__main__":
    main()