mirror of
https://github.com/sist2app/sist2.git
synced 2026-03-29 03:21:37 +00:00
wip
This commit is contained in:
118
sist2-vue/src/ml/BPETokenizer.js
Normal file
118
sist2-vue/src/ml/BPETokenizer.js
Normal file
@@ -0,0 +1,118 @@
|
||||
const inf = Number.POSITIVE_INFINITY;

// CLIP vocabulary ids of the special start/end-of-text tokens.
const START_TOK = 49406;
const END_TOK = 49407;

/**
 * Return the element of `array` preferred by the comparator `key`.
 * `key(a, b)` returns true when candidate `b` should replace the current
 * best `a`. Throws a TypeError on an empty array (reduce with no initial value).
 */
function min(array, key) {
    return array
        .reduce((a, b) => (key(a, b) ? b : a));
}

/**
 * A Set of string tuples. Elements are joined with a backtick — a character
 * that does not occur in the BPE vocabulary — so tuples compare by value.
 */
class TupleSet extends Set {
    add(elem) {
        return super.add(elem.join("`"));
    }

    has(elem) {
        return super.has(elem.join("`"));
    }

    /** Materialize the set back into a list of tuples (insertion order). */
    toList() {
        return [...this].map(x => x.split("`"));
    }
}

/**
 * Byte-pair-encoding tokenizer for the CLIP text encoder.
 */
export class BPETokenizer {

    _encoder = null;   // map: BPE sub-token -> vocabulary id
    _bpeRanks = null;  // map: backtick-joined pair "a`b" -> merge priority (lower merges first)

    constructor(encoder, bpeRanks) {
        this._encoder = encoder;
        this._bpeRanks = bpeRanks;
    }

    /** Return the deduplicated list of adjacent symbol pairs in `word`. */
    getPairs(word) {
        const pairs = new TupleSet();

        let prevChar = word[0];
        for (let i = 1; i < word.length; i++) {
            pairs.add([prevChar, word[i]]);
            prevChar = word[i];
        }

        return pairs.toList();
    }

    /**
     * Apply BPE merges to a single token. Returns the space-separated
     * sub-tokens, with "</w>" appended to the last symbol to mark end-of-word.
     */
    bpe(token) {
        let word = [...token];
        word[word.length - 1] += "</w>";
        let pairs = this.getPairs(word);

        if (pairs.length === 0) {
            return token + "</w>";
        }

        while (true) {
            // Pick the pair with the lowest merge rank; unknown pairs rank as
            // +Infinity. FIX: the original applied `?? inf` inside the
            // subscript (`bpeRanks[b.join("`") ?? inf]`), a no-op since
            // join() is never nullish — it only behaved correctly by the
            // coincidence that `x > undefined` is always false.
            const bigram = min(pairs, (a, b) => {
                return (this._bpeRanks[a.join("`")] ?? inf) > (this._bpeRanks[b.join("`")] ?? inf);
            });

            // The best remaining pair is not in the merge table: done.
            if (this._bpeRanks[bigram.join("`")] === undefined) {
                break;
            }

            const [first, second] = bigram;
            const newWord = [];
            let i = 0;

            while (i < word.length) {
                // Copy everything up to the next occurrence of `first`.
                const j = word.indexOf(first, i);
                if (j === -1) {
                    newWord.push(...word.slice(i));
                    break;
                } else {
                    newWord.push(...word.slice(i, j));
                    i = j;
                }

                // Merge `first`+`second` when adjacent, otherwise keep as-is.
                if (word[i] === first && i < word.length - 1 && word[i + 1] === second) {
                    newWord.push(first + second);
                    i += 2;
                } else {
                    newWord.push(word[i]);
                    i += 1;
                }
            }

            word = [...newWord];
            if (word.length === 1) {
                break;
            } else {
                pairs = this.getPairs(word);
            }
        }

        return word.join(" ");
    }

