Skip to content

Commit

Permalink
feat: train and clean commands, 127k dataset, 397k dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
rafaelkallis committed Nov 6, 2020
1 parent 70a404d commit c7a7508
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 24 deletions.
106 changes: 86 additions & 20 deletions bin/tickettagger.js
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,14 @@ fs.mkdirSync(DATASET_DIR, { recursive: true });
const datasetManager = new DatasetManager({ DATASET_DIR });
const labels = ["__label__bug", "__label__enhancement", "__label__question"];
const datasetTable = {
balanced:
"https://gist.githubusercontent.com/rafaelkallis/6aa281b00d73d77fc843bd34f8184854/raw/8c10ebf2fd6f937f8667c660ea33d122bac739eb/issues.txt",
unbalanced:
['30k']:
"https://tickettagger.blob.core.windows.net/datasets/github-labels-top3-30493-real.csv",
['127k']:
"https://tickettagger.blob.core.windows.net/datasets/github-labels-top3-real-127k.txt",
['397k']:
"https://tickettagger.blob.core.windows.net/datasets/github-labels-top3-real-397k.txt",
['30k-balanced']:
"https://gist.githubusercontent.com/rafaelkallis/6aa281b00d73d77fc843bd34f8184854/raw/8c10ebf2fd6f937f8667c660ea33d122bac739eb/issues.txt",
english:
"https://gist.githubusercontent.com/rafaelkallis/6aa281b00d73d77fc843bd34f8184854/raw/8c10ebf2fd6f937f8667c660ea33d122bac739eb/issues_english.txt",
["english:baseline"]:
Expand Down Expand Up @@ -112,10 +116,13 @@ const filterHyperparameters = (opts) =>
)
);

console.log(chalk.magenta(`tickettagger, Copyright (C) ${new Date().getFullYear()} Rafael Kallis, GPL-v3 license\n`))

yargs(process.argv.slice(2))
.scriptName("tickettagger")
.usage("$0 <command>")
.command({
command: "benchmark <mode>",
command: `${chalk.magenta(benchmark)} <mode>`,
description: "Run benchmarks on Ticket-Tagger.",
builder: (yargs) =>
yargs
Expand Down Expand Up @@ -191,6 +198,37 @@ yargs(process.argv.slice(2))
handler: crossHandler,
}),
})
.command({
command: `${chalk.magenta(train)} <dataset> <model>`,
description: "Train a model.",
builder: (yargs) =>
withHyperparameterOptions(yargs)
.positional(
"dataset",
datasetOption({
description:
"The dataset (key or URL) to train the model with.",
})
)
.option("force", {
type: "boolean",
default: false,
description:
"Force a new download even if the data is present locally.",
})
.example([
[
"$0 train 127k result",
"Train a model using the 127k dataset and output to 'result.bin'.",
],
]),
handler: trainHandler,
})
.command({
command: "clean",
description: "Clean the dataset + model cache.",
handler: cleanHandler,
})
.demandCommand()
.help()
.parse();
Expand All @@ -214,14 +252,16 @@ async function trivialHandler({
)
);
}
const modelPath = path.join(MODEL_DIR, `${trainingset.id}.bin`);
const modelPath = path.join(MODEL_DIR, `${trainingset.id}`);
await classifier.train("supervised", {
input: trainingset.path,
output: modelPath,
...filterHyperparameters(opts),
});
await classifier.loadModel(modelPath);
const { actual, predicted } = await evaluate(testset.path, classifier);
const actual = [];
const predicted = [];
await evaluateInline(testset.path, classifier, actual, predicted);
printStats({ actual, predicted });
}

Expand All @@ -240,40 +280,66 @@ async function crossHandler({ dataset: datasetUri, folds, force, ...opts }) {
run,
force,
});
const modelPath = path.join(MODEL_DIR, `${id}.bin`);
const modelPath = path.join(MODEL_DIR, `${id}`);
await classifier.train("supervised", {
input: trainPath,
output: modelPath,
...filterHyperparameters(opts),
});
await classifier.loadModel(modelPath);
const { actual: runActual, predicted: runPredicted } = await evaluate(
testPath,
classifier
);
actual.push(...runActual);
predicted.push(...runPredicted);
await evaluateInline(testPath, classifier, actual, predicted);
console.log(chalk.magenta(`run ${run + 1}/${folds} finished`));
}
printStats({ actual, predicted });
}

