5 分钟搭建自己的 AI 知识库

目录结构

AIbot-File
├── db/                         数据库文件(SQLite 与向量库,自动生成)
├── storage/                    上传的原始文件(自动生成)
├── static/                     存储 Web 静态资源
├── templates/                  存储 Web 模板
├── app.py                      对外服务接口(FastAPI)
├── chatbot.py                  对接大语言模型
├── embedding.py                对接嵌入模型
├── file_service.py             文件存储与管理
├── FileManage.py               文件管理 Web 服务(Flask)
├── models.py                   定义 File 数据模型(用于保存文件信息)
└── requirements.txt

模型对接

分为两部分:对接大语言模型(chatbot.py)和对接嵌入模型(embedding.py)。

大语言模型(chatbot.py)

通过 OpenAI 兼容接口调用 SiliconFlow 上的 DeepSeek-V3 模型;每次提问前先从向量库检索相关知识,并拼接到系统提示中。

from openai import OpenAI
from embedding import EmbeddingBot
from typing import Optional

class ChatBot:
    """Retrieval-augmented chat bot.

    For every question it retrieves related passages from the vector store
    (through ``EmbeddingBot``) and sends them, together with recent
    conversation history, to an OpenAI-compatible chat-completions endpoint.
    """

    def __init__(self, api_key="", model_name="deepseek-ai/DeepSeek-V3",
                 base_url="https://api.siliconflow.cn/v1"):
        """Initialise the chat bot.

        Args:
            api_key: Optional API key. It is used for the chat endpoint and
                also forwarded to the embedding client, which talks to the
                same provider.
            model_name: Name of the chat model.
            base_url: Base URL of the OpenAI-compatible API. The OpenAI client
                appends ``/chat/completions`` to it, so it must include the
                ``/v1`` version prefix.
        """
        self.model_name = model_name
        self.client = OpenAI(
            base_url=base_url,
            api_key=api_key if api_key else "no-key-required",
        )

        # Forward the key so embedding requests authenticate with the same
        # provider as chat requests.
        self.embedding_bot = EmbeddingBot(api_key=api_key)
        self.context_length = 5  # number of history messages sent with each request

    def _get_relevant_context(self, query: str, n_results: int = 5) -> str:
        """Return the passages most relevant to *query*, joined by newlines.

        ``EmbeddingBot.query`` reports failures by returning ``{'error': ...}``
        rather than raising. In that case, or when nothing matches, an empty
        string is returned so the chat can still proceed without extra
        knowledge.

        Args:
            query: The user's question.
            n_results: Maximum number of passages to retrieve.
        """
        results = self.embedding_bot.query(query_text=query, n_results=n_results)
        documents = results.get("documents") or []
        # One inner list per query embedding; we send exactly one query.
        if not documents or not documents[0]:
            return ""
        return "\n".join(documents[0])

    def chat(self, user_input: str, history: Optional[list] = None) -> str:
        """Answer *user_input* using retrieved knowledge and recent history.

        Args:
            user_input: The user's question.
            history: Optional earlier messages in OpenAI chat format
                (``{"role": ..., "content": ...}`` dicts). Only the last
                ``self.context_length`` entries are sent.

        Returns:
            The assistant's reply text, or an empty string if the model
            returned no content.
        """
        knowledge = self._get_relevant_context(user_input)

        # The system prompt must come first: chat models treat the system
        # message as instructions for the conversation that follows it.
        messages = [{
            "role": "system",
            "content": f"你是一个 AI 助手,请根据以下知识回答问题:\n{knowledge}"
        }]
        # Guard against context_length == 0, where history[-0:] would be the whole list.
        if history and self.context_length > 0:
            messages.extend(history[-self.context_length:])
        messages.append({"role": "user", "content": user_input})

        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=messages,
            temperature=0.7,
        )
        return response.choices[0].message.content or ""

嵌入模型(embedding.py)

嵌入模型负责把文本转化为向量。

使用向量的好处是可以按语义相似度检索,更方便地找到与用户提问相关的文本。这里使用 SiliconFlow 提供的 BAAI/bge-m3 模型(1024 维),向量保存在本地 ChromaDB(db/files 目录)中。

代码包含以下功能:

  • 原始文件向量化并存储
  • 删除文件时,删除关联的向量数据
  • 用户提问时查找相关文本

import os

import chromadb
import markdown
import requests


