Initial commit
Moved training code from 'avarias'

.gitignore (new file, vendored): 230 lines added
@@ -0,0 +1,230 @@
# File created using '.gitignore Generator' for Visual Studio Code: https://bit.ly/vscode-gig
# Created by https://www.toptal.com/developers/gitignore/api/visualstudiocode,linux,python,venv,virtualenv
# Edit at https://www.toptal.com/developers/gitignore?templates=visualstudiocode,linux,python,venv,virtualenv

### Linux ###
*~

# temporary files which can be created if a process still has a handle open of a deleted file
.fuse_hidden*

# KDE directory preferences
.directory

# Linux trash folder which might appear on any partition or disk
.Trash-*

# .nfs files are created when an open file is removed but is still being accessed
.nfs*

### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
#   For a library or package, you might want to ignore these files since the code is
#   intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# poetry
#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
#   This is especially recommended for binary packages to ensure reproducibility, and is more
#   commonly ignored for libraries.
#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
#   in version control.
#   https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
#  and can be added to the global gitignore or merged into this file.  For a more nuclear
#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

### Python Patch ###
# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
poetry.toml

# ruff
.ruff_cache/

# LSP config files
pyrightconfig.json

### venv ###
# Virtualenv
# http://iamzed.com/2009/05/07/a-primer-on-virtualenv/
[Bb]in
[Ii]nclude
[Ll]ib
[Ll]ib64
[Ll]ocal
[Ss]cripts
pyvenv.cfg
pip-selfcheck.json

### VirtualEnv ###
# Virtualenv
# http://iamzed.com/2009/05/07/a-primer-on-virtualenv/

### VisualStudioCode ###
.vscode/*
!.vscode/settings.json
!.vscode/tasks.json
!.vscode/launch.json
!.vscode/extensions.json
!.vscode/*.code-snippets

# Local History for Visual Studio Code
.history/

# Built Visual Studio Code Extensions
*.vsix

### VisualStudioCode Patch ###
# Ignore all local history of files
.history
.ionide

# End of https://www.toptal.com/developers/gitignore/api/visualstudiocode,linux,python,venv,virtualenv

# Custom rules (everything added below won't be overridden by 'Generate .gitignore File' if you use 'Update' option)

environment.yml (new file): 8 lines added
@@ -0,0 +1,8 @@
name: huggingface_transformers_utils
channels:
  - conda-forge
  - huggingface
dependencies:
  - transformers=4.28
  - pip:
      - evaluate==0.4.0
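
Assuming the environment is created in the usual way (conda env create -f environment.yml, then conda activate huggingface_transformers_utils), a quick sanity check of the pins might look like the sketch below; the assertions simply mirror the pins above, and datasets arrives transitively as a dependency of evaluate. Note that torch, which Trainer requires, is not pinned here and would presumably come from the base environment or be installed separately.

# Sanity-check sketch for the environment above.
import evaluate
import transformers

assert transformers.__version__.startswith("4.28"), transformers.__version__
assert evaluate.__version__ == "0.4.0", evaluate.__version__
print("environment OK")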

trainers.py (new file): 121 lines added
@@ -0,0 +1,121 @@
import numpy as np
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    DataCollatorForTokenClassification,
    Trainer,
    TrainingArguments,
    EvalPrediction,
)
from typing import Union
from datasets import (
    Dataset,
    DatasetDict,
    IterableDataset,
    IterableDatasetDict,
    load_dataset,
)
import evaluate


class TokenClassificationTrainer:
    def __init__(
        self,
        model: str,
        dataset: str,
        labels_name: str = "labels",
        evaluator: str = "seqeval",
    ) -> None:
        self._dataset: Union[
            DatasetDict, Dataset, IterableDatasetDict, IterableDataset
        ] = load_dataset(dataset)
        self._labels_name = labels_name
        self._labels: list[str] = (
            self._dataset["train"].features[labels_name].feature.names
        )  # type: ignore
        self._id_to_label: dict[int, str] = {}
        self._label_to_id: dict[str, int] = {}
        for label_id, label in enumerate(self._labels):
            self._id_to_label[label_id] = label
            self._label_to_id[label] = label_id
        # Token classification needs a per-token head; the sequence
        # classification auto-class would build a single-label classifier
        # incompatible with the per-token labels this trainer produces.
        self._model = AutoModelForTokenClassification.from_pretrained(
            model,
            num_labels=len(self._labels),
            id2label=self._id_to_label,
            label2id=self._label_to_id,
        )
        self._tokenizer = AutoTokenizer.from_pretrained(model)
        self._data_collator = DataCollatorForTokenClassification(
            tokenizer=self._tokenizer
        )
        self._evaluator = evaluate.load(evaluator)

    def tokenize_and_align_labels(self, examples):
        # Adapted from
        # https://huggingface.co/docs/transformers/tasks/token_classification
        tokenized_inputs = self._tokenizer(
            examples["tokens"], truncation=True, is_split_into_words=True
        )

        labels = []
        for i, label in enumerate(examples[self._labels_name]):
            word_ids = tokenized_inputs.word_ids(
                batch_index=i
            )  # Map tokens to their respective word.
            previous_word_idx = None
            label_ids = []
            for word_idx in word_ids:  # Set the special tokens to -100.
                if word_idx is None:
                    label_ids.append(-100)
                elif (
                    word_idx != previous_word_idx
                ):  # Only label the first token of a given word.
                    label_ids.append(label[word_idx])
                else:
                    label_ids.append(-100)
                previous_word_idx = word_idx
            labels.append(label_ids)

        tokenized_inputs["labels"] = labels
        return tokenized_inputs

    def tokenize_and_align_labels_over_dataset(self):
        return self._dataset.map(self.tokenize_and_align_labels, batched=True)

    def compute_metrics(
        self, evaluation_prediction: EvalPrediction
    ) -> dict[str, float]:
        predictions, expectations = evaluation_prediction
        predictions = np.argmax(predictions, axis=2)

        # Drop positions labelled -100 (special tokens and non-first
        # sub-word tokens) before handing the sequences to the evaluator.
        true_predictions = [
            [self._labels[p] for (p, l) in zip(prediction, label) if l != -100]
            for prediction, label in zip(predictions, expectations)
        ]

        true_labels = [
            [self._labels[l] for (p, l) in zip(prediction, label) if l != -100]
            for prediction, label in zip(predictions, expectations)
        ]

        results: dict[str, float] = self._evaluator.compute(
            predictions=true_predictions, references=true_labels
        )  # type: ignore

        return {
            "precision": results["overall_precision"],
            "recall": results["overall_recall"],
            "f1": results["overall_f1"],
            "accuracy": results["overall_accuracy"],
        }

    def train(self, output_dir: str, **arguments):
        # Train on the tokenized dataset; the raw dataset carries no
        # input_ids/attention_mask, so passing it directly would fail.
        tokenized_dataset = self.tokenize_and_align_labels_over_dataset()
        trainer = Trainer(
            model=self._model,  # Trainer requires the model explicitly.
            args=TrainingArguments(output_dir=output_dir, **arguments),
            train_dataset=tokenized_dataset["train"],  # type: ignore
            eval_dataset=tokenized_dataset["test"],  # type: ignore
            tokenizer=self._tokenizer,
            data_collator=self._data_collator,
            compute_metrics=self.compute_metrics,
        )

        trainer.train()
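
For reference, a minimal usage sketch of the class above. The checkpoint and dataset names are illustrative assumptions, not part of the commit; wnut_17 is one Hub dataset that exposes the expected tokens and ner_tags columns along with train and test splits.

# Hypothetical driver script; model and dataset choices are assumptions.
from trainers import TokenClassificationTrainer

trainer = TokenClassificationTrainer(
    model="distilbert-base-uncased",
    dataset="wnut_17",
    labels_name="ner_tags",
)
trainer.train(
    output_dir="ner_model",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    num_train_epochs=2,
    evaluation_strategy="epoch",  # forwarded into TrainingArguments
)

Because train() forwards **arguments straight into TrainingArguments, any keyword that constructor accepts in transformers 4.28 can be passed through this way.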