Learning Rate Decay Algorithm Based on Mutual Information in Deep Learning

Shoujin Wang, Fan Wang, Yu Zhang

2024 International Conference on Distributed Computing and Optimization Techniques (ICDCOT)(2024)

Abstract
The gradient descent method is a commonly used optimization algorithm for tuning the parameters of deep learning models. On a high-dimensional non-convex loss surface, the learning rate is the key hyperparameter controlling how fast the network model learns. If the learning rate is too large, the loss fails to converge; if it is too small, convergence is slow and the model easily becomes stuck in a local minimum. An inappropriate learning rate schedule can cause incorrect update directions during training, high variance in the updated parameters, and sharp fluctuations in the loss. To address these problems, this paper improves a stochastic gradient descent algorithm for deep neural networks (DNNs) in which mutual information (MI) drives adaptive optimization of the learning rate. The algorithm first estimates a suitable learning rate range, and then uses the change in mutual information between the network's outputs and the ground-truth labels at different training stages to decay the learning rate, so that it is adjusted adaptively in each cycle. Experimental results show that the improved algorithm performs well on multiple classic datasets: compared with conventional learning rate optimization algorithms such as Adam, SGD, and CLR, it improves the relevant performance metrics and converges to the optimum within a shorter training period.
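The MI-driven decay described above might be sketched roughly as follows. This is a minimal illustration under assumptions, not the authors' implementation: the empirical MI estimator over discrete predicted/true labels, the function names, and the halving rule when MI gain stalls are all hypothetical.

```python
import numpy as np

def mutual_information(preds, labels, n_classes):
    """Empirical mutual information (in nats) between predicted
    and true class labels, via the joint label histogram."""
    joint = np.zeros((n_classes, n_classes))
    for p, t in zip(preds, labels):
        joint[p, t] += 1
    joint /= joint.sum()
    px = joint.sum(axis=1, keepdims=True)   # marginal over predictions
    py = joint.sum(axis=0, keepdims=True)   # marginal over true labels
    nz = joint > 0                          # avoid log(0)
    return float((joint[nz] * np.log(joint[nz] / (px @ py)[nz])).sum())

def decay_lr(lr, mi_prev, mi_curr, decay=0.5):
    """Hypothetical schedule: shrink the learning rate when the
    MI between outputs and labels stops increasing between stages."""
    if mi_prev is not None and mi_curr <= mi_prev:
        return lr * decay
    return lr
```

For perfectly matching predictions over two balanced classes, the MI equals the label entropy log 2, and the learning rate is halved only when MI stagnates between evaluation stages.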
Key words
deep learning,image classification,gradient descent,mutual information