AI Based 6G Applications’ Security Mitigation: Defensive Distillation
How to build secure AI-based 6G solutions against adversarial machine learning attacks
The design of security schemes for deep learning algorithms is critical for protecting wireless networks from adversarial attacks. However, there is no consensus about the best way to do this. In this blog, I present a security vulnerability in deep learning for beamforming prediction using deep neural networks (DNNs), and offer two methods for mitigating this vulnerability. Experimental results show that these methods are effective in defending DNN models against adversarial attacks in next-generation wireless networks.
Cite The Work
If you find these results useful, please cite the paper:
@misc{kuzlu2022adversarial,
title={The Adversarial Security Mitigations of mmWave Beamforming Prediction Models using Defensive Distillation and Adversarial Retraining},
author={Murat Kuzlu and Ferhat Ozgur Catak and Umit Cali and Evren Catak and Ozgur Guler},
year={2022},
eprint={2202.08185},
archivePrefix={arXiv},
primaryClass={cs.CR}
}
1. Introduction
The first 5G standard was approved by 3GPP in December 2017. This early standardization work is expected to provide a solid and stable foundation for the early adoption of 5G services, and 5G will be essential for Internet of Things (IoT) applications and future mobile networks. Designing 5G networks raises many challenges, one of which is securing beamforming prediction. Beamforming is a critical part of wireless networks and is studied in both communication systems and signal processing, and designing and implementing beamforming algorithms remains crucial in next-generation (i.e., 6G) wireless networks. In current wireless networks, deep learning (DL)-based beamforming prediction is vulnerable to adversarial machine learning attacks, so a security scheme for beamforming prediction is needed for 6G networks.
6G is the next generation of cellular wireless technology and is currently under development. In 6G solutions, artificial intelligence (AI)-based algorithms are expected to be one of the main components of wireless communication systems, used to improve overall system performance. Existing 5G solutions would be migrated to the AI domain, and more specifically to DL. DL security vulnerabilities therefore become a new attack surface on top of the existing 5G security problems, and it is crucial to design secure DL solutions for the AI-based models in 6G wireless networks. Researchers and companies should mitigate the security problems of their DL models before deploying them to production environments. They need to identify, document, and perform a risk assessment for new security threats in next-generation wireless communication systems.
2. Defensive Distillation
Knowledge distillation was introduced by Hinton et al. to compress a large model into a smaller one. Papernot et al. proposed this technique as a defense against adversarial machine learning attacks. The defensive distillation mitigation method includes a larger teacher model and a compressed student model. The first step is to train the teacher model with a high temperature ($T$) parameter to soften the softmax probability outputs of the DNN model:
$$p_{i} = \frac{\exp(z_{i}/T)}{\sum_{j}\exp(z_{j}/T)}$$
where $p_{i}$ is the probability of the $i$th class and $z_{i}$ are the logits. The second step is to use the previously trained teacher model to obtain the soft labels of the training data. In this step, the teacher model predicts each sample in the training data using the same temperature ($T$) value, and these predictions become the labels (i.e., soft labels) used to train the student model. The student model is trained with the soft labels acquired from the teacher model, again with a high $T$ value in the softmax. After the student model’s training phase, the $T$ parameter is set to 1 at prediction time.
Figure 1 shows the overall steps for this technique.
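To make the effect of the temperature concrete, here is a minimal NumPy sketch of a softmax with a temperature parameter. The function name `softmax_with_temperature` and the example logits are made up for illustration; this is not the code used in `util`.

```python
import numpy as np

def softmax_with_temperature(logits, temperature=1.0):
    """Return softmax(logits / temperature) along the last axis."""
    scaled = np.asarray(logits, dtype=float) / temperature
    # Subtracting the maximum improves numerical stability and does not change the result.
    scaled = scaled - scaled.max(axis=-1, keepdims=True)
    exp_scaled = np.exp(scaled)
    return exp_scaled / exp_scaled.sum(axis=-1, keepdims=True)

# Illustrative logits for a three-class problem (made-up values).
logits = np.array([4.0, 1.0, 0.5])

hard = softmax_with_temperature(logits, temperature=1.0)    # sharp distribution
soft = softmax_with_temperature(logits, temperature=100.0)  # much flatter distribution
print("T=1:  ", hard)
print("T=100:", soft)
```

With a high $T$ the probabilities move toward a uniform distribution while the ranking of the classes stays the same, which is what makes the teacher's outputs usable as soft labels.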
3. Let’s get coding
I will demonstrate our proposed approach in Python. First, I import the usual libraries needed to build a deep learning model that predicts the RF beamforming codeword.
!pip install -q cleverhans
!pip install -q plot_keras_history
!pip install -q loguru
import util
from plot_keras_history import plot_history
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import numpy as np
from sklearn.metrics import mean_squared_error
import pandas as pd
import seaborn as sns
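SCENARIO_NAME = 'O1'            # DeepMIMO scenario name
nb_epoch = 2000                 # maximum number of training epochs
batch_size = 1000
num_tot_TX = 4                  # number of base stations
num_beams = 512                 # number of beams per base station
loss_fn = 'mean_squared_error'
DS_PATH = '/Users/ozgur/Documents/Workshop/DeepMIMO_Dataset_Generation_v1.1/DLCB_dataset/'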
In this work, I use the publicly available DeepMIMO dataset to attack the RF beamforming prediction model. I use the FGSM attack with the mean squared error (MSE) loss function for the input manipulation. My FGSM attack implementation is here. In this scenario, there are 4 base stations.
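As a rough illustration of what FGSM with an MSE loss does, the sketch below perturbs the input of a toy linear model using $x_{adv} = x + \epsilon \cdot \mathrm{sign}(\nabla_x L)$. The linear model, the random data, and the helper names are assumptions made for this example; the real attack in this post runs against the trained DNNs through `util.attack_models`.

```python
import numpy as np

def mse_loss(W, x, y):
    """Mean squared error of the toy linear model y_hat = W @ x."""
    return np.mean((W @ x - y) ** 2)

def mse_input_gradient(W, x, y):
    """Analytic gradient of the MSE loss with respect to the input x."""
    residual = W @ x - y
    return 2.0 * W.T @ residual / residual.size

def fgsm(W, x, y, eps):
    """One-step FGSM: shift every input feature by eps in the loss-increasing direction."""
    return x + eps * np.sign(mse_input_gradient(W, x, y))

rng = np.random.default_rng(0)
W = rng.normal(size=(8, 16))            # toy "model": 16 inputs, 8 outputs
x = rng.normal(size=16)                 # clean input
y = W @ x + 0.01 * rng.normal(size=8)   # target close to the clean prediction

x_adv = fgsm(W, x, y, eps=0.1)
clean_loss = mse_loss(W, x, y)
adv_loss = mse_loss(W, x_adv, y)
print(f"clean MSE = {clean_loss:.6f}, adversarial MSE = {adv_loss:.6f}")
```

The perturbation never exceeds `eps` in any single feature (an $L_\infty$ bound), yet the loss grows because every feature is pushed in its worst-case direction at once.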
In_train, In_test, Out_train, Out_test = util.get_dataset(SCENARIO_NAME)
AP_models, AP_models_history = util.train(In_train, Out_train, In_test, Out_test,
                                          nb_epoch, batch_size, loss_fn=loss_fn,
                                          n_BS=num_tot_TX, n_beams=num_beams,
                                          sc_name=SCENARIO_NAME)
for i in range(num_tot_TX):
    plot_history(AP_models_history[i].history, side=3)
    plt.show()
Restoring model weights from the end of the best epoch: 91.
Epoch 00111: early stopping
Restoring model weights from the end of the best epoch: 106.
Epoch 00126: early stopping
Restoring model weights from the end of the best epoch: 96.
Epoch 00116: early stopping
Restoring model weights from the end of the best epoch: 57.
Epoch 00077: early stopping
Adversarial Machine Learning Attacks
In this step, we generate malicious uplink signals using the Fast-Gradient Sign Method (FGSM), Basic Iterative Method (BIM), Momentum Iterative Method (MIM), and Projected Gradient Descent (PGD) attacks.
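The three iterative attacks share the same basic loop: take several small signed-gradient steps and keep the result inside an $\epsilon$-ball around the clean input. BIM does exactly that, PGD is commonly run with a random starting point inside the ball, and MIM accumulates a momentum term over the normalized gradients. The sketch below shows this on a toy linear model with an MSE loss. All names and data are hypothetical, and this is not the CleverHans implementation or the code behind `util.attack_models`.

```python
import numpy as np

def mse_loss(W, x, y):
    """Mean squared error of the toy linear model y_hat = W @ x."""
    return np.mean((W @ x - y) ** 2)

def input_gradient(W, x, y):
    """Analytic gradient of the MSE loss with respect to the input x."""
    residual = W @ x - y
    return 2.0 * W.T @ residual / residual.size

def iterative_attack(W, x, y, eps, steps=10, momentum=0.0, random_start=False, seed=0):
    """L-infinity iterative attack.

    momentum=0, random_start=False -> BIM-style
    random_start=True              -> PGD-style
    momentum>0                     -> MIM-style
    """
    step_size = eps / steps
    x_adv = x.copy()
    if random_start:
        x_adv = x + np.random.default_rng(seed).uniform(-eps, eps, size=x.shape)
    velocity = np.zeros_like(x)
    for _ in range(steps):
        grad = input_gradient(W, x_adv, y)
        # Accumulate L1-normalized gradients; with momentum=0 only the current gradient counts.
        velocity = momentum * velocity + grad / (np.abs(grad).sum() + 1e-12)
        x_adv = x_adv + step_size * np.sign(velocity)
        # Project back into the eps-ball around the clean input.
        x_adv = np.clip(x_adv, x - eps, x + eps)
    return x_adv

rng = np.random.default_rng(0)
W = rng.normal(size=(8, 16))
x = rng.normal(size=16)
y = W @ x + 0.01 * rng.normal(size=8)
eps = 0.1

x_bim = iterative_attack(W, x, y, eps)
x_pgd = iterative_attack(W, x, y, eps, random_start=True)
x_mim = iterative_attack(W, x, y, eps, momentum=1.0)
for name, x_adv in [("BIM", x_bim), ("PGD", x_pgd), ("MIM", x_mim)]:
    print(name, mse_loss(W, x_adv, y))
```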
FGSM (Fast-Gradient Sign Method)
epsilon_vals = [0.1, 0.2, 0.3]  # attack budgets
fgm_eps_mse = []    # mean MSE per epsilon
fgsm_all_mse = []   # per-sample MSE for every epsilon and base station
fgsm_all_eps = []   # epsilon matching each entry of fgsm_all_mse
for eps_val in tqdm(epsilon_vals):
    fgsm_mse_malicious = []
    for i_bs in range(num_tot_TX):
        # Craft adversarial inputs against the model of base station i_bs
        adv_inputs = util.attack_models(model=AP_models[i_bs],
                                        attack_name='FGSM',
                                        eps_val=eps_val,
                                        testset=In_test,
                                        norm=np.inf)
        adv_inputs_pred = AP_models[i_bs].predict(adv_inputs)
        # Per-sample MSE against this base station's slice of the targets
        for j in range(adv_inputs_pred.shape[0]):
            tmp = mean_squared_error(Out_test[j, i_bs*num_beams:(i_bs+1)*num_beams],
                                     adv_inputs_pred[j, :])
            fgsm_mse_malicious.append(tmp)
            fgsm_all_mse.append(tmp)
            fgsm_all_eps.append(eps_val)
    fgm_eps_mse.append(np.mean(fgsm_mse_malicious))
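The inner loop over `j` computes one MSE value per test sample for the slice of `Out_test` that belongs to base station `i_bs`. As a possible alternative, the same per-sample values can be computed in one vectorized NumPy expression. The helper name `per_sample_mse` and the toy arrays below are illustrative only and are not part of `util`.

```python
import numpy as np

def per_sample_mse(targets, predictions, bs_index, n_beams):
    """MSE of each sample for one base station's slice of the target matrix."""
    target_slice = targets[:, bs_index * n_beams:(bs_index + 1) * n_beams]
    return np.mean((target_slice - predictions) ** 2, axis=1)

# Toy shapes: 5 samples, 2 base stations, 3 beams each (made-up data).
rng = np.random.default_rng(0)
targets = rng.random((5, 2 * 3))
predictions = rng.random((5, 3))

vectorized = per_sample_mse(targets, predictions, bs_index=1, n_beams=3)
# Reference: the same quantity computed one row at a time, as in the loop above.
looped = np.array([np.mean((targets[j, 3:6] - predictions[j]) ** 2) for j in range(5)])
print(np.allclose(vectorized, looped))
```

In the attack loops, the resulting array could be appended with `list.extend` instead of calling `mean_squared_error` once per row.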
PGD (Projected Gradient Descent)
pgd_eps_mse = []
pgd_all_mse = []
for eps_val in tqdm(epsilon_vals):
    pgd_mse_malicious = []
    for i_bs in range(num_tot_TX):
        adv_inputs = util.attack_models(model=AP_models[i_bs],
                                        attack_name='PGD',
                                        eps_val=eps_val,
                                        testset=In_test,
                                        norm=np.inf)
        adv_inputs_pred = AP_models[i_bs].predict(adv_inputs)
        for j in range(adv_inputs_pred.shape[0]):
            tmp = mean_squared_error(Out_test[j, i_bs*num_beams:(i_bs+1)*num_beams],
                                     adv_inputs_pred[j, :])
            pgd_mse_malicious.append(tmp)
            pgd_all_mse.append(tmp)
    pgd_eps_mse.append(np.mean(pgd_mse_malicious))
BIM (Basic Iterative Method)
bim_eps_mse = []
bim_all_mse = []
for eps_val in tqdm(epsilon_vals):
    bim_mse_malicious = []
    for i_bs in range(num_tot_TX):
        adv_inputs = util.attack_models(model=AP_models[i_bs],
                                        attack_name='BIM',
                                        eps_val=eps_val,
                                        testset=In_test,
                                        norm=np.inf)
        adv_inputs_pred = AP_models[i_bs].predict(adv_inputs)
        for j in range(adv_inputs_pred.shape[0]):
            tmp = mean_squared_error(Out_test[j, i_bs*num_beams:(i_bs+1)*num_beams],
                                     adv_inputs_pred[j, :])
            bim_mse_malicious.append(tmp)
            bim_all_mse.append(tmp)
    bim_eps_mse.append(np.mean(bim_mse_malicious))
MIM (Momentum Iterative Method)
mim_eps_mse = []
mim_all_mse = []
for eps_val in tqdm(epsilon_vals):
    mim_mse_malicious = []
    for i_bs in range(num_tot_TX):
        adv_inputs = util.attack_models(model=AP_models[i_bs],
                                        attack_name='MIM',
                                        eps_val=eps_val,
                                        testset=In_test,
                                        norm=np.inf)
        adv_inputs_pred = AP_models[i_bs].predict(adv_inputs)
        for j in range(adv_inputs_pred.shape[0]):
            tmp = mean_squared_error(Out_test[j, i_bs*num_beams:(i_bs+1)*num_beams],
                                     adv_inputs_pred[j, :])
            mim_mse_malicious.append(tmp)
            mim_all_mse.append(tmp)
    mim_eps_mse.append(np.mean(mim_mse_malicious))

df_avg_mse = pd.DataFrame({'epsilon_vals': epsilon_vals, 'fgm_eps_mse': fgm_eps_mse,
                           'pgd_eps_mse': pgd_eps_mse, 'bim_eps_mse': bim_eps_mse,
                           'mim_eps_mse': mim_eps_mse})
df_avg_mse.to_csv('df_avg_mse.csv', index=False)
plt.figure(figsize=(10, 3))
sns.lineplot(data=df_avg_mse, x='epsilon_vals', y='fgm_eps_mse', ls='--', lw=2)
sns.lineplot(data=df_avg_mse, x='epsilon_vals', y='pgd_eps_mse', ls='-.', lw=2)
sns.lineplot(data=df_avg_mse, x='epsilon_vals', y='bim_eps_mse', ls='-', lw=2)
sns.lineplot(data=df_avg_mse, x='epsilon_vals', y='mim_eps_mse', ls='-.', lw=2)
plt.ylabel('MSE', fontsize=20)
plt.xlabel(r'$\epsilon$ values', fontsize=20)
plt.xticks(fontsize=16)
plt.yticks(fontsize=16)
plt.grid(True)
plt.legend(['FGSM', 'PGD', 'BIM', 'MIM'], shadow=True, ncol=2, fontsize=16)
plt.show()
Mean Square Error Distributions
This section plots the MSE distribution histogram of each attack. There is a strong negative correlation between $\epsilon$ and the DL model’s prediction performance. The $p$-value of the correlation is reported as 0. The $p$-value is the probability of observing a correlation at least this strong if $\epsilon$ and the prediction performance were in fact uncorrelated.
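One way to obtain such a correlation coefficient and its $p$-value is `scipy.stats.pearsonr`. The sketch below uses synthetic numbers, not the results of this post: the relationship between $\epsilon$ and the MSE, the noise level, and the sample counts are all assumptions chosen for illustration.

```python
import numpy as np
from scipy.stats import pearsonr

rng = np.random.default_rng(0)
# Synthetic stand-in for the attack results: three epsilon budgets, 200 samples each.
eps = np.repeat([0.1, 0.2, 0.3], 200)
# Assumed toy relationship: the MSE grows with epsilon, plus some noise.
mse = 0.05 * eps + rng.normal(scale=0.002, size=eps.size)

# Correlation between epsilon and the error.
r_mse, p_value = pearsonr(eps, mse)
# Prediction performance falls as the MSE rises, so correlate epsilon with -MSE.
r_perf, _ = pearsonr(eps, -mse)
print(f"r(eps, MSE) = {r_mse:.3f}, r(eps, performance) = {r_perf:.3f}, p = {p_value:.3g}")
```

A $p$-value that prints as 0 usually means it is smaller than the displayed or floating-point precision, not that it is exactly zero.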
df_fgsm_res = pd.DataFrame({'epsilon': fgsm_all_eps,
                            'FGSM': fgsm_all_mse, 'PGD': pgd_all_mse,
                            'BIM': bim_all_mse, 'MIM': mim_all_mse})
df_fgsm_res.to_csv('df_fgsm_res.csv', index=False)
color_codes = ['r', 'g', 'b']
attack_types = ['FGSM', 'PGD', 'BIM', 'MIM']
fig, ax = plt.subplots(1, 4, figsize=(20, 5), sharey=True)
ax_idx = 0
for attack_type in attack_types:
    for tmp_eps_val, c_name in zip(epsilon_vals, color_codes):
        df_fgsm_res_tmp = df_fgsm_res.query('epsilon == ' + str(tmp_eps_val))
        sns.histplot(data=df_fgsm_res_tmp, x=attack_type, element="poly", cumulative=False,
                     lw=2, stat='count', log_scale=(False, False),
                     ax=ax[ax_idx], color=c_name)
    ax[ax_idx].set_xlabel(attack_type, fontsize=20)
    ax[ax_idx].set_ylabel(r'Count', fontsize=20)
    ax[ax_idx].legend(epsilon_vals, shadow=False, ncol=1, fontsize=14)
    ax[ax_idx].grid()
    ax_idx += 1
plt.show()
Defensive Distillation
AP_distilled_models = util.train_distill(In_train, Out_train, In_test, Out_test,
                                         nb_epoch, batch_size, loss_fn=loss_fn,
                                         n_BS=num_tot_TX, n_beams=num_beams,
                                         teacher_multiplication=3.0, student_multiplication=0.1,
                                         distill_alpha=0.1, distill_temp=100,
                                         sc_name=SCENARIO_NAME)
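The internals of `util.train_distill` are not shown in this post. As a hedged illustration of how a weight such as `distill_alpha` and a temperature such as `distill_temp` are commonly combined in knowledge distillation, the sketch below mixes a hard-label MSE term with a soft-label KL-divergence term. The weighting scheme, the function names, and the toy data are assumptions, not the actual implementation.

```python
import numpy as np

def softmax_t(z, temperature):
    """Softmax of z / temperature along the last axis."""
    z = np.asarray(z, dtype=float) / temperature
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def distillation_loss(student_out, teacher_out, target, alpha=0.1, temperature=100.0):
    """Assumed form: alpha * hard-label loss + (1 - alpha) * soft-label loss."""
    # Hard-label term: ordinary MSE against the ground truth.
    hard = np.mean((student_out - target) ** 2)
    # Soft-label term: KL divergence between softened teacher and student outputs.
    p_teacher = softmax_t(teacher_out, temperature)
    p_student = softmax_t(student_out, temperature)
    soft = np.mean(np.sum(p_teacher * (np.log(p_teacher) - np.log(p_student)), axis=-1))
    return alpha * hard + (1.0 - alpha) * soft

# Made-up outputs for 4 samples and 6 beams.
rng = np.random.default_rng(0)
target = rng.random((4, 6))
teacher_out = target + 0.01 * rng.normal(size=target.shape)
student_out = target + 0.10 * rng.normal(size=target.shape)

loss = distillation_loss(student_out, teacher_out, target)
print(loss)
```

With a small `alpha`, most of the training signal comes from matching the teacher's softened outputs and only a small part from the ground-truth labels.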
Attacks on the distilled models
We now run the same attacks against the distilled models. The plots show that the distilled models are more robust to the attacks.
epsilon_vals = [0.1, 0.2, 0.3]
fgm_eps_mse = []
fgsm_all_mse = []
fgsm_all_eps = []
for eps_val in tqdm(epsilon_vals):
    fgsm_mse_malicious = []
    for i_bs in range(num_tot_TX):
        adv_inputs = util.attack_models(model=AP_distilled_models[i_bs],
                                        attack_name='FGSM',
                                        eps_val=eps_val,
                                        testset=In_test,
                                        norm=np.inf)
        adv_inputs_pred = AP_distilled_models[i_bs].predict(adv_inputs)
        for j in range(adv_inputs_pred.shape[0]):
            tmp = mean_squared_error(Out_test[j, i_bs*num_beams:(i_bs+1)*num_beams],
                                     adv_inputs_pred[j, :])
            fgsm_mse_malicious.append(tmp)
            fgsm_all_mse.append(tmp)
            fgsm_all_eps.append(eps_val)
    fgm_eps_mse.append(np.mean(fgsm_mse_malicious))

pgd_eps_mse = []
pgd_all_mse = []
for eps_val in tqdm(epsilon_vals):
    pgd_mse_malicious = []
    for i_bs in range(num_tot_TX):
        adv_inputs = util.attack_models(model=AP_distilled_models[i_bs],
                                        attack_name='PGD',
                                        eps_val=eps_val,
                                        testset=In_test,
                                        norm=np.inf)
        adv_inputs_pred = AP_distilled_models[i_bs].predict(adv_inputs)
        for j in range(adv_inputs_pred.shape[0]):
            tmp = mean_squared_error(Out_test[j, i_bs*num_beams:(i_bs+1)*num_beams],
                                     adv_inputs_pred[j, :])
            pgd_mse_malicious.append(tmp)
            pgd_all_mse.append(tmp)
    pgd_eps_mse.append(np.mean(pgd_mse_malicious))

bim_eps_mse = []
bim_all_mse = []
for eps_val in tqdm(epsilon_vals):
    bim_mse_malicious = []
    for i_bs in range(num_tot_TX):
        adv_inputs = util.attack_models(model=AP_distilled_models[i_bs],
                                        attack_name='BIM',
                                        eps_val=eps_val,
                                        testset=In_test,
                                        norm=np.inf)
        adv_inputs_pred = AP_distilled_models[i_bs].predict(adv_inputs)
        for j in range(adv_inputs_pred.shape[0]):
            tmp = mean_squared_error(Out_test[j, i_bs*num_beams:(i_bs+1)*num_beams],
                                     adv_inputs_pred[j, :])
            bim_mse_malicious.append(tmp)
            bim_all_mse.append(tmp)
    bim_eps_mse.append(np.mean(bim_mse_malicious))

mim_eps_mse = []
mim_all_mse = []
for eps_val in tqdm(epsilon_vals):
    mim_mse_malicious = []
    for i_bs in range(num_tot_TX):
        adv_inputs = util.attack_models(model=AP_distilled_models[i_bs],
                                        attack_name='MIM',
                                        eps_val=eps_val,
                                        testset=In_test,
                                        norm=np.inf)
        adv_inputs_pred = AP_distilled_models[i_bs].predict(adv_inputs)
        for j in range(adv_inputs_pred.shape[0]):
            tmp = mean_squared_error(Out_test[j, i_bs*num_beams:(i_bs+1)*num_beams],
                                     adv_inputs_pred[j, :])
            mim_mse_malicious.append(tmp)
            mim_all_mse.append(tmp)
    mim_eps_mse.append(np.mean(mim_mse_malicious))

df_avg_mse = pd.DataFrame({'epsilon_vals': epsilon_vals, 'fgm_eps_mse': fgm_eps_mse,
                           'pgd_eps_mse': pgd_eps_mse, 'bim_eps_mse': bim_eps_mse,
                           'mim_eps_mse': mim_eps_mse})
df_avg_mse.to_csv('df_avg_mse.csv', index=False)
plt.figure(figsize=(10, 3))
sns.lineplot(data=df_avg_mse, x='epsilon_vals', y='fgm_eps_mse', ls='--', lw=2)
sns.lineplot(data=df_avg_mse, x='epsilon_vals', y='pgd_eps_mse', ls='-.', lw=2)
sns.lineplot(data=df_avg_mse, x='epsilon_vals', y='bim_eps_mse', ls='-', lw=2)
sns.lineplot(data=df_avg_mse, x='epsilon_vals', y='mim_eps_mse', ls='-.', lw=2)
plt.ylabel('MSE', fontsize=20)
plt.xlabel(r'$\epsilon$ values', fontsize=20)
plt.xticks(fontsize=16)
plt.yticks(fontsize=16)
plt.grid(True)
plt.legend(['FGSM', 'PGD', 'BIM', 'MIM'], shadow=True, ncol=2, fontsize=16)
plt.show()
df_fgsm_res = pd.DataFrame({'epsilon': fgsm_all_eps,
                            'FGSM': fgsm_all_mse, 'PGD': pgd_all_mse,
                            'BIM': bim_all_mse, 'MIM': mim_all_mse})
df_fgsm_res.to_csv('df_fgsm_res.csv', index=False)
color_codes = ['r', 'g', 'b']
attack_types = ['FGSM', 'PGD', 'BIM', 'MIM']
fig, ax = plt.subplots(1, 4, figsize=(20, 5), sharey=True)
ax_idx = 0
for attack_type in attack_types:
    for tmp_eps_val, c_name in zip(epsilon_vals, color_codes):
        df_fgsm_res_tmp = df_fgsm_res.query('epsilon == ' + str(tmp_eps_val))
        sns.histplot(data=df_fgsm_res_tmp, x=attack_type, element="poly", cumulative=False,
                     lw=2, stat='count', log_scale=(False, False),
                     ax=ax[ax_idx], color=c_name)
    ax[ax_idx].set_xlabel(attack_type, fontsize=20)
    ax[ax_idx].set_ylabel(r'Count', fontsize=20)
    ax[ax_idx].legend(epsilon_vals, shadow=False, ncol=1, fontsize=14)
    ax[ax_idx].grid()
    ax_idx += 1
plt.show()
4. Conclusion
This blog post presents the vulnerabilities of RF beamforming prediction models to adversarial attacks, together with techniques for mitigating them. The experiments were performed with selected DeepMIMO scenarios. The results confirm that the original DL-based beamforming model is significantly vulnerable to FGSM, PGD, BIM, and MIM attacks. The MSE value increases in all three scenarios under a heavy BIM adversarial attack. On the other hand, the results show that the proposed mitigation methods successfully improve the RF beamforming prediction performance under these attacks.