Skip to content

Commit

Permalink
chore: refactor groq extension
Browse files Browse the repository at this point in the history
  • Loading branch information
louis-jan committed Mar 21, 2024
1 parent 724fd49 commit 75d49d0
Show file tree
Hide file tree
Showing 8 changed files with 81 additions and 310 deletions.
4 changes: 3 additions & 1 deletion core/src/extensions/ai-engines/AIEngine.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ export abstract class AIEngine extends BaseExtension {
// The model folder
modelFolder: string = 'models'

abstract models(): Promise<Model[]>
models(): Promise<Model[]> {
return Promise.resolve([])
}

/**
* On extension load, subscribe to events.
Expand Down
2 changes: 1 addition & 1 deletion extensions/inference-groq-extension/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
"@janhq/core": "file:../../core",
"fetch-retry": "^5.0.6",
"path-browserify": "^1.0.1",
"ulid": "^2.3.0"
"ulidx": "^2.3.0"
},
"engines": {
"node": ">=18.0.0"
Expand Down
83 changes: 0 additions & 83 deletions extensions/inference-groq-extension/src/helpers/sse.ts

This file was deleted.

187 changes: 20 additions & 167 deletions extensions/inference-groq-extension/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,218 +7,71 @@
*/

import {
ChatCompletionRole,
ContentType,
MessageRequest,
MessageStatus,
ThreadContent,
ThreadMessage,
events,
fs,
InferenceEngine,
BaseExtension,
MessageEvent,
MessageRequestType,
ModelEvent,
InferenceEvent,
AppConfigurationEventName,
joinPath,
RemoteOAIEngine,
} from '@janhq/core'
import { requestInference } from './helpers/sse'
import { ulid } from 'ulid'
import { join } from 'path'

/**
* A class that implements the InferenceExtension interface from the @janhq/core package.
* The class provides methods for initializing and stopping a model, and for making inference requests.
* It also subscribes to events emitted by the @janhq/core package and handles new message requests.
*/
export default class JanInferenceGroqExtension extends BaseExtension {
private static readonly _engineDir = 'file://engines'
private static readonly _engineMetadataFileName = 'groq.json'
export default class JanInferenceGroqExtension extends RemoteOAIEngine {
private readonly _engineDir = 'file://engines'
private readonly _engineMetadataFileName = 'groq.json'

private static _currentModel: GroqModel
inferenceUrl: string = 'https://api.groq.com/openai/v1/chat/completions'
provider = 'groq'

private static _engineSettings: EngineSettings = {
private _engineSettings: EngineSettings = {
full_url: 'https://api.groq.com/openai/v1/chat/completions',
api_key: 'gsk-<your key here>',
}

controller = new AbortController()
isCancelled = false

/**
* Subscribes to events emitted by the @janhq/core package.
*/
async onLoad() {
if (!(await fs.existsSync(JanInferenceGroqExtension._engineDir))) {
await fs
.mkdirSync(JanInferenceGroqExtension._engineDir)
.catch((err) => console.debug(err))
if (!(await fs.existsSync(this._engineDir))) {
await fs.mkdirSync(this._engineDir).catch((err) => console.debug(err))
}

JanInferenceGroqExtension.writeDefaultEngineSettings()

// Events subscription
events.on(MessageEvent.OnMessageSent, (data) =>
JanInferenceGroqExtension.handleMessageRequest(data, this)
)

events.on(ModelEvent.OnModelInit, (model: GroqModel) => {
JanInferenceGroqExtension.handleModelInit(model)
})

events.on(ModelEvent.OnModelStop, (model: GroqModel) => {
JanInferenceGroqExtension.handleModelStop(model)
})
events.on(InferenceEvent.OnInferenceStopped, () => {
JanInferenceGroqExtension.handleInferenceStopped(this)
})
this.writeDefaultEngineSettings()

const settingsFilePath = await joinPath([
JanInferenceGroqExtension._engineDir,
JanInferenceGroqExtension._engineMetadataFileName,
this._engineDir,
this._engineMetadataFileName,
])


// Events subscription
events.on(
AppConfigurationEventName.OnConfigurationUpdate,
(settingsKey: string) => {
// Update settings on changes
if (settingsKey === settingsFilePath)
JanInferenceGroqExtension.writeDefaultEngineSettings()
if (settingsKey === settingsFilePath) this.writeDefaultEngineSettings()
}
)
}

/**
* Stops the model inference.
*/
onUnload(): void {}

static async writeDefaultEngineSettings() {
async writeDefaultEngineSettings() {
try {
const engineFile = join(
JanInferenceGroqExtension._engineDir,
JanInferenceGroqExtension._engineMetadataFileName
)
const engineFile = join(this._engineDir, this._engineMetadataFileName)
if (await fs.existsSync(engineFile)) {
const engine = await fs.readFileSync(engineFile, 'utf-8')
JanInferenceGroqExtension._engineSettings =
this._engineSettings =
typeof engine === 'object' ? engine : JSON.parse(engine)
this.inferenceUrl = this._engineSettings.full_url
} else {
await fs.writeFileSync(
engineFile,
JSON.stringify(JanInferenceGroqExtension._engineSettings, null, 2)
JSON.stringify(this._engineSettings, null, 2)
)
}
} catch (err) {
console.error(err)
}
}
private static async handleModelInit(model: GroqModel) {
if (model.engine !== InferenceEngine.groq) {
return
} else {
JanInferenceGroqExtension._currentModel = model
JanInferenceGroqExtension.writeDefaultEngineSettings()
// Todo: Check model list with API key
events.emit(ModelEvent.OnModelReady, model)
}
}

private static async handleModelStop(model: GroqModel) {
if (model.engine !== 'groq') {
return
}
events.emit(ModelEvent.OnModelStopped, model)
}

private static async handleInferenceStopped(
instance: JanInferenceGroqExtension
) {
instance.isCancelled = true
instance.controller?.abort()
}

/**
* Handles a new message request by making an inference request and emitting events.
* Function registered in event manager, should be static to avoid binding issues.
* Pass instance as a reference.
* @param {MessageRequest} data - The data for the new message request.
*/
private static async handleMessageRequest(
data: MessageRequest,
instance: JanInferenceGroqExtension
) {
if (data.model.engine !== 'groq') {
return
}

const timestamp = Date.now()
const message: ThreadMessage = {
id: ulid(),
thread_id: data.threadId,
type: data.type,
assistant_id: data.assistantId,
role: ChatCompletionRole.Assistant,
content: [],
status: MessageStatus.Pending,
created: timestamp,
updated: timestamp,
object: 'thread.message',
}

if (data.type !== MessageRequestType.Summary) {
events.emit(MessageEvent.OnMessageResponse, message)
}

instance.isCancelled = false
instance.controller = new AbortController()

requestInference(
data?.messages ?? [],
this._engineSettings,
{
...JanInferenceGroqExtension._currentModel,
parameters: data.model.parameters,
},
instance.controller
).subscribe({
next: (content) => {
const messageContent: ThreadContent = {
type: ContentType.Text,
text: {
value: content.trim(),
annotations: [],
},
}
message.content = [messageContent]
events.emit(MessageEvent.OnMessageUpdate, message)
},
complete: async () => {
message.status = message.content.length
? MessageStatus.Ready
: MessageStatus.Error
events.emit(MessageEvent.OnMessageUpdate, message)
},
error: async (err) => {
if (instance.isCancelled || message.content.length > 0) {
message.status = MessageStatus.Stopped
events.emit(MessageEvent.OnMessageUpdate, message)
return
}
const messageContent: ThreadContent = {
type: ContentType.Text,
text: {
value: 'An error occurred. ' + err.message,
annotations: [],
},
}
message.content = [messageContent]
message.status = MessageStatus.Error
message.error_code = err.code
events.emit(MessageEvent.OnMessageUpdate, message)
},
})
}
}
51 changes: 25 additions & 26 deletions models/groq-llama2-70b/model.json
Original file line number Diff line number Diff line change
@@ -1,27 +1,26 @@
{
"sources": [
{
"url": "https://groq.com"
}
],
"id": "llama2-70b-4096",
"object": "model",
"name": "Groq Llama 2 70b",
"version": "1.0",
"description": "Groq Llama 2 70b with supercharged speed!",
"format": "api",
"settings": {},
"parameters": {
"max_tokens": 4096,
"temperature": 0.7,
"top_p": 1,
"stop": null,
"stream": true
},
"metadata": {
"author": "Meta",
"tags": ["General", "Big Context Length"]
},
"engine": "groq"
}

"sources": [
{
"url": "https://groq.com"
}
],
"id": "llama2-70b-4096",
"object": "model",
"name": "Groq Llama 2 70b",
"version": "1.0",
"description": "Groq Llama 2 70b with supercharged speed!",
"format": "api",
"settings": {},
"parameters": {
"max_tokens": 4096,
"temperature": 0.7,
"top_p": 1,
"stop": null,
"stream": true
},
"metadata": {
"author": "Meta",
"tags": ["General", "Big Context Length"]
},
"engine": "groq"
}
Loading

0 comments on commit 75d49d0

Please sign in to comment.