Table of Contents

Reinforcement Learning with PPO: Classification Task

I. Definition

  1. Notes
  2. Example

II. Implementation

  1. Notes
    In reinforcement learning you need to build an environment that matches your own business scenario, so the agent has something to interact with.
    1. The environment's reset() and step() functions have to be written yourself. Because a classification episode needs only a single step, reset() should not restart from the same sample id every time; instead an episode counter is incremented on every reset so that, over many episodes, every sample in the dataset gets visited.
    2. In step(), since each episode takes exactly one step, the termination flag must be set: terminated = True. The class excerpt below illustrates both points, and a small sanity-check sketch follows it.
class ImprovedClassificationEnv(gym.Env):
    """改进的分类环境"""
    metadata = {'render.modes': ['human']}

    def __init__(self, X, y):
        super(ImprovedClassificationEnv, self).__init__()
        ...

    def reset(self, seed=None, options=None):
        """重置环境状态"""
        super().reset(seed=seed)
        # 循环使用所有样本
        self.current_sample_idx = self.current_episode % self.num_samples
        self.current_episode += 1
        
        # 如果指定了特定样本(用于评估)
        if options and 'sample_idx' in options:
            self.current_sample_idx = options['sample_idx']
        
        # 创建增强的状态:特征 + 样本索引归一化 + 类别先验
        sample_features = self.X[self.current_sample_idx].astype(np.float32)
        
        # 添加额外信息:样本索引(归一化)和类别分布先验
        extra_info = np.array([
            self.current_sample_idx / self.num_samples,  # 归一化索引
            np.mean(self.y == self.y[self.current_sample_idx])  # 同类比例
        ], dtype=np.float32)
        
        state = np.concatenate([sample_features, extra_info])
        
        return state, {}

    def step(self, action):
        """执行动作(进行分类)"""
        true_label = self.y[self.current_sample_idx]
        
        # 改进的奖励函数
        if action == true_label:
            # 正确分类:基础奖励 + 置信度奖励
            reward = 2.0
        else:
            # 错误分类:基础惩罚 + 根据错误程度调整
            reward = -1.0
            
            # 如果错误程度较大(如将类别0预测为类别2),惩罚更重
            if abs(action - true_label) > 1:
                reward -= 0.5
        
        # 添加探索奖励(鼓励尝试不同类别)
        if len(self.visited_samples) < self.num_samples * 0.1:  # 前10%的探索阶段
            if action not in self.visited_samples:
                reward += 0.1
                self.visited_samples.add(action)
        
        # 总是终止,因为每个样本只做一次分类决策
        terminated = True
        truncated = False
        
        info = {
            'true_label': true_label,
            'predicted_label': action,
            'correct': action == true_label,
            'sample_idx': self.current_sample_idx
        }
        
        # 返回下一个状态(虽然是终止状态,但仍返回当前状态)
        next_state, _ = self.reset()
        return next_state, reward, terminated, truncated, info

    def render(self, mode='human'):
        if mode == 'human':
            print(f"样本 {self.current_sample_idx}: 真实标签={self.y[self.current_sample_idx]}")
        return None
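
A quick way to verify the two notes above is the short sanity check sketched below. It is not part of the original listing: it assumes X_train and y_train arrays such as those produced by generate_classification_data in the example that follows, and uses Stable-Baselines3's check_env helper to validate the Gymnasium API before training.

from stable_baselines3.common.env_checker import check_env

# Minimal sanity check (assumes X_train, y_train already exist, see the example below)
env = ImprovedClassificationEnv(X_train, y_train)
check_env(env)  # validates the observation/action spaces and the reset()/step() API

for episode in range(3):
    obs, _ = env.reset()
    action = env.action_space.sample()  # random class, just to exercise step()
    obs, reward, terminated, truncated, info = env.step(action)
    # Each episode is a single decision, so terminated is always True,
    # and reset() advances the sample index on every call.
    print(episode, info['sample_idx'], info['correct'], terminated)
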
  2. Example
import numpy as np
import gymnasium as gym
from gymnasium import spaces
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.callbacks import BaseCallback
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from collections import deque
import torch
import torch.nn as nn

# Set random seeds for reproducibility
np.random.seed(42)
torch.manual_seed(42)

# 1. Generate synthetic data
def generate_classification_data(n_samples=10000, n_features=10, n_classes=3):
    """Generate a classification dataset."""
    print("Generating classification data...")
    X, y = make_classification(
        n_samples=n_samples,
        n_features=n_features,
        n_informative=8,
        n_redundant=2,
        n_classes=n_classes,
        n_clusters_per_class=1,
        random_state=42
    )
    
    # Standardize the features
    scaler = StandardScaler()
    X = scaler.fit_transform(X)
    
    # Split into train and test sets
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42, stratify=y
    )
    
    print(f"训练集形状: {X_train.shape}, 测试集形状: {X_test.shape}")
    print(f"类别分布 - 训练集: {np.bincount(y_train)}, 测试集: {np.bincount(y_test)}")
    
    return X_train, X_test, y_train, y_test, scaler

# 2. Improved classification environment
class ImprovedClassificationEnv(gym.Env):
    """Improved classification environment"""
    metadata = {'render_modes': ['human']}

    def __init__(self, X, y):
        super(ImprovedClassificationEnv, self).__init__()
        
        self.X = X
        self.y = y
        self.num_samples, self.num_features = X.shape
        self.num_classes = len(np.unique(y))
        self.current_episode = 0
        
        # Action space: choose a class label
        self.action_space = spaces.Discrete(self.num_classes)
        
        # Observation space: current sample features + extra information
        self.observation_space = spaces.Box(
            low=-5.0,
            high=5.0,
            shape=(self.num_features + 2,),  # two extra features appended
            dtype=np.float32
        )
        
        self.current_sample_idx = None
        self.visited_samples = set()

    def reset(self, seed=None, options=None):
        """重置环境状态"""
        super().reset(seed=seed)
        
        # 循环使用所有样本
        self.current_sample_idx = self.current_episode % self.num_samples
        self.current_episode += 1
        
        # 如果指定了特定样本(用于评估)
        if options and 'sample_idx' in options:
            self.current_sample_idx = options['sample_idx']
        
        # 创建增强的状态:特征 + 样本索引归一化 + 类别先验
        sample_features = self.X[self.current_sample_idx].astype(np.float32)
        
        # 添加额外信息:样本索引(归一化)和类别分布先验
        extra_info = np.array([
            self.current_sample_idx / self.num_samples,  # 归一化索引
            np.mean(self.y == self.y[self.current_sample_idx])  # 同类比例
        ], dtype=np.float32)
        
        state = np.concatenate([sample_features, extra_info])
        
        return state, {}

    def step(self, action):
        """执行动作(进行分类)"""
        true_label = self.y[self.current_sample_idx]
        
        # 改进的奖励函数
        if action == true_label:
            # 正确分类:基础奖励 + 置信度奖励
            reward = 2.0
        else:
            # 错误分类:基础惩罚 + 根据错误程度调整
            reward = -1.0
            
            # 如果错误程度较大(如将类别0预测为类别2),惩罚更重
            if abs(action - true_label) > 1:
                reward -= 0.5
        
        # 添加探索奖励(鼓励尝试不同类别)
        if len(self.visited_samples) < self.num_samples * 0.1:  # 前10%的探索阶段
            if action not in self.visited_samples:
                reward += 0.1
                self.visited_samples.add(action)
        
        # 总是终止,因为每个样本只做一次分类决策
        terminated = True
        truncated = False
        
        info = {
            'true_label': true_label,
            'predicted_label': action,
            'correct': action == true_label,
            'sample_idx': self.current_sample_idx
        }
        
        # 返回下一个状态(虽然是终止状态,但仍返回当前状态)
        next_state, _ = self.reset()
        return next_state, reward, terminated, truncated, info

    def render(self, mode='human'):
        if mode == 'human':
            print(f"样本 {self.current_sample_idx}: 真实标签={self.y[self.current_sample_idx]}")
        return None


# 3. Enhanced evaluation function
def enhanced_evaluate_model(model, X_test, y_test, env_class):
    """Enhanced model evaluation."""
    print("\nEvaluating model performance...")
    
    eval_env = env_class(X_test, y_test)
    
    # Reward-based evaluation
    mean_reward, std_reward = evaluate_policy(model, eval_env, n_eval_episodes=len(X_test))
    print(f"Mean reward: {mean_reward:.4f} +/- {std_reward:.4f}")
    
    # Compute accuracy and other metrics
    all_predictions = []
    all_true_labels = []
    all_confidences = []
    
    for i in range(len(X_test)):
        obs, _ = eval_env.reset(options={'sample_idx': i})
        action, _ = model.predict(obs, deterministic=True)
        all_predictions.append(action)
        all_true_labels.append(y_test[i])
    
    accuracy = accuracy_score(all_true_labels, all_predictions)
    print(f"分类准确率: {accuracy:.4f}")
    
    # 显示详细分类报告
    print("\n详细分类报告:")
    print(classification_report(all_true_labels, all_predictions))
    
    # Plot the confusion matrix (optional)
    # cm = confusion_matrix(all_true_labels, all_predictions)
    # plt.figure(figsize=(8, 6))
    # sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
    # plt.title('Confusion Matrix')
    # plt.ylabel('True Label')
    # plt.xlabel('Predicted Label')
    # plt.tight_layout()
    # plt.savefig('confusion_matrix.png')
    # plt.show()
    
    return accuracy, mean_reward

# 4. Improved training callback
class ImprovedTrainingCallback(BaseCallback):
    """Improved training callback"""
    def __init__(self, eval_env, check_freq=1000, verbose=0):
        super(ImprovedTrainingCallback, self).__init__(verbose)
        self.eval_env = eval_env
        self.check_freq = check_freq
        self.best_accuracy = 0
        self.accuracies = []
        self.rewards = []
        
    def _on_step(self) -> bool:
        if self.n_calls % self.check_freq == 0:
            # Evaluate the current model
            current_accuracy, mean_reward = enhanced_evaluate_model(
                self.model, 
                self.eval_env.X, 
                self.eval_env.y,
                ImprovedClassificationEnv
            )
            
            self.accuracies.append(current_accuracy)
            self.rewards.append(mean_reward)
            
            print(f"Timestep: {self.n_calls}")
            print(f"准确率: {current_accuracy:.4f}, 平均奖励: {mean_reward:.4f}")
            
            # Save the best model so far
            if current_accuracy > self.best_accuracy:
                self.best_accuracy = current_accuracy
                print(f"New best model! Accuracy: {current_accuracy:.4f}")
                self.model.save("best_classifier_model")
                
        return True


# 5. Main function
def main():
    # Generate the data
    X_train, X_test, y_train, y_test, scaler = generate_classification_data()

    # Create the vectorized training environment
    env = make_vec_env(
        lambda: ImprovedClassificationEnv(X_train, y_train), 
        n_envs=4,
        seed=42
    )
    
    # Create the PPO model
    model = PPO(
        "MlpPolicy", 
        env, 
        verbose=1,
        learning_rate=1e-4,  # smaller learning rate
        n_steps=1024,        # longer rollouts
        batch_size=256,      # larger batch size
        n_epochs=20,         # more optimization epochs per update
        gamma=0.99,
        gae_lambda=0.95,
        clip_range=0.1,      # tighter clipping range
        ent_coef=0.02,       # moderate entropy coefficient
        vf_coef=0.5,         # value-function coefficient
        max_grad_norm=0.5,   # gradient clipping
        tensorboard_log="./tensorboard_logs/",
    )
    
    # Create the evaluation environment and callback
    eval_env = ImprovedClassificationEnv(X_test, y_test)
    callback = ImprovedTrainingCallback(eval_env, check_freq=5000)
    # Baseline: evaluate the untrained model for reference
    accuracy, mean_reward = enhanced_evaluate_model(model, X_test, y_test, ImprovedClassificationEnv)
    # Train the model
    print("\nStarting training...")
    model.learn(
        total_timesteps=200000,  # more training timesteps
        callback=callback,
        progress_bar=True,
        tb_log_name="ppo_classification"
    )
    
    # Save the final model
    model.save("ppo_classifier_final")
    print("Training finished, model saved")
    
    # Evaluate the trained model
    accuracy, mean_reward = enhanced_evaluate_model(model, X_test, y_test, ImprovedClassificationEnv)
    
    # Compare with traditional supervised methods
    from sklearn.linear_model import LogisticRegression
    from sklearn.ensemble import RandomForestClassifier
    
    print("\n与传统监督学习方法对比:")
    
    lr_model = LogisticRegression(random_state=42, max_iter=1000)
    lr_model.fit(X_train, y_train)
    lr_pred = lr_model.predict(X_test)
    lr_accuracy = accuracy_score(y_test, lr_pred)
    print(f"逻辑回归准确率: {lr_accuracy:.4f}")
    
    rf_model = RandomForestClassifier(random_state=42, n_estimators=100)
    rf_model.fit(X_train, y_train)
    rf_pred = rf_model.predict(X_test)
    rf_accuracy = accuracy_score(y_test, rf_pred)
    print(f"随机森林准确率: {rf_accuracy:.4f}")
    
    # Plot the performance comparison
    methods = ['PPO RL', 'Logistic Regression', 'Random Forest']
    accuracies = [accuracy, lr_accuracy, rf_accuracy]
    
    plt.figure(figsize=(10, 6))
    bars = plt.bar(methods, accuracies, color=['blue', 'green', 'orange'])
    plt.ylabel('Accuracy')
    plt.title('Classification Performance Comparison')
    plt.ylim(0, 1)
    
    for bar, acc in zip(bars, accuracies):
        plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01, 
                f'{acc:.4f}', ha='center', va='bottom')
    
    plt.tight_layout()
    plt.savefig('performance_comparison.png')
    plt.show()
    
    # Summarize the results
    print("\nResult summary:")
    print(f"PPO reinforcement learning accuracy: {accuracy:.4f}")
    print(f"Logistic regression accuracy: {lr_accuracy:.4f}")
    print(f"Random forest accuracy: {rf_accuracy:.4f}")

if __name__ == "__main__":
    main()