This commit is contained in:
l.gabrysiak 2025-02-28 19:58:02 +01:00
parent 2d37e5c858
commit 31c82973b2
4549 changed files with 281380 additions and 4685 deletions

1709
CHANGELOG.md Normal file

File diff suppressed because it is too large

99
CODE_OF_CONDUCT.md Normal file

@ -0,0 +1,99 @@
# Contributor Covenant Code of Conduct
## Our Pledge
As members, contributors, and leaders of this community, we pledge to make participation in our open-source project a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socioeconomic status, nationality, personal appearance, race, religion, or sexual identity and orientation.
We are committed to creating and maintaining an open, respectful, and professional environment where positive contributions and meaningful discussions can flourish. By participating in this project, you agree to uphold these values and align your behavior to the standards outlined in this Code of Conduct.
## Why These Standards Are Important
Open-source projects rely on a community of volunteers dedicating their time, expertise, and effort toward a shared goal. These projects are inherently collaborative but also fragile, as the success of the project depends on the goodwill, energy, and productivity of those involved.
Maintaining a positive and respectful environment is essential to safeguarding the integrity of this project and protecting contributors' efforts. Behavior that disrupts this atmosphere—whether through hostility, entitlement, or unprofessional conduct—can severely harm the morale and productivity of the community. **Strict enforcement of these standards ensures a safe and supportive space for meaningful collaboration.**
This is a community where **respect and professionalism are mandatory.** Violations of these standards will result in **zero tolerance** and immediate enforcement to prevent disruption and ensure the well-being of all participants.
## Our Standards
Examples of behavior that contribute to a positive and professional community include:
- **Respecting others.** Be considerate, listen actively, and engage with empathy toward others' viewpoints and experiences.
- **Constructive feedback.** Provide actionable, thoughtful, and respectful feedback that helps improve the project and encourages collaboration. Avoid unproductive negativity or hypercriticism.
- **Recognizing volunteer contributions.** Appreciate that contributors dedicate their free time and resources selflessly. Approach them with gratitude and patience.
- **Focusing on shared goals.** Collaborate in ways that prioritize the health, success, and sustainability of the community over individual agendas.
Examples of unacceptable behavior include:
- The use of discriminatory, demeaning, or sexualized language or behavior.
- Personal attacks, derogatory comments, trolling, or inflammatory political or ideological arguments.
- Harassment, intimidation, or any behavior intended to create a hostile, uncomfortable, or unsafe environment.
- Publishing others' private information (e.g., physical or email addresses) without explicit permission.
- **Entitlement, demand, or aggression toward contributors.** Volunteers are under no obligation to provide immediate or personalized support. Rude or dismissive behavior will not be tolerated.
- **Unproductive or destructive behavior.** This includes venting frustration as hostility ("tantrums"), hypercriticism, attention-seeking negativity, or anything that distracts from the project's goals.
- **Spamming and promotional exploitation.** Sharing irrelevant product promotions or self-promotion in the community is not allowed unless it directly contributes value to the discussion.
### Feedback and Community Engagement
- **Constructive feedback is encouraged, but hostile or entitled behavior will result in immediate action.** If you disagree with elements of the project, we encourage you to offer meaningful improvements or fork the project if necessary. Healthy discussions and technical disagreements are welcome only when handled with professionalism.
- **Respect contributors' time and efforts.** No one is entitled to personalized or on-demand assistance. This is a community built on collaboration and shared effort; demanding or demeaning behavior undermines that trust and will not be allowed.
### Zero Tolerance: No Warnings, Immediate Action
This community operates under a **zero-tolerance policy.** Any behavior deemed unacceptable under this Code of Conduct will result in **immediate enforcement, without prior warning.**
We employ this approach to ensure that unproductive or disruptive behavior does not escalate further or cause unnecessary harm to other contributors. The standards are clear, and violations of any kind—whether mild or severe—will be addressed decisively to protect the community.
## Enforcement Responsibilities
Community leaders are responsible for upholding and enforcing these standards. They are empowered to take **immediate and appropriate action** to address any behaviors they deem unacceptable under this Code of Conduct. These actions are taken with the goal of protecting the community and preserving its safe, positive, and productive environment.
## Scope
This Code of Conduct applies to all community spaces, including forums, repositories, social media accounts, and in-person events. It also applies when an individual represents the community in public settings, such as conferences or official communications.
Additionally, any behavior outside of these defined spaces that negatively impacts the community or its members may fall within the scope of this Code of Conduct.
## Reporting Violations
Instances of unacceptable behavior can be reported to the leadership team at **hello@openwebui.com**. Reports will be handled promptly, confidentially, and with consideration for the safety and well-being of the reporter.
All community leaders are required to uphold confidentiality and impartiality when addressing reports of violations.
## Enforcement Guidelines
### Ban
**Community Impact**: Community leaders will issue a ban to any participant whose behavior is deemed unacceptable according to this Code of Conduct. Bans are enforced immediately and without prior notice.
A ban may be temporary or permanent, depending on the severity of the violation. This includes—but is not limited to—behavior such as:
- Harassment or abusive behavior toward contributors.
- Persistent negativity or hostility that disrupts the collaborative environment.
- Disrespectful, demanding, or aggressive interactions with others.
- Attempts to cause harm or sabotage the community.
**Consequence**: A banned individual is immediately removed from access to all community spaces, communication channels, and events. Community leaders reserve the right to enforce either a time-limited suspension or a permanent ban based on the specific circumstances of the violation.
This approach ensures that disruptive behaviors are addressed swiftly and decisively in order to maintain the integrity and productivity of the community.
## Why Zero Tolerance Is Necessary
Open-source projects thrive on collaboration, goodwill, and mutual respect. Toxic behaviors—such as entitlement, hostility, or persistent negativity—threaten not just individual contributors but the health of the project as a whole. Allowing such behaviors to persist robs contributors of their time, energy, and enthusiasm for the work they do.
By enforcing a zero-tolerance policy, we ensure that the community remains a safe, welcoming space for all participants. These measures are not about harshness—they are about protecting contributors and fostering a productive environment where innovation can thrive.
Our expectations are clear, and our enforcement reflects our commitment to this project's long-term success.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 2.0, available at
https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
Community Impact Guidelines were inspired by [Mozilla's code of conduct enforcement ladder](https://github.com/mozilla/diversity).
[homepage]: https://www.contributor-covenant.org
For answers to common questions about this code of conduct, see the FAQ at
https://www.contributor-covenant.org/faq. Translations are available at
https://www.contributor-covenant.org/translations.

64
Caddyfile.localhost Normal file

@ -0,0 +1,64 @@
# Run with
# caddy run --envfile ./example.env --config ./Caddyfile.localhost
#
# This is configured for
# - Automatic HTTPS (even for localhost)
# - Reverse Proxying to Ollama API Base URL (http://localhost:11434/api)
# - CORS
# - HTTP Basic Auth API Tokens (uncomment basicauth section)
# CORS Preflight (OPTIONS) + Request (GET, POST, PATCH, PUT, DELETE)
(cors-api) {
@match-cors-api-preflight method OPTIONS
handle @match-cors-api-preflight {
header {
Access-Control-Allow-Origin "{http.request.header.origin}"
Access-Control-Allow-Methods "GET, POST, PUT, PATCH, DELETE, OPTIONS"
Access-Control-Allow-Headers "Origin, Accept, Authorization, Content-Type, X-Requested-With"
Access-Control-Allow-Credentials "true"
Access-Control-Max-Age "3600"
defer
}
respond "" 204
}
@match-cors-api-request {
not {
header Origin "{http.request.scheme}://{http.request.host}"
}
header Origin "{http.request.header.origin}"
}
handle @match-cors-api-request {
header {
Access-Control-Allow-Origin "{http.request.header.origin}"
Access-Control-Allow-Methods "GET, POST, PUT, PATCH, DELETE, OPTIONS"
Access-Control-Allow-Headers "Origin, Accept, Authorization, Content-Type, X-Requested-With"
Access-Control-Allow-Credentials "true"
Access-Control-Max-Age "3600"
defer
}
}
}
# replace localhost with example.com or whatever
localhost {
## HTTP Basic Auth
## (uncomment to enable)
# basicauth {
# # see .example.env for how to generate tokens
# {env.OLLAMA_API_ID} {env.OLLAMA_API_TOKEN_DIGEST}
# }
handle /api/* {
# Comment to disable CORS
import cors-api
reverse_proxy localhost:11434
}
# Same-Origin Static Web Server
file_server {
root ./build/
}
}
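
For reference, a hedged sketch of how the HTTP Basic Auth credentials referenced in the commented `basicauth` block above might be produced; the variable names come from the Caddyfile, while the token value and env-file layout are illustrative only:

```bash
# Hash a plaintext API token with Caddy's built-in password hasher
caddy hash-password --plaintext 'my-ollama-api-token'

# Put the ID and the printed bcrypt digest into the env file passed via --envfile:
#   OLLAMA_API_ID=ollama
#   OLLAMA_API_TOKEN_DIGEST=<output of the command above>
```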


@ -1,30 +1,176 @@
# Use the official Python image as the base image
FROM --platform=linux/amd64 python:3.9-slim
# syntax=docker/dockerfile:1
# Initialize device type args
# use build args in the docker build command with --build-arg="BUILDARG=true"
ARG USE_CUDA=false
ARG USE_OLLAMA=false
# Tested with cu117 for CUDA 11 and cu121 for CUDA 12 (default)
ARG USE_CUDA_VER=cu121
# any sentence transformer model; models to use can be found at https://huggingface.co/models?library=sentence-transformers
# Leaderboard: https://huggingface.co/spaces/mteb/leaderboard
# for better performance and multilingual support use "intfloat/multilingual-e5-large" (~2.5GB) or "intfloat/multilingual-e5-base" (~1.5GB)
# IMPORTANT: If you change the embedding model (default: sentence-transformers/all-MiniLM-L6-v2), RAG Chat will no longer work with documents previously loaded in the WebUI; you need to re-embed them.
ARG USE_EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
ARG USE_RERANKING_MODEL=""
# Tiktoken encoding name; models to use can be found at https://huggingface.co/models?library=tiktoken
ARG USE_TIKTOKEN_ENCODING_NAME="cl100k_base"
ARG BUILD_HASH=dev-build
# Override at your own risk - non-root configurations are untested
ARG UID=0
ARG GID=0
######## WebUI frontend ########
FROM --platform=$BUILDPLATFORM node:22-alpine3.20 AS build
ARG BUILD_HASH
# Set the working directory in the container
WORKDIR /app
# Install git and basic utilities
RUN apt-get update && apt-get install -y git nano wget curl iputils-ping
COPY package.json package-lock.json ./
RUN npm ci
# Copy the requirements files (if present) and install dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# Copy the requirements.txt file into the container
COPY requirements.txt .
# Install dependencies from requirements.txt
RUN pip install --no-cache-dir -r requirements.txt
# Install Tesseract OCR
RUN apt-get install -y tesseract-ocr
# Copy the source code into the container
COPY . .
COPY entrypoint.sh /entrypoint.sh
ENV APP_BUILD_HASH=${BUILD_HASH}
RUN npm run build
RUN chmod +x /entrypoint.sh
######## WebUI backend ########
FROM python:3.11-slim-bookworm AS base
# Run the application
ENTRYPOINT ["/entrypoint.sh"]
# Use args
ARG USE_CUDA
ARG USE_OLLAMA
ARG USE_CUDA_VER
ARG USE_EMBEDDING_MODEL
ARG USE_RERANKING_MODEL
ARG UID
ARG GID
## Basis ##
ENV ENV=prod \
PORT=8080 \
# pass build args to the build
USE_OLLAMA_DOCKER=${USE_OLLAMA} \
USE_CUDA_DOCKER=${USE_CUDA} \
USE_CUDA_DOCKER_VER=${USE_CUDA_VER} \
USE_EMBEDDING_MODEL_DOCKER=${USE_EMBEDDING_MODEL} \
USE_RERANKING_MODEL_DOCKER=${USE_RERANKING_MODEL}
## Basis URL Config ##
ENV OLLAMA_BASE_URL="/ollama" \
OPENAI_API_BASE_URL=""
## API Key and Security Config ##
ENV OPENAI_API_KEY="" \
WEBUI_SECRET_KEY="" \
SCARF_NO_ANALYTICS=true \
DO_NOT_TRACK=true \
ANONYMIZED_TELEMETRY=false
#### Other models #########################################################
## whisper TTS model settings ##
ENV WHISPER_MODEL="base" \
WHISPER_MODEL_DIR="/app/backend/data/cache/whisper/models"
## RAG Embedding model settings ##
ENV RAG_EMBEDDING_MODEL="$USE_EMBEDDING_MODEL_DOCKER" \
RAG_RERANKING_MODEL="$USE_RERANKING_MODEL_DOCKER" \
SENTENCE_TRANSFORMERS_HOME="/app/backend/data/cache/embedding/models"
## Tiktoken model settings ##
ENV TIKTOKEN_ENCODING_NAME="cl100k_base" \
TIKTOKEN_CACHE_DIR="/app/backend/data/cache/tiktoken"
## Hugging Face download cache ##
ENV HF_HOME="/app/backend/data/cache/embedding/models"
## Torch Extensions ##
# ENV TORCH_EXTENSIONS_DIR="/.cache/torch_extensions"
#### Other models ##########################################################
WORKDIR /app/backend
ENV HOME=/root
# Create user and group if not root
RUN if [ $UID -ne 0 ]; then \
if [ $GID -ne 0 ]; then \
addgroup --gid $GID app; \
fi; \
adduser --uid $UID --gid $GID --home $HOME --disabled-password --no-create-home app; \
fi
RUN mkdir -p $HOME/.cache/chroma
RUN echo -n 00000000-0000-0000-0000-000000000000 > $HOME/.cache/chroma/telemetry_user_id
# Make sure the user has access to the app and root directory
RUN chown -R $UID:$GID /app $HOME
RUN if [ "$USE_OLLAMA" = "true" ]; then \
apt-get update && \
# Install pandoc and netcat
apt-get install -y --no-install-recommends git build-essential pandoc netcat-openbsd curl && \
apt-get install -y --no-install-recommends gcc python3-dev && \
# for RAG OCR
apt-get install -y --no-install-recommends ffmpeg libsm6 libxext6 && \
# install helper tools
apt-get install -y --no-install-recommends curl jq && \
# install ollama
curl -fsSL https://ollama.com/install.sh | sh && \
# cleanup
rm -rf /var/lib/apt/lists/*; \
else \
apt-get update && \
# Install pandoc, netcat and gcc
apt-get install -y --no-install-recommends git build-essential pandoc gcc netcat-openbsd curl jq && \
apt-get install -y --no-install-recommends gcc python3-dev && \
# for RAG OCR
apt-get install -y --no-install-recommends ffmpeg libsm6 libxext6 && \
# cleanup
rm -rf /var/lib/apt/lists/*; \
fi
# install python dependencies
COPY --chown=$UID:$GID ./backend/requirements.txt ./requirements.txt
RUN pip3 install uv && \
if [ "$USE_CUDA" = "true" ]; then \
# If you use CUDA the whisper and embedding model will be downloaded on first use
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/$USE_CUDA_DOCKER_VER --no-cache-dir && \
uv pip install --system -r requirements.txt --no-cache-dir && \
python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \
python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \
python -c "import os; import tiktoken; tiktoken.get_encoding(os.environ['TIKTOKEN_ENCODING_NAME'])"; \
else \
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu --no-cache-dir && \
uv pip install --system -r requirements.txt --no-cache-dir && \
python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \
python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \
python -c "import os; import tiktoken; tiktoken.get_encoding(os.environ['TIKTOKEN_ENCODING_NAME'])"; \
fi; \
chown -R $UID:$GID /app/backend/data/
# copy embedding weight from build
# RUN mkdir -p /root/.cache/chroma/onnx_models/all-MiniLM-L6-v2
# COPY --from=build /app/onnx /root/.cache/chroma/onnx_models/all-MiniLM-L6-v2/onnx
# copy built frontend files
COPY --chown=$UID:$GID --from=build /app/build /app/build
COPY --chown=$UID:$GID --from=build /app/CHANGELOG.md /app/CHANGELOG.md
COPY --chown=$UID:$GID --from=build /app/package.json /app/package.json
# copy backend files
COPY --chown=$UID:$GID ./backend .
EXPOSE 8080
HEALTHCHECK CMD curl --silent --fail http://localhost:${PORT:-8080}/health | jq -ne 'input.status == true' || exit 1
USER $UID:$GID
ARG BUILD_HASH
ENV WEBUI_BUILD_VERSION=${BUILD_HASH}
ENV DOCKER=true
CMD [ "bash", "start.sh"]

35
INSTALLATION.md Normal file

@ -0,0 +1,35 @@
### Installing Both Ollama and Open WebUI Using Kustomize
For a CPU-only pod
```bash
kubectl apply -f ./kubernetes/manifest/base
```
For a GPU-enabled pod
```bash
kubectl apply -k ./kubernetes/manifest
```
### Installing Both Ollama and Open WebUI Using Helm
Package the Helm chart first
```bash
helm package ./kubernetes/helm/
```
For a CPU-only pod
```bash
helm install ollama-webui ./ollama-webui-*.tgz
```
For a GPU-enabled pod
```bash
helm install ollama-webui ./ollama-webui-*.tgz --set ollama.resources.limits.nvidia.com/gpu="1"
```
Check the `kubernetes/helm/values.yaml` file to see which parameters are available for customization.
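To override any of those parameters at install time, pass a custom values file (the file name below is only an example):
```bash
helm install ollama-webui ./ollama-webui-*.tgz -f my-values.yaml
```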

27
LICENSE Normal file

@ -0,0 +1,27 @@
Copyright (c) 2023-2025 Timothy Jaeryang Baek
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

33
Makefile Normal file

@ -0,0 +1,33 @@
ifneq ($(shell which docker-compose 2>/dev/null),)
DOCKER_COMPOSE := docker-compose
else
DOCKER_COMPOSE := docker compose
endif
install:
$(DOCKER_COMPOSE) up -d
remove:
@chmod +x confirm_remove.sh
@./confirm_remove.sh
start:
$(DOCKER_COMPOSE) start
startAndBuild:
$(DOCKER_COMPOSE) up -d --build
stop:
$(DOCKER_COMPOSE) stop
update:
# Calls the LLM update script
chmod +x update_ollama_models.sh
@./update_ollama_models.sh
@git pull
$(DOCKER_COMPOSE) down
# Make sure the ollama-webui container is stopped before rebuilding
@docker stop open-webui || true
$(DOCKER_COMPOSE) up --build -d
$(DOCKER_COMPOSE) start
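
A few typical invocations of the targets above, assuming Docker and Docker Compose are available:

```bash
make install   # start the stack in the background
make stop      # stop the running containers
make update    # pull model updates and the latest code, rebuild, and restart
```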

36
TROUBLESHOOTING.md Normal file

@ -0,0 +1,36 @@
# Open WebUI Troubleshooting Guide
## Understanding the Open WebUI Architecture
The Open WebUI system is designed to streamline interactions between the client (your browser) and the Ollama API. At the heart of this design is a backend reverse proxy, enhancing security and resolving CORS issues.
- **How it Works**: The Open WebUI is designed to interact with the Ollama API through a specific route. When a request is made from the WebUI to Ollama, it is not directly sent to the Ollama API. Initially, the request is sent to the Open WebUI backend via `/ollama` route. From there, the backend is responsible for forwarding the request to the Ollama API. This forwarding is accomplished by using the route specified in the `OLLAMA_BASE_URL` environment variable. Therefore, a request made to `/ollama` in the WebUI is effectively the same as making a request to `OLLAMA_BASE_URL` in the backend. For instance, a request to `/ollama/api/tags` in the WebUI is equivalent to `OLLAMA_BASE_URL/api/tags` in the backend.
- **Security Benefits**: This design prevents direct exposure of the Ollama API to the frontend, safeguarding against potential CORS (Cross-Origin Resource Sharing) issues and unauthorized access. Requiring authentication to access the Ollama API further enhances this security layer.
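For example, assuming the WebUI is reachable at `http://localhost:3000` and you have an Open WebUI API token (placeholder below), the following two requests should return the same model list:
```bash
# Through the Open WebUI backend proxy (authenticated)
curl -H "Authorization: Bearer YOUR_API_TOKEN" http://localhost:3000/ollama/api/tags

# Directly against the Ollama API that OLLAMA_BASE_URL points to
curl http://localhost:11434/api/tags
```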
## Open WebUI: Server Connection Error
If you're experiencing connection issues, it's often because the WebUI docker container cannot reach the Ollama server at 127.0.0.1:11434 (host.docker.internal:11434) from inside the container. Use the `--network=host` flag in your docker command to resolve this. Note that the port changes from 3000 to 8080, resulting in the link: `http://localhost:8080`.
**Example Docker Command**:
```bash
docker run -d --network=host -v open-webui:/app/backend/data -e OLLAMA_BASE_URL=http://127.0.0.1:11434 --name open-webui --restart always ghcr.io/open-webui/open-webui:main
```
### Error on Slow Responses for Ollama
Open WebUI has a default timeout of 5 minutes for Ollama to finish generating the response. If needed, this can be adjusted via the environment variable AIOHTTP_CLIENT_TIMEOUT, which sets the timeout in seconds.
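For example, to raise the timeout to 15 minutes, add the variable to your `docker run` command (only the `-e` flag is the point here; the other options mirror a typical setup):
```bash
docker run -d -p 3000:8080 \
  -e AIOHTTP_CLIENT_TIMEOUT=900 \
  -v open-webui:/app/backend/data \
  --name open-webui --restart always \
  ghcr.io/open-webui/open-webui:main
```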
### General Connection Errors
**Ensure Ollama Version is Up-to-Date**: Always start by checking that you have the latest version of Ollama. Visit [Ollama's official site](https://ollama.com/) for the latest updates.
**Troubleshooting Steps**:
1. **Verify Ollama URL Format**:
   - When running the Web UI container, ensure that `OLLAMA_BASE_URL` is correctly set (e.g., `http://192.168.1.1:11434` for a different-host setup).
- In the Open WebUI, navigate to "Settings" > "General".
   - Confirm that the Ollama Server URL is correctly set to `[OLLAMA URL]` (e.g., `http://localhost:11434`), and that the server actually responds, as in the quick check below.
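A quick way to confirm the configured URL responds from your host (adjust host and port to your setup):
```bash
curl http://localhost:11434/api/version   # should return Ollama's version as JSON
```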
By following these enhanced troubleshooting steps, connection issues should be effectively resolved. For further assistance or queries, feel free to reach out to us on our community Discord.


@ -1,120 +0,0 @@
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import torch
import numpy as np
from sentence_transformers import SentenceTransformer
from datasets import Dataset
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, TrainingArguments, Trainer, DataCollatorForSeq2Seq
import weaviate
from weaviate.client import WeaviateClient
from weaviate.connect import ConnectionParams
# 1. Initialize the embedding model
embed_model = SentenceTransformer("all-MiniLM-L6-v2")
# 2. Connect to Weaviate and fetch the documents
client = WeaviateClient(
connection_params=ConnectionParams.from_params(
http_host="weaviate",
http_port=8080,
http_secure=False,
grpc_host="weaviate",
grpc_port=50051,
grpc_secure=False,
)
)
collection_name = "Document"  # Assuming this is the name of your collection
result = (
client.query.get(collection_name, ["content"])
.with_additional(["id"])
.do()
)
documents = [item['content'] for item in result['data']['Get'][collection_name]]
# 3. Generate embeddings
embeddings = embed_model.encode(documents)
# 4. Prepare the training data
def create_training_data():
data = {
"text": documents,
"embedding": embeddings.tolist()
}
return Dataset.from_dict(data)
dataset = create_training_data()
# Split the data into training and evaluation sets
split_dataset = dataset.train_test_split(test_size=0.25)
train_dataset = split_dataset["train"]
eval_dataset = split_dataset["test"]
# 5. Load the allegro/multislav-5lang model
device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "allegro/multislav-5lang"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# 6. Configure LoRA
lora_config = LoraConfig(
r=8, lora_alpha=32, lora_dropout=0.1, bias="none", task_type="SEQ_2_SEQ_LM"
)
model = get_peft_model(model, lora_config)
# 7. Tokenize the data
max_length = 384
def tokenize_function(examples):
return tokenizer(
examples["text"],
padding="max_length",
truncation=True,
max_length=max_length
)
tokenized_train = train_dataset.map(tokenize_function, batched=True)
tokenized_eval = eval_dataset.map(tokenize_function, batched=True)
# 8. Training parameters
training_args = TrainingArguments(
output_dir="./results",
eval_strategy="steps",
eval_steps=500,
save_strategy="steps",
save_steps=500,
learning_rate=1e-5,
per_device_train_batch_size=2,
per_device_eval_batch_size=2,
num_train_epochs=16,
weight_decay=0.01,
load_best_model_at_end=True,
metric_for_best_model="loss",
greater_is_better=False,
)
# 9. Data collator
data_collator = DataCollatorForSeq2Seq(
tokenizer=tokenizer,
model=model
)
# 10. Train the model
trainer = Trainer(
model=model,
args=training_args,
train_dataset=tokenized_train,
eval_dataset=tokenized_eval,
data_collator=data_collator,
)
trainer.train()
# 11. Save the model
model.save_pretrained("./models/allegro")
tokenizer.save_pretrained("./models/allegro")
print("✅ Model trained and saved!")

14
backend/.dockerignore Normal file

@ -0,0 +1,14 @@
__pycache__
.env
_old
uploads
.ipynb_checkpoints
*.db
_test
!/data
/data/*
!/data/litellm
/data/litellm/*
!data/litellm/config.yaml
!data/config.json

12
backend/.gitignore vendored Normal file

@ -0,0 +1,12 @@
__pycache__
.env
_old
uploads
.ipynb_checkpoints
*.db
_test
Pipfile
!/data
/data/*
/open_webui/data/*
.webui_secret_key

2
backend/dev.sh Executable file

@ -0,0 +1,2 @@
PORT="${PORT:-8080}"
uvicorn open_webui.main:app --port $PORT --host 0.0.0.0 --forwarded-allow-ips '*' --reload


@ -0,0 +1,96 @@
import base64
import os
import random
from pathlib import Path
import typer
import uvicorn
from typing import Optional
from typing_extensions import Annotated
app = typer.Typer()
KEY_FILE = Path.cwd() / ".webui_secret_key"
def version_callback(value: bool):
if value:
from open_webui.env import VERSION
typer.echo(f"Open WebUI version: {VERSION}")
raise typer.Exit()
@app.command()
def main(
version: Annotated[
Optional[bool], typer.Option("--version", callback=version_callback)
] = None,
):
pass
@app.command()
def serve(
host: str = "0.0.0.0",
port: int = 8080,
):
os.environ["FROM_INIT_PY"] = "true"
if os.getenv("WEBUI_SECRET_KEY") is None:
typer.echo(
"Loading WEBUI_SECRET_KEY from file, not provided as an environment variable."
)
if not KEY_FILE.exists():
typer.echo(f"Generating a new secret key and saving it to {KEY_FILE}")
KEY_FILE.write_bytes(base64.b64encode(random.randbytes(12)))
typer.echo(f"Loading WEBUI_SECRET_KEY from {KEY_FILE}")
os.environ["WEBUI_SECRET_KEY"] = KEY_FILE.read_text()
if os.getenv("USE_CUDA_DOCKER", "false") == "true":
typer.echo(
"CUDA is enabled, appending LD_LIBRARY_PATH to include torch/cudnn & cublas libraries."
)
LD_LIBRARY_PATH = os.getenv("LD_LIBRARY_PATH", "").split(":")
os.environ["LD_LIBRARY_PATH"] = ":".join(
LD_LIBRARY_PATH
+ [
"/usr/local/lib/python3.11/site-packages/torch/lib",
"/usr/local/lib/python3.11/site-packages/nvidia/cudnn/lib",
]
)
try:
import torch
assert torch.cuda.is_available(), "CUDA not available"
typer.echo("CUDA seems to be working")
except Exception as e:
typer.echo(
"Error when testing CUDA but USE_CUDA_DOCKER is true. "
"Resetting USE_CUDA_DOCKER to false and removing "
f"LD_LIBRARY_PATH modifications: {e}"
)
os.environ["USE_CUDA_DOCKER"] = "false"
os.environ["LD_LIBRARY_PATH"] = ":".join(LD_LIBRARY_PATH)
import open_webui.main # we need set environment variables before importing main
uvicorn.run(open_webui.main.app, host=host, port=port, forwarded_allow_ips="*")
@app.command()
def dev(
host: str = "0.0.0.0",
port: int = 8080,
reload: bool = True,
):
uvicorn.run(
"open_webui.main:app",
host=host,
port=port,
reload=reload,
forwarded_allow_ips="*",
)
if __name__ == "__main__":
app()
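
Assuming this Typer app is exposed as an `open-webui` console script (an assumption; the entry-point name is not shown in this diff), typical invocations would look roughly like:

```bash
open-webui --version          # print the installed version
open-webui serve --port 8080  # run the server
open-webui dev --port 8080    # run with auto-reload for development
```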


@ -0,0 +1,114 @@
# A generic, single database configuration.
[alembic]
# path to migration scripts
script_location = migrations
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory.
prepend_sys_path = .
# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the python>=3.9 or backports.zoneinfo library.
# Any required deps can installed by adding `alembic[tz]` to the pip requirements
# string value is passed to ZoneInfo()
# leave blank for localtime
# timezone =
# max length of characters to apply to the
# "slug" field
# truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version location specification; This defaults
# to migrations/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "version_path_separator" below.
# version_locations = %(here)s/bar:%(here)s/bat:migrations/versions
# version path separator; As mentioned above, this is the character used to split
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
# Valid values for version_path_separator are:
#
# version_path_separator = :
# version_path_separator = ;
# version_path_separator = space
version_path_separator = os # Use os.pathsep. Default configuration used for new projects.
# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false
# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8
# sqlalchemy.url = REPLACE_WITH_DATABASE_URL
[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples
# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME
# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
# hooks = ruff
# ruff.type = exec
# ruff.executable = %(here)s/.venv/bin/ruff
# ruff.options = --fix REVISION_SCRIPT_FILENAME
# Logging configuration
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARN
handlers = console
qualname =
[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S
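
With `script_location = migrations`, Alembic would typically be invoked from the directory containing this `alembic.ini`, with the migration environment supplying the database URL since `sqlalchemy.url` is commented out above; a rough sketch:

```bash
alembic upgrade head                                  # apply all pending migrations
alembic revision --autogenerate -m "describe change"  # generate a new revision (message is illustrative)
```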

2436
backend/open_webui/config.py Normal file

File diff suppressed because it is too large


@ -0,0 +1,119 @@
from enum import Enum
class MESSAGES(str, Enum):
DEFAULT = lambda msg="": f"{msg if msg else ''}"
MODEL_ADDED = lambda model="": f"The model '{model}' has been added successfully."
MODEL_DELETED = (
lambda model="": f"The model '{model}' has been deleted successfully."
)
class WEBHOOK_MESSAGES(str, Enum):
DEFAULT = lambda msg="": f"{msg if msg else ''}"
USER_SIGNUP = lambda username="": (
f"New user signed up: {username}" if username else "New user signed up"
)
class ERROR_MESSAGES(str, Enum):
def __str__(self) -> str:
return super().__str__()
DEFAULT = (
lambda err="": f'{"Something went wrong :/" if err == "" else "[ERROR: " + str(err) + "]"}'
)
ENV_VAR_NOT_FOUND = "Required environment variable not found. Terminating now."
CREATE_USER_ERROR = "Oops! Something went wrong while creating your account. Please try again later. If the issue persists, contact support for assistance."
DELETE_USER_ERROR = "Oops! Something went wrong. We encountered an issue while trying to delete the user. Please give it another shot."
EMAIL_MISMATCH = "Uh-oh! This email does not match the email your provider is registered with. Please check your email and try again."
EMAIL_TAKEN = "Uh-oh! This email is already registered. Sign in with your existing account or choose another email to start anew."
USERNAME_TAKEN = (
"Uh-oh! This username is already registered. Please choose another username."
)
COMMAND_TAKEN = "Uh-oh! This command is already registered. Please choose another command string."
FILE_EXISTS = "Uh-oh! This file is already registered. Please choose another file."
ID_TAKEN = "Uh-oh! This id is already registered. Please choose another id string."
MODEL_ID_TAKEN = "Uh-oh! This model id is already registered. Please choose another model id string."
NAME_TAG_TAKEN = "Uh-oh! This name tag is already registered. Please choose another name tag string."
INVALID_TOKEN = (
"Your session has expired or the token is invalid. Please sign in again."
)
INVALID_CRED = "The email or password provided is incorrect. Please check for typos and try logging in again."
INVALID_EMAIL_FORMAT = "The email format you entered is invalid. Please double-check and make sure you're using a valid email address (e.g., yourname@example.com)."
INVALID_PASSWORD = (
"The password provided is incorrect. Please check for typos and try again."
)
INVALID_TRUSTED_HEADER = "Your provider has not provided a trusted header. Please contact your administrator for assistance."
EXISTING_USERS = "You can't turn off authentication because there are existing users. If you want to disable WEBUI_AUTH, make sure your web interface doesn't have any existing users and is a fresh installation."
UNAUTHORIZED = "401 Unauthorized"
ACCESS_PROHIBITED = "You do not have permission to access this resource. Please contact your administrator for assistance."
ACTION_PROHIBITED = (
"The requested action has been restricted as a security measure."
)
FILE_NOT_SENT = "FILE_NOT_SENT"
FILE_NOT_SUPPORTED = "Oops! It seems like the file format you're trying to upload is not supported. Please upload a file with a supported format and try again."
NOT_FOUND = "We could not find what you're looking for :/"
USER_NOT_FOUND = "We could not find what you're looking for :/"
API_KEY_NOT_FOUND = "Oops! It looks like there's a hiccup. The API key is missing. Please make sure to provide a valid API key to access this feature."
API_KEY_NOT_ALLOWED = "Use of API key is not enabled in the environment."
MALICIOUS = "Unusual activities detected, please try again in a few minutes."
PANDOC_NOT_INSTALLED = "Pandoc is not installed on the server. Please contact your administrator for assistance."
INCORRECT_FORMAT = (
lambda err="": f"Invalid format. Please use the correct format{err}"
)
RATE_LIMIT_EXCEEDED = "API rate limit exceeded"
MODEL_NOT_FOUND = lambda name="": f"Model '{name}' was not found"
OPENAI_NOT_FOUND = lambda name="": "OpenAI API was not found"
OLLAMA_NOT_FOUND = "WebUI could not connect to Ollama"
CREATE_API_KEY_ERROR = "Oops! Something went wrong while creating your API key. Please try again later. If the issue persists, contact support for assistance."
API_KEY_CREATION_NOT_ALLOWED = "API key creation is not allowed in the environment."
EMPTY_CONTENT = "The content provided is empty. Please ensure that there is text or data present before proceeding."
DB_NOT_SQLITE = "This feature is only available when running with SQLite databases."
INVALID_URL = (
"Oops! The URL you provided is invalid. Please double-check and try again."
)
WEB_SEARCH_ERROR = (
lambda err="": f"{err if err else 'Oops! Something went wrong while searching the web.'}"
)
OLLAMA_API_DISABLED = (
"The Ollama API is disabled. Please enable it to use this feature."
)
FILE_TOO_LARGE = (
lambda size="": f"Oops! The file you're trying to upload is too large. Please upload a file that is less than {size}."
)
DUPLICATE_CONTENT = (
"Duplicate content detected. Please provide unique content to proceed."
)
FILE_NOT_PROCESSED = "Extracted content is not available for this file. Please ensure that the file is processed before proceeding."
class TASKS(str, Enum):
def __str__(self) -> str:
return super().__str__()
DEFAULT = lambda task="": f"{task if task else 'generation'}"
TITLE_GENERATION = "title_generation"
TAGS_GENERATION = "tags_generation"
EMOJI_GENERATION = "emoji_generation"
QUERY_GENERATION = "query_generation"
IMAGE_PROMPT_GENERATION = "image_prompt_generation"
AUTOCOMPLETE_GENERATION = "autocomplete_generation"
FUNCTION_CALLING = "function_calling"
MOA_RESPONSE_GENERATION = "moa_response_generation"

443
backend/open_webui/env.py Normal file

@ -0,0 +1,443 @@
import importlib.metadata
import json
import logging
import os
import pkgutil
import sys
import shutil
from pathlib import Path
import markdown
from bs4 import BeautifulSoup
from open_webui.constants import ERROR_MESSAGES
####################################
# Load .env file
####################################
OPEN_WEBUI_DIR = Path(__file__).parent # the path containing this file
print(OPEN_WEBUI_DIR)
BACKEND_DIR = OPEN_WEBUI_DIR.parent  # the path containing the open_webui package
BASE_DIR = BACKEND_DIR.parent # the path containing the backend/
print(BACKEND_DIR)
print(BASE_DIR)
try:
from dotenv import find_dotenv, load_dotenv
load_dotenv(find_dotenv(str(BASE_DIR / ".env")))
except ImportError:
print("dotenv not installed, skipping...")
DOCKER = os.environ.get("DOCKER", "False").lower() == "true"
# Device type for embedding models - "cpu" (default), "cuda" (NVIDIA GPU required) or "mps" (Apple Silicon); choosing the right one can significantly improve performance
USE_CUDA = os.environ.get("USE_CUDA_DOCKER", "false")
if USE_CUDA.lower() == "true":
try:
import torch
assert torch.cuda.is_available(), "CUDA not available"
DEVICE_TYPE = "cuda"
except Exception as e:
cuda_error = (
"Error when testing CUDA but USE_CUDA_DOCKER is true. "
f"Resetting USE_CUDA_DOCKER to false: {e}"
)
os.environ["USE_CUDA_DOCKER"] = "false"
USE_CUDA = "false"
DEVICE_TYPE = "cpu"
else:
DEVICE_TYPE = "cpu"
try:
import torch
if torch.backends.mps.is_available() and torch.backends.mps.is_built():
DEVICE_TYPE = "mps"
except Exception:
pass
####################################
# LOGGING
####################################
log_levels = ["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"]
GLOBAL_LOG_LEVEL = os.environ.get("GLOBAL_LOG_LEVEL", "").upper()
if GLOBAL_LOG_LEVEL in log_levels:
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL, force=True)
else:
GLOBAL_LOG_LEVEL = "INFO"
log = logging.getLogger(__name__)
log.info(f"GLOBAL_LOG_LEVEL: {GLOBAL_LOG_LEVEL}")
if "cuda_error" in locals():
log.exception(cuda_error)
log_sources = [
"AUDIO",
"COMFYUI",
"CONFIG",
"DB",
"IMAGES",
"MAIN",
"MODELS",
"OLLAMA",
"OPENAI",
"RAG",
"WEBHOOK",
"SOCKET",
"OAUTH",
]
SRC_LOG_LEVELS = {}
for source in log_sources:
log_env_var = source + "_LOG_LEVEL"
SRC_LOG_LEVELS[source] = os.environ.get(log_env_var, "").upper()
if SRC_LOG_LEVELS[source] not in log_levels:
SRC_LOG_LEVELS[source] = GLOBAL_LOG_LEVEL
log.info(f"{log_env_var}: {SRC_LOG_LEVELS[source]}")
log.setLevel(SRC_LOG_LEVELS["CONFIG"])
WEBUI_NAME = os.environ.get("WEBUI_NAME", "Open WebUI")
if WEBUI_NAME != "Open WebUI":
WEBUI_NAME += " (Open WebUI)"
WEBUI_FAVICON_URL = "https://openwebui.com/favicon.png"
TRUSTED_SIGNATURE_KEY = os.environ.get("TRUSTED_SIGNATURE_KEY", "")
####################################
# ENV (dev,test,prod)
####################################
ENV = os.environ.get("ENV", "dev")
FROM_INIT_PY = os.environ.get("FROM_INIT_PY", "False").lower() == "true"
if FROM_INIT_PY:
PACKAGE_DATA = {"version": importlib.metadata.version("open-webui")}
else:
try:
PACKAGE_DATA = json.loads((BASE_DIR / "package.json").read_text())
except Exception:
PACKAGE_DATA = {"version": "0.0.0"}
VERSION = PACKAGE_DATA["version"]
# Function to parse each section
def parse_section(section):
items = []
for li in section.find_all("li"):
# Extract raw HTML string
raw_html = str(li)
# Extract text without HTML tags
text = li.get_text(separator=" ", strip=True)
# Split into title and content
parts = text.split(": ", 1)
title = parts[0].strip() if len(parts) > 1 else ""
content = parts[1].strip() if len(parts) > 1 else text
items.append({"title": title, "content": content, "raw": raw_html})
return items
try:
changelog_path = BASE_DIR / "CHANGELOG.md"
with open(str(changelog_path.absolute()), "r", encoding="utf8") as file:
changelog_content = file.read()
except Exception:
changelog_content = (pkgutil.get_data("open_webui", "CHANGELOG.md") or b"").decode()
# Convert markdown content to HTML
html_content = markdown.markdown(changelog_content)
# Parse the HTML content
soup = BeautifulSoup(html_content, "html.parser")
# Initialize JSON structure
changelog_json = {}
# Iterate over each version
for version in soup.find_all("h2"):
version_number = version.get_text().strip().split(" - ")[0][1:-1] # Remove brackets
date = version.get_text().strip().split(" - ")[1]
version_data = {"date": date}
# Find the next sibling that is a h3 tag (section title)
current = version.find_next_sibling()
while current and current.name != "h2":
if current.name == "h3":
section_title = current.get_text().lower() # e.g., "added", "fixed"
section_items = parse_section(current.find_next_sibling("ul"))
version_data[section_title] = section_items
# Move to the next element
current = current.find_next_sibling()
changelog_json[version_number] = version_data
CHANGELOG = changelog_json
####################################
# SAFE_MODE
####################################
SAFE_MODE = os.environ.get("SAFE_MODE", "false").lower() == "true"
####################################
# ENABLE_FORWARD_USER_INFO_HEADERS
####################################
ENABLE_FORWARD_USER_INFO_HEADERS = (
os.environ.get("ENABLE_FORWARD_USER_INFO_HEADERS", "False").lower() == "true"
)
####################################
# WEBUI_BUILD_HASH
####################################
WEBUI_BUILD_HASH = os.environ.get("WEBUI_BUILD_HASH", "dev-build")
####################################
# DATA/FRONTEND BUILD DIR
####################################
DATA_DIR = Path(os.getenv("DATA_DIR", BACKEND_DIR / "data")).resolve()
if FROM_INIT_PY:
NEW_DATA_DIR = Path(os.getenv("DATA_DIR", OPEN_WEBUI_DIR / "data")).resolve()
NEW_DATA_DIR.mkdir(parents=True, exist_ok=True)
# Check if the data directory exists in the package directory
if DATA_DIR.exists() and DATA_DIR != NEW_DATA_DIR:
log.info(f"Moving {DATA_DIR} to {NEW_DATA_DIR}")
for item in DATA_DIR.iterdir():
dest = NEW_DATA_DIR / item.name
if item.is_dir():
shutil.copytree(item, dest, dirs_exist_ok=True)
else:
shutil.copy2(item, dest)
# Zip the data directory
shutil.make_archive(DATA_DIR.parent / "open_webui_data", "zip", DATA_DIR)
# Remove the old data directory
shutil.rmtree(DATA_DIR)
DATA_DIR = Path(os.getenv("DATA_DIR", OPEN_WEBUI_DIR / "data"))
STATIC_DIR = Path(os.getenv("STATIC_DIR", OPEN_WEBUI_DIR / "static"))
FONTS_DIR = Path(os.getenv("FONTS_DIR", OPEN_WEBUI_DIR / "static" / "fonts"))
FRONTEND_BUILD_DIR = Path(os.getenv("FRONTEND_BUILD_DIR", BASE_DIR / "build")).resolve()
if FROM_INIT_PY:
FRONTEND_BUILD_DIR = Path(
os.getenv("FRONTEND_BUILD_DIR", OPEN_WEBUI_DIR / "frontend")
).resolve()
####################################
# Database
####################################
# Check if the file exists
if os.path.exists(f"{DATA_DIR}/ollama.db"):
# Rename the file
os.rename(f"{DATA_DIR}/ollama.db", f"{DATA_DIR}/webui.db")
log.info("Database migrated from Ollama-WebUI successfully.")
else:
pass
DATABASE_URL = os.environ.get("DATABASE_URL", f"sqlite:///{DATA_DIR}/webui.db")
# Replace the postgres:// with postgresql://
if "postgres://" in DATABASE_URL:
DATABASE_URL = DATABASE_URL.replace("postgres://", "postgresql://")
DATABASE_SCHEMA = os.environ.get("DATABASE_SCHEMA", None)
DATABASE_POOL_SIZE = os.environ.get("DATABASE_POOL_SIZE", 0)
if DATABASE_POOL_SIZE == "":
DATABASE_POOL_SIZE = 0
else:
try:
DATABASE_POOL_SIZE = int(DATABASE_POOL_SIZE)
except Exception:
DATABASE_POOL_SIZE = 0
DATABASE_POOL_MAX_OVERFLOW = os.environ.get("DATABASE_POOL_MAX_OVERFLOW", 0)
if DATABASE_POOL_MAX_OVERFLOW == "":
DATABASE_POOL_MAX_OVERFLOW = 0
else:
try:
DATABASE_POOL_MAX_OVERFLOW = int(DATABASE_POOL_MAX_OVERFLOW)
except Exception:
DATABASE_POOL_MAX_OVERFLOW = 0
DATABASE_POOL_TIMEOUT = os.environ.get("DATABASE_POOL_TIMEOUT", 30)
if DATABASE_POOL_TIMEOUT == "":
DATABASE_POOL_TIMEOUT = 30
else:
try:
DATABASE_POOL_TIMEOUT = int(DATABASE_POOL_TIMEOUT)
except Exception:
DATABASE_POOL_TIMEOUT = 30
DATABASE_POOL_RECYCLE = os.environ.get("DATABASE_POOL_RECYCLE", 3600)
if DATABASE_POOL_RECYCLE == "":
DATABASE_POOL_RECYCLE = 3600
else:
try:
DATABASE_POOL_RECYCLE = int(DATABASE_POOL_RECYCLE)
except Exception:
DATABASE_POOL_RECYCLE = 3600
RESET_CONFIG_ON_START = (
os.environ.get("RESET_CONFIG_ON_START", "False").lower() == "true"
)
ENABLE_REALTIME_CHAT_SAVE = (
os.environ.get("ENABLE_REALTIME_CHAT_SAVE", "False").lower() == "true"
)
####################################
# REDIS
####################################
REDIS_URL = os.environ.get("REDIS_URL", "redis://localhost:6379/0")
####################################
# WEBUI_AUTH (Required for security)
####################################
WEBUI_AUTH = os.environ.get("WEBUI_AUTH", "True").lower() == "true"
WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get(
"WEBUI_AUTH_TRUSTED_EMAIL_HEADER", None
)
WEBUI_AUTH_TRUSTED_NAME_HEADER = os.environ.get("WEBUI_AUTH_TRUSTED_NAME_HEADER", None)
BYPASS_MODEL_ACCESS_CONTROL = (
os.environ.get("BYPASS_MODEL_ACCESS_CONTROL", "False").lower() == "true"
)
####################################
# WEBUI_SECRET_KEY
####################################
WEBUI_SECRET_KEY = os.environ.get(
"WEBUI_SECRET_KEY",
os.environ.get(
"WEBUI_JWT_SECRET_KEY", "t0p-s3cr3t"
), # DEPRECATED: remove at next major version
)
WEBUI_SESSION_COOKIE_SAME_SITE = os.environ.get("WEBUI_SESSION_COOKIE_SAME_SITE", "lax")
WEBUI_SESSION_COOKIE_SECURE = (
os.environ.get("WEBUI_SESSION_COOKIE_SECURE", "false").lower() == "true"
)
WEBUI_AUTH_COOKIE_SAME_SITE = os.environ.get(
"WEBUI_AUTH_COOKIE_SAME_SITE", WEBUI_SESSION_COOKIE_SAME_SITE
)
WEBUI_AUTH_COOKIE_SECURE = (
os.environ.get(
"WEBUI_AUTH_COOKIE_SECURE",
os.environ.get("WEBUI_SESSION_COOKIE_SECURE", "false"),
).lower()
== "true"
)
if WEBUI_AUTH and WEBUI_SECRET_KEY == "":
raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND)
ENABLE_WEBSOCKET_SUPPORT = (
os.environ.get("ENABLE_WEBSOCKET_SUPPORT", "True").lower() == "true"
)
WEBSOCKET_MANAGER = os.environ.get("WEBSOCKET_MANAGER", "")
WEBSOCKET_REDIS_URL = os.environ.get("WEBSOCKET_REDIS_URL", REDIS_URL)
AIOHTTP_CLIENT_TIMEOUT = os.environ.get("AIOHTTP_CLIENT_TIMEOUT", "")
if AIOHTTP_CLIENT_TIMEOUT == "":
AIOHTTP_CLIENT_TIMEOUT = None
else:
try:
AIOHTTP_CLIENT_TIMEOUT = int(AIOHTTP_CLIENT_TIMEOUT)
except Exception:
AIOHTTP_CLIENT_TIMEOUT = 300
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = os.environ.get(
"AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST", ""
)
if AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST == "":
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = None
else:
try:
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = int(
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST
)
except Exception:
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = 5
####################################
# OFFLINE_MODE
####################################
OFFLINE_MODE = os.environ.get("OFFLINE_MODE", "false").lower() == "true"
if OFFLINE_MODE:
os.environ["HF_HUB_OFFLINE"] = "1"
####################################
# AUDIT LOGGING
####################################
ENABLE_AUDIT_LOGS = os.getenv("ENABLE_AUDIT_LOGS", "false").lower() == "true"
# Where to store log file
AUDIT_LOGS_FILE_PATH = f"{DATA_DIR}/audit.log"
# Maximum size of a file before rotating into a new log file
AUDIT_LOG_FILE_ROTATION_SIZE = os.getenv("AUDIT_LOG_FILE_ROTATION_SIZE", "10MB")
# METADATA | REQUEST | REQUEST_RESPONSE
AUDIT_LOG_LEVEL = os.getenv("AUDIT_LOG_LEVEL", "REQUEST_RESPONSE").upper()
try:
MAX_BODY_LOG_SIZE = int(os.environ.get("MAX_BODY_LOG_SIZE") or 2048)
except ValueError:
MAX_BODY_LOG_SIZE = 2048
# Comma-separated list of URL paths to exclude from audit logging
AUDIT_EXCLUDED_PATHS = os.getenv("AUDIT_EXCLUDED_PATHS", "/chats,/chat,/folders").split(
","
)
AUDIT_EXCLUDED_PATHS = [path.strip() for path in AUDIT_EXCLUDED_PATHS]
AUDIT_EXCLUDED_PATHS = [path.lstrip("/") for path in AUDIT_EXCLUDED_PATHS]
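
A minimal sketch of a `.env` file exercising a few of the variables parsed above; every value is illustrative, and anything omitted falls back to the defaults shown in the code:

```bash
# Loaded from BASE_DIR/.env via python-dotenv, if installed
GLOBAL_LOG_LEVEL=DEBUG
WEBUI_SECRET_KEY=change-me-to-a-long-random-string
DATABASE_URL=sqlite:///data/webui.db
ENABLE_AUDIT_LOGS=true
AIOHTTP_CLIENT_TIMEOUT=300
```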


@ -0,0 +1,319 @@
import logging
import sys
import inspect
import json
import asyncio
from pydantic import BaseModel
from typing import AsyncGenerator, Generator, Iterator
from fastapi import (
Depends,
FastAPI,
File,
Form,
HTTPException,
Request,
UploadFile,
status,
)
from starlette.responses import Response, StreamingResponse
from open_webui.socket.main import (
get_event_call,
get_event_emitter,
)
from open_webui.models.functions import Functions
from open_webui.models.models import Models
from open_webui.utils.plugin import load_function_module_by_id
from open_webui.utils.tools import get_tools
from open_webui.utils.access_control import has_access
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
from open_webui.utils.misc import (
add_or_update_system_message,
get_last_user_message,
prepend_to_first_user_message_content,
openai_chat_chunk_message_template,
openai_chat_completion_message_template,
)
from open_webui.utils.payload import (
apply_model_params_to_body_openai,
apply_model_system_prompt_to_body,
)
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
def get_function_module_by_id(request: Request, pipe_id: str):
# Check if function is already loaded
if pipe_id not in request.app.state.FUNCTIONS:
function_module, _, _ = load_function_module_by_id(pipe_id)
request.app.state.FUNCTIONS[pipe_id] = function_module
else:
function_module = request.app.state.FUNCTIONS[pipe_id]
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
valves = Functions.get_function_valves_by_id(pipe_id)
function_module.valves = function_module.Valves(**(valves if valves else {}))
return function_module
async def get_function_models(request):
pipes = Functions.get_functions_by_type("pipe", active_only=True)
pipe_models = []
for pipe in pipes:
function_module = get_function_module_by_id(request, pipe.id)
# Check if function is a manifold
if hasattr(function_module, "pipes"):
sub_pipes = []
# Handle pipes being a list, sync function, or async function
try:
if callable(function_module.pipes):
if asyncio.iscoroutinefunction(function_module.pipes):
sub_pipes = await function_module.pipes()
else:
sub_pipes = function_module.pipes()
else:
sub_pipes = function_module.pipes
except Exception as e:
log.exception(e)
sub_pipes = []
log.debug(
f"get_function_models: function '{pipe.id}' is a manifold of {sub_pipes}"
)
for p in sub_pipes:
sub_pipe_id = f'{pipe.id}.{p["id"]}'
sub_pipe_name = p["name"]
if hasattr(function_module, "name"):
sub_pipe_name = f"{function_module.name}{sub_pipe_name}"
pipe_flag = {"type": pipe.type}
pipe_models.append(
{
"id": sub_pipe_id,
"name": sub_pipe_name,
"object": "model",
"created": pipe.created_at,
"owned_by": "openai",
"pipe": pipe_flag,
}
)
else:
pipe_flag = {"type": "pipe"}
log.debug(
f"get_function_models: function '{pipe.id}' is a single pipe {{ 'id': {pipe.id}, 'name': {pipe.name} }}"
)
pipe_models.append(
{
"id": pipe.id,
"name": pipe.name,
"object": "model",
"created": pipe.created_at,
"owned_by": "openai",
"pipe": pipe_flag,
}
)
return pipe_models
async def generate_function_chat_completion(
request, form_data, user, models: dict = {}
):
async def execute_pipe(pipe, params):
if inspect.iscoroutinefunction(pipe):
return await pipe(**params)
else:
return pipe(**params)
async def get_message_content(res: str | Generator | AsyncGenerator) -> str:
if isinstance(res, str):
return res
if isinstance(res, Generator):
return "".join(map(str, res))
if isinstance(res, AsyncGenerator):
return "".join([str(stream) async for stream in res])
def process_line(form_data: dict, line):
if isinstance(line, BaseModel):
line = line.model_dump_json()
line = f"data: {line}"
if isinstance(line, dict):
line = f"data: {json.dumps(line)}"
try:
line = line.decode("utf-8")
except Exception:
pass
if line.startswith("data:"):
return f"{line}\n\n"
else:
line = openai_chat_chunk_message_template(form_data["model"], line)
return f"data: {json.dumps(line)}\n\n"
def get_pipe_id(form_data: dict) -> str:
pipe_id = form_data["model"]
if "." in pipe_id:
pipe_id, _ = pipe_id.split(".", 1)
return pipe_id
def get_function_params(function_module, form_data, user, extra_params=None):
if extra_params is None:
extra_params = {}
pipe_id = get_pipe_id(form_data)
# Get the signature of the function
sig = inspect.signature(function_module.pipe)
params = {"body": form_data} | {
k: v for k, v in extra_params.items() if k in sig.parameters
}
if "__user__" in params and hasattr(function_module, "UserValves"):
user_valves = Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
try:
params["__user__"]["valves"] = function_module.UserValves(**user_valves)
except Exception as e:
log.exception(e)
params["__user__"]["valves"] = function_module.UserValves()
return params
model_id = form_data.get("model")
model_info = Models.get_model_by_id(model_id)
metadata = form_data.pop("metadata", {})
files = metadata.get("files", [])
tool_ids = metadata.get("tool_ids", [])
# Check if tool_ids is None
if tool_ids is None:
tool_ids = []
__event_emitter__ = None
__event_call__ = None
__task__ = None
__task_body__ = None
if metadata:
if all(k in metadata for k in ("session_id", "chat_id", "message_id")):
__event_emitter__ = get_event_emitter(metadata)
__event_call__ = get_event_call(metadata)
__task__ = metadata.get("task", None)
__task_body__ = metadata.get("task_body", None)
extra_params = {
"__event_emitter__": __event_emitter__,
"__event_call__": __event_call__,
"__task__": __task__,
"__task_body__": __task_body__,
"__files__": files,
"__user__": {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
},
"__metadata__": metadata,
"__request__": request,
}
extra_params["__tools__"] = get_tools(
request,
tool_ids,
user,
{
**extra_params,
"__model__": models.get(form_data["model"], None),
"__messages__": form_data["messages"],
"__files__": files,
},
)
if model_info:
if model_info.base_model_id:
form_data["model"] = model_info.base_model_id
params = model_info.params.model_dump()
form_data = apply_model_params_to_body_openai(params, form_data)
form_data = apply_model_system_prompt_to_body(params, form_data, metadata, user)
pipe_id = get_pipe_id(form_data)
function_module = get_function_module_by_id(request, pipe_id)
pipe = function_module.pipe
params = get_function_params(function_module, form_data, user, extra_params)
if form_data.get("stream", False):
async def stream_content():
try:
res = await execute_pipe(pipe, params)
# Directly return if the response is a StreamingResponse
if isinstance(res, StreamingResponse):
async for data in res.body_iterator:
yield data
return
if isinstance(res, dict):
yield f"data: {json.dumps(res)}\n\n"
return
except Exception as e:
log.error(f"Error: {e}")
yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n"
return
if isinstance(res, str):
message = openai_chat_chunk_message_template(form_data["model"], res)
yield f"data: {json.dumps(message)}\n\n"
if isinstance(res, Iterator):
for line in res:
yield process_line(form_data, line)
if isinstance(res, AsyncGenerator):
async for line in res:
yield process_line(form_data, line)
if isinstance(res, str) or isinstance(res, Generator):
finish_message = openai_chat_chunk_message_template(
form_data["model"], ""
)
finish_message["choices"][0]["finish_reason"] = "stop"
yield f"data: {json.dumps(finish_message)}\n\n"
yield "data: [DONE]"
return StreamingResponse(stream_content(), media_type="text/event-stream")
else:
try:
res = await execute_pipe(pipe, params)
except Exception as e:
log.error(f"Error: {e}")
return {"error": {"detail": str(e)}}
if isinstance(res, StreamingResponse) or isinstance(res, dict):
return res
if isinstance(res, BaseModel):
return res.model_dump()
message = await get_message_content(res)
return openai_chat_completion_message_template(form_data["model"], message)

View File

@@ -0,0 +1,116 @@
import json
import logging
from contextlib import contextmanager
from typing import Any, Optional
from open_webui.internal.wrappers import register_connection
from open_webui.env import (
OPEN_WEBUI_DIR,
DATABASE_URL,
DATABASE_SCHEMA,
SRC_LOG_LEVELS,
DATABASE_POOL_MAX_OVERFLOW,
DATABASE_POOL_RECYCLE,
DATABASE_POOL_SIZE,
DATABASE_POOL_TIMEOUT,
)
from peewee_migrate import Router
from sqlalchemy import Dialect, create_engine, MetaData, types
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import scoped_session, sessionmaker
from sqlalchemy.pool import QueuePool, NullPool
from sqlalchemy.sql.type_api import _T
from typing_extensions import Self
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["DB"])
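# TypeDecorator that persists Python objects as JSON text; db_value/python_value mirror
# Peewee's field hooks of the same names.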
class JSONField(types.TypeDecorator):
impl = types.Text
cache_ok = True
def process_bind_param(self, value: Optional[_T], dialect: Dialect) -> Any:
return json.dumps(value)
def process_result_value(self, value: Optional[_T], dialect: Dialect) -> Any:
if value is not None:
return json.loads(value)
def copy(self, **kw: Any) -> Self:
return JSONField(self.impl.length)
def db_value(self, value):
return json.dumps(value)
def python_value(self, value):
if value is not None:
return json.loads(value)
# Workaround to handle the legacy Peewee migrations
# This is required to ensure the Peewee migrations run before the Alembic migrations
def handle_peewee_migration(DATABASE_URL):
db = None  # ensure 'db' is always defined for the finally block below
try:
# Replace the postgresql:// with postgres:// to handle the peewee migration
db = register_connection(DATABASE_URL.replace("postgresql://", "postgres://"))
migrate_dir = OPEN_WEBUI_DIR / "internal" / "migrations"
router = Router(db, logger=log, migrate_dir=migrate_dir)
router.run()
db.close()
except Exception as e:
log.error(f"Failed to initialize the database connection: {e}")
raise
finally:
# Properly closing the database connection
if db and not db.is_closed():
db.close()
# Assert that the db connection has been closed
assert db is None or db.is_closed(), "Database connection is still open."
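# Run the legacy Peewee migrations once at import time, before the SQLAlchemy engine below is created.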
handle_peewee_migration(DATABASE_URL)
SQLALCHEMY_DATABASE_URL = DATABASE_URL
if "sqlite" in SQLALCHEMY_DATABASE_URL:
engine = create_engine(
SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
)
else:
if DATABASE_POOL_SIZE > 0:
engine = create_engine(
SQLALCHEMY_DATABASE_URL,
pool_size=DATABASE_POOL_SIZE,
max_overflow=DATABASE_POOL_MAX_OVERFLOW,
pool_timeout=DATABASE_POOL_TIMEOUT,
pool_recycle=DATABASE_POOL_RECYCLE,
pool_pre_ping=True,
poolclass=QueuePool,
)
else:
engine = create_engine(
SQLALCHEMY_DATABASE_URL, pool_pre_ping=True, poolclass=NullPool
)
SessionLocal = sessionmaker(
autocommit=False, autoflush=False, bind=engine, expire_on_commit=False
)
metadata_obj = MetaData(schema=DATABASE_SCHEMA)
Base = declarative_base(metadata=metadata_obj)
Session = scoped_session(SessionLocal)
def get_session():
db = SessionLocal()
try:
yield db
finally:
db.close()
get_db = contextmanager(get_session)
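For orientation only (not part of the commit), a minimal sketch of how these helpers are typically consumed; both forms open a short-lived Session bound to the engine configured above:

from sqlalchemy import text

with get_db() as db:            # context-manager form
    db.execute(text("SELECT 1"))

for db in get_session():        # generator form (yields exactly once, then closes the session)
    db.execute(text("SELECT 1"))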

View File

@@ -0,0 +1,254 @@
"""Peewee migrations -- 001_initial_schema.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
# We perform different migrations for SQLite and for other databases.
# SQLite is very loose about enforcing its schema, and migrating the other databases the same way
# would require per-database SQL queries.
# Instead, because external DB support was added later, we assume a newer base schema for
# external databases rather than trying to migrate them from an older schema.
if isinstance(database, pw.SqliteDatabase):
migrate_sqlite(migrator, database, fake=fake)
else:
migrate_external(migrator, database, fake=fake)
def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
@migrator.create_model
class Auth(pw.Model):
id = pw.CharField(max_length=255, unique=True)
email = pw.CharField(max_length=255)
password = pw.CharField(max_length=255)
active = pw.BooleanField()
class Meta:
table_name = "auth"
@migrator.create_model
class Chat(pw.Model):
id = pw.CharField(max_length=255, unique=True)
user_id = pw.CharField(max_length=255)
title = pw.CharField()
chat = pw.TextField()
timestamp = pw.BigIntegerField()
class Meta:
table_name = "chat"
@migrator.create_model
class ChatIdTag(pw.Model):
id = pw.CharField(max_length=255, unique=True)
tag_name = pw.CharField(max_length=255)
chat_id = pw.CharField(max_length=255)
user_id = pw.CharField(max_length=255)
timestamp = pw.BigIntegerField()
class Meta:
table_name = "chatidtag"
@migrator.create_model
class Document(pw.Model):
id = pw.AutoField()
collection_name = pw.CharField(max_length=255, unique=True)
name = pw.CharField(max_length=255, unique=True)
title = pw.CharField()
filename = pw.CharField()
content = pw.TextField(null=True)
user_id = pw.CharField(max_length=255)
timestamp = pw.BigIntegerField()
class Meta:
table_name = "document"
@migrator.create_model
class Modelfile(pw.Model):
id = pw.AutoField()
tag_name = pw.CharField(max_length=255, unique=True)
user_id = pw.CharField(max_length=255)
modelfile = pw.TextField()
timestamp = pw.BigIntegerField()
class Meta:
table_name = "modelfile"
@migrator.create_model
class Prompt(pw.Model):
id = pw.AutoField()
command = pw.CharField(max_length=255, unique=True)
user_id = pw.CharField(max_length=255)
title = pw.CharField()
content = pw.TextField()
timestamp = pw.BigIntegerField()
class Meta:
table_name = "prompt"
@migrator.create_model
class Tag(pw.Model):
id = pw.CharField(max_length=255, unique=True)
name = pw.CharField(max_length=255)
user_id = pw.CharField(max_length=255)
data = pw.TextField(null=True)
class Meta:
table_name = "tag"
@migrator.create_model
class User(pw.Model):
id = pw.CharField(max_length=255, unique=True)
name = pw.CharField(max_length=255)
email = pw.CharField(max_length=255)
role = pw.CharField(max_length=255)
profile_image_url = pw.CharField(max_length=255)
timestamp = pw.BigIntegerField()
class Meta:
table_name = "user"
def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False):
@migrator.create_model
class Auth(pw.Model):
id = pw.CharField(max_length=255, unique=True)
email = pw.CharField(max_length=255)
password = pw.TextField()
active = pw.BooleanField()
class Meta:
table_name = "auth"
@migrator.create_model
class Chat(pw.Model):
id = pw.CharField(max_length=255, unique=True)
user_id = pw.CharField(max_length=255)
title = pw.TextField()
chat = pw.TextField()
timestamp = pw.BigIntegerField()
class Meta:
table_name = "chat"
@migrator.create_model
class ChatIdTag(pw.Model):
id = pw.CharField(max_length=255, unique=True)
tag_name = pw.CharField(max_length=255)
chat_id = pw.CharField(max_length=255)
user_id = pw.CharField(max_length=255)
timestamp = pw.BigIntegerField()
class Meta:
table_name = "chatidtag"
@migrator.create_model
class Document(pw.Model):
id = pw.AutoField()
collection_name = pw.CharField(max_length=255, unique=True)
name = pw.CharField(max_length=255, unique=True)
title = pw.TextField()
filename = pw.TextField()
content = pw.TextField(null=True)
user_id = pw.CharField(max_length=255)
timestamp = pw.BigIntegerField()
class Meta:
table_name = "document"
@migrator.create_model
class Modelfile(pw.Model):
id = pw.AutoField()
tag_name = pw.CharField(max_length=255, unique=True)
user_id = pw.CharField(max_length=255)
modelfile = pw.TextField()
timestamp = pw.BigIntegerField()
class Meta:
table_name = "modelfile"
@migrator.create_model
class Prompt(pw.Model):
id = pw.AutoField()
command = pw.CharField(max_length=255, unique=True)
user_id = pw.CharField(max_length=255)
title = pw.TextField()
content = pw.TextField()
timestamp = pw.BigIntegerField()
class Meta:
table_name = "prompt"
@migrator.create_model
class Tag(pw.Model):
id = pw.CharField(max_length=255, unique=True)
name = pw.CharField(max_length=255)
user_id = pw.CharField(max_length=255)
data = pw.TextField(null=True)
class Meta:
table_name = "tag"
@migrator.create_model
class User(pw.Model):
id = pw.CharField(max_length=255, unique=True)
name = pw.CharField(max_length=255)
email = pw.CharField(max_length=255)
role = pw.CharField(max_length=255)
profile_image_url = pw.TextField()
timestamp = pw.BigIntegerField()
class Meta:
table_name = "user"
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
migrator.remove_model("user")
migrator.remove_model("tag")
migrator.remove_model("prompt")
migrator.remove_model("modelfile")
migrator.remove_model("document")
migrator.remove_model("chatidtag")
migrator.remove_model("chat")
migrator.remove_model("auth")

View File

@@ -0,0 +1,48 @@
"""Peewee migrations -- 002_add_local_sharing.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
migrator.add_fields(
"chat", share_id=pw.CharField(max_length=255, null=True, unique=True)
)
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
migrator.remove_fields("chat", "share_id")

View File

@@ -0,0 +1,48 @@
"""Peewee migrations -- 002_add_local_sharing.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
migrator.add_fields(
"user", api_key=pw.CharField(max_length=255, null=True, unique=True)
)
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
migrator.remove_fields("user", "api_key")

View File

@@ -0,0 +1,46 @@
"""Peewee migrations -- 002_add_local_sharing.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
migrator.add_fields("chat", archived=pw.BooleanField(default=False))
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
migrator.remove_fields("chat", "archived")

View File

@@ -0,0 +1,130 @@
"""Peewee migrations -- 002_add_local_sharing.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
if isinstance(database, pw.SqliteDatabase):
migrate_sqlite(migrator, database, fake=fake)
else:
migrate_external(migrator, database, fake=fake)
def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
# Adding fields created_at and updated_at to the 'chat' table
migrator.add_fields(
"chat",
created_at=pw.DateTimeField(null=True), # Allow null for transition
updated_at=pw.DateTimeField(null=True), # Allow null for transition
)
# Populate the new fields from an existing 'timestamp' field
migrator.sql(
"UPDATE chat SET created_at = timestamp, updated_at = timestamp WHERE timestamp IS NOT NULL"
)
# Now that the data has been copied, remove the original 'timestamp' field
migrator.remove_fields("chat", "timestamp")
# Update the fields to be not null now that they are populated
migrator.change_fields(
"chat",
created_at=pw.DateTimeField(null=False),
updated_at=pw.DateTimeField(null=False),
)
def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False):
# Adding fields created_at and updated_at to the 'chat' table
migrator.add_fields(
"chat",
created_at=pw.BigIntegerField(null=True), # Allow null for transition
updated_at=pw.BigIntegerField(null=True), # Allow null for transition
)
# Populate the new fields from an existing 'timestamp' field
migrator.sql(
"UPDATE chat SET created_at = timestamp, updated_at = timestamp WHERE timestamp IS NOT NULL"
)
# Now that the data has been copied, remove the original 'timestamp' field
migrator.remove_fields("chat", "timestamp")
# Update the fields to be not null now that they are populated
migrator.change_fields(
"chat",
created_at=pw.BigIntegerField(null=False),
updated_at=pw.BigIntegerField(null=False),
)
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
if isinstance(database, pw.SqliteDatabase):
rollback_sqlite(migrator, database, fake=fake)
else:
rollback_external(migrator, database, fake=fake)
def rollback_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
# Recreate the timestamp field initially allowing null values for safe transition
migrator.add_fields("chat", timestamp=pw.DateTimeField(null=True))
# Copy the earliest created_at date back into the new timestamp field
# This assumes created_at was originally a copy of timestamp
migrator.sql("UPDATE chat SET timestamp = created_at")
# Remove the created_at and updated_at fields
migrator.remove_fields("chat", "created_at", "updated_at")
# Finally, alter the timestamp field to not allow nulls if that was the original setting
migrator.change_fields("chat", timestamp=pw.DateTimeField(null=False))
def rollback_external(migrator: Migrator, database: pw.Database, *, fake=False):
# Recreate the timestamp field initially allowing null values for safe transition
migrator.add_fields("chat", timestamp=pw.BigIntegerField(null=True))
# Copy the earliest created_at date back into the new timestamp field
# This assumes created_at was originally a copy of timestamp
migrator.sql("UPDATE chat SET timestamp = created_at")
# Remove the created_at and updated_at fields
migrator.remove_fields("chat", "created_at", "updated_at")
# Finally, alter the timestamp field to not allow nulls if that was the original setting
migrator.change_fields("chat", timestamp=pw.BigIntegerField(null=False))

View File

@@ -0,0 +1,130 @@
"""Peewee migrations -- 006_migrate_timestamps_and_charfields.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
# Alter the tables with timestamps
migrator.change_fields(
"chatidtag",
timestamp=pw.BigIntegerField(),
)
migrator.change_fields(
"document",
timestamp=pw.BigIntegerField(),
)
migrator.change_fields(
"modelfile",
timestamp=pw.BigIntegerField(),
)
migrator.change_fields(
"prompt",
timestamp=pw.BigIntegerField(),
)
migrator.change_fields(
"user",
timestamp=pw.BigIntegerField(),
)
# Alter the tables, converting varchar columns to text where necessary
migrator.change_fields(
"auth",
password=pw.TextField(),
)
migrator.change_fields(
"chat",
title=pw.TextField(),
)
migrator.change_fields(
"document",
title=pw.TextField(),
filename=pw.TextField(),
)
migrator.change_fields(
"prompt",
title=pw.TextField(),
)
migrator.change_fields(
"user",
profile_image_url=pw.TextField(),
)
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
if isinstance(database, pw.SqliteDatabase):
# Alter the tables with timestamps
migrator.change_fields(
"chatidtag",
timestamp=pw.DateField(),
)
migrator.change_fields(
"document",
timestamp=pw.DateField(),
)
migrator.change_fields(
"modelfile",
timestamp=pw.DateField(),
)
migrator.change_fields(
"prompt",
timestamp=pw.DateField(),
)
migrator.change_fields(
"user",
timestamp=pw.DateField(),
)
migrator.change_fields(
"auth",
password=pw.CharField(max_length=255),
)
migrator.change_fields(
"chat",
title=pw.CharField(),
)
migrator.change_fields(
"document",
title=pw.CharField(),
filename=pw.CharField(),
)
migrator.change_fields(
"prompt",
title=pw.CharField(),
)
migrator.change_fields(
"user",
profile_image_url=pw.CharField(),
)

View File

@@ -0,0 +1,79 @@
"""Peewee migrations -- 002_add_local_sharing.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
# Adding created_at, updated_at, and last_active_at fields to the 'user' table
migrator.add_fields(
"user",
created_at=pw.BigIntegerField(null=True), # Allow null for transition
updated_at=pw.BigIntegerField(null=True), # Allow null for transition
last_active_at=pw.BigIntegerField(null=True), # Allow null for transition
)
# Populate the new fields from an existing 'timestamp' field
migrator.sql(
'UPDATE "user" SET created_at = timestamp, updated_at = timestamp, last_active_at = timestamp WHERE timestamp IS NOT NULL'
)
# Now that the data has been copied, remove the original 'timestamp' field
migrator.remove_fields("user", "timestamp")
# Update the fields to be not null now that they are populated
migrator.change_fields(
"user",
created_at=pw.BigIntegerField(null=False),
updated_at=pw.BigIntegerField(null=False),
last_active_at=pw.BigIntegerField(null=False),
)
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
# Recreate the timestamp field initially allowing null values for safe transition
migrator.add_fields("user", timestamp=pw.BigIntegerField(null=True))
# Copy the earliest created_at date back into the new timestamp field
# This assumes created_at was originally a copy of timestamp
migrator.sql('UPDATE "user" SET timestamp = created_at')
# Remove the created_at, updated_at, and last_active_at fields
migrator.remove_fields("user", "created_at", "updated_at", "last_active_at")
# Finally, alter the timestamp field to not allow nulls if that was the original setting
migrator.change_fields("user", timestamp=pw.BigIntegerField(null=False))

View File

@@ -0,0 +1,53 @@
"""Peewee migrations -- 002_add_local_sharing.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
@migrator.create_model
class Memory(pw.Model):
id = pw.CharField(max_length=255, unique=True)
user_id = pw.CharField(max_length=255)
content = pw.TextField(null=False)
updated_at = pw.BigIntegerField(null=False)
created_at = pw.BigIntegerField(null=False)
class Meta:
table_name = "memory"
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
migrator.remove_model("memory")

View File

@@ -0,0 +1,61 @@
"""Peewee migrations -- 009_add_models.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
@migrator.create_model
class Model(pw.Model):
id = pw.TextField(unique=True)
user_id = pw.TextField()
base_model_id = pw.TextField(null=True)
name = pw.TextField()
meta = pw.TextField()
params = pw.TextField()
created_at = pw.BigIntegerField(null=False)
updated_at = pw.BigIntegerField(null=False)
class Meta:
table_name = "model"
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
migrator.remove_model("model")

View File

@@ -0,0 +1,130 @@
"""Peewee migrations -- 009_add_models.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
import json
from open_webui.utils.misc import parse_ollama_modelfile
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
# Fetch data from 'modelfile' table and insert into 'model' table
migrate_modelfile_to_model(migrator, database)
# Drop the 'modelfile' table
migrator.remove_model("modelfile")
def migrate_modelfile_to_model(migrator: Migrator, database: pw.Database):
ModelFile = migrator.orm["modelfile"]
Model = migrator.orm["model"]
modelfiles = ModelFile.select()
for modelfile in modelfiles:
# Extract and transform data in Python
modelfile.modelfile = json.loads(modelfile.modelfile)
meta = json.dumps(
{
"description": modelfile.modelfile.get("desc"),
"profile_image_url": modelfile.modelfile.get("imageUrl"),
"ollama": {"modelfile": modelfile.modelfile.get("content")},
"suggestion_prompts": modelfile.modelfile.get("suggestionPrompts"),
"categories": modelfile.modelfile.get("categories"),
"user": {**modelfile.modelfile.get("user", {}), "community": True},
}
)
info = parse_ollama_modelfile(modelfile.modelfile.get("content"))
# Insert the processed data into the 'model' table
Model.create(
id=f"ollama-{modelfile.tag_name}",
user_id=modelfile.user_id,
base_model_id=info.get("base_model_id"),
name=modelfile.modelfile.get("title"),
meta=meta,
params=json.dumps(info.get("params", {})),
created_at=modelfile.timestamp,
updated_at=modelfile.timestamp,
)
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
recreate_modelfile_table(migrator, database)
move_data_back_to_modelfile(migrator, database)
migrator.remove_model("model")
def recreate_modelfile_table(migrator: Migrator, database: pw.Database):
query = """
CREATE TABLE IF NOT EXISTS modelfile (
user_id TEXT,
tag_name TEXT,
modelfile JSON,
timestamp BIGINT
)
"""
migrator.sql(query)
def move_data_back_to_modelfile(migrator: Migrator, database: pw.Database):
Model = migrator.orm["model"]
Modelfile = migrator.orm["modelfile"]
models = Model.select()
for model in models:
# Extract and transform data in Python
meta = json.loads(model.meta)
modelfile_data = {
"title": model.name,
"desc": meta.get("description"),
"imageUrl": meta.get("profile_image_url"),
"content": meta.get("ollama", {}).get("modelfile"),
"suggestionPrompts": meta.get("suggestion_prompts"),
"categories": meta.get("categories"),
"user": {k: v for k, v in meta.get("user", {}).items() if k != "community"},
}
# Insert the processed data back into the 'modelfile' table
Modelfile.create(
user_id=model.user_id,
tag_name=model.id,
modelfile=modelfile_data,
timestamp=model.created_at,
)

View File

@@ -0,0 +1,48 @@
"""Peewee migrations -- 002_add_local_sharing.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
# Adding the 'settings' field to the 'user' table
migrator.add_fields("user", settings=pw.TextField(null=True))
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
# Remove the settings field
migrator.remove_fields("user", "settings")

View File

@@ -0,0 +1,61 @@
"""Peewee migrations -- 009_add_models.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
@migrator.create_model
class Tool(pw.Model):
id = pw.TextField(unique=True)
user_id = pw.TextField()
name = pw.TextField()
content = pw.TextField()
specs = pw.TextField()
meta = pw.TextField()
created_at = pw.BigIntegerField(null=False)
updated_at = pw.BigIntegerField(null=False)
class Meta:
table_name = "tool"
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
migrator.remove_model("tool")

View File

@@ -0,0 +1,48 @@
"""Peewee migrations -- 002_add_local_sharing.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
# Adding the 'info' field to the 'user' table
migrator.add_fields("user", info=pw.TextField(null=True))
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
# Remove the 'info' field
migrator.remove_fields("user", "info")

View File

@@ -0,0 +1,55 @@
"""Peewee migrations -- 009_add_models.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
@migrator.create_model
class File(pw.Model):
id = pw.TextField(unique=True)
user_id = pw.TextField()
filename = pw.TextField()
meta = pw.TextField()
created_at = pw.BigIntegerField(null=False)
class Meta:
table_name = "file"
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
migrator.remove_model("file")

View File

@@ -0,0 +1,61 @@
"""Peewee migrations -- 009_add_models.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
@migrator.create_model
class Function(pw.Model):
id = pw.TextField(unique=True)
user_id = pw.TextField()
name = pw.TextField()
type = pw.TextField()
content = pw.TextField()
meta = pw.TextField()
created_at = pw.BigIntegerField(null=False)
updated_at = pw.BigIntegerField(null=False)
class Meta:
table_name = "function"
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
migrator.remove_model("function")

View File

@@ -0,0 +1,50 @@
"""Peewee migrations -- 009_add_models.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
migrator.add_fields("tool", valves=pw.TextField(null=True))
migrator.add_fields("function", valves=pw.TextField(null=True))
migrator.add_fields("function", is_active=pw.BooleanField(default=False))
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
migrator.remove_fields("tool", "valves")
migrator.remove_fields("function", "valves")
migrator.remove_fields("function", "is_active")

View File

@@ -0,0 +1,45 @@
"""Peewee migrations -- 017_add_user_oauth_sub.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
migrator.add_fields(
"user",
oauth_sub=pw.TextField(null=True, unique=True),
)
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
migrator.remove_fields("user", "oauth_sub")

View File

@@ -0,0 +1,49 @@
"""Peewee migrations -- 017_add_user_oauth_sub.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
migrator.add_fields(
"function",
is_global=pw.BooleanField(default=False),
)
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
migrator.remove_fields("function", "is_global")

View File

@@ -0,0 +1,66 @@
import logging
from contextvars import ContextVar
from open_webui.env import SRC_LOG_LEVELS
from peewee import *
from peewee import InterfaceError as PeeWeeInterfaceError
from peewee import PostgresqlDatabase
from playhouse.db_url import connect, parse
from playhouse.shortcuts import ReconnectMixin
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["DB"])
db_state_default = {"closed": None, "conn": None, "ctx": None, "transactions": None}
db_state = ContextVar("db_state", default=db_state_default.copy())
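# Keep Peewee's per-connection state in a ContextVar so each async task/request gets its own
# connection state instead of sharing a single thread-local one.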
class PeeweeConnectionState(object):
def __init__(self, **kwargs):
super().__setattr__("_state", db_state)
super().__init__(**kwargs)
def __setattr__(self, name, value):
self._state.get()[name] = value
def __getattr__(self, name):
value = self._state.get()[name]
return value
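# Transparently reconnect and retry when the server side has closed the connection
# (e.g. a stale pooled connection), matching on the error fragments listed below.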
class CustomReconnectMixin(ReconnectMixin):
reconnect_errors = (
# psycopg2
(OperationalError, "termin"),
(InterfaceError, "closed"),
# peewee
(PeeWeeInterfaceError, "closed"),
)
class ReconnectingPostgresqlDatabase(CustomReconnectMixin, PostgresqlDatabase):
pass
def register_connection(db_url):
db = connect(db_url, unquote_password=True)
if isinstance(db, PostgresqlDatabase):
# Enable autoconnect and connection reuse on the initial connection before swapping in the reconnecting class
db.autoconnect = True
db.reuse_if_open = True
log.info("Connected to PostgreSQL database")
# Get the connection details
connection = parse(db_url, unquote_password=True)
# Use our custom database class that supports reconnection
db = ReconnectingPostgresqlDatabase(**connection)
db.connect(reuse_if_open=True)
elif isinstance(db, SqliteDatabase):
# Enable autoconnect for SQLite databases, managed by Peewee
db.autoconnect = True
db.reuse_if_open = True
log.info("Connected to SQLite database")
else:
raise ValueError("Unsupported database connection")
return db

1384
backend/open_webui/main.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,4 @@
Generic single-database configuration.
Create new migrations with
DATABASE_URL=<replace with actual url> alembic revision --autogenerate -m "a description"
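To apply the migrations (a hedged example, assuming the same environment variable and the standard Alembic CLI):
DATABASE_URL=<replace with actual url> alembic upgrade head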

View File

@@ -0,0 +1,81 @@
from logging.config import fileConfig
from alembic import context
from open_webui.models.auths import Auth
from open_webui.env import DATABASE_URL
from sqlalchemy import engine_from_config, pool
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
# Interpret the config file for Python logging.
# This line sets up loggers.
if config.config_file_name is not None:
fileConfig(config.config_file_name, disable_existing_loggers=False)
# add your model's MetaData object here
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
target_metadata = Auth.metadata
# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
DB_URL = DATABASE_URL
if DB_URL:
config.set_main_option("sqlalchemy.url", DB_URL.replace("%", "%%"))
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
url = config.get_main_option("sqlalchemy.url")
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
def run_migrations_online() -> None:
"""Run migrations in 'online' mode.
In this scenario we need to create an Engine
and associate a connection with the context.
"""
connectable = engine_from_config(
config.get_section(config.config_ini_section, {}),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
with connectable.connect() as connection:
context.configure(connection=connection, target_metadata=target_metadata)
with context.begin_transaction():
context.run_migrations()
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()

View File

@@ -0,0 +1,27 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import open_webui.internal.db
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
def upgrade() -> None:
${upgrades if upgrades else "pass"}
def downgrade() -> None:
${downgrades if downgrades else "pass"}

View File

@@ -0,0 +1,15 @@
from alembic import op
from sqlalchemy import Inspector
def get_existing_tables():
con = op.get_bind()
inspector = Inspector.from_engine(con)
tables = set(inspector.get_table_names())
return tables
def get_revision_id():
import uuid
return str(uuid.uuid4()).replace("-", "")[:12]
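# Illustration only: these helpers are presumably intended for hand-written Alembic revisions,
# e.g. checking get_existing_tables() before creating a table that the legacy Peewee migrations
# may already have created, and using get_revision_id() to mint a new revision id.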

View File

@@ -0,0 +1,151 @@
"""Migrate tags
Revision ID: 1af9b942657b
Revises: 242a2047eae0
Create Date: 2024-10-09 21:02:35.241684
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.sql import table, select, update, column
from sqlalchemy.engine.reflection import Inspector
import json
revision = "1af9b942657b"
down_revision = "242a2047eae0"
branch_labels = None
depends_on = None
def upgrade():
# Set up an inspector so we can check existing tables, columns, and constraints before altering them
conn = op.get_bind()
inspector = Inspector.from_engine(conn)
# Clean up potential leftover temp table from previous failures
conn.execute(sa.text("DROP TABLE IF EXISTS _alembic_tmp_tag"))
# Check if the 'tag' table exists
tables = inspector.get_table_names()
# Step 1: Modify Tag table using batch mode for SQLite support
if "tag" in tables:
# Get the current columns in the 'tag' table
columns = [col["name"] for col in inspector.get_columns("tag")]
# Get any existing unique constraints on the 'tag' table
current_constraints = inspector.get_unique_constraints("tag")
with op.batch_alter_table("tag", schema=None) as batch_op:
# Check if the unique constraint already exists
if not any(
constraint["name"] == "uq_id_user_id"
for constraint in current_constraints
):
# Create unique constraint if it doesn't exist
batch_op.create_unique_constraint("uq_id_user_id", ["id", "user_id"])
# Check if the 'data' column exists before trying to drop it
if "data" in columns:
batch_op.drop_column("data")
# Check if the 'meta' column needs to be created
if "meta" not in columns:
# Add the 'meta' column if it doesn't already exist
batch_op.add_column(sa.Column("meta", sa.JSON(), nullable=True))
tag = table(
"tag",
column("id", sa.String()),
column("name", sa.String()),
column("user_id", sa.String()),
column("meta", sa.JSON()),
)
# Step 2: Migrate tags
conn = op.get_bind()
result = conn.execute(sa.select(tag.c.id, tag.c.name, tag.c.user_id))
tag_updates = {}
for row in result:
new_id = row.name.replace(" ", "_").lower()
tag_updates[row.id] = new_id
for tag_id, new_tag_id in tag_updates.items():
print(f"Updating tag {tag_id} to {new_tag_id}")
if new_tag_id == "pinned":
# delete tag
delete_stmt = sa.delete(tag).where(tag.c.id == tag_id)
conn.execute(delete_stmt)
else:
# Check if the new_tag_id already exists in the database
existing_tag_query = sa.select(tag.c.id).where(tag.c.id == new_tag_id)
existing_tag_result = conn.execute(existing_tag_query).fetchone()
if existing_tag_result:
# Handle duplicate case: the new_tag_id already exists
print(
f"Tag {new_tag_id} already exists. Removing current tag with ID {tag_id} to avoid duplicates."
)
# Option 1: Delete the current tag if an update to new_tag_id would cause duplication
delete_stmt = sa.delete(tag).where(tag.c.id == tag_id)
conn.execute(delete_stmt)
else:
update_stmt = sa.update(tag).where(tag.c.id == tag_id)
update_stmt = update_stmt.values(id=new_tag_id)
conn.execute(update_stmt)
# Add columns `pinned` and `meta` to 'chat'
op.add_column("chat", sa.Column("pinned", sa.Boolean(), nullable=True))
op.add_column(
"chat", sa.Column("meta", sa.JSON(), nullable=False, server_default="{}")
)
chatidtag = table(
"chatidtag", column("chat_id", sa.String()), column("tag_name", sa.String())
)
chat = table(
"chat",
column("id", sa.String()),
column("pinned", sa.Boolean()),
column("meta", sa.JSON()),
)
# Fetch existing tags
conn = op.get_bind()
result = conn.execute(sa.select(chatidtag.c.chat_id, chatidtag.c.tag_name))
chat_updates = {}
for row in result:
chat_id = row.chat_id
tag_name = row.tag_name.replace(" ", "_").lower()
if tag_name == "pinned":
# Specifically handle 'pinned' tag
if chat_id not in chat_updates:
chat_updates[chat_id] = {"pinned": True, "meta": {}}
else:
chat_updates[chat_id]["pinned"] = True
else:
if chat_id not in chat_updates:
chat_updates[chat_id] = {"pinned": False, "meta": {"tags": [tag_name]}}
else:
tags = chat_updates[chat_id]["meta"].get("tags", [])
tags.append(tag_name)
chat_updates[chat_id]["meta"]["tags"] = list(set(tags))
# Update chats based on accumulated changes
for chat_id, updates in chat_updates.items():
update_stmt = sa.update(chat).where(chat.c.id == chat_id)
update_stmt = update_stmt.values(
meta=updates.get("meta", {}), pinned=updates.get("pinned", False)
)
conn.execute(update_stmt)
pass
def downgrade():
pass
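# Editor's note: downgrade() above is left as a no-op. The tag-id rewrite and the
# consolidation of chatidtag rows into chat.meta["tags"] performed in upgrade()
# are lossy, so the original rows cannot be reliably reconstructed.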

View File

@@ -0,0 +1,107 @@
"""Update chat table
Revision ID: 242a2047eae0
Revises: 6a39f3d8e55c
Create Date: 2024-10-09 21:02:35.241684
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.sql import table, select, update
import json
revision = "242a2047eae0"
down_revision = "6a39f3d8e55c"
branch_labels = None
depends_on = None
def upgrade():
conn = op.get_bind()
inspector = sa.inspect(conn)
columns = inspector.get_columns("chat")
column_dict = {col["name"]: col for col in columns}
chat_column = column_dict.get("chat")
old_chat_exists = "old_chat" in column_dict
if chat_column:
if isinstance(chat_column["type"], sa.Text):
print("Converting 'chat' column to JSON")
if old_chat_exists:
print("Dropping old 'old_chat' column")
op.drop_column("chat", "old_chat")
# Step 1: Rename current 'chat' column to 'old_chat'
print("Renaming 'chat' column to 'old_chat'")
op.alter_column(
"chat", "chat", new_column_name="old_chat", existing_type=sa.Text()
)
# Step 2: Add new 'chat' column of type JSON
print("Adding new 'chat' column of type JSON")
op.add_column("chat", sa.Column("chat", sa.JSON(), nullable=True))
else:
# If the column is already JSON, no need to do anything
pass
# Step 3: Migrate data from 'old_chat' to 'chat'
chat_table = table(
"chat",
sa.Column("id", sa.String(), primary_key=True),
sa.Column("old_chat", sa.Text()),
sa.Column("chat", sa.JSON()),
)
# - Selecting all data from the table
connection = op.get_bind()
results = connection.execute(select(chat_table.c.id, chat_table.c.old_chat))
for row in results:
try:
# Convert text JSON to actual JSON object, assuming the text is in JSON format
json_data = json.loads(row.old_chat)
except json.JSONDecodeError:
json_data = None # Handle cases where the text cannot be converted to JSON
connection.execute(
sa.update(chat_table)
.where(chat_table.c.id == row.id)
.values(chat=json_data)
)
# Step 4: Drop 'old_chat' column
print("Dropping 'old_chat' column")
op.drop_column("chat", "old_chat")
def downgrade():
# Step 1: Add 'old_chat' column back as Text
op.add_column("chat", sa.Column("old_chat", sa.Text(), nullable=True))
# Step 2: Convert 'chat' JSON data back to text and store in 'old_chat'
chat_table = table(
"chat",
sa.Column("id", sa.String(), primary_key=True),
sa.Column("chat", sa.JSON()),
sa.Column("old_chat", sa.Text()),
)
connection = op.get_bind()
results = connection.execute(select(chat_table.c.id, chat_table.c.chat))
for row in results:
text_data = json.dumps(row.chat) if row.chat is not None else None
connection.execute(
sa.update(chat_table)
.where(chat_table.c.id == row.id)
.values(old_chat=text_data)
)
# Step 3: Remove the new 'chat' JSON column
op.drop_column("chat", "chat")
# Step 4: Rename 'old_chat' back to 'chat'
op.alter_column("chat", "old_chat", new_column_name="chat", existing_type=sa.Text())

View File

@@ -0,0 +1,70 @@
"""Update message & channel tables
Revision ID: 3781e22d8b01
Revises: 7826ab40b532
Create Date: 2024-12-30 03:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
revision = "3781e22d8b01"
down_revision = "7826ab40b532"
branch_labels = None
depends_on = None
def upgrade():
# Add 'type' column to the 'channel' table
op.add_column(
"channel",
sa.Column(
"type",
sa.Text(),
nullable=True,
),
)
# Add 'parent_id' column to the 'message' table for threads
op.add_column(
"message",
sa.Column("parent_id", sa.Text(), nullable=True),
)
op.create_table(
"message_reaction",
sa.Column(
"id", sa.Text(), nullable=False, primary_key=True, unique=True
), # Unique reaction ID
sa.Column("user_id", sa.Text(), nullable=False), # User who reacted
sa.Column(
"message_id", sa.Text(), nullable=False
), # Message that was reacted to
sa.Column(
"name", sa.Text(), nullable=False
), # Reaction name (e.g. "thumbs_up")
sa.Column(
"created_at", sa.BigInteger(), nullable=True
), # Timestamp of when the reaction was added
)
op.create_table(
"channel_member",
sa.Column(
"id", sa.Text(), nullable=False, primary_key=True, unique=True
), # Record ID for the membership row
sa.Column("channel_id", sa.Text(), nullable=False), # Associated channel
sa.Column("user_id", sa.Text(), nullable=False), # Associated user
sa.Column(
"created_at", sa.BigInteger(), nullable=True
), # Timestamp of when the user joined the channel
)
def downgrade():
# Revert 'type' column addition to the 'channel' table
op.drop_column("channel", "type")
op.drop_column("message", "parent_id")
op.drop_table("message_reaction")
op.drop_table("channel_member")

View File

@@ -0,0 +1,81 @@
"""Update tags
Revision ID: 3ab32c4b8f59
Revises: 1af9b942657b
Create Date: 2024-10-09 21:02:35.241684
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.sql import table, select, update, column
from sqlalchemy.engine.reflection import Inspector
import json
revision = "3ab32c4b8f59"
down_revision = "1af9b942657b"
branch_labels = None
depends_on = None
def upgrade():
conn = op.get_bind()
inspector = Inspector.from_engine(conn)
# Inspecting the 'tag' table constraints and structure
existing_pk = inspector.get_pk_constraint("tag")
unique_constraints = inspector.get_unique_constraints("tag")
existing_indexes = inspector.get_indexes("tag")
print(f"Primary Key: {existing_pk}")
print(f"Unique Constraints: {unique_constraints}")
print(f"Indexes: {existing_indexes}")
with op.batch_alter_table("tag", schema=None) as batch_op:
# Drop existing primary key constraint if it exists
if existing_pk and existing_pk.get("constrained_columns"):
pk_name = existing_pk.get("name")
if pk_name:
print(f"Dropping primary key constraint: {pk_name}")
batch_op.drop_constraint(pk_name, type_="primary")
# Now create the new primary key with the combination of 'id' and 'user_id'
print("Creating new primary key with 'id' and 'user_id'.")
batch_op.create_primary_key("pk_id_user_id", ["id", "user_id"])
# Drop unique constraints that could conflict with the new primary key
for constraint in unique_constraints:
if (
constraint["name"] == "uq_id_user_id"
): # Adjust this name according to what is actually returned by the inspector
print(f"Dropping unique constraint: {constraint['name']}")
batch_op.drop_constraint(constraint["name"], type_="unique")
for index in existing_indexes:
if index["unique"]:
if not any(
constraint["name"] == index["name"]
for constraint in unique_constraints
):
                    # Drop unique indexes that are not backed by a named unique constraint
print(f"Dropping unique index: {index['name']}")
batch_op.drop_index(index["name"])
def downgrade():
conn = op.get_bind()
inspector = Inspector.from_engine(conn)
current_pk = inspector.get_pk_constraint("tag")
with op.batch_alter_table("tag", schema=None) as batch_op:
# Drop the current primary key first, if it matches the one we know we added in upgrade
if current_pk and "pk_id_user_id" == current_pk.get("name"):
batch_op.drop_constraint("pk_id_user_id", type_="primary")
# Restore the original primary key
batch_op.create_primary_key("pk_id", ["id"])
# Since primary key on just 'id' is restored, we now add back any unique constraints if necessary
batch_op.create_unique_constraint("uq_id_user_id", ["id", "user_id"])

View File

@@ -0,0 +1,67 @@
"""Update folder table and change DateTime to BigInteger for timestamp fields
Revision ID: 4ace53fd72c8
Revises: af906e964978
Create Date: 2024-10-23 03:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
revision = "4ace53fd72c8"
down_revision = "af906e964978"
branch_labels = None
depends_on = None
def upgrade():
# Perform safe alterations using batch operation
with op.batch_alter_table("folder", schema=None) as batch_op:
# Step 1: Remove server defaults for created_at and updated_at
batch_op.alter_column(
"created_at",
server_default=None, # Removing server default
)
batch_op.alter_column(
"updated_at",
server_default=None, # Removing server default
)
# Step 2: Change the column types to BigInteger for created_at
batch_op.alter_column(
"created_at",
type_=sa.BigInteger(),
existing_type=sa.DateTime(),
existing_nullable=False,
postgresql_using="extract(epoch from created_at)::bigint", # Conversion for PostgreSQL
)
# Change the column types to BigInteger for updated_at
batch_op.alter_column(
"updated_at",
type_=sa.BigInteger(),
existing_type=sa.DateTime(),
existing_nullable=False,
postgresql_using="extract(epoch from updated_at)::bigint", # Conversion for PostgreSQL
)
def downgrade():
# Downgrade: Convert columns back to DateTime and restore defaults
with op.batch_alter_table("folder", schema=None) as batch_op:
batch_op.alter_column(
"created_at",
type_=sa.DateTime(),
existing_type=sa.BigInteger(),
existing_nullable=False,
server_default=sa.func.now(), # Restoring server default on downgrade
)
batch_op.alter_column(
"updated_at",
type_=sa.DateTime(),
existing_type=sa.BigInteger(),
existing_nullable=False,
server_default=sa.func.now(), # Restoring server default on downgrade
onupdate=sa.func.now(), # Restoring onupdate behavior if it was there
)

View File

@@ -0,0 +1,48 @@
"""Add channel table
Revision ID: 57c599a3cb57
Revises: 922e7a387820
Create Date: 2024-12-22 03:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
revision = "57c599a3cb57"
down_revision = "922e7a387820"
branch_labels = None
depends_on = None
def upgrade():
op.create_table(
"channel",
sa.Column("id", sa.Text(), nullable=False, primary_key=True, unique=True),
sa.Column("user_id", sa.Text()),
sa.Column("name", sa.Text()),
sa.Column("description", sa.Text(), nullable=True),
sa.Column("data", sa.JSON(), nullable=True),
sa.Column("meta", sa.JSON(), nullable=True),
sa.Column("access_control", sa.JSON(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
)
op.create_table(
"message",
sa.Column("id", sa.Text(), nullable=False, primary_key=True, unique=True),
sa.Column("user_id", sa.Text()),
sa.Column("channel_id", sa.Text(), nullable=True),
sa.Column("content", sa.Text()),
sa.Column("data", sa.JSON(), nullable=True),
sa.Column("meta", sa.JSON(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
)
def downgrade():
op.drop_table("channel")
op.drop_table("message")

View File

@@ -0,0 +1,80 @@
"""Add knowledge table
Revision ID: 6a39f3d8e55c
Revises: c0fbf31ca0db
Create Date: 2024-10-01 14:02:35.241684
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.sql import table, column, select
import json
revision = "6a39f3d8e55c"
down_revision = "c0fbf31ca0db"
branch_labels = None
depends_on = None
def upgrade():
# Creating the 'knowledge' table
print("Creating knowledge table")
knowledge_table = op.create_table(
"knowledge",
sa.Column("id", sa.Text(), primary_key=True),
sa.Column("user_id", sa.Text(), nullable=False),
sa.Column("name", sa.Text(), nullable=False),
sa.Column("description", sa.Text(), nullable=True),
sa.Column("data", sa.JSON(), nullable=True),
sa.Column("meta", sa.JSON(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=False),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
)
print("Migrating data from document table to knowledge table")
# Representation of the existing 'document' table
document_table = table(
"document",
column("collection_name", sa.String()),
column("user_id", sa.String()),
column("name", sa.String()),
column("title", sa.Text()),
column("content", sa.Text()),
column("timestamp", sa.BigInteger()),
)
# Select all from existing document table
documents = op.get_bind().execute(
select(
document_table.c.collection_name,
document_table.c.user_id,
document_table.c.name,
document_table.c.title,
document_table.c.content,
document_table.c.timestamp,
)
)
# Insert data into knowledge table from document table
for doc in documents:
op.get_bind().execute(
knowledge_table.insert().values(
id=doc.collection_name,
user_id=doc.user_id,
description=doc.name,
meta={
"legacy": True,
"document": True,
"tags": json.loads(doc.content or "{}").get("tags", []),
},
name=doc.title,
created_at=doc.timestamp,
                updated_at=doc.timestamp,  # the legacy document row has a single timestamp, reused for both created_at and updated_at
)
)
def downgrade():
op.drop_table("knowledge")

View File

@@ -0,0 +1,26 @@
"""Update file table
Revision ID: 7826ab40b532
Revises: 57c599a3cb57
Create Date: 2024-12-23 03:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
revision = "7826ab40b532"
down_revision = "57c599a3cb57"
branch_labels = None
depends_on = None
def upgrade():
op.add_column(
"file",
sa.Column("access_control", sa.JSON(), nullable=True),
)
def downgrade():
op.drop_column("file", "access_control")

View File

@@ -0,0 +1,204 @@
"""init
Revision ID: 7e5b5dc7342b
Revises:
Create Date: 2024-06-24 13:15:33.808998
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
import open_webui.internal.db
from open_webui.internal.db import JSONField
from open_webui.migrations.util import get_existing_tables
# revision identifiers, used by Alembic.
revision: str = "7e5b5dc7342b"
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
existing_tables = set(get_existing_tables())
# ### commands auto generated by Alembic - please adjust! ###
if "auth" not in existing_tables:
op.create_table(
"auth",
sa.Column("id", sa.String(), nullable=False),
sa.Column("email", sa.String(), nullable=True),
sa.Column("password", sa.Text(), nullable=True),
sa.Column("active", sa.Boolean(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
if "chat" not in existing_tables:
op.create_table(
"chat",
sa.Column("id", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("title", sa.Text(), nullable=True),
sa.Column("chat", sa.Text(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.Column("share_id", sa.Text(), nullable=True),
sa.Column("archived", sa.Boolean(), nullable=True),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("share_id"),
)
if "chatidtag" not in existing_tables:
op.create_table(
"chatidtag",
sa.Column("id", sa.String(), nullable=False),
sa.Column("tag_name", sa.String(), nullable=True),
sa.Column("chat_id", sa.String(), nullable=True),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("timestamp", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
if "document" not in existing_tables:
op.create_table(
"document",
sa.Column("collection_name", sa.String(), nullable=False),
sa.Column("name", sa.String(), nullable=True),
sa.Column("title", sa.Text(), nullable=True),
sa.Column("filename", sa.Text(), nullable=True),
sa.Column("content", sa.Text(), nullable=True),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("timestamp", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("collection_name"),
sa.UniqueConstraint("name"),
)
if "file" not in existing_tables:
op.create_table(
"file",
sa.Column("id", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("filename", sa.Text(), nullable=True),
sa.Column("meta", JSONField(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
if "function" not in existing_tables:
op.create_table(
"function",
sa.Column("id", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("name", sa.Text(), nullable=True),
sa.Column("type", sa.Text(), nullable=True),
sa.Column("content", sa.Text(), nullable=True),
sa.Column("meta", JSONField(), nullable=True),
sa.Column("valves", JSONField(), nullable=True),
sa.Column("is_active", sa.Boolean(), nullable=True),
sa.Column("is_global", sa.Boolean(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
if "memory" not in existing_tables:
op.create_table(
"memory",
sa.Column("id", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("content", sa.Text(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
if "model" not in existing_tables:
op.create_table(
"model",
sa.Column("id", sa.Text(), nullable=False),
sa.Column("user_id", sa.Text(), nullable=True),
sa.Column("base_model_id", sa.Text(), nullable=True),
sa.Column("name", sa.Text(), nullable=True),
sa.Column("params", JSONField(), nullable=True),
sa.Column("meta", JSONField(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
if "prompt" not in existing_tables:
op.create_table(
"prompt",
sa.Column("command", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("title", sa.Text(), nullable=True),
sa.Column("content", sa.Text(), nullable=True),
sa.Column("timestamp", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("command"),
)
if "tag" not in existing_tables:
op.create_table(
"tag",
sa.Column("id", sa.String(), nullable=False),
sa.Column("name", sa.String(), nullable=True),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("data", sa.Text(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
if "tool" not in existing_tables:
op.create_table(
"tool",
sa.Column("id", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("name", sa.Text(), nullable=True),
sa.Column("content", sa.Text(), nullable=True),
sa.Column("specs", JSONField(), nullable=True),
sa.Column("meta", JSONField(), nullable=True),
sa.Column("valves", JSONField(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
if "user" not in existing_tables:
op.create_table(
"user",
sa.Column("id", sa.String(), nullable=False),
sa.Column("name", sa.String(), nullable=True),
sa.Column("email", sa.String(), nullable=True),
sa.Column("role", sa.String(), nullable=True),
sa.Column("profile_image_url", sa.Text(), nullable=True),
sa.Column("last_active_at", sa.BigInteger(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.Column("api_key", sa.String(), nullable=True),
sa.Column("settings", JSONField(), nullable=True),
sa.Column("info", JSONField(), nullable=True),
sa.Column("oauth_sub", sa.Text(), nullable=True),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("api_key"),
sa.UniqueConstraint("oauth_sub"),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("user")
op.drop_table("tool")
op.drop_table("tag")
op.drop_table("prompt")
op.drop_table("model")
op.drop_table("memory")
op.drop_table("function")
op.drop_table("file")
op.drop_table("document")
op.drop_table("chatidtag")
op.drop_table("chat")
op.drop_table("auth")
# ### end Alembic commands ###
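# Editor's note: from the revision/down_revision identifiers shown in this commit,
# the migration chain rooted at this init revision appears to run as follows
# (applied in order by `alembic upgrade head`):
#   7e5b5dc7342b (init) -> ca81bd47c050 (config) -> c0fbf31ca0db (file columns)
#   -> 6a39f3d8e55c (knowledge) -> 242a2047eae0 (chat JSON) -> 1af9b942657b (migrate tags)
#   -> 3ab32c4b8f59 (tag PK) -> c69f45358db4 (folder) -> c29facfe716b (file path)
#   -> af906e964978 (feedback) -> 4ace53fd72c8 (folder timestamps) -> 922e7a387820 (group)
#   -> 57c599a3cb57 (channel) -> 7826ab40b532 (file access_control) -> 3781e22d8b01 (message threads)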

View File

@@ -0,0 +1,85 @@
"""Add group table
Revision ID: 922e7a387820
Revises: 4ace53fd72c8
Create Date: 2024-11-14 03:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
revision = "922e7a387820"
down_revision = "4ace53fd72c8"
branch_labels = None
depends_on = None
def upgrade():
op.create_table(
"group",
sa.Column("id", sa.Text(), nullable=False, primary_key=True, unique=True),
sa.Column("user_id", sa.Text(), nullable=True),
sa.Column("name", sa.Text(), nullable=True),
sa.Column("description", sa.Text(), nullable=True),
sa.Column("data", sa.JSON(), nullable=True),
sa.Column("meta", sa.JSON(), nullable=True),
sa.Column("permissions", sa.JSON(), nullable=True),
sa.Column("user_ids", sa.JSON(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
)
# Add 'access_control' column to 'model' table
op.add_column(
"model",
sa.Column("access_control", sa.JSON(), nullable=True),
)
# Add 'is_active' column to 'model' table
op.add_column(
"model",
sa.Column(
"is_active",
sa.Boolean(),
nullable=False,
server_default=sa.sql.expression.true(),
),
)
# Add 'access_control' column to 'knowledge' table
op.add_column(
"knowledge",
sa.Column("access_control", sa.JSON(), nullable=True),
)
# Add 'access_control' column to 'prompt' table
op.add_column(
"prompt",
sa.Column("access_control", sa.JSON(), nullable=True),
)
# Add 'access_control' column to 'tools' table
op.add_column(
"tool",
sa.Column("access_control", sa.JSON(), nullable=True),
)
def downgrade():
op.drop_table("group")
# Drop 'access_control' column from 'model' table
op.drop_column("model", "access_control")
# Drop 'is_active' column from 'model' table
op.drop_column("model", "is_active")
# Drop 'access_control' column from 'knowledge' table
op.drop_column("knowledge", "access_control")
# Drop 'access_control' column from 'prompt' table
op.drop_column("prompt", "access_control")
# Drop 'access_control' column from 'tools' table
op.drop_column("tool", "access_control")

View File

@@ -0,0 +1,51 @@
"""Add feedback table
Revision ID: af906e964978
Revises: c29facfe716b
Create Date: 2024-10-20 17:02:35.241684
"""
from alembic import op
import sqlalchemy as sa
# Revision identifiers, used by Alembic.
revision = "af906e964978"
down_revision = "c29facfe716b"
branch_labels = None
depends_on = None
def upgrade():
# ### Create feedback table ###
op.create_table(
"feedback",
sa.Column(
"id", sa.Text(), primary_key=True
), # Unique identifier for each feedback (TEXT type)
sa.Column(
"user_id", sa.Text(), nullable=True
), # ID of the user providing the feedback (TEXT type)
sa.Column(
"version", sa.BigInteger(), default=0
), # Version of feedback (BIGINT type)
sa.Column("type", sa.Text(), nullable=True), # Type of feedback (TEXT type)
sa.Column("data", sa.JSON(), nullable=True), # Feedback data (JSON type)
sa.Column(
"meta", sa.JSON(), nullable=True
), # Metadata for feedback (JSON type)
sa.Column(
"snapshot", sa.JSON(), nullable=True
), # snapshot data for feedback (JSON type)
sa.Column(
"created_at", sa.BigInteger(), nullable=False
), # Feedback creation timestamp (BIGINT representing epoch)
sa.Column(
"updated_at", sa.BigInteger(), nullable=False
), # Feedback update timestamp (BIGINT representing epoch)
)
def downgrade():
# ### Drop feedback table ###
op.drop_table("feedback")

View File

@@ -0,0 +1,32 @@
"""Update file table
Revision ID: c0fbf31ca0db
Revises: ca81bd47c050
Create Date: 2024-09-20 15:26:35.241684
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "c0fbf31ca0db"
down_revision: Union[str, None] = "ca81bd47c050"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("file", sa.Column("hash", sa.Text(), nullable=True))
op.add_column("file", sa.Column("data", sa.JSON(), nullable=True))
op.add_column("file", sa.Column("updated_at", sa.BigInteger(), nullable=True))
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("file", "updated_at")
op.drop_column("file", "data")
op.drop_column("file", "hash")

View File

@@ -0,0 +1,79 @@
"""Update file table path
Revision ID: c29facfe716b
Revises: c69f45358db4
Create Date: 2024-10-20 17:02:35.241684
"""
from alembic import op
import sqlalchemy as sa
import json
from sqlalchemy.sql import table, column
from sqlalchemy import String, Text, JSON, and_
revision = "c29facfe716b"
down_revision = "c69f45358db4"
branch_labels = None
depends_on = None
def upgrade():
# 1. Add the `path` column to the "file" table.
op.add_column("file", sa.Column("path", sa.Text(), nullable=True))
# 2. Convert the `meta` column from Text/JSONField to `JSON()`
# Use Alembic's default batch_op for dialect compatibility.
with op.batch_alter_table("file", schema=None) as batch_op:
batch_op.alter_column(
"meta",
type_=sa.JSON(),
existing_type=sa.Text(),
existing_nullable=True,
nullable=True,
postgresql_using="meta::json",
)
# 3. Migrate legacy data from `meta` JSONField
# Fetch and process `meta` data from the table, add values to the new `path` column as necessary.
# We will use SQLAlchemy core bindings to ensure safety across different databases.
file_table = table(
"file", column("id", String), column("meta", JSON), column("path", Text)
)
# Create connection to the database
connection = op.get_bind()
# Get the rows where `meta` has a path and `path` column is null (new column)
# Loop through each row in the result set to update the path
results = connection.execute(
sa.select(file_table.c.id, file_table.c.meta).where(
and_(file_table.c.path.is_(None), file_table.c.meta.isnot(None))
)
).fetchall()
# Iterate over each row to extract and update the `path` from `meta` column
for row in results:
if "path" in row.meta:
# Extract the `path` field from the `meta` JSON
path = row.meta.get("path")
# Update the `file` table with the new `path` value
connection.execute(
file_table.update()
.where(file_table.c.id == row.id)
.values({"path": path})
)
def downgrade():
# 1. Remove the `path` column
op.drop_column("file", "path")
# 2. Revert the `meta` column back to Text/JSONField
with op.batch_alter_table("file", schema=None) as batch_op:
batch_op.alter_column(
"meta", type_=sa.Text(), existing_type=sa.JSON(), existing_nullable=True
)

View File

@@ -0,0 +1,50 @@
"""Add folder table
Revision ID: c69f45358db4
Revises: 3ab32c4b8f59
Create Date: 2024-10-16 02:02:35.241684
"""
from alembic import op
import sqlalchemy as sa
revision = "c69f45358db4"
down_revision = "3ab32c4b8f59"
branch_labels = None
depends_on = None
def upgrade():
op.create_table(
"folder",
sa.Column("id", sa.Text(), nullable=False),
sa.Column("parent_id", sa.Text(), nullable=True),
sa.Column("user_id", sa.Text(), nullable=False),
sa.Column("name", sa.Text(), nullable=False),
sa.Column("items", sa.JSON(), nullable=True),
sa.Column("meta", sa.JSON(), nullable=True),
sa.Column("is_expanded", sa.Boolean(), default=False, nullable=False),
sa.Column(
"created_at", sa.DateTime(), server_default=sa.func.now(), nullable=False
),
sa.Column(
"updated_at",
sa.DateTime(),
nullable=False,
server_default=sa.func.now(),
onupdate=sa.func.now(),
),
sa.PrimaryKeyConstraint("id", "user_id"),
)
op.add_column(
"chat",
sa.Column("folder_id", sa.Text(), nullable=True),
)
def downgrade():
op.drop_column("chat", "folder_id")
op.drop_table("folder")

View File

@@ -0,0 +1,41 @@
"""Add config table
Revision ID: ca81bd47c050
Revises: 7e5b5dc7342b
Create Date: 2024-08-25 15:26:35.241684
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "ca81bd47c050"
down_revision: Union[str, None] = "7e5b5dc7342b"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade():
op.create_table(
"config",
sa.Column("id", sa.Integer, primary_key=True),
sa.Column("data", sa.JSON(), nullable=False),
sa.Column("version", sa.Integer, nullable=False),
sa.Column(
"created_at", sa.DateTime(), nullable=False, server_default=sa.func.now()
),
sa.Column(
"updated_at",
sa.DateTime(),
nullable=True,
server_default=sa.func.now(),
onupdate=sa.func.now(),
),
)
def downgrade():
op.drop_table("config")

View File

@@ -0,0 +1,206 @@
import logging
import uuid
from typing import Optional
from open_webui.internal.db import Base, get_db
from open_webui.models.users import UserModel, Users
from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel
from sqlalchemy import Boolean, Column, String, Text
from open_webui.utils.auth import verify_password
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
# DB MODEL
####################
class Auth(Base):
__tablename__ = "auth"
id = Column(String, primary_key=True)
email = Column(String)
password = Column(Text)
active = Column(Boolean)
class AuthModel(BaseModel):
id: str
email: str
password: str
active: bool = True
####################
# Forms
####################
class Token(BaseModel):
token: str
token_type: str
class ApiKey(BaseModel):
api_key: Optional[str] = None
class UserResponse(BaseModel):
id: str
email: str
name: str
role: str
profile_image_url: str
class SigninResponse(Token, UserResponse):
pass
class SigninForm(BaseModel):
email: str
password: str
class LdapForm(BaseModel):
user: str
password: str
class ProfileImageUrlForm(BaseModel):
profile_image_url: str
class UpdateProfileForm(BaseModel):
profile_image_url: str
name: str
class UpdatePasswordForm(BaseModel):
password: str
new_password: str
class SignupForm(BaseModel):
name: str
email: str
password: str
profile_image_url: Optional[str] = "/user.png"
class AddUserForm(SignupForm):
role: Optional[str] = "pending"
class AuthsTable:
def insert_new_auth(
self,
email: str,
password: str,
name: str,
profile_image_url: str = "/user.png",
role: str = "pending",
oauth_sub: Optional[str] = None,
) -> Optional[UserModel]:
with get_db() as db:
log.info("insert_new_auth")
id = str(uuid.uuid4())
auth = AuthModel(
**{"id": id, "email": email, "password": password, "active": True}
)
result = Auth(**auth.model_dump())
db.add(result)
user = Users.insert_new_user(
id, name, email, profile_image_url, role, oauth_sub
)
db.commit()
db.refresh(result)
if result and user:
return user
else:
return None
def authenticate_user(self, email: str, password: str) -> Optional[UserModel]:
log.info(f"authenticate_user: {email}")
try:
with get_db() as db:
auth = db.query(Auth).filter_by(email=email, active=True).first()
if auth:
if verify_password(password, auth.password):
user = Users.get_user_by_id(auth.id)
return user
else:
return None
else:
return None
except Exception:
return None
def authenticate_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
log.info(f"authenticate_user_by_api_key: {api_key}")
# if no api_key, return None
if not api_key:
return None
try:
user = Users.get_user_by_api_key(api_key)
return user if user else None
except Exception:
            return None
def authenticate_user_by_trusted_header(self, email: str) -> Optional[UserModel]:
log.info(f"authenticate_user_by_trusted_header: {email}")
try:
with get_db() as db:
auth = db.query(Auth).filter_by(email=email, active=True).first()
if auth:
user = Users.get_user_by_id(auth.id)
return user
except Exception:
return None
def update_user_password_by_id(self, id: str, new_password: str) -> bool:
try:
with get_db() as db:
result = (
db.query(Auth).filter_by(id=id).update({"password": new_password})
)
db.commit()
return True if result == 1 else False
except Exception:
return False
def update_email_by_id(self, id: str, email: str) -> bool:
try:
with get_db() as db:
result = db.query(Auth).filter_by(id=id).update({"email": email})
db.commit()
return True if result == 1 else False
except Exception:
return False
def delete_auth_by_id(self, id: str) -> bool:
try:
with get_db() as db:
# Delete User
result = Users.delete_user_by_id(id)
if result:
db.query(Auth).filter_by(id=id).delete()
db.commit()
return True
else:
return False
except Exception:
return False
Auths = AuthsTable()
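# --- usage sketch (editor's annotation; names marked as assumed are not in this file) ---
# A minimal example of the AuthsTable singleton above. The password must already be
# hashed before insert_new_auth(); the helper name `get_password_hash` is assumed to
# live alongside verify_password in open_webui.utils.auth.
from open_webui.models.auths import Auths
from open_webui.utils.auth import get_password_hash  # assumed counterpart of verify_password

user = Auths.insert_new_auth(
    email="jane@example.com",
    password=get_password_hash("s3cret"),  # store a hash, never the plaintext
    name="Jane",
    role="user",
)

# Sign-in later checks the plaintext against the stored hash via verify_password():
authenticated = Auths.authenticate_user("jane@example.com", "s3cret")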

View File

@@ -0,0 +1,136 @@
import json
import time
import uuid
from typing import Optional
from open_webui.internal.db import Base, get_db
from open_webui.utils.access_control import has_access
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
from sqlalchemy import or_, func, select, and_, text
from sqlalchemy.sql import exists
####################
# Channel DB Schema
####################
class Channel(Base):
__tablename__ = "channel"
id = Column(Text, primary_key=True)
user_id = Column(Text)
type = Column(Text, nullable=True)
name = Column(Text)
description = Column(Text, nullable=True)
data = Column(JSON, nullable=True)
meta = Column(JSON, nullable=True)
access_control = Column(JSON, nullable=True)
created_at = Column(BigInteger)
updated_at = Column(BigInteger)
class ChannelModel(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: str
user_id: str
type: Optional[str] = None
name: str
description: Optional[str] = None
data: Optional[dict] = None
meta: Optional[dict] = None
access_control: Optional[dict] = None
created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
####################
# Forms
####################
class ChannelForm(BaseModel):
name: str
description: Optional[str] = None
data: Optional[dict] = None
meta: Optional[dict] = None
access_control: Optional[dict] = None
class ChannelTable:
def insert_new_channel(
self, type: Optional[str], form_data: ChannelForm, user_id: str
) -> Optional[ChannelModel]:
with get_db() as db:
channel = ChannelModel(
**{
**form_data.model_dump(),
"type": type,
"name": form_data.name.lower(),
"id": str(uuid.uuid4()),
"user_id": user_id,
"created_at": int(time.time_ns()),
"updated_at": int(time.time_ns()),
}
)
new_channel = Channel(**channel.model_dump())
db.add(new_channel)
db.commit()
return channel
def get_channels(self) -> list[ChannelModel]:
with get_db() as db:
channels = db.query(Channel).all()
return [ChannelModel.model_validate(channel) for channel in channels]
def get_channels_by_user_id(
self, user_id: str, permission: str = "read"
) -> list[ChannelModel]:
channels = self.get_channels()
return [
channel
for channel in channels
if channel.user_id == user_id
or has_access(user_id, permission, channel.access_control)
]
def get_channel_by_id(self, id: str) -> Optional[ChannelModel]:
with get_db() as db:
channel = db.query(Channel).filter(Channel.id == id).first()
return ChannelModel.model_validate(channel) if channel else None
def update_channel_by_id(
self, id: str, form_data: ChannelForm
) -> Optional[ChannelModel]:
with get_db() as db:
channel = db.query(Channel).filter(Channel.id == id).first()
if not channel:
return None
channel.name = form_data.name
channel.data = form_data.data
channel.meta = form_data.meta
channel.access_control = form_data.access_control
channel.updated_at = int(time.time_ns())
db.commit()
return ChannelModel.model_validate(channel) if channel else None
def delete_channel_by_id(self, id: str):
with get_db() as db:
db.query(Channel).filter(Channel.id == id).delete()
db.commit()
return True
Channels = ChannelTable()
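# --- usage sketch (editor's annotation) ---
# A minimal example of the ChannelTable singleton above; the channel type and the
# user id are illustrative values. Note that insert_new_channel() lowercases the
# channel name before storing it.
from open_webui.models.channels import Channels, ChannelForm

channel = Channels.insert_new_channel(
    type="channel",  # illustrative; the column is free-form Text
    form_data=ChannelForm(name="General", description="Team-wide discussion"),
    user_id="user-123",
)

# Returns channels the user owns or can reach through access_control:
visible = Channels.get_channels_by_user_id("user-123")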

View File

@@ -0,0 +1,912 @@
import logging
import json
import time
import uuid
from typing import Optional
from open_webui.internal.db import Base, get_db
from open_webui.models.tags import TagModel, Tag, Tags
from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
from sqlalchemy import or_, func, select, and_, text
from sqlalchemy.sql import exists
####################
# Chat DB Schema
####################
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
class Chat(Base):
__tablename__ = "chat"
id = Column(String, primary_key=True)
user_id = Column(String)
title = Column(Text)
chat = Column(JSON)
created_at = Column(BigInteger)
updated_at = Column(BigInteger)
share_id = Column(Text, unique=True, nullable=True)
archived = Column(Boolean, default=False)
pinned = Column(Boolean, default=False, nullable=True)
meta = Column(JSON, server_default="{}")
folder_id = Column(Text, nullable=True)
class ChatModel(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: str
user_id: str
title: str
chat: dict
created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
share_id: Optional[str] = None
archived: bool = False
pinned: Optional[bool] = False
meta: dict = {}
folder_id: Optional[str] = None
####################
# Forms
####################
class ChatForm(BaseModel):
chat: dict
class ChatImportForm(ChatForm):
meta: Optional[dict] = {}
pinned: Optional[bool] = False
folder_id: Optional[str] = None
class ChatTitleMessagesForm(BaseModel):
title: str
messages: list[dict]
class ChatTitleForm(BaseModel):
title: str
class ChatResponse(BaseModel):
id: str
user_id: str
title: str
chat: dict
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
share_id: Optional[str] = None # id of the chat to be shared
archived: bool
pinned: Optional[bool] = False
meta: dict = {}
folder_id: Optional[str] = None
class ChatTitleIdResponse(BaseModel):
id: str
title: str
updated_at: int
created_at: int
class ChatTable:
def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]:
with get_db() as db:
id = str(uuid.uuid4())
chat = ChatModel(
**{
"id": id,
"user_id": user_id,
"title": (
form_data.chat["title"]
if "title" in form_data.chat
else "New Chat"
),
"chat": form_data.chat,
"created_at": int(time.time()),
"updated_at": int(time.time()),
}
)
result = Chat(**chat.model_dump())
db.add(result)
db.commit()
db.refresh(result)
return ChatModel.model_validate(result) if result else None
def import_chat(
self, user_id: str, form_data: ChatImportForm
) -> Optional[ChatModel]:
with get_db() as db:
id = str(uuid.uuid4())
chat = ChatModel(
**{
"id": id,
"user_id": user_id,
"title": (
form_data.chat["title"]
if "title" in form_data.chat
else "New Chat"
),
"chat": form_data.chat,
"meta": form_data.meta,
"pinned": form_data.pinned,
"folder_id": form_data.folder_id,
"created_at": int(time.time()),
"updated_at": int(time.time()),
}
)
result = Chat(**chat.model_dump())
db.add(result)
db.commit()
db.refresh(result)
return ChatModel.model_validate(result) if result else None
def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
try:
with get_db() as db:
chat_item = db.get(Chat, id)
chat_item.chat = chat
chat_item.title = chat["title"] if "title" in chat else "New Chat"
chat_item.updated_at = int(time.time())
db.commit()
db.refresh(chat_item)
return ChatModel.model_validate(chat_item)
except Exception:
return None
def update_chat_title_by_id(self, id: str, title: str) -> Optional[ChatModel]:
chat = self.get_chat_by_id(id)
if chat is None:
return None
chat = chat.chat
chat["title"] = title
return self.update_chat_by_id(id, chat)
def update_chat_tags_by_id(
self, id: str, tags: list[str], user
) -> Optional[ChatModel]:
chat = self.get_chat_by_id(id)
if chat is None:
return None
self.delete_all_tags_by_id_and_user_id(id, user.id)
for tag in chat.meta.get("tags", []):
if self.count_chats_by_tag_name_and_user_id(tag, user.id) == 0:
Tags.delete_tag_by_name_and_user_id(tag, user.id)
for tag_name in tags:
if tag_name.lower() == "none":
continue
self.add_chat_tag_by_id_and_user_id_and_tag_name(id, user.id, tag_name)
return self.get_chat_by_id(id)
def get_chat_title_by_id(self, id: str) -> Optional[str]:
chat = self.get_chat_by_id(id)
if chat is None:
return None
return chat.chat.get("title", "New Chat")
def get_messages_by_chat_id(self, id: str) -> Optional[dict]:
chat = self.get_chat_by_id(id)
if chat is None:
return None
return chat.chat.get("history", {}).get("messages", {}) or {}
def get_message_by_id_and_message_id(
self, id: str, message_id: str
) -> Optional[dict]:
chat = self.get_chat_by_id(id)
if chat is None:
return None
return chat.chat.get("history", {}).get("messages", {}).get(message_id, {})
def upsert_message_to_chat_by_id_and_message_id(
self, id: str, message_id: str, message: dict
) -> Optional[ChatModel]:
chat = self.get_chat_by_id(id)
if chat is None:
return None
chat = chat.chat
history = chat.get("history", {})
if message_id in history.get("messages", {}):
history["messages"][message_id] = {
**history["messages"][message_id],
**message,
}
else:
history["messages"][message_id] = message
history["currentId"] = message_id
chat["history"] = history
return self.update_chat_by_id(id, chat)
def add_message_status_to_chat_by_id_and_message_id(
self, id: str, message_id: str, status: dict
) -> Optional[ChatModel]:
chat = self.get_chat_by_id(id)
if chat is None:
return None
chat = chat.chat
history = chat.get("history", {})
if message_id in history.get("messages", {}):
status_history = history["messages"][message_id].get("statusHistory", [])
status_history.append(status)
history["messages"][message_id]["statusHistory"] = status_history
chat["history"] = history
return self.update_chat_by_id(id, chat)
def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
with get_db() as db:
# Get the existing chat to share
chat = db.get(Chat, chat_id)
# Check if the chat is already shared
if chat.share_id:
return self.get_chat_by_id_and_user_id(chat.share_id, "shared")
# Create a new chat with the same data, but with a new ID
shared_chat = ChatModel(
**{
"id": str(uuid.uuid4()),
"user_id": f"shared-{chat_id}",
"title": chat.title,
"chat": chat.chat,
"created_at": chat.created_at,
"updated_at": int(time.time()),
}
)
shared_result = Chat(**shared_chat.model_dump())
db.add(shared_result)
db.commit()
db.refresh(shared_result)
# Update the original chat with the share_id
result = (
db.query(Chat)
.filter_by(id=chat_id)
.update({"share_id": shared_chat.id})
)
db.commit()
return shared_chat if (shared_result and result) else None
def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
try:
with get_db() as db:
chat = db.get(Chat, chat_id)
shared_chat = (
db.query(Chat).filter_by(user_id=f"shared-{chat_id}").first()
)
if shared_chat is None:
return self.insert_shared_chat_by_chat_id(chat_id)
shared_chat.title = chat.title
shared_chat.chat = chat.chat
shared_chat.updated_at = int(time.time())
db.commit()
db.refresh(shared_chat)
return ChatModel.model_validate(shared_chat)
except Exception:
return None
def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool:
try:
with get_db() as db:
db.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete()
db.commit()
return True
except Exception:
return False
def update_chat_share_id_by_id(
self, id: str, share_id: Optional[str]
) -> Optional[ChatModel]:
try:
with get_db() as db:
chat = db.get(Chat, id)
chat.share_id = share_id
db.commit()
db.refresh(chat)
return ChatModel.model_validate(chat)
except Exception:
return None
def toggle_chat_pinned_by_id(self, id: str) -> Optional[ChatModel]:
try:
with get_db() as db:
chat = db.get(Chat, id)
chat.pinned = not chat.pinned
chat.updated_at = int(time.time())
db.commit()
db.refresh(chat)
return ChatModel.model_validate(chat)
except Exception:
return None
def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]:
try:
with get_db() as db:
chat = db.get(Chat, id)
chat.archived = not chat.archived
chat.updated_at = int(time.time())
db.commit()
db.refresh(chat)
return ChatModel.model_validate(chat)
except Exception:
return None
def archive_all_chats_by_user_id(self, user_id: str) -> bool:
try:
with get_db() as db:
db.query(Chat).filter_by(user_id=user_id).update({"archived": True})
db.commit()
return True
except Exception:
return False
def get_archived_chat_list_by_user_id(
self, user_id: str, skip: int = 0, limit: int = 50
) -> list[ChatModel]:
with get_db() as db:
all_chats = (
db.query(Chat)
.filter_by(user_id=user_id, archived=True)
.order_by(Chat.updated_at.desc())
# .limit(limit).offset(skip)
.all()
)
return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chat_list_by_user_id(
self,
user_id: str,
include_archived: bool = False,
skip: int = 0,
limit: int = 50,
) -> list[ChatModel]:
with get_db() as db:
query = db.query(Chat).filter_by(user_id=user_id)
if not include_archived:
query = query.filter_by(archived=False)
query = query.order_by(Chat.updated_at.desc())
if skip:
query = query.offset(skip)
if limit:
query = query.limit(limit)
all_chats = query.all()
return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chat_title_id_list_by_user_id(
self,
user_id: str,
include_archived: bool = False,
skip: Optional[int] = None,
limit: Optional[int] = None,
) -> list[ChatTitleIdResponse]:
with get_db() as db:
query = db.query(Chat).filter_by(user_id=user_id).filter_by(folder_id=None)
query = query.filter(or_(Chat.pinned == False, Chat.pinned == None))
if not include_archived:
query = query.filter_by(archived=False)
query = query.order_by(Chat.updated_at.desc()).with_entities(
Chat.id, Chat.title, Chat.updated_at, Chat.created_at
)
if skip:
query = query.offset(skip)
if limit:
query = query.limit(limit)
all_chats = query.all()
            # each result row has to be destructured from the sqlalchemy `Row` and mapped to a dict, since `ChatModel` is not the returned dataclass
return [
ChatTitleIdResponse.model_validate(
{
"id": chat[0],
"title": chat[1],
"updated_at": chat[2],
"created_at": chat[3],
}
)
for chat in all_chats
]
def get_chat_list_by_chat_ids(
self, chat_ids: list[str], skip: int = 0, limit: int = 50
) -> list[ChatModel]:
with get_db() as db:
all_chats = (
db.query(Chat)
.filter(Chat.id.in_(chat_ids))
.filter_by(archived=False)
.order_by(Chat.updated_at.desc())
.all()
)
return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chat_by_id(self, id: str) -> Optional[ChatModel]:
try:
with get_db() as db:
chat = db.get(Chat, id)
return ChatModel.model_validate(chat)
except Exception:
return None
def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]:
try:
with get_db() as db:
# it is possible that the shared link was deleted. hence,
# we check if the chat is still shared by checking if a chat with the share_id exists
chat = db.query(Chat).filter_by(share_id=id).first()
if chat:
return self.get_chat_by_id(id)
else:
return None
except Exception:
return None
def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]:
try:
with get_db() as db:
chat = db.query(Chat).filter_by(id=id, user_id=user_id).first()
return ChatModel.model_validate(chat)
except Exception:
return None
def get_chats(self, skip: int = 0, limit: int = 50) -> list[ChatModel]:
with get_db() as db:
all_chats = (
db.query(Chat)
# .limit(limit).offset(skip)
.order_by(Chat.updated_at.desc())
)
return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chats_by_user_id(self, user_id: str) -> list[ChatModel]:
with get_db() as db:
all_chats = (
db.query(Chat)
.filter_by(user_id=user_id)
.order_by(Chat.updated_at.desc())
)
return [ChatModel.model_validate(chat) for chat in all_chats]
def get_pinned_chats_by_user_id(self, user_id: str) -> list[ChatModel]:
with get_db() as db:
all_chats = (
db.query(Chat)
.filter_by(user_id=user_id, pinned=True, archived=False)
.order_by(Chat.updated_at.desc())
)
return [ChatModel.model_validate(chat) for chat in all_chats]
def get_archived_chats_by_user_id(self, user_id: str) -> list[ChatModel]:
with get_db() as db:
all_chats = (
db.query(Chat)
.filter_by(user_id=user_id, archived=True)
.order_by(Chat.updated_at.desc())
)
return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chats_by_user_id_and_search_text(
self,
user_id: str,
search_text: str,
include_archived: bool = False,
skip: int = 0,
limit: int = 60,
) -> list[ChatModel]:
"""
        Filter chats by a search query over the title and message content (with optional
        'tag:<name>' terms), applying skip/limit pagination at the SQL level.
"""
search_text = search_text.lower().strip()
if not search_text:
return self.get_chat_list_by_user_id(user_id, include_archived, skip, limit)
search_text_words = search_text.split(" ")
# search_text might contain 'tag:tag_name' format so we need to extract the tag_name, split the search_text and remove the tags
tag_ids = [
word.replace("tag:", "").replace(" ", "_").lower()
for word in search_text_words
if word.startswith("tag:")
]
search_text_words = [
word for word in search_text_words if not word.startswith("tag:")
]
search_text = " ".join(search_text_words)
with get_db() as db:
query = db.query(Chat).filter(Chat.user_id == user_id)
if not include_archived:
query = query.filter(Chat.archived == False)
query = query.order_by(Chat.updated_at.desc())
# Check if the database dialect is either 'sqlite' or 'postgresql'
dialect_name = db.bind.dialect.name
if dialect_name == "sqlite":
# SQLite case: using JSON1 extension for JSON searching
query = query.filter(
(
Chat.title.ilike(
f"%{search_text}%"
) # Case-insensitive search in title
| text(
"""
EXISTS (
SELECT 1
FROM json_each(Chat.chat, '$.messages') AS message
WHERE LOWER(message.value->>'content') LIKE '%' || :search_text || '%'
)
"""
)
).params(search_text=search_text)
)
# Check if there are any tags to filter, it should have all the tags
if "none" in tag_ids:
query = query.filter(
text(
"""
NOT EXISTS (
SELECT 1
FROM json_each(Chat.meta, '$.tags') AS tag
)
"""
)
)
elif tag_ids:
query = query.filter(
and_(
*[
text(
f"""
EXISTS (
SELECT 1
FROM json_each(Chat.meta, '$.tags') AS tag
WHERE tag.value = :tag_id_{tag_idx}
)
"""
).params(**{f"tag_id_{tag_idx}": tag_id})
for tag_idx, tag_id in enumerate(tag_ids)
]
)
)
elif dialect_name == "postgresql":
# PostgreSQL relies on proper JSON query for search
query = query.filter(
(
Chat.title.ilike(
f"%{search_text}%"
) # Case-insensitive search in title
| text(
"""
EXISTS (
SELECT 1
FROM json_array_elements(Chat.chat->'messages') AS message
WHERE LOWER(message->>'content') LIKE '%' || :search_text || '%'
)
"""
)
).params(search_text=search_text)
)
# Check if there are any tags to filter, it should have all the tags
if "none" in tag_ids:
query = query.filter(
text(
"""
NOT EXISTS (
SELECT 1
FROM json_array_elements_text(Chat.meta->'tags') AS tag
)
"""
)
)
elif tag_ids:
query = query.filter(
and_(
*[
text(
f"""
EXISTS (
SELECT 1
FROM json_array_elements_text(Chat.meta->'tags') AS tag
WHERE tag = :tag_id_{tag_idx}
)
"""
).params(**{f"tag_id_{tag_idx}": tag_id})
for tag_idx, tag_id in enumerate(tag_ids)
]
)
)
else:
raise NotImplementedError(
f"Unsupported dialect: {db.bind.dialect.name}"
)
# Perform pagination at the SQL level
all_chats = query.offset(skip).limit(limit).all()
log.info(f"The number of chats: {len(all_chats)}")
# Validate and return chats
return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chats_by_folder_id_and_user_id(
self, folder_id: str, user_id: str
) -> list[ChatModel]:
with get_db() as db:
query = db.query(Chat).filter_by(folder_id=folder_id, user_id=user_id)
query = query.filter(or_(Chat.pinned == False, Chat.pinned == None))
query = query.filter_by(archived=False)
query = query.order_by(Chat.updated_at.desc())
all_chats = query.all()
return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chats_by_folder_ids_and_user_id(
self, folder_ids: list[str], user_id: str
) -> list[ChatModel]:
with get_db() as db:
query = db.query(Chat).filter(
Chat.folder_id.in_(folder_ids), Chat.user_id == user_id
)
query = query.filter(or_(Chat.pinned == False, Chat.pinned == None))
query = query.filter_by(archived=False)
query = query.order_by(Chat.updated_at.desc())
all_chats = query.all()
return [ChatModel.model_validate(chat) for chat in all_chats]
def update_chat_folder_id_by_id_and_user_id(
self, id: str, user_id: str, folder_id: str
) -> Optional[ChatModel]:
try:
with get_db() as db:
chat = db.get(Chat, id)
chat.folder_id = folder_id
chat.updated_at = int(time.time())
chat.pinned = False
db.commit()
db.refresh(chat)
return ChatModel.model_validate(chat)
except Exception:
return None
def get_chat_tags_by_id_and_user_id(self, id: str, user_id: str) -> list[TagModel]:
with get_db() as db:
chat = db.get(Chat, id)
tags = chat.meta.get("tags", [])
return [Tags.get_tag_by_name_and_user_id(tag, user_id) for tag in tags]
def get_chat_list_by_user_id_and_tag_name(
self, user_id: str, tag_name: str, skip: int = 0, limit: int = 50
) -> list[ChatModel]:
with get_db() as db:
query = db.query(Chat).filter_by(user_id=user_id)
tag_id = tag_name.replace(" ", "_").lower()
log.info(f"DB dialect name: {db.bind.dialect.name}")
if db.bind.dialect.name == "sqlite":
# SQLite JSON1 querying for tags within the meta JSON field
query = query.filter(
text(
f"EXISTS (SELECT 1 FROM json_each(Chat.meta, '$.tags') WHERE json_each.value = :tag_id)"
)
).params(tag_id=tag_id)
elif db.bind.dialect.name == "postgresql":
# PostgreSQL JSON query for tags within the meta JSON field (for `json` type)
query = query.filter(
text(
"EXISTS (SELECT 1 FROM json_array_elements_text(Chat.meta->'tags') elem WHERE elem = :tag_id)"
)
).params(tag_id=tag_id)
else:
raise NotImplementedError(
f"Unsupported dialect: {db.bind.dialect.name}"
)
all_chats = query.all()
log.debug(f"all_chats: {all_chats}")
return [ChatModel.model_validate(chat) for chat in all_chats]
def add_chat_tag_by_id_and_user_id_and_tag_name(
self, id: str, user_id: str, tag_name: str
) -> Optional[ChatModel]:
tag = Tags.get_tag_by_name_and_user_id(tag_name, user_id)
if tag is None:
tag = Tags.insert_new_tag(tag_name, user_id)
try:
with get_db() as db:
chat = db.get(Chat, id)
tag_id = tag.id
if tag_id not in chat.meta.get("tags", []):
chat.meta = {
**chat.meta,
"tags": list(set(chat.meta.get("tags", []) + [tag_id])),
}
db.commit()
db.refresh(chat)
return ChatModel.model_validate(chat)
except Exception:
return None
def count_chats_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> int:
with get_db() as db: # Assuming `get_db()` returns a session object
query = db.query(Chat).filter_by(user_id=user_id, archived=False)
# Normalize the tag_name for consistency
tag_id = tag_name.replace(" ", "_").lower()
if db.bind.dialect.name == "sqlite":
# SQLite JSON1 support for querying the tags inside the `meta` JSON field
query = query.filter(
text(
f"EXISTS (SELECT 1 FROM json_each(Chat.meta, '$.tags') WHERE json_each.value = :tag_id)"
)
).params(tag_id=tag_id)
elif db.bind.dialect.name == "postgresql":
# PostgreSQL JSONB support for querying the tags inside the `meta` JSON field
query = query.filter(
text(
"EXISTS (SELECT 1 FROM json_array_elements_text(Chat.meta->'tags') elem WHERE elem = :tag_id)"
)
).params(tag_id=tag_id)
else:
raise NotImplementedError(
f"Unsupported dialect: {db.bind.dialect.name}"
)
# Get the count of matching records
count = query.count()
# Debugging output for inspection
log.info(f"Count of chats for tag '{tag_name}': {count}")
return count
def delete_tag_by_id_and_user_id_and_tag_name(
self, id: str, user_id: str, tag_name: str
) -> bool:
try:
with get_db() as db:
chat = db.get(Chat, id)
tags = chat.meta.get("tags", [])
tag_id = tag_name.replace(" ", "_").lower()
tags = [tag for tag in tags if tag != tag_id]
chat.meta = {
**chat.meta,
"tags": list(set(tags)),
}
db.commit()
return True
except Exception:
return False
def delete_all_tags_by_id_and_user_id(self, id: str, user_id: str) -> bool:
try:
with get_db() as db:
chat = db.get(Chat, id)
chat.meta = {
**chat.meta,
"tags": [],
}
db.commit()
return True
except Exception:
return False
def delete_chat_by_id(self, id: str) -> bool:
try:
with get_db() as db:
db.query(Chat).filter_by(id=id).delete()
db.commit()
return True and self.delete_shared_chat_by_chat_id(id)
except Exception:
return False
def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool:
try:
with get_db() as db:
db.query(Chat).filter_by(id=id, user_id=user_id).delete()
db.commit()
return True and self.delete_shared_chat_by_chat_id(id)
except Exception:
return False
def delete_chats_by_user_id(self, user_id: str) -> bool:
try:
with get_db() as db:
self.delete_shared_chats_by_user_id(user_id)
db.query(Chat).filter_by(user_id=user_id).delete()
db.commit()
return True
except Exception:
return False
def delete_chats_by_user_id_and_folder_id(
self, user_id: str, folder_id: str
) -> bool:
try:
with get_db() as db:
db.query(Chat).filter_by(user_id=user_id, folder_id=folder_id).delete()
db.commit()
return True
except Exception:
return False
def delete_shared_chats_by_user_id(self, user_id: str) -> bool:
try:
with get_db() as db:
chats_by_user = db.query(Chat).filter_by(user_id=user_id).all()
shared_chat_ids = [f"shared-{chat.id}" for chat in chats_by_user]
db.query(Chat).filter(Chat.user_id.in_(shared_chat_ids)).delete()
db.commit()
return True
except Exception:
return False
Chats = ChatTable()
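# A minimal usage sketch of the tag helpers above, assuming a configured
# database and an existing chat row; the ids below are hypothetical.
def _example_chat_tag_workflow() -> None:
    """Tag a chat, count chats carrying that tag, then remove the tag again."""
    chat_id, user_id = "chat-123", "user-abc"  # hypothetical ids
    # Creates the tag on first use (its id is normalised to "project_x") and
    # stores that id inside chat.meta["tags"].
    Chats.add_chat_tag_by_id_and_user_id_and_tag_name(chat_id, user_id, "Project X")
    # Counting uses dialect-specific JSON queries: json_each on SQLite,
    # json_array_elements_text on PostgreSQL.
    count = Chats.count_chats_by_tag_name_and_user_id("Project X", user_id)
    log.debug(f"chats tagged 'Project X': {count}")
    # Remove a single tag from the chat, or wipe all of its tags.
    Chats.delete_tag_by_id_and_user_id_and_tag_name(chat_id, user_id, "Project X")
    Chats.delete_all_tags_by_id_and_user_id(chat_id, user_id)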

View File

@ -0,0 +1,254 @@
import logging
import time
import uuid
from typing import Optional
from open_webui.internal.db import Base, get_db
from open_webui.models.chats import Chats
from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, Text, JSON, Boolean
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
# Feedback DB Schema
####################
class Feedback(Base):
__tablename__ = "feedback"
id = Column(Text, primary_key=True)
user_id = Column(Text)
version = Column(BigInteger, default=0)
type = Column(Text)
data = Column(JSON, nullable=True)
meta = Column(JSON, nullable=True)
snapshot = Column(JSON, nullable=True)
created_at = Column(BigInteger)
updated_at = Column(BigInteger)
class FeedbackModel(BaseModel):
id: str
user_id: str
version: int
type: str
data: Optional[dict] = None
meta: Optional[dict] = None
snapshot: Optional[dict] = None
created_at: int
updated_at: int
model_config = ConfigDict(from_attributes=True)
####################
# Forms
####################
class FeedbackResponse(BaseModel):
id: str
user_id: str
version: int
type: str
data: Optional[dict] = None
meta: Optional[dict] = None
created_at: int
updated_at: int
class RatingData(BaseModel):
rating: Optional[str | int] = None
model_id: Optional[str] = None
sibling_model_ids: Optional[list[str]] = None
reason: Optional[str] = None
comment: Optional[str] = None
model_config = ConfigDict(extra="allow", protected_namespaces=())
class MetaData(BaseModel):
arena: Optional[bool] = None
chat_id: Optional[str] = None
message_id: Optional[str] = None
tags: Optional[list[str]] = None
model_config = ConfigDict(extra="allow")
class SnapshotData(BaseModel):
chat: Optional[dict] = None
model_config = ConfigDict(extra="allow")
class FeedbackForm(BaseModel):
type: str
data: Optional[RatingData] = None
meta: Optional[dict] = None
snapshot: Optional[SnapshotData] = None
model_config = ConfigDict(extra="allow")
class FeedbackTable:
def insert_new_feedback(
self, user_id: str, form_data: FeedbackForm
) -> Optional[FeedbackModel]:
with get_db() as db:
id = str(uuid.uuid4())
feedback = FeedbackModel(
**{
"id": id,
"user_id": user_id,
"version": 0,
**form_data.model_dump(),
"created_at": int(time.time()),
"updated_at": int(time.time()),
}
)
try:
result = Feedback(**feedback.model_dump())
db.add(result)
db.commit()
db.refresh(result)
if result:
return FeedbackModel.model_validate(result)
else:
return None
except Exception as e:
log.exception(f"Error creating a new feedback: {e}")
return None
def get_feedback_by_id(self, id: str) -> Optional[FeedbackModel]:
try:
with get_db() as db:
feedback = db.query(Feedback).filter_by(id=id).first()
if not feedback:
return None
return FeedbackModel.model_validate(feedback)
except Exception:
return None
def get_feedback_by_id_and_user_id(
self, id: str, user_id: str
) -> Optional[FeedbackModel]:
try:
with get_db() as db:
feedback = db.query(Feedback).filter_by(id=id, user_id=user_id).first()
if not feedback:
return None
return FeedbackModel.model_validate(feedback)
except Exception:
return None
def get_all_feedbacks(self) -> list[FeedbackModel]:
with get_db() as db:
return [
FeedbackModel.model_validate(feedback)
for feedback in db.query(Feedback)
.order_by(Feedback.updated_at.desc())
.all()
]
def get_feedbacks_by_type(self, type: str) -> list[FeedbackModel]:
with get_db() as db:
return [
FeedbackModel.model_validate(feedback)
for feedback in db.query(Feedback)
.filter_by(type=type)
.order_by(Feedback.updated_at.desc())
.all()
]
def get_feedbacks_by_user_id(self, user_id: str) -> list[FeedbackModel]:
with get_db() as db:
return [
FeedbackModel.model_validate(feedback)
for feedback in db.query(Feedback)
.filter_by(user_id=user_id)
.order_by(Feedback.updated_at.desc())
.all()
]
def update_feedback_by_id(
self, id: str, form_data: FeedbackForm
) -> Optional[FeedbackModel]:
with get_db() as db:
feedback = db.query(Feedback).filter_by(id=id).first()
if not feedback:
return None
if form_data.data:
feedback.data = form_data.data.model_dump()
if form_data.meta:
feedback.meta = form_data.meta
if form_data.snapshot:
feedback.snapshot = form_data.snapshot.model_dump()
feedback.updated_at = int(time.time())
db.commit()
return FeedbackModel.model_validate(feedback)
def update_feedback_by_id_and_user_id(
self, id: str, user_id: str, form_data: FeedbackForm
) -> Optional[FeedbackModel]:
with get_db() as db:
feedback = db.query(Feedback).filter_by(id=id, user_id=user_id).first()
if not feedback:
return None
if form_data.data:
feedback.data = form_data.data.model_dump()
if form_data.meta:
feedback.meta = form_data.meta
if form_data.snapshot:
feedback.snapshot = form_data.snapshot.model_dump()
feedback.updated_at = int(time.time())
db.commit()
return FeedbackModel.model_validate(feedback)
def delete_feedback_by_id(self, id: str) -> bool:
with get_db() as db:
feedback = db.query(Feedback).filter_by(id=id).first()
if not feedback:
return False
db.delete(feedback)
db.commit()
return True
def delete_feedback_by_id_and_user_id(self, id: str, user_id: str) -> bool:
with get_db() as db:
feedback = db.query(Feedback).filter_by(id=id, user_id=user_id).first()
if not feedback:
return False
db.delete(feedback)
db.commit()
return True
def delete_feedbacks_by_user_id(self, user_id: str) -> bool:
with get_db() as db:
feedbacks = db.query(Feedback).filter_by(user_id=user_id).all()
if not feedbacks:
return False
for feedback in feedbacks:
db.delete(feedback)
db.commit()
return True
def delete_all_feedbacks(self) -> bool:
with get_db() as db:
feedbacks = db.query(Feedback).all()
if not feedbacks:
return False
for feedback in feedbacks:
db.delete(feedback)
db.commit()
return True
Feedbacks = FeedbackTable()
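# A minimal usage sketch of the feedback helpers above, assuming a configured
# database; the user, chat and model ids are hypothetical.
def _example_feedback_roundtrip() -> None:
    """Insert a rating feedback, update its comment, then delete it."""
    form = FeedbackForm(
        type="rating",
        data=RatingData(rating=1, model_id="model-1", comment="helpful answer"),
        meta={"chat_id": "chat-123", "message_id": "msg-456"},
    )
    feedback = Feedbacks.insert_new_feedback(user_id="user-abc", form_data=form)
    if feedback:
        # update_feedback_by_id only overwrites the fields that are set on the form.
        form.data.comment = "very helpful answer"
        Feedbacks.update_feedback_by_id(feedback.id, form)
        Feedbacks.delete_feedback_by_id(feedback.id)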

View File

@ -0,0 +1,235 @@
import logging
import time
from typing import Optional
from open_webui.internal.db import Base, JSONField, get_db
from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text, JSON
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
# Files DB Schema
####################
class File(Base):
__tablename__ = "file"
id = Column(String, primary_key=True)
user_id = Column(String)
hash = Column(Text, nullable=True)
filename = Column(Text)
path = Column(Text, nullable=True)
data = Column(JSON, nullable=True)
meta = Column(JSON, nullable=True)
access_control = Column(JSON, nullable=True)
created_at = Column(BigInteger)
updated_at = Column(BigInteger)
class FileModel(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: str
user_id: str
hash: Optional[str] = None
filename: str
path: Optional[str] = None
data: Optional[dict] = None
meta: Optional[dict] = None
access_control: Optional[dict] = None
created_at: Optional[int] # timestamp in epoch
updated_at: Optional[int] # timestamp in epoch
####################
# Forms
####################
class FileMeta(BaseModel):
name: Optional[str] = None
content_type: Optional[str] = None
size: Optional[int] = None
model_config = ConfigDict(extra="allow")
class FileModelResponse(BaseModel):
id: str
user_id: str
hash: Optional[str] = None
filename: str
data: Optional[dict] = None
meta: FileMeta
created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
model_config = ConfigDict(extra="allow")
class FileMetadataResponse(BaseModel):
id: str
meta: dict
created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
class FileForm(BaseModel):
id: str
hash: Optional[str] = None
filename: str
path: str
data: dict = {}
meta: dict = {}
access_control: Optional[dict] = None
class FilesTable:
def insert_new_file(self, user_id: str, form_data: FileForm) -> Optional[FileModel]:
with get_db() as db:
file = FileModel(
**{
**form_data.model_dump(),
"user_id": user_id,
"created_at": int(time.time()),
"updated_at": int(time.time()),
}
)
try:
result = File(**file.model_dump())
db.add(result)
db.commit()
db.refresh(result)
if result:
return FileModel.model_validate(result)
else:
return None
except Exception as e:
log.exception(f"Error inserting a new file: {e}")
return None
def get_file_by_id(self, id: str) -> Optional[FileModel]:
with get_db() as db:
try:
file = db.get(File, id)
return FileModel.model_validate(file)
except Exception:
return None
def get_file_metadata_by_id(self, id: str) -> Optional[FileMetadataResponse]:
with get_db() as db:
try:
file = db.get(File, id)
return FileMetadataResponse(
id=file.id,
meta=file.meta,
created_at=file.created_at,
updated_at=file.updated_at,
)
except Exception:
return None
def get_files(self) -> list[FileModel]:
with get_db() as db:
return [FileModel.model_validate(file) for file in db.query(File).all()]
def get_files_by_ids(self, ids: list[str]) -> list[FileModel]:
with get_db() as db:
return [
FileModel.model_validate(file)
for file in db.query(File)
.filter(File.id.in_(ids))
.order_by(File.updated_at.desc())
.all()
]
def get_file_metadatas_by_ids(self, ids: list[str]) -> list[FileMetadataResponse]:
with get_db() as db:
return [
FileMetadataResponse(
id=file.id,
meta=file.meta,
created_at=file.created_at,
updated_at=file.updated_at,
)
for file in db.query(File)
.filter(File.id.in_(ids))
.order_by(File.updated_at.desc())
.all()
]
def get_files_by_user_id(self, user_id: str) -> list[FileModel]:
with get_db() as db:
return [
FileModel.model_validate(file)
for file in db.query(File).filter_by(user_id=user_id).all()
]
def update_file_hash_by_id(self, id: str, hash: str) -> Optional[FileModel]:
with get_db() as db:
try:
file = db.query(File).filter_by(id=id).first()
file.hash = hash
db.commit()
return FileModel.model_validate(file)
except Exception:
return None
def update_file_data_by_id(self, id: str, data: dict) -> Optional[FileModel]:
with get_db() as db:
try:
file = db.query(File).filter_by(id=id).first()
file.data = {**(file.data if file.data else {}), **data}
db.commit()
return FileModel.model_validate(file)
except Exception:
return None
def update_file_metadata_by_id(self, id: str, meta: dict) -> Optional[FileModel]:
with get_db() as db:
try:
file = db.query(File).filter_by(id=id).first()
file.meta = {**(file.meta if file.meta else {}), **meta}
db.commit()
return FileModel.model_validate(file)
except Exception:
return None
def delete_file_by_id(self, id: str) -> bool:
with get_db() as db:
try:
db.query(File).filter_by(id=id).delete()
db.commit()
return True
except Exception:
return False
def delete_all_files(self) -> bool:
with get_db() as db:
try:
db.query(File).delete()
db.commit()
return True
except Exception:
return False
Files = FilesTable()
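# A minimal usage sketch of the file helpers above, assuming a configured
# database; the id, path and hash values are hypothetical placeholders.
def _example_file_lifecycle() -> None:
    """Register an uploaded file, then merge extra data and metadata into it."""
    form = FileForm(
        id="file-123",
        filename="report.pdf",
        path="/data/uploads/file-123_report.pdf",
        meta={"name": "report.pdf", "content_type": "application/pdf", "size": 2048},
    )
    file = Files.insert_new_file(user_id="user-abc", form_data=form)
    if file:
        Files.update_file_hash_by_id(file.id, "sha256:0123abcd")
        # The data/meta updaters merge into the stored JSON instead of replacing it.
        Files.update_file_data_by_id(file.id, {"status": "processed"})
        Files.update_file_metadata_by_id(file.id, {"collection_name": "kb-1"})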

View File

@ -0,0 +1,271 @@
import logging
import time
import uuid
from typing import Optional
from open_webui.internal.db import Base, get_db
from open_webui.models.chats import Chats
from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, Text, JSON, Boolean
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
# Folder DB Schema
####################
class Folder(Base):
__tablename__ = "folder"
id = Column(Text, primary_key=True)
parent_id = Column(Text, nullable=True)
user_id = Column(Text)
name = Column(Text)
items = Column(JSON, nullable=True)
meta = Column(JSON, nullable=True)
is_expanded = Column(Boolean, default=False)
created_at = Column(BigInteger)
updated_at = Column(BigInteger)
class FolderModel(BaseModel):
id: str
parent_id: Optional[str] = None
user_id: str
name: str
items: Optional[dict] = None
meta: Optional[dict] = None
is_expanded: bool = False
created_at: int
updated_at: int
model_config = ConfigDict(from_attributes=True)
####################
# Forms
####################
class FolderForm(BaseModel):
name: str
model_config = ConfigDict(extra="allow")
class FolderTable:
def insert_new_folder(
self, user_id: str, name: str, parent_id: Optional[str] = None
) -> Optional[FolderModel]:
with get_db() as db:
id = str(uuid.uuid4())
folder = FolderModel(
**{
"id": id,
"user_id": user_id,
"name": name,
"parent_id": parent_id,
"created_at": int(time.time()),
"updated_at": int(time.time()),
}
)
try:
result = Folder(**folder.model_dump())
db.add(result)
db.commit()
db.refresh(result)
if result:
return FolderModel.model_validate(result)
else:
return None
except Exception as e:
log.exception(f"Error inserting a new folder: {e}")
return None
def get_folder_by_id_and_user_id(
self, id: str, user_id: str
) -> Optional[FolderModel]:
try:
with get_db() as db:
folder = db.query(Folder).filter_by(id=id, user_id=user_id).first()
if not folder:
return None
return FolderModel.model_validate(folder)
except Exception:
return None
def get_children_folders_by_id_and_user_id(
self, id: str, user_id: str
) -> Optional[list[FolderModel]]:
try:
with get_db() as db:
folders = []
def get_children(folder):
children = self.get_folders_by_parent_id_and_user_id(
folder.id, user_id
)
for child in children:
get_children(child)
folders.append(child)
folder = db.query(Folder).filter_by(id=id, user_id=user_id).first()
if not folder:
return None
get_children(folder)
return folders
except Exception:
return None
def get_folders_by_user_id(self, user_id: str) -> list[FolderModel]:
with get_db() as db:
return [
FolderModel.model_validate(folder)
for folder in db.query(Folder).filter_by(user_id=user_id).all()
]
def get_folder_by_parent_id_and_user_id_and_name(
self, parent_id: Optional[str], user_id: str, name: str
) -> Optional[FolderModel]:
try:
with get_db() as db:
# Check if folder exists
folder = (
db.query(Folder)
.filter_by(parent_id=parent_id, user_id=user_id)
.filter(Folder.name.ilike(name))
.first()
)
if not folder:
return None
return FolderModel.model_validate(folder)
except Exception as e:
log.error(f"get_folder_by_parent_id_and_user_id_and_name: {e}")
return None
def get_folders_by_parent_id_and_user_id(
self, parent_id: Optional[str], user_id: str
) -> list[FolderModel]:
with get_db() as db:
return [
FolderModel.model_validate(folder)
for folder in db.query(Folder)
.filter_by(parent_id=parent_id, user_id=user_id)
.all()
]
def update_folder_parent_id_by_id_and_user_id(
self,
id: str,
user_id: str,
parent_id: str,
) -> Optional[FolderModel]:
try:
with get_db() as db:
folder = db.query(Folder).filter_by(id=id, user_id=user_id).first()
if not folder:
return None
folder.parent_id = parent_id
folder.updated_at = int(time.time())
db.commit()
return FolderModel.model_validate(folder)
except Exception as e:
log.error(f"update_folder: {e}")
return
def update_folder_name_by_id_and_user_id(
self, id: str, user_id: str, name: str
) -> Optional[FolderModel]:
try:
with get_db() as db:
folder = db.query(Folder).filter_by(id=id, user_id=user_id).first()
if not folder:
return None
existing_folder = (
db.query(Folder)
.filter_by(name=name, parent_id=folder.parent_id, user_id=user_id)
.first()
)
if existing_folder:
return None
folder.name = name
folder.updated_at = int(time.time())
db.commit()
return FolderModel.model_validate(folder)
except Exception as e:
log.error(f"update_folder: {e}")
return
def update_folder_is_expanded_by_id_and_user_id(
self, id: str, user_id: str, is_expanded: bool
) -> Optional[FolderModel]:
try:
with get_db() as db:
folder = db.query(Folder).filter_by(id=id, user_id=user_id).first()
if not folder:
return None
folder.is_expanded = is_expanded
folder.updated_at = int(time.time())
db.commit()
return FolderModel.model_validate(folder)
except Exception as e:
log.error(f"update_folder: {e}")
return
def delete_folder_by_id_and_user_id(self, id: str, user_id: str) -> bool:
try:
with get_db() as db:
folder = db.query(Folder).filter_by(id=id, user_id=user_id).first()
if not folder:
return False
# Delete all chats in the folder
Chats.delete_chats_by_user_id_and_folder_id(user_id, folder.id)
# Delete all children folders
def delete_children(folder):
folder_children = self.get_folders_by_parent_id_and_user_id(
folder.id, user_id
)
for folder_child in folder_children:
Chats.delete_chats_by_user_id_and_folder_id(
user_id, folder_child.id
)
delete_children(folder_child)
child_row = db.query(Folder).filter_by(id=folder_child.id).first()
db.delete(child_row)
db.commit()
delete_children(folder)
db.delete(folder)
db.commit()
return True
except Exception as e:
log.error(f"delete_folder: {e}")
return False
Folders = FolderTable()
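# A minimal usage sketch of the folder helpers above, assuming a configured
# database; the user id is hypothetical.
def _example_folder_tree() -> None:
    """Create a nested folder, collect its descendants, then delete the subtree."""
    user_id = "user-abc"  # hypothetical
    root = Folders.insert_new_folder(user_id, "Projects")
    if root is None:
        return
    Folders.insert_new_folder(user_id, "2025", parent_id=root.id)
    # Recursively collects every descendant of the root folder.
    descendants = Folders.get_children_folders_by_id_and_user_id(root.id, user_id)
    log.debug(f"descendants of {root.name}: {[f.name for f in (descendants or [])]}")
    # Deleting the root also removes child folders and the chats they contain.
    Folders.delete_folder_by_id_and_user_id(root.id, user_id)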

View File

@ -0,0 +1,274 @@
import logging
import time
from typing import Optional
from open_webui.internal.db import Base, JSONField, get_db
from open_webui.models.users import Users
from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Boolean, Column, String, Text
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
# Functions DB Schema
####################
class Function(Base):
__tablename__ = "function"
id = Column(String, primary_key=True)
user_id = Column(String)
name = Column(Text)
type = Column(Text)
content = Column(Text)
meta = Column(JSONField)
valves = Column(JSONField)
is_active = Column(Boolean)
is_global = Column(Boolean)
updated_at = Column(BigInteger)
created_at = Column(BigInteger)
class FunctionMeta(BaseModel):
description: Optional[str] = None
manifest: Optional[dict] = {}
class FunctionModel(BaseModel):
id: str
user_id: str
name: str
type: str
content: str
meta: FunctionMeta
is_active: bool = False
is_global: bool = False
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
model_config = ConfigDict(from_attributes=True)
####################
# Forms
####################
class FunctionResponse(BaseModel):
id: str
user_id: str
type: str
name: str
meta: FunctionMeta
is_active: bool
is_global: bool
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
class FunctionForm(BaseModel):
id: str
name: str
content: str
meta: FunctionMeta
class FunctionValves(BaseModel):
valves: Optional[dict] = None
class FunctionsTable:
def insert_new_function(
self, user_id: str, type: str, form_data: FunctionForm
) -> Optional[FunctionModel]:
function = FunctionModel(
**{
**form_data.model_dump(),
"user_id": user_id,
"type": type,
"updated_at": int(time.time()),
"created_at": int(time.time()),
}
)
try:
with get_db() as db:
result = Function(**function.model_dump())
db.add(result)
db.commit()
db.refresh(result)
if result:
return FunctionModel.model_validate(result)
else:
return None
except Exception as e:
log.exception(f"Error creating a new function: {e}")
return None
def get_function_by_id(self, id: str) -> Optional[FunctionModel]:
try:
with get_db() as db:
function = db.get(Function, id)
return FunctionModel.model_validate(function)
except Exception:
return None
def get_functions(self, active_only=False) -> list[FunctionModel]:
with get_db() as db:
if active_only:
return [
FunctionModel.model_validate(function)
for function in db.query(Function).filter_by(is_active=True).all()
]
else:
return [
FunctionModel.model_validate(function)
for function in db.query(Function).all()
]
def get_functions_by_type(
self, type: str, active_only=False
) -> list[FunctionModel]:
with get_db() as db:
if active_only:
return [
FunctionModel.model_validate(function)
for function in db.query(Function)
.filter_by(type=type, is_active=True)
.all()
]
else:
return [
FunctionModel.model_validate(function)
for function in db.query(Function).filter_by(type=type).all()
]
def get_global_filter_functions(self) -> list[FunctionModel]:
with get_db() as db:
return [
FunctionModel.model_validate(function)
for function in db.query(Function)
.filter_by(type="filter", is_active=True, is_global=True)
.all()
]
def get_global_action_functions(self) -> list[FunctionModel]:
with get_db() as db:
return [
FunctionModel.model_validate(function)
for function in db.query(Function)
.filter_by(type="action", is_active=True, is_global=True)
.all()
]
def get_function_valves_by_id(self, id: str) -> Optional[dict]:
with get_db() as db:
try:
function = db.get(Function, id)
return function.valves if function.valves else {}
except Exception as e:
log.exception(f"Error getting function valves by id {id}: {e}")
return None
def update_function_valves_by_id(
self, id: str, valves: dict
) -> Optional[FunctionValves]:
with get_db() as db:
try:
function = db.get(Function, id)
function.valves = valves
function.updated_at = int(time.time())
db.commit()
db.refresh(function)
return self.get_function_by_id(id)
except Exception:
return None
def get_user_valves_by_id_and_user_id(
self, id: str, user_id: str
) -> Optional[dict]:
try:
user = Users.get_user_by_id(user_id)
user_settings = user.settings.model_dump() if user.settings else {}
# Check if user has "functions" and "valves" settings
if "functions" not in user_settings:
user_settings["functions"] = {}
if "valves" not in user_settings["functions"]:
user_settings["functions"]["valves"] = {}
return user_settings["functions"]["valves"].get(id, {})
except Exception as e:
log.exception(
f"Error getting user values by id {id} and user id {user_id}: {e}"
)
return None
def update_user_valves_by_id_and_user_id(
self, id: str, user_id: str, valves: dict
) -> Optional[dict]:
try:
user = Users.get_user_by_id(user_id)
user_settings = user.settings.model_dump() if user.settings else {}
# Check if user has "functions" and "valves" settings
if "functions" not in user_settings:
user_settings["functions"] = {}
if "valves" not in user_settings["functions"]:
user_settings["functions"]["valves"] = {}
user_settings["functions"]["valves"][id] = valves
# Update the user settings in the database
Users.update_user_by_id(user_id, {"settings": user_settings})
return user_settings["functions"]["valves"][id]
except Exception as e:
log.exception(
f"Error updating user valves by id {id} and user_id {user_id}: {e}"
)
return None
def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]:
with get_db() as db:
try:
db.query(Function).filter_by(id=id).update(
{
**updated,
"updated_at": int(time.time()),
}
)
db.commit()
return self.get_function_by_id(id)
except Exception:
return None
def deactivate_all_functions(self) -> Optional[bool]:
with get_db() as db:
try:
db.query(Function).update(
{
"is_active": False,
"updated_at": int(time.time()),
}
)
db.commit()
return True
except Exception:
return None
def delete_function_by_id(self, id: str) -> bool:
with get_db() as db:
try:
db.query(Function).filter_by(id=id).delete()
db.commit()
return True
except Exception:
return False
Functions = FunctionsTable()
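# A minimal usage sketch of the function helpers above, assuming a configured
# database; the function id, user id and valve values are hypothetical.
def _example_function_valves() -> None:
    """Register a filter function, then set global and per-user valves."""
    form = FunctionForm(
        id="my_filter",
        name="My Filter",
        content="# the filter's Python source would go here",
        meta=FunctionMeta(description="example filter"),
    )
    function = Functions.insert_new_function("user-abc", "filter", form)
    if function:
        # Global valves live on the function row itself ...
        Functions.update_function_valves_by_id(function.id, {"priority": 5})
        # ... while per-user valves are stored under user.settings["functions"]["valves"].
        Functions.update_user_valves_by_id_and_user_id(function.id, "user-abc", {"enabled": True})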

View File

@ -0,0 +1,211 @@
import json
import logging
import time
from typing import Optional
import uuid
from open_webui.internal.db import Base, get_db
from open_webui.env import SRC_LOG_LEVELS
from open_webui.models.files import FileMetadataResponse
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text, JSON, func
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
# UserGroup DB Schema
####################
class Group(Base):
__tablename__ = "group"
id = Column(Text, unique=True, primary_key=True)
user_id = Column(Text)
name = Column(Text)
description = Column(Text)
data = Column(JSON, nullable=True)
meta = Column(JSON, nullable=True)
permissions = Column(JSON, nullable=True)
user_ids = Column(JSON, nullable=True)
created_at = Column(BigInteger)
updated_at = Column(BigInteger)
class GroupModel(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: str
user_id: str
name: str
description: str
data: Optional[dict] = None
meta: Optional[dict] = None
permissions: Optional[dict] = None
user_ids: list[str] = []
created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
####################
# Forms
####################
class GroupResponse(BaseModel):
id: str
user_id: str
name: str
description: str
permissions: Optional[dict] = None
data: Optional[dict] = None
meta: Optional[dict] = None
user_ids: list[str] = []
created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
class GroupForm(BaseModel):
name: str
description: str
permissions: Optional[dict] = None
class GroupUpdateForm(GroupForm):
user_ids: Optional[list[str]] = None
class GroupTable:
def insert_new_group(
self, user_id: str, form_data: GroupForm
) -> Optional[GroupModel]:
with get_db() as db:
group = GroupModel(
**{
**form_data.model_dump(exclude_none=True),
"id": str(uuid.uuid4()),
"user_id": user_id,
"created_at": int(time.time()),
"updated_at": int(time.time()),
}
)
try:
result = Group(**group.model_dump())
db.add(result)
db.commit()
db.refresh(result)
if result:
return GroupModel.model_validate(result)
else:
return None
except Exception:
return None
def get_groups(self) -> list[GroupModel]:
with get_db() as db:
return [
GroupModel.model_validate(group)
for group in db.query(Group).order_by(Group.updated_at.desc()).all()
]
def get_groups_by_member_id(self, user_id: str) -> list[GroupModel]:
with get_db() as db:
return [
GroupModel.model_validate(group)
for group in db.query(Group)
.filter(
func.json_array_length(Group.user_ids) > 0
) # Ensure array exists
.filter(
Group.user_ids.cast(String).like(f'%"{user_id}"%')
) # String-based check
.order_by(Group.updated_at.desc())
.all()
]
def get_group_by_id(self, id: str) -> Optional[GroupModel]:
try:
with get_db() as db:
group = db.query(Group).filter_by(id=id).first()
return GroupModel.model_validate(group) if group else None
except Exception:
return None
def get_group_user_ids_by_id(self, id: str) -> Optional[list[str]]:
group = self.get_group_by_id(id)
if group:
return group.user_ids
else:
return None
def update_group_by_id(
self, id: str, form_data: GroupUpdateForm, overwrite: bool = False
) -> Optional[GroupModel]:
try:
with get_db() as db:
db.query(Group).filter_by(id=id).update(
{
**form_data.model_dump(exclude_none=True),
"updated_at": int(time.time()),
}
)
db.commit()
return self.get_group_by_id(id=id)
except Exception as e:
log.exception(e)
return None
def delete_group_by_id(self, id: str) -> bool:
try:
with get_db() as db:
db.query(Group).filter_by(id=id).delete()
db.commit()
return True
except Exception:
return False
def delete_all_groups(self) -> bool:
with get_db() as db:
try:
db.query(Group).delete()
db.commit()
return True
except Exception:
return False
def remove_user_from_all_groups(self, user_id: str) -> bool:
with get_db() as db:
try:
groups = self.get_groups_by_member_id(user_id)
for group in groups:
group.user_ids.remove(user_id)
db.query(Group).filter_by(id=group.id).update(
{
"user_ids": group.user_ids,
"updated_at": int(time.time()),
}
)
db.commit()
return True
except Exception:
return False
Groups = GroupTable()
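# A minimal usage sketch of the group helpers above, assuming a configured
# database; the user ids are hypothetical.
def _example_group_membership() -> None:
    """Create a group, add a member, and look the group up by that member."""
    group = Groups.insert_new_group(
        "admin-user", GroupForm(name="Engineering", description="Engineering team")
    )
    if group is None:
        return
    Groups.update_group_by_id(
        group.id,
        GroupUpdateForm(
            name=group.name, description=group.description, user_ids=["user-abc"]
        ),
    )
    # Membership lookups rely on a string LIKE over the serialised user_ids JSON,
    # so they behave the same on SQLite and PostgreSQL.
    member_groups = Groups.get_groups_by_member_id("user-abc")
    log.debug(f"groups containing user-abc: {[g.name for g in member_groups]}")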

View File

@ -0,0 +1,221 @@
import json
import logging
import time
from typing import Optional
import uuid
from open_webui.internal.db import Base, get_db
from open_webui.env import SRC_LOG_LEVELS
from open_webui.models.files import FileMetadataResponse
from open_webui.models.users import Users, UserResponse
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text, JSON
from open_webui.utils.access_control import has_access
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
# Knowledge DB Schema
####################
class Knowledge(Base):
__tablename__ = "knowledge"
id = Column(Text, unique=True, primary_key=True)
user_id = Column(Text)
name = Column(Text)
description = Column(Text)
data = Column(JSON, nullable=True)
meta = Column(JSON, nullable=True)
access_control = Column(JSON, nullable=True) # Controls data access levels.
# Defines access control rules for this entry.
# - `None`: Public access, available to all users with the "user" role.
# - `{}`: Private access, restricted exclusively to the owner.
# - Custom permissions: Specific access control for reading and writing;
# Can specify group or user-level restrictions:
# {
# "read": {
# "group_ids": ["group_id1", "group_id2"],
# "user_ids": ["user_id1", "user_id2"]
# },
# "write": {
# "group_ids": ["group_id1", "group_id2"],
# "user_ids": ["user_id1", "user_id2"]
# }
# }
created_at = Column(BigInteger)
updated_at = Column(BigInteger)
class KnowledgeModel(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: str
user_id: str
name: str
description: str
data: Optional[dict] = None
meta: Optional[dict] = None
access_control: Optional[dict] = None
created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
####################
# Forms
####################
class KnowledgeUserModel(KnowledgeModel):
user: Optional[UserResponse] = None
class KnowledgeResponse(KnowledgeModel):
files: Optional[list[FileMetadataResponse | dict]] = None
class KnowledgeUserResponse(KnowledgeUserModel):
files: Optional[list[FileMetadataResponse | dict]] = None
class KnowledgeForm(BaseModel):
name: str
description: str
data: Optional[dict] = None
access_control: Optional[dict] = None
class KnowledgeTable:
def insert_new_knowledge(
self, user_id: str, form_data: KnowledgeForm
) -> Optional[KnowledgeModel]:
with get_db() as db:
knowledge = KnowledgeModel(
**{
**form_data.model_dump(),
"id": str(uuid.uuid4()),
"user_id": user_id,
"created_at": int(time.time()),
"updated_at": int(time.time()),
}
)
try:
result = Knowledge(**knowledge.model_dump())
db.add(result)
db.commit()
db.refresh(result)
if result:
return KnowledgeModel.model_validate(result)
else:
return None
except Exception:
return None
def get_knowledge_bases(self) -> list[KnowledgeUserModel]:
with get_db() as db:
knowledge_bases = []
for knowledge in (
db.query(Knowledge).order_by(Knowledge.updated_at.desc()).all()
):
user = Users.get_user_by_id(knowledge.user_id)
knowledge_bases.append(
KnowledgeUserModel.model_validate(
{
**KnowledgeModel.model_validate(knowledge).model_dump(),
"user": user.model_dump() if user else None,
}
)
)
return knowledge_bases
def get_knowledge_bases_by_user_id(
self, user_id: str, permission: str = "write"
) -> list[KnowledgeUserModel]:
knowledge_bases = self.get_knowledge_bases()
return [
knowledge_base
for knowledge_base in knowledge_bases
if knowledge_base.user_id == user_id
or has_access(user_id, permission, knowledge_base.access_control)
]
def get_knowledge_by_id(self, id: str) -> Optional[KnowledgeModel]:
try:
with get_db() as db:
knowledge = db.query(Knowledge).filter_by(id=id).first()
return KnowledgeModel.model_validate(knowledge) if knowledge else None
except Exception:
return None
def update_knowledge_by_id(
self, id: str, form_data: KnowledgeForm, overwrite: bool = False
) -> Optional[KnowledgeModel]:
try:
with get_db() as db:
knowledge = self.get_knowledge_by_id(id=id)
db.query(Knowledge).filter_by(id=id).update(
{
**form_data.model_dump(),
"updated_at": int(time.time()),
}
)
db.commit()
return self.get_knowledge_by_id(id=id)
except Exception as e:
log.exception(e)
return None
def update_knowledge_data_by_id(
self, id: str, data: dict
) -> Optional[KnowledgeModel]:
try:
with get_db() as db:
knowledge = self.get_knowledge_by_id(id=id)
db.query(Knowledge).filter_by(id=id).update(
{
"data": data,
"updated_at": int(time.time()),
}
)
db.commit()
return self.get_knowledge_by_id(id=id)
except Exception as e:
log.exception(e)
return None
def delete_knowledge_by_id(self, id: str) -> bool:
try:
with get_db() as db:
db.query(Knowledge).filter_by(id=id).delete()
db.commit()
return True
except Exception:
return False
def delete_all_knowledge(self) -> bool:
with get_db() as db:
try:
db.query(Knowledge).delete()
db.commit()
return True
except Exception:
return False
Knowledges = KnowledgeTable()
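# A minimal usage sketch of the knowledge helpers above, illustrating the
# access_control modes documented on the column; all ids are hypothetical.
def _example_knowledge_access_control() -> None:
    """Create a private knowledge base, then open it up to one group for reading."""
    # access_control={} means private: only the owner passes has_access().
    kb = Knowledges.insert_new_knowledge(
        "owner-user",
        KnowledgeForm(name="Handbook", description="Internal docs", access_control={}),
    )
    if kb is None:
        return
    # Grant read access to a group; access_control=None would make it public.
    Knowledges.update_knowledge_by_id(
        kb.id,
        KnowledgeForm(
            name=kb.name,
            description=kb.description,
            access_control={"read": {"group_ids": ["group-eng"], "user_ids": []}},
        ),
    )
    readable = Knowledges.get_knowledge_bases_by_user_id("some-user", "read")
    log.debug(f"knowledge bases readable by some-user: {len(readable)}")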

View File

@ -0,0 +1,137 @@
import time
import uuid
from typing import Optional
from open_webui.internal.db import Base, get_db
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text
####################
# Memory DB Schema
####################
class Memory(Base):
__tablename__ = "memory"
id = Column(String, primary_key=True)
user_id = Column(String)
content = Column(Text)
updated_at = Column(BigInteger)
created_at = Column(BigInteger)
class MemoryModel(BaseModel):
id: str
user_id: str
content: str
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
model_config = ConfigDict(from_attributes=True)
####################
# Forms
####################
class MemoriesTable:
def insert_new_memory(
self,
user_id: str,
content: str,
) -> Optional[MemoryModel]:
with get_db() as db:
id = str(uuid.uuid4())
memory = MemoryModel(
**{
"id": id,
"user_id": user_id,
"content": content,
"created_at": int(time.time()),
"updated_at": int(time.time()),
}
)
result = Memory(**memory.model_dump())
db.add(result)
db.commit()
db.refresh(result)
if result:
return MemoryModel.model_validate(result)
else:
return None
def update_memory_by_id(
self,
id: str,
content: str,
) -> Optional[MemoryModel]:
with get_db() as db:
try:
db.query(Memory).filter_by(id=id).update(
{"content": content, "updated_at": int(time.time())}
)
db.commit()
return self.get_memory_by_id(id)
except Exception:
return None
def get_memories(self) -> Optional[list[MemoryModel]]:
with get_db() as db:
try:
memories = db.query(Memory).all()
return [MemoryModel.model_validate(memory) for memory in memories]
except Exception:
return None
def get_memories_by_user_id(self, user_id: str) -> Optional[list[MemoryModel]]:
with get_db() as db:
try:
memories = db.query(Memory).filter_by(user_id=user_id).all()
return [MemoryModel.model_validate(memory) for memory in memories]
except Exception:
return None
def get_memory_by_id(self, id: str) -> Optional[MemoryModel]:
with get_db() as db:
try:
memory = db.get(Memory, id)
return MemoryModel.model_validate(memory)
except Exception:
return None
def delete_memory_by_id(self, id: str) -> bool:
with get_db() as db:
try:
db.query(Memory).filter_by(id=id).delete()
db.commit()
return True
except Exception:
return False
def delete_memories_by_user_id(self, user_id: str) -> bool:
with get_db() as db:
try:
db.query(Memory).filter_by(user_id=user_id).delete()
db.commit()
return True
except Exception:
return False
def delete_memory_by_id_and_user_id(self, id: str, user_id: str) -> bool:
with get_db() as db:
try:
db.query(Memory).filter_by(id=id, user_id=user_id).delete()
db.commit()
return True
except Exception:
return False
Memories = MemoriesTable()

View File

@ -0,0 +1,279 @@
import json
import time
import uuid
from typing import Optional
from open_webui.internal.db import Base, get_db
from open_webui.models.tags import TagModel, Tag, Tags
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
from sqlalchemy import or_, func, select, and_, text
from sqlalchemy.sql import exists
####################
# Message DB Schema
####################
class MessageReaction(Base):
__tablename__ = "message_reaction"
id = Column(Text, primary_key=True)
user_id = Column(Text)
message_id = Column(Text)
name = Column(Text)
created_at = Column(BigInteger)
class MessageReactionModel(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: str
user_id: str
message_id: str
name: str
created_at: int # timestamp in epoch
class Message(Base):
__tablename__ = "message"
id = Column(Text, primary_key=True)
user_id = Column(Text)
channel_id = Column(Text, nullable=True)
parent_id = Column(Text, nullable=True)
content = Column(Text)
data = Column(JSON, nullable=True)
meta = Column(JSON, nullable=True)
created_at = Column(BigInteger) # time_ns
updated_at = Column(BigInteger) # time_ns
class MessageModel(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: str
user_id: str
channel_id: Optional[str] = None
parent_id: Optional[str] = None
content: str
data: Optional[dict] = None
meta: Optional[dict] = None
created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
####################
# Forms
####################
class MessageForm(BaseModel):
content: str
parent_id: Optional[str] = None
data: Optional[dict] = None
meta: Optional[dict] = None
class Reactions(BaseModel):
name: str
user_ids: list[str]
count: int
class MessageResponse(MessageModel):
latest_reply_at: Optional[int]
reply_count: int
reactions: list[Reactions]
class MessageTable:
def insert_new_message(
self, form_data: MessageForm, channel_id: str, user_id: str
) -> Optional[MessageModel]:
with get_db() as db:
id = str(uuid.uuid4())
ts = int(time.time_ns())
message = MessageModel(
**{
"id": id,
"user_id": user_id,
"channel_id": channel_id,
"parent_id": form_data.parent_id,
"content": form_data.content,
"data": form_data.data,
"meta": form_data.meta,
"created_at": ts,
"updated_at": ts,
}
)
result = Message(**message.model_dump())
db.add(result)
db.commit()
db.refresh(result)
return MessageModel.model_validate(result) if result else None
def get_message_by_id(self, id: str) -> Optional[MessageResponse]:
with get_db() as db:
message = db.get(Message, id)
if not message:
return None
reactions = self.get_reactions_by_message_id(id)
replies = self.get_replies_by_message_id(id)
return MessageResponse(
**{
**MessageModel.model_validate(message).model_dump(),
"latest_reply_at": replies[0].created_at if replies else None,
"reply_count": len(replies),
"reactions": reactions,
}
)
def get_replies_by_message_id(self, id: str) -> list[MessageModel]:
with get_db() as db:
all_messages = (
db.query(Message)
.filter_by(parent_id=id)
.order_by(Message.created_at.desc())
.all()
)
return [MessageModel.model_validate(message) for message in all_messages]
def get_reply_user_ids_by_message_id(self, id: str) -> list[str]:
with get_db() as db:
return [
message.user_id
for message in db.query(Message).filter_by(parent_id=id).all()
]
def get_messages_by_channel_id(
self, channel_id: str, skip: int = 0, limit: int = 50
) -> list[MessageModel]:
with get_db() as db:
all_messages = (
db.query(Message)
.filter_by(channel_id=channel_id, parent_id=None)
.order_by(Message.created_at.desc())
.offset(skip)
.limit(limit)
.all()
)
return [MessageModel.model_validate(message) for message in all_messages]
def get_messages_by_parent_id(
self, channel_id: str, parent_id: str, skip: int = 0, limit: int = 50
) -> list[MessageModel]:
with get_db() as db:
message = db.get(Message, parent_id)
if not message:
return []
all_messages = (
db.query(Message)
.filter_by(channel_id=channel_id, parent_id=parent_id)
.order_by(Message.created_at.desc())
.offset(skip)
.limit(limit)
.all()
)
# If length of all_messages is less than limit, then add the parent message
if len(all_messages) < limit:
all_messages.append(message)
return [MessageModel.model_validate(message) for message in all_messages]
def update_message_by_id(
self, id: str, form_data: MessageForm
) -> Optional[MessageModel]:
with get_db() as db:
message = db.get(Message, id)
message.content = form_data.content
message.data = form_data.data
message.meta = form_data.meta
message.updated_at = int(time.time_ns())
db.commit()
db.refresh(message)
return MessageModel.model_validate(message) if message else None
def add_reaction_to_message(
self, id: str, user_id: str, name: str
) -> Optional[MessageReactionModel]:
with get_db() as db:
reaction_id = str(uuid.uuid4())
reaction = MessageReactionModel(
id=reaction_id,
user_id=user_id,
message_id=id,
name=name,
created_at=int(time.time_ns()),
)
result = MessageReaction(**reaction.model_dump())
db.add(result)
db.commit()
db.refresh(result)
return MessageReactionModel.model_validate(result) if result else None
def get_reactions_by_message_id(self, id: str) -> list[Reactions]:
with get_db() as db:
all_reactions = db.query(MessageReaction).filter_by(message_id=id).all()
reactions = {}
for reaction in all_reactions:
if reaction.name not in reactions:
reactions[reaction.name] = {
"name": reaction.name,
"user_ids": [],
"count": 0,
}
reactions[reaction.name]["user_ids"].append(reaction.user_id)
reactions[reaction.name]["count"] += 1
return [Reactions(**reaction) for reaction in reactions.values()]
def remove_reaction_by_id_and_user_id_and_name(
self, id: str, user_id: str, name: str
) -> bool:
with get_db() as db:
db.query(MessageReaction).filter_by(
message_id=id, user_id=user_id, name=name
).delete()
db.commit()
return True
def delete_reactions_by_id(self, id: str) -> bool:
with get_db() as db:
db.query(MessageReaction).filter_by(message_id=id).delete()
db.commit()
return True
def delete_replies_by_id(self, id: str) -> bool:
with get_db() as db:
db.query(Message).filter_by(parent_id=id).delete()
db.commit()
return True
def delete_message_by_id(self, id: str) -> bool:
with get_db() as db:
db.query(Message).filter_by(id=id).delete()
# Delete all reactions to this message
db.query(MessageReaction).filter_by(message_id=id).delete()
db.commit()
return True
Messages = MessageTable()
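# A minimal usage sketch of the message helpers above, assuming a configured
# database; the channel and user ids are hypothetical.
def _example_channel_thread() -> None:
    """Post a channel message, reply to it in a thread, and react to it."""
    root = Messages.insert_new_message(
        MessageForm(content="Deploy finished"), channel_id="channel-1", user_id="user-abc"
    )
    if root is None:
        return
    # Replies point at the root message via parent_id and are therefore excluded
    # from the top-level channel listing (which filters on parent_id=None).
    Messages.insert_new_message(
        MessageForm(content="Nice!", parent_id=root.id),
        channel_id="channel-1",
        user_id="user-def",
    )
    Messages.add_reaction_to_message(root.id, "user-def", "tada")
    # get_message_by_id folds reply counts and reactions into a MessageResponse.
    Messages.get_message_by_id(root.id)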

View File

@ -0,0 +1,273 @@
import logging
import time
from typing import Optional
from open_webui.internal.db import Base, JSONField, get_db
from open_webui.env import SRC_LOG_LEVELS
from open_webui.models.users import Users, UserResponse
from pydantic import BaseModel, ConfigDict
from sqlalchemy import or_, and_, func
from sqlalchemy.dialects import postgresql, sqlite
from sqlalchemy import BigInteger, Column, Text, JSON, Boolean
from open_webui.utils.access_control import has_access
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
# Models DB Schema
####################
# ModelParams is a model for the data stored in the params field of the Model table
class ModelParams(BaseModel):
model_config = ConfigDict(extra="allow")
pass
# ModelMeta is a model for the data stored in the meta field of the Model table
class ModelMeta(BaseModel):
profile_image_url: Optional[str] = "/static/favicon.png"
description: Optional[str] = None
"""
User-facing description of the model.
"""
capabilities: Optional[dict] = None
model_config = ConfigDict(extra="allow")
pass
class Model(Base):
__tablename__ = "model"
id = Column(Text, primary_key=True)
"""
The model's id as used in the API. If set to an existing model, it will override the model.
"""
user_id = Column(Text)
base_model_id = Column(Text, nullable=True)
"""
An optional pointer to the actual model that should be used when proxying requests.
"""
name = Column(Text)
"""
The human-readable display name of the model.
"""
params = Column(JSONField)
"""
Holds a JSON encoded blob of parameters, see `ModelParams`.
"""
meta = Column(JSONField)
"""
Holds a JSON encoded blob of metadata, see `ModelMeta`.
"""
access_control = Column(JSON, nullable=True) # Controls data access levels.
# Defines access control rules for this entry.
# - `None`: Public access, available to all users with the "user" role.
# - `{}`: Private access, restricted exclusively to the owner.
# - Custom permissions: Specific access control for reading and writing;
# Can specify group or user-level restrictions:
# {
# "read": {
# "group_ids": ["group_id1", "group_id2"],
# "user_ids": ["user_id1", "user_id2"]
# },
# "write": {
# "group_ids": ["group_id1", "group_id2"],
# "user_ids": ["user_id1", "user_id2"]
# }
# }
is_active = Column(Boolean, default=True)
updated_at = Column(BigInteger)
created_at = Column(BigInteger)
class ModelModel(BaseModel):
id: str
user_id: str
base_model_id: Optional[str] = None
name: str
params: ModelParams
meta: ModelMeta
access_control: Optional[dict] = None
is_active: bool
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
model_config = ConfigDict(from_attributes=True)
####################
# Forms
####################
class ModelUserResponse(ModelModel):
user: Optional[UserResponse] = None
class ModelResponse(ModelModel):
pass
class ModelForm(BaseModel):
id: str
base_model_id: Optional[str] = None
name: str
meta: ModelMeta
params: ModelParams
access_control: Optional[dict] = None
is_active: bool = True
class ModelsTable:
def insert_new_model(
self, form_data: ModelForm, user_id: str
) -> Optional[ModelModel]:
model = ModelModel(
**{
**form_data.model_dump(),
"user_id": user_id,
"created_at": int(time.time()),
"updated_at": int(time.time()),
}
)
try:
with get_db() as db:
result = Model(**model.model_dump())
db.add(result)
db.commit()
db.refresh(result)
if result:
return ModelModel.model_validate(result)
else:
return None
except Exception as e:
log.exception(f"Failed to insert a new model: {e}")
return None
def get_all_models(self) -> list[ModelModel]:
with get_db() as db:
return [ModelModel.model_validate(model) for model in db.query(Model).all()]
def get_models(self) -> list[ModelUserResponse]:
with get_db() as db:
models = []
for model in db.query(Model).filter(Model.base_model_id != None).all():
user = Users.get_user_by_id(model.user_id)
models.append(
ModelUserResponse.model_validate(
{
**ModelModel.model_validate(model).model_dump(),
"user": user.model_dump() if user else None,
}
)
)
return models
def get_base_models(self) -> list[ModelModel]:
with get_db() as db:
return [
ModelModel.model_validate(model)
for model in db.query(Model).filter(Model.base_model_id == None).all()
]
def get_models_by_user_id(
self, user_id: str, permission: str = "write"
) -> list[ModelUserResponse]:
models = self.get_models()
return [
model
for model in models
if model.user_id == user_id
or has_access(user_id, permission, model.access_control)
]
def get_model_by_id(self, id: str) -> Optional[ModelModel]:
try:
with get_db() as db:
model = db.get(Model, id)
return ModelModel.model_validate(model)
except Exception:
return None
def toggle_model_by_id(self, id: str) -> Optional[ModelModel]:
with get_db() as db:
try:
is_active = db.query(Model).filter_by(id=id).first().is_active
db.query(Model).filter_by(id=id).update(
{
"is_active": not is_active,
"updated_at": int(time.time()),
}
)
db.commit()
return self.get_model_by_id(id)
except Exception:
return None
def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]:
try:
with get_db() as db:
# update only the fields that are present in the model
result = (
db.query(Model)
.filter_by(id=id)
.update(model.model_dump(exclude={"id"}))
)
db.commit()
model = db.get(Model, id)
db.refresh(model)
return ModelModel.model_validate(model)
except Exception as e:
log.exception(f"Failed to update the model by id {id}: {e}")
return None
def delete_model_by_id(self, id: str) -> bool:
try:
with get_db() as db:
db.query(Model).filter_by(id=id).delete()
db.commit()
return True
except Exception:
return False
def delete_all_models(self) -> bool:
try:
with get_db() as db:
db.query(Model).delete()
db.commit()
return True
except Exception:
return False
Models = ModelsTable()
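# A minimal usage sketch of the model helpers above, assuming a configured
# database; the model and user ids are hypothetical.
def _example_workspace_model() -> None:
    """Register a workspace model that proxies requests to an existing base model."""
    form = ModelForm(
        id="friendly-assistant",
        base_model_id="base-model-1",  # rows with base_model_id=None are the base models
        name="Friendly Assistant",
        meta=ModelMeta(description="A softer preset on top of the base model"),
        params=ModelParams(system="You are a friendly assistant."),
        access_control=None,  # public to all users with the "user" role
    )
    model = Models.insert_new_model(form, user_id="user-abc")
    if model:
        Models.toggle_model_by_id(model.id)  # flips is_active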

View File

@ -0,0 +1,159 @@
import time
from typing import Optional
from open_webui.internal.db import Base, get_db
from open_webui.models.users import Users, UserResponse
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text, JSON
from open_webui.utils.access_control import has_access
####################
# Prompts DB Schema
####################
class Prompt(Base):
__tablename__ = "prompt"
command = Column(String, primary_key=True)
user_id = Column(String)
title = Column(Text)
content = Column(Text)
timestamp = Column(BigInteger)
access_control = Column(JSON, nullable=True) # Controls data access levels.
# Defines access control rules for this entry.
# - `None`: Public access, available to all users with the "user" role.
# - `{}`: Private access, restricted exclusively to the owner.
# - Custom permissions: Specific access control for reading and writing;
# Can specify group or user-level restrictions:
# {
# "read": {
# "group_ids": ["group_id1", "group_id2"],
# "user_ids": ["user_id1", "user_id2"]
# },
# "write": {
# "group_ids": ["group_id1", "group_id2"],
# "user_ids": ["user_id1", "user_id2"]
# }
# }
class PromptModel(BaseModel):
command: str
user_id: str
title: str
content: str
timestamp: int # timestamp in epoch
access_control: Optional[dict] = None
model_config = ConfigDict(from_attributes=True)
####################
# Forms
####################
class PromptUserResponse(PromptModel):
user: Optional[UserResponse] = None
class PromptForm(BaseModel):
command: str
title: str
content: str
access_control: Optional[dict] = None
class PromptsTable:
def insert_new_prompt(
self, user_id: str, form_data: PromptForm
) -> Optional[PromptModel]:
prompt = PromptModel(
**{
"user_id": user_id,
**form_data.model_dump(),
"timestamp": int(time.time()),
}
)
try:
with get_db() as db:
result = Prompt(**prompt.model_dump())
db.add(result)
db.commit()
db.refresh(result)
if result:
return PromptModel.model_validate(result)
else:
return None
except Exception:
return None
def get_prompt_by_command(self, command: str) -> Optional[PromptModel]:
try:
with get_db() as db:
prompt = db.query(Prompt).filter_by(command=command).first()
return PromptModel.model_validate(prompt)
except Exception:
return None
def get_prompts(self) -> list[PromptUserResponse]:
with get_db() as db:
prompts = []
for prompt in db.query(Prompt).order_by(Prompt.timestamp.desc()).all():
user = Users.get_user_by_id(prompt.user_id)
prompts.append(
PromptUserResponse.model_validate(
{
**PromptModel.model_validate(prompt).model_dump(),
"user": user.model_dump() if user else None,
}
)
)
return prompts
def get_prompts_by_user_id(
self, user_id: str, permission: str = "write"
) -> list[PromptUserResponse]:
prompts = self.get_prompts()
return [
prompt
for prompt in prompts
if prompt.user_id == user_id
or has_access(user_id, permission, prompt.access_control)
]
def update_prompt_by_command(
self, command: str, form_data: PromptForm
) -> Optional[PromptModel]:
try:
with get_db() as db:
prompt = db.query(Prompt).filter_by(command=command).first()
prompt.title = form_data.title
prompt.content = form_data.content
prompt.access_control = form_data.access_control
prompt.timestamp = int(time.time())
db.commit()
return PromptModel.model_validate(prompt)
except Exception:
return None
def delete_prompt_by_command(self, command: str) -> bool:
try:
with get_db() as db:
db.query(Prompt).filter_by(command=command).delete()
db.commit()
return True
except Exception:
return False
Prompts = PromptsTable()
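# A minimal usage sketch of the prompt helpers above, assuming a configured
# database; the command and user id are hypothetical.
def _example_prompt_command() -> None:
    """Register a prompt keyed by its command string and look it up again."""
    form = PromptForm(
        command="/summarize",
        title="Summarize",
        content="Summarize the following text:",
        access_control=None,  # public; {} would keep it private to the owner
    )
    prompt = Prompts.insert_new_prompt("user-abc", form)
    if prompt:
        Prompts.get_prompt_by_command("/summarize")
        Prompts.delete_prompt_by_command("/summarize")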

View File

@ -0,0 +1,109 @@
import logging
import time
import uuid
from typing import Optional
from open_webui.internal.db import Base, get_db
from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, JSON, PrimaryKeyConstraint
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
# Tag DB Schema
####################
class Tag(Base):
__tablename__ = "tag"
id = Column(String)
name = Column(String)
user_id = Column(String)
meta = Column(JSON, nullable=True)
# Composite primary key ensuring (id, user_id) is unique, not just the `id` column
__table_args__ = (PrimaryKeyConstraint("id", "user_id", name="pk_id_user_id"),)
class TagModel(BaseModel):
id: str
name: str
user_id: str
meta: Optional[dict] = None
model_config = ConfigDict(from_attributes=True)
####################
# Forms
####################
class TagChatIdForm(BaseModel):
name: str
chat_id: str
class TagTable:
def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]:
with get_db() as db:
id = name.replace(" ", "_").lower()
tag = TagModel(**{"id": id, "user_id": user_id, "name": name})
try:
result = Tag(**tag.model_dump())
db.add(result)
db.commit()
db.refresh(result)
if result:
return TagModel.model_validate(result)
else:
return None
except Exception as e:
log.exception(f"Error inserting a new tag: {e}")
return None
def get_tag_by_name_and_user_id(
self, name: str, user_id: str
) -> Optional[TagModel]:
try:
id = name.replace(" ", "_").lower()
with get_db() as db:
tag = db.query(Tag).filter_by(id=id, user_id=user_id).first()
return TagModel.model_validate(tag)
except Exception:
return None
def get_tags_by_user_id(self, user_id: str) -> list[TagModel]:
with get_db() as db:
return [
TagModel.model_validate(tag)
for tag in (db.query(Tag).filter_by(user_id=user_id).all())
]
def get_tags_by_ids_and_user_id(
self, ids: list[str], user_id: str
) -> list[TagModel]:
with get_db() as db:
return [
TagModel.model_validate(tag)
for tag in (
db.query(Tag).filter(Tag.id.in_(ids), Tag.user_id == user_id).all()
)
]
def delete_tag_by_name_and_user_id(self, name: str, user_id: str) -> bool:
try:
with get_db() as db:
id = name.replace(" ", "_").lower()
res = db.query(Tag).filter_by(id=id, user_id=user_id).delete()
log.debug(f"res: {res}")
db.commit()
return True
except Exception as e:
log.error(f"delete_tag: {e}")
return False
Tags = TagTable()
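# A minimal usage sketch of the tag helpers above, showing the id normalisation
# and the per-user composite key; the user ids are hypothetical.
def _example_tag_normalisation() -> None:
    """The display name keeps its casing while the id is lower-cased and underscored."""
    tag = Tags.insert_new_tag("Project X", "user-abc")
    # tag.id == "project_x" while tag.name == "Project X"
    Tags.insert_new_tag("Project X", "user-def")
    # Both rows share the id "project_x" but differ in user_id, which the
    # composite primary key (id, user_id) allows.
    log.debug(f"tags for user-abc: {Tags.get_tags_by_user_id('user-abc')}")
    Tags.delete_tag_by_name_and_user_id("Project X", "user-abc")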

View File

@ -0,0 +1,262 @@
import logging
import time
from typing import Optional
from open_webui.internal.db import Base, JSONField, get_db
from open_webui.models.users import Users, UserResponse
from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text, JSON
from open_webui.utils.access_control import has_access
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
# Tools DB Schema
####################
class Tool(Base):
__tablename__ = "tool"
id = Column(String, primary_key=True)
user_id = Column(String)
name = Column(Text)
content = Column(Text)
specs = Column(JSONField)
meta = Column(JSONField)
valves = Column(JSONField)
access_control = Column(JSON, nullable=True) # Controls data access levels.
# Defines access control rules for this entry.
# - `None`: Public access, available to all users with the "user" role.
# - `{}`: Private access, restricted exclusively to the owner.
# - Custom permissions: Specific access control for reading and writing;
# Can specify group or user-level restrictions:
# {
# "read": {
# "group_ids": ["group_id1", "group_id2"],
# "user_ids": ["user_id1", "user_id2"]
# },
# "write": {
# "group_ids": ["group_id1", "group_id2"],
# "user_ids": ["user_id1", "user_id2"]
# }
# }
updated_at = Column(BigInteger)
created_at = Column(BigInteger)
class ToolMeta(BaseModel):
description: Optional[str] = None
manifest: Optional[dict] = {}
class ToolModel(BaseModel):
id: str
user_id: str
name: str
content: str
specs: list[dict]
meta: ToolMeta
access_control: Optional[dict] = None
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
model_config = ConfigDict(from_attributes=True)
####################
# Forms
####################
class ToolUserModel(ToolModel):
user: Optional[UserResponse] = None
class ToolResponse(BaseModel):
id: str
user_id: str
name: str
meta: ToolMeta
access_control: Optional[dict] = None
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
class ToolUserResponse(ToolResponse):
user: Optional[UserResponse] = None
class ToolForm(BaseModel):
id: str
name: str
content: str
meta: ToolMeta
access_control: Optional[dict] = None
class ToolValves(BaseModel):
valves: Optional[dict] = None
class ToolsTable:
def insert_new_tool(
self, user_id: str, form_data: ToolForm, specs: list[dict]
) -> Optional[ToolModel]:
with get_db() as db:
tool = ToolModel(
**{
**form_data.model_dump(),
"specs": specs,
"user_id": user_id,
"updated_at": int(time.time()),
"created_at": int(time.time()),
}
)
try:
result = Tool(**tool.model_dump())
db.add(result)
db.commit()
db.refresh(result)
if result:
return ToolModel.model_validate(result)
else:
return None
except Exception as e:
log.exception(f"Error creating a new tool: {e}")
return None
def get_tool_by_id(self, id: str) -> Optional[ToolModel]:
try:
with get_db() as db:
tool = db.get(Tool, id)
return ToolModel.model_validate(tool)
except Exception:
return None
def get_tools(self) -> list[ToolUserModel]:
with get_db() as db:
tools = []
for tool in db.query(Tool).order_by(Tool.updated_at.desc()).all():
user = Users.get_user_by_id(tool.user_id)
tools.append(
ToolUserModel.model_validate(
{
**ToolModel.model_validate(tool).model_dump(),
"user": user.model_dump() if user else None,
}
)
)
return tools
def get_tools_by_user_id(
self, user_id: str, permission: str = "write"
) -> list[ToolUserModel]:
tools = self.get_tools()
return [
tool
for tool in tools
if tool.user_id == user_id
or has_access(user_id, permission, tool.access_control)
]
def get_tool_valves_by_id(self, id: str) -> Optional[dict]:
try:
with get_db() as db:
tool = db.get(Tool, id)
return tool.valves if tool.valves else {}
except Exception as e:
log.exception(f"Error getting tool valves by id {id}: {e}")
return None
def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]:
try:
with get_db() as db:
db.query(Tool).filter_by(id=id).update(
{"valves": valves, "updated_at": int(time.time())}
)
db.commit()
return self.get_tool_by_id(id)
except Exception:
return None
def get_user_valves_by_id_and_user_id(
self, id: str, user_id: str
) -> Optional[dict]:
try:
user = Users.get_user_by_id(user_id)
user_settings = user.settings.model_dump() if user.settings else {}
# Check if user has "tools" and "valves" settings
if "tools" not in user_settings:
user_settings["tools"] = {}
if "valves" not in user_settings["tools"]:
user_settings["tools"]["valves"] = {}
return user_settings["tools"]["valves"].get(id, {})
except Exception as e:
log.exception(
f"Error getting user values by id {id} and user_id {user_id}: {e}"
)
return None
def update_user_valves_by_id_and_user_id(
self, id: str, user_id: str, valves: dict
) -> Optional[dict]:
try:
user = Users.get_user_by_id(user_id)
user_settings = user.settings.model_dump() if user.settings else {}
# Check if user has "tools" and "valves" settings
if "tools" not in user_settings:
user_settings["tools"] = {}
if "valves" not in user_settings["tools"]:
user_settings["tools"]["valves"] = {}
user_settings["tools"]["valves"][id] = valves
# Update the user settings in the database
Users.update_user_by_id(user_id, {"settings": user_settings})
return user_settings["tools"]["valves"][id]
except Exception as e:
log.exception(
f"Error updating user valves by id {id} and user_id {user_id}: {e}"
)
return None
def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]:
try:
with get_db() as db:
db.query(Tool).filter_by(id=id).update(
{**updated, "updated_at": int(time.time())}
)
db.commit()
tool = db.query(Tool).get(id)
db.refresh(tool)
return ToolModel.model_validate(tool)
except Exception:
return None
def delete_tool_by_id(self, id: str) -> bool:
try:
with get_db() as db:
db.query(Tool).filter_by(id=id).delete()
db.commit()
return True
except Exception:
return False
Tools = ToolsTable()
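# A minimal usage sketch for the singleton above; a configured database is assumed,
# and the ids, names, and tool content shown are placeholders.
if __name__ == "__main__":
    form = ToolForm(
        id="example_tool",
        name="Example Tool",
        content="def example_tool():\n    return 'hello'",
        meta=ToolMeta(description="Illustrative tool entry"),
        access_control=None,  # None -> public to every user with the "user" role
    )
    created = Tools.insert_new_tool(user_id="user-123", form_data=form, specs=[])
    if created:
        # Owners always see their own tools; other users need matching access_control.
        editable = Tools.get_tools_by_user_id("user-123", permission="write")
        print([tool.id for tool in editable])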

View File

@ -0,0 +1,334 @@
import time
from typing import Optional
from open_webui.internal.db import Base, JSONField, get_db
from open_webui.models.chats import Chats
from open_webui.models.groups import Groups
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text
####################
# User DB Schema
####################
class User(Base):
__tablename__ = "user"
id = Column(String, primary_key=True)
name = Column(String)
email = Column(String)
role = Column(String)
profile_image_url = Column(Text)
last_active_at = Column(BigInteger)
updated_at = Column(BigInteger)
created_at = Column(BigInteger)
api_key = Column(String, nullable=True, unique=True)
settings = Column(JSONField, nullable=True)
info = Column(JSONField, nullable=True)
oauth_sub = Column(Text, unique=True)
class UserSettings(BaseModel):
ui: Optional[dict] = {}
model_config = ConfigDict(extra="allow")
pass
class UserModel(BaseModel):
id: str
name: str
email: str
role: str = "pending"
profile_image_url: str
last_active_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
api_key: Optional[str] = None
settings: Optional[UserSettings] = None
info: Optional[dict] = None
oauth_sub: Optional[str] = None
model_config = ConfigDict(from_attributes=True)
####################
# Forms
####################
class UserResponse(BaseModel):
id: str
name: str
email: str
role: str
profile_image_url: str
class UserNameResponse(BaseModel):
id: str
name: str
role: str
profile_image_url: str
class UserRoleUpdateForm(BaseModel):
id: str
role: str
class UserUpdateForm(BaseModel):
name: str
email: str
profile_image_url: str
password: Optional[str] = None
class UsersTable:
def insert_new_user(
self,
id: str,
name: str,
email: str,
profile_image_url: str = "/user.png",
role: str = "pending",
oauth_sub: Optional[str] = None,
) -> Optional[UserModel]:
with get_db() as db:
user = UserModel(
**{
"id": id,
"name": name,
"email": email,
"role": role,
"profile_image_url": profile_image_url,
"last_active_at": int(time.time()),
"created_at": int(time.time()),
"updated_at": int(time.time()),
"oauth_sub": oauth_sub,
}
)
result = User(**user.model_dump())
db.add(result)
db.commit()
db.refresh(result)
if result:
return user
else:
return None
def get_user_by_id(self, id: str) -> Optional[UserModel]:
try:
with get_db() as db:
user = db.query(User).filter_by(id=id).first()
return UserModel.model_validate(user)
except Exception:
return None
def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
try:
with get_db() as db:
user = db.query(User).filter_by(api_key=api_key).first()
return UserModel.model_validate(user)
except Exception:
return None
def get_user_by_email(self, email: str) -> Optional[UserModel]:
try:
with get_db() as db:
user = db.query(User).filter_by(email=email).first()
return UserModel.model_validate(user)
except Exception:
return None
def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]:
try:
with get_db() as db:
user = db.query(User).filter_by(oauth_sub=sub).first()
return UserModel.model_validate(user)
except Exception:
return None
def get_users(
self, skip: Optional[int] = None, limit: Optional[int] = None
) -> list[UserModel]:
with get_db() as db:
query = db.query(User).order_by(User.created_at.desc())
if skip:
query = query.offset(skip)
if limit:
query = query.limit(limit)
users = query.all()
return [UserModel.model_validate(user) for user in users]
def get_users_by_user_ids(self, user_ids: list[str]) -> list[UserModel]:
with get_db() as db:
users = db.query(User).filter(User.id.in_(user_ids)).all()
return [UserModel.model_validate(user) for user in users]
def get_num_users(self) -> Optional[int]:
with get_db() as db:
return db.query(User).count()
    def get_first_user(self) -> Optional[UserModel]:
try:
with get_db() as db:
user = db.query(User).order_by(User.created_at).first()
return UserModel.model_validate(user)
except Exception:
return None
def get_user_webhook_url_by_id(self, id: str) -> Optional[str]:
try:
with get_db() as db:
user = db.query(User).filter_by(id=id).first()
if user.settings is None:
return None
else:
return (
user.settings.get("ui", {})
.get("notifications", {})
.get("webhook_url", None)
)
except Exception:
return None
def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]:
try:
with get_db() as db:
db.query(User).filter_by(id=id).update({"role": role})
db.commit()
user = db.query(User).filter_by(id=id).first()
return UserModel.model_validate(user)
except Exception:
return None
def update_user_profile_image_url_by_id(
self, id: str, profile_image_url: str
) -> Optional[UserModel]:
try:
with get_db() as db:
db.query(User).filter_by(id=id).update(
{"profile_image_url": profile_image_url}
)
db.commit()
user = db.query(User).filter_by(id=id).first()
return UserModel.model_validate(user)
except Exception:
return None
def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]:
try:
with get_db() as db:
db.query(User).filter_by(id=id).update(
{"last_active_at": int(time.time())}
)
db.commit()
user = db.query(User).filter_by(id=id).first()
return UserModel.model_validate(user)
except Exception:
return None
def update_user_oauth_sub_by_id(
self, id: str, oauth_sub: str
) -> Optional[UserModel]:
try:
with get_db() as db:
db.query(User).filter_by(id=id).update({"oauth_sub": oauth_sub})
db.commit()
user = db.query(User).filter_by(id=id).first()
return UserModel.model_validate(user)
except Exception:
return None
def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]:
try:
with get_db() as db:
db.query(User).filter_by(id=id).update(updated)
db.commit()
user = db.query(User).filter_by(id=id).first()
return UserModel.model_validate(user)
# return UserModel(**user.dict())
except Exception:
return None
def update_user_settings_by_id(self, id: str, updated: dict) -> Optional[UserModel]:
try:
with get_db() as db:
user_settings = db.query(User).filter_by(id=id).first().settings
if user_settings is None:
user_settings = {}
user_settings.update(updated)
db.query(User).filter_by(id=id).update({"settings": user_settings})
db.commit()
user = db.query(User).filter_by(id=id).first()
return UserModel.model_validate(user)
except Exception:
return None
def delete_user_by_id(self, id: str) -> bool:
try:
# Remove User from Groups
Groups.remove_user_from_all_groups(id)
# Delete User Chats
result = Chats.delete_chats_by_user_id(id)
if result:
with get_db() as db:
# Delete User
db.query(User).filter_by(id=id).delete()
db.commit()
return True
else:
return False
except Exception:
return False
    def update_user_api_key_by_id(self, id: str, api_key: str) -> bool:
try:
with get_db() as db:
result = db.query(User).filter_by(id=id).update({"api_key": api_key})
db.commit()
return True if result == 1 else False
except Exception:
return False
def get_user_api_key_by_id(self, id: str) -> Optional[str]:
try:
with get_db() as db:
user = db.query(User).filter_by(id=id).first()
return user.api_key
except Exception:
return None
def get_valid_user_ids(self, user_ids: list[str]) -> list[str]:
with get_db() as db:
users = db.query(User).filter(User.id.in_(user_ids)).all()
return [user.id for user in users]
Users = UsersTable()
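# A minimal sketch of creating a user and storing a notification webhook through the
# settings helpers above; every value is a placeholder and a configured database is assumed.
if __name__ == "__main__":
    import uuid

    new_id = str(uuid.uuid4())
    user = Users.insert_new_user(
        id=new_id,
        name="Jane Doe",
        email="jane@example.com",
        role="user",
    )
    if user:
        # Settings are stored as a plain dict and shallow-merged on update.
        Users.update_user_settings_by_id(
            new_id,
            {"ui": {"notifications": {"webhook_url": "https://example.com/hook"}}},
        )
        print(Users.get_user_webhook_url_by_id(new_id))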

View File

@ -0,0 +1,215 @@
import requests
import logging
import ftfy
import sys
from langchain_community.document_loaders import (
AzureAIDocumentIntelligenceLoader,
BSHTMLLoader,
CSVLoader,
Docx2txtLoader,
OutlookMessageLoader,
PyPDFLoader,
TextLoader,
UnstructuredEPubLoader,
UnstructuredExcelLoader,
UnstructuredMarkdownLoader,
UnstructuredPowerPointLoader,
UnstructuredRSTLoader,
UnstructuredXMLLoader,
YoutubeLoader,
)
from langchain_core.documents import Document
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
known_source_ext = [
"go",
"py",
"java",
"sh",
"bat",
"ps1",
"cmd",
"js",
"ts",
"css",
"cpp",
"hpp",
"h",
"c",
"cs",
"sql",
"log",
"ini",
"pl",
"pm",
"r",
"dart",
"dockerfile",
"env",
"php",
"hs",
"hsc",
"lua",
"nginxconf",
"conf",
"m",
"mm",
"plsql",
"perl",
"rb",
"rs",
"db2",
"scala",
"bash",
"swift",
"vue",
"svelte",
"msg",
"ex",
"exs",
"erl",
"tsx",
"jsx",
"hs",
"lhs",
"json",
]
class TikaLoader:
def __init__(self, url, file_path, mime_type=None):
self.url = url
self.file_path = file_path
self.mime_type = mime_type
def load(self) -> list[Document]:
with open(self.file_path, "rb") as f:
data = f.read()
if self.mime_type is not None:
headers = {"Content-Type": self.mime_type}
else:
headers = {}
endpoint = self.url
if not endpoint.endswith("/"):
endpoint += "/"
endpoint += "tika/text"
r = requests.put(endpoint, data=data, headers=headers)
if r.ok:
raw_metadata = r.json()
text = raw_metadata.get("X-TIKA:content", "<No text content found>")
if "Content-Type" in raw_metadata:
headers["Content-Type"] = raw_metadata["Content-Type"]
log.debug("Tika extracted text: %s", text)
return [Document(page_content=text, metadata=headers)]
else:
raise Exception(f"Error calling Tika: {r.reason}")
class Loader:
def __init__(self, engine: str = "", **kwargs):
self.engine = engine
self.kwargs = kwargs
def load(
self, filename: str, file_content_type: str, file_path: str
) -> list[Document]:
loader = self._get_loader(filename, file_content_type, file_path)
docs = loader.load()
return [
Document(
page_content=ftfy.fix_text(doc.page_content), metadata=doc.metadata
)
for doc in docs
]
def _get_loader(self, filename: str, file_content_type: str, file_path: str):
file_ext = filename.split(".")[-1].lower()
if self.engine == "tika" and self.kwargs.get("TIKA_SERVER_URL"):
if file_ext in known_source_ext or (
file_content_type and file_content_type.find("text/") >= 0
):
loader = TextLoader(file_path, autodetect_encoding=True)
else:
loader = TikaLoader(
url=self.kwargs.get("TIKA_SERVER_URL"),
file_path=file_path,
mime_type=file_content_type,
)
elif (
self.engine == "document_intelligence"
and self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT") != ""
and self.kwargs.get("DOCUMENT_INTELLIGENCE_KEY") != ""
and (
file_ext in ["pdf", "xls", "xlsx", "docx", "ppt", "pptx"]
or file_content_type
in [
"application/vnd.ms-excel",
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
"application/vnd.ms-powerpoint",
"application/vnd.openxmlformats-officedocument.presentationml.presentation",
]
)
):
loader = AzureAIDocumentIntelligenceLoader(
file_path=file_path,
api_endpoint=self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT"),
api_key=self.kwargs.get("DOCUMENT_INTELLIGENCE_KEY"),
)
else:
if file_ext == "pdf":
loader = PyPDFLoader(
file_path, extract_images=self.kwargs.get("PDF_EXTRACT_IMAGES")
)
elif file_ext == "csv":
loader = CSVLoader(file_path)
elif file_ext == "rst":
loader = UnstructuredRSTLoader(file_path, mode="elements")
elif file_ext == "xml":
loader = UnstructuredXMLLoader(file_path)
elif file_ext in ["htm", "html"]:
loader = BSHTMLLoader(file_path, open_encoding="unicode_escape")
elif file_ext == "md":
loader = TextLoader(file_path, autodetect_encoding=True)
elif file_content_type == "application/epub+zip":
loader = UnstructuredEPubLoader(file_path)
elif (
file_content_type
== "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
or file_ext == "docx"
):
loader = Docx2txtLoader(file_path)
elif file_content_type in [
"application/vnd.ms-excel",
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
] or file_ext in ["xls", "xlsx"]:
loader = UnstructuredExcelLoader(file_path)
elif file_content_type in [
"application/vnd.ms-powerpoint",
"application/vnd.openxmlformats-officedocument.presentationml.presentation",
] or file_ext in ["ppt", "pptx"]:
loader = UnstructuredPowerPointLoader(file_path)
elif file_ext == "msg":
loader = OutlookMessageLoader(file_path)
elif file_ext in known_source_ext or (
file_content_type and file_content_type.find("text/") >= 0
):
loader = TextLoader(file_path, autodetect_encoding=True)
else:
loader = TextLoader(file_path, autodetect_encoding=True)
return loader
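# A minimal sketch of driving the Loader above with the default engine; the path and
# content type are placeholders, and the engine-specific kwargs (TIKA_SERVER_URL,
# DOCUMENT_INTELLIGENCE_ENDPOINT/KEY) are only needed when those engines are selected.
if __name__ == "__main__":
    loader = Loader(engine="", PDF_EXTRACT_IMAGES=False)
    docs = loader.load(
        filename="notes.md",
        file_content_type="text/markdown",
        file_path="/tmp/notes.md",  # placeholder path
    )
    for doc in docs:
        # Text is passed through ftfy to repair mojibake before being returned.
        print(doc.metadata, doc.page_content[:80])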

View File

@ -0,0 +1,117 @@
import logging
from typing import Any, Dict, Generator, List, Optional, Sequence, Union
from urllib.parse import parse_qs, urlparse
from langchain_core.documents import Document
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
ALLOWED_SCHEMES = {"http", "https"}
ALLOWED_NETLOCS = {
"youtu.be",
"m.youtube.com",
"youtube.com",
"www.youtube.com",
"www.youtube-nocookie.com",
"vid.plus",
}
def _parse_video_id(url: str) -> Optional[str]:
"""Parse a YouTube URL and return the video ID if valid, otherwise None."""
parsed_url = urlparse(url)
if parsed_url.scheme not in ALLOWED_SCHEMES:
return None
if parsed_url.netloc not in ALLOWED_NETLOCS:
return None
path = parsed_url.path
if path.endswith("/watch"):
query = parsed_url.query
parsed_query = parse_qs(query)
if "v" in parsed_query:
ids = parsed_query["v"]
video_id = ids if isinstance(ids, str) else ids[0]
else:
return None
else:
path = parsed_url.path.lstrip("/")
video_id = path.split("/")[-1]
if len(video_id) != 11: # Video IDs are 11 characters long
return None
return video_id
class YoutubeLoader:
"""Load `YouTube` video transcripts."""
def __init__(
self,
video_id: str,
language: Union[str, Sequence[str]] = "en",
proxy_url: Optional[str] = None,
):
"""Initialize with YouTube video ID."""
_video_id = _parse_video_id(video_id)
self.video_id = _video_id if _video_id is not None else video_id
self._metadata = {"source": video_id}
self.language = language
self.proxy_url = proxy_url
if isinstance(language, str):
self.language = [language]
else:
self.language = language
def load(self) -> List[Document]:
"""Load YouTube transcripts into `Document` objects."""
try:
from youtube_transcript_api import (
NoTranscriptFound,
TranscriptsDisabled,
YouTubeTranscriptApi,
)
except ImportError:
raise ImportError(
'Could not import "youtube_transcript_api" Python package. '
"Please install it with `pip install youtube-transcript-api`."
)
if self.proxy_url:
youtube_proxies = {
"http": self.proxy_url,
"https": self.proxy_url,
}
# Don't log complete URL because it might contain secrets
log.debug(f"Using proxy URL: {self.proxy_url[:14]}...")
else:
youtube_proxies = None
try:
transcript_list = YouTubeTranscriptApi.list_transcripts(
self.video_id, proxies=youtube_proxies
)
except Exception as e:
log.exception("Loading YouTube transcript failed")
return []
try:
transcript = transcript_list.find_transcript(self.language)
except NoTranscriptFound:
transcript = transcript_list.find_transcript(["en"])
transcript_pieces: List[Dict[str, Any]] = transcript.fetch()
transcript = " ".join(
map(
lambda transcript_piece: transcript_piece["text"].strip(" "),
transcript_pieces,
)
)
return [Document(page_content=transcript, metadata=self._metadata)]
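# A minimal sketch of fetching a transcript with the loader above; the video id is a
# placeholder and the youtube-transcript-api package must be installed.
if __name__ == "__main__":
    loader = YoutubeLoader("dQw4w9WgXcQ", language=["en", "de"])
    documents = loader.load()
    if documents:
        # A single Document is returned with the joined transcript text.
        print(documents[0].metadata, documents[0].page_content[:120])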

View File

@ -0,0 +1,87 @@
import os
import logging
import torch
import numpy as np
from colbert.infra import ColBERTConfig
from colbert.modeling.checkpoint import Checkpoint
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class ColBERT:
def __init__(self, name, **kwargs) -> None:
log.info("ColBERT: Loading model", name)
self.device = "cuda" if torch.cuda.is_available() else "cpu"
DOCKER = kwargs.get("env") == "docker"
if DOCKER:
# This is a workaround for the issue with the docker container
# where the torch extension is not loaded properly
# and the following error is thrown:
# /root/.cache/torch_extensions/py311_cpu/segmented_maxsim_cpp/segmented_maxsim_cpp.so: cannot open shared object file: No such file or directory
lock_file = (
"/root/.cache/torch_extensions/py311_cpu/segmented_maxsim_cpp/lock"
)
if os.path.exists(lock_file):
os.remove(lock_file)
self.ckpt = Checkpoint(
name,
colbert_config=ColBERTConfig(model_name=name),
).to(self.device)
pass
def calculate_similarity_scores(self, query_embeddings, document_embeddings):
query_embeddings = query_embeddings.to(self.device)
document_embeddings = document_embeddings.to(self.device)
# Validate dimensions to ensure compatibility
if query_embeddings.dim() != 3:
raise ValueError(
f"Expected query embeddings to have 3 dimensions, but got {query_embeddings.dim()}."
)
if document_embeddings.dim() != 3:
raise ValueError(
f"Expected document embeddings to have 3 dimensions, but got {document_embeddings.dim()}."
)
if query_embeddings.size(0) not in [1, document_embeddings.size(0)]:
raise ValueError(
"There should be either one query or queries equal to the number of documents."
)
# Transpose the query embeddings to align for matrix multiplication
transposed_query_embeddings = query_embeddings.permute(0, 2, 1)
# Compute similarity scores using batch matrix multiplication
computed_scores = torch.matmul(document_embeddings, transposed_query_embeddings)
# Apply max pooling to extract the highest semantic similarity across each document's sequence
maximum_scores = torch.max(computed_scores, dim=1).values
# Sum up the maximum scores across features to get the overall document relevance scores
final_scores = maximum_scores.sum(dim=1)
normalized_scores = torch.softmax(final_scores, dim=0)
return normalized_scores.detach().cpu().numpy().astype(np.float32)
def predict(self, sentences):
query = sentences[0][0]
docs = [i[1] for i in sentences]
# Embedding the documents
embedded_docs = self.ckpt.docFromText(docs, bsize=32)[0]
# Embedding the queries
embedded_queries = self.ckpt.queryFromText([query], bsize=32)
embedded_query = embedded_queries[0]
# Calculate retrieval scores for the query against all documents
scores = self.calculate_similarity_scores(
embedded_query.unsqueeze(0), embedded_docs
)
return scores
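# A minimal sketch of the (query, passage) pair format predict expects, mirroring
# cross-encoder style rerankers; the checkpoint name is an assumed Hugging Face model id,
# not something this module pins.
if __name__ == "__main__":
    reranker = ColBERT("colbert-ir/colbertv2.0", env="")
    query = "how do I reset my password?"
    passages = [
        "To reset your password, open Settings and choose Account.",
        "Our office is closed on public holidays.",
    ]
    scores = reranker.predict([(query, passage) for passage in passages])
    # Scores are softmax-normalized over the passages; higher means more relevant.
    print(list(zip(passages, scores)))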

View File

@ -0,0 +1,699 @@
import logging
import os
import uuid
from typing import Optional, Union
import asyncio
import requests
import hashlib
from huggingface_hub import snapshot_download
from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
from langchain_community.retrievers import BM25Retriever
from langchain_core.documents import Document
from open_webui.config import VECTOR_DB
from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
from open_webui.utils.misc import get_last_user_message, calculate_sha256_string
from open_webui.models.users import UserModel
from open_webui.models.files import Files
from open_webui.env import (
SRC_LOG_LEVELS,
OFFLINE_MODE,
ENABLE_FORWARD_USER_INFO_HEADERS,
)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
from typing import Any
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.retrievers import BaseRetriever
class VectorSearchRetriever(BaseRetriever):
collection_name: Any
embedding_function: Any
top_k: int
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
) -> list[Document]:
result = VECTOR_DB_CLIENT.search(
collection_name=self.collection_name,
vectors=[self.embedding_function(query)],
limit=self.top_k,
)
ids = result.ids[0]
metadatas = result.metadatas[0]
documents = result.documents[0]
results = []
for idx in range(len(ids)):
results.append(
Document(
metadata=metadatas[idx],
page_content=documents[idx],
)
)
return results
def query_doc(
collection_name: str, query_embedding: list[float], k: int, user: UserModel = None
):
try:
result = VECTOR_DB_CLIENT.search(
collection_name=collection_name,
vectors=[query_embedding],
limit=k,
)
if result:
log.info(f"query_doc:result {result.ids} {result.metadatas}")
return result
except Exception as e:
log.exception(f"Error querying doc {collection_name} with limit {k}: {e}")
raise e
def get_doc(collection_name: str, user: UserModel = None):
try:
result = VECTOR_DB_CLIENT.get(collection_name=collection_name)
if result:
log.info(f"query_doc:result {result.ids} {result.metadatas}")
return result
except Exception as e:
log.exception(f"Error getting doc {collection_name}: {e}")
raise e
def query_doc_with_hybrid_search(
collection_name: str,
query: str,
embedding_function,
k: int,
reranking_function,
r: float,
) -> dict:
try:
result = VECTOR_DB_CLIENT.get(collection_name=collection_name)
bm25_retriever = BM25Retriever.from_texts(
texts=result.documents[0],
metadatas=result.metadatas[0],
)
bm25_retriever.k = k
vector_search_retriever = VectorSearchRetriever(
collection_name=collection_name,
embedding_function=embedding_function,
top_k=k,
)
ensemble_retriever = EnsembleRetriever(
retrievers=[bm25_retriever, vector_search_retriever], weights=[0.5, 0.5]
)
compressor = RerankCompressor(
embedding_function=embedding_function,
top_n=k,
reranking_function=reranking_function,
r_score=r,
)
compression_retriever = ContextualCompressionRetriever(
base_compressor=compressor, base_retriever=ensemble_retriever
)
result = compression_retriever.invoke(query)
result = {
"distances": [[d.metadata.get("score") for d in result]],
"documents": [[d.page_content for d in result]],
"metadatas": [[d.metadata for d in result]],
}
log.info(
"query_doc_with_hybrid_search:result "
+ f'{result["metadatas"]} {result["distances"]}'
)
return result
except Exception as e:
raise e
def merge_get_results(get_results: list[dict]) -> dict:
# Initialize lists to store combined data
combined_documents = []
combined_metadatas = []
combined_ids = []
for data in get_results:
combined_documents.extend(data["documents"][0])
combined_metadatas.extend(data["metadatas"][0])
combined_ids.extend(data["ids"][0])
# Create the output dictionary
result = {
"documents": [combined_documents],
"metadatas": [combined_metadatas],
"ids": [combined_ids],
}
return result
def merge_and_sort_query_results(
query_results: list[dict], k: int, reverse: bool = False
) -> dict:
# Initialize lists to store combined data
combined = []
seen_hashes = set() # To store unique document hashes
for data in query_results:
distances = data["distances"][0]
documents = data["documents"][0]
metadatas = data["metadatas"][0]
for distance, document, metadata in zip(distances, documents, metadatas):
if isinstance(document, str):
doc_hash = hashlib.md5(
document.encode()
).hexdigest() # Compute a hash for uniqueness
if doc_hash not in seen_hashes:
seen_hashes.add(doc_hash)
combined.append((distance, document, metadata))
# Sort the list based on distances
combined.sort(key=lambda x: x[0], reverse=reverse)
# Slice to keep only the top k elements
sorted_distances, sorted_documents, sorted_metadatas = (
zip(*combined[:k]) if combined else ([], [], [])
)
# Create and return the output dictionary
return {
"distances": [list(sorted_distances)],
"documents": [list(sorted_documents)],
"metadatas": [list(sorted_metadatas)],
}
def get_all_items_from_collections(collection_names: list[str]) -> dict:
results = []
for collection_name in collection_names:
if collection_name:
try:
result = get_doc(collection_name=collection_name)
if result is not None:
results.append(result.model_dump())
except Exception as e:
log.exception(f"Error when querying the collection: {e}")
else:
pass
return merge_get_results(results)
def query_collection(
collection_names: list[str],
queries: list[str],
embedding_function,
k: int,
) -> dict:
results = []
for query in queries:
query_embedding = embedding_function(query)
for collection_name in collection_names:
if collection_name:
try:
result = query_doc(
collection_name=collection_name,
k=k,
query_embedding=query_embedding,
)
if result is not None:
results.append(result.model_dump())
except Exception as e:
log.exception(f"Error when querying the collection: {e}")
else:
pass
if VECTOR_DB == "chroma":
        # Chroma returns cosine distance (lower is more similar), so the merged results are not reversed
# https://docs.trychroma.com/docs/collections/configure#configuring-chroma-collections
return merge_and_sort_query_results(results, k=k, reverse=False)
else:
return merge_and_sort_query_results(results, k=k, reverse=True)
def query_collection_with_hybrid_search(
collection_names: list[str],
queries: list[str],
embedding_function,
k: int,
reranking_function,
r: float,
) -> dict:
results = []
error = False
for collection_name in collection_names:
try:
for query in queries:
result = query_doc_with_hybrid_search(
collection_name=collection_name,
query=query,
embedding_function=embedding_function,
k=k,
reranking_function=reranking_function,
r=r,
)
results.append(result)
except Exception as e:
log.exception(
"Error when querying the collection with " f"hybrid_search: {e}"
)
error = True
if error:
raise Exception(
"Hybrid search failed for all collections. Using Non hybrid search as fallback."
)
if VECTOR_DB == "chroma":
        # Chroma returns cosine distance (lower is more similar), so the merged results are not reversed
# https://docs.trychroma.com/docs/collections/configure#configuring-chroma-collections
return merge_and_sort_query_results(results, k=k, reverse=False)
else:
return merge_and_sort_query_results(results, k=k, reverse=True)
def get_embedding_function(
embedding_engine,
embedding_model,
embedding_function,
url,
key,
embedding_batch_size,
):
if embedding_engine == "":
return lambda query, user=None: embedding_function.encode(query).tolist()
elif embedding_engine in ["ollama", "openai"]:
func = lambda query, user=None: generate_embeddings(
engine=embedding_engine,
model=embedding_model,
text=query,
url=url,
key=key,
user=user,
)
def generate_multiple(query, user, func):
if isinstance(query, list):
embeddings = []
for i in range(0, len(query), embedding_batch_size):
embeddings.extend(
func(query[i : i + embedding_batch_size], user=user)
)
return embeddings
else:
return func(query, user)
return lambda query, user=None: generate_multiple(query, user, func)
else:
raise ValueError(f"Unknown embedding engine: {embedding_engine}")
def get_sources_from_files(
request,
files,
queries,
embedding_function,
k,
reranking_function,
r,
hybrid_search,
full_context=False,
):
log.debug(
f"files: {files} {queries} {embedding_function} {reranking_function} {full_context}"
)
extracted_collections = []
relevant_contexts = []
for file in files:
context = None
if file.get("docs"):
# BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL
context = {
"documents": [[doc.get("content") for doc in file.get("docs")]],
"metadatas": [[doc.get("metadata") for doc in file.get("docs")]],
}
elif file.get("context") == "full":
# Manual Full Mode Toggle
context = {
"documents": [[file.get("file").get("data", {}).get("content")]],
"metadatas": [[{"file_id": file.get("id"), "name": file.get("name")}]],
}
elif (
file.get("type") != "web_search"
and request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL
):
# BYPASS_EMBEDDING_AND_RETRIEVAL
if file.get("type") == "collection":
file_ids = file.get("data", {}).get("file_ids", [])
documents = []
metadatas = []
for file_id in file_ids:
file_object = Files.get_file_by_id(file_id)
if file_object:
documents.append(file_object.data.get("content", ""))
metadatas.append(
{
"file_id": file_id,
"name": file_object.filename,
"source": file_object.filename,
}
)
context = {
"documents": [documents],
"metadatas": [metadatas],
}
elif file.get("id"):
file_object = Files.get_file_by_id(file.get("id"))
if file_object:
context = {
"documents": [[file_object.data.get("content", "")]],
"metadatas": [
[
{
"file_id": file.get("id"),
"name": file_object.filename,
"source": file_object.filename,
}
]
],
}
else:
collection_names = []
if file.get("type") == "collection":
if file.get("legacy"):
collection_names = file.get("collection_names", [])
else:
collection_names.append(file["id"])
elif file.get("collection_name"):
collection_names.append(file["collection_name"])
elif file.get("id"):
if file.get("legacy"):
collection_names.append(f"{file['id']}")
else:
collection_names.append(f"file-{file['id']}")
collection_names = set(collection_names).difference(extracted_collections)
if not collection_names:
log.debug(f"skipping {file} as it has already been extracted")
continue
if full_context:
try:
context = get_all_items_from_collections(collection_names)
except Exception as e:
log.exception(e)
else:
try:
context = None
if file.get("type") == "text":
context = file["content"]
else:
if hybrid_search:
try:
context = query_collection_with_hybrid_search(
collection_names=collection_names,
queries=queries,
embedding_function=embedding_function,
k=k,
reranking_function=reranking_function,
r=r,
)
except Exception as e:
log.debug(
"Error when using hybrid search, using"
" non hybrid search as fallback."
)
if (not hybrid_search) or (context is None):
context = query_collection(
collection_names=collection_names,
queries=queries,
embedding_function=embedding_function,
k=k,
)
except Exception as e:
log.exception(e)
extracted_collections.extend(collection_names)
if context:
if "data" in file:
del file["data"]
relevant_contexts.append({**context, "file": file})
sources = []
for context in relevant_contexts:
try:
if "documents" in context:
if "metadatas" in context:
source = {
"source": context["file"],
"document": context["documents"][0],
"metadata": context["metadatas"][0],
}
if "distances" in context and context["distances"]:
source["distances"] = context["distances"][0]
sources.append(source)
except Exception as e:
log.exception(e)
return sources
def get_model_path(model: str, update_model: bool = False):
# Construct huggingface_hub kwargs with local_files_only to return the snapshot path
cache_dir = os.getenv("SENTENCE_TRANSFORMERS_HOME")
local_files_only = not update_model
if OFFLINE_MODE:
local_files_only = True
snapshot_kwargs = {
"cache_dir": cache_dir,
"local_files_only": local_files_only,
}
log.debug(f"model: {model}")
log.debug(f"snapshot_kwargs: {snapshot_kwargs}")
# Inspiration from upstream sentence_transformers
if (
os.path.exists(model)
or ("\\" in model or model.count("/") > 1)
and local_files_only
):
# If fully qualified path exists, return input, else set repo_id
return model
elif "/" not in model:
# Set valid repo_id for model short-name
model = "sentence-transformers" + "/" + model
snapshot_kwargs["repo_id"] = model
# Attempt to query the huggingface_hub library to determine the local path and/or to update
try:
model_repo_path = snapshot_download(**snapshot_kwargs)
log.debug(f"model_repo_path: {model_repo_path}")
return model_repo_path
except Exception as e:
log.exception(f"Cannot determine model snapshot path: {e}")
return model
def generate_openai_batch_embeddings(
model: str,
texts: list[str],
url: str = "https://api.openai.com/v1",
key: str = "",
user: UserModel = None,
) -> Optional[list[list[float]]]:
try:
r = requests.post(
f"{url}/embeddings",
headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {key}",
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
json={"input": texts, "model": model},
)
r.raise_for_status()
data = r.json()
if "data" in data:
return [elem["embedding"] for elem in data["data"]]
else:
raise "Something went wrong :/"
except Exception as e:
log.exception(f"Error generating openai batch embeddings: {e}")
return None
def generate_ollama_batch_embeddings(
model: str, texts: list[str], url: str, key: str = "", user: UserModel = None
) -> Optional[list[list[float]]]:
try:
r = requests.post(
f"{url}/api/embed",
headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {key}",
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
                    if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
json={"input": texts, "model": model},
)
r.raise_for_status()
data = r.json()
if "embeddings" in data:
return data["embeddings"]
else:
raise "Something went wrong :/"
except Exception as e:
log.exception(f"Error generating ollama batch embeddings: {e}")
return None
def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs):
url = kwargs.get("url", "")
key = kwargs.get("key", "")
user = kwargs.get("user")
if engine == "ollama":
if isinstance(text, list):
embeddings = generate_ollama_batch_embeddings(
**{"model": model, "texts": text, "url": url, "key": key, "user": user}
)
else:
embeddings = generate_ollama_batch_embeddings(
**{
"model": model,
"texts": [text],
"url": url,
"key": key,
"user": user,
}
)
return embeddings[0] if isinstance(text, str) else embeddings
elif engine == "openai":
if isinstance(text, list):
embeddings = generate_openai_batch_embeddings(model, text, url, key, user)
else:
embeddings = generate_openai_batch_embeddings(model, [text], url, key, user)
return embeddings[0] if isinstance(text, str) else embeddings
import operator
from typing import Optional, Sequence
from langchain_core.callbacks import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document
class RerankCompressor(BaseDocumentCompressor):
embedding_function: Any
top_n: int
reranking_function: Any
r_score: float
class Config:
extra = "forbid"
arbitrary_types_allowed = True
def compress_documents(
self,
documents: Sequence[Document],
query: str,
callbacks: Optional[Callbacks] = None,
) -> Sequence[Document]:
reranking = self.reranking_function is not None
if reranking:
scores = self.reranking_function.predict(
[(query, doc.page_content) for doc in documents]
)
else:
from sentence_transformers import util
query_embedding = self.embedding_function(query)
document_embedding = self.embedding_function(
[doc.page_content for doc in documents]
)
scores = util.cos_sim(query_embedding, document_embedding)[0]
docs_with_scores = list(zip(documents, scores.tolist()))
if self.r_score:
docs_with_scores = [
(d, s) for d, s in docs_with_scores if s >= self.r_score
]
result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True)
final_results = []
for doc, doc_score in result[: self.top_n]:
metadata = doc.metadata
metadata["score"] = doc_score
doc = Document(
page_content=doc.page_content,
metadata=metadata,
)
final_results.append(doc)
return final_results
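# A minimal sketch of calling the embedding helpers above directly; the Ollama URL and
# model name are placeholders and a reachable server is assumed.
if __name__ == "__main__":
    single = generate_embeddings(
        engine="ollama",
        model="nomic-embed-text",  # placeholder model name
        text="hello world",
        url="http://localhost:11434",
        key="",
    )
    batch = generate_embeddings(
        engine="ollama",
        model="nomic-embed-text",
        text=["first chunk", "second chunk"],
        url="http://localhost:11434",
        key="",
    )
    # A single string yields one vector; a list yields one vector per item.
    print(len(single), len(batch))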

View File

@ -0,0 +1,22 @@
from open_webui.config import VECTOR_DB
if VECTOR_DB == "milvus":
from open_webui.retrieval.vector.dbs.milvus import MilvusClient
VECTOR_DB_CLIENT = MilvusClient()
elif VECTOR_DB == "qdrant":
from open_webui.retrieval.vector.dbs.qdrant import QdrantClient
VECTOR_DB_CLIENT = QdrantClient()
elif VECTOR_DB == "opensearch":
from open_webui.retrieval.vector.dbs.opensearch import OpenSearchClient
VECTOR_DB_CLIENT = OpenSearchClient()
elif VECTOR_DB == "pgvector":
from open_webui.retrieval.vector.dbs.pgvector import PgvectorClient
VECTOR_DB_CLIENT = PgvectorClient()
else:
from open_webui.retrieval.vector.dbs.chroma import ChromaClient
VECTOR_DB_CLIENT = ChromaClient()
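# A minimal sketch of the shared client interface the retrieval code relies on regardless
# of which backend was selected above; the collection name and query vector are placeholders.
if __name__ == "__main__":
    collection = "file-example"
    query_vector = [0.0] * 384  # placeholder; the dimension must match the stored vectors
    if VECTOR_DB_CLIENT.has_collection(collection):
        result = VECTOR_DB_CLIENT.search(
            collection_name=collection, vectors=[query_vector], limit=3
        )
        if result:
            print(result.ids[0], result.distances[0])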

View File

@ -0,0 +1,178 @@
import chromadb
import logging
from chromadb import Settings
from chromadb.utils.batch_utils import create_batches
from typing import Optional
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.config import (
CHROMA_DATA_PATH,
CHROMA_HTTP_HOST,
CHROMA_HTTP_PORT,
CHROMA_HTTP_HEADERS,
CHROMA_HTTP_SSL,
CHROMA_TENANT,
CHROMA_DATABASE,
CHROMA_CLIENT_AUTH_PROVIDER,
CHROMA_CLIENT_AUTH_CREDENTIALS,
)
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class ChromaClient:
def __init__(self):
settings_dict = {
"allow_reset": True,
"anonymized_telemetry": False,
}
if CHROMA_CLIENT_AUTH_PROVIDER is not None:
settings_dict["chroma_client_auth_provider"] = CHROMA_CLIENT_AUTH_PROVIDER
if CHROMA_CLIENT_AUTH_CREDENTIALS is not None:
settings_dict["chroma_client_auth_credentials"] = (
CHROMA_CLIENT_AUTH_CREDENTIALS
)
if CHROMA_HTTP_HOST != "":
self.client = chromadb.HttpClient(
host=CHROMA_HTTP_HOST,
port=CHROMA_HTTP_PORT,
headers=CHROMA_HTTP_HEADERS,
ssl=CHROMA_HTTP_SSL,
tenant=CHROMA_TENANT,
database=CHROMA_DATABASE,
settings=Settings(**settings_dict),
)
else:
self.client = chromadb.PersistentClient(
path=CHROMA_DATA_PATH,
settings=Settings(**settings_dict),
tenant=CHROMA_TENANT,
database=CHROMA_DATABASE,
)
def has_collection(self, collection_name: str) -> bool:
# Check if the collection exists based on the collection name.
collection_names = self.client.list_collections()
return collection_name in collection_names
def delete_collection(self, collection_name: str):
# Delete the collection based on the collection name.
return self.client.delete_collection(name=collection_name)
def search(
self, collection_name: str, vectors: list[list[float | int]], limit: int
) -> Optional[SearchResult]:
# Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
try:
collection = self.client.get_collection(name=collection_name)
if collection:
result = collection.query(
query_embeddings=vectors,
n_results=limit,
)
return SearchResult(
**{
"ids": result["ids"],
"distances": result["distances"],
"documents": result["documents"],
"metadatas": result["metadatas"],
}
)
return None
except Exception as e:
return None
def query(
self, collection_name: str, filter: dict, limit: Optional[int] = None
) -> Optional[GetResult]:
# Query the items from the collection based on the filter.
try:
collection = self.client.get_collection(name=collection_name)
if collection:
result = collection.get(
where=filter,
limit=limit,
)
return GetResult(
**{
"ids": [result["ids"]],
"documents": [result["documents"]],
"metadatas": [result["metadatas"]],
}
)
return None
        except Exception:
return None
def get(self, collection_name: str) -> Optional[GetResult]:
# Get all the items in the collection.
collection = self.client.get_collection(name=collection_name)
if collection:
result = collection.get()
return GetResult(
**{
"ids": [result["ids"]],
"documents": [result["documents"]],
"metadatas": [result["metadatas"]],
}
)
return None
def insert(self, collection_name: str, items: list[VectorItem]):
# Insert the items into the collection, if the collection does not exist, it will be created.
collection = self.client.get_or_create_collection(
name=collection_name, metadata={"hnsw:space": "cosine"}
)
ids = [item["id"] for item in items]
documents = [item["text"] for item in items]
embeddings = [item["vector"] for item in items]
metadatas = [item["metadata"] for item in items]
for batch in create_batches(
api=self.client,
documents=documents,
embeddings=embeddings,
ids=ids,
metadatas=metadatas,
):
collection.add(*batch)
def upsert(self, collection_name: str, items: list[VectorItem]):
# Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
collection = self.client.get_or_create_collection(
name=collection_name, metadata={"hnsw:space": "cosine"}
)
ids = [item["id"] for item in items]
documents = [item["text"] for item in items]
embeddings = [item["vector"] for item in items]
metadatas = [item["metadata"] for item in items]
collection.upsert(
ids=ids, documents=documents, embeddings=embeddings, metadatas=metadatas
)
def delete(
self,
collection_name: str,
ids: Optional[list[str]] = None,
filter: Optional[dict] = None,
):
# Delete the items from the collection based on the ids.
collection = self.client.get_collection(name=collection_name)
if collection:
if ids:
collection.delete(ids=ids)
elif filter:
collection.delete(where=filter)
def reset(self):
# Resets the database. This will delete all collections and item entries.
return self.client.reset()
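# A minimal sketch of the VectorItem dict shape the methods above expect; ids, vectors,
# and metadata are placeholders, and every vector in a collection should share one dimension.
if __name__ == "__main__":
    client = ChromaClient()
    items = [
        {
            "id": "chunk-1",
            "text": "Open WebUI supports several vector databases.",
            "vector": [0.1, 0.2, 0.3],  # toy 3-dimensional vector
            "metadata": {"file_id": "file-abc", "source": "README.md"},
        }
    ]
    client.upsert("demo_collection", items)  # creates the collection if missing
    found = client.search("demo_collection", vectors=[[0.1, 0.2, 0.3]], limit=1)
    if found:
        print(found.documents[0])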

View File

@ -0,0 +1,297 @@
from pymilvus import MilvusClient as Client
from pymilvus import FieldSchema, DataType
import json
import logging
from typing import Optional
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.config import (
MILVUS_URI,
MILVUS_DB,
MILVUS_TOKEN,
)
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class MilvusClient:
def __init__(self):
self.collection_prefix = "open_webui"
if MILVUS_TOKEN is None:
self.client = Client(uri=MILVUS_URI, database=MILVUS_DB)
else:
self.client = Client(uri=MILVUS_URI, database=MILVUS_DB, token=MILVUS_TOKEN)
def _result_to_get_result(self, result) -> GetResult:
ids = []
documents = []
metadatas = []
for match in result:
_ids = []
_documents = []
_metadatas = []
for item in match:
_ids.append(item.get("id"))
_documents.append(item.get("data", {}).get("text"))
_metadatas.append(item.get("metadata"))
ids.append(_ids)
documents.append(_documents)
metadatas.append(_metadatas)
return GetResult(
**{
"ids": ids,
"documents": documents,
"metadatas": metadatas,
}
)
def _result_to_search_result(self, result) -> SearchResult:
ids = []
distances = []
documents = []
metadatas = []
for match in result:
_ids = []
_distances = []
_documents = []
_metadatas = []
for item in match:
_ids.append(item.get("id"))
_distances.append(item.get("distance"))
_documents.append(item.get("entity", {}).get("data", {}).get("text"))
_metadatas.append(item.get("entity", {}).get("metadata"))
ids.append(_ids)
distances.append(_distances)
documents.append(_documents)
metadatas.append(_metadatas)
return SearchResult(
**{
"ids": ids,
"distances": distances,
"documents": documents,
"metadatas": metadatas,
}
)
def _create_collection(self, collection_name: str, dimension: int):
schema = self.client.create_schema(
auto_id=False,
enable_dynamic_field=True,
)
schema.add_field(
field_name="id",
datatype=DataType.VARCHAR,
is_primary=True,
max_length=65535,
)
schema.add_field(
field_name="vector",
datatype=DataType.FLOAT_VECTOR,
dim=dimension,
description="vector",
)
schema.add_field(field_name="data", datatype=DataType.JSON, description="data")
schema.add_field(
field_name="metadata", datatype=DataType.JSON, description="metadata"
)
index_params = self.client.prepare_index_params()
index_params.add_index(
field_name="vector",
index_type="HNSW",
metric_type="COSINE",
params={"M": 16, "efConstruction": 100},
)
self.client.create_collection(
collection_name=f"{self.collection_prefix}_{collection_name}",
schema=schema,
index_params=index_params,
)
def has_collection(self, collection_name: str) -> bool:
# Check if the collection exists based on the collection name.
collection_name = collection_name.replace("-", "_")
return self.client.has_collection(
collection_name=f"{self.collection_prefix}_{collection_name}"
)
def delete_collection(self, collection_name: str):
# Delete the collection based on the collection name.
collection_name = collection_name.replace("-", "_")
return self.client.drop_collection(
collection_name=f"{self.collection_prefix}_{collection_name}"
)
def search(
self, collection_name: str, vectors: list[list[float | int]], limit: int
) -> Optional[SearchResult]:
# Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
collection_name = collection_name.replace("-", "_")
result = self.client.search(
collection_name=f"{self.collection_prefix}_{collection_name}",
data=vectors,
limit=limit,
output_fields=["data", "metadata"],
)
return self._result_to_search_result(result)
def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
# Construct the filter string for querying
collection_name = collection_name.replace("-", "_")
if not self.has_collection(collection_name):
return None
filter_string = " && ".join(
[
f'metadata["{key}"] == {json.dumps(value)}'
for key, value in filter.items()
]
)
max_limit = 16383 # The maximum number of records per request
all_results = []
if limit is None:
limit = float("inf") # Use infinity as a placeholder for no limit
# Initialize offset and remaining to handle pagination
offset = 0
remaining = limit
try:
# Loop until there are no more items to fetch or the desired limit is reached
while remaining > 0:
log.info(f"remaining: {remaining}")
current_fetch = min(
max_limit, remaining
) # Determine how many items to fetch in this iteration
results = self.client.query(
collection_name=f"{self.collection_prefix}_{collection_name}",
filter=filter_string,
output_fields=["*"],
limit=current_fetch,
offset=offset,
)
if not results:
break
all_results.extend(results)
results_count = len(results)
remaining -= (
results_count # Decrease remaining by the number of items fetched
)
offset += results_count
# Break the loop if the results returned are less than the requested fetch count
if results_count < current_fetch:
break
log.debug(all_results)
return self._result_to_get_result([all_results])
except Exception as e:
log.exception(
f"Error querying collection {collection_name} with limit {limit}: {e}"
)
return None
def get(self, collection_name: str) -> Optional[GetResult]:
# Get all the items in the collection.
collection_name = collection_name.replace("-", "_")
result = self.client.query(
collection_name=f"{self.collection_prefix}_{collection_name}",
filter='id != ""',
)
return self._result_to_get_result([result])
def insert(self, collection_name: str, items: list[VectorItem]):
# Insert the items into the collection, if the collection does not exist, it will be created.
collection_name = collection_name.replace("-", "_")
if not self.client.has_collection(
collection_name=f"{self.collection_prefix}_{collection_name}"
):
self._create_collection(
collection_name=collection_name, dimension=len(items[0]["vector"])
)
return self.client.insert(
collection_name=f"{self.collection_prefix}_{collection_name}",
data=[
{
"id": item["id"],
"vector": item["vector"],
"data": {"text": item["text"]},
"metadata": item["metadata"],
}
for item in items
],
)
def upsert(self, collection_name: str, items: list[VectorItem]):
# Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
collection_name = collection_name.replace("-", "_")
if not self.client.has_collection(
collection_name=f"{self.collection_prefix}_{collection_name}"
):
self._create_collection(
collection_name=collection_name, dimension=len(items[0]["vector"])
)
return self.client.upsert(
collection_name=f"{self.collection_prefix}_{collection_name}",
data=[
{
"id": item["id"],
"vector": item["vector"],
"data": {"text": item["text"]},
"metadata": item["metadata"],
}
for item in items
],
)
def delete(
self,
collection_name: str,
ids: Optional[list[str]] = None,
filter: Optional[dict] = None,
):
# Delete the items from the collection based on the ids.
collection_name = collection_name.replace("-", "_")
if ids:
return self.client.delete(
collection_name=f"{self.collection_prefix}_{collection_name}",
ids=ids,
)
elif filter:
# Convert the filter dictionary to a string using JSON_CONTAINS.
filter_string = " && ".join(
[
f'metadata["{key}"] == {json.dumps(value)}'
for key, value in filter.items()
]
)
return self.client.delete(
collection_name=f"{self.collection_prefix}_{collection_name}",
filter=filter_string,
)
def reset(self):
# Resets the database. This will delete all collections and item entries.
collection_names = self.client.list_collections()
for collection_name in collection_names:
if collection_name.startswith(self.collection_prefix):
self.client.drop_collection(collection_name=collection_name)
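# A minimal sketch assuming a reachable Milvus instance at MILVUS_URI; note that the client
# prefixes collection names with "open_webui" and replaces dashes with underscores internally.
if __name__ == "__main__":
    client = MilvusClient()
    items = [
        {
            "id": "chunk-1",
            "text": "Example chunk stored in Milvus.",
            "vector": [0.1] * 384,  # the dimension is fixed when the collection is first created
            "metadata": {"file_id": "file-abc"},
        }
    ]
    # Stored internally as "open_webui_file_abc".
    client.upsert("file-abc", items)
    hits = client.search("file-abc", vectors=[[0.1] * 384], limit=1)
    print(hits.ids, hits.distances)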

View File

@ -0,0 +1,206 @@
from opensearchpy import OpenSearch
from typing import Optional
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.config import (
OPENSEARCH_URI,
OPENSEARCH_SSL,
OPENSEARCH_CERT_VERIFY,
OPENSEARCH_USERNAME,
OPENSEARCH_PASSWORD,
)
class OpenSearchClient:
def __init__(self):
self.index_prefix = "open_webui"
self.client = OpenSearch(
hosts=[OPENSEARCH_URI],
use_ssl=OPENSEARCH_SSL,
verify_certs=OPENSEARCH_CERT_VERIFY,
http_auth=(OPENSEARCH_USERNAME, OPENSEARCH_PASSWORD),
)
def _result_to_get_result(self, result) -> GetResult:
ids = []
documents = []
metadatas = []
for hit in result["hits"]["hits"]:
ids.append(hit["_id"])
documents.append(hit["_source"].get("text"))
metadatas.append(hit["_source"].get("metadata"))
return GetResult(ids=ids, documents=documents, metadatas=metadatas)
def _result_to_search_result(self, result) -> SearchResult:
ids = []
distances = []
documents = []
metadatas = []
for hit in result["hits"]["hits"]:
ids.append(hit["_id"])
distances.append(hit["_score"])
documents.append(hit["_source"].get("text"))
metadatas.append(hit["_source"].get("metadata"))
return SearchResult(
ids=ids, distances=distances, documents=documents, metadatas=metadatas
)
def _create_index(self, index_name: str, dimension: int):
body = {
"mappings": {
"properties": {
"id": {"type": "keyword"},
"vector": {
"type": "dense_vector",
"dims": dimension, # Adjust based on your vector dimensions
"index": true,
"similarity": "faiss",
"method": {
"name": "hnsw",
"space_type": "ip", # Use inner product to approximate cosine similarity
"engine": "faiss",
"ef_construction": 128,
"m": 16,
},
},
"text": {"type": "text"},
"metadata": {"type": "object"},
}
}
}
self.client.indices.create(index=f"{self.index_prefix}_{index_name}", body=body)
def _create_batches(self, items: list[VectorItem], batch_size=100):
for i in range(0, len(items), batch_size):
yield items[i : i + batch_size]
def has_collection(self, index_name: str) -> bool:
# has_collection here means has index.
# We are simply adapting to the norms of the other DBs.
return self.client.indices.exists(index=f"{self.index_prefix}_{index_name}")
    def delete_collection(self, index_name: str):
# delete_collection here means delete index.
# We are simply adapting to the norms of the other DBs.
self.client.indices.delete(index=f"{self.index_prefix}_{index_name}")
def search(
self, index_name: str, vectors: list[list[float]], limit: int
) -> Optional[SearchResult]:
query = {
"size": limit,
"_source": ["text", "metadata"],
"query": {
"script_score": {
"query": {"match_all": {}},
"script": {
"source": "cosineSimilarity(params.vector, 'vector') + 1.0",
"params": {
"vector": vectors[0]
}, # Assuming single query vector
},
}
},
}
result = self.client.search(
index=f"{self.index_prefix}_{index_name}", body=query
)
return self._result_to_search_result(result)
def query(
self, collection_name: str, filter: dict, limit: Optional[int] = None
) -> Optional[GetResult]:
if not self.has_collection(collection_name):
return None
query_body = {
"query": {"bool": {"filter": []}},
"_source": ["text", "metadata"],
}
for field, value in filter.items():
query_body["query"]["bool"]["filter"].append({"term": {field: value}})
size = limit if limit else 10
try:
result = self.client.search(
index=f"{self.index_prefix}_{collection_name}",
body=query_body,
size=size,
)
return self._result_to_get_result(result)
except Exception as e:
return None
def get_or_create_index(self, index_name: str, dimension: int):
        if not self.has_collection(index_name):
self._create_index(index_name, dimension)
def get(self, index_name: str) -> Optional[GetResult]:
query = {"query": {"match_all": {}}, "_source": ["text", "metadata"]}
result = self.client.search(
index=f"{self.index_prefix}_{index_name}", body=query
)
return self._result_to_get_result(result)
def insert(self, index_name: str, items: list[VectorItem]):
        if not self.has_collection(index_name):
self._create_index(index_name, dimension=len(items[0]["vector"]))
        for batch in self._create_batches(items):
            # The bulk API expects an action line followed by the document source line.
            actions = []
            for item in batch:
                actions.append(
                    {
                        "index": {
                            "_index": f"{self.index_prefix}_{index_name}",
                            "_id": item["id"],
                        }
                    }
                )
                actions.append(
                    {
                        "vector": item["vector"],
                        "text": item["text"],
                        "metadata": item["metadata"],
                    }
                )
            self.client.bulk(body=actions)
def upsert(self, index_name: str, items: list[VectorItem]):
        if not self.has_collection(index_name):
self._create_index(index_name, dimension=len(items[0]["vector"]))
        for batch in self._create_batches(items):
            # The bulk API expects an action line followed by the document source line.
            actions = []
            for item in batch:
                actions.append(
                    {
                        "index": {
                            "_index": f"{self.index_prefix}_{index_name}",
                            "_id": item["id"],
                        }
                    }
                )
                actions.append(
                    {
                        "vector": item["vector"],
                        "text": item["text"],
                        "metadata": item["metadata"],
                    }
                )
            self.client.bulk(body=actions)
def delete(self, index_name: str, ids: list[str]):
actions = [
{"delete": {"_index": f"{self.index_prefix}_{index_name}", "_id": id}}
for id in ids
]
self.client.bulk(body=actions)
def reset(self):
indices = self.client.indices.get(index=f"{self.index_prefix}_*")
for index in indices:
self.client.indices.delete(index=index)

View File

@ -0,0 +1,403 @@
from typing import Optional, List, Dict, Any
import logging
from sqlalchemy import (
cast,
column,
create_engine,
Column,
Integer,
MetaData,
select,
text,
Text,
Table,
values,
)
from sqlalchemy.sql import true
from sqlalchemy.pool import NullPool
from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker
from sqlalchemy.dialects.postgresql import JSONB, array
from pgvector.sqlalchemy import Vector
from sqlalchemy.ext.mutable import MutableDict
from sqlalchemy.exc import NoSuchTableError
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.config import PGVECTOR_DB_URL, PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
from open_webui.env import SRC_LOG_LEVELS
VECTOR_LENGTH = PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
Base = declarative_base()
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class DocumentChunk(Base):
__tablename__ = "document_chunk"
id = Column(Text, primary_key=True)
vector = Column(Vector(dim=VECTOR_LENGTH), nullable=True)
collection_name = Column(Text, nullable=False)
text = Column(Text, nullable=True)
vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True)
class PgvectorClient:
def __init__(self) -> None:
# if no pgvector uri, use the existing database connection
if not PGVECTOR_DB_URL:
from open_webui.internal.db import Session
self.session = Session
else:
engine = create_engine(
PGVECTOR_DB_URL, pool_pre_ping=True, poolclass=NullPool
)
SessionLocal = sessionmaker(
autocommit=False, autoflush=False, bind=engine, expire_on_commit=False
)
self.session = scoped_session(SessionLocal)
try:
# Ensure the pgvector extension is available
self.session.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))
# Check vector length consistency
self.check_vector_length()
# Create the tables if they do not exist
# Base.metadata.create_all requires a bind (engine or connection)
# Get the connection from the session
connection = self.session.connection()
Base.metadata.create_all(bind=connection)
# Create an index on the vector column if it doesn't exist
self.session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_document_chunk_vector "
"ON document_chunk USING ivfflat (vector vector_cosine_ops) WITH (lists = 100);"
)
)
self.session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_document_chunk_collection_name "
"ON document_chunk (collection_name);"
)
)
self.session.commit()
log.info("Initialization complete.")
except Exception as e:
self.session.rollback()
log.exception(f"Error during initialization: {e}")
raise
def check_vector_length(self) -> None:
"""
Check if the VECTOR_LENGTH matches the existing vector column dimension in the database.
Raises an exception if there is a mismatch.
"""
metadata = MetaData()
try:
# Attempt to reflect the 'document_chunk' table
document_chunk_table = Table(
"document_chunk", metadata, autoload_with=self.session.bind
)
except NoSuchTableError:
# Table does not exist; no action needed
return
# Proceed to check the vector column
if "vector" in document_chunk_table.columns:
vector_column = document_chunk_table.columns["vector"]
vector_type = vector_column.type
if isinstance(vector_type, Vector):
db_vector_length = vector_type.dim
if db_vector_length != VECTOR_LENGTH:
raise Exception(
f"VECTOR_LENGTH {VECTOR_LENGTH} does not match existing vector column dimension {db_vector_length}. "
"Cannot change vector size after initialization without migrating the data."
)
else:
raise Exception(
"The 'vector' column exists but is not of type 'Vector'."
)
else:
raise Exception(
"The 'vector' column does not exist in the 'document_chunk' table."
)
def adjust_vector_length(self, vector: List[float]) -> List[float]:
# Adjust vector to have length VECTOR_LENGTH
current_length = len(vector)
if current_length < VECTOR_LENGTH:
# Pad the vector with zeros
vector += [0.0] * (VECTOR_LENGTH - current_length)
elif current_length > VECTOR_LENGTH:
            raise Exception(
                f"Vector length {current_length} exceeds the maximum supported length {VECTOR_LENGTH}."
            )
return vector
def insert(self, collection_name: str, items: List[VectorItem]) -> None:
try:
new_items = []
for item in items:
vector = self.adjust_vector_length(item["vector"])
new_chunk = DocumentChunk(
id=item["id"],
vector=vector,
collection_name=collection_name,
text=item["text"],
vmetadata=item["metadata"],
)
new_items.append(new_chunk)
self.session.bulk_save_objects(new_items)
self.session.commit()
log.info(
f"Inserted {len(new_items)} items into collection '{collection_name}'."
)
except Exception as e:
self.session.rollback()
log.exception(f"Error during insert: {e}")
raise
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
try:
for item in items:
vector = self.adjust_vector_length(item["vector"])
existing = (
self.session.query(DocumentChunk)
.filter(DocumentChunk.id == item["id"])
.first()
)
if existing:
existing.vector = vector
existing.text = item["text"]
existing.vmetadata = item["metadata"]
existing.collection_name = (
collection_name # Update collection_name if necessary
)
else:
new_chunk = DocumentChunk(
id=item["id"],
vector=vector,
collection_name=collection_name,
text=item["text"],
vmetadata=item["metadata"],
)
self.session.add(new_chunk)
self.session.commit()
log.info(
f"Upserted {len(items)} items into collection '{collection_name}'."
)
except Exception as e:
self.session.rollback()
log.exception(f"Error during upsert: {e}")
raise
def search(
self,
collection_name: str,
vectors: List[List[float]],
limit: Optional[int] = None,
) -> Optional[SearchResult]:
try:
if not vectors:
return None
# Adjust query vectors to VECTOR_LENGTH
vectors = [self.adjust_vector_length(vector) for vector in vectors]
num_queries = len(vectors)
def vector_expr(vector):
return cast(array(vector), Vector(VECTOR_LENGTH))
# Create the values for query vectors
qid_col = column("qid", Integer)
q_vector_col = column("q_vector", Vector(VECTOR_LENGTH))
query_vectors = (
values(qid_col, q_vector_col)
.data(
[(idx, vector_expr(vector)) for idx, vector in enumerate(vectors)]
)
.alias("query_vectors")
)
# Build the lateral subquery for each query vector
subq = (
select(
DocumentChunk.id,
DocumentChunk.text,
DocumentChunk.vmetadata,
(
DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)
).label("distance"),
)
.where(DocumentChunk.collection_name == collection_name)
.order_by(
(DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector))
)
)
if limit is not None:
subq = subq.limit(limit)
subq = subq.lateral("result")
# Build the main query by joining query_vectors and the lateral subquery
stmt = (
select(
query_vectors.c.qid,
subq.c.id,
subq.c.text,
subq.c.vmetadata,
subq.c.distance,
)
.select_from(query_vectors)
.join(subq, true())
.order_by(query_vectors.c.qid, subq.c.distance)
)
result_proxy = self.session.execute(stmt)
results = result_proxy.all()
ids = [[] for _ in range(num_queries)]
distances = [[] for _ in range(num_queries)]
documents = [[] for _ in range(num_queries)]
metadatas = [[] for _ in range(num_queries)]
if not results:
return SearchResult(
ids=ids,
distances=distances,
documents=documents,
metadatas=metadatas,
)
for row in results:
qid = int(row.qid)
ids[qid].append(row.id)
distances[qid].append(row.distance)
documents[qid].append(row.text)
metadatas[qid].append(row.vmetadata)
return SearchResult(
ids=ids, distances=distances, documents=documents, metadatas=metadatas
)
except Exception as e:
log.exception(f"Error during search: {e}")
return None
def query(
self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
) -> Optional[GetResult]:
try:
query = self.session.query(DocumentChunk).filter(
DocumentChunk.collection_name == collection_name
)
for key, value in filter.items():
query = query.filter(DocumentChunk.vmetadata[key].astext == str(value))
if limit is not None:
query = query.limit(limit)
results = query.all()
if not results:
return None
ids = [[result.id for result in results]]
documents = [[result.text for result in results]]
metadatas = [[result.vmetadata for result in results]]
return GetResult(
ids=ids,
documents=documents,
metadatas=metadatas,
)
except Exception as e:
log.exception(f"Error during query: {e}")
return None
def get(
self, collection_name: str, limit: Optional[int] = None
) -> Optional[GetResult]:
try:
query = self.session.query(DocumentChunk).filter(
DocumentChunk.collection_name == collection_name
)
if limit is not None:
query = query.limit(limit)
results = query.all()
if not results:
return None
ids = [[result.id for result in results]]
documents = [[result.text for result in results]]
metadatas = [[result.vmetadata for result in results]]
return GetResult(ids=ids, documents=documents, metadatas=metadatas)
except Exception as e:
log.exception(f"Error during get: {e}")
return None
def delete(
self,
collection_name: str,
ids: Optional[List[str]] = None,
filter: Optional[Dict[str, Any]] = None,
) -> None:
try:
query = self.session.query(DocumentChunk).filter(
DocumentChunk.collection_name == collection_name
)
if ids:
query = query.filter(DocumentChunk.id.in_(ids))
if filter:
for key, value in filter.items():
query = query.filter(
DocumentChunk.vmetadata[key].astext == str(value)
)
deleted = query.delete(synchronize_session=False)
self.session.commit()
log.info(f"Deleted {deleted} items from collection '{collection_name}'.")
except Exception as e:
self.session.rollback()
log.exception(f"Error during delete: {e}")
raise
def reset(self) -> None:
try:
deleted = self.session.query(DocumentChunk).delete()
self.session.commit()
log.info(
f"Reset complete. Deleted {deleted} items from 'document_chunk' table."
)
except Exception as e:
self.session.rollback()
log.exception(f"Error during reset: {e}")
raise
def close(self) -> None:
pass
def has_collection(self, collection_name: str) -> bool:
try:
exists = (
self.session.query(DocumentChunk)
.filter(DocumentChunk.collection_name == collection_name)
.first()
is not None
)
return exists
except Exception as e:
log.exception(f"Error checking collection existence: {e}")
return False
def delete_collection(self, collection_name: str) -> None:
self.delete(collection_name)
log.info(f"Collection '{collection_name}' deleted.")

View File

@@ -0,0 +1,189 @@
from typing import Optional
import logging
from qdrant_client import QdrantClient as Qclient
from qdrant_client.http.models import PointStruct
from qdrant_client.models import models
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.config import QDRANT_URI, QDRANT_API_KEY
from open_webui.env import SRC_LOG_LEVELS
NO_LIMIT = 999999999
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class QdrantClient:
def __init__(self):
self.collection_prefix = "open-webui"
self.QDRANT_URI = QDRANT_URI
self.QDRANT_API_KEY = QDRANT_API_KEY
self.client = (
Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY)
if self.QDRANT_URI
else None
)
def _result_to_get_result(self, points) -> GetResult:
ids = []
documents = []
metadatas = []
for point in points:
payload = point.payload
ids.append(point.id)
documents.append(payload["text"])
metadatas.append(payload["metadata"])
return GetResult(
**{
"ids": [ids],
"documents": [documents],
"metadatas": [metadatas],
}
)
def _create_collection(self, collection_name: str, dimension: int):
collection_name_with_prefix = f"{self.collection_prefix}_{collection_name}"
self.client.create_collection(
collection_name=collection_name_with_prefix,
vectors_config=models.VectorParams(
size=dimension, distance=models.Distance.COSINE
),
)
log.info(f"collection {collection_name_with_prefix} successfully created!")
def _create_collection_if_not_exists(self, collection_name, dimension):
if not self.has_collection(collection_name=collection_name):
self._create_collection(
collection_name=collection_name, dimension=dimension
)
def _create_points(self, items: list[VectorItem]):
return [
PointStruct(
id=item["id"],
vector=item["vector"],
payload={"text": item["text"], "metadata": item["metadata"]},
)
for item in items
]
def has_collection(self, collection_name: str) -> bool:
return self.client.collection_exists(
f"{self.collection_prefix}_{collection_name}"
)
def delete_collection(self, collection_name: str):
return self.client.delete_collection(
collection_name=f"{self.collection_prefix}_{collection_name}"
)
def search(
self, collection_name: str, vectors: list[list[float | int]], limit: int
) -> Optional[SearchResult]:
# Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
if limit is None:
limit = NO_LIMIT # otherwise qdrant would set limit to 10!
query_response = self.client.query_points(
collection_name=f"{self.collection_prefix}_{collection_name}",
query=vectors[0],
limit=limit,
)
get_result = self._result_to_get_result(query_response.points)
return SearchResult(
ids=get_result.ids,
documents=get_result.documents,
metadatas=get_result.metadatas,
distances=[[point.score for point in query_response.points]],
)
def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
# Construct the filter string for querying
if not self.has_collection(collection_name):
return None
try:
if limit is None:
limit = NO_LIMIT # otherwise qdrant would set limit to 10!
field_conditions = []
for key, value in filter.items():
field_conditions.append(
models.FieldCondition(
key=f"metadata.{key}", match=models.MatchValue(value=value)
)
)
points = self.client.query_points(
collection_name=f"{self.collection_prefix}_{collection_name}",
query_filter=models.Filter(should=field_conditions),
limit=limit,
)
return self._result_to_get_result(points.points)
except Exception as e:
log.exception(f"Error querying a collection '{collection_name}': {e}")
return None
def get(self, collection_name: str) -> Optional[GetResult]:
# Get all the items in the collection.
points = self.client.query_points(
collection_name=f"{self.collection_prefix}_{collection_name}",
limit=NO_LIMIT, # otherwise qdrant would set limit to 10!
)
return self._result_to_get_result(points.points)
def insert(self, collection_name: str, items: list[VectorItem]):
# Insert the items into the collection, if the collection does not exist, it will be created.
self._create_collection_if_not_exists(collection_name, len(items[0]["vector"]))
points = self._create_points(items)
self.client.upload_points(f"{self.collection_prefix}_{collection_name}", points)
def upsert(self, collection_name: str, items: list[VectorItem]):
# Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
self._create_collection_if_not_exists(collection_name, len(items[0]["vector"]))
points = self._create_points(items)
return self.client.upsert(f"{self.collection_prefix}_{collection_name}", points)
def delete(
self,
collection_name: str,
ids: Optional[list[str]] = None,
filter: Optional[dict] = None,
):
# Delete the items from the collection based on the ids.
field_conditions = []
        if ids:
            for id_value in ids:
                field_conditions.append(
                    models.FieldCondition(
                        key="metadata.id",
                        match=models.MatchValue(value=id_value),
                    )
                )
        elif filter:
            for key, value in filter.items():
                field_conditions.append(
                    models.FieldCondition(
                        key=f"metadata.{key}",
                        match=models.MatchValue(value=value),
                    )
                )
return self.client.delete(
collection_name=f"{self.collection_prefix}_{collection_name}",
points_selector=models.FilterSelector(
filter=models.Filter(must=field_conditions)
),
)
def reset(self):
# Resets the database. This will delete all collections and item entries.
collection_names = self.client.get_collections().collections
for collection_name in collection_names:
if collection_name.name.startswith(self.collection_prefix):
self.client.delete_collection(collection_name=collection_name.name)
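For reference, a minimal usage sketch of this client. The import path and values are illustrative assumptions; QDRANT_URI must be configured, otherwise self.client stays None, and Qdrant point ids must be UUIDs or unsigned integers.

# Illustrative sketch only; names and values are made up.
from open_webui.retrieval.vector.dbs.qdrant import QdrantClient  # assumed module path

client = QdrantClient()

items = [
    {
        "id": "11111111-1111-1111-1111-111111111111",  # Qdrant requires UUID or integer ids
        "text": "chunk text",
        "vector": [0.1, 0.2, 0.3],
        "metadata": {"source": "doc-1"},
    },
]

# upsert() creates the collection on first use, sized from the first item's vector.
client.upsert("my_collection", items)

# Only vectors[0] is used by this implementation; limit=None falls back to NO_LIMIT.
result = client.search("my_collection", vectors=[[0.1, 0.2, 0.3]], limit=5)
if result:
    print(result.ids[0], result.distances[0])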

View File

@@ -0,0 +1,19 @@
from pydantic import BaseModel
from typing import Optional, List, Any
class VectorItem(BaseModel):
id: str
text: str
vector: List[float | int]
metadata: Any
class GetResult(BaseModel):
ids: Optional[List[List[str]]]
documents: Optional[List[List[str]]]
metadatas: Optional[List[List[Any]]]
class SearchResult(GetResult):
distances: Optional[List[List[float | int]]]
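The clients above index incoming items with dictionary syntax (item["vector"], item["metadata"]), so in practice VectorItem describes the expected keys of plain dicts, while GetResult/SearchResult are returned as models whose outer lists hold one inner list per query vector. A small illustration with made-up values:

# Illustrative values only.
item = {
    "id": "chunk-1",
    "text": "some chunk text",
    "vector": [0.1, 0.2, 0.3],
    "metadata": {"source": "doc-1", "page": 1},
}

result = SearchResult(
    ids=[["chunk-1", "chunk-7"]],  # one inner list per query vector
    documents=[["some chunk text", "another chunk"]],
    metadatas=[[{"source": "doc-1"}, {"source": "doc-2"}]],
    distances=[[0.12, 0.38]],  # distance/score semantics depend on the backend
)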

View File

@@ -0,0 +1,73 @@
import logging
import os
from pprint import pprint
from typing import Optional
import requests
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
import argparse
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
"""
Documentation: https://docs.microsoft.com/en-us/bing/search-apis/bing-web-search/overview
"""
def search_bing(
subscription_key: str,
endpoint: str,
locale: str,
query: str,
count: int,
filter_list: Optional[list[str]] = None,
) -> list[SearchResult]:
mkt = locale
params = {"q": query, "mkt": mkt, "count": count}
headers = {"Ocp-Apim-Subscription-Key": subscription_key}
try:
response = requests.get(endpoint, headers=headers, params=params)
response.raise_for_status()
json_response = response.json()
results = json_response.get("webPages", {}).get("value", [])
if filter_list:
results = get_filtered_results(results, filter_list)
return [
SearchResult(
link=result["url"],
title=result.get("name"),
snippet=result.get("snippet"),
)
for result in results
]
except Exception as ex:
log.error(f"Error: {ex}")
raise ex
def main():
parser = argparse.ArgumentParser(description="Search Bing from the command line.")
    parser.add_argument(
        "query",
        type=str,
        nargs="?",  # optional, so the default below is actually used when no query is given
        default="Top 10 international news today",
        help="The search query.",
    )
parser.add_argument(
"--count", type=int, default=10, help="Number of search results to return."
)
parser.add_argument(
"--filter", nargs="*", help="List of filters to apply to the search results."
)
parser.add_argument(
"--locale",
type=str,
default="en-US",
help="The locale to use for the search, maps to market in api",
)
    args = parser.parse_args()
    # search_bing() requires the subscription key and endpoint as its first two
    # arguments; the environment variable names below are assumptions for this CLI helper.
    subscription_key = os.environ["BING_SUBSCRIPTION_KEY"]
    endpoint = os.environ.get(
        "BING_SEARCH_ENDPOINT", "https://api.bing.microsoft.com/v7.0/search"
    )
    results = search_bing(
        subscription_key, endpoint, args.locale, args.query, args.count, args.filter
    )
    pprint(results)
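For reference, a direct call outside the CLI wrapper; the key, endpoint, query, and filter below are placeholders:

# Placeholder credentials and query for illustration only.
results = search_bing(
    subscription_key="YOUR_BING_KEY",
    endpoint="https://api.bing.microsoft.com/v7.0/search",
    locale="en-US",
    query="open source LLM frontends",
    count=5,
    filter_list=["wikipedia.org"],  # keep only results whose domain ends with this
)
for r in results:
    print(r.link, r.title)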

View File

@@ -0,0 +1,65 @@
import logging
from typing import Optional
import requests
import json
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def _parse_response(response):
result = {}
if "data" in response:
data = response["data"]
if "webPages" in data:
webPages = data["webPages"]
if "value" in webPages:
result["webpage"] = [
{
"id": item.get("id", ""),
"name": item.get("name", ""),
"url": item.get("url", ""),
"snippet": item.get("snippet", ""),
"summary": item.get("summary", ""),
"siteName": item.get("siteName", ""),
"siteIcon": item.get("siteIcon", ""),
"datePublished": item.get("datePublished", "")
or item.get("dateLastCrawled", ""),
}
for item in webPages["value"]
]
return result
def search_bocha(
api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None
) -> list[SearchResult]:
"""Search using Bocha's Search API and return the results as a list of SearchResult objects.
Args:
api_key (str): A Bocha Search API key
query (str): The query to search for
"""
url = "https://api.bochaai.com/v1/web-search?utm_source=ollama"
headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
payload = json.dumps(
{"query": query, "summary": True, "freshness": "noLimit", "count": count}
)
response = requests.post(url, headers=headers, data=payload, timeout=5)
response.raise_for_status()
    results = _parse_response(response.json())
    log.debug(results)
    # Filter the list of web pages, not the wrapping dict, so the domain
    # filter sees items that actually carry a "url" field.
    webpages = results.get("webpage", [])
    if filter_list:
        webpages = get_filtered_results(webpages, filter_list)
    return [
        SearchResult(
            link=result["url"], title=result.get("name"), snippet=result.get("summary")
        )
        for result in webpages[:count]
    ]

View File

@@ -0,0 +1,42 @@
import logging
from typing import Optional
import requests
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_brave(
api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None
) -> list[SearchResult]:
"""Search using Brave's Search API and return the results as a list of SearchResult objects.
Args:
api_key (str): A Brave Search API key
query (str): The query to search for
"""
url = "https://api.search.brave.com/res/v1/web/search"
headers = {
"Accept": "application/json",
"Accept-Encoding": "gzip",
"X-Subscription-Token": api_key,
}
params = {"q": query, "count": count}
response = requests.get(url, headers=headers, params=params)
response.raise_for_status()
json_response = response.json()
results = json_response.get("web", {}).get("results", [])
if filter_list:
results = get_filtered_results(results, filter_list)
return [
SearchResult(
link=result["url"], title=result.get("title"), snippet=result.get("snippet")
)
for result in results[:count]
]

View File

@@ -0,0 +1,46 @@
import logging
from typing import Optional
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from duckduckgo_search import DDGS
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_duckduckgo(
query: str, count: int, filter_list: Optional[list[str]] = None
) -> list[SearchResult]:
"""
Search using DuckDuckGo's Search API and return the results as a list of SearchResult objects.
Args:
query (str): The query to search for
count (int): The number of results to return
Returns:
list[SearchResult]: A list of search results
"""
# Use the DDGS context manager to create a DDGS object
with DDGS() as ddgs:
# Use the ddgs.text() method to perform the search
ddgs_gen = ddgs.text(
query, safesearch="moderate", max_results=count, backend="api"
)
# Check if there are search results
if ddgs_gen:
# Convert the search results into a list
search_results = [r for r in ddgs_gen]
if filter_list:
search_results = get_filtered_results(search_results, filter_list)
# Return the list of search results
return [
SearchResult(
link=result["href"],
title=result.get("title"),
snippet=result.get("body"),
)
for result in search_results
]
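DuckDuckGo needs no API key, so the helper can be exercised directly; the query and domain filter below are arbitrary examples:

# Arbitrary query and filter for illustration.
results = search_duckduckgo("python vector databases", count=5, filter_list=["github.com"])
for r in results:
    print(r.link, r.snippet)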

View File

@@ -0,0 +1,76 @@
import logging
from dataclasses import dataclass
from typing import Optional
import requests
from open_webui.env import SRC_LOG_LEVELS
from open_webui.retrieval.web.main import SearchResult
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
EXA_API_BASE = "https://api.exa.ai"
@dataclass
class ExaResult:
url: str
title: str
text: str
def search_exa(
api_key: str,
query: str,
count: int,
filter_list: Optional[list[str]] = None,
) -> list[SearchResult]:
"""Search using Exa Search API and return the results as a list of SearchResult objects.
Args:
        api_key (str): An Exa Search API key
query (str): The query to search for
count (int): Number of results to return
filter_list (Optional[list[str]]): List of domains to filter results by
"""
log.info(f"Searching with Exa for query: {query}")
headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
payload = {
"query": query,
"numResults": count or 5,
"includeDomains": filter_list,
"contents": {"text": True, "highlights": True},
"type": "auto", # Use the auto search type (keyword or neural)
}
try:
response = requests.post(
f"{EXA_API_BASE}/search", headers=headers, json=payload
)
response.raise_for_status()
data = response.json()
results = []
for result in data["results"]:
results.append(
ExaResult(
url=result["url"],
title=result["title"],
text=result["text"],
)
)
log.info(f"Found {len(results)} results")
return [
SearchResult(
link=result.url,
title=result.title,
snippet=result.text,
)
for result in results
]
except Exception as e:
log.error(f"Error searching Exa: {e}")
return []

View File

@@ -0,0 +1,69 @@
import logging
from typing import Optional
import requests
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_google_pse(
api_key: str,
search_engine_id: str,
query: str,
count: int,
filter_list: Optional[list[str]] = None,
) -> list[SearchResult]:
"""Search using Google's Programmable Search Engine API and return the results as a list of SearchResult objects.
Handles pagination for counts greater than 10.
Args:
api_key (str): A Programmable Search Engine API key
search_engine_id (str): A Programmable Search Engine ID
query (str): The query to search for
count (int): The number of results to return (max 100, as PSE max results per query is 10 and max page is 10)
        filter_list (Optional[list[str]], optional): A list of allowed domains; only results whose domain ends with one of them are kept. Defaults to None.
Returns:
list[SearchResult]: A list of SearchResult objects.
"""
url = "https://www.googleapis.com/customsearch/v1"
headers = {"Content-Type": "application/json"}
all_results = []
start_index = 1 # Google PSE start parameter is 1-based
while count > 0:
num_results_this_page = min(count, 10) # Google PSE max results per page is 10
params = {
"cx": search_engine_id,
"q": query,
"key": api_key,
"num": num_results_this_page,
"start": start_index,
}
response = requests.request("GET", url, headers=headers, params=params)
response.raise_for_status()
json_response = response.json()
results = json_response.get("items", [])
if results: # check if results are returned. If not, no more pages to fetch.
all_results.extend(results)
count -= len(
results
) # Decrement count by the number of results fetched in this page.
start_index += 10 # Increment start index for the next page
else:
break # No more results from Google PSE, break the loop
if filter_list:
all_results = get_filtered_results(all_results, filter_list)
return [
SearchResult(
link=result["link"],
title=result.get("title"),
snippet=result.get("snippet"),
)
for result in all_results
]
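Because Google PSE caps each page at 10 items, a call like the sketch below is split by the pagination loop above into three requests (start=1, 11, 21); the key and engine id are placeholders:

# Placeholder credentials for illustration only.
results = search_google_pse(
    api_key="YOUR_PSE_API_KEY",
    search_engine_id="YOUR_ENGINE_ID",
    query="retrieval augmented generation",
    count=25,
)
print(len(results))  # up to 25, fewer if Google returns less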

View File

@@ -0,0 +1,48 @@
import logging
import requests
from open_webui.retrieval.web.main import SearchResult
from open_webui.env import SRC_LOG_LEVELS
from yarl import URL
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_jina(api_key: str, query: str, count: int) -> list[SearchResult]:
"""
Search using Jina's Search API and return the results as a list of SearchResult objects.
Args:
query (str): The query to search for
count (int): The number of results to return
Returns:
list[SearchResult]: A list of search results
"""
jina_search_endpoint = "https://s.jina.ai/"
headers = {
"Accept": "application/json",
"Content-Type": "application/json",
"Authorization": api_key,
"X-Retain-Images": "none",
}
payload = {"q": query, "count": count if count <= 10 else 10}
url = str(URL(jina_search_endpoint))
response = requests.post(url, headers=headers, json=payload)
response.raise_for_status()
data = response.json()
results = []
for result in data["data"]:
results.append(
SearchResult(
link=result["url"],
title=result.get("title"),
snippet=result.get("content"),
)
)
return results

View File

@@ -0,0 +1,48 @@
import logging
from typing import Optional
import requests
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_kagi(
api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None
) -> list[SearchResult]:
"""Search using Kagi's Search API and return the results as a list of SearchResult objects.
The Search API will inherit the settings in your account, including results personalization and snippet length.
Args:
api_key (str): A Kagi Search API key
query (str): The query to search for
count (int): The number of results to return
"""
url = "https://kagi.com/api/v0/search"
headers = {
"Authorization": f"Bot {api_key}",
}
params = {"q": query, "limit": count}
response = requests.get(url, headers=headers, params=params)
response.raise_for_status()
json_response = response.json()
search_results = json_response.get("data", [])
results = [
SearchResult(
link=result["url"], title=result["title"], snippet=result.get("snippet")
)
for result in search_results
if result["t"] == 0
]
print(results)
if filter_list:
results = get_filtered_results(results, filter_list)
return results

View File

@@ -0,0 +1,26 @@
import validators
from typing import Optional
from urllib.parse import urlparse
from pydantic import BaseModel
def get_filtered_results(results, filter_list):
if not filter_list:
return results
filtered_results = []
for result in results:
url = result.get("url") or result.get("link", "")
if not validators.url(url):
continue
domain = urlparse(url).netloc
if any(domain.endswith(filtered_domain) for filtered_domain in filter_list):
filtered_results.append(result)
return filtered_results
class SearchResult(BaseModel):
link: str
title: Optional[str]
snippet: Optional[str]
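A small illustration of the domain filter above: it keeps only results whose host ends with one of the allowed domains and silently drops entries whose URL does not validate. The data is made up:

# Illustrative data only.
raw = [
    {"link": "https://docs.python.org/3/library/json.html", "title": "json"},
    {"url": "https://example.com/post", "title": "other"},
    {"link": "not-a-url", "title": "broken"},
]
kept = get_filtered_results(raw, ["python.org"])
# kept == [{"link": "https://docs.python.org/3/library/json.html", "title": "json"}]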

View File

@@ -0,0 +1,40 @@
import logging
from typing import Optional
import requests
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_mojeek(
api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None
) -> list[SearchResult]:
"""Search using Mojeek's Search API and return the results as a list of SearchResult objects.
Args:
api_key (str): A Mojeek Search API key
query (str): The query to search for
"""
url = "https://api.mojeek.com/search"
headers = {
"Accept": "application/json",
}
params = {"q": query, "api_key": api_key, "fmt": "json", "t": count}
response = requests.get(url, headers=headers, params=params)
response.raise_for_status()
json_response = response.json()
results = json_response.get("response", {}).get("results", [])
    log.debug(results)
if filter_list:
results = get_filtered_results(results, filter_list)
return [
SearchResult(
link=result["url"], title=result.get("title"), snippet=result.get("desc")
)
for result in results
]

View File

@@ -0,0 +1,48 @@
import logging
from typing import Optional
from urllib.parse import urlencode
import requests
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_searchapi(
api_key: str,
engine: str,
query: str,
count: int,
filter_list: Optional[list[str]] = None,
) -> list[SearchResult]:
"""Search using searchapi.io's API and return the results as a list of SearchResult objects.
Args:
api_key (str): A searchapi.io API key
query (str): The query to search for
"""
url = "https://www.searchapi.io/api/v1/search"
engine = engine or "google"
payload = {"engine": engine, "q": query, "api_key": api_key}
url = f"{url}?{urlencode(payload)}"
response = requests.request("GET", url)
json_response = response.json()
log.info(f"results from searchapi search: {json_response}")
results = sorted(
json_response.get("organic_results", []), key=lambda x: x.get("position", 0)
)
if filter_list:
results = get_filtered_results(results, filter_list)
return [
SearchResult(
link=result["link"], title=result["title"], snippet=result["snippet"]
)
for result in results[:count]
]

View File

@@ -0,0 +1,91 @@
import logging
from typing import Optional
import requests
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_searxng(
query_url: str,
query: str,
count: int,
filter_list: Optional[list[str]] = None,
**kwargs,
) -> list[SearchResult]:
"""
Search a SearXNG instance for a given query and return the results as a list of SearchResult objects.
The function allows passing additional parameters such as language or time_range to tailor the search result.
Args:
query_url (str): The base URL of the SearXNG server.
query (str): The search term or question to find in the SearXNG database.
count (int): The maximum number of results to retrieve from the search.
Keyword Args:
language (str): Language filter for the search results; e.g., "en-US". Defaults to an empty string.
safesearch (int): Safe search filter for safer web results; 0 = off, 1 = moderate, 2 = strict. Defaults to 1 (moderate).
time_range (str): Time range for filtering results by date; e.g., "2023-04-05..today" or "all-time". Defaults to ''.
categories: (Optional[list[str]]): Specific categories within which the search should be performed, defaulting to an empty string if not provided.
Returns:
list[SearchResult]: A list of SearchResults sorted by relevance score in descending order.
    Raises:
requests.exceptions.RequestException: If a request error occurs during the search process.
"""
# Default values for optional parameters are provided as empty strings or None when not specified.
language = kwargs.get("language", "en-US")
safesearch = kwargs.get("safesearch", "1")
time_range = kwargs.get("time_range", "")
categories = "".join(kwargs.get("categories", []))
params = {
"q": query,
"format": "json",
"pageno": 1,
"safesearch": safesearch,
"language": language,
"time_range": time_range,
"categories": categories,
"theme": "simple",
"image_proxy": 0,
}
# Legacy query format
if "<query>" in query_url:
# Strip all query parameters from the URL
query_url = query_url.split("?")[0]
log.debug(f"searching {query_url}")
response = requests.get(
query_url,
headers={
"User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) RAG Bot",
"Accept": "text/html",
"Accept-Encoding": "gzip, deflate",
"Accept-Language": "en-US,en;q=0.5",
"Connection": "keep-alive",
},
params=params,
)
response.raise_for_status() # Raise an exception for HTTP errors.
json_response = response.json()
results = json_response.get("results", [])
sorted_results = sorted(results, key=lambda x: x.get("score", 0), reverse=True)
if filter_list:
sorted_results = get_filtered_results(sorted_results, filter_list)
return [
SearchResult(
link=result["url"], title=result.get("title"), snippet=result.get("content")
)
for result in sorted_results[:count]
]
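For reference, a call that exercises the documented keyword arguments; the instance URL is a placeholder and the keyword values map directly onto the params dict above:

# Placeholder SearXNG instance URL; values are illustrative.
results = search_searxng(
    query_url="https://searxng.example.com/search",
    query="open webui",
    count=5,
    filter_list=None,
    language="en-US",
    safesearch=1,
    time_range="",
    categories=["general"],
)
for r in results:
    print(r.link, r.title)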

View File

@@ -0,0 +1,48 @@
import logging
from typing import Optional
from urllib.parse import urlencode
import requests
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_serpapi(
api_key: str,
engine: str,
query: str,
count: int,
filter_list: Optional[list[str]] = None,
) -> list[SearchResult]:
"""Search using serpapi.com's API and return the results as a list of SearchResult objects.
Args:
api_key (str): A serpapi.com API key
query (str): The query to search for
"""
url = "https://serpapi.com/search"
engine = engine or "google"
payload = {"engine": engine, "q": query, "api_key": api_key}
url = f"{url}?{urlencode(payload)}"
response = requests.request("GET", url)
json_response = response.json()
log.info(f"results from serpapi search: {json_response}")
results = sorted(
json_response.get("organic_results", []), key=lambda x: x.get("position", 0)
)
if filter_list:
results = get_filtered_results(results, filter_list)
return [
SearchResult(
link=result["link"], title=result["title"], snippet=result["snippet"]
)
for result in results[:count]
]

View File

@@ -0,0 +1,43 @@
import json
import logging
from typing import Optional
import requests
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_serper(
api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None
) -> list[SearchResult]:
"""Search using serper.dev's API and return the results as a list of SearchResult objects.
Args:
api_key (str): A serper.dev API key
query (str): The query to search for
"""
url = "https://google.serper.dev/search"
payload = json.dumps({"q": query})
headers = {"X-API-KEY": api_key, "Content-Type": "application/json"}
response = requests.request("POST", url, headers=headers, data=payload)
response.raise_for_status()
json_response = response.json()
results = sorted(
json_response.get("organic", []), key=lambda x: x.get("position", 0)
)
if filter_list:
results = get_filtered_results(results, filter_list)
return [
SearchResult(
link=result["link"],
title=result.get("title"),
snippet=result.get("description"),
)
for result in results[:count]
]

View File

@@ -0,0 +1,69 @@
import logging
from typing import Optional
from urllib.parse import urlencode
import requests
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_serply(
api_key: str,
query: str,
count: int,
hl: str = "us",
limit: int = 10,
device_type: str = "desktop",
proxy_location: str = "US",
filter_list: Optional[list[str]] = None,
) -> list[SearchResult]:
"""Search using serper.dev's API and return the results as a list of SearchResult objects.
Args:
api_key (str): A serply.io API key
query (str): The query to search for
hl (str): Host Language code to display results in (reference https://developers.google.com/custom-search/docs/xml_results?hl=en#wsInterfaceLanguages)
limit (int): The maximum number of results to return [10-100, defaults to 10]
"""
log.info("Searching with Serply")
url = "https://api.serply.io/v1/search/"
query_payload = {
"q": query,
"language": "en",
"num": limit,
"gl": proxy_location.upper(),
"hl": hl.lower(),
}
url = f"{url}{urlencode(query_payload)}"
headers = {
"X-API-KEY": api_key,
"X-User-Agent": device_type,
"User-Agent": "open-webui",
"X-Proxy-Location": proxy_location,
}
response = requests.request("GET", url, headers=headers)
response.raise_for_status()
json_response = response.json()
log.info(f"results from serply search: {json_response}")
results = sorted(
json_response.get("results", []), key=lambda x: x.get("realPosition", 0)
)
if filter_list:
results = get_filtered_results(results, filter_list)
return [
SearchResult(
link=result["link"],
title=result.get("title"),
snippet=result.get("description"),
)
for result in results[:count]
]

Some files were not shown because too many files have changed in this diff.