Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Azure support #1604

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mem0/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 1,6 @@
import importlib.metadata

# Resolve the package version from installed metadata so it stays in sync
# with the distribution. Fall back to a pinned string (not a float — tools
# such as packaging expect __version__ to be a str) when the package is not
# installed, e.g. when running from a source checkout.
try:
    __version__ = importlib.metadata.version("mem0ai")
except importlib.metadata.PackageNotFoundError:
    __version__ = "0.9"

from mem0.memory.main import Memory  # noqa
from mem0.client.main import MemoryClient  # noqa
42 changes: 42 additions & 0 deletions mem0/embeddings/azure_openai.py
Original file line number Diff line number Diff line change
@@ -0,0 1,42 @@
from openai import AzureOpenAI
import os
from mem0.embeddings.base import EmbeddingBase


class AzureOpenAIEmbedding(EmbeddingBase):
    """Embedding provider backed by an Azure OpenAI deployment.

    Credentials are resolved from the environment. When all three
    ``EMBED_``-prefixed variables are set, they take precedence so the
    embedder can target a different Azure resource than the LLM; otherwise
    the standard Azure OpenAI variables are used.
    """

    def __init__(self, model="text-embedding-ada-002"):
        # Use the dedicated embedding resource only if the full triple
        # (key, endpoint, version) is present; partial overrides are ignored.
        if (
            os.getenv("EMBED_AZURE_OPENAI_API_KEY")
            and os.getenv("EMBED_AZURE_OPENAI_ENDPOINT")
            and os.getenv("EMBED_OPENAI_API_VERSION")
        ):
            self.api_key = os.getenv("EMBED_AZURE_OPENAI_API_KEY")
            self.azure_endpoint = os.getenv("EMBED_AZURE_OPENAI_ENDPOINT")
            self.api_version = os.getenv("EMBED_OPENAI_API_VERSION")
        else:
            self.api_key = os.getenv("AZURE_OPENAI_API_KEY")
            self.azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
            # BUG FIX: this previously read AZURE_OPENAI_ENDPOINT, so the
            # API version was set to the endpoint URL. OPENAI_API_VERSION is
            # the variable the companion test file sets for this purpose.
            self.api_version = os.getenv("OPENAI_API_VERSION")
        # SECURITY FIX: removed a print() that leaked the API key to stdout.
        self.client = AzureOpenAI(
            api_version=self.api_version,
            api_key=self.api_key,
            azure_endpoint=self.azure_endpoint,
        )
        self.model = model
        # text-embedding-ada-002 produces 1536-dimensional vectors.
        self.dims = 1536

    def embed(self, text):
        """
        Get the embedding for the given text using Azure OpenAI.

        Args:
            text (str): The text to embed. Newlines are replaced with spaces
                before the request, as recommended for embedding inputs.

        Returns:
            list: The embedding vector.
        """
        text = text.replace("\n", " ")
        response = self.client.embeddings.create(input=[text], model=self.model)
        return response.data[0].embedding

6 changes: 3 additions & 3 deletions mem0/embeddings/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 5,17 @@

class EmbedderConfig(BaseModel):
    """Configuration wrapper for the embedding provider."""

    # Name of the embedding backend; must be one of the providers accepted
    # by validate_config below.
    provider: str = Field(
        description="Provider of the embedding model (e.g., 'ollama', 'openai', 'litellm')",
        default="openai",
    )
    # Provider-specific keyword arguments; defaults to an empty dict so
    # downstream code can unpack it without a None check.
    config: Optional[dict] = Field(
        description="Configuration for the specific embedding model", default={}
    )

    @field_validator("config")
    def validate_config(cls, v, values):
        """Reject configs whose provider is not a supported embedder."""
        provider = values.data.get("provider")
        if provider in ["openai", "ollama", "litellm"]:
            return v
        raise ValueError(f"Unsupported embedding provider: {provider}")
Expand Down
76 changes: 76 additions & 0 deletions mem0/llms/azure_openai.py
Original file line number Diff line number Diff line change
@@ -0,0 1,76 @@
import json
from typing import Dict, List, Optional

from openai import AzureOpenAI

from mem0.llms.base import LLMBase
from mem0.configs.llms.base import BaseLlmConfig

