---
id: topic-classification
title: Topic Classification with Triton
sidebar_label: Topic Classification
description: Deploy zero-shot topic classification models with NVIDIA Triton Inference Server
---


Triton Inference Server for Topic Classification

This documentation provides a comprehensive guide to understanding and running the zero-shot classification project using NVIDIA's Triton Inference Server with a FastAPI proxy.

Project Overview

This project implements a topic classification service for text using NVIDIA's Triton Inference Server with a FastAPI application as a proxy. The system uses a pre-trained BART-MNLI model for zero-shot text classification, allowing text to be classified into arbitrary categories without any category-specific training.
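
To make the behavior concrete: outside of Triton, the same kind of result can be produced with the Hugging Face zero-shot pipeline. The sketch below assumes the facebook/bart-large-mnli checkpoint and illustrates the technique only; it is not the code this project deploys.

# Minimal zero-shot classification sketch with BART-MNLI (illustration, not the deployed code).
# Assumes `transformers` and `torch` are installed and the checkpoint can be downloaded.
from transformers import pipeline

classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
result = classifier(
    "The latest climate report shows alarming increases in global temperatures.",
    candidate_labels=["politics", "science", "environment", "technology"],
    multi_label=False,
)
print(result["labels"][0], result["scores"][0])  # highest-scoring label and its score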

Key Components

  1. Triton Inference Server: Manages model execution and scaling
  2. FastAPI Proxy: Provides a user-friendly REST API for the classification service
  3. Zero-Shot Classification Pipeline: Split into three components (the end-to-end flow is sketched after this list):
    • Tokenizer: Prepares text and labels for the model
    • Model: Runs inference on the tokenized inputs
    • Postprocessor: Transforms model outputs into user-friendly results
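
Together, these three components implement the standard NLI-based zero-shot recipe: each candidate label is turned into an entailment hypothesis, the MNLI model scores it against the input text, and the scores are normalized into label probabilities. The sketch below shows that flow with plain transformers/PyTorch calls; the tensor names and the exact hypothesis template used inside the Triton components may differ.

# Conceptual sketch of the tokenizer -> model -> postprocessor split (single-label case).
# Assumes facebook/bart-large-mnli, whose classes are contradiction/neutral/entailment.
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

MODEL_ID = "facebook/bart-large-mnli"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID).eval()

text = "The latest climate report shows alarming increases in global temperatures."
labels = ["politics", "science", "environment"]

# 1) Tokenizer step: pair the text with one NLI hypothesis per candidate label.
hypotheses = [f"This example is {label}." for label in labels]
batch = tokenizer([text] * len(labels), hypotheses, return_tensors="pt", padding=True)

# 2) Model step: one contradiction/neutral/entailment logit triple per label.
with torch.no_grad():
    logits = model(**batch).logits

# 3) Postprocessor step: softmax the entailment logits across labels (single-label mode).
entailment_idx = model.config.label2id.get("entailment", 2)
scores = logits[:, entailment_idx].softmax(dim=0)
for label, score in sorted(zip(labels, scores.tolist()), key=lambda x: -x[1]):
    print(f"{label}: {score:.3f}")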

Project Structure

ai/TritonTopicClassification/
├── app/                             # FastAPI application
│   ├── main.py                      # FastAPI server code
│   └── start.sh                     # Script to start both Triton and FastAPI
├── models/                          # Triton model repository
│   ├── zero-shot/                   # Ensemble model (combines all components)
│   │   ├── 1/                       # Model version directory
│   │   └── config.pbtxt             # Ensemble configuration
│   ├── zero-shot-model/             # ONNX model component
│   │   ├── 1/                       # Model version with model.onnx file
│   │   └── config.pbtxt             # Model configuration
│   ├── zero-shot-postprocessor/     # Python postprocessing component
│   │   ├── 1/                       # Model version with model.py
│   │   └── config.pbtxt             # Postprocessor configuration
│   └── zero-shot-tokenizer/         # Python tokenizer component
│       ├── 1/                       # Model version with model.py and tokenizer_data/
│       └── config.pbtxt             # Tokenizer configuration
└── notebooks/                       # Test scripts and notebooks
    └── test_fastapi_proxy.py        # Script to test the API
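
The tokenizer and postprocessor directories follow Triton's Python backend layout: each 1/model.py defines a TritonPythonModel class that Triton calls with batched requests. The sketch below shows the general shape of such a file; the tensor names (TEXT, OUTPUT) are placeholders, not the names used by this project's actual components.

# Sketch of a Triton Python backend model.py (placeholder tensor names, not the project's code).
import json
import numpy as np
import triton_python_backend_utils as pb_utils

class TritonPythonModel:
    def initialize(self, args):
        # args["model_config"] is the JSON-serialized config.pbtxt for this model.
        self.model_config = json.loads(args["model_config"])

    def execute(self, requests):
        responses = []
        for request in requests:
            # Read an input tensor by the name declared in config.pbtxt.
            text = pb_utils.get_input_tensor_by_name(request, "TEXT").as_numpy()
            # ... tokenization or postprocessing work would happen here ...
            output = pb_utils.Tensor("OUTPUT", np.array([len(t) for t in text], dtype=np.int64))
            responses.append(pb_utils.InferenceResponse(output_tensors=[output]))
        return responses

    def finalize(self):
        pass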

Prerequisites

  • NVIDIA GPU with CUDA support
  • Docker
  • NVIDIA Container Toolkit
  • Python 3.10+ (for local development/testing)

Getting Started

Step 1: Build the Docker Image

Build the Docker image using the provided Dockerfile:

docker build -t triton-topic-fastapi:1 .

This builds an image based on NVIDIA's Triton Inference Server (version 24.12-py3) with the additional dependencies needed for the FastAPI proxy.

Step 2: Start the Container

Use Docker Compose to start the container:

docker-compose up -d

This will:

  • Start the Triton container with GPU access
  • Mount the ./models and ./app directories into the container
  • Map the necessary ports:
    • 8000: Triton HTTP API
    • 8001: Triton gRPC API
    • 8002: Triton Metrics API
    • 8005: FastAPI Proxy API

Step 3: Compile the Models (if needed)

If you need to compile or export the models, you can:

Set up a local Python environment:

conda create -n compile python=3.12
conda activate compile
pip install -r requirements_compile.txt

Use the provided compilation scripts:

python compile_onnx.py

This will:

  • Export a BART-MNLI model to ONNX format (a sketch of this step follows the list)
  • Create the necessary Triton configuration files
  • Save the model files to the appropriate directories
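
The export step can look roughly like the following; the actual compile_onnx.py may use different input/output names, opset version, or tooling (for example Hugging Face Optimum), so treat this as a sketch under those assumptions.

# Sketch of exporting BART-MNLI to ONNX with torch.onnx.export
# (input/output names here are illustrative, not necessarily those used by compile_onnx.py).
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

MODEL_ID = "facebook/bart-large-mnli"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID).eval()

dummy = tokenizer("premise text", "hypothesis text", return_tensors="pt")
torch.onnx.export(
    model,
    (dummy["input_ids"], dummy["attention_mask"]),
    "models/zero-shot-model/1/model.onnx",
    input_names=["input_ids", "attention_mask"],
    output_names=["logits"],
    dynamic_axes={
        "input_ids": {0: "batch", 1: "sequence"},
        "attention_mask": {0: "batch", 1: "sequence"},
        "logits": {0: "batch"},
    },
    opset_version=17,
)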

Step 4: Start the Services

Connect to the container and start the Triton server and FastAPI proxy:

docker exec -it triton-topic /bin/bash
cd /app
./start.sh

This script:

  1. Starts the Triton server with the zero-shot model
  2. Waits for the Triton server to become ready
  3. Starts the FastAPI proxy server
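
Once start.sh reports that both servers are up, you can verify them from the host. The sketch below assumes the port mapping from Step 2 and uses Triton's standard /v2/health/ready route plus the proxy's /health route.

# Quick readiness check for Triton (port 8000) and the FastAPI proxy (port 8005).
import requests

triton_ready = requests.get("http://localhost:8000/v2/health/ready", timeout=5)
print("Triton ready:", triton_ready.status_code == 200)

proxy_health = requests.get("http://localhost:8005/health", timeout=5)
print("Proxy health:", proxy_health.json())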

Testing the API

You can test the API using the provided test script:

python notebooks/test_fastapi_proxy.py

This script tests various aspects of the API, including:

  • Basic health checks
  • Sentiment classification
  • Topic classification
  • Multi-label classification
  • Model management (loading/unloading)

API Endpoints

Classification Endpoint

POST /models/classify

Example request:

{
  "text": "The latest climate report shows alarming increases in global temperatures over the past decade.",
  "candidate_labels": [
    "politics",
    "science",
    "environment",
    "technology",
    "health",
    "entertainment"
  ],
  "multi_label": false
}

Example response:

{
  "text": "The latest climate report shows alarming increases in global temperatures over the past decade.",
  "candidate_labels": [
    "politics",
    "science",
    "environment",
    "technology",
    "health",
    "entertainment"
  ],
  "multi_label": false,
  "prediction": "environment",
  "score": 0.712,
  "all_predictions": [
    { "label": "environment", "score": 0.7124707102775574 },
    { "label": "science", "score": 0.23611347377300262 },
    { "label": "technology", "score": 0.020212717354297638 },
    { "label": "health", "score": 0.013427206315100193 },
    { "label": "politics", "score": 0.010179099626839161 },
    { "label": "entertainment", "score": 0.007596719544380903 }
  ]
}
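
The same request can be sent from Python. The snippet below assumes the proxy is reachable on port 8005 as mapped in Step 2.

# Call the classification endpoint on the FastAPI proxy (port 8005 per the compose file).
import requests

payload = {
    "text": "The latest climate report shows alarming increases in global temperatures.",
    "candidate_labels": ["politics", "science", "environment", "technology"],
    "multi_label": False,
}
response = requests.post("http://localhost:8005/models/classify", json=payload, timeout=30)
response.raise_for_status()
result = response.json()
print(result["prediction"], result["score"])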

Health Check

GET /health

Returns the status of the server and loaded models.

Model Management

POST /models/zero_shot/load
POST /models/zero_shot/unload

Manually load or unload models from the Triton server.
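
For example, to free GPU memory when the model is not needed and bring it back later (assuming the proxy on port 8005):

import requests

BASE = "http://localhost:8005"
requests.post(f"{BASE}/models/zero_shot/unload", timeout=30)  # unload to free the model
requests.post(f"{BASE}/models/zero_shot/load", timeout=30)    # load it again before use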

Advanced Usage

Model Management Configuration

The FastAPI proxy includes a model management system that can automatically load and unload models based on usage. Configuration is in main.py:

MODEL_MANAGEMENT = {
    "enabled": True,          # Enable/disable model management
    "idle_threshold": 3600,   # Unload models after 1 hour of inactivity
    "check_interval": 300,    # Check for idle models every 5 minutes
    "always_loaded": [],      # Models that should never be unloaded
    "load_timeout": 30        # Timeout for model loading in seconds
}
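
The idle-unload behavior can also be driven directly against Triton's model repository API (POST /v2/repository/models/&lt;name&gt;/load and /unload). The checker sketched below is a simplified illustration of the idea, not the code in main.py.

# Simplified idle-model checker (illustration only; main.py's implementation may differ).
import time
import requests

TRITON_URL = "http://localhost:8000"
last_used = {"zero-shot": time.time()}  # updated elsewhere whenever a model serves a request

def unload_idle_models(idle_threshold=3600, always_loaded=()):
    now = time.time()
    for model, last in last_used.items():
        if model not in always_loaded and now - last > idle_threshold:
            # Triton's repository API unloads the model without restarting the server.
            requests.post(f"{TRITON_URL}/v2/repository/models/{model}/unload", timeout=30)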

Troubleshooting

  1. GPU not detected: Make sure the NVIDIA Container Toolkit is installed and your Docker Compose file includes the correct GPU configuration.

  2. Model loading errors: Check the Triton server logs for detailed error messages:

    docker exec -it triton-topic cat /var/log/triton/server.log
  3. API errors: Check the FastAPI logs for detailed error messages:

    docker exec -it triton-topic cat /var/log/fastapi.log
  4. Memory issues: If you encounter out-of-memory errors, try:

    • Reducing batch sizes in the model configurations
    • Using a smaller model
    • Adding more GPU memory

Performance Optimization

  1. Instance Groups: You can modify the number of model instances by editing the instance_group section in the model's config.pbtxt:

    instance_group [
      {
        kind: KIND_GPU
        count: 4  # Run 4 instances of the model
      }
    ]
  2. Dynamic Batching: Enable dynamic batching to improve throughput by adding to the model's config.pbtxt:

    dynamic_batching { }
  3. Model Concurrency: Adjust the max concurrency parameter in the Triton server startup command:

    --concurrency-limit=4
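
To judge whether these settings help, measure throughput against the proxy. A rough concurrent-request sketch (assuming the proxy on port 8005) is shown below; Triton's perf_analyzer tool is an alternative for benchmarking the Triton endpoint directly.

# Rough throughput check against the FastAPI proxy (not a substitute for perf_analyzer).
import time
from concurrent.futures import ThreadPoolExecutor
import requests

payload = {
    "text": "GPU inference servers batch requests to improve throughput.",
    "candidate_labels": ["technology", "sports", "finance"],
    "multi_label": False,
}

def one_request(_):
    return requests.post("http://localhost:8005/models/classify", json=payload, timeout=60)

start = time.time()
with ThreadPoolExecutor(max_workers=8) as pool:
    results = list(pool.map(one_request, range(64)))
elapsed = time.time() - start
print(f"{len(results)} requests in {elapsed:.1f}s -> {len(results) / elapsed:.1f} req/s")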

Conclusion

This project provides a flexible and scalable solution for zero-shot text classification using NVIDIA's Triton Inference Server and FastAPI. By following the steps outlined in this documentation, you can build, deploy, and use the system for a variety of text classification tasks without the need for task-specific training.