---
id: topic-classification
title: Topic Classification with Triton
sidebar_label: Topic Classification
description: Deploy zero-shot topic classification models with NVIDIA Triton Inference Server
---
Triton Inference Server for Topic Classification
This documentation explains how to set up and run the zero-shot classification project using NVIDIA's Triton Inference Server with a FastAPI proxy.
Project Overview
This project implements a topic text classification service using NVIDIA's Triton Inference Server and a FastAPI application as a proxy. The system uses a pre-trained BART-MNLI model for zero-shot text classification, allowing users to classify text into arbitrary categories without specific training for those categories.
Key Components
- Triton Inference Server: Manages model execution and scaling
- FastAPI Proxy: Provides a user-friendly REST API for the classification service
- Zero-Shot Classification Pipeline: Split into three components (see the ensemble sketch after this list):
  - Tokenizer: Prepares text and labels for the model
  - Model: Runs inference on the tokenized inputs
  - Postprocessor: Transforms model outputs into user-friendly results
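To show how Triton wires these components together, the config.pbtxt for the zero-shot ensemble might look roughly like the sketch below; the tensor names and data types here are illustrative assumptions, not the project's actual configuration.

```
# Hypothetical sketch of models/zero-shot/config.pbtxt (tensor names assumed)
name: "zero-shot"
platform: "ensemble"
input [
  { name: "TEXT", data_type: TYPE_STRING, dims: [ 1 ] },
  { name: "CANDIDATE_LABELS", data_type: TYPE_STRING, dims: [ -1 ] }
]
output [
  { name: "RESULTS", data_type: TYPE_STRING, dims: [ 1 ] }
]
ensemble_scheduling {
  step [
    {
      model_name: "zero-shot-tokenizer"
      model_version: -1
      input_map  { key: "TEXT" value: "TEXT" }
      input_map  { key: "CANDIDATE_LABELS" value: "CANDIDATE_LABELS" }
      output_map { key: "INPUT_IDS" value: "input_ids" }
      output_map { key: "ATTENTION_MASK" value: "attention_mask" }
    },
    {
      model_name: "zero-shot-model"
      model_version: -1
      input_map  { key: "input_ids" value: "input_ids" }
      input_map  { key: "attention_mask" value: "attention_mask" }
      output_map { key: "logits" value: "logits" }
    },
    {
      model_name: "zero-shot-postprocessor"
      model_version: -1
      input_map  { key: "logits" value: "logits" }
      input_map  { key: "CANDIDATE_LABELS" value: "CANDIDATE_LABELS" }
      output_map { key: "RESULTS" value: "RESULTS" }
    }
  ]
}
```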
Project Structure
ai/TritonTopicClassification/
├── app/ # FastAPI application
│ ├── main.py # FastAPI server code
│ └── start.sh # Script to start both Triton and FastAPI
├── models/ # Triton model repository
│ ├── zero-shot/ # Ensemble model (combines all components)
│ │ ├── 1/ # Model version directory
│ │ └── config.pbtxt # Ensemble configuration
│ ├── zero-shot-model/ # ONNX model component
│ │ ├── 1/ # Model version with model.onnx file
│ │ └── config.pbtxt # Model configuration
│ ├── zero-shot-postprocessor/ # Python postprocessing component
│ │ ├── 1/ # Model version with model.py
│ │ └── config.pbtxt # Postprocessor configuration
│ └── zero-shot-tokenizer/ # Python tokenizer component
│ ├── 1/ # Model version with model.py and tokenizer_data/
│ └── config.pbtxt # Tokenizer configuration
└── notebooks/ # Test scripts and notebooks
└── test_fastapi_proxy.py # Script to test the API
Prerequisites
- NVIDIA GPU with CUDA support
- Docker
- NVIDIA Container Toolkit
- Python 3.10+ (for local development/testing)
Getting Started
Step 1: Build the Docker Image
Build the Docker image using the provided Dockerfile:
docker build -t triton-topic-fastapi:1 .
This builds an image based on NVIDIA's Triton Inference Server (version 24.12-py3) with the additional dependencies needed for the FastAPI proxy.
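For orientation, a minimal Dockerfile along these lines might look like the sketch below; the base image tag comes from the description above, but the dependency list is an assumption and the project's actual Dockerfile is authoritative.

```dockerfile
# Hypothetical sketch; see the project's Dockerfile for the real contents.
FROM nvcr.io/nvidia/tritonserver:24.12-py3

# Python dependencies for the FastAPI proxy (assumed list)
RUN pip install --no-cache-dir fastapi uvicorn requests

# The models/ and app/ directories are mounted at runtime by Docker Compose,
# so they are not baked into the image here.
```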
Step 2: Start the Container
Use Docker Compose to start the container:
docker-compose up -d
This will (see the compose sketch below):
- Start the Triton container with GPU access
- Mount the ./models and ./app directories into the container
- Map the necessary ports:
  - 8000: Triton HTTP API
  - 8001: Triton gRPC API
  - 8002: Triton Metrics API
  - 8005: FastAPI Proxy API
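A hedged sketch of what such a docker-compose.yml might look like follows; the service name and in-container mount paths are assumptions (only /app is confirmed by the later steps in this guide).

```yaml
# Hypothetical sketch; adapt to the project's actual docker-compose.yml.
services:
  triton-topic:
    image: triton-topic-fastapi:1
    container_name: triton-topic
    volumes:
      - ./models:/models   # Triton model repository (assumed mount point)
      - ./app:/app         # FastAPI proxy code
    ports:
      - "8000:8000"   # Triton HTTP API
      - "8001:8001"   # Triton gRPC API
      - "8002:8002"   # Triton Metrics API
      - "8005:8005"   # FastAPI proxy
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: all
              capabilities: [gpu]
```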
Step 3: Compile the Models (if needed)
If you need to compile or export the models:
Set up a local Python environment:
conda create -n compile python=3.12
conda activate compile
pip install -r requirements_compile.txt
Use the provided compilation scripts:
python compile_onnx.py
This will (see the export sketch below):
- Export a BART-MNLI model to ONNX format
- Create the necessary Triton configuration files
- Save the model files to the appropriate directories
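The export step could be sketched roughly as follows with the Hugging Face optimum library; this is an illustration under assumptions (the facebook/bart-large-mnli checkpoint and the output paths), not necessarily what compile_onnx.py actually does.

```python
# Hypothetical sketch of an ONNX export for a BART-MNLI checkpoint.
from optimum.onnxruntime import ORTModelForSequenceClassification
from transformers import AutoTokenizer

MODEL_ID = "facebook/bart-large-mnli"        # assumed checkpoint
ONNX_DIR = "models/zero-shot-model/1"        # Triton model version directory
TOKENIZER_DIR = "models/zero-shot-tokenizer/1/tokenizer_data"

# Export the PyTorch weights to ONNX and write model.onnx plus its config
model = ORTModelForSequenceClassification.from_pretrained(MODEL_ID, export=True)
model.save_pretrained(ONNX_DIR)

# Save the tokenizer files for the Python tokenizer component
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
tokenizer.save_pretrained(TOKENIZER_DIR)
```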
Step 4: Start the Services
Connect to the container and start the Triton server and FastAPI proxy:
docker exec -it triton-topic /bin/bash
cd /app
./start.sh
This script (sketched below):
- Starts the Triton server with the zero-shot model
- Waits for the Triton server to become ready
- Starts the FastAPI proxy server
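In outline, start.sh likely resembles the following; the model repository path and server flags are assumptions, while the readiness endpoint and port numbers match the setup described above.

```bash
#!/bin/bash
# Hypothetical sketch of start.sh; see the project's app/start.sh for the real script.

# 1. Launch Triton in the background with the zero-shot ensemble loaded
tritonserver --model-repository=/models \
             --model-control-mode=explicit \
             --load-model=zero-shot &

# 2. Block until Triton reports ready on its HTTP health endpoint
until curl -sf http://localhost:8000/v2/health/ready > /dev/null; do
  sleep 2
done

# 3. Start the FastAPI proxy on port 8005
uvicorn main:app --host 0.0.0.0 --port 8005
```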
Testing the API
You can test the API using the provided test script:
python notebooks/test_fastapi_proxy.py
This script tests various aspects of the API, including:
- Basic health checks
- Sentiment classification
- Topic classification
- Multi-label classification
- Model management (loading/unloading)
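The kinds of calls the script makes can be illustrated with a short snippet like the one below; this is not the actual test script, and it assumes the proxy is reachable on localhost:8005.

```python
# Hypothetical client calls against the FastAPI proxy (not the real test script).
import requests

BASE_URL = "http://localhost:8005"

# Basic health check
print(requests.get(f"{BASE_URL}/health", timeout=10).json())

# Topic classification request matching the schema documented below
payload = {
    "text": "The latest climate report shows alarming increases in global temperatures.",
    "candidate_labels": ["politics", "science", "environment", "technology"],
    "multi_label": False,
}
result = requests.post(f"{BASE_URL}/models/classify", json=payload, timeout=60).json()
print(result["prediction"], result["score"])
```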
API Endpoints
Classification Endpoint
POST /models/classify
Example request:
{
"text": "The latest climate report shows alarming increases in global temperatures over the past decade.",
"candidate_labels": [
"politics",
"science",
"environment",
"technology",
"health",
"entertainment"
],
"multi_label": false
}
Example response:
{
"text": "The latest climate report shows alarming increases in global temperatures over the past decade.",
"candidate_labels": [
"politics",
"science",
"environment",
"technology",
"health",
"entertainment"
],
"multi_label": false,
"prediction": "environment",
"score": 0.712,
"all_predictions": [
{ "label": "environment", "score": 0.7124707102775574 },
{ "label": "science", "score": 0.23611347377300262 },
{ "label": "technology", "score": 0.020212717354297638 },
{ "label": "health", "score": 0.013427206315100193 },
{ "label": "politics", "score": 0.010179099626839161 },
{ "label": "entertainment", "score": 0.007596719544380903 }
]
}
Health Check
GET /health
Returns the status of the server and loaded models.
Model Management
POST /models/zero_shot/load
POST /models/zero_shot/unload
Manually load or unload models from the Triton server.
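Assuming the proxy listens on port 8005, these endpoints (together with the health check above) can be exercised with curl; the exact response bodies depend on main.py.

```bash
# Health check via the FastAPI proxy
curl -s http://localhost:8005/health

# Manually load and unload the zero-shot model
curl -s -X POST http://localhost:8005/models/zero_shot/load
curl -s -X POST http://localhost:8005/models/zero_shot/unload
```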
Advanced Usage
Model Management Configuration
The FastAPI proxy includes a model management system that can automatically load and unload models based on usage. Configuration is in main.py:
MODEL_MANAGEMENT = {
"enabled": True, # Enable/disable model management
"idle_threshold": 3600, # Unload models after 1 hour of inactivity
"check_interval": 300, # Check for idle models every 5 minutes
"always_loaded": [], # Models that should never be unloaded
"load_timeout": 30 # Timeout for model loading in seconds
}
Troubleshooting
- GPU not detected: Make sure the NVIDIA Container Toolkit is installed and your Docker Compose file includes the correct GPU configuration.
- Model loading errors: Check the Triton server logs for detailed error messages:
  docker exec -it triton-topic cat /var/log/triton/server.log
- API errors: Check the FastAPI logs for detailed error messages:
  docker exec -it triton-topic cat /var/log/fastapi.log
- Memory issues: If you encounter out-of-memory errors, try:
  - Reducing batch sizes in the model configurations
  - Using a smaller model
  - Adding more GPU memory
Performance Optimization
- Instance Groups: You can modify the number of model instances by editing the instance_group section in the model's config.pbtxt:
  instance_group [
    {
      kind: KIND_GPU
      count: 4  # Run 4 instances of the model
    }
  ]
- Dynamic Batching: Enable dynamic batching to improve throughput by adding the following to the model's config.pbtxt (see the expanded example below):
  dynamic_batching { }
- Model Concurrency: Adjust the maximum concurrency parameter in the Triton server startup command:
  --concurrency-limit=4
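As an expanded example, a dynamic batching block with explicit settings might look like the following; the values are illustrative assumptions to tune for your workload.

```
dynamic_batching {
  preferred_batch_size: [ 4, 8 ]
  max_queue_delay_microseconds: 100
}
```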
Conclusion
This project provides a flexible and scalable solution for zero-shot text classification using NVIDIA's Triton Inference Server and FastAPI. By following the steps outlined in this documentation, you can build, deploy, and use the system for a variety of text classification tasks without the need for task-specific training.