Skip to content

Commit

Permalink
chore: load, unload model and inference synchronously
Browse files Browse the repository at this point in the history
  • Loading branch information
louis-jan committed Mar 24, 2024
1 parent 1ad794c commit 767b06d
Show file tree
Hide file tree
Showing 11 changed files with 163 additions and 145 deletions.
51 changes: 44 additions & 7 deletions core/src/extensions/ai-engines/AIEngine.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 2,8 @@ import { getJanDataFolderPath, joinPath } from '../../core'
import { events } from '../../events'
import { BaseExtension } from '../../extension'
import { fs } from '../../fs'
import { Model, ModelEvent } from '../../types'
import { MessageRequest, Model, ModelEvent } from '../../types'
import { engineManager } from './EngineManager'

/**
* Base AIEngine
Expand All @@ -11,30 12,66 @@ import { Model, ModelEvent } from '../../types'
export abstract class AIEngine extends BaseExtension {
// The inference engine
abstract provider: string
// The model folder
modelFolder: string = 'models'

/**
* On extension load, subscribe to events.
*/
override onLoad() {
this.registerEngine()

events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model))
events.on(ModelEvent.OnModelStop, (model: Model) => this.unloadModel(model))

this.prePopulateModels()
}

models(): Promise<Model[]> {
return Promise.resolve([])
}

registerEngine() {
// Register AI Engines
engineManager.register(this)
}

