12,900
社区成员
发帖
与我相关
我的任务
分享在人工智能项目开发中,数据管理、模型训练和结果存储往往需要无缝集成。MindSpore作为华为开源的深度学习框架,与MySQL这样的关系型数据库结合,可以构建强大的端到端AI解决方案。本文将详细介绍如何在MindSpore中集成MySQL,实现数据读取、训练监控和结果存储的全流程管理。
# 安装必要的依赖
pip install mindspore==2.0.0
pip install pymysql
pip install sqlalchemy
pip install pandas
-- Create database and application user
CREATE DATABASE mindspore_ai;
CREATE USER 'ai_user'@'%' IDENTIFIED BY 'AIPassword123!';
GRANT ALL PRIVILEGES ON mindspore_ai.* TO 'ai_user'@'%';
FLUSH PRIVILEGES;

USE mindspore_ai;

-- Training samples: four float features plus an integer class label
CREATE TABLE training_data (
    id INT AUTO_INCREMENT PRIMARY KEY,
    feature1 FLOAT NOT NULL,
    feature2 FLOAT NOT NULL,
    feature3 FLOAT NOT NULL,
    feature4 FLOAT NOT NULL,
    label INT NOT NULL,
    created_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    data_source VARCHAR(100)
);

-- One row per training run
CREATE TABLE training_tasks (
    task_id VARCHAR(50) PRIMARY KEY,
    model_name VARCHAR(100) NOT NULL,
    status VARCHAR(20) DEFAULT 'pending',
    start_time TIMESTAMP NULL,
    end_time TIMESTAMP NULL,
    total_epochs INT,
    current_epoch INT DEFAULT 0,
    train_loss FLOAT,
    val_accuracy FLOAT,
    created_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);

-- Per-epoch / per-step metric log
CREATE TABLE training_logs (
    log_id INT AUTO_INCREMENT PRIMARY KEY,
    task_id VARCHAR(50),
    epoch INT,
    step INT,
    loss FLOAT,
    accuracy FLOAT,
    learning_rate FLOAT,
    log_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    FOREIGN KEY (task_id) REFERENCES training_tasks(task_id)
);

-- Final evaluation result per task
CREATE TABLE model_results (
    result_id INT AUTO_INCREMENT PRIMARY KEY,
    task_id VARCHAR(50),
    model_path VARCHAR(255),
    test_accuracy FLOAT,
    test_loss FLOAT,
    inference_time FLOAT,
    saved_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    FOREIGN KEY (task_id) REFERENCES training_tasks(task_id)
);

-- Added: referenced by MySQLDataLoader._save_data_statistics but missing
-- from the original schema (its INSERT would fail with "table not found").
CREATE TABLE data_statistics (
    stat_id INT AUTO_INCREMENT PRIMARY KEY,
    table_name VARCHAR(100),
    total_samples INT,
    feature_means TEXT,
    feature_stds TEXT,
    label_distribution TEXT,
    created_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);

