This commit is contained in:
simon987
2023-07-24 19:36:20 -04:00
parent f56cfb0f2f
commit 27188b6fa0
29 changed files with 1008 additions and 75 deletions

View File

@@ -0,0 +1,118 @@
// Fallback rank for symbol pairs that have no entry in the merge table, so
// that they compare as worse than any ranked pair.
const inf = Number.POSITIVE_INFINITY;
// Token id that encode() puts at the front of every encoded sequence.
const START_TOK = 49406;
// Token id that encode() appends after the (possibly truncated) token ids.
const END_TOK = 49407;
/**
 * Folds `array` down to a single "best" element.
 *
 * @param {Array} array - Candidates; must not be empty (reduce throws otherwise).
 * @param {function(*, *): boolean} key - Called as key(current, candidate);
 *     a truthy result means the candidate replaces the current best.
 * @returns {*} The element left standing after all comparisons.
 */
function min(array, key) {
    const keepBetter = (current, candidate) => {
        if (key(current, candidate)) {
            return candidate;
        }
        return current;
    };
    return array.reduce(keepBetter);
}
/**
 * A Set of string tuples. Each tuple is stored as its elements joined with a
 * backtick, so two arrays with the same contents count as the same member.
 */
class TupleSet extends Set {
    /** Inserts the tuple `elem` (an array) and returns this set. */
    add(elem) {
        const joined = elem.join("`");
        return super.add(joined);
    }

    /** Reports whether a tuple with the same contents as `elem` was added. */
    has(elem) {
        const joined = elem.join("`");
        return super.has(joined);
    }

    /** Returns the stored tuples, in insertion order, as arrays of strings. */
    toList() {
        return Array.from(this, joined => joined.split("`"));
    }
}
/**
 * Byte-pair-encoding tokenizer that turns text into a fixed-length sequence
 * of 77 token ids.
 *
 * `encoder` maps a BPE symbol (for example "ab</w>") to its integer id.
 * `bpeRanks` maps a merge pair, written as the two symbols joined with a
 * backtick, to its merge rank; the pair with the lowest rank is merged first.
 */
export class BPETokenizer {
    _encoder = null;
    _bpeRanks = null;

    /**
     * @param {Object<string, number>} encoder - Symbol to token id table.
     * @param {Object<string, number>} bpeRanks - Merge pair to rank table.
     */
    constructor(encoder, bpeRanks) {
        this._encoder = encoder;
        this._bpeRanks = bpeRanks;
    }

    /**
     * Lists the distinct adjacent symbol pairs of `word`, in order of first
     * appearance.
     *
     * @param {string[]} word - Sequence of symbols.
     * @returns {string[][]} Unique [left, right] pairs.
     */
    getPairs(word) {
        const pairs = new TupleSet();
        let prevChar = word[0];
        for (let i = 1; i < word.length; i++) {
            pairs.add([prevChar, word[i]]);
            prevChar = word[i];
        }
        return pairs.toList();
    }

    /**
     * Applies the ranked merges to a single token.
     *
     * @param {string} token - One token as matched by encode().
     * @returns {string} The resulting symbols joined by single spaces; the
     *     last symbol carries the "</w>" end-of-word marker.
     */
    bpe(token) {
        let word = [...token];
        word[word.length - 1] += "</w>";
        let pairs = this.getPairs(word);
        if (pairs.length === 0) {
            return `${token}</w>`;
        }

        // Pairs without a rank compare as +Infinity so they are never
        // preferred over a ranked pair.
        const rankOf = pair => this._bpeRanks[pair.join("`")] ?? inf;

        while (true) {
            const bigram = min(pairs, (a, b) => rankOf(a) > rankOf(b));
            if (this._bpeRanks[bigram.join("`")] === undefined) {
                // Even the best remaining pair has no merge rule: done.
                break;
            }
            const [first, second] = bigram;
            const newWord = [];
            let i = 0;
            while (i < word.length) {
                const j = word.indexOf(first, i);
                if (j === -1) {
                    // No further occurrence of `first`: copy the tail as is.
                    newWord.push(...word.slice(i));
                    break;
                }
                newWord.push(...word.slice(i, j));
                i = j;
                if (word[i] === first && i < word.length - 1 && word[i + 1] === second) {
                    // Replace the adjacent (first, second) with the merged symbol.
                    newWord.push(first + second);
                    i += 2;
                } else {
                    newWord.push(word[i]);
                    i += 1;
                }
            }
            word = newWord;
            if (word.length === 1) {
                break;
            }
            pairs = this.getPairs(word);
        }
        return word.join(" ");
    }

