Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: db migration #3595

Merged
merged 34 commits into from
Jul 6, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
34 commits
Select commit Hold shift click to select a range
df09d08
feat(sqlalchemy): Replace peewee with sqlalchemy
jonathan-rohde Jun 18, 2024
bee835c
feat(sqlalchemy): remove session reference from router
jonathan-rohde Jun 21, 2024
070d908
feat(sqlalchemy): use subprocess to do migrations
jonathan-rohde Jun 24, 2024
320e658
feat(sqlalchemy): cleanup fixes
jonathan-rohde Jun 24, 2024
c134eab
feat(sqlalchemy): format backend
jonathan-rohde Jun 24, 2024
eb01e8d
feat(sqlalchemy): use scoped session
jonathan-rohde Jun 24, 2024
da403f3
feat(sqlalchemy): use session factory instead of context manager
jonathan-rohde Jun 24, 2024
a9b1487
feat(sqlalchemy): fix wrong column types
jonathan-rohde Jun 24, 2024
8f939cf
feat(sqlalchemy): some fixes
jonathan-rohde Jun 24, 2024
2fb27ad
feat(sqlalchemy): add missing file
jonathan-rohde Jun 24, 2024
d88bd51
feat(sqlalchemy): format backend
jonathan-rohde Jun 24, 2024
642c352
feat(sqlalchemy): rebase
jonathan-rohde Jun 25, 2024
d4b6b7c
feat(sqlalchemy): reverted not needed api change
jonathan-rohde Jun 25, 2024
23e4d9d
feat(sqlalchemy): formatting
jonathan-rohde Jun 25, 2024
827b1e5
feat(sqlalchemy): execute tests in github actions
jonathan-rohde Jun 25, 2024
df47c49
Merge branch 'refs/heads/dev' into feat/sqlalchemy-instead-of-peewee
jonathan-rohde Jun 28, 2024
5391f4c
feat(sqlalchemy): add new column
jonathan-rohde Jun 28, 2024
2aecd7d
Merge branch 'refs/heads/dev' into feat/sqlalchemy-instead-of-peewee
jonathan-rohde Jul 1, 2024
d0e89a0
Merge pull request #3327 from jonathan-rohde/feat/sqlalchemy-instead-…
tjbck Jul 2, 2024
647aa19
chore: format
tjbck Jul 2, 2024
44a9b86
fix: functions
tjbck Jul 3, 2024
aa88022
fix: functions
tjbck Jul 3, 2024
4d23957
revert: model_validate
tjbck Jul 3, 2024
0d78b63
Merge pull request #3621 from open-webui/dev
tjbck Jul 4, 2024
15f6f7b
revert: peewee migrations
tjbck Jul 4, 2024
bfc53b4
revert
tjbck Jul 4, 2024
1b65df3
revert
tjbck Jul 4, 2024
8646460
refac
tjbck Jul 4, 2024
37a5d2c
Update db.py
tjbck Jul 4, 2024
8fe2a7b
fix
tjbck Jul 4, 2024
8b13755
Update auths.py
tjbck Jul 4, 2024
d60f066
Merge pull request #3668 from open-webui/dev
tjbck Jul 6, 2024
1436bb7
enh: handle peewee migration
tjbck Jul 6, 2024
4e75150
Merge pull request #3669 from open-webui/dev-migration-session
tjbck Jul 6, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
feat(sqlalchemy): format backend
  • Loading branch information
jonathan-rohde committed Jun 27, 2024
commit c134eab27a929cbf678a60356a4c8f6c2e718201
5 changes: 3 additions & 2 deletions backend/apps/webui/internal/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 53,9 @@ def python_value(self, value):
)
else:
engine = create_engine(SQLALCHEMY_DATABASE_URL, pool_pre_ping=True)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine, expire_on_commit=False)
SessionLocal = sessionmaker(
autocommit=False, autoflush=False, bind=engine, expire_on_commit=False
)
Base = declarative_base()


Expand All @@ -66,4 68,3 @@ def get_session():
except Exception as e:
db.rollback()
raise e