class EmbeddingBot:
    """Vectorise documents through an HTTP embeddings API and keep them in ChromaDB.

    Each file is stored as a single Chroma document whose id is the file's
    basename; the same basename is used to delete it later.
    """

    # Vector size requested from the API (``dimensions``), recorded in the
    # collection metadata, and enforced on every response.
    DIMENSION = 1024

    def __init__(self, api_url="https://api.siliconflow.cn/v1/embeddings",
                 api_key="", model_name="BAAI/bge-m3"):
        """Configure the embeddings API and open the persistent vector store.

        Args:
            api_url: Embeddings endpoint URL.
            api_key: Optional API key, sent as a ``Bearer`` token when non-empty.
            model_name: Embedding model name.
        """
        self.api_url = api_url
        self.api_key = api_key
        self.model_name = model_name

        # ChromaDB in persistent mode under db/files (created if missing).
        db_path = os.path.join("db", "files")
        os.makedirs(db_path, exist_ok=True)
        from chromadb.config import Settings
        self.chroma_client = chromadb.PersistentClient(
            path=db_path,
            settings=Settings(anonymized_telemetry=False)
        )
        self.collection = self.chroma_client.get_or_create_collection(
            name="knowledge_base", metadata={"dimension": self.DIMENSION})

    def extract_text(self, file_path):
        """Return the content of a Markdown file, or "" for unsupported types.

        Note: ``markdown.markdown`` renders Markdown to HTML, so the returned
        (and therefore stored and embedded) text contains HTML tags.
        """
        if not file_path.lower().endswith('.md'):
            return ""
        with open(file_path, 'r', encoding='utf-8') as f:
            return markdown.markdown(f.read())

    def process_file(self, file_path, metadata=None):
        """Embed a file's text and store it in the collection.

        Args:
            file_path: Path of the file; its basename becomes the Chroma id.
            metadata: Optional extra metadata merged over ``{"file_path": ...}``.

        Returns:
            bool: True if the vector was stored, False if no text could be
            extracted.

        Raises:
            requests.HTTPError, ValueError: propagated from :meth:`embedding`.
        """
        text = self.extract_text(file_path)
        if not text:
            return False

        embeddings = self.embedding([text])
        if not embeddings:
            return False

        doc_metadata = {"file_path": file_path}
        if metadata:
            doc_metadata.update(metadata)

        # upsert rather than add: re-processing the same file must replace the
        # stored vector, and Chroma's add() does not overwrite an existing id.
        self.collection.upsert(
            documents=[text],
            embeddings=[embeddings[0]],
            metadatas=[doc_metadata],
            ids=[os.path.basename(file_path)],
        )
        return True

    def process_files(self, file_paths, metadata=None):
        """Process several files; returns one ``process_file`` result per path."""
        return [self.process_file(path, metadata) for path in file_paths]

    def delete_embeddings(self, file_path, delete_file=False):
        """Delete the vector stored for *file_path* (keyed by its basename).

        Args:
            file_path: Path whose basename is the Chroma id.
            delete_file: Also remove the file from disk if it exists.
        """
        file_id = os.path.basename(file_path)
        if delete_file and os.path.exists(file_path):
            os.remove(file_path)
        self.collection.delete(ids=[file_id])

    def query(self, query_text, n_results=5, include_metadata=True):
        """Search the knowledge base for passages similar to *query_text*.

        Args:
            query_text: Query text.
            n_results: Maximum number of matches.
            include_metadata: Whether to include metadata in the result.

        Returns:
            dict: Chroma's query result, with one inner list per query (so
            ``documents[0]`` holds this query's matches). ``distances`` is
            replaced by ``scores`` = ``1 - distance / 2``. On any failure,
            ``{'error': <message>}`` is returned instead of raising.
        """
        try:
            query_embedding = self.embedding([query_text])
            results = self.collection.query(
                query_embeddings=query_embedding,
                n_results=n_results,
                include=["documents", "metadatas", "distances"] if include_metadata else ["documents", "distances"]
            )

            # NOTE(review): 1 - d/2 maps a cosine distance in [0, 2] onto
            # [0, 1]. The collection is created without an explicit distance
            # space, so it uses Chroma's default metric. Confirm the actual
            # score range.
            distances = results.get('distances')
            if distances:
                results['scores'] = [1 - (d / 2) for d in distances[0]]
                del results['distances']

            return results
        except Exception as e:
            print(f"查询失败:{str(e)}")
            return {'error': str(e)}

    def embedding(self, input_data):
        """Embed a list of strings through the API.

        Args:
            input_data: List of texts to embed.

        Returns:
            list: One ``DIMENSION``-length vector per input, in input order.
            Returns an empty list when *input_data* is empty, without calling
            the API.

        Raises:
            requests.HTTPError: Non-2xx response from the API.
            ValueError: A returned vector has the wrong dimension.
        """
        if not input_data:
            return []

        headers = {}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"

        response = requests.post(
            self.api_url,
            json={"input": input_data, "model": self.model_name, "dimensions": self.DIMENSION},
            headers=headers,
            timeout=30,
        )
        response.raise_for_status()
        return self._extract_vectors(response.json())

    @classmethod
    def _extract_vectors(cls, embedding_data):
        """Pull every vector out of an embeddings API payload and check its size.

        Items are ordered by their ``index`` field when present, so the
        result lines up with the input order. Previously only ``data[0]`` was
        returned, silently dropping every vector after the first.
        """
        items = sorted(embedding_data['data'], key=lambda item: item.get('index', 0))
        vectors = [item['embedding'] for item in items]
        for vector in vectors:
            actual_dim = len(vector)
            if actual_dim != cls.DIMENSION:
                raise ValueError(f"维度不匹配:预期 {cls.DIMENSION} 维,实际{actual_dim}维")
        return vectors