async function evaluate(datasetPath, classifier) {
/**
 * Train a model.
 *
 * Fetches the requested dataset through the dataset manager and runs a
 * supervised training pass over it.
 *
 * @param {object} argv - Parsed command arguments.
 * @param {string} argv.dataset - Dataset key or URL handed to the dataset manager.
 * @param {string} argv.model - Output location handed to the classifier.
 * @param {boolean} argv.force - Forwarded to the dataset manager's fetch call.
 * @returns {Promise<void>} Settles once training has finished.
 */
async function trainHandler(argv) {
  const { dataset, model, force, ...remainingArgs } = argv;

  const trainingSet = await datasetManager.fetch(dataset, force);
  const trainer = new Classifier();

  // Hyperparameters are spread last so they are applied on top of the
  // input/output entries, exactly as the option object was built before.
  const trainingOptions = {
    input: trainingSet.path,
    output: model,
    ...filterHyperparameters(remainingArgs),
  };
  await trainer.train("supervised", trainingOptions);
}

/**
 * Delete every entry located directly inside a cache directory.
 *
 * A directory that does not exist is treated as already clean; any other
 * file system error is rethrown unchanged.
 *
 * @param {string} cacheDir - Directory whose entries are removed.
 */
function removeCachedFiles(cacheDir) {
  let entries;
  try {
    entries = fs.readdirSync(cacheDir);
  } catch (err) {
    if (err.code === "ENOENT") {
      return;
    }
    throw err;
  }
  for (const entry of entries) {
    fs.unlinkSync(path.join(cacheDir, entry));
  }
}

/**
 * Handler for the `clean` command: empties the dataset cache and the model
 * cache.
 *
 * Takes no parameters, so it can be invoked with or without the parsed
 * argument object.
 */
function cleanHandler() {
  removeCachedFiles(DATASET_DIR);
  removeCachedFiles(MODEL_DIR);
}

/**
 * Lazily evaluate a classifier against a labelled dataset file.
 *
 * Reads the file line by line and, for every line containing a
 * `__label__<name>` token, asks the classifier for its single best prediction
 * and yields the pair. Lines without a label are reported with a warning and
 * skipped. Records are yielded one at a time and nothing is accumulated here,
 * so the caller decides what to keep.
 *
 * @param {string} datasetPath - Path of the dataset file to read.
 * @param {{ predict: Function }} classifier - Object whose
 *   `predict(text, 1)` resolves to an array of `{ label }` results.
 * @yields {{ actual: string, predicted: (string|null) }} The label found on
 *   the line and the predicted label, or `null` when the classifier returned
 *   no prediction.
 */
async function* evaluateIter(datasetPath, classifier) {
  const labelPattern = /__label__[a-zA-Z0-9]+/;
  const lines = readline.createInterface({
    input: fs.createReadStream(datasetPath),
  });
  for await (const line of lines) {
    // One match serves both as the "has a label" check and the extraction.
    const labelMatch = line.match(labelPattern);
    if (labelMatch === null) {
      console.warn(chalk.yellow("found line with no label, skipping line"));
      continue;
    }
    const [actual] = labelMatch;
    // Drops as many leading characters as the label is long.
    // NOTE(review): this presumes the label is the first token on the line.
    const text = line.substring(actual.length);
    // Exactly one prediction request per record; an empty result maps to null.
    const [predictionResult = { label: null }] = await classifier.predict(text, 1);
    yield { actual, predicted: predictionResult.label };
  }
}

/**
 * Run an evaluation and append its results to caller-owned arrays.
 *
 * Each record produced by `evaluateIter` contributes one entry to `actual`
 * and one entry to `predicted`, in the order the records are produced.
 *
 * @param {string} datasetPath - Forwarded to `evaluateIter`.
 * @param {object} classifier - Forwarded to `evaluateIter`.
 * @param {Array} actual - Receives the `actual` value of every record.
 * @param {Array} predicted - Receives the `predicted` value of every record.
 * @returns {Promise<void>} Settles once every record has been appended.
 */
async function evaluateInline(datasetPath, classifier, actual, predicted) {
  const records = evaluateIter(datasetPath, classifier);
  for await (const { actual: expectedLabel, predicted: guessedLabel } of records) {
    actual.push(expectedLabel);
    predicted.push(guessedLabel);
  }
}

function printStats({ actual, predicted }) {
const cm = ConfusionMatrix.fromLabels(actual, predicted);
console.log(chalk.bgMagenta(" stats "));
Expand Down
6 changes: 3 additions & 3 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "tickettagger",
"version": "2.1.1",
"version": "2.1.2",
"description": "Machine learning driven issue classification bot.",
"license": "GPL-3.0",
"repository": {
Expand Down

0 comments on commit c7a7508

Please sign in to comment.