
Training API

This page documents the model training components of AniSearch Model.

Overview

The training package provides functionality for fine-tuning cross-encoder models on anime and manga data. It includes:

  • Base trainer class with common functionality
  • Specialized trainers for anime and manga
  • Dataset handling utilities
  • Training utilities
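A minimal end-to-end sketch using the anime trainer documented below (the hyperparameters are illustrative, and the merged anime dataset must already exist):

```python
from src.training.anime_trainer import AnimeModelTrainer

# Fine-tune the default cross-encoder on the merged anime dataset
trainer = AnimeModelTrainer(epochs=3, batch_size=16)
model_path = trainer.train(loss_type="mse")
print(f"Fine-tuned model saved to: {model_path}")
```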

Data Processing Flow

The following diagram illustrates how data flows through the training process:

flowchart LR
    A[(Raw Dataset)] --> B[Load Dataset]
    B --> C[Filter Light Novels]
    C --> D[Clean Data]
    D --> E[Generate Training Pairs]
    E --> F[Create Query Variations]
    F --> G[Create Positive/Negative Examples]
    G --> H[Prepare Model Input]
    H --> I[Fine-tune Model]

    subgraph Manga Specific
        C
    end

    subgraph Both Trainers
        D
        E
        F
        G
        H
        I
    end

    style A fill:#e3f2fd,stroke:#1976d2
    style I fill:#e8f5e9,stroke:#4caf50
Press "Alt" / "Option" to enable Pan & Zoom

Similarity Score Calculation

When generating synthetic training data, the system calculates similarity scores between entries based on their metadata:

flowchart TD
    A[Start] --> B[Parse Genre Lists]
    B --> C[Parse Theme Lists]
    C --> D{Both Lists Empty?}
    D -->|Yes| E[Return 0.0]
    D -->|No| F[Calculate Jaccard Similarity]
    F --> G[Weight Themes Higher]
    G --> H[Cap Score at 0.8]
    H --> I[Return Final Score]

    style A fill:#e3f2fd,stroke:#1976d2
    style E fill:#ffebee,stroke:#f44336
    style I fill:#e8f5e9,stroke:#4caf50
Press "Alt" / "Option" to enable Pan & Zoom

This process ensures that synthetic training pairs reflect meaningful relationships between entries based on their genres and themes.
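A minimal sketch of this calculation is shown below. The empty-metadata check, Jaccard computation, and 0.8 cap follow the diagram above; the specific 0.4/0.6 genre/theme weights are assumptions, since the actual weighting lives inside the trainer's internal _calculate_similarity_score method.

```python
from typing import Iterable


def jaccard(a: set, b: set) -> float:
    """Jaccard similarity of two sets; 0.0 when both are empty."""
    if not a and not b:
        return 0.0
    return len(a & b) / len(a | b)


def similarity_score(
    genres_a: Iterable[str], themes_a: Iterable[str],
    genres_b: Iterable[str], themes_b: Iterable[str],
) -> float:
    genres_a, genres_b = set(genres_a), set(genres_b)
    themes_a, themes_b = set(themes_a), set(themes_b)

    # No genre or theme metadata on either side: nothing to compare
    if not (genres_a | themes_a | genres_b | themes_b):
        return 0.0

    # Weight theme overlap higher than genre overlap (weights are assumed)
    score = 0.4 * jaccard(genres_a, genres_b) + 0.6 * jaccard(themes_a, themes_b)

    # Cap the score so synthetic partial matches never reach a perfect 1.0
    return min(score, 0.8)
```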

Base Trainer

The foundation class with core training functionality:

src.training.base_trainer.BaseModelTrainer

BaseModelTrainer(dataset_type: str = 'anime', model_name: str = MODEL_NAME, epochs: int = DEFAULT_EPOCHS, batch_size: int = DEFAULT_BATCH_SIZE, eval_steps: int = DEFAULT_EVAL_STEPS, warmup_steps: int = DEFAULT_WARMUP_STEPS, max_samples: int = DEFAULT_MAX_SAMPLES, learning_rate: float = DEFAULT_LEARNING_RATE, eval_split: float = 0.1, seed: int = 42, device: Optional[str] = None, dataset_path: Optional[str] = None)

Base trainer class for fine-tuning cross-encoder models on anime/manga datasets.

This class provides the core functionality for training cross-encoder models on anime and manga datasets. It handles dataset preparation, synthetic training data generation, model configuration, and training execution. The trainer supports various training parameters and loss functions, allowing for flexible model tuning.

The trainer creates training examples by pairing titles (queries) with synopses (documents), generating both positive pairs (matching title-synopsis) and negative pairs (title with unrelated synopsis). It can also generate variations of queries to improve model robustness.
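Concretely, each generated pair is a sentence-transformers InputExample, as the create_synthetic_training_data source further below shows; the placeholder strings in this snippet are purely illustrative:

```python
from sentence_transformers import InputExample

# Positive pair: a title matched with its own synopsis
positive = InputExample(texts=["<title>", "<its own synopsis>"], label=1.0)

# Negative pair: the same title matched with an unrelated synopsis
negative = InputExample(texts=["<title>", "<unrelated synopsis>"], label=0.0)
```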

ATTRIBUTE DESCRIPTION
dataset_type

Type of dataset ('anime' or 'manga')

TYPE: str

model_name

Name or path of the base model to fine-tune

TYPE: str

epochs

Number of training epochs

TYPE: int

batch_size

Training batch size

TYPE: int

eval_steps

Number of steps between evaluations

TYPE: int

warmup_steps

Number of warmup steps for learning rate scheduler

TYPE: int

max_samples

Maximum number of training samples to use

TYPE: int

learning_rate

Learning rate for the optimizer

TYPE: float

eval_split

Fraction of data to use for evaluation

TYPE: float

seed

Random seed for reproducibility

TYPE: int

device

Device to use for training ('cpu', 'cuda', etc.)

TYPE: str

dataset_path

Path to the dataset file

TYPE: str

df

The loaded dataset

TYPE: DataFrame

output_path

Path where the fine-tuned model will be saved

TYPE: str

synopsis_cols

Columns containing synopsis information

TYPE: List[str]

id_col

Column containing the ID for anime/manga entries

TYPE: str

Example
# Initialize a trainer for anime dataset
trainer = BaseModelTrainer(
    dataset_type="anime",
    model_name="cross-encoder/ms-marco-MiniLM-L-6-v2",
    epochs=3,
    batch_size=16
)

# Train the model
output_path = trainer.train(loss_type="mse")

# Create labeled data for inspection
trainer.create_and_save_labeled_data(
    output_file="labeled_anime_data.csv",
    n_samples=5000
)
Notes
  • The trainer requires the merged dataset to be available. If it is not found, initialization raises a FileNotFoundError that suggests running the src/merge_datasets.py script first.
  • For best results, ensure the dataset contains adequate synopsis information and relevant metadata like genres and themes.
  • The trainer automatically handles text truncation to fit within model token limits, prioritizing the query (title) over the document (synopsis); see the sketch below.
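The truncation behaviour described in the last note can be pictured with the hypothetical sketch below; the actual batch_truncate_text_pairs utility used during training may differ in its details.

```python
def truncate_pair(query: str, document: str, tokenizer, max_length: int = 512):
    """Keep the full query and trim only the document to fit the token budget."""
    query_ids = tokenizer.encode(query, add_special_tokens=False)
    doc_ids = tokenizer.encode(document, add_special_tokens=False)

    # Reserve a few positions for the special tokens the model adds ([CLS]/[SEP])
    budget = max(max_length - len(query_ids) - 3, 0)
    return query, tokenizer.decode(doc_ids[:budget], skip_special_tokens=True)
```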

This constructor sets up the training environment, loads the appropriate dataset, and prepares internal state for the training process. It validates inputs, sets up random seeds for reproducibility, configures the device, and establishes the model output path.

PARAMETER DESCRIPTION
dataset_type

The type of dataset to use for training. Must be either 'anime' or 'manga'. This determines which dataset is loaded and how certain processing steps are performed. Default is 'anime'.

TYPE: str DEFAULT: 'anime'

model_name

The name or path of the base cross-encoder model to fine-tune. Can be a HuggingFace model identifier or a local path. Default is the value from MODEL_NAME constant.

TYPE: str DEFAULT: MODEL_NAME

epochs

Number of complete passes through the training dataset. Higher values may improve performance but risk overfitting. Default is DEFAULT_EPOCHS (3).

TYPE: int DEFAULT: DEFAULT_EPOCHS

batch_size

Number of examples processed in each training step. Larger batches provide more stable gradients but require more memory. Default is DEFAULT_BATCH_SIZE (16).

TYPE: int DEFAULT: DEFAULT_BATCH_SIZE

eval_steps

Number of training steps between model evaluations. If not specified, a reasonable value will be calculated based on dataset size. Default is DEFAULT_EVAL_STEPS (500).

TYPE: int DEFAULT: DEFAULT_EVAL_STEPS

warmup_steps

Number of steps for learning rate warm-up. During warm-up, the learning rate gradually increases from 0 to the specified rate. Default is DEFAULT_WARMUP_STEPS (500).

TYPE: int DEFAULT: DEFAULT_WARMUP_STEPS

max_samples

Maximum number of training samples to use from the dataset. Useful for limiting training time or for testing. Set to None to use all available data. Default is DEFAULT_MAX_SAMPLES (10000).

TYPE: int DEFAULT: DEFAULT_MAX_SAMPLES

learning_rate

Learning rate for the optimizer. Controls how quickly model weights are updated during training. Default is DEFAULT_LEARNING_RATE (2e-6).

TYPE: float DEFAULT: DEFAULT_LEARNING_RATE

eval_split

Fraction of data to use for evaluation instead of training. Must be between 0 and 1. Default is 0.1 (10% for evaluation).

TYPE: float DEFAULT: 0.1

seed

Random seed for reproducibility. Ensures the same training/evaluation split and data sampling across runs. Default is 42.

TYPE: int DEFAULT: 42

device

Device to use for training ('cpu', 'cuda', 'cuda:0', etc.). If None, automatically selects GPU if available, otherwise CPU. Default is None.

TYPE: Optional[str] DEFAULT: None

dataset_path

Path to the dataset file. If None, uses the default path based on dataset_type. Default is None.

TYPE: Optional[str] DEFAULT: None

RAISES DESCRIPTION
ValueError

If dataset_type is not 'anime' or 'manga'

FileNotFoundError

If the dataset file doesn't exist

Notes
  • The method automatically creates the output directory if it doesn't exist
  • The output path is constructed from the model name and dataset type
  • After initialization, the dataset is prepared by calling _prepare_dataset()
Source code in src/training/base_trainer.py
def __init__(  # pylint: disable=too-many-arguments, too-many-positional-arguments
    self,
    dataset_type: str = "anime",
    model_name: str = MODEL_NAME,
    epochs: int = DEFAULT_EPOCHS,
    batch_size: int = DEFAULT_BATCH_SIZE,
    eval_steps: int = DEFAULT_EVAL_STEPS,
    warmup_steps: int = DEFAULT_WARMUP_STEPS,
    max_samples: int = DEFAULT_MAX_SAMPLES,
    learning_rate: float = DEFAULT_LEARNING_RATE,
    eval_split: float = 0.1,
    seed: int = 42,
    device: Optional[str] = None,
    dataset_path: Optional[str] = None,
):
    """
    Initialize the trainer with configuration parameters for model fine-tuning.

    This constructor sets up the training environment, loads the appropriate dataset,
    and prepares internal state for the training process. It validates inputs,
    sets up random seeds for reproducibility, configures the device, and establishes
    the model output path.

    Args:
        dataset_type: The type of dataset to use for training. Must be either 'anime'
            or 'manga'. This determines which dataset is loaded and how certain
            processing steps are performed. Default is 'anime'.

        model_name: The name or path of the base cross-encoder model to fine-tune.
            Can be a HuggingFace model identifier or a local path. Default is the
            value from MODEL_NAME constant.

        epochs: Number of complete passes through the training dataset. Higher values
            may improve performance but risk overfitting. Default is DEFAULT_EPOCHS (3).

        batch_size: Number of examples processed in each training step. Larger batches
            provide more stable gradients but require more memory. Default is
            DEFAULT_BATCH_SIZE (16).

        eval_steps: Number of training steps between model evaluations. If not specified,
            a reasonable value will be calculated based on dataset size. Default is
            DEFAULT_EVAL_STEPS (500).

        warmup_steps: Number of steps for learning rate warm-up. During warm-up, the
            learning rate gradually increases from 0 to the specified rate. Default
            is DEFAULT_WARMUP_STEPS (500).

        max_samples: Maximum number of training samples to use from the dataset. Useful
            for limiting training time or for testing. Set to None to use all available
            data. Default is DEFAULT_MAX_SAMPLES (10000).

        learning_rate: Learning rate for the optimizer. Controls how quickly model
            weights are updated during training. Default is DEFAULT_LEARNING_RATE (2e-6).

        eval_split: Fraction of data to use for evaluation instead of training. Must
            be between 0 and 1. Default is 0.1 (10% for evaluation).

        seed: Random seed for reproducibility. Ensures the same training/evaluation
            split and data sampling across runs. Default is 42.

        device: Device to use for training ('cpu', 'cuda', 'cuda:0', etc.). If None,
            automatically selects GPU if available, otherwise CPU. Default is None.

        dataset_path: Path to the dataset file. If None, uses the default path based
            on dataset_type. Default is None.

    Raises:
        ValueError: If dataset_type is not 'anime' or 'manga'
        FileNotFoundError: If the dataset file doesn't exist

    Notes:
        - The method automatically creates the output directory if it doesn't exist
        - The output path is constructed from the model name and dataset type
        - After initialization, the dataset is prepared by calling _prepare_dataset()
    """
    self.dataset_type = dataset_type.lower()
    if self.dataset_type not in ["anime", "manga"]:
        raise ValueError("Dataset type must be either 'anime' or 'manga'")

    self.model_name = model_name
    self.epochs = epochs
    self.batch_size = batch_size
    self.eval_steps = eval_steps
    self.warmup_steps = warmup_steps
    self.max_samples = max_samples
    self.learning_rate = learning_rate
    self.eval_split = eval_split
    self.seed = seed

    # Track whether eval_steps was explicitly set
    self.eval_steps_specified = eval_steps != DEFAULT_EVAL_STEPS

    # Track whether warmup_steps was explicitly set
    self.warmup_steps_specified = warmup_steps != DEFAULT_WARMUP_STEPS

    # Fix random seeds for reproducibility
    setup_random_seeds(seed)

    # Set device
    self.device = get_device(device)
    logger.info("Using device: %s", self.device)

    # Set model save path
    if not os.path.exists(MODEL_SAVE_PATH):
        os.makedirs(MODEL_SAVE_PATH, exist_ok=True)

    model_basename = os.path.basename(model_name.replace("/", "-"))
    self.output_path = os.path.join(
        MODEL_SAVE_PATH, f"{model_basename}-{dataset_type}-finetuned"
    )

    # Load dataset
    if dataset_path is None:
        # Use default dataset path if none is provided

        self.dataset_path = (
            ANIME_DATASET_PATH if dataset_type == "anime" else MANGA_DATASET_PATH
        )
    else:
        self.dataset_path = dataset_path

    if not os.path.exists(self.dataset_path):
        raise FileNotFoundError(
            f"Dataset not found: {self.dataset_path}. "
            f"Run 'python src/merge_datasets.py --type {dataset_type}' first."
        )

    logger.info("Loading dataset from: %s", self.dataset_path)
    self.df = pd.read_csv(self.dataset_path)
    logger.info("Loaded %d entries", len(self.df))

    # Prepare dataset for training
    self._prepare_dataset()

batch_size instance-attribute

batch_size = batch_size

dataset_path instance-attribute

dataset_path = ANIME_DATASET_PATH if dataset_type == 'anime' else MANGA_DATASET_PATH

dataset_type instance-attribute

dataset_type = dataset_type.lower()

device instance-attribute

device = get_device(device)

df instance-attribute

df = pd.read_csv(dataset_path)

epochs instance-attribute

epochs = epochs

eval_split instance-attribute

eval_split = eval_split

eval_steps instance-attribute

eval_steps = eval_steps

eval_steps_specified instance-attribute

eval_steps_specified = eval_steps != DEFAULT_EVAL_STEPS

learning_rate instance-attribute

learning_rate = learning_rate

max_samples instance-attribute

max_samples = max_samples

model_name instance-attribute

model_name = model_name

output_path instance-attribute

output_path = os.path.join(MODEL_SAVE_PATH, f'{model_basename}-{dataset_type}-finetuned')

seed instance-attribute

seed = seed

warmup_steps instance-attribute

warmup_steps = warmup_steps

warmup_steps_specified instance-attribute

warmup_steps_specified = warmup_steps != DEFAULT_WARMUP_STEPS

create_and_save_labeled_data

create_and_save_labeled_data(output_file: str, n_samples: int = 10000, include_partial_matches: bool = True) -> None

Create and save synthetic labeled data to a CSV file for inspection or custom training.

This method generates a rich dataset of labeled examples with various levels of relevance between queries and documents. Unlike the synthetic training data used directly for training (which uses binary labels), this method creates examples with graded relevance scores between 0.0 and 1.0, capturing partial matches based on content similarity.

The generated CSV file includes:

  • Perfect matches: Title paired with its own synopsis (score 1.0)
  • Partial matches: Title paired with synopses of similar content based on genres and themes (scores 0.1-0.8)
  • Query variations: Conversational variations of titles (e.g., "Looking for X") paired with matching synopses (score 1.0)
PARAMETER DESCRIPTION
output_file

Path to save the labeled data CSV file. If the directory doesn't exist, it will be created. If writing fails due to permissions, the file will be saved to the current directory.

TYPE: str

n_samples

Number of base entries to sample from the dataset for creating labeled examples. The actual number of examples in the output will be larger due to variations and partial matches. Default is 10000.

TYPE: int DEFAULT: 10000

include_partial_matches

Whether to include examples with partial relevance based on genre/theme similarity. When True, the dataset will include examples with scores between 0.1 and 0.8. When False, only perfect matches (1.0) and variations will be included. Default is True.

TYPE: bool DEFAULT: True

RETURNS DESCRIPTION
None

The method saves the labeled data to a file but doesn't return a value.

TYPE: None

Example
# Create labeled data with default settings
trainer = BaseModelTrainer(dataset_type="anime")
trainer.create_and_save_labeled_data("data/labeled_anime.csv")

# Create a smaller dataset without partial matches
trainer.create_and_save_labeled_data(
    "data/simple_labeled_anime.csv",
    n_samples=5000,
    include_partial_matches=False
)
Notes
  • The output CSV includes an 'example_type' column indicating the type of each example (positive, variation_positive, or similarity-based)
  • Similarity-based scores are rounded to the nearest 0.1 for cleaner values
  • Query variations are added to approximately 50% of the titles
  • The method handles permission errors by falling back to the current directory
  • The method logs a distribution of scores in the final dataset
Source code in src/training/base_trainer.py
def create_and_save_labeled_data(
    self,
    output_file: str,
    n_samples: int = 10000,
    include_partial_matches: bool = True,
) -> None:
    """
    Create and save synthetic labeled data to a CSV file for inspection or custom training.

    This method generates a rich dataset of labeled examples with various levels of
    relevance between queries and documents. Unlike the synthetic training data used
    directly for training (which uses binary labels), this method creates examples with
    graded relevance scores between 0.0 and 1.0, capturing partial matches based on
    content similarity.

    The generated CSV file includes:

    - **Perfect matches**: Title paired with its own synopsis (score 1.0)
    - **Partial matches**: Title paired with synopses of similar content based on
      genres and themes (scores 0.1-0.8)
    - **Query variations**: Conversational variations of titles (e.g., "Looking for X")
      paired with matching synopses (score 1.0)

    Args:
        output_file: Path to save the labeled data CSV file. If the directory doesn't
            exist, it will be created. If writing fails due to permissions, the file
            will be saved to the current directory.

        n_samples: Number of base entries to sample from the dataset for creating
            labeled examples. The actual number of examples in the output will be
            larger due to variations and partial matches. Default is 10000.

        include_partial_matches: Whether to include examples with partial relevance
            based on genre/theme similarity. When True, the dataset will include
            examples with scores between 0.1 and 0.8. When False, only perfect
            matches (1.0) and variations will be included. Default is True.

    Returns:
        None: The method saves the labeled data to a file but doesn't return a value.

    Example:
        ```python
        # Create labeled data with default settings
        trainer = BaseModelTrainer(dataset_type="anime")
        trainer.create_and_save_labeled_data("data/labeled_anime.csv")

        # Create a smaller dataset without partial matches
        trainer.create_and_save_labeled_data(
            "data/simple_labeled_anime.csv",
            n_samples=5000,
            include_partial_matches=False
        )
        ```

    Notes:
        - The output CSV includes an 'example_type' column indicating the type of each
          example (positive, variation_positive, or similarity-based)
        - Similarity-based scores are rounded to the nearest 0.1 for cleaner values
        - Query variations are added to approximately 50% of the titles
        - The method handles permission errors by falling back to the current directory
        - The method logs a distribution of scores in the final dataset
    """
    logger.info("Creating %d labeled examples for inspection", n_samples)

    # Ensure output_file has a proper path
    if output_file.startswith("/"):
        output_file = output_file.lstrip("/")

    if not output_file.endswith(".csv"):
        output_file = f"{output_file}.csv"

    output_dir = os.path.dirname(output_file)
    if output_dir:
        try:
            os.makedirs(output_dir, exist_ok=True)
            logger.info("Created output directory: %s", output_dir)
        except PermissionError:
            logger.error(
                "Permission denied when creating directory: %s", output_dir
            )
            logger.info("Trying to save to current directory instead")
            output_file = os.path.basename(output_file)

    # Limit dataset size
    df_sample = self.df.sample(min(len(self.df), n_samples), random_state=self.seed)

    # Create data
    data = []
    for idx, row in tqdm(
        df_sample.iterrows(), total=len(df_sample), desc="Creating labeled data"
    ):
        title = str(row["title"]) if not pd.isna(row["title"]) else ""
        synopsis = (
            str(row["combined_synopsis"])
            if not pd.isna(row["combined_synopsis"])
            else ""
        )

        # Skip entries with empty titles or synopses
        if not title or not synopsis:
            continue

        # 1. Add positive example (perfect match)
        data.append(
            {
                "query": title,
                "text": synopsis,
                "score": 1.0,
                "example_type": "positive",
            }
        )

        # 2. Add similarity-based matches with varying scores
        if include_partial_matches and (
            "genres" in self.df.columns or "themes" in self.df.columns
        ):
            # Sample a larger pool of entries to evaluate for similarity
            sample_indices = random.sample(
                list(set(df_sample.index) - {idx}),
                min(
                    50, len(df_sample) - 1
                ),  # Increased from 20 to 50 for better coverage
            )

            # Calculate similarity for all sampled entries
            similarity_scores = []
            for sample_idx in sample_indices:
                sample_row = df_sample.loc[sample_idx]
                score = self._calculate_similarity_score(row, sample_row)
                similarity_scores.append((sample_idx, score))

            # Sort by similarity score
            similarity_scores.sort(key=lambda x: x[1])

            # Select entries across the similarity spectrum
            # We'll pick examples from different similarity bands to ensure good coverage
            selected_scores = []

            # Very low similarity (0.0-0.2) - replaces completely random negatives
            very_low = [s for s in similarity_scores if s[1] < 0.2]
            if very_low:
                selected_scores.extend(
                    random.sample(very_low, min(2, len(very_low)))
                )

            # Low similarity (0.2-0.4)
            low = [s for s in similarity_scores if 0.2 <= s[1] < 0.4]
            if low:
                selected_scores.extend(random.sample(low, min(2, len(low))))

            # Medium similarity (0.4-0.6)
            medium = [s for s in similarity_scores if 0.4 <= s[1] < 0.6]
            if medium:
                selected_scores.extend(random.sample(medium, min(2, len(medium))))

            # High similarity (0.6-0.8)
            high = [s for s in similarity_scores if 0.6 <= s[1] <= 0.8]
            if high:
                selected_scores.extend(random.sample(high, min(2, len(high))))

            # Add all selected examples
            for sample_idx, score in selected_scores:
                sample_row = df_sample.loc[sample_idx]
                sample_synopsis = (
                    str(sample_row["combined_synopsis"])
                    if not pd.isna(sample_row["combined_synopsis"])
                    else ""
                )

                if not sample_synopsis:
                    continue

                # Round score to nearest 0.1 for cleaner values
                rounded_score = round(score * 10) / 10

                # Set example type based on score range
                if rounded_score < 0.2:
                    example_type = "very_low_similarity"
                elif rounded_score < 0.4:
                    example_type = "low_similarity"
                elif rounded_score < 0.6:
                    example_type = "medium_similarity"
                else:
                    example_type = "high_similarity"

                data.append(
                    {
                        "query": title,
                        "text": sample_synopsis,
                        "score": rounded_score,
                        "example_type": example_type,
                    }
                )

        # 3. Add query variations for some entries
        if random.random() < 0.5:  # 50% chance
            query_variations = self.create_query_variations([title])
            for variation in query_variations[1:]:  # Skip the original
                data.append(
                    {
                        "query": variation,
                        "text": synopsis,
                        "score": 1.0,
                        "example_type": "variation_positive",
                    }
                )

    # Create DataFrame
    labeled_df = pd.DataFrame(data)

    # Print distribution of scores
    score_counts = labeled_df["score"].value_counts().sort_index()
    logger.info("Score distribution:")
    for score, count in score_counts.items():
        logger.info(
            "  Score %.1f: %d examples (%.1f%%)",
            score,
            count,
            100 * count / len(labeled_df),
        )

    # Try to save the file
    try:
        labeled_df.to_csv(output_file, index=False)
        logger.info("Saved %d labeled examples to %s", len(labeled_df), output_file)
    except PermissionError:
        fallback_file = f"labeled_data_{self.dataset_type}.csv"
        logger.error("Permission denied when writing to %s", output_file)
        logger.info("Saving to %s instead", fallback_file)
        try:
            labeled_df.to_csv(fallback_file, index=False)
            logger.info(
                "Saved %d labeled examples to %s", len(labeled_df), fallback_file
            )
        except Exception as e:
            logger.error("Failed to save labeled data: %s", str(e))
            raise

create_query_variations

create_query_variations(base_queries: List[str], n_variations: int = 7) -> List[str]

Create natural language variations of base queries to enhance training robustness.

This method generates conversational and alternative phrasings of the base queries to help the model recognize the same intent expressed in different ways. For each base query, it creates variations using templates like "I'm looking for {query}" or "Find me {query}".

Query variations are important for training more robust models that can handle real-world search inputs, which often contain conversational phrases and different formulations of the same information need.

PARAMETER DESCRIPTION
base_queries

List of original query strings (typically anime/manga titles) that will be used as the basis for generating variations.

TYPE: List[str]

n_variations

Number of variations to create for each base query. The actual number may be less if there aren't enough templates. Default is 7.

TYPE: int DEFAULT: 7

RETURNS DESCRIPTION
List[str]

List[str]: A combined list containing both the original queries and their variations. The length will be approximately len(base_queries) * (1 + n_variations), but may be less if n_variations exceeds the number of available templates.

Example
# Create variations of anime titles
titles = ["Naruto", "One Piece", "Attack on Titan"]
trainer = BaseModelTrainer(dataset_type="anime")
variations = trainer.create_query_variations(titles, n_variations=3)

# Print all variations
for var in variations:
    print(var)
# Example output:
# Naruto
# I'm looking for Naruto
# Can you recommend Naruto?
# Find me Naruto
# One Piece
# ...etc.
Notes
  • The method always includes the original queries in the returned list
  • Templates are selected randomly for each query
  • The method is designed for English language variations
  • The method is decorated with handle_exceptions for error handling
Source code in src/training/base_trainer.py
@handle_exceptions(log_exceptions=True, include_exc_info=True)
def create_query_variations(
    self, base_queries: List[str], n_variations: int = 7
) -> List[str]:
    """
    Create natural language variations of base queries to enhance training robustness.

    This method generates conversational and alternative phrasings of the base queries
    to help the model recognize the same intent expressed in different ways. For each
    base query, it creates variations using templates like "I'm looking for {query}"
    or "Find me {query}".

    Query variations are important for training more robust models that can handle
    real-world search inputs, which often contain conversational phrases and different
    formulations of the same information need.

    Args:
        base_queries: List of original query strings (typically anime/manga titles)
            that will be used as the basis for generating variations.

        n_variations: Number of variations to create for each base query. The actual
            number may be less if there aren't enough templates. Default is 7.

    Returns:
        List[str]: A combined list containing both the original queries and their
            variations. The length will be approximately len(base_queries) * (1 + n_variations),
            but may be less if n_variations exceeds the number of available templates.

    Example:
        ```python
        # Create variations of anime titles
        titles = ["Naruto", "One Piece", "Attack on Titan"]
        trainer = BaseModelTrainer(dataset_type="anime")
        variations = trainer.create_query_variations(titles, n_variations=3)

        # Print all variations
        for var in variations:
            print(var)
        # Example output:
        # Naruto
        # I'm looking for Naruto
        # Can you recommend Naruto?
        # Find me Naruto
        # One Piece
        # ...etc.
        ```

    Notes:
        - The method always includes the original queries in the returned list
        - Templates are selected randomly for each query
        - The method is designed for English language variations
        - The method is decorated with handle_exceptions for error handling
    """
    templates = [
        "I'm looking for {query}",
        "Can you recommend {query}?",
        "Find me {query}",
        "I want to watch {query}",
        "Suggest {query}",
        "I need {query}",
        "Something like {query}",
        "Similar to {query}",
        "{query} or similar",
        "Looking for {query}",
        "Need recommendations for {query}",
        "What's similar to {query}?",
        "I enjoyed {query}, what else?",
    ]

    variations = []
    for query in base_queries:
        # Add the original query
        variations.append(query)

        # Add variations based on templates
        n_to_use = min(n_variations, len(templates))
        selected_templates = random.sample(templates, n_to_use)
        for template in selected_templates:
            variations.append(template.format(query=query))

    return variations

create_synthetic_training_data

create_synthetic_training_data() -> List[InputExample]

Create synthetic training data pairs for cross-encoder model fine-tuning.

This method generates a balanced dataset of positive and negative examples:

  • Positive examples: Pairs of titles with their matching synopses (label 1.0)
  • Negative examples: Pairs of titles with randomly selected unrelated synopses (label 0.0)

For each positive example, the method creates 3 negative examples, resulting in a 1:3 ratio of positive to negative examples. This ratio helps the model learn to distinguish relevant from irrelevant content.

The method samples up to max_samples entries from the dataset and applies randomization with the configured seed for reproducibility. Examples with empty titles or synopses are skipped.

RETURNS DESCRIPTION
List[InputExample]

List[InputExample]: A list of InputExample objects ready for training, where each example contains:

  • texts[0]: A title (query)
  • texts[1]: A synopsis (document)
  • label: 1.0 for positive pairs, 0.0 for negative pairs

Example
# Create synthetic training data
trainer = BaseModelTrainer(dataset_type="anime")
examples = trainer.create_synthetic_training_data()

# Examine the first few examples
for i, example in enumerate(examples[:5]):
    print(f"Example {i}:")
    print(f"  Query: {example.texts[0][:50]}...")
    print(f"  Document: {example.texts[1][:50]}...")
    print(f"  Label: {example.label}")
Notes
  • The method is decorated with handle_exceptions for error handling
  • Results are shuffled before returning to randomize the training order
  • If max_samples is smaller than the dataset size, a random subset is used
  • The 1:3 positive-to-negative ratio is a common practice in information retrieval tasks to handle the natural imbalance of relevant vs. irrelevant documents
Source code in src/training/base_trainer.py
@handle_exceptions(log_exceptions=True, include_exc_info=True)
def create_synthetic_training_data(self) -> List[InputExample]:
    """
    Create synthetic training data pairs for cross-encoder model fine-tuning.

    This method generates a balanced dataset of positive and negative examples:

    - **Positive examples**: Pairs of titles with their matching synopses (label 1.0)
    - **Negative examples**: Pairs of titles with randomly selected unrelated
      synopses (label 0.0)

    For each positive example, the method creates 3 negative examples, resulting
    in a 1:3 ratio of positive to negative examples. This ratio helps the model
    learn to distinguish relevant from irrelevant content.

    The method samples up to max_samples entries from the dataset and applies
    randomization with the configured seed for reproducibility. Examples with
    empty titles or synopses are skipped.

    Returns:
        List[InputExample]: A list of InputExample objects ready for training,
            where each example contains:
            - texts[0]: A title (query)
            - texts[1]: A synopsis (document)
            - label: 1.0 for positive pairs, 0.0 for negative pairs

    Example:
        ```python
        # Create synthetic training data
        trainer = BaseModelTrainer(dataset_type="anime")
        examples = trainer.create_synthetic_training_data()

        # Examine the first few examples
        for i, example in enumerate(examples[:5]):
            print(f"Example {i}:")
            print(f"  Query: {example.texts[0][:50]}...")
            print(f"  Document: {example.texts[1][:50]}...")
            print(f"  Label: {example.label}")
        ```

    Notes:
        - The method is decorated with handle_exceptions for error handling
        - Results are shuffled before returning to randomize the training order
        - If max_samples is smaller than the dataset size, a random subset is used
        - The 1:3 positive-to-negative ratio is a common practice in information
          retrieval tasks to handle the natural imbalance of relevant vs. irrelevant
          documents
    """
    logger.info("Creating synthetic training data")

    # Limit dataset size if needed
    df_sample = self.df.sample(
        min(len(self.df), self.max_samples), random_state=self.seed
    )
    examples = []

    for idx, row in tqdm(
        df_sample.iterrows(), total=len(df_sample), desc="Creating training pairs"
    ):
        title = str(row["title"]) if not pd.isna(row["title"]) else ""
        synopsis = (
            str(row["combined_synopsis"])
            if not pd.isna(row["combined_synopsis"])
            else ""
        )

        # Skip entries with empty titles or synopses
        if not title or not synopsis:
            continue

        # Create positive pair (score 1.0)
        examples.append(InputExample(texts=[title, synopsis], label=1.0))

        # For each positive pair, create 3 negative pairs (score 0.0)
        # with random synopses from other entries
        for _ in range(3):
            negative_idx = random.choice(df_sample.index)
            while negative_idx == idx:  # Ensure different entry
                negative_idx = random.choice(df_sample.index)

            negative_synopsis = (
                str(df_sample.loc[negative_idx, "combined_synopsis"])
                if not pd.isna(df_sample.loc[negative_idx, "combined_synopsis"])
                else ""
            )
            if not negative_synopsis:
                continue

            examples.append(
                InputExample(texts=[title, negative_synopsis], label=0.0)
            )

    logger.info(
        "Created %d training examples: %d positive, %d negative",
        len(examples),
        len(examples) // 4,
        3 * len(examples) // 4,
    )

    # Shuffle examples
    random.shuffle(examples)
    return examples

create_training_data_from_labeled_file

create_training_data_from_labeled_file(labeled_file: str) -> List[InputExample]

Create training data from a pre-labeled CSV file instead of synthetic generation.

This method allows for using custom or human-labeled data for training. The labeled file should be a CSV containing at least three columns:

  • query: The search query or title
  • text: The document or synopsis text
  • score: A numerical score/label (typically 0-1) indicating relevance

Using labeled data gives more control over the training examples and can incorporate domain expertise about what constitutes good matches. It's especially useful for fine-grained relevance levels beyond just binary classification.
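As an illustration, a compatible labeled file could be produced as follows; the file name, placeholder texts, and scores here are made up for demonstration:

```python
import pandas as pd

# Build a tiny labeled dataset with the three required columns
labeled = pd.DataFrame(
    {
        "query": ["<title>", "Looking for <title>", "<title>"],
        "text": ["<matching synopsis>", "<matching synopsis>", "<unrelated synopsis>"],
        "score": [1.0, 1.0, 0.0],
    }
)
labeled.to_csv("custom_labeled_data.csv", index=False)

# The file can then be passed to train():
# trainer.train(labeled_file="custom_labeled_data.csv")
```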

PARAMETER DESCRIPTION
labeled_file

Path to the CSV file containing labeled examples. The file must include 'query', 'text', and 'score' columns.

TYPE: str

RETURNS DESCRIPTION
List[InputExample]

List[InputExample]: A list of InputExample objects created from the labeled file, where each example contains:

  • texts[0]: The query from the 'query' column
  • texts[1]: The document from the 'text' column
  • label: The float score from the 'score' column

RAISES DESCRIPTION
FileNotFoundError

If the labeled_file doesn't exist

ValueError

If the required columns are missing from the file

Example
# Create training data from labeled file
trainer = BaseModelTrainer(dataset_type="anime")
examples = trainer.create_training_data_from_labeled_file(
    "path/to/labeled_data.csv"
)

# Print distribution of scores
score_counts = {}
for example in examples:
    score = example.label
    score_counts[score] = score_counts.get(score, 0) + 1

for score, count in sorted(score_counts.items()):
    print(f"Score {score}: {count} examples")
Notes
  • The method is decorated with handle_exceptions for error handling
  • No shuffling is performed as the labeled file may already have a specific order
  • Empty values in the CSV are converted to empty strings
  • The scores are converted to float values
Source code in src/training/base_trainer.py
@handle_exceptions(log_exceptions=True, include_exc_info=True)
def create_training_data_from_labeled_file(
    self, labeled_file: str
) -> List[InputExample]:
    """
    Create training data from a pre-labeled CSV file instead of synthetic generation.

    This method allows for using custom or human-labeled data for training. The
    labeled file should be a CSV containing at least three columns:

    - **query**: The search query or title
    - **text**: The document or synopsis text
    - **score**: A numerical score/label (typically 0-1) indicating relevance

    Using labeled data gives more control over the training examples and can
    incorporate domain expertise about what constitutes good matches. It's especially
    useful for fine-grained relevance levels beyond just binary classification.

    Args:
        labeled_file: Path to the CSV file containing labeled examples. The file
            must include 'query', 'text', and 'score' columns.

    Returns:
        List[InputExample]: A list of InputExample objects created from the labeled
            file, where each example contains:
            - texts[0]: The query from the 'query' column
            - texts[1]: The document from the 'text' column
            - label: The float score from the 'score' column

    Raises:
        FileNotFoundError: If the labeled_file doesn't exist
        ValueError: If the required columns are missing from the file

    Example:
        ```python
        # Create training data from labeled file
        trainer = BaseModelTrainer(dataset_type="anime")
        examples = trainer.create_training_data_from_labeled_file(
            "path/to/labeled_data.csv"
        )

        # Print distribution of scores
        score_counts = {}
        for example in examples:
            score = example.label
            score_counts[score] = score_counts.get(score, 0) + 1

        for score, count in sorted(score_counts.items()):
            print(f"Score {score}: {count} examples")
        ```

    Notes:
        - The method is decorated with handle_exceptions for error handling
        - No shuffling is performed as the labeled file may already have a
          specific order
        - Empty values in the CSV are converted to empty strings
        - The scores are converted to float values
    """
    logger.info("Loading labeled data from %s", labeled_file)
    if not os.path.exists(labeled_file):
        raise FileNotFoundError(f"Labeled file not found: {labeled_file}")

    df = pd.read_csv(labeled_file)
    logger.info("Loaded %d labeled examples", len(df))

    # Ensure required columns exist
    required_cols = ["query", "text", "score"]
    missing_cols = [col for col in required_cols if col not in df.columns]
    if missing_cols:
        raise ValueError(
            f"Missing required columns in labeled data: {missing_cols}"
        )

    # Convert to InputExample format
    examples = []
    for _, row in df.iterrows():
        examples.append(
            InputExample(
                texts=[
                    str(row["query"]) if not pd.isna(row["query"]) else "",
                    str(row["text"]) if not pd.isna(row["text"]) else "",
                ],
                label=float(row["score"]),
            )
        )

    logger.info("Created %d training examples from labeled data", len(examples))
    return examples

train

train(labeled_file: Optional[str] = None, loss_type: str = 'mse', scheduler: str = 'linear') -> str

Train the cross-encoder model with the prepared dataset.

This method executes the full training pipeline:

  1. Prepares training data (synthetic or from labeled file)
  2. Splits data into training and evaluation sets
  3. Truncates text pairs to fit model token limits
  4. Configures the model, loss function, and training arguments
  5. Executes the training process
  6. Saves the fine-tuned model

The method supports various loss functions and learning rate schedulers to optimize different aspects of model performance. It automatically handles device placement, batching, and evaluation during training.

PARAMETER DESCRIPTION
labeled_file

Optional path to a pre-labeled CSV file containing training examples. If provided, uses this file instead of generating synthetic data. Default is None (generate synthetic data).

TYPE: Optional[str] DEFAULT: None

loss_type

Type of loss function to use for training. Supported options:

  • 'mse' (default): Mean Squared Error loss
  • 'binary_cross_entropy': Binary Cross Entropy loss
  • 'cross_entropy': Cross Entropy loss
  • 'lambda': LambdaLoss for LambdaRank-style learning to rank
  • 'list_mle', 'p_list_mle': ListMLE/PListMLE losses for listwise learning
  • 'list_net': ListNet loss for listwise learning
  • 'multiple_negatives_ranking': Multiple Negatives Ranking loss
  • 'cached_multiple_negatives_ranking': Cached version of MNR loss
  • 'margin_mse': Margin MSE loss
  • 'rank_net': RankNet loss for pairwise learning

TYPE: str DEFAULT: 'mse'

scheduler

Learning rate scheduler type. Options include:

  • 'linear' (default): Linear decay from initial value to 0
  • 'cosine': Cosine decay schedule
  • 'cosine_with_restarts': Cosine decay with periodic restarts
  • 'polynomial': Polynomial decay
  • 'constant': Constant learning rate
  • 'constant_with_warmup': Constant learning rate after warmup

TYPE: str DEFAULT: 'linear'

RETURNS DESCRIPTION
str

Path to the saved fine-tuned model, which can be loaded later for inference or additional training.

TYPE: str

Example
# Train a model with default settings
trainer = BaseModelTrainer(
    dataset_type="anime",
    model_name="cross-encoder/ms-marco-MiniLM-L-6-v2",
    epochs=3
)
model_path = trainer.train()
print(f"Model saved to: {model_path}")

# Train with custom loss and scheduler
trainer2 = BaseModelTrainer(dataset_type="manga")
model_path = trainer2.train(
    loss_type="binary_cross_entropy",
    scheduler="cosine"
)
Notes
  • The method automatically calculates reasonable evaluation and warmup steps if they weren't explicitly specified during initialization
  • Training progress is logged using tqdm progress bars and the logger
  • The model with the best evaluation performance is automatically saved
  • The method is decorated with handle_exceptions for error handling
Source code in src/training/base_trainer.py
@handle_exceptions(log_exceptions=True, include_exc_info=True)
def train(
    self,
    labeled_file: Optional[str] = None,
    loss_type: str = "mse",
    scheduler: str = "linear",
) -> str:
    """
    Train the cross-encoder model with the prepared dataset.

    This method executes the full training pipeline:

    1. Prepares training data (synthetic or from labeled file)
    2. Splits data into training and evaluation sets
    3. Truncates text pairs to fit model token limits
    4. Configures the model, loss function, and training arguments
    5. Executes the training process
    6. Saves the fine-tuned model

    The method supports various loss functions and learning rate schedulers to
    optimize different aspects of model performance. It automatically handles
    device placement, batching, and evaluation during training.

    Args:
        labeled_file: Optional path to a pre-labeled CSV file containing training
            examples. If provided, uses this file instead of generating synthetic
            data. Default is None (generate synthetic data).

        loss_type: Type of loss function to use for training. Supported options:
            - 'mse' (default): Mean Squared Error loss
            - 'binary_cross_entropy': Binary Cross Entropy loss
            - 'cross_entropy': Cross Entropy loss
            - 'lambda': LambdaLoss for LambdaRank-style learning to rank
            - 'list_mle', 'p_list_mle': ListMLE/PListMLE losses for listwise learning
            - 'list_net': ListNet loss for listwise learning
            - 'multiple_negatives_ranking': Multiple Negatives Ranking loss
            - 'cached_multiple_negatives_ranking': Cached version of MNR loss
            - 'margin_mse': Margin MSE loss
            - 'rank_net': RankNet loss for pairwise learning

        scheduler: Learning rate scheduler type. Options include:
            - 'linear' (default): Linear decay from initial value to 0
            - 'cosine': Cosine decay schedule
            - 'cosine_with_restarts': Cosine decay with periodic restarts
            - 'polynomial': Polynomial decay
            - 'constant': Constant learning rate
            - 'constant_with_warmup': Constant learning rate after warmup

    Returns:
        str: Path to the saved fine-tuned model, which can be loaded later for
            inference or additional training.

    Example:
        ```python
        # Train a model with default settings
        trainer = BaseModelTrainer(
            dataset_type="anime",
            model_name="cross-encoder/ms-marco-MiniLM-L-6-v2",
            epochs=3
        )
        model_path = trainer.train()
        print(f"Model saved to: {model_path}")

        # Train with custom loss and scheduler
        trainer2 = BaseModelTrainer(dataset_type="manga")
        model_path = trainer2.train(
            loss_type="binary_cross_entropy",
            scheduler="cosine"
        )
        ```

    Notes:
        - The method automatically calculates reasonable evaluation and warmup steps
          if they weren't explicitly specified during initialization
        - Training progress is logged using tqdm progress bars and the logger
        - The model with the best evaluation performance is automatically saved
        - The method is decorated with handle_exceptions for error handling
    """
    logger.info("Starting training with %s", self.model_name)

    # Prepare training data
    if labeled_file is not None:
        train_examples = self.create_training_data_from_labeled_file(labeled_file)
    else:
        train_examples = self.create_synthetic_training_data()

    # Split into train and evaluation sets
    train_size = int(len(train_examples) * (1 - self.eval_split))
    train_data = train_examples[:train_size]
    eval_data = train_examples[train_size:]

    logger.info(
        "Training on %d examples, evaluating on %d examples",
        len(train_data),
        len(eval_data),
    )

    # Try to disable tokenizer warnings
    transformers.logging.set_verbosity_error()
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # Initialize the model with env vars to encourage fast tokenizer
    os.environ["USE_FAST_TOKENIZER"] = "true"

    # Initialize the model
    logger.info("Initializing model: %s", self.model_name)
    model = CrossEncoder(
        self.model_name,
        num_labels=1,
        device=self.device,
        max_length=512,
    )

    # Get the tokenizer
    tokenizer = model.tokenizer
    max_length = 512

    # Manually truncate text pairs to avoid tokenizer warnings
    logger.info("Truncating training examples to fit max_length")
    # Extract text pairs from examples
    train_pairs = [(example.texts[0], example.texts[1]) for example in train_data]
    train_labels = [example.label for example in train_data]

    # Process training data in batches
    truncated_train_pairs = batch_truncate_text_pairs(
        train_pairs, tokenizer, max_length=max_length
    )

    # Create truncated examples
    truncated_train_data = []
    for i, (text_a, text_b) in enumerate(truncated_train_pairs):
        truncated_train_data.append(
            InputExample(texts=[text_a, text_b], label=train_labels[i])
        )

    # Process evaluation data in batches
    eval_pairs = [(example.texts[0], example.texts[1]) for example in eval_data]
    eval_labels = [example.label for example in eval_data]

    truncated_eval_pairs = batch_truncate_text_pairs(
        eval_pairs, tokenizer, max_length=max_length
    )

    # Create truncated examples
    truncated_eval_data = []
    for i, (text_a, text_b) in enumerate(truncated_eval_pairs):
        truncated_eval_data.append(
            InputExample(texts=[text_a, text_b], label=eval_labels[i])
        )

    # Prepare datasets for the CrossEncoderTrainer
    train_texts1 = [example.texts[0] for example in truncated_train_data]
    train_texts2 = [example.texts[1] for example in truncated_train_data]
    train_labels = [example.label for example in truncated_train_data]

    eval_texts1 = [example.texts[0] for example in truncated_eval_data]
    eval_texts2 = [example.texts[1] for example in truncated_eval_data]
    eval_labels = [example.label for example in truncated_eval_data]

    # Create HuggingFace datasets
    train_hf_dataset = HFDataset.from_dict(
        {
            "sentence_A": train_texts1,
            "sentence_B": train_texts2,
            "labels": train_labels,
        }
    )

    eval_hf_dataset = HFDataset.from_dict(
        {
            "sentence_A": eval_texts1,
            "sentence_B": eval_texts2,
            "labels": eval_labels,
        }
    )

    # Set warm-up steps based on epochs and dataset size
    if not self.warmup_steps_specified:
        # Calculate steps per epoch (approx)
        steps_per_epoch = max(1, len(truncated_train_data) // self.batch_size)
        # Use 10% of total steps but ensure at least 100 steps
        total_steps = steps_per_epoch * self.epochs
        self.warmup_steps = max(100, int(total_steps * 0.1))
        logger.info(
            "Using %d warm-up steps (approx. %d%% of total steps)",
            self.warmup_steps,
            int(100 * self.warmup_steps / total_steps),
        )

    # Set evaluation steps if not specified
    if not self.eval_steps_specified:
        # Calculate reasonable evaluation frequency - evaluate ~5 times per epoch
        steps_per_epoch = max(1, len(truncated_train_data) // self.batch_size)
        self.eval_steps = max(100, steps_per_epoch // 5)
        logger.info(
            "Using %d evaluation steps (approx. 5 times per epoch)", self.eval_steps
        )

    # Set up loss function
    if loss_type == "binary_cross_entropy":
        loss: Any = BinaryCrossEntropyLoss(model)
    elif loss_type == "cross_entropy":
        loss = CrossEntropyLoss(model)
    elif loss_type == "lambda":
        loss = LambdaLoss(model)
    elif loss_type == "list_mle":
        loss = ListMLELoss(model)
    elif loss_type == "p_list_mle":
        loss = PListMLELoss(model)
    elif loss_type == "list_net":
        loss = ListNetLoss(model)
    elif loss_type == "multiple_negatives_ranking":
        loss = MultipleNegativesRankingLoss(model)
    elif loss_type == "cached_multiple_negatives_ranking":
        loss = CachedMultipleNegativesRankingLoss(model)
    elif loss_type == "mse":
        loss = MSELoss(model)
    elif loss_type == "margin_mse":
        loss = MarginMSELoss(model)
    elif loss_type == "rank_net":
        loss = RankNetLoss(model)
    else:
        logger.info("Unknown loss type '%s', falling back to MSE loss", loss_type)
        loss = MSELoss(model)

    # Create training arguments
    training_args = CrossEncoderTrainingArguments(
        output_dir=self.output_path,
        num_train_epochs=self.epochs,
        per_device_train_batch_size=self.batch_size,
        per_device_eval_batch_size=self.batch_size,
        eval_strategy="steps",
        eval_steps=self.eval_steps,
        warmup_steps=self.warmup_steps,
        learning_rate=self.learning_rate,
        weight_decay=0.05,  # L2 regularization
        lr_scheduler_type=scheduler,
        save_strategy="steps",
        save_steps=self.eval_steps,
        logging_steps=100,
        load_best_model_at_end=True,
        auto_find_batch_size=True,
        disable_tqdm=False,
    )

    # Initialize trainer
    trainer = CrossEncoderTrainer(
        model=model,
        args=training_args,
        train_dataset=train_hf_dataset,
        eval_dataset=eval_hf_dataset,
        loss=loss,
    )

    # Train the model
    logger.info(
        "Training with: epochs=%d, batch_size=%d, warmup_steps=%d, eval_steps=%d",
        self.epochs,
        self.batch_size,
        self.warmup_steps,
        self.eval_steps,
    )
    trainer.train()

    # Save the model
    logger.info("Saving fine-tuned model to %s", self.output_path)
    model.save(self.output_path)
    logger.info("Training completed successfully!")

    return self.output_path

Anime Trainer

Specialized trainer for anime models:

src.training.anime_trainer.AnimeModelTrainer

AnimeModelTrainer(model_name: str = MODEL_NAME, epochs: int = DEFAULT_EPOCHS, batch_size: int = DEFAULT_BATCH_SIZE, eval_steps: int = DEFAULT_EVAL_STEPS, warmup_steps: int = DEFAULT_WARMUP_STEPS, max_samples: int = DEFAULT_MAX_SAMPLES, learning_rate: float = DEFAULT_LEARNING_RATE, eval_split: float = 0.1, seed: int = 42, device: Optional[str] = None, dataset_path: Optional[str] = None)

Bases: BaseModelTrainer

Specialized trainer for fine-tuning cross-encoder models on anime datasets.

This class extends the BaseModelTrainer with anime-specific functionality, simplifying the creation of search models optimized for anime content. It automatically configures the training process for anime datasets and provides anime-specific query generation for more robust training.

The trainer creates relevant training examples using anime titles and synopses, and generates query variations that reflect how users typically search for anime content (e.g., "Looking for anime about...", "Anime similar to...").

ATTRIBUTE DESCRIPTION
dataset_type

Fixed to "anime" to specify this trainer works with anime datasets.

TYPE: str

model_name

Name of the base cross-encoder model used for fine-tuning.

TYPE: str

epochs

Number of training epochs.

TYPE: int

batch_size

Number of examples processed in each training step.

TYPE: int

eval_steps

Number of steps between model evaluations.

TYPE: int

warmup_steps

Number of warmup steps for the learning rate scheduler.

TYPE: int

max_samples

Maximum number of training samples to use.

TYPE: int

learning_rate

Learning rate for the optimizer.

TYPE: float

eval_split

Fraction of data used for evaluation.

TYPE: float

seed

Random seed for reproducibility.

TYPE: int

device

Device used for training (cpu or cuda).

TYPE: str

df

The loaded anime dataset after preparation.

TYPE: DataFrame

Example
# Initialize a trainer for anime model
trainer = AnimeModelTrainer(
    model_name="cross-encoder/ms-marco-MiniLM-L-6-v2",
    epochs=5,
    batch_size=16
)

# Train the model with MSE loss and linear scheduler
model_path = trainer.train(loss_type="mse", scheduler="linear")
print(f"Anime search model saved to: {model_path}")

# Create labeled data for inspection
trainer.create_and_save_labeled_data(
    output_file="anime_labeled_data.csv",
    n_samples=5000
)
Notes
  • The trainer automatically uses the default anime dataset path unless specified
  • For best results, ensure your anime dataset contains adequate synopsis information and metadata like genres and themes
  • This class sets dataset_type="anime" in the parent class, focusing all operations on anime data

This constructor sets up the training environment specifically for anime data, passing "anime" as the dataset_type to the parent class. It configures all training parameters and loads the appropriate anime dataset.

PARAMETER DESCRIPTION
model_name

The name or path of the base cross-encoder model to fine-tune. Can be a HuggingFace model identifier or a local path. Default is the value from MODEL_NAME constant.

TYPE: str DEFAULT: MODEL_NAME

epochs

Number of complete passes through the training dataset. Higher values may improve performance but risk overfitting. Default is DEFAULT_EPOCHS (3).

TYPE: int DEFAULT: DEFAULT_EPOCHS

batch_size

Number of examples processed in each training step. Larger batches provide more stable gradients but require more memory. Default is DEFAULT_BATCH_SIZE (16).

TYPE: int DEFAULT: DEFAULT_BATCH_SIZE

eval_steps

Number of training steps between model evaluations. If not specified, a reasonable value will be calculated based on dataset size. Default is DEFAULT_EVAL_STEPS (500).

TYPE: int DEFAULT: DEFAULT_EVAL_STEPS

warmup_steps

Number of steps for learning rate warm-up. During warm-up, the learning rate gradually increases from 0 to the specified rate. Default is DEFAULT_WARMUP_STEPS (500).

TYPE: int DEFAULT: DEFAULT_WARMUP_STEPS

max_samples

Maximum number of training samples to use from the anime dataset. Useful for limiting training time or for testing. Set to None to use all available data. Default is DEFAULT_MAX_SAMPLES (10000).

TYPE: int DEFAULT: DEFAULT_MAX_SAMPLES

learning_rate

Learning rate for the optimizer. Controls how quickly model weights are updated during training. Default is DEFAULT_LEARNING_RATE (2e-6).

TYPE: float DEFAULT: DEFAULT_LEARNING_RATE

eval_split

Fraction of data to use for evaluation instead of training. Must be between 0 and 1. Default is 0.1 (10% for evaluation).

TYPE: float DEFAULT: 0.1

seed

Random seed for reproducibility. Ensures the same training/evaluation split and data sampling across runs. Default is 42.

TYPE: int DEFAULT: 42

device

Device to use for training ('cpu', 'cuda', 'cuda:0', etc.). If None, automatically selects GPU if available, otherwise CPU. Default is None.

TYPE: Optional[str] DEFAULT: None

dataset_path

Path to the anime dataset file. If None, uses the default anime dataset path. Default is None.

TYPE: Optional[str] DEFAULT: None

Notes
  • This constructor passes "anime" as the dataset_type to the parent class
  • The method automatically creates the output directory if it doesn't exist
  • The output path is constructed from the model name and "anime"
  • After initialization, the anime dataset is prepared for training
Source code in src/training/anime_trainer.py
def __init__(  # pylint: disable=too-many-arguments, too-many-positional-arguments
    self,
    model_name: str = MODEL_NAME,
    epochs: int = DEFAULT_EPOCHS,
    batch_size: int = DEFAULT_BATCH_SIZE,
    eval_steps: int = DEFAULT_EVAL_STEPS,
    warmup_steps: int = DEFAULT_WARMUP_STEPS,
    max_samples: int = DEFAULT_MAX_SAMPLES,
    learning_rate: float = DEFAULT_LEARNING_RATE,
    eval_split: float = 0.1,
    seed: int = 42,
    device: Optional[str] = None,
    dataset_path: Optional[str] = None,
):
    """
    Initialize the anime-specific trainer with configuration parameters.

    This constructor sets up the training environment specifically for anime data,
    passing "anime" as the dataset_type to the parent class. It configures all
    training parameters and loads the appropriate anime dataset.

    Args:
        model_name: The name or path of the base cross-encoder model to fine-tune.
            Can be a HuggingFace model identifier or a local path. Default is the
            value from MODEL_NAME constant.

        epochs: Number of complete passes through the training dataset. Higher values
            may improve performance but risk overfitting. Default is DEFAULT_EPOCHS (3).

        batch_size: Number of examples processed in each training step. Larger batches
            provide more stable gradients but require more memory. Default is
            DEFAULT_BATCH_SIZE (16).

        eval_steps: Number of training steps between model evaluations. If not specified,
            a reasonable value will be calculated based on dataset size. Default is
            DEFAULT_EVAL_STEPS (500).

        warmup_steps: Number of steps for learning rate warm-up. During warm-up, the
            learning rate gradually increases from 0 to the specified rate. Default
            is DEFAULT_WARMUP_STEPS (500).

        max_samples: Maximum number of training samples to use from the anime dataset.
            Useful for limiting training time or for testing. Set to None to use all
            available data. Default is DEFAULT_MAX_SAMPLES (10000).

        learning_rate: Learning rate for the optimizer. Controls how quickly model
            weights are updated during training. Default is DEFAULT_LEARNING_RATE (2e-6).

        eval_split: Fraction of data to use for evaluation instead of training.
            Must be between 0 and 1. Default is 0.1 (10% for evaluation).

        seed: Random seed for reproducibility. Ensures the same training/evaluation
            split and data sampling across runs. Default is 42.

        device: Device to use for training ('cpu', 'cuda', 'cuda:0', etc.). If None,
            automatically selects GPU if available, otherwise CPU. Default is None.

        dataset_path: Path to the anime dataset file. If None, uses the default
            anime dataset path. Default is None.

    Notes:
        - This constructor passes "anime" as the dataset_type to the parent class
        - The method automatically creates the output directory if it doesn't exist
        - The output path is constructed from the model name and "anime"
        - After initialization, the anime dataset is prepared for training
    """
    super().__init__(
        dataset_type="anime",
        model_name=model_name,
        epochs=epochs,
        batch_size=batch_size,
        eval_steps=eval_steps,
        warmup_steps=warmup_steps,
        max_samples=max_samples,
        learning_rate=learning_rate,
        eval_split=eval_split,
        seed=seed,
        device=device,
        dataset_path=dataset_path,
    )
    logger.info("Initialized AnimeModelTrainer")

create_query_variations

create_query_variations(base_queries: List[str], n_variations: int = 7) -> List[str]

Create anime-specific variations of base queries to improve training robustness.

This method overrides the parent class implementation to generate query variations specifically tailored for anime search, using templates that reflect how users typically search for anime content (e.g., "Looking for anime about...", "Anime similar to...").

The variations help the model learn to recognize the same anime-related intent expressed in different ways, making it more robust to real-world search queries.

PARAMETER DESCRIPTION
base_queries

List of original query strings (typically anime titles or descriptions) that will be used as the basis for generating variations.

TYPE: List[str]

n_variations

Number of anime-specific variations to create for each base query. If this exceeds the number of available templates, all templates will be used. Default is 7.

TYPE: int DEFAULT: 7

RETURNS DESCRIPTION
List[str]

List[str]: A combined list containing both the original queries and their anime-specific variations. The length will be approximately len(base_queries) * (1 + n_variations), but may be less if n_variations exceeds the number of available templates.

Example
# Create variations of anime titles
titles = ["Naruto", "One Piece", "Attack on Titan"]
trainer = AnimeModelTrainer()
variations = trainer.create_query_variations(titles, n_variations=3)

# Print all variations
for var in variations:
    print(var)
# Example output:
# Naruto
# Looking for anime about Naruto
# I want to watch anime with Naruto
# Find me anime where Naruto
# One Piece
# ...etc.
Notes
  • The method always includes the original queries in the returned list
  • Templates are selected randomly for each query
  • All templates include the word "anime" to help the model recognize anime-specific search patterns
  • This anime-specific implementation provides better training examples than the generic implementation in the parent class
Source code in src/training/anime_trainer.py
@handle_exceptions(log_exceptions=True, include_exc_info=True)
def create_query_variations(
    self, base_queries: List[str], n_variations: int = 7
) -> List[str]:
    """
    Create anime-specific variations of base queries to improve training robustness.

    This method overrides the parent class implementation to generate query variations
    specifically tailored for anime search, using templates that reflect how users
    typically search for anime content (e.g., "Looking for anime about...",
    "Anime similar to...").

    The variations help the model learn to recognize the same anime-related intent
    expressed in different ways, making it more robust to real-world search queries.

    Args:
        base_queries: List of original query strings (typically anime titles or
            descriptions) that will be used as the basis for generating variations.

        n_variations: Number of anime-specific variations to create for each base
            query. If this exceeds the number of available templates, all templates
            will be used. Default is 7.

    Returns:
        List[str]: A combined list containing both the original queries and their
            anime-specific variations. The length will be approximately
            len(base_queries) * (1 + n_variations), but may be less if n_variations
            exceeds the number of available templates.

    Example:
        ```python
        # Create variations of anime titles
        titles = ["Naruto", "One Piece", "Attack on Titan"]
        trainer = AnimeModelTrainer()
        variations = trainer.create_query_variations(titles, n_variations=3)

        # Print all variations
        for var in variations:
            print(var)
        # Example output:
        # Naruto
        # Looking for anime about Naruto
        # I want to watch anime with Naruto
        # Find me anime where Naruto
        # One Piece
        # ...etc.
        ```

    Notes:
        - The method always includes the original queries in the returned list
        - Templates are selected randomly for each query
        - All templates include the word "anime" to help the model recognize
          anime-specific search patterns
        - This anime-specific implementation provides better training examples
          than the generic implementation in the parent class
    """
    # Add anime-specific templates
    anime_templates = [
        "Looking for anime about {query}",
        "I want to watch anime with {query}",
        "Find me anime where {query}",
        "Can you recommend anime that has {query}",
        "What anime is about {query}",
        "Anime similar to {query}",
        "{query} anime recommendation",
        "I'm looking for anime with {query}",
        "I'm searching for anime with {query}",
        "I'm trying to find anime with {query}",
    ]

    variations = []
    for query in base_queries:
        # Add the original query
        variations.append(query)

        # Select n_variations randomly from anime-specific templates
        n_to_use = min(n_variations, len(anime_templates))
        selected_templates = random.sample(anime_templates, n_to_use)
        for template in selected_templates:
            variations.append(template.format(query=query))

    return variations

Manga Trainer

Specialized trainer for manga models:

src.training.manga_trainer.MangaModelTrainer

MangaModelTrainer(model_name: str = MODEL_NAME, epochs: int = DEFAULT_EPOCHS, batch_size: int = DEFAULT_BATCH_SIZE, eval_steps: int = DEFAULT_EVAL_STEPS, warmup_steps: int = DEFAULT_WARMUP_STEPS, max_samples: int = DEFAULT_MAX_SAMPLES, learning_rate: float = DEFAULT_LEARNING_RATE, eval_split: float = 0.1, seed: int = 42, device: Optional[str] = None, dataset_path: Optional[str] = None, include_light_novels: bool = False)

Bases: BaseModelTrainer

Specialized trainer for fine-tuning cross-encoder models on manga datasets.

This class extends the BaseModelTrainer with manga-specific functionality, simplifying the creation of search models optimized for manga content. It automatically configures the training process for manga datasets and provides manga-specific query generation for more robust training.

The trainer creates relevant training examples using manga titles and synopses, and generates query variations that reflect how users typically search for manga content (e.g., "Looking for manga about...", "Manga similar to...").

ATTRIBUTE DESCRIPTION
dataset_type

Fixed to "manga" to specify this trainer works with manga datasets.

TYPE: str

include_light_novels

Flag indicating whether light novels should be included in the training dataset. When False, light novels are filtered out.

TYPE: bool

model_name

Name of the base cross-encoder model used for fine-tuning.

TYPE: str

epochs

Number of training epochs.

TYPE: int

batch_size

Number of examples processed in each training step.

TYPE: int

eval_steps

Number of steps between model evaluations.

TYPE: int

warmup_steps

Number of warmup steps for the learning rate scheduler.

TYPE: int

max_samples

Maximum number of training samples to use.

TYPE: int

learning_rate

Learning rate for the optimizer.

TYPE: float

eval_split

Fraction of data used for evaluation.

TYPE: float

seed

Random seed for reproducibility.

TYPE: int

device

Device used for training (cpu or cuda).

TYPE: str

df

The loaded manga dataset after preparation.

TYPE: DataFrame

Example
# Initialize a trainer for manga model, excluding light novels
trainer = MangaModelTrainer(
    model_name="cross-encoder/ms-marco-MiniLM-L-6-v2",
    epochs=5,
    batch_size=16,
    include_light_novels=False
)

# Train the model with MSE loss and linear scheduler
model_path = trainer.train(loss_type="mse", scheduler="linear")
print(f"Manga search model saved to: {model_path}")

# Create labeled data for inspection
trainer.create_and_save_labeled_data(
    output_file="manga_labeled_data.csv",
    n_samples=5000
)
Notes
  • The trainer automatically uses the default manga dataset path unless specified
  • For best results, ensure your manga dataset contains adequate synopsis information and metadata like genres and themes
  • This class sets dataset_type="manga" in the parent class, focusing all operations on manga data
  • Light novels can be excluded from training to create more manga-specific models

This constructor sets up the training environment specifically for manga data, passing "manga" as the dataset_type to the parent class. It configures all training parameters, loads the appropriate manga dataset, and optionally filters out light novels from the dataset.

PARAMETER DESCRIPTION
model_name

The name or path of the base cross-encoder model to fine-tune. Can be a HuggingFace model identifier or a local path. Default is the value from MODEL_NAME constant.

TYPE: str DEFAULT: MODEL_NAME

epochs

Number of complete passes through the training dataset. Higher values may improve performance but risk overfitting. Default is DEFAULT_EPOCHS (3).

TYPE: int DEFAULT: DEFAULT_EPOCHS

batch_size

Number of examples processed in each training step. Larger batches provide more stable gradients but require more memory. Default is DEFAULT_BATCH_SIZE (16).

TYPE: int DEFAULT: DEFAULT_BATCH_SIZE

eval_steps

Number of training steps between model evaluations. If not specified, a reasonable value will be calculated based on dataset size. Default is DEFAULT_EVAL_STEPS (500).

TYPE: int DEFAULT: DEFAULT_EVAL_STEPS

warmup_steps

Number of steps for learning rate warm-up. During warm-up, the learning rate gradually increases from 0 to the specified rate. Default is DEFAULT_WARMUP_STEPS (500).

TYPE: int DEFAULT: DEFAULT_WARMUP_STEPS

max_samples

Maximum number of training samples to use from the manga dataset. Useful for limiting training time or for testing. Set to None to use all available data. Default is DEFAULT_MAX_SAMPLES (10000).

TYPE: int DEFAULT: DEFAULT_MAX_SAMPLES

learning_rate

Learning rate for the optimizer. Controls how quickly model weights are updated during training. Default is DEFAULT_LEARNING_RATE (2e-6).

TYPE: float DEFAULT: DEFAULT_LEARNING_RATE

eval_split

Fraction of data to use for evaluation instead of training. Must be between 0 and 1. Default is 0.1 (10% for evaluation).

TYPE: float DEFAULT: 0.1

seed

Random seed for reproducibility. Ensures the same training/evaluation split and data sampling across runs. Default is 42.

TYPE: int DEFAULT: 42

device

Device to use for training ('cpu', 'cuda', 'cuda:0', etc.). If None, automatically selects GPU if available, otherwise CPU. Default is None.

TYPE: Optional[str] DEFAULT: None

dataset_path

Path to the manga dataset file. If None, uses the default manga dataset path. Default is None.

TYPE: Optional[str] DEFAULT: None

include_light_novels

Whether to include light novels in the manga dataset. When False, entries identified as light novels based on their genres will be filtered out. Default is False.

TYPE: bool DEFAULT: False

Notes
  • This constructor passes "manga" as the dataset_type to the parent class
  • The method automatically creates the output directory if it doesn't exist
  • The output path is constructed from the model name and "manga"
  • After initialization, the manga dataset is prepared for training
  • If include_light_novels is False, light novels will be filtered from the dataset
Source code in src/training/manga_trainer.py
def __init__(  # pylint: disable=too-many-arguments, too-many-positional-arguments
    self,
    model_name: str = MODEL_NAME,
    epochs: int = DEFAULT_EPOCHS,
    batch_size: int = DEFAULT_BATCH_SIZE,
    eval_steps: int = DEFAULT_EVAL_STEPS,
    warmup_steps: int = DEFAULT_WARMUP_STEPS,
    max_samples: int = DEFAULT_MAX_SAMPLES,
    learning_rate: float = DEFAULT_LEARNING_RATE,
    eval_split: float = 0.1,
    seed: int = 42,
    device: Optional[str] = None,
    dataset_path: Optional[str] = None,
    include_light_novels: bool = False,
):
    """
    Initialize the manga-specific trainer with configuration parameters.

    This constructor sets up the training environment specifically for manga data,
    passing "manga" as the dataset_type to the parent class. It configures all
    training parameters, loads the appropriate manga dataset, and optionally filters
    out light novels from the dataset.

    Args:
        model_name: The name or path of the base cross-encoder model to fine-tune.
            Can be a HuggingFace model identifier or a local path. Default is the
            value from MODEL_NAME constant.

        epochs: Number of complete passes through the training dataset. Higher values
            may improve performance but risk overfitting. Default is DEFAULT_EPOCHS (3).

        batch_size: Number of examples processed in each training step. Larger batches
            provide more stable gradients but require more memory. Default is
            DEFAULT_BATCH_SIZE (16).

        eval_steps: Number of training steps between model evaluations. If not specified,
            a reasonable value will be calculated based on dataset size. Default is
            DEFAULT_EVAL_STEPS (500).

        warmup_steps: Number of steps for learning rate warm-up. During warm-up, the
            learning rate gradually increases from 0 to the specified rate. Default
            is DEFAULT_WARMUP_STEPS (500).

        max_samples: Maximum number of training samples to use from the manga dataset.
            Useful for limiting training time or for testing. Set to None to use all
            available data. Default is DEFAULT_MAX_SAMPLES (10000).

        learning_rate: Learning rate for the optimizer. Controls how quickly model
            weights are updated during training. Default is DEFAULT_LEARNING_RATE (2e-6).

        eval_split: Fraction of data to use for evaluation instead of training.
            Must be between 0 and 1. Default is 0.1 (10% for evaluation).

        seed: Random seed for reproducibility. Ensures the same training/evaluation
            split and data sampling across runs. Default is 42.

        device: Device to use for training ('cpu', 'cuda', 'cuda:0', etc.). If None,
            automatically selects GPU if available, otherwise CPU. Default is None.

        dataset_path: Path to the manga dataset file. If None, uses the default
            manga dataset path. Default is None.

        include_light_novels: Whether to include light novels in the manga dataset.
            When False, entries identified as light novels based on their genres will
            be filtered out. Default is False.

    Notes:
        - This constructor passes "manga" as the dataset_type to the parent class
        - The method automatically creates the output directory if it doesn't exist
        - The output path is constructed from the model name and "manga"
        - After initialization, the manga dataset is prepared for training
        - If include_light_novels is False, light novels will be filtered from the dataset
    """
    super().__init__(
        dataset_type="manga",
        model_name=model_name,
        epochs=epochs,
        batch_size=batch_size,
        eval_steps=eval_steps,
        warmup_steps=warmup_steps,
        max_samples=max_samples,
        learning_rate=learning_rate,
        eval_split=eval_split,
        seed=seed,
        device=device,
        dataset_path=dataset_path,
    )
    self.include_light_novels = include_light_novels
    logger.info(
        "Initialized MangaModelTrainer with include_light_novels=%s",
        include_light_novels,
    )

    # Filter light novels if necessary
    if not self.include_light_novels:
        self._filter_light_novels()
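
The _filter_light_novels helper called above is not reproduced on this page. The following is only a hypothetical sketch of what a genre-based filter could look like; the "genres" column name and the "Light Novel" label are assumptions, not the project's actual implementation.

import pandas as pd

from src.training.utils import parse_list_column  # documented training utility


def filter_light_novels_sketch(df: pd.DataFrame) -> pd.DataFrame:
    """Hypothetical illustration only; not the real _filter_light_novels."""
    # Keep rows whose parsed genre list does not contain the assumed label
    mask = df["genres"].apply(
        lambda value: "Light Novel" not in parse_list_column(value)
    )
    return df[mask].reset_index(drop=True)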

include_light_novels instance-attribute

include_light_novels = include_light_novels

create_query_variations

create_query_variations(base_queries: List[str], n_variations: int = 7) -> List[str]

Create manga-specific variations of base queries to improve training robustness.

This method overrides the parent class implementation to generate query variations specifically tailored for manga search, using templates that reflect how users typically search for manga content (e.g., "Looking for manga about...", "Manga similar to...").

The variations help the model learn to recognize the same manga-related intent expressed in different ways, making it more robust to real-world search queries.

PARAMETER DESCRIPTION
base_queries

List of original query strings (typically manga titles or descriptions) that will be used as the basis for generating variations.

TYPE: List[str]

n_variations

Number of manga-specific variations to create for each base query. If this exceeds the number of available templates, all templates will be used. Default is 7.

TYPE: int DEFAULT: 7

RETURNS DESCRIPTION
List[str]

List[str]: A combined list containing both the original queries and their manga-specific variations. The length will be approximately len(base_queries) * (1 + n_variations), but may be less if n_variations exceeds the number of available templates.

Example
# Create variations of manga titles
titles = ["One Piece", "Berserk", "Chainsaw Man"]
trainer = MangaModelTrainer()
variations = trainer.create_query_variations(titles, n_variations=3)

# Print all variations
for var in variations:
    print(var)
# Example output:
# One Piece
# Looking for manga about One Piece
# I want to read manga with One Piece
# Find me manga where One Piece
# Berserk
# ...etc.
Notes
  • The method always includes the original queries in the returned list
  • Templates are selected randomly for each query
  • All templates include the word "manga" to help the model recognize manga-specific search patterns
  • This manga-specific implementation provides better training examples than the generic implementation in the parent class
  • The manga templates focus on reading rather than watching (compared to anime)
Source code in src/training/manga_trainer.py
@handle_exceptions(log_exceptions=True, include_exc_info=True)
def create_query_variations(
    self, base_queries: List[str], n_variations: int = 7
) -> List[str]:
    """
    Create manga-specific variations of base queries to improve training robustness.

    This method overrides the parent class implementation to generate query variations
    specifically tailored for manga search, using templates that reflect how users
    typically search for manga content (e.g., "Looking for manga about...",
    "Manga similar to...").

    The variations help the model learn to recognize the same manga-related intent
    expressed in different ways, making it more robust to real-world search queries.

    Args:
        base_queries: List of original query strings (typically manga titles or
            descriptions) that will be used as the basis for generating variations.

        n_variations: Number of manga-specific variations to create for each base
            query. If this exceeds the number of available templates, all templates
            will be used. Default is 7.

    Returns:
        List[str]: A combined list containing both the original queries and their
            manga-specific variations. The length will be approximately
            len(base_queries) * (1 + n_variations), but may be less if n_variations
            exceeds the number of available templates.

    Example:
        ```python
        # Create variations of manga titles
        titles = ["One Piece", "Berserk", "Chainsaw Man"]
        trainer = MangaModelTrainer()
        variations = trainer.create_query_variations(titles, n_variations=3)

        # Print all variations
        for var in variations:
            print(var)
        # Example output:
        # One Piece
        # Looking for manga about One Piece
        # I want to read manga with One Piece
        # Find me manga where One Piece
        # Berserk
        # ...etc.
        ```

    Notes:
        - The method always includes the original queries in the returned list
        - Templates are selected randomly for each query
        - All templates include the word "manga" to help the model recognize
          manga-specific search patterns
        - This manga-specific implementation provides better training examples
          than the generic implementation in the parent class
        - The manga templates focus on reading rather than watching (compared to anime)
    """
    # Add manga-specific templates
    manga_templates = [
        "Looking for manga about {query}",
        "I want to read manga with {query}",
        "Find me manga where {query}",
        "Can you recommend manga that has {query}",
        "What manga is about {query}",
        "Manga similar to {query}",
        "{query} manga recommendation",
        "I'm looking for manga with {query}",
        "I'm searching for manga with {query}",
        "I'm trying to find manga with {query}",
        "What manga should I read if I like {query}",
        "Good manga about {query}",
    ]

    variations = []
    for query in base_queries:
        # Add the original query
        variations.append(query)

        # Select n_variations randomly from manga-specific templates
        n_to_use = min(n_variations, len(manga_templates))
        selected_templates = random.sample(manga_templates, n_to_use)
        for template in selected_templates:
            variations.append(template.format(query=query))

    return variations

Dataset

Dataset implementation for cross-encoder training:

src.training.dataset.InputExampleDataset

InputExampleDataset(examples: List[InputExample])

Bases: Dataset

PyTorch Dataset wrapper for a collection of SentenceTransformers InputExamples.

This dataset class adapts a list of InputExample objects (from SentenceTransformers) to be compatible with PyTorch's data loading utilities. It enables efficient batch processing during training and evaluation of cross-encoder models.

InputExamples typically contain:

  • A pair of texts (query and document)
  • A label indicating relevance or similarity
  • An optional identifier

This dataset implementation allows seamless integration with PyTorch's DataLoader for efficient batching, shuffling, and parallel data loading during training.

ATTRIBUTE DESCRIPTION
examples

A list of InputExample objects containing the text pairs and labels for training or evaluation.

Example
from sentence_transformers import InputExample
from torch.utils.data import DataLoader

# Create example data
examples = [
    InputExample(texts=['anime query', 'anime description'], label=1.0),
    InputExample(texts=['unrelated query', 'anime description'], label=0.0),
    # ... more examples
]

# Create dataset
dataset = InputExampleDataset(examples)

# Create DataLoader for training
train_dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

# Use in training loop
for batch in train_dataloader:
    # Process batch...
    pass
PARAMETER DESCRIPTION
examples

List of InputExample objects from SentenceTransformers. Each example should contain a pair of texts and a label. For cross-encoder training, each example typically contains:

  • texts[0]: The query text
  • texts[1]: The document text
  • label: A float value indicating relevance (typically 0 to 1)

TYPE: List[InputExample]

Source code in src/training/dataset.py
def __init__(self, examples: List[InputExample]):
    """
    Initialize the dataset with a list of InputExample objects.

    Args:
        examples: List of InputExample objects from SentenceTransformers.
            Each example should contain a pair of texts and a label.
            For cross-encoder training, each example typically contains:
            - texts[0]: The query text
            - texts[1]: The document text
            - label: A float value indicating relevance (typically 0 to 1)
    """
    self.examples = examples

examples instance-attribute

examples = examples

__getitem__

__getitem__(idx: int) -> InputExample

Retrieve an example by its index.

This method is required by PyTorch's Dataset interface and is called by DataLoader during batch generation. It retrieves a single example by its index in the examples list.

PARAMETER DESCRIPTION
idx

Integer index of the example to retrieve, must be in range 0 <= idx < len(self).

TYPE: int

RETURNS DESCRIPTION
InputExample

The example at the specified index, containing text pairs and a label.

TYPE: InputExample

RAISES DESCRIPTION
IndexError

If idx is out of bounds for the examples list.

Source code in src/training/dataset.py
def __getitem__(self, idx: int) -> InputExample:
    """
    Retrieve an example by its index.

    This method is required by PyTorch's Dataset interface and is called by
    DataLoader during batch generation. It retrieves a single example by its
    index in the examples list.

    Args:
        idx: Integer index of the example to retrieve, must be in range
            0 <= idx < len(self).

    Returns:
        InputExample: The example at the specified index, containing text pairs
            and a label.

    Raises:
        IndexError: If idx is out of bounds for the examples list.
    """
    return self.examples[idx]

__len__

__len__() -> int

Return the number of examples in the dataset.

This method is required by PyTorch's Dataset interface and is called by DataLoader to determine the size of the dataset and the number of batches.

RETURNS DESCRIPTION
int

The total number of examples in the dataset.

TYPE: int

Source code in src/training/dataset.py
def __len__(self) -> int:
    """
    Return the number of examples in the dataset.

    This method is required by PyTorch's Dataset interface and is called by
    DataLoader to determine the size of the dataset and the number of batches.

    Returns:
        int: The total number of examples in the dataset.
    """
    return len(self.examples)

Training Utilities

Helper functions for model training:

src.training.utils

Utility functions for training and fine-tuning cross-encoder models.

This module provides specialized utility functions for training, evaluating, and optimizing models in the anime/manga search application. It includes functionality for text preprocessing, device management, and optimization settings that are specifically tailored for cross-encoder model training.

Features

  • Random seed initialization for reproducible experiments
  • Device detection and configuration for CPU/GPU training
  • Efficient batch text truncation for handling large text pairs
  • Data parsing utilities for handling list data from datasets
  • Default training parameters for common scenarios

Usage Context

These utilities are primarily used in:

  1. Model fine-tuning workflows
  2. Training script configuration
  3. Dataset preprocessing for training

The functions work together to provide a consistent environment for model training and help manage the complexities of preparing text data for transformer models.
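
A short sketch of how these utilities are typically combined at the start of a training script. The dataset path and column name below are placeholders, not values taken from the project.

import pandas as pd

from src.training.utils import get_device, parse_list_column, setup_random_seeds

# Make the run reproducible and pick the compute device
setup_random_seeds(42)
device = get_device()  # "cuda" if available, otherwise "cpu"

# Load a dataset and normalize a stringified list column (path/column are placeholders)
df = pd.read_csv("data/anime.csv")
df["genres"] = df["genres"].apply(parse_list_column)

print(f"Training on {device} with {len(df)} rows")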

DEFAULT_BATCH_SIZE module-attribute

DEFAULT_BATCH_SIZE: int = 16

Default training batch size.

This batch size works well on most consumer GPUs with 8GB+ VRAM. Adjust based on available memory - larger batches generally provide more stable training but require more memory.

DEFAULT_EPOCHS module-attribute

DEFAULT_EPOCHS: int = 3

Default number of training epochs.

The model will iterate over the training data this many times. For cross-encoder models, 3 epochs is often sufficient to get good performance while avoiding overfitting.

DEFAULT_EVAL_STEPS module-attribute

DEFAULT_EVAL_STEPS: int = 500

Default number of steps between model evaluations during training.

Controls how frequently the model is evaluated on the validation set.

DEFAULT_LEARNING_RATE module-attribute

DEFAULT_LEARNING_RATE: float = 2e-06

Default learning rate for fine-tuning.

A conservative learning rate that works well for most cross-encoder fine-tuning. Smaller than typical learning rates for training from scratch to avoid disrupting pre-trained weights.

DEFAULT_MAX_SAMPLES module-attribute

DEFAULT_MAX_SAMPLES: int = 10000

Default maximum number of training samples to use.

Limits the training dataset size to avoid excessive training times for large datasets. Set to None to use the entire dataset.

DEFAULT_WARMUP_STEPS module-attribute

DEFAULT_WARMUP_STEPS: int = 500

Default number of learning rate warmup steps.

Learning rate starts at a low value and gradually increases to the full learning rate over this many steps, which helps with training stability.

MODEL_SAVE_PATH module-attribute

MODEL_SAVE_PATH: str = 'model/fine-tuned/'

Path where fine-tuned models are saved.

logger module-attribute

logger = getLogger(__name__)

batch_truncate_text_pairs

batch_truncate_text_pairs(text_pairs: List[Tuple[str, str]], tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], max_length: int = 512, batch_size: int = 128) -> List[Tuple[str, str]]

Efficiently truncate multiple text pairs to fit within a specified token length.

This function processes a large list of text pairs (query, text) and truncates them to fit within the model's maximum sequence length. It uses a batch processing approach for efficiency and preserves as much of the query (text_a) as possible, truncating the document (text_b) to fit the remaining space.

The truncation process:

  1. Tokenizes all text_a entries to calculate their token lengths
  2. For each pair, reserves tokens for text_a plus special tokens
  3. Allocates remaining tokens for text_b and truncates as needed
  4. Performs validation checks on a sample of results to ensure compliance

PARAMETER DESCRIPTION
text_pairs

List of tuples, each containing two strings (text_a, text_b). Typically, text_a is a query and text_b is a document or longer text.

TYPE: List[Tuple[str, str]]

tokenizer

The tokenizer that will be used with the model. Must be a transformers PreTrainedTokenizer or PreTrainedTokenizerFast instance compatible with the target model.

TYPE: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]

max_length

Maximum allowed sequence length in tokens (including all special tokens). Default is 512, which is common for many transformer models.

TYPE: int DEFAULT: 512

batch_size

Number of text pairs to process in each batch. Higher values increase memory usage but improve processing speed. Default is 128.

TYPE: int DEFAULT: 128

RETURNS DESCRIPTION
List[Tuple[str, str]]

List[Tuple[str, str]]: A list of truncated text pairs, where each text_b has been truncated as needed to fit within the max_length constraint when combined with its text_a.

Example
from transformers import AutoTokenizer

# Load a tokenizer
tokenizer = AutoTokenizer.from_pretrained("cross-encoder/ms-marco-MiniLM-L-6-v2")

# Sample text pairs (query, document)
text_pairs = [
    ("short query", "very long document text that exceeds the limit..."),
    ("another query", "another document that's also quite long...")
]

# Truncate to fit model's constraints
truncated_pairs = batch_truncate_text_pairs(
    text_pairs=text_pairs,
    tokenizer=tokenizer,
    max_length=128,  # Short for example purposes
    batch_size=32
)
Notes
  • The function prioritizes preserving text_a (usually the query) completely
  • Only text_b is truncated unless absolutely necessary
  • The function includes a double-check mechanism that samples some pairs to verify they actually fit within max_length
  • Very long text_a entries might result in empty text_b if there's no space left
  • The function uses the @handle_exceptions decorator for error handling
Source code in src/training/utils.py
@handle_exceptions(log_exceptions=True, include_exc_info=True)
def batch_truncate_text_pairs(
    text_pairs: List[Tuple[str, str]],
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    max_length: int = 512,
    batch_size: int = 128,
) -> List[Tuple[str, str]]:
    """
    Efficiently truncate multiple text pairs to fit within a specified token length.

    This function processes a large list of text pairs (query, text) and truncates
    them to fit within the model's maximum sequence length. It uses a batch processing
    approach for efficiency and preserves as much of the query (text_a) as possible,
    truncating the document (text_b) to fit the remaining space.

    The truncation process:
    1. Tokenizes all text_a entries to calculate their token lengths
    2. For each pair, reserves tokens for text_a plus special tokens
    3. Allocates remaining tokens for text_b and truncates as needed
    4. Performs validation checks on a sample of results to ensure compliance

    Args:
        text_pairs: List of tuples, each containing two strings (text_a, text_b).
            Typically, text_a is a query and text_b is a document or longer text.

        tokenizer: The tokenizer that will be used with the model. Must be a
            transformers PreTrainedTokenizer or PreTrainedTokenizerFast instance
            compatible with the target model.

        max_length: Maximum allowed sequence length in tokens (including all
            special tokens). Default is 512, which is common for many transformer
            models.

        batch_size: Number of text pairs to process in each batch. Higher values
            increase memory usage but improve processing speed. Default is 128.

    Returns:
        List[Tuple[str, str]]: A list of truncated text pairs, where each text_b
            has been truncated as needed to fit within the max_length constraint
            when combined with its text_a.

    Example:
        ```python
        from transformers import AutoTokenizer

        # Load a tokenizer
        tokenizer = AutoTokenizer.from_pretrained("cross-encoder/ms-marco-MiniLM-L-6-v2")

        # Sample text pairs (query, document)
        text_pairs = [
            ("short query", "very long document text that exceeds the limit..."),
            ("another query", "another document that's also quite long...")
        ]

        # Truncate to fit model's constraints
        truncated_pairs = batch_truncate_text_pairs(
            text_pairs=text_pairs,
            tokenizer=tokenizer,
            max_length=128,  # Short for example purposes
            batch_size=32
        )
        ```

    Notes:
        - The function prioritizes preserving text_a (usually the query) completely
        - Only text_b is truncated unless absolutely necessary
        - The function includes a double-check mechanism that samples some pairs
          to verify they actually fit within max_length
        - Very long text_a entries might result in empty text_b if there's no space left
        - The function uses the @handle_exceptions decorator for error handling
    """
    results = []
    total_batches = math.ceil(len(text_pairs) / batch_size)

    # First, tokenize all text_a entries to get their lengths
    logger.info("Pre-computing text_a token lengths")
    text_a_list = [pair[0] for pair in text_pairs]

    # Process queries in smaller batches to avoid memory issues
    text_a_lengths = []
    sub_batch_size = 1000  # Smaller batch size for tokenization

    for i in range(0, len(text_a_list), sub_batch_size):
        sub_batch = text_a_list[i : i + sub_batch_size]
        # Get token counts using direct encoding
        sub_lengths = [
            len(tokenizer.encode(text, add_special_tokens=True)) for text in sub_batch
        ]
        text_a_lengths.extend(sub_lengths)

    # Process in batches
    logger.info("Truncating text pairs in batches")
    for batch_idx in tqdm(range(total_batches), desc="Truncating pairs"):
        start_idx = batch_idx * batch_size
        end_idx = min(start_idx + batch_size, len(text_pairs))
        batch_pairs = text_pairs[start_idx:end_idx]
        batch_a_lengths = text_a_lengths[start_idx:end_idx]

        # Calculate available tokens for each text_b in the batch
        special_tokens_count = 3  # Account for special tokens
        available_tokens_for_b = [
            max(0, max_length - length - special_tokens_count)
            for length in batch_a_lengths
        ]

        # Extract text_b entries for the batch
        batch_text_b = [pair[1] for pair in batch_pairs]

        # Process text_b truncation sequentially
        batch_truncated_b = []
        for i, (text_b, available_tokens) in enumerate(
            zip(batch_text_b, available_tokens_for_b)
        ):
            if available_tokens <= 0:
                batch_truncated_b.append("")
                continue

            # Tokenize and truncate
            truncated_b = tokenizer.encode(
                text_b,
                add_special_tokens=False,
                max_length=available_tokens,
                truncation=True,
            )

            # Decode back to text
            batch_truncated_b.append(
                tokenizer.decode(truncated_b, skip_special_tokens=True)
            )

        # Create truncated pairs for this batch
        batch_results = [
            (batch_pairs[i][0], batch_truncated_b[i]) for i in range(len(batch_pairs))
        ]

        # Double-check only a small sample (every 20th) to save time
        for i in range(0, len(batch_results), 20):
            if i >= len(batch_results):
                break

            text_a, text_b = batch_results[i]
            if not text_b:  # Skip empty text_b
                continue

            # Check the final length
            final_tokens = tokenizer.encode(
                text_a, text_b, add_special_tokens=True, truncation=False
            )
            final_length = len(final_tokens)

            # Emergency truncation if needed
            if final_length > max_length:
                available = max(
                    0,
                    max_length - batch_a_lengths[i] - special_tokens_count - 5,
                )
                if available <= 0:
                    batch_results[i] = (text_a, "")
                else:
                    truncated_b = tokenizer.encode(
                        batch_text_b[i],
                        add_special_tokens=False,
                        max_length=available,
                        truncation=True,
                    )
                    batch_results[i] = (
                        text_a,
                        tokenizer.decode(truncated_b, skip_special_tokens=True),
                    )

        results.extend(batch_results)

    return results
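
As a concrete illustration of the token budget computed above (the 40-token query length is an assumed example value):

max_length = 512          # model's maximum sequence length
text_a_tokens = 40        # assumed token count for the query (example value)
special_tokens_count = 3  # reserved for special tokens, as in the source above

available_tokens_for_b = max(0, max_length - text_a_tokens - special_tokens_count)
print(available_tokens_for_b)  # 469 tokens left for text_b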

get_device

get_device(device: Optional[str] = None) -> str

Determine the appropriate computing device for model training and inference.

This function selects the best available device for running models, defaulting to CUDA (GPU) if available, and falling back to CPU if not. It also allows for explicitly specifying a device if needed.

PARAMETER DESCRIPTION
device

Optional string specifying the device to use. If provided, this overrides the automatic detection. Valid values include 'cpu', 'cuda', 'cuda:0', etc. Default is None, which triggers automatic detection.

TYPE: Optional[str] DEFAULT: None

RETURNS DESCRIPTION
str

A string identifier for the device to use, compatible with PyTorch's device specification format (e.g., 'cuda', 'cpu', 'cuda:1').

TYPE: str

Example
# Get the best available device
device = get_device()

# Move a model to the appropriate device
model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
model.model = model.model.to(device)

# Or explicitly specify a device
device = get_device('cuda:1')  # Use second GPU if available
Notes
  • CUDA device is only returned if PyTorch can access CUDA
  • The function doesn't check for specific CUDA device availability beyond what torch.cuda.is_available() provides
  • For multi-GPU setups, you may want to explicitly specify a device or implement more sophisticated device selection logic
Source code in src/training/utils.py
def get_device(device: Optional[str] = None) -> str:
    """
    Determine the appropriate computing device for model training and inference.

    This function selects the best available device for running models, defaulting
    to CUDA (GPU) if available, and falling back to CPU if not. It also allows for
    explicitly specifying a device if needed.

    Args:
        device: Optional string specifying the device to use. If provided, this
            overrides the automatic detection. Valid values include 'cpu', 'cuda',
            'cuda:0', etc. Default is None, which triggers automatic detection.

    Returns:
        str: A string identifier for the device to use, compatible with PyTorch's
            device specification format (e.g., 'cuda', 'cpu', 'cuda:1').

    Example:
        ```python
        # Get the best available device
        device = get_device()

        # Move a model to the appropriate device
        model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
        model.model = model.model.to(device)

        # Or explicitly specify a device
        device = get_device('cuda:1')  # Use second GPU if available
        ```

    Notes:
        - CUDA device is only returned if PyTorch can access CUDA
        - The function doesn't check for specific CUDA device availability beyond
          what torch.cuda.is_available() provides
        - For multi-GPU setups, you may want to explicitly specify a device
          or implement more sophisticated device selection logic
    """
    if device is None:
        return "cuda" if torch.cuda.is_available() else "cpu"
    return device

parse_list_column

parse_list_column(column_value: Any) -> List[str]

Parse a list column from a dataset that may be stored as a string representation.

This function handles various formats of list data that may come from CSV or DataFrame columns, converting them to Python lists. It handles:

  • String representations of lists like "[item1, item2, item3]"
  • Already parsed list objects
  • Single string values (converted to a single-item list)
  • None or NaN values (converted to empty list)
PARAMETER DESCRIPTION
column_value

The value to parse, which could be a string representation of a list, an actual list object, a single string, or a missing value (None/NaN).

TYPE: Any

RETURNS DESCRIPTION
List[str]

List[str]: A list of strings parsed from the input. Returns an empty list for None or NaN values.

Example
# Parse string representation of a list
genres = parse_list_column("['Action', 'Comedy', 'Drama']")
# Result: ['Action', 'Comedy', 'Drama']

# Parse a single string
tags = parse_list_column("Shounen")
# Result: ['Shounen']

# Handle NaN values
empty = parse_list_column(float('nan'))
# Result: []
Notes
  • Uses ast.literal_eval to safely parse string representations of lists
  • Falls back to comma splitting if literal_eval fails
  • Handles missing values (None, NaN) by returning an empty list
  • Non-string, non-list inputs that aren't NaN will result in an empty list
Source code in src/training/utils.py
def parse_list_column(column_value: Any) -> List[str]:
    """
    Parse a list column from a dataset that may be stored as a string representation.

    This function handles various formats of list data that may come from CSV or
    DataFrame columns, converting them to Python lists. It handles:

    - String representations of lists like "[item1, item2, item3]"
    - Already parsed list objects
    - Single string values (converted to a single-item list)
    - None or NaN values (converted to empty list)

    Args:
        column_value: The value to parse, which could be a string representation of
            a list, an actual list object, a single string, or a missing value (None/NaN).

    Returns:
        List[str]: A list of strings parsed from the input. Returns an empty list
            for None or NaN values.

    Example:
        ```python
        # Parse string representation of a list
        genres = parse_list_column("['Action', 'Comedy', 'Drama']")
        # Result: ['Action', 'Comedy', 'Drama']

        # Parse a single string
        tags = parse_list_column("Shounen")
        # Result: ['Shounen']

        # Handle NaN values
        empty = parse_list_column(float('nan'))
        # Result: []
        ```

    Notes:
        - Uses ast.literal_eval to safely parse string representations of lists
        - Falls back to comma splitting if literal_eval fails
        - Handles missing values (None, NaN) by returning an empty list
        - Non-string, non-list inputs that aren't NaN will result in an empty list
    """
    # Return list objects as-is before the NaN check: pd.isna on a multi-element
    # list returns an array, which cannot be used directly in a boolean context
    if isinstance(column_value, list):
        return column_value

    if pd.isna(column_value):
        return []

    if isinstance(column_value, str):
        # Try to parse as literal if it looks like a list
        if column_value.startswith("[") and column_value.endswith("]"):
            try:
                return ast.literal_eval(column_value)
            except (ValueError, SyntaxError):
                # If parsing fails, split by comma
                return [item.strip() for item in column_value.strip("[]").split(",")]
        else:
            # If it's just a single string value
            return [column_value]

    return []

setup_random_seeds

setup_random_seeds(seed: int = 42) -> None

Set random seeds for reproducibility across Python, NumPy, and PyTorch.

This function sets consistent random seeds for all random number generators used in the training process, ensuring that experiments can be reproduced with the same randomization patterns. It sets seeds for:

  • Python's random module
  • NumPy's random number generator
  • PyTorch's CPU random number generator
  • PyTorch's GPU random number generators (if available)
PARAMETER DESCRIPTION
seed

Integer value to use as the random seed. Default is 42, which is a common choice for reproducible machine learning experiments.

TYPE: int DEFAULT: 42

RETURNS DESCRIPTION
None

This function doesn't return a value but sets global random states.

TYPE: None

Example
# Initialize random seeds before training
setup_random_seeds(42)

# Now all random operations will be reproducible
train_dataset, val_dataset = random_split(dataset, [0.8, 0.2])
Notes
  • Using the same seed guarantees the same random sequence across runs
  • Different hardware or PyTorch versions might still produce variations
  • For full reproducibility, also set deterministic algorithms in PyTorch configurations and control the environment more strictly
Source code in src/training/utils.py
def setup_random_seeds(seed: int = 42) -> None:
    """
    Set random seeds for reproducibility across Python, NumPy, and PyTorch.

    This function sets consistent random seeds for all random number generators used
    in the training process, ensuring that experiments can be reproduced with the
    same randomization patterns. It sets seeds for:

    - Python's random module
    - NumPy's random number generator
    - PyTorch's CPU random number generator
    - PyTorch's GPU random number generators (if available)

    Args:
        seed: Integer value to use as the random seed. Default is 42, which is a
            common choice for reproducible machine learning experiments.

    Returns:
        None: This function doesn't return a value but sets global random states.

    Example:
        ```python
        # Initialize random seeds before training
        setup_random_seeds(42)

        # Now all random operations will be reproducible
        train_dataset, val_dataset = random_split(dataset, [0.8, 0.2])
        ```

    Notes:
        - Using the same seed guarantees the same random sequence across runs
        - Different hardware or PyTorch versions might still produce variations
        - For full reproducibility, also set deterministic algorithms in PyTorch
          configurations and control the environment more strictly
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
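
The last note above refers to PyTorch's deterministic-execution switches. A minimal sketch of what that additional setup can look like (flag availability varies by PyTorch version, and some operations may error or run slower in deterministic mode):

import os
import torch

# Request deterministic kernels where available (may raise for unsupported ops)
torch.use_deterministic_algorithms(True)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Needed for deterministic cuBLAS matrix multiplications on CUDA >= 10.2
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"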