Support for LM Studio LLMs and Embedding models #1646

Open · wants to merge 7 commits into main
Changes from 1 commit
Support for LM Studio LLMs and Embedding models
Dev-Khant committed Aug 5, 2024
commit 309d93b3c7e3e8d1e155f55bfd848d3fd99a281a
20 changes: 16 additions & 4 deletions mem0/configs/embeddings/base.py
@@ -10,9 +10,13 @@ def __init__(
         self,
         model: Optional[str] = None,
         embedding_dims: Optional[int] = None,
+        api_key: Optional[str] = None,
 
         # Ollama specific
-        base_url: Optional[str] = None
+        ollama_base_url: Optional[str] = None,
+
+        # LM Studio specific
+        lmstudio_base_url: Optional[str] = None
     ):
         """
         Initializes a configuration class instance for the Embeddings.
@@ -21,12 +25,20 @@ def __init__(
         :type model: Optional[str], optional
         :param embedding_dims: The number of dimensions in the embedding, defaults to None
         :type embedding_dims: Optional[int], optional
-        :param base_url: Base URL for the Ollama API, defaults to None
-        :type base_url: Optional[str], optional
+        :param api_key: API key to use, defaults to None
+        :type api_key: Optional[str], optional
+        :param ollama_base_url: Base URL for the Ollama API, defaults to None
+        :type ollama_base_url: Optional[str], optional
+        :param lmstudio_base_url: Base URL for LM Studio, defaults to None
+        :type lmstudio_base_url: Optional[str], optional
         """
 
         self.model = model
         self.embedding_dims = embedding_dims
+        self.api_key = api_key
 
         # Ollama specific
-        self.base_url = base_url
+        self.ollama_base_url = ollama_base_url
+
+        # LM Studio specific
+        self.lmstudio_base_url = lmstudio_base_url
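For reference, a config built against this new surface might look like the sketch below. The endpoint URL and placeholder key are assumptions, not part of this diff (LM Studio's local server conventionally listens on http://localhost:1234/v1 and does not validate the API key); the model name and dimensions mirror the LM Studio defaults used in mem0/embeddings/openai.py later in this PR.

from mem0.configs.embeddings.base import BaseEmbedderConfig

config = BaseEmbedderConfig(
    model="nomic-ai/nomic-embed-text-v1.5-GGUF/nomic-embed-text-v1.5.f16.gguf",
    embedding_dims=768,
    api_key="lm-studio",  # placeholder; LM Studio ignores the key
    lmstudio_base_url="http://localhost:1234/v1",  # assumed local default
)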
14 changes: 13 additions & 1 deletion mem0/configs/llms/base.py
@@ -10,6 +10,7 @@ def __init__(
         self,
         model: Optional[str] = None,
         temperature: float = 0,
+        api_key: Optional[str] = None,
         max_tokens: int = 3000,
         top_p: float = 0,
         top_k: int = 1,
@@ -22,7 +23,10 @@ def __init__(
         app_name: Optional[str] = None,
 
         # Ollama specific
-        ollama_base_url: Optional[str] = None
+        ollama_base_url: Optional[str] = None,
+
+        # LM Studio specific
+        lmstudio_base_url: Optional[str] = None
     ):
"""
Initializes a configuration class instance for the LLM.
@@ -32,6 +36,8 @@ def __init__(
         :param temperature: Controls the randomness of the model's output.
             Higher values (closer to 1) make output more random, lower values make it more deterministic, defaults to 0
         :type temperature: float, optional
+        :param api_key: API key to use, defaults to None
+        :type api_key: Optional[str], optional
         :param max_tokens: Controls how many tokens are generated, defaults to 3000
         :type max_tokens: int, optional
         :param top_p: Controls the diversity of words. Higher values (closer to 1) make word selection more diverse,
@@ -51,10 +57,13 @@ def __init__(
         :type app_name: Optional[str], optional
         :param ollama_base_url: The base URL of the LLM, defaults to None
         :type ollama_base_url: Optional[str], optional
+        :param lmstudio_base_url: The base URL of LM Studio, defaults to None
+        :type lmstudio_base_url: Optional[str], optional
         """
 
         self.model = model
         self.temperature = temperature
+        self.api_key = api_key
         self.max_tokens = max_tokens
         self.top_p = top_p
         self.top_k = top_k
@@ -68,3 +77,6 @@ def __init__(
 
         # Ollama specific
         self.ollama_base_url = ollama_base_url
+
+        # LM Studio specific
+        self.lmstudio_base_url = lmstudio_base_url
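The LLM config mirrors the embedder side. A minimal sketch, assuming the same illustrative endpoint and placeholder key; the model name is the default added in mem0/llms/lm_studio.py below:

from mem0.configs.llms.base import BaseLlmConfig

llm_config = BaseLlmConfig(
    model="lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf",
    temperature=0,
    api_key="lm-studio",  # placeholder; LM Studio ignores the key
    lmstudio_base_url="http://localhost:1234/v1",  # assumed local default
)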
110 changes: 110 additions & 0 deletions mem0/configs/prompts.py
@@ -40,3 +40,113 @@

Here are the details of the task:
"""

FUNCTION_CALLING_PROMPT = """
You are an expert in function calling. Your task is to analyze user conversations, identify the appropriate functions to call from the provided list, and return the function calls in JSON format.

Function List:
[
{
"type": "function",
"function": {
"name": "add_memory",
"description": "Add a memory",
"parameters": {
"type": "object",
"properties": {
"data": {"type": "string", "description": "Data to add to memory, natural language text"},
},
"required": ["data"],
},
},
},
{
"type": "function",
"function": {
"name": "update_memory",
"description": "Update memory provided ID and data",
"parameters": {
"type": "object",
"properties": {
"memory_id": {
"type": "string",
"description": "memory_id of the memory to update",
},
"data": {
"type": "string",
"description": "Updated data for the memory, natural language text",
},
},
"required": ["memory_id", "data"],
},
},
},
{
"type": "function",
"function": {
"name": "delete_memory",
"description": "Delete memory by memory_id",
"parameters": {
"type": "object",
"properties": {
"memory_id": {
"type": "string",
"description": "memory_id of the memory to delete",
}
},
"required": ["memory_id"],
},
},
}
]


Each function in the list above includes:
- "name": The name of the function
- "description": A brief description of what the function does
- "parameters": The required parameters for the function
- "type": The data type of the parameters
- "properties": Specific properties of the parameters
- "required": List of required parameters

Your responsibilities:
1. Carefully read and understand the user's conversation.
2. Identify which function(s) from the provided list are relevant to the user's request.
3. For each relevant function:
a. Ensure all required parameters are included and properly formatted.
b. Strictly follow the data types of the parameters.
c. Extract or infer parameter values from the user's conversation.
4. Construct a JSON object for each function call with the following structure:
{
"name": "function_name",
"parameters": {
"param1": "value1",
"param2": "value2",
...
}
}
5. If multiple functions are needed, return an array of these JSON objects.

Guidelines for response:
- Do not make contradictory function calls. Ensure all function calls are logically consistent with each other and the user's request.
- Ensure all required parameters are included in your function calls.
- Only call update_memory or delete_memory if a memory_id is present in the user's request.
- Do not call both update_memory and delete_memory on the same memory_id.
- Strictly follow the JSON format provided in the example response below.

Example response format:
{
"function_calls": [
{
"name": "function_1",
"parameters": {
"data": "Name is John"
}
}
]
}

CRITICAL: Your entire response must be a single JSON object. Do not write anything before or after the JSON. Do not explain your reasoning or provide any commentary. Only output the function calls JSON.

Now, please analyze the following conversation and provide the appropriate function call(s):
"""
2 changes: 1 addition & 1 deletion mem0/embeddings/configs.py
@@ -9,7 +9,7 @@ class EmbedderConfig(BaseModel):
         default="openai",
     )
     config: Optional[dict] = Field(
-        description="Configuration for the specific embedding model", default=None
+        description="Configuration for the specific embedding model", default={}
     )

@field_validator("config")
2 changes: 1 addition & 1 deletion mem0/embeddings/ollama.py
@@ -18,7 +18,7 @@ def __init__(self, config: Optional[BaseEmbedderConfig] = None):
         if not self.config.embedding_dims:
             self.config.embedding_dims = 512
 
-        self.client = Client(host=self.config.base_url)
+        self.client = Client(host=self.config.ollama_base_url)
self._ensure_model_exists()

def _ensure_model_exists(self):
22 changes: 16 additions & 6 deletions mem0/embeddings/openai.py
@@ -1,3 +1,4 @@
+import os
 from typing import Optional
 from openai import OpenAI
 
@@ -8,13 +9,22 @@
 class OpenAIEmbedding(EmbeddingBase):
     def __init__(self, config: Optional[BaseEmbedderConfig] = None):
         super().__init__(config)
 
-        if not self.config.model:
-            self.config.model = "text-embedding-3-small"
-        if not self.config.embedding_dims:
-            self.config.embedding_dims = 1536
-
-        self.client = OpenAI()
+        if self.config.lmstudio_base_url:  # Use LM Studio
+            if not self.config.model:
+                self.config.model = "nomic-ai/nomic-embed-text-v1.5-GGUF/nomic-embed-text-v1.5.f16.gguf"
+            if not self.config.embedding_dims:
+                self.config.embedding_dims = 768
+
+            self.client = OpenAI(base_url=self.config.lmstudio_base_url, api_key=self.config.api_key)
+        else:
+            if not self.config.model:
+                self.config.model = "text-embedding-3-small"
+            if not self.config.embedding_dims:
+                self.config.embedding_dims = 1536
+
+            api_key = os.getenv("OPENAI_API_KEY") or self.config.api_key
+            self.client = OpenAI(api_key=api_key)

def embed(self, text):
"""
2 changes: 1 addition & 1 deletion mem0/llms/configs.py
@@ -14,7 +14,7 @@ class LlmConfig(BaseModel):
     @field_validator("config")
     def validate_config(cls, v, values):
         provider = values.data.get("provider")
-        if provider in ("openai", "ollama", "groq", "together", "aws_bedrock", "litellm", "azure_openai"):
+        if provider in ("openai", "ollama", "groq", "together", "aws_bedrock", "litellm", "azure_openai", "lmstudio"):
             return v
         else:
             raise ValueError(f"Unsupported LLM provider: {provider}")
91 changes: 91 additions & 0 deletions mem0/llms/lm_studio.py
@@ -0,0 +1,91 @@
import json
from typing import Dict, List, Optional

from openai import OpenAI

from mem0.llms.base import LLMBase
from mem0.configs.llms.base import BaseLlmConfig
from mem0.configs.prompts import FUNCTION_CALLING_PROMPT


class LMStudioLLM(LLMBase):
def __init__(self, config: Optional[BaseLlmConfig] = None):
super().__init__(config)

if not self.config.model:
self.config.model = "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf"

if self.config.lmstudio_base_url:
self.client = OpenAI(base_url=self.config.lmstudio_base_url, api_key=self.config.api_key)
else:
raise ValueError("LM Studio base URL is required.")


def _parse_response(self, response, tools):
"""
Process the response based on whether tools are used or not.

Args:
response: The raw response from API.
tools: The list of tools provided in the request.

Returns:
str or dict: The processed response.
"""
if tools:
processed_response = {
"tool_calls": []
}

tool_calls = json.loads(response.choices[0].message.content)

for tool_call in tool_calls["function_calls"]:
if tool_call["name"] == "update_memory" or tool_call["name"] == "delete_memory":
if tool_call["parameters"]["memory_id"] == "":
continue
processed_response["tool_calls"].append({
"name": tool_call["name"],
"arguments": tool_call["parameters"]
})

return processed_response
else:
return response.choices[0].message.content

def generate_response(
self,
messages: List[Dict[str, str]],
response_format: dict = {"type": "json_object"},
tools: Optional[List[Dict]] = None,
tool_choice: str = "auto"
):
"""
Generate a response based on the given messages using LM Studio.

Args:
messages (list): List of message dicts containing 'role' and 'content'.
response_format (dict, optional): Format of the response. Defaults to {"type": "json_object"}.
tools (list, optional): List of tools that the model can call. Defaults to None.
tool_choice (str, optional): Tool choice method. Defaults to "auto".

Returns:
str or dict: The generated response, or a dict of parsed tool calls when tools are provided.
"""
params = {
"model": self.config.model,
"messages": messages,
"temperature": self.config.temperature,
"max_tokens": self.config.max_tokens,
"top_p": self.config.top_p
}

if tools:
params["response_format"] = response_format
system_prompt = {
"role": "system",
"content": FUNCTION_CALLING_PROMPT
}
params["messages"].insert(0, system_prompt)

response = self.client.chat.completions.create(**params)
return self._parse_response(response, tools)
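A sketch of exercising the new class directly, assuming an LM Studio server is already serving a model at the usual local endpoint (URL and key are placeholders):

from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.lm_studio import LMStudioLLM

llm = LMStudioLLM(BaseLlmConfig(
    lmstudio_base_url="http://localhost:1234/v1",  # assumed local default
    api_key="lm-studio",  # placeholder; LM Studio ignores the key
))
# Without tools, generate_response returns the raw completion text.
print(llm.generate_response([{"role": "user", "content": "Say hello."}]))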
5 changes: 3 additions & 2 deletions mem0/llms/openai.py
@@ -14,10 +14,11 @@ def __init__(self, config: Optional[BaseLlmConfig] = None):
         if not self.config.model:
             self.config.model = "gpt-4o"
 
-        if os.environ.get("OPENROUTER_API_KEY"):
+        if os.environ.get("OPENROUTER_API_KEY"):  # Use OpenRouter
             self.client = OpenAI(api_key=os.environ.get("OPENROUTER_API_KEY"), base_url=self.config.openrouter_base_url)
         else:
-            self.client = OpenAI()
+            api_key = os.getenv("OPENAI_API_KEY") or self.config.api_key
+            self.client = OpenAI(api_key=api_key)

def _parse_response(self, response, tools):
"""
3 changes: 2 additions & 1 deletion mem0/memory/main.py
@@ -29,7 +29,7 @@
 class Memory(MemoryBase):
     def __init__(self, config: MemoryConfig = MemoryConfig()):
         self.config = config
-        self.embedding_model = EmbedderFactory.create(self.config.embedder.provider)
+        self.embedding_model = EmbedderFactory.create(self.config.embedder.provider, self.config.embedder.config)
         self.vector_store = VectorStoreFactory.create(self.config.vector_store.provider, self.config.vector_store.config)
         self.llm = LlmFactory.create(self.config.llm.provider, self.config.llm.config)
         self.db = SQLiteManager(self.config.history_db_path)
@@ -375,6 +375,7 @@ def _update_memory_tool(self, memory_id, data, metadata=None):
 
         new_metadata = metadata or {}
         new_metadata["data"] = data
+        new_metadata["hash"] = existing_memory.payload.get("hash")
         new_metadata["created_at"] = existing_memory.payload.get("created_at")
         new_metadata["updated_at"] = datetime.now(pytz.timezone('US/Pacific')).isoformat()