20 changes: 7 additions & 13 deletions backend/apps/webui/models/auths.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 126,7 @@ def insert_new_auth(
else:
return None

def authenticate_user(
self, email: str, password: str
) -> Optional[UserModel]:
def authenticate_user(self, email: str, password: str) -> Optional[UserModel]:
log.info(f"authenticate_user: {email}")
with get_session() as db:
try:
Expand All @@ -144,9 142,7 @@ def authenticate_user(
except:
return None

def authenticate_user_by_api_key(
self, api_key: str
) -> Optional[UserModel]:
def authenticate_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
log.info(f"authenticate_user_by_api_key: {api_key}")
with get_session() as db:
# if no api_key, return None
Expand All @@ -159,9 155,7 @@ def authenticate_user_by_api_key(
except:
return False

def authenticate_user_by_trusted_header(
self, email: str
) -> Optional[UserModel]:
def authenticate_user_by_trusted_header(self, email: str) -> Optional[UserModel]:
log.info(f"authenticate_user_by_trusted_header: {email}")
with get_session() as db:
try:
Expand All @@ -172,12 166,12 @@ def authenticate_user_by_trusted_header(
except:
return None

def update_user_password_by_id(
self, id: str, new_password: str
) -> bool:
def update_user_password_by_id(self, id: str, new_password: str) -> bool:
with get_session() as db:
try:
result = db.query(Auth).filter_by(id=id).update({"password": new_password})
result = (
db.query(Auth).filter_by(id=id).update({"password": new_password})
)
return True if result == 1 else False
except:
return False
Expand Down
36 changes: 15 additions & 21 deletions backend/apps/webui/models/chats.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,17 79,17 @@ class ChatTitleIdResponse(BaseModel):

class ChatTable:

def insert_new_chat(
self, user_id: str, form_data: ChatForm
) -> Optional[ChatModel]:
def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]:
with get_session() as db:
id = str(uuid.uuid4())
chat = ChatModel(
**{
"id": id,
"user_id": user_id,
"title": (
form_data.chat["title"] if "title" in form_data.chat else "New Chat"
form_data.chat["title"]
if "title" in form_data.chat
else "New Chat"
),
"chat": json.dumps(form_data.chat),
"created_at": int(time.time()),
Expand All @@ -103,9 103,7 @@ def insert_new_chat(
db.refresh(result)
return ChatModel.model_validate(result) if result else None

def update_chat_by_id(
self, id: str, chat: dict
) -> Optional[ChatModel]:
def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
with get_session() as db:
try:
chat_obj = db.get(Chat, id)
Expand All @@ -119,9 117,7 @@ def update_chat_by_id(
except Exception as e:
return None

def insert_shared_chat_by_chat_id(
self, chat_id: str
) -> Optional[ChatModel]:
def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
with get_session() as db:
# Get the existing chat to share
chat = db.get(Chat, chat_id)
Expand All @@ -145,14 141,14 @@ def insert_shared_chat_by_chat_id(
db.refresh(shared_result)
# Update the original chat with the share_id
result = (
db.query(Chat).filter_by(id=chat_id).update({"share_id": shared_chat.id})
db.query(Chat)
.filter_by(id=chat_id)
.update({"share_id": shared_chat.id})
)

return shared_chat if (shared_result and result) else None

def update_shared_chat_by_chat_id(
self, chat_id: str
) -> Optional[ChatModel]:
def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
with get_session() as db:
try:
print("update_shared_chat_by_id")
Expand Down Expand Up @@ -271,9 267,7 @@ def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]:
except Exception as e:
return None

def get_chat_by_id_and_user_id(
self, id: str, user_id: str
) -> Optional[ChatModel]:
def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]:
try:
with get_session() as db:
chat = db.query(Chat).filter_by(id=id, user_id=user_id).first()
Expand All @@ -293,13 287,13 @@ def get_chats(self, skip: int = 0, limit: int = 50) -> List[ChatModel]:
def get_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
with get_session() as db:
all_chats = (
db.query(Chat).filter_by(user_id=user_id).order_by(Chat.updated_at.desc())
db.query(Chat)
.filter_by(user_id=user_id)
.order_by(Chat.updated_at.desc())
)
return [ChatModel.model_validate(chat) for chat in all_chats]

def get_archived_chats_by_user_id(
self, user_id: str
) -> List[ChatModel]:
def get_archived_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
with get_session() as db:
all_chats = (
db.query(Chat)
Expand Down
4 changes: 3 additions & 1 deletion backend/apps/webui/models/documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 106,9 @@ def get_doc_by_name(self, name: str) -> Optional[DocumentModel]:

def get_docs(self) -> List[DocumentModel]:
with get_session() as db:
return [DocumentModel.model_validate(doc) for doc in db.query(Document).all()]
return [
DocumentModel.model_validate(doc) for doc in db.query(Document).all()
]

def update_doc_by_name(
self, name: str, form_data: DocumentUpdateForm
Expand Down
1 change: 1 addition & 0 deletions backend/apps/webui/models/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 39,7 @@ class FileModel(BaseModel):

model_config = ConfigDict(from_attributes=True)


####################
# Forms
####################
Expand Down
26 changes: 15 additions & 11 deletions backend/apps/webui/models/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 142,9 @@ def get_functions_by_type(
with get_session() as db:
return [
FunctionModel.model_validate(function)
for function in db.query(Function).filter_by(
type=type, is_active=True
).all()
for function in db.query(Function)
.filter_by(type=type, is_active=True)
.all()
]
else:
with get_session() as db:
Expand Down Expand Up @@ -220,10 220,12 @@ def update_user_valves_by_id_and_user_id(
def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]:
try:
with get_session() as db:
db.query(Function).filter_by(id=id).update({
**updated,
"updated_at": int(time.time()),
})
db.query(Function).filter_by(id=id).update(
{
**updated,
"updated_at": int(time.time()),
}
)
db.commit()
return self.get_function_by_id(id)
except:
Expand All @@ -232,10 234,12 @@ def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionMode
def deactivate_all_functions(self) -> Optional[bool]:
try:
with get_session() as db:
db.query(Function).update({
"is_active": False,
"updated_at": int(time.time()),
})
db.query(Function).update(
{
"is_active": False,
"updated_at": int(time.time()),
}
)
db.commit()
return True
except:
Expand Down
4 changes: 1 addition & 3 deletions backend/apps/webui/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 153,7 @@ def get_model_by_id(self, id: str) -> Optional[ModelModel]:
except:
return None

def update_model_by_id(
self, id: str, model: ModelForm
) -> Optional[ModelModel]:
def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]:
try:
# update only the fields that are present in the model
with get_session() as db:
Expand Down
4 changes: 3 additions & 1 deletion backend/apps/webui/models/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 83,9 @@ def get_prompt_by_command(self, command: str) -> Optional[PromptModel]:

def get_prompts(self) -> List[PromptModel]:
with get_session() as db:
return [PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all()]
return [
PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all()
]

def update_prompt_by_command(
self, command: str, form_data: PromptForm
Expand Down
18 changes: 8 additions & 10 deletions backend/apps/webui/models/tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 79,7 @@ class ChatTagsResponse(BaseModel):

class TagTable:

def insert_new_tag(
self, name: str, user_id: str
) -> Optional[TagModel]:
def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]:
id = str(uuid.uuid4())
tag = TagModel(**{"id": id, "user_id": user_id, "name": name})
try:
Expand Down Expand Up @@ -201,11 199,13 @@ def count_chat_ids_by_tag_name_and_user_id(
self, tag_name: str, user_id: str
) -> int:
with get_session() as db:
return db.query(ChatIdTag).filter_by(tag_name=tag_name, user_id=user_id).count()
return (
db.query(ChatIdTag)
.filter_by(tag_name=tag_name, user_id=user_id)
.count()
)

def delete_tag_by_tag_name_and_user_id(
self, tag_name: str, user_id: str
) -> bool:
def delete_tag_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> bool:
try:
with get_session() as db:
res = (
Expand Down Expand Up @@ -252,9 252,7 @@ def delete_tag_by_tag_name_and_chat_id_and_user_id(
log.error(f"delete_tag: {e}")
return False

def delete_tags_by_chat_id_and_user_id(
self, chat_id: str, user_id: str
) -> bool:
def delete_tags_by_chat_id_and_user_id(self, chat_id: str, user_id: str) -> bool:
tags = self.get_tags_by_chat_id_and_user_id(chat_id, user_id)

for tag in tags:
Expand Down
16 changes: 6 additions & 10 deletions backend/apps/webui/models/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 165,7 @@ def get_first_user(self) -> UserModel:
except:
return None

def update_user_role_by_id(
self, id: str, role: str
) -> Optional[UserModel]:
def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]:
with get_session() as db:
try:
db.query(User).filter_by(id=id).update({"role": role})
Expand All @@ -193,12 191,12 @@ def update_user_profile_image_url_by_id(
except:
return None

def update_user_last_active_by_id(
self, id: str
) -> Optional[UserModel]:
def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]:
with get_session() as db:
try:
db.query(User).filter_by(id=id).update({"last_active_at": int(time.time())})
db.query(User).filter_by(id=id).update(
{"last_active_at": int(time.time())}
)

user = db.query(User).filter_by(id=id).first()
return UserModel.model_validate(user)
Expand All @@ -217,9 215,7 @@ def update_user_oauth_sub_by_id(
except:
return None

def update_user_by_id(
self, id: str, updated: dict
) -> Optional[UserModel]:
def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]:
with get_session() as db:
try:
db.query(User).filter_by(id=id).update(updated)
Expand Down
14 changes: 4 additions & 10 deletions backend/apps/webui/routers/auths.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 78,7 @@ async def get_session_user(

@router.post("/update/profile", response_model=UserResponse)
async def update_profile(
form_data: UpdateProfileForm,
session_user=Depends(get_current_user)
form_data: UpdateProfileForm, session_user=Depends(get_current_user)
):
if session_user:
user = Users.update_user_by_id(
Expand All @@ -101,8 100,7 @@ async def update_profile(

@router.post("/update/password", response_model=bool)
async def update_password(
form_data: UpdatePasswordForm,
session_user=Depends(get_current_user)
form_data: UpdatePasswordForm, session_user=Depends(get_current_user)
):
if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
raise HTTPException(400, detail=ERROR_MESSAGES.ACTION_PROHIBITED)
Expand Down Expand Up @@ -269,9 267,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):


@router.post("/add", response_model=SigninResponse)
async def add_user(
form_data: AddUserForm, user=Depends(get_admin_user)
):
async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):

if not validate_email_format(form_data.email.lower()):
raise HTTPException(
Expand Down Expand Up @@ -316,9 312,7 @@ async def add_user(


@router.get("/admin/details")
async def get_admin_details(
request: Request, user=Depends(get_current_user)
):
async def get_admin_details(request: Request, user=Depends(get_current_user)):
if request.app.state.config.SHOW_ADMIN_DETAILS:
admin_email = request.app.state.config.ADMIN_EMAIL
admin_name = None
Expand Down
Loading