- 统一为 DASHSCOPE_API_KEY(百炼 Key 通吃 LLM + Embedding)
- import-voc 后自动触发 UDE 转写 + 向量化(后台 asyncio task)
- 新增 GET /pipeline-status 查询流水线进度
- run_clustering 变纯 CPU(向量已预计算)
- 新增独立 run_vectorization 函数
- 修复 Python 3.9 类型注解兼容性

354 lines · 12 KiB · Python
"""
|
||
黑手党提案 — 独立后端(BYOK v2.1)
|
||
|
||
FastAPI 服务,端口 8093。
|
||
合伙人通过 X-DashScope-Key Header 传入百炼 Key,服务器统一用于 LLM + Embedding。
|
||
import-voc 完成后自动触发 UDE 转写 + 向量化流水线。
|
||
"""
|
||
import os
|
||
import logging
|
||
import asyncio
|
||
import httpx
|
||
|
||
from fastapi import FastAPI, Header, HTTPException, Query
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
from pydantic import BaseModel
|
||
from typing import Optional
|
||
from dotenv import load_dotenv
|
||
|
||
load_dotenv()
|
||
|
||
from db import get_case_conn, init_case_db, list_cases as _list_cases, DATA_DIR, VOC_API_BASE
|
||
|
||
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(name)s] %(levelname)s %(message)s")
|
||
logger = logging.getLogger("mafia")
|
||
|
||
app = FastAPI(title="黑手党提案后端", version="2.1.0", description="独立后端:阿里云内闭环 + BYOK,合伙人自带 API Key")
|
||
app.add_middleware(
|
||
CORSMiddleware,
|
||
allow_origins=["*"],
|
||
allow_methods=["*"],
|
||
allow_headers=["*"],
|
||
)
|
||
|
||
|
||
# ═══════════ BYOK key resolution ═══════════

def resolve_key(header_key: Optional[str]) -> str:
    """Resolve the DashScope API key to use for this request.

    Precedence: the partner-supplied X-DashScope-Key header first, then the
    server-side DASHSCOPE_API_KEY environment fallback. One Bailian key
    covers both LLM and Embedding calls.

    Raises:
        HTTPException: 401 when neither source provides a key.
    """
    if header_key:
        logger.info(f"[BYOK] 使用合伙人 Key: ...{header_key[-6:]}")
        return header_key

    env_val = os.getenv("DASHSCOPE_API_KEY", "")
    if not env_val:
        raise HTTPException(401, "缺少 DashScope API Key。请在 CoPaw 中配置 DASHSCOPE_API_KEY 并通过 X-DashScope-Key Header 传入")
    logger.info(f"[BYOK] 使用 fallback Key: ...{env_val[-6:]}")
    return env_val


# ═══════════ Models ═══════════

class CreateCaseRequest(BaseModel):
    """Payload for POST /api/cases."""
    brandName: str
    category: str = ""
    focusProduct: str = ""
    competitors: str = "[]"  # presumably a JSON-encoded list; stored verbatim
    # Fix: was `str = None` — a type error (None default on a non-Optional
    # field) that pydantic validation rejects. Optional matches the intended
    # contract: the VOC research id may be attached later via link-voc.
    vocResearchId: Optional[str] = None


class LinkVocRequest(BaseModel):
    """Payload for POST /api/cases/{case_id}/link-voc."""
    # VOC research id to attach to the case (written into case_card).
    vocResearchId: str


# ═══════════ Case management ═══════════

@app.post("/api/cases")
async def create_case(req: CreateCaseRequest):
    """Create a new case database and return its generated id."""
    new_id = init_case_db(
        brand_name=req.brandName,
        category=req.category,
        focus_product=req.focusProduct,
        competitors=req.competitors,
        voc_research_id=req.vocResearchId,
    )
    return {"caseId": new_id}


