Security: disable debug mode, sanitize query input, validate Qdrant filters, add size/offset bounds
This commit is contained in:
@@ -17,6 +17,7 @@ from __future__ import annotations
|
|||||||
import copy
|
import copy
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
import re
|
import re
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, Tuple
|
from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, Tuple
|
||||||
@@ -46,6 +47,46 @@ LOGGER = logging.getLogger(__name__)
|
|||||||
_EMBED_MODEL = None
|
_EMBED_MODEL = None
|
||||||
_EMBED_MODEL_NAME: Optional[str] = None
|
_EMBED_MODEL_NAME: Optional[str] = None
|
||||||
|
|
||||||
|
# Security constants
|
||||||
|
MAX_QUERY_SIZE = 100
|
||||||
|
MAX_OFFSET = 10000
|
||||||
|
ALLOWED_QDRANT_FILTER_FIELDS = {"channel_id", "date", "video_status", "external_reference"}
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_query_string(query: str) -> str:
|
||||||
|
"""
|
||||||
|
Sanitize user input for Elasticsearch query_string queries.
|
||||||
|
Removes dangerous field targeting and script injection patterns.
|
||||||
|
"""
|
||||||
|
if not query:
|
||||||
|
return "*"
|
||||||
|
sanitized = query.strip()
|
||||||
|
# Remove field targeting patterns like "_id:", "_source:", "script:"
|
||||||
|
dangerous_field_patterns = [
|
||||||
|
r'\b_[a-z_]+\s*:', # Internal fields like _id:, _source:
|
||||||
|
r'\bscript\s*:', # Script injection
|
||||||
|
]
|
||||||
|
for pattern in dangerous_field_patterns:
|
||||||
|
sanitized = re.sub(pattern, '', sanitized, flags=re.IGNORECASE)
|
||||||
|
# Remove excessive wildcards that could cause ReDoS
|
||||||
|
sanitized = re.sub(r'\*{2,}', '*', sanitized)
|
||||||
|
sanitized = re.sub(r'\?{2,}', '?', sanitized)
|
||||||
|
return sanitized.strip() or "*"
|
||||||
|
|
||||||
|
|
||||||
|
def validate_qdrant_filter(filters: Any) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Validate and sanitize Qdrant filter objects.
|
||||||
|
Only allows whitelisted fields to prevent filter injection.
|
||||||
|
"""
|
||||||
|
if not isinstance(filters, dict):
|
||||||
|
return {}
|
||||||
|
validated: Dict[str, Any] = {}
|
||||||
|
for key, value in filters.items():
|
||||||
|
if key in ALLOWED_QDRANT_FILTER_FIELDS:
|
||||||
|
validated[key] = value
|
||||||
|
return validated
|
||||||
|
|
||||||
|
|
||||||
def _ensure_embedder(model_name: str) -> "SentenceTransformer":
|
def _ensure_embedder(model_name: str) -> "SentenceTransformer":
|
||||||
global _EMBED_MODEL, _EMBED_MODEL_NAME
|
global _EMBED_MODEL, _EMBED_MODEL_NAME
|
||||||
@@ -415,7 +456,7 @@ def build_query_payload(
|
|||||||
|
|
||||||
if use_query_string:
|
if use_query_string:
|
||||||
base_fields = ["title^3", "description^2", "transcript_full", "transcript_secondary_full"]
|
base_fields = ["title^3", "description^2", "transcript_full", "transcript_secondary_full"]
|
||||||
qs_query = (query or "").strip() or "*"
|
qs_query = sanitize_query_string(query or "")
|
||||||
query_body: Dict[str, Any] = {
|
query_body: Dict[str, Any] = {
|
||||||
"query_string": {
|
"query_string": {
|
||||||
"query": qs_query,
|
"query": qs_query,
|
||||||
@@ -1496,9 +1537,9 @@ def create_app(config: AppConfig = CONFIG) -> Flask:
|
|||||||
def api_vector_search():
|
def api_vector_search():
|
||||||
payload = request.get_json(silent=True) or {}
|
payload = request.get_json(silent=True) or {}
|
||||||
query_text = (payload.get("query") or "").strip()
|
query_text = (payload.get("query") or "").strip()
|
||||||
filters = payload.get("filters") or {}
|
filters = validate_qdrant_filter(payload.get("filters"))
|
||||||
limit = max(int(payload.get("size", 10)), 1)
|
limit = min(max(int(payload.get("size", 10)), 1), MAX_QUERY_SIZE)
|
||||||
offset = max(int(payload.get("offset", 0)), 0)
|
offset = min(max(int(payload.get("offset", 0)), 0), MAX_OFFSET)
|
||||||
|
|
||||||
if not query_text:
|
if not query_text:
|
||||||
return jsonify(
|
return jsonify(
|
||||||
@@ -1667,7 +1708,8 @@ def create_app(config: AppConfig = CONFIG) -> Flask:
|
|||||||
def main() -> None: # pragma: no cover
|
def main() -> None: # pragma: no cover
|
||||||
logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s")
|
logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s")
|
||||||
app = create_app()
|
app = create_app()
|
||||||
app.run(host="0.0.0.0", port=8080, debug=True)
|
debug_mode = os.environ.get("FLASK_DEBUG", "0").lower() in ("1", "true")
|
||||||
|
app.run(host="0.0.0.0", port=8080, debug=debug_mode)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__": # pragma: no cover
|
if __name__ == "__main__": # pragma: no cover
|
||||||
|
|||||||
Reference in New Issue
Block a user