ably.do/backend/open_webui/models/users.py

424 lines
13 KiB
Python

import base64
import io
import logging
import random
import time
from typing import Optional, List

import requests
import requests
from PIL import Image, ImageDraw, ImageFont
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text, JSON

from open_webui.internal.db import Base, JSONField, get_db
from open_webui.models.chats import Chats
from open_webui.models.groups import Groups
####################
# User DB Schema
####################
class User(Base):
    """SQLAlchemy ORM model mapped to the ``user`` table."""

    __tablename__ = "user"

    id = Column(String, primary_key=True)
    name = Column(String)
    # NOTE(review): no unique constraint on email here, while
    # UsersTable.get_user_by_email returns only the first match — confirm
    # uniqueness is enforced elsewhere.
    email = Column(String)
    role = Column(String)
    profile_image_url = Column(Text)
    # Integer timestamps; UsersTable.insert_new_user fills them with
    # int(time.time()), i.e. epoch seconds.
    last_active_at = Column(BigInteger)
    updated_at = Column(BigInteger)
    created_at = Column(BigInteger)
    # Optional, but unique across users when set (used for key lookups).
    api_key = Column(String, nullable=True, unique=True)
    # Stored with the project's JSONField column type.
    settings = Column(JSONField, nullable=True)
    info = Column(JSONField, nullable=True)
    # Unique per user; presumably the OAuth provider's "sub" claim — confirm.
    oauth_sub = Column(Text, unique=True)
    # NOTE(review): permissions/subscription use SQLAlchemy's JSON type rather
    # than JSONField like settings/info — confirm this is intentional.
    permissions = Column(JSON, nullable=True)
    # UserModel declares this as Optional[List[str]].
    subscription = Column(JSON, nullable=True)
class UserSettings(BaseModel):
    # Per-user preferences. Keys other than the declared fields are accepted
    # and kept on the model because extra="allow".
    model_config = ConfigDict(extra="allow")

    # UI-related preferences; an empty dict when not supplied.
    ui: Optional[dict] = {}
class UserModel(BaseModel):
    # Pydantic view of a ``user`` row. from_attributes=True allows
    # UserModel.model_validate(orm_user) on SQLAlchemy instances.
    model_config = ConfigDict(from_attributes=True)

    # --- identity -------------------------------------------------------
    id: str
    name: str
    email: str
    role: str = "pending"
    profile_image_url: str

    # --- epoch timestamps (seconds) -------------------------------------
    last_active_at: int
    updated_at: int
    created_at: int

    # --- optional attributes --------------------------------------------
    api_key: Optional[str] = None
    settings: Optional[UserSettings] = None
    info: Optional[dict] = None
    oauth_sub: Optional[str] = None
    permissions: Optional[dict] = None
    subscription: Optional[List[str]] = None
####################
# Forms
####################
class UserResponse(BaseModel):
    # Subset of UserModel's fields: omits api_key, settings, info, timestamps,
    # permissions and subscription. Presumably the shape exposed to API
    # clients — confirm against the routers.
    id: str
    name: str
    email: str
    role: str
    profile_image_url: str
class UserNameResponse(BaseModel):
    # Same as UserResponse minus the email field.
    id: str
    name: str
    role: str
    profile_image_url: str
class UserRoleUpdateForm(BaseModel):
    # Input for a role change: the target user's id and the new role string.
    id: str
    role: str
class UserUpdateForm(BaseModel):
    # Input for editing a user's profile fields.
    name: str
    email: str
    profile_image_url: str
    # Optional; presumably None means "leave the password unchanged" — confirm
    # with the handler that consumes this form.
    password: Optional[str] = None
