From 7fe009f1f12f69f5e1ac5344bfa3422b0b36e9b5 Mon Sep 17 00:00:00 2001 From: h4njlg Date: Thu, 15 Jan 2026 07:04:12 +0000 Subject: [PATCH] docs: Add MLflow integration tutorial and example Add comprehensive documentation and working example for MLflow integration: - Tutorial covering experiment tracking, pyfunc model creation, and deployment - Example scripts with training data and pyfunc wrapper implementation - Updated README with PyPI installation instructions --- README.md | 20 + docs/source/tutorials/index.md | 21 + docs/source/tutorials/mlflow_integration.md | 645 ++++++++++++++++++++ examples/mlflow_logging/README.md | 194 ++++++ examples/mlflow_logging/__init__.py | 49 ++ examples/mlflow_logging/data/test.csv | 7 + examples/mlflow_logging/data/train.csv | 31 + examples/mlflow_logging/data/val.csv | 7 + examples/mlflow_logging/pyfunc_wrapper.py | 406 ++++++++++++ examples/mlflow_logging/run_example.py | 519 ++++++++++++++++ 10 files changed, 1899 insertions(+) create mode 100644 docs/source/tutorials/mlflow_integration.md create mode 100644 examples/mlflow_logging/README.md create mode 100644 examples/mlflow_logging/__init__.py create mode 100644 examples/mlflow_logging/data/test.csv create mode 100644 examples/mlflow_logging/data/train.csv create mode 100644 examples/mlflow_logging/data/val.csv create mode 100644 examples/mlflow_logging/pyfunc_wrapper.py create mode 100644 examples/mlflow_logging/run_example.py diff --git a/README.md b/README.md index ad7ec83..f57e648 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,26 @@ A unified, extensible framework for text classification with categorical variabl ## 📦 Installation +### From Pypi (recommended) + +```bash +# With uv + +uv add torchTextClassifiers + +# For using huggingface tokenizers +uv add torchTextClassifiers --extra huggingface + +# With pip + +pip install torchTextclassifiers + +# For using huggingface tokenizers +pip install torchTextclassifiers[huggingface] +``` + + +### From Source ```bash # Clone the repository git clone https://github.com/InseeFrLab/torchTextClassifiers.git diff --git a/docs/source/tutorials/index.md b/docs/source/tutorials/index.md index 845c221..4977a86 100644 --- a/docs/source/tutorials/index.md +++ b/docs/source/tutorials/index.md @@ -10,6 +10,7 @@ multiclass_classification mixed_features explainability multilabel_classification +mlflow_integration ``` ## Overview @@ -116,6 +117,22 @@ Assign multiple labels to each text sample for complex classification scenarios. **Difficulty:** Advanced | **Time:** 30 minutes ::: +:::{grid-item-card} {fas}`chart-line` MLflow Integration +:link: mlflow_integration +:link-type: doc + +Track experiments and deploy models with MLflow for production-ready ML pipelines. + +**What you'll learn:** +- Log training metrics per epoch +- Create portable pyfunc models +- Export models without package dependency +- Flexible inference input formats +- Use MLflow UI for visualization + +**Difficulty:** Advanced | **Time:** 25 minutes +::: + :::: ## Learning Path @@ -129,6 +146,7 @@ graph LR C --> D[Mixed Features] C --> F[Multilabel Classification] D --> E[Explainability] + D --> G[MLflow Integration] F --> E style A fill:#e3f2fd @@ -137,6 +155,7 @@ graph LR style D fill:#64b5f6 style E fill:#1976d2 style F fill:#42a5f5 + style G fill:#4caf50 ``` 1. **Start with**: {doc}`../getting_started/quickstart` - Get familiar with the basics @@ -144,6 +163,7 @@ graph LR 3. **Next**: {doc}`multiclass_classification` - Handle multiple classes 4. 
**Branch out**: {doc}`mixed_features` for categorical features OR {doc}`multilabel_classification` for multiple labels 5. **Master**: {doc}`explainability` - Understand your model's predictions +6. **Deploy**: {doc}`mlflow_integration` - Track experiments and deploy to production ## Tutorial Format @@ -238,6 +258,7 @@ All tutorials are based on runnable examples in the repository: - [examples/using_additional_features.py](https://github.com/InseeFrLab/torchTextClassifiers/blob/main/examples/using_additional_features.py) - [examples/advanced_training.py](https://github.com/InseeFrLab/torchTextClassifiers/blob/main/examples/advanced_training.py) - [examples/simple_explainability_example.py](https://github.com/InseeFrLab/torchTextClassifiers/blob/main/examples/simple_explainability_example.py) +- [examples/mlflow_logging_example.py](https://github.com/InseeFrLab/torchTextClassifiers/blob/main/examples/mlflow_logging_example.py) ### Jupyter Notebooks diff --git a/docs/source/tutorials/mlflow_integration.md b/docs/source/tutorials/mlflow_integration.md new file mode 100644 index 0000000..0e8c2ef --- /dev/null +++ b/docs/source/tutorials/mlflow_integration.md @@ -0,0 +1,645 @@ +# MLflow Integration + +Learn how to track experiments, log training metrics, and deploy models with MLflow. + +## Learning Objectives + +By the end of this tutorial, you will be able to: + +- Log training metrics per epoch using PyTorch Lightning's MLFlowLogger +- Create a pyfunc wrapper for flexible model deployment +- Export models with full Captum explainability support +- Handle flexible input formats in production environments +- Use inference parameters (`top_k`, `explain`) for advanced predictions +- Use the MLflow UI to visualize training progress + +## Prerequisites + +Before starting this tutorial, you should: + +- Complete the {doc}`mixed_features` tutorial +- Have MLflow installed (`pip install mlflow` or `uv add mlflow`) +- Have Captum installed for explainability (`pip install captum`) +- Understand PyTorch Lightning basics +- Be familiar with model deployment concepts + +## Overview + +MLflow is an open-source platform for managing the machine learning lifecycle. In this tutorial, we'll integrate torchTextClassifiers with MLflow to: + +1. **Track experiments**: Log parameters, metrics, and artifacts for reproducibility +2. **Monitor training**: Visualize loss and accuracy curves in real-time +3. **Deploy models**: Create portable models with explainability support + +### Why Full PyTorch Models? + +Our approach saves the full PyTorch model (not TorchScript) to enable: + +- **Captum explainability**: Token-level attributions via Integrated Gradients +- **Full model access**: Access to embedding layers for gradient computation +- **Flexible deployment**: Works with any deployment environment + +The trade-off is that `torchTextClassifiers` must be installed in the inference environment. The pip requirements are automatically included in the logged model. 
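To make the trade-off concrete, here is a minimal sketch, assuming a trained `classifier` like the one built later in this tutorial, of saving the full model and reloading it so that Captum can still reach the embedding layer (the same layer handle the pyfunc wrapper uses):

```python
import torch
from captum.attr import LayerIntegratedGradients

# Assumes `classifier` is an already-trained torchTextClassifiers instance.
pytorch_model = classifier.pytorch_model
pytorch_model.eval()

# Full pickle: model class + weights. A TorchScript export would lose the
# Python-level attributes Captum needs for gradient-based attribution.
torch.save(pytorch_model, "model.pt")

# In the inference environment, torchTextClassifiers must be importable,
# because unpickling resolves the original model class.
model = torch.load("model.pt", weights_only=False)
model.eval()

# The embedding layer is still reachable, so attribution works as expected.
lig = LayerIntegratedGradients(model, model.text_embedder.embedding_layer)
```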
+ +## Running the Example + +A complete working example is provided in the `examples/mlflow_logging/` directory: + +```bash +# From the repository root +uv run examples/mlflow_logging/run_example.py + +# Or with explicit dependencies +uv run --extra huggingface --with mlflow --with captum \ + examples/mlflow_logging/run_example.py +``` + +### Example Structure + +``` +examples/mlflow_logging/ +├── __init__.py # Package exports +├── pyfunc_wrapper.py # TextClassifierWrapper class +├── run_example.py # Main training script +├── README.md # Detailed documentation +└── data/ + ├── train.csv # Training data (30 samples) + ├── val.csv # Validation data (6 samples) + └── test.csv # Test data (6 samples) +``` + +## Complete Code + +Here's the core workflow. See `examples/mlflow_logging/run_example.py` for the full implementation. + +```python +import json +import os +import tempfile + +import mlflow +import mlflow.pyfunc +import numpy as np +import pandas as pd +import torch +from mlflow.models.signature import ModelSignature +from mlflow.types import DataType, ParamSchema, ParamSpec +from pytorch_lightning.loggers import MLFlowLogger + +from torchTextClassifiers import ModelConfig, TrainingConfig, torchTextClassifiers +from torchTextClassifiers.tokenizers import WordPieceTokenizer + +# Import the pyfunc wrapper +from examples.mlflow_logging import TextClassifierWrapper + + +def main(): + # Step 1: Load data from CSV files + data_dir = Path("examples/mlflow_logging/data") + train_df = pd.read_csv(data_dir / "train.csv") + X_train = np.array([[row["text"], row["category"]] for _, row in train_df.iterrows()], dtype=object) + y_train = train_df["label"].values + + # Step 2: Train tokenizer + tokenizer = WordPieceTokenizer(vocab_size=1000, output_dim=64) + tokenizer.train(X_train[:, 0].tolist()) + + # Step 3: Configure model + model_config = ModelConfig( + embedding_dim=32, + num_classes=2, + categorical_vocabulary_sizes=[3], + categorical_embedding_dims=8, + ) + classifier = torchTextClassifiers(tokenizer=tokenizer, model_config=model_config) + + # Step 4: Create MLFlowLogger + mlflow_logger = MLFlowLogger( + experiment_name="text-classification", + log_model=False, + ) + + # Step 5: Train with metric logging + training_config = TrainingConfig( + num_epochs=15, batch_size=8, lr=1e-3, + trainer_params={"logger": mlflow_logger}, + ) + classifier.train(X_train, y_train, training_config=training_config, X_val=X_val, y_val=y_val) + + # Step 6: Log artifacts to MLflow + run_id = mlflow_logger.run_id + with mlflow.start_run(run_id=run_id): + mlflow.log_params({"embedding_dim": 32, "num_classes": 2, "vocab_size": 1000}) + + with tempfile.TemporaryDirectory() as tmpdir: + # Save tokenizer (HuggingFace format) + tokenizer_path = os.path.join(tmpdir, "tokenizer") + classifier.tokenizer.tokenizer.save_pretrained(tokenizer_path) + + # Save PyTorch model (full model for Captum support) + model_path = os.path.join(tmpdir, "model.pt") + pytorch_model = classifier.pytorch_model + pytorch_model.eval() + torch.save(pytorch_model, model_path) + + # Save configs + with open(os.path.join(tmpdir, "label_mapping.json"), "w") as f: + json.dump({"0": "negative", "1": "positive"}, f) + with open(os.path.join(tmpdir, "model_config.json"), "w") as f: + json.dump({"output_dim": 64, "num_classes": 2, "categorical_columns": ["category"]}, f) + + # Define params schema for inference parameters + params_schema = ParamSchema([ + ParamSpec("top_k", DataType.long, default=1), + ParamSpec("explain", DataType.boolean, default=False), + ]) 
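            # Attaching this schema to the model signature is what makes
            # predict() honor the `params` argument (top_k, explain); without
            # it, MLflow silently ignores params (see "Defining the Params
            # Schema" below).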
+ signature = ModelSignature(inputs=None, outputs=None, params=params_schema) + + # Log pyfunc model + mlflow.pyfunc.log_model( + artifact_path="model", + python_model=TextClassifierWrapper(), + artifacts={ + "model": model_path, + "tokenizer": tokenizer_path, + "label_mapping": os.path.join(tmpdir, "label_mapping.json"), + "model_config": os.path.join(tmpdir, "model_config.json"), + }, + pip_requirements=["torch>=2.0", "transformers>=4.30", "pandas", "numpy", "captum", "torchTextClassifiers"], + signature=signature, + ) + + # Step 7: Test inference + model = mlflow.pyfunc.load_model(f"runs:/{run_id}/model") + print(model.predict([["Great product!", 0]])) + +if __name__ == "__main__": + main() +``` + +## Step-by-Step Walkthrough + +### Step 1: Understanding the PyFunc Wrapper + +The `TextClassifierWrapper` class is the heart of our deployment strategy. It inherits from `mlflow.pyfunc.PythonModel` and implements two required methods: + +```python +class TextClassifierWrapper(mlflow.pyfunc.PythonModel): + def load_context(self, context): + """Called once when the model is loaded.""" + import torch + from transformers import AutoTokenizer + + # Load full PyTorch model (enables Captum explainability) + self.model = torch.load(context.artifacts["model"], weights_only=False) + self.model.eval() + + # Load HuggingFace tokenizer + self.tokenizer = AutoTokenizer.from_pretrained(context.artifacts["tokenizer"]) + + # Load configurations + with open(context.artifacts["label_mapping"]) as f: + self.label_mapping = json.load(f) + + def predict(self, context, model_input, params=None): + """Called for each prediction request.""" + # Process input, run inference, return results + ... +``` + +**Key Points:** +- `load_context()` is called **once** when the model is loaded +- `predict()` is called for **each inference request** +- The `context.artifacts` dictionary maps artifact names to file paths +- The `params` argument receives inference parameters like `top_k` and `explain` + +:::{note} +The wrapper imports `torch` and `transformers` inside the methods. This ensures the dependencies are only loaded when the model is used, not when it's defined. +::: + +### Step 2: Flexible Input Handling + +The `_parse_input()` method accepts multiple input formats for production flexibility: + +```python +def _parse_input(self, model_input): + # Single string + if isinstance(model_input, str): + return [model_input], [[0] * num_cat_features] + + # DataFrame with named columns + if isinstance(model_input, pd.DataFrame): + if "text" in model_input.columns: + texts = model_input["text"].tolist() + # ... 
extract categories + + # List of lists: [["text1", cat1], ["text2", cat2]] + if isinstance(model_input, list): + if isinstance(model_input[0], list): + texts = [row[0] for row in model_input] + categories = [row[1:] for row in model_input] +``` + +**Supported Formats:** + +| Format | Example | Use Case | +|--------|---------|----------| +| Single string | `"Great product!"` | Quick single prediction | +| List of strings | `["Text 1", "Text 2"]` | Batch without categories | +| List of lists | `[["Text", 0], ["Text", 1]]` | Batch with categories | +| DataFrame | `pd.DataFrame({"text": [...], "category": [...]})` | Production pipelines | + +### Step 3: Setting Up MLFlowLogger + +PyTorch Lightning's `MLFlowLogger` automatically logs training metrics: + +```python +from pytorch_lightning.loggers import MLFlowLogger + +mlflow_logger = MLFlowLogger( + experiment_name="text-classification", # Groups related runs + log_model=False, # We'll log manually with pyfunc +) + +training_config = TrainingConfig( + num_epochs=15, + batch_size=8, + lr=1e-3, + trainer_params={"logger": mlflow_logger}, # Pass to Lightning trainer +) +``` + +**Metrics Logged Automatically:** +- `train_loss_step` - Loss at each training step +- `train_loss_epoch` - Average loss per epoch +- `train_accuracy` - Training accuracy per epoch +- `val_loss` - Validation loss per epoch +- `val_accuracy` - Validation accuracy per epoch + +:::{tip} +The `experiment_name` parameter groups related runs together. Use descriptive names like `"sentiment-analysis-v2"` or `"product-reviews"`. +::: + +### Step 4: Saving the Model for Captum Support + +We save the full PyTorch model (not TorchScript) to enable Captum explainability: + +```python +# Get the PyTorch model +pytorch_model = classifier.pytorch_model +pytorch_model.eval() + +# Save the full model (enables Captum attribution methods) +torch.save(pytorch_model, "model.pt") +``` + +**Why Full PyTorch Instead of TorchScript?** +- TorchScript models don't propagate gradients properly +- Captum's `LayerIntegratedGradients` requires gradient access +- The full model preserves access to embedding layers + +:::{note} +The trade-off is that `torchTextClassifiers` must be installed at inference time, since `torch.save()` pickles the model class which includes references to the original module. +::: + +### Step 5: Defining the Params Schema + +To enable inference parameters (`top_k`, `explain`), define a params schema: + +```python +from mlflow.models.signature import ModelSignature +from mlflow.types import DataType, ParamSchema, ParamSpec + +params_schema = ParamSchema([ + ParamSpec("top_k", DataType.long, default=1), + ParamSpec("explain", DataType.boolean, default=False), +]) + +signature = ModelSignature( + inputs=None, # Flexible input formats + outputs=None, # Output varies based on params + params=params_schema, +) +``` + +:::{important} +Without a params schema, MLflow ignores the `params` argument in `predict()`. The schema explicitly declares which parameters are accepted. 
+::: + +### Step 6: Logging the PyFunc Model + +Finally, we log everything to MLflow: + +```python +with mlflow.start_run(run_id=mlflow_logger.run_id): + # Log hyperparameters + mlflow.log_params({ + "embedding_dim": embedding_dim, + "num_classes": num_classes, + "vocab_size": vocab_size, + }) + + # Log the pyfunc model with all artifacts + mlflow.pyfunc.log_model( + artifact_path="model", + python_model=TextClassifierWrapper(), + artifacts={ + "model": model_path, + "tokenizer": tokenizer_path, + "label_mapping": label_mapping_path, + "model_config": model_config_path, + }, + pip_requirements=[ + "torch>=2.0", + "transformers>=4.30", + "pandas", + "numpy", + "captum", + "torchTextClassifiers", + ], + signature=signature, + ) +``` + +**Artifacts Saved:** + +| Artifact | Format | Purpose | +|----------|--------|---------| +| `model.pt` | PyTorch | The full PyTorch model (for Captum) | +| `tokenizer/` | HuggingFace | Vocabulary and tokenizer config | +| `label_mapping.json` | JSON | `{"0": "negative", "1": "positive"}` | +| `model_config.json` | JSON | Model configuration | + +## Using the MLflow UI + +After training, launch the MLflow UI to visualize your experiments: + +```bash +mlflow ui +``` + +Then open http://localhost:5000 in your browser. + +### Viewing Training Curves + +1. Select your experiment ("text-classification") +2. Click on a run +3. Go to the "Metrics" tab +4. View `train_loss`, `val_loss`, `train_accuracy`, `val_accuracy` over epochs + +### Comparing Runs + +1. Select multiple runs using checkboxes +2. Click "Compare" +3. View side-by-side metrics and parameters +4. Identify the best-performing configuration + +### Accessing Artifacts + +1. Click on a run +2. Go to the "Artifacts" tab +3. Browse the saved files: + - `model/` - Contains the pyfunc model + - `model/artifacts/` - PyTorch model, tokenizer, configs + +## Loading and Using the Model + +Once logged, load the model anywhere MLflow is installed: + +```python +import mlflow + +# Load the model +model = mlflow.pyfunc.load_model("runs://model") + +# Make predictions with flexible input formats +model.predict("Great product!") +model.predict(["Text 1", "Text 2"]) +model.predict([["Text with category", 0]]) +model.predict(pd.DataFrame({"text": ["Hello"], "category": [1]})) +``` + +:::{note} +Replace `` with the actual run ID from your training session. You can find it in the MLflow UI or in the training output. +::: + +## Inference Parameters + +The pyfunc wrapper supports additional inference parameters via the `params` argument: + +### top_k: Multiple Predictions + +Return the top k predictions instead of just the best one: + +```python +# Get top 3 predictions per sample +result = model.predict( + [["This product might be good or bad.", 0]], + params={"top_k": 3} +) +print(result) +# prediction_1 confidence_1 prediction_2 confidence_2 +# 0 positive 0.52 negative 0.48 +``` + +**Output Columns with top_k:** +- `prediction_1`, `confidence_1` - Best prediction +- `prediction_2`, `confidence_2` - Second best +- ... up to `prediction_k`, `confidence_k` + +:::{tip} +If `top_k` exceeds the number of classes, it's automatically limited to the number of classes. 
+::: + +### explain: Token Attributions + +Get token-level attributions to understand which parts of the input influenced the prediction: + +```python +# Get token attributions +result = model.predict( + [["Amazing quality product!"]], + params={"explain": True} +) +print(result) +# prediction confidence tokens attributions +# 0 positive 0.92 [amazing, quality, prod...] [0.35, 0.28, 0.15, ...] + +# Analyze which tokens were most important +tokens = result.iloc[0]["tokens"] +attributions = result.iloc[0]["attributions"] +for tok, attr in sorted(zip(tokens, attributions), key=lambda x: x[1], reverse=True)[:5]: + if tok != "[PAD]": + print(f" {tok}: {attr:.4f}") +# Output: +# amazing: 0.3500 +# quality: 0.2800 +# product: 0.1500 +# !: 0.0200 +``` + +**Output Columns with explain:** +- `prediction`, `confidence` - The prediction result +- `tokens` - List of tokenized words +- `attributions` - Normalized importance scores (sum to 1.0) + +:::{note} +The attributions use Captum's `LayerIntegratedGradients`, which computes gradient-based feature importance through the embedding layer. +::: + +### Combined Parameters + +Use multiple parameters together: + +```python +# Get top 2 predictions with explanations +result = model.predict( + [["Great product but shipping was slow."]], + params={"top_k": 2, "explain": True} +) +print(result.columns.tolist()) +# ['prediction_1', 'confidence_1', 'prediction_2', 'confidence_2', 'tokens', 'attributions'] +``` + +### Summary of Parameters + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `top_k` | int | 1 | Number of top predictions to return | +| `explain` | bool | False | Include token attributions | + +**Example:** +```python +# Default behavior (top_k=1, explain=False) +model.predict(data) + +# Get top 3 predictions +model.predict(data, params={"top_k": 3}) + +# Get attributions +model.predict(data, params={"explain": True}) + +# Get both +model.predict(data, params={"top_k": 2, "explain": True}) +``` + +## Common Issues and Solutions + +### Issue: "Run not found" Error + +**Symptoms:** `MlflowException: Run 'xxx' not found` + +**Solutions:** +1. Ensure you're using the same tracking URI: + ```python + mlflow.set_tracking_uri("file:./mlruns") # Local + mlflow.set_tracking_uri("http://localhost:5000") # Remote server + ``` +2. Check the `mlruns/` directory exists in your working directory +3. Verify the run ID is correct + +### Issue: Uniform Attributions + +**Symptoms:** All attribution values are the same (e.g., 0.0156) + +**Solutions:** +1. Ensure you're using the full PyTorch model, not TorchScript +2. The model must be loaded with `torch.load()`, not `torch.jit.load()` +3. Verify Captum has access to the embedding layer + +### Issue: Missing Dependencies at Inference + +**Symptoms:** `ModuleNotFoundError` when loading the model + +**Solutions:** +1. Install the required packages: + ```bash + pip install torch transformers pandas numpy captum torchTextClassifiers + ``` +2. Or use the generated `requirements.txt`: + ```bash + pip install -r mlruns///artifacts/model/requirements.txt + ``` + +### Issue: Input Format Errors + +**Symptoms:** `ValueError: Unsupported input format` + +**Solutions:** +1. Check your input is one of the supported formats +2. For DataFrames, ensure the "text" column exists +3. For lists of lists, ensure each inner list has `[text, category]` format + +### Issue: Params Ignored + +**Symptoms:** `params={"top_k": 3}` has no effect + +**Solutions:** +1. 
Ensure the model was logged with a `ModelSignature` that includes a `ParamSchema` +2. The params schema must be defined before logging: + ```python + params_schema = ParamSchema([ + ParamSpec("top_k", DataType.long, default=1), + ParamSpec("explain", DataType.boolean, default=False), + ]) + signature = ModelSignature(inputs=None, outputs=None, params=params_schema) + ``` + +## Production Deployment + +### Serving with MLflow + +Start a REST API server: + +```bash +mlflow models serve -m runs://model -p 5001 +``` + +Make predictions via HTTP: + +```bash +curl -X POST http://localhost:5001/invocations \ + -H "Content-Type: application/json" \ + -d '{"inputs": [["Great product!", 0]]}' +``` + +### Loading in a Fresh Environment + +```python +# In a new environment with the required dependencies +import mlflow + +# Load from MLflow tracking server +mlflow.set_tracking_uri("http://your-mlflow-server:5000") +model = mlflow.pyfunc.load_model("models:/text-classifier/Production") + +# Or load from a local path +model = mlflow.pyfunc.load_model("/path/to/model") + +# Make predictions +result = model.predict(["This is great!"]) +print(result) +# prediction confidence +# 0 positive 0.89 +``` + +## Next Steps + +- **MLflow Model Registry**: Learn to version and stage models with [MLflow Model Registry](https://mlflow.org/docs/latest/model-registry.html) +- **Remote Tracking**: Set up a [remote MLflow tracking server](https://mlflow.org/docs/latest/tracking.html#tracking-server) +- **Custom Metrics**: Add custom metrics by modifying the Lightning module +- **Batch Inference**: Process large datasets efficiently with batch predictions + +## Summary + +In this tutorial, you learned: + +- How to integrate PyTorch Lightning with MLflow for automatic metric logging +- How to create a pyfunc wrapper for flexible model deployment +- How to save models with full Captum explainability support +- How to handle flexible input formats in production +- How to use inference parameters (`top_k` for multiple predictions, `explain` for attributions) +- How to use the MLflow UI to visualize training progress +- How to serve models via REST API + +Your models are now ready for production deployment with full experiment tracking, explainability, and reproducibility! diff --git a/examples/mlflow_logging/README.md b/examples/mlflow_logging/README.md new file mode 100644 index 0000000..3c60c6d --- /dev/null +++ b/examples/mlflow_logging/README.md @@ -0,0 +1,194 @@ +# MLflow Logging Example + +This example demonstrates how to train a text classifier with torchTextClassifiers +and log it to MLflow with a pyfunc wrapper for flexible deployment. 
+ +## Features + +- **Training with MLflow Logging**: Uses PyTorch Lightning's `MLFlowLogger` to log + training metrics (loss, accuracy) per epoch automatically +- **Flexible Input Formats**: The pyfunc wrapper accepts strings, lists, DataFrames +- **Top-K Predictions**: Return multiple predictions per sample with confidence scores +- **Explainability**: Token-level attributions via Captum's Integrated Gradients + +## Prerequisites + +Install the required dependencies: + +```bash +# Using uv (recommended) +uv add torchTextClassifiers mlflow captum + +# Or using pip +pip install torchTextClassifiers mlflow captum +``` + +## Quick Start + +Run the example script: + +```bash +# From the repository root +uv run examples/mlflow_logging/run_example.py + +# Or with explicit dependencies +uv run --extra huggingface --with mlflow --with captum \ + examples/mlflow_logging/run_example.py +``` + +## Package Structure + +``` +mlflow_logging/ +├── __init__.py # Package exports (TextClassifierWrapper) +├── pyfunc_wrapper.py # MLflow pyfunc wrapper class +├── run_example.py # Main example script +├── README.md # This file +└── data/ + ├── train.csv # Training data (30 samples) + ├── val.csv # Validation data (6 samples) + └── test.csv # Test data (6 samples) +``` + +## Data Format + +The CSV files contain product reviews with sentiment labels: + +| Column | Description | +|----------|------------------------------------------------| +| text | Product review text | +| category | Product category (0=electronics, 1=clothing, 2=books) | +| label | Sentiment label (0=negative, 1=positive) | + +## Using the Logged Model + +After running the example, load and use the model: + +```python +import mlflow + +# Load the model (replace with actual run ID) +model = mlflow.pyfunc.load_model("runs://model") + +# Basic prediction - single string +model.predict("Great product!") + +# Multiple samples as list of strings +model.predict(["Love it!", "Terrible quality."]) + +# With categorical features as list of lists +model.predict([["Amazing electronics!", 0], ["Bad fit.", 1]]) + +# DataFrame input +import pandas as pd +df = pd.DataFrame({"text": ["Nice book!"], "category": [2]}) +model.predict(df) +``` + +## Inference Parameters + +The pyfunc wrapper supports inference-time parameters via the `params` argument: + +### `top_k` - Multiple Predictions + +Return the top-k predictions with confidence scores: + +```python +# Get top 3 predictions per sample +result = model.predict(data, params={"top_k": 3}) + +# Result columns: prediction_1, confidence_1, prediction_2, confidence_2, ... +print(result) +# prediction_1 confidence_1 prediction_2 confidence_2 ... +# 0 positive 0.85 negative 0.15 ... +``` + +### `explain` - Token Attributions + +Get token-level importance scores using Captum: + +```python +# Get explanations +result = model.predict(data, params={"explain": True}) + +# Result includes tokens and attributions columns +print(result["tokens"][0]) # ['[CLS]', 'amazing', 'product', '!', '[SEP]', ...] +print(result["attributions"][0]) # [0.05, 0.45, 0.30, 0.10, 0.05, ...] +``` + +### Combined Parameters + +```python +# Top-2 predictions with explanations +result = model.predict(data, params={"top_k": 2, "explain": True}) +``` + +## How It Works + +### Training Flow + +1. **Data Loading**: Reads CSV files into numpy arrays +2. **Tokenizer Training**: Trains a WordPiece tokenizer on training texts +3. **Model Configuration**: Sets up embedding dimensions, categorical features +4. 
**MLflow Logger**: Creates `MLFlowLogger` for automatic metric logging +5. **Training**: Trains the classifier with validation monitoring +6. **Artifact Export**: Saves model, tokenizer, and config files +7. **Model Logging**: Logs pyfunc model with all artifacts to MLflow + +### Pyfunc Wrapper + +The `TextClassifierWrapper` class: + +- **`load_context()`**: Loads PyTorch model, HuggingFace tokenizer, and configs +- **`_parse_input()`**: Converts various input formats to (texts, categories) +- **`predict()`**: Runs inference with optional top_k and explain parameters +- **`_predict_with_explain()`**: Uses Captum for token attributions + +## Metrics Logged + +The example logs these metrics to MLflow: + +| Metric | Description | +|--------|-------------| +| `train_loss_step` | Training loss per batch | +| `train_loss_epoch` | Training loss per epoch | +| `train_accuracy` | Training accuracy per epoch | +| `val_loss` | Validation loss per epoch | +| `val_accuracy` | Validation accuracy per epoch | +| `final_train_accuracy` | Final training accuracy | +| `final_val_accuracy` | Final validation accuracy | +| `test_accuracy` | Test set accuracy | + +## Troubleshooting + +### FutureWarning about filesystem tracking + +``` +FutureWarning: Relying on the default value of `tracking_uri` is deprecated +``` + +This is a non-breaking warning from MLflow. To suppress it, set the tracking URI: + +```python +mlflow.set_tracking_uri("sqlite:///mlruns.db") # Or your preferred backend +``` + +### Model requires torchTextClassifiers + +The model uses `torch.save()` which pickles the model class. This requires +`torchTextClassifiers` to be installed when loading. The pip requirements +are included in the logged model metadata. + +### Attributions are uniform + +Ensure you're using the full PyTorch model (not TorchScript). TorchScript models +don't propagate gradients properly for Captum attribution methods. + +## Customization + +To adapt this example for your own data: + +1. Replace the CSV files in `data/` with your own data +2. Update `category_mapping` in `run_example.py` for your categories +3. Update `label_mapping` for your class labels +4. Adjust model hyperparameters as needed diff --git a/examples/mlflow_logging/__init__.py b/examples/mlflow_logging/__init__.py new file mode 100644 index 0000000..3c47258 --- /dev/null +++ b/examples/mlflow_logging/__init__.py @@ -0,0 +1,49 @@ +""" +MLflow Integration Example +========================== + +This package demonstrates how to train a text classifier with torchTextClassifiers +and log it to MLflow with a pyfunc wrapper for flexible deployment. + +Features: + - Train a text classifier with text + categorical features + - Log training metrics per epoch via PyTorch Lightning's MLFlowLogger + - Create a pyfunc wrapper for inference with multiple input formats + - Use Captum for model explainability (token attributions) + +Package Structure: + - pyfunc_wrapper.py: TextClassifierWrapper class for MLflow pyfunc + - run_example.py: Main script demonstrating the complete workflow + - data/: CSV files with sample training, validation, and test data + +Usage: + Run the example script: + + .. code-block:: bash + + uv run examples/mlflow_logging/run_example.py + + Or import the wrapper for custom use: + + .. code-block:: python + + from examples.mlflow_logging import TextClassifierWrapper + +Example: + After running the example, load the model: + + .. 
code-block:: python + + import mlflow + model = mlflow.pyfunc.load_model("runs://model") + + # Basic prediction + model.predict(["Great product!"]) + + # With parameters + model.predict(data, params={"top_k": 3, "explain": True}) +""" + +from .pyfunc_wrapper import TextClassifierWrapper + +__all__ = ["TextClassifierWrapper"] diff --git a/examples/mlflow_logging/data/test.csv b/examples/mlflow_logging/data/test.csv new file mode 100644 index 0000000..bbc3224 --- /dev/null +++ b/examples/mlflow_logging/data/test.csv @@ -0,0 +1,7 @@ +text,category,label +"This camera takes stunning photos!",0,1 +"Product stopped working after a week.",0,0 +"Most comfortable sweater I've ever worn.",1,1 +"Size runs way too small.",1,0 +"An inspiring and beautiful story.",2,1 +"Confusing plot and bad writing.",2,0 diff --git a/examples/mlflow_logging/data/train.csv b/examples/mlflow_logging/data/train.csv new file mode 100644 index 0000000..d991d43 --- /dev/null +++ b/examples/mlflow_logging/data/train.csv @@ -0,0 +1,31 @@ +text,category,label +"This phone has amazing battery life and great camera!",0,1 +"Best laptop I've ever owned, super fast and reliable.",0,1 +"The sound quality of these headphones is incredible!",0,1 +"Love this tablet, perfect for reading and gaming.",0,1 +"Great smart watch with accurate fitness tracking.",0,1 +"Phone screen cracked after one week, terrible quality.",0,0 +"Laptop overheats constantly and crashes often.",0,0 +"Headphones broke after a month, waste of money.",0,0 +"Tablet is slow and battery drains too fast.",0,0 +"Watch stopped working after firmware update.",0,0 +"Perfect fit and comfortable material, love it!",1,1 +"Beautiful dress, exactly as pictured, great quality.",1,1 +"These jeans are so comfortable and durable.",1,1 +"Best jacket I own, keeps me warm in winter.",1,1 +"Shoes are stylish and very comfortable for walking.",1,1 +"Shirt shrunk after first wash, poor quality fabric.",1,0 +"Color faded quickly, not worth the price.",1,0 +"Doesn't fit as described, returning immediately.",1,0 +"Material feels cheap and uncomfortable.",1,0 +"Stitching came undone after wearing twice.",1,0 +"Couldn't put this book down, absolutely captivating!",2,1 +"Well written and thought-provoking, highly recommend.",2,1 +"Amazing story with unforgettable characters.",2,1 +"This book changed my perspective on life.",2,1 +"Beautifully written, a masterpiece of literature.",2,1 +"Boring plot, couldn't finish reading it.",2,0 +"Poorly written with many grammatical errors.",2,0 +"Disappointing ending ruined the whole book.",2,0 +"Characters are flat and uninteresting.",2,0 +"Overrated, don't understand the hype.",2,0 diff --git a/examples/mlflow_logging/data/val.csv b/examples/mlflow_logging/data/val.csv new file mode 100644 index 0000000..a9f9553 --- /dev/null +++ b/examples/mlflow_logging/data/val.csv @@ -0,0 +1,7 @@ +text,category,label +"Excellent product quality and fast shipping!",0,1 +"Broken on arrival, very disappointed.",0,0 +"Fits perfectly and looks great!",1,1 +"Terrible quality, falling apart already.",1,0 +"A wonderful read from start to finish.",2,1 +"Waste of time, very boring.",2,0 diff --git a/examples/mlflow_logging/pyfunc_wrapper.py b/examples/mlflow_logging/pyfunc_wrapper.py new file mode 100644 index 0000000..45650f1 --- /dev/null +++ b/examples/mlflow_logging/pyfunc_wrapper.py @@ -0,0 +1,406 @@ +""" +PyFunc Wrapper for Text Classification Models +============================================== + +This module provides an MLflow pyfunc wrapper for text classification 
models +trained with torchTextClassifiers. The wrapper enables: + +- **Flexible input formats**: Accept strings, lists, DataFrames +- **Multiple predictions**: Return top-k predictions per sample +- **Explainability**: Token-level attributions via Captum + +The wrapper is designed to work with models logged to MLflow and can be +loaded in any environment with the required dependencies. + +Dependencies: + - torch>=2.0 + - transformers>=4.30 + - pandas + - numpy + - captum (for explainability) + - torchTextClassifiers (for model class) + +Example: + >>> import mlflow + >>> model = mlflow.pyfunc.load_model("runs://model") + >>> model.predict(["Great product!"]) + >>> model.predict(data, params={"top_k": 3, "explain": True}) +""" + +from typing import Any, Dict, List, Optional, Tuple, Union + +import mlflow.pyfunc +import numpy as np +import pandas as pd + + +class TextClassifierWrapper(mlflow.pyfunc.PythonModel): + """ + MLflow pyfunc wrapper for text classification. + + This wrapper loads a PyTorch model and HuggingFace tokenizer, + enabling inference with full explainability support via Captum. + + Attributes: + model: The loaded PyTorch text classification model. + tokenizer: HuggingFace tokenizer for text preprocessing. + label_mapping: Dict mapping class indices to label names. + config: Model configuration (output_dim, num_classes, etc.). + categorical_mapping: Optional mapping for categorical features. + + Supported Input Formats: + - Single string: "text" + - List of strings: ["text1", "text2"] + - List of lists: [["text1", cat1], ["text2", cat2]] + - DataFrame with columns: pd.DataFrame({"text": [...], "category": [...]}) + - DataFrame positional: pd.DataFrame([["text1", 0], ["text2", 1]]) + + Inference Parameters (via params dict): + - top_k (int): Number of top predictions to return (default: 1) + - explain (bool): Return token attributions (default: False) + + Example: + >>> # Basic prediction + >>> model.predict(["Great product!"]) + + >>> # Top-3 predictions + >>> model.predict(data, params={"top_k": 3}) + + >>> # With explainability + >>> model.predict(data, params={"explain": True}) + """ + + def load_context(self, context: mlflow.pyfunc.PythonModelContext) -> None: + """ + Load model artifacts from MLflow context. + + This method is called once when the model is loaded. It initializes: + - The PyTorch model from the saved checkpoint + - The HuggingFace tokenizer + - Label mapping and model configuration + + Args: + context: MLflow context containing artifact paths. 
+ """ + import json + + import torch + from transformers import AutoTokenizer + + # ===================================================================== + # Load PyTorch Model + # ===================================================================== + # The model is saved as a full PyTorch model (not TorchScript) + # to enable Captum-based explainability + model_path = context.artifacts["model"] + self.model = torch.load(model_path, weights_only=False) + self.model.eval() + + # ===================================================================== + # Load Tokenizer + # ===================================================================== + # HuggingFace tokenizer saved with save_pretrained() + tokenizer_path = context.artifacts["tokenizer"] + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) + + # ===================================================================== + # Load Configurations + # ===================================================================== + # Label mapping: {"0": "negative", "1": "positive"} + with open(context.artifacts["label_mapping"]) as f: + self.label_mapping = json.load(f) + + # Model config: output_dim, num_classes, categorical_columns + with open(context.artifacts["model_config"]) as f: + self.config = json.load(f) + + # Optional: categorical feature mapping + if "categorical_mapping" in context.artifacts: + with open(context.artifacts["categorical_mapping"]) as f: + self.categorical_mapping = json.load(f) + else: + self.categorical_mapping = {} + + def _parse_input( + self, model_input: Any + ) -> Tuple[List[str], List[List[int]]]: + """ + Parse various input formats into (texts, categories). + + This method handles multiple input formats for flexibility in + production environments where data may come in different shapes. + + Args: + model_input: Input data in one of the supported formats. + + Returns: + Tuple of (texts, categories) where: + - texts: List of text strings + - categories: List of category values per sample + + Raises: + ValueError: If input format is not supported. 
+ + Examples: + >>> # Single string + >>> texts, cats = self._parse_input("Hello world") + >>> # texts = ["Hello world"], cats = [[0]] + + >>> # List of lists with categories + >>> texts, cats = self._parse_input([["Text", 1], ["Other", 2]]) + >>> # texts = ["Text", "Other"], cats = [[1], [2]] + """ + num_cat_features = len(self.config.get("categorical_columns", [])) + + # ----------------------------------------------------------------- + # Case 1: Single string + # ----------------------------------------------------------------- + if isinstance(model_input, str): + texts = [model_input] + categories = [[0] * num_cat_features] if num_cat_features else [] + return texts, categories + + # ----------------------------------------------------------------- + # Case 2: DataFrame + # ----------------------------------------------------------------- + if isinstance(model_input, pd.DataFrame): + if "text" in model_input.columns: + # Named columns: DataFrame({"text": [...], "category": [...]}) + texts = model_input["text"].tolist() + cat_cols = self.config.get("categorical_columns", []) + if cat_cols and all(c in model_input.columns for c in cat_cols): + categories = model_input[cat_cols].values.tolist() + else: + categories = [[0] * num_cat_features] * len(texts) + else: + # Positional: first column = text, rest = categories + texts = model_input.iloc[:, 0].tolist() + if model_input.shape[1] > 1: + categories = model_input.iloc[:, 1:].values.tolist() + else: + categories = [[0] * num_cat_features] * len(texts) + return texts, categories + + # ----------------------------------------------------------------- + # Case 3: List or Array + # ----------------------------------------------------------------- + if isinstance(model_input, (list, np.ndarray)): + model_input = list(model_input) + + if len(model_input) == 0: + return [], [] + + # Check if first element is a list/array (batch of samples) + if isinstance(model_input[0], (list, np.ndarray)): + # List of lists: [["text1", cat1], ["text2", cat2]] + texts = [row[0] for row in model_input] + if len(model_input[0]) > 1: + categories = [list(row[1:]) for row in model_input] + else: + categories = [[0] * num_cat_features] * len(texts) + return texts, categories + + # First element is not a list + if all(isinstance(x, str) for x in model_input): + # List of strings: ["text1", "text2"] + return model_input, [[0] * num_cat_features] * len(model_input) + else: + # Single sample: ["text", cat1, cat2] + texts = [str(model_input[0])] + if len(model_input) > 1: + categories = [list(model_input[1:])] + else: + categories = [[0] * num_cat_features] + return texts, categories + + raise ValueError(f"Unsupported input format: {type(model_input)}") + + def predict( + self, + context: mlflow.pyfunc.PythonModelContext, + model_input: Any, + params: Optional[Dict[str, Any]] = None, + ) -> pd.DataFrame: + """ + Run inference on input data. + + This method handles both standard predictions and explainability. + When explain=True, it uses Captum's LayerIntegratedGradients to + compute token-level attributions. + + Args: + context: MLflow context (not used directly). + model_input: Input data in any supported format. + params: Optional inference parameters: + - top_k (int): Number of top predictions (default: 1) + - explain (bool): Return attributions (default: False) + + Returns: + DataFrame with predictions. 
Columns depend on params: + - top_k=1: "prediction", "confidence" + - top_k>1: "prediction_1", "confidence_1", ..., "prediction_k", "confidence_k" + - explain=True: adds "tokens", "attributions" columns + + Example: + >>> # Standard prediction + >>> result = model.predict(["Great product!"]) + >>> print(result) + # prediction confidence + # 0 positive 0.89 + + >>> # Top-3 with explanations + >>> result = model.predict(data, params={"top_k": 3, "explain": True}) + """ + import torch + + # ===================================================================== + # Parse Parameters + # ===================================================================== + params = params or {} + top_k = params.get("top_k", 1) + explain = params.get("explain", False) + + # ===================================================================== + # Parse Input + # ===================================================================== + texts, categories = self._parse_input(model_input) + + if len(texts) == 0: + return pd.DataFrame({"prediction": [], "confidence": []}) + + # ===================================================================== + # Tokenize Text + # ===================================================================== + tokens = self.tokenizer( + texts, + padding="max_length", + truncation=True, + max_length=self.config["output_dim"], + return_tensors="pt", + ) + + # ===================================================================== + # Prepare Categorical Features + # ===================================================================== + num_cat_features = len(self.config.get("categorical_columns", [])) + if num_cat_features > 0 and categories: + categorical_vars = torch.tensor(categories, dtype=torch.long) + else: + categorical_vars = torch.zeros(len(texts), num_cat_features, dtype=torch.long) + + # Ensure top_k doesn't exceed number of classes + num_classes = self.config["num_classes"] + top_k = min(top_k, num_classes) + + # ===================================================================== + # Run Inference (with or without explainability) + # ===================================================================== + if explain: + # Use Captum for token-level attributions + probs, attributions_list, tokens_list = self._predict_with_explain( + tokens, categorical_vars, texts + ) + else: + # Standard inference without gradients + with torch.no_grad(): + logits = self.model( + tokens["input_ids"], tokens["attention_mask"], categorical_vars + ) + probs = torch.softmax(logits, dim=-1) + + # ===================================================================== + # Build Result DataFrame + # ===================================================================== + result_data = {} + + if top_k == 1: + # Single prediction per sample + predictions = probs.argmax(dim=-1) + confidence = probs.max(dim=-1).values + pred_labels = [self.label_mapping[str(p.item())] for p in predictions] + result_data["prediction"] = pred_labels + result_data["confidence"] = confidence.detach().numpy() + else: + # Top-k predictions per sample + top_probs, top_indices = probs.topk(k=top_k, dim=-1) + for k_idx in range(top_k): + pred_labels = [ + self.label_mapping[str(idx[k_idx].item())] for idx in top_indices + ] + result_data[f"prediction_{k_idx + 1}"] = pred_labels + result_data[f"confidence_{k_idx + 1}"] = ( + top_probs[:, k_idx].detach().numpy() + ) + + # Add attribution columns if explain=True + if explain: + result_data["tokens"] = tokens_list + result_data["attributions"] = attributions_list + + return 
pd.DataFrame(result_data) + + def _predict_with_explain( + self, + tokens: Dict[str, "torch.Tensor"], + categorical_vars: "torch.Tensor", + texts: List[str], + ) -> Tuple["torch.Tensor", List[List[float]], List[List[str]]]: + """ + Run inference with Captum explainability. + + Uses LayerIntegratedGradients to compute token-level attributions + by analyzing gradients through the embedding layer. + + Args: + tokens: Tokenized input with input_ids and attention_mask. + categorical_vars: Categorical feature tensor. + texts: Original text strings (for token decoding). + + Returns: + Tuple of (probs, attributions_list, tokens_list): + - probs: Softmax probabilities + - attributions_list: Normalized attribution scores per sample + - tokens_list: Decoded tokens per sample + """ + import torch + from captum.attr import LayerIntegratedGradients + + # Initialize Captum with the embedding layer + lig = LayerIntegratedGradients( + self.model, self.model.text_embedder.embedding_layer + ) + + # Forward pass to get predictions + with torch.no_grad(): + logits = self.model( + tokens["input_ids"], tokens["attention_mask"], categorical_vars + ) + probs = torch.softmax(logits, dim=-1) + + # Get predictions for attribution targets + predictions = probs.argmax(dim=-1) + + # Compute attributions using LayerIntegratedGradients + attributions = lig.attribute( + (tokens["input_ids"], tokens["attention_mask"], categorical_vars), + target=predictions, + ) + # Sum over embedding dimension to get per-token attributions + attributions = attributions.sum(dim=-1) # (batch_size, seq_len) + + # Normalize attributions per sample + attributions_list = [] + tokens_list = [] + for i in range(len(texts)): + attr = attributions[i].abs() + if attr.sum() > 0: + attr = attr / attr.sum() + attributions_list.append(attr.detach().tolist()) + + # Decode tokens for this sample + sample_tokens = self.tokenizer.convert_ids_to_tokens( + tokens["input_ids"][i].tolist() + ) + tokens_list.append(sample_tokens) + + return probs, attributions_list, tokens_list diff --git a/examples/mlflow_logging/run_example.py b/examples/mlflow_logging/run_example.py new file mode 100644 index 0000000..f2d3c12 --- /dev/null +++ b/examples/mlflow_logging/run_example.py @@ -0,0 +1,519 @@ +#!/usr/bin/env python +""" +MLflow Logging Example - Main Script +===================================== + +This script demonstrates the complete workflow for training a text classifier +with torchTextClassifiers and logging it to MLflow with a pyfunc wrapper. + +The workflow consists of: +1. Loading training data from CSV files +2. Training a WordPiece tokenizer +3. Configuring and training the model with MLflow logging +4. Exporting artifacts (model, tokenizer, configs) +5. Logging a pyfunc model to MLflow +6. 
Testing inference with various input formats + +Usage: + uv run examples/mlflow_logging/run_example.py + + # Or with explicit dependencies: + uv run --extra huggingface --with mlflow --with captum \\ + examples/mlflow_logging/run_example.py + +Requirements: + - torchTextClassifiers[huggingface] + - mlflow + - captum (for explainability) + +Output: + - MLflow run with training metrics logged per epoch + - Logged pyfunc model ready for deployment + - Test predictions demonstrating various input formats +""" + +import json +import os +import tempfile +import warnings +from pathlib import Path + +import mlflow +import mlflow.pyfunc +import numpy as np +import pandas as pd +import torch +from mlflow.models.signature import ModelSignature +from mlflow.types import DataType, ParamSchema, ParamSpec +from pytorch_lightning.loggers import MLFlowLogger + +from torchTextClassifiers import ModelConfig, TrainingConfig, torchTextClassifiers +from torchTextClassifiers.tokenizers import WordPieceTokenizer + +# Import the pyfunc wrapper from the same package +# Handle both direct script execution and module import +try: + from .pyfunc_wrapper import TextClassifierWrapper +except ImportError: + from pyfunc_wrapper import TextClassifierWrapper + +# Suppress common warnings for cleaner output +warnings.filterwarnings("ignore", category=UserWarning) +warnings.filterwarnings("ignore", category=FutureWarning, module="mlflow") + + +# ============================================================================= +# Configuration (from environment variables or defaults) +# ============================================================================= +# These can be overridden via environment variables for Argo Workflows +# or other orchestration systems. + + +def get_config() -> dict: + """ + Get training configuration from environment variables or defaults. + + Environment Variables: + NUM_WORKERS: Number of data loader workers (default: 0) + NUM_EPOCHS: Number of training epochs (default: 15) + BATCH_SIZE: Training batch size (default: 8) + LEARNING_RATE: Learning rate (default: 0.001) + VOCAB_SIZE: Tokenizer vocabulary size (default: 1000) + EMBEDDING_DIM: Token embedding dimension (default: 32) + OUTPUT_DIM: Maximum sequence length (default: 64) + EXPERIMENT_NAME: MLflow experiment name (default: text-classification) + + Returns: + Dictionary with configuration values. + """ + return { + "num_workers": int(os.environ.get("NUM_WORKERS", 0)), + "num_epochs": int(os.environ.get("NUM_EPOCHS", 15)), + "batch_size": int(os.environ.get("BATCH_SIZE", 8)), + "learning_rate": float(os.environ.get("LEARNING_RATE", 0.001)), + "vocab_size": int(os.environ.get("VOCAB_SIZE", 1000)), + "embedding_dim": int(os.environ.get("EMBEDDING_DIM", 32)), + "output_dim": int(os.environ.get("OUTPUT_DIM", 64)), + "experiment_name": os.environ.get("EXPERIMENT_NAME", "text-classification"), + } + + +# ============================================================================= +# Data Loading +# ============================================================================= + + +def load_data(data_dir: Path) -> tuple: + """ + Load training, validation, and test data from CSV files. 
+ + The CSV files should have columns: text, category, label + - text: The text content to classify + - category: Categorical feature (0=electronics, 1=clothing, 2=books) + - label: Target label (0=negative, 1=positive) + + Args: + data_dir: Path to the directory containing train.csv, val.csv, test.csv + + Returns: + Tuple of (X_train, y_train, X_val, y_val, X_test, y_test) where: + - X arrays have shape (n_samples, 2) with [text, category] + - y arrays have shape (n_samples,) with labels + + Example: + >>> data_dir = Path("examples/mlflow_logging/data") + >>> X_train, y_train, X_val, y_val, X_test, y_test = load_data(data_dir) + >>> print(f"Training samples: {len(X_train)}") + """ + def load_csv(filename: str) -> tuple: + """Load a single CSV file and convert to arrays.""" + df = pd.read_csv(data_dir / filename) + # X: [text, category] as object array + X = np.array( + [[row["text"], row["category"]] for _, row in df.iterrows()], + dtype=object + ) + # y: labels as integer array + y = df["label"].values + return X, y + + X_train, y_train = load_csv("train.csv") + X_val, y_val = load_csv("val.csv") + X_test, y_test = load_csv("test.csv") + + return X_train, y_train, X_val, y_val, X_test, y_test + + +# ============================================================================= +# Main Training and Logging Function +# ============================================================================= + + +def main(): + """ + Main function demonstrating the complete MLflow logging workflow. + + This function: + 1. Loads data from CSV files + 2. Trains a tokenizer and classifier + 3. Logs training metrics to MLflow via PyTorch Lightning + 4. Exports and logs a pyfunc model + 5. Tests the loaded model with various input formats + """ + print("=" * 60) + print("MLflow Logging Example") + print("=" * 60) + + # ========================================================================= + # Load Configuration + # ========================================================================= + config = get_config() + print("\nConfiguration:") + for key, value in config.items(): + print(f" {key}: {value}") + + # ========================================================================= + # Step 1: Load Data from CSV Files + # ========================================================================= + print("\n1. Loading data from CSV files...") + + # Get the directory where this script is located + script_dir = Path(__file__).parent + data_dir = script_dir / "data" + + X_train, y_train, X_val, y_val, X_test, y_test = load_data(data_dir) + + print(f" Training samples: {len(X_train)}") + print(f" Validation samples: {len(X_val)}") + print(f" Test samples: {len(X_test)}") + + # Extract text for tokenizer training + training_texts = X_train[:, 0].tolist() + + # ========================================================================= + # Step 2: Configure and Train Tokenizer + # ========================================================================= + print("\n2. Training tokenizer...") + + # Tokenizer hyperparameters (from config) + vocab_size = config["vocab_size"] + output_dim = config["output_dim"] + + tokenizer = WordPieceTokenizer(vocab_size=vocab_size, output_dim=output_dim) + tokenizer.train(training_texts) + + print(f" Vocabulary size: {len(tokenizer)}") + + # ========================================================================= + # Step 3: Configure Model + # ========================================================================= + print("\n3. 
Creating classifier...") + + # Model hyperparameters (from config) + embedding_dim = config["embedding_dim"] + num_classes = 2 # Binary classification (positive/negative) + categorical_vocab_sizes = [3] # 3 categories: electronics, clothing, books + categorical_embedding_dims = 8 # Dimension for categorical embeddings + + model_config = ModelConfig( + embedding_dim=embedding_dim, + num_classes=num_classes, + categorical_vocabulary_sizes=categorical_vocab_sizes, + categorical_embedding_dims=categorical_embedding_dims, + ) + + # Create the classifier + classifier = torchTextClassifiers(tokenizer=tokenizer, model_config=model_config) + + # Label and category mappings for inference + label_mapping = {"0": "negative", "1": "positive"} + category_mapping = {"category": {"electronics": 0, "clothing": 1, "books": 2}} + + # ========================================================================= + # Step 4: Configure MLflow Logger + # ========================================================================= + # PyTorch Lightning's MLFlowLogger automatically logs metrics per epoch: + # - train_loss_step, train_loss_epoch + # - train_accuracy + # - val_loss, val_accuracy + + print("\n4. Training with MLflow logging...") + + mlflow_logger = MLFlowLogger( + experiment_name=config["experiment_name"], + log_model=False, # We'll log manually with pyfunc wrapper + ) + + # Training hyperparameters (from config) + num_epochs = config["num_epochs"] + batch_size = config["batch_size"] + lr = config["learning_rate"] + num_workers = config["num_workers"] + + training_config = TrainingConfig( + num_epochs=num_epochs, + batch_size=batch_size, + lr=lr, + patience_early_stopping=5, + num_workers=num_workers, + trainer_params={"logger": mlflow_logger}, + ) + + # ========================================================================= + # Step 5: Train the Model + # ========================================================================= + # Training metrics are logged automatically per epoch + classifier.train( + X_train, + y_train, + training_config=training_config, + X_val=X_val, + y_val=y_val, + verbose=True, + ) + + # Get the run_id from MLFlowLogger to continue logging in the same run + run_id = mlflow_logger.run_id + + # ========================================================================= + # Step 6: Log Additional Metrics and Artifacts + # ========================================================================= + with mlflow.start_run(run_id=run_id): + # Log hyperparameters + mlflow.log_params({ + "embedding_dim": embedding_dim, + "num_classes": num_classes, + "vocab_size": vocab_size, + "output_dim": output_dim, + "num_epochs": num_epochs, + "batch_size": batch_size, + "learning_rate": lr, + "num_workers": num_workers, + "categorical_vocab_sizes": str(categorical_vocab_sizes), + "categorical_embedding_dims": categorical_embedding_dims, + }) + + # ----------------------------------------------------------------- + # Evaluate on all sets + # ----------------------------------------------------------------- + print("\n5. 
Evaluating model...") + + def evaluate(X, y, name): + result = classifier.predict(X) + preds = result["prediction"].squeeze().numpy() + acc = (preds == y).mean() + print(f" {name} accuracy: {acc:.4f}") + return acc + + train_acc = evaluate(X_train, y_train, "Training") + val_acc = evaluate(X_val, y_val, "Validation") + test_acc = evaluate(X_test, y_test, "Test") + + # Log final evaluation metrics + mlflow.log_metrics({ + "final_train_accuracy": train_acc, + "final_val_accuracy": val_acc, + "test_accuracy": test_acc, + }) + + # ----------------------------------------------------------------- + # Prepare artifacts for logging + # ----------------------------------------------------------------- + print("\n6. Preparing artifacts for logging...") + + with tempfile.TemporaryDirectory() as tmpdir: + # Save tokenizer (HuggingFace format) + tokenizer_path = os.path.join(tmpdir, "tokenizer") + classifier.tokenizer.tokenizer.save_pretrained(tokenizer_path) + print(f" Saved tokenizer to {tokenizer_path}") + + # Save PyTorch model (full model for Captum support) + model_path = os.path.join(tmpdir, "model.pt") + pytorch_model = classifier.pytorch_model + pytorch_model.eval() + torch.save(pytorch_model, model_path) + print(f" Saved PyTorch model to {model_path}") + + # Save label mapping + label_mapping_path = os.path.join(tmpdir, "label_mapping.json") + with open(label_mapping_path, "w") as f: + json.dump(label_mapping, f) + print(f" Saved label mapping to {label_mapping_path}") + + # Save model config + model_config_path = os.path.join(tmpdir, "model_config.json") + config_dict = { + "output_dim": output_dim, + "num_classes": num_classes, + "embedding_dim": embedding_dim, + "categorical_columns": ["category"], + } + with open(model_config_path, "w") as f: + json.dump(config_dict, f) + print(f" Saved model config to {model_config_path}") + + # Save categorical mapping + categorical_mapping_path = os.path.join(tmpdir, "categorical_mapping.json") + with open(categorical_mapping_path, "w") as f: + json.dump(category_mapping, f) + print(f" Saved categorical mapping to {categorical_mapping_path}") + + # Define artifacts dictionary + artifacts = { + "model": model_path, + "tokenizer": tokenizer_path, + "label_mapping": label_mapping_path, + "model_config": model_config_path, + "categorical_mapping": categorical_mapping_path, + } + + # Define pip requirements + pip_requirements = [ + "torch>=2.0", + "transformers>=4.30", + "pandas", + "numpy", + "captum", + "torchTextClassifiers", + ] + + # ----------------------------------------------------------------- + # Define model signature with params schema + # ----------------------------------------------------------------- + # This enables the params argument (top_k, explain) in predict() + params_schema = ParamSchema([ + ParamSpec("top_k", DataType.long, default=1), + ParamSpec("explain", DataType.boolean, default=False), + ]) + signature = ModelSignature( + inputs=None, # Flexible input formats + outputs=None, # Output varies based on params + params=params_schema, + ) + + # ----------------------------------------------------------------- + # Log the pyfunc model + # ----------------------------------------------------------------- + print("\n7. 
Logging pyfunc model to MLflow...") + + mlflow.pyfunc.log_model( + artifact_path="model", + python_model=TextClassifierWrapper(), + artifacts=artifacts, + pip_requirements=pip_requirements, + signature=signature, + ) + + print(f"\n Run ID: {run_id}") + + # ===================================================================== + # Step 7: Test Model Loading and Inference + # ===================================================================== + print("\n8. Testing model loading and inference...") + + model_uri = f"runs:/{run_id}/model" + loaded_model = mlflow.pyfunc.load_model(model_uri) + + # ----------------------------------------------------------------- + # Test different input formats + # ----------------------------------------------------------------- + print("\n Testing different input formats:") + + # Format 1: List of lists + print("\n Format 1: List of lists [['text', category], ...]") + preds = loaded_model.predict([["This is an amazing product!", 0]]) + print(f" -> {preds.iloc[0]['prediction']} ({preds.iloc[0]['confidence']:.3f})") + + # Format 2: Multiple samples + print("\n Format 2: Multiple samples") + preds = loaded_model.predict([ + ["Great quality electronics!", 0], + ["Terrible fit, returning it.", 1], + ]) + for i, p in preds.iterrows(): + print(f" -> {p['prediction']} ({p['confidence']:.3f})") + + # Format 3: List of strings (text only) + print("\n Format 3: List of strings (text only)") + preds = loaded_model.predict(["Love this product!", "Worst purchase ever."]) + for i, p in preds.iterrows(): + print(f" -> {p['prediction']} ({p['confidence']:.3f})") + + # Format 4: Single string + print("\n Format 4: Single string") + preds = loaded_model.predict("Absolutely fantastic!") + print(f" -> {preds.iloc[0]['prediction']} ({preds.iloc[0]['confidence']:.3f})") + + # Format 5: DataFrame + print("\n Format 5: DataFrame with named columns") + test_df = pd.DataFrame({"text": ["Beautiful design!"], "category": [1]}) + preds = loaded_model.predict(test_df) + print(f" -> {preds.iloc[0]['prediction']} ({preds.iloc[0]['confidence']:.3f})") + + # ----------------------------------------------------------------- + # Test inference parameters + # ----------------------------------------------------------------- + print("\n Testing inference parameters:") + + # Test top_k parameter + print("\n Parameter: top_k=2 (get top 2 predictions)") + preds = loaded_model.predict( + [["This could be good or bad, not sure.", 0]], + params={"top_k": 2}, + ) + print(f" -> Top 1: {preds.iloc[0]['prediction_1']} ({preds.iloc[0]['confidence_1']:.3f})") + print(f" -> Top 2: {preds.iloc[0]['prediction_2']} ({preds.iloc[0]['confidence_2']:.3f})") + + # Test explain parameter + print("\n Parameter: explain=True (get token attributions)") + preds = loaded_model.predict( + [["Amazing quality product!", 0]], + params={"explain": True}, + ) + print(f" -> Prediction: {preds.iloc[0]['prediction']} ({preds.iloc[0]['confidence']:.3f})") + tokens = preds.iloc[0]["tokens"] + attributions = preds.iloc[0]["attributions"] + # Show top 5 most important tokens (excluding padding) + token_attr_pairs = [ + (t, a) for t, a in zip(tokens, attributions) if t != "[PAD]" + ] + token_attr_pairs.sort(key=lambda x: x[1], reverse=True) + print(" -> Top tokens by attribution:") + for tok, attr in token_attr_pairs[:5]: + print(f" '{tok}': {attr:.4f}") + + # Test combined top_k and explain + print("\n Parameters: top_k=2, explain=True (combined)") + preds = loaded_model.predict( + [["Great product but shipping was slow.", 0]], + 
params={"top_k": 2, "explain": True}, + ) + print(f" -> Top 1: {preds.iloc[0]['prediction_1']} ({preds.iloc[0]['confidence_1']:.3f})") + print(f" -> Top 2: {preds.iloc[0]['prediction_2']} ({preds.iloc[0]['confidence_2']:.3f})") + print(f" -> Tokens: {preds.iloc[0]['tokens'][:10]}...") + + # ========================================================================= + # Print Usage Instructions + # ========================================================================= + print("\n" + "=" * 60) + print("Done! Model logged successfully to MLflow.") + print(f"\nTo load and use the model:") + print() + print(f" import mlflow") + print(f" model = mlflow.pyfunc.load_model('runs:/{run_id}/model')") + print() + print(f" # Flexible input formats supported:") + print(f' model.predict([["Great product!", 0]]) # List of [text, category]') + print(f' model.predict(["Text 1", "Text 2"]) # List of strings') + print(f' model.predict("Single text") # Single string') + print(f' model.predict(pd.DataFrame(...)) # DataFrame') + print() + print(f" # Inference parameters (via params dict):") + print(f' model.predict(data, params={{"top_k": 3}}) # Get top 3 predictions') + print(f' model.predict(data, params={{"explain": True}}) # Get token attributions') + print(f' model.predict(data, params={{"top_k": 2, "explain": True}}) # Combined') + print("=" * 60) + + +if __name__ == "__main__": + main()
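

# -----------------------------------------------------------------------------
# Optional: inspecting and serving the logged model from the command line
# -----------------------------------------------------------------------------
# The commands below are a minimal sketch, not part of the example run itself:
# <RUN_ID> is a placeholder for the run id printed by main(), and the exact
# JSON payload accepted by the scoring endpoint depends on your MLflow version.
#
#   # Browse logged params, per-epoch metrics, and the pyfunc model artifact:
#   mlflow ui
#
#   # Serve the pyfunc model as a local REST endpoint (port 5001 here to avoid
#   # clashing with the UI, which uses port 5000 by default):
#   mlflow models serve -m "runs:/<RUN_ID>/model" -p 5001
#
#   # Query it; recent MLflow versions accept an "inputs" payload plus an
#   # optional "params" object mirroring the predict() params used above:
#   curl -X POST http://127.0.0.1:5001/invocations \
#        -H "Content-Type: application/json" \
#        -d '{"inputs": [["Great product!", 0]], "params": {"top_k": 2}}'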