Vaibhav Mali's other Models Reports


ECG Arrhythmia Classification Using Random Forest


Model Overview

An arrhythmia is an abnormal heart rhythm. Some arrhythmias can cause problems with the contraction of your heart chambers, for example by not allowing the lower chambers (ventricles) to fill with enough blood because an abnormal electrical signal is causing your heart to pump too fast or too slow.

The dataset contains features extracted from two-lead ECG signals (lead II and V5) from the MIT-BIH Arrhythmia Database (PhysioNet). We have programmatically extracted relevant features from the ECG signals to classify heartbeats as regular or irregular, so the dataset can be used for arrhythmia detection.

There are four ECG arrhythmia datasets here, each using 2-lead ECG features. All four were obtained from PhysioNet: the MIT-BIH Supraventricular Arrhythmia Database, the MIT-BIH Arrhythmia Database, the St Petersburg INCART 12-lead Arrhythmia Database, and the Sudden Cardiac Death Holter Database.

In each dataset, the first column, named "record", is the name of the subject/patient.

Each dataset contains five classes: N (Normal), S (Supraventricular ectopic beat), V (Ventricular ectopic beat), F (Fusion beat), and Q (Unknown beat). The "type" column holds the class label; the label-conversion code refers to the S and V classes as SVEB and VEB.

The remaining 34 columns contain 17 features for each ECG lead (17 features for lead II and 17 features for lead V5).
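The column layout described above (a `record` column, a `type` column and 34 feature columns) can be checked before any training. The following is a minimal sketch, assuming the label codes N, SVEB, VEB, F and Q used by the label-conversion code; the helper `check_layout` and the placeholder column names `f0` to `f33` are hypothetical and are not the real feature names.

```python
import pandas as pd

# Assumed label codes, matching the codes used in the label-conversion step
EXPECTED_LABELS = {"N", "SVEB", "VEB", "F", "Q"}
N_FEATURES_PER_LEAD = 17
N_LEADS = 2


def check_layout(df: pd.DataFrame) -> list:
    """Return a list of problems found in an ECG feature table (empty if none)."""
    problems = []
    if df.columns[0] != "record":
        problems.append("first column is not 'record'")
    if "type" not in df.columns:
        problems.append("missing 'type' column")
    n_features = df.shape[1] - 2  # everything except 'record' and 'type'
    if n_features != N_FEATURES_PER_LEAD * N_LEADS:
        problems.append(f"expected 34 feature columns, found {n_features}")
    if "type" in df.columns:
        unknown = set(df["type"].unique()) - EXPECTED_LABELS
        if unknown:
            problems.append(f"unexpected labels: {sorted(unknown)}")
    return problems


# Tiny synthetic table with placeholder feature names (hypothetical, not the real ones)
features = {f"f{i}": [0.1, 0.2] for i in range(34)}
demo = pd.DataFrame({"record": [101, 101], "type": ["N", "VEB"], **features})
print(check_layout(demo))  # []
```

On the real data the same call would be `check_layout(data_df)`.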

Link Of Dataset:

Importing The Necessary Libraries

import pickle

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib
import matplotlib.pyplot as plt

plt.rcParams.update({'figure.figsize': (12.0, 8.0)})
plt.rcParams.update({'font.size': 14})

# Default theme
sns.set(context='notebook', style='darkgrid', palette='colorblind', font='sans-serif', font_scale=1, rc=None)
matplotlib.rcParams['figure.figsize'] = [8, 8]
matplotlib.rcParams.update({'font.size': 15})
matplotlib.rcParams['font.family'] = 'sans-serif'

Read the dataset and check its first 5 rows

# Reading the MIT-BIH Arrhythmia Dataset as an example
data_df = pd.read_csv('MIT-BIH Arrhythmia Database.csv', low_memory=False)

print("Number of rows in data =", data_df.shape[0])
print("Number of columns in data =", data_df.shape[1])
print("**Sample data:**")
print(data_df.head())

Description Of The Dataset
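A common way to describe a table like this with pandas is `DataFrame.describe()` for the numeric features plus a count of beats per class. The sketch below runs on a small synthetic stand-in; the columns `feat_a` and `feat_b` are hypothetical placeholders. On the real data the equivalent calls would be `data_df.drop(columns=['record']).describe()` and `data_df['type'].value_counts()`.

```python
import numpy as np
import pandas as pd

# Synthetic stand-in for data_df; 'feat_a' and 'feat_b' are placeholder feature columns
rng = np.random.default_rng(0)
demo_df = pd.DataFrame({
    "record": [101] * 5 + [102] * 5,
    "type": ["N", "N", "VEB", "N", "SVEB", "N", "N", "N", "F", "Q"],
    "feat_a": rng.normal(size=10),
    "feat_b": rng.uniform(size=10),
})

# Summary statistics of the numeric feature columns ('record' is an ID, so it is dropped)
summary = demo_df.drop(columns=["record"]).describe()
print(summary)

# Number of beats per class, which shows how imbalanced the labels are
class_counts = demo_df["type"].value_counts()
print(class_counts)
```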


Checking for Null values
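Missing values are usually counted per column with `isnull().sum()`. A minimal sketch on a synthetic table with one injected missing value (the feature column names are hypothetical); on the real data this would be `data_df.isnull().sum()`.

```python
import numpy as np
import pandas as pd

# Synthetic table with one missing value injected into 'feat_b'
demo_df = pd.DataFrame({
    "record": [101, 101, 102],
    "type": ["N", "VEB", "N"],
    "feat_a": [0.5, 0.7, 0.6],
    "feat_b": [1.2, np.nan, 0.9],
})

# Number of missing values per column
null_counts = demo_df.isnull().sum()
print(null_counts)

# Total number of missing cells, and the rows containing at least one of them
total_nulls = int(null_counts.sum())
rows_with_nulls = demo_df[demo_df.isnull().any(axis=1)]
print("Total missing values:", total_nulls)
print(rows_with_nulls)
```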


Split the data into features and class labels

x_data = data_df.drop(['type', 'record'], axis=1)
y_label = data_df['type'].copy()


Converting The Multi-Class Target Into Binary Classes

# Transform multi-class labels into binary classes (arrhythmia and normal)
y_label = y_label.replace(['VEB', 'SVEB', 'F', 'Q'], 'arrhythmia')
y_label = y_label.replace(['N'], 'normal')

# Encode the binary labels as integers: normal = 0, arrhythmia = 1
y_label = y_label.map({'normal': 0, 'arrhythmia': 1})
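An equivalent one-step alternative to the conversion above is a single dictionary lookup, with a guard so that an unexpected label code fails loudly instead of silently becoming NaN. This is a sketch under the assumption that the only codes present are N, SVEB, VEB, F and Q; `BINARY_MAP` and `to_binary` are hypothetical names.

```python
import pandas as pd

# Assumed label codes; any code not listed here is treated as an error
BINARY_MAP = {"N": 0, "SVEB": 1, "VEB": 1, "F": 1, "Q": 1}


def to_binary(labels: pd.Series) -> pd.Series:
    """Map beat labels to 0 (normal) / 1 (arrhythmia), failing on unknown codes."""
    encoded = labels.map(BINARY_MAP)
    if encoded.isna().any():
        unknown = sorted(set(labels[encoded.isna()]))
        raise ValueError(f"Unmapped labels: {unknown}")
    return encoded.astype(int)


labels = pd.Series(["N", "SVEB", "N", "VEB", "F", "N", "Q", "N"])
y = to_binary(labels)
print(y.tolist())  # [0, 1, 0, 1, 1, 0, 1, 0]
# Class balance after binarization
print(y.value_counts().to_dict())
```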

Train-test Split

from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(x_data, y_label, random_state=101)
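By default `train_test_split` holds out 25% of the rows for testing (`test_size=0.25`). Because arrhythmic beats may be a minority class, passing `stratify=` keeps the class ratio roughly equal in both parts. The sketch below shows this on synthetic, deliberately imbalanced labels; whether stratification is the right choice for this dataset is an assumption, not something the original analysis does. On the real data the call would be `train_test_split(x_data, y_label, random_state=101, stratify=y_label)`.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Synthetic, imbalanced labels: 90 normal (0) and 10 arrhythmia (1) samples
X = np.arange(200).reshape(100, 2)
y = np.array([0] * 90 + [1] * 10)

# Same call as above plus stratify=y; test_size defaults to 0.25
X_tr, X_te, y_tr, y_te = train_test_split(X, y, random_state=101, stratify=y)

print("Train size:", len(y_tr), "arrhythmia share:", y_tr.mean())
print("Test size:", len(y_te), "arrhythmia share:", y_te.mean())
```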

Feature Scaling

from sklearn.preprocessing import MinMaxScaler

min_max_scaler = MinMaxScaler()
X_train = min_max_scaler.fit_transform(X_train)
X_test = min_max_scaler.transform(X_test)


# Save the fitted scaler so the same scaling can be applied to new data later
with open('min_max_scaler.pkl', 'wb') as f:
    pickle.dump(min_max_scaler, f)
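The saved `min_max_scaler.pkl` is presumably meant for scaling new ECG beats at prediction time with the minimum and maximum learned from the training set, for example via `with open('min_max_scaler.pkl', 'rb') as f: min_max_scaler = pickle.load(f)`. The sketch below (toy numbers, pickled in memory rather than on disk) shows that the restored scaler applies $x' = (x - \min) / (\max - \min)$ per column, and that values outside the training range land outside $[0, 1]$.

```python
import pickle

import numpy as np
from sklearn.preprocessing import MinMaxScaler

# Toy training data (2 features) and one new sample to scale
X_train_demo = np.array([[0.0, 10.0], [5.0, 20.0], [10.0, 30.0]])
X_new = np.array([[2.5, 40.0]])

scaler = MinMaxScaler().fit(X_train_demo)

# With the default feature_range=(0, 1), each column is scaled as
# (x - min) / (max - min), where min and max come from the training data only
col_min = X_train_demo.min(axis=0)
col_max = X_train_demo.max(axis=0)
manual = (X_new - col_min) / (col_max - col_min)

# Round-trip through pickle (in memory here) and reuse the restored scaler
restored = pickle.loads(pickle.dumps(scaler))
scaled_new = restored.transform(X_new)
print(scaled_new)  # values 0.25 and 1.5; 1.5 lies outside [0, 1]
```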


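Model training

from sklearn.ensemble import RandomForestClassifier

# Train a Random Forest with 150 trees on the scaled training features
model = RandomForestClassifier(random_state=101, n_estimators=150)
model.fit(X_train, y_train.values.ravel())

Feature Importance

importances = model.feature_importances_
# Sort the feature importances in descending order
sorted_indices = np.argsort(importances)[::-1]

# Feature names in the same order as the columns of X_train
feat_labels = x_data.columns

for f in range(X_train.shape[1]):
    print("%2d) %-*s %f" % (f + 1, 30,
                            feat_labels[sorted_indices[f]],
                            importances[sorted_indices[f]]))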
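The importance ranking can be tried on synthetic data where the informative features are known in advance. In this sketch, `make_classification` with `shuffle=False` places the 2 informative features in the first two columns, so they would be expected (though not guaranteed) to rank near the top. The feature names are hypothetical. Note that `feature_importances_` in scikit-learn is impurity-based and normalized to sum to 1.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

# Synthetic binary problem: 6 features, only the first 2 informative
X, y = make_classification(n_samples=500, n_features=6, n_informative=2,
                           n_redundant=0, shuffle=False, random_state=0)
names = [f"feat_{i}" for i in range(X.shape[1])]  # hypothetical feature names

model = RandomForestClassifier(n_estimators=150, random_state=101).fit(X, y)

importances = model.feature_importances_
order = np.argsort(importances)[::-1]  # indices from most to least important
for rank, idx in enumerate(order, start=1):
    print(f"{rank:2d}) {names[idx]:<30} {importances[idx]:.4f}")
```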

Model testing
a) Accuracy of the model

from sklearn import metrics
y_pred = model.predict(X_test)
print("Accuracy:",metrics.accuracy_score(y_test, y_pred))
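Accuracy alone can look high when one class dominates. A minimal illustration on hand-made labels (not results from this dataset): a classifier that always predicts "normal" scores 0.8 accuracy on an 8:2 split, while `balanced_accuracy_score`, the mean of the per-class recalls, drops to 0.5.

```python
from sklearn import metrics

# Toy imbalanced test set: 8 normal (0) and 2 arrhythmia (1) beats
y_true = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1]
# A classifier that always predicts "normal"
y_always_normal = [0] * 10

acc = metrics.accuracy_score(y_true, y_always_normal)
bal_acc = metrics.balanced_accuracy_score(y_true, y_always_normal)
print("Accuracy:", acc)                # 0.8
print("Balanced accuracy:", bal_acc)   # 0.5
```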

b) Confusion Matrix

print("*** Confusion Matrix ***")
print(metrics.confusion_matrix(y_test, y_pred))
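For binary labels encoded as 0 (normal) and 1 (arrhythmia), scikit-learn's confusion matrix has the true classes as rows and the predicted classes as columns, so it can be unpacked into TN, FP, FN and TP. A sketch on hand-made labels showing how sensitivity and specificity follow from those four counts:

```python
from sklearn import metrics

# Hand-made labels: 0 = normal, 1 = arrhythmia
y_true = [0, 0, 0, 0, 1, 1, 1, 0, 1, 0]
y_pred = [0, 0, 1, 0, 1, 1, 0, 0, 1, 0]

# Layout for labels [0, 1]:
# [[TN, FP],
#  [FN, TP]]  (rows = true class, columns = predicted class)
tn, fp, fn, tp = metrics.confusion_matrix(y_true, y_pred).ravel()

sensitivity = tp / (tp + fn)  # share of arrhythmic beats detected (recall of class 1)
specificity = tn / (tn + fp)  # share of normal beats correctly kept as normal
print(f"TN={tn} FP={fp} FN={fn} TP={tp}")  # TN=5 FP=1 FN=1 TP=3
print(f"Sensitivity={sensitivity:.2f} Specificity={specificity:.2f}")
```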

c) Classification Report

print("*** Classification Report ***")
print(metrics.classification_report(y_test, y_pred))
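To use the report's numbers programmatically (for instance, to track arrhythmia recall across runs), `classification_report` accepts `output_dict=True` and returns the same figures as a nested dictionary. A sketch on hand-made labels; the `target_names` order assumes the 0 = normal, 1 = arrhythmia encoding.

```python
from sklearn import metrics

# Hand-made labels: 0 = normal, 1 = arrhythmia
y_true = [0, 0, 0, 0, 1, 1, 1, 0, 1, 0]
y_pred = [0, 0, 1, 0, 1, 1, 0, 0, 1, 0]

# output_dict=True returns the figures of the printed report as a nested dict
report = metrics.classification_report(
    y_true, y_pred, target_names=["normal", "arrhythmia"], output_dict=True
)
print("Arrhythmia recall:", report["arrhythmia"]["recall"])  # 0.75
print("Macro F1:", report["macro avg"]["f1-score"])
```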