189 lines
5.7 KiB
Python
189 lines
5.7 KiB
Python
"""
|
|
Utility to backfill channel titles/names inside the Qdrant payloads.
|
|
|
|
Usage:
|
|
python -m python_app.sync_qdrant_channels \
|
|
--batch-size 512 \
|
|
--max-batches 200 \
|
|
--dry-run
|
|
"""
|
|
|
|
from __future__ import annotations

import argparse
import logging
import time
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple

import requests

from .config import CONFIG
from .search_app import _ensure_client
|
|
|
|
LOGGER = logging.getLogger(__name__)
|
|
|
|
|
|
def chunked(iterable: Iterable, size: int):
    """Yield consecutive lists of at most ``size`` items from ``iterable``.

    The final chunk may be shorter when the total number of items is not a
    multiple of ``size``.
    """
    current: List = []
    for item in iterable:
        current.append(item)
        if len(current) < size:
            continue
        yield current
        current = []
    if current:
        yield current
|
|
|
|
|
|
def resolve_channels(channel_ids: Iterable[str]) -> Dict[str, str]:
    """Resolve human-readable channel names for the given ids via Elasticsearch.

    Returns a ``channel_id -> channel_name`` mapping; ids that could not be
    resolved are simply absent from the result.
    """
    client = _ensure_client(CONFIG)
    unique_ids = list(set(channel_ids))
    if not unique_ids:
        return {}

    # Over-fetch (2x the id count) so duplicate documents for the same
    # channel do not crowd out other ids within the single result page.
    query_body = {
        "size": len(unique_ids) * 2,
        "_source": ["channel_id", "channel_name"],
        "query": {"terms": {"channel_id.keyword": unique_ids}},
    }
    response = client.search(index=CONFIG.elastic.index, body=query_body)

    names: Dict[str, str] = {}
    for hit in response.get("hits", {}).get("hits", []):
        doc = hit.get("_source") or {}
        channel_id = doc.get("channel_id")
        channel_name = doc.get("channel_name")
        # First hit per channel wins; skip docs missing either field.
        if channel_id and channel_name:
            names.setdefault(channel_id, channel_name)
    return names
|
|
|
|
|
|
def upsert_channel_payload(
    qdrant_url: str,
    collection: str,
    channel_id: str,
    channel_name: str,
    *,
    dry_run: bool = False,
) -> bool:
    """Set channel_name/channel_title for all vectors with this channel_id."""
    update_body = {
        "payload": {"channel_name": channel_name, "channel_title": channel_name},
        "filter": {"must": [{"key": "channel_id", "match": {"value": channel_id}}]},
    }
    LOGGER.info("Updating channel_id=%s -> %s", channel_id, channel_name)
    if dry_run:
        # Report success without touching Qdrant.
        return True

    endpoint = f"{qdrant_url}/collections/{collection}/points/payload"
    response = requests.post(endpoint, json=update_body, timeout=120)
    if response.status_code >= 400:
        LOGGER.error("Failed to update %s: %s", channel_id, response.text)
        return False
    return True
|
|
|
|
|
|
def scroll_missing_payloads(
    qdrant_url: str,
    collection: str,
    batch_size: int,
    *,
    max_points: Optional[int] = None,
) -> Iterable[List[Tuple[str, Dict[str, Any]]]]:
    """Yield batches of (point_id, payload) missing channel names.

    Scrolls the Qdrant collection for points whose ``channel_name`` payload
    field is empty. On HTTP errors the scroll page size is halved (down to a
    floor of 5) before retrying; other transient request errors (connection
    resets, timeouts) are retried a bounded number of times.

    Args:
        qdrant_url: Base URL of the Qdrant HTTP API.
        collection: Name of the Qdrant collection to scroll.
        batch_size: Initial scroll page size.
        max_points: Stop after roughly this many points; whole pages are
            always yielded, so slightly more may be processed.

    Yields:
        Lists of ``(point_id, payload)`` tuples, one list per scroll page.

    Raises:
        requests.HTTPError: if the server keeps failing even at the minimum
            page size.
        requests.RequestException: if transient errors persist beyond the
            retry budget.
    """
    fetched = 0
    next_page = None
    while True:
        current_limit = batch_size
        transient_errors = 0
        while True:
            body: Dict[str, Any] = {
                "limit": current_limit,
                "with_payload": True,
                "filter": {"must": [{"is_empty": {"key": "channel_name"}}]},
            }
            if next_page:
                body["offset"] = next_page
            try:
                resp = requests.post(
                    f"{qdrant_url}/collections/{collection}/points/scroll",
                    json=body,
                    timeout=120,
                )
                resp.raise_for_status()
                break
            except requests.HTTPError as exc:
                LOGGER.warning(
                    "Scroll request failed at limit=%s: %s", current_limit, exc
                )
                if current_limit <= 5:
                    raise
                # Large pages are the usual culprit (server timeouts /
                # response size); shrink and retry instead of giving up.
                current_limit = max(5, current_limit // 2)
                LOGGER.info("Reducing scroll batch size to %s", current_limit)
                time.sleep(2)
            except requests.RequestException as exc:
                # Retry transient failures with a small pause, but bound the
                # attempts so a persistent outage cannot hang the job forever
                # (the original looped indefinitely here).
                transient_errors += 1
                if transient_errors > 10:
                    raise
                LOGGER.warning("Transient scroll error: %s", exc)
                time.sleep(2)

        result = resp.json().get("result", {})
        points = result.get("points", [])
        if not points:
            break

        batch: List[Tuple[str, Dict[str, Any]]] = [
            (point.get("id"), point.get("payload") or {}) for point in points
        ]
        yield batch

        fetched += len(points)
        if max_points and fetched >= max_points:
            break
        next_page = result.get("next_page_offset")
        if not next_page:
            break
|
|
|
|
|
|
def main() -> None:
    """CLI entry point: backfill missing channel names in Qdrant payloads."""
    logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s")
    parser = argparse.ArgumentParser(
        description="Backfill missing channel_name/channel_title in Qdrant payloads"
    )
    parser.add_argument("--batch-size", type=int, default=512)
    parser.add_argument(
        "--max-points",
        type=int,
        default=None,
        help="Limit processing to the first N points for testing",
    )
    parser.add_argument("--dry-run", action="store_true")
    args = parser.parse_args()

    q_url = CONFIG.qdrant_url
    collection = CONFIG.qdrant_collection
    total_updates = 0

    batches = scroll_missing_payloads(
        q_url, collection, args.batch_size, max_points=args.max_points
    )
    for batch in batches:
        # Gather the distinct channel ids present in this page of points.
        channel_ids: Set[str] = {
            str(payload["channel_id"])
            for _, payload in batch
            if payload.get("channel_id")
        }
        if not channel_ids:
            continue
        resolved = resolve_channels(channel_ids)
        if not resolved:
            LOGGER.warning("No channel names resolved for ids: %s", channel_ids)
            continue
        for cid, name in resolved.items():
            success = upsert_channel_payload(
                q_url, collection, cid, name, dry_run=args.dry_run
            )
            if success:
                total_updates += 1
        LOGGER.info("Updated %s channel payloads so far", total_updates)

    LOGGER.info("Finished. Total channel updates attempted: %s", total_updates)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|