157 lines
4.5 KiB
Python
157 lines
4.5 KiB
Python
import io
import os
from datetime import datetime
from typing import List, Optional

from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel, ConfigDict
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from database.database import get_db, ChatMessage as DBChatMessage, ChatFile as DBChatFile
from database.minio_processor import MinIOProcessor
|
|
|
|
def _env_flag(name: str, default: bool = False) -> bool:
    """Read environment variable *name* as a boolean flag.

    Unset returns *default*; otherwise "1", "true", "yes" and "on"
    (case-insensitive, surrounding whitespace ignored) are true and any other
    value is false.
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    return raw.strip().lower() in {"1", "true", "yes", "on"}


# Object-storage client shared by the chat endpoints. Connection settings come
# from the environment; the fallbacks point at a local development instance.
# NOTE(review): "minioadmin"/"minioadmin" are MinIO's well-known out-of-the-box
# credentials -- make sure MINIO_ACCESS_KEY / MINIO_SECRET_KEY are set in any
# non-development deployment.
minio_client = MinIOProcessor(
    endpoint=os.getenv("MINIO_ENDPOINT", "localhost:9000"),
    access_key=os.getenv("MINIO_ACCESS_KEY", "minioadmin"),
    secret_key=os.getenv("MINIO_SECRET_KEY", "minioadmin"),
    # TLS is off unless MINIO_SECURE is set (previously hard-coded to False).
    secure=_env_flag("MINIO_SECURE"),
)

# Bucket that receives chat attachments.
# NOTE(review): the default is "resumes" -- confirm chat uploads are meant to
# share that bucket.
MINIO_BUCKET = os.getenv("MINIO_BUCKET", "resumes")

router = APIRouter(prefix="/api/chat", tags=["chat"])
|
|
|
|
|
|
class ChatFileResponse(BaseModel):
    """API representation of a file attached to a chat message.

    Instances are built directly from ORM objects (``from_attributes``), so the
    field names must match the attribute names on the file model.
    """

    id: int
    file_name: str
    # MIME type recorded for the upload.
    file_type: str
    # Storage location string for the stored object.
    file_url: str
    created_at: datetime

    # Pydantic v2 configuration; the nested ``class Config`` form is deprecated.
    model_config = ConfigDict(from_attributes=True)
|
|
|
|
|
|
class ChatMessageResponse(BaseModel):
    """API representation of a chat message together with its attachments.

    Instances are built directly from ORM objects (``from_attributes``); the
    ``files`` relationship must already be loaded on the source object.
    """

    id: int
    user_id: int
    role: str
    content: str
    is_favorite: bool
    files: List[ChatFileResponse]
    created_at: datetime

    # Pydantic v2 configuration; the nested ``class Config`` form is deprecated.
    model_config = ConfigDict(from_attributes=True)
|
|
|
|
|
|
@router.post("", response_model=ChatMessageResponse)
async def create_chat_message(
    user_id: int = Form(...),
    role: str = Form(...),
    content: str = Form(...),
    files: Optional[List[UploadFile]] = File(None),
    db: AsyncSession = Depends(get_db),
):
    """Create a chat message and store any attached files in MinIO.

    The message row is flushed first so its generated id can be used in the
    object key ``chat/<user_id>/<message_id>/<file name>``. Each attachment is
    written to ``MINIO_BUCKET`` and recorded as a file row whose ``file_url``
    is ``minio://<bucket>/<object key>``. Everything is committed in a single
    transaction at the end.

    Args:
        user_id: Owner of the message (form field).
        role: Message role (form field).
        content: Message text (form field).
        files: Optional attachments (multipart file fields).
        db: Async database session injected by FastAPI.

    Returns:
        The stored message with its ``files`` relationship loaded.

    Raises:
        HTTPException: 500 if reading or uploading an attachment fails; the
            database transaction is rolled back in that case.
    """
    message = DBChatMessage(
        user_id=user_id,
        role=role,
        content=content,
        is_favorite=False,
    )
    db.add(message)
    # Flush (not commit) so the database assigns message.id for the object keys.
    await db.flush()
    # Capture the id now: the session may expire the instance on commit, and
    # reading an expired attribute through an AsyncSession needs implicit I/O.
    message_id = message.id

    for upload in files or []:
        # The client controls the filename: it may be missing or carry path
        # segments. Keep only the last component so the key stays under this
        # message's prefix, and fall back to a placeholder when nothing usable
        # remains.
        file_name = os.path.basename(upload.filename or "")
        if file_name in ("", ".", ".."):
            file_name = "unnamed"
        content_type = upload.content_type or "application/octet-stream"
        object_name = f"chat/{user_id}/{message_id}/{file_name}"
        try:
            file_content = await upload.read()
            # put_object is a plain (non-awaited) call; run it in a worker
            # thread so the upload does not block the event loop.
            # NOTE(review): assumes MinIOProcessor is safe to call from a
            # worker thread -- confirm.
            await run_in_threadpool(
                minio_client.put_object,
                bucket_name=MINIO_BUCKET,
                object_name=object_name,
                data=io.BytesIO(file_content),
                length=len(file_content),
                content_type=content_type,
            )
        except Exception as e:
            # NOTE(review): objects already uploaded for earlier attachments are
            # not removed here, so a mid-batch failure can leave orphans in MinIO.
            await db.rollback()
            raise HTTPException(
                status_code=500, detail=f"Failed to upload file: {str(e)}"
            ) from e

        db.add(
            DBChatFile(
                message_id=message_id,
                file_name=file_name,
                file_type=content_type,
                file_url=f"minio://{MINIO_BUCKET}/{object_name}",
            )
        )

    await db.commit()

    # Re-select with the files relationship eagerly loaded so serializing the
    # response does not trigger a lazy load.
    result = await db.execute(
        select(DBChatMessage)
        .where(DBChatMessage.id == message_id)
        .options(selectinload(DBChatMessage.files))
    )
    return result.scalar_one()
|
|
|
|
|
|
@router.get("", response_model=List[ChatMessageResponse])
async def get_chat_messages(
    user_id: Optional[int] = None,
    is_favorite: Optional[bool] = None,
    db: AsyncSession = Depends(get_db),
):
    """List chat messages, newest first, each with its attached files.

    Both filters are optional and are combined with AND; a filter left as
    ``None`` is not applied.

    Args:
        user_id: Only return messages owned by this user.
        is_favorite: Only return messages whose favorite flag equals this value.
        db: Async database session injected by FastAPI.

    Returns:
        The matching messages ordered by ``created_at`` descending.
    """
    # NOTE(review): no pagination -- the result set is unbounded.
    conditions = []
    if user_id is not None:
        conditions.append(DBChatMessage.user_id == user_id)
    if is_favorite is not None:
        conditions.append(DBChatMessage.is_favorite == is_favorite)

    stmt = select(DBChatMessage).options(selectinload(DBChatMessage.files))
    for condition in conditions:
        stmt = stmt.where(condition)
    stmt = stmt.order_by(DBChatMessage.created_at.desc())

    rows = await db.execute(stmt)
    return rows.scalars().all()
|
|
|
|
|
|
@router.get("/{message_id}", response_model=ChatMessageResponse)
async def get_chat_message(message_id: int, db: AsyncSession = Depends(get_db)):
    """Fetch a single chat message, with its files, by id.

    Raises:
        HTTPException: 404 when no message has ``message_id``.
    """
    # NOTE(review): no ownership check is visible here -- confirm access
    # control is enforced elsewhere.
    stmt = (
        select(DBChatMessage)
        .where(DBChatMessage.id == message_id)
        .options(selectinload(DBChatMessage.files))
    )
    found = (await db.execute(stmt)).scalar_one_or_none()

    if not found:
        raise HTTPException(status_code=404, detail="Message not found")
    return found
|
|
|
|
|
|
@router.put("/{message_id}/favorite", response_model=ChatMessageResponse)
async def toggle_favorite(message_id: int, db: AsyncSession = Depends(get_db)):
    """Flip a message's favorite flag and return the updated message.

    Raises:
        HTTPException: 404 when no message has ``message_id``.
    """
    lookup = (
        select(DBChatMessage)
        .where(DBChatMessage.id == message_id)
        .options(selectinload(DBChatMessage.files))
    )
    target = (await db.execute(lookup)).scalar_one_or_none()
    if not target:
        raise HTTPException(status_code=404, detail="Message not found")

    # NOTE(review): this is a read-modify-write in Python; two concurrent
    # toggles on the same message can cancel each other out.
    target.is_favorite = not target.is_favorite
    await db.commit()
    # Reload the row after the commit (the session may expire it on commit)
    # before the response is serialized.
    await db.refresh(target)
    return target
|