Skip to content

Commit

Permalink
feat: more f1 averages
Browse files Browse the repository at this point in the history
  • Loading branch information
rafaelkallis committed Apr 5, 2021
1 parent c779bb4 commit 46f5bc8
Showing 1 changed file with 26 additions and 3 deletions.
29 changes: 26 additions & 3 deletions bin/tickettagger.js
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ const chalk = require("chalk");
const ConfusionMatrix = require("ml-confusion-matrix");
const { DatasetManager } = require("../src/dataset-manager");

const MODEL_DIR = path.join(os.tmpdir(), "ticket-tagger", "models");
const DATASET_DIR = path.join(os.tmpdir(), "ticket-tagger", "datasets");
const MODEL_DIR = path.join(os.homedir(), ".tickettagger/models");
const DATASET_DIR = path.join(os.homedir(), ".tickettagger/datasets");

fs.mkdirSync(MODEL_DIR, { recursive: true });
fs.mkdirSync(DATASET_DIR, { recursive: true });
Expand All @@ -44,12 +44,16 @@ const labels = ["__label__bug", "__label__enhancement", "__label__question"];
const datasetTable = {
["30k"]:
"https://tickettagger.blob.core.windows.net/datasets/github-labels-top3-30493-real.csv",
["unbalanced"]:
"https://tickettagger.blob.core.windows.net/datasets/github-labels-top3-30493-real.csv",
["127k"]:
"https://tickettagger.blob.core.windows.net/datasets/github-labels-top3-real-127k.txt",
["397k"]:
"https://tickettagger.blob.core.windows.net/datasets/github-labels-top3-real-397k.txt",
["30k-balanced"]:
"https://gist.githubusercontent.com/rafaelkallis/6aa281b00d73d77fc843bd34f8184854/raw/8c10ebf2fd6f937f8667c660ea33d122bac739eb/issues.txt",
["balanced"]:
"https://gist.githubusercontent.com/rafaelkallis/6aa281b00d73d77fc843bd34f8184854/raw/8c10ebf2fd6f937f8667c660ea33d122bac739eb/issues.txt",
english:
"https://gist.githubusercontent.com/rafaelkallis/6aa281b00d73d77fc843bd34f8184854/raw/8c10ebf2fd6f937f8667c660ea33d122bac739eb/issues_english.txt",
["english:baseline"]:
Expand Down Expand Up @@ -353,7 +357,22 @@ async function evaluateInline(datasetPath, classifier, actual, predicted) {
function printStats({ actual, predicted }) {
const cm = ConfusionMatrix.fromLabels(actual, predicted);
console.log(chalk.bgMagenta(" stats "));
console.log(chalk.magenta("accuracy: "), cm.getAccuracy().toFixed(3));
const weights = Object.fromEntries(
labels.map((l) => [l, actual.filter((a) => a === l).length / actual.length])
);
console.log(
chalk.magenta("f1 weighted: "),
sum(labels.map((l) => cm.getF1Score(l) * weights[l])).toFixed(3)
);
const TP = sum(labels.map((l) => cm.getTruePositiveCount(l)));
const FP = sum(labels.map((l) => cm.getFalsePositiveCount(l)));
const FN = sum(labels.map((l) => cm.getFalseNegativeCount(l)));
const microF1 = (2 * TP) / (2 * TP + FP + FN);
console.log(chalk.magenta("f1 micro: "), microF1.toFixed(3));
console.log(
chalk.magenta("f1 macro: "),
sum(labels.map((l) => cm.getF1Score(l) / labels.length)).toFixed(3)
);

for (const label of labels) {
console.log(chalk.bgMagenta(` ${label.substring(9)} `));
Expand All @@ -368,3 +387,7 @@ function printStats({ actual, predicted }) {
console.log(chalk.magenta("f1 score: "), cm.getF1Score(label).toFixed(3));
}
}

function sum(values) {
return values.reduce((acc, next) => acc + next, 0);
}

0 comments on commit 46f5bc8

Please sign in to comment.