Security: disable debug mode, sanitize query input, validate Qdrant filters, add size/offset bounds

This commit is contained in:
2026-01-08 14:41:42 -05:00
parent d26edda029
commit 1565c8db38

View File

@@ -17,6 +17,7 @@ from __future__ import annotations
import copy import copy
import json import json
import logging import logging
import os
import re import re
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, Tuple from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, Tuple
@@ -46,6 +47,46 @@ LOGGER = logging.getLogger(__name__)
_EMBED_MODEL = None _EMBED_MODEL = None
_EMBED_MODEL_NAME: Optional[str] = None _EMBED_MODEL_NAME: Optional[str] = None
# Security constants
MAX_QUERY_SIZE = 100
MAX_OFFSET = 10000
ALLOWED_QDRANT_FILTER_FIELDS = {"channel_id", "date", "video_status", "external_reference"}
def sanitize_query_string(query: str) -> str:
"""
Sanitize user input for Elasticsearch query_string queries.
Removes dangerous field targeting and script injection patterns.
"""
if not query:
return "*"
sanitized = query.strip()
# Remove field targeting patterns like "_id:", "_source:", "script:"
dangerous_field_patterns = [
r'\b_[a-z_]+\s*:', # Internal fields like _id:, _source:
r'\bscript\s*:', # Script injection
]
for pattern in dangerous_field_patterns:
sanitized = re.sub(pattern, '', sanitized, flags=re.IGNORECASE)
# Remove excessive wildcards that could cause ReDoS
sanitized = re.sub(r'\*{2,}', '*', sanitized)
sanitized = re.sub(r'\?{2,}', '?', sanitized)
return sanitized.strip() or "*"
def validate_qdrant_filter(filters: Any) -> Dict[str, Any]:
"""
Validate and sanitize Qdrant filter objects.
Only allows whitelisted fields to prevent filter injection.
"""
if not isinstance(filters, dict):
return {}
validated: Dict[str, Any] = {}
for key, value in filters.items():
if key in ALLOWED_QDRANT_FILTER_FIELDS:
validated[key] = value
return validated
def _ensure_embedder(model_name: str) -> "SentenceTransformer": def _ensure_embedder(model_name: str) -> "SentenceTransformer":
global _EMBED_MODEL, _EMBED_MODEL_NAME global _EMBED_MODEL, _EMBED_MODEL_NAME
@@ -415,7 +456,7 @@ def build_query_payload(
if use_query_string: if use_query_string:
base_fields = ["title^3", "description^2", "transcript_full", "transcript_secondary_full"] base_fields = ["title^3", "description^2", "transcript_full", "transcript_secondary_full"]
qs_query = (query or "").strip() or "*" qs_query = sanitize_query_string(query or "")
query_body: Dict[str, Any] = { query_body: Dict[str, Any] = {
"query_string": { "query_string": {
"query": qs_query, "query": qs_query,
@@ -1496,9 +1537,9 @@ def create_app(config: AppConfig = CONFIG) -> Flask:
def api_vector_search(): def api_vector_search():
payload = request.get_json(silent=True) or {} payload = request.get_json(silent=True) or {}
query_text = (payload.get("query") or "").strip() query_text = (payload.get("query") or "").strip()
filters = payload.get("filters") or {} filters = validate_qdrant_filter(payload.get("filters"))
limit = max(int(payload.get("size", 10)), 1) limit = min(max(int(payload.get("size", 10)), 1), MAX_QUERY_SIZE)
offset = max(int(payload.get("offset", 0)), 0) offset = min(max(int(payload.get("offset", 0)), 0), MAX_OFFSET)
if not query_text: if not query_text:
return jsonify( return jsonify(
@@ -1667,7 +1708,8 @@ def create_app(config: AppConfig = CONFIG) -> Flask:
def main() -> None: # pragma: no cover def main() -> None: # pragma: no cover
logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s") logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s")
app = create_app() app = create_app()
app.run(host="0.0.0.0", port=8080, debug=True) debug_mode = os.environ.get("FLASK_DEBUG", "0").lower() in ("1", "true")
app.run(host="0.0.0.0", port=8080, debug=debug_mode)
if __name__ == "__main__": # pragma: no cover if __name__ == "__main__": # pragma: no cover