-- Added: referenced by the /predict endpoint's audit insert
CREATE TABLE prediction_logs (
    log_id INT AUTO_INCREMENT PRIMARY KEY,
    task_id VARCHAR(50),
    features TEXT,
    prediction TEXT,
    log_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
# db_manager.py
import logging
from contextlib import contextmanager
from urllib.parse import quote_plus

import pandas as pd
import pymysql
from sqlalchemy import create_engine
class DatabaseManager:
    """Thin access layer for the mindspore_ai MySQL schema.

    Opens one short-lived pymysql connection per call and also keeps a
    SQLAlchemy engine for pandas-based bulk reads.
    """

    def __init__(self, host='localhost', user='ai_user',
                 password='AIPassword123!', database='mindspore_ai'):
        self.db_config = {
            'host': host,
            'user': user,
            'password': password,
            'database': database,
            'charset': 'utf8mb4',
        }
        # quote_plus: special characters in credentials (the '!' here)
        # must be URL-escaped or they corrupt the SQLAlchemy URL.
        self.engine = create_engine(
            f"mysql+pymysql://{quote_plus(user)}:{quote_plus(password)}"
            f"@{host}/{database}"
        )

    @contextmanager
    def get_connection(self):
        """Yield a fresh pymysql connection; always closed on exit."""
        conn = None
        try:
            conn = pymysql.connect(**self.db_config)
            yield conn
        except Exception as e:
            logging.error(f"Database connection failed: {e}")
            raise
        finally:
            if conn:
                conn.close()

    def execute_query(self, query, params=None):
        """Run a SELECT and return all rows as tuples."""
        with self.get_connection() as conn:
            # Cursor as context manager: the original leaked the cursor
            # whenever execute()/fetchall() raised.
            with conn.cursor() as cursor:
                cursor.execute(query, params or ())
                return cursor.fetchall()

    def execute_update(self, query, params=None):
        """Run an INSERT/UPDATE/DELETE, commit, and return the row count."""
        with self.get_connection() as conn:
            with conn.cursor() as cursor:
                cursor.execute(query, params or ())
                conn.commit()
                return cursor.rowcount

    def load_training_data(self, table_name='training_data', limit=None):
        """Load a whole table (optionally LIMITed) into a DataFrame.

        A table name cannot be bound as a query parameter, so it is
        validated as a bare identifier to block SQL injection; the limit
        is forced to int for the same reason.
        """
        if not table_name.isidentifier():
            raise ValueError(f"Invalid table name: {table_name!r}")
        query = f"SELECT * FROM {table_name}"
        if limit:
            query += f" LIMIT {int(limit)}"
        with self.get_connection() as conn:
            return pd.read_sql(query, conn)

    def save_training_task(self, task_data):
        """Insert one training_tasks row; returns the affected row count.

        task_data must contain task_id, model_name, total_epochs;
        status defaults to 'pending'.
        """
        query = """
            INSERT INTO training_tasks
            (task_id, model_name, total_epochs, status)
            VALUES (%s, %s, %s, %s)
        """
        return self.execute_update(query, (
            task_data['task_id'],
            task_data['model_name'],
            task_data['total_epochs'],
            task_data.get('status', 'pending'),
        ))

    def update_training_progress(self, task_id, epoch, loss, accuracy):
        """Refresh the live progress columns of one training task."""
        query = """
            UPDATE training_tasks
            SET current_epoch = %s, train_loss = %s, val_accuracy = %s
            WHERE task_id = %s
        """
        return self.execute_update(query, (epoch, loss, accuracy, task_id))
# data_loader.py
import mindspore as ms
import mindspore.dataset as ds
from mindspore import Tensor
import numpy as np
from db_manager import DatabaseManager
class MySQLDataLoader:
    """Builds MindSpore datasets from rows stored in MySQL."""

    # Model input columns, in the order the network expects them.
    FEATURE_COLUMNS = ['feature1', 'feature2', 'feature3', 'feature4']

    def __init__(self, db_manager, batch_size=32):
        self.db_manager = db_manager
        self.batch_size = batch_size

    def create_dataset(self, table_name='training_data', split_ratio=0.8):
        """Load `table_name` and return (train_dataset, test_dataset).

        The split is positional: the first `split_ratio` fraction of rows
        forms the train set, the remainder the test set.
        """
        df = self.db_manager.load_training_data(table_name)
        # Cast explicitly: pandas yields float64/int64, while nn.Dense
        # weights are float32 and sparse softmax-CE expects int32 labels.
        features = df[self.FEATURE_COLUMNS].values.astype(np.float32)
        labels = df['label'].values.astype(np.int32)

        split_idx = int(len(features) * split_ratio)
        train_features, test_features = features[:split_idx], features[split_idx:]
        train_labels, test_labels = labels[:split_idx], labels[split_idx:]

        train_dataset = ds.NumpySlicesDataset(
            (train_features, train_labels),
            column_names=['features', 'labels']
        )
        test_dataset = ds.NumpySlicesDataset(
            (test_features, test_labels),
            column_names=['features', 'labels']
        )
        # Only the training split is shuffled; evaluation order is irrelevant.
        train_dataset = self._preprocess_dataset(train_dataset, shuffle=True)
        test_dataset = self._preprocess_dataset(test_dataset, shuffle=False)
        return train_dataset, test_dataset

    def _preprocess_dataset(self, dataset, shuffle=True):
        """Shuffle at sample level, then batch.

        The original batched first and shuffled afterwards, which only
        permutes whole batches instead of individual samples; the no-op
        repeat(1) has been dropped.
        """
        if shuffle:
            dataset = dataset.shuffle(buffer_size=1000)
        dataset = dataset.batch(self.batch_size, drop_remainder=True)
        return dataset

    def get_data_statistics(self, table_name='training_data'):
        """Compute simple per-feature statistics and persist them."""
        df = self.db_manager.load_training_data(table_name)
        stats = {
            'total_samples': len(df),
            'feature_means': df[self.FEATURE_COLUMNS].mean().to_dict(),
            'feature_stds': df[self.FEATURE_COLUMNS].std().to_dict(),
            'label_distribution': df['label'].value_counts().to_dict(),
        }
        self._save_data_statistics(stats, table_name)
        return stats

    def _save_data_statistics(self, stats, table_name):
        """Persist one statistics snapshot.

        NOTE(review): assumes a `data_statistics` table exists — it is not
        created by the schema script in this article; confirm before running.
        """
        query = """
            INSERT INTO data_statistics
            (table_name, total_samples, feature_means, feature_stds, label_distribution)
            VALUES (%s, %s, %s, %s, %s)
        """
        self.db_manager.execute_update(query, (
            table_name,
            stats['total_samples'],
            str(stats['feature_means']),
            str(stats['feature_stds']),
            str(stats['label_distribution']),
        ))
# training_monitor.py
import mindspore as ms
from mindspore import TrainOneStepCell, LossMonitor
import logging
import time
from db_manager import DatabaseManager
class MySQLTrainingMonitor:
    """Persists training lifecycle events (begin / epoch end / end) to MySQL."""

    def __init__(self, db_manager, task_id):
        self.db_manager = db_manager
        self.task_id = task_id
        # Set by on_train_begin(); None until then.
        self.start_time = None

    def on_train_begin(self):
        """Mark the task row as running and remember the wall-clock start."""
        self.start_time = time.time()
        logging.info(f"Training task {self.task_id} started")
        self.db_manager.execute_update(
            "UPDATE training_tasks SET status = 'running', start_time = NOW() WHERE task_id = %s",
            (self.task_id,)
        )

    def on_epoch_end(self, epoch, loss, accuracy, learning_rate):
        """Append one training_logs row and refresh the task progress."""
        self.db_manager.execute_update("""
            INSERT INTO training_logs
            (task_id, epoch, loss, accuracy, learning_rate)
            VALUES (%s, %s, %s, %s, %s)
        """, (self.task_id, epoch, float(loss), float(accuracy), float(learning_rate)))
        self.db_manager.update_training_progress(
            self.task_id, epoch, float(loss), float(accuracy)
        )
        logging.info(f"Epoch {epoch}: loss={loss:.4f}, accuracy={accuracy:.4f}")

    def on_train_end(self, model, test_accuracy, test_loss, model_path):
        """Record final results and mark the task completed.

        Note: the value stored in model_results.inference_time is actually
        the total wall-clock training duration (kept as-is for schema
        compatibility with the original code).
        """
        # Guard against on_train_begin() never having been called; the
        # original raised TypeError (None minus float) in that case.
        if self.start_time is None:
            training_time = 0.0
        else:
            training_time = time.time() - self.start_time
        self.db_manager.execute_update("""
            INSERT INTO model_results
            (task_id, model_path, test_accuracy, test_loss, inference_time)
            VALUES (%s, %s, %s, %s, %s)
        """, (self.task_id, model_path, test_accuracy, test_loss, training_time))
        self.db_manager.execute_update("""
            UPDATE training_tasks
            SET status = 'completed', end_time = NOW(), val_accuracy = %s
            WHERE task_id = %s
        """, (test_accuracy, self.task_id))
        logging.info(f"Training task {self.task_id} completed in {training_time:.2f}s")
class CustomLossMonitor(LossMonitor):
    """LossMonitor variant intended to forward per-step info to MySQL.

    Currently it only keeps the stock console printing; DB writes are
    delegated to the wrapped MySQLTrainingMonitor.
    """

    def __init__(self, mysql_monitor, per_print_times=1):
        super().__init__(per_print_times)
        self.mysql_monitor = mysql_monitor
        # Epoch counter placeholder for future per-step logging logic.
        self.current_epoch = 0

    def step_end(self, run_context):
        """Called by MindSpore after every step.

        The original fetched run_context.original_args() into an unused
        local; it has been removed. Extend here (via
        run_context.original_args()) to push per-step metrics through
        self.mysql_monitor.
        """
        super().step_end(run_context)
# train_with_mysql.py
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Model, context
from mindspore.common.initializer import Normal
import uuid
from db_manager import DatabaseManager
from data_loader import MySQLDataLoader
from training_monitor import MySQLTrainingMonitor
# Simple feed-forward network definition
class SimpleNN(nn.Cell):
    """3-layer MLP mapping 4 input features to 3 class logits.

    Forward output is raw logits (no softmax); pair with
    SoftmaxCrossEntropyWithLogits for training.
    """
    def __init__(self, input_size=4, hidden_size=64, num_classes=3):
        super(SimpleNN, self).__init__()
        # Normal(0.02): small random init for every dense layer.
        self.fc1 = nn.Dense(input_size, hidden_size, weight_init=Normal(0.02))
        self.fc2 = nn.Dense(hidden_size, hidden_size, weight_init=Normal(0.02))
        self.fc3 = nn.Dense(hidden_size, num_classes, weight_init=Normal(0.02))
        self.relu = nn.ReLU()
        # NOTE(review): `keep_prob` was superseded by `p` in MindSpore 2.x
        # (keep_prob=0.5 and p=0.5 coincide) — confirm against the pinned
        # mindspore==2.0.0 API before upgrading.
        self.dropout = nn.Dropout(keep_prob=0.5)

    def construct(self, x):
        """Forward pass: fc1 -> relu -> dropout -> fc2 -> relu -> dropout -> fc3."""
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.relu(self.fc2(x))
        x = self.dropout(x)
        x = self.fc3(x)
        return x
class MySQLIntegratedTrainer:
    """End-to-end trainer: MySQL data -> MindSpore model -> MySQL results."""

    def __init__(self, config):
        self.config = config
        self.db_manager = DatabaseManager(**config['database'])
        self.data_loader = MySQLDataLoader(self.db_manager, config['batch_size'])
        # Short random id so concurrent runs get distinct task rows.
        self.task_id = f"task_{uuid.uuid4().hex[:8]}"
        self.monitor = MySQLTrainingMonitor(self.db_manager, self.task_id)

    def prepare_training_task(self):
        """Register the task row in MySQL and print dataset statistics."""
        task_data = {
            'task_id': self.task_id,
            'model_name': self.config['model_name'],
            'total_epochs': self.config['epochs'],
        }
        self.db_manager.save_training_task(task_data)
        stats = self.data_loader.get_data_statistics()
        print(f"Data statistics: {stats}")

    def train(self):
        """Run the full train/eval/save cycle; returns test accuracy."""
        import os  # local import: only this method needs it

        context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
        self.prepare_training_task()
        train_dataset, test_dataset = self.data_loader.create_dataset()

        network = SimpleNN()
        loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
        optimizer = nn.Adam(network.trainable_params(),
                            learning_rate=self.config['learning_rate'])
        # 'loss' is requested explicitly so model.eval() reports it; the
        # original asked only for 'accuracy' and then read a missing
        # 'loss' key, raising KeyError after a full training run.
        model = Model(network, loss_fn, optimizer, metrics={'accuracy', 'loss'})

        self.monitor.on_train_begin()
        print("Starting training...")
        # NOTE(review): MySQLTrainingMonitor is not a mindspore Callback
        # subclass, so it cannot be passed in `callbacks` (Model.train
        # validates callback types); its begin/end hooks are invoked
        # manually around the training call instead — confirm against the
        # pinned MindSpore version.
        model.train(
            epoch=self.config['epochs'],
            train_dataset=train_dataset,
            callbacks=[],
            dataset_sink_mode=False
        )

        print("Evaluating model...")
        test_result = model.eval(test_dataset)
        test_accuracy = test_result['accuracy']
        # Defensive .get(): metric sets without 'loss' no longer crash here.
        test_loss = test_result.get('loss', float('nan'))

        # Create the checkpoint directory up front; save_checkpoint does not.
        os.makedirs("./models", exist_ok=True)
        model_path = f"./models/{self.task_id}.ckpt"
        ms.save_checkpoint(network, model_path)

        self.monitor.on_train_end(model, test_accuracy, test_loss, model_path)
        print(f"Training completed! Test accuracy: {test_accuracy:.4f}")
        return test_accuracy
# Training configuration (DB credentials must match the SQL setup above).
config = {
    'database': {
        'host': 'localhost',
        'user': 'ai_user',
        'password': 'AIPassword123!',
        'database': 'mindspore_ai'
    },
    'model_name': 'SimpleNN',
    'batch_size': 32,
    'epochs': 10,
    'learning_rate': 0.001
}

if __name__ == "__main__":
    # Create the trainer and run the full train/eval/save pipeline.
    trainer = MySQLIntegratedTrainer(config)
    accuracy = trainer.train()
    print(f"Final model accuracy: {accuracy:.4f}")
# data_generator.py
import numpy as np
from db_manager import DatabaseManager
def _label_for(features):
    """Deterministic label rule for the synthetic dataset (3 classes)."""
    if features[0] + features[1] > 0:
        return 0
    if features[2] * features[3] > 0.5:
        return 1
    return 2


def generate_sample_data(num_samples=1000):
    """Insert `num_samples` synthetic rows into training_data.

    Features are i.i.d. standard normal draws; labels follow _label_for().
    """
    db_manager = DatabaseManager()
    np.random.seed(42)  # reproducible synthetic data
    # Hoisted out of the loop: the statement text never changes per row.
    query = """
        INSERT INTO training_data
        (feature1, feature2, feature3, feature4, label, data_source)
        VALUES (%s, %s, %s, %s, %s, %s)
    """
    for _ in range(num_samples):
        features = np.random.normal(0, 1, 4)
        db_manager.execute_update(query, (
            float(features[0]),
            float(features[1]),
            float(features[2]),
            float(features[3]),
            _label_for(features),
            'synthetic',
        ))
    print(f"Generated {num_samples} sample data records")


if __name__ == "__main__":
    generate_sample_data(1000)
# result_analyzer.py
import pandas as pd
import matplotlib.pyplot as plt
from db_manager import DatabaseManager
class TrainingResultAnalyzer:
    """Reads training history from MySQL and renders loss/accuracy plots."""

    def __init__(self, db_manager):
        self.db_manager = db_manager

    def get_training_history(self, task_id):
        """Return one task's training_logs rows as a DataFrame, time-ordered."""
        query = """
            SELECT epoch, loss, accuracy, learning_rate, log_time
            FROM training_logs
            WHERE task_id = %s
            ORDER BY epoch, log_time
        """
        # Use a raw DBAPI connection so pymysql handles the %s paramstyle;
        # routing this string through the SQLAlchemy 2.x engine would
        # require :named bind parameters instead.
        with self.db_manager.get_connection() as conn:
            return pd.read_sql(query, conn, params=(task_id,))

    def plot_training_curves(self, task_id):
        """Save and show loss/accuracy curves for a single task."""
        df = self.get_training_history(task_id)
        if df.empty:
            print("No training data found for task:", task_id)
            return
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

        ax1.plot(df['epoch'], df['loss'], 'b-', label='Training Loss')
        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('Loss')
        ax1.set_title('Training Loss')
        ax1.legend()
        ax1.grid(True)

        ax2.plot(df['epoch'], df['accuracy'], 'r-', label='Accuracy')
        ax2.set_xlabel('Epoch')
        ax2.set_ylabel('Accuracy')
        ax2.set_title('Training Accuracy')
        ax2.legend()
        ax2.grid(True)

        plt.tight_layout()
        plt.savefig(f'training_curves_{task_id}.png')
        plt.show()
        plt.close(fig)  # free the figure; repeated calls otherwise leak memory

    def compare_tasks(self, task_ids):
        """Overlay loss/accuracy curves of several tasks on shared axes."""
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
        plotted_any = False
        for task_id in task_ids:
            df = self.get_training_history(task_id)
            if not df.empty:
                plotted_any = True
                ax1.plot(df['epoch'], df['loss'], label=f'Task {task_id}')
                ax2.plot(df['epoch'], df['accuracy'], label=f'Task {task_id}')
        if not plotted_any:
            # Nothing to compare; avoid writing an empty chart.
            plt.close(fig)
            print("No training data found for tasks:", task_ids)
            return
        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('Loss')
        ax1.set_title('Loss Comparison')
        ax1.legend()
        ax1.grid(True)
        ax2.set_xlabel('Epoch')
        ax2.set_ylabel('Accuracy')
        ax2.set_title('Accuracy Comparison')
        ax2.legend()
        ax2.grid(True)
        plt.tight_layout()
        plt.savefig('task_comparison.png')
        plt.show()
        plt.close(fig)
# Usage example: plot curves for the two most recent tasks.
if __name__ == "__main__":
    db_manager = DatabaseManager()
    analyzer = TrainingResultAnalyzer(db_manager)
    # Fetch the two newest tasks (most recent first).
    tasks = db_manager.execute_query(
        "SELECT task_id FROM training_tasks ORDER BY created_time DESC LIMIT 2"
    )
    if tasks:
        # execute_query returns tuples; column 0 is task_id.
        task_ids = [task[0] for task in tasks]
        analyzer.plot_training_curves(task_ids[0])
        analyzer.compare_tasks(task_ids)
# distributed_training.py
from mindspore.communication import init, get_rank, get_group_size
from db_manager import DatabaseManager
class DistributedMySQLTrainer:
    """Loads per-rank shards of training_data for data-parallel training."""

    def __init__(self, config):
        self.config = config
        self.db_manager = DatabaseManager(**config['database'])
        # Initialize the communication layer; must run before get_rank().
        init()
        self.rank = get_rank()
        self.group_size = get_group_size()

    def distributed_data_loading(self):
        """Load this rank's shard and return it as a MindSpore dataset.

        NOTE(review): LIMIT/OFFSET without an ORDER BY gives no guaranteed
        row order in MySQL, so shards may overlap across runs — add
        `ORDER BY id` if strict disjointness is required.
        """
        total_data = self.db_manager.execute_query(
            "SELECT COUNT(*) FROM training_data"
        )[0][0]
        chunk_size = total_data // self.group_size
        offset = self.rank * chunk_size
        # The last rank also takes the division remainder so no row is lost.
        if self.rank == self.group_size - 1:
            chunk_size = total_data - offset
        # chunk_size/offset are computed integers, not user input, so
        # interpolating them into the query is safe here.
        query = f"SELECT * FROM training_data LIMIT {chunk_size} OFFSET {offset}"
        df = pd.read_sql(query, self.db_manager.engine)
        # Convert the shard to a MindSpore dataset (the original returned
        # an undefined `dataset` name here, raising NameError).
        features = df[['feature1', 'feature2', 'feature3', 'feature4']].values.astype(np.float32)
        labels = df['label'].values.astype(np.int32)
        dataset = ds.NumpySlicesDataset(
            (features, labels), column_names=['features', 'labels']
        )
        return dataset
# model_serving.py
import mindspore as ms
from flask import Flask, request, jsonify
from db_manager import DatabaseManager
app = Flask(__name__)
db_manager = DatabaseManager()
class ModelServer:
    """In-memory cache of networks restored from saved checkpoints."""

    def __init__(self):
        # task_id -> restored SimpleNN network
        self.loaded_models = {}

    def load_model(self, task_id):
        """Return the network for `task_id`, loading and caching it on demand.

        Returns None when no model_results row exists for the task.
        """
        if task_id in self.loaded_models:
            return self.loaded_models[task_id]  # avoid re-reading the checkpoint
        result = db_manager.execute_query(
            "SELECT model_path FROM model_results WHERE task_id = %s",
            (task_id,)
        )
        if not result:
            return None
        model_path = result[0][0]
        # load_checkpoint returns a parameter dict, not a network object.
        param_dict = ms.load_checkpoint(model_path)
        network = SimpleNN()
        ms.load_param_into_net(network, param_dict)
        # Serve in inference mode so dropout layers are inactive.
        network.set_train(False)
        self.loaded_models[task_id] = network
        return network
model_server = ModelServer()


@app.route('/predict', methods=['POST'])
def predict():
    """JSON API: {"task_id": ..., "features": [[...], ...]} -> class indices."""
    data = request.json or {}
    task_id = data.get('task_id')
    features = data.get('features')
    # Validate input before touching the model cache.
    if task_id is None or features is None:
        return jsonify({'error': 'task_id and features are required'}), 400
    model = model_server.loaded_models.get(task_id) or model_server.load_model(task_id)
    if model is None:
        # The original crashed with KeyError when the task was unknown.
        return jsonify({'error': f'no model found for task {task_id}'}), 404
    input_tensor = ms.Tensor(features, dtype=ms.float32)
    output = model(input_tensor)
    prediction = ms.ops.argmax(output, axis=1).asnumpy().tolist()
    # Audit log of every prediction request.
    # NOTE(review): requires a `prediction_logs` table — it is not created
    # by the schema script in this article; confirm before deploying.
    db_manager.execute_update(
        "INSERT INTO prediction_logs (task_id, features, prediction) VALUES (%s, %s, %s)",
        (task_id, str(features), str(prediction))
    )
    return jsonify({'prediction': prediction})


if __name__ == '__main__':
    # Development server only; front with a WSGI server in production.
    app.run(host='0.0.0.0', port=5000)
通过本文的实战案例,我们实现了MindSpore与MySQL的深度集成,主要成果包括:从MySQL直接加载训练数据并转换为MindSpore数据集、训练任务与逐轮指标的数据库化跟踪,以及模型评估结果与在线预测日志的持久化存储。
这种集成方案的优势在于:数据、训练过程与结果集中管理,便于审计与复现;训练进度可随时通过SQL查询监控;多个历史任务可以横向对比分析。
这种模式特别适合需要严格数据管理、训练过程监控和结果分析的企业级AI应用场景。