class AzureOpenAILLM(LLMBase):
    """LLM provider that routes chat completions through Azure OpenAI."""

    def __init__(self, config: Optional[BaseLlmConfig] = None):
        super().__init__(config)
        # The model name must match the custom deployment name chosen in Azure.
        if not self.config.model:
            self.config.model = "gpt-4o"
        # AzureOpenAI() reads its credentials from the standard environment
        # variables (AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, ...).
        self.client = AzureOpenAI()

    def _parse_response(self, response, tools):
        """
        Process the response based on whether tools are used or not.

        Args:
            response: The raw response from API.
            tools: The list of tools provided in the request.

        Returns:
            str or dict: Plain message content when no tools were supplied,
            otherwise a dict with "content" and parsed "tool_calls".
        """
        message = response.choices[0].message
        if not tools:
            return message.content

        calls = [
            {
                "name": call.function.name,
                "arguments": json.loads(call.function.arguments),
            }
            for call in (message.tool_calls or [])
        ]
        return {"content": message.content, "tool_calls": calls}

    def generate_response(
        self,
        messages: List[Dict[str, str]],
        response_format=None,
        tools: Optional[List[Dict]] = None,
        tool_choice: str = "auto",
    ):
        """
        Generate a response based on the given messages using Azure OpenAI.

        Args:
            messages (list): List of message dicts containing 'role' and 'content'.
            response_format (str or object, optional): Format of the response. Defaults to None.
            tools (list, optional): List of tools that the model can call. Defaults to None.
            tool_choice (str, optional): Tool choice method. Defaults to "auto".

        Returns:
            str: The generated response.
        """
        params = dict(
            model=self.config.model,
            messages=messages,
            temperature=self.config.temperature,
            max_tokens=self.config.max_tokens,
            top_p=self.config.top_p,
        )
        if response_format:
            params["response_format"] = response_format
        # tool_choice is only meaningful when tools are supplied.
        if tools:
            params["tools"] = tools
            params["tool_choice"] = tool_choice

        completion = self.client.chat.completions.create(**params)
        return self._parse_response(completion, tools)
4 changes: 2 additions & 2 deletions mem0/llms/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 5,7 @@

class LlmConfig(BaseModel):
provider: str = Field(
description="Provider of the LLM (e.g., 'ollama', 'openai')", default="openai"
description="Provider of the LLM (e.g., 'ollama', 'openai','azure_openai')", default="openai"
)
config: Optional[dict] = Field(
description="Configuration for the specific LLM", default={}
Expand All @@ -14,7 14,7 @@ class LlmConfig(BaseModel):
@field_validator("config")
def validate_config(cls, v, values):
provider = values.data.get("provider")
if provider in ("openai", "ollama", "groq", "together", "aws_bedrock", "litellm"):
if provider in ("openai", "ollama", "groq", "together", "aws_bedrock", "litellm","azure_openai"):
return v
else:
raise ValueError(f"Unsupported LLM provider: {provider}")
Expand Down
4 changes: 3 additions & 1 deletion mem0/utils/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 17,7 @@ class LlmFactory:
"together": "mem0.llms.together.TogetherLLM",
"aws_bedrock": "mem0.llms.aws_bedrock.AWSBedrockLLM",
"litellm": "mem0.llms.litellm.LiteLLM",
"azure_openai": "mem0.llms.azure_openai.AzureOpenAILLM"
}

@classmethod
Expand All @@ -33,7 34,8 @@ class EmbedderFactory:
provider_to_class = {
"openai": "mem0.embeddings.openai.OpenAIEmbedding",
"ollama": "mem0.embeddings.ollama.OllamaEmbedding",
"huggingface": "mem0.embeddings.huggingface.HuggingFaceEmbedding"
"huggingface": "mem0.embeddings.huggingface.HuggingFaceEmbedding",
"litellm": "mem0.embeddings.azure_openai.AzureOpenAIEmbedding"
}

@classmethod
Expand Down
37 changes: 37 additions & 0 deletions tests/test_azure.py
Original file line number Diff line number Diff line change
@@ -0,0 1,37 @@

import openai
import os, time,sys
# Make the local mem0 package importable when running this script from tests/.
sys.path.append("../")
from mem0 import Memory

# NOTE: this exercises the scenario where the embedder and the LLM are served
# by different Azure OpenAI resources. If a single resource provides both, it
# is sufficient to set only the three standard variables:
# AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT and OPENAI_API_VERSION.

# Credentials for the LLM resource (fill in before running).
os.environ["AZURE_OPENAI_API_KEY"] = ""
os.environ["AZURE_OPENAI_ENDPOINT"] = ""
os.environ["OPENAI_API_VERSION"] = ""

# Credentials for the dedicated embedding resource; the EMBED_-prefixed
# variables take precedence in the Azure embedder when all three are set.
os.environ["EMBED_AZURE_OPENAI_API_KEY"] = ""
os.environ["EMBED_AZURE_OPENAI_ENDPOINT"] = ""
os.environ["EMBED_OPENAI_API_VERSION"] = ""

# NOTE(review): the embedder provider key "litellm" is mapped by the factory
# to the Azure OpenAI embedder — confirm this naming is intentional.
config = {
    "llm": {
        "provider": "azure_openai",
        "config": {
            "model": "gpt-4o",
            "temperature": 0.1,
            "max_tokens": 2000,
        }
    },
    "embedder":{
        "provider":"litellm"
    }

}

# End-to-end smoke test: add a memory, then exercise search / get_all / history.
m = Memory.from_config(config)
m.add("Likes to play cricket on weekends", user_id="alice", metadata={"category": "hobbies"})

related_memories = m.search(query="What are Alice's hobbies?", user_id="alice")
print(related_memories)
all_memories = m.get_all()
# assumes get_all() returns a non-empty list of dicts with an "id" key — TODO confirm
memory_id = all_memories[0]["id"]
history = m.history(memory_id=memory_id)