Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: inference engine provider and accessibility #2470

Merged
merged 5 commits into from
Mar 25, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
chore: load, unload model and inference synchronously
  • Loading branch information
louis-jan committed Mar 25, 2024
commit 9551996e349b1f75c8293cc117981f9a5e33b238
46 changes: 0 additions & 46 deletions core/src/extensions/ai-engines/RemoteOAIEngine.ts

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -2,7 2,8 @@ import { getJanDataFolderPath, joinPath } from '../../core'
import { events } from '../../events'
import { BaseExtension } from '../../extension'
import { fs } from '../../fs'
import { Model, ModelEvent } from '../../types'
import { MessageRequest, Model, ModelEvent } from '../../types'
import { EngineManager } from './EngineManager'

/**
* Base AIEngine
Expand All @@ -11,30 12,71 @@ import { Model, ModelEvent } from '../../types'
export abstract class AIEngine extends BaseExtension {
// The inference engine
abstract provider: string
// The model folder
modelFolder: string = 'models'

/**
* On extension load, subscribe to events.
*/
override onLoad() {
this.registerEngine()

events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model))
events.on(ModelEvent.OnModelStop, (model: Model) => this.unloadModel(model))

this.prePopulateModels()
}

/**
* Defines models
*/
models(): Promise<Model[]> {
return Promise.resolve([])
}

