init
parent
2d37e5c858
commit
31c82973b2
File diff suppressed because it is too large
|
|
@ -0,0 +1,99 @@
|
|||
# Contributor Covenant Code of Conduct
|
||||
|
||||
## Our Pledge
|
||||
|
||||
As members, contributors, and leaders of this community, we pledge to make participation in our open-source project a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socioeconomic status, nationality, personal appearance, race, religion, or sexual identity and orientation.
|
||||
|
||||
We are committed to creating and maintaining an open, respectful, and professional environment where positive contributions and meaningful discussions can flourish. By participating in this project, you agree to uphold these values and align your behavior to the standards outlined in this Code of Conduct.
|
||||
|
||||
## Why These Standards Are Important
|
||||
|
||||
Open-source projects rely on a community of volunteers dedicating their time, expertise, and effort toward a shared goal. These projects are inherently collaborative but also fragile, as the success of the project depends on the goodwill, energy, and productivity of those involved.
|
||||
|
||||
Maintaining a positive and respectful environment is essential to safeguarding the integrity of this project and protecting contributors' efforts. Behavior that disrupts this atmosphere—whether through hostility, entitlement, or unprofessional conduct—can severely harm the morale and productivity of the community. **Strict enforcement of these standards ensures a safe and supportive space for meaningful collaboration.**
|
||||
|
||||
This is a community where **respect and professionalism are mandatory.** Violations of these standards will result in **zero tolerance** and immediate enforcement to prevent disruption and ensure the well-being of all participants.
|
||||
|
||||
## Our Standards
|
||||
|
||||
Examples of behavior that contributes to a positive and professional community include:
|
||||
|
||||
- **Respecting others.** Be considerate, listen actively, and engage with empathy toward others' viewpoints and experiences.
|
||||
- **Constructive feedback.** Provide actionable, thoughtful, and respectful feedback that helps improve the project and encourages collaboration. Avoid unproductive negativity or hypercriticism.
|
||||
- **Recognizing volunteer contributions.** Appreciate that contributors dedicate their free time and resources selflessly. Approach them with gratitude and patience.
|
||||
- **Focusing on shared goals.** Collaborate in ways that prioritize the health, success, and sustainability of the community over individual agendas.
|
||||
|
||||
Examples of unacceptable behavior include:
|
||||
|
||||
- The use of discriminatory, demeaning, or sexualized language or behavior.
|
||||
- Personal attacks, derogatory comments, trolling, or inflammatory political or ideological arguments.
|
||||
- Harassment, intimidation, or any behavior intended to create a hostile, uncomfortable, or unsafe environment.
|
||||
- Publishing others' private information (e.g., physical or email addresses) without explicit permission.
|
||||
- **Entitlement, demand, or aggression toward contributors.** Volunteers are under no obligation to provide immediate or personalized support. Rude or dismissive behavior will not be tolerated.
|
||||
- **Unproductive or destructive behavior.** This includes venting frustration as hostility ("tantrums"), hypercriticism, attention-seeking negativity, or anything that distracts from the project's goals.
|
||||
- **Spamming and promotional exploitation.** Sharing irrelevant product promotions or self-promotion in the community is not allowed unless it directly contributes value to the discussion.
|
||||
|
||||
### Feedback and Community Engagement
|
||||
|
||||
- **Constructive feedback is encouraged, but hostile or entitled behavior will result in immediate action.** If you disagree with elements of the project, we encourage you to offer meaningful improvements or fork the project if necessary. Healthy discussions and technical disagreements are welcome only when handled with professionalism.
|
||||
- **Respect contributors' time and efforts.** No one is entitled to personalized or on-demand assistance. This is a community built on collaboration and shared effort; demanding or demeaning behavior undermines that trust and will not be allowed.
|
||||
|
||||
### Zero Tolerance: No Warnings, Immediate Action
|
||||
|
||||
This community operates under a **zero-tolerance policy.** Any behavior deemed unacceptable under this Code of Conduct will result in **immediate enforcement, without prior warning.**
|
||||
|
||||
We employ this approach to ensure that unproductive or disruptive behavior does not escalate further or cause unnecessary harm to other contributors. The standards are clear, and violations of any kind—whether mild or severe—will be addressed decisively to protect the community.
|
||||
|
||||
## Enforcement Responsibilities
|
||||
|
||||
Community leaders are responsible for upholding and enforcing these standards. They are empowered to take **immediate and appropriate action** to address any behaviors they deem unacceptable under this Code of Conduct. These actions are taken with the goal of protecting the community and preserving its safe, positive, and productive environment.
|
||||
|
||||
## Scope
|
||||
|
||||
This Code of Conduct applies to all community spaces, including forums, repositories, social media accounts, and in-person events. It also applies when an individual represents the community in public settings, such as conferences or official communications.
|
||||
|
||||
Additionally, any behavior outside of these defined spaces that negatively impacts the community or its members may fall within the scope of this Code of Conduct.
|
||||
|
||||
## Reporting Violations
|
||||
|
||||
Instances of unacceptable behavior can be reported to the leadership team at **hello@openwebui.com**. Reports will be handled promptly, confidentially, and with consideration for the safety and well-being of the reporter.
|
||||
|
||||
All community leaders are required to uphold confidentiality and impartiality when addressing reports of violations.
|
||||
|
||||
## Enforcement Guidelines
|
||||
|
||||
### Ban
|
||||
|
||||
**Community Impact**: Community leaders will issue a ban to any participant whose behavior is deemed unacceptable according to this Code of Conduct. Bans are enforced immediately and without prior notice.
|
||||
|
||||
A ban may be temporary or permanent, depending on the severity of the violation. This includes—but is not limited to—behavior such as:
|
||||
|
||||
- Harassment or abusive behavior toward contributors.
|
||||
- Persistent negativity or hostility that disrupts the collaborative environment.
|
||||
- Disrespectful, demanding, or aggressive interactions with others.
|
||||
- Attempts to cause harm or sabotage the community.
|
||||
|
||||
**Consequence**: A banned individual is immediately removed from access to all community spaces, communication channels, and events. Community leaders reserve the right to enforce either a time-limited suspension or a permanent ban based on the specific circumstances of the violation.
|
||||
|
||||
This approach ensures that disruptive behaviors are addressed swiftly and decisively in order to maintain the integrity and productivity of the community.
|
||||
|
||||
## Why Zero Tolerance Is Necessary
|
||||
|
||||
Open-source projects thrive on collaboration, goodwill, and mutual respect. Toxic behaviors—such as entitlement, hostility, or persistent negativity—threaten not just individual contributors but the health of the project as a whole. Allowing such behaviors to persist robs contributors of their time, energy, and enthusiasm for the work they do.
|
||||
|
||||
By enforcing a zero-tolerance policy, we ensure that the community remains a safe, welcoming space for all participants. These measures are not about harshness—they are about protecting contributors and fostering a productive environment where innovation can thrive.
|
||||
|
||||
Our expectations are clear, and our enforcement reflects our commitment to this project's long-term success.
|
||||
|
||||
## Attribution
|
||||
|
||||
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 2.0, available at
|
||||
https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
|
||||
|
||||
Community Impact Guidelines were inspired by [Mozilla's code of conduct enforcement ladder](https://github.com/mozilla/diversity).
|
||||
|
||||
[homepage]: https://www.contributor-covenant.org
|
||||
|
||||
For answers to common questions about this code of conduct, see the FAQ at
|
||||
https://www.contributor-covenant.org/faq. Translations are available at
|
||||
https://www.contributor-covenant.org/translations.
|
||||
|
|
@ -0,0 +1,64 @@
|
|||
# Run with
|
||||
# caddy run --envfile ./example.env --config ./Caddyfile.localhost
|
||||
#
|
||||
# This is configured for
|
||||
# - Automatic HTTPS (even for localhost)
|
||||
# - Reverse Proxying to Ollama API Base URL (http://localhost:11434/api)
|
||||
# - CORS
|
||||
# - HTTP Basic Auth API Tokens (uncomment basicauth section)
|
||||
|
||||
|
||||
# CORS Preflight (OPTIONS) + Request (GET, POST, PATCH, PUT, DELETE)
|
||||
(cors-api) {
|
||||
@match-cors-api-preflight method OPTIONS
|
||||
handle @match-cors-api-preflight {
|
||||
header {
|
||||
Access-Control-Allow-Origin "{http.request.header.origin}"
|
||||
Access-Control-Allow-Methods "GET, POST, PUT, PATCH, DELETE, OPTIONS"
|
||||
Access-Control-Allow-Headers "Origin, Accept, Authorization, Content-Type, X-Requested-With"
|
||||
Access-Control-Allow-Credentials "true"
|
||||
Access-Control-Max-Age "3600"
|
||||
defer
|
||||
}
|
||||
respond "" 204
|
||||
}
|
||||
|
||||
@match-cors-api-request {
|
||||
not {
|
||||
header Origin "{http.request.scheme}://{http.request.host}"
|
||||
}
|
||||
header Origin "{http.request.header.origin}"
|
||||
}
|
||||
handle @match-cors-api-request {
|
||||
header {
|
||||
Access-Control-Allow-Origin "{http.request.header.origin}"
|
||||
Access-Control-Allow-Methods "GET, POST, PUT, PATCH, DELETE, OPTIONS"
|
||||
Access-Control-Allow-Headers "Origin, Accept, Authorization, Content-Type, X-Requested-With"
|
||||
Access-Control-Allow-Credentials "true"
|
||||
Access-Control-Max-Age "3600"
|
||||
defer
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# replace localhost with example.com or whatever
|
||||
localhost {
|
||||
## HTTP Basic Auth
|
||||
## (uncomment to enable)
|
||||
# basicauth {
|
||||
# # see .example.env for how to generate tokens
|
||||
# {env.OLLAMA_API_ID} {env.OLLAMA_API_TOKEN_DIGEST}
|
||||
# }
|
||||
|
||||
handle /api/* {
|
||||
# Comment to disable CORS
|
||||
import cors-api
|
||||
|
||||
reverse_proxy localhost:11434
|
||||
}
|
||||
|
||||
# Same-Origin Static Web Server
|
||||
file_server {
|
||||
root ./build/
|
||||
}
|
||||
}
|
||||
192
Dockerfile
|
|
@ -1,30 +1,176 @@
|
|||
# Use the official Python image as the base
|
||||
FROM --platform=linux/amd64 python:3.9-slim
|
||||
# syntax=docker/dockerfile:1
|
||||
# Initialize device type args
|
||||
# use build args in the docker build command with --build-arg="BUILDARG=true"
|
||||
ARG USE_CUDA=false
|
||||
ARG USE_OLLAMA=false
|
||||
# Tested with cu117 for CUDA 11 and cu121 for CUDA 12 (default)
|
||||
ARG USE_CUDA_VER=cu121
|
||||
# any sentence transformer model; models to use can be found at https://huggingface.co/models?library=sentence-transformers
|
||||
# Leaderboard: https://huggingface.co/spaces/mteb/leaderboard
|
||||
# for better performance and multilingual support use "intfloat/multilingual-e5-large" (~2.5GB) or "intfloat/multilingual-e5-base" (~1.5GB)
|
||||
# IMPORTANT: If you change the embedding model (default: sentence-transformers/all-MiniLM-L6-v2), documents previously loaded in the WebUI can no longer be used with RAG Chat until you re-embed them.
|
||||
ARG USE_EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
|
||||
ARG USE_RERANKING_MODEL=""
|
||||
|
||||
# Tiktoken encoding name; models to use can be found at https://huggingface.co/models?library=tiktoken
|
||||
ARG USE_TIKTOKEN_ENCODING_NAME="cl100k_base"
|
||||
|
||||
ARG BUILD_HASH=dev-build
|
||||
# Override at your own risk - non-root configurations are untested
|
||||
ARG UID=0
|
||||
ARG GID=0
|
||||
|
||||
######## WebUI frontend ########
|
||||
FROM --platform=$BUILDPLATFORM node:22-alpine3.20 AS build
|
||||
ARG BUILD_HASH
|
||||
|
||||
# Set the working directory in the container
|
||||
WORKDIR /app
|
||||
|
||||
# Install git
|
||||
RUN apt-get update && apt-get install -y git nano wget curl iputils-ping
|
||||
COPY package.json package-lock.json ./
|
||||
RUN npm ci
|
||||
|
||||
# Copy the requirements files (if present) and install dependencies
|
||||
COPY requirements.txt .
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# Copy requirements.txt into the container
|
||||
COPY requirements.txt .
|
||||
|
||||
# Install dependencies from requirements.txt
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# Install Tesseract OCR
|
||||
RUN apt-get install -y tesseract-ocr
|
||||
|
||||
# Copy the source code into the container
|
||||
COPY . .
|
||||
COPY entrypoint.sh /entrypoint.sh
|
||||
ENV APP_BUILD_HASH=${BUILD_HASH}
|
||||
RUN npm run build
|
||||
|
||||
RUN chmod +x /entrypoint.sh
|
||||
######## WebUI backend ########
|
||||
FROM python:3.11-slim-bookworm AS base
|
||||
|
||||
# Run the application
|
||||
ENTRYPOINT ["/entrypoint.sh"]
|
||||
# Use args
|
||||
ARG USE_CUDA
|
||||
ARG USE_OLLAMA
|
||||
ARG USE_CUDA_VER
|
||||
ARG USE_EMBEDDING_MODEL
|
||||
ARG USE_RERANKING_MODEL
|
||||
ARG UID
|
||||
ARG GID
|
||||
|
||||
## Basis ##
|
||||
ENV ENV=prod \
|
||||
PORT=8080 \
|
||||
# pass build args to the build
|
||||
USE_OLLAMA_DOCKER=${USE_OLLAMA} \
|
||||
USE_CUDA_DOCKER=${USE_CUDA} \
|
||||
USE_CUDA_DOCKER_VER=${USE_CUDA_VER} \
|
||||
USE_EMBEDDING_MODEL_DOCKER=${USE_EMBEDDING_MODEL} \
|
||||
USE_RERANKING_MODEL_DOCKER=${USE_RERANKING_MODEL}
|
||||
|
||||
## Basis URL Config ##
|
||||
ENV OLLAMA_BASE_URL="/ollama" \
|
||||
OPENAI_API_BASE_URL=""
|
||||
|
||||
## API Key and Security Config ##
|
||||
ENV OPENAI_API_KEY="" \
|
||||
WEBUI_SECRET_KEY="" \
|
||||
SCARF_NO_ANALYTICS=true \
|
||||
DO_NOT_TRACK=true \
|
||||
ANONYMIZED_TELEMETRY=false
|
||||
|
||||
#### Other models #########################################################
|
||||
## whisper STT model settings ##
|
||||
ENV WHISPER_MODEL="base" \
|
||||
WHISPER_MODEL_DIR="/app/backend/data/cache/whisper/models"
|
||||
|
||||
## RAG Embedding model settings ##
|
||||
ENV RAG_EMBEDDING_MODEL="$USE_EMBEDDING_MODEL_DOCKER" \
|
||||
RAG_RERANKING_MODEL="$USE_RERANKING_MODEL_DOCKER" \
|
||||
SENTENCE_TRANSFORMERS_HOME="/app/backend/data/cache/embedding/models"
|
||||
|
||||
## Tiktoken model settings ##
|
||||
ENV TIKTOKEN_ENCODING_NAME="cl100k_base" \
|
||||
TIKTOKEN_CACHE_DIR="/app/backend/data/cache/tiktoken"
|
||||
|
||||
## Hugging Face download cache ##
|
||||
ENV HF_HOME="/app/backend/data/cache/embedding/models"
|
||||
|
||||
## Torch Extensions ##
|
||||
# ENV TORCH_EXTENSIONS_DIR="/.cache/torch_extensions"
|
||||
|
||||
#### Other models ##########################################################
|
||||
|
||||
WORKDIR /app/backend
|
||||
|
||||
ENV HOME=/root
|
||||
# Create user and group if not root
|
||||
RUN if [ $UID -ne 0 ]; then \
|
||||
if [ $GID -ne 0 ]; then \
|
||||
addgroup --gid $GID app; \
|
||||
fi; \
|
||||
adduser --uid $UID --gid $GID --home $HOME --disabled-password --no-create-home app; \
|
||||
fi
|
||||
|
||||
RUN mkdir -p $HOME/.cache/chroma
|
||||
RUN echo -n 00000000-0000-0000-0000-000000000000 > $HOME/.cache/chroma/telemetry_user_id
|
||||
|
||||
# Make sure the user has access to the app and root directory
|
||||
RUN chown -R $UID:$GID /app $HOME
|
||||
|
||||
RUN if [ "$USE_OLLAMA" = "true" ]; then \
|
||||
apt-get update && \
|
||||
# Install pandoc and netcat
|
||||
apt-get install -y --no-install-recommends git build-essential pandoc netcat-openbsd curl && \
|
||||
apt-get install -y --no-install-recommends gcc python3-dev && \
|
||||
# for RAG OCR
|
||||
apt-get install -y --no-install-recommends ffmpeg libsm6 libxext6 && \
|
||||
# install helper tools
|
||||
apt-get install -y --no-install-recommends curl jq && \
|
||||
# install ollama
|
||||
curl -fsSL https://ollama.com/install.sh | sh && \
|
||||
# cleanup
|
||||
rm -rf /var/lib/apt/lists/*; \
|
||||
else \
|
||||
apt-get update && \
|
||||
# Install pandoc, netcat and gcc
|
||||
apt-get install -y --no-install-recommends git build-essential pandoc gcc netcat-openbsd curl jq && \
|
||||
apt-get install -y --no-install-recommends gcc python3-dev && \
|
||||
# for RAG OCR
|
||||
apt-get install -y --no-install-recommends ffmpeg libsm6 libxext6 && \
|
||||
# cleanup
|
||||
rm -rf /var/lib/apt/lists/*; \
|
||||
fi
|
||||
|
||||
# install python dependencies
|
||||
COPY --chown=$UID:$GID ./backend/requirements.txt ./requirements.txt
|
||||
|
||||
RUN pip3 install uv && \
|
||||
if [ "$USE_CUDA" = "true" ]; then \
|
||||
# If you use CUDA the whisper and embedding model will be downloaded on first use
|
||||
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/$USE_CUDA_DOCKER_VER --no-cache-dir && \
|
||||
uv pip install --system -r requirements.txt --no-cache-dir && \
|
||||
python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \
|
||||
python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \
|
||||
python -c "import os; import tiktoken; tiktoken.get_encoding(os.environ['TIKTOKEN_ENCODING_NAME'])"; \
|
||||
else \
|
||||
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu --no-cache-dir && \
|
||||
uv pip install --system -r requirements.txt --no-cache-dir && \
|
||||
python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \
|
||||
python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \
|
||||
python -c "import os; import tiktoken; tiktoken.get_encoding(os.environ['TIKTOKEN_ENCODING_NAME'])"; \
|
||||
fi; \
|
||||
chown -R $UID:$GID /app/backend/data/
|
||||
|
||||
|
||||
|
||||
# copy embedding weight from build
|
||||
# RUN mkdir -p /root/.cache/chroma/onnx_models/all-MiniLM-L6-v2
|
||||
# COPY --from=build /app/onnx /root/.cache/chroma/onnx_models/all-MiniLM-L6-v2/onnx
|
||||
|
||||
# copy built frontend files
|
||||
COPY --chown=$UID:$GID --from=build /app/build /app/build
|
||||
COPY --chown=$UID:$GID --from=build /app/CHANGELOG.md /app/CHANGELOG.md
|
||||
COPY --chown=$UID:$GID --from=build /app/package.json /app/package.json
|
||||
|
||||
# copy backend files
|
||||
COPY --chown=$UID:$GID ./backend .
|
||||
|
||||
EXPOSE 8080
|
||||
|
||||
HEALTHCHECK CMD curl --silent --fail http://localhost:${PORT:-8080}/health | jq -ne 'input.status == true' || exit 1
|
||||
|
||||
USER $UID:$GID
|
||||
|
||||
ARG BUILD_HASH
|
||||
ENV WEBUI_BUILD_VERSION=${BUILD_HASH}
|
||||
ENV DOCKER=true
|
||||
|
||||
CMD [ "bash", "start.sh"]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,35 @@
|
|||
### Installing Both Ollama and Open WebUI Using Kustomize

For a CPU-only pod:

```bash
kubectl apply -f ./kubernetes/manifest/base
```

For a GPU-enabled pod:

```bash
kubectl apply -k ./kubernetes/manifest
```

### Installing Both Ollama and Open WebUI Using Helm

Package the Helm chart first:

```bash
helm package ./kubernetes/helm/
```

For a CPU-only pod:

```bash
helm install ollama-webui ./ollama-webui-*.tgz
```

For a GPU-enabled pod:

```bash
helm install ollama-webui ./ollama-webui-*.tgz --set ollama.resources.limits.nvidia.com/gpu="1"
```

Check the `kubernetes/helm/values.yaml` file to see which parameters are available for customization.
|
||||
|
|
@ -0,0 +1,27 @@
|
|||
Copyright (c) 2023-2025 Timothy Jaeryang Baek
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
|
@ -0,0 +1,33 @@
|
|||
|
||||
ifneq ($(shell which docker-compose 2>/dev/null),)
    DOCKER_COMPOSE := docker-compose
else
    DOCKER_COMPOSE := docker compose
endif

install:
	$(DOCKER_COMPOSE) up -d

remove:
	@chmod +x confirm_remove.sh
	@./confirm_remove.sh

start:
	$(DOCKER_COMPOSE) start
startAndBuild:
	$(DOCKER_COMPOSE) up -d --build

stop:
	$(DOCKER_COMPOSE) stop

update:
	# Calls the LLM update script
	chmod +x update_ollama_models.sh
	@./update_ollama_models.sh
	@git pull
	$(DOCKER_COMPOSE) down
	# Make sure the ollama-webui container is stopped before rebuilding
	@docker stop open-webui || true
	$(DOCKER_COMPOSE) up --build -d
	$(DOCKER_COMPOSE) start
|
||||
|
||||
|
|
@ -0,0 +1,36 @@
|
|||
# Open WebUI Troubleshooting Guide

## Understanding the Open WebUI Architecture

Open WebUI streamlines interactions between the client (your browser) and the Ollama API. At the heart of this design is a backend reverse proxy, which improves security and resolves CORS issues.

- **How it Works**: The WebUI never sends requests directly to the Ollama API. A request first goes to the Open WebUI backend via the `/ollama` route, and the backend then forwards it to the URL specified in the `OLLAMA_BASE_URL` environment variable. A request to `/ollama` in the WebUI is therefore effectively a request to `OLLAMA_BASE_URL` in the backend. For instance, a request to `/ollama/api/tags` in the WebUI is equivalent to `OLLAMA_BASE_URL/api/tags` in the backend.

- **Security Benefits**: This design prevents direct exposure of the Ollama API to the frontend, safeguarding against potential CORS (Cross-Origin Resource Sharing) issues and unauthorized access. Requiring authentication to access the Ollama API further strengthens this security layer.
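
The route mapping described above can be illustrated with a small sketch. This is not the Open WebUI backend code; the function name `to_upstream_url` and the fallback URL are hypothetical and exist only to show how a path under `/ollama` corresponds to a URL under `OLLAMA_BASE_URL`.

```python
import os


def to_upstream_url(path: str, base_url: str) -> str:
    """Map a WebUI path under /ollama to the matching URL under base_url."""
    prefix = "/ollama"
    if path != prefix and not path.startswith(prefix + "/"):
        raise ValueError(f"not an /ollama route: {path}")
    # Drop the prefix and append the remainder to the configured base URL.
    return base_url.rstrip("/") + path[len(prefix):]


if __name__ == "__main__":
    # Hypothetical fallback; the real value depends on your deployment.
    base = os.environ.get("OLLAMA_BASE_URL", "http://localhost:11434")
    print(to_upstream_url("/ollama/api/tags", base))
```

With `OLLAMA_BASE_URL=http://localhost:11434`, the path `/ollama/api/tags` maps to `http://localhost:11434/api/tags`, matching the example in the text.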
|
||||
|
||||
## Open WebUI: Server Connection Error

Connection issues are often caused by the WebUI Docker container being unable to reach the Ollama server at 127.0.0.1:11434 (host.docker.internal:11434) from inside the container. Use the `--network=host` flag in your docker command to resolve this. Note that the port changes from 3000 to 8080, so the WebUI is then available at `http://localhost:8080`.

**Example Docker Command**:

```bash
docker run -d --network=host -v open-webui:/app/backend/data -e OLLAMA_BASE_URL=http://127.0.0.1:11434 --name open-webui --restart always ghcr.io/open-webui/open-webui:main
```

### Error on Slow Responses for Ollama

Open WebUI has a default timeout of 5 minutes for Ollama to finish generating the response. If needed, this can be adjusted via the environment variable `AIOHTTP_CLIENT_TIMEOUT`, which sets the timeout in seconds.
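
As a rough illustration of how a seconds-based timeout variable can be interpreted, the sketch below reads `AIOHTTP_CLIENT_TIMEOUT` and falls back to 300 seconds (the 5-minute default stated above). The helper name `resolve_timeout` and the handling of empty or invalid values are assumptions made for this example, not a description of Open WebUI's actual parsing.

```python
import math
import os

DEFAULT_TIMEOUT_SECONDS = 300.0  # 5 minutes, the default stated above


def resolve_timeout(env=None) -> float:
    """Return the timeout in seconds, falling back to the default.

    Treating empty, non-numeric, or non-positive values as "use the
    default" is an assumption made for this sketch.
    """
    env = os.environ if env is None else env
    raw = env.get("AIOHTTP_CLIENT_TIMEOUT", "").strip()
    if not raw:
        return DEFAULT_TIMEOUT_SECONDS
    try:
        seconds = float(raw)
    except ValueError:
        return DEFAULT_TIMEOUT_SECONDS
    if not math.isfinite(seconds) or seconds <= 0:
        return DEFAULT_TIMEOUT_SECONDS
    return seconds
```

For example, setting `AIOHTTP_CLIENT_TIMEOUT=600` would correspond to a 10-minute limit.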
|
||||
|
||||
### General Connection Errors

**Ensure Ollama Version is Up-to-Date**: Always start by checking that you have the latest version of Ollama. Visit [Ollama's official site](https://ollama.com/) for the latest updates.

**Troubleshooting Steps**:

1. **Verify Ollama URL Format**:
   - When running the Web UI container, ensure the `OLLAMA_BASE_URL` is correctly set (e.g., `http://192.168.1.1:11434` for different host setups).
   - In the Open WebUI, navigate to "Settings" > "General".
   - Confirm that the Ollama Server URL is correctly set to `[OLLAMA URL]` (e.g., `http://localhost:11434`).
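
The URL check above can also be done from a script. The following is a minimal sketch using only the Python standard library; the helper names `normalize_ollama_url` and `is_reachable` are hypothetical, and probing `/api/tags` simply reuses the endpoint mentioned earlier in this guide. Run it from the same machine or container that needs to reach Ollama.

```python
from urllib.parse import urlsplit
from urllib.request import urlopen


def normalize_ollama_url(url: str) -> str:
    """Validate the scheme and host, and strip any trailing slash."""
    parts = urlsplit(url.strip())
    if parts.scheme not in ("http", "https"):
        raise ValueError("URL must start with http:// or https://")
    if not parts.hostname:
        raise ValueError("URL is missing a host")
    return f"{parts.scheme}://{parts.netloc}{parts.path.rstrip('/')}"


def is_reachable(url: str, timeout: float = 5.0) -> bool:
    """Return True if GET <url>/api/tags answers with HTTP 200."""
    try:
        target = normalize_ollama_url(url) + "/api/tags"
        with urlopen(target, timeout=timeout) as resp:
            return resp.status == 200
    except (OSError, ValueError):
        # Covers malformed URLs, refused connections, timeouts, and HTTP errors.
        return False
```

A value such as `localhost:11434` (no scheme) is rejected by `normalize_ollama_url`, which is a common cause of a misconfigured `OLLAMA_BASE_URL`.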
|
||||
|
||||
Following these troubleshooting steps should resolve most connection issues. For further assistance or questions, feel free to reach out to us on our community Discord.
|
||||
120
allegro.py
|
|
@ -1,120 +0,0 @@
|
|||
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import torch
import numpy as np
from sentence_transformers import SentenceTransformer
from datasets import Dataset
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, TrainingArguments, Trainer, DataCollatorForSeq2Seq
import weaviate
from weaviate.client import WeaviateClient
from weaviate.connect import ConnectionParams

# 1️⃣ Initialize the embedding model
embed_model = SentenceTransformer("all-MiniLM-L6-v2")

# 2️⃣ Connect to Weaviate and fetch the documents
client = WeaviateClient(
    connection_params=ConnectionParams.from_params(
        http_host="weaviate",
        http_port=8080,
        http_secure=False,
        grpc_host="weaviate",
        grpc_port=50051,
        grpc_secure=False,
    )
)
|
||||
|
||||
client.connect()  # a WeaviateClient built from ConnectionParams must be connected explicitly

collection_name = "Document"  # Assumed to be the name of your collection
# Read every object through the v4 collections API and keep its "content" property
collection = client.collections.get(collection_name)
documents = [obj.properties["content"] for obj in collection.iterator()]
|
||||
|
||||
# 3️⃣ Generate embeddings
embeddings = embed_model.encode(documents)

# 4️⃣ Prepare the training data
def create_training_data():
    data = {
        "text": documents,
        "embedding": embeddings.tolist()
    }
    return Dataset.from_dict(data)

dataset = create_training_data()

# Split the data into training and evaluation sets
split_dataset = dataset.train_test_split(test_size=0.25)
train_dataset = split_dataset["train"]
eval_dataset = split_dataset["test"]
|
||||
|
||||
# 5️⃣ Load the allegro/multislav-5lang model
device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "allegro/multislav-5lang"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# 6️⃣ LoRA configuration
lora_config = LoraConfig(
    r=8, lora_alpha=32, lora_dropout=0.1, bias="none", task_type="SEQ_2_SEQ_LM"
)
model = get_peft_model(model, lora_config)

# 7️⃣ Tokenize the data
max_length = 384

def tokenize_function(examples):
    return tokenizer(
        examples["text"],
        padding="max_length",
        truncation=True,
        max_length=max_length
    )

tokenized_train = train_dataset.map(tokenize_function, batched=True)
tokenized_eval = eval_dataset.map(tokenize_function, batched=True)
|
||||
|
||||
# 8️⃣ Training parameters
training_args = TrainingArguments(
    output_dir="./results",
    eval_strategy="steps",
    eval_steps=500,
    save_strategy="steps",
    save_steps=500,
    learning_rate=1e-5,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    num_train_epochs=16,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model="loss",
    greater_is_better=False,
)

# 9️⃣ Data collator
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model
)

# 🔟 Train the model
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_eval,
    data_collator=data_collator,
)

trainer.train()

# 1️⃣1️⃣ Save the model
model.save_pretrained("./models/allegro")
tokenizer.save_pretrained("./models/allegro")

print("✅ The model has been trained and saved!")
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
__pycache__
|
||||
.env
|
||||
_old
|
||||
uploads
|
||||
.ipynb_checkpoints
|
||||
*.db
|
||||
_test
|
||||
!/data
|
||||
/data/*
|
||||
!/data/litellm
|
||||
/data/litellm/*
|
||||
!data/litellm/config.yaml
|
||||
|
||||
!data/config.json
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
__pycache__
|
||||
.env
|
||||
_old
|
||||
uploads
|
||||
.ipynb_checkpoints
|
||||
*.db
|
||||
_test
|
||||
Pipfile
|
||||
!/data
|
||||
/data/*
|
||||
/open_webui/data/*
|
||||
.webui_secret_key
|
||||
|
|
@ -0,0 +1,2 @@
|
|||
PORT="${PORT:-8080}"
|
||||
uvicorn open_webui.main:app --port $PORT --host 0.0.0.0 --forwarded-allow-ips '*' --reload
|
||||
|
|
@ -0,0 +1,96 @@
|
|||
import base64
|
||||
import os
|
||||
import random
|
||||
from pathlib import Path
|
||||
|
||||
import typer
|
||||
import uvicorn
|
||||
from typing import Optional
|
||||
from typing_extensions import Annotated
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
KEY_FILE = Path.cwd() / ".webui_secret_key"
|
||||
|
||||
|
||||
def version_callback(value: bool):
|
||||
if value:
|
||||
from open_webui.env import VERSION
|
||||
|
||||
typer.echo(f"Open WebUI version: {VERSION}")
|
||||
raise typer.Exit()
|
||||
|
||||
|
||||
@app.command()
|
||||
def main(
|
||||
version: Annotated[
|
||||
Optional[bool], typer.Option("--version", callback=version_callback)
|
||||
] = None,
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
@app.command()
|
||||
def serve(
|
||||
host: str = "0.0.0.0",
|
||||
port: int = 8080,
|
||||
):
|
||||
os.environ["FROM_INIT_PY"] = "true"
|
||||
if os.getenv("WEBUI_SECRET_KEY") is None:
|
||||
typer.echo(
|
||||
"Loading WEBUI_SECRET_KEY from file, not provided as an environment variable."
|
||||
)
|
||||
if not KEY_FILE.exists():
|
||||
typer.echo(f"Generating a new secret key and saving it to {KEY_FILE}")
|
||||
KEY_FILE.write_bytes(base64.b64encode(os.urandom(12)))  # os.urandom uses the OS CSPRNG; the random module is not meant for secrets
|
||||
typer.echo(f"Loading WEBUI_SECRET_KEY from {KEY_FILE}")
|
||||
os.environ["WEBUI_SECRET_KEY"] = KEY_FILE.read_text()
|
||||
|
||||
if os.getenv("USE_CUDA_DOCKER", "false") == "true":
|
||||
typer.echo(
|
||||
"CUDA is enabled, appending LD_LIBRARY_PATH to include torch/cudnn & cublas libraries."
|
||||
)
|
||||
LD_LIBRARY_PATH = os.getenv("LD_LIBRARY_PATH", "").split(":")
|
||||
os.environ["LD_LIBRARY_PATH"] = ":".join(
|
||||
LD_LIBRARY_PATH
|
||||
+ [
|
||||
"/usr/local/lib/python3.11/site-packages/torch/lib",
|
||||
"/usr/local/lib/python3.11/site-packages/nvidia/cudnn/lib",
|
||||
]
|
||||
)
|
||||
try:
|
||||
import torch
|
||||
|
||||
assert torch.cuda.is_available(), "CUDA not available"
|
||||
typer.echo("CUDA seems to be working")
|
||||
except Exception as e:
|
||||
typer.echo(
|
||||
"Error when testing CUDA but USE_CUDA_DOCKER is true. "
|
||||
"Resetting USE_CUDA_DOCKER to false and removing "
|
||||
f"LD_LIBRARY_PATH modifications: {e}"
|
||||
)
|
||||
os.environ["USE_CUDA_DOCKER"] = "false"
|
||||
os.environ["LD_LIBRARY_PATH"] = ":".join(LD_LIBRARY_PATH)
|
||||
|
||||
import open_webui.main # we need set environment variables before importing main
|
||||
|
||||
uvicorn.run(open_webui.main.app, host=host, port=port, forwarded_allow_ips="*")
|
||||
|
||||
|
||||
@app.command()
|
||||
def dev(
|
||||
host: str = "0.0.0.0",
|
||||
port: int = 8080,
|
||||
reload: bool = True,
|
||||
):
|
||||
uvicorn.run(
|
||||
"open_webui.main:app",
|
||||
host=host,
|
||||
port=port,
|
||||
reload=reload,
|
||||
forwarded_allow_ips="*",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app()
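The `serve` command above falls back to a key file when `WEBUI_SECRET_KEY` is not set, creating the file on first run. The load-or-create step can be sketched on its own as follows. This is a minimal sketch, assuming the standard-library `secrets` module as the randomness source; the helper name `load_or_create_key` and its parameters are hypothetical and not part of Open WebUI.

```python
import base64
import secrets
import tempfile
from pathlib import Path


def load_or_create_key(key_file: Path, num_bytes: int = 12) -> str:
    """Return the key stored in key_file, creating the file on first use."""
    if not key_file.exists():
        # secrets draws from the operating system's CSPRNG.
        key_file.write_bytes(base64.b64encode(secrets.token_bytes(num_bytes)))
    return key_file.read_text()


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        path = Path(tmp) / ".webui_secret_key"
        first = load_or_create_key(path)   # generates and stores a key
        second = load_or_create_key(path)  # reads the stored key back
        print(first == second)
```

Because the key is only generated when the file is missing, restarting the server keeps existing sessions valid as long as the file survives.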
|
||||
|
|
@ -0,0 +1,114 @@
|
|||
# A generic, single database configuration.
|
||||
|
||||
[alembic]
|
||||
# path to migration scripts
|
||||
script_location = migrations
|
||||
|
||||
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
|
||||
# Uncomment the line below if you want the files to be prepended with date and time
|
||||
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
|
||||
|
||||
# sys.path path, will be prepended to sys.path if present.
|
||||
# defaults to the current working directory.
|
||||
prepend_sys_path = .
|
||||
|
||||
# timezone to use when rendering the date within the migration file
|
||||
# as well as the filename.
|
||||
# If specified, requires the python>=3.9 or backports.zoneinfo library.
|
||||
# Any required deps can be installed by adding `alembic[tz]` to the pip requirements
|
||||
# string value is passed to ZoneInfo()
|
||||
# leave blank for localtime
|
||||
# timezone =
|
||||
|
||||
# max length of characters to apply to the
|
||||
# "slug" field
|
||||
# truncate_slug_length = 40
|
||||
|
||||
# set to 'true' to run the environment during
|
||||
# the 'revision' command, regardless of autogenerate
|
||||
# revision_environment = false
|
||||
|
||||
# set to 'true' to allow .pyc and .pyo files without
|
||||
# a source .py file to be detected as revisions in the
|
||||
# versions/ directory
|
||||
# sourceless = false
|
||||
|
||||
# version location specification; This defaults
|
||||
# to migrations/versions. When using multiple version
|
||||
# directories, initial revisions must be specified with --version-path.
|
||||
# The path separator used here should be the separator specified by "version_path_separator" below.
|
||||
# version_locations = %(here)s/bar:%(here)s/bat:migrations/versions
|
||||
|
||||
# version path separator; As mentioned above, this is the character used to split
|
||||
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
|
||||
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
|
||||
# Valid values for version_path_separator are:
|
||||
#
|
||||
# version_path_separator = :
|
||||
# version_path_separator = ;
|
||||
# version_path_separator = space
|
||||
version_path_separator = os # Use os.pathsep. Default configuration used for new projects.
|
||||
|
||||
# set to 'true' to search source files recursively
|
||||
# in each "version_locations" directory
|
||||
# new in Alembic version 1.10
|
||||
# recursive_version_locations = false
|
||||
|
||||
# the output encoding used when revision files
|
||||
# are written from script.py.mako
|
||||
# output_encoding = utf-8
|
||||
|
||||
# sqlalchemy.url = REPLACE_WITH_DATABASE_URL
|
||||
|
||||
|
||||
[post_write_hooks]
|
||||
# post_write_hooks defines scripts or Python functions that are run
|
||||
# on newly generated revision scripts. See the documentation for further
|
||||
# detail and examples
|
||||
|
||||
# format using "black" - use the console_scripts runner, against the "black" entrypoint
|
||||
# hooks = black
|
||||
# black.type = console_scripts
|
||||
# black.entrypoint = black
|
||||
# black.options = -l 79 REVISION_SCRIPT_FILENAME
|
||||
|
||||
# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
|
||||
# hooks = ruff
|
||||
# ruff.type = exec
|
||||
# ruff.executable = %(here)s/.venv/bin/ruff
|
||||
# ruff.options = --fix REVISION_SCRIPT_FILENAME
|
||||
|
||||
# Logging configuration
|
||||
[loggers]
|
||||
keys = root,sqlalchemy,alembic
|
||||
|
||||
[handlers]
|
||||
keys = console
|
||||
|
||||
[formatters]
|
||||
keys = generic
|
||||
|
||||
[logger_root]
|
||||
level = WARN
|
||||
handlers = console
|
||||
qualname =
|
||||
|
||||
[logger_sqlalchemy]
|
||||
level = WARN
|
||||
handlers =
|
||||
qualname = sqlalchemy.engine
|
||||
|
||||
[logger_alembic]
|
||||
level = INFO
|
||||
handlers =
|
||||
qualname = alembic
|
||||
|
||||
[handler_console]
|
||||
class = StreamHandler
|
||||
args = (sys.stderr,)
|
||||
level = NOTSET
|
||||
formatter = generic
|
||||
|
||||
[formatter_generic]
|
||||
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||
datefmt = %H:%M:%S
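The ini file above mixes plain settings with logging format strings that contain `%` placeholders. Alembic reads the file through its own configuration machinery; the sketch below is only an illustration of inspecting such a file with the standard-library `configparser`, and `SAMPLE_INI` is a shortened, hypothetical excerpt rather than the full file.

```python
import configparser

SAMPLE_INI = """
[alembic]
script_location = migrations
prepend_sys_path = .

[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S
"""

# interpolation=None keeps "%(levelname)..." and "%H" literal; the default
# interpolation would try to expand them as references to other keys.
parser = configparser.ConfigParser(interpolation=None)
parser.read_string(SAMPLE_INI)

script_location = parser.get("alembic", "script_location")
log_format = parser.get("formatter_generic", "format")
date_format = parser.get("formatter_generic", "datefmt")

print(script_location)
print(log_format)
print(date_format)
```

Disabling interpolation is the design choice that matters here: the format strings are meant for the `logging` module, not for the ini parser.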
|
||||
|
|
@ -0,0 +1,119 @@
|
|||
from enum import Enum
|
||||
|
||||
|
||||
class MESSAGES(str, Enum):
|
||||
DEFAULT = lambda msg="": f"{msg if msg else ''}"
|
||||
MODEL_ADDED = lambda model="": f"The model '{model}' has been added successfully."
|
||||
MODEL_DELETED = (
|
||||
lambda model="": f"The model '{model}' has been deleted successfully."
|
||||
)
|
||||
|
||||
|
||||
class WEBHOOK_MESSAGES(str, Enum):
|
||||
DEFAULT = lambda msg="": f"{msg if msg else ''}"
|
||||
USER_SIGNUP = lambda username="": (
|
||||
f"New user signed up: {username}" if username else "New user signed up"
|
||||
)
|
||||
|
||||
|
||||
class ERROR_MESSAGES(str, Enum):
|
||||
def __str__(self) -> str:
|
||||
return super().__str__()
|
||||
|
||||
DEFAULT = (
|
||||
lambda err="": f'{"Something went wrong :/" if err == "" else "[ERROR: " + str(err) + "]"}'
|
||||
)
|
||||
ENV_VAR_NOT_FOUND = "Required environment variable not found. Terminating now."
|
||||
CREATE_USER_ERROR = "Oops! Something went wrong while creating your account. Please try again later. If the issue persists, contact support for assistance."
|
||||
DELETE_USER_ERROR = "Oops! Something went wrong. We encountered an issue while trying to delete the user. Please give it another shot."
|
||||
EMAIL_MISMATCH = "Uh-oh! This email does not match the email your provider is registered with. Please check your email and try again."
|
||||
EMAIL_TAKEN = "Uh-oh! This email is already registered. Sign in with your existing account or choose another email to start anew."
|
||||
USERNAME_TAKEN = (
|
||||
"Uh-oh! This username is already registered. Please choose another username."
|
||||
)
|
||||
COMMAND_TAKEN = "Uh-oh! This command is already registered. Please choose another command string."
|
||||
FILE_EXISTS = "Uh-oh! This file is already registered. Please choose another file."
|
||||
|
||||
ID_TAKEN = "Uh-oh! This id is already registered. Please choose another id string."
|
||||
MODEL_ID_TAKEN = "Uh-oh! This model id is already registered. Please choose another model id string."
|
||||
NAME_TAG_TAKEN = "Uh-oh! This name tag is already registered. Please choose another name tag string."
|
||||
|
||||
INVALID_TOKEN = (
|
||||
"Your session has expired or the token is invalid. Please sign in again."
|
||||
)
|
||||
INVALID_CRED = "The email or password provided is incorrect. Please check for typos and try logging in again."
|
||||
INVALID_EMAIL_FORMAT = "The email format you entered is invalid. Please double-check and make sure you're using a valid email address (e.g., yourname@example.com)."
|
||||
INVALID_PASSWORD = (
|
||||
"The password provided is incorrect. Please check for typos and try again."
|
||||
)
|
||||
INVALID_TRUSTED_HEADER = "Your provider has not provided a trusted header. Please contact your administrator for assistance."
|
||||
|
||||
EXISTING_USERS = "You can't turn off authentication because there are existing users. If you want to disable WEBUI_AUTH, make sure your web interface doesn't have any existing users and is a fresh installation."
|
||||
|
||||
UNAUTHORIZED = "401 Unauthorized"
|
||||
ACCESS_PROHIBITED = "You do not have permission to access this resource. Please contact your administrator for assistance."
|
||||
ACTION_PROHIBITED = (
|
||||
"The requested action has been restricted as a security measure."
|
||||
)
|
||||
|
||||
FILE_NOT_SENT = "FILE_NOT_SENT"
|
||||
FILE_NOT_SUPPORTED = "Oops! It seems like the file format you're trying to upload is not supported. Please upload a file with a supported format and try again."
|
||||
|
||||
NOT_FOUND = "We could not find what you're looking for :/"
|
||||
USER_NOT_FOUND = "We could not find what you're looking for :/"
|
||||
API_KEY_NOT_FOUND = "Oops! It looks like there's a hiccup. The API key is missing. Please make sure to provide a valid API key to access this feature."
|
||||
API_KEY_NOT_ALLOWED = "Use of API key is not enabled in the environment."
|
||||
|
||||
MALICIOUS = "Unusual activities detected, please try again in a few minutes."
|
||||
|
||||
PANDOC_NOT_INSTALLED = "Pandoc is not installed on the server. Please contact your administrator for assistance."
|
||||
INCORRECT_FORMAT = (
|
||||
lambda err="": f"Invalid format. Please use the correct format{err}"
|
||||
)
|
||||
RATE_LIMIT_EXCEEDED = "API rate limit exceeded"
|
||||
|
||||
MODEL_NOT_FOUND = lambda name="": f"Model '{name}' was not found"
|
||||
OPENAI_NOT_FOUND = lambda name="": "OpenAI API was not found"
|
||||
OLLAMA_NOT_FOUND = "WebUI could not connect to Ollama"
|
||||
CREATE_API_KEY_ERROR = "Oops! Something went wrong while creating your API key. Please try again later. If the issue persists, contact support for assistance."
|
||||
API_KEY_CREATION_NOT_ALLOWED = "API key creation is not allowed in the environment."
|
||||
|
||||
EMPTY_CONTENT = "The content provided is empty. Please ensure that there is text or data present before proceeding."
|
||||
|
||||
DB_NOT_SQLITE = "This feature is only available when running with SQLite databases."
|
||||
|
||||
INVALID_URL = (
|
||||
"Oops! The URL you provided is invalid. Please double-check and try again."
|
||||
)
|
||||
|
||||
WEB_SEARCH_ERROR = (
|
||||
lambda err="": f"{err if err else 'Oops! Something went wrong while searching the web.'}"
|
||||
)
|
||||
|
||||
OLLAMA_API_DISABLED = (
|
||||
"The Ollama API is disabled. Please enable it to use this feature."
|
||||
)
|
||||
|
||||
FILE_TOO_LARGE = (
|
||||
lambda size="": f"Oops! The file you're trying to upload is too large. Please upload a file that is less than {size}."
|
||||
)
|
||||
|
||||
DUPLICATE_CONTENT = (
|
||||
"Duplicate content detected. Please provide unique content to proceed."
|
||||
)
|
||||
FILE_NOT_PROCESSED = "Extracted content is not available for this file. Please ensure that the file is processed before proceeding."
|
||||
|
||||
|
||||
class TASKS(str, Enum):
|
||||
def __str__(self) -> str:
|
||||
return super().__str__()
|
||||
|
||||
DEFAULT = lambda task="": f"{task if task else 'generation'}"
|
||||
TITLE_GENERATION = "title_generation"
|
||||
TAGS_GENERATION = "tags_generation"
|
||||
EMOJI_GENERATION = "emoji_generation"
|
||||
QUERY_GENERATION = "query_generation"
|
||||
IMAGE_PROMPT_GENERATION = "image_prompt_generation"
|
||||
AUTOCOMPLETE_GENERATION = "autocomplete_generation"
|
||||
FUNCTION_CALLING = "function_calling"
|
||||
MOA_RESPONSE_GENERATION = "moa_response_generation"
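The classes above combine two kinds of attributes inside a `str`-based `Enum`: plain strings and lambdas that build a message from an argument. The sketch below shows how Python treats the two differently. `DemoMessages` and its values are hypothetical stand-ins, not the project's own constants.

```python
from enum import Enum


class DemoMessages(str, Enum):
    def __str__(self) -> str:
        return super().__str__()

    # A lambda is a function, so Enum leaves it as a plain callable attribute
    # instead of turning it into a member.
    DEFAULT = lambda err="": "Something went wrong" if err == "" else f"[ERROR: {err}]"
    # A plain string becomes a real member that also compares equal to a str.
    NOT_FOUND = "not found"


print(DemoMessages.DEFAULT("boom"))
print(DemoMessages.DEFAULT())
print(DemoMessages.NOT_FOUND == "not found")
print(list(DemoMessages.__members__))
```

In practice this means the templated entries must be called, for example `DemoMessages.DEFAULT("boom")`, while the fixed entries can be passed directly wherever a string is expected.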
|
||||
|
|
@ -0,0 +1,443 @@
|
|||
import importlib.metadata
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import pkgutil
|
||||
import sys
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import markdown
|
||||
from bs4 import BeautifulSoup
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
|
||||
####################################
|
||||
# Load .env file
|
||||
####################################
|
||||
|
||||
OPEN_WEBUI_DIR = Path(__file__).parent # the path containing this file
|
||||
print(OPEN_WEBUI_DIR)
|
||||
|
||||
BACKEND_DIR = OPEN_WEBUI_DIR.parent # the path containing the open_webui/ package
|
||||
BASE_DIR = BACKEND_DIR.parent # the path containing the backend/
|
||||
|
||||
print(BACKEND_DIR)
|
||||
print(BASE_DIR)
|
||||
|
||||
try:
|
||||
from dotenv import find_dotenv, load_dotenv
|
||||
|
||||
load_dotenv(find_dotenv(str(BASE_DIR / ".env")))
|
||||
except ImportError:
|
||||
print("dotenv not installed, skipping...")
|
||||
|
||||
DOCKER = os.environ.get("DOCKER", "False").lower() == "true"
|
||||
|
||||
# device type for embedding models - "cpu" (default), "cuda" (NVIDIA GPU required) or "mps" (Apple silicon) - choosing the right one can improve performance
|
||||
USE_CUDA = os.environ.get("USE_CUDA_DOCKER", "false")
|
||||
|
||||
if USE_CUDA.lower() == "true":
|
||||
try:
|
||||
import torch
|
||||
|
||||
assert torch.cuda.is_available(), "CUDA not available"
|
||||
DEVICE_TYPE = "cuda"
|
||||
except Exception as e:
|
||||
cuda_error = (
|
||||
"Error when testing CUDA but USE_CUDA_DOCKER is true. "
|
||||
f"Resetting USE_CUDA_DOCKER to false: {e}"
|
||||
)
|
||||
os.environ["USE_CUDA_DOCKER"] = "false"
|
||||
USE_CUDA = "false"
|
||||
DEVICE_TYPE = "cpu"
|
||||
else:
|
||||
DEVICE_TYPE = "cpu"
|
||||
|
||||
try:
|
||||
import torch
|
||||
|
||||
if torch.backends.mps.is_available() and torch.backends.mps.is_built():
|
||||
DEVICE_TYPE = "mps"
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
####################################
|
||||
# LOGGING
|
||||
####################################
|
||||
|
||||
log_levels = ["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"]
|
||||
|
||||
GLOBAL_LOG_LEVEL = os.environ.get("GLOBAL_LOG_LEVEL", "").upper()
|
||||
if GLOBAL_LOG_LEVEL in log_levels:
|
||||
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL, force=True)
|
||||
else:
|
||||
GLOBAL_LOG_LEVEL = "INFO"
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.info(f"GLOBAL_LOG_LEVEL: {GLOBAL_LOG_LEVEL}")
|
||||
|
||||
if "cuda_error" in locals():
|
||||
log.error(cuda_error)  # no exception is active here, so log.exception would append a meaningless "NoneType: None" traceback
|
||||
|
||||
log_sources = [
|
||||
"AUDIO",
|
||||
"COMFYUI",
|
||||
"CONFIG",
|
||||
"DB",
|
||||
"IMAGES",
|
||||
"MAIN",
|
||||
"MODELS",
|
||||
"OLLAMA",
|
||||
"OPENAI",
|
||||
"RAG",
|
||||
"WEBHOOK",
|
||||
"SOCKET",
|
||||
"OAUTH",
|
||||
]
|
||||
|
||||
SRC_LOG_LEVELS = {}
|
||||
|
||||
for source in log_sources:
|
||||
log_env_var = source + "_LOG_LEVEL"
|
||||
SRC_LOG_LEVELS[source] = os.environ.get(log_env_var, "").upper()
|
||||
if SRC_LOG_LEVELS[source] not in log_levels:
|
||||
SRC_LOG_LEVELS[source] = GLOBAL_LOG_LEVEL
|
||||
log.info(f"{log_env_var}: {SRC_LOG_LEVELS[source]}")
|
||||
|
||||
log.setLevel(SRC_LOG_LEVELS["CONFIG"])
|
||||
|
||||
|
||||
WEBUI_NAME = os.environ.get("WEBUI_NAME", "Open WebUI")
|
||||
if WEBUI_NAME != "Open WebUI":
|
||||
WEBUI_NAME += " (Open WebUI)"
|
||||
|
||||
WEBUI_FAVICON_URL = "https://openwebui.com/favicon.png"
|
||||
|
||||
TRUSTED_SIGNATURE_KEY = os.environ.get("TRUSTED_SIGNATURE_KEY", "")
|
||||
|
||||
####################################
|
||||
# ENV (dev,test,prod)
|
||||
####################################
|
||||
|
||||
ENV = os.environ.get("ENV", "dev")
|
||||
|
||||
FROM_INIT_PY = os.environ.get("FROM_INIT_PY", "False").lower() == "true"
|
||||
|
||||
if FROM_INIT_PY:
|
||||
PACKAGE_DATA = {"version": importlib.metadata.version("open-webui")}
|
||||
else:
|
||||
try:
|
||||
PACKAGE_DATA = json.loads((BASE_DIR / "package.json").read_text())
|
||||
except Exception:
|
||||
PACKAGE_DATA = {"version": "0.0.0"}
|
||||
|
||||
|
||||
VERSION = PACKAGE_DATA["version"]
|
||||
|
||||
|
||||
# Function to parse each section
|
||||
def parse_section(section):
|
||||
items = []
|
||||
for li in section.find_all("li"):
|
||||
# Extract raw HTML string
|
||||
raw_html = str(li)
|
||||
|
||||
# Extract text without HTML tags
|
||||
text = li.get_text(separator=" ", strip=True)
|
||||
|
||||
# Split into title and content
|
||||
parts = text.split(": ", 1)
|
||||
title = parts[0].strip() if len(parts) > 1 else ""
|
||||
content = parts[1].strip() if len(parts) > 1 else text
|
||||
|
||||
items.append({"title": title, "content": content, "raw": raw_html})
|
||||
return items
|
||||
|
||||
|
||||
try:
|
||||
changelog_path = BASE_DIR / "CHANGELOG.md"
|
||||
with open(str(changelog_path.absolute()), "r", encoding="utf8") as file:
|
||||
changelog_content = file.read()
|
||||
|
||||
except Exception:
|
||||
changelog_content = (pkgutil.get_data("open_webui", "CHANGELOG.md") or b"").decode()
|
||||
|
||||
|
||||
# Convert markdown content to HTML
|
||||
html_content = markdown.markdown(changelog_content)
|
||||
|
||||
# Parse the HTML content
|
||||
soup = BeautifulSoup(html_content, "html.parser")
|
||||
|
||||
# Initialize JSON structure
|
||||
changelog_json = {}
|
||||
|
||||
# Iterate over each version
|
||||
for version in soup.find_all("h2"):
|
||||
version_number = version.get_text().strip().split(" - ")[0][1:-1] # Remove brackets
|
||||
date = version.get_text().strip().split(" - ")[1]
|
||||
|
||||
version_data = {"date": date}
|
||||
|
||||
# Find the next sibling that is a h3 tag (section title)
|
||||
current = version.find_next_sibling()
|
||||
|
||||
while current and current.name != "h2":
|
||||
if current.name == "h3":
|
||||
section_title = current.get_text().lower() # e.g., "added", "fixed"
|
||||
section_items = parse_section(current.find_next_sibling("ul"))
|
||||
version_data[section_title] = section_items
|
||||
|
||||
# Move to the next element
|
||||
current = current.find_next_sibling()
|
||||
|
||||
changelog_json[version_number] = version_data
|
||||
|
||||
|
||||
CHANGELOG = changelog_json
|
||||
|
||||
####################################
|
||||
# SAFE_MODE
|
||||
####################################
|
||||
|
||||
SAFE_MODE = os.environ.get("SAFE_MODE", "false").lower() == "true"
|
||||
|
||||
####################################
|
||||
# ENABLE_FORWARD_USER_INFO_HEADERS
|
||||
####################################
|
||||
|
||||
ENABLE_FORWARD_USER_INFO_HEADERS = (
|
||||
os.environ.get("ENABLE_FORWARD_USER_INFO_HEADERS", "False").lower() == "true"
|
||||
)
|
||||
|
||||
|
||||
####################################
|
||||
# WEBUI_BUILD_HASH
|
||||
####################################
|
||||
|
||||
WEBUI_BUILD_HASH = os.environ.get("WEBUI_BUILD_HASH", "dev-build")
|
||||
|
||||
####################################
|
||||
# DATA/FRONTEND BUILD DIR
|
||||
####################################
|
||||
|
||||
DATA_DIR = Path(os.getenv("DATA_DIR", BACKEND_DIR / "data")).resolve()
|
||||
|
||||
if FROM_INIT_PY:
|
||||
NEW_DATA_DIR = Path(os.getenv("DATA_DIR", OPEN_WEBUI_DIR / "data")).resolve()
|
||||
NEW_DATA_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Check if the data directory exists in the package directory
|
||||
if DATA_DIR.exists() and DATA_DIR != NEW_DATA_DIR:
|
||||
log.info(f"Moving {DATA_DIR} to {NEW_DATA_DIR}")
|
||||
for item in DATA_DIR.iterdir():
|
||||
dest = NEW_DATA_DIR / item.name
|
||||
if item.is_dir():
|
||||
shutil.copytree(item, dest, dirs_exist_ok=True)
|
||||
else:
|
||||
shutil.copy2(item, dest)
|
||||
|
||||
# Zip the data directory
|
||||
shutil.make_archive(DATA_DIR.parent / "open_webui_data", "zip", DATA_DIR)
|
||||
|
||||
# Remove the old data directory
|
||||
shutil.rmtree(DATA_DIR)
|
||||
|
||||
DATA_DIR = Path(os.getenv("DATA_DIR", OPEN_WEBUI_DIR / "data"))
|
||||
|
||||
|
||||
STATIC_DIR = Path(os.getenv("STATIC_DIR", OPEN_WEBUI_DIR / "static"))
|
||||
|
||||
FONTS_DIR = Path(os.getenv("FONTS_DIR", OPEN_WEBUI_DIR / "static" / "fonts"))
|
||||
|
||||
FRONTEND_BUILD_DIR = Path(os.getenv("FRONTEND_BUILD_DIR", BASE_DIR / "build")).resolve()
|
||||
|
||||
if FROM_INIT_PY:
|
||||
FRONTEND_BUILD_DIR = Path(
|
||||
os.getenv("FRONTEND_BUILD_DIR", OPEN_WEBUI_DIR / "frontend")
|
||||
).resolve()
|
||||
|
||||
|
||||
####################################
|
||||
# Database
|
||||
####################################
|
||||
|
||||
# Check if the file exists
|
||||
if os.path.exists(f"{DATA_DIR}/ollama.db"):
|
||||
# Rename the file
|
||||
os.rename(f"{DATA_DIR}/ollama.db", f"{DATA_DIR}/webui.db")
|
||||
log.info("Database migrated from Ollama-WebUI successfully.")
|
||||
else:
|
||||
pass
|
||||
|
||||
DATABASE_URL = os.environ.get("DATABASE_URL", f"sqlite:///{DATA_DIR}/webui.db")
|
||||
|
||||
# Replace the postgres:// with postgresql://
|
||||
if "postgres://" in DATABASE_URL:
|
||||
DATABASE_URL = DATABASE_URL.replace("postgres://", "postgresql://")
|
||||
|
||||
DATABASE_SCHEMA = os.environ.get("DATABASE_SCHEMA", None)
|
||||
|
||||
DATABASE_POOL_SIZE = os.environ.get("DATABASE_POOL_SIZE", 0)
|
||||
|
||||
if DATABASE_POOL_SIZE == "":
|
||||
DATABASE_POOL_SIZE = 0
|
||||
else:
|
||||
try:
|
||||
DATABASE_POOL_SIZE = int(DATABASE_POOL_SIZE)
|
||||
except Exception:
|
||||
DATABASE_POOL_SIZE = 0
|
||||
|
||||
DATABASE_POOL_MAX_OVERFLOW = os.environ.get("DATABASE_POOL_MAX_OVERFLOW", 0)
|
||||
|
||||
if DATABASE_POOL_MAX_OVERFLOW == "":
|
||||
DATABASE_POOL_MAX_OVERFLOW = 0
|
||||
else:
|
||||
try:
|
||||
DATABASE_POOL_MAX_OVERFLOW = int(DATABASE_POOL_MAX_OVERFLOW)
|
||||
except Exception:
|
||||
DATABASE_POOL_MAX_OVERFLOW = 0
|
||||
|
||||
DATABASE_POOL_TIMEOUT = os.environ.get("DATABASE_POOL_TIMEOUT", 30)
|
||||
|
||||
if DATABASE_POOL_TIMEOUT == "":
|
||||
DATABASE_POOL_TIMEOUT = 30
|
||||
else:
|
||||
try:
|
||||
DATABASE_POOL_TIMEOUT = int(DATABASE_POOL_TIMEOUT)
|
||||
except Exception:
|
||||
DATABASE_POOL_TIMEOUT = 30
|
||||
|
||||
DATABASE_POOL_RECYCLE = os.environ.get("DATABASE_POOL_RECYCLE", 3600)
|
||||
|
||||
if DATABASE_POOL_RECYCLE == "":
|
||||
DATABASE_POOL_RECYCLE = 3600
|
||||
else:
|
||||
try:
|
||||
DATABASE_POOL_RECYCLE = int(DATABASE_POOL_RECYCLE)
|
||||
except Exception:
|
||||
DATABASE_POOL_RECYCLE = 3600
|
||||
|
||||
RESET_CONFIG_ON_START = (
|
||||
os.environ.get("RESET_CONFIG_ON_START", "False").lower() == "true"
|
||||
)
|
||||
|
||||
|
||||
ENABLE_REALTIME_CHAT_SAVE = (
|
||||
os.environ.get("ENABLE_REALTIME_CHAT_SAVE", "False").lower() == "true"
|
||||
)
|
||||
|
||||
####################################
|
||||
# REDIS
|
||||
####################################
|
||||
|
||||
REDIS_URL = os.environ.get("REDIS_URL", "redis://localhost:6379/0")
|
||||
|
||||
####################################
|
||||
# WEBUI_AUTH (Required for security)
|
||||
####################################
|
||||
|
||||
WEBUI_AUTH = os.environ.get("WEBUI_AUTH", "True").lower() == "true"
|
||||
WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get(
|
||||
"WEBUI_AUTH_TRUSTED_EMAIL_HEADER", None
|
||||
)
|
||||
WEBUI_AUTH_TRUSTED_NAME_HEADER = os.environ.get("WEBUI_AUTH_TRUSTED_NAME_HEADER", None)
|
||||
|
||||
BYPASS_MODEL_ACCESS_CONTROL = (
|
||||
os.environ.get("BYPASS_MODEL_ACCESS_CONTROL", "False").lower() == "true"
|
||||
)
|
||||
|
||||
####################################
|
||||
# WEBUI_SECRET_KEY
|
||||
####################################
|
||||
|
||||
WEBUI_SECRET_KEY = os.environ.get(
|
||||
"WEBUI_SECRET_KEY",
|
||||
os.environ.get(
|
||||
"WEBUI_JWT_SECRET_KEY", "t0p-s3cr3t"
|
||||
), # DEPRECATED: remove at next major version
|
||||
)
|
||||
|
||||
WEBUI_SESSION_COOKIE_SAME_SITE = os.environ.get("WEBUI_SESSION_COOKIE_SAME_SITE", "lax")
|
||||
|
||||
WEBUI_SESSION_COOKIE_SECURE = (
|
||||
os.environ.get("WEBUI_SESSION_COOKIE_SECURE", "false").lower() == "true"
|
||||
)
|
||||
|
||||
WEBUI_AUTH_COOKIE_SAME_SITE = os.environ.get(
|
||||
"WEBUI_AUTH_COOKIE_SAME_SITE", WEBUI_SESSION_COOKIE_SAME_SITE
|
||||
)
|
||||
|
||||
WEBUI_AUTH_COOKIE_SECURE = (
|
||||
os.environ.get(
|
||||
"WEBUI_AUTH_COOKIE_SECURE",
|
||||
os.environ.get("WEBUI_SESSION_COOKIE_SECURE", "false"),
|
||||
).lower()
|
||||
== "true"
|
||||
)
|
||||
|
||||
if WEBUI_AUTH and WEBUI_SECRET_KEY == "":
|
||||
raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND)
|
||||
|
||||
ENABLE_WEBSOCKET_SUPPORT = (
|
||||
os.environ.get("ENABLE_WEBSOCKET_SUPPORT", "True").lower() == "true"
|
||||
)
|
||||
|
||||
WEBSOCKET_MANAGER = os.environ.get("WEBSOCKET_MANAGER", "")
|
||||
|
||||
WEBSOCKET_REDIS_URL = os.environ.get("WEBSOCKET_REDIS_URL", REDIS_URL)
|
||||
|
||||
AIOHTTP_CLIENT_TIMEOUT = os.environ.get("AIOHTTP_CLIENT_TIMEOUT", "")
|
||||
|
||||
if AIOHTTP_CLIENT_TIMEOUT == "":
|
||||
AIOHTTP_CLIENT_TIMEOUT = None
|
||||
else:
|
||||
try:
|
||||
AIOHTTP_CLIENT_TIMEOUT = int(AIOHTTP_CLIENT_TIMEOUT)
|
||||
except Exception:
|
||||
AIOHTTP_CLIENT_TIMEOUT = 300
|
||||
|
||||
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = os.environ.get(
|
||||
"AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST", ""
|
||||
)
|
||||
|
||||
if AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST == "":
|
||||
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = None
|
||||
else:
|
||||
try:
|
||||
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = int(
|
||||
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST
|
||||
)
|
||||
except Exception:
|
||||
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = 5
|
||||
|
||||
####################################
|
||||
# OFFLINE_MODE
|
||||
####################################
|
||||
|
||||
OFFLINE_MODE = os.environ.get("OFFLINE_MODE", "false").lower() == "true"
|
||||
|
||||
if OFFLINE_MODE:
|
||||
os.environ["HF_HUB_OFFLINE"] = "1"
|
||||
|
||||
####################################
|
||||
# AUDIT LOGGING
|
||||
####################################
|
||||
ENABLE_AUDIT_LOGS = os.getenv("ENABLE_AUDIT_LOGS", "false").lower() == "true"
|
||||
# Where to store log file
|
||||
AUDIT_LOGS_FILE_PATH = f"{DATA_DIR}/audit.log"
|
||||
# Maximum size of a file before rotating into a new log file
|
||||
AUDIT_LOG_FILE_ROTATION_SIZE = os.getenv("AUDIT_LOG_FILE_ROTATION_SIZE", "10MB")
|
||||
# METADATA | REQUEST | REQUEST_RESPONSE
|
||||
AUDIT_LOG_LEVEL = os.getenv("AUDIT_LOG_LEVEL", "REQUEST_RESPONSE").upper()
|
||||
try:
|
||||
MAX_BODY_LOG_SIZE = int(os.environ.get("MAX_BODY_LOG_SIZE") or 2048)
|
||||
except ValueError:
|
||||
MAX_BODY_LOG_SIZE = 2048
|
||||
|
||||
# Comma-separated list of URL paths to exclude from audit logging
|
||||
AUDIT_EXCLUDED_PATHS = os.getenv("AUDIT_EXCLUDED_PATHS", "/chats,/chat,/folders").split(
|
||||
","
|
||||
)
|
||||
AUDIT_EXCLUDED_PATHS = [path.strip() for path in AUDIT_EXCLUDED_PATHS]
|
||||
AUDIT_EXCLUDED_PATHS = [path.lstrip("/") for path in AUDIT_EXCLUDED_PATHS]
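The settings module above repeats the same pattern several times: read an environment variable, fall back to a default when it is empty, and fall back again when it is not a valid integer. A minimal sketch of that pattern as reusable helpers follows. The names `env_int` and `env_bool` and the `DEMO_*` variables are hypothetical and do not exist in the project. The sketch covers the `DATABASE_POOL_*` style, where empty and invalid values share one default; it does not reproduce `AIOHTTP_CLIENT_TIMEOUT`, which uses `None` for empty and a number for invalid input.

```python
import os


def env_int(name: str, default: int) -> int:
    """Read an integer environment variable; empty or invalid values give the default."""
    raw = os.environ.get(name, "")
    try:
        return int(raw) if raw != "" else default
    except ValueError:
        return default


def env_bool(name: str, default: bool = False) -> bool:
    """Only the case-insensitive string "true" counts as True."""
    return os.environ.get(name, str(default)).lower() == "true"


if __name__ == "__main__":
    os.environ["DEMO_POOL_SIZE"] = "15"
    os.environ["DEMO_POOL_TIMEOUT"] = "not-a-number"
    print(env_int("DEMO_POOL_SIZE", 0))
    print(env_int("DEMO_POOL_TIMEOUT", 30))
    print(env_bool("DEMO_UNSET_FLAG"))
```

Catching only `ValueError` keeps unrelated failures visible, whereas the broad `except Exception` used in the module would hide them.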
|
||||
|
|
@ -0,0 +1,319 @@
|
|||
import logging
|
||||
import sys
|
||||
import inspect
|
||||
import json
|
||||
import asyncio
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing import AsyncGenerator, Generator, Iterator
|
||||
from fastapi import (
|
||||
Depends,
|
||||
FastAPI,
|
||||
File,
|
||||
Form,
|
||||
HTTPException,
|
||||
Request,
|
||||
UploadFile,
|
||||
status,
|
||||
)
|
||||
from starlette.responses import Response, StreamingResponse
|
||||
|
||||
|
||||
from open_webui.socket.main import (
|
||||
get_event_call,
|
||||
get_event_emitter,
|
||||
)
|
||||
|
||||
|
||||
from open_webui.models.functions import Functions
|
||||
from open_webui.models.models import Models
|
||||
|
||||
from open_webui.utils.plugin import load_function_module_by_id
|
||||
from open_webui.utils.tools import get_tools
|
||||
from open_webui.utils.access_control import has_access
|
||||
|
||||
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
|
||||
|
||||
from open_webui.utils.misc import (
|
||||
add_or_update_system_message,
|
||||
get_last_user_message,
|
||||
prepend_to_first_user_message_content,
|
||||
openai_chat_chunk_message_template,
|
||||
openai_chat_completion_message_template,
|
||||
)
|
||||
from open_webui.utils.payload import (
|
||||
apply_model_params_to_body_openai,
|
||||
apply_model_system_prompt_to_body,
|
||||
)
|
||||
|
||||
|
||||
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||
|
||||
|
||||
def get_function_module_by_id(request: Request, pipe_id: str):
|
||||
# Check if function is already loaded
|
||||
if pipe_id not in request.app.state.FUNCTIONS:
|
||||
function_module, _, _ = load_function_module_by_id(pipe_id)
|
||||
request.app.state.FUNCTIONS[pipe_id] = function_module
|
||||
else:
|
||||
function_module = request.app.state.FUNCTIONS[pipe_id]
|
||||
|
||||
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
|
||||
valves = Functions.get_function_valves_by_id(pipe_id)
|
||||
function_module.valves = function_module.Valves(**(valves if valves else {}))
|
||||
return function_module
|
||||
|
||||
|
||||
async def get_function_models(request):
|
||||
pipes = Functions.get_functions_by_type("pipe", active_only=True)
|
||||
pipe_models = []
|
||||
|
||||
for pipe in pipes:
|
||||
function_module = get_function_module_by_id(request, pipe.id)
|
||||
|
||||
# Check if function is a manifold
|
||||
if hasattr(function_module, "pipes"):
|
||||
sub_pipes = []
|
||||
|
||||
# Handle pipes being a list, sync function, or async function
|
||||
try:
|
||||
if callable(function_module.pipes):
|
||||
if asyncio.iscoroutinefunction(function_module.pipes):
|
||||
sub_pipes = await function_module.pipes()
|
||||
else:
|
||||
sub_pipes = function_module.pipes()
|
||||
else:
|
||||
sub_pipes = function_module.pipes
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
sub_pipes = []
|
||||
|
||||
log.debug(
|
||||
f"get_function_models: function '{pipe.id}' is a manifold of {sub_pipes}"
|
||||
)
|
||||
|
||||
for p in sub_pipes:
|
||||
sub_pipe_id = f'{pipe.id}.{p["id"]}'
|
||||
sub_pipe_name = p["name"]
|
||||
|
||||
if hasattr(function_module, "name"):
|
||||
sub_pipe_name = f"{function_module.name}{sub_pipe_name}"
|
||||
|
||||
pipe_flag = {"type": pipe.type}
|
||||
|
||||
pipe_models.append(
|
||||
{
|
||||
"id": sub_pipe_id,
|
||||
"name": sub_pipe_name,
|
||||
"object": "model",
|
||||
"created": pipe.created_at,
|
||||
"owned_by": "openai",
|
||||
"pipe": pipe_flag,
|
||||
}
|
||||
)
|
||||
else:
|
||||
pipe_flag = {"type": "pipe"}
|
||||
|
||||
log.debug(
|
||||
f"get_function_models: function '{pipe.id}' is a single pipe {{ 'id': {pipe.id}, 'name': {pipe.name} }}"
|
||||
)
|
||||
|
||||
pipe_models.append(
|
||||
{
|
||||
"id": pipe.id,
|
||||
"name": pipe.name,
|
||||
"object": "model",
|
||||
"created": pipe.created_at,
|
||||
"owned_by": "openai",
|
||||
"pipe": pipe_flag,
|
||||
}
|
||||
)
|
||||
|
||||
return pipe_models
|
||||
|
||||
|
||||
async def generate_function_chat_completion(
|
||||
request, form_data, user, models: dict = {}
|
||||
):
|
||||
async def execute_pipe(pipe, params):
|
||||
if inspect.iscoroutinefunction(pipe):
|
||||
return await pipe(**params)
|
||||
else:
|
||||
return pipe(**params)
|
||||
|
||||
async def get_message_content(res: str | Generator | AsyncGenerator) -> str:
|
||||
if isinstance(res, str):
|
||||
return res
|
||||
if isinstance(res, Generator):
|
||||
return "".join(map(str, res))
|
||||
if isinstance(res, AsyncGenerator):
|
||||
return "".join([str(stream) async for stream in res])
|
||||
|
||||
def process_line(form_data: dict, line):
|
||||
if isinstance(line, BaseModel):
|
||||
line = line.model_dump_json()
|
||||
line = f"data: {line}"
|
||||
if isinstance(line, dict):
|
||||
line = f"data: {json.dumps(line)}"
|
||||
|
||||
try:
|
||||
line = line.decode("utf-8")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if line.startswith("data:"):
|
||||
return f"{line}\n\n"
|
||||
else:
|
||||
line = openai_chat_chunk_message_template(form_data["model"], line)
|
||||
return f"data: {json.dumps(line)}\n\n"
|
||||
|
||||
def get_pipe_id(form_data: dict) -> str:
|
||||
pipe_id = form_data["model"]
|
||||
if "." in pipe_id:
|
||||
pipe_id, _ = pipe_id.split(".", 1)
|
||||
return pipe_id
|
||||
|
||||
def get_function_params(function_module, form_data, user, extra_params=None):
|
||||
if extra_params is None:
|
||||
extra_params = {}
|
||||
|
||||
pipe_id = get_pipe_id(form_data)
|
||||
|
||||
# Get the signature of the function
|
||||
sig = inspect.signature(function_module.pipe)
|
||||
params = {"body": form_data} | {
|
||||
k: v for k, v in extra_params.items() if k in sig.parameters
|
||||
}
|
||||
|
||||
if "__user__" in params and hasattr(function_module, "UserValves"):
|
||||
user_valves = Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
|
||||
try:
|
||||
params["__user__"]["valves"] = function_module.UserValves(**user_valves)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
params["__user__"]["valves"] = function_module.UserValves()
|
||||
|
||||
return params
|
||||
|
||||
model_id = form_data.get("model")
|
||||
model_info = Models.get_model_by_id(model_id)
|
||||
|
||||
metadata = form_data.pop("metadata", {})
|
||||
|
||||
files = metadata.get("files", [])
|
||||
tool_ids = metadata.get("tool_ids", [])
|
||||
# Check if tool_ids is None
|
||||
if tool_ids is None:
|
||||
tool_ids = []
|
||||
|
||||
__event_emitter__ = None
|
||||
__event_call__ = None
|
||||
__task__ = None
|
||||
__task_body__ = None
|
||||
|
||||
if metadata:
|
||||
if all(k in metadata for k in ("session_id", "chat_id", "message_id")):
|
||||
__event_emitter__ = get_event_emitter(metadata)
|
||||
__event_call__ = get_event_call(metadata)
|
||||
__task__ = metadata.get("task", None)
|
||||
__task_body__ = metadata.get("task_body", None)
|
||||
|
||||
extra_params = {
|
||||
"__event_emitter__": __event_emitter__,
|
||||
"__event_call__": __event_call__,
|
||||
"__task__": __task__,
|
||||
"__task_body__": __task_body__,
|
||||
"__files__": files,
|
||||
"__user__": {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
"role": user.role,
|
||||
},
|
||||
"__metadata__": metadata,
|
||||
"__request__": request,
|
||||
}
|
||||
extra_params["__tools__"] = get_tools(
|
||||
request,
|
||||
tool_ids,
|
||||
user,
|
||||
{
|
||||
**extra_params,
|
||||
"__model__": models.get(form_data["model"], None),
|
||||
"__messages__": form_data["messages"],
|
||||
"__files__": files,
|
||||
},
|
||||
)
|
||||
|
||||
if model_info:
|
||||
if model_info.base_model_id:
|
||||
form_data["model"] = model_info.base_model_id
|
||||
|
||||
params = model_info.params.model_dump()
|
||||
form_data = apply_model_params_to_body_openai(params, form_data)
|
||||
form_data = apply_model_system_prompt_to_body(params, form_data, metadata, user)
|
||||
|
||||
pipe_id = get_pipe_id(form_data)
|
||||
function_module = get_function_module_by_id(request, pipe_id)
|
||||
|
||||
pipe = function_module.pipe
|
||||
params = get_function_params(function_module, form_data, user, extra_params)
|
||||
|
||||
if form_data.get("stream", False):
|
||||
|
||||
        async def stream_content():
            try:
                res = await execute_pipe(pipe, params)

                # Pass the chunks straight through if the pipe returned a StreamingResponse
                if isinstance(res, StreamingResponse):
                    async for data in res.body_iterator:
                        yield data
                    return
                if isinstance(res, dict):
                    yield f"data: {json.dumps(res)}\n\n"
                    return

            except Exception as e:
                log.error(f"Error: {e}")
                yield f"data: {json.dumps({'error': {'detail': str(e)}})}\n\n"
                return

            if isinstance(res, str):
                message = openai_chat_chunk_message_template(form_data["model"], res)
                yield f"data: {json.dumps(message)}\n\n"

            if isinstance(res, Iterator):
                for line in res:
                    yield process_line(form_data, line)

            if isinstance(res, AsyncGenerator):
                async for line in res:
                    yield process_line(form_data, line)

            if isinstance(res, str) or isinstance(res, Generator):
                finish_message = openai_chat_chunk_message_template(
                    form_data["model"], ""
                )
                finish_message["choices"][0]["finish_reason"] = "stop"
                yield f"data: {json.dumps(finish_message)}\n\n"
                # Terminate the final event with a blank line, like every other SSE frame
                yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(stream_content(), media_type="text/event-stream")
|
||||
else:
|
||||
try:
|
||||
res = await execute_pipe(pipe, params)
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Error: {e}")
|
||||
return {"error": {"detail": str(e)}}
|
||||
|
||||
if isinstance(res, StreamingResponse) or isinstance(res, dict):
|
||||
return res
|
||||
if isinstance(res, BaseModel):
|
||||
return res.model_dump()
|
||||
|
||||
message = await get_message_content(res)
|
||||
return openai_chat_completion_message_template(form_data["model"], message)
|
||||
|
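`process_line` normalises each chunk a pipe yields into a server-sent-events frame: a `data:` line followed by a blank line. The sketch below reproduces that framing for dicts, bytes and strings, leaving out the Pydantic branch. `frame_chunk` and `chunk_template` are hypothetical names; the latter stands in for `openai_chat_chunk_message_template`, whose real output shape is not shown in this diff, so the fields used here are an assumption.

```python
import json


def chunk_template(model: str, content: str) -> dict:
    # Hypothetical stand-in; the real template's fields are an assumption
    return {"model": model, "choices": [{"delta": {"content": content}}]}


def frame_chunk(model: str, line) -> str:
    if isinstance(line, dict):
        line = f"data: {json.dumps(line)}"
    if isinstance(line, bytes):
        line = line.decode("utf-8")
    if line.startswith("data:"):
        # Already an SSE data line: only add the blank-line terminator
        return f"{line}\n\n"
    # Plain text: wrap it in a chunk object first
    return f"data: {json.dumps(chunk_template(model, line))}\n\n"


print(repr(frame_chunk("m", b"data: hi")))
print(repr(frame_chunk("m", {"a": 1})))
print(repr(frame_chunk("m", "hello")))
```

Every frame ends in `\n\n` because an SSE client only dispatches an event once it sees the blank line that closes it.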
|
@ -0,0 +1,116 @@
|
|||
import json
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Optional
|
||||
|
||||
from open_webui.internal.wrappers import register_connection
|
||||
from open_webui.env import (
|
||||
OPEN_WEBUI_DIR,
|
||||
DATABASE_URL,
|
||||
DATABASE_SCHEMA,
|
||||
SRC_LOG_LEVELS,
|
||||
DATABASE_POOL_MAX_OVERFLOW,
|
||||
DATABASE_POOL_RECYCLE,
|
||||
DATABASE_POOL_SIZE,
|
||||
DATABASE_POOL_TIMEOUT,
|
||||
)
|
||||
from peewee_migrate import Router
|
||||
from sqlalchemy import Dialect, create_engine, MetaData, types
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import scoped_session, sessionmaker
|
||||
from sqlalchemy.pool import QueuePool, NullPool
|
||||
from sqlalchemy.sql.type_api import _T
|
||||
from typing_extensions import Self
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["DB"])
|
||||
|
||||
|
||||
class JSONField(types.TypeDecorator):
|
||||
impl = types.Text
|
||||
cache_ok = True
|
||||
|
||||
def process_bind_param(self, value: Optional[_T], dialect: Dialect) -> Any:
|
||||
return json.dumps(value)
|
||||
|
||||
def process_result_value(self, value: Optional[_T], dialect: Dialect) -> Any:
|
||||
if value is not None:
|
||||
return json.loads(value)
|
||||
|
||||
def copy(self, **kw: Any) -> Self:
|
||||
return JSONField(self.impl.length)
|
||||
|
||||
def db_value(self, value):
|
||||
return json.dumps(value)
|
||||
|
||||
def python_value(self, value):
|
||||
if value is not None:
|
||||
return json.loads(value)
|
||||
|
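`JSONField` stores values as JSON text. One consequence is worth seeing concretely: `process_bind_param` serialises Python `None` to the string `"null"`, while `process_result_value` passes a database NULL through as `None`. A minimal sketch using plain functions; `to_db` and `from_db` are hypothetical names that mirror the two methods, with no SQLAlchemy involved:

```python
import json
from typing import Any, Optional


def to_db(value: Any) -> str:
    # Mirrors process_bind_param: always JSON-encode, even None
    return json.dumps(value)


def from_db(value: Optional[str]) -> Any:
    # Mirrors process_result_value: NULL columns come back as None
    if value is not None:
        return json.loads(value)
    return None


print(to_db({"a": [1, 2]}))   # {"a": [1, 2]}
print(to_db(None))            # null
print(from_db(to_db(None)))   # None
print(from_db(None))          # None
```

Both a stored `"null"` and a real NULL read back as `None`, so callers cannot tell the two apart after loading.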
||||
|
||||
# Workaround to handle the peewee migration
|
||||
# This is required to ensure the peewee migration is handled before the alembic migration
|
||||
def handle_peewee_migration(DATABASE_URL):
    # Start with no connection so the finally block is safe if connecting fails
    db = None
    try:
        # Replace the postgresql:// with postgres:// to handle the peewee migration
        db = register_connection(DATABASE_URL.replace("postgresql://", "postgres://"))
        migrate_dir = OPEN_WEBUI_DIR / "internal" / "migrations"
        router = Router(db, logger=log, migrate_dir=migrate_dir)
        router.run()
        db.close()

    except Exception as e:
        log.error(f"Failed to initialize the database connection: {e}")
        raise
    finally:
        # Properly closing the database connection
        if db and not db.is_closed():
            db.close()

        # Assert that the db connection, if one was opened, has been closed
        assert db is None or db.is_closed(), "Database connection is still open."
|
||||
|
||||
|
||||
handle_peewee_migration(DATABASE_URL)
|
||||
|
||||
|
||||
SQLALCHEMY_DATABASE_URL = DATABASE_URL
|
||||
if "sqlite" in SQLALCHEMY_DATABASE_URL:
|
||||
engine = create_engine(
|
||||
SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
|
||||
)
|
||||
else:
|
||||
if DATABASE_POOL_SIZE > 0:
|
||||
engine = create_engine(
|
||||
SQLALCHEMY_DATABASE_URL,
|
||||
pool_size=DATABASE_POOL_SIZE,
|
||||
max_overflow=DATABASE_POOL_MAX_OVERFLOW,
|
||||
pool_timeout=DATABASE_POOL_TIMEOUT,
|
||||
pool_recycle=DATABASE_POOL_RECYCLE,
|
||||
pool_pre_ping=True,
|
||||
poolclass=QueuePool,
|
||||
)
|
||||
else:
|
||||
engine = create_engine(
|
||||
SQLALCHEMY_DATABASE_URL, pool_pre_ping=True, poolclass=NullPool
|
||||
)
|
||||
|
||||
|
||||
SessionLocal = sessionmaker(
|
||||
autocommit=False, autoflush=False, bind=engine, expire_on_commit=False
|
||||
)
|
||||
metadata_obj = MetaData(schema=DATABASE_SCHEMA)
|
||||
Base = declarative_base(metadata=metadata_obj)
|
||||
Session = scoped_session(SessionLocal)
|
||||
|
||||
|
||||
def get_session():
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
get_db = contextmanager(get_session)
|
||||
|
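`get_db` is `get_session` wrapped with `contextlib.contextmanager`, so callers can use it in a `with` block and the session is closed even if the body raises. A minimal sketch of the same pattern; `FakeSession` is a hypothetical stand-in for `SessionLocal()` so the example runs without a database:

```python
from contextlib import contextmanager


class FakeSession:
    # Hypothetical stand-in for a SQLAlchemy session
    def __init__(self):
        self.closed = False

    def close(self):
        self.closed = True


def get_session():
    db = FakeSession()
    try:
        yield db
    finally:
        db.close()


get_db = contextmanager(get_session)

with get_db() as db:
    inside = db.closed   # still open while the block runs
after = db.closed        # closed once the block exits

failed = None
try:
    with get_db() as failed:
        raise RuntimeError("boom")
except RuntimeError:
    pass

print(inside, after, failed.closed)  # False True True
```

Keeping the plain generator (`get_session`) separate from the wrapped version lets the same function also serve where a bare generator is expected.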
|
@ -0,0 +1,254 @@
|
|||
"""Peewee migrations -- 001_initial_schema.py.
|
||||
|
||||
Some examples (model - class or model name)::
|
||||
|
||||
> Model = migrator.orm['table_name'] # Return model in current state by name
|
||||
> Model = migrator.ModelClass # Return model in current state by name
|
||||
|
||||
> migrator.sql(sql) # Run custom SQL
|
||||
> migrator.run(func, *args, **kwargs) # Run python function with the given args
|
||||
> migrator.create_model(Model) # Create a model (could be used as decorator)
|
||||
> migrator.remove_model(model, cascade=True) # Remove a model
|
||||
> migrator.add_fields(model, **fields) # Add fields to a model
|
||||
> migrator.change_fields(model, **fields) # Change fields
|
||||
> migrator.remove_fields(model, *field_names, cascade=True)
|
||||
> migrator.rename_field(model, old_field_name, new_field_name)
|
||||
> migrator.rename_table(model, new_table_name)
|
||||
> migrator.add_index(model, *col_names, unique=False)
|
||||
> migrator.add_not_null(model, *field_names)
|
||||
> migrator.add_default(model, field_name, default)
|
||||
> migrator.add_constraint(model, name, sql)
|
||||
> migrator.drop_index(model, *col_names)
|
||||
> migrator.drop_not_null(model, *field_names)
|
||||
> migrator.drop_constraints(model, *constraints)
|
||||
|
||||
"""
|
||||
|
||||
from contextlib import suppress
|
||||
|
||||
import peewee as pw
|
||||
from peewee_migrate import Migrator
|
||||
|
||||
|
||||
with suppress(ImportError):
|
||||
import playhouse.postgres_ext as pw_pext
|
||||
|
||||
|
||||
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
    """Write your migrations here."""

    # SQLite and other databases get different initial migrations.
    # SQLite enforces its schema very loosely, so replaying the SQLite migration
    # history on another database would require per-database SQL queries.
    # External DB support was added at a later date, so for external databases we
    # assume a newer base schema instead of migrating from the older one.
    if isinstance(database, pw.SqliteDatabase):
        migrate_sqlite(migrator, database, fake=fake)
    else:
        migrate_external(migrator, database, fake=fake)
|
||||
|
||||
|
||||
def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
@migrator.create_model
|
||||
class Auth(pw.Model):
|
||||
id = pw.CharField(max_length=255, unique=True)
|
||||
email = pw.CharField(max_length=255)
|
||||
password = pw.CharField(max_length=255)
|
||||
active = pw.BooleanField()
|
||||
|
||||
class Meta:
|
||||
table_name = "auth"
|
||||
|
||||
@migrator.create_model
|
||||
class Chat(pw.Model):
|
||||
id = pw.CharField(max_length=255, unique=True)
|
||||
user_id = pw.CharField(max_length=255)
|
||||
title = pw.CharField()
|
||||
chat = pw.TextField()
|
||||
timestamp = pw.BigIntegerField()
|
||||
|
||||
class Meta:
|
||||
table_name = "chat"
|
||||
|
||||
@migrator.create_model
|
||||
class ChatIdTag(pw.Model):
|
||||
id = pw.CharField(max_length=255, unique=True)
|
||||
tag_name = pw.CharField(max_length=255)
|
||||
chat_id = pw.CharField(max_length=255)
|
||||
user_id = pw.CharField(max_length=255)
|
||||
timestamp = pw.BigIntegerField()
|
||||
|
||||
class Meta:
|
||||
table_name = "chatidtag"
|
||||
|
||||
@migrator.create_model
|
||||
class Document(pw.Model):
|
||||
id = pw.AutoField()
|
||||
collection_name = pw.CharField(max_length=255, unique=True)
|
||||
name = pw.CharField(max_length=255, unique=True)
|
||||
title = pw.CharField()
|
||||
filename = pw.CharField()
|
||||
content = pw.TextField(null=True)
|
||||
user_id = pw.CharField(max_length=255)
|
||||
timestamp = pw.BigIntegerField()
|
||||
|
||||
class Meta:
|
||||
table_name = "document"
|
||||
|
||||
@migrator.create_model
|
||||
class Modelfile(pw.Model):
|
||||
id = pw.AutoField()
|
||||
tag_name = pw.CharField(max_length=255, unique=True)
|
||||
user_id = pw.CharField(max_length=255)
|
||||
modelfile = pw.TextField()
|
||||
timestamp = pw.BigIntegerField()
|
||||
|
||||
class Meta:
|
||||
table_name = "modelfile"
|
||||
|
||||
@migrator.create_model
|
||||
class Prompt(pw.Model):
|
||||
id = pw.AutoField()
|
||||
command = pw.CharField(max_length=255, unique=True)
|
||||
user_id = pw.CharField(max_length=255)
|
||||
title = pw.CharField()
|
||||
content = pw.TextField()
|
||||
timestamp = pw.BigIntegerField()
|
||||
|
||||
class Meta:
|
||||
table_name = "prompt"
|
||||
|
||||
@migrator.create_model
|
||||
class Tag(pw.Model):
|
||||
id = pw.CharField(max_length=255, unique=True)
|
||||
name = pw.CharField(max_length=255)
|
||||
user_id = pw.CharField(max_length=255)
|
||||
data = pw.TextField(null=True)
|
||||
|
||||
class Meta:
|
||||
table_name = "tag"
|
||||
|
||||
@migrator.create_model
|
||||
class User(pw.Model):
|
||||
id = pw.CharField(max_length=255, unique=True)
|
||||
name = pw.CharField(max_length=255)
|
||||
email = pw.CharField(max_length=255)
|
||||
role = pw.CharField(max_length=255)
|
||||
profile_image_url = pw.CharField(max_length=255)
|
||||
timestamp = pw.BigIntegerField()
|
||||
|
||||
class Meta:
|
||||
table_name = "user"
|
||||
|
||||
|
||||
def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
@migrator.create_model
|
||||
class Auth(pw.Model):
|
||||
id = pw.CharField(max_length=255, unique=True)
|
||||
email = pw.CharField(max_length=255)
|
||||
password = pw.TextField()
|
||||
active = pw.BooleanField()
|
||||
|
||||
class Meta:
|
||||
table_name = "auth"
|
||||
|
||||
@migrator.create_model
|
||||
class Chat(pw.Model):
|
||||
id = pw.CharField(max_length=255, unique=True)
|
||||
user_id = pw.CharField(max_length=255)
|
||||
title = pw.TextField()
|
||||
chat = pw.TextField()
|
||||
timestamp = pw.BigIntegerField()
|
||||
|
||||
class Meta:
|
||||
table_name = "chat"
|
||||
|
||||
@migrator.create_model
|
||||
class ChatIdTag(pw.Model):
|
||||
id = pw.CharField(max_length=255, unique=True)
|
||||
tag_name = pw.CharField(max_length=255)
|
||||
chat_id = pw.CharField(max_length=255)
|
||||
user_id = pw.CharField(max_length=255)
|
||||
timestamp = pw.BigIntegerField()
|
||||
|
||||
class Meta:
|
||||
table_name = "chatidtag"
|
||||
|
||||
@migrator.create_model
|
||||
class Document(pw.Model):
|
||||
id = pw.AutoField()
|
||||
collection_name = pw.CharField(max_length=255, unique=True)
|
||||
name = pw.CharField(max_length=255, unique=True)
|
||||
title = pw.TextField()
|
||||
filename = pw.TextField()
|
||||
content = pw.TextField(null=True)
|
||||
user_id = pw.CharField(max_length=255)
|
||||
timestamp = pw.BigIntegerField()
|
||||
|
||||
class Meta:
|
||||
table_name = "document"
|
||||
|
||||
@migrator.create_model
|
||||
class Modelfile(pw.Model):
|
||||
id = pw.AutoField()
|
||||
tag_name = pw.CharField(max_length=255, unique=True)
|
||||
user_id = pw.CharField(max_length=255)
|
||||
modelfile = pw.TextField()
|
||||
timestamp = pw.BigIntegerField()
|
||||
|
||||
class Meta:
|
||||
table_name = "modelfile"
|
||||
|
||||
@migrator.create_model
|
||||
class Prompt(pw.Model):
|
||||
id = pw.AutoField()
|
||||
command = pw.CharField(max_length=255, unique=True)
|
||||
user_id = pw.CharField(max_length=255)
|
||||
title = pw.TextField()
|
||||
content = pw.TextField()
|
||||
timestamp = pw.BigIntegerField()
|
||||
|
||||
class Meta:
|
||||
table_name = "prompt"
|
||||
|
||||
@migrator.create_model
|
||||
class Tag(pw.Model):
|
||||
id = pw.CharField(max_length=255, unique=True)
|
||||
name = pw.CharField(max_length=255)
|
||||
user_id = pw.CharField(max_length=255)
|
||||
data = pw.TextField(null=True)
|
||||
|
||||
class Meta:
|
||||
table_name = "tag"
|
||||
|
||||
@migrator.create_model
|
||||
class User(pw.Model):
|
||||
id = pw.CharField(max_length=255, unique=True)
|
||||
name = pw.CharField(max_length=255)
|
||||
email = pw.CharField(max_length=255)
|
||||
role = pw.CharField(max_length=255)
|
||||
profile_image_url = pw.TextField()
|
||||
timestamp = pw.BigIntegerField()
|
||||
|
||||
class Meta:
|
||||
table_name = "user"
|
||||
|
||||
|
||||
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your rollback migrations here."""
|
||||
|
||||
migrator.remove_model("user")
|
||||
|
||||
migrator.remove_model("tag")
|
||||
|
||||
migrator.remove_model("prompt")
|
||||
|
||||
migrator.remove_model("modelfile")
|
||||
|
||||
migrator.remove_model("document")
|
||||
|
||||
migrator.remove_model("chatidtag")
|
||||
|
||||
migrator.remove_model("chat")
|
||||
|
||||
migrator.remove_model("auth")
|
||||
|
|
@ -0,0 +1,48 @@
|
|||
"""Peewee migrations -- 002_add_local_sharing.py.
|
||||
|
||||
Some examples (model - class or model name)::
|
||||
|
||||
> Model = migrator.orm['table_name'] # Return model in current state by name
|
||||
> Model = migrator.ModelClass # Return model in current state by name
|
||||
|
||||
> migrator.sql(sql) # Run custom SQL
|
||||
> migrator.run(func, *args, **kwargs) # Run python function with the given args
|
||||
> migrator.create_model(Model) # Create a model (could be used as decorator)
|
||||
> migrator.remove_model(model, cascade=True) # Remove a model
|
||||
> migrator.add_fields(model, **fields) # Add fields to a model
|
||||
> migrator.change_fields(model, **fields) # Change fields
|
||||
> migrator.remove_fields(model, *field_names, cascade=True)
|
||||
> migrator.rename_field(model, old_field_name, new_field_name)
|
||||
> migrator.rename_table(model, new_table_name)
|
||||
> migrator.add_index(model, *col_names, unique=False)
|
||||
> migrator.add_not_null(model, *field_names)
|
||||
> migrator.add_default(model, field_name, default)
|
||||
> migrator.add_constraint(model, name, sql)
|
||||
> migrator.drop_index(model, *col_names)
|
||||
> migrator.drop_not_null(model, *field_names)
|
||||
> migrator.drop_constraints(model, *constraints)
|
||||
|
||||
"""
|
||||
|
||||
from contextlib import suppress
|
||||
|
||||
import peewee as pw
|
||||
from peewee_migrate import Migrator
|
||||
|
||||
|
||||
with suppress(ImportError):
|
||||
import playhouse.postgres_ext as pw_pext
|
||||
|
||||
|
||||
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your migrations here."""
|
||||
|
||||
migrator.add_fields(
|
||||
"chat", share_id=pw.CharField(max_length=255, null=True, unique=True)
|
||||
)
|
||||
|
||||
|
||||
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your rollback migrations here."""
|
||||
|
||||
migrator.remove_fields("chat", "share_id")
|
||||
|
|
@ -0,0 +1,48 @@
|
|||
"""Peewee migrations -- 002_add_local_sharing.py.
|
||||
|
||||
Some examples (model - class or model name)::
|
||||
|
||||
> Model = migrator.orm['table_name'] # Return model in current state by name
|
||||
> Model = migrator.ModelClass # Return model in current state by name
|
||||
|
||||
> migrator.sql(sql) # Run custom SQL
|
||||
> migrator.run(func, *args, **kwargs) # Run python function with the given args
|
||||
> migrator.create_model(Model) # Create a model (could be used as decorator)
|
||||
> migrator.remove_model(model, cascade=True) # Remove a model
|
||||
> migrator.add_fields(model, **fields) # Add fields to a model
|
||||
> migrator.change_fields(model, **fields) # Change fields
|
||||
> migrator.remove_fields(model, *field_names, cascade=True)
|
||||
> migrator.rename_field(model, old_field_name, new_field_name)
|
||||
> migrator.rename_table(model, new_table_name)
|
||||
> migrator.add_index(model, *col_names, unique=False)
|
||||
> migrator.add_not_null(model, *field_names)
|
||||
> migrator.add_default(model, field_name, default)
|
||||
> migrator.add_constraint(model, name, sql)
|
||||
> migrator.drop_index(model, *col_names)
|
||||
> migrator.drop_not_null(model, *field_names)
|
||||
> migrator.drop_constraints(model, *constraints)
|
||||
|
||||
"""
|
||||
|
||||
from contextlib import suppress
|
||||
|
||||
import peewee as pw
|
||||
from peewee_migrate import Migrator
|
||||
|
||||
|
||||
with suppress(ImportError):
|
||||
import playhouse.postgres_ext as pw_pext
|
||||
|
||||
|
||||
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your migrations here."""
|
||||
|
||||
migrator.add_fields(
|
||||
"user", api_key=pw.CharField(max_length=255, null=True, unique=True)
|
||||
)
|
||||
|
||||
|
||||
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your rollback migrations here."""
|
||||
|
||||
migrator.remove_fields("user", "api_key")
|
||||
|
|
@ -0,0 +1,46 @@
|
|||
"""Peewee migrations -- 002_add_local_sharing.py.
|
||||
|
||||
Some examples (model - class or model name)::
|
||||
|
||||
> Model = migrator.orm['table_name'] # Return model in current state by name
|
||||
> Model = migrator.ModelClass # Return model in current state by name
|
||||
|
||||
> migrator.sql(sql) # Run custom SQL
|
||||
> migrator.run(func, *args, **kwargs) # Run python function with the given args
|
||||
> migrator.create_model(Model) # Create a model (could be used as decorator)
|
||||
> migrator.remove_model(model, cascade=True) # Remove a model
|
||||
> migrator.add_fields(model, **fields) # Add fields to a model
|
||||
> migrator.change_fields(model, **fields) # Change fields
|
||||
> migrator.remove_fields(model, *field_names, cascade=True)
|
||||
> migrator.rename_field(model, old_field_name, new_field_name)
|
||||
> migrator.rename_table(model, new_table_name)
|
||||
> migrator.add_index(model, *col_names, unique=False)
|
||||
> migrator.add_not_null(model, *field_names)
|
||||
> migrator.add_default(model, field_name, default)
|
||||
> migrator.add_constraint(model, name, sql)
|
||||
> migrator.drop_index(model, *col_names)
|
||||
> migrator.drop_not_null(model, *field_names)
|
||||
> migrator.drop_constraints(model, *constraints)
|
||||
|
||||
"""
|
||||
|
||||
from contextlib import suppress
|
||||
|
||||
import peewee as pw
|
||||
from peewee_migrate import Migrator
|
||||
|
||||
|
||||
with suppress(ImportError):
|
||||
import playhouse.postgres_ext as pw_pext
|
||||
|
||||
|
||||
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your migrations here."""
|
||||
|
||||
migrator.add_fields("chat", archived=pw.BooleanField(default=False))
|
||||
|
||||
|
||||
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your rollback migrations here."""
|
||||
|
||||
migrator.remove_fields("chat", "archived")
|
||||
|
|
@ -0,0 +1,130 @@
|
|||
"""Peewee migrations -- 002_add_local_sharing.py.
|
||||
|
||||
Some examples (model - class or model name)::
|
||||
|
||||
> Model = migrator.orm['table_name'] # Return model in current state by name
|
||||
> Model = migrator.ModelClass # Return model in current state by name
|
||||
|
||||
> migrator.sql(sql) # Run custom SQL
|
||||
> migrator.run(func, *args, **kwargs) # Run python function with the given args
|
||||
> migrator.create_model(Model) # Create a model (could be used as decorator)
|
||||
> migrator.remove_model(model, cascade=True) # Remove a model
|
||||
> migrator.add_fields(model, **fields) # Add fields to a model
|
||||
> migrator.change_fields(model, **fields) # Change fields
|
||||
> migrator.remove_fields(model, *field_names, cascade=True)
|
||||
> migrator.rename_field(model, old_field_name, new_field_name)
|
||||
> migrator.rename_table(model, new_table_name)
|
||||
> migrator.add_index(model, *col_names, unique=False)
|
||||
> migrator.add_not_null(model, *field_names)
|
||||
> migrator.add_default(model, field_name, default)
|
||||
> migrator.add_constraint(model, name, sql)
|
||||
> migrator.drop_index(model, *col_names)
|
||||
> migrator.drop_not_null(model, *field_names)
|
||||
> migrator.drop_constraints(model, *constraints)
|
||||
|
||||
"""
|
||||
|
||||
from contextlib import suppress
|
||||
|
||||
import peewee as pw
|
||||
from peewee_migrate import Migrator
|
||||
|
||||
|
||||
with suppress(ImportError):
|
||||
import playhouse.postgres_ext as pw_pext
|
||||
|
||||
|
||||
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your migrations here."""
|
||||
|
||||
if isinstance(database, pw.SqliteDatabase):
|
||||
migrate_sqlite(migrator, database, fake=fake)
|
||||
else:
|
||||
migrate_external(migrator, database, fake=fake)
|
||||
|
||||
|
||||
def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
# Adding fields created_at and updated_at to the 'chat' table
|
||||
migrator.add_fields(
|
||||
"chat",
|
||||
created_at=pw.DateTimeField(null=True), # Allow null for transition
|
||||
updated_at=pw.DateTimeField(null=True), # Allow null for transition
|
||||
)
|
||||
|
||||
# Populate the new fields from an existing 'timestamp' field
|
||||
migrator.sql(
|
||||
"UPDATE chat SET created_at = timestamp, updated_at = timestamp WHERE timestamp IS NOT NULL"
|
||||
)
|
||||
|
||||
# Now that the data has been copied, remove the original 'timestamp' field
|
||||
migrator.remove_fields("chat", "timestamp")
|
||||
|
||||
# Update the fields to be not null now that they are populated
|
||||
migrator.change_fields(
|
||||
"chat",
|
||||
created_at=pw.DateTimeField(null=False),
|
||||
updated_at=pw.DateTimeField(null=False),
|
||||
)
|
||||
|
||||
|
||||
def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
# Adding fields created_at and updated_at to the 'chat' table
|
||||
migrator.add_fields(
|
||||
"chat",
|
||||
created_at=pw.BigIntegerField(null=True), # Allow null for transition
|
||||
updated_at=pw.BigIntegerField(null=True), # Allow null for transition
|
||||
)
|
||||
|
||||
# Populate the new fields from an existing 'timestamp' field
|
||||
migrator.sql(
|
||||
"UPDATE chat SET created_at = timestamp, updated_at = timestamp WHERE timestamp IS NOT NULL"
|
||||
)
|
||||
|
||||
# Now that the data has been copied, remove the original 'timestamp' field
|
||||
migrator.remove_fields("chat", "timestamp")
|
||||
|
||||
# Update the fields to be not null now that they are populated
|
||||
migrator.change_fields(
|
||||
"chat",
|
||||
created_at=pw.BigIntegerField(null=False),
|
||||
updated_at=pw.BigIntegerField(null=False),
|
||||
)
|
||||
|
||||
|
||||
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your rollback migrations here."""
|
||||
|
||||
if isinstance(database, pw.SqliteDatabase):
|
||||
rollback_sqlite(migrator, database, fake=fake)
|
||||
else:
|
||||
rollback_external(migrator, database, fake=fake)
|
||||
|
||||
|
||||
def rollback_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
    # Recreate the timestamp field, initially allowing nulls for a safe transition
    migrator.add_fields("chat", timestamp=pw.DateTimeField(null=True))

    # Copy each row's created_at value back into the new timestamp field
    # This assumes created_at was originally a copy of timestamp
    migrator.sql("UPDATE chat SET timestamp = created_at")

    # Remove the created_at and updated_at fields
    migrator.remove_fields("chat", "created_at", "updated_at")

    # Finally, alter the timestamp field to not allow nulls if that was the original setting
    migrator.change_fields("chat", timestamp=pw.DateTimeField(null=False))
|
||||
|
||||
|
||||
def rollback_external(migrator: Migrator, database: pw.Database, *, fake=False):
    # Recreate the timestamp field, initially allowing nulls for a safe transition
    migrator.add_fields("chat", timestamp=pw.BigIntegerField(null=True))

    # Copy each row's created_at value back into the new timestamp field
    # This assumes created_at was originally a copy of timestamp
    migrator.sql("UPDATE chat SET timestamp = created_at")

    # Remove the created_at and updated_at fields
    migrator.remove_fields("chat", "created_at", "updated_at")

    # Finally, alter the timestamp field to not allow nulls if that was the original setting
    migrator.change_fields("chat", timestamp=pw.BigIntegerField(null=False))
|
||||
|
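The `chat` timestamp migration above follows a common pattern: add the new columns as nullable, backfill them from the old column, then tighten the schema. The sketch below shows only the add-and-backfill steps, using the standard-library `sqlite3` module on an in-memory table. The table contents are made up, and the last two steps (dropping `timestamp` and enforcing NOT NULL, which the migration does through `migrator.remove_fields` and `migrator.change_fields`) are left out.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE chat (id TEXT PRIMARY KEY, timestamp INTEGER)")
conn.executemany(
    "INSERT INTO chat (id, timestamp) VALUES (?, ?)",
    [("a", 1700000000), ("b", 1700000500)],
)

# Step 1: add the new columns as nullable so existing rows stay valid
conn.execute("ALTER TABLE chat ADD COLUMN created_at INTEGER")
conn.execute("ALTER TABLE chat ADD COLUMN updated_at INTEGER")

# Step 2: backfill both columns from the old timestamp column
conn.execute(
    "UPDATE chat SET created_at = timestamp, updated_at = timestamp "
    "WHERE timestamp IS NOT NULL"
)

rows = conn.execute(
    "SELECT id, created_at, updated_at FROM chat ORDER BY id"
).fetchall()
missing = conn.execute(
    "SELECT COUNT(*) FROM chat WHERE created_at IS NULL OR updated_at IS NULL"
).fetchone()[0]
print(rows, missing)
conn.close()
```

Checking that no row is left with a NULL before tightening the constraint is the step that makes the later `null=False` change safe.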
|
@ -0,0 +1,130 @@
|
|||
"""Peewee migrations -- 006_migrate_timestamps_and_charfields.py.
|
||||
|
||||
Some examples (model - class or model name)::
|
||||
|
||||
> Model = migrator.orm['table_name'] # Return model in current state by name
|
||||
    > Model = migrator.ModelClass # Return model in current state by name

    > migrator.sql(sql) # Run custom SQL
    > migrator.run(func, *args, **kwargs) # Run python function with the given args
    > migrator.create_model(Model) # Create a model (could be used as decorator)
    > migrator.remove_model(model, cascade=True) # Remove a model
    > migrator.add_fields(model, **fields) # Add fields to a model
    > migrator.change_fields(model, **fields) # Change fields
    > migrator.remove_fields(model, *field_names, cascade=True)
    > migrator.rename_field(model, old_field_name, new_field_name)
    > migrator.rename_table(model, new_table_name)
    > migrator.add_index(model, *col_names, unique=False)
    > migrator.add_not_null(model, *field_names)
    > migrator.add_default(model, field_name, default)
    > migrator.add_constraint(model, name, sql)
    > migrator.drop_index(model, *col_names)
    > migrator.drop_not_null(model, *field_names)
    > migrator.drop_constraints(model, *constraints)

"""

from contextlib import suppress

import peewee as pw
from peewee_migrate import Migrator


with suppress(ImportError):
    import playhouse.postgres_ext as pw_pext


def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
    """Write your migrations here."""

    # Alter the tables with timestamps
    migrator.change_fields(
        "chatidtag",
        timestamp=pw.BigIntegerField(),
    )
    migrator.change_fields(
        "document",
        timestamp=pw.BigIntegerField(),
    )
    migrator.change_fields(
        "modelfile",
        timestamp=pw.BigIntegerField(),
    )
    migrator.change_fields(
        "prompt",
        timestamp=pw.BigIntegerField(),
    )
    migrator.change_fields(
        "user",
        timestamp=pw.BigIntegerField(),
    )
    # Alter the tables with varchar to text where necessary
    migrator.change_fields(
        "auth",
        password=pw.TextField(),
    )
    migrator.change_fields(
        "chat",
        title=pw.TextField(),
    )
    migrator.change_fields(
        "document",
        title=pw.TextField(),
        filename=pw.TextField(),
    )
    migrator.change_fields(
        "prompt",
        title=pw.TextField(),
    )
    migrator.change_fields(
        "user",
        profile_image_url=pw.TextField(),
    )


def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
    """Write your rollback migrations here."""

    if isinstance(database, pw.SqliteDatabase):
        # Alter the tables with timestamps
        migrator.change_fields(
            "chatidtag",
            timestamp=pw.DateField(),
        )
        migrator.change_fields(
            "document",
            timestamp=pw.DateField(),
        )
        migrator.change_fields(
            "modelfile",
            timestamp=pw.DateField(),
        )
        migrator.change_fields(
            "prompt",
            timestamp=pw.DateField(),
        )
        migrator.change_fields(
            "user",
            timestamp=pw.DateField(),
        )
        migrator.change_fields(
            "auth",
            password=pw.CharField(max_length=255),
        )
        migrator.change_fields(
            "chat",
            title=pw.CharField(),
        )
        migrator.change_fields(
            "document",
            title=pw.CharField(),
            filename=pw.CharField(),
        )
        migrator.change_fields(
            "prompt",
            title=pw.CharField(),
        )
        migrator.change_fields(
            "user",
            profile_image_url=pw.CharField(),
        )
@ -0,0 +1,79 @@
"""Peewee migrations -- 002_add_local_sharing.py.

Some examples (model - class or model name)::

    > Model = migrator.orm['table_name'] # Return model in current state by name
    > Model = migrator.ModelClass # Return model in current state by name

    > migrator.sql(sql) # Run custom SQL
    > migrator.run(func, *args, **kwargs) # Run python function with the given args
    > migrator.create_model(Model) # Create a model (could be used as decorator)
    > migrator.remove_model(model, cascade=True) # Remove a model
    > migrator.add_fields(model, **fields) # Add fields to a model
    > migrator.change_fields(model, **fields) # Change fields
    > migrator.remove_fields(model, *field_names, cascade=True)
    > migrator.rename_field(model, old_field_name, new_field_name)
    > migrator.rename_table(model, new_table_name)
    > migrator.add_index(model, *col_names, unique=False)
    > migrator.add_not_null(model, *field_names)
    > migrator.add_default(model, field_name, default)
    > migrator.add_constraint(model, name, sql)
    > migrator.drop_index(model, *col_names)
    > migrator.drop_not_null(model, *field_names)
    > migrator.drop_constraints(model, *constraints)

"""

from contextlib import suppress

import peewee as pw
from peewee_migrate import Migrator


with suppress(ImportError):
    import playhouse.postgres_ext as pw_pext


def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
    """Write your migrations here."""

    # Add created_at, updated_at, and last_active_at to the 'user' table
    migrator.add_fields(
        "user",
        created_at=pw.BigIntegerField(null=True),  # Allow null for transition
        updated_at=pw.BigIntegerField(null=True),  # Allow null for transition
        last_active_at=pw.BigIntegerField(null=True),  # Allow null for transition
    )

    # Populate the new fields from the existing 'timestamp' field
    migrator.sql(
        'UPDATE "user" SET created_at = timestamp, updated_at = timestamp, last_active_at = timestamp WHERE timestamp IS NOT NULL'
    )

    # Now that the data has been copied, remove the original 'timestamp' field
    migrator.remove_fields("user", "timestamp")

    # Update the fields to be not null now that they are populated
    migrator.change_fields(
        "user",
        created_at=pw.BigIntegerField(null=False),
        updated_at=pw.BigIntegerField(null=False),
        last_active_at=pw.BigIntegerField(null=False),
    )


def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
    """Write your rollback migrations here."""

    # Recreate the timestamp field, initially allowing null values for a safe transition
    migrator.add_fields("user", timestamp=pw.BigIntegerField(null=True))

    # Copy created_at back into the recreated timestamp field
    # This assumes created_at was originally a copy of timestamp
    migrator.sql('UPDATE "user" SET timestamp = created_at')

    # Remove the created_at, updated_at, and last_active_at fields
    migrator.remove_fields("user", "created_at", "updated_at", "last_active_at")

    # Finally, alter the timestamp field to not allow nulls if that was the original setting
    migrator.change_fields("user", timestamp=pw.BigIntegerField(null=False))
|
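The add, backfill, then tighten sequence used in the migration above can be tried outside peewee. The following is a minimal sketch using only the standard-library `sqlite3` module, not the project's migration code: the in-memory table, the two sample rows, and the final NULL check are assumptions made for illustration, while the column names and the `UPDATE` statement mirror the migration.

```python
import sqlite3

# In-memory stand-in for the real "user" table (sample rows are made up).
conn = sqlite3.connect(":memory:")
conn.execute('CREATE TABLE "user" (id TEXT PRIMARY KEY, timestamp BIGINT)')
conn.executemany(
    'INSERT INTO "user" (id, timestamp) VALUES (?, ?)',
    [("a", 1700000000), ("b", 1700000500)],
)

# Step 1: add the new columns as nullable so existing rows remain valid.
for column in ("created_at", "updated_at", "last_active_at"):
    conn.execute(f'ALTER TABLE "user" ADD COLUMN {column} BIGINT')

# Step 2: backfill the new columns from the old timestamp column.
conn.execute(
    'UPDATE "user" SET created_at = timestamp, updated_at = timestamp, '
    "last_active_at = timestamp WHERE timestamp IS NOT NULL"
)

# Step 3: before tightening to NOT NULL, confirm that no row was left empty.
missing = conn.execute(
    'SELECT COUNT(*) FROM "user" WHERE created_at IS NULL '
    "OR updated_at IS NULL OR last_active_at IS NULL"
).fetchone()[0]

rows = conn.execute(
    'SELECT id, created_at, updated_at, last_active_at FROM "user" ORDER BY id'
).fetchall()
print(missing, rows)
```

Note that rows whose `timestamp` is NULL are skipped by the `WHERE` clause, so a check like the one in step 3 is what tells you whether the later NOT NULL change can succeed.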
@ -0,0 +1,53 @@
"""Peewee migrations -- 002_add_local_sharing.py.

Some examples (model - class or model name)::

    > Model = migrator.orm['table_name'] # Return model in current state by name
    > Model = migrator.ModelClass # Return model in current state by name

    > migrator.sql(sql) # Run custom SQL
    > migrator.run(func, *args, **kwargs) # Run python function with the given args
    > migrator.create_model(Model) # Create a model (could be used as decorator)
    > migrator.remove_model(model, cascade=True) # Remove a model
    > migrator.add_fields(model, **fields) # Add fields to a model
    > migrator.change_fields(model, **fields) # Change fields
    > migrator.remove_fields(model, *field_names, cascade=True)
    > migrator.rename_field(model, old_field_name, new_field_name)
    > migrator.rename_table(model, new_table_name)
    > migrator.add_index(model, *col_names, unique=False)
    > migrator.add_not_null(model, *field_names)
    > migrator.add_default(model, field_name, default)
    > migrator.add_constraint(model, name, sql)
    > migrator.drop_index(model, *col_names)
    > migrator.drop_not_null(model, *field_names)
    > migrator.drop_constraints(model, *constraints)

"""

from contextlib import suppress

import peewee as pw
from peewee_migrate import Migrator


with suppress(ImportError):
    import playhouse.postgres_ext as pw_pext


def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
    @migrator.create_model
    class Memory(pw.Model):
        id = pw.CharField(max_length=255, unique=True)
        user_id = pw.CharField(max_length=255)
        content = pw.TextField(null=False)
        updated_at = pw.BigIntegerField(null=False)
        created_at = pw.BigIntegerField(null=False)

        class Meta:
            table_name = "memory"


def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
    """Write your rollback migrations here."""

    migrator.remove_model("memory")
@ -0,0 +1,61 @@
"""Peewee migrations -- 009_add_models.py.

Some examples (model - class or model name)::

    > Model = migrator.orm['table_name'] # Return model in current state by name
    > Model = migrator.ModelClass # Return model in current state by name

    > migrator.sql(sql) # Run custom SQL
    > migrator.run(func, *args, **kwargs) # Run python function with the given args
    > migrator.create_model(Model) # Create a model (could be used as decorator)
    > migrator.remove_model(model, cascade=True) # Remove a model
    > migrator.add_fields(model, **fields) # Add fields to a model
    > migrator.change_fields(model, **fields) # Change fields
    > migrator.remove_fields(model, *field_names, cascade=True)
    > migrator.rename_field(model, old_field_name, new_field_name)
    > migrator.rename_table(model, new_table_name)
    > migrator.add_index(model, *col_names, unique=False)
    > migrator.add_not_null(model, *field_names)
    > migrator.add_default(model, field_name, default)
    > migrator.add_constraint(model, name, sql)
    > migrator.drop_index(model, *col_names)
    > migrator.drop_not_null(model, *field_names)
    > migrator.drop_constraints(model, *constraints)

"""

from contextlib import suppress

import peewee as pw
from peewee_migrate import Migrator


with suppress(ImportError):
    import playhouse.postgres_ext as pw_pext


def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
    """Write your migrations here."""

    @migrator.create_model
    class Model(pw.Model):
        id = pw.TextField(unique=True)
        user_id = pw.TextField()
        base_model_id = pw.TextField(null=True)

        name = pw.TextField()

        meta = pw.TextField()
        params = pw.TextField()

        created_at = pw.BigIntegerField(null=False)
        updated_at = pw.BigIntegerField(null=False)

        class Meta:
            table_name = "model"


def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
    """Write your rollback migrations here."""

    migrator.remove_model("model")
@ -0,0 +1,130 @@
"""Peewee migrations -- 009_add_models.py.

Some examples (model - class or model name)::

    > Model = migrator.orm['table_name'] # Return model in current state by name
    > Model = migrator.ModelClass # Return model in current state by name

    > migrator.sql(sql) # Run custom SQL
    > migrator.run(func, *args, **kwargs) # Run python function with the given args
    > migrator.create_model(Model) # Create a model (could be used as decorator)
    > migrator.remove_model(model, cascade=True) # Remove a model
    > migrator.add_fields(model, **fields) # Add fields to a model
    > migrator.change_fields(model, **fields) # Change fields
    > migrator.remove_fields(model, *field_names, cascade=True)
    > migrator.rename_field(model, old_field_name, new_field_name)
    > migrator.rename_table(model, new_table_name)
    > migrator.add_index(model, *col_names, unique=False)
    > migrator.add_not_null(model, *field_names)
    > migrator.add_default(model, field_name, default)
    > migrator.add_constraint(model, name, sql)
    > migrator.drop_index(model, *col_names)
    > migrator.drop_not_null(model, *field_names)
    > migrator.drop_constraints(model, *constraints)

"""

from contextlib import suppress

import peewee as pw
from peewee_migrate import Migrator
import json

from open_webui.utils.misc import parse_ollama_modelfile

with suppress(ImportError):
    import playhouse.postgres_ext as pw_pext


def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
    """Write your migrations here."""

    # Fetch data from 'modelfile' table and insert into 'model' table
    migrate_modelfile_to_model(migrator, database)
    # Drop the 'modelfile' table
    migrator.remove_model("modelfile")


def migrate_modelfile_to_model(migrator: Migrator, database: pw.Database):
    ModelFile = migrator.orm["modelfile"]
    Model = migrator.orm["model"]

    modelfiles = ModelFile.select()

    for modelfile in modelfiles:
        # Extract and transform data in Python

        modelfile.modelfile = json.loads(modelfile.modelfile)
        meta = json.dumps(
            {
                "description": modelfile.modelfile.get("desc"),
                "profile_image_url": modelfile.modelfile.get("imageUrl"),
                "ollama": {"modelfile": modelfile.modelfile.get("content")},
                "suggestion_prompts": modelfile.modelfile.get("suggestionPrompts"),
                "categories": modelfile.modelfile.get("categories"),
                "user": {**modelfile.modelfile.get("user", {}), "community": True},
            }
        )

        info = parse_ollama_modelfile(modelfile.modelfile.get("content"))

        # Insert the processed data into the 'model' table
        Model.create(
            id=f"ollama-{modelfile.tag_name}",
            user_id=modelfile.user_id,
            base_model_id=info.get("base_model_id"),
            name=modelfile.modelfile.get("title"),
            meta=meta,
            params=json.dumps(info.get("params", {})),
            created_at=modelfile.timestamp,
            updated_at=modelfile.timestamp,
        )


def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
    """Write your rollback migrations here."""

    recreate_modelfile_table(migrator, database)
    move_data_back_to_modelfile(migrator, database)
    migrator.remove_model("model")


def recreate_modelfile_table(migrator: Migrator, database: pw.Database):
    query = """
    CREATE TABLE IF NOT EXISTS modelfile (
        user_id TEXT,
        tag_name TEXT,
        modelfile JSON,
        timestamp BIGINT
    )
    """
    migrator.sql(query)


def move_data_back_to_modelfile(migrator: Migrator, database: pw.Database):
    Model = migrator.orm["model"]
    Modelfile = migrator.orm["modelfile"]

    models = Model.select()

    for model in models:
        # Extract and transform data in Python
        meta = json.loads(model.meta)

        modelfile_data = {
            "title": model.name,
            "desc": meta.get("description"),
            "imageUrl": meta.get("profile_image_url"),
            "content": meta.get("ollama", {}).get("modelfile"),
            "suggestionPrompts": meta.get("suggestion_prompts"),
            "categories": meta.get("categories"),
            "user": {k: v for k, v in meta.get("user", {}).items() if k != "community"},
        }

        # Insert the processed data back into the 'modelfile' table
        Modelfile.create(
            user_id=model.user_id,
            tag_name=model.id,
            modelfile=modelfile_data,
            timestamp=model.created_at,
        )
|
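The forward and rollback key mappings in the migration above are meant to be inverses of each other. The sketch below isolates just that mapping as two pure functions so the round trip can be checked without a database. The function names `modelfile_to_meta` and `meta_to_modelfile` and the sample record are hypothetical and exist only for illustration; the key names are taken from the migration.

```python
import json


def modelfile_to_meta(modelfile: dict) -> dict:
    """Forward mapping: legacy modelfile keys to the 'meta' payload."""
    return {
        "description": modelfile.get("desc"),
        "profile_image_url": modelfile.get("imageUrl"),
        "ollama": {"modelfile": modelfile.get("content")},
        "suggestion_prompts": modelfile.get("suggestionPrompts"),
        "categories": modelfile.get("categories"),
        "user": {**modelfile.get("user", {}), "community": True},
    }


def meta_to_modelfile(name: str, meta: dict) -> dict:
    """Reverse mapping: 'meta' payload back to legacy modelfile keys."""
    return {
        "title": name,
        "desc": meta.get("description"),
        "imageUrl": meta.get("profile_image_url"),
        "content": meta.get("ollama", {}).get("modelfile"),
        "suggestionPrompts": meta.get("suggestion_prompts"),
        "categories": meta.get("categories"),
        "user": {k: v for k, v in meta.get("user", {}).items() if k != "community"},
    }


# Made-up sample record, shaped like the JSON stored in the old table.
original = {
    "title": "Helper",
    "desc": "A demo model",
    "imageUrl": "/img.png",
    "content": "FROM llama2",
    "suggestionPrompts": [],
    "categories": ["demo"],
    "user": {"name": "alice"},
}

# Serialize and parse again, as the migration stores meta as JSON text.
meta = json.loads(json.dumps(modelfile_to_meta(original)))
restored = meta_to_modelfile(original["title"], meta)
print(restored == original)
```

The round trip is only exact for records shaped like the sample: a record without a `user` key comes back with `"user": {}`, and a `community` key already present in the original `user` dict would be dropped on the way back.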
@ -0,0 +1,48 @@
"""Peewee migrations -- 002_add_local_sharing.py.

Some examples (model - class or model name)::

    > Model = migrator.orm['table_name'] # Return model in current state by name
    > Model = migrator.ModelClass # Return model in current state by name

    > migrator.sql(sql) # Run custom SQL
    > migrator.run(func, *args, **kwargs) # Run python function with the given args
    > migrator.create_model(Model) # Create a model (could be used as decorator)
    > migrator.remove_model(model, cascade=True) # Remove a model
    > migrator.add_fields(model, **fields) # Add fields to a model
    > migrator.change_fields(model, **fields) # Change fields
    > migrator.remove_fields(model, *field_names, cascade=True)
    > migrator.rename_field(model, old_field_name, new_field_name)
    > migrator.rename_table(model, new_table_name)
    > migrator.add_index(model, *col_names, unique=False)
    > migrator.add_not_null(model, *field_names)
    > migrator.add_default(model, field_name, default)
    > migrator.add_constraint(model, name, sql)
    > migrator.drop_index(model, *col_names)
    > migrator.drop_not_null(model, *field_names)
    > migrator.drop_constraints(model, *constraints)

"""

from contextlib import suppress

import peewee as pw
from peewee_migrate import Migrator


with suppress(ImportError):
    import playhouse.postgres_ext as pw_pext


def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
    """Write your migrations here."""

    # Add the settings field to the 'user' table
    migrator.add_fields("user", settings=pw.TextField(null=True))


def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
    """Write your rollback migrations here."""

    # Remove the settings field
    migrator.remove_fields("user", "settings")
@ -0,0 +1,61 @@
"""Peewee migrations -- 009_add_models.py.

Some examples (model - class or model name)::

    > Model = migrator.orm['table_name'] # Return model in current state by name
    > Model = migrator.ModelClass # Return model in current state by name

    > migrator.sql(sql) # Run custom SQL
    > migrator.run(func, *args, **kwargs) # Run python function with the given args
    > migrator.create_model(Model) # Create a model (could be used as decorator)
    > migrator.remove_model(model, cascade=True) # Remove a model
    > migrator.add_fields(model, **fields) # Add fields to a model
    > migrator.change_fields(model, **fields) # Change fields
    > migrator.remove_fields(model, *field_names, cascade=True)
    > migrator.rename_field(model, old_field_name, new_field_name)
    > migrator.rename_table(model, new_table_name)
    > migrator.add_index(model, *col_names, unique=False)
    > migrator.add_not_null(model, *field_names)
    > migrator.add_default(model, field_name, default)
    > migrator.add_constraint(model, name, sql)
    > migrator.drop_index(model, *col_names)
    > migrator.drop_not_null(model, *field_names)
    > migrator.drop_constraints(model, *constraints)

"""

from contextlib import suppress

import peewee as pw
from peewee_migrate import Migrator


with suppress(ImportError):
    import playhouse.postgres_ext as pw_pext


def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
    """Write your migrations here."""

    @migrator.create_model
    class Tool(pw.Model):
        id = pw.TextField(unique=True)
        user_id = pw.TextField()

        name = pw.TextField()
        content = pw.TextField()
        specs = pw.TextField()

        meta = pw.TextField()

        created_at = pw.BigIntegerField(null=False)
        updated_at = pw.BigIntegerField(null=False)

        class Meta:
            table_name = "tool"


def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
    """Write your rollback migrations here."""

    migrator.remove_model("tool")
@ -0,0 +1,48 @@
"""Peewee migrations -- 002_add_local_sharing.py.

Some examples (model - class or model name)::

    > Model = migrator.orm['table_name'] # Return model in current state by name
    > Model = migrator.ModelClass # Return model in current state by name

    > migrator.sql(sql) # Run custom SQL
    > migrator.run(func, *args, **kwargs) # Run python function with the given args
    > migrator.create_model(Model) # Create a model (could be used as decorator)
    > migrator.remove_model(model, cascade=True) # Remove a model
    > migrator.add_fields(model, **fields) # Add fields to a model
    > migrator.change_fields(model, **fields) # Change fields
    > migrator.remove_fields(model, *field_names, cascade=True)
    > migrator.rename_field(model, old_field_name, new_field_name)
    > migrator.rename_table(model, new_table_name)
    > migrator.add_index(model, *col_names, unique=False)
    > migrator.add_not_null(model, *field_names)
    > migrator.add_default(model, field_name, default)
    > migrator.add_constraint(model, name, sql)
    > migrator.drop_index(model, *col_names)
    > migrator.drop_not_null(model, *field_names)
    > migrator.drop_constraints(model, *constraints)

"""

from contextlib import suppress

import peewee as pw
from peewee_migrate import Migrator


with suppress(ImportError):
    import playhouse.postgres_ext as pw_pext


def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
    """Write your migrations here."""

    # Add the info field to the 'user' table
    migrator.add_fields("user", info=pw.TextField(null=True))


def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
    """Write your rollback migrations here."""

    # Remove the info field
    migrator.remove_fields("user", "info")
@ -0,0 +1,55 @@
"""Peewee migrations -- 009_add_models.py.

Some examples (model - class or model name)::

    > Model = migrator.orm['table_name'] # Return model in current state by name
    > Model = migrator.ModelClass # Return model in current state by name

    > migrator.sql(sql) # Run custom SQL
    > migrator.run(func, *args, **kwargs) # Run python function with the given args
    > migrator.create_model(Model) # Create a model (could be used as decorator)
    > migrator.remove_model(model, cascade=True) # Remove a model
    > migrator.add_fields(model, **fields) # Add fields to a model
    > migrator.change_fields(model, **fields) # Change fields
    > migrator.remove_fields(model, *field_names, cascade=True)
    > migrator.rename_field(model, old_field_name, new_field_name)
    > migrator.rename_table(model, new_table_name)
    > migrator.add_index(model, *col_names, unique=False)
    > migrator.add_not_null(model, *field_names)
    > migrator.add_default(model, field_name, default)
    > migrator.add_constraint(model, name, sql)
    > migrator.drop_index(model, *col_names)
    > migrator.drop_not_null(model, *field_names)
    > migrator.drop_constraints(model, *constraints)

"""

from contextlib import suppress

import peewee as pw
from peewee_migrate import Migrator


with suppress(ImportError):
    import playhouse.postgres_ext as pw_pext


def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
    """Write your migrations here."""

    @migrator.create_model
    class File(pw.Model):
        id = pw.TextField(unique=True)
        user_id = pw.TextField()
        filename = pw.TextField()
        meta = pw.TextField()
        created_at = pw.BigIntegerField(null=False)

        class Meta:
            table_name = "file"


def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
    """Write your rollback migrations here."""

    migrator.remove_model("file")
@ -0,0 +1,61 @@
"""Peewee migrations -- 009_add_models.py.

Some examples (model - class or model name)::

    > Model = migrator.orm['table_name'] # Return model in current state by name
    > Model = migrator.ModelClass # Return model in current state by name

    > migrator.sql(sql) # Run custom SQL
    > migrator.run(func, *args, **kwargs) # Run python function with the given args
    > migrator.create_model(Model) # Create a model (could be used as decorator)
    > migrator.remove_model(model, cascade=True) # Remove a model
    > migrator.add_fields(model, **fields) # Add fields to a model
    > migrator.change_fields(model, **fields) # Change fields
    > migrator.remove_fields(model, *field_names, cascade=True)
    > migrator.rename_field(model, old_field_name, new_field_name)
    > migrator.rename_table(model, new_table_name)
    > migrator.add_index(model, *col_names, unique=False)
    > migrator.add_not_null(model, *field_names)
    > migrator.add_default(model, field_name, default)
    > migrator.add_constraint(model, name, sql)
    > migrator.drop_index(model, *col_names)
    > migrator.drop_not_null(model, *field_names)
    > migrator.drop_constraints(model, *constraints)

"""

from contextlib import suppress

import peewee as pw
from peewee_migrate import Migrator


with suppress(ImportError):
    import playhouse.postgres_ext as pw_pext


def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
    """Write your migrations here."""

    @migrator.create_model
    class Function(pw.Model):
        id = pw.TextField(unique=True)
        user_id = pw.TextField()

        name = pw.TextField()
        type = pw.TextField()

        content = pw.TextField()
        meta = pw.TextField()

        created_at = pw.BigIntegerField(null=False)
        updated_at = pw.BigIntegerField(null=False)

        class Meta:
            table_name = "function"


def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
    """Write your rollback migrations here."""

    migrator.remove_model("function")
@ -0,0 +1,50 @@
"""Peewee migrations -- 009_add_models.py.

Some examples (model - class or model name)::

    > Model = migrator.orm['table_name'] # Return model in current state by name
    > Model = migrator.ModelClass # Return model in current state by name

    > migrator.sql(sql) # Run custom SQL
    > migrator.run(func, *args, **kwargs) # Run python function with the given args
    > migrator.create_model(Model) # Create a model (could be used as decorator)
    > migrator.remove_model(model, cascade=True) # Remove a model
    > migrator.add_fields(model, **fields) # Add fields to a model
    > migrator.change_fields(model, **fields) # Change fields
    > migrator.remove_fields(model, *field_names, cascade=True)
    > migrator.rename_field(model, old_field_name, new_field_name)
    > migrator.rename_table(model, new_table_name)
    > migrator.add_index(model, *col_names, unique=False)
    > migrator.add_not_null(model, *field_names)
    > migrator.add_default(model, field_name, default)
    > migrator.add_constraint(model, name, sql)
    > migrator.drop_index(model, *col_names)
    > migrator.drop_not_null(model, *field_names)
    > migrator.drop_constraints(model, *constraints)

"""

from contextlib import suppress

import peewee as pw
from peewee_migrate import Migrator


with suppress(ImportError):
    import playhouse.postgres_ext as pw_pext


def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
    """Write your migrations here."""

    migrator.add_fields("tool", valves=pw.TextField(null=True))
    migrator.add_fields("function", valves=pw.TextField(null=True))
    migrator.add_fields("function", is_active=pw.BooleanField(default=False))


def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
    """Write your rollback migrations here."""

    migrator.remove_fields("tool", "valves")
    migrator.remove_fields("function", "valves")
    migrator.remove_fields("function", "is_active")
@ -0,0 +1,45 @@
"""Peewee migrations -- 017_add_user_oauth_sub.py.
Some examples (model - class or model name)::
    > Model = migrator.orm['table_name'] # Return model in current state by name
    > Model = migrator.ModelClass # Return model in current state by name
    > migrator.sql(sql) # Run custom SQL
    > migrator.run(func, *args, **kwargs) # Run python function with the given args
    > migrator.create_model(Model) # Create a model (could be used as decorator)
    > migrator.remove_model(model, cascade=True) # Remove a model
    > migrator.add_fields(model, **fields) # Add fields to a model
    > migrator.change_fields(model, **fields) # Change fields
    > migrator.remove_fields(model, *field_names, cascade=True)
    > migrator.rename_field(model, old_field_name, new_field_name)
    > migrator.rename_table(model, new_table_name)
    > migrator.add_index(model, *col_names, unique=False)
    > migrator.add_not_null(model, *field_names)
    > migrator.add_default(model, field_name, default)
    > migrator.add_constraint(model, name, sql)
    > migrator.drop_index(model, *col_names)
    > migrator.drop_not_null(model, *field_names)
    > migrator.drop_constraints(model, *constraints)
"""

from contextlib import suppress

import peewee as pw
from peewee_migrate import Migrator


with suppress(ImportError):
    import playhouse.postgres_ext as pw_pext


def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
    """Write your migrations here."""

    migrator.add_fields(
        "user",
        oauth_sub=pw.TextField(null=True, unique=True),
    )


def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
    """Write your rollback migrations here."""

    migrator.remove_fields("user", "oauth_sub")
@ -0,0 +1,49 @@
"""Peewee migrations -- 017_add_user_oauth_sub.py.

Some examples (model - class or model name)::

    > Model = migrator.orm['table_name'] # Return model in current state by name
    > Model = migrator.ModelClass # Return model in current state by name

    > migrator.sql(sql) # Run custom SQL
    > migrator.run(func, *args, **kwargs) # Run python function with the given args
    > migrator.create_model(Model) # Create a model (could be used as decorator)
    > migrator.remove_model(model, cascade=True) # Remove a model
    > migrator.add_fields(model, **fields) # Add fields to a model
    > migrator.change_fields(model, **fields) # Change fields
    > migrator.remove_fields(model, *field_names, cascade=True)
    > migrator.rename_field(model, old_field_name, new_field_name)
    > migrator.rename_table(model, new_table_name)
    > migrator.add_index(model, *col_names, unique=False)
    > migrator.add_not_null(model, *field_names)
    > migrator.add_default(model, field_name, default)
    > migrator.add_constraint(model, name, sql)
    > migrator.drop_index(model, *col_names)
    > migrator.drop_not_null(model, *field_names)
    > migrator.drop_constraints(model, *constraints)

"""

from contextlib import suppress

import peewee as pw
from peewee_migrate import Migrator


with suppress(ImportError):
    import playhouse.postgres_ext as pw_pext


def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
    """Write your migrations here."""

    migrator.add_fields(
        "function",
        is_global=pw.BooleanField(default=False),
    )


def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
    """Write your rollback migrations here."""

    migrator.remove_fields("function", "is_global")
@ -0,0 +1,66 @@
|
|||
import logging
|
||||
from contextvars import ContextVar
|
||||
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
from peewee import *
|
||||
from peewee import InterfaceError as PeeWeeInterfaceError
|
||||
from peewee import PostgresqlDatabase
|
||||
from playhouse.db_url import connect, parse
|
||||
from playhouse.shortcuts import ReconnectMixin
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["DB"])
|
||||
|
||||
db_state_default = {"closed": None, "conn": None, "ctx": None, "transactions": None}
|
||||
db_state = ContextVar("db_state", default=db_state_default.copy())
|
||||
|
||||
|
||||
class PeeweeConnectionState(object):
|
||||
def __init__(self, **kwargs):
|
||||
super().__setattr__("_state", db_state)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
self._state.get()[name] = value
|
||||
|
||||
def __getattr__(self, name):
|
||||
value = self._state.get()[name]
|
||||
return value
|
||||
|
||||
|
||||
class CustomReconnectMixin(ReconnectMixin):
|
||||
reconnect_errors = (
|
||||
# psycopg2
|
||||
(OperationalError, "termin"),
|
||||
(InterfaceError, "closed"),
|
||||
# peewee
|
||||
(PeeWeeInterfaceError, "closed"),
|
||||
)
|
||||
|
||||
|
||||
class ReconnectingPostgresqlDatabase(CustomReconnectMixin, PostgresqlDatabase):
|
||||
pass
|
||||
|
||||
|
||||
def register_connection(db_url):
|
||||
db = connect(db_url, unquote_password=True)
|
||||
if isinstance(db, PostgresqlDatabase):
|
||||
# Enable autoconnect for PostgreSQL databases, managed by Peewee
|
||||
db.autoconnect = True
|
||||
db.reuse_if_open = True
|
||||
log.info("Connected to PostgreSQL database")
|
||||
|
||||
# Get the connection details
|
||||
connection = parse(db_url, unquote_password=True)
|
||||
|
||||
# Use our custom database class that supports reconnection
|
||||
db = ReconnectingPostgresqlDatabase(**connection)
|
||||
db.connect(reuse_if_open=True)
|
||||
elif isinstance(db, SqliteDatabase):
|
||||
# Enable autoconnect for SQLite databases, managed by Peewee
|
||||
db.autoconnect = True
|
||||
db.reuse_if_open = True
|
||||
log.info("Connected to SQLite database")
|
||||
else:
|
||||
raise ValueError("Unsupported database connection")
|
||||
return db
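The `PeeweeConnectionState` class above proxies attribute reads and writes to a dictionary held in a `ContextVar`. A minimal standard-library sketch of that pattern follows. `ContextState`, `reset`, `_state_var` and `_DEFAULTS` are hypothetical names used only for illustration, and the explicit `reset()` call is an assumption about how each request or task would obtain its own dictionary; none of this is part of the commit or of Peewee.

```python
from contextvars import ContextVar, copy_context

# Hypothetical names for illustration only; not part of the commit.
_DEFAULTS = {"closed": None, "conn": None, "ctx": None, "transactions": None}
_state_var = ContextVar("state", default=None)


class ContextState:
    """Proxy attribute reads and writes to a per-context dictionary."""

    def reset(self):
        # Give the current context its own fresh dictionary.
        _state_var.set(dict(_DEFAULTS))

    def _current(self):
        data = _state_var.get()
        if data is None:
            data = dict(_DEFAULTS)
            _state_var.set(data)
        return data

    def __setattr__(self, name, value):
        self._current()[name] = value

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails.
        try:
            return self._current()[name]
        except KeyError:
            raise AttributeError(name) from None


state = ContextState()
state.reset()
state.conn = "main-connection"


def worker():
    # Without this reset, the copied context would share the caller's dict.
    state.reset()
    state.conn = "worker-connection"
    return state.conn


worker_conn = copy_context().run(worker)
```

The point of the sketch is the `reset()` step: copying a context copies the reference to the dictionary, not the dictionary itself, so isolation only holds once the new context stores its own copy.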
|
||||
File diff suppressed because it is too large
|
|
@ -0,0 +1,4 @@
|
|||
Generic single-database configuration.
|
||||
|
||||
Create new migrations with
|
||||
DATABASE_URL=<replace with actual url> alembic revision --autogenerate -m "a description"
|
||||
|
|
@ -0,0 +1,81 @@
|
|||
from logging.config import fileConfig
|
||||
|
||||
from alembic import context
|
||||
from open_webui.models.auths import Auth
|
||||
from open_webui.env import DATABASE_URL
|
||||
from sqlalchemy import engine_from_config, pool
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
config = context.config
|
||||
|
||||
# Interpret the config file for Python logging.
|
||||
# This line sets up loggers basically.
|
||||
if config.config_file_name is not None:
|
||||
fileConfig(config.config_file_name, disable_existing_loggers=False)
|
||||
|
||||
# add your model's MetaData object here
|
||||
# for 'autogenerate' support
|
||||
# from myapp import mymodel
|
||||
# target_metadata = mymodel.Base.metadata
|
||||
target_metadata = Auth.metadata
|
||||
|
||||
# other values from the config, defined by the needs of env.py,
|
||||
# can be acquired:
|
||||
# my_important_option = config.get_main_option("my_important_option")
|
||||
# ... etc.
|
||||
|
||||
DB_URL = DATABASE_URL
|
||||
|
||||
if DB_URL:
|
||||
config.set_main_option("sqlalchemy.url", DB_URL.replace("%", "%%"))
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
"""Run migrations in 'offline' mode.
|
||||
|
||||
This configures the context with just a URL
|
||||
and not an Engine, though an Engine is acceptable
|
||||
here as well. By skipping the Engine creation
|
||||
we don't even need a DBAPI to be available.
|
||||
|
||||
Calls to context.execute() here emit the given string to the
|
||||
script output.
|
||||
|
||||
"""
|
||||
url = config.get_main_option("sqlalchemy.url")
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
|
||||
"""Run migrations in 'online' mode.
|
||||
|
||||
In this scenario we need to create an Engine
|
||||
and associate a connection with the context.
|
||||
|
||||
"""
|
||||
connectable = engine_from_config(
|
||||
config.get_section(config.config_ini_section, {}),
|
||||
prefix="sqlalchemy.",
|
||||
poolclass=pool.NullPool,
|
||||
)
|
||||
|
||||
with connectable.connect() as connection:
|
||||
context.configure(connection=connection, target_metadata=target_metadata)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
run_migrations_online()
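The `DB_URL.replace("%", "%%")` call in this file is likely there because the option value goes through `ConfigParser`, which treats `%` as the start of an interpolation expression. The sketch below reproduces that behaviour with the standard library only, assuming Alembic stores the option through a `ConfigParser` as its documentation describes. The URL and the `escape_percent` helper are hypothetical.

```python
from configparser import ConfigParser


def escape_percent(url: str) -> str:
    """Double every '%' so ConfigParser does not treat it as interpolation."""
    return url.replace("%", "%%")


# Hypothetical URL with a URL-encoded '@' (%40) in the password.
raw_url = "postgresql://user:p%40ss@localhost:5432/app"

parser = ConfigParser()
parser.add_section("alembic")
parser.set("alembic", "sqlalchemy.url", escape_percent(raw_url))

# Reading the option back undoes the escaping.
round_tripped = parser.get("alembic", "sqlalchemy.url")

# Setting the unescaped value is rejected by ConfigParser.
try:
    parser.set("alembic", "sqlalchemy.url", raw_url)
    unescaped_accepted = True
except ValueError:
    unescaped_accepted = False
```

Passwords containing URL-encoded characters are the usual case where this matters, since the encoded form always contains a `%`.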
|
||||
|
|
@ -0,0 +1,27 @@
|
|||
"""${message}
|
||||
|
||||
Revision ID: ${up_revision}
|
||||
Revises: ${down_revision | comma,n}
|
||||
Create Date: ${create_date}
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import open_webui.internal.db
|
||||
${imports if imports else ""}
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = ${repr(up_revision)}
|
||||
down_revision: Union[str, None] = ${repr(down_revision)}
|
||||
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
${upgrades if upgrades else "pass"}
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
${downgrades if downgrades else "pass"}
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
from alembic import op
|
||||
from sqlalchemy import Inspector
|
||||
|
||||
|
||||
def get_existing_tables():
|
||||
con = op.get_bind()
|
||||
inspector = Inspector.from_engine(con)
|
||||
tables = set(inspector.get_table_names())
|
||||
return tables
|
||||
|
||||
|
||||
def get_revision_id():
|
||||
import uuid
|
||||
|
||||
return str(uuid.uuid4()).replace("-", "")[:12]
|
||||
|
|
@ -0,0 +1,151 @@
|
|||
"""Migrate tags
|
||||
|
||||
Revision ID: 1af9b942657b
|
||||
Revises: 242a2047eae0
|
||||
Create Date: 2024-10-09 21:02:35.241684
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.sql import table, select, update, column
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
|
||||
import json
|
||||
|
||||
revision = "1af9b942657b"
|
||||
down_revision = "242a2047eae0"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# Set up an inspector for the existing tables to avoid issues
|
||||
conn = op.get_bind()
|
||||
inspector = Inspector.from_engine(conn)
|
||||
|
||||
# Clean up potential leftover temp table from previous failures
|
||||
conn.execute(sa.text("DROP TABLE IF EXISTS _alembic_tmp_tag"))
|
||||
|
||||
# Check if the 'tag' table exists
|
||||
tables = inspector.get_table_names()
|
||||
|
||||
# Step 1: Modify Tag table using batch mode for SQLite support
|
||||
if "tag" in tables:
|
||||
# Get the current columns in the 'tag' table
|
||||
columns = [col["name"] for col in inspector.get_columns("tag")]
|
||||
|
||||
# Get any existing unique constraints on the 'tag' table
|
||||
current_constraints = inspector.get_unique_constraints("tag")
|
||||
|
||||
with op.batch_alter_table("tag", schema=None) as batch_op:
|
||||
# Check if the unique constraint already exists
|
||||
if not any(
|
||||
constraint["name"] == "uq_id_user_id"
|
||||
for constraint in current_constraints
|
||||
):
|
||||
# Create unique constraint if it doesn't exist
|
||||
batch_op.create_unique_constraint("uq_id_user_id", ["id", "user_id"])
|
||||
|
||||
# Check if the 'data' column exists before trying to drop it
|
||||
if "data" in columns:
|
||||
batch_op.drop_column("data")
|
||||
|
||||
# Check if the 'meta' column needs to be created
|
||||
if "meta" not in columns:
|
||||
# Add the 'meta' column if it doesn't already exist
|
||||
batch_op.add_column(sa.Column("meta", sa.JSON(), nullable=True))
|
||||
|
||||
tag = table(
|
||||
"tag",
|
||||
column("id", sa.String()),
|
||||
column("name", sa.String()),
|
||||
column("user_id", sa.String()),
|
||||
column("meta", sa.JSON()),
|
||||
)
|
||||
|
||||
# Step 2: Migrate tags
|
||||
conn = op.get_bind()
|
||||
result = conn.execute(sa.select(tag.c.id, tag.c.name, tag.c.user_id))
|
||||
|
||||
tag_updates = {}
|
||||
for row in result:
|
||||
new_id = row.name.replace(" ", "_").lower()
|
||||
tag_updates[row.id] = new_id
|
||||
|
||||
for tag_id, new_tag_id in tag_updates.items():
|
||||
print(f"Updating tag {tag_id} to {new_tag_id}")
|
||||
if new_tag_id == "pinned":
|
||||
# delete tag
|
||||
delete_stmt = sa.delete(tag).where(tag.c.id == tag_id)
|
||||
conn.execute(delete_stmt)
|
||||
else:
|
||||
# Check if the new_tag_id already exists in the database
|
||||
existing_tag_query = sa.select(tag.c.id).where(tag.c.id == new_tag_id)
|
||||
existing_tag_result = conn.execute(existing_tag_query).fetchone()
|
||||
|
||||
if existing_tag_result:
|
||||
# Handle duplicate case: the new_tag_id already exists
|
||||
print(
|
||||
f"Tag {new_tag_id} already exists. Removing current tag with ID {tag_id} to avoid duplicates."
|
||||
)
|
||||
# Option 1: Delete the current tag if an update to new_tag_id would cause duplication
|
||||
delete_stmt = sa.delete(tag).where(tag.c.id == tag_id)
|
||||
conn.execute(delete_stmt)
|
||||
else:
|
||||
update_stmt = sa.update(tag).where(tag.c.id == tag_id)
|
||||
update_stmt = update_stmt.values(id=new_tag_id)
|
||||
conn.execute(update_stmt)
|
||||
|
||||
# Add columns `pinned` and `meta` to 'chat'
|
||||
op.add_column("chat", sa.Column("pinned", sa.Boolean(), nullable=True))
|
||||
op.add_column(
|
||||
"chat", sa.Column("meta", sa.JSON(), nullable=False, server_default="{}")
|
||||
)
|
||||
|
||||
chatidtag = table(
|
||||
"chatidtag", column("chat_id", sa.String()), column("tag_name", sa.String())
|
||||
)
|
||||
chat = table(
|
||||
"chat",
|
||||
column("id", sa.String()),
|
||||
column("pinned", sa.Boolean()),
|
||||
column("meta", sa.JSON()),
|
||||
)
|
||||
|
||||
# Fetch existing tags
|
||||
conn = op.get_bind()
|
||||
result = conn.execute(sa.select(chatidtag.c.chat_id, chatidtag.c.tag_name))
|
||||
|
||||
chat_updates = {}
|
||||
for row in result:
|
||||
chat_id = row.chat_id
|
||||
tag_name = row.tag_name.replace(" ", "_").lower()
|
||||
|
||||
if tag_name == "pinned":
|
||||
# Specifically handle 'pinned' tag
|
||||
if chat_id not in chat_updates:
|
||||
chat_updates[chat_id] = {"pinned": True, "meta": {}}
|
||||
else:
|
||||
chat_updates[chat_id]["pinned"] = True
|
||||
else:
|
||||
if chat_id not in chat_updates:
|
||||
chat_updates[chat_id] = {"pinned": False, "meta": {"tags": [tag_name]}}
|
||||
else:
|
||||
tags = chat_updates[chat_id]["meta"].get("tags", [])
|
||||
tags.append(tag_name)
|
||||
|
||||
chat_updates[chat_id]["meta"]["tags"] = list(set(tags))
|
||||
|
||||
# Update chats based on accumulated changes
|
||||
for chat_id, updates in chat_updates.items():
|
||||
update_stmt = sa.update(chat).where(chat.c.id == chat_id)
|
||||
update_stmt = update_stmt.values(
|
||||
meta=updates.get("meta", {}), pinned=updates.get("pinned", False)
|
||||
)
|
||||
conn.execute(update_stmt)
|
||||
pass
|
||||
|
||||
|
||||
def downgrade():
|
||||
pass
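The second half of `upgrade()` above folds `chatidtag` rows into one update per chat. A database-free sketch of that accumulation step is given below. `normalize_tag` and `collect_chat_updates` are hypothetical helper names, and unlike the migration (which deduplicates with `list(set(...))`) this version keeps tags in first-seen order.

```python
def normalize_tag(name: str) -> str:
    """Mirror the migration's rule: spaces become underscores, then lowercase."""
    return name.replace(" ", "_").lower()


def collect_chat_updates(rows):
    """Fold (chat_id, tag_name) pairs into one update dict per chat."""
    updates = {}
    for chat_id, tag_name in rows:
        entry = updates.setdefault(chat_id, {"pinned": False, "meta": {}})
        tag = normalize_tag(tag_name)
        if tag == "pinned":
            # The 'pinned' tag becomes a boolean flag instead of a tag entry.
            entry["pinned"] = True
        else:
            tags = entry["meta"].setdefault("tags", [])
            if tag not in tags:
                tags.append(tag)
    return updates


# Hypothetical sample rows: (chat_id, tag_name)
rows = [
    ("c1", "Pinned"),
    ("c1", "Work Notes"),
    ("c1", "work notes"),
    ("c2", "Ideas"),
]
updates = collect_chat_updates(rows)
```

Keeping the folding logic in a plain function like this would also make it possible to unit-test the rules without a database connection.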
|
||||
|
|
@ -0,0 +1,107 @@
|
|||
"""Update chat table
|
||||
|
||||
Revision ID: 242a2047eae0
|
||||
Revises: 6a39f3d8e55c
|
||||
Create Date: 2024-10-09 21:02:35.241684
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.sql import table, select, update
|
||||
|
||||
import json
|
||||
|
||||
revision = "242a2047eae0"
|
||||
down_revision = "6a39f3d8e55c"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
conn = op.get_bind()
|
||||
inspector = sa.inspect(conn)
|
||||
|
||||
columns = inspector.get_columns("chat")
|
||||
column_dict = {col["name"]: col for col in columns}
|
||||
|
||||
chat_column = column_dict.get("chat")
|
||||
old_chat_exists = "old_chat" in column_dict
|
||||
|
||||
if chat_column:
|
||||
if isinstance(chat_column["type"], sa.Text):
|
||||
print("Converting 'chat' column to JSON")
|
||||
|
||||
if old_chat_exists:
|
||||
print("Dropping old 'old_chat' column")
|
||||
op.drop_column("chat", "old_chat")
|
||||
|
||||
# Step 1: Rename current 'chat' column to 'old_chat'
|
||||
print("Renaming 'chat' column to 'old_chat'")
|
||||
op.alter_column(
|
||||
"chat", "chat", new_column_name="old_chat", existing_type=sa.Text()
|
||||
)
|
||||
|
||||
# Step 2: Add new 'chat' column of type JSON
|
||||
print("Adding new 'chat' column of type JSON")
|
||||
op.add_column("chat", sa.Column("chat", sa.JSON(), nullable=True))
|
||||
else:
|
||||
# If the column is already JSON, no need to do anything
|
||||
pass
|
||||
|
||||
# Step 3: Migrate data from 'old_chat' to 'chat'
|
||||
chat_table = table(
|
||||
"chat",
|
||||
sa.Column("id", sa.String(), primary_key=True),
|
||||
sa.Column("old_chat", sa.Text()),
|
||||
sa.Column("chat", sa.JSON()),
|
||||
)
|
||||
|
||||
# - Selecting all data from the table
|
||||
connection = op.get_bind()
|
||||
results = connection.execute(select(chat_table.c.id, chat_table.c.old_chat))
|
||||
for row in results:
|
||||
try:
|
||||
# Convert text JSON to actual JSON object, assuming the text is in JSON format
|
||||
json_data = json.loads(row.old_chat)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
json_data = None # Handle cases where the text cannot be converted to JSON
|
||||
|
||||
connection.execute(
|
||||
sa.update(chat_table)
|
||||
.where(chat_table.c.id == row.id)
|
||||
.values(chat=json_data)
|
||||
)
|
||||
|
||||
# Step 4: Drop 'old_chat' column
|
||||
print("Dropping 'old_chat' column")
|
||||
op.drop_column("chat", "old_chat")
|
||||
|
||||
|
||||
def downgrade():
|
||||
# Step 1: Add 'old_chat' column back as Text
|
||||
op.add_column("chat", sa.Column("old_chat", sa.Text(), nullable=True))
|
||||
|
||||
# Step 2: Convert 'chat' JSON data back to text and store in 'old_chat'
|
||||
chat_table = table(
|
||||
"chat",
|
||||
sa.Column("id", sa.String(), primary_key=True),
|
||||
sa.Column("chat", sa.JSON()),
|
||||
sa.Column("old_chat", sa.Text()),
|
||||
)
|
||||
|
||||
connection = op.get_bind()
|
||||
results = connection.execute(select(chat_table.c.id, chat_table.c.chat))
|
||||
for row in results:
|
||||
text_data = json.dumps(row.chat) if row.chat is not None else None
|
||||
connection.execute(
|
||||
sa.update(chat_table)
|
||||
.where(chat_table.c.id == row.id)
|
||||
.values(old_chat=text_data)
|
||||
)
|
||||
|
||||
# Step 3: Remove the new 'chat' JSON column
|
||||
op.drop_column("chat", "chat")
|
||||
|
||||
# Step 4: Rename 'old_chat' back to 'chat'
|
||||
op.alter_column("chat", "old_chat", new_column_name="chat", existing_type=sa.Text())
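The per-row conversion used by `upgrade()` and `downgrade()` above can be isolated as two small functions. This is a sketch only: `text_to_json` and `json_to_text` are hypothetical names, and the explicit `None` check in `text_to_json` is an addition for illustration, since the `chat` column is nullable.

```python
import json


def text_to_json(text):
    """Parse stored text as JSON; return None when it is missing or invalid."""
    if text is None:
        return None
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        return None


def json_to_text(value):
    """Serialize a JSON value back to text, leaving NULLs as NULL."""
    return json.dumps(value) if value is not None else None
```

Rows whose text cannot be parsed end up as `None`, so a downgrade after an upgrade would not restore the original unparseable text.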
|
||||
|
|
@ -0,0 +1,70 @@
|
|||
"""Update message & channel tables
|
||||
|
||||
Revision ID: 3781e22d8b01
|
||||
Revises: 7826ab40b532
|
||||
Create Date: 2024-12-30 03:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "3781e22d8b01"
|
||||
down_revision = "7826ab40b532"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# Add 'type' column to the 'channel' table
|
||||
op.add_column(
|
||||
"channel",
|
||||
sa.Column(
|
||||
"type",
|
||||
sa.Text(),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
|
||||
# Add 'parent_id' column to the 'message' table for threads
|
||||
op.add_column(
|
||||
"message",
|
||||
sa.Column("parent_id", sa.Text(), nullable=True),
|
||||
)
|
||||
|
||||
op.create_table(
|
||||
"message_reaction",
|
||||
sa.Column(
|
||||
"id", sa.Text(), nullable=False, primary_key=True, unique=True
|
||||
), # Unique reaction ID
|
||||
sa.Column("user_id", sa.Text(), nullable=False), # User who reacted
|
||||
sa.Column(
|
||||
"message_id", sa.Text(), nullable=False
|
||||
), # Message that was reacted to
|
||||
sa.Column(
|
||||
"name", sa.Text(), nullable=False
|
||||
), # Reaction name (e.g. "thumbs_up")
|
||||
sa.Column(
|
||||
"created_at", sa.BigInteger(), nullable=True
|
||||
), # Timestamp of when the reaction was added
|
||||
)
|
||||
|
||||
op.create_table(
|
||||
"channel_member",
|
||||
sa.Column(
|
||||
"id", sa.Text(), nullable=False, primary_key=True, unique=True
|
||||
), # Record ID for the membership row
|
||||
sa.Column("channel_id", sa.Text(), nullable=False), # Associated channel
|
||||
sa.Column("user_id", sa.Text(), nullable=False), # Associated user
|
||||
sa.Column(
|
||||
"created_at", sa.BigInteger(), nullable=True
|
||||
), # Timestamp of when the user joined the channel
|
||||
)
|
||||
|
||||
|
||||
def downgrade():
|
||||
# Revert 'type' column addition to the 'channel' table
|
||||
op.drop_column("channel", "type")
|
||||
op.drop_column("message", "parent_id")
|
||||
op.drop_table("message_reaction")
|
||||
op.drop_table("channel_member")
|
||||
|
|
@ -0,0 +1,81 @@
|
|||
"""Update tags
|
||||
|
||||
Revision ID: 3ab32c4b8f59
|
||||
Revises: 1af9b942657b
|
||||
Create Date: 2024-10-09 21:02:35.241684
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.sql import table, select, update, column
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
|
||||
import json
|
||||
|
||||
revision = "3ab32c4b8f59"
|
||||
down_revision = "1af9b942657b"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
conn = op.get_bind()
|
||||
inspector = Inspector.from_engine(conn)
|
||||
|
||||
# Inspecting the 'tag' table constraints and structure
|
||||
existing_pk = inspector.get_pk_constraint("tag")
|
||||
unique_constraints = inspector.get_unique_constraints("tag")
|
||||
existing_indexes = inspector.get_indexes("tag")
|
||||
|
||||
print(f"Primary Key: {existing_pk}")
|
||||
print(f"Unique Constraints: {unique_constraints}")
|
||||
print(f"Indexes: {existing_indexes}")
|
||||
|
||||
with op.batch_alter_table("tag", schema=None) as batch_op:
|
||||
# Drop existing primary key constraint if it exists
|
||||
if existing_pk and existing_pk.get("constrained_columns"):
|
||||
pk_name = existing_pk.get("name")
|
||||
if pk_name:
|
||||
print(f"Dropping primary key constraint: {pk_name}")
|
||||
batch_op.drop_constraint(pk_name, type_="primary")
|
||||
|
||||
# Now create the new primary key with the combination of 'id' and 'user_id'
|
||||
print("Creating new primary key with 'id' and 'user_id'.")
|
||||
batch_op.create_primary_key("pk_id_user_id", ["id", "user_id"])
|
||||
|
||||
# Drop unique constraints that could conflict with the new primary key
|
||||
for constraint in unique_constraints:
|
||||
if (
|
||||
constraint["name"] == "uq_id_user_id"
|
||||
): # Adjust this name according to what is actually returned by the inspector
|
||||
print(f"Dropping unique constraint: {constraint['name']}")
|
||||
batch_op.drop_constraint(constraint["name"], type_="unique")
|
||||
|
||||
for index in existing_indexes:
|
||||
if index["unique"]:
|
||||
if not any(
|
||||
constraint["name"] == index["name"]
|
||||
for constraint in unique_constraints
|
||||
):
|
||||
# Drop unique indexes that are not backed by a named unique constraint
|
||||
print(f"Dropping unique index: {index['name']}")
|
||||
batch_op.drop_index(index["name"])
|
||||
|
||||
|
||||
def downgrade():
|
||||
conn = op.get_bind()
|
||||
inspector = Inspector.from_engine(conn)
|
||||
|
||||
current_pk = inspector.get_pk_constraint("tag")
|
||||
|
||||
with op.batch_alter_table("tag", schema=None) as batch_op:
|
||||
# Drop the current primary key first, if it matches the one we know we added in upgrade
|
||||
if current_pk and "pk_id_user_id" == current_pk.get("name"):
|
||||
batch_op.drop_constraint("pk_id_user_id", type_="primary")
|
||||
|
||||
# Restore the original primary key
|
||||
batch_op.create_primary_key("pk_id", ["id"])
|
||||
|
||||
# Since primary key on just 'id' is restored, we now add back any unique constraints if necessary
|
||||
batch_op.create_unique_constraint("uq_id_user_id", ["id", "user_id"])
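The effect of moving the `tag` primary key from `id` to `(id, user_id)` can be shown with an in-memory SQLite table. The schema below is a simplified, hypothetical stand-in for the real `tag` table (the columns are declared `NOT NULL` here for clarity) and the rows are invented.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE tag ("
    "id TEXT NOT NULL, user_id TEXT NOT NULL, name TEXT, "
    "PRIMARY KEY (id, user_id))"
)

# The same tag id may now exist once per user.
conn.execute("INSERT INTO tag VALUES ('work', 'alice', 'Work')")
conn.execute("INSERT INTO tag VALUES ('work', 'bob', 'Work')")

# A second row with the same (id, user_id) pair violates the composite key.
try:
    conn.execute("INSERT INTO tag VALUES ('work', 'alice', 'Work again')")
    duplicate_rejected = False
except sqlite3.IntegrityError:
    duplicate_rejected = True

row_count = conn.execute("SELECT COUNT(*) FROM tag").fetchone()[0]
conn.close()
```

This is presumably why the earlier tag migration normalizes ids from tag names: two users can then hold a tag with the same normalized id without colliding.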
|
||||
|
|
@ -0,0 +1,67 @@
|
|||
"""Update folder table and change DateTime to BigInteger for timestamp fields
|
||||
|
||||
Revision ID: 4ace53fd72c8
|
||||
Revises: af906e964978
|
||||
Create Date: 2024-10-23 03:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "4ace53fd72c8"
|
||||
down_revision = "af906e964978"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# Perform safe alterations using batch operation
|
||||
with op.batch_alter_table("folder", schema=None) as batch_op:
|
||||
# Step 1: Remove server defaults for created_at and updated_at
|
||||
batch_op.alter_column(
|
||||
"created_at",
|
||||
server_default=None, # Removing server default
|
||||
)
|
||||
batch_op.alter_column(
|
||||
"updated_at",
|
||||
server_default=None, # Removing server default
|
||||
)
|
||||
|
||||
# Step 2: Change the column type of created_at to BigInteger
|
||||
batch_op.alter_column(
|
||||
"created_at",
|
||||
type_=sa.BigInteger(),
|
||||
existing_type=sa.DateTime(),
|
||||
existing_nullable=False,
|
||||
postgresql_using="extract(epoch from created_at)::bigint", # Conversion for PostgreSQL
|
||||
)
|
||||
|
||||
# Change the column type of updated_at to BigInteger
|
||||
batch_op.alter_column(
|
||||
"updated_at",
|
||||
type_=sa.BigInteger(),
|
||||
existing_type=sa.DateTime(),
|
||||
existing_nullable=False,
|
||||
postgresql_using="extract(epoch from updated_at)::bigint", # Conversion for PostgreSQL
|
||||
)
|
||||
|
||||
|
||||
def downgrade():
|
||||
# Downgrade: Convert columns back to DateTime and restore defaults
|
||||
with op.batch_alter_table("folder", schema=None) as batch_op:
|
||||
batch_op.alter_column(
|
||||
"created_at",
|
||||
type_=sa.DateTime(),
|
||||
existing_type=sa.BigInteger(),
|
||||
existing_nullable=False,
|
||||
server_default=sa.func.now(), # Restoring server default on downgrade
|
||||
)
|
||||
batch_op.alter_column(
|
||||
"updated_at",
|
||||
type_=sa.DateTime(),
|
||||
existing_type=sa.BigInteger(),
|
||||
existing_nullable=False,
|
||||
server_default=sa.func.now(), # Restoring server default on downgrade
|
||||
onupdate=sa.func.now(), # Restoring onupdate behavior if it was there
|
||||
)
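The `postgresql_using` expressions above convert timestamps to epoch seconds inside PostgreSQL. A comparable conversion in Python might look like the sketch below. It assumes naive datetimes are UTC, truncates to whole seconds, and uses hypothetical helper names; it is not a drop-in replacement for the SQL expression.

```python
from datetime import datetime, timezone


def to_epoch_seconds(value: datetime) -> int:
    """Convert a datetime to whole seconds since the Unix epoch.

    Naive datetimes are treated as UTC (an assumption of this sketch).
    """
    if value.tzinfo is None:
        value = value.replace(tzinfo=timezone.utc)
    return int(value.timestamp())


def from_epoch_seconds(seconds: int) -> datetime:
    """Convert epoch seconds back to a timezone-aware UTC datetime."""
    return datetime.fromtimestamp(seconds, tz=timezone.utc)


created_at = datetime(2024, 10, 23, 3, 0, 0)
epoch = to_epoch_seconds(created_at)
```

On databases other than PostgreSQL the `postgresql_using` argument has no effect, so existing values may need a separate data conversion along these lines.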
|
||||
|
|
@ -0,0 +1,48 @@
|
|||
"""Add channel table
|
||||
|
||||
Revision ID: 57c599a3cb57
|
||||
Revises: 922e7a387820
|
||||
Create Date: 2024-12-22 03:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "57c599a3cb57"
|
||||
down_revision = "922e7a387820"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
op.create_table(
|
||||
"channel",
|
||||
sa.Column("id", sa.Text(), nullable=False, primary_key=True, unique=True),
|
||||
sa.Column("user_id", sa.Text()),
|
||||
sa.Column("name", sa.Text()),
|
||||
sa.Column("description", sa.Text(), nullable=True),
|
||||
sa.Column("data", sa.JSON(), nullable=True),
|
||||
sa.Column("meta", sa.JSON(), nullable=True),
|
||||
sa.Column("access_control", sa.JSON(), nullable=True),
|
||||
sa.Column("created_at", sa.BigInteger(), nullable=True),
|
||||
sa.Column("updated_at", sa.BigInteger(), nullable=True),
|
||||
)
|
||||
|
||||
op.create_table(
|
||||
"message",
|
||||
sa.Column("id", sa.Text(), nullable=False, primary_key=True, unique=True),
|
||||
sa.Column("user_id", sa.Text()),
|
||||
sa.Column("channel_id", sa.Text(), nullable=True),
|
||||
sa.Column("content", sa.Text()),
|
||||
sa.Column("data", sa.JSON(), nullable=True),
|
||||
sa.Column("meta", sa.JSON(), nullable=True),
|
||||
sa.Column("created_at", sa.BigInteger(), nullable=True),
|
||||
sa.Column("updated_at", sa.BigInteger(), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade():
|
||||
op.drop_table("channel")
|
||||
|
||||
op.drop_table("message")
|
||||
|
|
@ -0,0 +1,80 @@
|
|||
"""Add knowledge table
|
||||
|
||||
Revision ID: 6a39f3d8e55c
|
||||
Revises: c0fbf31ca0db
|
||||
Create Date: 2024-10-01 14:02:35.241684
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.sql import table, column, select
|
||||
import json
|
||||
|
||||
|
||||
revision = "6a39f3d8e55c"
|
||||
down_revision = "c0fbf31ca0db"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# Creating the 'knowledge' table
|
||||
print("Creating knowledge table")
|
||||
knowledge_table = op.create_table(
|
||||
"knowledge",
|
||||
sa.Column("id", sa.Text(), primary_key=True),
|
||||
sa.Column("user_id", sa.Text(), nullable=False),
|
||||
sa.Column("name", sa.Text(), nullable=False),
|
||||
sa.Column("description", sa.Text(), nullable=True),
|
||||
sa.Column("data", sa.JSON(), nullable=True),
|
||||
sa.Column("meta", sa.JSON(), nullable=True),
|
||||
sa.Column("created_at", sa.BigInteger(), nullable=False),
|
||||
sa.Column("updated_at", sa.BigInteger(), nullable=True),
|
||||
)
|
||||
|
||||
print("Migrating data from document table to knowledge table")
|
||||
# Representation of the existing 'document' table
|
||||
document_table = table(
|
||||
"document",
|
||||
column("collection_name", sa.String()),
|
||||
column("user_id", sa.String()),
|
||||
column("name", sa.String()),
|
||||
column("title", sa.Text()),
|
||||
column("content", sa.Text()),
|
||||
column("timestamp", sa.BigInteger()),
|
||||
)
|
||||
|
||||
# Select all from existing document table
|
||||
documents = op.get_bind().execute(
|
||||
select(
|
||||
document_table.c.collection_name,
|
||||
document_table.c.user_id,
|
||||
document_table.c.name,
|
||||
document_table.c.title,
|
||||
document_table.c.content,
|
||||
document_table.c.timestamp,
|
||||
)
|
||||
)
|
||||
|
||||
# Insert data into knowledge table from document table
|
||||
for doc in documents:
|
||||
op.get_bind().execute(
|
||||
knowledge_table.insert().values(
|
||||
id=doc.collection_name,
|
||||
user_id=doc.user_id,
|
||||
description=doc.name,
|
||||
meta={
|
||||
"legacy": True,
|
||||
"document": True,
|
||||
"tags": json.loads(doc.content or "{}").get("tags", []),
|
||||
},
|
||||
name=doc.title,
|
||||
created_at=doc.timestamp,
|
||||
updated_at=doc.timestamp, # reuse the document timestamp for both created_at and updated_at
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def downgrade():
|
||||
op.drop_table("knowledge")
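The field mapping that `upgrade()` applies to each `document` row can be written as a plain function, which makes the renames easier to see (`title` becomes `name`, `name` becomes `description`). This is a sketch under assumptions: `document_to_knowledge` is a hypothetical helper, and rows are represented as dictionaries rather than SQLAlchemy result rows.

```python
import json


def document_to_knowledge(doc: dict) -> dict:
    """Map one legacy document row to the shape of a knowledge row."""
    # 'content' holds JSON text in the legacy table; treat NULL as empty.
    content = json.loads(doc.get("content") or "{}")
    return {
        "id": doc["collection_name"],
        "user_id": doc["user_id"],
        "name": doc["title"],
        "description": doc["name"],
        "meta": {
            "legacy": True,
            "document": True,
            "tags": content.get("tags", []),
        },
        "created_at": doc["timestamp"],
        "updated_at": doc["timestamp"],
    }


# Hypothetical sample row.
sample = {
    "collection_name": "abc123",
    "user_id": "u1",
    "name": "notes",
    "title": "My Notes",
    "content": '{"tags": [{"name": "draft"}]}',
    "timestamp": 1727791355,
}
knowledge_row = document_to_knowledge(sample)
```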
|
||||
|
|
@ -0,0 +1,26 @@
|
|||
"""Update file table
|
||||
|
||||
Revision ID: 7826ab40b532
|
||||
Revises: 57c599a3cb57
|
||||
Create Date: 2024-12-23 03:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "7826ab40b532"
|
||||
down_revision = "57c599a3cb57"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
op.add_column(
|
||||
"file",
|
||||
sa.Column("access_control", sa.JSON(), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade():
|
||||
op.drop_column("file", "access_control")
|
||||
|
|
@ -0,0 +1,204 @@
|
|||
"""init
|
||||
|
||||
Revision ID: 7e5b5dc7342b
|
||||
Revises:
|
||||
Create Date: 2024-06-24 13:15:33.808998
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
import open_webui.internal.db
|
||||
from open_webui.internal.db import JSONField
|
||||
from open_webui.migrations.util import get_existing_tables
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "7e5b5dc7342b"
|
||||
down_revision: Union[str, None] = None
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
existing_tables = set(get_existing_tables())
|
||||
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
if "auth" not in existing_tables:
|
||||
op.create_table(
|
||||
"auth",
|
||||
sa.Column("id", sa.String(), nullable=False),
|
||||
sa.Column("email", sa.String(), nullable=True),
|
||||
sa.Column("password", sa.Text(), nullable=True),
|
||||
sa.Column("active", sa.Boolean(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
if "chat" not in existing_tables:
|
||||
op.create_table(
|
||||
"chat",
|
||||
sa.Column("id", sa.String(), nullable=False),
|
||||
sa.Column("user_id", sa.String(), nullable=True),
|
||||
sa.Column("title", sa.Text(), nullable=True),
|
||||
sa.Column("chat", sa.Text(), nullable=True),
|
||||
sa.Column("created_at", sa.BigInteger(), nullable=True),
|
||||
sa.Column("updated_at", sa.BigInteger(), nullable=True),
|
||||
sa.Column("share_id", sa.Text(), nullable=True),
|
||||
sa.Column("archived", sa.Boolean(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("share_id"),
|
||||
)
|
||||
|
||||
if "chatidtag" not in existing_tables:
|
||||
op.create_table(
|
||||
"chatidtag",
|
||||
sa.Column("id", sa.String(), nullable=False),
|
||||
sa.Column("tag_name", sa.String(), nullable=True),
|
||||
sa.Column("chat_id", sa.String(), nullable=True),
|
||||
sa.Column("user_id", sa.String(), nullable=True),
|
||||
sa.Column("timestamp", sa.BigInteger(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
if "document" not in existing_tables:
|
||||
op.create_table(
|
||||
"document",
|
||||
sa.Column("collection_name", sa.String(), nullable=False),
|
||||
sa.Column("name", sa.String(), nullable=True),
|
||||
sa.Column("title", sa.Text(), nullable=True),
|
||||
sa.Column("filename", sa.Text(), nullable=True),
|
||||
sa.Column("content", sa.Text(), nullable=True),
|
||||
sa.Column("user_id", sa.String(), nullable=True),
|
||||
sa.Column("timestamp", sa.BigInteger(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("collection_name"),
|
||||
sa.UniqueConstraint("name"),
|
||||
)
|
||||
|
||||
if "file" not in existing_tables:
|
||||
op.create_table(
|
||||
"file",
|
||||
sa.Column("id", sa.String(), nullable=False),
|
||||
sa.Column("user_id", sa.String(), nullable=True),
|
||||
sa.Column("filename", sa.Text(), nullable=True),
|
||||
sa.Column("meta", JSONField(), nullable=True),
|
||||
sa.Column("created_at", sa.BigInteger(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
if "function" not in existing_tables:
|
||||
op.create_table(
|
||||
"function",
|
||||
sa.Column("id", sa.String(), nullable=False),
|
||||
sa.Column("user_id", sa.String(), nullable=True),
|
||||
sa.Column("name", sa.Text(), nullable=True),
|
||||
sa.Column("type", sa.Text(), nullable=True),
|
||||
sa.Column("content", sa.Text(), nullable=True),
|
||||
sa.Column("meta", JSONField(), nullable=True),
|
||||
sa.Column("valves", JSONField(), nullable=True),
|
||||
sa.Column("is_active", sa.Boolean(), nullable=True),
|
||||
sa.Column("is_global", sa.Boolean(), nullable=True),
|
||||
sa.Column("updated_at", sa.BigInteger(), nullable=True),
|
||||
sa.Column("created_at", sa.BigInteger(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
if "memory" not in existing_tables:
|
||||
op.create_table(
|
||||
"memory",
|
||||
sa.Column("id", sa.String(), nullable=False),
|
||||
sa.Column("user_id", sa.String(), nullable=True),
|
||||
sa.Column("content", sa.Text(), nullable=True),
|
||||
sa.Column("updated_at", sa.BigInteger(), nullable=True),
|
||||
sa.Column("created_at", sa.BigInteger(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
if "model" not in existing_tables:
|
||||
op.create_table(
|
||||
"model",
|
||||
sa.Column("id", sa.Text(), nullable=False),
|
||||
sa.Column("user_id", sa.Text(), nullable=True),
|
||||
sa.Column("base_model_id", sa.Text(), nullable=True),
|
||||
sa.Column("name", sa.Text(), nullable=True),
|
||||
sa.Column("params", JSONField(), nullable=True),
|
||||
sa.Column("meta", JSONField(), nullable=True),
|
||||
sa.Column("updated_at", sa.BigInteger(), nullable=True),
|
||||
sa.Column("created_at", sa.BigInteger(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
if "prompt" not in existing_tables:
|
||||
op.create_table(
|
||||
"prompt",
|
||||
sa.Column("command", sa.String(), nullable=False),
|
||||
sa.Column("user_id", sa.String(), nullable=True),
|
||||
sa.Column("title", sa.Text(), nullable=True),
|
||||
sa.Column("content", sa.Text(), nullable=True),
|
||||
sa.Column("timestamp", sa.BigInteger(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("command"),
|
||||
)
|
||||
|
||||
if "tag" not in existing_tables:
|
||||
op.create_table(
|
||||
"tag",
|
||||
sa.Column("id", sa.String(), nullable=False),
|
||||
sa.Column("name", sa.String(), nullable=True),
|
||||
sa.Column("user_id", sa.String(), nullable=True),
|
||||
sa.Column("data", sa.Text(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
if "tool" not in existing_tables:
|
||||
op.create_table(
|
||||
"tool",
|
||||
sa.Column("id", sa.String(), nullable=False),
|
||||
sa.Column("user_id", sa.String(), nullable=True),
|
||||
sa.Column("name", sa.Text(), nullable=True),
|
||||
sa.Column("content", sa.Text(), nullable=True),
|
||||
sa.Column("specs", JSONField(), nullable=True),
|
||||
sa.Column("meta", JSONField(), nullable=True),
|
||||
sa.Column("valves", JSONField(), nullable=True),
|
||||
sa.Column("updated_at", sa.BigInteger(), nullable=True),
|
||||
sa.Column("created_at", sa.BigInteger(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
if "user" not in existing_tables:
|
||||
op.create_table(
|
||||
"user",
|
||||
sa.Column("id", sa.String(), nullable=False),
|
||||
sa.Column("name", sa.String(), nullable=True),
|
||||
sa.Column("email", sa.String(), nullable=True),
|
||||
sa.Column("role", sa.String(), nullable=True),
|
||||
sa.Column("profile_image_url", sa.Text(), nullable=True),
|
||||
sa.Column("last_active_at", sa.BigInteger(), nullable=True),
|
||||
sa.Column("updated_at", sa.BigInteger(), nullable=True),
|
||||
sa.Column("created_at", sa.BigInteger(), nullable=True),
|
||||
sa.Column("api_key", sa.String(), nullable=True),
|
||||
sa.Column("settings", JSONField(), nullable=True),
|
||||
sa.Column("info", JSONField(), nullable=True),
|
||||
sa.Column("oauth_sub", sa.Text(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("api_key"),
|
||||
sa.UniqueConstraint("oauth_sub"),
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table("user")
|
||||
op.drop_table("tool")
|
||||
op.drop_table("tag")
|
||||
op.drop_table("prompt")
|
||||
op.drop_table("model")
|
||||
op.drop_table("memory")
|
||||
op.drop_table("function")
|
||||
op.drop_table("file")
|
||||
op.drop_table("document")
|
||||
op.drop_table("chatidtag")
|
||||
op.drop_table("chat")
|
||||
op.drop_table("auth")
|
||||
# ### end Alembic commands ###
|
||||
|
|
@@ -0,0 +1,85 @@
|
|||
"""Add group table
|
||||
|
||||
Revision ID: 922e7a387820
|
||||
Revises: 4ace53fd72c8
|
||||
Create Date: 2024-11-14 03:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "922e7a387820"
|
||||
down_revision = "4ace53fd72c8"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
op.create_table(
|
||||
"group",
|
||||
sa.Column("id", sa.Text(), nullable=False, primary_key=True, unique=True),
|
||||
sa.Column("user_id", sa.Text(), nullable=True),
|
||||
sa.Column("name", sa.Text(), nullable=True),
|
||||
sa.Column("description", sa.Text(), nullable=True),
|
||||
sa.Column("data", sa.JSON(), nullable=True),
|
||||
sa.Column("meta", sa.JSON(), nullable=True),
|
||||
sa.Column("permissions", sa.JSON(), nullable=True),
|
||||
sa.Column("user_ids", sa.JSON(), nullable=True),
|
||||
sa.Column("created_at", sa.BigInteger(), nullable=True),
|
||||
sa.Column("updated_at", sa.BigInteger(), nullable=True),
|
||||
)
|
||||
|
||||
# Add 'access_control' column to 'model' table
|
||||
op.add_column(
|
||||
"model",
|
||||
sa.Column("access_control", sa.JSON(), nullable=True),
|
||||
)
|
||||
|
||||
# Add 'is_active' column to 'model' table
|
||||
op.add_column(
|
||||
"model",
|
||||
sa.Column(
|
||||
"is_active",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.sql.expression.true(),
|
||||
),
|
||||
)
|
||||
|
||||
# Add 'access_control' column to 'knowledge' table
|
||||
op.add_column(
|
||||
"knowledge",
|
||||
sa.Column("access_control", sa.JSON(), nullable=True),
|
||||
)
|
||||
|
||||
# Add 'access_control' column to 'prompt' table
|
||||
op.add_column(
|
||||
"prompt",
|
||||
sa.Column("access_control", sa.JSON(), nullable=True),
|
||||
)
|
||||
|
||||
# Add 'access_control' column to 'tool' table
|
||||
op.add_column(
|
||||
"tool",
|
||||
sa.Column("access_control", sa.JSON(), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade():
|
||||
op.drop_table("group")
|
||||
|
||||
# Drop 'access_control' column from 'model' table
|
||||
op.drop_column("model", "access_control")
|
||||
|
||||
# Drop 'is_active' column from 'model' table
|
||||
op.drop_column("model", "is_active")
|
||||
|
||||
# Drop 'access_control' column from 'knowledge' table
|
||||
op.drop_column("knowledge", "access_control")
|
||||
|
||||
# Drop 'access_control' column from 'prompt' table
|
||||
op.drop_column("prompt", "access_control")
|
||||
|
||||
# Drop 'access_control' column from 'tool' table
|
||||
op.drop_column("tool", "access_control")
|
||||
|
|
@@ -0,0 +1,51 @@
|
|||
"""Add feedback table
|
||||
|
||||
Revision ID: af906e964978
|
||||
Revises: c29facfe716b
|
||||
Create Date: 2024-10-20 17:02:35.241684
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# Revision identifiers, used by Alembic.
|
||||
revision = "af906e964978"
|
||||
down_revision = "c29facfe716b"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### Create feedback table ###
|
||||
op.create_table(
|
||||
"feedback",
|
||||
sa.Column(
|
||||
"id", sa.Text(), primary_key=True
|
||||
), # Unique identifier for each feedback (TEXT type)
|
||||
sa.Column(
|
||||
"user_id", sa.Text(), nullable=True
|
||||
), # ID of the user providing the feedback (TEXT type)
|
||||
sa.Column(
|
||||
"version", sa.BigInteger(), default=0
|
||||
), # Version of feedback (BIGINT type)
|
||||
sa.Column("type", sa.Text(), nullable=True), # Type of feedback (TEXT type)
|
||||
sa.Column("data", sa.JSON(), nullable=True), # Feedback data (JSON type)
|
||||
sa.Column(
|
||||
"meta", sa.JSON(), nullable=True
|
||||
), # Metadata for feedback (JSON type)
|
||||
sa.Column(
|
||||
"snapshot", sa.JSON(), nullable=True
|
||||
), # Snapshot data for feedback (JSON type)
|
||||
sa.Column(
|
||||
"created_at", sa.BigInteger(), nullable=False
|
||||
), # Feedback creation timestamp (BIGINT representing epoch)
|
||||
sa.Column(
|
||||
"updated_at", sa.BigInteger(), nullable=False
|
||||
), # Feedback update timestamp (BIGINT representing epoch)
|
||||
)
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### Drop feedback table ###
|
||||
op.drop_table("feedback")
|
||||
|
|
@@ -0,0 +1,32 @@
|
|||
"""Update file table
|
||||
|
||||
Revision ID: c0fbf31ca0db
|
||||
Revises: ca81bd47c050
|
||||
Create Date: 2024-09-20 15:26:35.241684
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "c0fbf31ca0db"
|
||||
down_revision: Union[str, None] = "ca81bd47c050"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column("file", sa.Column("hash", sa.Text(), nullable=True))
|
||||
op.add_column("file", sa.Column("data", sa.JSON(), nullable=True))
|
||||
op.add_column("file", sa.Column("updated_at", sa.BigInteger(), nullable=True))
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column("file", "updated_at")
|
||||
op.drop_column("file", "data")
|
||||
op.drop_column("file", "hash")
|
||||
|
|
@@ -0,0 +1,79 @@
|
|||
"""Update file table path
|
||||
|
||||
Revision ID: c29facfe716b
|
||||
Revises: c69f45358db4
|
||||
Create Date: 2024-10-20 17:02:35.241684
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import json
|
||||
from sqlalchemy.sql import table, column
|
||||
from sqlalchemy import String, Text, JSON, and_
|
||||
|
||||
|
||||
revision = "c29facfe716b"
|
||||
down_revision = "c69f45358db4"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# 1. Add the `path` column to the "file" table.
|
||||
op.add_column("file", sa.Column("path", sa.Text(), nullable=True))
|
||||
|
||||
# 2. Convert the `meta` column from Text/JSONField to `JSON()`
|
||||
# Use Alembic's default batch_op for dialect compatibility.
|
||||
with op.batch_alter_table("file", schema=None) as batch_op:
|
||||
batch_op.alter_column(
|
||||
"meta",
|
||||
type_=sa.JSON(),
|
||||
existing_type=sa.Text(),
|
||||
existing_nullable=True,
|
||||
nullable=True,
|
||||
postgresql_using="meta::json",
|
||||
)
|
||||
|
||||
# 3. Migrate legacy data from `meta` JSONField
|
||||
# Fetch `meta` from each row and, where it contains a `path`, copy that value into the new `path` column.
|
||||
# We will use SQLAlchemy core bindings to ensure safety across different databases.
|
||||
|
||||
file_table = table(
|
||||
"file", column("id", String), column("meta", JSON), column("path", Text)
|
||||
)
|
||||
|
||||
# Create connection to the database
|
||||
connection = op.get_bind()
|
||||
|
||||
# Select rows where `meta` is set and the new `path` column is still NULL
|
||||
results = connection.execute(
|
||||
sa.select(file_table.c.id, file_table.c.meta).where(
|
||||
and_(file_table.c.path.is_(None), file_table.c.meta.isnot(None))
|
||||
)
|
||||
).fetchall()
|
||||
|
||||
# Iterate over each row to extract and update the `path` from `meta` column
|
||||
for row in results:
|
||||
if "path" in row.meta:
|
||||
# Extract the `path` field from the `meta` JSON
|
||||
path = row.meta.get("path")
|
||||
|
||||
# Update the `file` table with the new `path` value
|
||||
connection.execute(
|
||||
file_table.update()
|
||||
.where(file_table.c.id == row.id)
|
||||
.values({"path": path})
|
||||
)
|
||||
|
||||
|
||||
def downgrade():
|
||||
# 1. Remove the `path` column
|
||||
op.drop_column("file", "path")
|
||||
|
||||
# 2. Revert the `meta` column back to Text/JSONField
|
||||
with op.batch_alter_table("file", schema=None) as batch_op:
|
||||
batch_op.alter_column(
|
||||
"meta", type_=sa.Text(), existing_type=sa.JSON(), existing_nullable=True
|
||||
)
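The data-migration step in `upgrade()` reads `path` out of each row's `meta` value. The sketch below isolates that extraction so it can be checked without a database. It is only an illustration: `extract_path` and the sample rows are hypothetical names, and the string branch assumes that some database drivers may return a JSON column as raw text rather than a dict.

```python
import json
from typing import Any, Optional


def extract_path(meta: Any) -> Optional[str]:
    """Return the legacy `path` stored in a file row's `meta`, if any."""
    # Assumption: a driver may hand the JSON column back as a raw string.
    if isinstance(meta, str):
        try:
            meta = json.loads(meta)
        except json.JSONDecodeError:
            return None
    # Anything that is not a JSON object cannot carry a `path` key.
    if not isinstance(meta, dict):
        return None
    path = meta.get("path")
    return path if isinstance(path, str) else None


# Hypothetical rows standing in for the result of the SELECT on `file`.
rows = [
    {"id": "a", "meta": {"path": "/data/uploads/a.txt"}},
    {"id": "b", "meta": '{"path": "/data/uploads/b.txt"}'},
    {"id": "c", "meta": {"name": "c.txt"}},
    {"id": "d", "meta": None},
]

# Only rows with a usable path would receive an UPDATE.
updates = {}
for row in rows:
    path = extract_path(row["meta"])
    if path is not None:
        updates[row["id"]] = path

print(updates)
```

Keeping the extraction in a pure function like this makes the edge cases (missing key, non-dict JSON, undecodable text) explicit before any UPDATE statement is issued.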
|
||||
|
|
@@ -0,0 +1,50 @@
|
|||
"""Add folder table
|
||||
|
||||
Revision ID: c69f45358db4
|
||||
Revises: 3ab32c4b8f59
|
||||
Create Date: 2024-10-16 02:02:35.241684
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "c69f45358db4"
|
||||
down_revision = "3ab32c4b8f59"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
op.create_table(
|
||||
"folder",
|
||||
sa.Column("id", sa.Text(), nullable=False),
|
||||
sa.Column("parent_id", sa.Text(), nullable=True),
|
||||
sa.Column("user_id", sa.Text(), nullable=False),
|
||||
sa.Column("name", sa.Text(), nullable=False),
|
||||
sa.Column("items", sa.JSON(), nullable=True),
|
||||
sa.Column("meta", sa.JSON(), nullable=True),
|
||||
sa.Column("is_expanded", sa.Boolean(), default=False, nullable=False),
|
||||
sa.Column(
|
||||
"created_at", sa.DateTime(), server_default=sa.func.now(), nullable=False
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(),
|
||||
nullable=False,
|
||||
server_default=sa.func.now(),
|
||||
onupdate=sa.func.now(),
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id", "user_id"),
|
||||
)
|
||||
|
||||
op.add_column(
|
||||
"chat",
|
||||
sa.Column("folder_id", sa.Text(), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade():
|
||||
op.drop_column("chat", "folder_id")
|
||||
|
||||
op.drop_table("folder")
|
||||
|
|
@@ -0,0 +1,41 @@
|
|||
"""Add config table
|
||||
|
||||
Revision ID: ca81bd47c050
|
||||
Revises: 7e5b5dc7342b
|
||||
Create Date: 2024-08-25 15:26:35.241684
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "ca81bd47c050"
|
||||
down_revision: Union[str, None] = "7e5b5dc7342b"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
op.create_table(
|
||||
"config",
|
||||
sa.Column("id", sa.Integer, primary_key=True),
|
||||
sa.Column("data", sa.JSON(), nullable=False),
|
||||
sa.Column("version", sa.Integer, nullable=False),
|
||||
sa.Column(
|
||||
"created_at", sa.DateTime(), nullable=False, server_default=sa.func.now()
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(),
|
||||
nullable=True,
|
||||
server_default=sa.func.now(),
|
||||
onupdate=sa.func.now(),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade():
|
||||
op.drop_table("config")
|
||||
|
|
@@ -0,0 +1,206 @@
|
|||
import logging
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.internal.db import Base, get_db
|
||||
from open_webui.models.users import UserModel, Users
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import Boolean, Column, String, Text
|
||||
from open_webui.utils.auth import verify_password
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||
|
||||
####################
|
||||
# DB MODEL
|
||||
####################
|
||||
|
||||
|
||||
class Auth(Base):
|
||||
__tablename__ = "auth"
|
||||
|
||||
id = Column(String, primary_key=True)
|
||||
email = Column(String)
|
||||
password = Column(Text)
|
||||
active = Column(Boolean)
|
||||
|
||||
|
||||
class AuthModel(BaseModel):
|
||||
id: str
|
||||
email: str
|
||||
password: str
|
||||
active: bool = True
|
||||
|
||||
|
||||
####################
|
||||
# Forms
|
||||
####################
|
||||
|
||||
|
||||
class Token(BaseModel):
|
||||
token: str
|
||||
token_type: str
|
||||
|
||||
|
||||
class ApiKey(BaseModel):
|
||||
api_key: Optional[str] = None
|
||||
|
||||
|
||||
class UserResponse(BaseModel):
|
||||
id: str
|
||||
email: str
|
||||
name: str
|
||||
role: str
|
||||
profile_image_url: str
|
||||
|
||||
|
||||
class SigninResponse(Token, UserResponse):
|
||||
pass
|
||||
|
||||
|
||||
class SigninForm(BaseModel):
|
||||
email: str
|
||||
password: str
|
||||
|
||||
|
||||
class LdapForm(BaseModel):
|
||||
user: str
|
||||
password: str
|
||||
|
||||
|
||||
class ProfileImageUrlForm(BaseModel):
|
||||
profile_image_url: str
|
||||
|
||||
|
||||
class UpdateProfileForm(BaseModel):
|
||||
profile_image_url: str
|
||||
name: str
|
||||
|
||||
|
||||
class UpdatePasswordForm(BaseModel):
|
||||
password: str
|
||||
new_password: str
|
||||
|
||||
|
||||
class SignupForm(BaseModel):
|
||||
name: str
|
||||
email: str
|
||||
password: str
|
||||
profile_image_url: Optional[str] = "/user.png"
|
||||
|
||||
|
||||
class AddUserForm(SignupForm):
|
||||
role: Optional[str] = "pending"
|
||||
|
||||
|
||||
class AuthsTable:
|
||||
def insert_new_auth(
|
||||
self,
|
||||
email: str,
|
||||
password: str,
|
||||
name: str,
|
||||
profile_image_url: str = "/user.png",
|
||||
role: str = "pending",
|
||||
oauth_sub: Optional[str] = None,
|
||||
) -> Optional[UserModel]:
|
||||
with get_db() as db:
|
||||
log.info("insert_new_auth")
|
||||
|
||||
id = str(uuid.uuid4())
|
||||
|
||||
auth = AuthModel(
|
||||
**{"id": id, "email": email, "password": password, "active": True}
|
||||
)
|
||||
result = Auth(**auth.model_dump())
|
||||
db.add(result)
|
||||
|
||||
user = Users.insert_new_user(
|
||||
id, name, email, profile_image_url, role, oauth_sub
|
||||
)
|
||||
|
||||
db.commit()
|
||||
db.refresh(result)
|
||||
|
||||
if result and user:
|
||||
return user
|
||||
else:
|
||||
return None
|
||||
|
||||
def authenticate_user(self, email: str, password: str) -> Optional[UserModel]:
|
||||
log.info(f"authenticate_user: {email}")
|
||||
try:
|
||||
with get_db() as db:
|
||||
auth = db.query(Auth).filter_by(email=email, active=True).first()
|
||||
if auth:
|
||||
if verify_password(password, auth.password):
|
||||
user = Users.get_user_by_id(auth.id)
|
||||
return user
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def authenticate_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
|
||||
log.info("authenticate_user_by_api_key")  # never write the raw API key to the log
|
||||
# if no api_key, return None
|
||||
if not api_key:
|
||||
return None
|
||||
|
||||
try:
|
||||
user = Users.get_user_by_api_key(api_key)
|
||||
return user if user else None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def authenticate_user_by_trusted_header(self, email: str) -> Optional[UserModel]:
|
||||
log.info(f"authenticate_user_by_trusted_header: {email}")
|
||||
try:
|
||||
with get_db() as db:
|
||||
auth = db.query(Auth).filter_by(email=email, active=True).first()
|
||||
if auth:
|
||||
user = Users.get_user_by_id(auth.id)
|
||||
return user
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def update_user_password_by_id(self, id: str, new_password: str) -> bool:
|
||||
try:
|
||||
with get_db() as db:
|
||||
result = (
|
||||
db.query(Auth).filter_by(id=id).update({"password": new_password})
|
||||
)
|
||||
db.commit()
|
||||
return result == 1
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def update_email_by_id(self, id: str, email: str) -> bool:
|
||||
try:
|
||||
with get_db() as db:
|
||||
result = db.query(Auth).filter_by(id=id).update({"email": email})
|
||||
db.commit()
|
||||
return result == 1
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def delete_auth_by_id(self, id: str) -> bool:
|
||||
try:
|
||||
with get_db() as db:
|
||||
# Delete User
|
||||
result = Users.delete_user_by_id(id)
|
||||
|
||||
if result:
|
||||
db.query(Auth).filter_by(id=id).delete()
|
||||
db.commit()
|
||||
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
Auths = AuthsTable()
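When authentication helpers log their inputs, secrets such as API keys are usually better masked first. The following is a minimal sketch of one way to do that. `mask_secret` is a hypothetical helper, not part of this module, and the choice to keep the last four characters is an assumption made for illustration.

```python
def mask_secret(value: str, visible: int = 4) -> str:
    """Return a log-safe form of a secret, keeping only its last characters."""
    if not value:
        return "<empty>"
    # Short secrets are hidden completely so nothing meaningful leaks.
    if len(value) <= visible:
        return "***"
    return "***" + value[-visible:]


# Example: the masked form is what would be passed to the logger.
masked = mask_secret("sk-1234567890abcd")
print(f"authenticate_user_by_api_key: {masked}")
```

A fixed-length prefix is used instead of one asterisk per character so that the log line does not reveal the length of the secret.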
|
||||
|
|
@@ -0,0 +1,136 @@
|
|||
import json
|
||||
import time
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.internal.db import Base, get_db
|
||||
from open_webui.utils.access_control import has_access
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
|
||||
from sqlalchemy import or_, func, select, and_, text
|
||||
from sqlalchemy.sql import exists
|
||||
|
||||
####################
|
||||
# Channel DB Schema
|
||||
####################
|
||||
|
||||
|
||||
class Channel(Base):
|
||||
__tablename__ = "channel"
|
||||
|
||||
id = Column(Text, primary_key=True)
|
||||
user_id = Column(Text)
|
||||
type = Column(Text, nullable=True)
|
||||
|
||||
name = Column(Text)
|
||||
description = Column(Text, nullable=True)
|
||||
|
||||
data = Column(JSON, nullable=True)
|
||||
meta = Column(JSON, nullable=True)
|
||||
access_control = Column(JSON, nullable=True)
|
||||
|
||||
created_at = Column(BigInteger)
|
||||
updated_at = Column(BigInteger)
|
||||
|
||||
|
||||
class ChannelModel(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: str
|
||||
user_id: str
|
||||
type: Optional[str] = None
|
||||
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
|
||||
data: Optional[dict] = None
|
||||
meta: Optional[dict] = None
|
||||
access_control: Optional[dict] = None
|
||||
|
||||
created_at: int # epoch timestamp in nanoseconds (time.time_ns())
|
||||
updated_at: int # epoch timestamp in nanoseconds (time.time_ns())
|
||||
|
||||
|
||||
####################
|
||||
# Forms
|
||||
####################
|
||||
|
||||
|
||||
class ChannelForm(BaseModel):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
data: Optional[dict] = None
|
||||
meta: Optional[dict] = None
|
||||
access_control: Optional[dict] = None
|
||||
|
||||
|
||||
class ChannelTable:
|
||||
def insert_new_channel(
|
||||
self, type: Optional[str], form_data: ChannelForm, user_id: str
|
||||
) -> Optional[ChannelModel]:
|
||||
with get_db() as db:
|
||||
channel = ChannelModel(
|
||||
**{
|
||||
**form_data.model_dump(),
|
||||
"type": type,
|
||||
"name": form_data.name.lower(),
|
||||
"id": str(uuid.uuid4()),
|
||||
"user_id": user_id,
|
||||
"created_at": int(time.time_ns()),
|
||||
"updated_at": int(time.time_ns()),
|
||||
}
|
||||
)
|
||||
|
||||
new_channel = Channel(**channel.model_dump())
|
||||
|
||||
db.add(new_channel)
|
||||
db.commit()
|
||||
return channel
|
||||
|
||||
def get_channels(self) -> list[ChannelModel]:
|
||||
with get_db() as db:
|
||||
channels = db.query(Channel).all()
|
||||
return [ChannelModel.model_validate(channel) for channel in channels]
|
||||
|
||||
def get_channels_by_user_id(
|
||||
self, user_id: str, permission: str = "read"
|
||||
) -> list[ChannelModel]:
|
||||
channels = self.get_channels()
|
||||
return [
|
||||
channel
|
||||
for channel in channels
|
||||
if channel.user_id == user_id
|
||||
or has_access(user_id, permission, channel.access_control)
|
||||
]
|
||||
|
||||
def get_channel_by_id(self, id: str) -> Optional[ChannelModel]:
|
||||
with get_db() as db:
|
||||
channel = db.query(Channel).filter(Channel.id == id).first()
|
||||
return ChannelModel.model_validate(channel) if channel else None
|
||||
|
||||
def update_channel_by_id(
|
||||
self, id: str, form_data: ChannelForm
|
||||
) -> Optional[ChannelModel]:
|
||||
with get_db() as db:
|
||||
channel = db.query(Channel).filter(Channel.id == id).first()
|
||||
if not channel:
|
||||
return None
|
||||
|
||||
channel.name = form_data.name
|
||||
channel.data = form_data.data
|
||||
channel.meta = form_data.meta
|
||||
channel.access_control = form_data.access_control
|
||||
channel.updated_at = int(time.time_ns())
|
||||
|
||||
db.commit()
|
||||
return ChannelModel.model_validate(channel) if channel else None
|
||||
|
||||
def delete_channel_by_id(self, id: str):
|
||||
with get_db() as db:
|
||||
db.query(Channel).filter(Channel.id == id).delete()
|
||||
db.commit()
|
||||
return True
|
||||
|
||||
|
||||
Channels = ChannelTable()
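Channel rows store `created_at` and `updated_at` from `time.time_ns()`, so the values are nanoseconds since the epoch rather than seconds. The sketch below shows one way to turn such a value into a timezone-aware `datetime`. `ns_to_datetime` is a hypothetical helper introduced here for illustration only.

```python
import time
from datetime import datetime, timezone

NS_PER_SECOND = 1_000_000_000


def ns_to_datetime(ts_ns: int) -> datetime:
    """Convert a nanosecond epoch timestamp to a UTC datetime."""
    # Split with integer arithmetic to avoid float rounding on large values.
    seconds, remainder_ns = divmod(ts_ns, NS_PER_SECOND)
    base = datetime.fromtimestamp(seconds, tz=timezone.utc)
    # datetime only resolves microseconds, so sub-microsecond digits are dropped.
    return base.replace(microsecond=remainder_ns // 1000)


# Usage with a timestamp produced the same way the channel table produces it.
now_ns = time.time_ns()
print(ns_to_datetime(now_ns).isoformat())
```

Dividing a nanosecond value by `NS_PER_SECOND` also gives a seconds-based timestamp that can be compared with tables that store `int(time.time())`.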
|
||||
|
|
@@ -0,0 +1,912 @@
|
|||
import logging
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.internal.db import Base, get_db
|
||||
from open_webui.models.tags import TagModel, Tag, Tags
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
|
||||
from sqlalchemy import or_, func, select, and_, text
|
||||
from sqlalchemy.sql import exists
|
||||
|
||||
####################
|
||||
# Chat DB Schema
|
||||
####################
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||
|
||||
|
||||
class Chat(Base):
|
||||
__tablename__ = "chat"
|
||||
|
||||
id = Column(String, primary_key=True)
|
||||
user_id = Column(String)
|
||||
title = Column(Text)
|
||||
chat = Column(JSON)
|
||||
|
||||
created_at = Column(BigInteger)
|
||||
updated_at = Column(BigInteger)
|
||||
|
||||
share_id = Column(Text, unique=True, nullable=True)
|
||||
archived = Column(Boolean, default=False)
|
||||
pinned = Column(Boolean, default=False, nullable=True)
|
||||
|
||||
meta = Column(JSON, server_default="{}")
|
||||
folder_id = Column(Text, nullable=True)
|
||||
|
||||
|
||||
class ChatModel(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: str
|
||||
user_id: str
|
||||
title: str
|
||||
chat: dict
|
||||
|
||||
created_at: int # epoch timestamp in seconds
|
||||
updated_at: int # epoch timestamp in seconds
|
||||
|
||||
share_id: Optional[str] = None
|
||||
archived: bool = False
|
||||
pinned: Optional[bool] = False
|
||||
|
||||
meta: dict = {}
|
||||
folder_id: Optional[str] = None
|
||||
|
||||
|
||||
####################
|
||||
# Forms
|
||||
####################
|
||||
|
||||
|
||||
class ChatForm(BaseModel):
|
||||
chat: dict
|
||||
|
||||
|
||||
class ChatImportForm(ChatForm):
|
||||
meta: Optional[dict] = {}
|
||||
pinned: Optional[bool] = False
|
||||
folder_id: Optional[str] = None
|
||||
|
||||
|
||||
class ChatTitleMessagesForm(BaseModel):
|
||||
title: str
|
||||
messages: list[dict]
|
||||
|
||||
|
||||
class ChatTitleForm(BaseModel):
|
||||
title: str
|
||||
|
||||
|
||||
class ChatResponse(BaseModel):
|
||||
id: str
|
||||
user_id: str
|
||||
title: str
|
||||
chat: dict
|
||||
updated_at: int # epoch timestamp in seconds
|
||||
created_at: int # epoch timestamp in seconds
|
||||
share_id: Optional[str] = None # id of the chat to be shared
|
||||
archived: bool
|
||||
pinned: Optional[bool] = False
|
||||
meta: dict = {}
|
||||
folder_id: Optional[str] = None
|
||||
|
||||
|
||||
class ChatTitleIdResponse(BaseModel):
|
||||
id: str
|
||||
title: str
|
||||
updated_at: int
|
||||
created_at: int
|
||||
|
||||
|
||||
class ChatTable:
|
||||
def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]:
|
||||
with get_db() as db:
|
||||
id = str(uuid.uuid4())
|
||||
chat = ChatModel(
|
||||
**{
|
||||
"id": id,
|
||||
"user_id": user_id,
|
||||
"title": (
|
||||
form_data.chat["title"]
|
||||
if "title" in form_data.chat
|
||||
else "New Chat"
|
||||
),
|
||||
"chat": form_data.chat,
|
||||
"created_at": int(time.time()),
|
||||
"updated_at": int(time.time()),
|
||||
}
|
||||
)
|
||||
|
||||
result = Chat(**chat.model_dump())
|
||||
db.add(result)
|
||||
db.commit()
|
||||
db.refresh(result)
|
||||
return ChatModel.model_validate(result) if result else None
|
||||
|
||||
def import_chat(
|
||||
self, user_id: str, form_data: ChatImportForm
|
||||
) -> Optional[ChatModel]:
|
||||
with get_db() as db:
|
||||
id = str(uuid.uuid4())
|
||||
chat = ChatModel(
|
||||
**{
|
||||
"id": id,
|
||||
"user_id": user_id,
|
||||
"title": (
|
||||
form_data.chat["title"]
|
||||
if "title" in form_data.chat
|
||||
else "New Chat"
|
||||
),
|
||||
"chat": form_data.chat,
|
||||
"meta": form_data.meta,
|
||||
"pinned": form_data.pinned,
|
||||
"folder_id": form_data.folder_id,
|
||||
"created_at": int(time.time()),
|
||||
"updated_at": int(time.time()),
|
||||
}
|
||||
)
|
||||
|
||||
result = Chat(**chat.model_dump())
|
||||
db.add(result)
|
||||
db.commit()
|
||||
db.refresh(result)
|
||||
return ChatModel.model_validate(result) if result else None
|
||||
|
||||
def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
chat_item = db.get(Chat, id)
|
||||
chat_item.chat = chat
|
||||
chat_item.title = chat["title"] if "title" in chat else "New Chat"
|
||||
chat_item.updated_at = int(time.time())
|
||||
db.commit()
|
||||
db.refresh(chat_item)
|
||||
|
||||
return ChatModel.model_validate(chat_item)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def update_chat_title_by_id(self, id: str, title: str) -> Optional[ChatModel]:
|
||||
chat = self.get_chat_by_id(id)
|
||||
if chat is None:
|
||||
return None
|
||||
|
||||
chat = chat.chat
|
||||
chat["title"] = title
|
||||
|
||||
return self.update_chat_by_id(id, chat)
|
||||
|
||||
def update_chat_tags_by_id(
|
||||
self, id: str, tags: list[str], user
|
||||
) -> Optional[ChatModel]:
|
||||
chat = self.get_chat_by_id(id)
|
||||
if chat is None:
|
||||
return None
|
||||
|
||||
self.delete_all_tags_by_id_and_user_id(id, user.id)
|
||||
|
||||
for tag in chat.meta.get("tags", []):
|
||||
if self.count_chats_by_tag_name_and_user_id(tag, user.id) == 0:
|
||||
Tags.delete_tag_by_name_and_user_id(tag, user.id)
|
||||
|
||||
for tag_name in tags:
|
||||
if tag_name.lower() == "none":
|
||||
continue
|
||||
|
||||
self.add_chat_tag_by_id_and_user_id_and_tag_name(id, user.id, tag_name)
|
||||
return self.get_chat_by_id(id)
|
||||
|
||||
def get_chat_title_by_id(self, id: str) -> Optional[str]:
|
||||
chat = self.get_chat_by_id(id)
|
||||
if chat is None:
|
||||
return None
|
||||
|
||||
return chat.chat.get("title", "New Chat")
|
||||
|
||||
def get_messages_by_chat_id(self, id: str) -> Optional[dict]:
|
||||
chat = self.get_chat_by_id(id)
|
||||
if chat is None:
|
||||
return None
|
||||
|
||||
return chat.chat.get("history", {}).get("messages", {}) or {}
|
||||
|
||||
def get_message_by_id_and_message_id(
|
||||
self, id: str, message_id: str
|
||||
) -> Optional[dict]:
|
||||
chat = self.get_chat_by_id(id)
|
||||
if chat is None:
|
||||
return None
|
||||
|
||||
return chat.chat.get("history", {}).get("messages", {}).get(message_id, {})
|
||||
|
||||
def upsert_message_to_chat_by_id_and_message_id(
|
||||
self, id: str, message_id: str, message: dict
|
||||
) -> Optional[ChatModel]:
|
||||
chat = self.get_chat_by_id(id)
|
||||
if chat is None:
|
||||
return None
|
||||
|
||||
chat = chat.chat
|
||||
history = chat.get("history", {})
|
||||
|
||||
if message_id in history.get("messages", {}):
|
||||
history["messages"][message_id] = {
|
||||
**history["messages"][message_id],
|
||||
**message,
|
||||
}
|
||||
else:
|
||||
history.setdefault("messages", {})[message_id] = message  # avoids KeyError when the chat has no messages yet
|
||||
|
||||
history["currentId"] = message_id
|
||||
|
||||
chat["history"] = history
|
||||
return self.update_chat_by_id(id, chat)
|
||||
|
||||
def add_message_status_to_chat_by_id_and_message_id(
|
||||
self, id: str, message_id: str, status: dict
|
||||
) -> Optional[ChatModel]:
|
||||
chat = self.get_chat_by_id(id)
|
||||
if chat is None:
|
||||
return None
|
||||
|
||||
chat = chat.chat
|
||||
history = chat.get("history", {})
|
||||
|
||||
if message_id in history.get("messages", {}):
|
||||
status_history = history["messages"][message_id].get("statusHistory", [])
|
||||
status_history.append(status)
|
||||
history["messages"][message_id]["statusHistory"] = status_history
|
||||
|
||||
chat["history"] = history
|
||||
return self.update_chat_by_id(id, chat)
|
||||
|
||||
def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
|
||||
with get_db() as db:
|
||||
# Get the existing chat to share
|
||||
chat = db.get(Chat, chat_id)
|
||||
# Check if the chat is already shared
|
||||
if chat.share_id:
|
||||
return self.get_chat_by_id_and_user_id(chat.share_id, "shared")
|
||||
# Create a new chat with the same data, but with a new ID
|
||||
shared_chat = ChatModel(
|
||||
**{
|
||||
"id": str(uuid.uuid4()),
|
||||
"user_id": f"shared-{chat_id}",
|
||||
"title": chat.title,
|
||||
"chat": chat.chat,
|
||||
"created_at": chat.created_at,
|
||||
"updated_at": int(time.time()),
|
||||
}
|
||||
)
|
||||
shared_result = Chat(**shared_chat.model_dump())
|
||||
db.add(shared_result)
|
||||
db.commit()
|
||||
db.refresh(shared_result)
|
||||
|
||||
# Update the original chat with the share_id
|
||||
result = (
|
||||
db.query(Chat)
|
||||
.filter_by(id=chat_id)
|
||||
.update({"share_id": shared_chat.id})
|
||||
)
|
||||
db.commit()
|
||||
return shared_chat if (shared_result and result) else None
|
||||
|
||||
def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
chat = db.get(Chat, chat_id)
|
||||
shared_chat = (
|
||||
db.query(Chat).filter_by(user_id=f"shared-{chat_id}").first()
|
||||
)
|
||||
|
||||
if shared_chat is None:
|
||||
return self.insert_shared_chat_by_chat_id(chat_id)
|
||||
|
||||
shared_chat.title = chat.title
|
||||
shared_chat.chat = chat.chat
|
||||
|
||||
shared_chat.updated_at = int(time.time())
|
||||
db.commit()
|
||||
db.refresh(shared_chat)
|
||||
|
||||
return ChatModel.model_validate(shared_chat)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool:
|
||||
try:
|
||||
with get_db() as db:
|
||||
db.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete()
|
||||
db.commit()
|
||||
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def update_chat_share_id_by_id(
|
||||
self, id: str, share_id: Optional[str]
|
||||
) -> Optional[ChatModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
chat = db.get(Chat, id)
|
||||
chat.share_id = share_id
|
||||
db.commit()
|
||||
db.refresh(chat)
|
||||
return ChatModel.model_validate(chat)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def toggle_chat_pinned_by_id(self, id: str) -> Optional[ChatModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
chat = db.get(Chat, id)
|
||||
chat.pinned = not chat.pinned
|
||||
chat.updated_at = int(time.time())
|
||||
db.commit()
|
||||
db.refresh(chat)
|
||||
return ChatModel.model_validate(chat)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
chat = db.get(Chat, id)
|
||||
chat.archived = not chat.archived
|
||||
chat.updated_at = int(time.time())
|
||||
db.commit()
|
||||
db.refresh(chat)
|
||||
return ChatModel.model_validate(chat)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def archive_all_chats_by_user_id(self, user_id: str) -> bool:
|
||||
try:
|
||||
with get_db() as db:
|
||||
db.query(Chat).filter_by(user_id=user_id).update({"archived": True})
|
||||
db.commit()
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def get_archived_chat_list_by_user_id(
|
||||
self, user_id: str, skip: int = 0, limit: int = 50
|
||||
) -> list[ChatModel]:
|
||||
with get_db() as db:
|
||||
all_chats = (
|
||||
db.query(Chat)
|
||||
.filter_by(user_id=user_id, archived=True)
|
||||
.order_by(Chat.updated_at.desc())
|
||||
# .limit(limit).offset(skip)
|
||||
.all()
|
||||
)
|
||||
return [ChatModel.model_validate(chat) for chat in all_chats]
|
||||
|
||||
def get_chat_list_by_user_id(
|
||||
self,
|
||||
user_id: str,
|
||||
include_archived: bool = False,
|
||||
skip: int = 0,
|
||||
limit: int = 50,
|
||||
) -> list[ChatModel]:
|
||||
with get_db() as db:
|
||||
query = db.query(Chat).filter_by(user_id=user_id)
|
||||
if not include_archived:
|
||||
query = query.filter_by(archived=False)
|
||||
|
||||
query = query.order_by(Chat.updated_at.desc())
|
||||
|
||||
if skip:
|
||||
query = query.offset(skip)
|
||||
if limit:
|
||||
query = query.limit(limit)
|
||||
|
||||
all_chats = query.all()
|
||||
return [ChatModel.model_validate(chat) for chat in all_chats]
|
||||
|
||||
def get_chat_title_id_list_by_user_id(
|
||||
self,
|
||||
user_id: str,
|
||||
include_archived: bool = False,
|
||||
skip: Optional[int] = None,
|
||||
limit: Optional[int] = None,
|
||||
) -> list[ChatTitleIdResponse]:
|
||||
with get_db() as db:
|
||||
query = db.query(Chat).filter_by(user_id=user_id).filter_by(folder_id=None)
|
||||
query = query.filter(or_(Chat.pinned == False, Chat.pinned == None))
|
||||
|
||||
if not include_archived:
|
||||
query = query.filter_by(archived=False)
|
||||
|
||||
query = query.order_by(Chat.updated_at.desc()).with_entities(
|
||||
Chat.id, Chat.title, Chat.updated_at, Chat.created_at
|
||||
)
|
||||
|
||||
if skip:
|
||||
query = query.offset(skip)
|
||||
if limit:
|
||||
query = query.limit(limit)
|
||||
|
||||
all_chats = query.all()
|
||||
|
||||
# The result has to be destructured from the SQLAlchemy `Row` and mapped to a dict, since `ChatModel` is not the returned model.
|
||||
return [
|
||||
ChatTitleIdResponse.model_validate(
|
||||
{
|
||||
"id": chat[0],
|
||||
"title": chat[1],
|
||||
"updated_at": chat[2],
|
||||
"created_at": chat[3],
|
||||
}
|
||||
)
|
||||
for chat in all_chats
|
||||
]
|
||||
|
||||
def get_chat_list_by_chat_ids(
|
||||
self, chat_ids: list[str], skip: int = 0, limit: int = 50
|
||||
) -> list[ChatModel]:
|
||||
with get_db() as db:
|
||||
all_chats = (
|
||||
db.query(Chat)
|
||||
.filter(Chat.id.in_(chat_ids))
|
||||
.filter_by(archived=False)
|
||||
.order_by(Chat.updated_at.desc())
|
||||
.all()
|
||||
)
|
||||
return [ChatModel.model_validate(chat) for chat in all_chats]
|
||||
|
||||
def get_chat_by_id(self, id: str) -> Optional[ChatModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
chat = db.get(Chat, id)
|
||||
return ChatModel.model_validate(chat)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
# it is possible that the shared link was deleted. hence,
|
||||
# we check if the chat is still shared by checking if a chat with the share_id exists
|
||||
chat = db.query(Chat).filter_by(share_id=id).first()
|
||||
|
||||
if chat:
|
||||
return self.get_chat_by_id(id)
|
||||
else:
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
chat = db.query(Chat).filter_by(id=id, user_id=user_id).first()
|
||||
return ChatModel.model_validate(chat)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_chats(self, skip: int = 0, limit: int = 50) -> list[ChatModel]:
|
||||
with get_db() as db:
|
||||
all_chats = (
|
||||
db.query(Chat)
|
||||
# .limit(limit).offset(skip)
|
||||
.order_by(Chat.updated_at.desc())
|
||||
)
|
||||
return [ChatModel.model_validate(chat) for chat in all_chats]
|
||||
|
||||
def get_chats_by_user_id(self, user_id: str) -> list[ChatModel]:
|
||||
with get_db() as db:
|
||||
all_chats = (
|
||||
db.query(Chat)
|
||||
.filter_by(user_id=user_id)
|
||||
.order_by(Chat.updated_at.desc())
|
||||
)
|
||||
return [ChatModel.model_validate(chat) for chat in all_chats]
|
||||
|
||||
def get_pinned_chats_by_user_id(self, user_id: str) -> list[ChatModel]:
|
||||
with get_db() as db:
|
||||
all_chats = (
|
||||
db.query(Chat)
|
||||
.filter_by(user_id=user_id, pinned=True, archived=False)
|
||||
.order_by(Chat.updated_at.desc())
|
||||
)
|
||||
return [ChatModel.model_validate(chat) for chat in all_chats]
|
||||
|
||||
def get_archived_chats_by_user_id(self, user_id: str) -> list[ChatModel]:
|
||||
with get_db() as db:
|
||||
all_chats = (
|
||||
db.query(Chat)
|
||||
.filter_by(user_id=user_id, archived=True)
|
||||
.order_by(Chat.updated_at.desc())
|
||||
)
|
||||
return [ChatModel.model_validate(chat) for chat in all_chats]
|
||||
|
||||
def get_chats_by_user_id_and_search_text(
|
||||
self,
|
||||
user_id: str,
|
||||
search_text: str,
|
||||
include_archived: bool = False,
|
||||
skip: int = 0,
|
||||
limit: int = 60,
|
||||
) -> list[ChatModel]:
|
||||
"""
|
||||
Filters chats by a search query at the SQL level (SQLite or PostgreSQL), with pagination via skip and limit.
|
||||
"""
|
||||
search_text = search_text.lower().strip()
|
||||
|
||||
if not search_text:
|
||||
return self.get_chat_list_by_user_id(user_id, include_archived, skip, limit)
|
||||
|
||||
search_text_words = search_text.split(" ")
|
||||
|
||||
# search_text may contain 'tag:<tag_name>' tokens: extract the tag names, then remove those tokens from the search text
|
||||
tag_ids = [
|
||||
word.replace("tag:", "").replace(" ", "_").lower()
|
||||
for word in search_text_words
|
||||
if word.startswith("tag:")
|
||||
]
|
||||
|
||||
search_text_words = [
|
||||
word for word in search_text_words if not word.startswith("tag:")
|
||||
]
|
||||
|
||||
search_text = " ".join(search_text_words)
|
||||
|
||||
with get_db() as db:
|
||||
query = db.query(Chat).filter(Chat.user_id == user_id)
|
||||
|
||||
if not include_archived:
|
||||
query = query.filter(Chat.archived == False)
|
||||
|
||||
query = query.order_by(Chat.updated_at.desc())
|
||||
|
||||
# Check if the database dialect is either 'sqlite' or 'postgresql'
|
||||
dialect_name = db.bind.dialect.name
|
||||
if dialect_name == "sqlite":
|
||||
# SQLite case: using JSON1 extension for JSON searching
|
||||
query = query.filter(
|
||||
(
|
||||
Chat.title.ilike(
|
||||
f"%{search_text}%"
|
||||
) # Case-insensitive search in title
|
||||
| text(
|
||||
"""
|
||||
EXISTS (
|
||||
SELECT 1
|
||||
FROM json_each(Chat.chat, '$.messages') AS message
|
||||
WHERE LOWER(message.value->>'content') LIKE '%' || :search_text || '%'
|
||||
)
|
||||
"""
|
||||
)
|
||||
).params(search_text=search_text)
|
||||
)
|
||||
|
||||
# Filter by tags, if any: 'none' matches chats without tags; otherwise a chat must have all of the given tags
|
||||
if "none" in tag_ids:
|
||||
query = query.filter(
|
||||
text(
|
||||
"""
|
||||
NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM json_each(Chat.meta, '$.tags') AS tag
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
elif tag_ids:
|
||||
query = query.filter(
|
||||
and_(
|
||||
*[
|
||||
text(
|
||||
f"""
|
||||
EXISTS (
|
||||
SELECT 1
|
||||
FROM json_each(Chat.meta, '$.tags') AS tag
|
||||
WHERE tag.value = :tag_id_{tag_idx}
|
||||
)
|
||||
"""
|
||||
).params(**{f"tag_id_{tag_idx}": tag_id})
|
||||
for tag_idx, tag_id in enumerate(tag_ids)
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
elif dialect_name == "postgresql":
|
||||
# PostgreSQL relies on proper JSON query for search
|
||||
query = query.filter(
|
||||
(
|
||||
Chat.title.ilike(
|
||||
f"%{search_text}%"
|
||||
) # Case-insensitive search in title
|
||||
| text(
|
||||
"""
|
||||
EXISTS (
|
||||
SELECT 1
|
||||
FROM json_array_elements(Chat.chat->'messages') AS message
|
||||
WHERE LOWER(message->>'content') LIKE '%' || :search_text || '%'
|
||||
)
|
||||
"""
|
||||
)
|
||||
).params(search_text=search_text)
|
||||
)
|
||||
|
||||
# Filter by tags, if any: 'none' matches chats without tags; otherwise a chat must have all of the given tags
|
||||
if "none" in tag_ids:
|
||||
query = query.filter(
|
||||
text(
|
||||
"""
|
||||
NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM json_array_elements_text(Chat.meta->'tags') AS tag
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
elif tag_ids:
|
||||
query = query.filter(
|
||||
and_(
|
||||
*[
|
||||
text(
|
||||
f"""
|
||||
EXISTS (
|
||||
SELECT 1
|
||||
FROM json_array_elements_text(Chat.meta->'tags') AS tag
|
||||
WHERE tag = :tag_id_{tag_idx}
|
||||
)
|
||||
"""
|
||||
).params(**{f"tag_id_{tag_idx}": tag_id})
|
||||
for tag_idx, tag_id in enumerate(tag_ids)
|
||||
]
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Unsupported dialect: {db.bind.dialect.name}"
|
||||
)
|
||||
|
||||
# Perform pagination at the SQL level
|
||||
all_chats = query.offset(skip).limit(limit).all()
|
||||
|
||||
log.info(f"The number of chats: {len(all_chats)}")
|
||||
|
||||
# Validate and return chats
|
||||
return [ChatModel.model_validate(chat) for chat in all_chats]
|
||||
|
||||
def get_chats_by_folder_id_and_user_id(
|
||||
self, folder_id: str, user_id: str
|
||||
) -> list[ChatModel]:
|
||||
with get_db() as db:
|
||||
query = db.query(Chat).filter_by(folder_id=folder_id, user_id=user_id)
|
||||
query = query.filter(or_(Chat.pinned == False, Chat.pinned == None))
|
||||
query = query.filter_by(archived=False)
|
||||
|
||||
query = query.order_by(Chat.updated_at.desc())
|
||||
|
||||
all_chats = query.all()
|
||||
return [ChatModel.model_validate(chat) for chat in all_chats]
|
||||
|
||||
def get_chats_by_folder_ids_and_user_id(
|
||||
self, folder_ids: list[str], user_id: str
|
||||
) -> list[ChatModel]:
|
||||
with get_db() as db:
|
||||
query = db.query(Chat).filter(
|
||||
Chat.folder_id.in_(folder_ids), Chat.user_id == user_id
|
||||
)
|
||||
query = query.filter(or_(Chat.pinned == False, Chat.pinned == None))
|
||||
query = query.filter_by(archived=False)
|
||||
|
||||
query = query.order_by(Chat.updated_at.desc())
|
||||
|
||||
all_chats = query.all()
|
||||
return [ChatModel.model_validate(chat) for chat in all_chats]
|
||||
|
||||
def update_chat_folder_id_by_id_and_user_id(
|
||||
self, id: str, user_id: str, folder_id: str
|
||||
) -> Optional[ChatModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
chat = db.get(Chat, id)
|
||||
chat.folder_id = folder_id
|
||||
chat.updated_at = int(time.time())
|
||||
chat.pinned = False
|
||||
db.commit()
|
||||
db.refresh(chat)
|
||||
return ChatModel.model_validate(chat)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_chat_tags_by_id_and_user_id(self, id: str, user_id: str) -> list[TagModel]:
|
||||
with get_db() as db:
|
||||
chat = db.get(Chat, id)
|
||||
tags = chat.meta.get("tags", [])
|
||||
return [Tags.get_tag_by_name_and_user_id(tag, user_id) for tag in tags]
|
||||
|
||||
def get_chat_list_by_user_id_and_tag_name(
|
||||
self, user_id: str, tag_name: str, skip: int = 0, limit: int = 50
|
||||
) -> list[ChatModel]:
|
||||
with get_db() as db:
|
||||
query = db.query(Chat).filter_by(user_id=user_id)
|
||||
tag_id = tag_name.replace(" ", "_").lower()
|
||||
|
||||
log.info(f"DB dialect name: {db.bind.dialect.name}")
|
||||
if db.bind.dialect.name == "sqlite":
|
||||
# SQLite JSON1 querying for tags within the meta JSON field
|
||||
query = query.filter(
|
||||
text(
|
||||
"EXISTS (SELECT 1 FROM json_each(Chat.meta, '$.tags') WHERE json_each.value = :tag_id)"
|
||||
)
|
||||
).params(tag_id=tag_id)
|
||||
elif db.bind.dialect.name == "postgresql":
|
||||
# PostgreSQL JSON query for tags within the meta JSON field (for `json` type)
|
||||
query = query.filter(
|
||||
text(
|
||||
"EXISTS (SELECT 1 FROM json_array_elements_text(Chat.meta->'tags') elem WHERE elem = :tag_id)"
|
||||
)
|
||||
).params(tag_id=tag_id)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Unsupported dialect: {db.bind.dialect.name}"
|
||||
)
|
||||
|
||||
all_chats = query.all()
|
||||
log.debug(f"all_chats: {all_chats}")
|
||||
return [ChatModel.model_validate(chat) for chat in all_chats]
|
||||
|
||||
def add_chat_tag_by_id_and_user_id_and_tag_name(
|
||||
self, id: str, user_id: str, tag_name: str
|
||||
) -> Optional[ChatModel]:
|
||||
tag = Tags.get_tag_by_name_and_user_id(tag_name, user_id)
|
||||
if tag is None:
|
||||
tag = Tags.insert_new_tag(tag_name, user_id)
|
||||
try:
|
||||
with get_db() as db:
|
||||
chat = db.get(Chat, id)
|
||||
|
||||
tag_id = tag.id
|
||||
if tag_id not in chat.meta.get("tags", []):
|
||||
chat.meta = {
|
||||
**chat.meta,
|
||||
"tags": list(set(chat.meta.get("tags", []) + [tag_id])),
|
||||
}
|
||||
|
||||
db.commit()
|
||||
db.refresh(chat)
|
||||
return ChatModel.model_validate(chat)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def count_chats_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> int:
|
||||
with get_db() as db: # Assuming `get_db()` returns a session object
|
||||
query = db.query(Chat).filter_by(user_id=user_id, archived=False)
|
||||
|
||||
# Normalize the tag_name for consistency
|
||||
tag_id = tag_name.replace(" ", "_").lower()
|
||||
|
||||
if db.bind.dialect.name == "sqlite":
|
||||
# SQLite JSON1 support for querying the tags inside the `meta` JSON field
|
||||
query = query.filter(
|
||||
text(
|
||||
"EXISTS (SELECT 1 FROM json_each(Chat.meta, '$.tags') WHERE json_each.value = :tag_id)"
|
||||
)
|
||||
).params(tag_id=tag_id)
|
||||
|
||||
elif db.bind.dialect.name == "postgresql":
|
||||
# PostgreSQL JSON support for querying the tags inside the `meta` JSON field (for `json` type)
|
||||
query = query.filter(
|
||||
text(
|
||||
"EXISTS (SELECT 1 FROM json_array_elements_text(Chat.meta->'tags') elem WHERE elem = :tag_id)"
|
||||
)
|
||||
).params(tag_id=tag_id)
|
||||
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Unsupported dialect: {db.bind.dialect.name}"
|
||||
)
|
||||
|
||||
# Get the count of matching records
|
||||
count = query.count()
|
||||
|
||||
# Debugging output for inspection
|
||||
log.info(f"Count of chats for tag '{tag_name}': {count}")
|
||||
|
||||
return count
|
||||
|
||||
def delete_tag_by_id_and_user_id_and_tag_name(
|
||||
self, id: str, user_id: str, tag_name: str
|
||||
) -> bool:
|
||||
try:
|
||||
with get_db() as db:
|
||||
chat = db.get(Chat, id)
|
||||
tags = chat.meta.get("tags", [])
|
||||
tag_id = tag_name.replace(" ", "_").lower()
|
||||
|
||||
tags = [tag for tag in tags if tag != tag_id]
|
||||
chat.meta = {
|
||||
**chat.meta,
|
||||
"tags": list(set(tags)),
|
||||
}
|
||||
db.commit()
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def delete_all_tags_by_id_and_user_id(self, id: str, user_id: str) -> bool:
|
||||
try:
|
||||
with get_db() as db:
|
||||
chat = db.get(Chat, id)
|
||||
chat.meta = {
|
||||
**chat.meta,
|
||||
"tags": [],
|
||||
}
|
||||
db.commit()
|
||||
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def delete_chat_by_id(self, id: str) -> bool:
|
||||
try:
|
||||
with get_db() as db:
|
||||
db.query(Chat).filter_by(id=id).delete()
|
||||
db.commit()
|
||||
|
||||
return self.delete_shared_chat_by_chat_id(id)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool:
|
||||
try:
|
||||
with get_db() as db:
|
||||
db.query(Chat).filter_by(id=id, user_id=user_id).delete()
|
||||
db.commit()
|
||||
|
||||
return self.delete_shared_chat_by_chat_id(id)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def delete_chats_by_user_id(self, user_id: str) -> bool:
|
||||
try:
|
||||
with get_db() as db:
|
||||
self.delete_shared_chats_by_user_id(user_id)
|
||||
|
||||
db.query(Chat).filter_by(user_id=user_id).delete()
|
||||
db.commit()
|
||||
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def delete_chats_by_user_id_and_folder_id(
|
||||
self, user_id: str, folder_id: str
|
||||
) -> bool:
|
||||
try:
|
||||
with get_db() as db:
|
||||
db.query(Chat).filter_by(user_id=user_id, folder_id=folder_id).delete()
|
||||
db.commit()
|
||||
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def delete_shared_chats_by_user_id(self, user_id: str) -> bool:
|
||||
try:
|
||||
with get_db() as db:
|
||||
chats_by_user = db.query(Chat).filter_by(user_id=user_id).all()
|
||||
shared_chat_ids = [f"shared-{chat.id}" for chat in chats_by_user]
|
||||
|
||||
db.query(Chat).filter(Chat.user_id.in_(shared_chat_ids)).delete()
|
||||
db.commit()
|
||||
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
Chats = ChatTable()
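
The tag-aware search in `get_chats_by_user_id_and_search_text` can be tried in isolation with a small sketch. It uses only the standard-library `sqlite3` module rather than SQLAlchemy, and it assumes the SQLite build includes the JSON functions (`json_each`). The `chat` table layout and the helper names `split_search_text`, `find_chat_ids`, and `make_demo_db` are hypothetical and exist only for illustration. It searches titles only, not message content.

```python
import json
import sqlite3


def split_search_text(search_text: str) -> tuple[str, list[str]]:
    """Split a query into free text and the names given as 'tag:<name>' tokens."""
    words = search_text.lower().strip().split(" ")
    tag_ids = [word.replace("tag:", "") for word in words if word.startswith("tag:")]
    free_text = " ".join(word for word in words if not word.startswith("tag:"))
    return free_text, tag_ids


def find_chat_ids(
    conn: sqlite3.Connection, free_text: str, tag_ids: list[str]
) -> list[str]:
    """Return ids of chats whose title contains free_text and whose
    meta.tags array contains every tag in tag_ids."""
    sql = "SELECT id FROM chat WHERE LOWER(title) LIKE '%' || :q || '%'"
    params = {"q": free_text}
    for idx, tag_id in enumerate(tag_ids):
        # One EXISTS clause per tag, each with its own bound parameter.
        sql += (
            " AND EXISTS (SELECT 1 FROM json_each(chat.meta, '$.tags')"
            f" WHERE json_each.value = :tag_{idx})"
        )
        params[f"tag_{idx}"] = tag_id
    sql += " ORDER BY id"
    return [row[0] for row in conn.execute(sql, params)]


def make_demo_db() -> sqlite3.Connection:
    """Build a hypothetical in-memory table with three sample chats."""
    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE chat (id TEXT PRIMARY KEY, title TEXT, meta TEXT)")
    samples = [
        ("a", "Budget report", {"tags": ["work", "finance"]}),
        ("b", "Budget ideas", {"tags": ["home"]}),
        ("c", "Trip plan", {}),
    ]
    conn.executemany(
        "INSERT INTO chat VALUES (?, ?, ?)",
        [(chat_id, title, json.dumps(meta)) for chat_id, title, meta in samples],
    )
    return conn
```

Only the parameter names are interpolated into the SQL; the tag values themselves are always bound, which is the same split the method above makes with `:tag_id_{tag_idx}`.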
|
||||
|
|
@ -0,0 +1,254 @@
|
|||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.internal.db import Base, get_db
|
||||
from open_webui.models.chats import Chats
|
||||
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Column, Text, JSON, Boolean
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||
|
||||
|
||||
####################
|
||||
# Feedback DB Schema
|
||||
####################
|
||||
|
||||
|
||||
class Feedback(Base):
|
||||
__tablename__ = "feedback"
|
||||
id = Column(Text, primary_key=True)
|
||||
user_id = Column(Text)
|
||||
version = Column(BigInteger, default=0)
|
||||
type = Column(Text)
|
||||
data = Column(JSON, nullable=True)
|
||||
meta = Column(JSON, nullable=True)
|
||||
snapshot = Column(JSON, nullable=True)
|
||||
created_at = Column(BigInteger)
|
||||
updated_at = Column(BigInteger)
|
||||
|
||||
|
||||
class FeedbackModel(BaseModel):
|
||||
id: str
|
||||
user_id: str
|
||||
version: int
|
||||
type: str
|
||||
data: Optional[dict] = None
|
||||
meta: Optional[dict] = None
|
||||
snapshot: Optional[dict] = None
|
||||
created_at: int
|
||||
updated_at: int
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
####################
|
||||
# Forms
|
||||
####################
|
||||
|
||||
|
||||
class FeedbackResponse(BaseModel):
|
||||
id: str
|
||||
user_id: str
|
||||
version: int
|
||||
type: str
|
||||
data: Optional[dict] = None
|
||||
meta: Optional[dict] = None
|
||||
created_at: int
|
||||
updated_at: int
|
||||
|
||||
|
||||
class RatingData(BaseModel):
|
||||
rating: Optional[str | int] = None
|
||||
model_id: Optional[str] = None
|
||||
sibling_model_ids: Optional[list[str]] = None
|
||||
reason: Optional[str] = None
|
||||
comment: Optional[str] = None
|
||||
model_config = ConfigDict(extra="allow", protected_namespaces=())
|
||||
|
||||
|
||||
class MetaData(BaseModel):
|
||||
arena: Optional[bool] = None
|
||||
chat_id: Optional[str] = None
|
||||
message_id: Optional[str] = None
|
||||
tags: Optional[list[str]] = None
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class SnapshotData(BaseModel):
|
||||
chat: Optional[dict] = None
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class FeedbackForm(BaseModel):
|
||||
type: str
|
||||
data: Optional[RatingData] = None
|
||||
meta: Optional[dict] = None
|
||||
snapshot: Optional[SnapshotData] = None
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class FeedbackTable:
|
||||
def insert_new_feedback(
|
||||
self, user_id: str, form_data: FeedbackForm
|
||||
) -> Optional[FeedbackModel]:
|
||||
with get_db() as db:
|
||||
id = str(uuid.uuid4())
|
||||
feedback = FeedbackModel(
|
||||
**{
|
||||
"id": id,
|
||||
"user_id": user_id,
|
||||
"version": 0,
|
||||
**form_data.model_dump(),
|
||||
"created_at": int(time.time()),
|
||||
"updated_at": int(time.time()),
|
||||
}
|
||||
)
|
||||
try:
|
||||
result = Feedback(**feedback.model_dump())
|
||||
db.add(result)
|
||||
db.commit()
|
||||
db.refresh(result)
|
||||
if result:
|
||||
return FeedbackModel.model_validate(result)
|
||||
else:
|
||||
return None
|
||||
except Exception as e:
|
||||
log.exception(f"Error creating a new feedback: {e}")
|
||||
return None
|
||||
|
||||
def get_feedback_by_id(self, id: str) -> Optional[FeedbackModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
feedback = db.query(Feedback).filter_by(id=id).first()
|
||||
if not feedback:
|
||||
return None
|
||||
return FeedbackModel.model_validate(feedback)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_feedback_by_id_and_user_id(
|
||||
self, id: str, user_id: str
|
||||
) -> Optional[FeedbackModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
feedback = db.query(Feedback).filter_by(id=id, user_id=user_id).first()
|
||||
if not feedback:
|
||||
return None
|
||||
return FeedbackModel.model_validate(feedback)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_all_feedbacks(self) -> list[FeedbackModel]:
|
||||
with get_db() as db:
|
||||
return [
|
||||
FeedbackModel.model_validate(feedback)
|
||||
for feedback in db.query(Feedback)
|
||||
.order_by(Feedback.updated_at.desc())
|
||||
.all()
|
||||
]
|
||||
|
||||
def get_feedbacks_by_type(self, type: str) -> list[FeedbackModel]:
|
||||
with get_db() as db:
|
||||
return [
|
||||
FeedbackModel.model_validate(feedback)
|
||||
for feedback in db.query(Feedback)
|
||||
.filter_by(type=type)
|
||||
.order_by(Feedback.updated_at.desc())
|
||||
.all()
|
||||
]
|
||||
|
||||
def get_feedbacks_by_user_id(self, user_id: str) -> list[FeedbackModel]:
|
||||
with get_db() as db:
|
||||
return [
|
||||
FeedbackModel.model_validate(feedback)
|
||||
for feedback in db.query(Feedback)
|
||||
.filter_by(user_id=user_id)
|
||||
.order_by(Feedback.updated_at.desc())
|
||||
.all()
|
||||
]
|
||||
|
||||
def update_feedback_by_id(
|
||||
self, id: str, form_data: FeedbackForm
|
||||
) -> Optional[FeedbackModel]:
|
||||
with get_db() as db:
|
||||
feedback = db.query(Feedback).filter_by(id=id).first()
|
||||
if not feedback:
|
||||
return None
|
||||
|
||||
if form_data.data:
|
||||
feedback.data = form_data.data.model_dump()
|
||||
if form_data.meta:
|
||||
feedback.meta = form_data.meta
|
||||
if form_data.snapshot:
|
||||
feedback.snapshot = form_data.snapshot.model_dump()
|
||||
|
||||
feedback.updated_at = int(time.time())
|
||||
|
||||
db.commit()
|
||||
return FeedbackModel.model_validate(feedback)
|
||||
|
||||
def update_feedback_by_id_and_user_id(
|
||||
self, id: str, user_id: str, form_data: FeedbackForm
|
||||
) -> Optional[FeedbackModel]:
|
||||
with get_db() as db:
|
||||
feedback = db.query(Feedback).filter_by(id=id, user_id=user_id).first()
|
||||
if not feedback:
|
||||
return None
|
||||
|
||||
if form_data.data:
|
||||
feedback.data = form_data.data.model_dump()
|
||||
if form_data.meta:
|
||||
feedback.meta = form_data.meta
|
||||
if form_data.snapshot:
|
||||
feedback.snapshot = form_data.snapshot.model_dump()
|
||||
|
||||
feedback.updated_at = int(time.time())
|
||||
|
||||
db.commit()
|
||||
return FeedbackModel.model_validate(feedback)
|
||||
|
||||
def delete_feedback_by_id(self, id: str) -> bool:
|
||||
with get_db() as db:
|
||||
feedback = db.query(Feedback).filter_by(id=id).first()
|
||||
if not feedback:
|
||||
return False
|
||||
db.delete(feedback)
|
||||
db.commit()
|
||||
return True
|
||||
|
||||
def delete_feedback_by_id_and_user_id(self, id: str, user_id: str) -> bool:
|
||||
with get_db() as db:
|
||||
feedback = db.query(Feedback).filter_by(id=id, user_id=user_id).first()
|
||||
if not feedback:
|
||||
return False
|
||||
db.delete(feedback)
|
||||
db.commit()
|
||||
return True
|
||||
|
||||
def delete_feedbacks_by_user_id(self, user_id: str) -> bool:
|
||||
with get_db() as db:
|
||||
feedbacks = db.query(Feedback).filter_by(user_id=user_id).all()
|
||||
if not feedbacks:
|
||||
return False
|
||||
for feedback in feedbacks:
|
||||
db.delete(feedback)
|
||||
db.commit()
|
||||
return True
|
||||
|
||||
def delete_all_feedbacks(self) -> bool:
|
||||
with get_db() as db:
|
||||
feedbacks = db.query(Feedback).all()
|
||||
if not feedbacks:
|
||||
return False
|
||||
for feedback in feedbacks:
|
||||
db.delete(feedback)
|
||||
db.commit()
|
||||
return True
|
||||
|
||||
|
||||
Feedbacks = FeedbackTable()
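
The two `update_feedback_*` methods overwrite a field only when the submitted value is truthy, so an omitted or empty value leaves the stored one untouched. A minimal sketch of that behaviour, using a plain dataclass as a hypothetical stand-in for the `Feedback` row (`FeedbackRecord` and `apply_update` are illustrative names, not part of the module):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class FeedbackRecord:
    """Hypothetical stand-in for a stored feedback row."""

    data: Optional[dict] = None
    meta: Optional[dict] = None
    snapshot: Optional[dict] = None


def apply_update(
    record: FeedbackRecord,
    data: Optional[dict] = None,
    meta: Optional[dict] = None,
    snapshot: Optional[dict] = None,
) -> FeedbackRecord:
    """Overwrite only the fields for which a truthy value was supplied."""
    if data:
        record.data = data
    if meta:
        record.meta = meta
    if snapshot:
        record.snapshot = snapshot
    return record
```

One consequence, at least in this sketch, is that a field cannot be cleared by sending an empty dict; clearing would need an explicit `is not None` check instead of a truthiness check.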
|
||||
|
|
@ -0,0 +1,235 @@
|
|||
import logging
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.internal.db import Base, JSONField, get_db
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Column, String, Text, JSON
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||
|
||||
####################
|
||||
# Files DB Schema
|
||||
####################
|
||||
|
||||
|
||||
class File(Base):
|
||||
__tablename__ = "file"
|
||||
id = Column(String, primary_key=True)
|
||||
user_id = Column(String)
|
||||
hash = Column(Text, nullable=True)
|
||||
|
||||
filename = Column(Text)
|
||||
path = Column(Text, nullable=True)
|
||||
|
||||
data = Column(JSON, nullable=True)
|
||||
meta = Column(JSON, nullable=True)
|
||||
|
||||
access_control = Column(JSON, nullable=True)
|
||||
|
||||
created_at = Column(BigInteger)
|
||||
updated_at = Column(BigInteger)
|
||||
|
||||
|
||||
class FileModel(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: str
|
||||
user_id: str
|
||||
hash: Optional[str] = None
|
||||
|
||||
filename: str
|
||||
path: Optional[str] = None
|
||||
|
||||
data: Optional[dict] = None
|
||||
meta: Optional[dict] = None
|
||||
|
||||
access_control: Optional[dict] = None
|
||||
|
||||
created_at: Optional[int] # timestamp in epoch
|
||||
updated_at: Optional[int] # timestamp in epoch
|
||||
|
||||
|
||||
####################
|
||||
# Forms
|
||||
####################
|
||||
|
||||
|
||||
class FileMeta(BaseModel):
|
||||
name: Optional[str] = None
|
||||
content_type: Optional[str] = None
|
||||
size: Optional[int] = None
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class FileModelResponse(BaseModel):
|
||||
id: str
|
||||
user_id: str
|
||||
hash: Optional[str] = None
|
||||
|
||||
filename: str
|
||||
data: Optional[dict] = None
|
||||
meta: FileMeta
|
||||
|
||||
created_at: int # timestamp in epoch
|
||||
updated_at: int # timestamp in epoch
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class FileMetadataResponse(BaseModel):
|
||||
id: str
|
||||
meta: dict
|
||||
created_at: int # timestamp in epoch
|
||||
updated_at: int # timestamp in epoch
|
||||
|
||||
|
||||
class FileForm(BaseModel):
|
||||
id: str
|
||||
hash: Optional[str] = None
|
||||
filename: str
|
||||
path: str
|
||||
data: dict = {}
|
||||
meta: dict = {}
|
||||
access_control: Optional[dict] = None
|
||||
|
||||
|
||||
class FilesTable:
|
||||
def insert_new_file(self, user_id: str, form_data: FileForm) -> Optional[FileModel]:
|
||||
with get_db() as db:
|
||||
file = FileModel(
|
||||
**{
|
||||
**form_data.model_dump(),
|
||||
"user_id": user_id,
|
||||
"created_at": int(time.time()),
|
||||
"updated_at": int(time.time()),
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
result = File(**file.model_dump())
|
||||
db.add(result)
|
||||
db.commit()
|
||||
db.refresh(result)
|
||||
if result:
|
||||
return FileModel.model_validate(result)
|
||||
else:
|
||||
return None
|
||||
except Exception as e:
|
||||
log.exception(f"Error inserting a new file: {e}")
|
||||
return None
|
||||
|
||||
def get_file_by_id(self, id: str) -> Optional[FileModel]:
|
||||
with get_db() as db:
|
||||
try:
|
||||
file = db.get(File, id)
|
||||
return FileModel.model_validate(file)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_file_metadata_by_id(self, id: str) -> Optional[FileMetadataResponse]:
|
||||
with get_db() as db:
|
||||
try:
|
||||
file = db.get(File, id)
|
||||
return FileMetadataResponse(
|
||||
id=file.id,
|
||||
meta=file.meta,
|
||||
created_at=file.created_at,
|
||||
updated_at=file.updated_at,
|
||||
)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_files(self) -> list[FileModel]:
|
||||
with get_db() as db:
|
||||
return [FileModel.model_validate(file) for file in db.query(File).all()]
|
||||
|
||||
def get_files_by_ids(self, ids: list[str]) -> list[FileModel]:
|
||||
with get_db() as db:
|
||||
return [
|
||||
FileModel.model_validate(file)
|
||||
for file in db.query(File)
|
||||
.filter(File.id.in_(ids))
|
||||
.order_by(File.updated_at.desc())
|
||||
.all()
|
||||
]
|
||||
|
||||
def get_file_metadatas_by_ids(self, ids: list[str]) -> list[FileMetadataResponse]:
|
||||
with get_db() as db:
|
||||
return [
|
||||
FileMetadataResponse(
|
||||
id=file.id,
|
||||
meta=file.meta,
|
||||
created_at=file.created_at,
|
||||
updated_at=file.updated_at,
|
||||
)
|
||||
for file in db.query(File)
|
||||
.filter(File.id.in_(ids))
|
||||
.order_by(File.updated_at.desc())
|
||||
.all()
|
||||
]
|
||||
|
||||
def get_files_by_user_id(self, user_id: str) -> list[FileModel]:
|
||||
with get_db() as db:
|
||||
return [
|
||||
FileModel.model_validate(file)
|
||||
for file in db.query(File).filter_by(user_id=user_id).all()
|
||||
]
|
||||
|
||||
def update_file_hash_by_id(self, id: str, hash: str) -> Optional[FileModel]:
|
||||
with get_db() as db:
|
||||
try:
|
||||
file = db.query(File).filter_by(id=id).first()
|
||||
file.hash = hash
|
||||
db.commit()
|
||||
|
||||
return FileModel.model_validate(file)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def update_file_data_by_id(self, id: str, data: dict) -> Optional[FileModel]:
|
||||
with get_db() as db:
|
||||
try:
|
||||
file = db.query(File).filter_by(id=id).first()
|
||||
file.data = {**(file.data if file.data else {}), **data}
|
||||
db.commit()
|
||||
return FileModel.model_validate(file)
|
||||
except Exception:
|
||||
|
||||
return None
|
||||
|
||||
def update_file_metadata_by_id(self, id: str, meta: dict) -> Optional[FileModel]:
|
||||
with get_db() as db:
|
||||
try:
|
||||
file = db.query(File).filter_by(id=id).first()
|
||||
file.meta = {**(file.meta if file.meta else {}), **meta}
|
||||
db.commit()
|
||||
return FileModel.model_validate(file)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def delete_file_by_id(self, id: str) -> bool:
|
||||
with get_db() as db:
|
||||
try:
|
||||
db.query(File).filter_by(id=id).delete()
|
||||
db.commit()
|
||||
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def delete_all_files(self) -> bool:
|
||||
with get_db() as db:
|
||||
try:
|
||||
db.query(File).delete()
|
||||
db.commit()
|
||||
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
Files = FilesTable()
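
Note on `update_file_data_by_id` and `update_file_metadata_by_id` above: both combine the stored dict with the incoming one via `{**old, **new}`, which is a shallow merge. The sketch below isolates that expression in a hypothetical helper (`merge_json_field` is not part of the diff) to show that nested dicts are replaced wholesale rather than merged key by key.

```python
from typing import Optional


def merge_json_field(current: Optional[dict], patch: dict) -> dict:
    """Shallow-merge `patch` over `current`, treating None or empty as {}.

    Hypothetical helper mirroring the `{**(x if x else {}), **y}` expression
    used in the update methods; it is not defined in the original file.
    """
    return {**(current if current else {}), **patch}


existing = {"name": "a.txt", "tags": {"lang": "en", "kind": "doc"}}
merged = merge_json_field(existing, {"tags": {"lang": "fr"}, "size": 10})

# Top-level keys are combined, but the nested "tags" dict is replaced,
# so the "kind" entry is lost.
print(merged)
```

If callers need nested keys preserved, they would have to read the current value and merge it themselves before calling the update method.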
|
||||
|
|
@ -0,0 +1,271 @@
|
|||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.internal.db import Base, get_db
|
||||
from open_webui.models.chats import Chats
|
||||
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Column, Text, JSON, Boolean
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||
|
||||
|
||||
####################
|
||||
# Folder DB Schema
|
||||
####################
|
||||
|
||||
|
||||
class Folder(Base):
|
||||
__tablename__ = "folder"
|
||||
id = Column(Text, primary_key=True)
|
||||
parent_id = Column(Text, nullable=True)
|
||||
user_id = Column(Text)
|
||||
name = Column(Text)
|
||||
items = Column(JSON, nullable=True)
|
||||
meta = Column(JSON, nullable=True)
|
||||
is_expanded = Column(Boolean, default=False)
|
||||
created_at = Column(BigInteger)
|
||||
updated_at = Column(BigInteger)
|
||||
|
||||
|
||||
class FolderModel(BaseModel):
|
||||
id: str
|
||||
parent_id: Optional[str] = None
|
||||
user_id: str
|
||||
name: str
|
||||
items: Optional[dict] = None
|
||||
meta: Optional[dict] = None
|
||||
is_expanded: bool = False
|
||||
created_at: int
|
||||
updated_at: int
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
####################
|
||||
# Forms
|
||||
####################
|
||||
|
||||
|
||||
class FolderForm(BaseModel):
|
||||
name: str
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class FolderTable:
|
||||
def insert_new_folder(
|
||||
self, user_id: str, name: str, parent_id: Optional[str] = None
|
||||
) -> Optional[FolderModel]:
|
||||
with get_db() as db:
|
||||
id = str(uuid.uuid4())
|
||||
folder = FolderModel(
|
||||
**{
|
||||
"id": id,
|
||||
"user_id": user_id,
|
||||
"name": name,
|
||||
"parent_id": parent_id,
|
||||
"created_at": int(time.time()),
|
||||
"updated_at": int(time.time()),
|
||||
}
|
||||
)
|
||||
try:
|
||||
result = Folder(**folder.model_dump())
|
||||
db.add(result)
|
||||
db.commit()
|
||||
db.refresh(result)
|
||||
if result:
|
||||
return FolderModel.model_validate(result)
|
||||
else:
|
||||
return None
|
||||
except Exception as e:
|
||||
log.exception(f"Error inserting a new folder: {e}")
|
||||
return None
|
||||
|
||||
def get_folder_by_id_and_user_id(
|
||||
self, id: str, user_id: str
|
||||
) -> Optional[FolderModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
folder = db.query(Folder).filter_by(id=id, user_id=user_id).first()
|
||||
|
||||
if not folder:
|
||||
return None
|
||||
|
||||
return FolderModel.model_validate(folder)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_children_folders_by_id_and_user_id(
|
||||
self, id: str, user_id: str
|
||||
) -> Optional[list[FolderModel]]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
folders = []
|
||||
|
||||
def get_children(folder):
|
||||
children = self.get_folders_by_parent_id_and_user_id(
|
||||
folder.id, user_id
|
||||
)
|
||||
for child in children:
|
||||
get_children(child)
|
||||
folders.append(child)
|
||||
|
||||
folder = db.query(Folder).filter_by(id=id, user_id=user_id).first()
|
||||
if not folder:
|
||||
return None
|
||||
|
||||
get_children(folder)
|
||||
return folders
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_folders_by_user_id(self, user_id: str) -> list[FolderModel]:
|
||||
with get_db() as db:
|
||||
return [
|
||||
FolderModel.model_validate(folder)
|
||||
for folder in db.query(Folder).filter_by(user_id=user_id).all()
|
||||
]
|
||||
|
||||
def get_folder_by_parent_id_and_user_id_and_name(
|
||||
self, parent_id: Optional[str], user_id: str, name: str
|
||||
) -> Optional[FolderModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
# Check if folder exists
|
||||
folder = (
|
||||
db.query(Folder)
|
||||
.filter_by(parent_id=parent_id, user_id=user_id)
|
||||
.filter(Folder.name.ilike(name))
|
||||
.first()
|
||||
)
|
||||
|
||||
if not folder:
|
||||
return None
|
||||
|
||||
return FolderModel.model_validate(folder)
|
||||
except Exception as e:
|
||||
log.error(f"get_folder_by_parent_id_and_user_id_and_name: {e}")
|
||||
return None
|
||||
|
||||
def get_folders_by_parent_id_and_user_id(
|
||||
self, parent_id: Optional[str], user_id: str
|
||||
) -> list[FolderModel]:
|
||||
with get_db() as db:
|
||||
return [
|
||||
FolderModel.model_validate(folder)
|
||||
for folder in db.query(Folder)
|
||||
.filter_by(parent_id=parent_id, user_id=user_id)
|
||||
.all()
|
||||
]
|
||||
|
||||
def update_folder_parent_id_by_id_and_user_id(
|
||||
self,
|
||||
id: str,
|
||||
user_id: str,
|
||||
parent_id: str,
|
||||
) -> Optional[FolderModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
folder = db.query(Folder).filter_by(id=id, user_id=user_id).first()
|
||||
|
||||
if not folder:
|
||||
return None
|
||||
|
||||
folder.parent_id = parent_id
|
||||
folder.updated_at = int(time.time())
|
||||
|
||||
db.commit()
|
||||
|
||||
return FolderModel.model_validate(folder)
|
||||
except Exception as e:
|
||||
log.error(f"update_folder: {e}")
|
||||
return None
|
||||
|
||||
def update_folder_name_by_id_and_user_id(
|
||||
self, id: str, user_id: str, name: str
|
||||
) -> Optional[FolderModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
folder = db.query(Folder).filter_by(id=id, user_id=user_id).first()
|
||||
|
||||
if not folder:
|
||||
return None
|
||||
|
||||
existing_folder = (
|
||||
db.query(Folder)
|
||||
.filter_by(name=name, parent_id=folder.parent_id, user_id=user_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if existing_folder:
|
||||
return None
|
||||
|
||||
folder.name = name
|
||||
folder.updated_at = int(time.time())
|
||||
|
||||
db.commit()
|
||||
|
||||
return FolderModel.model_validate(folder)
|
||||
except Exception as e:
|
||||
log.error(f"update_folder: {e}")
|
||||
return None
|
||||
|
||||
def update_folder_is_expanded_by_id_and_user_id(
|
||||
self, id: str, user_id: str, is_expanded: bool
|
||||
) -> Optional[FolderModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
folder = db.query(Folder).filter_by(id=id, user_id=user_id).first()
|
||||
|
||||
if not folder:
|
||||
return None
|
||||
|
||||
folder.is_expanded = is_expanded
|
||||
folder.updated_at = int(time.time())
|
||||
|
||||
db.commit()
|
||||
|
||||
return FolderModel.model_validate(folder)
|
||||
except Exception as e:
|
||||
log.error(f"update_folder: {e}")
|
||||
return None
|
||||
|
||||
def delete_folder_by_id_and_user_id(self, id: str, user_id: str) -> bool:
|
||||
try:
|
||||
with get_db() as db:
|
||||
folder = db.query(Folder).filter_by(id=id, user_id=user_id).first()
|
||||
if not folder:
|
||||
return False
|
||||
|
||||
# Delete all chats in the folder
|
||||
Chats.delete_chats_by_user_id_and_folder_id(user_id, folder.id)
|
||||
|
||||
# Delete all children folders
|
||||
def delete_children(folder):
|
||||
folder_children = self.get_folders_by_parent_id_and_user_id(
|
||||
folder.id, user_id
|
||||
)
|
||||
for folder_child in folder_children:
|
||||
Chats.delete_chats_by_user_id_and_folder_id(
|
||||
user_id, folder_child.id
|
||||
)
|
||||
delete_children(folder_child)
|
||||
|
||||
folder = db.query(Folder).filter_by(id=folder_child.id).first()
|
||||
db.delete(folder)
|
||||
db.commit()
|
||||
|
||||
delete_children(folder)
|
||||
db.delete(folder)
|
||||
db.commit()
|
||||
return True
|
||||
except Exception as e:
|
||||
log.error(f"delete_folder: {e}")
|
||||
return False
|
||||
|
||||
|
||||
Folders = FolderTable()
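
Note on `get_children_folders_by_id_and_user_id` above: the nested `get_children` recurses into each child before appending it, so descendants are collected in post-order (deepest folders first) and the starting folder itself is not included. A minimal sketch of that traversal, assuming a hypothetical in-memory `PARENTS` mapping in place of the `folder` table:

```python
from typing import Optional

# Hypothetical stand-in for the folder table: folder id -> parent id.
PARENTS: dict[str, Optional[str]] = {
    "root": None,
    "a": "root",
    "b": "root",
    "a1": "a",
}


def children_of(parent_id: str) -> list[str]:
    """Return direct children, analogous to a parent_id lookup."""
    return [fid for fid, pid in PARENTS.items() if pid == parent_id]


def collect_descendants(folder_id: str) -> list[str]:
    """Collect all descendants of `folder_id` in post-order."""
    found: list[str] = []

    def visit(current: str) -> None:
        for child in children_of(current):
            visit(child)          # recurse first ...
            found.append(child)   # ... then record the child itself

    visit(folder_id)
    return found


descendants = collect_descendants("root")
print(descendants)
```

The same ordering matters for deletion: removing deepest folders first means no folder is deleted while it still has children.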
|
||||
|
|
@ -0,0 +1,274 @@
|
|||
import logging
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.internal.db import Base, JSONField, get_db
|
||||
from open_webui.models.users import Users
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Boolean, Column, String, Text
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||
|
||||
####################
|
||||
# Functions DB Schema
|
||||
####################
|
||||
|
||||
|
||||
class Function(Base):
|
||||
__tablename__ = "function"
|
||||
|
||||
id = Column(String, primary_key=True)
|
||||
user_id = Column(String)
|
||||
name = Column(Text)
|
||||
type = Column(Text)
|
||||
content = Column(Text)
|
||||
meta = Column(JSONField)
|
||||
valves = Column(JSONField)
|
||||
is_active = Column(Boolean)
|
||||
is_global = Column(Boolean)
|
||||
updated_at = Column(BigInteger)
|
||||
created_at = Column(BigInteger)
|
||||
|
||||
|
||||
class FunctionMeta(BaseModel):
|
||||
description: Optional[str] = None
|
||||
manifest: Optional[dict] = {}
|
||||
|
||||
|
||||
class FunctionModel(BaseModel):
|
||||
id: str
|
||||
user_id: str
|
||||
name: str
|
||||
type: str
|
||||
content: str
|
||||
meta: FunctionMeta
|
||||
is_active: bool = False
|
||||
is_global: bool = False
|
||||
updated_at: int # timestamp in epoch
|
||||
created_at: int # timestamp in epoch
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
####################
|
||||
# Forms
|
||||
####################
|
||||
|
||||
|
||||
class FunctionResponse(BaseModel):
|
||||
id: str
|
||||
user_id: str
|
||||
type: str
|
||||
name: str
|
||||
meta: FunctionMeta
|
||||
is_active: bool
|
||||
is_global: bool
|
||||
updated_at: int # timestamp in epoch
|
||||
created_at: int # timestamp in epoch
|
||||
|
||||
|
||||
class FunctionForm(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
content: str
|
||||
meta: FunctionMeta
|
||||
|
||||
|
||||
class FunctionValves(BaseModel):
|
||||
valves: Optional[dict] = None
|
||||
|
||||
|
||||
class FunctionsTable:
|
||||
def insert_new_function(
|
||||
self, user_id: str, type: str, form_data: FunctionForm
|
||||
) -> Optional[FunctionModel]:
|
||||
function = FunctionModel(
|
||||
**{
|
||||
**form_data.model_dump(),
|
||||
"user_id": user_id,
|
||||
"type": type,
|
||||
"updated_at": int(time.time()),
|
||||
"created_at": int(time.time()),
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
with get_db() as db:
|
||||
result = Function(**function.model_dump())
|
||||
db.add(result)
|
||||
db.commit()
|
||||
db.refresh(result)
|
||||
if result:
|
||||
return FunctionModel.model_validate(result)
|
||||
else:
|
||||
return None
|
||||
except Exception as e:
|
||||
log.exception(f"Error creating a new function: {e}")
|
||||
return None
|
||||
|
||||
def get_function_by_id(self, id: str) -> Optional[FunctionModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
function = db.get(Function, id)
|
||||
return FunctionModel.model_validate(function)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_functions(self, active_only=False) -> list[FunctionModel]:
|
||||
with get_db() as db:
|
||||
if active_only:
|
||||
return [
|
||||
FunctionModel.model_validate(function)
|
||||
for function in db.query(Function).filter_by(is_active=True).all()
|
||||
]
|
||||
else:
|
||||
return [
|
||||
FunctionModel.model_validate(function)
|
||||
for function in db.query(Function).all()
|
||||
]
|
||||
|
||||
def get_functions_by_type(
|
||||
self, type: str, active_only=False
|
||||
) -> list[FunctionModel]:
|
||||
with get_db() as db:
|
||||
if active_only:
|
||||
return [
|
||||
FunctionModel.model_validate(function)
|
||||
for function in db.query(Function)
|
||||
.filter_by(type=type, is_active=True)
|
||||
.all()
|
||||
]
|
||||
else:
|
||||
return [
|
||||
FunctionModel.model_validate(function)
|
||||
for function in db.query(Function).filter_by(type=type).all()
|
||||
]
|
||||
|
||||
def get_global_filter_functions(self) -> list[FunctionModel]:
|
||||
with get_db() as db:
|
||||
return [
|
||||
FunctionModel.model_validate(function)
|
||||
for function in db.query(Function)
|
||||
.filter_by(type="filter", is_active=True, is_global=True)
|
||||
.all()
|
||||
]
|
||||
|
||||
def get_global_action_functions(self) -> list[FunctionModel]:
|
||||
with get_db() as db:
|
||||
return [
|
||||
FunctionModel.model_validate(function)
|
||||
for function in db.query(Function)
|
||||
.filter_by(type="action", is_active=True, is_global=True)
|
||||
.all()
|
||||
]
|
||||
|
||||
def get_function_valves_by_id(self, id: str) -> Optional[dict]:
|
||||
with get_db() as db:
|
||||
try:
|
||||
function = db.get(Function, id)
|
||||
return function.valves if function.valves else {}
|
||||
except Exception as e:
|
||||
log.exception(f"Error getting function valves by id {id}: {e}")
|
||||
return None
|
||||
|
||||
def update_function_valves_by_id(
|
||||
self, id: str, valves: dict
|
||||
) -> Optional[FunctionModel]:
|
||||
with get_db() as db:
|
||||
try:
|
||||
function = db.get(Function, id)
|
||||
function.valves = valves
|
||||
function.updated_at = int(time.time())
|
||||
db.commit()
|
||||
db.refresh(function)
|
||||
return self.get_function_by_id(id)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_user_valves_by_id_and_user_id(
|
||||
self, id: str, user_id: str
|
||||
) -> Optional[dict]:
|
||||
try:
|
||||
user = Users.get_user_by_id(user_id)
|
||||
user_settings = user.settings.model_dump() if user.settings else {}
|
||||
|
||||
# Check if user has "functions" and "valves" settings
|
||||
if "functions" not in user_settings:
|
||||
user_settings["functions"] = {}
|
||||
if "valves" not in user_settings["functions"]:
|
||||
user_settings["functions"]["valves"] = {}
|
||||
|
||||
return user_settings["functions"]["valves"].get(id, {})
|
||||
except Exception as e:
|
||||
log.exception(
|
||||
f"Error getting user valves by id {id} and user id {user_id}: {e}"
|
||||
)
|
||||
return None
|
||||
|
||||
def update_user_valves_by_id_and_user_id(
|
||||
self, id: str, user_id: str, valves: dict
|
||||
) -> Optional[dict]:
|
||||
try:
|
||||
user = Users.get_user_by_id(user_id)
|
||||
user_settings = user.settings.model_dump() if user.settings else {}
|
||||
|
||||
# Check if user has "functions" and "valves" settings
|
||||
if "functions" not in user_settings:
|
||||
user_settings["functions"] = {}
|
||||
if "valves" not in user_settings["functions"]:
|
||||
user_settings["functions"]["valves"] = {}
|
||||
|
||||
user_settings["functions"]["valves"][id] = valves
|
||||
|
||||
# Update the user settings in the database
|
||||
Users.update_user_by_id(user_id, {"settings": user_settings})
|
||||
|
||||
return user_settings["functions"]["valves"][id]
|
||||
except Exception as e:
|
||||
log.exception(
|
||||
f"Error updating user valves by id {id} and user_id {user_id}: {e}"
|
||||
)
|
||||
return None
|
||||
|
||||
def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]:
|
||||
with get_db() as db:
|
||||
try:
|
||||
db.query(Function).filter_by(id=id).update(
|
||||
{
|
||||
**updated,
|
||||
"updated_at": int(time.time()),
|
||||
}
|
||||
)
|
||||
db.commit()
|
||||
return self.get_function_by_id(id)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def deactivate_all_functions(self) -> Optional[bool]:
|
||||
with get_db() as db:
|
||||
try:
|
||||
db.query(Function).update(
|
||||
{
|
||||
"is_active": False,
|
||||
"updated_at": int(time.time()),
|
||||
}
|
||||
)
|
||||
db.commit()
|
||||
return True
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def delete_function_by_id(self, id: str) -> bool:
|
||||
with get_db() as db:
|
||||
try:
|
||||
db.query(Function).filter_by(id=id).delete()
|
||||
db.commit()
|
||||
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
Functions = FunctionsTable()
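
Note on the user-valves methods above: per-user valves are stored inside the user's settings under `settings["functions"]["valves"][<function id>]`. The sketch below reproduces that nesting on a plain dict; `get_user_valves` and `set_user_valves` are hypothetical names, and persistence through `Users.update_user_by_id` is deliberately left out.

```python
def get_user_valves(settings: dict, function_id: str) -> dict:
    """Read valves for one function, defaulting to {} at every level."""
    return settings.get("functions", {}).get("valves", {}).get(function_id, {})


def set_user_valves(settings: dict, function_id: str, valves: dict) -> dict:
    """Write valves for one function, creating missing levels in place."""
    settings.setdefault("functions", {}).setdefault("valves", {})[function_id] = valves
    return settings["functions"]["valves"][function_id]


user_settings: dict = {}
before = get_user_valves(user_settings, "my_filter")
after = set_user_valves(user_settings, "my_filter", {"threshold": 3})
print(before, after, user_settings)
```

Unlike the original read path, this sketch does not add empty `functions`/`valves` keys when only reading.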
|
||||
|
|
@ -0,0 +1,211 @@
|
|||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional
|
||||
import uuid
|
||||
|
||||
from open_webui.internal.db import Base, get_db
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
from open_webui.models.files import FileMetadataResponse
|
||||
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Column, String, Text, JSON, func
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||
|
||||
####################
|
||||
# UserGroup DB Schema
|
||||
####################
|
||||
|
||||
|
||||
class Group(Base):
|
||||
__tablename__ = "group"
|
||||
|
||||
id = Column(Text, unique=True, primary_key=True)
|
||||
user_id = Column(Text)
|
||||
|
||||
name = Column(Text)
|
||||
description = Column(Text)
|
||||
|
||||
data = Column(JSON, nullable=True)
|
||||
meta = Column(JSON, nullable=True)
|
||||
|
||||
permissions = Column(JSON, nullable=True)
|
||||
user_ids = Column(JSON, nullable=True)
|
||||
|
||||
created_at = Column(BigInteger)
|
||||
updated_at = Column(BigInteger)
|
||||
|
||||
|
||||
class GroupModel(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
id: str
|
||||
user_id: str
|
||||
|
||||
name: str
|
||||
description: str
|
||||
|
||||
data: Optional[dict] = None
|
||||
meta: Optional[dict] = None
|
||||
|
||||
permissions: Optional[dict] = None
|
||||
user_ids: list[str] = []
|
||||
|
||||
created_at: int # timestamp in epoch
|
||||
updated_at: int # timestamp in epoch
|
||||
|
||||
|
||||
####################
|
||||
# Forms
|
||||
####################
|
||||
|
||||
|
||||
class GroupResponse(BaseModel):
|
||||
id: str
|
||||
user_id: str
|
||||
name: str
|
||||
description: str
|
||||
permissions: Optional[dict] = None
|
||||
data: Optional[dict] = None
|
||||
meta: Optional[dict] = None
|
||||
user_ids: list[str] = []
|
||||
created_at: int # timestamp in epoch
|
||||
updated_at: int # timestamp in epoch
|
||||
|
||||
|
||||
class GroupForm(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
permissions: Optional[dict] = None
|
||||
|
||||
|
||||
class GroupUpdateForm(GroupForm):
|
||||
user_ids: Optional[list[str]] = None
|
||||
|
||||
|
||||
class GroupTable:
|
||||
def insert_new_group(
|
||||
self, user_id: str, form_data: GroupForm
|
||||
) -> Optional[GroupModel]:
|
||||
with get_db() as db:
|
||||
group = GroupModel(
|
||||
**{
|
||||
**form_data.model_dump(exclude_none=True),
|
||||
"id": str(uuid.uuid4()),
|
||||
"user_id": user_id,
|
||||
"created_at": int(time.time()),
|
||||
"updated_at": int(time.time()),
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
result = Group(**group.model_dump())
|
||||
db.add(result)
|
||||
db.commit()
|
||||
db.refresh(result)
|
||||
if result:
|
||||
return GroupModel.model_validate(result)
|
||||
else:
|
||||
return None
|
||||
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_groups(self) -> list[GroupModel]:
|
||||
with get_db() as db:
|
||||
return [
|
||||
GroupModel.model_validate(group)
|
||||
for group in db.query(Group).order_by(Group.updated_at.desc()).all()
|
||||
]
|
||||
|
||||
def get_groups_by_member_id(self, user_id: str) -> list[GroupModel]:
|
||||
with get_db() as db:
|
||||
return [
|
||||
GroupModel.model_validate(group)
|
||||
for group in db.query(Group)
|
||||
.filter(
|
||||
func.json_array_length(Group.user_ids) > 0
|
||||
) # Ensure array exists
|
||||
.filter(
|
||||
Group.user_ids.cast(String).like(f'%"{user_id}"%')
|
||||
) # String-based check
|
||||
.order_by(Group.updated_at.desc())
|
||||
.all()
|
||||
]
|
||||
|
||||
def get_group_by_id(self, id: str) -> Optional[GroupModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
group = db.query(Group).filter_by(id=id).first()
|
||||
return GroupModel.model_validate(group) if group else None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_group_user_ids_by_id(self, id: str) -> Optional[list[str]]:
|
||||
group = self.get_group_by_id(id)
|
||||
if group:
|
||||
return group.user_ids
|
||||
else:
|
||||
return None
|
||||
|
||||
def update_group_by_id(
|
||||
self, id: str, form_data: GroupUpdateForm, overwrite: bool = False
|
||||
) -> Optional[GroupModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
db.query(Group).filter_by(id=id).update(
|
||||
{
|
||||
**form_data.model_dump(exclude_none=True),
|
||||
"updated_at": int(time.time()),
|
||||
}
|
||||
)
|
||||
db.commit()
|
||||
return self.get_group_by_id(id=id)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
return None
|
||||
|
||||
def delete_group_by_id(self, id: str) -> bool:
|
||||
try:
|
||||
with get_db() as db:
|
||||
db.query(Group).filter_by(id=id).delete()
|
||||
db.commit()
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def delete_all_groups(self) -> bool:
|
||||
with get_db() as db:
|
||||
try:
|
||||
db.query(Group).delete()
|
||||
db.commit()
|
||||
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def remove_user_from_all_groups(self, user_id: str) -> bool:
|
||||
with get_db() as db:
|
||||
try:
|
||||
groups = self.get_groups_by_member_id(user_id)
|
||||
|
||||
for group in groups:
|
||||
group.user_ids.remove(user_id)
|
||||
db.query(Group).filter_by(id=group.id).update(
|
||||
{
|
||||
"user_ids": group.user_ids,
|
||||
"updated_at": int(time.time()),
|
||||
}
|
||||
)
|
||||
db.commit()
|
||||
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
Groups = GroupTable()
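
Note on `get_groups_by_member_id` above: membership is tested by casting the JSON `user_ids` column to a string and matching `%"<user_id>"%`. The sketch below imitates that check in plain Python, assuming the stored text looks like `json.dumps` output (the exact serialization depends on the database backend). `like_membership` is a hypothetical name.

```python
import json


def like_membership(user_ids: list[str], user_id: str) -> bool:
    """Approximate the LIKE '%"<user_id>"%' filter on serialized JSON."""
    serialized = json.dumps(user_ids)
    # The surrounding double quotes keep "u1" from matching inside "u10".
    return f'"{user_id}"' in serialized


exact = like_membership(["u1", "u2"], "u1")
prefix_only = like_membership(["u10", "u2"], "u1")
naive_substring = "u1" in json.dumps(["u10", "u2"])
print(exact, prefix_only, naive_substring)
```

The quoted pattern avoids prefix collisions between ids, but it still relies on ids not containing characters that are special to `LIKE`, such as `%` or `_`.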
|
||||
|
|
@ -0,0 +1,221 @@
|
|||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional
|
||||
import uuid
|
||||
|
||||
from open_webui.internal.db import Base, get_db
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
from open_webui.models.files import FileMetadataResponse
|
||||
from open_webui.models.users import Users, UserResponse
|
||||
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Column, String, Text, JSON
|
||||
|
||||
from open_webui.utils.access_control import has_access
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||
|
||||
####################
|
||||
# Knowledge DB Schema
|
||||
####################
|
||||
|
||||
|
||||
class Knowledge(Base):
|
||||
__tablename__ = "knowledge"
|
||||
|
||||
id = Column(Text, unique=True, primary_key=True)
|
||||
user_id = Column(Text)
|
||||
|
||||
name = Column(Text)
|
||||
description = Column(Text)
|
||||
|
||||
data = Column(JSON, nullable=True)
|
||||
meta = Column(JSON, nullable=True)
|
||||
|
||||
access_control = Column(JSON, nullable=True) # Controls data access levels.
|
||||
# Defines access control rules for this entry.
|
||||
# - `None`: Public access, available to all users with the "user" role.
|
||||
# - `{}`: Private access, restricted exclusively to the owner.
|
||||
# - Custom permissions: Specific access control for reading and writing;
|
||||
# Can specify group or user-level restrictions:
|
||||
# {
|
||||
# "read": {
|
||||
# "group_ids": ["group_id1", "group_id2"],
|
||||
# "user_ids": ["user_id1", "user_id2"]
|
||||
# },
|
||||
# "write": {
|
||||
# "group_ids": ["group_id1", "group_id2"],
|
||||
# "user_ids": ["user_id1", "user_id2"]
|
||||
# }
|
||||
# }
|
||||
|
||||
created_at = Column(BigInteger)
|
||||
updated_at = Column(BigInteger)
|
||||
|
||||
|
||||
class KnowledgeModel(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: str
|
||||
user_id: str
|
||||
|
||||
name: str
|
||||
description: str
|
||||
|
||||
data: Optional[dict] = None
|
||||
meta: Optional[dict] = None
|
||||
|
||||
access_control: Optional[dict] = None
|
||||
|
||||
created_at: int # timestamp in epoch
|
||||
updated_at: int # timestamp in epoch
|
||||
|
||||
|
||||
####################
|
||||
# Forms
|
||||
####################
|
||||
|
||||
|
||||
class KnowledgeUserModel(KnowledgeModel):
|
||||
user: Optional[UserResponse] = None
|
||||
|
||||
|
||||
class KnowledgeResponse(KnowledgeModel):
|
||||
files: Optional[list[FileMetadataResponse | dict]] = None
|
||||
|
||||
|
||||
class KnowledgeUserResponse(KnowledgeUserModel):
|
||||
files: Optional[list[FileMetadataResponse | dict]] = None
|
||||
|
||||
|
||||
class KnowledgeForm(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
data: Optional[dict] = None
|
||||
access_control: Optional[dict] = None
|
||||
|
||||
|
||||
class KnowledgeTable:
|
||||
def insert_new_knowledge(
|
||||
self, user_id: str, form_data: KnowledgeForm
|
||||
) -> Optional[KnowledgeModel]:
|
||||
with get_db() as db:
|
||||
knowledge = KnowledgeModel(
|
||||
**{
|
||||
**form_data.model_dump(),
|
||||
"id": str(uuid.uuid4()),
|
||||
"user_id": user_id,
|
||||
"created_at": int(time.time()),
|
||||
"updated_at": int(time.time()),
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
result = Knowledge(**knowledge.model_dump())
|
||||
db.add(result)
|
||||
db.commit()
|
||||
db.refresh(result)
|
||||
if result:
|
||||
return KnowledgeModel.model_validate(result)
|
||||
else:
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_knowledge_bases(self) -> list[KnowledgeUserModel]:
|
||||
with get_db() as db:
|
||||
knowledge_bases = []
|
||||
for knowledge in (
|
||||
db.query(Knowledge).order_by(Knowledge.updated_at.desc()).all()
|
||||
):
|
||||
user = Users.get_user_by_id(knowledge.user_id)
|
||||
knowledge_bases.append(
|
||||
KnowledgeUserModel.model_validate(
|
||||
{
|
||||
**KnowledgeModel.model_validate(knowledge).model_dump(),
|
||||
"user": user.model_dump() if user else None,
|
||||
}
|
||||
)
|
||||
)
|
||||
return knowledge_bases
|
||||
|
||||
def get_knowledge_bases_by_user_id(
|
||||
self, user_id: str, permission: str = "write"
|
||||
) -> list[KnowledgeUserModel]:
|
||||
knowledge_bases = self.get_knowledge_bases()
|
||||
return [
|
||||
knowledge_base
|
||||
for knowledge_base in knowledge_bases
|
||||
if knowledge_base.user_id == user_id
|
||||
or has_access(user_id, permission, knowledge_base.access_control)
|
||||
]
|
||||
|
||||
def get_knowledge_by_id(self, id: str) -> Optional[KnowledgeModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
knowledge = db.query(Knowledge).filter_by(id=id).first()
|
||||
return KnowledgeModel.model_validate(knowledge) if knowledge else None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def update_knowledge_by_id(
|
||||
self, id: str, form_data: KnowledgeForm, overwrite: bool = False
|
||||
) -> Optional[KnowledgeModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
knowledge = self.get_knowledge_by_id(id=id)
|
||||
db.query(Knowledge).filter_by(id=id).update(
|
||||
{
|
||||
**form_data.model_dump(),
|
||||
"updated_at": int(time.time()),
|
||||
}
|
||||
)
|
||||
db.commit()
|
||||
return self.get_knowledge_by_id(id=id)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
return None
|
||||
|
||||
def update_knowledge_data_by_id(
|
||||
self, id: str, data: dict
|
||||
) -> Optional[KnowledgeModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
knowledge = self.get_knowledge_by_id(id=id)
|
||||
db.query(Knowledge).filter_by(id=id).update(
|
||||
{
|
||||
"data": data,
|
||||
"updated_at": int(time.time()),
|
||||
}
|
||||
)
|
||||
db.commit()
|
||||
return self.get_knowledge_by_id(id=id)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
return None
|
||||
|
||||
def delete_knowledge_by_id(self, id: str) -> bool:
|
||||
try:
|
||||
with get_db() as db:
|
||||
db.query(Knowledge).filter_by(id=id).delete()
|
||||
db.commit()
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def delete_all_knowledge(self) -> bool:
|
||||
with get_db() as db:
|
||||
try:
|
||||
db.query(Knowledge).delete()
|
||||
db.commit()
|
||||
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
Knowledges = KnowledgeTable()
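
Note on the `access_control` comment in the `Knowledge` schema above: it describes three cases (`None` is public, `{}` is owner-only, otherwise per-permission `group_ids` and `user_ids`). The sketch below is one possible reading of that comment. `can_access` and its parameters are hypothetical, and the project's own `has_access` helper is not shown in this diff and may behave differently, for example in how it treats `None` for write permission.

```python
from typing import Optional


def can_access(
    user_id: str,
    group_ids: set[str],
    permission: str,
    access_control: Optional[dict],
    owner_id: str,
) -> bool:
    """Evaluate the documented access_control shapes (illustrative only)."""
    if user_id == owner_id:
        return True  # the owner always has access
    if access_control is None:
        return True  # assumption: None means public for any permission
    rule = access_control.get(permission, {})  # {} yields no grants (private)
    if user_id in rule.get("user_ids", []):
        return True
    return any(group in group_ids for group in rule.get("group_ids", []))


rules = {
    "read": {"group_ids": ["g1"], "user_ids": []},
    "write": {"group_ids": [], "user_ids": ["u3"]},
}
print(can_access("u2", {"g1"}, "read", rules, owner_id="u1"))
print(can_access("u2", {"g1"}, "write", rules, owner_id="u1"))
```

This mirrors the filter in `get_knowledge_bases_by_user_id`, which keeps an entry when the caller owns it or the access check passes.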
|
||||
|
|
@ -0,0 +1,137 @@
|
|||
import time
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.internal.db import Base, get_db
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Column, String, Text
|
||||
|
||||
####################
|
||||
# Memory DB Schema
|
||||
####################
|
||||
|
||||
|
||||
class Memory(Base):
|
||||
__tablename__ = "memory"
|
||||
|
||||
id = Column(String, primary_key=True)
|
||||
user_id = Column(String)
|
||||
content = Column(Text)
|
||||
updated_at = Column(BigInteger)
|
||||
created_at = Column(BigInteger)
|
||||
|
||||
|
||||
class MemoryModel(BaseModel):
|
||||
id: str
|
||||
user_id: str
|
||||
content: str
|
||||
updated_at: int # timestamp in epoch
|
||||
created_at: int # timestamp in epoch
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
####################
|
||||
# Forms
|
||||
####################
|
||||
|
||||
|
||||
class MemoriesTable:
|
||||
def insert_new_memory(
|
||||
self,
|
||||
user_id: str,
|
||||
content: str,
|
||||
) -> Optional[MemoryModel]:
|
||||
with get_db() as db:
|
||||
id = str(uuid.uuid4())
|
||||
|
||||
memory = MemoryModel(
|
||||
**{
|
||||
"id": id,
|
||||
"user_id": user_id,
|
||||
"content": content,
|
||||
"created_at": int(time.time()),
|
||||
"updated_at": int(time.time()),
|
||||
}
|
||||
)
|
||||
result = Memory(**memory.model_dump())
|
||||
db.add(result)
|
||||
db.commit()
|
||||
db.refresh(result)
|
||||
if result:
|
||||
return MemoryModel.model_validate(result)
|
||||
else:
|
||||
return None
|
||||
|
||||
def update_memory_by_id(
|
||||
self,
|
||||
id: str,
|
||||
content: str,
|
||||
) -> Optional[MemoryModel]:
|
||||
with get_db() as db:
|
||||
try:
|
||||
db.query(Memory).filter_by(id=id).update(
|
||||
{"content": content, "updated_at": int(time.time())}
|
||||
)
|
||||
db.commit()
|
||||
return self.get_memory_by_id(id)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_memories(self) -> Optional[list[MemoryModel]]:
|
||||
with get_db() as db:
|
||||
try:
|
||||
memories = db.query(Memory).all()
|
||||
return [MemoryModel.model_validate(memory) for memory in memories]
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_memories_by_user_id(self, user_id: str) -> Optional[list[MemoryModel]]:
|
||||
with get_db() as db:
|
||||
try:
|
||||
memories = db.query(Memory).filter_by(user_id=user_id).all()
|
||||
return [MemoryModel.model_validate(memory) for memory in memories]
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_memory_by_id(self, id: str) -> Optional[MemoryModel]:
|
||||
with get_db() as db:
|
||||
try:
|
||||
memory = db.get(Memory, id)
|
||||
return MemoryModel.model_validate(memory)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def delete_memory_by_id(self, id: str) -> bool:
|
||||
with get_db() as db:
|
||||
try:
|
||||
db.query(Memory).filter_by(id=id).delete()
|
||||
db.commit()
|
||||
|
||||
return True
|
||||
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def delete_memories_by_user_id(self, user_id: str) -> bool:
|
||||
with get_db() as db:
|
||||
try:
|
||||
db.query(Memory).filter_by(user_id=user_id).delete()
|
||||
db.commit()
|
||||
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def delete_memory_by_id_and_user_id(self, id: str, user_id: str) -> bool:
|
||||
with get_db() as db:
|
||||
try:
|
||||
db.query(Memory).filter_by(id=id, user_id=user_id).delete()
|
||||
db.commit()
|
||||
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
Memories = MemoriesTable()
|
||||
|
|
@ -0,0 +1,279 @@
|
|||
import json
|
||||
import time
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.internal.db import Base, get_db
|
||||
from open_webui.models.tags import TagModel, Tag, Tags
|
||||
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
|
||||
from sqlalchemy import or_, func, select, and_, text
|
||||
from sqlalchemy.sql import exists
|
||||
|
||||
####################
|
||||
# Message DB Schema
|
||||
####################
|
||||
|
||||
|
||||
class MessageReaction(Base):
|
||||
__tablename__ = "message_reaction"
|
||||
id = Column(Text, primary_key=True)
|
||||
user_id = Column(Text)
|
||||
message_id = Column(Text)
|
||||
name = Column(Text)
|
||||
created_at = Column(BigInteger)
|
||||
|
||||
|
||||
class MessageReactionModel(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: str
|
||||
user_id: str
|
||||
message_id: str
|
||||
name: str
|
||||
created_at: int # epoch timestamp in nanoseconds (time.time_ns)
|
||||
|
||||
|
||||
class Message(Base):
|
||||
__tablename__ = "message"
|
||||
id = Column(Text, primary_key=True)
|
||||
|
||||
user_id = Column(Text)
|
||||
channel_id = Column(Text, nullable=True)
|
||||
|
||||
parent_id = Column(Text, nullable=True)
|
||||
|
||||
content = Column(Text)
|
||||
data = Column(JSON, nullable=True)
|
||||
meta = Column(JSON, nullable=True)
|
||||
|
||||
created_at = Column(BigInteger) # time_ns
|
||||
updated_at = Column(BigInteger) # time_ns
|
||||
|
||||
|
||||
class MessageModel(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: str
|
||||
user_id: str
|
||||
channel_id: Optional[str] = None
|
||||
|
||||
parent_id: Optional[str] = None
|
||||
|
||||
content: str
|
||||
data: Optional[dict] = None
|
||||
meta: Optional[dict] = None
|
||||
|
||||
created_at: int # epoch timestamp in nanoseconds (time.time_ns)
|
||||
updated_at: int # epoch timestamp in nanoseconds (time.time_ns)
|
||||
|
||||
|
||||
####################
|
||||
# Forms
|
||||
####################
|
||||
|
||||
|
||||
class MessageForm(BaseModel):
|
||||
content: str
|
||||
parent_id: Optional[str] = None
|
||||
data: Optional[dict] = None
|
||||
meta: Optional[dict] = None
|
||||
|
||||
|
||||
class Reactions(BaseModel):
|
||||
name: str
|
||||
user_ids: list[str]
|
||||
count: int
|
||||
|
||||
|
||||
class MessageResponse(MessageModel):
|
||||
latest_reply_at: Optional[int]
|
||||
reply_count: int
|
||||
reactions: list[Reactions]
|
||||
|
||||
|
||||
class MessageTable:
|
||||
def insert_new_message(
|
||||
self, form_data: MessageForm, channel_id: str, user_id: str
|
||||
) -> Optional[MessageModel]:
|
||||
with get_db() as db:
|
||||
id = str(uuid.uuid4())
|
||||
|
||||
ts = int(time.time_ns())
|
||||
message = MessageModel(
|
||||
**{
|
||||
"id": id,
|
||||
"user_id": user_id,
|
||||
"channel_id": channel_id,
|
||||
"parent_id": form_data.parent_id,
|
||||
"content": form_data.content,
|
||||
"data": form_data.data,
|
||||
"meta": form_data.meta,
|
||||
"created_at": ts,
|
||||
"updated_at": ts,
|
||||
}
|
||||
)
|
||||
|
||||
result = Message(**message.model_dump())
|
||||
db.add(result)
|
||||
db.commit()
|
||||
db.refresh(result)
|
||||
return MessageModel.model_validate(result) if result else None
|
||||
|
||||
def get_message_by_id(self, id: str) -> Optional[MessageResponse]:
|
||||
with get_db() as db:
|
||||
message = db.get(Message, id)
|
||||
if not message:
|
||||
return None
|
||||
|
||||
reactions = self.get_reactions_by_message_id(id)
|
||||
replies = self.get_replies_by_message_id(id)
|
||||
|
||||
return MessageResponse(
|
||||
**{
|
||||
**MessageModel.model_validate(message).model_dump(),
|
||||
"latest_reply_at": replies[0].created_at if replies else None,
|
||||
"reply_count": len(replies),
|
||||
"reactions": reactions,
|
||||
}
|
||||
)
|
||||
|
||||
def get_replies_by_message_id(self, id: str) -> list[MessageModel]:
|
||||
with get_db() as db:
|
||||
all_messages = (
|
||||
db.query(Message)
|
||||
.filter_by(parent_id=id)
|
||||
.order_by(Message.created_at.desc())
|
||||
.all()
|
||||
)
|
||||
return [MessageModel.model_validate(message) for message in all_messages]
|
||||
|
||||
def get_reply_user_ids_by_message_id(self, id: str) -> list[str]:
|
||||
with get_db() as db:
|
||||
return [
|
||||
message.user_id
|
||||
for message in db.query(Message).filter_by(parent_id=id).all()
|
||||
]
|
||||
|
||||
def get_messages_by_channel_id(
|
||||
self, channel_id: str, skip: int = 0, limit: int = 50
|
||||
) -> list[MessageModel]:
|
||||
with get_db() as db:
|
||||
all_messages = (
|
||||
db.query(Message)
|
||||
.filter_by(channel_id=channel_id, parent_id=None)
|
||||
.order_by(Message.created_at.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
return [MessageModel.model_validate(message) for message in all_messages]
|
||||
|
||||
def get_messages_by_parent_id(
|
||||
self, channel_id: str, parent_id: str, skip: int = 0, limit: int = 50
|
||||
) -> list[MessageModel]:
|
||||
with get_db() as db:
|
||||
message = db.get(Message, parent_id)
|
||||
|
||||
if not message:
|
||||
return []
|
||||
|
||||
all_messages = (
|
||||
db.query(Message)
|
||||
.filter_by(channel_id=channel_id, parent_id=parent_id)
|
||||
.order_by(Message.created_at.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
# If length of all_messages is less than limit, then add the parent message
|
||||
if len(all_messages) < limit:
|
||||
all_messages.append(message)
|
||||
|
||||
return [MessageModel.model_validate(message) for message in all_messages]
|
||||
|
||||
def update_message_by_id(
|
||||
self, id: str, form_data: MessageForm
|
||||
) -> Optional[MessageModel]:
|
||||
with get_db() as db:
|
||||
message = db.get(Message, id)
|
||||
if not message:
|
||||
return None
|
||||
message.content = form_data.content
|
||||
message.data = form_data.data
|
||||
message.meta = form_data.meta
|
||||
message.updated_at = int(time.time_ns())
|
||||
db.commit()
|
||||
db.refresh(message)
|
||||
return MessageModel.model_validate(message) if message else None
|
||||
|
||||
def add_reaction_to_message(
|
||||
self, id: str, user_id: str, name: str
|
||||
) -> Optional[MessageReactionModel]:
|
||||
with get_db() as db:
|
||||
reaction_id = str(uuid.uuid4())
|
||||
reaction = MessageReactionModel(
|
||||
id=reaction_id,
|
||||
user_id=user_id,
|
||||
message_id=id,
|
||||
name=name,
|
||||
created_at=int(time.time_ns()),
|
||||
)
|
||||
result = MessageReaction(**reaction.model_dump())
|
||||
db.add(result)
|
||||
db.commit()
|
||||
db.refresh(result)
|
||||
return MessageReactionModel.model_validate(result) if result else None
|
||||
|
||||
def get_reactions_by_message_id(self, id: str) -> list[Reactions]:
|
||||
with get_db() as db:
|
||||
all_reactions = db.query(MessageReaction).filter_by(message_id=id).all()
|
||||
|
||||
reactions = {}
|
||||
for reaction in all_reactions:
|
||||
if reaction.name not in reactions:
|
||||
reactions[reaction.name] = {
|
||||
"name": reaction.name,
|
||||
"user_ids": [],
|
||||
"count": 0,
|
||||
}
|
||||
reactions[reaction.name]["user_ids"].append(reaction.user_id)
|
||||
reactions[reaction.name]["count"] += 1
|
||||
|
||||
return [Reactions(**reaction) for reaction in reactions.values()]
|
||||
|
||||
def remove_reaction_by_id_and_user_id_and_name(
|
||||
self, id: str, user_id: str, name: str
|
||||
) -> bool:
|
||||
with get_db() as db:
|
||||
db.query(MessageReaction).filter_by(
|
||||
message_id=id, user_id=user_id, name=name
|
||||
).delete()
|
||||
db.commit()
|
||||
return True
|
||||
|
||||
def delete_reactions_by_id(self, id: str) -> bool:
|
||||
with get_db() as db:
|
||||
db.query(MessageReaction).filter_by(message_id=id).delete()
|
||||
db.commit()
|
||||
return True
|
||||
|
||||
def delete_replies_by_id(self, id: str) -> bool:
|
||||
with get_db() as db:
|
||||
db.query(Message).filter_by(parent_id=id).delete()
|
||||
db.commit()
|
||||
return True
|
||||
|
||||
def delete_message_by_id(self, id: str) -> bool:
|
||||
with get_db() as db:
|
||||
db.query(Message).filter_by(id=id).delete()
|
||||
|
||||
# Delete all reactions to this message
|
||||
db.query(MessageReaction).filter_by(message_id=id).delete()
|
||||
|
||||
db.commit()
|
||||
return True
|
||||
|
||||
|
||||
Messages = MessageTable()
|
||||
|
|
@ -0,0 +1,273 @@
|
|||
import logging
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.internal.db import Base, JSONField, get_db
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
from open_webui.models.users import Users, UserResponse
|
||||
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from sqlalchemy import or_, and_, func
|
||||
from sqlalchemy.dialects import postgresql, sqlite
|
||||
from sqlalchemy import BigInteger, Column, Text, JSON, Boolean
|
||||
|
||||
|
||||
from open_webui.utils.access_control import has_access
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||
|
||||
|
||||
####################
|
||||
# Models DB Schema
|
||||
####################
|
||||
|
||||
|
||||
# ModelParams is a model for the data stored in the params field of the Model table
|
||||
class ModelParams(BaseModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
pass
|
||||
|
||||
|
||||
# ModelMeta is a model for the data stored in the meta field of the Model table
|
||||
class ModelMeta(BaseModel):
|
||||
profile_image_url: Optional[str] = "/static/favicon.png"
|
||||
|
||||
description: Optional[str] = None
|
||||
"""
|
||||
User-facing description of the model.
|
||||
"""
|
||||
|
||||
capabilities: Optional[dict] = None
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class Model(Base):
|
||||
__tablename__ = "model"
|
||||
|
||||
id = Column(Text, primary_key=True)
|
||||
"""
|
||||
The model's id as used in the API. If set to an existing model, it will override the model.
|
||||
"""
|
||||
user_id = Column(Text)
|
||||
|
||||
base_model_id = Column(Text, nullable=True)
|
||||
"""
|
||||
An optional pointer to the actual model that should be used when proxying requests.
|
||||
"""
|
||||
|
||||
name = Column(Text)
|
||||
"""
|
||||
The human-readable display name of the model.
|
||||
"""
|
||||
|
||||
params = Column(JSONField)
|
||||
"""
|
||||
Holds a JSON encoded blob of parameters, see `ModelParams`.
|
||||
"""
|
||||
|
||||
meta = Column(JSONField)
|
||||
"""
|
||||
Holds a JSON encoded blob of metadata, see `ModelMeta`.
|
||||
"""
|
||||
|
||||
access_control = Column(JSON, nullable=True) # Controls data access levels.
|
||||
# Defines access control rules for this entry.
|
||||
# - `None`: Public access, available to all users with the "user" role.
|
||||
# - `{}`: Private access, restricted exclusively to the owner.
|
||||
# - Custom permissions: Specific access control for reading and writing;
|
||||
# Can specify group or user-level restrictions:
|
||||
# {
|
||||
# "read": {
|
||||
# "group_ids": ["group_id1", "group_id2"],
|
||||
# "user_ids": ["user_id1", "user_id2"]
|
||||
# },
|
||||
# "write": {
|
||||
# "group_ids": ["group_id1", "group_id2"],
|
||||
# "user_ids": ["user_id1", "user_id2"]
|
||||
# }
|
||||
# }
|
||||
|
||||
is_active = Column(Boolean, default=True)
|
||||
|
||||
updated_at = Column(BigInteger)
|
||||
created_at = Column(BigInteger)
|
||||
|
||||
|
||||
class ModelModel(BaseModel):
|
||||
id: str
|
||||
user_id: str
|
||||
base_model_id: Optional[str] = None
|
||||
|
||||
name: str
|
||||
params: ModelParams
|
||||
meta: ModelMeta
|
||||
|
||||
access_control: Optional[dict] = None
|
||||
|
||||
is_active: bool
|
||||
updated_at: int # timestamp in epoch
|
||||
created_at: int # timestamp in epoch
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
####################
|
||||
# Forms
|
||||
####################
|
||||
|
||||
|
||||
class ModelUserResponse(ModelModel):
|
||||
user: Optional[UserResponse] = None
|
||||
|
||||
|
||||
class ModelResponse(ModelModel):
|
||||
pass
|
||||
|
||||
|
||||
class ModelForm(BaseModel):
|
||||
id: str
|
||||
base_model_id: Optional[str] = None
|
||||
name: str
|
||||
meta: ModelMeta
|
||||
params: ModelParams
|
||||
access_control: Optional[dict] = None
|
||||
is_active: bool = True
|
||||
|
||||
|
||||
class ModelsTable:
|
||||
def insert_new_model(
|
||||
self, form_data: ModelForm, user_id: str
|
||||
) -> Optional[ModelModel]:
|
||||
model = ModelModel(
|
||||
**{
|
||||
**form_data.model_dump(),
|
||||
"user_id": user_id,
|
||||
"created_at": int(time.time()),
|
||||
"updated_at": int(time.time()),
|
||||
}
|
||||
)
|
||||
try:
|
||||
with get_db() as db:
|
||||
result = Model(**model.model_dump())
|
||||
db.add(result)
|
||||
db.commit()
|
||||
db.refresh(result)
|
||||
|
||||
if result:
|
||||
return ModelModel.model_validate(result)
|
||||
else:
|
||||
return None
|
||||
except Exception as e:
|
||||
log.exception(f"Failed to insert a new model: {e}")
|
||||
return None
|
||||
|
||||
def get_all_models(self) -> list[ModelModel]:
|
||||
with get_db() as db:
|
||||
return [ModelModel.model_validate(model) for model in db.query(Model).all()]
|
||||
|
||||
def get_models(self) -> list[ModelUserResponse]:
|
||||
with get_db() as db:
|
||||
models = []
|
||||
for model in db.query(Model).filter(Model.base_model_id != None).all():
|
||||
user = Users.get_user_by_id(model.user_id)
|
||||
models.append(
|
||||
ModelUserResponse.model_validate(
|
||||
{
|
||||
**ModelModel.model_validate(model).model_dump(),
|
||||
"user": user.model_dump() if user else None,
|
||||
}
|
||||
)
|
||||
)
|
||||
return models
|
||||
|
||||
def get_base_models(self) -> list[ModelModel]:
|
||||
with get_db() as db:
|
||||
return [
|
||||
ModelModel.model_validate(model)
|
||||
for model in db.query(Model).filter(Model.base_model_id == None).all()
|
||||
]
|
||||
|
||||
def get_models_by_user_id(
|
||||
self, user_id: str, permission: str = "write"
|
||||
) -> list[ModelUserResponse]:
|
||||
models = self.get_models()
|
||||
return [
|
||||
model
|
||||
for model in models
|
||||
if model.user_id == user_id
|
||||
or has_access(user_id, permission, model.access_control)
|
||||
]
|
||||
|
||||
def get_model_by_id(self, id: str) -> Optional[ModelModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
model = db.get(Model, id)
|
||||
return ModelModel.model_validate(model)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def toggle_model_by_id(self, id: str) -> Optional[ModelModel]:
|
||||
with get_db() as db:
|
||||
try:
|
||||
is_active = db.query(Model).filter_by(id=id).first().is_active
|
||||
|
||||
db.query(Model).filter_by(id=id).update(
|
||||
{
|
||||
"is_active": not is_active,
|
||||
"updated_at": int(time.time()),
|
||||
}
|
||||
)
|
||||
db.commit()
|
||||
|
||||
return self.get_model_by_id(id)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
# overwrite every form field except the id
|
||||
result = (
|
||||
db.query(Model)
|
||||
.filter_by(id=id)
|
||||
.update(model.model_dump(exclude={"id"}))
|
||||
)
|
||||
db.commit()
|
||||
|
||||
model = db.get(Model, id)
|
||||
db.refresh(model)
|
||||
return ModelModel.model_validate(model)
|
||||
except Exception as e:
|
||||
log.exception(f"Failed to update the model by id {id}: {e}")
|
||||
return None
|
||||
|
||||
def delete_model_by_id(self, id: str) -> bool:
|
||||
try:
|
||||
with get_db() as db:
|
||||
db.query(Model).filter_by(id=id).delete()
|
||||
db.commit()
|
||||
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def delete_all_models(self) -> bool:
|
||||
try:
|
||||
with get_db() as db:
|
||||
db.query(Model).delete()
|
||||
db.commit()
|
||||
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
Models = ModelsTable()
|
||||
|
|
@ -0,0 +1,159 @@
|
|||
import time
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.internal.db import Base, get_db
|
||||
from open_webui.models.users import Users, UserResponse
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Column, String, Text, JSON
|
||||
|
||||
from open_webui.utils.access_control import has_access
|
||||
|
||||
####################
|
||||
# Prompts DB Schema
|
||||
####################
|
||||
|
||||
|
||||
class Prompt(Base):
|
||||
__tablename__ = "prompt"
|
||||
|
||||
command = Column(String, primary_key=True)
|
||||
user_id = Column(String)
|
||||
title = Column(Text)
|
||||
content = Column(Text)
|
||||
timestamp = Column(BigInteger)
|
||||
|
||||
access_control = Column(JSON, nullable=True) # Controls data access levels.
|
||||
# Defines access control rules for this entry.
|
||||
# - `None`: Public access, available to all users with the "user" role.
|
||||
# - `{}`: Private access, restricted exclusively to the owner.
|
||||
# - Custom permissions: Specific access control for reading and writing;
|
||||
# Can specify group or user-level restrictions:
|
||||
# {
|
||||
# "read": {
|
||||
# "group_ids": ["group_id1", "group_id2"],
|
||||
# "user_ids": ["user_id1", "user_id2"]
|
||||
# },
|
||||
# "write": {
|
||||
# "group_ids": ["group_id1", "group_id2"],
|
||||
# "user_ids": ["user_id1", "user_id2"]
|
||||
# }
|
||||
# }
|
||||
|
||||
|
||||
class PromptModel(BaseModel):
|
||||
command: str
|
||||
user_id: str
|
||||
title: str
|
||||
content: str
|
||||
timestamp: int # timestamp in epoch
|
||||
|
||||
access_control: Optional[dict] = None
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
####################
|
||||
# Forms
|
||||
####################
|
||||
|
||||
|
||||
class PromptUserResponse(PromptModel):
|
||||
user: Optional[UserResponse] = None
|
||||
|
||||
|
||||
class PromptForm(BaseModel):
|
||||
command: str
|
||||
title: str
|
||||
content: str
|
||||
access_control: Optional[dict] = None
|
||||
|
||||
|
||||
class PromptsTable:
|
||||
def insert_new_prompt(
|
||||
self, user_id: str, form_data: PromptForm
|
||||
) -> Optional[PromptModel]:
|
||||
prompt = PromptModel(
|
||||
**{
|
||||
"user_id": user_id,
|
||||
**form_data.model_dump(),
|
||||
"timestamp": int(time.time()),
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
with get_db() as db:
|
||||
result = Prompt(**prompt.model_dump())
|
||||
db.add(result)
|
||||
db.commit()
|
||||
db.refresh(result)
|
||||
if result:
|
||||
return PromptModel.model_validate(result)
|
||||
else:
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_prompt_by_command(self, command: str) -> Optional[PromptModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
prompt = db.query(Prompt).filter_by(command=command).first()
|
||||
return PromptModel.model_validate(prompt)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_prompts(self) -> list[PromptUserResponse]:
|
||||
with get_db() as db:
|
||||
prompts = []
|
||||
|
||||
for prompt in db.query(Prompt).order_by(Prompt.timestamp.desc()).all():
|
||||
user = Users.get_user_by_id(prompt.user_id)
|
||||
prompts.append(
|
||||
PromptUserResponse.model_validate(
|
||||
{
|
||||
**PromptModel.model_validate(prompt).model_dump(),
|
||||
"user": user.model_dump() if user else None,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
return prompts
|
||||
|
||||
def get_prompts_by_user_id(
|
||||
self, user_id: str, permission: str = "write"
|
||||
) -> list[PromptUserResponse]:
|
||||
prompts = self.get_prompts()
|
||||
|
||||
return [
|
||||
prompt
|
||||
for prompt in prompts
|
||||
if prompt.user_id == user_id
|
||||
or has_access(user_id, permission, prompt.access_control)
|
||||
]
|
||||
|
||||
def update_prompt_by_command(
|
||||
self, command: str, form_data: PromptForm
|
||||
) -> Optional[PromptModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
prompt = db.query(Prompt).filter_by(command=command).first()
|
||||
prompt.title = form_data.title
|
||||
prompt.content = form_data.content
|
||||
prompt.access_control = form_data.access_control
|
||||
prompt.timestamp = int(time.time())
|
||||
db.commit()
|
||||
return PromptModel.model_validate(prompt)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def delete_prompt_by_command(self, command: str) -> bool:
|
||||
try:
|
||||
with get_db() as db:
|
||||
db.query(Prompt).filter_by(command=command).delete()
|
||||
db.commit()
|
||||
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
Prompts = PromptsTable()
|
||||
|
|
@ -0,0 +1,109 @@
|
|||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.internal.db import Base, get_db
|
||||
|
||||
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Column, String, JSON, PrimaryKeyConstraint
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||
|
||||
|
||||
####################
|
||||
# Tag DB Schema
|
||||
####################
|
||||
class Tag(Base):
|
||||
__tablename__ = "tag"
|
||||
id = Column(String)
|
||||
name = Column(String)
|
||||
user_id = Column(String)
|
||||
meta = Column(JSON, nullable=True)
|
||||
|
||||
# Composite primary key: a tag is identified by (id, user_id), not by the `id` column alone
|
||||
__table_args__ = (PrimaryKeyConstraint("id", "user_id", name="pk_id_user_id"),)
|
||||
|
||||
|
||||
class TagModel(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
user_id: str
|
||||
meta: Optional[dict] = None
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
####################
|
||||
# Forms
|
||||
####################
|
||||
|
||||
|
||||
class TagChatIdForm(BaseModel):
|
||||
name: str
|
||||
chat_id: str
|
||||
|
||||
|
||||
class TagTable:
|
||||
def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]:
|
||||
with get_db() as db:
|
||||
id = name.replace(" ", "_").lower()
|
||||
tag = TagModel(**{"id": id, "user_id": user_id, "name": name})
|
||||
try:
|
||||
result = Tag(**tag.model_dump())
|
||||
db.add(result)
|
||||
db.commit()
|
||||
db.refresh(result)
|
||||
if result:
|
||||
return TagModel.model_validate(result)
|
||||
else:
|
||||
return None
|
||||
except Exception as e:
|
||||
log.exception(f"Error inserting a new tag: {e}")
|
||||
return None
|
||||
|
||||
def get_tag_by_name_and_user_id(
|
||||
self, name: str, user_id: str
|
||||
) -> Optional[TagModel]:
|
||||
try:
|
||||
id = name.replace(" ", "_").lower()
|
||||
with get_db() as db:
|
||||
tag = db.query(Tag).filter_by(id=id, user_id=user_id).first()
|
||||
return TagModel.model_validate(tag)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_tags_by_user_id(self, user_id: str) -> list[TagModel]:
|
||||
with get_db() as db:
|
||||
return [
|
||||
TagModel.model_validate(tag)
|
||||
for tag in (db.query(Tag).filter_by(user_id=user_id).all())
|
||||
]
|
||||
|
||||
def get_tags_by_ids_and_user_id(
|
||||
self, ids: list[str], user_id: str
|
||||
) -> list[TagModel]:
|
||||
with get_db() as db:
|
||||
return [
|
||||
TagModel.model_validate(tag)
|
||||
for tag in (
|
||||
db.query(Tag).filter(Tag.id.in_(ids), Tag.user_id == user_id).all()
|
||||
)
|
||||
]
|
||||
|
||||
def delete_tag_by_name_and_user_id(self, name: str, user_id: str) -> bool:
|
||||
try:
|
||||
with get_db() as db:
|
||||
id = name.replace(" ", "_").lower()
|
||||
res = db.query(Tag).filter_by(id=id, user_id=user_id).delete()
|
||||
log.debug(f"res: {res}")
|
||||
db.commit()
|
||||
return True
|
||||
except Exception as e:
|
||||
log.error(f"delete_tag: {e}")
|
||||
return False
|
||||
|
||||
|
||||
Tags = TagTable()
|
||||
|
|
@ -0,0 +1,262 @@
|
|||
import logging
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.internal.db import Base, JSONField, get_db
|
||||
from open_webui.models.users import Users, UserResponse
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Column, String, Text, JSON
|
||||
|
||||
from open_webui.utils.access_control import has_access
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||
|
||||
####################
|
||||
# Tools DB Schema
|
||||
####################
|
||||
|
||||
|
||||
class Tool(Base):
|
||||
__tablename__ = "tool"
|
||||
|
||||
id = Column(String, primary_key=True)
|
||||
user_id = Column(String)
|
||||
name = Column(Text)
|
||||
content = Column(Text)
|
||||
specs = Column(JSONField)
|
||||
meta = Column(JSONField)
|
||||
valves = Column(JSONField)
|
||||
|
||||
access_control = Column(JSON, nullable=True) # Controls data access levels.
|
||||
# Defines access control rules for this entry.
|
||||
# - `None`: Public access, available to all users with the "user" role.
|
||||
# - `{}`: Private access, restricted exclusively to the owner.
|
||||
# - Custom permissions: Specific access control for reading and writing;
|
||||
# Can specify group or user-level restrictions:
|
||||
# {
|
||||
# "read": {
|
||||
# "group_ids": ["group_id1", "group_id2"],
|
||||
# "user_ids": ["user_id1", "user_id2"]
|
||||
# },
|
||||
# "write": {
|
||||
# "group_ids": ["group_id1", "group_id2"],
|
||||
# "user_ids": ["user_id1", "user_id2"]
|
||||
# }
|
||||
# }
|
||||
|
||||
updated_at = Column(BigInteger)
|
||||
created_at = Column(BigInteger)
|
||||
|
||||
|
||||
class ToolMeta(BaseModel):
|
||||
description: Optional[str] = None
|
||||
manifest: Optional[dict] = {}
|
||||
|
||||
|
||||
class ToolModel(BaseModel):
|
||||
id: str
|
||||
user_id: str
|
||||
name: str
|
||||
content: str
|
||||
specs: list[dict]
|
||||
meta: ToolMeta
|
||||
access_control: Optional[dict] = None
|
||||
|
||||
updated_at: int # timestamp in epoch
|
||||
created_at: int # timestamp in epoch
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
####################
|
||||
# Forms
|
||||
####################
|
||||
|
||||
|
||||
class ToolUserModel(ToolModel):
|
||||
user: Optional[UserResponse] = None
|
||||
|
||||
|
||||
class ToolResponse(BaseModel):
|
||||
id: str
|
||||
user_id: str
|
||||
name: str
|
||||
meta: ToolMeta
|
||||
access_control: Optional[dict] = None
|
||||
updated_at: int # timestamp in epoch
|
||||
created_at: int # timestamp in epoch
|
||||
|
||||
|
||||
class ToolUserResponse(ToolResponse):
|
||||
user: Optional[UserResponse] = None
|
||||
|
||||
|
||||
class ToolForm(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
content: str
|
||||
meta: ToolMeta
|
||||
access_control: Optional[dict] = None
|
||||
|
||||
|
||||
class ToolValves(BaseModel):
|
||||
valves: Optional[dict] = None
|
||||
|
||||
|
||||
class ToolsTable:
|
||||
def insert_new_tool(
|
||||
self, user_id: str, form_data: ToolForm, specs: list[dict]
|
||||
) -> Optional[ToolModel]:
|
||||
with get_db() as db:
|
||||
tool = ToolModel(
|
||||
**{
|
||||
**form_data.model_dump(),
|
||||
"specs": specs,
|
||||
"user_id": user_id,
|
||||
"updated_at": int(time.time()),
|
||||
"created_at": int(time.time()),
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
result = Tool(**tool.model_dump())
|
||||
db.add(result)
|
||||
db.commit()
|
||||
db.refresh(result)
|
||||
if result:
|
||||
return ToolModel.model_validate(result)
|
||||
else:
|
||||
return None
|
||||
except Exception as e:
|
||||
log.exception(f"Error creating a new tool: {e}")
|
||||
return None
|
||||
|
||||
def get_tool_by_id(self, id: str) -> Optional[ToolModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
tool = db.get(Tool, id)
|
||||
return ToolModel.model_validate(tool)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_tools(self) -> list[ToolUserModel]:
|
||||
with get_db() as db:
|
||||
tools = []
|
||||
for tool in db.query(Tool).order_by(Tool.updated_at.desc()).all():
|
||||
user = Users.get_user_by_id(tool.user_id)
|
||||
tools.append(
|
||||
ToolUserModel.model_validate(
|
||||
{
|
||||
**ToolModel.model_validate(tool).model_dump(),
|
||||
"user": user.model_dump() if user else None,
|
||||
}
|
||||
)
|
||||
)
|
||||
return tools
|
||||
|
||||
def get_tools_by_user_id(
|
||||
self, user_id: str, permission: str = "write"
|
||||
) -> list[ToolUserModel]:
|
||||
tools = self.get_tools()
|
||||
|
||||
return [
|
||||
tool
|
||||
for tool in tools
|
||||
if tool.user_id == user_id
|
||||
or has_access(user_id, permission, tool.access_control)
|
||||
]
|
||||
|
||||
def get_tool_valves_by_id(self, id: str) -> Optional[dict]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
tool = db.get(Tool, id)
|
||||
return tool.valves if tool.valves else {}
|
||||
except Exception as e:
|
||||
log.exception(f"Error getting tool valves by id {id}: {e}")
|
||||
return None
|
||||
|
||||
def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
db.query(Tool).filter_by(id=id).update(
|
||||
{"valves": valves, "updated_at": int(time.time())}
|
||||
)
|
||||
db.commit()
|
||||
return self.get_tool_by_id(id)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_user_valves_by_id_and_user_id(
|
||||
self, id: str, user_id: str
|
||||
) -> Optional[dict]:
|
||||
try:
|
||||
user = Users.get_user_by_id(user_id)
|
||||
user_settings = user.settings.model_dump() if user.settings else {}
|
||||
|
||||
# Check if user has "tools" and "valves" settings
|
||||
if "tools" not in user_settings:
|
||||
user_settings["tools"] = {}
|
||||
if "valves" not in user_settings["tools"]:
|
||||
user_settings["tools"]["valves"] = {}
|
||||
|
||||
return user_settings["tools"]["valves"].get(id, {})
|
||||
except Exception as e:
|
||||
log.exception(
|
||||
f"Error getting user valves by id {id} and user_id {user_id}: {e}"
|
||||
)
|
||||
return None
|
||||
|
||||
def update_user_valves_by_id_and_user_id(
|
||||
self, id: str, user_id: str, valves: dict
|
||||
) -> Optional[dict]:
|
||||
try:
|
||||
user = Users.get_user_by_id(user_id)
|
||||
user_settings = user.settings.model_dump() if user.settings else {}
|
||||
|
||||
# Check if user has "tools" and "valves" settings
|
||||
if "tools" not in user_settings:
|
||||
user_settings["tools"] = {}
|
||||
if "valves" not in user_settings["tools"]:
|
||||
user_settings["tools"]["valves"] = {}
|
||||
|
||||
user_settings["tools"]["valves"][id] = valves
|
||||
|
||||
# Update the user settings in the database
|
||||
Users.update_user_by_id(user_id, {"settings": user_settings})
|
||||
|
||||
return user_settings["tools"]["valves"][id]
|
||||
except Exception as e:
|
||||
log.exception(
|
||||
f"Error updating user valves by id {id} and user_id {user_id}: {e}"
|
||||
)
|
||||
return None
|
||||
|
||||
def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
db.query(Tool).filter_by(id=id).update(
|
||||
{**updated, "updated_at": int(time.time())}
|
||||
)
|
||||
db.commit()
|
||||
|
||||
tool = db.get(Tool, id)
|
||||
db.refresh(tool)
|
||||
return ToolModel.model_validate(tool)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def delete_tool_by_id(self, id: str) -> bool:
|
||||
try:
|
||||
with get_db() as db:
|
||||
db.query(Tool).filter_by(id=id).delete()
|
||||
db.commit()
|
||||
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
Tools = ToolsTable()
|
||||
|
|
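The two valves methods above share one pattern: make sure the nested `"tools"` and `"valves"` dictionaries exist, then read or write the entry for a tool id. The following is a minimal sketch of that pattern on plain dictionaries, assuming no database and no `Users` model. The names `get_valves` and `set_valves` are hypothetical and do not appear in the source.

```python
def get_valves(user_settings: dict, tool_id: str) -> dict:
    """Return the stored valves for tool_id, or {} when none exist."""
    # setdefault creates the nested "tools" -> "valves" dicts on first use.
    valves = user_settings.setdefault("tools", {}).setdefault("valves", {})
    return valves.get(tool_id, {})


def set_valves(user_settings: dict, tool_id: str, valves: dict) -> dict:
    """Store valves for tool_id and return what was stored."""
    user_settings.setdefault("tools", {}).setdefault("valves", {})[tool_id] = valves
    return user_settings["tools"]["valves"][tool_id]


settings: dict = {}
print(get_valves(settings, "weather"))  # nothing stored yet
set_valves(settings, "weather", {"units": "metric"})
print(get_valves(settings, "weather"))
```

Unlike the methods above, this sketch never persists anything. In the source, the update path writes the whole settings dictionary back through `Users.update_user_by_id`.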
@ -0,0 +1,334 @@
import time
from typing import Optional

from open_webui.internal.db import Base, JSONField, get_db


from open_webui.models.chats import Chats
from open_webui.models.groups import Groups


from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text

####################
# User DB Schema
####################


class User(Base):
    __tablename__ = "user"

    id = Column(String, primary_key=True)
    name = Column(String)
    email = Column(String)
    role = Column(String)
    profile_image_url = Column(Text)

    last_active_at = Column(BigInteger)
    updated_at = Column(BigInteger)
    created_at = Column(BigInteger)

    api_key = Column(String, nullable=True, unique=True)
    settings = Column(JSONField, nullable=True)
    info = Column(JSONField, nullable=True)

    oauth_sub = Column(Text, unique=True)


class UserSettings(BaseModel):
    ui: Optional[dict] = {}
    model_config = ConfigDict(extra="allow")
    pass


class UserModel(BaseModel):
    id: str
    name: str
    email: str
    role: str = "pending"
    profile_image_url: str

    last_active_at: int  # timestamp in epoch
    updated_at: int  # timestamp in epoch
    created_at: int  # timestamp in epoch

    api_key: Optional[str] = None
    settings: Optional[UserSettings] = None
    info: Optional[dict] = None

    oauth_sub: Optional[str] = None

    model_config = ConfigDict(from_attributes=True)


####################
# Forms
####################


class UserResponse(BaseModel):
    id: str
    name: str
    email: str
    role: str
    profile_image_url: str


class UserNameResponse(BaseModel):
    id: str
    name: str
    role: str
    profile_image_url: str


class UserRoleUpdateForm(BaseModel):
    id: str
    role: str


class UserUpdateForm(BaseModel):
    name: str
    email: str
    profile_image_url: str
    password: Optional[str] = None


class UsersTable:
    def insert_new_user(
        self,
        id: str,
        name: str,
        email: str,
        profile_image_url: str = "/user.png",
        role: str = "pending",
        oauth_sub: Optional[str] = None,
    ) -> Optional[UserModel]:
        with get_db() as db:
            user = UserModel(
                **{
                    "id": id,
                    "name": name,
                    "email": email,
                    "role": role,
                    "profile_image_url": profile_image_url,
                    "last_active_at": int(time.time()),
                    "created_at": int(time.time()),
                    "updated_at": int(time.time()),
                    "oauth_sub": oauth_sub,
                }
            )
            result = User(**user.model_dump())
            db.add(result)
            db.commit()
            db.refresh(result)
            if result:
                return user
            else:
                return None

    def get_user_by_id(self, id: str) -> Optional[UserModel]:
        try:
            with get_db() as db:
                user = db.query(User).filter_by(id=id).first()
                return UserModel.model_validate(user)
        except Exception:
            return None

    def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
        try:
            with get_db() as db:
                user = db.query(User).filter_by(api_key=api_key).first()
                return UserModel.model_validate(user)
        except Exception:
            return None

    def get_user_by_email(self, email: str) -> Optional[UserModel]:
        try:
            with get_db() as db:
                user = db.query(User).filter_by(email=email).first()
                return UserModel.model_validate(user)
        except Exception:
            return None

    def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]:
        try:
            with get_db() as db:
                user = db.query(User).filter_by(oauth_sub=sub).first()
                return UserModel.model_validate(user)
        except Exception:
            return None

    def get_users(
        self, skip: Optional[int] = None, limit: Optional[int] = None
    ) -> list[UserModel]:
        with get_db() as db:

            query = db.query(User).order_by(User.created_at.desc())

            if skip:
                query = query.offset(skip)
            if limit:
                query = query.limit(limit)

            users = query.all()

            return [UserModel.model_validate(user) for user in users]

    def get_users_by_user_ids(self, user_ids: list[str]) -> list[UserModel]:
        with get_db() as db:
            users = db.query(User).filter(User.id.in_(user_ids)).all()
            return [UserModel.model_validate(user) for user in users]

    def get_num_users(self) -> Optional[int]:
        with get_db() as db:
            return db.query(User).count()

    def get_first_user(self) -> Optional[UserModel]:
        try:
            with get_db() as db:
                user = db.query(User).order_by(User.created_at).first()
                return UserModel.model_validate(user)
        except Exception:
            return None

    def get_user_webhook_url_by_id(self, id: str) -> Optional[str]:
        try:
            with get_db() as db:
                user = db.query(User).filter_by(id=id).first()

                if user.settings is None:
                    return None
                else:
                    return (
                        user.settings.get("ui", {})
                        .get("notifications", {})
                        .get("webhook_url", None)
                    )
        except Exception:
            return None

    def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]:
        try:
            with get_db() as db:
                db.query(User).filter_by(id=id).update({"role": role})
                db.commit()
                user = db.query(User).filter_by(id=id).first()
                return UserModel.model_validate(user)
        except Exception:
            return None

    def update_user_profile_image_url_by_id(
        self, id: str, profile_image_url: str
    ) -> Optional[UserModel]:
        try:
            with get_db() as db:
                db.query(User).filter_by(id=id).update(
                    {"profile_image_url": profile_image_url}
                )
                db.commit()

                user = db.query(User).filter_by(id=id).first()
                return UserModel.model_validate(user)
        except Exception:
            return None

    def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]:
        try:
            with get_db() as db:
                db.query(User).filter_by(id=id).update(
                    {"last_active_at": int(time.time())}
                )
                db.commit()

                user = db.query(User).filter_by(id=id).first()
                return UserModel.model_validate(user)
        except Exception:
            return None

    def update_user_oauth_sub_by_id(
        self, id: str, oauth_sub: str
    ) -> Optional[UserModel]:
        try:
            with get_db() as db:
                db.query(User).filter_by(id=id).update({"oauth_sub": oauth_sub})
                db.commit()

                user = db.query(User).filter_by(id=id).first()
                return UserModel.model_validate(user)
        except Exception:
            return None

    def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]:
        try:
            with get_db() as db:
                db.query(User).filter_by(id=id).update(updated)
                db.commit()

                user = db.query(User).filter_by(id=id).first()
                return UserModel.model_validate(user)
                # return UserModel(**user.dict())
        except Exception:
            return None

    def update_user_settings_by_id(self, id: str, updated: dict) -> Optional[UserModel]:
        try:
            with get_db() as db:
                user_settings = db.query(User).filter_by(id=id).first().settings

                if user_settings is None:
                    user_settings = {}

                user_settings.update(updated)

                db.query(User).filter_by(id=id).update({"settings": user_settings})
                db.commit()

                user = db.query(User).filter_by(id=id).first()
                return UserModel.model_validate(user)
        except Exception:
            return None

    def delete_user_by_id(self, id: str) -> bool:
        try:
            # Remove User from Groups
            Groups.remove_user_from_all_groups(id)

            # Delete User Chats
            result = Chats.delete_chats_by_user_id(id)
            if result:
                with get_db() as db:
                    # Delete User
                    db.query(User).filter_by(id=id).delete()
                    db.commit()

                return True
            else:
                return False
        except Exception:
            return False

    def update_user_api_key_by_id(self, id: str, api_key: str) -> bool:
        try:
            with get_db() as db:
                result = db.query(User).filter_by(id=id).update({"api_key": api_key})
                db.commit()
                return result == 1
        except Exception:
            return False

    def get_user_api_key_by_id(self, id: str) -> Optional[str]:
        try:
            with get_db() as db:
                user = db.query(User).filter_by(id=id).first()
                return user.api_key
        except Exception:
            return None

    def get_valid_user_ids(self, user_ids: list[str]) -> list[str]:
        with get_db() as db:
            users = db.query(User).filter(User.id.in_(user_ids)).all()
            return [user.id for user in users]


Users = UsersTable()

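`update_user_settings_by_id` above merges the incoming dictionary with `dict.update`, which is a shallow merge: a top-level key in `updated` replaces the stored value for that key as a whole. The sketch below illustrates the consequence on plain dictionaries. `merge_settings` is a hypothetical helper, and unlike the source it works on a copy instead of mutating the stored settings.

```python
from typing import Optional


def merge_settings(current: Optional[dict], updated: dict) -> dict:
    """Shallow-merge updated into a copy of current, as dict.update does."""
    merged = dict(current or {})
    merged.update(updated)
    return merged


current = {
    "ui": {
        "theme": "dark",
        "notifications": {"webhook_url": "https://example.invalid/hook"},
    }
}

# Sending only {"ui": {"theme": "light"}} replaces the whole "ui" value,
# so the nested "notifications" entry is no longer present in the result.
merged = merge_settings(current, {"ui": {"theme": "light"}})
print(merged)
```

A caller that wants to keep sibling keys such as `notifications` would need to send the complete `ui` object, or the merge would have to recurse. Which of these the project expects is not stated in the source.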
@ -0,0 +1,215 @@
import requests
import logging
import ftfy
import sys

from langchain_community.document_loaders import (
    AzureAIDocumentIntelligenceLoader,
    BSHTMLLoader,
    CSVLoader,
    Docx2txtLoader,
    OutlookMessageLoader,
    PyPDFLoader,
    TextLoader,
    UnstructuredEPubLoader,
    UnstructuredExcelLoader,
    UnstructuredMarkdownLoader,
    UnstructuredPowerPointLoader,
    UnstructuredRSTLoader,
    UnstructuredXMLLoader,
    YoutubeLoader,
)
from langchain_core.documents import Document
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL

logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])

known_source_ext = [
    "go",
    "py",
    "java",
    "sh",
    "bat",
    "ps1",
    "cmd",
    "js",
    "ts",
    "css",
    "cpp",
    "hpp",
    "h",
    "c",
    "cs",
    "sql",
    "log",
    "ini",
    "pl",
    "pm",
    "r",
    "dart",
    "dockerfile",
    "env",
    "php",
    "hs",
    "hsc",
    "lua",
    "nginxconf",
    "conf",
    "m",
    "mm",
    "plsql",
    "perl",
    "rb",
    "rs",
    "db2",
    "scala",
    "bash",
    "swift",
    "vue",
    "svelte",
    "msg",
    "ex",
    "exs",
    "erl",
    "tsx",
    "jsx",
    "hs",
    "lhs",
    "json",
]


class TikaLoader:
    def __init__(self, url, file_path, mime_type=None):
        self.url = url
        self.file_path = file_path
        self.mime_type = mime_type

    def load(self) -> list[Document]:
        with open(self.file_path, "rb") as f:
            data = f.read()

        if self.mime_type is not None:
            headers = {"Content-Type": self.mime_type}
        else:
            headers = {}

        endpoint = self.url
        if not endpoint.endswith("/"):
            endpoint += "/"
        endpoint += "tika/text"

        r = requests.put(endpoint, data=data, headers=headers)

        if r.ok:
            raw_metadata = r.json()
            text = raw_metadata.get("X-TIKA:content", "<No text content found>")

            if "Content-Type" in raw_metadata:
                headers["Content-Type"] = raw_metadata["Content-Type"]

            log.debug("Tika extracted text: %s", text)

            return [Document(page_content=text, metadata=headers)]
        else:
            raise Exception(f"Error calling Tika: {r.reason}")


class Loader:
    def __init__(self, engine: str = "", **kwargs):
        self.engine = engine
        self.kwargs = kwargs

    def load(
        self, filename: str, file_content_type: str, file_path: str
    ) -> list[Document]:
        loader = self._get_loader(filename, file_content_type, file_path)
        docs = loader.load()

        return [
            Document(
                page_content=ftfy.fix_text(doc.page_content), metadata=doc.metadata
            )
            for doc in docs
        ]

    def _get_loader(self, filename: str, file_content_type: str, file_path: str):
        file_ext = filename.split(".")[-1].lower()

        if self.engine == "tika" and self.kwargs.get("TIKA_SERVER_URL"):
            if file_ext in known_source_ext or (
                file_content_type and file_content_type.find("text/") >= 0
            ):
                loader = TextLoader(file_path, autodetect_encoding=True)
            else:
                loader = TikaLoader(
                    url=self.kwargs.get("TIKA_SERVER_URL"),
                    file_path=file_path,
                    mime_type=file_content_type,
                )
        elif (
            self.engine == "document_intelligence"
            and self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT") != ""
            and self.kwargs.get("DOCUMENT_INTELLIGENCE_KEY") != ""
            and (
                file_ext in ["pdf", "xls", "xlsx", "docx", "ppt", "pptx"]
                or file_content_type
                in [
                    "application/vnd.ms-excel",
                    "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
                    "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
                    "application/vnd.ms-powerpoint",
                    "application/vnd.openxmlformats-officedocument.presentationml.presentation",
                ]
            )
        ):
            loader = AzureAIDocumentIntelligenceLoader(
                file_path=file_path,
                api_endpoint=self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT"),
                api_key=self.kwargs.get("DOCUMENT_INTELLIGENCE_KEY"),
            )
        else:
            if file_ext == "pdf":
                loader = PyPDFLoader(
                    file_path, extract_images=self.kwargs.get("PDF_EXTRACT_IMAGES")
                )
            elif file_ext == "csv":
                loader = CSVLoader(file_path)
            elif file_ext == "rst":
                loader = UnstructuredRSTLoader(file_path, mode="elements")
            elif file_ext == "xml":
                loader = UnstructuredXMLLoader(file_path)
            elif file_ext in ["htm", "html"]:
                loader = BSHTMLLoader(file_path, open_encoding="unicode_escape")
            elif file_ext == "md":
                loader = TextLoader(file_path, autodetect_encoding=True)
            elif file_content_type == "application/epub+zip":
                loader = UnstructuredEPubLoader(file_path)
            elif (
                file_content_type
                == "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
                or file_ext == "docx"
            ):
                loader = Docx2txtLoader(file_path)
            elif file_content_type in [
                "application/vnd.ms-excel",
                "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
            ] or file_ext in ["xls", "xlsx"]:
                loader = UnstructuredExcelLoader(file_path)
            elif file_content_type in [
                "application/vnd.ms-powerpoint",
                "application/vnd.openxmlformats-officedocument.presentationml.presentation",
            ] or file_ext in ["ppt", "pptx"]:
                loader = UnstructuredPowerPointLoader(file_path)
            elif file_ext == "msg":
                loader = OutlookMessageLoader(file_path)
            elif file_ext in known_source_ext or (
                file_content_type and file_content_type.find("text/") >= 0
            ):
                loader = TextLoader(file_path, autodetect_encoding=True)
            else:
                loader = TextLoader(file_path, autodetect_encoding=True)

        return loader

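Two small string steps in the loader code above decide a lot: how the file extension is derived from the filename, and how the Tika endpoint is built from the configured base URL. The sketch below isolates both, together with a deliberately reduced version of the default dispatch that returns loader names as strings. `file_extension`, `tika_endpoint`, and `pick_loader_name` are hypothetical helpers, the dispatch covers only a few of the branches, and nothing here imports LangChain.

```python
from typing import Optional


def file_extension(filename: str) -> str:
    """Same expression the loader uses: text after the last dot, lowercased."""
    return filename.split(".")[-1].lower()


def tika_endpoint(base_url: str) -> str:
    """Append 'tika/text' to the base URL, adding a slash when missing."""
    endpoint = base_url
    if not endpoint.endswith("/"):
        endpoint += "/"
    return endpoint + "tika/text"


def pick_loader_name(filename: str, content_type: Optional[str] = None) -> str:
    """Reduced sketch of the default branch; returns a loader class name."""
    ext = file_extension(filename)
    if ext == "pdf":
        return "PyPDFLoader"
    if ext == "csv":
        return "CSVLoader"
    if ext in ("htm", "html"):
        return "BSHTMLLoader"
    if content_type == "application/epub+zip":
        return "UnstructuredEPubLoader"
    # Everything not matched above falls back to plain text loading.
    return "TextLoader"


print(file_extension("Report.PDF"))
print(tika_endpoint("http://localhost:9998"))
print(pick_loader_name("index.html"))
```

One detail worth noticing: a filename without any dot yields the whole lowercased name, which is how a file called `Dockerfile` can match the `dockerfile` entry in `known_source_ext`.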
@ -0,0 +1,117 @@
import logging

from typing import Any, Dict, Generator, List, Optional, Sequence, Union
from urllib.parse import parse_qs, urlparse
from langchain_core.documents import Document
from open_webui.env import SRC_LOG_LEVELS

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])

ALLOWED_SCHEMES = {"http", "https"}
ALLOWED_NETLOCS = {
    "youtu.be",
    "m.youtube.com",
    "youtube.com",
    "www.youtube.com",
    "www.youtube-nocookie.com",
    "vid.plus",
}


def _parse_video_id(url: str) -> Optional[str]:
    """Parse a YouTube URL and return the video ID if valid, otherwise None."""
    parsed_url = urlparse(url)

    if parsed_url.scheme not in ALLOWED_SCHEMES:
        return None

    if parsed_url.netloc not in ALLOWED_NETLOCS:
        return None

    path = parsed_url.path

    if path.endswith("/watch"):
        query = parsed_url.query
        parsed_query = parse_qs(query)
        if "v" in parsed_query:
            ids = parsed_query["v"]
            video_id = ids if isinstance(ids, str) else ids[0]
        else:
            return None
    else:
        path = parsed_url.path.lstrip("/")
        video_id = path.split("/")[-1]

    if len(video_id) != 11:  # Video IDs are 11 characters long
        return None

    return video_id

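To see which inputs `_parse_video_id` accepts, the rules above can be restated in a standalone form and run against a few URLs. This is a sketch using only the standard library. `parse_video_id` is a hypothetical compact copy of the same checks (scheme, host, `v` query parameter or last path segment, 11-character length), not a replacement for the function in the source, and the video ID used is just an 11-character example value.

```python
from typing import Optional
from urllib.parse import parse_qs, urlparse

ALLOWED_SCHEMES = {"http", "https"}
ALLOWED_NETLOCS = {
    "youtu.be",
    "m.youtube.com",
    "youtube.com",
    "www.youtube.com",
    "www.youtube-nocookie.com",
    "vid.plus",
}


def parse_video_id(url: str) -> Optional[str]:
    """Return the 11-character video ID, or None when the URL is rejected."""
    parsed = urlparse(url)
    if parsed.scheme not in ALLOWED_SCHEMES or parsed.netloc not in ALLOWED_NETLOCS:
        return None
    if parsed.path.endswith("/watch"):
        ids = parse_qs(parsed.query).get("v")
        if not ids:
            return None
        video_id = ids[0]
    else:
        # Short links and embed paths carry the ID as the last path segment.
        video_id = parsed.path.lstrip("/").split("/")[-1]
    return video_id if len(video_id) == 11 else None


for candidate in (
    "https://www.youtube.com/watch?v=dQw4w9WgXcQ",
    "https://youtu.be/dQw4w9WgXcQ",
    "https://example.com/watch?v=dQw4w9WgXcQ",
    "dQw4w9WgXcQ",
):
    print(candidate, "->", parse_video_id(candidate))
```

A bare ID has no scheme, so it is rejected here. In the source, `YoutubeLoader.__init__` handles that case by falling back to the raw string it was given.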


class YoutubeLoader:
    """Load `YouTube` video transcripts."""

    def __init__(
        self,
        video_id: str,
        language: Union[str, Sequence[str]] = "en",
        proxy_url: Optional[str] = None,
    ):
        """Initialize with YouTube video ID."""
        _video_id = _parse_video_id(video_id)
        self.video_id = _video_id if _video_id is not None else video_id
        self._metadata = {"source": video_id}
        self.language = language
        self.proxy_url = proxy_url
        if isinstance(language, str):
            self.language = [language]
        else:
            self.language = language

    def load(self) -> List[Document]:
        """Load YouTube transcripts into `Document` objects."""
        try:
            from youtube_transcript_api import (
                NoTranscriptFound,
                TranscriptsDisabled,
                YouTubeTranscriptApi,
            )
        except ImportError:
            raise ImportError(
                'Could not import "youtube_transcript_api" Python package. '
                "Please install it with `pip install youtube-transcript-api`."
            )

        if self.proxy_url:
            youtube_proxies = {
                "http": self.proxy_url,
                "https": self.proxy_url,
            }
            # Don't log complete URL because it might contain secrets
            log.debug(f"Using proxy URL: {self.proxy_url[:14]}...")
        else:
            youtube_proxies = None

        try:
            transcript_list = YouTubeTranscriptApi.list_transcripts(
                self.video_id, proxies=youtube_proxies
            )
        except Exception as e:
            log.exception("Loading YouTube transcript failed")
            return []

        try:
            transcript = transcript_list.find_transcript(self.language)
        except NoTranscriptFound:
            transcript = transcript_list.find_transcript(["en"])

        transcript_pieces: List[Dict[str, Any]] = transcript.fetch()

        transcript = " ".join(
            map(
                lambda transcript_piece: transcript_piece["text"].strip(" "),
                transcript_pieces,
            )
        )
        return [Document(page_content=transcript, metadata=self._metadata)]

@ -0,0 +1,87 @@
import os
import logging
import torch
import numpy as np
from colbert.infra import ColBERTConfig
from colbert.modeling.checkpoint import Checkpoint

from open_webui.env import SRC_LOG_LEVELS

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


class ColBERT:
    def __init__(self, name, **kwargs) -> None:
        log.info("ColBERT: Loading model %s", name)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        DOCKER = kwargs.get("env") == "docker"
        if DOCKER:
            # This is a workaround for the issue with the docker container
            # where the torch extension is not loaded properly
            # and the following error is thrown:
            # /root/.cache/torch_extensions/py311_cpu/segmented_maxsim_cpp/segmented_maxsim_cpp.so: cannot open shared object file: No such file or directory

            lock_file = (
                "/root/.cache/torch_extensions/py311_cpu/segmented_maxsim_cpp/lock"
            )
            if os.path.exists(lock_file):
                os.remove(lock_file)

        self.ckpt = Checkpoint(
            name,
            colbert_config=ColBERTConfig(model_name=name),
        ).to(self.device)
        pass

    def calculate_similarity_scores(self, query_embeddings, document_embeddings):

        query_embeddings = query_embeddings.to(self.device)
        document_embeddings = document_embeddings.to(self.device)

        # Validate dimensions to ensure compatibility
        if query_embeddings.dim() != 3:
            raise ValueError(
                f"Expected query embeddings to have 3 dimensions, but got {query_embeddings.dim()}."
            )
        if document_embeddings.dim() != 3:
            raise ValueError(
                f"Expected document embeddings to have 3 dimensions, but got {document_embeddings.dim()}."
            )
        if query_embeddings.size(0) not in [1, document_embeddings.size(0)]:
            raise ValueError(
                "There should be either one query or queries equal to the number of documents."
            )

        # Transpose the query embeddings to align for matrix multiplication
        transposed_query_embeddings = query_embeddings.permute(0, 2, 1)
        # Compute similarity scores using batch matrix multiplication
        computed_scores = torch.matmul(document_embeddings, transposed_query_embeddings)
        # Apply max pooling over each document's tokens: for every query token,
        # keep the similarity of its best-matching document token
        maximum_scores = torch.max(computed_scores, dim=1).values

        # Sum the per-query-token maxima to get the overall document relevance scores
        final_scores = maximum_scores.sum(dim=1)

        normalized_scores = torch.softmax(final_scores, dim=0)

        return normalized_scores.detach().cpu().numpy().astype(np.float32)

    def predict(self, sentences):

        query = sentences[0][0]
        docs = [i[1] for i in sentences]

        # Embedding the documents
        embedded_docs = self.ckpt.docFromText(docs, bsize=32)[0]
        # Embedding the queries
        embedded_queries = self.ckpt.queryFromText([query], bsize=32)
        embedded_query = embedded_queries[0]

        # Calculate retrieval scores for the query against all documents
        scores = self.calculate_similarity_scores(
            embedded_query.unsqueeze(0), embedded_docs
        )

        return scores

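The scoring in `calculate_similarity_scores` above follows the late-interaction pattern: take the dot product of every document token with every query token, keep the best document token per query token, sum those maxima per document, then apply a softmax across documents. The sketch below reproduces that arithmetic in pure Python on tiny hand-made vectors so it can be checked by hand. `maxsim_scores` and `dot` are hypothetical names, the vectors are made up, and padding or masking of document tokens is not modelled.

```python
import math
from typing import List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def maxsim_scores(query_tokens: List[Vector], docs: List[List[Vector]]) -> List[float]:
    """Softmax-normalised sum of per-query-token maximum similarities."""
    raw = []
    for doc_tokens in docs:
        # For each query token, keep its best-matching document token.
        raw.append(sum(max(dot(q, d) for d in doc_tokens) for q in query_tokens))
    # Softmax across documents; subtracting the peak avoids overflow.
    peak = max(raw)
    exps = [math.exp(s - peak) for s in raw]
    total = sum(exps)
    return [e / total for e in exps]


query = [[1.0, 0.0], [0.0, 1.0]]
documents = [
    [[1.0, 0.0], [0.0, 1.0]],  # raw score 1.0 + 1.0 = 2.0
    [[0.5, 0.0], [0.0, 0.5]],  # raw score 0.5 + 0.5 = 1.0
]
scores = maxsim_scores(query, documents)
print(scores)
```

Because of the final softmax, the scores are relative to the batch of documents passed in: they sum to 1 and are not comparable across separate calls.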
@ -0,0 +1,699 @@
import logging
import os
import uuid
from typing import Optional, Union

import asyncio
import requests
import hashlib

from huggingface_hub import snapshot_download
from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
from langchain_community.retrievers import BM25Retriever
from langchain_core.documents import Document


from open_webui.config import VECTOR_DB
from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
from open_webui.utils.misc import get_last_user_message, calculate_sha256_string

from open_webui.models.users import UserModel
from open_webui.models.files import Files

from open_webui.env import (
    SRC_LOG_LEVELS,
    OFFLINE_MODE,
    ENABLE_FORWARD_USER_INFO_HEADERS,
)

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


from typing import Any

from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.retrievers import BaseRetriever


class VectorSearchRetriever(BaseRetriever):
    collection_name: Any
    embedding_function: Any
    top_k: int

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> list[Document]:
        result = VECTOR_DB_CLIENT.search(
            collection_name=self.collection_name,
            vectors=[self.embedding_function(query)],
            limit=self.top_k,
        )

        ids = result.ids[0]
        metadatas = result.metadatas[0]
        documents = result.documents[0]

        results = []
        for idx in range(len(ids)):
            results.append(
                Document(
                    metadata=metadatas[idx],
                    page_content=documents[idx],
                )
            )
        return results


def query_doc(
    collection_name: str, query_embedding: list[float], k: int, user: UserModel = None
):
    try:
        result = VECTOR_DB_CLIENT.search(
            collection_name=collection_name,
            vectors=[query_embedding],
            limit=k,
        )

        if result:
            log.info(f"query_doc:result {result.ids} {result.metadatas}")

        return result
    except Exception as e:
        log.exception(f"Error querying doc {collection_name} with limit {k}: {e}")
        raise e


def get_doc(collection_name: str, user: UserModel = None):
    try:
        result = VECTOR_DB_CLIENT.get(collection_name=collection_name)

        if result:
            log.info(f"get_doc:result {result.ids} {result.metadatas}")

        return result
    except Exception as e:
        log.exception(f"Error getting doc {collection_name}: {e}")
        raise e


def query_doc_with_hybrid_search(
    collection_name: str,
    query: str,
    embedding_function,
    k: int,
    reranking_function,
    r: float,
) -> dict:
    try:
        result = VECTOR_DB_CLIENT.get(collection_name=collection_name)

        bm25_retriever = BM25Retriever.from_texts(
            texts=result.documents[0],
            metadatas=result.metadatas[0],
        )
        bm25_retriever.k = k

        vector_search_retriever = VectorSearchRetriever(
            collection_name=collection_name,
            embedding_function=embedding_function,
            top_k=k,
        )

        ensemble_retriever = EnsembleRetriever(
            retrievers=[bm25_retriever, vector_search_retriever], weights=[0.5, 0.5]
        )
        compressor = RerankCompressor(
            embedding_function=embedding_function,
            top_n=k,
            reranking_function=reranking_function,
            r_score=r,
        )

        compression_retriever = ContextualCompressionRetriever(
            base_compressor=compressor, base_retriever=ensemble_retriever
        )

        result = compression_retriever.invoke(query)
        result = {
            "distances": [[d.metadata.get("score") for d in result]],
            "documents": [[d.page_content for d in result]],
            "metadatas": [[d.metadata for d in result]],
        }

        log.info(
            "query_doc_with_hybrid_search:result "
            + f'{result["metadatas"]} {result["distances"]}'
        )
        return result
    except Exception as e:
        raise e


def merge_get_results(get_results: list[dict]) -> dict:
    # Initialize lists to store combined data
    combined_documents = []
    combined_metadatas = []
    combined_ids = []

    for data in get_results:
        combined_documents.extend(data["documents"][0])
        combined_metadatas.extend(data["metadatas"][0])
        combined_ids.extend(data["ids"][0])

    # Create the output dictionary
    result = {
        "documents": [combined_documents],
        "metadatas": [combined_metadatas],
        "ids": [combined_ids],
    }

    return result


def merge_and_sort_query_results(
    query_results: list[dict], k: int, reverse: bool = False
) -> dict:
    # Initialize lists to store combined data
    combined = []
    seen_hashes = set()  # To store unique document hashes

    for data in query_results:
        distances = data["distances"][0]
        documents = data["documents"][0]
        metadatas = data["metadatas"][0]

        for distance, document, metadata in zip(distances, documents, metadatas):
            if isinstance(document, str):
                doc_hash = hashlib.md5(
                    document.encode()
                ).hexdigest()  # Compute a hash for uniqueness

                if doc_hash not in seen_hashes:
                    seen_hashes.add(doc_hash)
                    combined.append((distance, document, metadata))

    # Sort the list based on distances
    combined.sort(key=lambda x: x[0], reverse=reverse)

    # Slice to keep only the top k elements
    sorted_distances, sorted_documents, sorted_metadatas = (
        zip(*combined[:k]) if combined else ([], [], [])
    )

    # Create and return the output dictionary
    return {
        "distances": [list(sorted_distances)],
        "documents": [list(sorted_documents)],
        "metadatas": [list(sorted_metadatas)],
    }

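The merge step above does three things in order: it drops documents whose text was already seen, sorts what remains by the numeric first field, and keeps the first `k` entries. The sketch below restates that logic in standalone form and runs it on two small made-up result sets. `merge_top_k` is a hypothetical name for this copy, and the distances and documents are invented for illustration.

```python
import hashlib


def merge_top_k(query_results: list, k: int, reverse: bool = False) -> dict:
    """Deduplicate by document text, sort by distance, keep the first k."""
    combined = []
    seen_hashes = set()
    for data in query_results:
        rows = zip(data["distances"][0], data["documents"][0], data["metadatas"][0])
        for distance, document, metadata in rows:
            if isinstance(document, str):
                doc_hash = hashlib.md5(document.encode()).hexdigest()
                # The first occurrence wins, whatever its distance.
                if doc_hash not in seen_hashes:
                    seen_hashes.add(doc_hash)
                    combined.append((distance, document, metadata))
    combined.sort(key=lambda row: row[0], reverse=reverse)
    top = combined[:k]
    return {
        "distances": [[row[0] for row in top]],
        "documents": [[row[1] for row in top]],
        "metadatas": [[row[2] for row in top]],
    }


results = [
    {"distances": [[0.9, 0.2]], "documents": [["a", "b"]], "metadatas": [[{"n": 1}, {"n": 2}]]},
    {"distances": [[0.5, 0.7]], "documents": [["a", "c"]], "metadatas": [[{"n": 3}, {"n": 4}]]},
]
print(merge_top_k(results, k=2, reverse=True))
print(merge_top_k(results, k=2, reverse=False))
```

Two behaviours follow from this ordering. Deduplication happens before sorting, so when the same text appears twice the entry kept is the one encountered first, not necessarily the one with the better value. The `reverse` flag only controls sort direction, so the caller has to know whether larger or smaller values mean a closer match for the vector store in use.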
|
||||
|
||||
def get_all_items_from_collections(collection_names: list[str]) -> dict:
|
||||
results = []
|
||||
|
||||
for collection_name in collection_names:
|
||||
if collection_name:
|
||||
try:
|
||||
result = get_doc(collection_name=collection_name)
|
||||
if result is not None:
|
||||
results.append(result.model_dump())
|
||||
except Exception as e:
|
||||
log.exception(f"Error when querying the collection: {e}")
|
||||
else:
|
||||
pass
|
||||
|
||||
return merge_get_results(results)
|
||||
|
||||
|
||||
def query_collection(
|
||||
collection_names: list[str],
|
||||
queries: list[str],
|
||||
embedding_function,
|
||||
k: int,
|
||||
) -> dict:
|
||||
results = []
|
||||
for query in queries:
|
||||
query_embedding = embedding_function(query)
|
||||
for collection_name in collection_names:
|
||||
if collection_name:
|
||||
try:
|
||||
result = query_doc(
|
||||
collection_name=collection_name,
|
||||
k=k,
|
||||
query_embedding=query_embedding,
|
||||
)
|
||||
if result is not None:
|
||||
results.append(result.model_dump())
|
||||
except Exception as e:
|
||||
log.exception(f"Error when querying the collection: {e}")
|
||||
else:
|
||||
pass
|
||||
|
||||
if VECTOR_DB == "chroma":
|
||||
# Chroma uses unconventional cosine similarity, so we don't need to reverse the results
|
||||
# https://docs.trychroma.com/docs/collections/configure#configuring-chroma-collections
|
||||
return merge_and_sort_query_results(results, k=k, reverse=False)
|
||||
else:
|
||||
return merge_and_sort_query_results(results, k=k, reverse=True)
|
||||
|
||||
|
||||
def query_collection_with_hybrid_search(
|
||||
collection_names: list[str],
|
||||
queries: list[str],
|
||||
embedding_function,
|
||||
k: int,
|
||||
reranking_function,
|
||||
r: float,
|
||||
) -> dict:
|
||||
results = []
|
||||
error = False
|
||||
for collection_name in collection_names:
|
||||
try:
|
||||
for query in queries:
|
||||
result = query_doc_with_hybrid_search(
|
||||
collection_name=collection_name,
|
||||
query=query,
|
||||
embedding_function=embedding_function,
|
||||
k=k,
|
||||
reranking_function=reranking_function,
|
||||
r=r,
|
||||
)
|
||||
results.append(result)
|
||||
except Exception as e:
|
||||
log.exception(
|
||||
"Error when querying the collection with " f"hybrid_search: {e}"
|
||||
)
|
||||
error = True
|
||||
|
||||
if error:
|
||||
raise Exception(
|
||||
"Hybrid search failed for at least one collection. Using non-hybrid search as fallback."
|
||||
)
|
||||
|
||||
if VECTOR_DB == "chroma":
|
||||
# Chroma returns cosine distance (lower is better), so the results are sorted ascending and not reversed
|
||||
# https://docs.trychroma.com/docs/collections/configure#configuring-chroma-collections
|
||||
return merge_and_sort_query_results(results, k=k, reverse=False)
|
||||
else:
|
||||
return merge_and_sort_query_results(results, k=k, reverse=True)
|
||||
|
||||
|
||||
def get_embedding_function(
|
||||
embedding_engine,
|
||||
embedding_model,
|
||||
embedding_function,
|
||||
url,
|
||||
key,
|
||||
embedding_batch_size,
|
||||
):
|
||||
if embedding_engine == "":
|
||||
return lambda query, user=None: embedding_function.encode(query).tolist()
|
||||
elif embedding_engine in ["ollama", "openai"]:
|
||||
func = lambda query, user=None: generate_embeddings(
|
||||
engine=embedding_engine,
|
||||
model=embedding_model,
|
||||
text=query,
|
||||
url=url,
|
||||
key=key,
|
||||
user=user,
|
||||
)
|
||||
|
||||
def generate_multiple(query, user, func):
|
||||
if isinstance(query, list):
|
||||
embeddings = []
|
||||
for i in range(0, len(query), embedding_batch_size):
|
||||
embeddings.extend(
|
||||
func(query[i : i + embedding_batch_size], user=user)
|
||||
)
|
||||
return embeddings
|
||||
else:
|
||||
return func(query, user)
|
||||
|
||||
return lambda query, user=None: generate_multiple(query, user, func)
|
||||
else:
|
||||
raise ValueError(f"Unknown embedding engine: {embedding_engine}")
|
||||
|
||||
|
||||
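The batching done by `generate_multiple` above can be illustrated in isolation. This is a minimal sketch under stated assumptions: `embed_in_batches` and `fake_embed` are hypothetical names, and the fake embedder only records batch sizes instead of calling a real embedding API.

```python
def embed_in_batches(texts, embed_batch, batch_size):
    """Embed a list of texts by calling embed_batch on fixed-size slices."""
    embeddings = []
    for start in range(0, len(texts), batch_size):
        # Each call sees at most batch_size texts; the last slice may be shorter.
        embeddings.extend(embed_batch(texts[start : start + batch_size]))
    return embeddings


calls = []


def fake_embed(batch):
    # Hypothetical embedder: one 1-dimensional vector per text (its length).
    calls.append(len(batch))
    return [[float(len(text))] for text in batch]


vectors = embed_in_batches(
    ["a", "bb", "ccc", "dddd", "eeeee"], fake_embed, batch_size=2
)
```

Five texts with a batch size of two lead to three calls, and the output order matches the input order.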
def get_sources_from_files(
|
||||
request,
|
||||
files,
|
||||
queries,
|
||||
embedding_function,
|
||||
k,
|
||||
reranking_function,
|
||||
r,
|
||||
hybrid_search,
|
||||
full_context=False,
|
||||
):
|
||||
log.debug(
|
||||
f"files: {files} {queries} {embedding_function} {reranking_function} {full_context}"
|
||||
)
|
||||
|
||||
extracted_collections = []
|
||||
relevant_contexts = []
|
||||
|
||||
for file in files:
|
||||
|
||||
context = None
|
||||
if file.get("docs"):
|
||||
# BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL
|
||||
context = {
|
||||
"documents": [[doc.get("content") for doc in file.get("docs")]],
|
||||
"metadatas": [[doc.get("metadata") for doc in file.get("docs")]],
|
||||
}
|
||||
elif file.get("context") == "full":
|
||||
# Manual Full Mode Toggle
|
||||
context = {
|
||||
"documents": [[file.get("file").get("data", {}).get("content")]],
|
||||
"metadatas": [[{"file_id": file.get("id"), "name": file.get("name")}]],
|
||||
}
|
||||
elif (
|
||||
file.get("type") != "web_search"
|
||||
and request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL
|
||||
):
|
||||
# BYPASS_EMBEDDING_AND_RETRIEVAL
|
||||
if file.get("type") == "collection":
|
||||
file_ids = file.get("data", {}).get("file_ids", [])
|
||||
|
||||
documents = []
|
||||
metadatas = []
|
||||
for file_id in file_ids:
|
||||
file_object = Files.get_file_by_id(file_id)
|
||||
|
||||
if file_object:
|
||||
documents.append(file_object.data.get("content", ""))
|
||||
metadatas.append(
|
||||
{
|
||||
"file_id": file_id,
|
||||
"name": file_object.filename,
|
||||
"source": file_object.filename,
|
||||
}
|
||||
)
|
||||
|
||||
context = {
|
||||
"documents": [documents],
|
||||
"metadatas": [metadatas],
|
||||
}
|
||||
|
||||
elif file.get("id"):
|
||||
file_object = Files.get_file_by_id(file.get("id"))
|
||||
if file_object:
|
||||
context = {
|
||||
"documents": [[file_object.data.get("content", "")]],
|
||||
"metadatas": [
|
||||
[
|
||||
{
|
||||
"file_id": file.get("id"),
|
||||
"name": file_object.filename,
|
||||
"source": file_object.filename,
|
||||
}
|
||||
]
|
||||
],
|
||||
}
|
||||
else:
|
||||
collection_names = []
|
||||
if file.get("type") == "collection":
|
||||
if file.get("legacy"):
|
||||
collection_names = file.get("collection_names", [])
|
||||
else:
|
||||
collection_names.append(file["id"])
|
||||
elif file.get("collection_name"):
|
||||
collection_names.append(file["collection_name"])
|
||||
elif file.get("id"):
|
||||
if file.get("legacy"):
|
||||
collection_names.append(f"{file['id']}")
|
||||
else:
|
||||
collection_names.append(f"file-{file['id']}")
|
||||
|
||||
collection_names = set(collection_names).difference(extracted_collections)
|
||||
if not collection_names:
|
||||
log.debug(f"skipping {file} as it has already been extracted")
|
||||
continue
|
||||
|
||||
if full_context:
|
||||
try:
|
||||
context = get_all_items_from_collections(collection_names)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
|
||||
else:
|
||||
try:
|
||||
context = None
|
||||
if file.get("type") == "text":
|
||||
context = file["content"]
|
||||
else:
|
||||
if hybrid_search:
|
||||
try:
|
||||
context = query_collection_with_hybrid_search(
|
||||
collection_names=collection_names,
|
||||
queries=queries,
|
||||
embedding_function=embedding_function,
|
||||
k=k,
|
||||
reranking_function=reranking_function,
|
||||
r=r,
|
||||
)
|
||||
except Exception as e:
|
||||
log.debug(
|
||||
"Error when using hybrid search, using"
|
||||
" non hybrid search as fallback."
|
||||
)
|
||||
|
||||
if (not hybrid_search) or (context is None):
|
||||
context = query_collection(
|
||||
collection_names=collection_names,
|
||||
queries=queries,
|
||||
embedding_function=embedding_function,
|
||||
k=k,
|
||||
)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
|
||||
extracted_collections.extend(collection_names)
|
||||
|
||||
if context:
|
||||
if "data" in file:
|
||||
del file["data"]
|
||||
|
||||
relevant_contexts.append({**context, "file": file})
|
||||
|
||||
sources = []
|
||||
for context in relevant_contexts:
|
||||
try:
|
||||
if "documents" in context:
|
||||
if "metadatas" in context:
|
||||
source = {
|
||||
"source": context["file"],
|
||||
"document": context["documents"][0],
|
||||
"metadata": context["metadatas"][0],
|
||||
}
|
||||
if "distances" in context and context["distances"]:
|
||||
source["distances"] = context["distances"][0]
|
||||
|
||||
sources.append(source)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
|
||||
return sources
|
||||
|
||||
|
||||
def get_model_path(model: str, update_model: bool = False):
|
||||
# Construct huggingface_hub kwargs with local_files_only to return the snapshot path
|
||||
cache_dir = os.getenv("SENTENCE_TRANSFORMERS_HOME")
|
||||
|
||||
local_files_only = not update_model
|
||||
|
||||
if OFFLINE_MODE:
|
||||
local_files_only = True
|
||||
|
||||
snapshot_kwargs = {
|
||||
"cache_dir": cache_dir,
|
||||
"local_files_only": local_files_only,
|
||||
}
|
||||
|
||||
log.debug(f"model: {model}")
|
||||
log.debug(f"snapshot_kwargs: {snapshot_kwargs}")
|
||||
|
||||
# Inspiration from upstream sentence_transformers
|
||||
if (
|
||||
os.path.exists(model)
|
||||
or ("\\" in model or model.count("/") > 1)
|
||||
and local_files_only
|
||||
):
|
||||
# If fully qualified path exists, return input, else set repo_id
|
||||
return model
|
||||
elif "/" not in model:
|
||||
# Set valid repo_id for model short-name
|
||||
model = "sentence-transformers" + "/" + model
|
||||
|
||||
snapshot_kwargs["repo_id"] = model
|
||||
|
||||
# Attempt to query the huggingface_hub library to determine the local path and/or to update
|
||||
try:
|
||||
model_repo_path = snapshot_download(**snapshot_kwargs)
|
||||
log.debug(f"model_repo_path: {model_repo_path}")
|
||||
return model_repo_path
|
||||
except Exception as e:
|
||||
log.exception(f"Cannot determine model snapshot path: {e}")
|
||||
return model
|
||||
|
||||
|
||||
def generate_openai_batch_embeddings(
|
||||
model: str,
|
||||
texts: list[str],
|
||||
url: str = "https://api.openai.com/v1",
|
||||
key: str = "",
|
||||
user: UserModel = None,
|
||||
) -> Optional[list[list[float]]]:
|
||||
try:
|
||||
r = requests.post(
|
||||
f"{url}/embeddings",
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {key}",
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user
|
||||
else {}
|
||||
),
|
||||
},
|
||||
json={"input": texts, "model": model},
|
||||
)
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
if "data" in data:
|
||||
return [elem["embedding"] for elem in data["data"]]
|
||||
else:
|
||||
raise ValueError("Something went wrong :/")
|
||||
except Exception as e:
|
||||
log.exception(f"Error generating openai batch embeddings: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def generate_ollama_batch_embeddings(
|
||||
model: str, texts: list[str], url: str, key: str = "", user: UserModel = None
|
||||
) -> Optional[list[list[float]]]:
|
||||
try:
|
||||
r = requests.post(
|
||||
f"{url}/api/embed",
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {key}",
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user
|
||||
else {}
|
||||
),
|
||||
},
|
||||
json={"input": texts, "model": model},
|
||||
)
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
|
||||
if "embeddings" in data:
|
||||
return data["embeddings"]
|
||||
else:
|
||||
raise ValueError("Something went wrong :/")
|
||||
except Exception as e:
|
||||
log.exception(f"Error generating ollama batch embeddings: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs):
|
||||
url = kwargs.get("url", "")
|
||||
key = kwargs.get("key", "")
|
||||
user = kwargs.get("user")
|
||||
|
||||
if engine == "ollama":
|
||||
if isinstance(text, list):
|
||||
embeddings = generate_ollama_batch_embeddings(
|
||||
**{"model": model, "texts": text, "url": url, "key": key, "user": user}
|
||||
)
|
||||
else:
|
||||
embeddings = generate_ollama_batch_embeddings(
|
||||
**{
|
||||
"model": model,
|
||||
"texts": [text],
|
||||
"url": url,
|
||||
"key": key,
|
||||
"user": user,
|
||||
}
|
||||
)
|
||||
return embeddings[0] if isinstance(text, str) else embeddings
|
||||
elif engine == "openai":
|
||||
if isinstance(text, list):
|
||||
embeddings = generate_openai_batch_embeddings(model, text, url, key, user)
|
||||
else:
|
||||
embeddings = generate_openai_batch_embeddings(model, [text], url, key, user)
|
||||
|
||||
return embeddings[0] if isinstance(text, str) else embeddings
|
||||
|
||||
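`generate_embeddings` accepts either one string or a list and returns a matching shape. The batch helpers above return `None` on failure, in which case indexing `embeddings[0]` would raise a `TypeError`. The sketch below shows the same normalization with an explicit `None` guard; `embed_one_or_many` and the lambda embedders are hypothetical and exist only for illustration.

```python
def embed_one_or_many(text, embed_batch):
    """Return one vector for a str input, or a list of vectors for a list input."""
    is_single = isinstance(text, str)
    vectors = embed_batch([text] if is_single else text)
    if vectors is None:
        # The batch call failed; propagate None instead of indexing into it.
        return None
    return vectors[0] if is_single else vectors


# Hypothetical embedders standing in for a real backend.
ok_embedder = lambda batch: [[1.0] for _ in batch]
failing_embedder = lambda batch: None

single = embed_one_or_many("hi", ok_embedder)
many = embed_one_or_many(["hi", "there"], ok_embedder)
failed = embed_one_or_many("hi", failing_embedder)
```

Whether a caller prefers `None` or an exception on failure is a design choice; the point of the sketch is only that the failure case is handled before indexing.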
|
||||
import operator
|
||||
from typing import Any, Optional, Sequence
|
||||
|
||||
from langchain_core.callbacks import Callbacks
|
||||
from langchain_core.documents import BaseDocumentCompressor, Document
|
||||
|
||||
|
||||
class RerankCompressor(BaseDocumentCompressor):
|
||||
embedding_function: Any
|
||||
top_n: int
|
||||
reranking_function: Any
|
||||
r_score: float
|
||||
|
||||
class Config:
|
||||
extra = "forbid"
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def compress_documents(
|
||||
self,
|
||||
documents: Sequence[Document],
|
||||
query: str,
|
||||
callbacks: Optional[Callbacks] = None,
|
||||
) -> Sequence[Document]:
|
||||
reranking = self.reranking_function is not None
|
||||
|
||||
if reranking:
|
||||
scores = self.reranking_function.predict(
|
||||
[(query, doc.page_content) for doc in documents]
|
||||
)
|
||||
else:
|
||||
from sentence_transformers import util
|
||||
|
||||
query_embedding = self.embedding_function(query)
|
||||
document_embedding = self.embedding_function(
|
||||
[doc.page_content for doc in documents]
|
||||
)
|
||||
scores = util.cos_sim(query_embedding, document_embedding)[0]
|
||||
|
||||
docs_with_scores = list(zip(documents, scores.tolist()))
|
||||
if self.r_score:
|
||||
docs_with_scores = [
|
||||
(d, s) for d, s in docs_with_scores if s >= self.r_score
|
||||
]
|
||||
|
||||
result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True)
|
||||
final_results = []
|
||||
for doc, doc_score in result[: self.top_n]:
|
||||
metadata = doc.metadata
|
||||
metadata["score"] = doc_score
|
||||
doc = Document(
|
||||
page_content=doc.page_content,
|
||||
metadata=metadata,
|
||||
)
|
||||
final_results.append(doc)
|
||||
return final_results
|
||||
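The fallback branch of `compress_documents` scores documents by cosine similarity, drops those below `r_score`, and keeps the `top_n` best. A dependency-free sketch of that flow follows; it assumes plain Python lists as vectors, and the names `cosine` and `rerank` are illustrative, not part of the source.

```python
import math


def cosine(a, b):
    """Cosine similarity of two equal-length vectors (0.0 if either is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def rerank(query_vec, docs, top_n, min_score=0.0):
    """docs is a list of (text, vector) pairs; returns (text, score) pairs."""
    scored = [(text, cosine(query_vec, vec)) for text, vec in docs]
    if min_score:
        # Same idea as the r_score filter: drop low-relevance documents first.
        scored = [(text, score) for text, score in scored if score >= min_score]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:top_n]


docs = [("x", [1.0, 0.0]), ("y", [0.0, 1.0]), ("z", [1.0, 1.0])]
ranked = rerank([1.0, 0.0], docs, top_n=2, min_score=0.5)
```

Here `"y"` is orthogonal to the query and is removed by the threshold, so only `"x"` and `"z"` remain, in that order.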
|
|
@ -0,0 +1,22 @@
|
|||
from open_webui.config import VECTOR_DB
|
||||
|
||||
if VECTOR_DB == "milvus":
|
||||
from open_webui.retrieval.vector.dbs.milvus import MilvusClient
|
||||
|
||||
VECTOR_DB_CLIENT = MilvusClient()
|
||||
elif VECTOR_DB == "qdrant":
|
||||
from open_webui.retrieval.vector.dbs.qdrant import QdrantClient
|
||||
|
||||
VECTOR_DB_CLIENT = QdrantClient()
|
||||
elif VECTOR_DB == "opensearch":
|
||||
from open_webui.retrieval.vector.dbs.opensearch import OpenSearchClient
|
||||
|
||||
VECTOR_DB_CLIENT = OpenSearchClient()
|
||||
elif VECTOR_DB == "pgvector":
|
||||
from open_webui.retrieval.vector.dbs.pgvector import PgvectorClient
|
||||
|
||||
VECTOR_DB_CLIENT = PgvectorClient()
|
||||
else:
|
||||
from open_webui.retrieval.vector.dbs.chroma import ChromaClient
|
||||
|
||||
VECTOR_DB_CLIENT = ChromaClient()
|
||||
|
|
@ -0,0 +1,178 @@
|
|||
import chromadb
|
||||
import logging
|
||||
from chromadb import Settings
|
||||
from chromadb.utils.batch_utils import create_batches
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
|
||||
from open_webui.config import (
|
||||
CHROMA_DATA_PATH,
|
||||
CHROMA_HTTP_HOST,
|
||||
CHROMA_HTTP_PORT,
|
||||
CHROMA_HTTP_HEADERS,
|
||||
CHROMA_HTTP_SSL,
|
||||
CHROMA_TENANT,
|
||||
CHROMA_DATABASE,
|
||||
CHROMA_CLIENT_AUTH_PROVIDER,
|
||||
CHROMA_CLIENT_AUTH_CREDENTIALS,
|
||||
)
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
|
||||
class ChromaClient:
|
||||
def __init__(self):
|
||||
settings_dict = {
|
||||
"allow_reset": True,
|
||||
"anonymized_telemetry": False,
|
||||
}
|
||||
if CHROMA_CLIENT_AUTH_PROVIDER is not None:
|
||||
settings_dict["chroma_client_auth_provider"] = CHROMA_CLIENT_AUTH_PROVIDER
|
||||
if CHROMA_CLIENT_AUTH_CREDENTIALS is not None:
|
||||
settings_dict["chroma_client_auth_credentials"] = (
|
||||
CHROMA_CLIENT_AUTH_CREDENTIALS
|
||||
)
|
||||
|
||||
if CHROMA_HTTP_HOST != "":
|
||||
self.client = chromadb.HttpClient(
|
||||
host=CHROMA_HTTP_HOST,
|
||||
port=CHROMA_HTTP_PORT,
|
||||
headers=CHROMA_HTTP_HEADERS,
|
||||
ssl=CHROMA_HTTP_SSL,
|
||||
tenant=CHROMA_TENANT,
|
||||
database=CHROMA_DATABASE,
|
||||
settings=Settings(**settings_dict),
|
||||
)
|
||||
else:
|
||||
self.client = chromadb.PersistentClient(
|
||||
path=CHROMA_DATA_PATH,
|
||||
settings=Settings(**settings_dict),
|
||||
tenant=CHROMA_TENANT,
|
||||
database=CHROMA_DATABASE,
|
||||
)
|
||||
|
||||
def has_collection(self, collection_name: str) -> bool:
|
||||
# Check if the collection exists based on the collection name.
|
||||
collection_names = self.client.list_collections()
|
||||
return collection_name in collection_names
|
||||
|
||||
def delete_collection(self, collection_name: str):
|
||||
# Delete the collection based on the collection name.
|
||||
return self.client.delete_collection(name=collection_name)
|
||||
|
||||
def search(
|
||||
self, collection_name: str, vectors: list[list[float | int]], limit: int
|
||||
) -> Optional[SearchResult]:
|
||||
# Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
|
||||
try:
|
||||
collection = self.client.get_collection(name=collection_name)
|
||||
if collection:
|
||||
result = collection.query(
|
||||
query_embeddings=vectors,
|
||||
n_results=limit,
|
||||
)
|
||||
|
||||
return SearchResult(
|
||||
**{
|
||||
"ids": result["ids"],
|
||||
"distances": result["distances"],
|
||||
"documents": result["documents"],
|
||||
"metadatas": result["metadatas"],
|
||||
}
|
||||
)
|
||||
return None
|
||||
except Exception as e:
|
||||
return None
|
||||
|
||||
def query(
|
||||
self, collection_name: str, filter: dict, limit: Optional[int] = None
|
||||
) -> Optional[GetResult]:
|
||||
# Query the items from the collection based on the filter.
|
||||
try:
|
||||
collection = self.client.get_collection(name=collection_name)
|
||||
if collection:
|
||||
result = collection.get(
|
||||
where=filter,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
return GetResult(
|
||||
**{
|
||||
"ids": [result["ids"]],
|
||||
"documents": [result["documents"]],
|
||||
"metadatas": [result["metadatas"]],
|
||||
}
|
||||
)
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get(self, collection_name: str) -> Optional[GetResult]:
|
||||
# Get all the items in the collection.
|
||||
collection = self.client.get_collection(name=collection_name)
|
||||
if collection:
|
||||
result = collection.get()
|
||||
return GetResult(
|
||||
**{
|
||||
"ids": [result["ids"]],
|
||||
"documents": [result["documents"]],
|
||||
"metadatas": [result["metadatas"]],
|
||||
}
|
||||
)
|
||||
return None
|
||||
|
||||
def insert(self, collection_name: str, items: list[VectorItem]):
|
||||
# Insert the items into the collection; the collection is created if it does not exist.
|
||||
collection = self.client.get_or_create_collection(
|
||||
name=collection_name, metadata={"hnsw:space": "cosine"}
|
||||
)
|
||||
|
||||
ids = [item["id"] for item in items]
|
||||
documents = [item["text"] for item in items]
|
||||
embeddings = [item["vector"] for item in items]
|
||||
metadatas = [item["metadata"] for item in items]
|
||||
|
||||
for batch in create_batches(
|
||||
api=self.client,
|
||||
documents=documents,
|
||||
embeddings=embeddings,
|
||||
ids=ids,
|
||||
metadatas=metadatas,
|
||||
):
|
||||
collection.add(*batch)
|
||||
|
||||
def upsert(self, collection_name: str, items: list[VectorItem]):
|
||||
# Update the items in the collection and insert any that are not present. The collection is created if it does not exist.
|
||||
collection = self.client.get_or_create_collection(
|
||||
name=collection_name, metadata={"hnsw:space": "cosine"}
|
||||
)
|
||||
|
||||
ids = [item["id"] for item in items]
|
||||
documents = [item["text"] for item in items]
|
||||
embeddings = [item["vector"] for item in items]
|
||||
metadatas = [item["metadata"] for item in items]
|
||||
|
||||
collection.upsert(
|
||||
ids=ids, documents=documents, embeddings=embeddings, metadatas=metadatas
|
||||
)
|
||||
|
||||
def delete(
|
||||
self,
|
||||
collection_name: str,
|
||||
ids: Optional[list[str]] = None,
|
||||
filter: Optional[dict] = None,
|
||||
):
|
||||
# Delete the items from the collection based on the ids or, if no ids are given, on a metadata filter.
|
||||
collection = self.client.get_collection(name=collection_name)
|
||||
if collection:
|
||||
if ids:
|
||||
collection.delete(ids=ids)
|
||||
elif filter:
|
||||
collection.delete(where=filter)
|
||||
|
||||
def reset(self):
|
||||
# Resets the database. This will delete all collections and item entries.
|
||||
return self.client.reset()
|
||||
|
|
@ -0,0 +1,297 @@
|
|||
from pymilvus import MilvusClient as Client
|
||||
from pymilvus import FieldSchema, DataType
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
|
||||
from open_webui.config import (
|
||||
MILVUS_URI,
|
||||
MILVUS_DB,
|
||||
MILVUS_TOKEN,
|
||||
)
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
|
||||
class MilvusClient:
|
||||
def __init__(self):
|
||||
self.collection_prefix = "open_webui"
|
||||
if MILVUS_TOKEN is None:
|
||||
self.client = Client(uri=MILVUS_URI, database=MILVUS_DB)
|
||||
else:
|
||||
self.client = Client(uri=MILVUS_URI, database=MILVUS_DB, token=MILVUS_TOKEN)
|
||||
|
||||
def _result_to_get_result(self, result) -> GetResult:
|
||||
ids = []
|
||||
documents = []
|
||||
metadatas = []
|
||||
|
||||
for match in result:
|
||||
_ids = []
|
||||
_documents = []
|
||||
_metadatas = []
|
||||
for item in match:
|
||||
_ids.append(item.get("id"))
|
||||
_documents.append(item.get("data", {}).get("text"))
|
||||
_metadatas.append(item.get("metadata"))
|
||||
|
||||
ids.append(_ids)
|
||||
documents.append(_documents)
|
||||
metadatas.append(_metadatas)
|
||||
|
||||
return GetResult(
|
||||
**{
|
||||
"ids": ids,
|
||||
"documents": documents,
|
||||
"metadatas": metadatas,
|
||||
}
|
||||
)
|
||||
|
||||
def _result_to_search_result(self, result) -> SearchResult:
|
||||
ids = []
|
||||
distances = []
|
||||
documents = []
|
||||
metadatas = []
|
||||
|
||||
for match in result:
|
||||
_ids = []
|
||||
_distances = []
|
||||
_documents = []
|
||||
_metadatas = []
|
||||
|
||||
for item in match:
|
||||
_ids.append(item.get("id"))
|
||||
_distances.append(item.get("distance"))
|
||||
_documents.append(item.get("entity", {}).get("data", {}).get("text"))
|
||||
_metadatas.append(item.get("entity", {}).get("metadata"))
|
||||
|
||||
ids.append(_ids)
|
||||
distances.append(_distances)
|
||||
documents.append(_documents)
|
||||
metadatas.append(_metadatas)
|
||||
|
||||
return SearchResult(
|
||||
**{
|
||||
"ids": ids,
|
||||
"distances": distances,
|
||||
"documents": documents,
|
||||
"metadatas": metadatas,
|
||||
}
|
||||
)
|
||||
|
||||
def _create_collection(self, collection_name: str, dimension: int):
|
||||
schema = self.client.create_schema(
|
||||
auto_id=False,
|
||||
enable_dynamic_field=True,
|
||||
)
|
||||
schema.add_field(
|
||||
field_name="id",
|
||||
datatype=DataType.VARCHAR,
|
||||
is_primary=True,
|
||||
max_length=65535,
|
||||
)
|
||||
schema.add_field(
|
||||
field_name="vector",
|
||||
datatype=DataType.FLOAT_VECTOR,
|
||||
dim=dimension,
|
||||
description="vector",
|
||||
)
|
||||
schema.add_field(field_name="data", datatype=DataType.JSON, description="data")
|
||||
schema.add_field(
|
||||
field_name="metadata", datatype=DataType.JSON, description="metadata"
|
||||
)
|
||||
|
||||
index_params = self.client.prepare_index_params()
|
||||
index_params.add_index(
|
||||
field_name="vector",
|
||||
index_type="HNSW",
|
||||
metric_type="COSINE",
|
||||
params={"M": 16, "efConstruction": 100},
|
||||
)
|
||||
|
||||
self.client.create_collection(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
schema=schema,
|
||||
index_params=index_params,
|
||||
)
|
||||
|
||||
def has_collection(self, collection_name: str) -> bool:
|
||||
# Check if the collection exists based on the collection name.
|
||||
collection_name = collection_name.replace("-", "_")
|
||||
return self.client.has_collection(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}"
|
||||
)
|
||||
|
||||
def delete_collection(self, collection_name: str):
|
||||
# Delete the collection based on the collection name.
|
||||
collection_name = collection_name.replace("-", "_")
|
||||
return self.client.drop_collection(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}"
|
||||
)
|
||||
|
||||
def search(
|
||||
self, collection_name: str, vectors: list[list[float | int]], limit: int
|
||||
) -> Optional[SearchResult]:
|
||||
# Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
|
||||
collection_name = collection_name.replace("-", "_")
|
||||
result = self.client.search(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
data=vectors,
|
||||
limit=limit,
|
||||
output_fields=["data", "metadata"],
|
||||
)
|
||||
|
||||
return self._result_to_search_result(result)
|
||||
|
||||
def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
|
||||
# Construct the filter string for querying
|
||||
collection_name = collection_name.replace("-", "_")
|
||||
if not self.has_collection(collection_name):
|
||||
return None
|
||||
|
||||
filter_string = " && ".join(
|
||||
[
|
||||
f'metadata["{key}"] == {json.dumps(value)}'
|
||||
for key, value in filter.items()
|
||||
]
|
||||
)
|
||||
|
||||
max_limit = 16383 # The maximum number of records per request
|
||||
all_results = []
|
||||
|
||||
if limit is None:
|
||||
limit = float("inf") # Use infinity as a placeholder for no limit
|
||||
|
||||
# Initialize offset and remaining to handle pagination
|
||||
offset = 0
|
||||
remaining = limit
|
||||
|
||||
try:
|
||||
# Loop until there are no more items to fetch or the desired limit is reached
|
||||
while remaining > 0:
|
||||
log.info(f"remaining: {remaining}")
|
||||
current_fetch = min(
|
||||
max_limit, remaining
|
||||
) # Determine how many items to fetch in this iteration
|
||||
|
||||
results = self.client.query(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
filter=filter_string,
|
||||
output_fields=["*"],
|
||||
limit=current_fetch,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
if not results:
|
||||
break
|
||||
|
||||
all_results.extend(results)
|
||||
results_count = len(results)
|
||||
remaining -= (
|
||||
results_count # Decrease remaining by the number of items fetched
|
||||
)
|
||||
offset += results_count
|
||||
|
||||
# Break the loop if the results returned are less than the requested fetch count
|
||||
if results_count < current_fetch:
|
||||
break
|
||||
|
||||
log.debug(all_results)
|
||||
return self._result_to_get_result([all_results])
|
||||
except Exception as e:
|
||||
log.exception(
|
||||
f"Error querying collection {collection_name} with limit {limit}: {e}"
|
||||
)
|
||||
return None
|
||||
|
||||
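The offset-based loop in `query` above can be shown without a Milvus server. The sketch below assumes a hypothetical `fetch_page(offset, limit)` callable in place of `self.client.query`; the default page size mirrors the `max_limit` value used in the source.

```python
def fetch_all(fetch_page, limit=None, page_size=16383):
    """Collect up to `limit` items by requesting pages of at most page_size."""
    remaining = float("inf") if limit is None else limit
    offset = 0
    items = []
    while remaining > 0:
        want = int(min(page_size, remaining))
        page = fetch_page(offset=offset, limit=want)
        if not page:
            break
        items.extend(page)
        offset += len(page)
        remaining -= len(page)
        if len(page) < want:
            # A short page means the source has no more rows.
            break
    return items


data = list(range(10))
fetch = lambda offset, limit: data[offset : offset + limit]

everything = fetch_all(fetch, page_size=4)
first_five = fetch_all(fetch, limit=5, page_size=4)
```

With `limit=None` the loop stops on the first short or empty page; with an explicit limit it stops as soon as enough items have been collected.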
def get(self, collection_name: str) -> Optional[GetResult]:
|
||||
# Get all the items in the collection.
|
||||
collection_name = collection_name.replace("-", "_")
|
||||
result = self.client.query(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
filter='id != ""',
|
||||
)
|
||||
return self._result_to_get_result([result])
|
||||
|
||||
def insert(self, collection_name: str, items: list[VectorItem]):
|
||||
# Insert the items into the collection; the collection is created if it does not exist.
|
||||
collection_name = collection_name.replace("-", "_")
|
||||
if not self.client.has_collection(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}"
|
||||
):
|
||||
self._create_collection(
|
||||
collection_name=collection_name, dimension=len(items[0]["vector"])
|
||||
)
|
||||
|
||||
return self.client.insert(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
data=[
|
||||
{
|
||||
"id": item["id"],
|
||||
"vector": item["vector"],
|
||||
"data": {"text": item["text"]},
|
||||
"metadata": item["metadata"],
|
||||
}
|
||||
for item in items
|
||||
],
|
||||
)
|
||||
|
||||
def upsert(self, collection_name: str, items: list[VectorItem]):
|
||||
# Update the items in the collection and insert any that are not present. The collection is created if it does not exist.
|
||||
collection_name = collection_name.replace("-", "_")
|
||||
if not self.client.has_collection(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}"
|
||||
):
|
||||
self._create_collection(
|
||||
collection_name=collection_name, dimension=len(items[0]["vector"])
|
||||
)
|
||||
|
||||
return self.client.upsert(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
data=[
|
||||
{
|
||||
"id": item["id"],
|
||||
"vector": item["vector"],
|
||||
"data": {"text": item["text"]},
|
||||
"metadata": item["metadata"],
|
||||
}
|
||||
for item in items
|
||||
],
|
||||
)
|
||||
|
||||
def delete(
|
||||
self,
|
||||
collection_name: str,
|
||||
ids: Optional[list[str]] = None,
|
||||
filter: Optional[dict] = None,
|
||||
):
|
||||
# Delete the items from the collection based on the ids or, if no ids are given, on a metadata filter.
|
||||
collection_name = collection_name.replace("-", "_")
|
||||
if ids:
|
||||
return self.client.delete(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
ids=ids,
|
||||
)
|
||||
elif filter:
|
||||
# Convert the filter dictionary to a boolean expression of metadata equality checks joined with &&.
|
||||
filter_string = " && ".join(
|
||||
[
|
||||
f'metadata["{key}"] == {json.dumps(value)}'
|
||||
for key, value in filter.items()
|
||||
]
|
||||
)
|
||||
|
||||
return self.client.delete(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
filter=filter_string,
|
||||
)
|
||||
|
||||
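Both `query` and `delete` above turn a filter dictionary into a filter string. The standalone sketch below shows what that string looks like for a small input; `build_metadata_filter` is a hypothetical helper name, and keys are interpolated as-is, so it assumes they contain no double quotes.

```python
import json


def build_metadata_filter(filter: dict) -> str:
    """Join one equality check per key; json.dumps quotes strings, not numbers."""
    return " && ".join(
        f'metadata["{key}"] == {json.dumps(value)}' for key, value in filter.items()
    )


expression = build_metadata_filter({"file_id": "abc", "page": 3})
```

Using `json.dumps` for the value keeps string and numeric comparisons distinct in the resulting expression.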
def reset(self):
|
||||
# Resets the database. This will delete all collections and item entries.
|
||||
collection_names = self.client.list_collections()
|
||||
for collection_name in collection_names:
|
||||
if collection_name.startswith(self.collection_prefix):
|
||||
self.client.drop_collection(collection_name=collection_name)
|
||||
|
|
@ -0,0 +1,206 @@
|
|||
from opensearchpy import OpenSearch
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
|
||||
from open_webui.config import (
|
||||
OPENSEARCH_URI,
|
||||
OPENSEARCH_SSL,
|
||||
OPENSEARCH_CERT_VERIFY,
|
||||
OPENSEARCH_USERNAME,
|
||||
OPENSEARCH_PASSWORD,
|
||||
)
|
||||
|
||||
|
||||
class OpenSearchClient:
|
||||
def __init__(self):
|
||||
self.index_prefix = "open_webui"
|
||||
self.client = OpenSearch(
|
||||
hosts=[OPENSEARCH_URI],
|
||||
use_ssl=OPENSEARCH_SSL,
|
||||
verify_certs=OPENSEARCH_CERT_VERIFY,
|
||||
http_auth=(OPENSEARCH_USERNAME, OPENSEARCH_PASSWORD),
|
||||
)
|
||||
|
||||
def _result_to_get_result(self, result) -> GetResult:
|
||||
ids = []
|
||||
documents = []
|
||||
metadatas = []
|
||||
|
||||
for hit in result["hits"]["hits"]:
|
||||
ids.append(hit["_id"])
|
||||
documents.append(hit["_source"].get("text"))
|
||||
metadatas.append(hit["_source"].get("metadata"))
|
||||
|
||||
return GetResult(ids=ids, documents=documents, metadatas=metadatas)
|
||||
|
||||
def _result_to_search_result(self, result) -> SearchResult:
|
||||
ids = []
|
||||
distances = []
|
||||
documents = []
|
||||
metadatas = []
|
||||
|
||||
for hit in result["hits"]["hits"]:
|
||||
ids.append(hit["_id"])
|
||||
distances.append(hit["_score"])
|
||||
documents.append(hit["_source"].get("text"))
|
||||
metadatas.append(hit["_source"].get("metadata"))
|
||||
|
||||
return SearchResult(
|
||||
ids=ids, distances=distances, documents=documents, metadatas=metadatas
|
||||
)
|
||||
|
||||
def _create_index(self, index_name: str, dimension: int):
|
||||
body = {
|
||||
"mappings": {
|
||||
"properties": {
|
||||
"id": {"type": "keyword"},
|
||||
"vector": {
|
||||
"type": "dense_vector",
|
||||
"dims": dimension, # Adjust based on your vector dimensions
|
||||
"index": True,
|
||||
"similarity": "faiss",
|
||||
"method": {
|
||||
"name": "hnsw",
|
||||
"space_type": "ip", # Use inner product to approximate cosine similarity
|
||||
"engine": "faiss",
|
||||
"ef_construction": 128,
|
||||
"m": 16,
|
||||
},
|
||||
},
|
||||
"text": {"type": "text"},
|
||||
"metadata": {"type": "object"},
|
||||
}
|
||||
}
|
||||
}
|
||||
self.client.indices.create(index=f"{self.index_prefix}_{index_name}", body=body)
|
||||
|
||||
def _create_batches(self, items: list[VectorItem], batch_size=100):
|
||||
for i in range(0, len(items), batch_size):
|
||||
yield items[i : i + batch_size]
|
||||
|
||||
def has_collection(self, index_name: str) -> bool:
|
||||
# has_collection here means has index.
|
||||
# We are simply adapting to the norms of the other DBs.
|
||||
return self.client.indices.exists(index=f"{self.index_prefix}_{index_name}")
|
||||
|
||||
def delete_collection(self, index_name: str):
|
||||
# delete_collection here means delete index.
|
||||
# We are simply adapting to the norms of the other DBs.
|
||||
self.client.indices.delete(index=f"{self.index_prefix}_{index_name}")
|
||||
|
||||
    def search(
        self, index_name: str, vectors: list[list[float]], limit: int
    ) -> Optional[SearchResult]:
        query = {
            "size": limit,
            "_source": ["text", "metadata"],
            "query": {
                "script_score": {
                    "query": {"match_all": {}},
                    "script": {
                        "source": "cosineSimilarity(params.vector, 'vector') + 1.0",
                        "params": {
                            "vector": vectors[0]
                        },  # Assuming single query vector
                    },
                }
            },
        }

        result = self.client.search(
            index=f"{self.index_prefix}_{index_name}", body=query
        )

        return self._result_to_search_result(result)

    def query(
        self, collection_name: str, filter: dict, limit: Optional[int] = None
    ) -> Optional[GetResult]:
        if not self.has_collection(collection_name):
            return None

        query_body = {
            "query": {"bool": {"filter": []}},
            "_source": ["text", "metadata"],
        }

        for field, value in filter.items():
            query_body["query"]["bool"]["filter"].append({"term": {field: value}})

        size = limit if limit else 10

        try:
            result = self.client.search(
                index=f"{self.index_prefix}_{collection_name}",
                body=query_body,
                size=size,
            )

            return self._result_to_get_result(result)

        except Exception:
            return None

    def get_or_create_index(self, index_name: str, dimension: int):
        if not self.has_collection(index_name):
            self._create_index(index_name, dimension)

    def get(self, index_name: str) -> Optional[GetResult]:
        query = {"query": {"match_all": {}}, "_source": ["text", "metadata"]}

        result = self.client.search(
            index=f"{self.index_prefix}_{index_name}", body=query
        )
        return self._result_to_get_result(result)

    def insert(self, index_name: str, items: list[VectorItem]):
        if not self.has_collection(index_name):
            self._create_index(index_name, dimension=len(items[0]["vector"]))

        for batch in self._create_batches(items):
            # The bulk API expects an action line followed by the document source
            # for each item; the target index is passed once for the whole request.
            actions = []
            for item in batch:
                actions.append({"index": {"_id": item["id"]}})
                actions.append(
                    {
                        "vector": item["vector"],
                        "text": item["text"],
                        "metadata": item["metadata"],
                    }
                )
            self.client.bulk(body=actions, index=f"{self.index_prefix}_{index_name}")

    def upsert(self, index_name: str, items: list[VectorItem]):
        if not self.has_collection(index_name):
            self._create_index(index_name, dimension=len(items[0]["vector"]))

        for batch in self._create_batches(items):
            # An "index" action overwrites any existing document with the same _id,
            # which gives upsert semantics. Each action line is followed by its source.
            actions = []
            for item in batch:
                actions.append({"index": {"_id": item["id"]}})
                actions.append(
                    {
                        "vector": item["vector"],
                        "text": item["text"],
                        "metadata": item["metadata"],
                    }
                )
            self.client.bulk(body=actions, index=f"{self.index_prefix}_{index_name}")

    def delete(self, index_name: str, ids: list[str]):
        actions = [
            {"delete": {"_index": f"{self.index_prefix}_{index_name}", "_id": id}}
            for id in ids
        ]
        self.client.bulk(body=actions)

    def reset(self):
        indices = self.client.indices.get(index=f"{self.index_prefix}_*")
        for index in indices:
            self.client.indices.delete(index=index)
@ -0,0 +1,403 @@
from typing import Optional, List, Dict, Any
import logging
from sqlalchemy import (
    cast,
    column,
    create_engine,
    Column,
    Integer,
    MetaData,
    select,
    text,
    Text,
    Table,
    values,
)
from sqlalchemy.sql import true
from sqlalchemy.pool import NullPool

from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker
from sqlalchemy.dialects.postgresql import JSONB, array
from pgvector.sqlalchemy import Vector
from sqlalchemy.ext.mutable import MutableDict
from sqlalchemy.exc import NoSuchTableError

from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.config import PGVECTOR_DB_URL, PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH

from open_webui.env import SRC_LOG_LEVELS

VECTOR_LENGTH = PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
Base = declarative_base()

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


class DocumentChunk(Base):
    __tablename__ = "document_chunk"

    id = Column(Text, primary_key=True)
    vector = Column(Vector(dim=VECTOR_LENGTH), nullable=True)
    collection_name = Column(Text, nullable=False)
    text = Column(Text, nullable=True)
    vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True)


class PgvectorClient:
    def __init__(self) -> None:

        # If no pgvector URL is configured, reuse the existing database connection
        if not PGVECTOR_DB_URL:
            from open_webui.internal.db import Session

            self.session = Session
        else:
            engine = create_engine(
                PGVECTOR_DB_URL, pool_pre_ping=True, poolclass=NullPool
            )
            SessionLocal = sessionmaker(
                autocommit=False, autoflush=False, bind=engine, expire_on_commit=False
            )
            self.session = scoped_session(SessionLocal)

        try:
            # Ensure the pgvector extension is available
            self.session.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))

            # Check vector length consistency
            self.check_vector_length()

            # Create the tables if they do not exist
            # Base.metadata.create_all requires a bind (engine or connection)
            # Get the connection from the session
            connection = self.session.connection()
            Base.metadata.create_all(bind=connection)

            # Create an index on the vector column if it doesn't exist
            self.session.execute(
                text(
                    "CREATE INDEX IF NOT EXISTS idx_document_chunk_vector "
                    "ON document_chunk USING ivfflat (vector vector_cosine_ops) WITH (lists = 100);"
                )
            )
            self.session.execute(
                text(
                    "CREATE INDEX IF NOT EXISTS idx_document_chunk_collection_name "
                    "ON document_chunk (collection_name);"
                )
            )
            self.session.commit()
            log.info("Initialization complete.")
        except Exception as e:
            self.session.rollback()
            log.exception(f"Error during initialization: {e}")
            raise

    def check_vector_length(self) -> None:
        """
        Check if the VECTOR_LENGTH matches the existing vector column dimension in the database.
        Raises an exception if there is a mismatch.
        """
        metadata = MetaData()
        try:
            # Attempt to reflect the 'document_chunk' table
            document_chunk_table = Table(
                "document_chunk", metadata, autoload_with=self.session.bind
            )
        except NoSuchTableError:
            # Table does not exist; no action needed
            return

        # Proceed to check the vector column
        if "vector" in document_chunk_table.columns:
            vector_column = document_chunk_table.columns["vector"]
            vector_type = vector_column.type
            if isinstance(vector_type, Vector):
                db_vector_length = vector_type.dim
                if db_vector_length != VECTOR_LENGTH:
                    raise Exception(
                        f"VECTOR_LENGTH {VECTOR_LENGTH} does not match existing vector column dimension {db_vector_length}. "
                        "Cannot change vector size after initialization without migrating the data."
                    )
            else:
                raise Exception(
                    "The 'vector' column exists but is not of type 'Vector'."
                )
        else:
            raise Exception(
                "The 'vector' column does not exist in the 'document_chunk' table."
            )

    def adjust_vector_length(self, vector: List[float]) -> List[float]:
        # Adjust vector to have length VECTOR_LENGTH
        current_length = len(vector)
        if current_length < VECTOR_LENGTH:
            # Pad a copy with zeros so the caller's list is not mutated
            vector = vector + [0.0] * (VECTOR_LENGTH - current_length)
        elif current_length > VECTOR_LENGTH:
            raise Exception(
                f"Vector length {current_length} not supported. Max length must be <= {VECTOR_LENGTH}"
            )
        return vector

    def insert(self, collection_name: str, items: List[VectorItem]) -> None:
        try:
            new_items = []
            for item in items:
                vector = self.adjust_vector_length(item["vector"])
                new_chunk = DocumentChunk(
                    id=item["id"],
                    vector=vector,
                    collection_name=collection_name,
                    text=item["text"],
                    vmetadata=item["metadata"],
                )
                new_items.append(new_chunk)
            self.session.bulk_save_objects(new_items)
            self.session.commit()
            log.info(
                f"Inserted {len(new_items)} items into collection '{collection_name}'."
            )
        except Exception as e:
            self.session.rollback()
            log.exception(f"Error during insert: {e}")
            raise

    def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
        try:
            for item in items:
                vector = self.adjust_vector_length(item["vector"])
                existing = (
                    self.session.query(DocumentChunk)
                    .filter(DocumentChunk.id == item["id"])
                    .first()
                )
                if existing:
                    existing.vector = vector
                    existing.text = item["text"]
                    existing.vmetadata = item["metadata"]
                    existing.collection_name = (
                        collection_name  # Update collection_name if necessary
                    )
                else:
                    new_chunk = DocumentChunk(
                        id=item["id"],
                        vector=vector,
                        collection_name=collection_name,
                        text=item["text"],
                        vmetadata=item["metadata"],
                    )
                    self.session.add(new_chunk)
            self.session.commit()
            log.info(
                f"Upserted {len(items)} items into collection '{collection_name}'."
            )
        except Exception as e:
            self.session.rollback()
            log.exception(f"Error during upsert: {e}")
            raise

    def search(
        self,
        collection_name: str,
        vectors: List[List[float]],
        limit: Optional[int] = None,
    ) -> Optional[SearchResult]:
        try:
            if not vectors:
                return None

            # Adjust query vectors to VECTOR_LENGTH
            vectors = [self.adjust_vector_length(vector) for vector in vectors]
            num_queries = len(vectors)

            def vector_expr(vector):
                return cast(array(vector), Vector(VECTOR_LENGTH))

            # Create the values for query vectors
            qid_col = column("qid", Integer)
            q_vector_col = column("q_vector", Vector(VECTOR_LENGTH))
            query_vectors = (
                values(qid_col, q_vector_col)
                .data(
                    [(idx, vector_expr(vector)) for idx, vector in enumerate(vectors)]
                )
                .alias("query_vectors")
            )

            # Build the lateral subquery for each query vector
            subq = (
                select(
                    DocumentChunk.id,
                    DocumentChunk.text,
                    DocumentChunk.vmetadata,
                    (
                        DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)
                    ).label("distance"),
                )
                .where(DocumentChunk.collection_name == collection_name)
                .order_by(
                    (DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector))
                )
            )
            if limit is not None:
                subq = subq.limit(limit)
            subq = subq.lateral("result")

            # Build the main query by joining query_vectors and the lateral subquery
            stmt = (
                select(
                    query_vectors.c.qid,
                    subq.c.id,
                    subq.c.text,
                    subq.c.vmetadata,
                    subq.c.distance,
                )
                .select_from(query_vectors)
                .join(subq, true())
                .order_by(query_vectors.c.qid, subq.c.distance)
            )

            result_proxy = self.session.execute(stmt)
            results = result_proxy.all()

            ids = [[] for _ in range(num_queries)]
            distances = [[] for _ in range(num_queries)]
            documents = [[] for _ in range(num_queries)]
            metadatas = [[] for _ in range(num_queries)]

            if not results:
                return SearchResult(
                    ids=ids,
                    distances=distances,
                    documents=documents,
                    metadatas=metadatas,
                )

            for row in results:
                qid = int(row.qid)
                ids[qid].append(row.id)
                distances[qid].append(row.distance)
                documents[qid].append(row.text)
                metadatas[qid].append(row.vmetadata)

            return SearchResult(
                ids=ids, distances=distances, documents=documents, metadatas=metadatas
            )
        except Exception as e:
            log.exception(f"Error during search: {e}")
            return None

    def query(
        self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
    ) -> Optional[GetResult]:
        try:
            query = self.session.query(DocumentChunk).filter(
                DocumentChunk.collection_name == collection_name
            )

            for key, value in filter.items():
                query = query.filter(DocumentChunk.vmetadata[key].astext == str(value))

            if limit is not None:
                query = query.limit(limit)

            results = query.all()

            if not results:
                return None

            ids = [[result.id for result in results]]
            documents = [[result.text for result in results]]
            metadatas = [[result.vmetadata for result in results]]

            return GetResult(
                ids=ids,
                documents=documents,
                metadatas=metadatas,
            )
        except Exception as e:
            log.exception(f"Error during query: {e}")
            return None

    def get(
        self, collection_name: str, limit: Optional[int] = None
    ) -> Optional[GetResult]:
        try:
            query = self.session.query(DocumentChunk).filter(
                DocumentChunk.collection_name == collection_name
            )
            if limit is not None:
                query = query.limit(limit)

            results = query.all()

            if not results:
                return None

            ids = [[result.id for result in results]]
            documents = [[result.text for result in results]]
            metadatas = [[result.vmetadata for result in results]]

            return GetResult(ids=ids, documents=documents, metadatas=metadatas)
        except Exception as e:
            log.exception(f"Error during get: {e}")
            return None

    def delete(
        self,
        collection_name: str,
        ids: Optional[List[str]] = None,
        filter: Optional[Dict[str, Any]] = None,
    ) -> None:
        try:
            query = self.session.query(DocumentChunk).filter(
                DocumentChunk.collection_name == collection_name
            )
            if ids:
                query = query.filter(DocumentChunk.id.in_(ids))
            if filter:
                for key, value in filter.items():
                    query = query.filter(
                        DocumentChunk.vmetadata[key].astext == str(value)
                    )
            deleted = query.delete(synchronize_session=False)
            self.session.commit()
            log.info(f"Deleted {deleted} items from collection '{collection_name}'.")
        except Exception as e:
            self.session.rollback()
            log.exception(f"Error during delete: {e}")
            raise

    def reset(self) -> None:
        try:
            deleted = self.session.query(DocumentChunk).delete()
            self.session.commit()
            log.info(
                f"Reset complete. Deleted {deleted} items from 'document_chunk' table."
            )
        except Exception as e:
            self.session.rollback()
            log.exception(f"Error during reset: {e}")
            raise

    def close(self) -> None:
        pass

    def has_collection(self, collection_name: str) -> bool:
        try:
            exists = (
                self.session.query(DocumentChunk)
                .filter(DocumentChunk.collection_name == collection_name)
                .first()
                is not None
            )
            return exists
        except Exception as e:
            log.exception(f"Error checking collection existence: {e}")
            return False

    def delete_collection(self, collection_name: str) -> None:
        self.delete(collection_name)
        log.info(f"Collection '{collection_name}' deleted.")
@ -0,0 +1,189 @@
from typing import Optional
import logging

from qdrant_client import QdrantClient as Qclient
from qdrant_client.http.models import PointStruct
from qdrant_client.models import models

from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.config import QDRANT_URI, QDRANT_API_KEY
from open_webui.env import SRC_LOG_LEVELS

NO_LIMIT = 999999999

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


class QdrantClient:
    def __init__(self):
        self.collection_prefix = "open-webui"
        self.QDRANT_URI = QDRANT_URI
        self.QDRANT_API_KEY = QDRANT_API_KEY
        self.client = (
            Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY)
            if self.QDRANT_URI
            else None
        )

    def _result_to_get_result(self, points) -> GetResult:
        ids = []
        documents = []
        metadatas = []

        for point in points:
            payload = point.payload
            ids.append(point.id)
            documents.append(payload["text"])
            metadatas.append(payload["metadata"])

        return GetResult(
            **{
                "ids": [ids],
                "documents": [documents],
                "metadatas": [metadatas],
            }
        )

    def _create_collection(self, collection_name: str, dimension: int):
        collection_name_with_prefix = f"{self.collection_prefix}_{collection_name}"
        self.client.create_collection(
            collection_name=collection_name_with_prefix,
            vectors_config=models.VectorParams(
                size=dimension, distance=models.Distance.COSINE
            ),
        )

        log.info(f"collection {collection_name_with_prefix} successfully created!")

    def _create_collection_if_not_exists(self, collection_name, dimension):
        if not self.has_collection(collection_name=collection_name):
            self._create_collection(
                collection_name=collection_name, dimension=dimension
            )

    def _create_points(self, items: list[VectorItem]):
        return [
            PointStruct(
                id=item["id"],
                vector=item["vector"],
                payload={"text": item["text"], "metadata": item["metadata"]},
            )
            for item in items
        ]

    def has_collection(self, collection_name: str) -> bool:
        return self.client.collection_exists(
            f"{self.collection_prefix}_{collection_name}"
        )

    def delete_collection(self, collection_name: str):
        return self.client.delete_collection(
            collection_name=f"{self.collection_prefix}_{collection_name}"
        )

    def search(
        self, collection_name: str, vectors: list[list[float | int]], limit: int
    ) -> Optional[SearchResult]:
        # Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
        if limit is None:
            limit = NO_LIMIT  # otherwise qdrant would set limit to 10!

        query_response = self.client.query_points(
            collection_name=f"{self.collection_prefix}_{collection_name}",
            query=vectors[0],
            limit=limit,
        )
        get_result = self._result_to_get_result(query_response.points)
        return SearchResult(
            ids=get_result.ids,
            documents=get_result.documents,
            metadatas=get_result.metadatas,
            distances=[[point.score for point in query_response.points]],
        )

    def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
        # Construct the filter conditions for querying
        if not self.has_collection(collection_name):
            return None
        try:
            if limit is None:
                limit = NO_LIMIT  # otherwise qdrant would set limit to 10!

            field_conditions = []
            for key, value in filter.items():
                field_conditions.append(
                    models.FieldCondition(
                        key=f"metadata.{key}", match=models.MatchValue(value=value)
                    )
                )

            points = self.client.query_points(
                collection_name=f"{self.collection_prefix}_{collection_name}",
                query_filter=models.Filter(should=field_conditions),
                limit=limit,
            )
            return self._result_to_get_result(points.points)
        except Exception as e:
            log.exception(f"Error querying a collection '{collection_name}': {e}")
            return None

    def get(self, collection_name: str) -> Optional[GetResult]:
        # Get all the items in the collection.
        points = self.client.query_points(
            collection_name=f"{self.collection_prefix}_{collection_name}",
            limit=NO_LIMIT,  # otherwise qdrant would set limit to 10!
        )
        return self._result_to_get_result(points.points)

    def insert(self, collection_name: str, items: list[VectorItem]):
        # Insert the items into the collection. If the collection does not exist, it will be created.
        self._create_collection_if_not_exists(collection_name, len(items[0]["vector"]))
        points = self._create_points(items)
        self.client.upload_points(f"{self.collection_prefix}_{collection_name}", points)

    def upsert(self, collection_name: str, items: list[VectorItem]):
        # Update the items in the collection, inserting any that are not present.
        # If the collection does not exist, it will be created.
        self._create_collection_if_not_exists(collection_name, len(items[0]["vector"]))
        points = self._create_points(items)
        return self.client.upsert(f"{self.collection_prefix}_{collection_name}", points)

    def delete(
        self,
        collection_name: str,
        ids: Optional[list[str]] = None,
        filter: Optional[dict] = None,
    ):
        # Delete the items from the collection based on the ids or, failing that, a metadata filter.
        field_conditions = []

        if ids:
            for id_value in ids:
                field_conditions.append(
                    models.FieldCondition(
                        key="metadata.id",
                        match=models.MatchValue(value=id_value),
                    )
                )
        elif filter:
            for key, value in filter.items():
                field_conditions.append(
                    models.FieldCondition(
                        key=f"metadata.{key}",
                        match=models.MatchValue(value=value),
                    )
                )

        return self.client.delete(
            collection_name=f"{self.collection_prefix}_{collection_name}",
            points_selector=models.FilterSelector(
                filter=models.Filter(must=field_conditions)
            ),
        )

    def reset(self):
        # Resets the database. This will delete all collections and item entries.
        collection_names = self.client.get_collections().collections
        for collection_name in collection_names:
            if collection_name.name.startswith(self.collection_prefix):
                self.client.delete_collection(collection_name=collection_name.name)
@ -0,0 +1,19 @@
from pydantic import BaseModel
from typing import Optional, List, Any


class VectorItem(BaseModel):
    id: str
    text: str
    vector: List[float | int]
    metadata: Any


class GetResult(BaseModel):
    ids: Optional[List[List[str]]]
    documents: Optional[List[List[str]]]
    metadatas: Optional[List[List[Any]]]


class SearchResult(GetResult):
    distances: Optional[List[List[float | int]]]
@ -0,0 +1,73 @@
import logging
import os
from pprint import pprint
from typing import Optional
import requests
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
import argparse

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
"""
Documentation: https://docs.microsoft.com/en-us/bing/search-apis/bing-web-search/overview
"""


def search_bing(
    subscription_key: str,
    endpoint: str,
    locale: str,
    query: str,
    count: int,
    filter_list: Optional[list[str]] = None,
) -> list[SearchResult]:
    mkt = locale
    params = {"q": query, "mkt": mkt, "count": count}
    headers = {"Ocp-Apim-Subscription-Key": subscription_key}

    try:
        response = requests.get(endpoint, headers=headers, params=params)
        response.raise_for_status()
        json_response = response.json()
        results = json_response.get("webPages", {}).get("value", [])
        if filter_list:
            results = get_filtered_results(results, filter_list)
        return [
            SearchResult(
                link=result["url"],
                title=result.get("name"),
                snippet=result.get("snippet"),
            )
            for result in results
        ]
    except Exception as ex:
        log.error(f"Error: {ex}")
        raise ex


def main():
    parser = argparse.ArgumentParser(description="Search Bing from the command line.")
    parser.add_argument(
        "query",
        type=str,
        nargs="?",  # optional positional, so the default below can apply
        default="Top 10 international news today",
        help="The search query.",
    )
    parser.add_argument(
        "--count", type=int, default=10, help="Number of search results to return."
    )
    parser.add_argument(
        "--filter", nargs="*", help="List of filters to apply to the search results."
    )
    parser.add_argument(
        "--locale",
        type=str,
        default="en-US",
        help="The locale to use for the search; maps to the market parameter in the API",
    )

    args = parser.parse_args()

    # search_bing also requires a subscription key and an endpoint. Reading them
    # from these environment variables is an assumption; the names are illustrative.
    results = search_bing(
        os.environ["BING_SEARCH_V7_SUBSCRIPTION_KEY"],
        os.environ["BING_SEARCH_V7_ENDPOINT"],
        args.locale,
        args.query,
        args.count,
        args.filter,
    )
    pprint(results)
@ -0,0 +1,65 @@
import logging
from typing import Optional

import requests
import json
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


def _parse_response(response):
    result = {}
    if "data" in response:
        data = response["data"]
        if "webPages" in data:
            webPages = data["webPages"]
            if "value" in webPages:
                result["webpage"] = [
                    {
                        "id": item.get("id", ""),
                        "name": item.get("name", ""),
                        "url": item.get("url", ""),
                        "snippet": item.get("snippet", ""),
                        "summary": item.get("summary", ""),
                        "siteName": item.get("siteName", ""),
                        "siteIcon": item.get("siteIcon", ""),
                        "datePublished": item.get("datePublished", "")
                        or item.get("dateLastCrawled", ""),
                    }
                    for item in webPages["value"]
                ]
    return result


def search_bocha(
    api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None
) -> list[SearchResult]:
    """Search using Bocha's Search API and return the results as a list of SearchResult objects.

    Args:
        api_key (str): A Bocha Search API key
        query (str): The query to search for
        count (int): The number of results to return
        filter_list (Optional[list[str]]): List of filters to apply to the results
    """
    url = "https://api.bochaai.com/v1/web-search?utm_source=ollama"
    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}

    payload = json.dumps(
        {"query": query, "summary": True, "freshness": "noLimit", "count": count}
    )

    response = requests.post(url, headers=headers, data=payload, timeout=5)
    response.raise_for_status()
    results = _parse_response(response.json())
    log.debug(f"Bocha search results: {results}")
    if filter_list:
        # Filter the list of result dicts, as the other search providers do,
        # rather than the wrapper dict returned by _parse_response.
        results["webpage"] = get_filtered_results(
            results.get("webpage", []), filter_list
        )

    return [
        SearchResult(
            link=result["url"], title=result.get("name"), snippet=result.get("summary")
        )
        for result in results.get("webpage", [])[:count]
    ]
@ -0,0 +1,42 @@
import logging
from typing import Optional

import requests
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


def search_brave(
    api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None
) -> list[SearchResult]:
    """Search using Brave's Search API and return the results as a list of SearchResult objects.

    Args:
        api_key (str): A Brave Search API key
        query (str): The query to search for
    """
    url = "https://api.search.brave.com/res/v1/web/search"
    headers = {
        "Accept": "application/json",
        "Accept-Encoding": "gzip",
        "X-Subscription-Token": api_key,
    }
    params = {"q": query, "count": count}

    response = requests.get(url, headers=headers, params=params)
    response.raise_for_status()

    json_response = response.json()
    results = json_response.get("web", {}).get("results", [])
    if filter_list:
        results = get_filtered_results(results, filter_list)

    return [
        SearchResult(
            link=result["url"], title=result.get("title"), snippet=result.get("snippet")
        )
        for result in results[:count]
    ]
@ -0,0 +1,46 @@
import logging
from typing import Optional

from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from duckduckgo_search import DDGS
from open_webui.env import SRC_LOG_LEVELS

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


def search_duckduckgo(
    query: str, count: int, filter_list: Optional[list[str]] = None
) -> list[SearchResult]:
    """
    Search using DuckDuckGo's Search API and return the results as a list of SearchResult objects.
    Args:
        query (str): The query to search for
        count (int): The number of results to return

    Returns:
        list[SearchResult]: A list of search results
    """
    # Use the DDGS context manager to create a DDGS object
    with DDGS() as ddgs:
        # Use the ddgs.text() method to perform the search
        ddgs_gen = ddgs.text(
            query, safesearch="moderate", max_results=count, backend="api"
        )
        # Convert the search results into a list; fall back to an empty list
        # so search_results is always defined, even when nothing is returned
        search_results = list(ddgs_gen) if ddgs_gen else []

    if filter_list:
        search_results = get_filtered_results(search_results, filter_list)

    # Return the list of search results
    return [
        SearchResult(
            link=result["href"],
            title=result.get("title"),
            snippet=result.get("body"),
        )
        for result in search_results
    ]
@ -0,0 +1,76 @@
import logging
from dataclasses import dataclass
from typing import Optional

import requests
from open_webui.env import SRC_LOG_LEVELS
from open_webui.retrieval.web.main import SearchResult

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])

EXA_API_BASE = "https://api.exa.ai"


@dataclass
class ExaResult:
    url: str
    title: str
    text: str


def search_exa(
|
||||
api_key: str,
|
||||
query: str,
|
||||
count: int,
|
||||
filter_list: Optional[list[str]] = None,
|
||||
) -> list[SearchResult]:
|
||||
"""Search using Exa Search API and return the results as a list of SearchResult objects.
|
||||
|
||||
Args:
|
||||
api_key (str): An Exa Search API key
|
||||
query (str): The query to search for
|
||||
count (int): Number of results to return
|
||||
filter_list (Optional[list[str]]): List of domains to filter results by
|
||||
"""
|
||||
log.info(f"Searching with Exa for query: {query}")
|
||||
|
||||
headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
|
||||
|
||||
payload = {
|
||||
"query": query,
|
||||
"numResults": count or 5,
|
||||
"includeDomains": filter_list,
|
||||
"contents": {"text": True, "highlights": True},
|
||||
"type": "auto", # Use the auto search type (keyword or neural)
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.post(
|
||||
f"{EXA_API_BASE}/search", headers=headers, json=payload
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
results = []
|
||||
for result in data["results"]:
|
||||
results.append(
|
||||
ExaResult(
|
||||
url=result["url"],
|
||||
title=result["title"],
|
||||
text=result["text"],
|
||||
)
|
||||
)
|
||||
|
||||
log.info(f"Found {len(results)} results")
|
||||
return [
|
||||
SearchResult(
|
||||
link=result.url,
|
||||
title=result.title,
|
||||
snippet=result.text,
|
||||
)
|
||||
for result in results
|
||||
]
|
||||
except Exception as e:
|
||||
log.error(f"Error searching Exa: {e}")
|
||||
return []
|
||||
|
|
@ -0,0 +1,69 @@
|
|||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
|
||||
def search_google_pse(
|
||||
api_key: str,
|
||||
search_engine_id: str,
|
||||
query: str,
|
||||
count: int,
|
||||
filter_list: Optional[list[str]] = None,
|
||||
) -> list[SearchResult]:
|
||||
"""Search using Google's Programmable Search Engine API and return the results as a list of SearchResult objects.
|
||||
Handles pagination for counts greater than 10.
|
||||
|
||||
Args:
|
||||
api_key (str): A Programmable Search Engine API key
|
||||
search_engine_id (str): A Programmable Search Engine ID
|
||||
query (str): The query to search for
|
||||
count (int): The number of results to return (max 100, as PSE max results per query is 10 and max page is 10)
|
||||
filter_list (Optional[list[str]], optional): A list of domains to restrict results to; results from other domains are dropped. Defaults to None.
|
||||
|
||||
Returns:
|
||||
list[SearchResult]: A list of SearchResult objects.
|
||||
"""
|
||||
url = "https://www.googleapis.com/customsearch/v1"
|
||||
headers = {"Content-Type": "application/json"}
|
||||
all_results = []
|
||||
start_index = 1 # Google PSE start parameter is 1-based
|
||||
|
||||
while count > 0:
|
||||
num_results_this_page = min(count, 10) # Google PSE max results per page is 10
|
||||
params = {
|
||||
"cx": search_engine_id,
|
||||
"q": query,
|
||||
"key": api_key,
|
||||
"num": num_results_this_page,
|
||||
"start": start_index,
|
||||
}
|
||||
response = requests.request("GET", url, headers=headers, params=params)
|
||||
response.raise_for_status()
|
||||
json_response = response.json()
|
||||
results = json_response.get("items", [])
|
||||
if results: # check if results are returned. If not, no more pages to fetch.
|
||||
all_results.extend(results)
|
||||
count -= len(
|
||||
results
|
||||
) # Decrement count by the number of results fetched in this page.
|
||||
start_index += 10 # Increment start index for the next page
|
||||
else:
|
||||
break # No more results from Google PSE, break the loop
|
||||
|
||||
if filter_list:
|
||||
all_results = get_filtered_results(all_results, filter_list)
|
||||
|
||||
return [
|
||||
SearchResult(
|
||||
link=result["link"],
|
||||
title=result.get("title"),
|
||||
snippet=result.get("snippet"),
|
||||
)
|
||||
for result in all_results
|
||||
]
|
||||
|
|
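The pagination loop in `search_google_pse` can be exercised without network access. The following is a minimal sketch, not the project's code: `fetch_paginated` and `fake_fetch_page` are hypothetical names, and the fake page source stands in for the HTTP call. It assumes the same rules as the function above: a 1-based start index, at most 10 results per page, and an empty page meaning there is nothing more to fetch.

```python
from typing import Callable

PAGE_SIZE = 10  # per-page cap, as stated in the docstring above


def fetch_paginated(
    fetch_page: Callable[[int, int], list[dict]], count: int
) -> list[dict]:
    """Collect up to `count` items by requesting pages of at most PAGE_SIZE."""
    collected: list[dict] = []
    start_index = 1  # the start parameter is 1-based
    while count > 0:
        page = fetch_page(start_index, min(count, PAGE_SIZE))
        if not page:
            break  # an empty page means no more results
        collected.extend(page)
        count -= len(page)
        start_index += PAGE_SIZE
    return collected


def fake_fetch_page(start: int, num: int) -> list[dict]:
    """Hypothetical stand-in for the HTTP request; serves 25 items in total."""
    items = [{"link": f"https://example.com/{i}"} for i in range(1, 26)]
    return items[start - 1 : start - 1 + num]


results = fetch_paginated(fake_fetch_page, 23)
```

Asking for 23 results issues three requests (10, 10, and 3 items). Asking for more than the source holds stops at the first empty page instead of looping forever.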
@ -0,0 +1,48 @@
|
|||
import logging
|
||||
|
||||
import requests
|
||||
from open_webui.retrieval.web.main import SearchResult
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
from yarl import URL
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
|
||||
def search_jina(api_key: str, query: str, count: int) -> list[SearchResult]:
|
||||
"""
|
||||
Search using Jina's Search API and return the results as a list of SearchResult objects.
|
||||
Args:
|
||||
api_key (str): A Jina Search API key
|
||||
query (str): The query to search for
|
||||
count (int): The number of results to return
|
||||
|
||||
Returns:
|
||||
list[SearchResult]: A list of search results
|
||||
"""
|
||||
jina_search_endpoint = "https://s.jina.ai/"
|
||||
|
||||
headers = {
|
||||
"Accept": "application/json",
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": api_key,
|
||||
"X-Retain-Images": "none",
|
||||
}
|
||||
|
||||
payload = {"q": query, "count": count if count <= 10 else 10}
|
||||
|
||||
url = str(URL(jina_search_endpoint))
|
||||
response = requests.post(url, headers=headers, json=payload)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
results = []
|
||||
for result in data["data"]:
|
||||
results.append(
|
||||
SearchResult(
|
||||
link=result["url"],
|
||||
title=result.get("title"),
|
||||
snippet=result.get("content"),
|
||||
)
|
||||
)
|
||||
|
||||
return results
|
||||
|
|
@ -0,0 +1,48 @@
|
|||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
|
||||
def search_kagi(
|
||||
api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None
|
||||
) -> list[SearchResult]:
|
||||
"""Search using Kagi's Search API and return the results as a list of SearchResult objects.
|
||||
|
||||
The Search API will inherit the settings in your account, including results personalization and snippet length.
|
||||
|
||||
Args:
|
||||
api_key (str): A Kagi Search API key
|
||||
query (str): The query to search for
|
||||
count (int): The number of results to return
|
||||
"""
|
||||
url = "https://kagi.com/api/v0/search"
|
||||
headers = {
|
||||
"Authorization": f"Bot {api_key}",
|
||||
}
|
||||
params = {"q": query, "limit": count}
|
||||
|
||||
response = requests.get(url, headers=headers, params=params)
|
||||
response.raise_for_status()
|
||||
json_response = response.json()
|
||||
search_results = json_response.get("data", [])
|
||||
|
||||
results = [
|
||||
SearchResult(
|
||||
link=result["url"], title=result["title"], snippet=result.get("snippet")
|
||||
)
|
||||
for result in search_results
|
||||
if result["t"] == 0
|
||||
]
|
||||
|
||||
log.debug(f"results from kagi search: {results}")
|
||||
|
||||
if filter_list:
|
||||
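results = [r for r in results if get_filtered_results([{"link": r.link}], filter_list)]  # results are SearchResult objects here (no .get()), so filter on a dict view of each link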
|
||||
|
||||
return results
|
||||
|
|
@ -0,0 +1,26 @@
|
|||
import validators
|
||||
|
||||
from typing import Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
def get_filtered_results(results, filter_list):
|
||||
if not filter_list:
|
||||
return results
|
||||
filtered_results = []
|
||||
for result in results:
|
||||
url = result.get("url") or result.get("link", "")
|
||||
if not validators.url(url):
|
||||
continue
|
||||
domain = urlparse(url).netloc
|
||||
if any(domain.endswith(filtered_domain) for filtered_domain in filter_list):
|
||||
filtered_results.append(result)
|
||||
return filtered_results
|
||||
|
||||
|
||||
class SearchResult(BaseModel):
|
||||
link: str
|
||||
title: Optional[str]
|
||||
snippet: Optional[str]
|
||||
|
|
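The domain filter above keeps a result only when its host ends with one of the entries in `filter_list`. The following is a minimal, self-contained sketch of that behaviour. It uses only the standard library, so the `validators.url` check is replaced here by a simple "host is non-empty" test; that substitution and the name `filter_by_domain` are assumptions made for illustration.

```python
from urllib.parse import urlparse


def filter_by_domain(results: list[dict], allowed_domains: list[str]) -> list[dict]:
    """Keep only results whose host ends with one of the allowed domains."""
    if not allowed_domains:
        return results
    kept = []
    for result in results:
        # Providers name the field either "url" or "link"
        url = result.get("url") or result.get("link", "")
        host = urlparse(url).netloc
        if host and any(host.endswith(domain) for domain in allowed_domains):
            kept.append(result)
    return kept


sample = [
    {"link": "https://docs.python.org/3/library/urllib.parse.html"},
    {"url": "https://example.com/page"},
    {"link": "not a url"},
]
filtered = filter_by_domain(sample, ["python.org"])
```

Because the comparison is a plain string suffix match, a host such as `notpython.org` also passes a `python.org` filter. Callers who need strict matching may want to compare against the exact host or a `.`-prefixed suffix.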
@ -0,0 +1,40 @@
|
|||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
|
||||
def search_mojeek(
|
||||
api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None
|
||||
) -> list[SearchResult]:
|
||||
"""Search using Mojeek's Search API and return the results as a list of SearchResult objects.
|
||||
|
||||
Args:
|
||||
api_key (str): A Mojeek Search API key
|
||||
query (str): The query to search for
|
||||
"""
|
||||
url = "https://api.mojeek.com/search"
|
||||
headers = {
|
||||
"Accept": "application/json",
|
||||
}
|
||||
params = {"q": query, "api_key": api_key, "fmt": "json", "t": count}
|
||||
|
||||
response = requests.get(url, headers=headers, params=params)
|
||||
response.raise_for_status()
|
||||
json_response = response.json()
|
||||
results = json_response.get("response", {}).get("results", [])
|
||||
log.debug(f"results from mojeek search: {results}")
|
||||
if filter_list:
|
||||
results = get_filtered_results(results, filter_list)
|
||||
|
||||
return [
|
||||
SearchResult(
|
||||
link=result["url"], title=result.get("title"), snippet=result.get("desc")
|
||||
)
|
||||
for result in results
|
||||
]
|
||||
|
|
@ -0,0 +1,48 @@
|
|||
import logging
|
||||
from typing import Optional
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import requests
|
||||
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
|
||||
def search_searchapi(
|
||||
api_key: str,
|
||||
engine: str,
|
||||
query: str,
|
||||
count: int,
|
||||
filter_list: Optional[list[str]] = None,
|
||||
) -> list[SearchResult]:
|
||||
"""Search using searchapi.io's API and return the results as a list of SearchResult objects.
|
||||
|
||||
Args:
|
||||
api_key (str): A searchapi.io API key
|
||||
query (str): The query to search for
|
||||
"""
|
||||
url = "https://www.searchapi.io/api/v1/search"
|
||||
|
||||
engine = engine or "google"
|
||||
|
||||
payload = {"engine": engine, "q": query, "api_key": api_key}
|
||||
|
||||
url = f"{url}?{urlencode(payload)}"
|
||||
response = requests.request("GET", url)
|
||||
|
||||
json_response = response.json()
|
||||
log.info(f"results from searchapi search: {json_response}")
|
||||
|
||||
results = sorted(
|
||||
json_response.get("organic_results", []), key=lambda x: x.get("position", 0)
|
||||
)
|
||||
if filter_list:
|
||||
results = get_filtered_results(results, filter_list)
|
||||
return [
|
||||
SearchResult(
|
||||
link=result["link"], title=result.get("title"), snippet=result.get("snippet")
|
||||
)
|
||||
for result in results[:count]
|
||||
]
|
||||
|
|
@ -0,0 +1,91 @@
|
|||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
|
||||
def search_searxng(
|
||||
query_url: str,
|
||||
query: str,
|
||||
count: int,
|
||||
filter_list: Optional[list[str]] = None,
|
||||
**kwargs,
|
||||
) -> list[SearchResult]:
|
||||
"""
|
||||
Search a SearXNG instance for a given query and return the results as a list of SearchResult objects.
|
||||
|
||||
The function allows passing additional parameters such as language or time_range to tailor the search result.
|
||||
|
||||
Args:
|
||||
query_url (str): The base URL of the SearXNG server.
|
||||
query (str): The search term or question to find in the SearXNG database.
|
||||
count (int): The maximum number of results to retrieve from the search.
|
||||
|
||||
Keyword Args:
|
||||
language (str): Language filter for the search results; e.g., "en-US". Defaults to "en-US".
|
||||
safesearch (int): Safe search filter for safer web results; 0 = off, 1 = moderate, 2 = strict. Defaults to 1 (moderate).
|
||||
time_range (str): Time range for filtering results by date; e.g., "day", "month", or "year". Defaults to ''.
|
||||
categories (Optional[list[str]]): Specific categories within which the search should be performed. Defaults to an empty list.
|
||||
|
||||
Returns:
|
||||
list[SearchResult]: A list of SearchResults sorted by relevance score in descending order.
|
||||
|
||||
Raises:
|
||||
requests.exceptions.RequestException: If a request error occurs during the search process.
|
||||
"""
|
||||
|
||||
# Optional parameters fall back to the defaults below when not supplied via kwargs.
|
||||
language = kwargs.get("language", "en-US")
|
||||
safesearch = kwargs.get("safesearch", "1")
|
||||
time_range = kwargs.get("time_range", "")
|
||||
categories = ",".join(kwargs.get("categories", []))  # SearXNG expects a comma-separated list of categories
|
||||
|
||||
params = {
|
||||
"q": query,
|
||||
"format": "json",
|
||||
"pageno": 1,
|
||||
"safesearch": safesearch,
|
||||
"language": language,
|
||||
"time_range": time_range,
|
||||
"categories": categories,
|
||||
"theme": "simple",
|
||||
"image_proxy": 0,
|
||||
}
|
||||
|
||||
# Legacy query format
|
||||
if "<query>" in query_url:
|
||||
# Strip all query parameters from the URL
|
||||
query_url = query_url.split("?")[0]
|
||||
|
||||
log.debug(f"searching {query_url}")
|
||||
|
||||
response = requests.get(
|
||||
query_url,
|
||||
headers={
|
||||
"User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) RAG Bot",
|
||||
"Accept": "text/html",
|
||||
"Accept-Encoding": "gzip, deflate",
|
||||
"Accept-Language": "en-US,en;q=0.5",
|
||||
"Connection": "keep-alive",
|
||||
},
|
||||
params=params,
|
||||
)
|
||||
|
||||
response.raise_for_status() # Raise an exception for HTTP errors.
|
||||
|
||||
json_response = response.json()
|
||||
results = json_response.get("results", [])
|
||||
sorted_results = sorted(results, key=lambda x: x.get("score", 0), reverse=True)
|
||||
if filter_list:
|
||||
sorted_results = get_filtered_results(sorted_results, filter_list)
|
||||
return [
|
||||
SearchResult(
|
||||
link=result["url"], title=result.get("title"), snippet=result.get("content")
|
||||
)
|
||||
for result in sorted_results[:count]
|
||||
]
|
||||
|
|
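The parameter handling and ranking in `search_searxng` can be sketched in isolation. The helper names `build_searxng_params` and `top_results` are hypothetical, and the comma-separated `categories` value reflects my understanding of the SearXNG search API rather than anything stated in this diff.

```python
def build_searxng_params(query: str, **kwargs) -> dict:
    """Assemble query parameters, falling back to the defaults used above."""
    categories = kwargs.get("categories") or []
    return {
        "q": query,
        "format": "json",
        "pageno": 1,
        "safesearch": kwargs.get("safesearch", "1"),
        "language": kwargs.get("language", "en-US"),
        "time_range": kwargs.get("time_range", ""),
        # Assumption: categories are sent as one comma-separated string
        "categories": ",".join(categories),
    }


def top_results(results: list[dict], count: int) -> list[dict]:
    """Sort by relevance score (missing scores count as 0) and keep `count`."""
    ranked = sorted(results, key=lambda r: r.get("score", 0), reverse=True)
    return ranked[:count]


params = build_searxng_params("open webui", categories=["general", "news"])
best = top_results(
    [{"url": "a", "score": 0.2}, {"url": "b", "score": 0.9}, {"url": "c"}],
    2,
)
```

Sorting before truncating matters here: slicing first would return the first `count` results in response order rather than the highest-scoring ones.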
@ -0,0 +1,48 @@
|
|||
import logging
|
||||
from typing import Optional
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import requests
|
||||
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
|
||||
def search_serpapi(
|
||||
api_key: str,
|
||||
engine: str,
|
||||
query: str,
|
||||
count: int,
|
||||
filter_list: Optional[list[str]] = None,
|
||||
) -> list[SearchResult]:
|
||||
"""Search using serpapi.com's API and return the results as a list of SearchResult objects.
|
||||
|
||||
Args:
|
||||
api_key (str): A serpapi.com API key
|
||||
query (str): The query to search for
|
||||
"""
|
||||
url = "https://serpapi.com/search"
|
||||
|
||||
engine = engine or "google"
|
||||
|
||||
payload = {"engine": engine, "q": query, "api_key": api_key}
|
||||
|
||||
url = f"{url}?{urlencode(payload)}"
|
||||
response = requests.request("GET", url)
|
||||
|
||||
json_response = response.json()
|
||||
log.info(f"results from serpapi search: {json_response}")
|
||||
|
||||
results = sorted(
|
||||
json_response.get("organic_results", []), key=lambda x: x.get("position", 0)
|
||||
)
|
||||
if filter_list:
|
||||
results = get_filtered_results(results, filter_list)
|
||||
return [
|
||||
SearchResult(
|
||||
link=result["link"], title=result.get("title"), snippet=result.get("snippet")
|
||||
)
|
||||
for result in results[:count]
|
||||
]
|
||||
|
|
@ -0,0 +1,43 @@
|
|||
import json
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
|
||||
def search_serper(
|
||||
api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None
|
||||
) -> list[SearchResult]:
|
||||
"""Search using serper.dev's API and return the results as a list of SearchResult objects.
|
||||
|
||||
Args:
|
||||
api_key (str): A serper.dev API key
|
||||
query (str): The query to search for
|
||||
"""
|
||||
url = "https://google.serper.dev/search"
|
||||
|
||||
payload = json.dumps({"q": query})
|
||||
headers = {"X-API-KEY": api_key, "Content-Type": "application/json"}
|
||||
|
||||
response = requests.request("POST", url, headers=headers, data=payload)
|
||||
response.raise_for_status()
|
||||
|
||||
json_response = response.json()
|
||||
results = sorted(
|
||||
json_response.get("organic", []), key=lambda x: x.get("position", 0)
|
||||
)
|
||||
if filter_list:
|
||||
results = get_filtered_results(results, filter_list)
|
||||
return [
|
||||
SearchResult(
|
||||
link=result["link"],
|
||||
title=result.get("title"),
|
||||
snippet=result.get("description"),
|
||||
)
|
||||
for result in results[:count]
|
||||
]
|
||||
|
|
@ -0,0 +1,69 @@
|
|||
import logging
|
||||
from typing import Optional
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import requests
|
||||
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
|
||||
def search_serply(
|
||||
api_key: str,
|
||||
query: str,
|
||||
count: int,
|
||||
hl: str = "us",
|
||||
limit: int = 10,
|
||||
device_type: str = "desktop",
|
||||
proxy_location: str = "US",
|
||||
filter_list: Optional[list[str]] = None,
|
||||
) -> list[SearchResult]:
|
||||
"""Search using serply.io's API and return the results as a list of SearchResult objects.
|
||||
|
||||
Args:
|
||||
api_key (str): A serply.io API key
|
||||
query (str): The query to search for
|
||||
hl (str): Host Language code to display results in (reference https://developers.google.com/custom-search/docs/xml_results?hl=en#wsInterfaceLanguages)
|
||||
limit (int): The maximum number of results to return [10-100, defaults to 10]
|
||||
"""
|
||||
log.info("Searching with Serply")
|
||||
|
||||
url = "https://api.serply.io/v1/search/"
|
||||
|
||||
query_payload = {
|
||||
"q": query,
|
||||
"language": "en",
|
||||
"num": limit,
|
||||
"gl": proxy_location.upper(),
|
||||
"hl": hl.lower(),
|
||||
}
|
||||
|
||||
url = f"{url}{urlencode(query_payload)}"
|
||||
headers = {
|
||||
"X-API-KEY": api_key,
|
||||
"X-User-Agent": device_type,
|
||||
"User-Agent": "open-webui",
|
||||
"X-Proxy-Location": proxy_location,
|
||||
}
|
||||
|
||||
response = requests.request("GET", url, headers=headers)
|
||||
response.raise_for_status()
|
||||
|
||||
json_response = response.json()
|
||||
log.info(f"results from serply search: {json_response}")
|
||||
|
||||
results = sorted(
|
||||
json_response.get("results", []), key=lambda x: x.get("realPosition", 0)
|
||||
)
|
||||
if filter_list:
|
||||
results = get_filtered_results(results, filter_list)
|
||||
return [
|
||||
SearchResult(
|
||||
link=result["link"],
|
||||
title=result.get("title"),
|
||||
snippet=result.get("description"),
|
||||
)
|
||||
for result in results[:count]
|
||||
]
|
||||
Some files were not shown because too many files have changed in this diff