    /**
     * Encode free text into a fixed-length (77) array of CLIP token ids:
     * [START_TOK, ...BPE ids (at most 75), END_TOK, 0-padding].
     */
    encode(text) {
        let bpeTokens = [];
        text = text.trim();
        text = text.replaceAll(/\s+/g, " ");

        // FIX: String.match() returns null when nothing matches (empty or
        // punctuation-only input); fall back to an empty token list instead
        // of crashing in forEach.
        const tokens = text
            .match(/<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[a-zA-Z0-9]+/ig) ?? [];

        tokens.forEach(token => {
            // NOTE(review): the encoder lookup is case-sensitive while the
            // regex is case-insensitive — presumably the vocabulary is
            // lowercase; confirm callers pass lowercased text.
            bpeTokens.push(...this.bpe(token).split(" ").map(t => this._encoder[t]));
        });

        bpeTokens.unshift(START_TOK);
        bpeTokens = bpeTokens.slice(0, 76); // leave room for END_TOK
        bpeTokens.push(END_TOK);
        while (bpeTokens.length < 77) {
            bpeTokens.push(0);
        }

        return bpeTokens;
    }
}
|
||||
@@ -1,6 +1,8 @@
|
||||
import BertTokenizer from "@/ml/BertTokenizer";
|
||||
import * as tf from "@tensorflow/tfjs";
|
||||
import axios from "axios";
|
||||
import {chunk as _chunk} from "underscore";
|
||||
import * as ort from "onnxruntime-web";
|
||||
import {argMax, downloadToBuffer, ORT_WASM_PATHS} from "@/ml/mlUtils";
|
||||
|
||||
export default class BertNerModel {
|
||||
vocabUrl;
|
||||
@@ -29,7 +31,10 @@ export default class BertNerModel {
|
||||
}
|
||||
|
||||
async loadModel(onProgress) {
|
||||
this._model = await tf.loadGraphModel(this.modelUrl, {onProgress});
|
||||
ort.env.wasm.wasmPaths = ORT_WASM_PATHS;
|
||||
const buf = await downloadToBuffer(this.modelUrl, onProgress);
|
||||
|
||||
this._model = await ort.InferenceSession.create(buf.buffer, {executionProviders: ["wasm"]});
|
||||
}
|
||||
|
||||
alignLabels(labels, wordIds, words) {
|
||||
@@ -57,21 +62,28 @@ export default class BertNerModel {
|
||||
|
||||
async predict(text, callback) {
|
||||
this._previousWordId = null;
|
||||
const encoded = this._tokenizer.encodeText(text, this.inputSize)
|
||||
const encoded = this._tokenizer.encodeText(text, this.inputSize);
|
||||
|
||||
let i = 0;
|
||||
for (let chunk of encoded.inputChunks) {
|
||||
const rawResult = tf.tidy(() => this._model.execute({
|
||||
input_ids: tf.tensor2d(chunk.inputIds, [1, this.inputSize], "int32"),
|
||||
token_type_ids: tf.tensor2d(chunk.segmentIds, [1, this.inputSize], "int32"),
|
||||
attention_mask: tf.tensor2d(chunk.inputMask, [1, this.inputSize], "int32"),
|
||||
}));
|
||||
|
||||
const labelIds = await tf.argMax(rawResult, -1);
|
||||
const labelIdsArray = await labelIds.array();
|
||||
const labels = labelIdsArray[0].map(id => this.id2label[id]);
|
||||
rawResult.dispose()
|
||||
const results = await this._model.run({
|
||||
input_ids: new ort.Tensor("int32", chunk.inputIds, [1, this.inputSize]),
|
||||
token_type_ids: new ort.Tensor("int32", chunk.segmentIds, [1, this.inputSize]),
|
||||
attention_mask: new ort.Tensor("int32", chunk.inputMask, [1, this.inputSize]),
|
||||
});
|
||||
|
||||
callback(this.alignLabels(labels, chunk.wordIds, encoded.words))
|
||||
const labelIds = _chunk(results["output"].data, this.id2label.length).map(argMax);
|
||||
const labels = labelIds.map(id => this.id2label[id]);
|
||||
|
||||
callback(this.alignLabels(labels, chunk.wordIds, encoded.words));
|
||||
|
||||
i += 1;
|
||||
|
||||
// give browser some time to repaint
|
||||
if (i % 2 === 0) {
|
||||
await new Promise(resolve => setTimeout(resolve, 0));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,5 @@
|
||||
import {zip, chunk} from "underscore";
|
||||
import {toInt64} from "@/ml/mlUtils";
|
||||
|
||||
const UNK_INDEX = 100;
|
||||
const CLS_INDEX = 101;
|
||||
|
||||
48
sist2-vue/src/ml/CLIPTransformerModel.js
Normal file
48
sist2-vue/src/ml/CLIPTransformerModel.js
Normal file
@@ -0,0 +1,48 @@
|
||||
import * as ort from "onnxruntime-web";
|
||||
import {BPETokenizer} from "@/ml/BPETokenizer";
|
||||
import axios from "axios";
|
||||
import {downloadToBuffer, ORT_WASM_PATHS} from "@/ml/mlUtils";
|
||||
|
||||
/**
 * CLIP text-encoder wrapper: tokenizes a query with BPE and runs the ONNX
 * transformer (via onnxruntime-web, WASM backend) to produce a
 * 512-dimensional text embedding.
 */
export class CLIPTransformerModel {

    _modelUrl = null;
    _tokenizerUrl = null;
    _model = null;
    _tokenizer = null;

    constructor(modelUrl, tokenizerUrl) {
        this._modelUrl = modelUrl;
        this._tokenizerUrl = tokenizerUrl;
    }

    /** Download the tokenizer and the model concurrently. */
    async init(onProgress) {
        const loaders = [this.loadTokenizer(), this.loadModel(onProgress)];
        await Promise.all(loaders);
    }

    /** Fetch the ONNX model and create a WASM inference session. */
    async loadModel(onProgress) {
        ort.env.wasm.wasmPaths = ORT_WASM_PATHS;

        const modelData = await downloadToBuffer(this._modelUrl, onProgress);
        this._model = await ort.InferenceSession.create(modelData.buffer, {
            executionProviders: ["wasm"],
        });
    }

    /** Fetch the BPE vocabulary/merge table and build the tokenizer. */
    async loadTokenizer() {
        const {data} = await axios.get(this._tokenizerUrl);
        this._tokenizer = new BPETokenizer(data.encoder, data.bpe_ranks);
    }

    /**
     * @param {string} text query text
     * @returns {Promise<number[]>} the 512-dim text embedding
     */
    async predict(text) {
        const tokenIds = this._tokenizer.encode(text);

        const results = await this._model.run({
            // NOTE(review): assumes the exported model takes int32 ids with a
            // fixed [1, 77] context window — confirm against the ONNX export.
            input_ids: new ort.Tensor("int32", tokenIds, [1, 77]),
        });

        // The embedding output is located by its element count (512 floats)
        // rather than by name, so any exported output name works.
        const embedding = Object.values(results).find((out) => out.size === 512);
        return Array.from(embedding.data);
    }
}
|
||||
47
sist2-vue/src/ml/mlUtils.js
Normal file
47
sist2-vue/src/ml/mlUtils.js
Normal file
@@ -0,0 +1,47 @@
|
||||
/**
 * Download a binary resource into a byte buffer while reporting progress.
 *
 * @param {string} url resource to fetch
 * @param {(fraction: number) => void} [onProgress] called with 0 at start and
 *        with downloaded/total after each chunk (only when the total is known)
 * @returns {Promise<Uint8Array>} the downloaded bytes
 *
 * FIXES vs. original: removed a leftover debug console.log; handle a missing
 * or zero Content-Length header (chunked encoding, data: URLs) instead of
 * crashing on buf.set / reporting Infinity; use Uint8Array instead of
 * Uint8ClampedArray (the latter is meant for image data — callers only use
 * .buffer/.length, so this stays compatible).
 */
export async function downloadToBuffer(url, onProgress) {
    const resp = await fetch(url);

    const contentLength = Number(resp.headers.get("Content-Length")) || 0;
    const reader = resp.body.getReader();

    onProgress?.(0);

    if (contentLength > 0) {
        // Known size: stream directly into a preallocated buffer so large
        // model files are not held twice in memory.
        const buf = new Uint8Array(contentLength);
        let cursor = 0;

        while (true) {
            const {done, value} = await reader.read();
            if (done) {
                break;
            }
            buf.set(value, cursor);
            cursor += value.length;
            onProgress?.(cursor / contentLength);
        }

        return buf;
    }

    // Unknown size: collect chunks, then concatenate once at the end.
    const chunks = [];
    let received = 0;

    while (true) {
        const {done, value} = await reader.read();
        if (done) {
            break;
        }
        chunks.push(value);
        received += value.length;
    }

    const buf = new Uint8Array(received);
    let cursor = 0;
    for (const chunk of chunks) {
        buf.set(chunk, cursor);
        cursor += chunk.length;
    }

    return buf;
}
|
||||
|
||||
/**
 * Return the index of the largest element of `array`.
 * Ties resolve to the first occurrence. Throws a TypeError on an empty array
 * (matching Array.prototype.reduce with no initial value).
 */
export function argMax(array) {
    if (array.length === 0) {
        // Same failure mode as the original reduce-based implementation.
        throw new TypeError("Reduce of empty array with no initial value");
    }

    let best = 0;
    for (let i = 1; i < array.length; i++) {
        if (array[i] > array[best]) {
            best = i;
        }
    }
    return best;
}
|
||||
|
||||
/**
 * Convert an array of numbers into a BigInt64Array (the int64 layout ONNX
 * Runtime expects for integer inputs). Throws a RangeError if any element is
 * not an integer, since BigInt() rejects fractional numbers.
 */
export function toInt64(array) {
    const out = new BigInt64Array(array.length);
    for (let i = 0; i < array.length; i++) {
        out[i] = BigInt(array[i]);
    }
    return out;
}
|
||||
|
||||
// Map of onnxruntime-web WASM binary names to their CDN locations, assigned
// to ort.env.wasm.wasmPaths so the runtime fetches them from jsDelivr instead
// of the app's own origin. The version must match the bundled onnxruntime-web.
export const ORT_WASM_PATHS = Object.fromEntries(
    [
        "ort-wasm-simd.wasm",
        "ort-wasm.wasm",
        "ort-wasm-simd-threaded.wasm",
        "ort-wasm-threaded.wasm",
    ].map((file) => [file, `https://cdn.jsdelivr.net/npm/onnxruntime-web@1.15.1/dist/${file}`])
);
|
||||
Reference in New Issue
Block a user