diff --git a/bin/tickettagger.js b/bin/tickettagger.js index b5a3e08f..d63bfe9a 100644 --- a/bin/tickettagger.js +++ b/bin/tickettagger.js @@ -42,10 +42,14 @@ fs.mkdirSync(DATASET_DIR, { recursive: true }); const datasetManager = new DatasetManager({ DATASET_DIR }); const labels = ["__label__bug", "__label__enhancement", "__label__question"]; const datasetTable = { - balanced: - "https://gist.githubusercontent.com/rafaelkallis/6aa281b00d73d77fc843bd34f8184854/raw/8c10ebf2fd6f937f8667c660ea33d122bac739eb/issues.txt", - unbalanced: + ['30k']: "https://tickettagger.blob.core.windows.net/datasets/github-labels-top3-30493-real.csv", + ['127k']: + "https://tickettagger.blob.core.windows.net/datasets/github-labels-top3-real-127k.txt", + ['397k']: + "https://tickettagger.blob.core.windows.net/datasets/github-labels-top3-real-397k.txt", + ['30k-balanced']: + "https://gist.githubusercontent.com/rafaelkallis/6aa281b00d73d77fc843bd34f8184854/raw/8c10ebf2fd6f937f8667c660ea33d122bac739eb/issues.txt", english: "https://gist.githubusercontent.com/rafaelkallis/6aa281b00d73d77fc843bd34f8184854/raw/8c10ebf2fd6f937f8667c660ea33d122bac739eb/issues_english.txt", ["english:baseline"]: @@ -112,10 +116,13 @@ const filterHyperparameters = (opts) => ) ); +console.log(chalk.magenta(`tickettagger, Copyright (C) ${new Date().getFullYear()} Rafael Kallis, GPL-v3 license\n`)) + yargs(process.argv.slice(2)) .scriptName("tickettagger") + .usage("$0 ") .command({ - command: "benchmark ", + command: `${chalk.magenta(benchmark)} `, description: "Run benchmarks on Ticket-Tagger.", builder: (yargs) => yargs @@ -191,6 +198,37 @@ yargs(process.argv.slice(2)) handler: crossHandler, }), }) + .command({ + command: `${chalk.magenta(train)} `, + description: "Train a model.", + builder: (yargs) => + withHyperparameterOptions(yargs) + .positional( + "dataset", + datasetOption({ + description: + "The dataset (key or URL) to train the model with.", + }) + ) + .option("force", { + type: "boolean", + default: false, + description: + "Force a new download even if the data is present locally.", + }) + .example([ + [ + "$0 train 127k result", + "Train a model using the 127k dataset and output to 'result.bin'.", + ], + ]), + handler: trainHandler, + }) + .command({ + command: "clean", + description: "Clean the dataset + model cache.", + handler: cleanHandler, + }) .demandCommand() .help() .parse(); @@ -214,14 +252,16 @@ async function trivialHandler({ ) ); } - const modelPath = path.join(MODEL_DIR, `${trainingset.id}.bin`); + const modelPath = path.join(MODEL_DIR, `${trainingset.id}`); await classifier.train("supervised", { input: trainingset.path, output: modelPath, ...filterHyperparameters(opts), }); await classifier.loadModel(modelPath); - const { actual, predicted } = await evaluate(testset.path, classifier); + const actual = []; + const predicted = []; + await evaluateInline(testset.path, classifier, actual, predicted); printStats({ actual, predicted }); } @@ -240,40 +280,66 @@ async function crossHandler({ dataset: datasetUri, folds, force, ...opts }) { run, force, }); - const modelPath = path.join(MODEL_DIR, `${id}.bin`); + const modelPath = path.join(MODEL_DIR, `${id}`); await classifier.train("supervised", { input: trainPath, output: modelPath, ...filterHyperparameters(opts), }); await classifier.loadModel(modelPath); - const { actual: runActual, predicted: runPredicted } = await evaluate( - testPath, - classifier - ); - actual.push(...runActual); - predicted.push(...runPredicted); + await evaluateInline(testPath, classifier, actual, predicted); console.log(chalk.magenta(`run ${run + 1}/${folds} finished`)); } printStats({ actual, predicted }); } -async function evaluate(datasetPath, classifier) { +/** + * Train a model. + */ +async function trainHandler({ dataset: datasetUri, model: modelPath, force, ...opts }) { + const dataset = await datasetManager.fetch(datasetUri, force); + const classifier = new Classifier(); + await classifier.train("supervised", { + input: dataset.path, + output: modelPath, + ...filterHyperparameters(opts), + }); +} + +function cleanHandler({}) { + for (const datasetPath of fs.readdirSync(DATASET_DIR)) { + fs.unlinkSync(path.join(DATASET_DIR, datasetPath)); + } + for (const modelPath of fs.readdirSync(MODEL_DIR)) { + fs.unlinkSync(path.join(MODEL_DIR, modelPath)); + } +} + +async function* evaluateIter(datasetPath, classifier) { const lines = readline.createInterface({ input: fs.createReadStream(datasetPath), }); - const actualList = []; - const predictedList = []; for await (const line of lines) { + if (!/__label__[a-zA-Z0-9]+/.test(line)) { + console.warn(chalk.yellow("found line with no label, skipping line")); + continue; + } const [actual] = line.match(/__label__[a-zA-Z0-9]+/); const text = line.substring(actual.length); - const [prediction = { label: null }] = await classifier.predict(text, 1); - actualList.push(actual); - predictedList.push(prediction.label); + const [predictionResult = { label: null }] = await classifier.predict(text, 1); + const predicted = predictionResult.label; + yield { actual, predicted }; } - return { actual: actualList, predicted: predictedList }; } +async function evaluateInline(datasetPath, classifier, actual, predicted) { + for await (const recordResult of evaluateIter(datasetPath, classifier)) { + const { actual: recordActual, predicted: recordPredicted } = recordResult; + actual.push(recordActual); + predicted.push(recordPredicted); + } +} + function printStats({ actual, predicted }) { const cm = ConfusionMatrix.fromLabels(actual, predicted); console.log(chalk.bgMagenta(" stats ")); diff --git a/package-lock.json b/package-lock.json index 51dd4f70..9cc5b3b9 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,6 +1,6 @@ { "name": "tickettagger", - "version": "2.1.0", + "version": "2.1.1", "lockfileVersion": 1, "requires": true, "dependencies": { @@ -4361,7 +4361,7 @@ }, "strip-ansi": { "version": "3.0.1", - "resolved": "http://registry.npmjs.org/strip-ansi/-/strip-ansi-3.0.1.tgz", + "resolved": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-3.0.1.tgz", "integrity": "sha1-ajhfuIU9lS1f8F0Oiq+UJ43GPc8=", "requires": { "ansi-regex": "^2.0.0" @@ -7152,7 +7152,7 @@ }, "readable-stream": { "version": "2.3.6", - "resolved": "http://registry.npmjs.org/readable-stream/-/readable-stream-2.3.6.tgz", + "resolved": "https://registry.npmjs.org/readable-stream/-/readable-stream-2.3.6.tgz", "integrity": "sha512-tQtKA9WIAhBF3+VLAseyMqZeBjW0AHJoxOtYqSUZNJxauErmLbVm2FW1y+J/YA9dUrAC39ITejlZWhVIwawkKw==", "requires": { "core-util-is": "~1.0.0", diff --git a/package.json b/package.json index 1338d567..4ab72931 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "tickettagger", - "version": "2.1.1", + "version": "2.1.2", "description": "Machine learning driven issue classification bot.", "license": "GPL-3.0", "repository": {