chore: load, unload model and inference synchronously
louis-jan committed Mar 23, 2024
1 parent 1ad794c commit b914301
Showing 10 changed files with 147 additions and 144 deletions.
43 changes: 36 additions & 7 deletions core/src/extensions/ai-engines/AIEngine.ts
@@ -2,7 +2,7 @@ import { getJanDataFolderPath, joinPath } from '../../core'
import { events } from '../../events'
import { BaseExtension } from '../../extension'
import { fs } from '../../fs'
import { Model, ModelEvent } from '../../types'
import { MessageRequest, Model, ModelEvent } from '../../types'

/**
* Base AIEngine
@@ -11,30 +11,59 @@ import { Model, ModelEvent } from '../../types'
export abstract class AIEngine extends BaseExtension {
// The inference engine
abstract provider: string
// The model folder
modelFolder: string = 'models'

/**
* On extension load, subscribe to events.
*/
override onLoad() {
events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model))
events.on(ModelEvent.OnModelStop, (model: Model) => this.unloadModel(model))

this.prePopulateModels()
}

models(): Promise<Model[]> {
return Promise.resolve([])
}

/**
* On extension load, subscribe to events.
* Load the model.
*/
onLoad() {
this.prePopulateModels()
async loadModel(model: Model): Promise<any> {
if (model.engine.toString() !== this.provider) return Promise.resolve()
events.emit(ModelEvent.OnModelReady, model)
return Promise.resolve()
}
/**
* Stops the model.
*/
async unloadModel(model?: Model): Promise<any> {
if (model?.engine && model.engine.toString() !== this.provider) return Promise.resolve()
events.emit(ModelEvent.OnModelStopped, model ?? {})
return Promise.resolve()
}

/*
* Inference request
*/
inference(data: MessageRequest) {}

/**
* Stop inference
*/
stopInference() {}

/**
* Pre-populate models to App Data Folder
*/
prePopulateModels(): Promise<void> {
const modelFolder = 'models'
return this.models().then((models) => {
const prePoluateOperations = models.map((model) =>
getJanDataFolderPath()
.then((janDataFolder) =>
// Attempt to create the model folder
joinPath([janDataFolder, this.modelFolder, model.id]).then((path) =>
joinPath([janDataFolder, modelFolder, model.id]).then((path) =>
fs
.mkdir(path)
.catch()
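For orientation, a minimal caller-side sketch (not part of the commit; `restartModel` is a hypothetical helper) of the awaitable lifecycle these base-class methods introduce:

```ts
import { AIEngine, Model } from '@janhq/core'

// Hypothetical helper: with loadModel/unloadModel returning promises, the
// model lifecycle can be sequenced directly instead of relying only on
// ModelEvent round-trips.
async function restartModel(engine: AIEngine, model: Model) {
  // Resolves once the engine has emitted ModelEvent.OnModelStopped.
  await engine.unloadModel()
  // Resolves on ModelEvent.OnModelReady; concrete engines may reject on failure.
  await engine.loadModel(model)
}
```

The event subscriptions added in `onLoad` keep the existing ModelEvent-driven flow working unchanged alongside this direct path.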
24 changes: 11 additions & 13 deletions core/src/extensions/ai-engines/LocalOAIEngine.ts
@@ -16,7 +16,7 @@ export abstract class LocalOAIEngine extends OAIEngine {
/**
* On extension load, subscribe to events.
*/
onLoad() {
override onLoad() {
super.onLoad()
// These events are applicable to local inference providers
events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model))
@@ -26,10 +26,10 @@ export abstract class LocalOAIEngine extends OAIEngine {
/**
* Load the model.
*/
async loadModel(model: Model) {
override async loadModel(model: Model): Promise<void> {
if (model.engine.toString() !== this.provider) return

const modelFolder = await joinPath([await getJanDataFolderPath(), this.modelFolder, model.id])
const modelFolderName = 'models'
const modelFolder = await joinPath([await getJanDataFolderPath(), modelFolderName, model.id])
const systemInfo = await systemInformation()
const res = await executeOnMain(
this.nodeModule,
@@ -42,24 +42,22 @@
)

if (res?.error) {
events.emit(ModelEvent.OnModelFail, {
...model,
error: res.error,
})
return
events.emit(ModelEvent.OnModelFail, { error: res.error })
return Promise.reject(res.error)
} else {
this.loadedModel = model
events.emit(ModelEvent.OnModelReady, model)
return Promise.resolve()
}
}
/**
* Stops the model.
*/
unloadModel(model: Model) {
if (model.engine && model.engine?.toString() !== this.provider) return
this.loadedModel = undefined
override async unloadModel(model?: Model): Promise<void> {
if (model?.engine && model.engine?.toString() !== this.provider) return Promise.resolve()

executeOnMain(this.nodeModule, this.unloadModelFunctionName).then(() => {
this.loadedModel = undefined
return executeOnMain(this.nodeModule, this.unloadModelFunctionName).then(() => {
events.emit(ModelEvent.OnModelStopped, {})
})
}
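A hedged usage sketch (the call site and helper name are hypothetical) of the rejection path this override adds: a failed load now surfaces where it is awaited, in addition to the ModelEvent.OnModelFail emission:

```ts
import { AIEngine, Model } from '@janhq/core'

// Hypothetical call site: a LocalOAIEngine's loadModel now rejects with the
// error reported by the node module, so the failure can be caught locally.
async function startLocalModel(engine: AIEngine, model: Model) {
  try {
    await engine.loadModel(model)
    console.log(`Model ${model.id} is ready`)
  } catch (error) {
    console.error(`Failed to load model ${model.id}:`, error)
  }
}
```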
8 changes: 4 additions & 4 deletions core/src/extensions/ai-engines/OAIEngine.ts
@@ -34,7 +34,7 @@ export abstract class OAIEngine extends AIEngine {
/**
* On extension load, subscribe to events.
*/
onLoad() {
override onLoad() {
super.onLoad()
events.on(MessageEvent.OnMessageSent, (data: MessageRequest) => this.inference(data))
events.on(InferenceEvent.OnInferenceStopped, () => this.stopInference())
@@ -43,12 +43,12 @@
/**
* On extension unload
*/
onUnload(): void {}
override onUnload(): void {}

/*
* Inference request
*/
inference(data: MessageRequest) {
override inference(data: MessageRequest) {
if (data.model?.engine?.toString() !== this.provider) return

const timestamp = Date.now()
@@ -114,7 +114,7 @@
/**
* Stops the inference.
*/
stopInference() {
override stopInference() {
this.isCancelled = true
this.controller?.abort()
}
20 changes: 1 addition & 19 deletions core/src/extensions/ai-engines/RemoteOAIEngine.ts
@@ -12,26 +12,8 @@ export abstract class RemoteOAIEngine extends OAIEngine {
/**
* On extension load, subscribe to events.
*/
onLoad() {
override onLoad() {
super.onLoad()
// These events are applicable to local inference providers
events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model))
events.on(ModelEvent.OnModelStop, (model: Model) => this.unloadModel(model))
}

/**
* Load the model.
*/
async loadModel(model: Model) {
if (model.engine.toString() !== this.provider) return
events.emit(ModelEvent.OnModelReady, model)
}
/**
* Stops the model.
*/
unloadModel(model: Model) {
if (model.engine && model.engine.toString() !== this.provider) return
events.emit(ModelEvent.OnModelStopped, {})
}

/**
7 changes: 3 additions & 4 deletions extensions/inference-nitro-extension/src/index.ts
@@ -91,15 +91,14 @@ export default class JanInferenceNitroExtension extends LocalOAIEngine {
return super.loadModel(model)
}

override unloadModel(model: Model): void {
super.unloadModel(model)

if (model.engine && model.engine !== this.provider) return
override async unloadModel(model?: Model) {
if (model?.engine && model.engine !== this.provider) return

// stop the periodic health check
if (this.getNitroProcesHealthIntervalId) {
clearInterval(this.getNitroProcesHealthIntervalId)
this.getNitroProcesHealthIntervalId = undefined
}
return super.unloadModel(model)
}
}
66 changes: 8 additions & 58 deletions web/containers/Providers/EventHandler.tsx
@@ -8,26 +8,16 @@ import {
ExtensionTypeEnum,
MessageStatus,
MessageRequest,
Model,
ConversationalExtension,
MessageEvent,
MessageRequestType,
ModelEvent,
Thread,
ModelInitFailed,
} from '@janhq/core'
import { useAtomValue, useSetAtom } from 'jotai'
import { ulid } from 'ulidx'

import {
activeModelAtom,
loadModelErrorAtom,
stateModelAtom,
} from '@/hooks/useActiveModel'

import { queuedMessageAtom } from '@/hooks/useSendChatMessage'

import { toaster } from '../Toast'
import { activeModelAtom, stateModelAtom } from '@/hooks/useActiveModel'

import { extensionManager } from '@/extension'
import {
@@ -51,8 +41,6 @@ export default function EventHandler({ children }: { children: ReactNode }) {
const activeModel = useAtomValue(activeModelAtom)
const setActiveModel = useSetAtom(activeModelAtom)
const setStateModel = useSetAtom(stateModelAtom)
const setQueuedMessage = useSetAtom(queuedMessageAtom)
const setLoadModelError = useSetAtom(loadModelErrorAtom)

const updateThreadWaiting = useSetAtom(updateThreadWaitingForResponseAtom)
const threads = useAtomValue(threadsAtom)
@@ -88,44 +76,11 @@
[addNewMessage]
)

const onModelReady = useCallback(
(model: Model) => {
setActiveModel(model)
toaster({
title: 'Success!',
description: `Model ${model.id} has been started.`,
type: 'success',
})
setStateModel(() => ({
state: 'stop',
loading: false,
model: model.id,
}))
},
[setActiveModel, setStateModel]
)

const onModelStopped = useCallback(() => {
setTimeout(() => {
setActiveModel(undefined)
setStateModel({ state: 'start', loading: false, model: '' })
}, 500)
setActiveModel(undefined)
setStateModel({ state: 'start', loading: false, model: '' })
}, [setActiveModel, setStateModel])

const onModelInitFailed = useCallback(
(res: ModelInitFailed) => {
console.error('Failed to load model: ', res.error.message)
setStateModel(() => ({
state: 'start',
loading: false,
model: res.id,
}))
setLoadModelError(res.error.message)
setQueuedMessage(false)
},
[setStateModel, setQueuedMessage, setLoadModelError]
)

const updateThreadTitle = useCallback(
(message: ThreadMessage) => {
// Update only when it's finished
@@ -274,7 +229,10 @@

// 2. Update the title with the result of the inference
setTimeout(() => {
events.emit(MessageEvent.OnMessageSent, messageRequest)
const engine = extensionManager.getEngine(
messageRequest.model?.engine ?? ''
)
engine?.inference(messageRequest)
}, 1000)
}
}
@@ -283,17 +241,9 @@
if (window.core?.events) {
events.on(MessageEvent.OnMessageResponse, onNewMessageResponse)
events.on(MessageEvent.OnMessageUpdate, onMessageResponseUpdate)
events.on(ModelEvent.OnModelReady, onModelReady)
events.on(ModelEvent.OnModelFail, onModelInitFailed)
events.on(ModelEvent.OnModelStopped, onModelStopped)
}
}, [
onNewMessageResponse,
onMessageResponseUpdate,
onModelReady,
onModelInitFailed,
onModelStopped,
])
}, [onNewMessageResponse, onMessageResponseUpdate, onModelStopped])

useEffect(() => {
return () => {
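A small sketch (the `dispatchInference` helper is hypothetical) of the direct-dispatch pattern the handler now uses for the title-summarisation request, resolving the engine registered for the model instead of broadcasting MessageEvent.OnMessageSent:

```ts
import { MessageRequest } from '@janhq/core'

import { extensionManager } from '@/extension'

// Hypothetical helper mirroring the EventHandler change: look up the engine
// registered for the request's model and call it directly.
function dispatchInference(request: MessageRequest) {
  const engine = extensionManager.getEngine(request.model?.engine ?? '')
  engine?.inference(request)
}
```

Dispatching to a single engine avoids having every OAIEngine instance receive the event and filter it by provider.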
25 changes: 24 additions & 1 deletion web/extension/ExtensionManager.ts
@@ -1,6 +1,11 @@
/* eslint-disable @typescript-eslint/no-explicit-any */

import { BaseExtension, ExtensionTypeEnum } from '@janhq/core'
import {
AIEngine,
BaseExtension,
ExtensionTypeEnum,
InferenceEngine,
} from '@janhq/core'

import Extension from './Extension'

@@ -9,13 +14,22 @@ import Extension from './Extension'
*/
export class ExtensionManager {
private extensions = new Map<string, BaseExtension>()
private engines = new Map<string, AIEngine>()

/**
* Registers an extension.
* @param extension - The extension to register.
*/
register<T extends BaseExtension>(name: string, extension: T) {
this.extensions.set(extension.type() ?? name, extension)

// Register AI Engines
if ('provider' in extension && typeof extension.provider === 'string') {
this.engines.set(
extension.provider as unknown as string,
extension as unknown as AIEngine
)
}
}

/**
@@ -29,6 +43,15 @@ export class ExtensionManager {
return this.extensions.get(type) as T | undefined
}

/**
* Retrieves an engine by its name.
* @param engine - The engine name to retrieve.
* @returns The engine, if found.
*/
getEngine<T extends AIEngine>(engine: string): T | undefined {
return this.engines.get(engine) as T | undefined
}

/**
* Loads all registered extensions.
*/
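To illustrate the new lookup, a brief usage sketch ('nitro' is used here as a hypothetical engine name; the provider's actual identifier may differ):

```ts
import type { AIEngine } from '@janhq/core'

import { extensionManager } from '@/extension'

// Any registered extension that exposes a string `provider` field is also
// indexed as an AIEngine, so it can later be resolved by engine name.
const engine: AIEngine | undefined = extensionManager.getEngine('nitro')

// The resolved engine can then be driven directly, e.g. to stop a running stream.
engine?.stopInference()
```

Indexing engines by provider keeps the existing type-based extension lookup intact while giving callers a direct, synchronous path to a specific engine.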