""" Mind CLI — WebSocket Tunnel 客户端。 连接到 Cloud 端 mindcli_bridge,接收工具调用指令并在本地执行。 采用 Browser-Donated JWT 认证:浏览器授权 CLI,CLI 不需独立认证。 """ import asyncio import json import logging import os import time from typing import Any logger = logging.getLogger("mindcli.tunnel") # 连接状态 DISCONNECTED = "disconnected" CONNECTING = "connecting" CONNECTED = "connected" class TunnelClient: """ CLI → Cloud WebSocket 隧道。 生命周期: 1. 浏览器 POST /tunnel/activate → 提供 JWT + tunnelUrl 2. TunnelClient.connect() → WebSocket 握手 + 能力协商 3. 消息循环:接收 tool_call → ManagedMCP 执行 → 返回结果 4. 心跳维持 30s / 断线指数退避重连 """ def __init__(self): self._status = DISCONNECTED self._ws = None self._jwt: str | None = None self._tunnel_url: str | None = None self._user_id: str | None = None self._managed_mcp = None self._reconnect_task: asyncio.Task | None = None self._heartbeat_task: asyncio.Task | None = None self._message_task: asyncio.Task | None = None # 重连参数 self._reconnect_delay = 1.0 self._max_reconnect_delay = 30.0 self._reconnect_attempts = 0 # 回调:通知 health.py 更新状态 self._on_status_change = None @property def status(self) -> str: return self._status @property def user_id(self) -> str | None: return self._user_id def set_status_callback(self, callback) -> None: """设置状态变更回调(由 health.py 注册)。""" self._on_status_change = callback def _set_status(self, status: str) -> None: self._status = status if self._on_status_change: self._on_status_change(status) async def activate(self, jwt: str, tunnel_url: str) -> dict: """ 浏览器授权激活 Tunnel。 Args: jwt: MindPass JWT(浏览器提供) tunnel_url: Cloud Tunnel WebSocket URL Returns: {"ok": True, "status": "connecting"} """ self._jwt = jwt self._tunnel_url = tunnel_url # 取消旧连接 await self.disconnect() # 后台启动连接 self._reconnect_task = asyncio.create_task(self._connect_loop()) return {"ok": True, "status": "connecting"} async def disconnect(self) -> None: """断开 Tunnel 连接。""" # 取消所有后台任务 for task in [self._reconnect_task, self._heartbeat_task, self._message_task]: if task and not task.done(): task.cancel() try: await task except asyncio.CancelledError: pass if self._ws: try: await self._ws.close() except Exception: pass self._ws = None self._set_status(DISCONNECTED) self._reconnect_attempts = 0 self._reconnect_delay = 1.0 logger.info("[Tunnel] 已断开") async def _connect_loop(self) -> None: """连接循环:握手 → 消息循环 → 断线重连。""" while True: try: self._set_status(CONNECTING) await self._connect() # 连接成功,重置重连参数 self._reconnect_delay = 1.0 self._reconnect_attempts = 0 # 进入消息循环(阻塞直到断线) await self._run() except asyncio.CancelledError: return except Exception as e: logger.warning("[Tunnel] 连接异常: %s", e) # 断线,指数退避重连 self._set_status(DISCONNECTED) self._reconnect_attempts += 1 delay = min(self._reconnect_delay, self._max_reconnect_delay) logger.info("[Tunnel] %ds 后重连(第 %d 次)...", delay, self._reconnect_attempts) await asyncio.sleep(delay) self._reconnect_delay = min(self._reconnect_delay * 2, self._max_reconnect_delay) async def _connect(self) -> None: """WebSocket 握手 + 能力协商。""" try: import websockets except ImportError: logger.error("[Tunnel] websockets 未安装。运行: pip install websockets") raise headers = {"Authorization": f"Bearer {self._jwt}"} self._ws = await websockets.connect( self._tunnel_url, additional_headers=headers, ping_interval=30, ping_timeout=10, close_timeout=5, ) # 等待 Cloud 确认 raw = await asyncio.wait_for(self._ws.recv(), timeout=10) msg = json.loads(raw) if msg.get("type") != "connected": raise ConnectionError(f"握手失败: {msg}") self._user_id = msg.get("userId") logger.info("[Tunnel] 已连接,userId=%s", self._user_id) # 发送能力报告 from mindcli.capability import scan_capabilities cap = scan_capabilities() await self._ws.send(json.dumps({ "type": "capability_report", **cap, })) # 接收审批结果 raw = await asyncio.wait_for(self._ws.recv(), timeout=10) approval = json.loads(raw) if approval.get("type") == "approved_tools": approved = approval.get("tools", []) # 初始化 ManagedMCP from mindcli.managed_mcp import ManagedMCP self._managed_mcp = ManagedMCP(approved_tools=approved) logger.info("[Tunnel] 审批通过工具: %s", approved) # 更新 health 状态 from mindcli.health import set_tunnel_status set_tunnel_status("connected", len(approved) if approval.get("type") == "approved_tools" else 0) self._set_status(CONNECTED) async def _run(self) -> None: """主消息循环:接收 Cloud 指令 → 本地执行 → 返回结果。""" try: async for raw in self._ws: msg = json.loads(raw) if msg.get("jsonrpc") == "2.0" and msg.get("method") == "tool_call": await self._handle_tool_call(msg) elif msg.get("type") == "approved_tools": # 热更新白名单 if self._managed_mcp: self._managed_mcp.update_approved(msg.get("tools", [])) elif msg.get("type") == "ping": await self._ws.send(json.dumps({"type": "pong"})) else: logger.debug("[Tunnel] 未知消息: %s", msg.get("type")) except Exception as e: logger.warning("[Tunnel] 消息循环异常: %s", e) raise async def _handle_tool_call(self, msg: dict) -> None: """处理工具调用请求。""" call_id = msg.get("id", "unknown") params = msg.get("params", {}) tool_name = params.get("tool", "") tool_args = params.get("args", {}) logger.info("[Tunnel] 工具调用: %s (id=%s)", tool_name, call_id) if not self._managed_mcp: result = {"error": "ManagedMCP not initialized"} else: result = await self._managed_mcp.execute(tool_name, tool_args) # 截断过大的输出(防止 WS 阻塞) output = result.get("output", "") if isinstance(output, str) and len(output) > 50000: result["output"] = output[:50000] + f"\n\n... [截断:原始 {len(output)} 字符]" result["truncated"] = True response = { "jsonrpc": "2.0", "id": call_id, "result": result, } await self._ws.send(json.dumps(response)) # ── 全局单例 ────────────────────────────────────────── _tunnel_client = TunnelClient() def get_tunnel_client() -> TunnelClient: """获取全局 TunnelClient 实例。""" return _tunnel_client