Security: disable debug mode, sanitize query input, validate Qdrant filters, add size/offset bounds

This commit is contained in:
knight 2026-01-08 14:41:42 -05:00
parent d26edda029
commit 1565c8db38

View File

@ -17,6 +17,7 @@ from __future__ import annotations
import copy
import json
import logging
import os
import re
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, Tuple
@ -46,6 +47,46 @@ LOGGER = logging.getLogger(__name__)
_EMBED_MODEL = None
_EMBED_MODEL_NAME: Optional[str] = None
# Security constants
MAX_QUERY_SIZE = 100
MAX_OFFSET = 10000
ALLOWED_QDRANT_FILTER_FIELDS = {"channel_id", "date", "video_status", "external_reference"}
def sanitize_query_string(query: str) -> str:
"""
Sanitize user input for Elasticsearch query_string queries.
Removes dangerous field targeting and script injection patterns.
"""
if not query:
return "*"
sanitized = query.strip()
# Remove field targeting patterns like "_id:", "_source:", "script:"
dangerous_field_patterns = [
r'\b_[a-z_]+\s*:', # Internal fields like _id:, _source:
r'\bscript\s*:', # Script injection
]
for pattern in dangerous_field_patterns:
sanitized = re.sub(pattern, '', sanitized, flags=re.IGNORECASE)
# Remove excessive wildcards that could cause ReDoS
sanitized = re.sub(r'\*{2,}', '*', sanitized)
sanitized = re.sub(r'\?{2,}', '?', sanitized)
return sanitized.strip() or "*"
def validate_qdrant_filter(filters: Any) -> Dict[str, Any]:
"""
Validate and sanitize Qdrant filter objects.
Only allows whitelisted fields to prevent filter injection.
"""
if not isinstance(filters, dict):
return {}
validated: Dict[str, Any] = {}
for key, value in filters.items():
if key in ALLOWED_QDRANT_FILTER_FIELDS:
validated[key] = value
return validated
def _ensure_embedder(model_name: str) -> "SentenceTransformer":
global _EMBED_MODEL, _EMBED_MODEL_NAME
@ -415,7 +456,7 @@ def build_query_payload(
if use_query_string:
base_fields = ["title^3", "description^2", "transcript_full", "transcript_secondary_full"]
qs_query = (query or "").strip() or "*"
qs_query = sanitize_query_string(query or "")
query_body: Dict[str, Any] = {
"query_string": {
"query": qs_query,
@ -1496,9 +1537,9 @@ def create_app(config: AppConfig = CONFIG) -> Flask:
def api_vector_search():
payload = request.get_json(silent=True) or {}
query_text = (payload.get("query") or "").strip()
filters = payload.get("filters") or {}
limit = max(int(payload.get("size", 10)), 1)
offset = max(int(payload.get("offset", 0)), 0)
filters = validate_qdrant_filter(payload.get("filters"))
limit = min(max(int(payload.get("size", 10)), 1), MAX_QUERY_SIZE)
offset = min(max(int(payload.get("offset", 0)), 0), MAX_OFFSET)
if not query_text:
return jsonify(
@ -1667,7 +1708,8 @@ def create_app(config: AppConfig = CONFIG) -> Flask:
def main() -> None: # pragma: no cover
logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s")
app = create_app()
app.run(host="0.0.0.0", port=8080, debug=True)
debug_mode = os.environ.get("FLASK_DEBUG", "0").lower() in ("1", "true")
app.run(host="0.0.0.0", port=8080, debug=debug_mode)
if __name__ == "__main__": # pragma: no cover