@app.get("/api/cases")
|
||
async def get_cases():
|
||
return _list_cases()
|
||
|
||
|
||
@app.get("/api/cases/{case_id}")
|
||
async def get_case(case_id: str):
|
||
try:
|
||
with get_case_conn(case_id) as conn:
|
||
card = conn.execute("SELECT * FROM case_card LIMIT 1").fetchone()
|
||
comment_count = conn.execute("SELECT count(*) FROM comments").fetchone()[0]
|
||
ude_count = conn.execute("SELECT count(*) FROM ude_sentences").fetchone()[0]
|
||
cluster_count = conn.execute("SELECT count(*) FROM ude_clusters").fetchone()[0]
|
||
if not card:
|
||
raise HTTPException(404, "案例不存在")
|
||
return {
|
||
"caseId": case_id, **dict(card),
|
||
"commentCount": comment_count,
|
||
"udeCount": ude_count,
|
||
"clusterCount": cluster_count,
|
||
}
|
||
except FileNotFoundError:
|
||
raise HTTPException(404, "案例不存在")
|
||
|
||
|
||
@app.delete("/api/cases/{case_id}")
async def delete_case(case_id: str):
    """Delete a case's DB file and drop its cached pipeline status.

    Raises:
        HTTPException: 404 when the case DB file does not exist.
    """
    path = DATA_DIR / f"{case_id}.db"
    try:
        # EAFP: unlink directly instead of exists()+unlink(), which had a
        # check-then-act race with a concurrent delete.
        path.unlink()
    except FileNotFoundError:
        raise HTTPException(404, "案例不存在")
    # Fix: the in-memory status entry was left behind, so /pipeline-status
    # kept reporting a deleted case as running/done.
    _pipeline_status.pop(case_id, None)
    return {"deleted": True}


# ═══════════ VOC import + auto pipeline ═══════════

# Background pipeline status (in-memory cache, per-process).
_pipeline_status: dict[str, dict] = {}


async def _auto_pipeline(case_id: str, dashscope_key: str):
    """Background pipeline: UDE extraction, then vectorization.

    Progress, stage transitions, and errors are all recorded in
    _pipeline_status[case_id] for the /pipeline-status endpoint.
    """
    from tools.ude_extract import run_ude_extraction, run_vectorization

    status = _pipeline_status[case_id] = {"stage": "extracting", "progress": {}}
    try:
        # Stage 1: UDE extraction (limit=0 — presumably "process all";
        # confirm against run_ude_extraction).
        extraction = await run_ude_extraction(case_id, limit=0, dashscope_key=dashscope_key)
        status["progress"]["extraction"] = extraction
        logger.info(f"[Pipeline] {case_id} UDE 转写完成: {extraction.get('totalUdes', 0)} 条")

        # Stage 2: vectorization (vectors are precomputed here so that
        # run_clustering can stay pure-CPU).
        status["stage"] = "vectorizing"
        vectors = run_vectorization(case_id, dashscope_key=dashscope_key)
        status["progress"]["vectorization"] = vectors
        logger.info(f"[Pipeline] {case_id} 向量化完成: {vectors.get('vectorized', 0)} 条")

        status["stage"] = "done"
    except Exception as e:
        # Task runs detached via create_task, so failures must be captured
        # here or they would vanish silently.
        logger.error(f"[Pipeline] {case_id} 失败: {e}")
        status["stage"] = "error"
        status["error"] = str(e)