/**
* On extension load, subscribe to events.
* Load the model.
*/
onLoad() {
this.prePopulateModels()
async loadModel(model: Model): Promise<any> {
if (model.engine.toString() !== this.provider) return Promise.resolve()
events.emit(ModelEvent.OnModelReady, model)
return Promise.resolve()
}
/**
* Stops the model.
*/
async unloadModel(model?: Model): Promise<any> {
if (model?.engine && model.engine.toString() !== this.provider) return Promise.resolve()
events.emit(ModelEvent.OnModelStopped, model ?? {})
return Promise.resolve()
}

/*
* Inference request
*/
inference(data: MessageRequest) {}

/**
* Stop inference
*/
stopInference() {}

/**
* Pre-populate models to App Data Folder
*/
prePopulateModels(): Promise<void> {
const modelFolder = 'models'
return this.models().then((models) => {
const prePoluateOperations = models.map((model) =>
getJanDataFolderPath()
.then((janDataFolder) =>
// Attempt to create the model folder
joinPath([janDataFolder, this.modelFolder, model.id]).then((path) =>
joinPath([janDataFolder, modelFolder, model.id]).then((path) =>
fs
.mkdir(path)
.catch()
Expand Down
30 changes: 30 additions & 0 deletions core/src/extensions/ai-engines/EngineManager.ts
Original file line number Diff line number Diff line change
@@ -0,0 1,30 @@
import { AIEngine } from './AIEngine'

/**
 * Manages the registration and retrieval of inference engines.
 */
export class EngineManager {
  /** Registered engines, keyed by their provider name. */
  public engines = new Map<string, AIEngine>()

  /**
   * Registers an engine under its provider name.
   * Registering a second engine with the same provider replaces the first.
   * @param engine - The engine to register.
   */
  register<T extends AIEngine>(engine: T) {
    this.engines.set(engine.provider, engine)
  }

  /**
   * Retrieves an engine by provider.
   * @param provider - The provider name of the engine to retrieve.
   * @returns The engine, if found.
   */
  get<T extends AIEngine>(provider: string): T | undefined {
    return this.engines.get(provider) as T | undefined
  }
}

/**
 * The singleton instance of the EngineManager.
 */
export const engineManager = new EngineManager()
24 changes: 11 additions & 13 deletions core/src/extensions/ai-engines/LocalOAIEngine.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 16,7 @@ export abstract class LocalOAIEngine extends OAIEngine {
/**
* On extension load, subscribe to events.
*/
onLoad() {
override onLoad() {
super.onLoad()
// These events are applicable to local inference providers
events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model))
Expand All @@ -26,10 26,10 @@ export abstract class LocalOAIEngine extends OAIEngine {
/**
* Load the model.
*/
async loadModel(model: Model) {
override async loadModel(model: Model): Promise<void> {
if (model.engine.toString() !== this.provider) return

const modelFolder = await joinPath([await getJanDataFolderPath(), this.modelFolder, model.id])
const modelFolderName = 'models'
const modelFolder = await joinPath([await getJanDataFolderPath(), modelFolderName, model.id])
const systemInfo = await systemInformation()
const res = await executeOnMain(
this.nodeModule,
Expand All @@ -42,24 42,22 @@ export abstract class LocalOAIEngine extends OAIEngine {
)

if (res?.error) {
events.emit(ModelEvent.OnModelFail, {
...model,
error: res.error,
})
return
events.emit(ModelEvent.OnModelFail, { error: res.error })
return Promise.reject(res.error)
} else {
this.loadedModel = model
events.emit(ModelEvent.OnModelReady, model)
return Promise.resolve()
}
}
/**
* Stops the model.
*/
unloadModel(model: Model) {
if (model.engine && model.engine?.toString() !== this.provider) return
this.loadedModel = undefined
override async unloadModel(model?: Model): Promise<void> {
if (model?.engine && model.engine?.toString() !== this.provider) return Promise.resolve()

executeOnMain(this.nodeModule, this.unloadModelFunctionName).then(() => {
this.loadedModel = undefined
return executeOnMain(this.nodeModule, this.unloadModelFunctionName).then(() => {
events.emit(ModelEvent.OnModelStopped, {})
})
}
Expand Down
8 changes: 4 additions & 4 deletions core/src/extensions/ai-engines/OAIEngine.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 34,7 @@ export abstract class OAIEngine extends AIEngine {
/**
* On extension load, subscribe to events.
*/
onLoad() {
override onLoad() {
super.onLoad()
events.on(MessageEvent.OnMessageSent, (data: MessageRequest) => this.inference(data))
events.on(InferenceEvent.OnInferenceStopped, () => this.stopInference())
Expand All @@ -43,12 43,12 @@ export abstract class OAIEngine extends AIEngine {
/**
* On extension unload
*/
onUnload(): void {}
override onUnload(): void {}

/*
* Inference request
*/
inference(data: MessageRequest) {
override inference(data: MessageRequest) {
if (data.model?.engine?.toString() !== this.provider) return

const timestamp = Date.now()
Expand Down Expand Up @@ -114,7 114,7 @@ export abstract class OAIEngine extends AIEngine {
/**
* Stops the inference.
*/
stopInference() {
override stopInference() {
this.isCancelled = true
this.controller?.abort()
}
Expand Down
22 changes: 1 addition & 21 deletions core/src/extensions/ai-engines/RemoteOAIEngine.ts
Original file line number Diff line number Diff line change
@@ -1,5 1,3 @@
import { events } from '../../events'
import { Model, ModelEvent } from '../../types'
import { OAIEngine } from './OAIEngine'

/**
Expand All @@ -12,26 10,8 @@ export abstract class RemoteOAIEngine extends OAIEngine {
/**
* On extension load, subscribe to events.
*/
onLoad() {
override onLoad() {
super.onLoad()
// These events are applicable to local inference providers
events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model))
events.on(ModelEvent.OnModelStop, (model: Model) => this.unloadModel(model))
}

/**
* Load the model.
*/
async loadModel(model: Model) {
if (model.engine.toString() !== this.provider) return
events.emit(ModelEvent.OnModelReady, model)
}
/**
* Stops the model.
*/
unloadModel(model: Model) {
if (model.engine && model.engine.toString() !== this.provider) return
events.emit(ModelEvent.OnModelStopped, {})
}

/**
Expand Down
1 change: 1 addition & 0 deletions core/src/extensions/ai-engines/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 2,4 @@ export * from './AIEngine'
export * from './OAIEngine'
export * from './LocalOAIEngine'
export * from './RemoteOAIEngine'
export * from './EngineManager'
7 changes: 3 additions & 4 deletions extensions/inference-nitro-extension/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -91,15 91,14 @@ export default class JanInferenceNitroExtension extends LocalOAIEngine {
return super.loadModel(model)
}

override unloadModel(model: Model): void {
super.unloadModel(model)

if (model.engine && model.engine !== this.provider) return
override async unloadModel(model?: Model) {
if (model?.engine && model.engine !== this.provider) return

// stop the periodic health check
if (this.getNitroProcesHealthIntervalId) {
clearInterval(this.getNitroProcesHealthIntervalId)
this.getNitroProcesHealthIntervalId = undefined
}
return super.unloadModel(model)
}
}
67 changes: 9 additions & 58 deletions web/containers/Providers/EventHandler.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 8,17 @@ import {
ExtensionTypeEnum,
MessageStatus,
MessageRequest,
Model,
ConversationalExtension,
MessageEvent,
MessageRequestType,
ModelEvent,
Thread,
ModelInitFailed,
engineManager,
} from '@janhq/core'
import { useAtomValue, useSetAtom } from 'jotai'
import { ulid } from 'ulidx'

import {
activeModelAtom,
loadModelErrorAtom,
stateModelAtom,
} from '@/hooks/useActiveModel'

import { queuedMessageAtom } from '@/hooks/useSendChatMessage'

import { toaster } from '../Toast'
import { activeModelAtom, stateModelAtom } from '@/hooks/useActiveModel'

import { extensionManager } from '@/extension'
import {
Expand All @@ -51,8 42,6 @@ export default function EventHandler({ children }: { children: ReactNode }) {
const activeModel = useAtomValue(activeModelAtom)
const setActiveModel = useSetAtom(activeModelAtom)
const setStateModel = useSetAtom(stateModelAtom)
const setQueuedMessage = useSetAtom(queuedMessageAtom)
const setLoadModelError = useSetAtom(loadModelErrorAtom)

const updateThreadWaiting = useSetAtom(updateThreadWaitingForResponseAtom)
const threads = useAtomValue(threadsAtom)
Expand Down Expand Up @@ -88,44 77,11 @@ export default function EventHandler({ children }: { children: ReactNode }) {
[addNewMessage]
)

const onModelReady = useCallback(
(model: Model) => {
setActiveModel(model)
toaster({
title: 'Success!',
description: `Model ${model.id} has been started.`,
type: 'success',
})
setStateModel(() => ({
state: 'stop',
loading: false,
model: model.id,
}))
},
[setActiveModel, setStateModel]
)

const onModelStopped = useCallback(() => {
setTimeout(() => {
setActiveModel(undefined)
setStateModel({ state: 'start', loading: false, model: '' })
}, 500)
setActiveModel(undefined)
setStateModel({ state: 'start', loading: false, model: '' })
}, [setActiveModel, setStateModel])

const onModelInitFailed = useCallback(
(res: ModelInitFailed) => {
console.error('Failed to load model: ', res.error.message)
setStateModel(() => ({
state: 'start',
loading: false,
model: res.id,
}))
setLoadModelError(res.error.message)
setQueuedMessage(false)
},
[setStateModel, setQueuedMessage, setLoadModelError]
)

const updateThreadTitle = useCallback(
(message: ThreadMessage) => {
// Update only when it's finished
Expand Down Expand Up @@ -274,7 230,10 @@ export default function EventHandler({ children }: { children: ReactNode }) {

// 2. Update the title with the result of the inference
setTimeout(() => {
events.emit(MessageEvent.OnMessageSent, messageRequest)
const engine = engineManager.get(
messageRequest.model?.engine ?? activeModelRef.current?.engine ?? ''
)
engine?.inference(messageRequest)
}, 1000)
}
}
Expand All @@ -283,17 242,9 @@ export default function EventHandler({ children }: { children: ReactNode }) {
if (window.core?.events) {
events.on(MessageEvent.OnMessageResponse, onNewMessageResponse)
events.on(MessageEvent.OnMessageUpdate, onMessageResponseUpdate)
events.on(ModelEvent.OnModelReady, onModelReady)
events.on(ModelEvent.OnModelFail, onModelInitFailed)
events.on(ModelEvent.OnModelStopped, onModelStopped)
}
}, [
onNewMessageResponse,
onMessageResponseUpdate,
onModelReady,
onModelInitFailed,
onModelStopped,
])
}, [onNewMessageResponse, onMessageResponseUpdate, onModelStopped])

useEffect(() => {
return () => {
Expand Down
Loading

0 comments on commit 767b06d

Please sign in to comment.