# 2025-12-31 20:11:44 -05:00
#
# 204 lines
# 8.6 KiB
# Python

import json
import re
import unittest
from unittest.mock import patch
class FakeResponse:
    """Minimal stand-in for the response object returned by ``requests.get``.

    Only the pieces the code under test touches are provided: a
    ``status_code`` attribute, ``json()`` and ``raise_for_status()``.
    """

    def __init__(self, payload, status_code=200):
        self.status_code = status_code
        self._payload = payload

    def json(self):
        """Return the canned payload exactly as given to the constructor."""
        return self._payload

    def raise_for_status(self):
        """Raise RuntimeError for status codes of 400 and above; otherwise do nothing."""
        if self.status_code < 400:
            return
        raise RuntimeError(f"HTTP {self.status_code}")
def _wildcard_match(pattern: str, value: str, case_insensitive: bool) -> bool:
if value is None:
return False
if case_insensitive:
pattern = pattern.lower()
value = value.lower()
if pattern.startswith("*") and pattern.endswith("*"):
needle = pattern.strip("*")
return needle in value
return pattern == value
def _extract_wildcard_clause(field_clause):
# Supports either {"field": "*term*"} or {"field": {"value":"*term*", "case_insensitive":true}}
if not isinstance(field_clause, dict):
return None, None, None
if len(field_clause) != 1:
return None, None, None
field, value = next(iter(field_clause.items()))
if isinstance(value, str):
return field, value, False
if isinstance(value, dict):
return field, value.get("value"), bool(value.get("case_insensitive"))
return None, None, None
def _filter_hosts_by_query(host_docs, query):
if not query:
return host_docs
bool_query = query.get("bool") if isinstance(query, dict) else None
if not bool_query:
return host_docs
filters = bool_query.get("filter") or []
if not filters:
return host_docs
matched = host_docs
for f in filters:
if "term" in f and "host.sources.keyword" in f["term"]:
src = f["term"]["host.sources.keyword"]
matched = [h for h in matched if src in (h.get("host", {}).get("sources") or [])]
continue
if "bool" in f and "should" in f["bool"]:
shoulds = f["bool"]["should"]
def matches_any(host_doc):
host = host_doc.get("host", {})
haystacks = {
"host.name.keyword": [host.get("name")],
"host.hostnames.keyword": host.get("hostnames") or [],
"host.id.keyword": [host.get("id")],
"host.ips": host.get("ips") or [],
"host.macs": host.get("macs") or [],
}
for clause in shoulds:
if "bool" in clause and "should" in clause["bool"]:
# nested should from multiple search terms
nested_shoulds = clause["bool"]["should"]
for nested in nested_shoulds:
if "wildcard" not in nested:
continue
field, value, ci = _extract_wildcard_clause(nested["wildcard"])
if not field or value is None:
continue
for candidate in haystacks.get(field, []):
if _wildcard_match(value, str(candidate or ""), ci):
return True
if "wildcard" in clause:
field, value, ci = _extract_wildcard_clause(clause["wildcard"])
if not field or value is None:
continue
for candidate in haystacks.get(field, []):
if _wildcard_match(value, str(candidate or ""), ci):
return True
return False
matched = [h for h in matched if matches_any(h)]
continue
return matched
class TestNetworkMCP(unittest.TestCase):
    """Exercise host search through the REST API and the MCP JSON-RPC endpoint.

    For every test, ``requests.get`` as seen by ``frontend.app`` is patched to
    :meth:`fake_requests_get`. That fake answers search requests from the two
    in-memory host documents built in :meth:`setUp`.
    """

    def setUp(self):
        from frontend import app as app_module

        self.app_module = app_module
        self.client = app_module.app.test_client()
        self.host_docs = [
            {
                "host": {
                    "id": "mac:dc:a6:32:67:55:dc",
                    "name": "SEELE",
                    "hostnames": ["SEELE"],
                    "ips": ["192.168.5.208"],
                    "macs": ["dc:a6:32:67:55:dc"],
                    "sources": ["opnsense-dhcp", "opnsense-arp"],
                    "last_seen": "2025-12-14T16:27:15.427091+00:00",
                },
                "ports": [{"port": 22, "state": "open", "service": {"name": "ssh"}}],
            },
            {
                "host": {
                    "id": "mac:aa:bb:cc:dd:ee:ff",
                    "name": "core",
                    "hostnames": ["core.localdomain"],
                    "ips": ["192.168.5.34"],
                    "macs": ["aa:bb:cc:dd:ee:ff"],
                    "sources": ["inventory", "opnsense-arp"],
                    "last_seen": "2025-12-14T16:27:15.427091+00:00",
                    "notes": "Production Docker host",
                },
                "ports": [{"port": 443, "state": "open", "service": {"name": "https"}}],
            },
        ]
        # Patch once here instead of in every test body. addCleanup restores the
        # real requests.get even when a test fails.
        patcher = patch.object(app_module.requests, "get", side_effect=self.fake_requests_get)
        patcher.start()
        self.addCleanup(patcher.stop)

    def fake_requests_get(self, url, params=None, **kwargs):
        """Stand in for ``requests.get`` against the search endpoints.

        The signature mirrors ``requests.get(url, params=None, **kwargs)``.
        Keyword arguments the app passes (``headers``, ``auth``, ``verify``,
        ``timeout``, ...) are therefore accepted instead of raising
        ``TypeError``. ``params`` is accepted but unused. The search body is
        read from the ``json`` keyword argument.

        Responses:

        * ``.../network-hosts/_search`` returns ``self.host_docs`` filtered by
          the request's ``query``.
        * ``.../network-events-*/_search`` returns no hits.
        * Any other URL returns an empty body with status 404.
        """
        if url.endswith("/network-hosts/_search"):
            body = kwargs.get("json") or {}
            hits = _filter_hosts_by_query(self.host_docs, body.get("query"))
            return FakeResponse({"hits": {"hits": [{"_source": h} for h in hits]}})
        if "/network-events-" in url and url.endswith("/_search"):
            return FakeResponse({"hits": {"hits": []}})
        return FakeResponse({}, status_code=404)

    def _get_json(self, path):
        """GET *path*, assert HTTP 200, and return the decoded JSON body."""
        resp = self.client.get(path)
        self.assertEqual(resp.status_code, 200)
        return resp.get_json()

    def _post_mcp(self, body):
        """POST *body*, serialized as JSON, to the MCP endpoint."""
        return self.client.post(
            "/.well-known/mcp.json", data=json.dumps(body), content_type="application/json"
        )

    def test_rest_search_hostname_case_insensitive(self):
        payload = self._get_json("/api/hosts?q=seele&limit=50")
        self.assertEqual(payload["total"], 1)
        self.assertEqual(payload["hosts"][0]["name"], "SEELE")

    def test_rest_search_by_ip(self):
        payload = self._get_json("/api/hosts?q=192.168.5.208")
        self.assertEqual(payload["total"], 1)
        self.assertEqual(payload["hosts"][0]["id"], "mac:dc:a6:32:67:55:dc")

    def test_rest_search_by_mac(self):
        payload = self._get_json("/api/hosts?q=dc:a6:32:67:55:dc")
        self.assertEqual(payload["total"], 1)
        self.assertEqual(payload["hosts"][0]["name"], "SEELE")

    def test_mcp_tools_call_search_terms(self):
        body = {
            "jsonrpc": "2.0",
            "id": 1,
            "method": "tools/call",
            "params": {"name": "list_hosts", "arguments": {"terms": ["seele"], "limit": 10}},
        }
        resp = self._post_mcp(body)
        self.assertEqual(resp.status_code, 200)
        payload = resp.get_json()
        self.assertFalse(payload["result"]["isError"])
        hosts = payload["result"]["structuredContent"]["hosts"]
        self.assertEqual(len(hosts), 1)
        self.assertEqual(hosts[0]["name"], "SEELE")

    def test_mcp_resources_read_hosts_query(self):
        body = {"jsonrpc": "2.0", "id": 2, "method": "resources/read", "params": {"uri": "network://hosts?q=seele&limit=5"}}
        resp = self._post_mcp(body)
        self.assertEqual(resp.status_code, 200)
        result = resp.get_json()["result"]
        self.assertEqual(result["contents"][0]["mimeType"], "application/json")
        data = json.loads(result["contents"][0]["text"])
        self.assertEqual(data["total"], 1)
        self.assertEqual(data["hosts"][0]["name"], "SEELE")

    def test_mcp_notifications_initialized_no_response(self):
        # A JSON-RPC message without an "id" is expected to get an empty 204 reply.
        body = {"jsonrpc": "2.0", "method": "notifications/initialized", "params": {}}
        resp = self._post_mcp(body)
        self.assertEqual(resp.status_code, 204)
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()