@app.post("/api/cases/{case_id}/link-voc")
|
||
async def link_voc(case_id: str, req: LinkVocRequest):
|
||
"""关联 VOC 研究 ID"""
|
||
try:
|
||
with get_case_conn(case_id) as conn:
|
||
conn.execute("UPDATE case_card SET voc_research_id = ?", (req.vocResearchId,))
|
||
conn.commit()
|
||
except FileNotFoundError:
|
||
raise HTTPException(404, "案例不存在")
|
||
return {"linked": True, "vocResearchId": req.vocResearchId}
|
||
|
||
|
||
@app.post("/api/cases/{case_id}/import-voc")
|
||
async def import_voc(
|
||
case_id: str,
|
||
page: int = Query(1),
|
||
pageSize: int = Query(100),
|
||
x_dashscope_key: str = Header(None),
|
||
):
|
||
"""从 VOC 公网 API 拉取评论,完成后自动触发 UDE 转写 + 向量化"""
|
||
dashscope_key = resolve_key(x_dashscope_key)
|
||
|
||
try:
|
||
with get_case_conn(case_id) as conn:
|
||
card = conn.execute("SELECT voc_research_id, voc_api_base FROM case_card LIMIT 1").fetchone()
|
||
except FileNotFoundError:
|
||
raise HTTPException(404, "案例不存在")
|
||
|
||
if not card or not card["voc_research_id"]:
|
||
raise HTTPException(400, "未关联 VOC 研究,请先调用 link-voc")
|
||
|
||
voc_rid = card["voc_research_id"]
|
||
api_base = card["voc_api_base"] or VOC_API_BASE
|
||
|
||
total_imported = 0
|
||
current_page = page
|
||
|
||
async with httpx.AsyncClient(timeout=30) as client:
|
||
while True:
|
||
url = f"{api_base}/{voc_rid}/voc-list?page={current_page}&page_size={pageSize}"
|
||
try:
|
||
resp = await client.get(url)
|
||
if resp.status_code != 200:
|
||
logger.warning(f"[Import] VOC API 返回 {resp.status_code}: {resp.text[:100]}")
|
||
break
|
||
data = resp.json()
|
||
except Exception as e:
|
||
logger.error(f"[Import] VOC API 请求失败: {e}")
|
||
break
|
||
|
||
items = data.get("items") or data.get("data") or []
|
||
if not items:
|
||
break
|
||
|
||
with get_case_conn(case_id) as conn:
|
||
for item in items:
|
||
text = item.get("text", "")
|
||
if len(text) < 10:
|
||
continue
|
||
try:
|
||
conn.execute(
|
||
"INSERT OR IGNORE INTO comments (voc_id, platform, text, like_count, published_at) VALUES (?,?,?,?,?)",
|
||
(
|
||
item.get("id"),
|
||
item.get("platform", ""),
|
||
text,
|
||
item.get("like_count", 0),
|
||
item.get("published_at", ""),
|
||
)
|
||
)
|
||
total_imported += 1
|
||
except Exception:
|
||
pass
|
||
conn.commit()
|
||
|
||
total = data.get("total", 0)
|
||
if current_page * pageSize >= total:
|
||
break
|
||
current_page += 1
|
||
|
||
with get_case_conn(case_id) as conn:
|
||
local_count = conn.execute("SELECT count(*) FROM comments").fetchone()[0]
|
||
|
||
# 自动触发后台流水线(UDE 转写 + 向量化)
|
||
pipeline_triggered = False
|
||
if local_count > 0:
|
||
asyncio.create_task(_auto_pipeline(case_id, dashscope_key))
|
||
pipeline_triggered = True
|
||
logger.info(f"[Import] {case_id} 自动流水线已触发({local_count} 条评论)")
|
||
|
||
return {
|
||
"imported": total_imported,
|
||
"totalLocal": local_count,
|
||
"vocResearchId": voc_rid,
|
||
"pagesProcessed": current_page - page + 1,
|
||
"pipelineTriggered": pipeline_triggered,
|
||
}
|
||
|
||
|
||
@app.get("/api/cases/{case_id}/pipeline-status")
|
||
async def pipeline_status(case_id: str):
|
||
"""查询后台流水线进度"""
|
||
status = _pipeline_status.get(case_id)
|
||
if not status:
|
||
# 没有正在运行的流水线,从 DB 推断状态
|
||
try:
|
||
with get_case_conn(case_id) as conn:
|
||
comments = conn.execute("SELECT count(*) FROM comments").fetchone()[0]
|
||
udes = conn.execute("SELECT count(*) FROM ude_sentences").fetchone()[0]
|
||
vectorized = conn.execute("SELECT count(*) FROM ude_sentences WHERE vector IS NOT NULL").fetchone()[0]
|
||
except FileNotFoundError:
|
||
raise HTTPException(404, "案例不存在")
|
||
return {
|
||
"stage": "idle",
|
||
"progress": {"comments": comments, "udesExtracted": udes, "udesVectorized": vectorized},
|
||
}
|
||
return status
|
||
|
||
|
||
@app.get("/api/cases/{case_id}/comments")
|
||
async def get_comments(case_id: str, page: int = 1, pageSize: int = 50):
|
||
"""查看本地导入的评论"""
|
||
try:
|
||
with get_case_conn(case_id) as conn:
|
||
total = conn.execute("SELECT count(*) FROM comments").fetchone()[0]
|
||
rows = conn.execute("""
|
||
SELECT id, voc_id, platform, text, like_count, published_at
|
||
FROM comments ORDER BY like_count DESC
|
||
LIMIT ? OFFSET ?
|
||
""", (pageSize, (page - 1) * pageSize)).fetchall()
|
||
except FileNotFoundError:
|
||
raise HTTPException(404, "案例不存在")
|
||
return {"total": total, "page": page, "items": [dict(r) for r in rows]}
|
||
|
||
|
||
# ═══════════ UDE analysis (BYOK) ═══════════

