Transfer Learning (12): Industrial Applications and Best Practices
Chen Kai

Can academic SOTA models actually be used in industry? How do you ship a transfer learning project quickly with limited time and computational resources? This chapter summarizes industrial experience with transfer learning in recommendation systems, NLP, and computer vision, and provides a practical best-practices guide covering everything from model selection to deployment monitoring.

This article systematically walks through the complete industrial transfer-learning workflow: pre-trained model selection, data preparation and augmentation, efficient fine-tuning strategies, model compression and quantization, deployment optimization, and performance monitoring with continuous iteration. It also provides complete code (300+ lines) for building a production-grade transfer learning system from scratch.

Industrial Application Scenarios of Transfer Learning

Natural Language Processing

1. Text Classification

Scenarios: Sentiment analysis, spam detection, news classification, intent recognition

Transfer strategies:
- Pre-trained models: BERT, RoBERTa, DistilBERT
- Fine-tuning layers: classification head (1-2 layer MLP)
- Data requirements: 100-1000 samples per class

Success cases:

| Company | Application | Results |
|---------|-------------|---------|
| Google | Gmail spam detection | 99.9% accuracy |
| Amazon | Product review sentiment analysis | 15% improvement over traditional methods |
| Twitter | Harmful content detection | 20% recall improvement |

2. Named Entity Recognition (NER)

Scenarios: Information extraction, knowledge graph construction, resume parsing

Transfer strategies:
- Pre-trained models: BERT + CRF layer
- Fine-tuning: sequence tagging head
- Data requirements: 1000-5000 annotated sentences

Practical experience:
- Domain dictionary integration: specialized domains such as medicine and finance
- Active learning: prioritize annotating samples the model is uncertain about
- Pseudo-labeling: expand the training set with high-confidence model predictions
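The pseudo-labeling step above can be sketched as a simple confidence filter (a minimal sketch; the 0.95 threshold is an illustrative choice):

```python
import numpy as np

def select_pseudo_labels(probs: np.ndarray, threshold: float = 0.95):
    """Keep only unlabeled samples whose top class probability clears the threshold."""
    confidence = probs.max(axis=1)
    labels = probs.argmax(axis=1)
    keep = np.where(confidence >= threshold)[0]
    return keep, labels[keep]

# Three unlabeled samples with model probabilities over 2 classes
probs = np.array([[0.98, 0.02], [0.60, 0.40], [0.01, 0.99]])
idx, pseudo = select_pseudo_labels(probs)
# The middle sample is too uncertain to pseudo-label
```

The selected samples are then added to the training set with their predicted labels; the threshold trades off label noise against data volume.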

3. Question Answering Systems

Scenarios: Customer service chatbots, knowledge Q&A, search engines

Transfer strategies:
- Pre-trained models: BERT-QA, RoBERTa-QA
- Fine-tuning: extractive QA (span prediction) or generative QA (seq2seq)
- Data requirements: 500-2000 question-answer pairs

Architecture:

User question → Retrieval module → Candidate passages → BERT-QA → Answer span

Computer Vision

1. Image Classification

Scenarios: Product recognition, medical imaging diagnosis, defect detection

Transfer strategies:
- Pre-trained models: ResNet, EfficientNet, ViT
- Fine-tuning: replace the classification head, freeze early layers
- Data requirements: 50-500 images per class

Practical tips:
- Progressive unfreezing: train the classification head first, then gradually unfreeze layers
- Data augmentation: random cropping, flipping, color jittering
- Mixed precision training: FP16 acceleration

2. Object Detection

Scenarios: Autonomous driving, security monitoring, retail checkout

Transfer strategies:
- Pre-trained models: Faster R-CNN, YOLO, DETR
- Fine-tuning: detection head + RPN (Region Proposal Network)
- Data requirements: 1000-5000 annotated images

Data annotation strategies:
- Phased annotation: coarse annotation first (bounding boxes), then fine annotation (subcategories)
- Weak supervision: train with image-level labels to reduce box annotation cost
- Semi-supervised: expand data with pseudo-labels

3. Semantic Segmentation

Scenarios: Medical image segmentation, autonomous driving scene understanding

Transfer strategies:
- Pre-trained models: U-Net, DeepLab, Mask R-CNN
- Fine-tuning: segmentation head
- Data requirements: 200-1000 pixel-level annotated images

Recommendation Systems

1. Cold Start Problem

Challenge: New users/items have no historical data

Transfer strategies:
- Pre-training: learn general representations on large-scale user behavior data
- Fine-tuning: adapt with the small amount of interaction data from new users/items
- Meta-learning: learn the ability to quickly adapt to new users/items

Methods:
- Two-tower model: user tower + item tower, pre-train then fine-tune independently
- Graph neural networks: leverage the user-item graph structure for transfer
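The two-tower idea can be sketched in a few lines (dimensions and layer choices are illustrative): each tower produces an embedding, the dot product is the match score, and either tower can be frozen while the other is fine-tuned.

```python
import torch
import torch.nn as nn

class TwoTower(nn.Module):
    """Toy two-tower model: score(u, i) = <user_emb(u), item_emb(i)>."""
    def __init__(self, num_users, num_items, dim=32):
        super().__init__()
        self.user_emb = nn.Embedding(num_users, dim)
        self.item_emb = nn.Embedding(num_items, dim)

    def forward(self, user_ids, item_ids):
        u = self.user_emb(user_ids)
        v = self.item_emb(item_ids)
        return (u * v).sum(dim=1)  # dot-product match score

model = TwoTower(num_users=1000, num_items=5000)
scores = model(torch.tensor([1, 2]), torch.tensor([10, 20]))

# Fine-tune the user tower only: freeze the item tower
for p in model.item_emb.parameters():
    p.requires_grad = False
```

In production the towers are usually deeper MLPs over many features, but the freeze-one-tower pattern for independent fine-tuning is the same.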

2. Cross-Domain Recommendation

Scenarios: E-commerce → Video, Music → Books

Transfer strategies:
- Shared user representations: share user embeddings across domains
- Domain adaptation: adversarial training to reduce domain differences
- Transfer learning: pre-train on the source domain, fine-tune on the target domain

Speech Recognition

Scenarios: Smart assistants, meeting transcription, call centers

Transfer strategies:
- Pre-trained models: Wav2Vec 2.0, Whisper
- Fine-tuning: language model head + CTC loss
- Data requirements: 10-100 hours of annotated audio

Practices:
- Data augmentation: speed perturbation, noise injection, spectrum augmentation
- Multi-task learning: train the ASR and language models simultaneously
- Self-supervised pre-training: pre-train on large amounts of unlabeled audio
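Noise injection, one of the augmentations listed above, can be sketched in a few lines (the waveform and SNR values are illustrative):

```python
import numpy as np

def add_noise(waveform: np.ndarray, snr_db: float) -> np.ndarray:
    """Inject white noise at a target signal-to-noise ratio (in dB)."""
    signal_power = np.mean(waveform ** 2)
    noise_power = signal_power / (10 ** (snr_db / 10))
    noise = np.random.randn(len(waveform)) * np.sqrt(noise_power)
    return waveform + noise

wave = np.sin(np.linspace(0, 100, 16000))  # dummy 1-second tone at 16 kHz
noisy = add_noise(wave, snr_db=20)  # 20 dB SNR: mild background noise
```

Lower SNR values produce harder training examples; in practice the SNR is often sampled randomly per utterance.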

Model Selection Strategies

Pre-trained Model Selection Matrix

NLP Tasks

| Task Type | Recommended Model | Alternatives | Reason |
|-----------|-------------------|--------------|--------|
| Text classification | RoBERTa-base | BERT, DistilBERT | Good performance, stable training |
| NER | BERT-base | RoBERTa, ELECTRA | Bidirectional modeling suits sequence tagging |
| Q&A | RoBERTa-large | BERT-large, ALBERT | Large models have strong understanding |
| Text generation | GPT-2, T5 | BART, mT5 | Generative architecture |
| Multilingual | XLM-R | mBERT, mT5 | Best cross-lingual performance |

CV Tasks

| Task Type | Recommended Model | Alternatives | Reason |
|-----------|-------------------|--------------|--------|
| Image classification | EfficientNet-B3 | ResNet-50, ViT | Accuracy-efficiency balance |
| Object detection | YOLOv8 | Faster R-CNN, DETR | Fast inference |
| Semantic segmentation | DeepLabv3+ | U-Net, Mask R-CNN | High accuracy |
| Image retrieval | CLIP | ResNet + ArcFace | Multimodal capability |

Selection Criteria

1. Task Similarity

Rule: More similar pre-training and target tasks lead to better results.

Examples:
- Sentiment classification: choose BERT (pre-trained on a general corpus)
- Biomedical NER: choose BioBERT (pre-trained on medical literature)
- Legal text understanding: choose Legal-BERT

2. Data Scale

| Data Volume | Model Size | Fine-tuning Strategy |
|-------------|-----------|----------------------|
| <100 samples | Small model (BERT-base) | Freeze most layers, train only the head |
| 100-1000 samples | Medium model (RoBERTa-base) | Freeze part of the layers |
| 1000-10000 samples | Large model (RoBERTa-large) | Full fine-tuning or LoRA |
| >10000 samples | Extra-large model (GPT-3) | Full fine-tuning |

Principle: Use small models with less data to avoid overfitting.

3. Inference Latency

Scenarios: Online inference vs offline batch processing

| Scenario | Latency Requirement | Recommended Model |
|----------|--------------------|-------------------|
| Online search | <50ms | DistilBERT, MobileNet |
| Real-time recommendation | <100ms | TinyBERT, EfficientNet-B0 |
| Offline analysis | No requirement | RoBERTa-large, EfficientNet-B7 |

Optimization:
- Model distillation: BERT → DistilBERT (2x speedup)
- Quantization: FP32 → INT8 (3-4x speedup)
- Pruning: remove unimportant parameters (50% computation reduction)

4. Resource Constraints

Factors: GPU memory, disk space, inference compute

| Resource Level | GPU Memory | Recommended Model |
|----------------|-----------|-------------------|
| Low | <8GB | DistilBERT, MobileNetV3 |
| Medium | 8-16GB | BERT-base, ResNet-50 |
| High | >16GB | RoBERTa-large, EfficientNet-B5 |

Data Preparation and Augmentation

Data Quality Assessment

1. Annotation Quality Check

Methods:
- Inter-annotator agreement: Kappa coefficient > 0.7
- Annotation error detection: train a model, then manually review high-loss samples
- Adversarial sample testing: test annotation robustness with adversarial samples

Practice:

from sklearn.metrics import cohen_kappa_score

# Calculate inter-annotator agreement
annotator1_labels = [0, 1, 1, 0, 1]
annotator2_labels = [0, 1, 0, 0, 1]

kappa = cohen_kappa_score(annotator1_labels, annotator2_labels)
print(f"Kappa score: {kappa:.3f}")

# Kappa > 0.7: Good agreement
# Kappa < 0.4: Need to retrain annotators

2. Data Distribution Check

Checklist:
- Class balance: similar number of samples per class
- Distribution consistency: train, validation, and test sets share the same distribution
- Noise level: proportion of duplicate samples and incorrect annotations
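The checklist items can be automated with a few lines (a minimal sketch; acceptable thresholds are project-specific):

```python
from collections import Counter

def check_distribution(texts, labels):
    """Report per-class counts, duplicate rate, and class imbalance ratio."""
    class_counts = Counter(labels)
    dup_rate = 1 - len(set(texts)) / len(texts)
    imbalance = max(class_counts.values()) / min(class_counts.values())
    return class_counts, dup_rate, imbalance

texts = ["great phone", "bad battery", "great phone", "ok screen"]
labels = [1, 0, 1, 1]
counts, dup_rate, ratio = check_distribution(texts, labels)
# counts: {1: 3, 0: 1}; dup_rate: 0.25; imbalance ratio: 3.0
```

Running the same check separately on the train, validation, and test splits also surfaces distribution mismatches between them.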

Data Augmentation Techniques

NLP Data Augmentation

  1. Back-Translation:
# English → Chinese → English
original = "This movie is great."
translated_zh = translate_en_to_zh(original) # "这部电影很棒"
back_translated = translate_zh_to_en(translated_zh) # "This film is excellent."
  2. EDA (Easy Data Augmentation) [1]:
    • Synonym replacement
    • Random insertion
    • Random swap
    • Random deletion
def eda(sentence, alpha=0.1, num_aug=4):
    words = sentence.split()
    n = len(words)

    augmented = []
    for _ in range(num_aug):
        # Randomly replace alpha proportion of words with synonyms
        num_replace = max(1, int(alpha * n))
        new_words = synonym_replacement(words, num_replace)  # user-provided helper, e.g. WordNet-based
        augmented.append(' '.join(new_words))

    return augmented
  3. Mixup for Text [2]:

Mix at embedding layer:

lambda_param = np.random.beta(0.2, 0.2)
mixed_embedding = lambda_param * emb_i + (1 - lambda_param) * emb_j
mixed_label = lambda_param * label_i + (1 - lambda_param) * label_j

CV Data Augmentation

  1. Basic Augmentation:
    • Random cropping, flipping, rotation
    • Color jittering, grayscale
    • Gaussian noise, blurring
  2. AutoAugment [3]: Automatically search for optimal augmentation strategies
from torchvision import transforms

augmentation = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
    transforms.RandomGrayscale(p=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
  3. Mixup & CutMix [4]:
import numpy as np

def cutmix(image1, image2, label1, label2, alpha=1.0):
    lam = np.random.beta(alpha, alpha)

    # Generate random crop box
    H, W = image1.shape[2:]
    cut_w = int(W * np.sqrt(1 - lam))
    cut_h = int(H * np.sqrt(1 - lam))

    cx = np.random.randint(W)
    cy = np.random.randint(H)

    x1 = np.clip(cx - cut_w // 2, 0, W)
    y1 = np.clip(cy - cut_h // 2, 0, H)
    x2 = np.clip(cx + cut_w // 2, 0, W)
    y2 = np.clip(cy + cut_h // 2, 0, H)

    # Mix images
    image1[:, :, y1:y2, x1:x2] = image2[:, :, y1:y2, x1:x2]

    # Mix labels (recompute lam from the actual box area)
    lam = 1 - ((x2 - x1) * (y2 - y1) / (W * H))
    mixed_label = lam * label1 + (1 - lam) * label2

    return image1, mixed_label

Efficient Fine-Tuning Strategies

Learning Rate Scheduling

1. Layer-wise Learning Rate

Principle: Shallow layers learn general features, deep layers learn task-specific features.

Strategy: lr_l = base_lr × α^(L − l), where L is the total number of layers, l is the layer index (larger l means closer to the output), and α < 1 is the decay factor.

Code:

from torch.optim import Adam

def get_layer_wise_lr_params(model, base_lr=2e-5, alpha=0.95):
    params = []
    num_layers = len(list(model.named_parameters()))

    for i, (name, param) in enumerate(model.named_parameters()):
        # Parameters closer to the input get a smaller learning rate
        layer_lr = base_lr * (alpha ** (num_layers - i))
        params.append({'params': param, 'lr': layer_lr})

    return params

optimizer = Adam(get_layer_wise_lr_params(model))

2. Learning Rate Warmup

Purpose: Avoid large steps destroying pre-trained weights in early training.

Linear warmup: lr_t = base_lr × t / T_warmup for steps t ≤ T_warmup, after which the learning rate decays linearly to 0.

Code:

from transformers import get_linear_schedule_with_warmup

num_training_steps = len(train_loader) * num_epochs
num_warmup_steps = int(0.1 * num_training_steps)

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps
)

3. Cosine Annealing

Advantage: Avoid sudden learning rate drops, smooth convergence.
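A minimal sketch with PyTorch's built-in scheduler (the model and hyperparameters are placeholders):

```python
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR

model = torch.nn.Linear(10, 2)  # placeholder model
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
scheduler = CosineAnnealingLR(optimizer, T_max=100, eta_min=1e-7)

lrs = []
for step in range(100):
    optimizer.step()      # real training would compute gradients first
    scheduler.step()
    lrs.append(optimizer.param_groups[0]['lr'])
# The learning rate follows a half-cosine from 2e-5 down toward eta_min
```

Cosine annealing is often combined with the linear warmup above: warm up first, then anneal over the remaining steps.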

Gradual Unfreezing

Strategy: Unfreeze model layers in stages.

Algorithm:

def gradual_unfreeze(model, optimizer, num_stages=4):
    # Assumes all layers start frozen (requires_grad = False)
    layers = list(model.children())
    layers_per_stage = len(layers) // num_stages

    for stage in range(num_stages):
        # Unfreeze layers in current stage
        start_idx = stage * layers_per_stage
        end_idx = (stage + 1) * layers_per_stage

        for layer in layers[start_idx:end_idx]:
            for param in layer.parameters():
                param.requires_grad = True

        # Train current stage (train_one_stage: user-defined training loop)
        train_one_stage(model, optimizer, train_loader)

Effect: - Avoid catastrophic forgetting - Reduce training time (early layers don't need updates)

Discriminative Fine-Tuning

Principle: Different layers use different learning rates.

Howard and Ruder [5] proposed the ULMFiT strategy:

  1. Stage 1: Only train classification head (freeze BERT)
  2. Stage 2: Unfreeze last few layers, fine-tune with small learning rate
  3. Stage 3: Full fine-tuning, layer-wise decreasing learning rates

Empirical values:

| Layer | Learning Rate Multiplier |
|-------|--------------------------|
| Classification head | 1.0 |
| BERT last layer | 0.5 |
| BERT middle layers | 0.25 |
| BERT initial layers | 0.1 |
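These multipliers map directly onto optimizer parameter groups. A sketch with a stand-in encoder (the 12-layer split and group boundaries are illustrative):

```python
import torch
import torch.nn as nn

# Stand-in for a 12-layer encoder plus a classification head
encoder = nn.ModuleList([nn.Linear(8, 8) for _ in range(12)])
head = nn.Linear(8, 3)

base_lr = 2e-5
param_groups = [
    {'params': head.parameters(), 'lr': base_lr * 1.0},         # classification head
    {'params': encoder[11].parameters(), 'lr': base_lr * 0.5},  # last layer
    {'params': [p for l in encoder[4:11] for p in l.parameters()],
     'lr': base_lr * 0.25},                                     # middle layers
    {'params': [p for l in encoder[:4] for p in l.parameters()],
     'lr': base_lr * 0.1},                                      # initial layers
]
optimizer = torch.optim.AdamW(param_groups)
# Each group keeps its own learning rate throughout training
```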

Early Stopping and Regularization

1. Early Stopping

Strategy: Stop training when validation performance stops improving.

class EarlyStopping:
    def __init__(self, patience=5, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):
        if self.best_loss is None:
            self.best_loss = val_loss
        elif val_loss > self.best_loss - self.min_delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = val_loss
            self.counter = 0

2. Dropout

Strategy: use higher dropout (0.3-0.5) in the classification head and lower dropout (0.1) in the BERT layers.

3. Weight Decay

Purpose: Prevent overfitting.

Empirical values:
- Small datasets (<1000 samples): 0.01-0.1
- Medium datasets (1000-10000 samples): 0.001-0.01
- Large datasets (>10000 samples): 0.0001-0.001
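In practice, biases and LayerNorm parameters are usually excluded from weight decay. A sketch (the model is a placeholder):

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(16, 16), nn.LayerNorm(16), nn.Linear(16, 3))

decay, no_decay = [], []
for module in model.modules():
    for name, param in module.named_parameters(recurse=False):
        # Exclude all biases and every LayerNorm parameter from decay
        if name == 'bias' or isinstance(module, nn.LayerNorm):
            no_decay.append(param)
        else:
            decay.append(param)

optimizer = torch.optim.AdamW([
    {'params': decay, 'weight_decay': 0.01},
    {'params': no_decay, 'weight_decay': 0.0},
], lr=2e-5)
```

Decaying LayerNorm weights or biases toward zero tends to hurt rather than regularize, which is why this split is common in Transformer fine-tuning recipes.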

Model Compression and Acceleration

Knowledge Distillation

Goal: Transfer knowledge from large model (teacher) to small model (student).

Loss function: L = α · L_CE + (1 − α) · L_KD, where:
- L_CE: cross-entropy loss against the true labels
- L_KD: KL-divergence loss against the teacher's temperature-softened outputs (soft labels)

Temperature scaling: p_i = exp(z_i / T) / Σ_j exp(z_j / T). A larger T makes the distribution smoother (the code below uses T = 2).

Code:

import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, T=2.0, alpha=0.5):
    # Soft label loss (KL divergence), scaled by T^2 to preserve gradient magnitude
    soft_targets = F.softmax(teacher_logits / T, dim=1)
    soft_student = F.log_softmax(student_logits / T, dim=1)
    kl_loss = F.kl_div(soft_student, soft_targets, reduction='batchmean') * (T ** 2)

    # Hard label loss (cross-entropy)
    ce_loss = F.cross_entropy(student_logits, labels)

    # Total loss
    loss = alpha * ce_loss + (1 - alpha) * kl_loss

    return loss

Performance:
- DistilBERT (66M params): retains 97% of BERT-base (110M) performance, 2x faster
- TinyBERT (14M params): retains 96% of BERT-base performance, 9x faster

Quantization

Goal: Convert FP32 weights to INT8, 4x reduction in storage and computation.

Post-Training Quantization

import torch
import torch.quantization as quant

# Quantize model
model.eval()
model_int8 = quant.quantize_dynamic(
    model,
    {torch.nn.Linear},  # Quantize Linear layers
    dtype=torch.qint8
)

# Inference
with torch.no_grad():
    output = model_int8(input)

Performance:
- Model size: 75% reduction
- Inference speed: 2-4x improvement
- Accuracy loss: <1%

Quantization-Aware Training

Strategy: Simulate quantization during training to reduce accuracy loss.

import torch.quantization as quant

# Prepare quantization config
model.qconfig = quant.get_default_qat_qconfig('fbgemm')

# Insert fake quantization modules
model_prepared = quant.prepare_qat(model, inplace=False)

# Train
for epoch in range(num_epochs):
    train_one_epoch(model_prepared, train_loader, optimizer)

# Convert to quantized model
model_quantized = quant.convert(model_prepared.eval(), inplace=False)

Pruning

Goal: Remove unimportant parameters.

Unstructured Pruning

Strategy: remove the fraction of parameters with the smallest absolute weights (e.g., 50%, as in the snippet below).

import torch.nn.utils.prune as prune

# Prune Linear layer
prune.l1_unstructured(model.layer, name='weight', amount=0.5)

# Remove pruned weights
prune.remove(model.layer, 'weight')

Structured Pruning

Strategy: Remove entire neurons or channels.

Advantage: No special hardware support needed, directly reduces computation.
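PyTorch also ships structured pruning. A sketch that zeroes half the output neurons of one Linear layer by L2 norm (layer sizes are illustrative):

```python
import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(16, 8)

# Zero out 50% of output neurons (rows of the weight matrix), ranked by L2 norm
prune.ln_structured(layer, name='weight', amount=0.5, n=2, dim=0)

zero_rows = int((layer.weight.abs().sum(dim=1) == 0).sum())
# 4 of the 8 output neurons are now entirely zero
```

Note that `prune` only applies a mask; to actually shrink the layer (and the compute), the zeroed rows must be physically removed afterward and the following layer resized to match.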

ONNX Export and Optimization

ONNX: Cross-platform model format supporting multiple inference engines (TensorRT, ONNX Runtime).

import torch

# Export ONNX (dummy token-id input; randn does not support integer dtypes)
dummy_input = torch.randint(0, 30522, (1, 128), dtype=torch.long)  # 30522 = BERT vocab size
torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    input_names=['input_ids'],
    output_names=['logits'],
    dynamic_axes={'input_ids': {0: 'batch_size', 1: 'sequence'}}
)

# Inference with ONNX Runtime
import onnxruntime as ort

session = ort.InferenceSession("model.onnx")
output = session.run(['logits'], {'input_ids': input_ids.numpy()})

Performance improvement:
- CPU inference: 2-3x speedup
- GPU inference: 1.5-2x speedup

Deployment and Monitoring

Model Serving

1. REST API Deployment (Flask/FastAPI)

from fastapi import FastAPI
from pydantic import BaseModel
import torch

app = FastAPI()

# Load model once at startup (load_model and tokenizer are project-specific helpers)
model = load_model("model.pt")
model.eval()

class PredictRequest(BaseModel):
    text: str

@app.post("/predict")
async def predict(request: PredictRequest):
    # Preprocessing
    inputs = tokenizer(request.text, return_tensors='pt')

    # Inference
    with torch.no_grad():
        logits = model(**inputs)
        probs = torch.softmax(logits, dim=1)
        pred = torch.argmax(probs, dim=1).item()

    return {"prediction": pred, "confidence": probs[0][pred].item()}

Startup:

uvicorn main:app --host 0.0.0.0 --port 8000 --workers 4

2. Batch Processing Optimization

Dynamic Batching: Combine multiple requests into one batch for inference.

import asyncio
from collections import deque

class BatchInferenceService:
    def __init__(self, model, max_batch_size=32, timeout_ms=100):
        self.model = model
        self.max_batch_size = max_batch_size
        self.timeout_ms = timeout_ms  # timeout-based flushing omitted in this sketch
        self.queue = deque()

    async def add_request(self, input_data):
        future = asyncio.Future()
        self.queue.append((input_data, future))

        # Trigger batch processing once the queue is full
        if len(self.queue) >= self.max_batch_size:
            await self.process_batch()

        return await future

    async def process_batch(self):
        if not self.queue:
            return

        # Extract batch
        batch_size = min(len(self.queue), self.max_batch_size)
        batch = [self.queue.popleft() for _ in range(batch_size)]

        # Batch inference
        inputs = [item[0] for item in batch]
        futures = [item[1] for item in batch]

        outputs = self.model(inputs)

        # Distribute results
        for future, output in zip(futures, outputs):
            future.set_result(output)

3. Model Caching

Strategy: Cache prediction results for common queries.

from functools import lru_cache

@lru_cache(maxsize=10000)
def predict_cached(text):
    return model.predict(text)

Performance Monitoring

1. Inference Latency Monitoring

import time
from prometheus_client import Histogram

# Prometheus metric
inference_latency = Histogram('inference_latency_seconds', 'Inference latency')

@inference_latency.time()
def predict(input_data):
    start_time = time.time()

    # Inference
    output = model(input_data)

    latency = time.time() - start_time
    print(f"Inference latency: {latency*1000:.2f}ms")

    return output

Alert thresholds:
- P50 latency > 100ms: warning
- P99 latency > 500ms: critical
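The thresholds above can be checked offline from a window of recorded latencies (the sample values are illustrative):

```python
import numpy as np

# A sliding window of recent request latencies, in milliseconds
latencies_ms = np.array([12, 40, 35, 90, 55, 600, 48, 70, 30, 45])

p50 = np.percentile(latencies_ms, 50)
p99 = np.percentile(latencies_ms, 99)

if p50 > 100:
    print("WARNING: P50 latency above 100ms")
if p99 > 500:
    print("CRITICAL: P99 latency above 500ms")
```

Note how a single slow request (600ms here) leaves the median untouched but blows past the tail-latency threshold, which is why both percentiles are monitored.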

2. Model Performance Monitoring

Metrics:
- Accuracy, precision, recall, F1
- Error sample analysis
- Data drift detection

from sklearn.metrics import accuracy_score, classification_report

def monitor_model_performance(predictions, ground_truth):
    accuracy = accuracy_score(ground_truth, predictions)
    report = classification_report(ground_truth, predictions)

    # Log to monitoring system (log_metric / send_alert: project-specific helpers)
    log_metric("model_accuracy", accuracy)

    # Alert
    if accuracy < 0.85:
        send_alert("Model accuracy dropped below threshold!")

    return report

3. Data Drift Detection

Method: Compare production data with training data distribution.

from scipy.stats import ks_2samp

def detect_data_drift(train_data, prod_data, threshold=0.05):
    # Kolmogorov-Smirnov test
    statistic, p_value = ks_2samp(train_data, prod_data)

    if p_value < threshold:
        send_alert(f"Data drift detected! p-value={p_value:.4f}")
        return True

    return False

A/B Testing

Goal: Compare actual effects of new and old models.

Process:

  1. Traffic split: 50% users use model A, 50% use model B
  2. Collect metrics: Click rate, conversion rate, user satisfaction
  3. Statistical testing: t-test to determine significance
  4. Decision: Full rollout if significantly better

Code:

def ab_test_model_selection(user_id, input_data):
    # Assign users to groups by hashing user_id
    group = hash(user_id) % 2

    if group == 0:
        return model_A.predict(input_data)
    else:
        return model_B.predict(input_data)

# Statistical analysis
from scipy.stats import ttest_ind

results_A = [0.75, 0.82, 0.78, 0.80, 0.79]  # Model A metrics
results_B = [0.81, 0.85, 0.83, 0.84, 0.82]  # Model B metrics

t_stat, p_value = ttest_ind(results_A, results_B)
print(f"T-statistic: {t_stat:.4f}, P-value: {p_value:.4f}")

if p_value < 0.05:
    print("Model B significantly better than Model A, recommend full rollout")

Continuous Iteration and Model Updates

Active Learning

Goal: Prioritize annotating samples with highest model uncertainty.

Uncertainty measures:

  1. Entropy: H(x) = −Σ_c p(c|x) log p(c|x)

  2. Least Confidence: LC(x) = 1 − max_c p(c|x)

  3. Margin Sampling: M(x) = p1 − p2, where p1 and p2 are the highest and second-highest prediction probabilities (a smaller margin means higher uncertainty).

Algorithm:

import torch
import torch.nn.functional as F

def active_learning_selection(model, unlabeled_data, n_samples=100):
    # Predict
    model.eval()
    with torch.no_grad():
        logits = model(unlabeled_data)
        probs = F.softmax(logits, dim=1)

    # Calculate entropy
    entropy = -torch.sum(probs * torch.log(probs + 1e-10), dim=1)

    # Select samples with highest entropy
    _, top_indices = torch.topk(entropy, n_samples)

    return unlabeled_data[top_indices]

Incremental Learning

Scenario: Continuously arriving new data, need to update model constantly.

Strategies:

  1. Periodic full retraining: Weekly/monthly retrain with all data
  2. Incremental fine-tuning: Fine-tune with new data, combine with EWC to prevent forgetting (see Chapter 10)
  3. Online learning: Real-time model updates

Pseudocode:

# Incremental learning workflow
while True:
    # 1. Collect new data
    new_data = collect_new_data()

    # 2. Data cleaning and annotation
    labeled_data = label_data(new_data)

    # 3. Incremental fine-tuning
    model = incremental_fine_tune(model, labeled_data, old_data_sample)

    # 4. Evaluate performance
    performance = evaluate(model, test_set)

    # 5. Performance monitoring
    if performance < threshold:
        # Roll back to the old model
        model = load_old_model()
        send_alert("Model performance degraded!")
    else:
        # Deploy new model
        deploy(model)

    # 6. Wait for next cycle
    time.sleep(period)

Model Version Management

Tools: MLflow, DVC, Weights & Biases

Practice:

import mlflow

# Log experiment
with mlflow.start_run():
    # Log hyperparameters
    mlflow.log_params({
        "learning_rate": 2e-5,
        "batch_size": 32,
        "num_epochs": 10
    })

    # Train model
    model = train_model(...)

    # Log metrics
    mlflow.log_metrics({
        "accuracy": 0.92,
        "f1_score": 0.89
    })

    # Save model
    mlflow.pytorch.log_model(model, "model")

Model rollback:

# Load specified version model
model_uri = f"models:/sentiment_classifier/production"
model = mlflow.pytorch.load_model(model_uri)

Complete Code: End-to-End Transfer Learning Project

Below is a complete industrial-grade transfer learning project template covering the entire workflow from data preparation to training, evaluation, and deployment.

"""
End-to-End Transfer Learning Project: Text Classification
Includes: Data preparation, model training, evaluation, deployment, monitoring
"""

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertForSequenceClassification, get_linear_schedule_with_warmup
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
import numpy as np
import json
import logging
from typing import List, Dict
import time

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ============================================================================
# Configuration
# ============================================================================

class Config:
    # Model configuration
    MODEL_NAME = 'bert-base-uncased'
    NUM_CLASSES = 3
    MAX_LENGTH = 128

    # Training configuration
    BATCH_SIZE = 32
    NUM_EPOCHS = 5
    LEARNING_RATE = 2e-5
    WARMUP_RATIO = 0.1
    WEIGHT_DECAY = 0.01

    # Early stopping configuration
    EARLY_STOPPING_PATIENCE = 3

    # Paths
    MODEL_SAVE_PATH = "models/best_model.pt"
    CONFIG_SAVE_PATH = "models/config.json"

# ============================================================================
# Dataset
# ============================================================================

class TextClassificationDataset(Dataset):
    def __init__(self, texts: List[str], labels: List[int], tokenizer, max_length: int):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]

        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'label': torch.tensor(label, dtype=torch.long)
        }

# ============================================================================
# Trainer
# ============================================================================

class Trainer:
    def __init__(self, model, train_loader, val_loader, config, device):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.config = config
        self.device = device

        # Optimizer
        self.optimizer = optim.AdamW(
            model.parameters(),
            lr=config.LEARNING_RATE,
            weight_decay=config.WEIGHT_DECAY
        )

        # Learning rate scheduler
        num_training_steps = len(train_loader) * config.NUM_EPOCHS
        num_warmup_steps = int(config.WARMUP_RATIO * num_training_steps)

        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps
        )

        # Loss function
        self.criterion = nn.CrossEntropyLoss()

        # Early stopping
        self.best_val_loss = float('inf')
        self.patience_counter = 0

    def train_epoch(self):
        self.model.train()
        total_loss = 0
        all_preds = []
        all_labels = []

        for batch in self.train_loader:
            input_ids = batch['input_ids'].to(self.device)
            attention_mask = batch['attention_mask'].to(self.device)
            labels = batch['label'].to(self.device)

            # Forward pass
            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
            logits = outputs.logits
            loss = self.criterion(logits, labels)

            # Backward pass
            self.optimizer.zero_grad()
            loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

            self.optimizer.step()
            self.scheduler.step()

            # Statistics
            total_loss += loss.item()
            preds = torch.argmax(logits, dim=1).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(labels.cpu().numpy())

        avg_loss = total_loss / len(self.train_loader)
        accuracy = accuracy_score(all_labels, all_preds)

        return avg_loss, accuracy

    def evaluate(self):
        self.model.eval()
        total_loss = 0
        all_preds = []
        all_labels = []

        with torch.no_grad():
            for batch in self.val_loader:
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                labels = batch['label'].to(self.device)

                outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
                logits = outputs.logits
                loss = self.criterion(logits, labels)

                total_loss += loss.item()
                preds = torch.argmax(logits, dim=1).cpu().numpy()
                all_preds.extend(preds)
                all_labels.extend(labels.cpu().numpy())

        avg_loss = total_loss / len(self.val_loader)
        accuracy = accuracy_score(all_labels, all_preds)

        return avg_loss, accuracy, all_preds, all_labels

    def train(self):
        logger.info("Starting training...")

        for epoch in range(self.config.NUM_EPOCHS):
            # Train
            train_loss, train_acc = self.train_epoch()

            # Validate
            val_loss, val_acc, val_preds, val_labels = self.evaluate()

            logger.info(
                f"Epoch [{epoch+1}/{self.config.NUM_EPOCHS}] "
                f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
                f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}"
            )

            # Early stopping check
            if val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                self.patience_counter = 0

                # Save best model
                self.save_model()
                logger.info("Best model saved!")
            else:
                self.patience_counter += 1

            if self.patience_counter >= self.config.EARLY_STOPPING_PATIENCE:
                logger.info(f"Early stopping triggered after {epoch+1} epochs")
                break

        # Final report
        logger.info("\nFinal Evaluation:")
        logger.info(classification_report(val_labels, val_preds))

    def save_model(self):
        torch.save(self.model.state_dict(), self.config.MODEL_SAVE_PATH)

        # Save config
        config_dict = {
            'model_name': self.config.MODEL_NAME,
            'num_classes': self.config.NUM_CLASSES,
            'max_length': self.config.MAX_LENGTH
        }
        with open(self.config.CONFIG_SAVE_PATH, 'w') as f:
            json.dump(config_dict, f)

# ============================================================================
# Inference Service
# ============================================================================

class InferenceService:
    def __init__(self, model_path: str, config_path: str, device: str = 'cpu'):
        # Load config
        with open(config_path, 'r') as f:
            config = json.load(f)

        # Load model
        self.model = BertForSequenceClassification.from_pretrained(
            config['model_name'],
            num_labels=config['num_classes']
        )
        self.model.load_state_dict(torch.load(model_path, map_location=device))
        self.model.to(device)
        self.model.eval()

        # Load tokenizer
        self.tokenizer = BertTokenizer.from_pretrained(config['model_name'])
        self.max_length = config['max_length']
        self.device = device

        logger.info("Model loaded successfully!")

    def predict(self, text: str) -> Dict:
        start_time = time.time()

        # Preprocessing
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        ).to(self.device)

        # Inference
        with torch.no_grad():
            outputs = self.model(**encoding)
            logits = outputs.logits
            probs = torch.softmax(logits, dim=1).cpu().numpy()[0]
            pred = int(np.argmax(probs))

        latency = time.time() - start_time

        return {
            'prediction': pred,
            'confidence': float(probs[pred]),
            'probabilities': probs.tolist(),
            'latency_ms': latency * 1000
        }

    def batch_predict(self, texts: List[str]) -> List[Dict]:
        results = []
        for text in texts:
            result = self.predict(text)
            results.append(result)
        return results

# ============================================================================
# Main Function
# ============================================================================

def main():
# Configuration
config = Config()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger.info(f"Using device: {device}")

# Simulate data (replace with real data in actual applications)
texts = [
"This product is amazing!", "Terrible service, very disappointed.",
"Average quality, nothing special.", "I love it!", "Waste of money."
] * 200
labels = [2, 0, 1, 2, 0] * 200 # 0: negative, 1: neutral, 2: positive

# Split dataset
train_texts, val_texts, train_labels, val_labels = train_test_split(
texts, labels, test_size=0.2, random_state=42
)

logger.info(f"Train samples: {len(train_texts)}, Val samples: {len(val_texts)}")

# Load tokenizer
tokenizer = BertTokenizer.from_pretrained(config.MODEL_NAME)

# Create datasets
train_dataset = TextClassificationDataset(train_texts, train_labels, tokenizer, config.MAX_LENGTH)
val_dataset = TextClassificationDataset(val_texts, val_labels, tokenizer, config.MAX_LENGTH)

train_loader = DataLoader(train_dataset, batch_size=config.BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=config.BATCH_SIZE, shuffle=False)

# Create model
model = BertForSequenceClassification.from_pretrained(
config.MODEL_NAME,
num_labels=config.NUM_CLASSES
)

# Train
trainer = Trainer(model, train_loader, val_loader, config, device)
trainer.train()

# Inference test
logger.info("\n" + "="*70)
logger.info("Testing Inference Service")
logger.info("="*70)

inference_service = InferenceService(config.MODEL_SAVE_PATH, config.CONFIG_SAVE_PATH, device)

test_texts = [
"This is the best product ever!",
"Completely useless, do not buy.",
"It's okay, meets expectations."
]

for text in test_texts:
result = inference_service.predict(text)
logger.info(f"\nText: {text}")
logger.info(f"Prediction: {result['prediction']}, Confidence: {result['confidence']:.4f}")
logger.info(f"Latency: {result['latency_ms']:.2f}ms")

logger.info("\n" + "="*70)
logger.info("Project completed successfully!")
logger.info("="*70)

if __name__ == "__main__":
main()

Code Explanation

Core modules:

  1. Config: Centrally manage all hyperparameters
  2. TextClassificationDataset: Data loading and preprocessing
  3. Trainer: Training workflow encapsulation (early stopping, learning rate scheduling)
  4. InferenceService: Inference service (model loading, prediction, latency monitoring)

Production-grade features:

  • Complete training and validation workflow
  • Early stopping to prevent overfitting
  • Learning rate warmup and decay
  • Model saving and loading
  • Inference latency monitoring
  • Logging

Extension suggestions:

  • Add data augmentation
  • Integrate MLflow for experiment tracking
  • Add Prometheus metrics export
  • Implement dynamic batching
  • Add model version management

Summary

This article comprehensively summarizes industrial applications and best practices of transfer learning:

  1. Application scenarios: Real cases in NLP, CV, recommendation systems, speech recognition
  2. Model selection: Pre-trained model selection matrix, task/data/resource-based selection
  3. Data preparation: Quality assessment, data augmentation techniques (back-translation, EDA, Mixup, CutMix)
  4. Efficient fine-tuning: Learning rate scheduling, gradual unfreezing, discriminative fine-tuning, early stopping
  5. Model compression: Knowledge distillation, quantization, pruning, ONNX export
  6. Deployment monitoring: Model serving, batch processing optimization, performance monitoring, A/B testing
  7. Continuous iteration: Active learning, incremental learning, model version management
  8. Complete code: a 300+ line production-grade end-to-end project template
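The quantization step in point 5 can be made concrete with a framework-free toy: post-training int8 quantization stores each float weight as an 8-bit integer q with a shared scale and zero point, so that w ≈ scale · (q − zero_point). The sketch below shows the arithmetic only (in PyTorch one would typically call `torch.quantization.quantize_dynamic` on the fine-tuned model instead); it quantizes a small weight list and checks the round-trip error:

```python
def quantize_int8(weights):
    """Affine int8 quantization: w ~ scale * (q - zero_point),
    with q clamped to the signed 8-bit range [-128, 127]."""
    w_min, w_max = min(weights), max(weights)
    scale = (w_max - w_min) / 255 if w_max > w_min else 1.0
    zero_point = -128 - round(w_min / scale)  # so w_min maps to -128
    q = [max(-128, min(127, round(w / scale) + zero_point)) for w in weights]
    return q, scale, zero_point

def dequantize_int8(q, scale, zero_point):
    """Recover approximate float weights from the int8 representation."""
    return [scale * (qi - zero_point) for qi in q]

weights = [-1.0, 0.0, 0.5, 2.0]  # toy weights
q, scale, zp = quantize_int8(weights)
recovered = dequantize_int8(q, scale, zp)
# Reconstruction error is bounded by the quantization step size
max_err = max(abs(w - r) for w, r in zip(weights, recovered))
print(q, f"max reconstruction error: {max_err:.4f}")
```

This illustrates why int8 quantization roughly quarters model size versus float32 at a small, bounded accuracy cost: each weight is recovered to within about half a quantization step.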

Transfer learning has become a core technology for putting AI into production. Mastering these best practices can significantly improve project success rates and time to deployment.

This concludes all 12 chapters of the transfer learning series! From basic concepts to cutting-edge techniques, from theoretical derivations to engineering practice, we have systematically covered every aspect of transfer learning. We hope this complete tutorial helps you excel in both research and industrial applications.

References


  1. Wei, J., & Zou, K. (2019). EDA: Easy data augmentation techniques for boosting performance on text classification tasks. EMNLP.

  2. Guo, H., Mao, Y., & Zhang, R. (2019). Augmenting data with mixup for sentence classification: An empirical study. arXiv:1905.08941.

  3. Cubuk, E. D., Zoph, B., Mane, D., et al. (2019). AutoAugment: Learning augmentation strategies from data. CVPR.

  4. Yun, S., Han, D., Oh, S. J., et al. (2019). CutMix: Regularization strategy to train strong classifiers with localizable features. ICCV.

  5. Howard, J., & Ruder, S. (2018). Universal language model fine-tuning for text classification. ACL.

  • Post title: Transfer Learning (12): Industrial Applications and Best Practices
  • Post author: Chen Kai
  • Create time: 2025-01-08 14:45:00
  • Post link: https://www.chenk.top/transfer-learning-12-industrial-applications-and-best-practices/
  • Copyright Notice: All articles in this blog are licensed under BY-NC-SA unless stated otherwise.