- 独立 pyproject.toml(pip install -e .) - vendor_hermes.sh 已改为显式路径模式(不再依赖相对目录) - 包含 hermes vendor 快照
243 lines
8.1 KiB
Python
243 lines
8.1 KiB
Python
"""
|
||
Mind CLI — WebSocket Tunnel 客户端。
|
||
|
||
连接到 Cloud 端 mindcli_bridge,接收工具调用指令并在本地执行。
|
||
采用 Browser-Donated JWT 认证:浏览器授权 CLI,CLI 不需独立认证。
|
||
"""
|
||
|
||
import asyncio
|
||
import json
|
||
import logging
|
||
import os
|
||
import time
|
||
from typing import Any
|
||
|
||
logger = logging.getLogger("mindcli.tunnel")

# Connection states exposed through TunnelClient.status.
DISCONNECTED, CONNECTING, CONNECTED = (
    "disconnected",
    "connecting",
    "connected",
)
|
||
|
||
|
||
class TunnelClient:
|
||
"""
|
||
CLI → Cloud WebSocket 隧道。
|
||
|
||
生命周期:
|
||
1. 浏览器 POST /tunnel/activate → 提供 JWT + tunnelUrl
|
||
2. TunnelClient.connect() → WebSocket 握手 + 能力协商
|
||
3. 消息循环:接收 tool_call → ManagedMCP 执行 → 返回结果
|
||
4. 心跳维持 30s / 断线指数退避重连
|
||
"""
|
||
|
||
def __init__(self):
|
||
self._status = DISCONNECTED
|
||
self._ws = None
|
||
self._jwt: str | None = None
|
||
self._tunnel_url: str | None = None
|
||
self._user_id: str | None = None
|
||
self._managed_mcp = None
|
||
self._reconnect_task: asyncio.Task | None = None
|
||
self._heartbeat_task: asyncio.Task | None = None
|
||
self._message_task: asyncio.Task | None = None
|
||
|
||
# 重连参数
|
||
self._reconnect_delay = 1.0
|
||
self._max_reconnect_delay = 30.0
|
||
self._reconnect_attempts = 0
|
||
|
||
# 回调:通知 health.py 更新状态
|
||
self._on_status_change = None
|
||
|
||
@property
|
||
def status(self) -> str:
|
||
return self._status
|
||
|
||
@property
|
||
def user_id(self) -> str | None:
|
||
return self._user_id
|
||
|
||
def set_status_callback(self, callback) -> None:
|
||
"""设置状态变更回调(由 health.py 注册)。"""
|
||
self._on_status_change = callback
|
||
|
||
def _set_status(self, status: str) -> None:
|
||
self._status = status
|
||
if self._on_status_change:
|
||
self._on_status_change(status)
|
||
|
||
async def activate(self, jwt: str, tunnel_url: str) -> dict:
|
||
"""
|
||
浏览器授权激活 Tunnel。
|
||
|
||
Args:
|
||
jwt: MindPass JWT(浏览器提供)
|
||
tunnel_url: Cloud Tunnel WebSocket URL
|
||
|
||
Returns:
|
||
{"ok": True, "status": "connecting"}
|
||
"""
|
||
self._jwt = jwt
|
||
self._tunnel_url = tunnel_url
|
||
|
||
# 取消旧连接
|
||
await self.disconnect()
|
||
|
||
# 后台启动连接
|
||
self._reconnect_task = asyncio.create_task(self._connect_loop())
|
||
|
||
return {"ok": True, "status": "connecting"}
|
||
|
||
async def disconnect(self) -> None:
|
||
"""断开 Tunnel 连接。"""
|
||
# 取消所有后台任务
|
||
for task in [self._reconnect_task, self._heartbeat_task, self._message_task]:
|
||
if task and not task.done():
|
||
task.cancel()
|
||
try:
|
||
await task
|
||
except asyncio.CancelledError:
|
||
pass
|
||
|
||
if self._ws:
|
||
try:
|
||
await self._ws.close()
|
||
except Exception:
|
||
pass
|
||
self._ws = None
|
||
|
||
self._set_status(DISCONNECTED)
|
||
self._reconnect_attempts = 0
|
||
self._reconnect_delay = 1.0
|
||
logger.info("[Tunnel] 已断开")
|
||
|
||
async def _connect_loop(self) -> None:
|
||
"""连接循环:握手 → 消息循环 → 断线重连。"""
|
||
while True:
|
||
try:
|
||
self._set_status(CONNECTING)
|
||
await self._connect()
|
||
# 连接成功,重置重连参数
|
||
self._reconnect_delay = 1.0
|
||
self._reconnect_attempts = 0
|
||
# 进入消息循环(阻塞直到断线)
|
||
await self._run()
|
||
except asyncio.CancelledError:
|
||
return
|
||
except Exception as e:
|
||
logger.warning("[Tunnel] 连接异常: %s", e)
|
||
|
||
# 断线,指数退避重连
|
||
self._set_status(DISCONNECTED)
|
||
self._reconnect_attempts += 1
|
||
delay = min(self._reconnect_delay, self._max_reconnect_delay)
|
||
logger.info("[Tunnel] %ds 后重连(第 %d 次)...", delay, self._reconnect_attempts)
|
||
await asyncio.sleep(delay)
|
||
self._reconnect_delay = min(self._reconnect_delay * 2, self._max_reconnect_delay)
|
||
|
||
async def _connect(self) -> None:
|
||
"""WebSocket 握手 + 能力协商。"""
|
||
try:
|
||
import websockets
|
||
except ImportError:
|
||
logger.error("[Tunnel] websockets 未安装。运行: pip install websockets")
|
||
raise
|
||
|
||
headers = {"Authorization": f"Bearer {self._jwt}"}
|
||
self._ws = await websockets.connect(
|
||
self._tunnel_url,
|
||
additional_headers=headers,
|
||
ping_interval=30,
|
||
ping_timeout=10,
|
||
close_timeout=5,
|
||
)
|
||
|
||
# 等待 Cloud 确认
|
||
raw = await asyncio.wait_for(self._ws.recv(), timeout=10)
|
||
msg = json.loads(raw)
|
||
if msg.get("type") != "connected":
|
||
raise ConnectionError(f"握手失败: {msg}")
|
||
|
||
self._user_id = msg.get("userId")
|
||
logger.info("[Tunnel] 已连接,userId=%s", self._user_id)
|
||
|
||
# 发送能力报告
|
||
from mindcli.capability import scan_capabilities
|
||
cap = scan_capabilities()
|
||
await self._ws.send(json.dumps({
|
||
"type": "capability_report",
|
||
**cap,
|
||
}))
|
||
|
||
# 接收审批结果
|
||
raw = await asyncio.wait_for(self._ws.recv(), timeout=10)
|
||
approval = json.loads(raw)
|
||
if approval.get("type") == "approved_tools":
|
||
approved = approval.get("tools", [])
|
||
# 初始化 ManagedMCP
|
||
from mindcli.managed_mcp import ManagedMCP
|
||
self._managed_mcp = ManagedMCP(approved_tools=approved)
|
||
logger.info("[Tunnel] 审批通过工具: %s", approved)
|
||
|
||
# 更新 health 状态
|
||
from mindcli.health import set_tunnel_status
|
||
set_tunnel_status("connected", len(approved) if approval.get("type") == "approved_tools" else 0)
|
||
self._set_status(CONNECTED)
|
||
|
||
async def _run(self) -> None:
|
||
"""主消息循环:接收 Cloud 指令 → 本地执行 → 返回结果。"""
|
||
try:
|
||
async for raw in self._ws:
|
||
msg = json.loads(raw)
|
||
|
||
if msg.get("jsonrpc") == "2.0" and msg.get("method") == "tool_call":
|
||
await self._handle_tool_call(msg)
|
||
elif msg.get("type") == "approved_tools":
|
||
# 热更新白名单
|
||
if self._managed_mcp:
|
||
self._managed_mcp.update_approved(msg.get("tools", []))
|
||
elif msg.get("type") == "ping":
|
||
await self._ws.send(json.dumps({"type": "pong"}))
|
||
else:
|
||
logger.debug("[Tunnel] 未知消息: %s", msg.get("type"))
|
||
except Exception as e:
|
||
logger.warning("[Tunnel] 消息循环异常: %s", e)
|
||
raise
|
||
|
||
async def _handle_tool_call(self, msg: dict) -> None:
|
||
"""处理工具调用请求。"""
|
||
call_id = msg.get("id", "unknown")
|
||
params = msg.get("params", {})
|
||
tool_name = params.get("tool", "")
|
||
tool_args = params.get("args", {})
|
||
|
||
logger.info("[Tunnel] 工具调用: %s (id=%s)", tool_name, call_id)
|
||
|
||
if not self._managed_mcp:
|
||
result = {"error": "ManagedMCP not initialized"}
|
||
else:
|
||
result = await self._managed_mcp.execute(tool_name, tool_args)
|
||
|
||
# 截断过大的输出(防止 WS 阻塞)
|
||
output = result.get("output", "")
|
||
if isinstance(output, str) and len(output) > 50000:
|
||
result["output"] = output[:50000] + f"\n\n... [截断:原始 {len(output)} 字符]"
|
||
result["truncated"] = True
|
||
|
||
response = {
|
||
"jsonrpc": "2.0",
|
||
"id": call_id,
|
||
"result": result,
|
||
}
|
||
await self._ws.send(json.dumps(response))
|
||
|
||
|
||
# ── Global singleton ──────────────────────────────────
# One TunnelClient is created when this module is imported and shared by
# every caller of get_tunnel_client().
_tunnel_client = TunnelClient()


def get_tunnel_client() -> TunnelClient:
    """Return the shared module-level TunnelClient instance."""
    shared = _tunnel_client
    return shared
|