diff --git a/search_app.py b/search_app.py index 82278fb..e74dbc8 100644 --- a/search_app.py +++ b/search_app.py @@ -17,6 +17,7 @@ from __future__ import annotations import copy import json import logging +import os import re from pathlib import Path from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, Tuple @@ -46,6 +47,46 @@ LOGGER = logging.getLogger(__name__) _EMBED_MODEL = None _EMBED_MODEL_NAME: Optional[str] = None +# Security constants +MAX_QUERY_SIZE = 100 +MAX_OFFSET = 10000 +ALLOWED_QDRANT_FILTER_FIELDS = {"channel_id", "date", "video_status", "external_reference"} + + +def sanitize_query_string(query: str) -> str: + """ + Sanitize user input for Elasticsearch query_string queries. + Removes dangerous field targeting and script injection patterns. + """ + if not query: + return "*" + sanitized = query.strip() + # Remove field targeting patterns like "_id:", "_source:", "script:" + dangerous_field_patterns = [ + r'\b_[a-z_]+\s*:', # Internal fields like _id:, _source: + r'\bscript\s*:', # Script injection + ] + for pattern in dangerous_field_patterns: + sanitized = re.sub(pattern, '', sanitized, flags=re.IGNORECASE) + # Remove excessive wildcards that could cause ReDoS + sanitized = re.sub(r'\*{2,}', '*', sanitized) + sanitized = re.sub(r'\?{2,}', '?', sanitized) + return sanitized.strip() or "*" + + +def validate_qdrant_filter(filters: Any) -> Dict[str, Any]: + """ + Validate and sanitize Qdrant filter objects. + Only allows whitelisted fields to prevent filter injection. + """ + if not isinstance(filters, dict): + return {} + validated: Dict[str, Any] = {} + for key, value in filters.items(): + if key in ALLOWED_QDRANT_FILTER_FIELDS: + validated[key] = value + return validated + def _ensure_embedder(model_name: str) -> "SentenceTransformer": global _EMBED_MODEL, _EMBED_MODEL_NAME @@ -415,7 +456,7 @@ def build_query_payload( if use_query_string: base_fields = ["title^3", "description^2", "transcript_full", "transcript_secondary_full"] - qs_query = (query or "").strip() or "*" + qs_query = sanitize_query_string(query or "") query_body: Dict[str, Any] = { "query_string": { "query": qs_query, @@ -1496,9 +1537,9 @@ def create_app(config: AppConfig = CONFIG) -> Flask: def api_vector_search(): payload = request.get_json(silent=True) or {} query_text = (payload.get("query") or "").strip() - filters = payload.get("filters") or {} - limit = max(int(payload.get("size", 10)), 1) - offset = max(int(payload.get("offset", 0)), 0) + filters = validate_qdrant_filter(payload.get("filters")) + limit = min(max(int(payload.get("size", 10)), 1), MAX_QUERY_SIZE) + offset = min(max(int(payload.get("offset", 0)), 0), MAX_OFFSET) if not query_text: return jsonify( @@ -1667,7 +1708,8 @@ def create_app(config: AppConfig = CONFIG) -> Flask: def main() -> None: # pragma: no cover logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s") app = create_app() - app.run(host="0.0.0.0", port=8080, debug=True) + debug_mode = os.environ.get("FLASK_DEBUG", "0").lower() in ("1", "true") + app.run(host="0.0.0.0", port=8080, debug=debug_mode) if __name__ == "__main__": # pragma: no cover