K-Fold Cross Validation In Machine Learning

Why K-Fold

Splitting data into training and test sets involves a familiar trade-off in machine learning. Typically, you reserve one fraction of the data for training and another for testing. The problem is that you want both sets to be as large as possible: more training data generally means a better-fitted model, and a larger test set gives a more reliable estimate of its performance. But every data point you move into the test set is a data point the model can no longer learn from.
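
To make the trade-off concrete, here is a minimal single-split sketch using scikit-learn's train_test_split on 200 toy data points (the arrays are placeholders, not real data):

from sklearn.model_selection import train_test_split
import numpy as np

X = np.arange(200).reshape(-1, 1)   # 200 toy samples, one feature each
y = np.arange(200)

# Hold out 30% for testing: those 60 points are no longer available for training
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
print(len(X_train), len(X_test))    # -> 140 60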

This trade-off can be addressed with cross-validation.

K-fold cross-validation is a method where you divide the dataset into K equal-sized subsets, or "bins." For example, if you have 200 data points and you choose K = 10, each bin will contain 20 data points. Unlike traditional data splitting, where you have a single test set and a single training set, K-fold cross-validation involves running K separate learning experiments.
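
As a quick illustration of the binning with those numbers, plain NumPy is enough (a toy sketch, not part of any real pipeline):

import numpy as np

data = np.arange(200)             # 200 data points
bins = np.array_split(data, 10)   # K = 10 equal-sized bins
print(len(bins), len(bins[0]))    # -> 10 bins of 20 points each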

In each experiment, you select one of the K bins as the test set and use the remaining K-1 bins as the training set. You then train your machine learning model on the training set and evaluate its performance on the test set. This process is repeated K times, each time with a different bin as the test set.
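
Here is a minimal sketch of that loop with scikit-learn's KFold, again on 200 placeholder points and K = 10; a real experiment would train and score a model inside the loop:

from sklearn.model_selection import KFold
import numpy as np

X = np.arange(200).reshape(-1, 1)   # 200 data points, one feature each
y = np.arange(200)

kf = KFold(n_splits=10)
for fold, (train_idx, test_idx) in enumerate(kf.split(X), start=1):
    # One bin (20 points) held out for testing, the other nine (180 points) used for training
    print(f'Fold {fold}: train={len(train_idx)}, test={len(test_idx)}')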

The key advantage of K-fold cross-validation is that it allows you to use all of your data for both training and testing. By averaging the performance across all K experiments, you obtain a more reliable estimate of your model's effectiveness. This approach requires more computational resources since you have to run multiple experiments, but it provides a more accurate and robust assessment of your model. In essence, K-fold cross-validation helps balance the trade-off between training and testing data, giving you the best of both worlds.
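
For scikit-learn estimators, cross_val_score runs this whole loop for you and returns one score per fold, which you then average; the classifier and synthetic data below are placeholders chosen only to keep the sketch self-contained:

from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score

X, y = make_classification(n_samples=200, n_features=10, random_state=42)

# One accuracy score per fold; the mean is the cross-validated estimate
scores = cross_val_score(LogisticRegression(max_iter=1000), X, y, cv=10)
print(f'Mean accuracy: {scores.mean():.3f} (+/- {scores.std():.3f})')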

The complete example below ties this together: a small 3D segmentation model in Keras, trained and evaluated with 5-fold cross-validation on dummy data. Thanks for reading!

import tensorflow as tf
from tensorflow.keras.layers import Conv3D, UpSampling3D, concatenate, Input
from tensorflow.keras.models import Model
from sklearn.model_selection import KFold
import numpy as np


# Dummy data: 100 random 128x128x128 single-channel volumes with binary voxel labels
X = np.random.rand(100, 128, 128, 128, 1)
y = np.random.randint(0, 2, size=(100, 128, 128, 128, 1))


def build_simple_3d_unet(input_shape=(128, 128, 128, 1)):
    inputs = Input(shape=input_shape)

    # Encoder: two 3x3x3 convolutions (the second dilated), followed by downsampling
    conv1 = Conv3D(64, (3, 3, 3), activation='relu', padding='same')(inputs)
    conv1 = Conv3D(64, (3, 3, 3), activation='relu', padding='same', dilation_rate=2)(conv1)
    pool1 = tf.keras.layers.MaxPooling3D(pool_size=(2, 2, 2))(conv1)

    # Bottleneck
    bottleneck = Conv3D(512, (3, 3, 3), activation='relu', padding='same')(pool1)

    # Decoder - Using UpSampling3D for the segmentation task
    up1 = UpSampling3D(size=(2, 2, 2))(bottleneck)
    up1 = Conv3D(64, (3, 3, 3), activation='relu', padding='same')(up1)
    concat1 = concatenate([up1, conv1], axis=-1)

    outputs = Conv3D(1, (1, 1, 1), activation='sigmoid')(concat1)

    model = Model(inputs, outputs)
    return model

# 5 folds: each experiment trains on 80 volumes and tests on the remaining 20
kf = KFold(n_splits=5)
fold_accuracies = []

fold_no = 1
for train_index, test_index in kf.split(X):
    print(f'Training on fold {fold_no}...')

    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]


    # Build a fresh model for every fold so no weights leak between experiments
    model = build_simple_3d_unet(input_shape=(128, 128, 128, 1))


    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

    model.fit(X_train, y_train, epochs=10, batch_size=2, validation_data=(X_test, y_test))


    # Evaluate on the held-out fold; scores = [loss, accuracy]
    scores = model.evaluate(X_test, y_test, verbose=0)
    fold_accuracies.append(scores[1])
    print(f'Score for fold {fold_no}: accuracy of {scores[1]*100:.2f}%')

    fold_no += 1

# Average across folds for the overall cross-validated estimate
print(f'Mean accuracy across folds: {np.mean(fold_accuracies)*100:.2f}%')
