mirror of
https://github.com/sist2app/sist2.git
synced 2026-03-29 11:31:35 +00:00
Rework user scripts, update DB schema to support embeddings
This commit is contained in:
@@ -2,6 +2,7 @@ import * as ort from "onnxruntime-web";
|
||||
import {BPETokenizer} from "@/ml/BPETokenizer";
|
||||
import axios from "axios";
|
||||
import {downloadToBuffer, ORT_WASM_PATHS} from "@/ml/mlUtils";
|
||||
import ModelStore from "@/ml/ModelStore";
|
||||
|
||||
export class CLIPTransformerModel {
|
||||
|
||||
@@ -21,9 +22,17 @@ export class CLIPTransformerModel {
|
||||
|
||||
async loadModel(onProgress) {
|
||||
ort.env.wasm.wasmPaths = ORT_WASM_PATHS;
|
||||
const buf = await downloadToBuffer(this._modelUrl, onProgress);
|
||||
ort.env.wasm.numThreads = 2;
|
||||
|
||||
this._model = await ort.InferenceSession.create(buf.buffer, {executionProviders: ["wasm"]});
|
||||
let buf = await ModelStore.get(this._modelUrl);
|
||||
if (!buf) {
|
||||
buf = await downloadToBuffer(this._modelUrl, onProgress);
|
||||
await ModelStore.set(this._modelUrl, buf);
|
||||
}
|
||||
|
||||
this._model = await ort.InferenceSession.create(buf.buffer, {
|
||||
executionProviders: ["wasm"],
|
||||
});
|
||||
}
|
||||
|
||||
async loadTokenizer() {
|
||||
@@ -34,11 +43,11 @@ export class CLIPTransformerModel {
|
||||
async predict(text) {
|
||||
const tokenized = this._tokenizer.encode(text);
|
||||
|
||||
const feeds = {
|
||||
const inputs = {
|
||||
input_ids: new ort.Tensor("int32", tokenized, [1, 77])
|
||||
};
|
||||
|
||||
const results = await this._model.run(feeds);
|
||||
const results = await this._model.run(inputs);
|
||||
|
||||
return Array.from(
|
||||
Object.values(results)
|
||||
|
||||
67
sist2-vue/src/ml/ModelStore.js
Normal file
67
sist2-vue/src/ml/ModelStore.js
Normal file
@@ -0,0 +1,67 @@
|
||||
class ModelStore {
|
||||
|
||||
_ok;
|
||||
_db;
|
||||
_resolve;
|
||||
_loadingPromise;
|
||||
|
||||
constructor() {
|
||||
const request = window.indexedDB.open("ModelStore", 1);
|
||||
|
||||
request.onerror = () => {
|
||||
this._ok = false;
|
||||
}
|
||||
|
||||
request.onupgradeneeded = event => {
|
||||
const db = event.target.result;
|
||||
db.createObjectStore("models");
|
||||
}
|
||||
|
||||
request.onsuccess = () => {
|
||||
this._ok = true;
|
||||
this._db = request.result;
|
||||
|
||||
this._resolve();
|
||||
}
|
||||
|
||||
this._loadingPromise = new Promise(resolve => this._resolve = resolve);
|
||||
}
|
||||
|
||||
async get(key) {
|
||||
await this._loadingPromise;
|
||||
|
||||
const req = this._db.transaction(["models"], "readwrite")
|
||||
.objectStore("models")
|
||||
.get(key);
|
||||
|
||||
return new Promise(resolve => {
|
||||
req.onsuccess = event => {
|
||||
resolve(event.target.result);
|
||||
};
|
||||
req.onerror = event => {
|
||||
console.log("ERROR:");
|
||||
console.log(event);
|
||||
resolve(null);
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
async set(key, val) {
|
||||
await this._loadingPromise;
|
||||
|
||||
const req = this._db.transaction(["models"], "readwrite")
|
||||
.objectStore("models")
|
||||
.put(val, key);
|
||||
|
||||
return new Promise(resolve => {
|
||||
req.onsuccess = () => {
|
||||
resolve(true);
|
||||
};
|
||||
req.onerror = () => {
|
||||
resolve(false);
|
||||
};
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
export default new ModelStore();
|
||||
@@ -17,7 +17,6 @@ export async function downloadToBuffer(url, onProgress) {
|
||||
break;
|
||||
}
|
||||
|
||||
console.log(`Sending ${value.length} bytes into ${buf.length} at offset ${cursor} (${buf.length - cursor} free)`)
|
||||
buf.set(value, cursor);
|
||||
cursor += value.length;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user