class UsersTable:
    """Data-access helpers for the ``user`` table, plus an avatar generator.

    Lookup and update helpers are best-effort: database or validation errors
    are logged and reported as ``None`` / ``False`` / ``0`` instead of being
    raised. ``insert_new_user``, ``get_users``, ``get_users_by_user_ids``,
    ``get_num_users`` and ``get_valid_user_ids`` let exceptions propagate.
    """

    _log = logging.getLogger(__name__)

    # Web font used by generate_image_base64 when "arial.ttf" cannot be opened.
    _FALLBACK_FONT_URL = "https://fonts.gstatic.com/s/readexpro/v21/SLXYc1bJ7HE5YDoGPuzj_dh8uc7wUy8ZQQyX2IwwZEzehiB9.woff2"
    # Seconds to wait for the fallback font download before giving up.
    _FONT_DOWNLOAD_TIMEOUT = 10
    # Fallback font bytes, downloaded at most once per process and reused.
    _fallback_font_data: Optional[bytes] = None

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    def _find_user(self, **filters) -> Optional[UserModel]:
        """Return the first user matching ``filters`` (column=value), else None.

        Errors are logged and reported as ``None``.
        """
        try:
            with get_db() as db:
                user = db.query(User).filter_by(**filters).first()
                return None if user is None else UserModel.model_validate(user)
        except Exception:
            # Log only the column names: values may be secrets (e.g. API keys).
            self._log.exception(
                "User lookup by %s failed", ", ".join(sorted(filters))
            )
            return None

    def _update_user_fields(self, id: str, values: dict) -> Optional[UserModel]:
        """Apply ``values`` to user ``id``, commit, and return the fresh row.

        Returns ``None`` if the user does not exist or the update fails.
        """
        try:
            with get_db() as db:
                db.query(User).filter_by(id=id).update(values)
                db.commit()
                user = db.query(User).filter_by(id=id).first()
                return None if user is None else UserModel.model_validate(user)
        except Exception:
            self._log.exception("Failed to update user %s", id)
            return None

    # ------------------------------------------------------------------
    # Create
    # ------------------------------------------------------------------

    def insert_new_user(
        self,
        id: str,
        name: str,
        email: str,
        profile_image_url: str = "/user.png",
        role: str = "pending",
        oauth_sub: Optional[str] = None,
        permissions: Optional[dict] = None,
        subscription: Optional[List[str]] = None,
    ) -> Optional[UserModel]:
        """Create and persist a new user, returning the validated model.

        All three timestamps come from a single ``time.time()`` reading, so a
        new user always has ``created_at == updated_at == last_active_at``.
        Validation and database errors (e.g. a primary-key collision) are
        raised to the caller.
        """
        now = int(time.time())
        user = UserModel(
            id=id,
            name=name,
            email=email,
            role=role,
            profile_image_url=profile_image_url,
            last_active_at=now,
            created_at=now,
            updated_at=now,
            oauth_sub=oauth_sub,
            permissions=permissions,
            subscription=subscription,
        )
        with get_db() as db:
            row = User(**user.model_dump())
            db.add(row)
            db.commit()
            db.refresh(row)
        return user

    # ------------------------------------------------------------------
    # Read
    # ------------------------------------------------------------------

    def get_user_by_id(self, id: str) -> Optional[UserModel]:
        """Return the user whose primary key is ``id``, or ``None``."""
        return self._find_user(id=id)

    def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
        """Return the user owning ``api_key``, or ``None``."""
        return self._find_user(api_key=api_key)

    def get_user_by_email(self, email: str) -> Optional[UserModel]:
        """Return the first user whose ``email`` equals ``email``, or ``None``."""
        return self._find_user(email=email)

    def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]:
        """Return the user linked to OAuth subject ``sub``, or ``None``."""
        return self._find_user(oauth_sub=sub)

    def get_users(
        self, skip: Optional[int] = None, limit: Optional[int] = None
    ) -> list[UserModel]:
        """Return users ordered newest first by ``created_at``.

        ``skip`` and ``limit`` are applied only when truthy, so ``None`` or
        ``0`` means "no offset" / "no limit" respectively.
        """
        with get_db() as db:
            query = db.query(User).order_by(User.created_at.desc())
            if skip:
                query = query.offset(skip)
            if limit:
                query = query.limit(limit)
            return [UserModel.model_validate(user) for user in query.all()]

    def get_users_by_user_ids(self, user_ids: list[str]) -> list[UserModel]:
        """Return the users whose ids are in ``user_ids`` (unknown ids are skipped)."""
        if not user_ids:
            return []  # nothing can match; skip the database round-trip
        with get_db() as db:
            users = db.query(User).filter(User.id.in_(user_ids)).all()
            return [UserModel.model_validate(user) for user in users]

    def get_num_users(self) -> Optional[int]:
        """Return the total number of rows in the ``user`` table."""
        with get_db() as db:
            return db.query(User).count()

    def get_first_user(self) -> Optional[UserModel]:
        """Return the earliest-created user, or ``None`` if there are none."""
        try:
            with get_db() as db:
                user = db.query(User).order_by(User.created_at).first()
                return None if user is None else UserModel.model_validate(user)
        except Exception:
            self._log.exception("Failed to load the first user")
            return None

    def get_user_webhook_url_by_id(self, id: str) -> Optional[str]:
        """Return ``settings["ui"]["notifications"]["webhook_url"]`` for user ``id``.

        Returns ``None`` when the user, their settings, or any level of that
        path is missing or empty.
        """
        try:
            with get_db() as db:
                user = db.query(User).filter_by(id=id).first()
                if user is None or user.settings is None:
                    return None
                ui = user.settings.get("ui") or {}
                notifications = ui.get("notifications") or {}
                return notifications.get("webhook_url")
        except Exception:
            self._log.exception("Failed to read webhook URL for user %s", id)
            return None

    def get_user_api_key_by_id(self, id: str) -> Optional[str]:
        """Return the API key stored for user ``id`` (``None`` if unset or absent)."""
        try:
            with get_db() as db:
                user = db.query(User).filter_by(id=id).first()
                return None if user is None else user.api_key
        except Exception:
            self._log.exception("Failed to read API key for user %s", id)
            return None

    def get_valid_user_ids(self, user_ids: list[str]) -> list[str]:
        """Return the subset of ``user_ids`` that exist in the table."""
        if not user_ids:
            return []  # nothing can match; skip the database round-trip
        with get_db() as db:
            users = db.query(User).filter(User.id.in_(user_ids)).all()
            return [user.id for user in users]

    # ------------------------------------------------------------------
    # Update
    # ------------------------------------------------------------------

    def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]:
        """Set the role of user ``id``; return the updated user or ``None``."""
        return self._update_user_fields(id, {"role": role})

    def update_user_profile_image_url_by_id(
        self, id: str, profile_image_url: str
    ) -> Optional[UserModel]:
        """Set the avatar URL of user ``id``; return the updated user or ``None``."""
        return self._update_user_fields(id, {"profile_image_url": profile_image_url})

    def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]:
        """Stamp ``last_active_at`` with the current epoch second."""
        return self._update_user_fields(id, {"last_active_at": int(time.time())})

    def update_user_oauth_sub_by_id(
        self, id: str, oauth_sub: str
    ) -> Optional[UserModel]:
        """Set the OAuth subject of user ``id``; return the updated user or ``None``."""
        return self._update_user_fields(id, {"oauth_sub": oauth_sub})

    def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]:
        """Apply the column/value pairs in ``updated`` to user ``id``."""
        return self._update_user_fields(id, updated)

    def update_user_settings_by_id(self, id: str, updated: dict) -> Optional[UserModel]:
        """Shallow-merge ``updated`` into the settings of user ``id``.

        A top-level key in ``updated`` (e.g. ``"ui"``) replaces the stored
        value for that key wholesale. Returns the updated user, or ``None`` if
        the user does not exist or the update fails.
        """
        try:
            with get_db() as db:
                user = db.query(User).filter_by(id=id).first()
                if user is None:
                    return None
                # Build a new dict instead of mutating the ORM-held one in place.
                merged = {**(user.settings or {}), **updated}
                db.query(User).filter_by(id=id).update({"settings": merged})
                db.commit()
                user = db.query(User).filter_by(id=id).first()
                return None if user is None else UserModel.model_validate(user)
        except Exception:
            self._log.exception("Failed to update settings for user %s", id)
            return None

    def update_user_api_key_by_id(self, id: str, api_key: str) -> bool:
        """Store ``api_key`` for user ``id``; ``True`` iff exactly one row changed."""
        try:
            with get_db() as db:
                changed = db.query(User).filter_by(id=id).update({"api_key": api_key})
                db.commit()
                return changed == 1
        except Exception:
            self._log.exception("Failed to update API key for user %s", id)
            return False

    def update_user_profile(
        self,
        id: str,
        name: str,
        profile_image_url: str,
        role: str,
        permissions: Optional[dict] = None,
        subscription: Optional[List[str]] = None,
    ) -> int:
        """Overwrite name, avatar, role, permissions and subscription of user ``id``.

        ``permissions`` and ``subscription`` are always written, so omitting
        them stores ``None``. Returns the number of rows updated: ``1`` on
        success (``id`` is the primary key), ``0`` if the user does not exist
        or the update fails.
        """
        try:
            with get_db() as db:
                changed = db.query(User).filter_by(id=id).update(
                    {
                        "name": name,
                        "profile_image_url": profile_image_url,
                        "role": role,
                        "permissions": permissions,
                        "subscription": subscription,
                    }
                )
                db.commit()
                return changed
        except Exception:
            self._log.exception("Failed to update profile for user %s", id)
            return 0

    # ------------------------------------------------------------------
    # Delete
    # ------------------------------------------------------------------

    def delete_user_by_id(self, id: str) -> bool:
        """Remove user ``id`` from all groups, delete their chats, then the user.

        Returns ``False`` if chat deletion reports failure (the user row is
        kept, although group memberships were already removed) or on error.
        """
        try:
            Groups.remove_user_from_all_groups(id)
            if not Chats.delete_chats_by_user_id(id):
                return False
            with get_db() as db:
                db.query(User).filter_by(id=id).delete()
                db.commit()
            return True
        except Exception:
            self._log.exception("Failed to delete user %s", id)
            return False

    # ------------------------------------------------------------------
    # Avatar generation
    # ------------------------------------------------------------------

    @staticmethod
    def _load_font(font_source, font_size):
        """Open ``font_source`` at ``font_size``.

        ``font_source`` may be a path/filename, raw font ``bytes`` or an
        ``io.BytesIO``. In-memory data is wrapped in a fresh stream on every
        call so each load starts reading at offset 0.
        """
        if isinstance(font_source, io.BytesIO):
            font_source = font_source.getvalue()
        if isinstance(font_source, bytes):
            return ImageFont.truetype(io.BytesIO(font_source), font_size)
        return ImageFont.truetype(font_source, font_size)

    @staticmethod
    def _fit_font_size(image_size, text, font_source, target_percentage=0.5):
        """Return the largest font size at which ``text`` stays below
        ``target_percentage`` of the image's width and height.

        Sizes are tried upwards from 1 and capped at ``max(image_size)``, so
        text with an empty bounding box (e.g. ``""``) cannot loop forever.
        The result is at least 1 so the font can always be opened.
        """
        max_font_size = max(image_size)
        font_size = 1
        while font_size <= max_font_size:
            font = UsersTable._load_font(font_source, font_size)
            left, top, right, bottom = font.getbbox(text)
            if (
                right - left >= target_percentage * image_size[0]
                or bottom - top >= target_percentage * image_size[1]
            ):
                break
            font_size += 1
        return max(font_size - 1, 1)

    @staticmethod
    def _get_fallback_font_data() -> bytes:
        """Download the fallback web font once per process and return its bytes.

        Raises a ``requests`` exception on network failure, timeout or an HTTP
        error status, rather than passing an error page to the font loader.
        """
        if UsersTable._fallback_font_data is None:
            response = requests.get(
                UsersTable._FALLBACK_FONT_URL,
                timeout=UsersTable._FONT_DOWNLOAD_TIMEOUT,
            )
            response.raise_for_status()
            UsersTable._fallback_font_data = response.content
        return UsersTable._fallback_font_data

    @staticmethod
    def generate_image_base64(letter, size=200):
        """Render ``letter`` in white on a random-colour background.

        Args:
            letter: Text to draw (normally a single character).
            size: Either an ``int`` for a square canvas or a
                ``(width, height)`` pair.

        Returns:
            A ``"data:image/png;base64,..."`` data URI.

        The system font ``arial.ttf`` is tried first. If it cannot be opened,
        the fallback web font is used (downloaded once per process). The
        glyph's ink box is centred on the canvas.
        """
        if isinstance(size, int):
            size = (size, size)
        background_color = (
            random.randint(0, 255),
            random.randint(0, 255),
            random.randint(0, 255),
        )
        image = Image.new("RGB", size, color=background_color)
        draw = ImageDraw.Draw(image)

        try:
            font_source = "arial.ttf"
            font_size = UsersTable._fit_font_size(size, letter, font_source)
            font = UsersTable._load_font(font_source, font_size)
        except OSError:
            font_source = UsersTable._get_fallback_font_data()
            font_size = UsersTable._fit_font_size(size, letter, font_source)
            font = UsersTable._load_font(font_source, font_size)

        left, top, right, bottom = font.getbbox(letter)
        # getbbox and draw.text both use Pillow's default "la" anchor. Subtracting
        # the bbox origin therefore centres the visible glyph on any canvas size,
        # replacing a hard-coded "- 80" y offset that did not scale with size.
        x = (size[0] - (right - left)) // 2 - left
        y = (size[1] - (bottom - top)) // 2 - top
        draw.text((x, y), letter, font=font, fill=(255, 255, 255))

        buffered = io.BytesIO()
        image.save(buffered, format="PNG")
        encoded = base64.b64encode(buffered.getvalue()).decode()
        return f"data:image/png;base64,{encoded}"
# Module-level instance of UsersTable. The class defines no __init__, so a
# single shared instance is created at import time.
Users: UsersTable = UsersTable()