diff --git a/.appveyor.yml b/.appveyor.yml new file mode 100644 index 0000000..caa3834 --- /dev/null +++ b/.appveyor.yml @@ -0,0 +1,56 @@ +environment: + matrix: + - PYTHON: "C:\\Python27-x64" + TARGET: x86_64-pc-windows-msvc + - PYTHON: "C:\\Python36-x64" + TARGET: x86_64-pc-windows-msvc + - PYTHON: "C:\\Python37-x64" + TARGET: x86_64-pc-windows-msvc + +branches: + only: + - develop + - master + +install: + - if "%APPVEYOR_REPO_BRANCH%" == "develop" if NOT "%APPVEYOR_PULL_REQUEST_HEAD_REPO_BRANCH:~0,8%" == "release/" if NOT "%APPVEYOR_PULL_REQUEST_HEAD_REPO_BRANCH:~0,5%" == "main/" appveyor exit + - appveyor-retry appveyor DownloadFile https://win.rustup.rs/ -FileName rustup-init.exe + - rustup-init.exe -y --default-host %TARGET% + - set PATH=%PATH%;C:\Users\appveyor\.cargo\bin + - if defined MSYS2_BITS set PATH=%PATH%;C:\msys64\mingw%MSYS2_BITS%\bin + - rustc -V + - cargo -V + - ps: (Get-Content python/ffi/Cargo.toml) | ForEach-Object { $_ -replace "^snips-nlu-parsers-ffi-macros = .*$", "snips-nlu-parsers-ffi-macros = { path = `"../../ffi/ffi-macros`" }" } | Set-Content python/ffi/Cargo.toml + - "%PYTHON%\\python.exe -m pip install -r python/requirements.txt" + +build: false + +test_script: + - cargo build --verbose + - cargo test --all --verbose + - cd python + - "%PYTHON%\\python.exe -m pip install -e . --verbose --install-option=\"--verbose\"" + - "%PYTHON%\\python.exe -m unittest discover" + +after_test: + - ECHO "BUILDING WHEELS..." + - "%PYTHON%\\python.exe setup.py bdist_wheel" + +artifacts: + - path: python\dist\* + name: pypiartifacts + +for: +- + branches: + only: + - master + + environment: + matrix: + - PYTHON: "C:\\Python27" + TARGET: i686-pc-windows-msvc + - PYTHON: "C:\\Python36" + TARGET: i686-pc-windows-msvc + - PYTHON: "C:\\Python37" + TARGET: i686-pc-windows-msvc diff --git a/.gitignore b/.gitignore index 088ba6b..6fb81ff 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,8 @@ # Generated by Cargo # will have compiled files and executables /target/ +*/target +*/**/target/ # Remove Cargo.lock from gitignore if creating an executable, leave it for libraries # More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html @@ -8,3 +10,6 @@ Cargo.lock # These are backup files generated by rustfmt **/*.rs.bk + +# Intellij +/.idea diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 0000000..ba6eb54 --- /dev/null +++ b/.travis.yml @@ -0,0 +1,54 @@ +language: python +conditions: v1 + +matrix: + include: + - if: branch = master + os: osx + osx_image: xcode8 + language: generic + env: + - TOXENV=py27 + - if: branch = master + os: osx + osx_image: xcode8 + language: generic + env: + - TOXENV=py36 + - if: branch = master + os: osx + osx_image: xcode8 + language: generic + env: + - TOXENV=py37 + - if: branch = master + os: linux + python: 2.7 + env: + - TOXENV=py27 + - os: linux + python: 3.6 + env: + - TOXENV=py36 + - if: branch = master + os: linux + python: 3.7 + # cf https://github.com/travis-ci/travis-ci/issues/9815 + dist: xenial + sudo: true + env: + - TOXENV=py37 + +before_install: . 
./.travis/before_install.sh + +install: + - ./.travis/install.sh + +script: + - ./.travis/test.sh + +before_script: + - echo $TRAVIS_COMMIT + - echo $TRAVIS_TAG + - echo $TRAVIS_BRANCH + - echo $TRAVIS_BUILD_NUMBER diff --git a/.travis/before_install.sh b/.travis/before_install.sh new file mode 100755 index 0000000..508efd7 --- /dev/null +++ b/.travis/before_install.sh @@ -0,0 +1,35 @@ +#!/bin/bash + +# Install Rust +curl https://sh.rustup.rs -sSf | bash -s -- -y +if [ "$TRAVIS_OS_NAME" == "osx" ]; then + brew update || brew update + + brew outdated openssl || brew upgrade openssl + brew install openssl@1.1 + + # install pyenv + git clone --depth 1 https://github.com/pyenv/pyenv ~/.pyenv + PYENV_ROOT="$HOME/.pyenv" + PATH="$PYENV_ROOT/bin:$PATH" + eval "$(pyenv init -)" + + case "${TOXENV}" in + py27) + pyenv install 2.7.14 + pyenv global 2.7.14 + ;; + py36) + pyenv install 3.6.1 + pyenv global 3.6.1 + ;; + py37) + pyenv install 3.7.2 + pyenv global 3.7.2 + ;; + esac + pyenv rehash + + # A manual check that the correct version of Python is running. + python --version +fi diff --git a/.travis/common.sh b/.travis/common.sh new file mode 100644 index 0000000..95d4662 --- /dev/null +++ b/.travis/common.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +export PATH="/usr/local/bin:$HOME/.cargo/bin:$PATH" + +PYTHON_PATH=$(which python"$PYTHON_VERSION") +COMMIT_ID=$(git rev-parse --short HEAD) +VENV_PATH="/tmp/venv$PYTHON_VERSION-$COMMIT_ID" + +warn() { echo "$@" >&2; } + +die() { warn "$@"; exit 1; } + +escape() { + echo "$1" | sed 's/\([\.\$\*]\)/\\\1/g' +} + +has() { + local item=$1; shift + echo " $@ " | grep -q " $(escape $item) " +} diff --git a/.travis/install.sh b/.travis/install.sh new file mode 100755 index 0000000..746d37e --- /dev/null +++ b/.travis/install.sh @@ -0,0 +1,8 @@ +#!/bin/bash +set -ev + +source .travis/common.sh + +perl -p -i -e \ + "s/^snips-nlu-parsers-ffi-macros = .*\$/snips-nlu-parsers-ffi-macros = { path = \"..\/..\/ffi\/ffi-macros\" \}/g" \ + python/ffi/Cargo.toml diff --git a/.travis/test.sh b/.travis/test.sh new file mode 100755 index 0000000..ae01113 --- /dev/null +++ b/.travis/test.sh @@ -0,0 +1,17 @@ +#!/bin/bash +set -ev + +source .travis/common.sh + +echo "Running rust tests..." +export PATH="$HOME/.cargo/bin:$PATH" +cargo build --all +cargo test --all + +if [ "$TRAVIS_BRANCH" == "master" ]; then + echo "Running python tests..." + cd python + python -m pip install tox + tox + cd ../.. +fi diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..1e17896 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,2 @@ +# Changelog +All notable changes to this project will be documented in this file. 
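Both the AppVeyor PowerShell step earlier in this diff and the ``perl`` one-liner in ``.travis/install.sh`` above perform the same substitution: on CI, the ``snips-nlu-parsers-ffi-macros`` dependency declared in ``python/ffi/Cargo.toml`` is rewritten from the tagged git source to the local checkout, presumably so the Python binding is built and tested against the code of the commit under test rather than a published release. After the substitution, the dependency line reads:

    snips-nlu-parsers-ffi-macros = { path = "../../ffi/ffi-macros" }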
diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..c35723a --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,29 @@ +[package] +name = "snips-nlu-parsers" +version = "0.1.0" +authors = ["Adrien Ball "] +edition = "2018" + +[workspace] +members = [ + ".", + "ffi", + "ffi/ffi-macros", + "python/ffi" +] + +[dependencies] +failure = "0.1" +itertools = "0.7" +lazy_static = "1.0" +regex = "0.2" +serde = "1.0" +serde_derive = "1.0" +serde_json = "1.0" +gazetteer-entity-parser = { git = "https://github.com/snipsco/gazetteer-entity-parser", tag = "0.6.0" } +rustling-ontology = { git = "https://github.com/snipsco/rustling-ontology", tag = "0.17.7" } +snips-nlu-ontology = { git = "https://github.com/snipsco/snips-nlu-ontology", tag = "0.62.0" } +snips-nlu-utils = { git = "https://github.com/snipsco/snips-nlu-utils", tag = "0.7.1" } + +[dev-dependencies] +tempfile = "3.0" diff --git a/LICENSE b/LICENSE index e27e797..09250ca 100644 --- a/LICENSE +++ b/LICENSE @@ -1,21 +1,12 @@ -MIT License +## License -Copyright (c) 2019 Snips +Licensed under either of + * Apache License, Version 2.0 ([LICENSE-APACHE](LICENSE-APACHE) or http://www.apache.org/licenses/LICENSE-2.0) + * MIT license ([LICENSE-MIT](LICENSE-MIT) or http://opensource.org/licenses/MIT) +at your option. -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: +### Contribution -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. +Unless you explicitly state otherwise, any contribution intentionally submitted +for inclusion in the work by you, as defined in the Apache-2.0 license, shall +be dual licensed as above, without any additional terms or conditions. diff --git a/LICENSE-APACHE b/LICENSE-APACHE new file mode 100644 index 0000000..16fe87b --- /dev/null +++ b/LICENSE-APACHE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. 
For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + +4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + +6. Trademarks. 
This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + +Copyright [yyyy] [name of copyright owner] + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
diff --git a/LICENSE-MIT b/LICENSE-MIT
new file mode 100644
index 0000000..31aa793
--- /dev/null
+++ b/LICENSE-MIT
@@ -0,0 +1,23 @@
+Permission is hereby granted, free of charge, to any
+person obtaining a copy of this software and associated
+documentation files (the "Software"), to deal in the
+Software without restriction, including without
+limitation the rights to use, copy, modify, merge,
+publish, distribute, sublicense, and/or sell copies of
+the Software, and to permit persons to whom the Software
+is furnished to do so, subject to the following
+conditions:
+
+The above copyright notice and this permission notice
+shall be included in all copies or substantial portions
+of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
+ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
+TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
+PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
+SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
+CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
+OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
+IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+DEALINGS IN THE SOFTWARE.
diff --git a/README.md b/README.md
deleted file mode 100644
index 5baca46..0000000
--- a/README.md
+++ /dev/null
@@ -1,2 +0,0 @@
-# snips-nlu-parsers
-Rust crate for entity parsing
diff --git a/README.rst b/README.rst
new file mode 100644
index 0000000..c7e9399
--- /dev/null
+++ b/README.rst
@@ -0,0 +1,61 @@
+Snips NLU Parsers
+=================
+
+.. image:: https://travis-ci.org/snipsco/snips-nlu-parsers.svg?branch=develop
+    :target: https://travis-ci.org/snipsco/snips-nlu-parsers
+
+.. image:: https://ci.appveyor.com/api/projects/status/github/snipsco/snips-nlu-parsers?branch=develop&svg=true
+    :target: https://ci.appveyor.com/project/snipsco/snips-nlu-parsers
+
+This crate provides APIs to extract entities in the context of a Natural Language Understanding (NLU)
+task.
+
+Installation
+------------
+
+Add this to your ``Cargo.toml``:
+
+.. code-block:: toml
+
+    [dependencies]
+    snips-nlu-parsers = { git = "https://github.com/snipsco/snips-nlu-parsers", tag = "0.1.0" }
+
+
+Usage
+-----
+
+.. code-block:: rust
+
+    use snips_nlu_parsers::{BuiltinEntityKind, BuiltinEntityParserLoader, Language};
+
+    fn parse_entities() {
+        let parser = BuiltinEntityParserLoader::new(Language::EN).load().unwrap();
+        let entities: Vec<(_, _)> = parser
+            .extract_entities("Book me restaurant for two people tomorrow", None)
+            .unwrap()
+            .into_iter()
+            .map(|e| (e.entity_kind, e.range))
+            .collect();
+        assert_eq!(
+            vec![
+                (BuiltinEntityKind::Number, 23..26),
+                (BuiltinEntityKind::Time, 34..42)
+            ],
+            entities
+        );
+    }
+
+License
+-------
+
+Licensed under either of
+ * Apache License, Version 2.0 (`LICENSE-APACHE <LICENSE-APACHE>`_ or http://www.apache.org/licenses/LICENSE-2.0)
+ * MIT license (`LICENSE-MIT <LICENSE-MIT>`_ or http://opensource.org/licenses/MIT)
+at your option.
+
+Contribution
+------------
+
+Unless you explicitly state otherwise, any contribution intentionally submitted
+for inclusion in the work by you, as defined in the Apache-2.0 license, shall
+be dual licensed as above, without any additional terms or conditions.
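The README example above extracts with the full entity scope (``None`` as second argument). Judging from the FFI layer added later in this diff, which forwards ``opt_filters.as_ref().map(|vec| vec.as_slice())``, ``extract_entities`` takes an optional ``&[BuiltinEntityKind]`` slice to restrict extraction. A minimal sketch of scoped extraction under that assumption (the ``parse_scoped_entities`` function and the asserted result are illustrative, not part of this change):

.. code-block:: rust

    use snips_nlu_parsers::{BuiltinEntityKind, BuiltinEntityParserLoader, Language};

    fn parse_scoped_entities() {
        let parser = BuiltinEntityParserLoader::new(Language::EN).load().unwrap();
        // Restrict extraction to numbers only: with this filter, "tomorrow"
        // should no longer be matched as a time entity.
        let scope = [BuiltinEntityKind::Number];
        let entities: Vec<_> = parser
            .extract_entities("Book me restaurant for two people tomorrow", Some(&scope[..]))
            .unwrap()
            .into_iter()
            .map(|e| e.entity_kind)
            .collect();
        assert_eq!(vec![BuiltinEntityKind::Number], entities);
    }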
diff --git a/data/tests/builtin_entity_parser/gazetteer_entity_parser/metadata.json b/data/tests/builtin_entity_parser/gazetteer_entity_parser/metadata.json new file mode 100644 index 0000000..ec3ea1f --- /dev/null +++ b/data/tests/builtin_entity_parser/gazetteer_entity_parser/metadata.json @@ -0,0 +1,12 @@ +{ + "parsers_metadata": [ + { + "entity_identifier": "snips/musicArtist", + "entity_parser": "parser_1" + }, + { + "entity_identifier": "snips/musicTrack", + "entity_parser": "parser_2" + } + ] +} diff --git a/data/tests/builtin_entity_parser/gazetteer_entity_parser/parser_1/metadata.json b/data/tests/builtin_entity_parser/gazetteer_entity_parser/parser_1/metadata.json new file mode 100644 index 0000000..eff8b1e --- /dev/null +++ b/data/tests/builtin_entity_parser/gazetteer_entity_parser/parser_1/metadata.json @@ -0,0 +1 @@ +{"version":"0.6.0","parser_filename":"parser","threshold":0.6,"stop_words":[],"edge_cases":[]} diff --git a/data/tests/builtin_entity_parser/gazetteer_entity_parser/parser_1/parser b/data/tests/builtin_entity_parser/gazetteer_entity_parser/parser_1/parser new file mode 100644 index 0000000..6f10de7 Binary files /dev/null and b/data/tests/builtin_entity_parser/gazetteer_entity_parser/parser_1/parser differ diff --git a/data/tests/builtin_entity_parser/gazetteer_entity_parser/parser_2/metadata.json b/data/tests/builtin_entity_parser/gazetteer_entity_parser/parser_2/metadata.json new file mode 100644 index 0000000..23ea1a1 --- /dev/null +++ b/data/tests/builtin_entity_parser/gazetteer_entity_parser/parser_2/metadata.json @@ -0,0 +1 @@ +{"version":"0.6.0","parser_filename":"parser","threshold":0.7,"stop_words":[],"edge_cases":[]} diff --git a/data/tests/builtin_entity_parser/gazetteer_entity_parser/parser_2/parser b/data/tests/builtin_entity_parser/gazetteer_entity_parser/parser_2/parser new file mode 100644 index 0000000..a6ad7c5 Binary files /dev/null and b/data/tests/builtin_entity_parser/gazetteer_entity_parser/parser_2/parser differ diff --git a/data/tests/builtin_entity_parser/metadata.json b/data/tests/builtin_entity_parser/metadata.json new file mode 100644 index 0000000..7cf1746 --- /dev/null +++ b/data/tests/builtin_entity_parser/metadata.json @@ -0,0 +1,4 @@ +{ + "language": "fr", + "gazetteer_parser": "gazetteer_entity_parser" +} diff --git a/data/tests/builtin_entity_parser_no_gazetteer/metadata.json b/data/tests/builtin_entity_parser_no_gazetteer/metadata.json new file mode 100644 index 0000000..a6042a4 --- /dev/null +++ b/data/tests/builtin_entity_parser_no_gazetteer/metadata.json @@ -0,0 +1,3 @@ +{ + "language": "en" +} diff --git a/data/tests/builtin_gazetteer_parser/metadata.json b/data/tests/builtin_gazetteer_parser/metadata.json new file mode 100644 index 0000000..ec3ea1f --- /dev/null +++ b/data/tests/builtin_gazetteer_parser/metadata.json @@ -0,0 +1,12 @@ +{ + "parsers_metadata": [ + { + "entity_identifier": "snips/musicArtist", + "entity_parser": "parser_1" + }, + { + "entity_identifier": "snips/musicTrack", + "entity_parser": "parser_2" + } + ] +} diff --git a/data/tests/builtin_gazetteer_parser/parser_1/metadata.json b/data/tests/builtin_gazetteer_parser/parser_1/metadata.json new file mode 100644 index 0000000..eff8b1e --- /dev/null +++ b/data/tests/builtin_gazetteer_parser/parser_1/metadata.json @@ -0,0 +1 @@ +{"version":"0.6.0","parser_filename":"parser","threshold":0.6,"stop_words":[],"edge_cases":[]} diff --git a/data/tests/builtin_gazetteer_parser/parser_1/parser b/data/tests/builtin_gazetteer_parser/parser_1/parser new file mode 100644 
index 0000000..6f10de7 Binary files /dev/null and b/data/tests/builtin_gazetteer_parser/parser_1/parser differ diff --git a/data/tests/builtin_gazetteer_parser/parser_2/metadata.json b/data/tests/builtin_gazetteer_parser/parser_2/metadata.json new file mode 100644 index 0000000..23ea1a1 --- /dev/null +++ b/data/tests/builtin_gazetteer_parser/parser_2/metadata.json @@ -0,0 +1 @@ +{"version":"0.6.0","parser_filename":"parser","threshold":0.7,"stop_words":[],"edge_cases":[]} diff --git a/data/tests/builtin_gazetteer_parser/parser_2/parser b/data/tests/builtin_gazetteer_parser/parser_2/parser new file mode 100644 index 0000000..a6ad7c5 Binary files /dev/null and b/data/tests/builtin_gazetteer_parser/parser_2/parser differ diff --git a/data/tests/custom_gazetteer_parser/metadata.json b/data/tests/custom_gazetteer_parser/metadata.json new file mode 100644 index 0000000..418fcd9 --- /dev/null +++ b/data/tests/custom_gazetteer_parser/metadata.json @@ -0,0 +1,12 @@ +{ + "parsers_metadata": [ + { + "entity_identifier": "music_artist", + "entity_parser": "parser_1" + }, + { + "entity_identifier": "music_track", + "entity_parser": "parser_2" + } + ] +} diff --git a/data/tests/custom_gazetteer_parser/parser_1/metadata.json b/data/tests/custom_gazetteer_parser/parser_1/metadata.json new file mode 100644 index 0000000..eff8b1e --- /dev/null +++ b/data/tests/custom_gazetteer_parser/parser_1/metadata.json @@ -0,0 +1 @@ +{"version":"0.6.0","parser_filename":"parser","threshold":0.6,"stop_words":[],"edge_cases":[]} diff --git a/data/tests/custom_gazetteer_parser/parser_1/parser b/data/tests/custom_gazetteer_parser/parser_1/parser new file mode 100644 index 0000000..6f10de7 Binary files /dev/null and b/data/tests/custom_gazetteer_parser/parser_1/parser differ diff --git a/data/tests/custom_gazetteer_parser/parser_2/metadata.json b/data/tests/custom_gazetteer_parser/parser_2/metadata.json new file mode 100644 index 0000000..23ea1a1 --- /dev/null +++ b/data/tests/custom_gazetteer_parser/parser_2/metadata.json @@ -0,0 +1 @@ +{"version":"0.6.0","parser_filename":"parser","threshold":0.7,"stop_words":[],"edge_cases":[]} diff --git a/data/tests/custom_gazetteer_parser/parser_2/parser b/data/tests/custom_gazetteer_parser/parser_2/parser new file mode 100644 index 0000000..a6ad7c5 Binary files /dev/null and b/data/tests/custom_gazetteer_parser/parser_2/parser differ diff --git a/ffi/Cargo.toml b/ffi/Cargo.toml new file mode 100644 index 0000000..093cca1 --- /dev/null +++ b/ffi/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "snips-nlu-parsers-ffi" +version = "0.1.0" +authors = ["Adrien Ball "] +edition = "2018" + +[dependencies] +failure = "0.1" +ffi-utils = { git = "https://github.com/snipsco/snips-utils-rs", rev = "4292ad9" } +libc = "0.2" +snips-nlu-ontology = { git = "https://github.com/snipsco/snips-nlu-ontology", tag = "0.62.0" } +snips-nlu-ontology-ffi-macros = { git = "https://github.com/snipsco/snips-nlu-ontology", tag = "0.62.0" } +snips-nlu-parsers-ffi-macros = { path = "ffi-macros" } + +[lib] +crate-type = ["rlib", "cdylib", "staticlib"] diff --git a/ffi/ffi-macros/Cargo.toml b/ffi/ffi-macros/Cargo.toml new file mode 100644 index 0000000..eaea029 --- /dev/null +++ b/ffi/ffi-macros/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "snips-nlu-parsers-ffi-macros" +version = "0.1.0" +authors = ["Adrien Ball "] +edition = "2018" + +[dependencies] +failure = "0.1" +ffi-utils = { git = "https://github.com/snipsco/snips-utils-rs", rev = "4292ad9" } +libc = "0.2" +serde = "1.0" +serde_json = "1.0" +snips-nlu-ontology 
= { git = "https://github.com/snipsco/snips-nlu-ontology", tag = "0.62.0" } +snips-nlu-ontology-ffi-macros = { git = "https://github.com/snipsco/snips-nlu-ontology", tag = "0.62.0" } +snips-nlu-parsers = { path = "../.." } + +[lib] +crate-type = ["rlib", "cdylib", "staticlib"] diff --git a/ffi/ffi-macros/src/builtin_entity_parser.rs b/ffi/ffi-macros/src/builtin_entity_parser.rs new file mode 100644 index 0000000..b957a9e --- /dev/null +++ b/ffi/ffi-macros/src/builtin_entity_parser.rs @@ -0,0 +1,141 @@ +use std::ffi::CStr; +use std::slice; + +use failure::ResultExt; +use libc; +use serde_json; + +use crate::Result; +use ffi_utils::{ + convert_to_c_string, convert_to_c_string_result, CReprOf, CStringArray, RawPointerConverter, +}; +use snips_nlu_ontology::{BuiltinEntity, BuiltinEntityKind}; +use snips_nlu_ontology_ffi_macros::{CBuiltinEntity, CBuiltinEntityArray}; +use snips_nlu_parsers::{BuiltinEntityParser, BuiltinEntityParserLoader}; + +#[repr(C)] +pub struct CBuiltinEntityParser(*const libc::c_void); + +macro_rules! get_parser { + ($opaque:ident) => {{ + let container: &$crate::CBuiltinEntityParser = unsafe { &*$opaque }; + let x = container.0 as *const BuiltinEntityParser; + unsafe { &*x } + }}; +} + +pub fn create_builtin_entity_parser( + ptr: *mut *const CBuiltinEntityParser, + json_config: *const libc::c_char, +) -> Result<()> { + let json_config = unsafe { CStr::from_ptr(json_config) }.to_str()?; + let parser_loader: BuiltinEntityParserLoader = serde_json::from_str(json_config)?; + let parser = parser_loader.load()?; + + let c_parser = CBuiltinEntityParser(parser.into_raw_pointer() as _).into_raw_pointer(); + + unsafe { + *ptr = c_parser; + } + Ok(()) +} + +pub fn persist_builtin_entity_parser( + ptr: *const CBuiltinEntityParser, + path: *const libc::c_char, +) -> Result<()> { + let parser = get_parser!(ptr); + let parser_path = unsafe { CStr::from_ptr(path) }.to_str()?; + parser.persist(parser_path)?; + Ok(()) +} + +pub fn load_builtin_entity_parser( + ptr: *mut *const CBuiltinEntityParser, + path: *const libc::c_char, +) -> Result<()> { + let parser_path = unsafe { CStr::from_ptr(path) }.to_str()?; + let builtin_entity_parser = BuiltinEntityParser::from_path(parser_path)?; + let c_parser = + CBuiltinEntityParser(builtin_entity_parser.into_raw_pointer() as _).into_raw_pointer(); + + unsafe { + *ptr = c_parser; + } + Ok(()) +} + +pub fn extract_builtin_entity_c( + ptr: *const CBuiltinEntityParser, + sentence: *const libc::c_char, + filter_entity_kinds: *const CStringArray, + results: *mut *const CBuiltinEntityArray, +) -> Result<()> { + let c_entities = extract_builtin_entity(ptr, sentence, filter_entity_kinds)? 
+        .into_iter()
+        .map(CBuiltinEntity::from)
+        .collect::<Vec<_>>();
+    let c_entities = CBuiltinEntityArray::from(c_entities).into_raw_pointer();
+
+    unsafe {
+        *results = c_entities;
+    }
+
+    Ok(())
+}
+
+pub fn extract_builtin_entity_json(
+    ptr: *const CBuiltinEntityParser,
+    sentence: *const libc::c_char,
+    filter_entity_kinds: *const CStringArray,
+    results: *mut *const libc::c_char,
+) -> Result<()> {
+    let entities = extract_builtin_entity(ptr, sentence, filter_entity_kinds)?;
+    let json = ::serde_json::to_string(&entities)?;
+
+    let cs = convert_to_c_string!(json);
+    unsafe { *results = cs }
+
+    Ok(())
+}
+
+pub fn extract_builtin_entity(
+    ptr: *const CBuiltinEntityParser,
+    sentence: *const libc::c_char,
+    filter_entity_kinds: *const CStringArray,
+) -> Result<Vec<BuiltinEntity>> {
+    let parser = get_parser!(ptr);
+    let sentence = unsafe { CStr::from_ptr(sentence) }.to_str()?;
+
+    let opt_filters: Option<Vec<BuiltinEntityKind>> = if !filter_entity_kinds.is_null() {
+        let filters = unsafe {
+            let array = &*filter_entity_kinds;
+            slice::from_raw_parts(array.data, array.size as usize)
+        }
+        .into_iter()
+        .map(|&ptr| {
+            Ok(unsafe { CStr::from_ptr(ptr) }
+                .to_str()
+                .map_err(::failure::Error::from)
+                .and_then(|s| {
+                    Ok(BuiltinEntityKind::from_identifier(s)
+                        .with_context(|_| format!("`{}` isn't a known builtin entity kind", s))?)
+                })?)
+        })
+        .collect::<Result<Vec<_>>>()?;
+        Some(filters)
+    } else {
+        None
+    };
+    let opt_filters = opt_filters.as_ref().map(|vec| vec.as_slice());
+
+    parser.extract_entities(sentence, opt_filters)
+}
+
+pub fn destroy_builtin_entity_parser(ptr: *mut CBuiltinEntityParser) -> Result<()> {
+    unsafe {
+        let parser = CBuiltinEntityParser::from_raw_pointer(ptr)?.0;
+        let _ = BuiltinEntityParser::from_raw_pointer(parser as _);
+    }
+    Ok(())
+}
diff --git a/ffi/ffi-macros/src/gazetteer_entity_parser.rs b/ffi/ffi-macros/src/gazetteer_entity_parser.rs
new file mode 100644
index 0000000..e5b148b
--- /dev/null
+++ b/ffi/ffi-macros/src/gazetteer_entity_parser.rs
@@ -0,0 +1,115 @@
+use std::ffi::CStr;
+use std::slice;
+
+use libc;
+use serde_json;
+
+use crate::Result;
+use ffi_utils::{
+    convert_to_c_string, convert_to_c_string_result, CReprOf, CStringArray, RawPointerConverter,
+};
+use snips_nlu_parsers::{GazetteerEntityMatch, GazetteerParser, GazetteerParserBuilder};
+
+#[repr(C)]
+pub struct CGazetteerEntityParser(*const libc::c_void);
+
+macro_rules! get_parser {
+    ($opaque:ident) => {{
+        let container: &$crate::CGazetteerEntityParser = unsafe { &*$opaque };
+        let x = container.0 as *const GazetteerParser<String>;
+        unsafe { &*x }
+    }};
+}
+
+pub fn load_gazetteer_entity_parser(
+    ptr: *mut *const CGazetteerEntityParser,
+    path: *const libc::c_char,
+) -> Result<()> {
+    let parser_path = unsafe { CStr::from_ptr(path) }.to_str()?;
+    let gazetteer_parser = GazetteerParser::<String>::from_path(parser_path)?;
+    let c_parser =
+        CGazetteerEntityParser(gazetteer_parser.into_raw_pointer() as _).into_raw_pointer();
+
+    unsafe {
+        *ptr = c_parser;
+    }
+    Ok(())
+}
+
+pub fn build_gazetteer_entity_parser(
+    ptr: *mut *const CGazetteerEntityParser,
+    json_config: *const libc::c_char,
+) -> Result<()> {
+    let json_config = unsafe { CStr::from_ptr(json_config) }.to_str()?;
+    let gazetteer_parser =
+        serde_json::from_str::<GazetteerParserBuilder>(json_config)?.build::<String>()?;
+    let c_parser =
+        CGazetteerEntityParser(gazetteer_parser.into_raw_pointer() as _).into_raw_pointer();
+
+    unsafe {
+        *ptr = c_parser;
+    }
+    Ok(())
+}
+
+pub fn persist_gazetteer_entity_parser(
+    ptr: *const CGazetteerEntityParser,
+    path: *const libc::c_char,
+) -> Result<()> {
+    let parser = get_parser!(ptr);
+    let parser_path = unsafe { CStr::from_ptr(path) }.to_str()?;
+    parser.persist(parser_path)?;
+    Ok(())
+}
+
+pub fn extract_gazetteer_entity_json(
+    ptr: *const CGazetteerEntityParser,
+    sentence: *const libc::c_char,
+    filter_entity_kinds: *const CStringArray,
+    results: *mut *const libc::c_char,
+) -> Result<()> {
+    let entities = extract_gazetteer_entity(ptr, sentence, filter_entity_kinds)?;
+    let json = ::serde_json::to_string(&entities)?;
+
+    let cs = convert_to_c_string!(json);
+    unsafe { *results = cs }
+
+    Ok(())
+}
+
+pub fn extract_gazetteer_entity(
+    ptr: *const CGazetteerEntityParser,
+    sentence: *const libc::c_char,
+    filter_entity_kinds: *const CStringArray,
+) -> Result<Vec<GazetteerEntityMatch<String>>> {
+    let parser = get_parser!(ptr);
+    let sentence = unsafe { CStr::from_ptr(sentence) }.to_str()?;
+
+    let opt_filters: Option<Vec<String>> = if !filter_entity_kinds.is_null() {
+        let filters = unsafe {
+            let array = &*filter_entity_kinds;
+            slice::from_raw_parts(array.data, array.size as usize)
+        }
+        .into_iter()
+        .map(|&ptr| {
+            Ok(unsafe { CStr::from_ptr(ptr) }
+                .to_str()
+                .map_err(::failure::Error::from)?
+                .to_string())
+        })
+        .collect::<Result<Vec<String>>>()?;
+        Some(filters)
+    } else {
+        None
+    };
+
+    parser.extract_entities(sentence, opt_filters.as_ref().map(|filters| &**filters))
+}
+
+pub fn destroy_gazetteer_entity_parser(ptr: *mut CGazetteerEntityParser) -> Result<()> {
+    unsafe {
+        let parser = CGazetteerEntityParser::from_raw_pointer(ptr)?.0;
+        let _ = GazetteerParser::<String>::from_raw_pointer(parser as _);
+    }
+    Ok(())
+}
diff --git a/ffi/ffi-macros/src/lib.rs b/ffi/ffi-macros/src/lib.rs
new file mode 100644
index 0000000..0c1f218
--- /dev/null
+++ b/ffi/ffi-macros/src/lib.rs
@@ -0,0 +1,128 @@
+mod builtin_entity_parser;
+mod gazetteer_entity_parser;
+
+pub use builtin_entity_parser::*;
+pub use gazetteer_entity_parser::*;
+
+type Result<T> = ::std::result::Result<T, ::failure::Error>;
+
+#[macro_export]
+macro_rules!
export_nlu_parsers_c_symbols { + () => { + #[no_mangle] + pub extern "C" fn snips_nlu_parsers_create_builtin_entity_parser( + ptr: *mut *const $crate::CBuiltinEntityParser, + json_config: *const ::libc::c_char, + ) -> ::ffi_utils::SNIPS_RESULT { + wrap!($crate::create_builtin_entity_parser(ptr, json_config)) + } + + #[no_mangle] + pub extern "C" fn snips_nlu_parsers_persist_builtin_entity_parser( + ptr: *const $crate::CBuiltinEntityParser, + path: *const ::libc::c_char, + ) -> ::ffi_utils::SNIPS_RESULT { + wrap!($crate::persist_builtin_entity_parser(ptr, path)) + } + + #[no_mangle] + pub extern "C" fn snips_nlu_parsers_load_builtin_entity_parser( + ptr: *mut *const $crate::CBuiltinEntityParser, + parser_path: *const ::libc::c_char, + ) -> ::ffi_utils::SNIPS_RESULT { + wrap!($crate::load_builtin_entity_parser(ptr, parser_path)) + } + + #[no_mangle] + pub extern "C" fn snips_nlu_parsers_extract_builtin_entities( + ptr: *const $crate::CBuiltinEntityParser, + sentence: *const ::libc::c_char, + filter_entity_kinds: *const ::ffi_utils::CStringArray, + results: *mut *const snips_nlu_ontology_ffi_macros::CBuiltinEntityArray, + ) -> ::ffi_utils::SNIPS_RESULT { + wrap!($crate::extract_builtin_entity_c( + ptr, + sentence, + filter_entity_kinds, + results + )) + } + + #[no_mangle] + pub extern "C" fn snips_nlu_parsers_extract_builtin_entities_json( + ptr: *const $crate::CBuiltinEntityParser, + sentence: *const ::libc::c_char, + filter_entity_kinds: *const ::ffi_utils::CStringArray, + results: *mut *const ::libc::c_char, + ) -> ::ffi_utils::SNIPS_RESULT { + wrap!($crate::extract_builtin_entity_json( + ptr, + sentence, + filter_entity_kinds, + results + )) + } + + #[no_mangle] + pub extern "C" fn snips_nlu_parsers_destroy_builtin_entity_array( + ptr: *mut ::snips_nlu_ontology_ffi_macros::CBuiltinEntityArray, + ) -> ::ffi_utils::SNIPS_RESULT { + use ffi_utils::RawPointerConverter; + use snips_nlu_ontology_ffi_macros::CBuiltinEntityArray; + wrap!(unsafe { CBuiltinEntityArray::from_raw_pointer(ptr) }) + } + + #[no_mangle] + pub extern "C" fn snips_nlu_parsers_destroy_builtin_entity_parser( + ptr: *mut $crate::CBuiltinEntityParser, + ) -> ::ffi_utils::SNIPS_RESULT { + wrap!($crate::destroy_builtin_entity_parser(ptr)) + } + + #[no_mangle] + pub extern "C" fn snips_nlu_parsers_build_gazetteer_entity_parser( + ptr: *mut *const $crate::CGazetteerEntityParser, + json_config: *const ::libc::c_char, + ) -> ::ffi_utils::SNIPS_RESULT { + wrap!($crate::build_gazetteer_entity_parser(ptr, json_config)) + } + + #[no_mangle] + pub extern "C" fn snips_nlu_parsers_load_gazetteer_entity_parser( + ptr: *mut *const $crate::CGazetteerEntityParser, + parser_path: *const ::libc::c_char, + ) -> ::ffi_utils::SNIPS_RESULT { + wrap!($crate::load_gazetteer_entity_parser(ptr, parser_path)) + } + + #[no_mangle] + pub extern "C" fn snips_nlu_parsers_persist_gazetteer_entity_parser( + ptr: *const $crate::CGazetteerEntityParser, + path: *const ::libc::c_char, + ) -> ::ffi_utils::SNIPS_RESULT { + wrap!($crate::persist_gazetteer_entity_parser(ptr, path)) + } + + #[no_mangle] + pub extern "C" fn snips_nlu_parsers_extract_gazetteer_entities_json( + ptr: *const $crate::CGazetteerEntityParser, + sentence: *const ::libc::c_char, + filter_entity_kinds: *const ::ffi_utils::CStringArray, + results: *mut *const ::libc::c_char, + ) -> ::ffi_utils::SNIPS_RESULT { + wrap!($crate::extract_gazetteer_entity_json( + ptr, + sentence, + filter_entity_kinds, + results + )) + } + + #[no_mangle] + pub extern "C" fn 
snips_nlu_parsers_destroy_gazetteer_entity_parser( + ptr: *mut $crate::CGazetteerEntityParser, + ) -> ::ffi_utils::SNIPS_RESULT { + wrap!($crate::destroy_gazetteer_entity_parser(ptr)) + } + }; +} diff --git a/ffi/src/lib.rs b/ffi/src/lib.rs new file mode 100644 index 0000000..5efadf2 --- /dev/null +++ b/ffi/src/lib.rs @@ -0,0 +1,9 @@ +use ffi_utils::{generate_error_handling, wrap}; +use snips_nlu_parsers_ffi_macros::export_nlu_parsers_c_symbols; +use snips_nlu_ontology_ffi_macros::export_nlu_ontology_c_symbols; + +generate_error_handling!(snips_nlu_parsers_get_last_error); + +export_nlu_ontology_c_symbols!(); + +export_nlu_parsers_c_symbols!(); diff --git a/python/.gitignore b/python/.gitignore new file mode 100644 index 0000000..4ea4c06 --- /dev/null +++ b/python/.gitignore @@ -0,0 +1,14 @@ +venv/ +venv2/ +venv3/ +venv34/ +venv35/ +venv36/ +venv37/ +build/ +dist/ +*.pyc +*.py.bak +*.egg-info/ +.idea +.tox/ diff --git a/python/LICENSE b/python/LICENSE new file mode 100644 index 0000000..de8afc6 --- /dev/null +++ b/python/LICENSE @@ -0,0 +1,13 @@ +Copyright 2018 Snips + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/python/MANIFEST.in b/python/MANIFEST.in new file mode 100644 index 0000000..a625519 --- /dev/null +++ b/python/MANIFEST.in @@ -0,0 +1,6 @@ +include snips_nlu_parsers/__version__ +include LICENSE README.rst +recursive-include snips_nlu_parsers/dylib/ * +recursive-include ffi/ * +recursive-exclude ffi/target/ * +global-exclude __pycache__ *.py[cod] diff --git a/python/README.rst b/python/README.rst new file mode 100644 index 0000000..4eb2cc6 --- /dev/null +++ b/python/README.rst @@ -0,0 +1,40 @@ +Snips NLU Parsers +================= + +Installation +------------ + +------------- +Linux / MacOS +------------- + +.. code-block:: bash + + pip install snips-nlu-parsers + +--------------- +Other platforms +--------------- + +This package can be installed via pip from a source distribution. As it contains +some ``rust`` code, ``rust`` must be installed on your machine. + +To install Rust, run the following in your terminal, then follow the onscreen instructions: + +.. code-block:: bash + + curl https://sh.rustup.rs -sSf | sh + + +You will also need the python lib ``setuptools_rust``: + +.. code-block:: bash + + pip install setuptools_rust + +Finally, you can install ``snips-nlu-parsers`` using pip: + +.. 
code-block:: bash + + pip install snips-nlu-parsers + diff --git a/python/ffi/Cargo.toml b/python/ffi/Cargo.toml new file mode 100644 index 0000000..fa6fefa --- /dev/null +++ b/python/ffi/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "snips-nlu-parsers-python-ffi" +version = "0.1.0" +authors = ["Adrien Ball "] +edition = "2018" + +[lib] +name = "snips_nlu_parsers_python_ffi" +crate-type = ["cdylib"] + +[dependencies] +failure = "0.1" +libc = "0.2" +ffi-utils = { git = "https://github.com/snipsco/snips-utils-rs", rev = "4292ad9" } +snips-nlu-parsers-ffi-macros = { git = "https://github.com/snipsco/snips-nlu-parsers", tag = "0.1.0" } +snips-nlu-ontology = { git = "https://github.com/snipsco/snips-nlu-ontology", tag = "0.62.0" } +snips-nlu-ontology-ffi-macros = { git = "https://github.com/snipsco/snips-nlu-ontology", tag = "0.62.0" } diff --git a/python/ffi/src/lib.rs b/python/ffi/src/lib.rs new file mode 100644 index 0000000..08fe6df --- /dev/null +++ b/python/ffi/src/lib.rs @@ -0,0 +1,9 @@ +use ffi_utils::*; +use snips_nlu_ontology_ffi_macros::export_nlu_ontology_c_symbols; +use snips_nlu_parsers_ffi_macros::export_nlu_parsers_c_symbols; + +generate_error_handling!(snips_nlu_parsers_get_last_error); + +export_nlu_ontology_c_symbols!(); + +export_nlu_parsers_c_symbols!(); diff --git a/python/requirements.txt b/python/requirements.txt new file mode 100644 index 0000000..5bfbd7e --- /dev/null +++ b/python/requirements.txt @@ -0,0 +1,2 @@ +setuptools_rust==0.8.4 +wheel==0.30.0 diff --git a/python/setup.py b/python/setup.py new file mode 100644 index 0000000..2889290 --- /dev/null +++ b/python/setup.py @@ -0,0 +1,60 @@ +from __future__ import print_function + +import io +import os +import sys + +from setuptools import setup, find_packages +from setuptools_rust import Binding, RustExtension + +packages = [p for p in find_packages() if "tests" not in p] + +PACKAGE_NAME = "snips_nlu_parsers" +ROOT_PATH = os.path.dirname(os.path.abspath(__file__)) +PACKAGE_PATH = os.path.join(ROOT_PATH, PACKAGE_NAME) +README = os.path.join(ROOT_PATH, "README.rst") +VERSION = "__version__" + +RUST_EXTENSION_NAME = 'snips_nlu_parsers.dylib.libsnips_nlu_parsers_rs' +CARGO_ROOT_PATH = os.path.join(ROOT_PATH, 'ffi') +CARGO_FILE_PATH = os.path.join(CARGO_ROOT_PATH, 'Cargo.toml') +CARGO_TARGET_DIR = os.path.join(CARGO_ROOT_PATH, 'target') +os.environ['CARGO_TARGET_DIR'] = CARGO_TARGET_DIR + +with io.open(os.path.join(PACKAGE_PATH, VERSION)) as f: + version = f.readline() + +with io.open(README, "rt", encoding="utf8") as f: + readme = f.read() + +required = [ + "future==0.16.0", + "pathlib==1.0.1; python_version < '3.4'", +] + +rust_extension = RustExtension( + RUST_EXTENSION_NAME, CARGO_FILE_PATH, debug="develop" in sys.argv, + args=["--verbose"] if "--verbose" in sys.argv else None, + binding=Binding.NoBinding) + +setup(name=PACKAGE_NAME, + description="Python wrapper of the snips-nlu-parsers Rust crate", + long_description=readme, + version=version, + license="Apache 2.0", + author="Adrien Ball", + author_email="adrien.ball@snips.ai", + classifiers=[ + "Programming Language :: Python :: 2", + "Programming Language :: Python :: 2.7", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.4", + "Programming Language :: Python :: 3.5", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + ], + rust_extensions=[rust_extension], + install_requires=required, + packages=packages, + include_package_data=True, + zip_safe=False) diff --git 
a/python/snips_nlu_parsers/__init__.py b/python/snips_nlu_parsers/__init__.py new file mode 100644 index 0000000..8a8ae77 --- /dev/null +++ b/python/snips_nlu_parsers/__init__.py @@ -0,0 +1,9 @@ +from __future__ import absolute_import + +from snips_nlu_parsers.builtin_entities import ( + get_all_builtin_entities, get_all_gazetteer_entities, + get_all_grammar_entities, get_all_languages, get_builtin_entity_examples, + get_builtin_entity_shortname, get_ontology_version, get_supported_entities, + get_supported_gazetteer_entities, get_supported_grammar_entities) +from snips_nlu_parsers.builtin_entity_parser import BuiltinEntityParser +from snips_nlu_parsers.gazetteer_entity_parser import GazetteerEntityParser diff --git a/python/snips_nlu_parsers/__version__ b/python/snips_nlu_parsers/__version__ new file mode 100644 index 0000000..6e8bf73 --- /dev/null +++ b/python/snips_nlu_parsers/__version__ @@ -0,0 +1 @@ +0.1.0 diff --git a/python/snips_nlu_parsers/builtin_entities.py b/python/snips_nlu_parsers/builtin_entities.py new file mode 100644 index 0000000..5f4b4e7 --- /dev/null +++ b/python/snips_nlu_parsers/builtin_entities.py @@ -0,0 +1,198 @@ +# coding=utf-8 +from __future__ import (absolute_import, division, print_function, + unicode_literals) + +from _ctypes import byref, pointer +from builtins import range, str +from ctypes import c_char_p, string_at + +from snips_nlu_parsers.utils import (CStringArray, check_ffi_error, lib, + string_array_pointer, string_pointer) + +_ALL_LANGUAGES = None +_SUPPORTED_ENTITIES = dict() +_SUPPORTED_GAZETTEER_ENTITIES = dict() +_SUPPORTED_GRAMMAR_ENTITIES = dict() +_ENTITIES_EXAMPLES = dict() +_ALL_BUILTIN_ENTITIES = None +_ALL_GAZETTEER_ENTITIES = None +_ALL_GRAMMAR_ENTITIES = None +_BUILTIN_ENTITIES_SHORTNAMES = dict() +_ONTOLOGY_VERSION = None + + +def get_ontology_version(): + """Get the version of the ontology""" + global _ONTOLOGY_VERSION + if _ONTOLOGY_VERSION is None: + lib.snips_nlu_ontology_version.restype = c_char_p + _ONTOLOGY_VERSION = lib.snips_nlu_ontology_version().decode("utf8") + return _ONTOLOGY_VERSION + + +def get_all_languages(): + """Lists all the supported languages""" + global _ALL_LANGUAGES + if _ALL_LANGUAGES is None: + lib.snips_nlu_ontology_supported_languages.restype = CStringArray + array = lib.snips_nlu_ontology_supported_languages() + _ALL_LANGUAGES = set( + array.data[i].decode("utf8") for i in range(array.size)) + return _ALL_LANGUAGES + + +def get_all_builtin_entities(): + """Lists the builtin entities that are supported in at least one + language""" + global _ALL_BUILTIN_ENTITIES + if _ALL_BUILTIN_ENTITIES is None: + lib.snips_nlu_ontology_all_builtin_entities.restype = CStringArray + array = lib.snips_nlu_ontology_all_builtin_entities() + _ALL_BUILTIN_ENTITIES = set( + array.data[i].decode("utf8") for i in range(array.size)) + return _ALL_BUILTIN_ENTITIES + + +def get_all_gazetteer_entities(): + """Lists the gazetteer entities that are supported in at least one + language""" + global _ALL_GAZETTEER_ENTITIES + if _ALL_GAZETTEER_ENTITIES is None: + lib.snips_nlu_ontology_all_gazetteer_entities.restype = CStringArray + array = lib.snips_nlu_ontology_all_gazetteer_entities() + _ALL_GAZETTEER_ENTITIES = set( + array.data[i].decode("utf8") for i in range(array.size)) + return _ALL_GAZETTEER_ENTITIES + + +def get_all_grammar_entities(): + """Lists the grammar entities that are supported in at least one + language""" + global _ALL_GRAMMAR_ENTITIES + if _ALL_GRAMMAR_ENTITIES is None: + 
lib.snips_nlu_ontology_all_grammar_entities.restype = CStringArray
+        array = lib.snips_nlu_ontology_all_grammar_entities()
+        _ALL_GRAMMAR_ENTITIES = set(
+            array.data[i].decode("utf8") for i in range(array.size))
+    return _ALL_GRAMMAR_ENTITIES
+
+
+def get_builtin_entity_shortname(entity):
+    """Get the short name of the entity
+
+    Examples:
+
+        >>> get_builtin_entity_shortname(u"snips/amountOfMoney")
+        'AmountOfMoney'
+    """
+    global _BUILTIN_ENTITIES_SHORTNAMES
+    if entity not in _BUILTIN_ENTITIES_SHORTNAMES:
+        with string_pointer(c_char_p()) as ptr:
+            exit_code = lib.snips_nlu_ontology_entity_shortname(
+                entity.encode("utf8"), byref(ptr))
+            check_ffi_error(exit_code, "Something went wrong when retrieving "
+                                       "builtin entity shortname")
+            result = string_at(ptr)
+            _BUILTIN_ENTITIES_SHORTNAMES[entity] = result.decode("utf8")
+    return _BUILTIN_ENTITIES_SHORTNAMES[entity]
+
+
+def get_supported_entities(language):
+    """Lists the builtin entities supported in the specified *language*
+
+    Returns:
+        set of str: the set of entity labels
+    """
+    global _SUPPORTED_ENTITIES
+
+    if not isinstance(language, str):
+        raise TypeError("Expected language to be of type 'str' but found: %s"
+                        % type(language))
+
+    if language not in _SUPPORTED_ENTITIES:
+        with string_array_pointer(pointer(CStringArray())) as ptr:
+            exit_code = lib.snips_nlu_ontology_supported_builtin_entities(
+                language.encode("utf8"), byref(ptr))
+            check_ffi_error(exit_code, "Something went wrong when retrieving "
+                                       "supported entities")
+            array = ptr.contents
+            _SUPPORTED_ENTITIES[language] = set(
+                array.data[i].decode("utf8") for i in range(array.size))
+    return _SUPPORTED_ENTITIES[language]
+
+
+def get_supported_gazetteer_entities(language):
+    """Lists the gazetteer entities supported in the specified *language*
+
+    Returns:
+        set of str: the set of entity labels
+    """
+    global _SUPPORTED_GAZETTEER_ENTITIES
+
+    if not isinstance(language, str):
+        raise TypeError("Expected language to be of type 'str' but found: %s"
+                        % type(language))
+
+    if language not in _SUPPORTED_GAZETTEER_ENTITIES:
+        with string_array_pointer(pointer(CStringArray())) as ptr:
+            exit_code = \
+                lib.snips_nlu_ontology_supported_builtin_gazetteer_entities(
+                    language.encode("utf8"), byref(ptr))
+            check_ffi_error(exit_code, "Something went wrong when retrieving "
+                                       "supported gazetteer entities")
+            array = ptr.contents
+            _SUPPORTED_GAZETTEER_ENTITIES[language] = set(
+                array.data[i].decode("utf8") for i in range(array.size))
+    return _SUPPORTED_GAZETTEER_ENTITIES[language]
+
+
+def get_supported_grammar_entities(language):
+    """Lists the grammar entities supported in the specified *language*
+
+    Returns:
+        set of str: the set of entity labels
+    """
+    global _SUPPORTED_GRAMMAR_ENTITIES
+
+    if not isinstance(language, str):
+        raise TypeError("Expected language to be of type 'str' but found: %s"
+                        % type(language))
+
+    if language not in _SUPPORTED_GRAMMAR_ENTITIES:
+        with string_array_pointer(pointer(CStringArray())) as ptr:
+            exit_code = lib.snips_nlu_ontology_supported_grammar_entities(
+                language.encode("utf8"), byref(ptr))
+            check_ffi_error(exit_code, "Something went wrong when retrieving "
+                                       "supported grammar entities")
+            array = ptr.contents
+            _SUPPORTED_GRAMMAR_ENTITIES[language] = set(
+                array.data[i].decode("utf8") for i in range(array.size))
+    return _SUPPORTED_GRAMMAR_ENTITIES[language]
+
+
+def get_builtin_entity_examples(builtin_entity_kind, language):
+    """Provides some examples of the builtin entity in the specified language
+    """
+    global _ENTITIES_EXAMPLES
+
+    if not isinstance(builtin_entity_kind, str):
+        raise TypeError("Expected `builtin_entity_kind` to be of type 'str' "
+                        "but found: %s" % type(builtin_entity_kind))
+    if not isinstance(language, str):
+        raise TypeError("Expected `language` to be of type 'str' but found: %s"
+                        % type(language))
+
+    if builtin_entity_kind not in _ENTITIES_EXAMPLES:
+        _ENTITIES_EXAMPLES[builtin_entity_kind] = dict()
+
+    if language not in _ENTITIES_EXAMPLES[builtin_entity_kind]:
+        with string_array_pointer(pointer(CStringArray())) as ptr:
+            exit_code = lib.snips_nlu_ontology_builtin_entity_examples(
+                builtin_entity_kind.encode("utf8"),
+                language.encode("utf8"), byref(ptr))
+            check_ffi_error(exit_code, "Something went wrong when retrieving "
+                                       "builtin entity examples")
+            array = ptr.contents
+            _ENTITIES_EXAMPLES[builtin_entity_kind][language] = list(
+                array.data[i].decode("utf8") for i in range(array.size))
+    return _ENTITIES_EXAMPLES[builtin_entity_kind][language]
diff --git a/python/snips_nlu_parsers/builtin_entity_parser.py b/python/snips_nlu_parsers/builtin_entity_parser.py
new file mode 100644
index 0000000..8cd0743
--- /dev/null
+++ b/python/snips_nlu_parsers/builtin_entity_parser.py
@@ -0,0 +1,101 @@
+import json
+from _ctypes import byref, pointer
+from builtins import bytes, str
+from ctypes import c_char_p, c_int, c_void_p, string_at
+from pathlib import Path
+
+from snips_nlu_parsers.utils import (
+    CStringArray, check_ffi_error, lib, string_pointer)
+
+
+class BuiltinEntityParser(object):
+    def __init__(self, parser):
+        self._parser = parser
+
+    @classmethod
+    def build(cls, language, gazetteer_entity_parser_path=None):
+        """Build a `BuiltinEntityParser`
+
+        Args:
+            language (str): Language identifier
+            gazetteer_entity_parser_path (str, optional): Path to a gazetteer
+                entity parser. If None, the builtin entity parser will only
+                use grammar entities.
+        """
+        if isinstance(gazetteer_entity_parser_path, Path):
+            gazetteer_entity_parser_path = str(gazetteer_entity_parser_path)
+        if not isinstance(language, str):
+            raise TypeError("Expected language to be of type 'str' but found:"
+                            " %s" % type(language))
+        parser_config = dict(
+            language=language.upper(),
+            gazetteer_parser_path=gazetteer_entity_parser_path)
+        parser = pointer(c_void_p())
+        json_parser_config = bytes(json.dumps(parser_config), encoding="utf8")
+        exit_code = lib.snips_nlu_parsers_create_builtin_entity_parser(
+            byref(parser), json_parser_config)
+        check_ffi_error(exit_code, "Something went wrong while creating the "
+                                   "builtin entity parser")
+        return cls(parser)
+
+    def parse(self, text, scope=None):
+        """Extract builtin entities from *text*
+
+        Args:
+            text (str): Input
+            scope (list of str, optional): List of builtin entity labels. If
+                defined, the parser will extract entities using the provided
+                scope instead of the entire scope of all available entities.
+                This allows looking for specific builtin entity kinds.
+
+        Returns:
+            list of dict: The list of extracted entities
+        """
+        if not isinstance(text, str):
+            raise TypeError("Expected text to be of type 'str' but found: "
+                            "%s" % type(text))
+        if scope is not None:
+            if not all(isinstance(e, str) for e in scope):
+                raise TypeError(
+                    "Expected scope to contain objects of type 'str'")
+            scope = [e.encode("utf8") for e in scope]
+            arr = CStringArray()
+            arr.size = c_int(len(scope))
+            arr.data = (c_char_p * len(scope))(*scope)
+            scope = byref(arr)
+
+        with string_pointer(c_char_p()) as ptr:
+            exit_code = lib.snips_nlu_parsers_extract_builtin_entities_json(
+                self._parser, text.encode("utf8"), scope, byref(ptr))
+            check_ffi_error(exit_code, "Something went wrong when extracting "
+                                       "builtin entities")
+            result = string_at(ptr)
+        return json.loads(result.decode("utf8"))
+
+    def persist(self, path):
+        """Persist the builtin entity parser on disk at the provided path"""
+        if isinstance(path, Path):
+            path = str(path)
+        exit_code = lib.snips_nlu_parsers_persist_builtin_entity_parser(
+            self._parser, path.encode("utf8"))
+        check_ffi_error(exit_code, "Something went wrong when persisting the "
+                                   "builtin entity parser")
+
+    @classmethod
+    def from_path(cls, parser_path):
+        """Create a :class:`BuiltinEntityParser` from a builtin entity parser
+        persisted on disk
+        """
+        if isinstance(parser_path, Path):
+            parser_path = str(parser_path)
+        parser = pointer(c_void_p())
+        parser_path = bytes(parser_path, encoding="utf8")
+        exit_code = lib.snips_nlu_parsers_load_builtin_entity_parser(
+            byref(parser), parser_path)
+        check_ffi_error(exit_code, "Something went wrong when loading the "
+                                   "builtin entity parser")
+        return cls(parser)
+
+    def __del__(self):
+        if lib is not None and hasattr(self, '_parser'):
+            lib.snips_nlu_parsers_destroy_builtin_entity_parser(self._parser)
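
For context, the parser above is typically driven as in the sketch below, which mirrors the unit tests added further down in this change. Each extracted entity is a dict with "value", "range", "entity" and "entity_kind" keys, as those tests show. The /tmp path is only an illustration, and persist appears to expect a directory that does not exist yet:

    from snips_nlu_parsers import BuiltinEntityParser

    parser = BuiltinEntityParser.build("en")

    # Parse with the full scope of available entities
    entities = parser.parse("Raise to sixty two degrees celsius")

    # Restrict parsing to a few builtin entity kinds
    entities = parser.parse("Raise to sixty two",
                            scope=["snips/duration", "snips/temperature"])

    # Round-trip through disk
    parser.persist("/tmp/builtin_entity_parser")
    reloaded = BuiltinEntityParser.from_path("/tmp/builtin_entity_parser")
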
diff --git a/python/snips_nlu_parsers/dylib/.gitignore b/python/snips_nlu_parsers/dylib/.gitignore
new file mode 100644
index 0000000..41b5b91
--- /dev/null
+++ b/python/snips_nlu_parsers/dylib/.gitignore
@@ -0,0 +1,3 @@
+*.dylib
+*.so
+*.dll
\ No newline at end of file
diff --git a/python/snips_nlu_parsers/gazetteer_entity_parser.py b/python/snips_nlu_parsers/gazetteer_entity_parser.py
new file mode 100644
index 0000000..53ae3f9
--- /dev/null
+++ b/python/snips_nlu_parsers/gazetteer_entity_parser.py
@@ -0,0 +1,127 @@
+import json
+from _ctypes import byref, pointer
+from builtins import bytes, str
+from ctypes import c_char_p, c_int, c_void_p, string_at
+from pathlib import Path
+
+from snips_nlu_parsers.utils import (CStringArray, check_ffi_error, lib,
+                                     string_pointer)
+
+
+class GazetteerEntityParser(object):
+    def __init__(self, parser):
+        self._parser = parser
+
+    @classmethod
+    def build(cls, build_config):
+        """Create a new :class:`GazetteerEntityParser` from a build config
+
+        The build configuration must have the following format:
+
+            {
+                "entity_parsers": [
+                    {
+                        "entity_identifier": "my_first_entity",
+                        "entity_parser": {
+                            "gazetteer": [
+                                {
+                                    "raw_value": "foo bar",
+                                    "resolved_value": "Foo Bar"
+                                },
+                                {
+                                    "raw_value": "yolo",
+                                    "resolved_value": "Yala"
+                                }
+                            ],
+                            "threshold": 0.6,
+                            "n_gazetteer_stop_words": 10,
+                            "additional_stop_words": ["the", "a"]
+                        }
+                    },
+                    {
+                        "entity_identifier": "my_second_entity",
+                        "entity_parser": {
+                            "gazetteer": [
+                                {
+                                    "raw_value": "the stones",
+                                    "resolved_value": "The Rolling Stones"
+                                }
+                            ],
+                            "threshold": 0.6,
+                            "n_gazetteer_stop_words": None,
+                            "additional_stop_words": None
+                        }
+                    },
+                ]
+            }
+        """
+        parser = pointer(c_void_p())
+        json_parser_config = bytes(json.dumps(build_config), encoding="utf8")
+        exit_code = lib.snips_nlu_parsers_build_gazetteer_entity_parser(
+            byref(parser), json_parser_config)
+        check_ffi_error(exit_code, "Something went wrong when building the "
+                                   "gazetteer entity parser")
+        return cls(parser)
+
+    def parse(self, text, scope=None):
+        """Extract gazetteer entities from *text*
+
+        Args:
+            text (str): Input
+            scope (list of str, optional): List of entity labels. If defined,
+                the parser will extract entities using the provided scope
+                instead of the entire scope of all available entities. This
+                allows looking for specific entities.
+
+        Returns:
+            list of dict: The list of extracted entities
+        """
+        if not isinstance(text, str):
+            raise TypeError("Expected text to be of type 'str' but found: "
+                            "%s" % type(text))
+        if scope is not None:
+            if not all(isinstance(e, str) for e in scope):
+                raise TypeError(
+                    "Expected scope to contain objects of type 'str'")
+            scope = [e.encode("utf8") for e in scope]
+            arr = CStringArray()
+            arr.size = c_int(len(scope))
+            arr.data = (c_char_p * len(scope))(*scope)
+            scope = byref(arr)
+
+        with string_pointer(c_char_p()) as ptr:
+            exit_code = lib.snips_nlu_parsers_extract_gazetteer_entities_json(
+                self._parser, text.encode("utf8"), scope, byref(ptr))
+            check_ffi_error(exit_code, "Something went wrong when "
+                                       "extracting gazetteer entities")
+            result = string_at(ptr)
+        return json.loads(result.decode("utf8"))
+
+    def persist(self, path):
+        """Persist the gazetteer parser on disk at the provided path"""
+        if isinstance(path, Path):
+            path = str(path)
+        exit_code = lib.snips_nlu_parsers_persist_gazetteer_entity_parser(
+            self._parser, path.encode("utf8"))
+        check_ffi_error(exit_code, "Something went wrong when persisting "
+                                   "the gazetteer entity parser")
+
+    @classmethod
+    def from_path(cls, parser_path):
+        """Create a :class:`GazetteerEntityParser` from a gazetteer parser
+        persisted on disk
+        """
+        if isinstance(parser_path, Path):
+            parser_path = str(parser_path)
+        parser = pointer(c_void_p())
+        parser_path = bytes(parser_path, encoding="utf8")
+        exit_code = lib.snips_nlu_parsers_load_gazetteer_entity_parser(
+            byref(parser), parser_path)
+        check_ffi_error(exit_code, "Something went wrong when loading the "
+                                   "gazetteer entity parser")
+        return cls(parser)
+
+    def __del__(self):
+        if lib is not None:
+            lib.snips_nlu_parsers_destroy_gazetteer_entity_parser(
+                self._parser)
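
A minimal end-to-end sketch of the build configuration documented above, mirroring the unit tests (the entity name and gazetteer values are illustrative):

    from snips_nlu_parsers import GazetteerEntityParser

    config = {
        "entity_parsers": [
            {
                "entity_identifier": "music_artist",
                "entity_parser": {
                    "gazetteer": [
                        {"raw_value": "the rolling stones",
                         "resolved_value": "The Rolling Stones"}
                    ],
                    "threshold": 0.6,
                    "n_gazetteer_stop_words": None,
                    "additional_stop_words": None
                }
            }
        ]
    }
    parser = GazetteerEntityParser.build(config)
    print(parser.parse("I want to listen to the stones", None))
    # -> [{"value": "the stones", "resolved_value": "The Rolling Stones",
    #      "range": {"start": 20, "end": 30}, "entity_identifier": "music_artist"}]
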
diff --git a/python/snips_nlu_parsers/tests/__init__.py b/python/snips_nlu_parsers/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/python/snips_nlu_parsers/tests/test_builtin_entities.py b/python/snips_nlu_parsers/tests/test_builtin_entities.py
new file mode 100644
index 0000000..b853dee
--- /dev/null
+++ b/python/snips_nlu_parsers/tests/test_builtin_entities.py
@@ -0,0 +1,103 @@
+from __future__ import unicode_literals
+
+import unittest
+from builtins import str
+
+from snips_nlu_parsers.builtin_entities import (
+    get_all_builtin_entities, get_all_gazetteer_entities,
+    get_all_grammar_entities, get_all_languages, get_builtin_entity_examples,
+    get_builtin_entity_shortname, get_ontology_version, get_supported_entities,
+    get_supported_gazetteer_entities, get_supported_grammar_entities)
+
+
+class TestBuiltinEntities(unittest.TestCase):
+    def test_should_get_all_languages(self):
+        # When
+        all_languages = get_all_languages()
+
+        # Then
+        self.assertIn("en", all_languages)
+        self.assertIn("fr", all_languages)
+        for language in all_languages:
+            self.assertIsInstance(language, str)
+
+    def test_should_get_builtin_entity_shortname(self):
+        # Given
+        entity_name = "snips/amountOfMoney"
+
+        # When
+        short_name = get_builtin_entity_shortname(entity_name)
+
+        # Then
+        self.assertEqual("AmountOfMoney", short_name)
+
+    def test_should_get_all_builtin_entities(self):
+        # When
+        all_builtins = get_all_builtin_entities()
+
+        # Then
+        self.assertIn("snips/number", all_builtins)
+        self.assertIn("snips/musicArtist", all_builtins)
+        for builtin in all_builtins:
+            self.assertIsInstance(builtin, str)
+
+    def test_should_get_all_grammar_entities(self):
+        # When
+        all_grammar_entities = get_all_grammar_entities()
+
+        # Then
+        self.assertIn("snips/number", all_grammar_entities)
+        self.assertNotIn("snips/musicArtist", all_grammar_entities)
+        for builtin in all_grammar_entities:
+            self.assertIsInstance(builtin, str)
+
+    def test_should_get_all_gazetteer_entities(self):
+        # When
+        all_gazetteer_entities = get_all_gazetteer_entities()
+
+        # Then
+        self.assertNotIn("snips/number", all_gazetteer_entities)
+        self.assertIn("snips/musicArtist", all_gazetteer_entities)
+        for builtin in all_gazetteer_entities:
+            self.assertIsInstance(builtin, str)
+
+    def test_should_get_supported_builtin_entities(self):
+        # When
+        supported_entities = get_supported_entities("en")
+
+        # Then
+        self.assertIn("snips/number", supported_entities)
+        self.assertIn("snips/datetime", supported_entities)
+        for builtin in supported_entities:
+            self.assertIsInstance(builtin, str)
+
+    def test_should_get_supported_gazetteer_entities(self):
+        # When
+        supported_entities = get_supported_gazetteer_entities("fr")
+
+        # Then
+        self.assertIn("snips/musicArtist", supported_entities)
+        self.assertIn("snips/musicAlbum", supported_entities)
+        self.assertNotIn("snips/number", supported_entities)
+        for builtin in supported_entities:
+            self.assertIsInstance(builtin, str)
+
+    def test_should_get_supported_grammar_entities(self):
+        # When
+        supported_entities = get_supported_grammar_entities("en")
+
+        # Then
+        self.assertIn("snips/number", supported_entities)
+        self.assertIn("snips/datetime", supported_entities)
+        for builtin in supported_entities:
+            self.assertIsInstance(builtin, str)
+
+    def test_should_get_ontology_version(self):
+        get_ontology_version()
+
+    def test_should_get_builtin_entity_examples(self):
+        for language in get_all_languages():
+            for builtin_entity in get_supported_entities(language):
+                examples = get_builtin_entity_examples(builtin_entity,
+                                                       language)
+                self.assertGreaterEqual(len(examples), 1)
diff --git a/python/snips_nlu_parsers/tests/test_builtin_entity_parser.py b/python/snips_nlu_parsers/tests/test_builtin_entity_parser.py
new file mode 100644
index 0000000..a88a5c8
--- /dev/null
+++ b/python/snips_nlu_parsers/tests/test_builtin_entity_parser.py
@@ -0,0 +1,213 @@
+from __future__ import unicode_literals
+
+import unittest
+
+from snips_nlu_parsers import BuiltinEntityParser, get_all_languages
+from snips_nlu_parsers.tests.utils import ROOT_DIR
+from snips_nlu_parsers.utils import temp_dir
+
+BUILTIN_PARSER_PATH = ROOT_DIR / "data" / "tests" / "builtin_entity_parser"
+BUILTIN_PARSER_NO_GAZETTEER_PATH = ROOT_DIR / "data" / "tests" / \
+    "builtin_entity_parser_no_gazetteer"
+
+
+class TestBuiltinEntityParser(unittest.TestCase):
+    def test_should_parse_without_scope(self):
+        # Given
+        parser = BuiltinEntityParser.build("en")
+
+        # When
+        res = parser.parse("Raise to sixty two degrees celsius")
+
+        # Then
+        expected_result = [
+            {
+                "entity": {
+                    "kind": "Temperature",
+                    "unit": "celsius",
+                    "value": 62.0
+                },
+                
"entity_kind": "snips/temperature", + "range": {"end": 34, "start": 9}, + "value": "sixty two degrees celsius" + } + ] + + self.assertListEqual(expected_result, res) + + def test_should_parse_with_scope(self): + # Given + parser = BuiltinEntityParser.build("en") + scope = ["snips/duration", "snips/temperature"] + + # When + res = parser.parse("Raise to sixty two", scope) + + # Then + expected_result = [ + { + "entity": { + "kind": "Temperature", + "unit": None, + "value": 62.0 + }, + "entity_kind": "snips/temperature", + "range": {"end": 18, "start": 9}, + "value": "sixty two" + } + ] + + self.assertListEqual(expected_result, res) + + def test_should_parse_with_gazetteer_entity(self): + # Given + gazetteer_parser_path = ROOT_DIR / "data" / "tests" / \ + "builtin_gazetteer_parser" + parser = BuiltinEntityParser.build("en", gazetteer_parser_path) + scope = ["snips/musicArtist"] + + # When + res = parser.parse("I want to listen to the stones please!", scope) + + # Then + expected_result = [ + { + "entity": { + "kind": "MusicArtist", + "value": "The Rolling Stones" + }, + "entity_kind": "snips/musicArtist", + "range": {"end": 30, "start": 20}, + "value": "the stones" + } + ] + + self.assertListEqual(expected_result, res) + + def test_should_parse_in_all_languages(self): + # Given + all_languages = get_all_languages() + text = "1234" + + # When / Then + for language in all_languages: + parser = BuiltinEntityParser.build(language) + parser.parse(text) + + def test_should_persist_parser(self): + # Given + parser = BuiltinEntityParser.build("en") + + # When + with temp_dir() as tmpdir: + persisted_path = str(tmpdir / "persisted_builtin_parser") + parser.persist(persisted_path) + loaded_parser = BuiltinEntityParser.from_path(persisted_path) + res = loaded_parser.parse("Raise the temperature to 9 degrees", None) + + # Then + expected_result = [ + { + "value": "9 degrees", + "entity": { + "kind": "Temperature", + "unit": "degree", + "value": 9.0 + }, + "range": {"start": 25, "end": 34}, + "entity_kind": "snips/temperature" + } + ] + self.assertListEqual(expected_result, res) + + def test_should_load_parser_from_path(self): + # Given + parser = BuiltinEntityParser.from_path( + BUILTIN_PARSER_NO_GAZETTEER_PATH) + + # When + res = parser.parse("Raise the temperature to 9 degrees", None) + + # Then + expected_result = [ + { + "value": "9 degrees", + "entity": { + "kind": "Temperature", + "unit": "degree", + "value": 9.0 + }, + "range": {"start": 25, "end": 34}, + "entity_kind": "snips/temperature" + } + ] + + self.assertListEqual(expected_result, res) + + def test_should_persist_parser_with_gazetteer_entities(self): + # Given + parser = BuiltinEntityParser.from_path(BUILTIN_PARSER_PATH) + + # When + with temp_dir() as tmpdir: + persisted_path = str(tmpdir / "persisted_builtin_parser") + parser.persist(persisted_path) + loaded_parser = BuiltinEntityParser.from_path(persisted_path) + res = loaded_parser.parse("I want to listen to the stones", None) + + # Then + expected_result = [ + { + "value": "the stones", + "entity": { + "kind": "MusicArtist", + "value": "The Rolling Stones" + }, + "range": {"start": 20, "end": 30}, + "entity_kind": "snips/musicArtist" + } + ] + self.assertListEqual(expected_result, res) + + def test_should_load_parser_with_gazetteer_entities_from_path(self): + # Given + parser = BuiltinEntityParser.from_path(BUILTIN_PARSER_PATH) + + # When + res = parser.parse("I want to listen to the stones", None) + + # Then + expected_result = [ + { + "value": "the stones", + "entity": { + "kind": 
"MusicArtist", + "value": "The Rolling Stones" + }, + "range": {"start": 20, "end": 30}, + "entity_kind": "snips/musicArtist" + } + ] + self.assertListEqual(expected_result, res) + + def test_should_not_accept_bytes_as_language(self): + with self.assertRaises(TypeError): + BuiltinEntityParser.build(b"en") + + def test_should_not_accept_bytes_in_text(self): + # Given + parser = BuiltinEntityParser.build("en") + bytes_text = b"Raise to sixty" + + # When/Then + with self.assertRaises(TypeError): + parser.parse(bytes_text) + + def test_should_not_accept_bytes_in_scope(self): + # Given + scope = [b"snips/number", b"snips/datetime"] + parser = BuiltinEntityParser.build("en") + + # When/Then + with self.assertRaises(TypeError): + parser.parse("Raise to sixty", scope) diff --git a/python/snips_nlu_parsers/tests/test_gazetteer_entity_parser.py b/python/snips_nlu_parsers/tests/test_gazetteer_entity_parser.py new file mode 100644 index 0000000..f0e652a --- /dev/null +++ b/python/snips_nlu_parsers/tests/test_gazetteer_entity_parser.py @@ -0,0 +1,169 @@ +from __future__ import unicode_literals + +import unittest +from builtins import str + +from snips_nlu_parsers import GazetteerEntityParser +from snips_nlu_parsers.tests.utils import ROOT_DIR +from snips_nlu_parsers.utils import temp_dir + +CUSTOM_PARSER_PATH = ROOT_DIR / "data" / "tests" / "custom_gazetteer_parser" + + +class TestGazetteerEntityParser(unittest.TestCase): + def get_test_parser_config(self): + return { + "entity_parsers": [ + self.get_music_artist_entity_config(), + self.get_music_track_entity_config(), + ] + } + + @staticmethod + def get_music_track_entity_config(): + return { + "entity_identifier": "music_track", + "entity_parser": { + "gazetteer": [ + { + "raw_value": "what s my age again", + "resolved_value": "What's my age again" + } + ], + "threshold": 0.7, + "n_gazetteer_stop_words": None, + "additional_stop_words": None + } + } + + @staticmethod + def get_music_artist_entity_config(): + return { + "entity_identifier": "music_artist", + "entity_parser": { + "gazetteer": [ + { + "raw_value": "the rolling stones", + "resolved_value": "The Rolling Stones" + }, + { + "raw_value": "blink one eight two", + "resolved_value": "Blink 182" + } + ], + "threshold": 0.6, + "n_gazetteer_stop_words": None, + "additional_stop_words": None + } + } + + def test_should_parse_from_built_parser(self): + # Given + parser_config = self.get_test_parser_config() + parser = GazetteerEntityParser.build(parser_config) + + # When + res = parser.parse("I want to listen to the stones", None) + + # Then + expected_result = [ + { + "value": "the stones", + "resolved_value": "The Rolling Stones", + "range": {"start": 20, "end": 30}, + "entity_identifier": "music_artist" + } + ] + + self.assertListEqual(expected_result, res) + + def test_should_parse_from_built_parser_with_scope(self): + # Given + parser_config = self.get_test_parser_config() + parser = GazetteerEntityParser.build(parser_config) + + # When + text = "I want to listen to what s my age again by blink one eight two" + res_artist = parser.parse(text, ["music_artist"]) + res_track = parser.parse(text, ["music_track"]) + + # Then + expected_artist_result = [ + { + "value": "blink one eight two", + "resolved_value": "Blink 182", + "range": {"start": 43, "end": 62}, + "entity_identifier": "music_artist" + } + ] + + expected_track_result = [ + { + "value": "what s my age again", + "resolved_value": "What's my age again", + "range": {"start": 20, "end": 39}, + "entity_identifier": "music_track" + } + ] + + 
self.assertListEqual(expected_artist_result, res_artist) + self.assertListEqual(expected_track_result, res_track) + + def test_should_persist_parser(self): + # Given + parser = GazetteerEntityParser.from_path(CUSTOM_PARSER_PATH) + + # When + with temp_dir() as tmpdir: + persisted_path = str(tmpdir / "persisted_gazetteer_parser") + parser.persist(persisted_path) + loaded_parser = GazetteerEntityParser.from_path(persisted_path) + res = loaded_parser.parse("I want to listen to the stones", None) + + # Then + expected_result = [ + { + "value": "the stones", + "resolved_value": "The Rolling Stones", + "range": {"start": 20, "end": 30}, + "entity_identifier": "music_artist" + } + ] + self.assertListEqual(expected_result, res) + + def test_should_load_parser_from_path(self): + # Given + parser = GazetteerEntityParser.from_path(CUSTOM_PARSER_PATH) + + # When + res = parser.parse("I want to listen to the stones", None) + + # Then + expected_result = [ + { + "value": "the stones", + "resolved_value": "The Rolling Stones", + "range": {"start": 20, "end": 30}, + "entity_identifier": "music_artist" + } + ] + + self.assertListEqual(expected_result, res) + + def test_should_not_accept_bytes_in_text(self): + # Given + parser = GazetteerEntityParser.from_path(CUSTOM_PARSER_PATH) + bytes_text = b"Raise to sixty" + + # When/Then + with self.assertRaises(TypeError): + parser.parse(bytes_text) + + def test_should_not_accept_bytes_in_scope(self): + # Given + scope = [b"snips/number", b"snips/datetime"] + parser = GazetteerEntityParser.from_path(CUSTOM_PARSER_PATH) + + # When/Then + with self.assertRaises(TypeError): + parser.parse("Raise to sixty", scope) diff --git a/python/snips_nlu_parsers/tests/utils.py b/python/snips_nlu_parsers/tests/utils.py new file mode 100644 index 0000000..2ca61cb --- /dev/null +++ b/python/snips_nlu_parsers/tests/utils.py @@ -0,0 +1,3 @@ +from snips_nlu_parsers.utils import PACKAGE_PATH + +ROOT_DIR = PACKAGE_PATH.parents[1] diff --git a/python/snips_nlu_parsers/utils.py b/python/snips_nlu_parsers/utils.py new file mode 100644 index 0000000..c926e01 --- /dev/null +++ b/python/snips_nlu_parsers/utils.py @@ -0,0 +1,54 @@ +import shutil +from _ctypes import POINTER, Structure, byref +from contextlib import contextmanager +from ctypes import c_char_p, c_int32, cdll, string_at +from pathlib import Path +from tempfile import mkdtemp + +PACKAGE_PATH = Path(__file__).absolute().parent + +dylib_dir = PACKAGE_PATH / "dylib" +dylib_path = list(dylib_dir.glob("libsnips_nlu_parsers_rs*"))[0] +lib = cdll.LoadLibrary(str(dylib_path)) + + +@contextmanager +def string_array_pointer(ptr): + try: + yield ptr + finally: + lib.snips_nlu_ontology_destroy_string_array(ptr) + + +@contextmanager +def string_pointer(ptr): + try: + yield ptr + finally: + lib.snips_nlu_ontology_destroy_string(ptr) + + +class CStringArray(Structure): + _fields_ = [ + ("data", POINTER(c_char_p)), + ("size", c_int32) + ] + + +@contextmanager +def temp_dir(): + tmp_dir = mkdtemp() + try: + yield Path(tmp_dir) + finally: + shutil.rmtree(tmp_dir) + + +def check_ffi_error(exit_code, error_context_msg): + if exit_code != 0: + with string_pointer(c_char_p()) as ptr: + if lib.snips_nlu_parsers_get_last_error(byref(ptr)) == 0: + ffi_error_message = string_at(ptr).decode("utf8") + else: + ffi_error_message = "see stderr" + raise ValueError("%s: %s" % (error_context_msg, ffi_error_message)) diff --git a/python/tox.ini b/python/tox.ini new file mode 100644 index 0000000..539cc03 --- /dev/null +++ b/python/tox.ini @@ -0,0 +1,13 @@ +[tox] 
+envlist = py27, py34, py35, py36, py37
+
+[testenv]
+usedevelop = true
+skip_install = true
+commands =
+    pip install -r requirements.txt
+    pip install --verbose -e . --install-option="--verbose"
+    python -m unittest discover
+setenv=
+    LANG=en_US.UTF-8
+    PYTHONIOENCODING=UTF-8
diff --git a/src/builtin_entity_parser.rs b/src/builtin_entity_parser.rs
new file mode 100644
index 0000000..f3c002a
--- /dev/null
+++ b/src/builtin_entity_parser.rs
@@ -0,0 +1,495 @@
+use std::fs;
+use std::ops::Range;
+use std::path::{Path, PathBuf};
+use std::str::FromStr;
+
+use crate::conversion::*;
+use crate::errors::*;
+use crate::gazetteer_parser::GazetteerParser;
+use crate::utils::{get_ranges_mapping, NON_SPACE_REGEX, NON_SPACE_SEPARATED_LANGUAGES};
+use failure::{format_err, ResultExt};
+use itertools::Itertools;
+use rustling_ontology::{build_parser, OutputKind, Parser as RustlingParser, ResolverContext};
+use serde_derive::{Deserialize, Serialize};
+use serde_json;
+use snips_nlu_ontology::*;
+use snips_nlu_utils::string::{convert_to_byte_range, convert_to_char_index};
+
+pub struct BuiltinEntityParser {
+    gazetteer_parser: Option<GazetteerParser<BuiltinGazetteerEntityKind>>,
+    rustling_parser: RustlingParser,
+    language: Language,
+    rustling_entity_kinds: Vec<BuiltinEntityKind>,
+}
+
+#[derive(Serialize, Deserialize)]
+pub struct BuiltinEntityParserLoader {
+    language: Language,
+    gazetteer_parser_path: Option<PathBuf>,
+}
+
+impl BuiltinEntityParserLoader {
+    pub fn new(language: Language) -> Self {
+        BuiltinEntityParserLoader {
+            language,
+            gazetteer_parser_path: None,
+        }
+    }
+
+    pub fn use_gazetter_parser<P: AsRef<Path>>(&mut self, parser_path: P) -> &mut Self {
+        self.gazetteer_parser_path = Some(parser_path.as_ref().to_path_buf());
+        self
+    }
+
+    pub fn load(&self) -> Result<BuiltinEntityParser> {
+        let supported_entity_kinds = BuiltinEntityKind::supported_entity_kinds(self.language);
+        let ordered_entity_kinds = OutputKind::all()
+            .iter()
+            .map(|output_kind| output_kind.ontology_into())
+            .filter(|builtin_entity_kind| supported_entity_kinds.contains(&builtin_entity_kind))
+            .collect();
+        let rustling_parser = build_parser(self.language.ontology_into()).map_err(|_| {
+            format_err!(
+                "Cannot create Rustling Parser for language {:?}",
+                self.language
+            )
+        })?;
+        let gazetteer_parser = match &self.gazetteer_parser_path {
+            Some(parser_path) => Some(GazetteerParser::from_path(parser_path)?),
+            None => None,
+        };
+
+        Ok(BuiltinEntityParser {
+            gazetteer_parser,
+            rustling_parser,
+            language: self.language,
+            rustling_entity_kinds: ordered_entity_kinds,
+        })
+    }
+}
+
+impl BuiltinEntityParser {
+    pub fn extract_entities(
+        &self,
+        sentence: &str,
+        filter_entity_kinds: Option<&[BuiltinEntityKind]>,
+    ) -> Result<Vec<BuiltinEntity>> {
+        if NON_SPACE_SEPARATED_LANGUAGES.contains(&self.language) {
+            self._extract_entities_for_non_space_separated(sentence, filter_entity_kinds)
+        } else {
+            self._extract_entities(sentence, filter_entity_kinds)
+        }
+    }
+
+    fn _extract_entities(
+        &self,
+        sentence: &str,
+        filter_entity_kinds: Option<&[BuiltinEntityKind]>,
+    ) -> Result<Vec<BuiltinEntity>> {
+        let context = ResolverContext::default();
+        let rustling_output_kinds = self
+            .rustling_entity_kinds
+            .iter()
+            .filter(|entity_kind| {
+                filter_entity_kinds
+                    .map(|kinds| kinds.contains(&entity_kind))
+                    .unwrap_or(true)
+            })
+            .flat_map(|kind| kind.try_ontology_into().ok())
+            .collect::<Vec<OutputKind>>();
+
+        let rustling_entities = if rustling_output_kinds.is_empty() {
+            vec![]
+        } else {
+            self.rustling_parser
+                .parse_with_kind_order(&sentence.to_lowercase(), &context, &rustling_output_kinds)
+                .unwrap_or_else(|_| vec![])
+                .into_iter()
+                .map(|parser_match| rustling::convert_to_builtin(sentence, parser_match))
+                .sorted_by(|a, b| Ord::cmp(&a.range.start, &b.range.start))
+        };
+
+        let mut gazetteer_entities = match &self.gazetteer_parser {
+            Some(gazetteer_parser) => {
+                let gazetteer_entity_kinds: Option<Vec<BuiltinGazetteerEntityKind>> =
+                    filter_entity_kinds.map(|kinds| {
+                        kinds
+                            .into_iter()
+                            .flat_map(|kind| kind.try_into_gazetteer_kind().ok())
+                            .collect()
+                    });
+                gazetteer_parser.extract_builtin_entities(
+                    sentence,
+                    gazetteer_entity_kinds.as_ref().map(|kinds| &**kinds),
+                )?
+            }
+            None => vec![],
+        };
+
+        let mut entities = rustling_entities;
+        entities.append(&mut gazetteer_entities);
+        Ok(entities)
+    }
+
+    pub fn _extract_entities_for_non_space_separated(
+        &self,
+        sentence: &str,
+        filter_entity_kinds: Option<&[BuiltinEntityKind]>,
+    ) -> Result<Vec<BuiltinEntity>> {
+        let original_tokens_bytes_ranges: Vec<Range<usize>> = NON_SPACE_REGEX
+            .find_iter(sentence)
+            .map(|m| m.start()..m.end())
+            .collect();
+
+        let joined_sentence = original_tokens_bytes_ranges
+            .iter()
+            .map(|r| &sentence[r.clone()])
+            .join("");
+
+        if original_tokens_bytes_ranges.is_empty() {
+            return Ok(vec![]);
+        }
+
+        let ranges_mapping = get_ranges_mapping(&original_tokens_bytes_ranges);
+
+        Ok(self
+            ._extract_entities(&*joined_sentence, filter_entity_kinds)?
+            .into_iter()
+            .filter_map(|ent| {
+                let byte_range = convert_to_byte_range(&*joined_sentence, &ent.range);
+                let start = byte_range.start;
+                let end = byte_range.end;
+                // Check if the match range corresponds to original tokens, otherwise skip
+                // the entity
+                if (start == 0 as usize || ranges_mapping.contains_key(&start))
+                    && (ranges_mapping.contains_key(&end))
+                {
+                    let start_token_index = if start == 0 as usize {
+                        0 as usize
+                    } else {
+                        ranges_mapping[&start] + 1
+                    };
+                    let end_token_index = ranges_mapping[&end];
+
+                    let original_start = original_tokens_bytes_ranges[start_token_index].start;
+                    let original_end = original_tokens_bytes_ranges[end_token_index].end;
+                    let value = sentence[original_start..original_end].to_string();
+
+                    let original_ent = BuiltinEntity {
+                        value,
+                        range: convert_to_char_index(&sentence, original_start)
+                            ..convert_to_char_index(&sentence, original_end),
+                        entity: ent.entity,
+                        entity_kind: ent.entity_kind,
+                    };
+                    Some(original_ent)
+                } else {
+                    None
+                }
+            })
+            .collect())
+    }
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct BuiltinParserMetadata {
+    pub language: String,
+    pub gazetteer_parser: Option<String>,
+}
+
+impl BuiltinEntityParser {
+    pub fn persist<P: AsRef<Path>>(&self, path: P) -> Result<()> {
+        fs::create_dir(path.as_ref()).with_context(|_| {
+            format!(
+                "Cannot create builtin entity parser directory at path: {:?}",
+                path.as_ref()
+            )
+        })?;
+        let gazetteer_parser_directory = if let Some(ref gazetteer_parser) = self.gazetteer_parser {
+            let gazetteer_parser_path = path.as_ref().join("gazetteer_entity_parser");
+            gazetteer_parser.persist(gazetteer_parser_path)?;
+            Some("gazetteer_entity_parser".to_string())
+        } else {
+            None
+        };
+        let gazetteer_parser_metadata = BuiltinParserMetadata {
+            language: self.language.to_string(),
+            gazetteer_parser: gazetteer_parser_directory,
+        };
+        let metadata_path = path.as_ref().join("metadata.json");
+        let metadata_file = fs::File::create(&metadata_path).with_context(|_| {
+            format!("Cannot create metadata file at path: {:?}", metadata_path)
+        })?;
+        serde_json::to_writer_pretty(metadata_file, &gazetteer_parser_metadata)
+            .with_context(|_| "Cannot serialize builtin parser metadata")?;
+        Ok(())
+    }
+
+    pub fn from_path<P: AsRef<Path>>(path: P) -> Result<Self> {
+        let metadata_path = path.as_ref().join("metadata.json");
+        let metadata_file = fs::File::open(&metadata_path).with_context(|_| {
+            format!(
+                "Cannot open builtin parser metadata file at path: {:?}",
+                metadata_path
+            )
+        })?;
+        let metadata: BuiltinParserMetadata = serde_json::from_reader(metadata_file)
+            .with_context(|_| "Cannot deserialize builtin parser metadata")?;
+        let language = Language::from_str(&metadata.language)?;
+        let mut parser_loader = BuiltinEntityParserLoader::new(language);
+        if let Some(gazetteer_parser_dir) = metadata.gazetteer_parser {
+            let gazetteer_parser_path = path.as_ref().join(&gazetteer_parser_dir);
+            parser_loader.use_gazetter_parser(gazetteer_parser_path);
+        }
+        parser_loader.load()
+    }
+}
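
For reference, the metadata.json written by persist above is a small JSON document. A sketch of how it can be inspected from Python (the directory name is hypothetical, and the exact language code follows the ontology's string representation):

    import json

    with open("builtin_entity_parser/metadata.json") as f:
        metadata = json.load(f)

    # e.g. {"language": "en", "gazetteer_parser": "gazetteer_entity_parser"},
    # with "gazetteer_parser" set to null when no gazetteer parser is attached
    print(metadata["language"], metadata.get("gazetteer_parser"))
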
+
+#[cfg(test)]
+mod test {
+    use super::*;
+
+    use crate::test_utils::test_path;
+    use snips_nlu_ontology::language::Language;
+    use snips_nlu_ontology::IntoBuiltinEntityKind;
+    use snips_nlu_ontology::SlotValue::InstantTime;
+    use tempfile::tempdir;
+
+    #[test]
+    fn test_entities_extraction() {
+        let parser = BuiltinEntityParserLoader::new(Language::EN).load().unwrap();
+        assert_eq!(
+            vec![BuiltinEntityKind::Number, BuiltinEntityKind::Time],
+            parser
+                .extract_entities("Book me restaurant for two people tomorrow", None)
+                .unwrap()
+                .iter()
+                .map(|e| e.entity_kind)
+                .collect_vec()
+        );
+
+        assert_eq!(
+            vec![BuiltinEntityKind::Duration],
+            parser
+                .extract_entities("The weather during two weeks", None)
+                .unwrap()
+                .iter()
+                .map(|e| e.entity_kind)
+                .collect_vec()
+        );
+
+        assert_eq!(
+            vec![BuiltinEntityKind::Percentage],
+            parser
+                .extract_entities("Set light to ten percents", None)
+                .unwrap()
+                .iter()
+                .map(|e| e.entity_kind)
+                .collect_vec()
+        );
+
+        assert_eq!(
+            vec![BuiltinEntityKind::AmountOfMoney],
+            parser
+                .extract_entities(
+                    "I would like to do a bank transfer of ten euros for my friends",
+                    None,
+                )
+                .unwrap()
+                .iter()
+                .map(|e| e.entity_kind)
+                .collect_vec()
+        );
+    }
+
+    #[test]
+    fn test_entities_extraction_with_empty_scope() {
+        let parser = BuiltinEntityParserLoader::new(Language::EN).load().unwrap();
+        let entities = parser
+            .extract_entities("tomorrow morning", Some(&[]))
+            .unwrap();
+        assert_eq!(Vec::<BuiltinEntity>::new(), entities);
+    }
+
+    #[test]
+    fn test_entities_extraction_with_gazetteer_entities() {
+        // Given
+        let language = Language::FR;
+        let parser = BuiltinEntityParserLoader::new(language)
+            .use_gazetter_parser(test_path().join("builtin_gazetteer_parser"))
+            .load()
+            .unwrap();
+
+        // When
+        let above_threshold_entity = parser
+            .extract_entities("Je voudrais écouter the stones s'il vous plaît", None)
+            .unwrap();
+        let below_threshold_entity = parser
+            .extract_entities("Je voudrais écouter les stones", None)
+            .unwrap();
+
+        // Then
+        let expected_entity = BuiltinEntity {
+            value: "the stones".to_string(),
+            range: 20..30,
+            entity: SlotValue::MusicArtist(StringValue {
+                value: "The Rolling Stones".to_string(),
+            }),
+            entity_kind: BuiltinEntityKind::MusicArtist,
+        };
+        assert_eq!(vec![expected_entity], above_threshold_entity);
+        assert_eq!(Vec::<BuiltinEntity>::new(), below_threshold_entity);
+    }
+
+    #[test]
+    fn test_entities_extraction_for_non_space_separated_languages() {
+        let parser = BuiltinEntityParserLoader::new(Language::JA).load().unwrap();
+        let expected_time_value = InstantTimeValue {
+            value: "2013-02-10 00:00:00 +01:00".to_string(),
+            grain: Grain::Day,
+            precision: Precision::Exact,
+        };
+
+        let expected_entity = BuiltinEntity {
+            value: "二 千 十三 年二 月十 日".to_string(),
+            range: 10..24,
+            entity_kind: BuiltinEntityKind::Time,
+            entity: InstantTime(expected_time_value.clone()),
+        };
+
+        let parsed_entities = parser.extract_entities(
+            " の カリフォル 二 千 十三 年二 月十 日 ニア州の天気予報は?",
+            None,
+        ).unwrap();
+        assert_eq!(1, parsed_entities.len());
+        let parsed_entity = &parsed_entities[0];
+        assert_eq!(expected_entity.value, parsed_entity.value);
+        assert_eq!(expected_entity.range, parsed_entity.range);
+        assert_eq!(expected_entity.entity_kind, parsed_entity.entity_kind);
+
+        if let SlotValue::InstantTime(ref parsed_time) = parsed_entity.entity {
+            assert_eq!(expected_time_value.grain, parsed_time.grain);
+            assert_eq!(expected_time_value.precision, parsed_time.precision);
+        } else {
+            panic!("expected an InstantTime slot value")
+        }
+
+        assert_eq!(
+            Vec::<BuiltinEntity>::new(),
+            parser.extract_entities(
+                "二 千 十三 年二 月十 日の カリフォルニア州の天気予報は?",
+                None,
+            ).unwrap()
+        );
+    }
+
+    #[test]
+    fn test_entity_examples_should_be_parsed() {
+        for language in Language::all() {
+            let parser = BuiltinEntityParserLoader::new(*language).load().unwrap();
+            for entity_kind in GrammarEntityKind::all() {
+                for example in (*entity_kind).examples(*language) {
+                    let results = parser
+                        .extract_entities(example, Some(&[entity_kind.into_builtin_kind()]))
+                        .unwrap();
+                    assert_eq!(
+                        1,
+                        results.len(),
+                        "Expected 1 result for entity kind '{:?}' in language '{:?}' for example \
+                         {:?}, but found: {:?}",
+                        entity_kind,
+                        language,
+                        example,
+                        results
+                    );
+                    assert_eq!(example.to_string(), results[0].value);
+                }
+            }
+        }
+    }
+
+    #[test]
+    fn test_should_persist_parser() {
+        // Given
+        let language = Language::FR;
+        let parser = BuiltinEntityParserLoader::new(language).load().unwrap();
+
+        let temp_dir = tempdir().unwrap();
+        let parser_dir = temp_dir.path().join("builtin_entity_parser");
+
+        // When
+        parser.persist(&parser_dir).unwrap();
+        let loaded_parser = BuiltinEntityParser::from_path(&parser_dir).unwrap();
+
+        // Then
+        assert_eq!(parser.language, loaded_parser.language);
+        assert_eq!(None, loaded_parser.gazetteer_parser);
+        assert_eq!(
+            parser.rustling_entity_kinds,
+            loaded_parser.rustling_entity_kinds
+        );
+    }
+
+    #[test]
+    fn test_should_load_parser_from_path() {
+        // Given
+        let parser_path = test_path().join("builtin_entity_parser_no_gazetteer");
+
+        // When
+        let parser = BuiltinEntityParser::from_path(parser_path).unwrap();
+
+        // Then
+        let expected_parser = BuiltinEntityParserLoader::new(Language::EN).load().unwrap();
+        assert_eq!(expected_parser.language, parser.language);
+        assert_eq!(expected_parser.gazetteer_parser, parser.gazetteer_parser);
+        assert_eq!(
+            expected_parser.rustling_entity_kinds,
+            parser.rustling_entity_kinds
+        );
+    }
+
+    #[test]
+    fn test_should_persist_parser_with_gazetteer_entities() {
+        // Given
+        let language = Language::FR;
+        let parser = BuiltinEntityParserLoader::new(language)
+            .use_gazetter_parser(test_path().join("builtin_gazetteer_parser"))
+            .load()
+            .unwrap();
+
+        let temp_dir = tempdir().unwrap();
+        let parser_dir = temp_dir.path().join("builtin_entity_parser");
+
+        // When
+        parser.persist(&parser_dir).unwrap();
+        let loaded_parser = BuiltinEntityParser::from_path(&parser_dir).unwrap();
+
+        // Then
+        assert_eq!(parser.language, loaded_parser.language);
+        assert_eq!(parser.gazetteer_parser, loaded_parser.gazetteer_parser);
+        assert_eq!(
+            parser.rustling_entity_kinds,
+            loaded_parser.rustling_entity_kinds
+        );
+    }
+
+    #[test]
+    fn test_should_load_parser_with_gazetteer_entities_from_path() {
+        // Given
+        let parser_path = test_path().join("builtin_entity_parser");
+
+        // When
+        let parser = BuiltinEntityParser::from_path(parser_path).unwrap();
+
+        // Then
+        let expected_parser = BuiltinEntityParserLoader::new(Language::FR)
+            .use_gazetter_parser(test_path().join("builtin_gazetteer_parser"))
+            .load()
+            .unwrap();
+        assert_eq!(expected_parser.language, parser.language);
+        assert_eq!(expected_parser.gazetteer_parser, parser.gazetteer_parser);
+        assert_eq!(
+            expected_parser.rustling_entity_kinds,
+            parser.rustling_entity_kinds
+        );
+    }
+}
diff --git a/src/conversion/gazetteer_entities.rs b/src/conversion/gazetteer_entities.rs
new file mode 100644
index 0000000..17df134
--- /dev/null
+++ b/src/conversion/gazetteer_entities.rs
@@ -0,0 +1,18 @@
+use snips_nlu_ontology::{BuiltinGazetteerEntityKind, SlotValue, StringValue};
+
+pub fn convert_to_slot_value(
+    resolved_value: String,
+    entity_kind: BuiltinGazetteerEntityKind,
+) -> SlotValue {
+    macro_rules! match_entity_kind_to_slot_value {
+        ($($varname:ident),*) => {
+            match entity_kind {
+                $(
+                    BuiltinGazetteerEntityKind::$varname => SlotValue::$varname(
+                        StringValue {value: resolved_value}),
+                )*
+            }
+        }
+    };
+    return match_entity_kind_to_slot_value!(MusicAlbum, MusicArtist, MusicTrack);
+}
diff --git a/src/conversion/mod.rs b/src/conversion/mod.rs
new file mode 100644
index 0000000..64d88eb
--- /dev/null
+++ b/src/conversion/mod.rs
@@ -0,0 +1,49 @@
+pub mod gazetteer_entities;
+pub mod rustling;
+
+use crate::errors::*;
+
+pub trait OntologyFrom<T> {
+    fn ontology_from(_: T) -> Self;
+}
+
+pub trait OntologyInto<T> {
+    fn ontology_into(self) -> T;
+}
+
+pub trait TryOntologyInto<T>: Sized {
+    /// Performs the conversion.
+    fn try_ontology_into(self) -> Result<T>;
+}
+
+/// Attempt to construct `Self` via a conversion.
+pub trait TryOntologyFrom<T>: Sized {
+    /// Performs the conversion.
+    fn try_ontology_from(value: T) -> Result<Self>;
+}
+
+impl<T, U> OntologyInto<U> for T
+where
+    U: OntologyFrom<T>,
+{
+    fn ontology_into(self) -> U {
+        U::ontology_from(self)
+    }
+}
+
+// From (and thus Into) is reflexive
+impl<T> OntologyFrom<T> for T {
+    fn ontology_from(t: T) -> T {
+        t
+    }
+}
+
+// TryFrom implies TryInto
+impl<T, U> TryOntologyInto<U> for T
+where
+    U: TryOntologyFrom<T>,
+{
+    fn try_ontology_into(self) -> Result<U> {
+        U::try_ontology_from(self)
+    }
+}
diff --git a/src/conversion/rustling.rs b/src/conversion/rustling.rs
new file mode 100644
index 0000000..2a4a9fd
--- /dev/null
+++ b/src/conversion/rustling.rs
@@ -0,0 +1,242 @@
+use crate::conversion::*;
+use crate::errors::Result;
+use failure::format_err;
+use rustling_ontology::dimension::Precision as RustlingPrecision;
+use rustling_ontology::output::{
+    AmountOfMoneyOutput, DurationOutput, FloatOutput, IntegerOutput, OrdinalOutput, Output,
+    OutputKind, PercentageOutput, TemperatureOutput, TimeIntervalOutput, TimeOutput,
+};
+use rustling_ontology::Grain as RustlingGrain;
+use rustling_ontology::Lang as RustlingLanguage;
+use rustling_ontology::ParserMatch;
+use snips_nlu_ontology::*;
+
+impl OntologyFrom<IntegerOutput> for NumberValue {
+    fn ontology_from(rustling_output: IntegerOutput) -> Self {
+        Self {
+            value: rustling_output.0 as f64,
+        }
+    }
+}
+
+impl OntologyFrom<FloatOutput> for NumberValue {
+    fn ontology_from(rustling_output: FloatOutput) -> Self {
+        Self {
+            value: rustling_output.0 as f64,
+        }
+    }
+}
+
+impl OntologyFrom<OrdinalOutput> for OrdinalValue {
+    fn ontology_from(rustling_output: OrdinalOutput) -> Self {
+        Self {
+            value: rustling_output.0,
+        }
+    }
+}
+
+impl OntologyFrom<PercentageOutput> for PercentageValue {
+    fn ontology_from(rustling_output: PercentageOutput) -> Self {
+        Self {
+            value: rustling_output.0 as f64,
+        }
+    }
+}
+
+impl OntologyFrom<TimeOutput> for InstantTimeValue {
+    fn ontology_from(rustling_output: TimeOutput) -> Self {
+        Self {
+            value: rustling_output.moment.to_string(),
+            grain: Grain::ontology_from(rustling_output.grain),
+            precision: Precision::ontology_from(rustling_output.precision),
+        }
+    }
+}
+
+impl OntologyFrom<TimeIntervalOutput> for TimeIntervalValue {
+    fn ontology_from(rustling_output: TimeIntervalOutput) -> Self {
+        match rustling_output {
+            TimeIntervalOutput::After(after) => Self {
+                from: Some(after.moment.to_string()),
+                to: None,
+            },
+            TimeIntervalOutput::Before(before) => Self {
+                from: None,
+                to: Some(before.moment.to_string()),
+            },
+            TimeIntervalOutput::Between {
+                start,
+                end,
+                precision: _,
+                latent: _,
+            } => Self {
+                from: Some(start.to_string()),
+                to: Some(end.to_string()),
+            },
+        }
+    }
+}
+
+impl OntologyFrom<AmountOfMoneyOutput> for AmountOfMoneyValue {
+    fn ontology_from(rustling_output: AmountOfMoneyOutput) -> Self {
+        Self {
+            value: rustling_output.value,
+            precision: rustling_output.precision.ontology_into(),
+            unit: rustling_output.unit.map(|s| s.to_string()),
+        }
+    }
+}
+
+impl OntologyFrom<TemperatureOutput> for TemperatureValue {
+    fn ontology_from(rustling_output: TemperatureOutput) -> Self {
+        Self {
+            value: rustling_output.value,
+            unit: rustling_output.unit.map(|s| s.to_string()),
+        }
+    }
+}
+
+impl OntologyFrom<DurationOutput> for DurationValue {
+    fn ontology_from(rustling_output: DurationOutput) -> Self {
+        let mut years: i64 = 0;
+        let mut quarters: i64 = 0;
+        let mut months: i64 = 0;
+        let mut weeks: i64 = 0;
+        let mut days: i64 = 0;
+        let mut hours: i64 = 0;
+        let mut minutes: i64 = 0;
+        let mut seconds: i64 = 0;
+        for comp in rustling_output.period.comps().iter() {
+            match comp.grain {
+                RustlingGrain::Year => years = comp.quantity,
+                RustlingGrain::Quarter => quarters = comp.quantity,
+                RustlingGrain::Month => months = comp.quantity,
+                RustlingGrain::Week => weeks = comp.quantity,
+                RustlingGrain::Day => days = comp.quantity,
+                RustlingGrain::Hour => hours = comp.quantity,
+                RustlingGrain::Minute => minutes = comp.quantity,
+                RustlingGrain::Second => seconds = comp.quantity,
+            }
+        }
+
+        Self {
+            years,
+            quarters,
+            months,
+            weeks,
+            days,
+            hours,
+            minutes,
+            seconds,
+            precision: rustling_output.precision.ontology_into(),
+        }
+    }
+}
+
+impl OntologyFrom<RustlingGrain> for Grain {
+    fn ontology_from(rustling_output: RustlingGrain) -> Self {
+        match rustling_output {
+            RustlingGrain::Year => Grain::Year,
+            RustlingGrain::Quarter => Grain::Quarter,
+            RustlingGrain::Month => Grain::Month,
+            RustlingGrain::Week => Grain::Week,
+            RustlingGrain::Day => Grain::Day,
+            RustlingGrain::Hour => Grain::Hour,
+            RustlingGrain::Minute => Grain::Minute,
+            RustlingGrain::Second => Grain::Second,
+        }
+    }
+}
+
+impl OntologyFrom<RustlingPrecision> for Precision {
+    fn ontology_from(rustling_output: RustlingPrecision) -> Self {
+        match rustling_output {
+            RustlingPrecision::Approximate => Precision::Approximate,
+            RustlingPrecision::Exact => Precision::Exact,
+        }
+    }
+}
+
+impl OntologyFrom<Output> for SlotValue {
+    fn ontology_from(rustling_output: Output) -> Self {
+        match rustling_output {
+            Output::AmountOfMoney(v) => SlotValue::AmountOfMoney(v.ontology_into()),
+            Output::Percentage(v) => SlotValue::Percentage(v.ontology_into()),
+            Output::Duration(v) => SlotValue::Duration(v.ontology_into()),
+            Output::Float(v) => SlotValue::Number(v.ontology_into()),
+            Output::Integer(v) => SlotValue::Number(v.ontology_into()),
+            Output::Ordinal(v) => SlotValue::Ordinal(v.ontology_into()),
+            Output::Temperature(v) => SlotValue::Temperature(v.ontology_into()),
+            Output::Time(v) => SlotValue::InstantTime(v.ontology_into()),
+            Output::TimeInterval(v) => SlotValue::TimeInterval(v.ontology_into()),
+        }
+    }
+}
+
+pub fn convert_to_builtin(input: &str, parser_match: ParserMatch<Output>) -> BuiltinEntity {
+    BuiltinEntity {
+        value: input[parser_match.byte_range.0..parser_match.byte_range.1].into(),
+        range: parser_match.char_range.0..parser_match.char_range.1,
+        entity: parser_match.value.clone().ontology_into(),
+        entity_kind: BuiltinEntityKind::ontology_from(&parser_match.value),
+    }
+}
+
+impl<'a> OntologyFrom<&'a Output> for BuiltinEntityKind {
+    fn ontology_from(v: &Output) -> Self {
+        match *v {
+            Output::AmountOfMoney(_) => BuiltinEntityKind::AmountOfMoney,
+            Output::Duration(_) => BuiltinEntityKind::Duration,
+            Output::Float(_) => BuiltinEntityKind::Number,
+            Output::Integer(_) => BuiltinEntityKind::Number,
+            Output::Ordinal(_) => BuiltinEntityKind::Ordinal,
+            Output::Temperature(_) => BuiltinEntityKind::Temperature,
+            Output::Time(_) => BuiltinEntityKind::Time,
+            Output::TimeInterval(_) => BuiltinEntityKind::Time,
+            Output::Percentage(_) => BuiltinEntityKind::Percentage,
+        }
+    }
+}
+
+impl<'a> OntologyFrom<&'a OutputKind> for BuiltinEntityKind {
+    fn ontology_from(v: &OutputKind) -> Self {
+        match *v {
+            OutputKind::AmountOfMoney => BuiltinEntityKind::AmountOfMoney,
+            OutputKind::Duration => BuiltinEntityKind::Duration,
+            OutputKind::Number => BuiltinEntityKind::Number,
+            OutputKind::Ordinal => BuiltinEntityKind::Ordinal,
+            OutputKind::Temperature => BuiltinEntityKind::Temperature,
+            OutputKind::Time => BuiltinEntityKind::Time,
+            OutputKind::Percentage => BuiltinEntityKind::Percentage,
+        }
+    }
+}
+
+impl<'a> TryOntologyFrom<&'a BuiltinEntityKind> for OutputKind {
+    fn try_ontology_from(v: &BuiltinEntityKind) -> Result<Self> {
+        match *v {
+            BuiltinEntityKind::AmountOfMoney => Ok(OutputKind::AmountOfMoney),
+            BuiltinEntityKind::Duration => Ok(OutputKind::Duration),
+            BuiltinEntityKind::Number => Ok(OutputKind::Number),
+            BuiltinEntityKind::Ordinal => Ok(OutputKind::Ordinal),
+            BuiltinEntityKind::Temperature => Ok(OutputKind::Temperature),
+            BuiltinEntityKind::Time => Ok(OutputKind::Time),
+            BuiltinEntityKind::Percentage => Ok(OutputKind::Percentage),
+            _ => Err(format_err!("Cannot convert {:?} into rustling type", v)),
+        }
+    }
+}
+
+impl OntologyFrom<Language> for RustlingLanguage {
+    fn ontology_from(lang: Language) -> Self {
+        match lang {
+            Language::DE => RustlingLanguage::DE,
+            Language::EN => RustlingLanguage::EN,
+            Language::ES => RustlingLanguage::ES,
+            Language::FR => RustlingLanguage::FR,
+            Language::IT => RustlingLanguage::IT,
+            Language::JA => RustlingLanguage::JA,
+            Language::KO => RustlingLanguage::KO,
+        }
+    }
+}
diff --git a/src/errors.rs b/src/errors.rs
new file mode 100644
index 0000000..3295431
--- /dev/null
+++ b/src/errors.rs
@@ -0,0 +1 @@
+pub type Result<T> = ::std::result::Result<T, ::failure::Error>;
diff --git a/src/gazetteer_parser.rs b/src/gazetteer_parser.rs
new file mode 100644
index 0000000..f6f9664
--- /dev/null
+++ b/src/gazetteer_parser.rs
@@ -0,0 +1,473 @@
+use std::fmt::Debug;
+use std::fs;
+use std::fs::File;
+use std::ops::Range;
+use std::path::Path;
+
+use crate::conversion::gazetteer_entities::convert_to_slot_value;
+use crate::errors::*;
+use failure::ResultExt;
+use gazetteer_entity_parser::{Parser as EntityParser, ParserBuilder as EntityParserBuilder};
+use serde::de::DeserializeOwned;
+use serde::Serialize;
+use serde_derive::{Deserialize, Serialize};
+use serde_json;
+use snips_nlu_ontology::{BuiltinEntity, BuiltinGazetteerEntityKind, IntoBuiltinEntityKind};
+use snips_nlu_utils::string::substring_with_char_range;
+
+pub trait EntityIdentifier:
+    Clone + Debug + PartialEq + Serialize + DeserializeOwned + Sized
+{
+    fn try_from_identifier(identifier: String) -> Result<Self>;
+    fn into_identifier(self) -> String;
+}
+
+impl EntityIdentifier for String {
+    fn try_from_identifier(identifier: String) -> Result<Self> {
+        Ok(identifier)
+    }
+
+    fn into_identifier(self) -> String {
+        self
+    }
+}
+
+impl EntityIdentifier for BuiltinGazetteerEntityKind {
+    fn try_from_identifier(identifier: String) -> Result<Self> {
+        BuiltinGazetteerEntityKind::from_identifier(&identifier)
+    }
+
+    fn into_identifier(self) -> String {
+        self.identifier().to_string()
+    }
+}
+
+#[derive(PartialEq, Debug)]
+pub struct GazetteerParser<T>
+where
+    T: EntityIdentifier,
+{
+    entity_parsers: Vec<GazetteerEntityParser<T>>,
+}
+
+#[derive(PartialEq, Debug)]
+struct GazetteerEntityParser<T>
+where
+    T: EntityIdentifier,
+{
+    entity_identifier: T,
+    parser: EntityParser,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct GazetteerParserBuilder {
+    pub entity_parsers: Vec<GazetteerEntityParserBuilder>,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct GazetteerEntityParserBuilder {
+    pub entity_identifier: String,
+    pub entity_parser: EntityParserBuilder,
+}
+
+impl GazetteerParserBuilder {
+    pub fn build<T>(self) -> Result<GazetteerParser<T>>
+    where
+        T: EntityIdentifier,
+    {
+        let entity_parsers = self
+            .entity_parsers
+            .into_iter()
+            .map(|parser_builder| parser_builder.build())
+            .collect::<Result<Vec<_>>>()?;
+        Ok(GazetteerParser { entity_parsers })
+    }
+}
+
+impl GazetteerEntityParserBuilder {
+    fn build<T>(self) -> Result<GazetteerEntityParser<T>>
+    where
+        T: EntityIdentifier,
+    {
+        Ok(GazetteerEntityParser {
+            entity_identifier: T::try_from_identifier(self.entity_identifier)?,
+            parser: self.entity_parser.build()?,
+        })
+    }
+}
+
+#[derive(Debug, Clone, PartialEq, Serialize)]
+pub struct GazetteerEntityMatch<T>
+where
+    T: EntityIdentifier,
+{
+    pub value: String,
+    pub resolved_value: String,
+    pub range: Range<usize>,
+    pub entity_identifier: T,
+}
+
+impl<T> GazetteerParser<T>
+where
+    T: EntityIdentifier,
+{
+    pub fn extract_entities(
+        &self,
+        sentence: &str,
+        filter_entities: Option<&[T]>,
+    ) -> Result<Vec<GazetteerEntityMatch<T>>> {
+        Ok(self
+            .entity_parsers
+            .iter()
+            .filter(|&parser| {
+                filter_entities
+                    .map(|kinds| kinds.contains(&parser.entity_identifier))
+                    .unwrap_or(true)
+            })
+            .map(|parser| {
+                Ok(parser
+                    .parser
+                    .run(&sentence.to_lowercase())?
+                    .into_iter()
+                    .map(|parsed_value| GazetteerEntityMatch {
+                        value: substring_with_char_range(sentence.to_string(), &parsed_value.range),
+                        range: parsed_value.range,
+                        resolved_value: parsed_value.resolved_value,
+                        entity_identifier: parser.entity_identifier.clone(),
+                    })
+                    .collect::<Vec<_>>())
+            })
+            .collect::<Result<Vec<_>>>()?
+            .into_iter()
+            .flat_map(|v| v)
+            .collect())
+    }
+}
+
+impl GazetteerParser<BuiltinGazetteerEntityKind> {
+    pub fn extract_builtin_entities(
+        &self,
+        sentence: &str,
+        filter_entities: Option<&[BuiltinGazetteerEntityKind]>,
+    ) -> Result<Vec<BuiltinEntity>> {
+        Ok(self
+            .extract_entities(sentence, filter_entities)?
+            .into_iter()
+            .map(|entity_match| BuiltinEntity {
+                value: entity_match.value,
+                range: entity_match.range,
+                entity: convert_to_slot_value(
+                    entity_match.resolved_value,
+                    entity_match.entity_identifier,
+                ),
+                entity_kind: entity_match.entity_identifier.into_builtin_kind(),
+            })
+            .collect())
+    }
+}
+
+#[derive(Serialize, Deserialize, Default)]
+pub struct GazetteerParserMetadata {
+    pub parsers_metadata: Vec<EntityParserMetadata>,
+}
+
+#[derive(Serialize, Deserialize)]
+pub struct EntityParserMetadata {
+    pub entity_identifier: String,
+    pub entity_parser: String,
+}
+
+impl<T> GazetteerParser<T>
+where
+    T: EntityIdentifier,
+{
+    pub fn persist<P: AsRef<Path>>(&self, path: P) -> Result<()> {
+        fs::create_dir(path.as_ref()).with_context(|_| {
+            format!(
+                "Cannot create gazetteer parser directory at path: {:?}",
+                path.as_ref()
+            )
+        })?;
+        let mut gazetteer_parser_metadata = GazetteerParserMetadata::default();
+        for (index, entity_parser) in self.entity_parsers.iter().enumerate() {
+            let parser_directory = format!("parser_{}", index + 1);
+            let parser_path = path.as_ref().join(&parser_directory);
+            let entity_identifier = entity_parser.entity_identifier.clone().into_identifier();
+            entity_parser.parser.dump(parser_path).with_context(|_| {
+                format!(
+                    "Cannot dump entity parser for entity '{}'",
+                    &entity_identifier
+                )
+            })?;
+            gazetteer_parser_metadata
+                .parsers_metadata
+                .push(EntityParserMetadata {
+                    entity_identifier,
+                    entity_parser: parser_directory,
+                })
+        }
+        let metadata_path = path.as_ref().join("metadata.json");
+        let metadata_file = File::create(&metadata_path).with_context(|_| {
+            format!(
+                "Cannot create metadata file for gazetteer parser at path: {:?}",
+                metadata_path
+            )
+        })?;
+        serde_json::to_writer_pretty(metadata_file, &gazetteer_parser_metadata)
+            .with_context(|_| "Cannot serialize gazetteer parser metadata")?;
+        Ok(())
+    }
+}
+
+impl<T> GazetteerParser<T>
+where
+    T: EntityIdentifier,
+{
+    pub fn from_path<P: AsRef<Path>>(path: P) -> Result<Self> {
+        let metadata_path = path.as_ref().join("metadata.json");
+        let metadata_file = File::open(&metadata_path).with_context(|_| {
+            format!(
+                "Cannot open metadata file for gazetteer parser at path: {:?}",
+                metadata_path
+            )
+        })?;
+        let metadata: GazetteerParserMetadata = serde_json::from_reader(metadata_file)
+            .with_context(|_| "Cannot deserialize gazetteer parser metadata")?;
+        let entity_parsers = metadata
+            .parsers_metadata
+            .into_iter()
+            .map(|entity_parser_metadata| {
+                let parser = EntityParser::from_folder(
+                    path.as_ref().join(&entity_parser_metadata.entity_parser),
+                )
+                .with_context(|_| {
+                    format!(
+                        "Cannot create entity parser from path: {}",
+                        entity_parser_metadata.entity_parser
+                    )
+                })?;
+                Ok(GazetteerEntityParser {
+                    entity_identifier: T::try_from_identifier(
+                        entity_parser_metadata.entity_identifier,
+                    )?,
+                    parser,
+                })
+            })
+            .collect::<Result<Vec<_>>>()?;
+        Ok(Self { entity_parsers })
+    }
+}
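
The on-disk layout written by the persist method above pairs one directory per entity parser with a metadata.json that maps identifiers to directories. A Python sketch of how it can be inspected (the top-level directory name is hypothetical):

    import json

    # Hypothetical persisted layout:
    #   gazetteer_parser/
    #     metadata.json
    #     parser_1/
    #     parser_2/
    with open("gazetteer_parser/metadata.json") as f:
        metadata = json.load(f)

    # e.g. {"parsers_metadata": [
    #          {"entity_identifier": "music_artist", "entity_parser": "parser_1"},
    #          {"entity_identifier": "music_track", "entity_parser": "parser_2"}]}
    for parser_metadata in metadata["parsers_metadata"]:
        print(parser_metadata["entity_identifier"],
              "->", parser_metadata["entity_parser"])
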
+
+#[cfg(test)]
+mod test {
+    use super::*;
+    use crate::test_utils::test_path;
+    use gazetteer_entity_parser::EntityValue;
+    use gazetteer_entity_parser::ParserBuilder;
+    use snips_nlu_ontology::{
+        BuiltinEntityKind, BuiltinGazetteerEntityKind, SlotValue, StringValue,
+    };
+    use tempfile::tempdir;
+
+    fn get_test_custom_gazetteer_parser() -> GazetteerParser<String> {
+        let artist_entity_parser_builder = get_test_music_artist_parser_builder();
+        let track_entity_parser_builder = get_test_music_track_parser_builder();
+        let gazetteer_parser_builder = GazetteerParserBuilder {
+            entity_parsers: vec![
+                GazetteerEntityParserBuilder {
+                    entity_identifier: "music_artist".to_string(),
+                    entity_parser: artist_entity_parser_builder,
+                },
+                GazetteerEntityParserBuilder {
+                    entity_identifier: "music_track".to_string(),
+                    entity_parser: track_entity_parser_builder,
+                },
+            ],
+        };
+        gazetteer_parser_builder.build().unwrap()
+    }
+
+    fn get_test_builtin_gazetteer_parser() -> GazetteerParser<BuiltinGazetteerEntityKind> {
+        let artist_entity_parser_builder = get_test_music_artist_parser_builder();
+        let track_entity_parser_builder = get_test_music_track_parser_builder();
+        let gazetteer_parser_builder = GazetteerParserBuilder {
+            entity_parsers: vec![
+                GazetteerEntityParserBuilder {
+                    entity_identifier: "snips/musicArtist".to_string(),
+                    entity_parser: artist_entity_parser_builder,
+                },
+                GazetteerEntityParserBuilder {
+                    entity_identifier: "snips/musicTrack".to_string(),
+                    entity_parser: track_entity_parser_builder,
+                },
+            ],
+        };
+        gazetteer_parser_builder.build().unwrap()
+    }
+
+    fn get_test_music_track_parser_builder() -> ParserBuilder {
+        let track_entity_parser_builder = EntityParserBuilder::default()
+            .minimum_tokens_ratio(0.7)
+            .add_value(EntityValue {
+                raw_value: "harder better faster stronger".to_string(),
+                resolved_value: "Harder Better Faster Stronger".to_string(),
+            })
+            .add_value(EntityValue {
+                raw_value: "what s my age again".to_string(),
+                resolved_value: "What's my age again".to_string(),
+            });
+        track_entity_parser_builder
+    }
+
+    fn get_test_music_artist_parser_builder() -> ParserBuilder {
+        EntityParserBuilder::default()
+            .minimum_tokens_ratio(0.6)
+            .add_value(EntityValue {
+                raw_value: "the rolling stones".to_string(),
+                resolved_value: "The Rolling Stones".to_string(),
+            })
+            .add_value(EntityValue {
+                raw_value: "blink one eight two".to_string(),
+                resolved_value: "Blink 182".to_string(),
+            })
+    }
+
+    #[test]
+    fn test_should_parse_above_threshold() {
+        // Given
+        let gazetteer_parser = get_test_custom_gazetteer_parser();
+
+        // When
+        let input = "I want to listen to the track harder better faster please";
+        let result = gazetteer_parser.extract_entities(input, None);
+
+        // Then
+        let expected_match = GazetteerEntityMatch {
+            value: "harder better faster".to_string(),
+            resolved_value: "Harder Better Faster Stronger".to_string(),
+            range: 30..50,
+            entity_identifier: "music_track".to_string(),
+        };
+        assert_eq!(Some(vec![expected_match]), result.ok());
+    }
+
+    #[test]
+    fn test_should_not_parse_below_threshold() {
+        // Given
+        let gazetteer_parser = get_test_custom_gazetteer_parser();
+
+        // When
+        let input = "I want to listen to the track harder better please";
+        let result = gazetteer_parser.extract_entities(input, None);
+
+        // Then
+        assert_eq!(Some(vec![]), result.ok());
+    }
+
+    #[test]
+    fn test_should_parse_using_scope() {
+        // Given
+        let gazetteer_parser = get_test_custom_gazetteer_parser();
+
+        // When
+        let input = "I want to listen to what s my age again by blink one eight two";
+        let artist_scope: &[String] = &["music_artist".to_string()];
+        let result_artist = gazetteer_parser.extract_entities(input, Some(artist_scope));
+        let track_scope: &[String] = &["music_track".to_string()];
+        let result_track = gazetteer_parser.extract_entities(input, Some(track_scope));
+
+        // Then
+        let expected_artist_match = GazetteerEntityMatch {
+            value: "blink one eight two".to_string(),
+            resolved_value: "Blink 182".to_string(),
+            range: 43..62,
+            entity_identifier: "music_artist".to_string(),
+        };
+
+        let expected_track_match = GazetteerEntityMatch {
+            value: "what s my age again".to_string(),
+            resolved_value: "What's my age again".to_string(),
+            range: 20..39,
+            entity_identifier: "music_track".to_string(),
+        };
+        assert_eq!(Some(vec![expected_artist_match]), result_artist.ok());
+        assert_eq!(Some(vec![expected_track_match]), result_track.ok());
+    }
+
+    #[test]
+    fn test_should_parse_with_builtin_entities() {
+        // Given
+        let builtin_gazetteer_parser = get_test_builtin_gazetteer_parser();
+
+        // When
+        let input = "I want to listen to the track harder better faster please";
+        let result = builtin_gazetteer_parser.extract_builtin_entities(input, None);
+
+        // Then
+        let expected_match = BuiltinEntity {
+            value: "harder better faster".to_string(),
+            entity: SlotValue::MusicTrack(StringValue {
+                value: "Harder Better Faster Stronger".to_string(),
+            }),
+            range: 30..50,
+            entity_kind: BuiltinEntityKind::MusicTrack,
+        };
+        assert_eq!(Some(vec![expected_match]), result.ok());
+    }
+
+    #[test]
+    fn test_should_persist_custom_gazetteer_parser() {
+        // Given
+        let gazetteer_parser = get_test_custom_gazetteer_parser();
+        let temp_dir = tempdir().unwrap();
+        let parser_dir = temp_dir.path().join("custom_gazetteer_parser");
+
+        // When
+        gazetteer_parser.persist(&parser_dir).unwrap();
+        let loaded_gazetteer_parser = GazetteerParser::from_path(&parser_dir).unwrap();
+
+        // Then
+        assert_eq!(gazetteer_parser, loaded_gazetteer_parser);
+    }
+
+    #[test]
+    fn test_should_load_custom_gazetteer_parser_from_path() {
+        // Given
+        let path = test_path().join("custom_gazetteer_parser");
+
+        // When
+        let parser = GazetteerParser::from_path(path);
+
+        // Then
+        let expected_parser = get_test_custom_gazetteer_parser();
+        assert_eq!(Some(expected_parser), parser.ok());
+    }
+
+    #[test]
+    fn test_should_persist_builtin_gazetteer_parser() {
+        // Given
+        let gazetteer_parser = get_test_builtin_gazetteer_parser();
+        let temp_dir = tempdir().unwrap();
+        let parser_dir = temp_dir.path().join("builtin_gazetteer_parser");
+
+        // When
+        gazetteer_parser.persist(&parser_dir).unwrap();
+        let loaded_gazetteer_parser = GazetteerParser::from_path(&parser_dir).unwrap();
+
+        // Then
+        assert_eq!(gazetteer_parser, loaded_gazetteer_parser);
+    }
+
+    #[test]
+    fn test_should_load_builtin_gazetteer_parser_from_path() {
+        // Given
+        let path = test_path().join("builtin_gazetteer_parser");
+
+        // When
+        let parser = GazetteerParser::from_path(path);
+
+        // Then
+        let expected_parser = get_test_builtin_gazetteer_parser();
+        assert_eq!(Some(expected_parser), parser.ok());
+    }
+}
diff --git a/src/lib.rs b/src/lib.rs
new file mode 100644
index 0000000..c5fa1f5
--- /dev/null
+++ b/src/lib.rs
@@ -0,0 +1,13 @@
+pub mod errors;
+
+mod builtin_entity_parser;
+mod conversion;
+mod gazetteer_parser;
+#[cfg(test)]
+mod test_utils;
+mod utils;
+
+pub use builtin_entity_parser::*;
+pub use conversion::*;
+pub use gazetteer_parser::*;
+pub use snips_nlu_ontology::*;
diff --git a/src/test_utils.rs b/src/test_utils.rs
new file mode 100644
index 0000000..31790e6
--- /dev/null
+++ b/src/test_utils.rs
@@ -0,0 +1,5 @@
+use std::path::{Path, PathBuf};
+
+pub fn test_path() -> PathBuf {
+    Path::new("data").join("tests")
+}
diff --git a/src/utils.rs b/src/utils.rs
new file mode 100644
index 0000000..700ffdd
--- /dev/null
+++ b/src/utils.rs
@@ -0,0 +1,42 @@
+use std::collections::{HashMap, HashSet};
+use std::iter::FromIterator;
+use std::ops::Range;
+
+use lazy_static::lazy_static;
+use regex::Regex;
+use snips_nlu_ontology::Language;
+
+lazy_static! {
+    pub static ref NON_SPACE_REGEX: Regex = Regex::new(r"[^\s]+").unwrap();
+}
+
+lazy_static! {
+    pub static ref NON_SPACE_SEPARATED_LANGUAGES: HashSet<Language> =
+        [Language::JA].into_iter().cloned().collect();
+}
+
+pub fn get_ranges_mapping(tokens_ranges: &Vec<Range<usize>>) -> HashMap<usize, usize> {
+    /* Given token ranges, returns a mapping from byte index to token index.
+       The byte indexes correspond to the indexes of the token ends in the string obtained by
+       joining all the tokens. The token index gives the index of the token preceding the byte
+       index.
+
+       For instance, if ranges_mapping[65] -> 5, this means that the token of index 6 starts at
+       the 65th byte in the joined string.
+    */
+    let ranges_mapping = HashMap::<usize, usize>::from_iter(tokens_ranges.iter().enumerate().fold(
+        vec![],
+        |mut acc: Vec<(usize, usize)>, (token_index, ref original_range)| {
+            let previous_end = if token_index == 0 {
+                0 as usize
+            } else {
+                acc[acc.len() - 1].0
+            };
+            acc.push((
+                previous_end + original_range.end - original_range.start,
+                token_index,
+            ));
+            acc
+        },
+    ));
+    ranges_mapping
+}
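
A Python transcription of get_ranges_mapping with a worked example, to make the byte-to-token bookkeeping concrete (this mirrors the Rust helper above; it is illustrative and not part of the package):

    def get_ranges_mapping(tokens_ranges):
        # Map the byte index of each token's end in the joined string
        # to the index of that token
        ranges_mapping = {}
        previous_end = 0
        for token_index, token_range in enumerate(tokens_ranges):
            previous_end += token_range.stop - token_range.start
            ranges_mapping[previous_end] = token_index
        return ranges_mapping

    # Tokens "ab" and "cde" at bytes 0..2 and 3..6 of the original sentence:
    # in the joined string "abcde", token 0 ends at byte 2, token 1 at byte 5
    print(get_ranges_mapping([range(0, 2), range(3, 6)]))  # {2: 0, 5: 1}
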
diff --git a/update_ontology_version.sh b/update_ontology_version.sh
new file mode 100755
index 0000000..28fa793
--- /dev/null
+++ b/update_ontology_version.sh
@@ -0,0 +1,6 @@
+#!/usr/bin/env bash
+
+NEW_VERSION=${1?"usage $0 <new version>"}
+
+echo "Updating ontology version to ${NEW_VERSION}"
+find . -name "Cargo.toml" -exec perl -p -i -e "s/snipsco\/snips-nlu-ontology\".*\$/snipsco\/snips-nlu-ontology\", tag = \"$NEW_VERSION\" }/g" {} \;
diff --git a/update_version.sh b/update_version.sh
new file mode 100755
index 0000000..20d3800
--- /dev/null
+++ b/update_version.sh
@@ -0,0 +1,22 @@
+#!/usr/bin/env bash
+
+NEW_VERSION=${1?"usage $0 <new version>"}
+
+echo "Updating versions to version ${NEW_VERSION}"
+find . -name "Cargo.toml" -exec perl -p -i -e "s/^version = \".*\"$/version = \"$NEW_VERSION\"/g" {} \;
+
+
+if [[ "${NEW_VERSION}" == "${NEW_VERSION/-SNAPSHOT/}" ]]
+then
+    perl -p -i -e "s/snips-nlu-parsers\", tag = \".*\"/snips-nlu-parsers\", tag = \"$NEW_VERSION\"/g" \
+        python/ffi/Cargo.toml
+    perl -p -i -e "s/snips-nlu-parsers\", branch = \".*\"/snips-nlu-parsers\", tag = \"$NEW_VERSION\"/g" \
+        python/ffi/Cargo.toml
+else
+    perl -p -i -e "s/snips-nlu-parsers\", branch = \".*\"/snips-nlu-parsers\", branch = \"develop\"/g" \
+        python/ffi/Cargo.toml
+    perl -p -i -e "s/snips-nlu-parsers\", tag = \".*\"/snips-nlu-parsers\", branch = \"develop\"/g" \
+        python/ffi/Cargo.toml
+fi
+
+echo "$NEW_VERSION" > python/snips_nlu_parsers/__version__