transformers style scripts #100

Open
qgallouedec opened this issue Apr 25, 2024 · 0 comments
Labels
✨ Enhancement New feature or request 🔄 Refactor Refactoring

Comments

@qgallouedec (Member)

We should try to converge to something that roughly looks like this for our scripts:

#!/usr/bin/env python3
"""Train my model on a dataset."""

from dataclasses import dataclass, field
from datasets import load_dataset
from transformers import AutoConfig, HfArgumentParser, Trainer, TrainingArguments
from my_model import MyModel


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config we are going to train from.
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    trust_remote_code: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
                "should only be set to `True` for repositories you trust and in which you have read the code, as it "
                "will execute code present on the Hub on your local machine."
            )
        },
    )


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """


def main():
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Load the model and the processor
    config = AutoConfig.from_pretrained(
        model_args.model_name_or_path,
        trust_remote_code=model_args.trust_remote_code,
    )
    model = MyModel(config)

    # Load the dataset
    dataset = load_dataset("my_dataset")

    # Train
    trainer = Trainer(model=model, args=training_args, train_dataset=dataset)
    trainer.train()


if __name__ == "__main__":
    main()
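For illustration, the key convention above is that `HfArgumentParser` turns each dataclass field into a CLI flag (e.g. `--model_name_or_path`), with `metadata["help"]` becoming the flag's help text. The following stdlib-only sketch mimics that mapping with `argparse` rather than `transformers` (the helper name `parse_into_dataclass` is made up for this example, not a real API):

```python
import argparse
from dataclasses import MISSING, dataclass, field, fields


@dataclass
class ModelArguments:
    model_name_or_path: str = field(metadata={"help": "Path to pretrained model or Hub id"})
    trust_remote_code: bool = field(default=False, metadata={"help": "Allow custom Hub code"})


def parse_into_dataclass(cls, argv):
    """Build an argparse parser from dataclass fields, roughly as HfArgumentParser does."""
    parser = argparse.ArgumentParser()
    for f in fields(cls):
        if f.type is bool or f.type == "bool":
            # Boolean fields become flags that default to their dataclass default.
            parser.add_argument(f"--{f.name}", action="store_true",
                                help=f.metadata.get("help", ""))
        else:
            # Fields without a default are required on the command line.
            required = f.default is MISSING
            parser.add_argument(f"--{f.name}", required=required,
                                default=None if required else f.default,
                                help=f.metadata.get("help", ""))
    return cls(**vars(parser.parse_args(argv)))


args = parse_into_dataclass(ModelArguments, ["--model_name_or_path", "gpt2"])
print(args.model_name_or_path, args.trust_remote_code)  # → gpt2 False
```

With the real `HfArgumentParser`, the script sketched in the issue would then be invoked in the usual transformers style, e.g. `python train.py --model_name_or_path gpt2 --output_dir out` (script and flag names here depend on the final layout).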
@aliberts aliberts added ✨ Enhancement New feature or request 🔄 Refactor Refactoring labels Apr 29, 2024