Add graph and vector search features
sync_qdrant_channels.py (new file, 188 lines)
@@ -0,0 +1,188 @@
"""
Utility to backfill channel titles/names inside the Qdrant payloads.

Usage:
    python -m python_app.sync_qdrant_channels \
        --batch-size 512 \
        --max-points 200 \
        --dry-run
"""

from __future__ import annotations

import argparse
import logging
import time
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple

import requests

from .config import CONFIG
from .search_app import _ensure_client

LOGGER = logging.getLogger(__name__)


def chunked(iterable: Iterable, size: int):
    """Yield lists of at most ``size`` items from ``iterable``."""
    chunk: List = []
    for item in iterable:
        chunk.append(item)
        if len(chunk) >= size:
            yield chunk
            chunk = []
    if chunk:
        yield chunk


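# Illustrative behaviour of chunked (doctest-style comment, not executed):
#   >>> list(chunked(range(5), 2))
#   [[0, 1], [2, 3], [4]]

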
def resolve_channels(channel_ids: Iterable[str]) -> Dict[str, str]:
    """Look up channel names for the given channel ids in Elasticsearch."""
    client = _ensure_client(CONFIG)
    ids = list(set(channel_ids))
    if not ids:
        return {}
    # Over-fetch (2x) so duplicate documents per channel do not crowd out ids.
    body = {
        "size": len(ids) * 2,
        "_source": ["channel_id", "channel_name"],
        "query": {"terms": {"channel_id.keyword": ids}},
    }
    response = client.search(index=CONFIG.elastic.index, body=body)
    resolved: Dict[str, str] = {}
    for hit in response.get("hits", {}).get("hits", []):
        source = hit.get("_source") or {}
        cid = source.get("channel_id")
        cname = source.get("channel_name")
        # First hit wins; later duplicates for the same id are ignored.
        if cid and cname and cid not in resolved:
            resolved[cid] = cname
    return resolved


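# The Elasticsearch exchange above, sketched (the index comes from CONFIG;
# ids and names here are illustrative):
#   request body:  {"size": 4, "_source": ["channel_id", "channel_name"],
#                   "query": {"terms": {"channel_id.keyword": ["UC_a", "UC_b"]}}}
#   response hits: [{"_source": {"channel_id": "UC_a", "channel_name": "Channel A"}}]

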
def upsert_channel_payload(
    qdrant_url: str,
    collection: str,
    channel_id: str,
    channel_name: str,
    *,
    dry_run: bool = False,
) -> bool:
    """Set channel_name/channel_title for all vectors with this channel_id."""
    payload = {"channel_name": channel_name, "channel_title": channel_name}
    body = {
        "payload": payload,
        "filter": {"must": [{"key": "channel_id", "match": {"value": channel_id}}]},
    }
    LOGGER.info("Updating channel_id=%s -> %s", channel_id, channel_name)
    if dry_run:
        return True
    resp = requests.post(
        f"{qdrant_url}/collections/{collection}/points/payload",
        json=body,
        timeout=120,
    )
    if resp.status_code >= 400:
        LOGGER.error("Failed to update %s: %s", channel_id, resp.text)
        return False
    return True


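# The call above is Qdrant's REST set-payload endpoint; an equivalent curl
# sketch (URL, collection name, and channel id are placeholders):
#   curl -X POST "http://localhost:6333/collections/<collection>/points/payload" \
#        -H "Content-Type: application/json" \
#        -d '{"payload": {"channel_name": "Name", "channel_title": "Name"},
#             "filter": {"must": [{"key": "channel_id", "match": {"value": "<id>"}}]}}'

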
def scroll_missing_payloads(
    qdrant_url: str,
    collection: str,
    batch_size: int,
    *,
    max_points: Optional[int] = None,
) -> Iterable[List[Tuple[str, Dict[str, Any]]]]:
    """Yield batches of (point_id, payload) missing channel names."""
    fetched = 0
    next_page = None
    while True:
        current_limit = batch_size
        # Retry loop: halve the batch size on HTTP errors; retry transient
        # network errors after a short pause.
        while True:
            body = {
                "limit": current_limit,
                "with_payload": True,
                "filter": {"must": [{"is_empty": {"key": "channel_name"}}]},
            }
            if next_page:
                body["offset"] = next_page
            try:
                resp = requests.post(
                    f"{qdrant_url}/collections/{collection}/points/scroll",
                    json=body,
                    timeout=120,
                )
                resp.raise_for_status()
                break
            except requests.HTTPError as exc:
                LOGGER.warning(
                    "Scroll request failed at limit=%s: %s", current_limit, exc
                )
                if current_limit <= 5:
                    raise
                current_limit = max(5, current_limit // 2)
                LOGGER.info("Reducing scroll batch size to %s", current_limit)
                time.sleep(2)
            except requests.RequestException as exc:
                LOGGER.warning("Transient scroll error: %s", exc)
                time.sleep(2)
        payload = resp.json().get("result", {})
        points = payload.get("points", [])
        if not points:
            break
        batch: List[Tuple[str, Dict[str, Any]]] = []
        for point in points:
            pid = point.get("id")
            p_payload = point.get("payload") or {}
            batch.append((pid, p_payload))
        yield batch
        fetched += len(points)
        if max_points and fetched >= max_points:
            break
        next_page = payload.get("next_page_offset")
        if not next_page:
            break


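# Shape of the scroll response consumed above (point data is illustrative):
#   {"result": {"points": [{"id": 1, "payload": {"channel_id": "UC_a"}}],
#               "next_page_offset": 123}}
# Pagination ends when "next_page_offset" is null or absent.

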
def main() -> None:
    logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s")
    parser = argparse.ArgumentParser(
        description="Backfill missing channel_name/channel_title in Qdrant payloads"
    )
    parser.add_argument("--batch-size", type=int, default=512)
    parser.add_argument(
        "--max-points",
        type=int,
        default=None,
        help="Limit processing to the first N points for testing",
    )
    parser.add_argument("--dry-run", action="store_true")
    args = parser.parse_args()

    q_url = CONFIG.qdrant_url
    collection = CONFIG.qdrant_collection
    total_updates = 0

    for batch in scroll_missing_payloads(
        q_url, collection, args.batch_size, max_points=args.max_points
    ):
        # Collect the distinct channel ids seen in this batch of points.
        channel_ids: Set[str] = set()
        for _, payload in batch:
            cid = payload.get("channel_id")
            if cid:
                channel_ids.add(str(cid))
        if not channel_ids:
            continue
        resolved = resolve_channels(channel_ids)
        if not resolved:
            LOGGER.warning("No channel names resolved for ids: %s", channel_ids)
            continue
        for cid, name in resolved.items():
            if upsert_channel_payload(
                q_url, collection, cid, name, dry_run=args.dry_run
            ):
                total_updates += 1
        LOGGER.info("Updated %s channel payloads so far", total_updates)

    LOGGER.info("Finished. Total channel updates attempted: %s", total_updates)


if __name__ == "__main__":
    main()
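
# A cautious first pass (flag values are only examples): --dry-run logs the
# intended updates without writing to Qdrant.
#   python -m python_app.sync_qdrant_channels --batch-size 256 --max-points 1000 --dry-run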