    /**
     * Tokenizes `text` into exactly 77 ids: START_TOK, at most 75 token ids,
     * END_TOK, then zero padding.
     *
     * Text in which the token pattern matches nothing (for example an empty
     * string) yields just the start/end markers plus padding.
     *
     * NOTE(review): a symbol that is missing from the encoder table produces
     * an `undefined` entry in the result; confirm the table covers every
     * symbol bpe() can emit for the expected inputs.
     *
     * @param {string} text - Input text.
     * @returns {number[]} Array of length 77.
     */
    encode(text) {
        let bpeTokens = [];
        text = text.trim();
        text = text.replaceAll(/\s+/g, " ");

        // String.prototype.match returns null (not []) when nothing matches.
        const tokens = text
            .match(/<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[a-zA-Z0-9]+/ig) ?? [];
        for (const token of tokens) {
            bpeTokens.push(...this.bpe(token).split(" ").map(t => this._encoder[t]));
        }

        bpeTokens.unshift(START_TOK);
        // Leave room for the end marker within the 77-id window.
        bpeTokens = bpeTokens.slice(0, 76);
        bpeTokens.push(END_TOK);
        while (bpeTokens.length < 77) {
            bpeTokens.push(0);
        }
        return bpeTokens;
    }
}

View File

@@ -1,6 +1,8 @@
import BertTokenizer from "@/ml/BertTokenizer";
import * as tf from "@tensorflow/tfjs";
import axios from "axios";
import {chunk as _chunk} from "underscore";
import * as ort from "onnxruntime-web";
import {argMax, downloadToBuffer, ORT_WASM_PATHS} from "@/ml/mlUtils";
export default class BertNerModel {
vocabUrl;
@@ -29,7 +31,10 @@ export default class BertNerModel {
}
async loadModel(onProgress) {
this._model = await tf.loadGraphModel(this.modelUrl, {onProgress});
ort.env.wasm.wasmPaths = ORT_WASM_PATHS;
const buf = await downloadToBuffer(this.modelUrl, onProgress);
this._model = await ort.InferenceSession.create(buf.buffer, {executionProviders: ["wasm"]});
}
alignLabels(labels, wordIds, words) {
@@ -57,21 +62,28 @@ export default class BertNerModel {
async predict(text, callback) {
this._previousWordId = null;
const encoded = this._tokenizer.encodeText(text, this.inputSize)
const encoded = this._tokenizer.encodeText(text, this.inputSize);
let i = 0;
for (let chunk of encoded.inputChunks) {
const rawResult = tf.tidy(() => this._model.execute({
input_ids: tf.tensor2d(chunk.inputIds, [1, this.inputSize], "int32"),
token_type_ids: tf.tensor2d(chunk.segmentIds, [1, this.inputSize], "int32"),
attention_mask: tf.tensor2d(chunk.inputMask, [1, this.inputSize], "int32"),
}));
const labelIds = await tf.argMax(rawResult, -1);
const labelIdsArray = await labelIds.array();
const labels = labelIdsArray[0].map(id => this.id2label[id]);
rawResult.dispose()
const results = await this._model.run({
input_ids: new ort.Tensor("int32", chunk.inputIds, [1, this.inputSize]),
token_type_ids: new ort.Tensor("int32", chunk.segmentIds, [1, this.inputSize]),
attention_mask: new ort.Tensor("int32", chunk.inputMask, [1, this.inputSize]),
});
callback(this.alignLabels(labels, chunk.wordIds, encoded.words))
const labelIds = _chunk(results["output"].data, this.id2label.length).map(argMax);
const labels = labelIds.map(id => this.id2label[id]);
callback(this.alignLabels(labels, chunk.wordIds, encoded.words));
i += 1;
// give browser some time to repaint
if (i % 2 === 0) {
await new Promise(resolve => setTimeout(resolve, 0));
}
}
}
}

View File

@@ -1,4 +1,5 @@
import {zip, chunk} from "underscore";
import {toInt64} from "@/ml/mlUtils";
const UNK_INDEX = 100;
const CLS_INDEX = 101;

View File

@@ -0,0 +1,48 @@
import * as ort from "onnxruntime-web";
import {BPETokenizer} from "@/ml/BPETokenizer";
import axios from "axios";
import {downloadToBuffer, ORT_WASM_PATHS} from "@/ml/mlUtils";
/**
 * Text model wrapper: downloads a tokenizer definition and an ONNX model,
 * then maps a text to the model output that has 512 elements.
 */
export class CLIPTransformerModel {
    _modelUrl = null;
    _tokenizerUrl = null;
    _model = null;
    _tokenizer = null;

    /**
     * @param {string} modelUrl - Location of the ONNX model file.
     * @param {string} tokenizerUrl - Location of the tokenizer JSON, which
     *     provides the `encoder` and `bpe_ranks` tables.
     */
    constructor(modelUrl, tokenizerUrl) {
        this._modelUrl = modelUrl;
        this._tokenizerUrl = tokenizerUrl;
    }

    /**
     * Loads tokenizer and model concurrently.
     *
     * @param {function(number): void} [onProgress] - Forwarded to the model download.
     */
    async init(onProgress) {
        const tokenizerReady = this.loadTokenizer();
        const modelReady = this.loadModel(onProgress);
        await Promise.all([tokenizerReady, modelReady]);
    }

    /** Downloads the model and creates a WASM-backed inference session. */
    async loadModel(onProgress) {
        ort.env.wasm.wasmPaths = ORT_WASM_PATHS;

        const modelBytes = await downloadToBuffer(this._modelUrl, onProgress);
        const sessionOptions = {executionProviders: ["wasm"]};
        this._model = await ort.InferenceSession.create(modelBytes.buffer, sessionOptions);
    }

    /** Fetches the tokenizer tables and builds the BPE tokenizer from them. */
    async loadTokenizer() {
        const {data} = await axios.get(this._tokenizerUrl);
        this._tokenizer = new BPETokenizer(data.encoder, data.bpe_ranks);
    }

    /**
     * Runs the model on `text`.
     *
     * @param {string} text - Text to encode.
     * @returns {Promise<Array>} Data of the first model output whose size is
     *     512, copied into a plain array.
     */
    async predict(text) {
        const tokenIds = this._tokenizer.encode(text);
        const inputIds = new ort.Tensor("int32", tokenIds, [1, 77]);

        const outputs = await this._model.run({input_ids: inputIds});

        const embedding = Object.values(outputs).find(output => output.size === 512);
        return Array.from(embedding.data);
    }
}

View File

@@ -0,0 +1,47 @@
/**
 * Downloads `url` into a single byte buffer while reporting progress.
 *
 * The buffer is pre-sized from the Content-Length header when it is present,
 * but the header is only treated as a hint: if it is missing, zero, or
 * smaller than the body actually received (e.g. a chunked or compressed
 * response), the buffer grows as needed. The returned array is always
 * exactly as long as the number of bytes read, so its `.buffer` can be
 * handed on directly.
 *
 * @param {string} url - Resource to fetch.
 * @param {function(number): void} [onProgress] - Called with 0 at the start,
 *     then with the fraction downloaded (capped at 1) after each chunk when
 *     the total size is known, and with 1 once the download has finished.
 * @returns {Promise<Uint8ClampedArray>} The downloaded bytes.
 * @throws {Error} If the server answers with a non-success HTTP status.
 */
export async function downloadToBuffer(url, onProgress) {
    const resp = await fetch(url);
    if (!resp.ok) {
        // Do not hand an error page to the caller as if it were the payload.
        throw new Error(`Failed to download ${url}: HTTP ${resp.status}`);
    }

    const declaredLength = Number(resp.headers.get("Content-Length"));
    const contentLength = Number.isFinite(declaredLength) && declaredLength > 0 ? declaredLength : 0;

    let buf = new Uint8ClampedArray(contentLength);
    const reader = resp.body.getReader();
    let cursor = 0;
    let lastReported = 0;

    if (onProgress) {
        onProgress(0);
    }

    while (true) {
        const {done, value} = await reader.read();
        if (done) {
            break;
        }

        const needed = cursor + value.length;
        if (needed > buf.length) {
            // Content-Length was absent or too small: grow geometrically.
            const grown = new Uint8ClampedArray(Math.max(needed, buf.length * 2));
            grown.set(buf.subarray(0, cursor));
            buf = grown;
        }
        buf.set(value, cursor);
        cursor = needed;

        if (onProgress && contentLength > 0) {
            lastReported = Math.min(cursor / contentLength, 1);
            onProgress(lastReported);
        }
    }

    if (onProgress && lastReported !== 1) {
        onProgress(1);
    }

    // slice() (not subarray()) so that the result owns a right-sized ArrayBuffer.
    return cursor === buf.length ? buf : buf.slice(0, cursor);
}
/**
 * Returns the index of the largest element of `array`.
 *
 * Works on plain arrays and typed arrays alike. When several elements share
 * the maximum, the index of the first one is returned.
 *
 * @param {ArrayLike<number>} array - Non-empty sequence of numbers.
 * @returns {number} Index of the first maximal element.
 * @throws {TypeError} If `array` is empty.
 */
export function argMax(array) {
    if (!(array.length > 0)) {
        throw new TypeError("argMax of empty array with no initial value");
    }
    let bestIndex = 0;
    for (let i = 1; i < array.length; i++) {
        // Strict comparison keeps the earliest index on ties.
        if (array[i] > array[bestIndex]) {
            bestIndex = i;
        }
    }
    return bestIndex;
}
/**
 * Converts an array of integers into a BigInt64Array.
 *
 * @param {number[]} array - Integer values; each one is passed through BigInt().
 * @returns {BigInt64Array} The same values as 64-bit signed integers.
 */
export function toInt64(array) {
    const asBigInts = array.map(BigInt);
    const packed = new BigInt64Array(asBigInts);
    return packed;
}
// Maps each onnxruntime-web WASM binary file name to the URL it is loaded
// from (jsDelivr CDN, onnxruntime-web 1.15.1). Intended to be assigned to
// `ort.env.wasm.wasmPaths` before an inference session is created.
export const ORT_WASM_PATHS = {
"ort-wasm-simd.wasm": "https://cdn.jsdelivr.net/npm/onnxruntime-web@1.15.1/dist/ort-wasm-simd.wasm",
"ort-wasm.wasm": "https://cdn.jsdelivr.net/npm/onnxruntime-web@1.15.1/dist/ort-wasm.wasm",
"ort-wasm-simd-threaded.wasm": "https://cdn.jsdelivr.net/npm/onnxruntime-web@1.15.1/dist/ort-wasm-simd-threaded.wasm",
"ort-wasm-threaded.wasm": "https://cdn.jsdelivr.net/npm/onnxruntime-web@1.15.1/dist/ort-wasm-threaded.wasm",
}