mirror of
https://github.com/sist2app/sist2.git
synced 2026-01-24 03:11:13 +00:00
wip
This commit is contained in:
48
sist2-vue/src/ml/CLIPTransformerModel.js
Normal file
48
sist2-vue/src/ml/CLIPTransformerModel.js
Normal file
@@ -0,0 +1,48 @@
|
||||
import * as ort from "onnxruntime-web";
|
||||
import {BPETokenizer} from "@/ml/BPETokenizer";
|
||||
import axios from "axios";
|
||||
import {downloadToBuffer, ORT_WASM_PATHS} from "@/ml/mlUtils";
|
||||
|
||||
/**
 * CLIP text encoder backed by onnxruntime-web (WASM execution provider).
 * Downloads an ONNX model and a BPE tokenizer description, then maps a text
 * query to a 512-dimensional embedding via predict().
 */
export class CLIPTransformerModel {

    _modelUrl = null;
    _tokenizerUrl = null;
    _model = null;      // ort.InferenceSession once loadModel() resolves
    _tokenizer = null;  // BPETokenizer once loadTokenizer() resolves

    /**
     * @param {string} modelUrl URL of the ONNX text-encoder model file.
     * @param {string} tokenizerUrl URL of a JSON document holding
     *        {encoder, bpe_ranks} for BPETokenizer.
     */
    constructor(modelUrl, tokenizerUrl) {
        this._modelUrl = modelUrl;
        this._tokenizerUrl = tokenizerUrl;
    }

    /**
     * Download tokenizer and model in parallel; resolves when both are ready.
     * @param {(progress: number) => void} [onProgress] forwarded to the model download.
     */
    async init(onProgress) {
        await Promise.all([this.loadTokenizer(), this.loadModel(onProgress)]);
    }

    /**
     * Fetch the ONNX model and create the inference session.
     * @param {(progress: number) => void} [onProgress] download progress callback.
     */
    async loadModel(onProgress) {
        // Point ort at the bundled WASM binaries before creating a session.
        ort.env.wasm.wasmPaths = ORT_WASM_PATHS;
        const buf = await downloadToBuffer(this._modelUrl, onProgress);

        this._model = await ort.InferenceSession.create(buf.buffer, {executionProviders: ["wasm"]});
    }

    /**
     * Fetch the tokenizer JSON and build the BPE tokenizer.
     */
    async loadTokenizer() {
        const resp = await axios.get(this._tokenizerUrl);
        this._tokenizer = new BPETokenizer(resp.data.encoder, resp.data.bpe_ranks);
    }

    /**
     * Encode `text` into a CLIP text embedding.
     * @param {string} text query text to embed.
     * @returns {Promise<number[]>} embedding of length 512.
     * @throws {Error} if init() has not completed, or if the model produces
     *         no output of size 512.
     */
    async predict(text) {
        // Fail loudly instead of with an opaque null-dereference TypeError.
        if (this._model === null || this._tokenizer === null) {
            throw new Error("CLIPTransformerModel: init() must complete before predict()");
        }

        // NOTE(review): assumes encode() pads/truncates to exactly 77 tokens
        // to match the [1, 77] tensor shape — verify against BPETokenizer.
        const tokenized = this._tokenizer.encode(text);

        const feeds = {
            input_ids: new ort.Tensor("int32", tokenized, [1, 77])
        };

        const results = await this._model.run(feeds);

        // The session may expose several outputs; select the 512-wide embedding
        // explicitly so a missing output yields a diagnosable error rather than
        // "cannot read property 'data' of undefined".
        const embedding = Object.values(results).find(result => result.size === 512);
        if (embedding === undefined) {
            throw new Error("CLIPTransformerModel: model produced no output of size 512");
        }
        return Array.from(embedding.data);
    }
}
|
||||
Reference in New Issue
Block a user