@app.post("/api/cases/{case_id}/ude/extract")
async def extract_ude(
    case_id: str,
    limit: int = Query(0),
    x_dashscope_key: str = Header(None),
):
    """Manually (re)run UDE extraction.

    Normally triggered automatically by import-voc; this endpoint exists
    for backfill / re-runs.
    """
    from tools.ude_extract import run_ude_extraction

    key = resolve_key(x_dashscope_key)
    try:
        return await run_ude_extraction(case_id, limit, dashscope_key=key)
    except FileNotFoundError as e:
        raise HTTPException(404, str(e))


@app.post("/api/cases/{case_id}/ude/cluster")
|
||
async def cluster_ude(
|
||
case_id: str,
|
||
eps: float = Query(0.25),
|
||
minSamples: int = Query(3),
|
||
):
|
||
"""向量聚类(纯 CPU,不调外部 API。向量已在流水线中预计算)"""
|
||
from tools.ude_extract import run_clustering
|
||
try:
|
||
result = run_clustering(case_id, eps, minSamples)
|
||
except FileNotFoundError as e:
|
||
raise HTTPException(404, str(e))
|
||
return result
|
||
|
||
|
||
@app.get("/api/cases/{case_id}/ude/clusters")
|
||
async def get_clusters(case_id: str):
|
||
try:
|
||
with get_case_conn(case_id) as conn:
|
||
clusters = conn.execute(
|
||
"SELECT * FROM ude_clusters ORDER BY coverage DESC"
|
||
).fetchall()
|
||
except FileNotFoundError:
|
||
raise HTTPException(404, "案例不存在")
|
||
return [dict(r) for r in clusters]
|
||
|
||
|
||
@app.get("/api/cases/{case_id}/ude/coverage")
|
||
async def get_coverage(case_id: str):
|
||
from tools.ude_extract import run_coverage_scan
|
||
try:
|
||
result = run_coverage_scan(case_id)
|
||
except FileNotFoundError as e:
|
||
raise HTTPException(404, str(e))
|
||
return result
|
||
|
||
|
||
# ═══════════ Health check ═══════════

@app.get("/api/health")
async def health():
    """Liveness probe plus a snapshot of the server configuration."""
    byok = {
        # NOTE(review): resolve_key only ever reads DASHSCOPE_API_KEY, so
        # BAILIAN_API_KEY reported here is never actually used as a fallback —
        # confirm whether this legacy field can be dropped.
        "bailianFallback": bool(os.getenv("BAILIAN_API_KEY")),
        "dashscopeFallback": bool(os.getenv("DASHSCOPE_API_KEY")),
    }
    return {
        "status": "ok",
        "version": "2.1.0",
        "architecture": "BYOK (Bring Your Own Key)",
        "vocApiBase": VOC_API_BASE,
        "caseDataDir": str(DATA_DIR),
        "byok": byok,
    }


# ═══════════ Entry point ═══════════

if __name__ == "__main__":
    import uvicorn

    # PORT env var overrides the default 8093.
    port = int(os.getenv("PORT", "8093"))
    uvicorn.run(app, host="0.0.0.0", port=port)