204 lines
8.6 KiB
Python
204 lines
8.6 KiB
Python
import json
|
|
import unittest
|
|
from unittest.mock import patch
|
|
|
|
|
|
class FakeResponse:
    """Tiny stand-in for ``requests.Response`` returned by the mocked ``requests.get``.

    Exposes ``status_code``, ``json()`` and ``raise_for_status()``. Note that
    ``raise_for_status`` raises ``RuntimeError`` rather than
    ``requests.HTTPError``.
    """

    def __init__(self, payload, status_code=200):
        # The payload is kept as-is; json() returns this very object.
        self._payload, self.status_code = payload, status_code

    def json(self):
        """Return the canned payload unchanged."""
        payload = self._payload
        return payload

    def raise_for_status(self):
        """Raise ``RuntimeError("HTTP <code>")`` for codes >= 400; otherwise do nothing."""
        is_error = self.status_code >= 400
        if not is_error:
            return
        raise RuntimeError(f"HTTP {self.status_code}")
|
|
|
|
|
|
def _wildcard_match(pattern: str, value: str, case_insensitive: bool) -> bool:
|
|
if value is None:
|
|
return False
|
|
if case_insensitive:
|
|
pattern = pattern.lower()
|
|
value = value.lower()
|
|
if pattern.startswith("*") and pattern.endswith("*"):
|
|
needle = pattern.strip("*")
|
|
return needle in value
|
|
return pattern == value
|
|
|
|
|
|
def _extract_wildcard_clause(field_clause):
|
|
# Supports either {"field": "*term*"} or {"field": {"value":"*term*", "case_insensitive":true}}
|
|
if not isinstance(field_clause, dict):
|
|
return None, None, None
|
|
if len(field_clause) != 1:
|
|
return None, None, None
|
|
field, value = next(iter(field_clause.items()))
|
|
if isinstance(value, str):
|
|
return field, value, False
|
|
if isinstance(value, dict):
|
|
return field, value.get("value"), bool(value.get("case_insensitive"))
|
|
return None, None, None
|
|
|
|
|
|
def _filter_hosts_by_query(host_docs, query):
|
|
if not query:
|
|
return host_docs
|
|
bool_query = query.get("bool") if isinstance(query, dict) else None
|
|
if not bool_query:
|
|
return host_docs
|
|
filters = bool_query.get("filter") or []
|
|
if not filters:
|
|
return host_docs
|
|
|
|
matched = host_docs
|
|
for f in filters:
|
|
if "term" in f and "host.sources.keyword" in f["term"]:
|
|
src = f["term"]["host.sources.keyword"]
|
|
matched = [h for h in matched if src in (h.get("host", {}).get("sources") or [])]
|
|
continue
|
|
|
|
if "bool" in f and "should" in f["bool"]:
|
|
shoulds = f["bool"]["should"]
|
|
|
|
def matches_any(host_doc):
|
|
host = host_doc.get("host", {})
|
|
haystacks = {
|
|
"host.name.keyword": [host.get("name")],
|
|
"host.hostnames.keyword": host.get("hostnames") or [],
|
|
"host.id.keyword": [host.get("id")],
|
|
"host.ips": host.get("ips") or [],
|
|
"host.macs": host.get("macs") or [],
|
|
}
|
|
for clause in shoulds:
|
|
if "bool" in clause and "should" in clause["bool"]:
|
|
# nested should from multiple search terms
|
|
nested_shoulds = clause["bool"]["should"]
|
|
for nested in nested_shoulds:
|
|
if "wildcard" not in nested:
|
|
continue
|
|
field, value, ci = _extract_wildcard_clause(nested["wildcard"])
|
|
if not field or value is None:
|
|
continue
|
|
for candidate in haystacks.get(field, []):
|
|
if _wildcard_match(value, str(candidate or ""), ci):
|
|
return True
|
|
if "wildcard" in clause:
|
|
field, value, ci = _extract_wildcard_clause(clause["wildcard"])
|
|
if not field or value is None:
|
|
continue
|
|
for candidate in haystacks.get(field, []):
|
|
if _wildcard_match(value, str(candidate or ""), ci):
|
|
return True
|
|
return False
|
|
|
|
matched = [h for h in matched if matches_any(h)]
|
|
continue
|
|
return matched
|
|
|
|
|
|
class TestNetworkMCP(unittest.TestCase):
    """Tests for the REST (``/api/hosts``) and MCP (``/.well-known/mcp.json``) endpoints.

    ``requests.get``, reached through ``frontend.app.requests``, is patched
    for every test with :meth:`fake_requests_get`. The fake serves two
    in-memory host documents, so no Elasticsearch instance is needed.
    """

    def setUp(self):
        from frontend import app as app_module

        self.app_module = app_module
        self.client = app_module.app.test_client()

        self.host_docs = [
            {
                "host": {
                    "id": "mac:dc:a6:32:67:55:dc",
                    "name": "SEELE",
                    "hostnames": ["SEELE"],
                    "ips": ["192.168.5.208"],
                    "macs": ["dc:a6:32:67:55:dc"],
                    "sources": ["opnsense-dhcp", "opnsense-arp"],
                    "last_seen": "2025-12-14T16:27:15.427091+00:00",
                },
                "ports": [{"port": 22, "state": "open", "service": {"name": "ssh"}}],
            },
            {
                "host": {
                    "id": "mac:aa:bb:cc:dd:ee:ff",
                    "name": "core",
                    "hostnames": ["core.localdomain"],
                    "ips": ["192.168.5.34"],
                    "macs": ["aa:bb:cc:dd:ee:ff"],
                    "sources": ["inventory", "opnsense-arp"],
                    "last_seen": "2025-12-14T16:27:15.427091+00:00",
                    "notes": "Production Docker host",
                },
                "ports": [{"port": 443, "state": "open", "service": {"name": "https"}}],
            },
        ]

        # Patch once per test instead of in every test body; addCleanup undoes
        # the patch even when a test fails or errors.
        patcher = patch.object(self.app_module.requests, "get", side_effect=self.fake_requests_get)
        patcher.start()
        self.addCleanup(patcher.stop)

    def fake_requests_get(self, url, json=None, headers=None, auth=None, verify=None, **kwargs):
        """Fake ``requests.get`` for the Elasticsearch ``_search`` endpoints.

        ``json`` is the request body. It is named after ``requests.get``'s
        keyword, which is why it shadows the ``json`` module in this method.
        Extra keyword arguments (e.g. ``timeout``, ``params``) are accepted
        and ignored, and any query string on *url* is ignored when routing.

        Returns:
            A ``FakeResponse``:
              * filtered host hits for ``.../network-hosts/_search``;
              * no hits for ``.../network-events-*/_search``;
              * HTTP 404 for anything else.
        """
        path = url.split("?", 1)[0]
        if path.endswith("/network-hosts/_search"):
            query = (json or {}).get("query")
            hits = _filter_hosts_by_query(self.host_docs, query)
            return FakeResponse({"hits": {"hits": [{"_source": h} for h in hits]}})
        if "/network-events-" in path and path.endswith("/_search"):
            return FakeResponse({"hits": {"hits": []}})
        return FakeResponse({}, status_code=404)

    def _post_mcp(self, body):
        """POST *body* as a JSON-RPC message to the MCP endpoint and return the response."""
        return self.client.post(
            "/.well-known/mcp.json", data=json.dumps(body), content_type="application/json"
        )

    def test_rest_search_hostname_case_insensitive(self):
        resp = self.client.get("/api/hosts?q=seele&limit=50")
        self.assertEqual(resp.status_code, 200)
        payload = resp.get_json()
        self.assertEqual(payload["total"], 1)
        self.assertEqual(payload["hosts"][0]["name"], "SEELE")

    def test_rest_search_by_ip(self):
        resp = self.client.get("/api/hosts?q=192.168.5.208")
        self.assertEqual(resp.status_code, 200)
        payload = resp.get_json()
        self.assertEqual(payload["total"], 1)
        self.assertEqual(payload["hosts"][0]["id"], "mac:dc:a6:32:67:55:dc")

    def test_rest_search_by_mac(self):
        resp = self.client.get("/api/hosts?q=dc:a6:32:67:55:dc")
        self.assertEqual(resp.status_code, 200)
        payload = resp.get_json()
        self.assertEqual(payload["total"], 1)
        self.assertEqual(payload["hosts"][0]["name"], "SEELE")

    def test_mcp_tools_call_search_terms(self):
        body = {
            "jsonrpc": "2.0",
            "id": 1,
            "method": "tools/call",
            "params": {"name": "list_hosts", "arguments": {"terms": ["seele"], "limit": 10}},
        }
        resp = self._post_mcp(body)
        self.assertEqual(resp.status_code, 200)
        payload = resp.get_json()
        self.assertFalse(payload["result"]["isError"])
        hosts = payload["result"]["structuredContent"]["hosts"]
        self.assertEqual(len(hosts), 1)
        self.assertEqual(hosts[0]["name"], "SEELE")

    def test_mcp_resources_read_hosts_query(self):
        body = {"jsonrpc": "2.0", "id": 2, "method": "resources/read", "params": {"uri": "network://hosts?q=seele&limit=5"}}
        resp = self._post_mcp(body)
        self.assertEqual(resp.status_code, 200)
        result = resp.get_json()["result"]
        self.assertEqual(result["contents"][0]["mimeType"], "application/json")
        data = json.loads(result["contents"][0]["text"])
        self.assertEqual(data["total"], 1)
        self.assertEqual(data["hosts"][0]["name"], "SEELE")

    def test_mcp_notifications_initialized_no_response(self):
        # A JSON-RPC notification carries no "id"; the endpoint is expected to
        # answer with 204 No Content.
        body = {"jsonrpc": "2.0", "method": "notifications/initialized", "params": {}}
        resp = self._post_mcp(body)
        self.assertEqual(resp.status_code, 204)
|
|
|
|
|
|
if __name__ == "__main__":
    # When executed as a script, load and run the TestCases defined in this module.
    unittest.main(module="__main__")
|
|
|