From d6c0130477abd51b6fb021d0a57e4d6327a5fbed Mon Sep 17 00:00:00 2001
From: Harrison
Date: Mon, 29 May 2023 15:45:37 -0500
Subject: [PATCH] Initial commit

Moved training code from 'avarias'
---
 .gitignore      | 230 ++++++++++++++++++++++++++++++++++++++++++++++++
 environment.yml |   8 ++
 tox.ini         |   3 +
 trainers.py     | 121 +++++++++++++++++++++++++
 4 files changed, 362 insertions(+)
 create mode 100644 .gitignore
 create mode 100644 environment.yml
 create mode 100644 tox.ini
 create mode 100644 trainers.py

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..785917d
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,230 @@
+# File created using '.gitignore Generator' for Visual Studio Code: https://bit.ly/vscode-gig
+# Created by https://www.toptal.com/developers/gitignore/api/visualstudiocode,linux,python,venv,virtualenv
+# Edit at https://www.toptal.com/developers/gitignore?templates=visualstudiocode,linux,python,venv,virtualenv
+
+### Linux ###
+*~
+
+# temporary files which can be created if a process still has a handle open of a deleted file
+.fuse_hidden*
+
+# KDE directory preferences
+.directory
+
+# Linux trash folder which might appear on any partition or disk
+.Trash-*
+
+# .nfs files are created when an open file is removed but is still being accessed
+.nfs*
+
+### Python ###
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
+
+### Python Patch ###
+# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
+poetry.toml
+
+# ruff
+.ruff_cache/
+
+# LSP config files
+pyrightconfig.json
+
+### venv ###
+# Virtualenv
+# http://iamzed.com/2009/05/07/a-primer-on-virtualenv/
+[Bb]in
+[Ii]nclude
+[Ll]ib
+[Ll]ib64
+[Ll]ocal
+[Ss]cripts
+pyvenv.cfg
+pip-selfcheck.json
+
+### VirtualEnv ###
+# Virtualenv
+# http://iamzed.com/2009/05/07/a-primer-on-virtualenv/
+
+### VisualStudioCode ###
+.vscode/*
+!.vscode/settings.json
+!.vscode/tasks.json
+!.vscode/launch.json
+!.vscode/extensions.json
+!.vscode/*.code-snippets
+
+# Local History for Visual Studio Code
+.history/
+
+# Built Visual Studio Code Extensions
+*.vsix
+
+### VisualStudioCode Patch ###
+# Ignore all local history of files
+.history
+.ionide
+
+# End of https://www.toptal.com/developers/gitignore/api/visualstudiocode,linux,python,venv,virtualenv
+
+# Custom rules (everything added below won't be overriden by 'Generate .gitignore File' if you use 'Update' option)
+
diff --git a/environment.yml b/environment.yml
new file mode 100644
index 0000000..52f9baa
--- /dev/null
+++ b/environment.yml
@@ -0,0 +1,8 @@
+name: huggingface_transformers_utils
+channels:
+  - conda-forge
+  - huggingface
+dependencies:
+  - transformers=4.28
+  - pip:
+    - evaluate==0.4.0
diff --git a/tox.ini b/tox.ini
new file mode 100644
index 0000000..e0ea542
--- /dev/null
+++ b/tox.ini
@@ -0,0 +1,3 @@
+[flake8]
+max-line-length = 88
+extend-ignore = E203
\ No newline at end of file
diff --git a/trainers.py b/trainers.py
new file mode 100644
index 0000000..06e47ba
--- /dev/null
+++ b/trainers.py
@@ -0,0 +1,121 @@
+import numpy as np
+from transformers import (
+    AutoModelForTokenClassification,
+    AutoTokenizer,
+    DataCollatorForTokenClassification,
+    Trainer,
+    TrainingArguments,
+    EvalPrediction,
+)
+from typing import Union
+from datasets import load_dataset
+from datasets.dataset_dict import (
+    DatasetDict,
+    Dataset,
+    IterableDatasetDict,
+)
+from datasets.iterable_dataset import IterableDataset
+import evaluate
+
+
+class TokenClassificationTrainer:
+    def __init__(
+        self,
+        model: str,
+        dataset: str,
+        labels_name: str = "labels",
+        evaluator: str = "seqeval",
+    ) -> None:
+        self._dataset: Union[
+            DatasetDict, Dataset, IterableDatasetDict, IterableDataset
+        ] = load_dataset(dataset)
+        self._labels_name = labels_name
+        self._labels: list[str] = (
+            self._dataset["train"].features[labels_name].feature.names
+        )  # type: ignore
+        self._id_to_label: dict[int, str] = {}
+        self._label_to_id: dict[str, int] = {}
+        for label_id, label in enumerate(self._labels):
+            self._id_to_label[label_id] = label
+            self._label_to_id[label] = label_id
+        # Token classification needs a token-level head, not a sequence-level one.
+        self._model = AutoModelForTokenClassification.from_pretrained(
+            model,
+            num_labels=len(self._labels),
+            id2label=self._id_to_label,
+            label2id=self._label_to_id,
+        )
+        self._tokenizer = AutoTokenizer.from_pretrained(model)
+        self._data_collator = DataCollatorForTokenClassification(
+            tokenizer=self._tokenizer
+        )
+        self._evaluator = evaluate.load(evaluator)
+
+    def tokenize_and_align_labels(self, examples):
+        # Adapted from
+        # https://huggingface.co/docs/transformers/tasks/token_classification
+        tokenized_inputs = self._tokenizer(
+            examples["tokens"], truncation=True, is_split_into_words=True
+        )
+
+        labels = []
+        for i, label in enumerate(examples[self._labels_name]):
+            word_ids = tokenized_inputs.word_ids(
+                batch_index=i
+            )  # Map tokens to their respective word.
+            previous_word_idx = None
+            label_ids = []
+            for word_idx in word_ids:  # Set the special tokens to -100.
+                if word_idx is None:
+                    label_ids.append(-100)
+                elif (
+                    word_idx != previous_word_idx
+                ):  # Only label the first token of a given word.
+                    label_ids.append(label[word_idx])
+                else:
+                    label_ids.append(-100)
+                previous_word_idx = word_idx
+            labels.append(label_ids)
+
+        tokenized_inputs["labels"] = labels
+        return tokenized_inputs
+
+    def tokenize_and_align_labels_over_dataset(self):
+        return self._dataset.map(self.tokenize_and_align_labels, batched=True)
+
+    def compute_metrics(
+        self, evaluation_prediction: EvalPrediction
+    ) -> dict[str, float]:
+        predictions, expectations = evaluation_prediction
+        predictions = np.argmax(predictions, axis=2)
+
+        true_predictions = [
+            [self._labels[p] for (p, l) in zip(prediction, label) if l != -100]
+            for prediction, label in zip(predictions, expectations)
+        ]
+
+        true_labels = [
+            [self._labels[l] for (p, l) in zip(prediction, label) if l != -100]
+            for prediction, label in zip(predictions, expectations)
+        ]
+
+        results: dict[str, float] = self._evaluator.compute(
+            predictions=true_predictions, references=true_labels
+        )  # type: ignore
+
+        return {
+            "precision": results["overall_precision"],
+            "recall": results["overall_recall"],
+            "f1": results["overall_f1"],
+            "accuracy": results["overall_accuracy"],
+        }
+
+    def train(self, output_dir: str, **arguments):
+        # Tokenize and align labels before handing the splits to the Trainer.
+        tokenized_dataset = self.tokenize_and_align_labels_over_dataset()
+        trainer = Trainer(
+            model=self._model,
+            args=TrainingArguments(output_dir=output_dir, **arguments),
+            train_dataset=tokenized_dataset["train"],  # type: ignore
+            eval_dataset=tokenized_dataset["test"],  # type: ignore
+            tokenizer=self._tokenizer,
+            data_collator=self._data_collator,
+            compute_metrics=self.compute_metrics,
+        )
+
+        trainer.train()
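+
+
+if __name__ == "__main__":
+    # Illustrative usage sketch, not part of the original training code.
+    # The names below are example choices, not requirements: the "conll2003"
+    # dataset (whose tag column is "ner_tags") and the "distilbert-base-uncased"
+    # checkpoint; swap in your own dataset and model as needed. Running this
+    # also assumes datasets, torch and seqeval are installed alongside the
+    # packages pinned in environment.yml.
+    token_trainer = TokenClassificationTrainer(
+        model="distilbert-base-uncased",
+        dataset="conll2003",
+        labels_name="ner_tags",
+    )
+    # Extra keyword arguments are forwarded to TrainingArguments.
+    token_trainer.train(
+        output_dir="token_classification_model",
+        learning_rate=2e-5,
+        per_device_train_batch_size=16,
+        num_train_epochs=2,
+        evaluation_strategy="epoch",
+    )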