pytorch
Assists with building, training, and deploying neural networks using PyTorch. Use when designing architectures for computer vision, NLP, or tabular data, optimizing training with mixed precision and distributed strategies, or exporting models for production inference. Trigger words: pytorch, torch, neural network, deep learning, training loop, cuda.
Usage
Getting Started
- Install the skill using the command above
- Open your AI coding agent (Claude Code, Codex, Gemini CLI, or Cursor)
- Reference the skill in your prompt
- The AI will use the skill's capabilities automatically
Example Prompts
- "Fine-tune a pretrained ResNet for classifying product images"
- "Build a sentiment analysis model using a pretrained transformer"
Documentation
Overview
PyTorch is a deep learning framework for building and training neural networks with dynamic computation graphs and automatic differentiation. It provides tensor operations with GPU acceleration, nn.Module for defining architectures, DataLoader for efficient data loading, mixed precision training for performance, and export tools (TorchScript, ONNX) for production deployment.
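The dynamic computation graph and automatic differentiation mentioned above can be seen in a minimal sketch: the graph is built as operations run, and `backward()` fills in gradients.

```python
import torch

# y = sum(x^2), so dy/dx = 2x; autograd computes this automatically
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = (x ** 2).sum()
y.backward()
print(x.grad)  # tensor([2., 4., 6.])
```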
Instructions
- When defining models, subclass `nn.Module` with `__init__` for layers and `forward` for computation, using `nn.Sequential` for simple stacks and custom forward logic for complex architectures.
- When training, implement the standard loop: forward pass, loss computation, `loss.backward()`, `optimizer.step()`, `optimizer.zero_grad()`, with gradient clipping via `clip_grad_norm_` for stability.
- When loading data, subclass `Dataset` with `__len__` and `__getitem__`, then use `DataLoader` with `num_workers=4` and `pin_memory=True` for GPU training throughput.
- When optimizing performance, use `torch.compile(model)` on PyTorch 2.0+ for a 20-50% speedup, mixed precision with `torch.amp.autocast()` for halved memory and roughly doubled throughput, and `DistributedDataParallel` for multi-GPU training.
- When doing transfer learning, load pretrained models from `torchvision.models` or Hugging Face, freeze the backbone, and replace the classifier head for your task.
- When deploying, use `torch.export()` or `torch.jit.trace()` for production, `torch.onnx.export()` for cross-framework compatibility, and `torch.quantization` for INT8 inference speedup.
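The `Dataset`/`DataLoader` and training-loop instructions above can be sketched as one self-contained example. The synthetic regression task and hyperparameters are illustrative assumptions; `num_workers` and `pin_memory` are omitted so it runs anywhere on CPU.

```python
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

torch.manual_seed(0)

# Toy dataset: noisy samples of y = 3x + 1 (synthetic, for illustration)
class ToyDataset(Dataset):
    def __init__(self, n=256):
        self.x = torch.randn(n, 1)
        self.y = 3 * self.x + 1 + 0.01 * torch.randn(n, 1)

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]

loader = DataLoader(ToyDataset(), batch_size=32, shuffle=True)
model = nn.Sequential(nn.Linear(1, 16), nn.ReLU(), nn.Linear(16, 1))
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-2)
loss_fn = nn.MSELoss()

# Standard loop: forward, loss, backward, clip, step, zero_grad
for epoch in range(20):
    for xb, yb in loader:
        loss = loss_fn(model(xb), yb)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()
```

On a real GPU run you would additionally wrap the forward pass in `torch.amp.autocast()` and move each batch to the device inside the loop.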
Examples
Example 1: Fine-tune a vision model for image classification
User request: "Fine-tune a pretrained ResNet for classifying product images"
Actions:
- Load `resnet50(weights=ResNet50_Weights.DEFAULT)` and freeze all layers except the final classifier
- Replace the classifier head with `nn.Linear(2048, num_classes)`
- Set up DataLoader with image augmentation transforms (RandomCrop, ColorJitter, Normalize)
- Train with AdamW, CosineAnnealingLR scheduler, and mixed precision
Output: A fine-tuned image classifier with production-quality accuracy and efficient mixed-precision training.
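The freeze-and-replace pattern from the actions above can be sketched as follows. A tiny stand-in backbone is used here so the example is self-contained; in practice it would be `torchvision.models.resnet50(weights=ResNet50_Weights.DEFAULT)`, and `num_classes=5` is an arbitrary assumption.

```python
import torch
from torch import nn

# Stand-in "pretrained" backbone (in practice: a torchvision ResNet-50)
backbone = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(),
)

# Freeze every backbone parameter so only the new head trains
for p in backbone.parameters():
    p.requires_grad = False

num_classes = 5  # assumption: a 5-class product catalog
model = nn.Sequential(backbone, nn.Linear(8, num_classes))

# Only the new head's weight and bias remain trainable
trainable = [n for n, p in model.named_parameters() if p.requires_grad]
print(trainable)
```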
Example 2: Train a text classification model with Hugging Face
User request: "Build a sentiment analysis model using a pretrained transformer"
Actions:
- Load `AutoModel.from_pretrained("bert-base-uncased")` with a classification head
- Tokenize the dataset using `AutoTokenizer` and create a DataLoader
- Fine-tune with AdamW, linear warmup scheduler, and gradient clipping
- Export the trained model with `torch.export()` for production serving
Output: A sentiment analysis model fine-tuned on custom data and exported for production inference.
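The export step above can be sketched with `torch.jit.trace`, one of the export paths named in the Instructions. A small classifier head stands in for the fine-tuned transformer so the example needs no downloads; the traced module should reproduce the eager model's outputs exactly.

```python
import torch
from torch import nn

# Stand-in for a fine-tuned classification head (the real model would
# be a pretrained transformer plus this kind of head)
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 2))
model.eval()

example_input = torch.randn(1, 8)
traced = torch.jit.trace(model, example_input)

with torch.no_grad():
    eager_out = model(example_input)
    traced_out = traced(example_input)
# traced.save("model.pt")  # serialize for a Python-free serving runtime
```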
Guidelines
- Use `torch.compile(model)` on PyTorch 2.0+ for a free 20-50% speedup with one line.
- Use `AdamW` over `Adam` for a correct weight decay implementation with modern architectures.
- Use mixed precision (`torch.amp`) for any GPU training to halve memory and roughly double throughput.
- Move data to device in the training loop, not in the Dataset, to keep the Dataset device-agnostic.
- Use `model.eval()` and `torch.no_grad()` during inference to prevent unnecessary gradient computation.
- Use `pin_memory=True` in DataLoader when training on GPU to speed up CPU-to-GPU data transfer.
- Save `model.state_dict()`, not the full model, since state dicts are portable across code changes.
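The state-dict and inference-mode guidelines above can be sketched together: save only the weights, reload them into a freshly constructed model, and run inference under `eval()` and `no_grad()`. An in-memory buffer stands in for a checkpoint file.

```python
import io
import torch
from torch import nn

model = nn.Linear(4, 2)

# Save only the state dict, not the pickled module object
buffer = io.BytesIO()  # stands in for a checkpoint file on disk
torch.save(model.state_dict(), buffer)

# Reload into a freshly constructed model of the same architecture
buffer.seek(0)
restored = nn.Linear(4, 2)
restored.load_state_dict(torch.load(buffer))

restored.eval()
with torch.no_grad():  # no autograd bookkeeping during inference
    x = torch.randn(3, 4)
    out = restored(x)
```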
Information
- Version
- 1.0.0
- Author
- terminal-skills
- Category
- Data & AI
- License
- Apache-2.0