/**
* On extension load, subscribe to events.
* Registers AI Engines
*/
onLoad() {
this.prePopulateModels()
registerEngine() {
  // Make this engine discoverable by provider name via the shared registry.
  // NOTE(review): `?.` means registration is silently skipped when no
  // EngineManager instance is installed — confirm that is intended.
  EngineManager.instance()?.register(this)
}

/**
 * Loads the model.
 *
 * Base implementation: ignores models that target a different provider,
 * otherwise simply announces the model as ready. Concrete engines
 * override this to perform the actual load.
 * @param model - The model to load.
 */
async loadModel(model: Model): Promise<any> {
  // Ignore requests addressed to other inference providers.
  if (model.engine.toString() !== this.provider) return
  events.emit(ModelEvent.OnModelReady, model)
  // No explicit Promise.resolve(): an async function already wraps the
  // return value in a resolved promise.
}
/**
 * Stops the model.
 *
 * Base implementation: skips models owned by another provider and emits
 * the stopped event. When no model is given, the stop is applied
 * unconditionally.
 * @param model - The model to stop; optional.
 */
async unloadModel(model?: Model): Promise<any> {
  // A mismatching provider means another engine owns this model.
  if (model?.engine && model.engine.toString() !== this.provider) return
  events.emit(ModelEvent.OnModelStopped, model ?? {})
  // async already yields a resolved promise; no Promise.resolve() needed.
}

/**
 * Handles an inference request. No-op in the base class; concrete
 * engines override this to produce a response.
 * @param data - The message request to run inference on.
 */
inference(data: MessageRequest) {}

/**
 * Stops an in-flight inference. No-op in the base class; concrete
 * engines override this to cancel their request.
 */
stopInference() {}

/**
* Pre-populate models to App Data Folder
*/
prePopulateModels(): Promise<void> {
const modelFolder = 'models'
return this.models().then((models) => {
const prePoluateOperations = models.map((model) =>
getJanDataFolderPath()
.then((janDataFolder) =>
// Attempt to create the model folder
joinPath([janDataFolder, this.modelFolder, model.id]).then((path) =>
joinPath([janDataFolder, modelFolder, model.id]).then((path) =>
fs
.mkdir(path)
.catch()
Expand Down
34 changes: 34 additions & 0 deletions core/src/extensions/engines/EngineManager.ts
Original file line number Diff line number Diff line change
@@ -0,0 1,34 @@
import { log } from '../../core'
import { AIEngine } from './AIEngine'

/**
 * Manages the registration and retrieval of inference engines.
 */
export class EngineManager {
  /** Registered engines, keyed by provider name. */
  public engines = new Map<string, AIEngine>()

  /**
   * Registers an engine.
   * Re-registering the same provider replaces the previous engine.
   * @param engine - The engine to register.
   */
  register<T extends AIEngine>(engine: T) {
    this.engines.set(engine.provider, engine)
  }

  /**
   * Retrieves an engine by provider.
   * @param provider - The name of the engine to retrieve.
   * @returns The engine, if found.
   */
  get<T extends AIEngine>(provider: string): T | undefined {
    return this.engines.get(provider) as T | undefined
  }

  /**
   * Returns the shared EngineManager exposed on `window.core`, if the
   * host application has installed one.
   */
  static instance(): EngineManager | undefined {
    // Keep `undefined` in the assertion — `?.` can legitimately yield it,
    // and asserting plain `EngineManager` would hide that from callers.
    return window.core?.engineManager as EngineManager | undefined
  }
}

/**
 * The singleton instance of the EngineManager.
 */
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 16,7 @@ export abstract class LocalOAIEngine extends OAIEngine {
/**
* On extension load, subscribe to events.
*/
onLoad() {
override onLoad() {
super.onLoad()
// These events are applicable to local inference providers
events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model))
Expand All @@ -26,10 26,10 @@ export abstract class LocalOAIEngine extends OAIEngine {
/**
* Load the model.
*/
async loadModel(model: Model) {
override async loadModel(model: Model): Promise<void> {
if (model.engine.toString() !== this.provider) return

const modelFolder = await joinPath([await getJanDataFolderPath(), this.modelFolder, model.id])
const modelFolderName = 'models'
const modelFolder = await joinPath([await getJanDataFolderPath(), modelFolderName, model.id])
const systemInfo = await systemInformation()
const res = await executeOnMain(
this.nodeModule,
Expand All @@ -42,24 42,22 @@ export abstract class LocalOAIEngine extends OAIEngine {
)

if (res?.error) {
events.emit(ModelEvent.OnModelFail, {
...model,
error: res.error,
})
return
events.emit(ModelEvent.OnModelFail, { error: res.error })
return Promise.reject(res.error)
} else {
this.loadedModel = model
events.emit(ModelEvent.OnModelReady, model)
return Promise.resolve()
}
}
/**
* Stops the model.
*/
unloadModel(model: Model) {
if (model.engine && model.engine?.toString() !== this.provider) return
this.loadedModel = undefined
override async unloadModel(model?: Model): Promise<void> {
if (model?.engine && model.engine?.toString() !== this.provider) return Promise.resolve()

executeOnMain(this.nodeModule, this.unloadModelFunctionName).then(() => {
this.loadedModel = undefined
return executeOnMain(this.nodeModule, this.unloadModelFunctionName).then(() => {
events.emit(ModelEvent.OnModelStopped, {})
})
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 34,7 @@ export abstract class OAIEngine extends AIEngine {
/**
* On extension load, subscribe to events.
*/
onLoad() {
override onLoad() {
super.onLoad()
events.on(MessageEvent.OnMessageSent, (data: MessageRequest) => this.inference(data))
events.on(InferenceEvent.OnInferenceStopped, () => this.stopInference())
Expand All @@ -43,12 43,12 @@ export abstract class OAIEngine extends AIEngine {
/**
* On extension unload
*/
onUnload(): void {}
override onUnload(): void {}

/*
* Inference request
*/
inference(data: MessageRequest) {
override inference(data: MessageRequest) {
if (data.model?.engine?.toString() !== this.provider) return

const timestamp = Date.now()
Expand Down Expand Up @@ -114,7 114,7 @@ export abstract class OAIEngine extends AIEngine {
/**
* Stops the inference.
*/
stopInference() {
override stopInference() {
this.isCancelled = true
this.controller?.abort()
}
Expand Down
26 changes: 26 additions & 0 deletions core/src/extensions/engines/RemoteOAIEngine.ts
Original file line number Diff line number Diff line change
@@ -0,0 1,26 @@
import { OAIEngine } from './OAIEngine'

/**
 * Base OAI Remote Inference Provider.
 * Remote providers have no local model lifecycle to manage; they only
 * need to authenticate requests against the remote API.
 */
export abstract class RemoteOAIEngine extends OAIEngine {
  // API key used to authenticate against the remote provider
  abstract apiKey: string

  // No onLoad override: the base OAIEngine subscription behavior is
  // inherited as-is (the previous `override onLoad() { super.onLoad() }`
  // was redundant).

  /**
   * Headers for the inference request.
   * Sends the key both as a Bearer token and as `api-key` to cover the
   * header conventions of different OAI-compatible services.
   */
  override headers(): HeadersInit {
    return {
      'Authorization': `Bearer ${this.apiKey}`,
      'api-key': `${this.apiKey}`,
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 2,4 @@ export * from './AIEngine'
export * from './OAIEngine'
export * from './LocalOAIEngine'
export * from './RemoteOAIEngine'
export * from './EngineManager'
2 changes: 1 addition & 1 deletion core/src/extensions/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 32,4 @@ export { HuggingFaceExtension } from './huggingface'
/**
* Base AI Engines.
*/
export * from './ai-engines'
export * from './engines'
7 changes: 3 additions & 4 deletions extensions/inference-nitro-extension/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -91,15 91,14 @@ export default class JanInferenceNitroExtension extends LocalOAIEngine {
return super.loadModel(model)
}

override unloadModel(model: Model): void {
super.unloadModel(model)

if (model.engine && model.engine !== this.provider) return
/**
 * Stops the model and clears the periodic Nitro-process health check.
 * @param model - The model to stop; skipped when owned by another provider.
 */
override async unloadModel(model?: Model) {
  if (model?.engine && model.engine !== this.provider) return

  // stop the periodically-run health check before tearing down the model
  if (this.getNitroProcesHealthIntervalId) {
    clearInterval(this.getNitroProcesHealthIntervalId)
    this.getNitroProcesHealthIntervalId = undefined
  }
  return super.unloadModel(model)
}
}
Loading
Loading