Skip to content

Commit

Permalink
Rename Model to ModelNgram and Models to Model
Browse files Browse the repository at this point in the history
The old Models struct is the main interface for accessing the model, so it
is now named Model. The inner per-order models are renamed to ModelNgram,
because each one stores only a single ngram order.
  • Loading branch information
ZJaume committed Sep 13, 2024
1 parent d270ee0 commit 2c6d366
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 28 deletions.
6 changes: 3 additions & 3 deletions src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use env_logger::Env;
use strum::IntoEnumIterator;
use target;

use crate::languagemodel::{Model, ModelType};
use crate::languagemodel::{ModelNgram, OrderNgram};
use crate::identifier::Identifier;
use crate::utils::Abort;
use crate::python::module_path;
Expand Down Expand Up @@ -46,10 +46,10 @@ impl BinarizeCmd {
let model_path = self.input_dir.unwrap_or(PathBuf::from("./LanguageModels"));
let save_path = self.output_dir.unwrap_or(module_path().unwrap());

for model_type in ModelType::iter() {
for model_type in OrderNgram::iter() {
let type_repr = model_type.to_string();
info!("Loading {type_repr} model");
let model = Model::from_text(&model_path, model_type)
let model = ModelNgram::from_text(&model_path, model_type)
.or_abort(1);
let size = model.dic.len();
info!("Created {size} entries");
Expand Down
14 changes: 7 additions & 7 deletions src/identifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use log::{debug,warn};
use lazy_static::lazy_static;
use rayon::prelude::*;

use crate::languagemodel::{Models};
use crate::languagemodel::Model;
use crate::lang::{Lang, LangScores, LangBitmap};

lazy_static! {
Expand All @@ -20,7 +20,7 @@ lazy_static! {
}

pub struct Identifier {
models: Arc<Models>,
model: Arc<Model>,
lang_scored: LangBitmap,
lang_points: LangScores,
word_scores: LangScores,
Expand All @@ -33,12 +33,12 @@ impl Identifier {
const MAX_NGRAM : usize = 6;

pub fn load(modelpath: &str) -> Result<Self> {
Ok(Self::new(Arc::new(Models::load(modelpath)?)))
Ok(Self::new(Arc::new(Model::load(modelpath)?)))
}

pub fn new(models: Arc<Models>) -> Self {
pub fn new(model: Arc<Model>) -> Self {
Self {
models: models,
model: model,
lang_scored: LangBitmap::new(),
lang_points: LangScores::new(),
word_scores: LangScores::new(),
Expand Down Expand Up @@ -96,7 +96,7 @@ impl Identifier {

/// Update scores according to current ngram probability if found
fn score_gram(&mut self, gram: &str, dic_id: usize) -> bool {
if let Some(kiepro) = self.models[dic_id].dic.get(gram) {
if let Some(kiepro) = self.model[dic_id].dic.get(gram) {
// found the word in language model
// update scores according to each lang that has the word
// use penalty value for langs that don't have the word
Expand Down Expand Up @@ -301,7 +301,7 @@ impl Identifier {
// Only initialize the identifier once
let mut identifier = identifier.lock().unwrap();
if identifier.is_none() {
*identifier = Some(Identifier::new(self.models.clone()));
*identifier = Some(Identifier::new(self.model.clone()));
}
identifier.as_mut().unwrap().identify(&text)
})
Expand Down
36 changes: 18 additions & 18 deletions src/languagemodel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use crate::lang::Lang;
#[derive(bitcode::Encode, bitcode::Decode, EnumIter, Display, EnumCount,
Debug, PartialEq, Clone, Copy)]
#[strum(serialize_all = "lowercase")]
pub enum ModelType {
pub enum OrderNgram {
Word,
Unigram,
Bigram,
Expand All @@ -32,21 +32,21 @@ pub enum ModelType {


#[derive(bitcode::Encode, bitcode::Decode, Debug, PartialEq)]
pub struct Model {
pub struct ModelNgram {
pub dic: HashMap<String, Vec<(Lang, f32)>, MyHasher>,
pub model_type: ModelType,
pub model_type: OrderNgram,
}

impl Model {
impl ModelNgram {
// The following values are the ones used in Jauhiainen et al. 2017.
pub const MAX_USED : f64 = 0.0000005;

pub fn contains(&self, key: &str) -> bool {
self.dic.contains_key(key)
}

pub fn from_text(model_dir: &Path, model_type: ModelType) -> Result<Self> {
let mut model = Model {
pub fn from_text(model_dir: &Path, model_type: OrderNgram) -> Result<Self> {
let mut model = ModelNgram {
dic: HashMap::default(),
model_type: model_type.clone()
};
Expand Down Expand Up @@ -155,15 +155,15 @@ impl Model {
}
}

pub struct Models {
inner: [Model; ModelType::COUNT],
pub struct Model {
inner: [ModelNgram; OrderNgram::COUNT],
}

impl Models {
impl Model {
pub fn load(modelpath: &str) -> Result<Self> {
// Run a separated thread to load each model
let mut handles: Vec<thread::JoinHandle<_>> = Vec::new();
for model_type in ModelType::iter() {
for model_type in OrderNgram::iter() {
let type_repr = model_type.to_string();
let filename = format!("{modelpath}/{type_repr}.bin");

Expand All @@ -178,10 +178,10 @@ impl Models {
return Err(io::Error::new(io::ErrorKind::NotFound, message).into());
}
handles.push(thread::spawn(move || {
let model = Model::from_bin(&filename)?;
let model = ModelNgram::from_bin(&filename)?;
// check model type is correct
assert!(model.model_type == model_type);
Ok::<Model, anyhow::Error>(model)
Ok::<ModelNgram, anyhow::Error>(model)
}));
}

Expand All @@ -201,8 +201,8 @@ impl Models {
}

// to avoid calling inner value
impl Index<usize> for Models {
type Output = Model;
impl Index<usize> for Model {
type Output = ModelNgram;

fn index(&self, num: usize) -> &Self::Output {
&self.inner[num]
Expand All @@ -220,22 +220,22 @@ mod tests {
#[test]
fn test_langs() {
let modelpath = Path::new("./LanguageModels");
let wordmodel = Model::from_text(&modelpath, ModelType::Word);
let wordmodel = ModelNgram::from_text(&modelpath, OrderNgram::Word);
let path = Path::new("wordict.ser");
wordmodel.save(path);

let charmodel = Model::from_text(&modelpath, ModelType::Quadgram);
let charmodel = ModelNgram::from_text(&modelpath, OrderNgram::Quadgram);
let path = Path::new("gramdict.ser");
charmodel.save(path);

let char_handle = thread::spawn(move || {
let path = Path::new("gramdict.ser");
Model::from_bin(path)
ModelNgram::from_bin(path)
});

let word_handle = thread::spawn(move || {
let path = Path::new("wordict.ser");
Model::from_bin(path)
ModelNgram::from_bin(path)
});

// let word_model = word_handle.join().unwrap();
Expand Down

0 comments on commit 2c6d366

Please sign in to comment.