文件存储与管理

定义数据模型(models.py)

这里使用 SQLite(db/files.db)存储源文件的元信息(原始文件名、存储路径、上传时间),方便后续下载和查看。

from sqlalchemy import create_engine, Column, Integer, String, DateTime
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from datetime import datetime

# Declarative base shared by the ORM models in this module; init_db() uses its
# metadata to create the tables.
Base = declarative_base()

class File(Base):
    """ORM model for an uploaded source file (one row per stored upload)."""
    __tablename__ = 'files'
    
    id = Column(Integer, primary_key=True)
    filename = Column(String(255), nullable=False)  # original name as uploaded, used as the download name
    upload_time = Column(DateTime, default=datetime.now)  # naive local time; default applies when not set explicitly
    file_path = Column(String(255), nullable=False)  # path of the stored copy on disk

def init_db(db_path='db/files.db'):
    """Create the SQLite database and its tables if needed; return a session factory.

    Args:
        db_path: Location of the SQLite file. Relative paths are resolved
            against the current working directory.

    Returns:
        A ``sessionmaker`` bound to the engine.
    """
    import os

    # SQLite creates the database file but not missing parent directories.
    # On a fresh checkout ``db/`` does not exist yet, so create it first.
    parent_dir = os.path.dirname(db_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    engine = create_engine(f'sqlite:///{db_path}')
    Base.metadata.create_all(engine)
    # expire_on_commit=False keeps loaded attributes after commit, so records
    # returned from an already-closed session stay readable.
    return sessionmaker(bind=engine, expire_on_commit=False)

文件管理类(file_service.py)

用于处理文件的上传(同时完成向量化存储)、下载和删除。上传的文件会以 UUID 重命名后保存在 storage 目录,目前仅支持 Markdown(.md)文件。

import os
import logging
from datetime import datetime
from models import File, init_db
from embedding import EmbeddingBot


# 配置日志
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(name=__name__)

class FileService:
    """Upload, delete and read knowledge-base files.

    Keeps three stores in sync: the file on disk (``storage/``), its row in
    the ``files`` table, and its vector in the collection managed by
    ``EmbeddingBot``.
    """

    def __init__(self, api_key=""):
        """Set up storage paths, the DB session factory and the embedding client.

        Args:
            api_key: Optional key forwarded to ``EmbeddingBot`` for the
                embeddings API. The default empty string keeps the previous
                behaviour of sending no key.
        """
        self.storage_dir = 'storage'  # uploaded files are saved here under UUID names
        self.db_dir = 'db'
        self.Session = init_db()  # session factory for the files table
        self.embedding_bot = EmbeddingBot(api_key=api_key)

    def save_file(self, uploaded_file):
        """Store an uploaded file, index it, and record it in the database.

        Steps:
            1. Validate the extension.
            2. Write the file under a UUID name.
            3. Reject empty files.
            4. Vectorise the file.
            5. Insert the ``File`` row.

        Vectorising happens before the row is committed. If any step fails,
        earlier work is undone (vectors deleted, file removed), so no
        committed row can point at a deleted file.

        Args:
            uploaded_file: Upload object exposing ``filename`` and
                ``save(fileobj)``. Presumably a Werkzeug ``FileStorage``
                from Flask; confirm with callers.

        Returns:
            File: The committed record, with its attributes loaded.

        Raises:
            ValueError: Unsupported extension, empty file, or vectorisation
                failure.
            Exception: Database and I/O errors are re-raised after cleanup.
        """
        import uuid

        logger.info(f"开始处理文件上传:{uploaded_file.filename}")
        original_filename = uploaded_file.filename

        # Validate before touching the disk.
        allowed_extensions = {'.md'}
        file_ext = os.path.splitext(original_filename)[1].lower()
        if file_ext not in allowed_extensions:
            error_msg = f"不支持的文件类型:{file_ext},仅支持:{', '.join(allowed_extensions)}"
            logger.error(error_msg)
            raise ValueError(error_msg)

        if not os.path.exists(self.storage_dir):
            logger.info(f"创建存储目录:{self.storage_dir}")
            os.makedirs(self.storage_dir, exist_ok=True)

        # A UUID name avoids collisions; the original name is kept in the DB.
        unique_filename = f"{uuid.uuid4()}{file_ext}"
        file_path = os.path.join(self.storage_dir, unique_filename)
        logger.info(f"文件保存路径:{file_path}")

        vectors_stored = False
        session = self.Session()
        try:
            logger.info("开始保存文件...")
            with open(file_path, 'wb') as buffer:
                uploaded_file.save(buffer)
            logger.info("文件保存成功")

            if os.path.getsize(file_path) == 0:
                error_msg = "上传的文件为空"
                logger.error(error_msg)
                raise ValueError(error_msg)

            logger.info("开始向量化处理...")
            try:
                stored = self.embedding_bot.process_file(file_path)
            except Exception as e:
                logger.error(f"向量化处理异常:{str(e)}")
                raise ValueError(f"向量化处理失败:{str(e)}") from e
            if not stored:
                error_msg = "文件向量化处理失败"
                logger.error(error_msg)
                raise ValueError(f"向量化处理失败:{error_msg}")
            vectors_stored = True
            logger.info("向量化处理成功")

            logger.info("开始数据库记录...")
            file_record = File(
                filename=original_filename,  # keep the user's name for downloads
                file_path=file_path,
                upload_time=datetime.now()
            )
            session.add(file_record)
            session.commit()
            # Reading the id inside the open session also loads the other
            # columns, so the record stays usable after close().
            logger.info(f"数据库记录成功,ID: {file_record.id}")
            return file_record
        except Exception as e:
            logger.error(f"文件处理失败:{str(e)}")
            session.rollback()
            if vectors_stored:
                try:
                    self.embedding_bot.delete_embeddings(file_path)
                except Exception as cleanup_error:
                    logger.error(f"向量数据删除失败:{str(cleanup_error)}")
            if os.path.exists(file_path):
                os.remove(file_path)
            raise
        finally:
            session.close()

    def delete_file(self, file_id):
        """Delete a file's vector, database row and stored copy.

        Order matters:
            1. Vectors are removed first. If that fails, nothing else has
               changed and the caller can retry.
            2. The DB row is removed next.
            3. The disk copy is removed last, best effort. A failure there
               only leaves an unreferenced file, never a row without a file.

        Args:
            file_id: Primary key of the ``File`` row.

        Returns:
            bool: True if the record existed and was deleted, otherwise False.

        Raises:
            Exception: Re-raised from vector deletion or the DB commit.
        """
        session = self.Session()
        try:
            file_record = session.query(File).filter(File.id == file_id).first()
            if file_record is None:
                return False
            file_path = file_record.file_path

            try:
                self.embedding_bot.delete_embeddings(file_path)
            except Exception as e:
                logger.error(f"向量数据删除失败:{str(e)}")
                raise

            session.delete(file_record)
            session.commit()
        except Exception:
            session.rollback()
            raise
        finally:
            session.close()

        if os.path.exists(file_path):
            try:
                os.remove(file_path)
            except OSError as e:
                logger.error(f"物理文件删除失败:{str(e)}")
        return True

    def get_file(self, file_id):
        """Return the ``File`` record with *file_id*, or None.

        The session is closed before returning, so the record is detached.
        Its column attributes were loaded by the query and remain readable.
        """
        session = self.Session()
        try:
            return session.query(File).filter(File.id == file_id).first()
        finally:
            session.close()

    def get_file_content(self, file_id):
        """Return ``{'content': bytes, 'filename': original_name}`` for *file_id*.

        Returns None when there is no such record, or when the stored copy is
        missing from disk. The missing-file case is logged.
        """
        session = self.Session()
        try:
            file_record = session.query(File).filter(File.id == file_id).first()
            if not file_record:
                return None
            file_path = file_record.file_path
            filename = file_record.filename  # original name, used for downloads
        finally:
            session.close()

        try:
            with open(file_path, 'rb') as f:
                content = f.read()
        except FileNotFoundError:
            logger.error(f"文件不存在:{file_path}")
            return None

        return {
            'content': content,
            'filename': filename
        }

Web 端文件管理

为了更方便地管理知识库文件,这里使用 Flask 做一个简单的 Web 页面,用于上传文件和管理已有文件。

模板中引用的静态资源为 Bootstrap v5.3.0,可从 Bootstrap 官网(https://getbootstrap.com/)下载,然后将 bootstrap.min.css 放到 static/css/ 目录、bootstrap.bundle.min.js 放到 static/js/ 目录。

Jinja2 模板(templates/index.html)

<!DOCTYPE html>
<html lang="zh-CN">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>AI知识库管理系统</title>
    <!-- Bootstrap CSS, served by Flask from static/css/ -->
    <link href="/static/css/bootstrap.min.css" rel="stylesheet">
    <style>
        /* Scroll long file lists inside the card instead of growing the page. */
        .file-list {
            max-height: 500px;
            overflow-y: auto;
        }
        /* Click / drag-and-drop target for uploads. */
        .upload-area {
            border: 2px dashed #ccc;
            padding: 20px;
            text-align: center;
            margin-bottom: 20px;
            cursor: pointer;
        }
        .upload-area:hover {
            border-color: #0d6efd;
        }
    </style>
</head>
<body>
    <div class="container mt-4">
        <h1 class="text-center mb-4">AI 知识库管理系统</h1>

        <!-- Upload card: choose or drop .md files, then upload them one by one -->
        <div class="card">
            <div class="card-header">
                <h5>上传文件</h5>
            </div>
            <div class="card-body">
                <div id="uploadArea" class="upload-area">
                    <p>点击或拖拽文件到此处上传</p>
                    <p class="small text-muted">仅支持 Markdown(.md) 文件</p>
                    <input type="file" id="fileInput" class="d-none" accept=".md" multiple>
                </div>
                <div class="progress d-none" id="uploadProgress">
                    <div class="progress-bar" role="progressbar" style="width: 0%"></div>
                </div>
                <div id="uploadError" class="alert alert-danger d-none mt-2"></div>
                <div id="fileNameDisplay" class="mt-2 small text-muted d-none"></div>
                <button id="uploadBtn" class="btn btn-primary mt-2" disabled>上传</button>
            </div>
        </div>

        <!-- File list card -->
        <div class="card mt-4">
            <div class="card-header">
                <h5>文件列表</h5>
            </div>
            <div class="card-body">
                <div class="file-list">
                    <table class="table table-striped">
                        <thead>
                            <tr>
                                <th><input type="checkbox" id="selectAll"></th>
                                <th>文件名</th>
                                <th>上传时间</th>
                                <th>操作</th>
                            </tr>
                        </thead>
                        <tbody id="fileTableBody">
                            <!-- Rows are filled in by loadFiles() -->
                        </tbody>
                    </table>
                </div>
            </div>
            <div class="card-footer">
                <button class="btn btn-danger" onclick="batchDelete()">批量删除选中文件</button>
            </div>
        </div>
    </div>

    <!-- Delete progress modal, shared by single and batch deletion -->
    <div class="modal fade" id="deleteModal" tabindex="-1" aria-hidden="true">
        <div class="modal-dialog">
            <div class="modal-content">
                <div class="modal-header">
                    <h5 class="modal-title">正在删除文件</h5>
                    <button type="button" class="btn-close" data-bs-dismiss="modal" aria-label="Close"></button>
                </div>
                <div class="modal-body">
                    <div class="progress">
                        <div id="deleteProgressBar" class="progress-bar" role="progressbar" style="width: 0%"></div>
                    </div>
                    <p id="deleteStatusText" class="mt-2">准备删除...</p>
                </div>
                <div class="modal-footer">
                    <button type="button" class="btn btn-secondary" data-bs-dismiss="modal">取消</button>
                </div>
            </div>
        </div>
    </div>

    <script src="/static/js/bootstrap.bundle.min.js"></script>
    <script>
        document.addEventListener('DOMContentLoaded', function() {
            // Initial load of the file table.
            loadFiles();

            const uploadArea = document.getElementById('uploadArea');
            const fileInput = document.getElementById('fileInput');
            const uploadBtn = document.getElementById('uploadBtn');
            const errorEl = document.getElementById('uploadError');
            const fileNameDisplay = document.getElementById('fileNameDisplay');

            // Select / deselect every row checkbox.
            document.getElementById('selectAll').addEventListener('change', function() {
                document.querySelectorAll('#fileTableBody input[type="checkbox"]').forEach(checkbox => {
                    checkbox.checked = this.checked;
                });
            });

            // Shared validation for picked and dropped files: only .md is
            // accepted (the server enforces the same rule). Updates the error
            // box, the upload button and the selection summary, and returns
            // true when the selection can be uploaded.
            function applySelection(files) {
                const invalidFiles = files.filter(file => !file.name.toLowerCase().endsWith('.md'));
                if (invalidFiles.length) {
                    errorEl.textContent = `有${invalidFiles.length}个文件不是.md 格式`;
                    errorEl.classList.remove('d-none');
                    uploadBtn.disabled = true;
                    // Hide any stale "已选择N个文件" left over from a previous selection.
                    fileNameDisplay.classList.add('d-none');
                    return false;
                }
                errorEl.classList.add('d-none');
                uploadBtn.disabled = false;
                fileNameDisplay.textContent = `已选择${files.length}个文件`;
                fileNameDisplay.classList.remove('d-none');
                return true;
            }

            uploadArea.addEventListener('click', () => fileInput.click());

            // Drag-and-drop upload.
            uploadArea.addEventListener('dragover', (e) => {
                e.preventDefault();  // required for the browser to allow dropping here
                uploadArea.classList.add('border-primary');
            });

            uploadArea.addEventListener('dragleave', () => {
                uploadArea.classList.remove('border-primary');
            });

            uploadArea.addEventListener('drop', (e) => {
                e.preventDefault();
                uploadArea.classList.remove('border-primary');
                const dropped = e.dataTransfer.files;
                if (dropped.length && applySelection(Array.from(dropped))) {
                    fileInput.files = dropped;
                }
            });

            // File picker selection changed.
            fileInput.addEventListener('change', () => {
                if (fileInput.files.length) {
                    applySelection(Array.from(fileInput.files));
                } else {
                    uploadBtn.disabled = true;
                    fileNameDisplay.classList.add('d-none');
                }
            });

            uploadBtn.addEventListener('click', uploadFile);
        });
        
        // Fetch the file list from /files and rebuild the table body.
        // File names come from user uploads, so they are inserted with
        // textContent (never innerHTML) to prevent HTML/script injection.
        function loadFiles() {
            fetch('/files')
                .then(response => {
                    if (!response.ok) throw new Error(`HTTP ${response.status}`);
                    return response.json();
                })
                .then(files => {
                    const tbody = document.getElementById('fileTableBody');
                    tbody.innerHTML = '';
                    // Rows are rebuilt from scratch, so the header checkbox no longer applies.
                    document.getElementById('selectAll').checked = false;

                    files.forEach(file => {
                        const row = document.createElement('tr');

                        const checkCell = document.createElement('td');
                        const checkbox = document.createElement('input');
                        checkbox.type = 'checkbox';
                        checkbox.className = 'file-checkbox';
                        checkbox.dataset.fileId = file.id;
                        checkCell.appendChild(checkbox);

                        const nameCell = document.createElement('td');
                        nameCell.textContent = file.filename;

                        const timeCell = document.createElement('td');
                        timeCell.textContent = new Date(file.upload_time).toLocaleString();

                        const actionCell = document.createElement('td');
                        const downloadBtn = document.createElement('button');
                        downloadBtn.className = 'btn btn-sm btn-success';
                        downloadBtn.textContent = '下载';
                        downloadBtn.addEventListener('click', () => downloadFile(file.id));
                        const deleteBtn = document.createElement('button');
                        deleteBtn.className = 'btn btn-sm btn-danger';
                        deleteBtn.textContent = '删除';
                        deleteBtn.addEventListener('click', () => deleteFile(file.id));
                        actionCell.append(downloadBtn, ' ', deleteBtn);

                        row.append(checkCell, nameCell, timeCell, actionCell);
                        tbody.appendChild(row);
                    });
                })
                .catch(error => {
                    console.error('加载文件列表失败:', error);
                });
        }
        
        // Upload queue: selected files are uploaded one at a time, in order.
        let uploadQueue = [];
        let isUploading = false;  // true while processUploadQueue is working through a batch
        let uploadErrors = [];    // failure messages collected for the current batch

        // Start uploading the selected files. Ignored while a previous batch
        // is still running, so two queues never run at the same time.
        function uploadFile() {
            const fileInput = document.getElementById('fileInput');
            const progressBar = document.getElementById('uploadProgress');
            const uploadBtn = document.getElementById('uploadBtn');
            const errorEl = document.getElementById('uploadError');

            if (isUploading || !fileInput.files.length) return;

            uploadQueue = Array.from(fileInput.files);
            // Clear errors left over from a previous batch.
            uploadErrors = [];
            errorEl.textContent = '';
            errorEl.classList.add('d-none');
            progressBar.classList.remove('d-none');
            uploadBtn.disabled = true;

            processUploadQueue();
        }

        // Upload the head of the queue, then schedule the next file. When the
        // queue is empty, finish the progress bar and reset the picker.
        // Failure messages are kept visible after the batch ends.
        function processUploadQueue() {
            const progressBar = document.getElementById('uploadProgress');
            const progressBarInner = progressBar.querySelector('.progress-bar');
            const fileNameDisplay = document.getElementById('fileNameDisplay');

            if (uploadQueue.length === 0) {
                isUploading = false;
                progressBarInner.style.width = '100%';
                setTimeout(() => {
                    progressBar.classList.add('d-none');
                    progressBarInner.style.width = '0%';
                    document.getElementById('fileInput').value = '';
                    document.getElementById('uploadBtn').disabled = true;
                    fileNameDisplay.classList.add('d-none');
                    // The error box is deliberately left as-is; hiding it here
                    // made failures disappear half a second after the batch ended.
                }, 500);
                return;
            }

            isUploading = true;
            const file = uploadQueue[0];
            fileNameDisplay.textContent = `正在上传:${file.name} (剩余${uploadQueue.length}个文件)`;
            progressBarInner.style.width = '0%';

            const formData = new FormData();
            formData.append('file', file);

            fetch('/upload', { method: 'POST', body: formData })
                .then(response => {
                    if (response.ok) return response.json();
                    // Error bodies are normally JSON ({error: ...}) but can be an
                    // HTML page (e.g. from a proxy or an unhandled server error).
                    return response.json()
                        .catch(() => ({}))
                        .then(err => { throw new Error(err.error || '上传失败'); });
                })
                .then(() => {
                    progressBarInner.style.width = '100%';
                    loadFiles();  // refresh the list after each successful upload
                })
                .catch(error => {
                    uploadErrors.push(`${file.name}上传失败:${error.message}`);
                    const errorEl = document.getElementById('uploadError');
                    errorEl.textContent = uploadErrors.join('; ');
                    errorEl.classList.remove('d-none');
                })
                .finally(() => {
                    uploadQueue.shift();
                    setTimeout(processUploadQueue, 300);
                });
        }
        
        // Open the download URL for fileId in a new tab.
        function downloadFile(fileId) {
            const downloadUrl = '/download/' + fileId;
            window.open(downloadUrl, '_blank');
        }
        
        // Delete every checked file one after another, showing progress in the
        // shared delete modal. Failures are collected and reported once at the
        // end, instead of blocking the batch with one alert per file.
        function batchDelete() {
            const checkboxes = document.querySelectorAll('#fileTableBody input.file-checkbox:checked');
            if (checkboxes.length === 0) {
                alert('请先选择要删除的文件');
                return;
            }

            const fileIds = Array.from(checkboxes, checkbox => checkbox.dataset.fileId);
            if (!confirm(`确定要删除选中的${fileIds.length}个文件吗?`)) return;

            // Reuse the modal's Bootstrap instance instead of stacking a new one per call.
            const deleteModal = bootstrap.Modal.getOrCreateInstance(document.getElementById('deleteModal'));
            const progressBar = document.getElementById('deleteProgressBar');
            const statusText = document.getElementById('deleteStatusText');

            deleteModal.show();
            progressBar.style.width = '0%';
            statusText.textContent = `准备删除 ${fileIds.length} 个文件...`;

            const total = fileIds.length;
            let completed = 0;
            const failedIds = [];

            function deleteNextFile() {
                if (completed >= total) {
                    if (failedIds.length === 0) {
                        statusText.textContent = '所有文件删除完成';
                        setTimeout(() => {
                            deleteModal.hide();
                            loadFiles();
                        }, 1000);
                    } else {
                        // Keep the modal open so the user can see which deletions failed.
                        statusText.textContent = `删除完成,${failedIds.length} 个文件删除失败 (ID: ${failedIds.join(', ')})`;
                        loadFiles();
                    }
                    return;
                }

                const currentFileId = fileIds[completed];
                statusText.textContent = `正在删除文件 ${completed + 1}/${total} (ID: ${currentFileId})`;

                fetch(`/delete/${currentFileId}`, { method: 'DELETE' })
                    .then(response => {
                        if (!response.ok) throw new Error('删除失败');
                        return response.json();
                    })
                    .catch(error => {
                        console.error(`删除文件 ${currentFileId} 失败:${error.message}`);
                        failedIds.push(currentFileId);
                    })
                    .then(() => {
                        completed++;
                        progressBar.style.width = `${(completed / total) * 100}%`;
                        deleteNextFile();
                    });
            }

            deleteNextFile();
        }

        // Delete one file after confirmation, reporting progress in the delete modal.
        function deleteFile(fileId) {
            if (!confirm('确定要删除这个文件吗?')) return;

            // Reuse the modal's Bootstrap instance rather than creating a new one per click.
            const deleteModal = bootstrap.Modal.getOrCreateInstance(document.getElementById('deleteModal'));
            const progressBar = document.getElementById('deleteProgressBar');
            const statusText = document.getElementById('deleteStatusText');

            deleteModal.show();
            progressBar.style.width = '0%';
            statusText.textContent = '正在删除文件...';

            fetch(`/delete/${fileId}`, { method: 'DELETE' })
                .then(response => {
                    if (response.ok) return response.json();
                    // Prefer the server's JSON error message when one is available.
                    return response.json()
                        .catch(() => ({}))
                        .then(err => { throw new Error(err.error || '删除失败'); });
                })
                .then(() => {
                    progressBar.style.width = '100%';
                    statusText.textContent = '文件删除成功';
                    setTimeout(() => {
                        deleteModal.hide();
                        loadFiles();
                    }, 1000);
                })
                .catch(error => {
                    progressBar.style.width = '100%';
                    statusText.textContent = '删除失败';
                    alert(error.message);
                });
        }
    </script>
</body>
</html>

Flask 服务(FileManage.py)

启动后访问 http://localhost:5051 即可打开文件管理页面。

from flask import Flask, request, render_template, send_file, jsonify
from file_service import FileService
from models import File

app = Flask(import_name=__name__, static_folder="static")

file_service = FileService()  # shared by every route below

# Web UI
@app.route('/')
def index():
    """Render the file-management page (templates/index.html)."""
    template_name = 'index.html'
    return render_template(template_name)

# File upload
@app.route('/upload', methods=['POST'])
def upload_file():
    """Store one multipart upload sent under the form field ``file``.

    Returns JSON ``{id, filename, upload_time}`` on success. Returns
    ``{error}`` with 400 when the part is missing or has no filename, and
    with 500 when storing fails.
    """
    uploaded = request.files.get('file')
    if uploaded is None:
        return jsonify({'error': 'No file part'}), 400
    if uploaded.filename == '':
        return jsonify({'error': 'No selected file'}), 400

    try:
        record = file_service.save_file(uploaded)
        payload = {
            'id': record.id,
            'filename': record.filename,
            'upload_time': record.upload_time.isoformat(),
        }
        return jsonify(payload)
    except Exception as exc:
        return jsonify({'error': str(exc)}), 500

# File download
@app.route('/download/<int:file_id>')
def download_file(file_id):
    """Send the stored file as an attachment named after its original filename.

    Returns 404 JSON when the file service has nothing for *file_id*.
    """
    from io import BytesIO

    file_data = file_service.get_file_content(file_id)
    if not file_data:
        return jsonify({'error': 'File not found'}), 404

    stream = BytesIO(file_data['content'])
    return send_file(stream, download_name=file_data['filename'], as_attachment=True)

# File deletion
@app.route('/delete/<int:file_id>', methods=['DELETE'])
def delete_file(file_id):
    """Delete a file through the file service.

    Returns:
        JSON ``{message}`` on success; ``{error}`` with 404 when the id is
        unknown; ``{error}`` with 500 when ``FileService.delete_file``
        raises (for example, when vector deletion fails). The 500 case uses
        the same JSON error shape as ``upload_file`` instead of Flask's
        default HTML error page.
    """
    try:
        deleted = file_service.delete_file(file_id)
    except Exception as e:
        return jsonify({'error': str(e)}), 500
    if not deleted:
        return jsonify({'error': 'File not found'}), 404
    return jsonify({'message': 'File deleted successfully'})

# File listing
@app.route('/files')
def list_files():
    """Return every stored file as JSON ``[{id, filename, upload_time}, ...]``."""
    session = file_service.Session()
    try:
        listing = []
        for record in session.query(File).all():
            listing.append({
                'id': record.id,
                'filename': record.filename,
                'upload_time': record.upload_time.isoformat(),
            })
        return jsonify(listing)
    finally:
        session.close()

if __name__ == '__main__':
    # Serve the file-management UI on port 5051 with debug mode off.
    app.run(port=5051, debug=False)

对外接口(app.py)

这里使用 FastAPI 提供一个对外接口示例,可以根据实际需要修改输入输出格式。

请求示例:向 http://localhost:8000/api/chat 发送 POST 请求,JSON 请求体为 {"message": "你的问题"},返回 {"response": "回答内容"}。

import json

import uvicorn
from fastapi import FastAPI, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse

from chatbot import ChatBot

app = FastAPI()

# A single ChatBot (with its vector-store client) is shared by all requests.
chatbot = ChatBot()


@app.post("/api/chat")
async def chat_endpoint(request: Request):
    """Answer one chat message.

    Expects a JSON object ``{"message": "<text>"}`` and returns
    ``{"response": "<reply>"}``. Errors come back as ``{"error": ...}``
    with HTTP 400 for bad input or 500 for server-side failures. FastAPI
    would serialise a ``(dict, status)`` tuple as a JSON array with status
    200, so ``JSONResponse`` is used to set the status code.
    """
    try:
        data = await request.json()
    except (json.JSONDecodeError, UnicodeDecodeError):
        return JSONResponse(status_code=400, content={"error": "Invalid JSON format"})

    if not data:
        return JSONResponse(status_code=400, content={"error": "Empty request body"})
    if not isinstance(data, dict):
        return JSONResponse(status_code=400, content={"error": "Request body must be a JSON object"})

    user_input = data.get("message")
    if not user_input:
        return JSONResponse(status_code=400, content={"error": "Missing 'message' field"})

    try:
        # ChatBot.chat makes blocking HTTP calls. Running it in the threadpool
        # keeps the event loop free to serve other requests meanwhile.
        response = await run_in_threadpool(chatbot.chat, user_input)
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": f"Server error: {str(e)}"})
    return {"response": response}


if __name__ == "__main__":
    # Listen on every interface, port 8000.
    uvicorn.run(app, port=8000, host="0.0.0.0")
Licensed under CC BY-NC-SA 4.0