Jeff Liu

10-Armed Bandit Experiment
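
This post walks through the classic 10-armed bandit testbed: a bandit environment with ten arms, an ε-greedy agent, a single 1000-step run to watch the Q-value estimates evolve, and a 2000-run average comparing ε = 0, 0.01, and 0.1.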

Bandit Testbed

import numpy as np
import matplotlib.pyplot as plt

class Bandit:
    def __init__(self, num_arms=10):
        # True action values q*(a) sampled from N(0,1)
        self.q_star = np.random.normal(0, 1, num_arms)

    def get_reward(self, action):
        # Reward sampled from N(q*(a),1) given an action
        return np.random.normal(self.q_star[action], 1)

    def optimal_action(self):
        return np.argmax(self.q_star)
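
Before wiring in an agent, the testbed can be poked directly. This is a minimal sketch of how Bandit behaves; the seed and the demo_bandit/best names are illustrative additions, not part of the original code:

np.random.seed(0)  # illustrative seed, only so the printout is repeatable
demo_bandit = Bandit()
best = demo_bandit.optimal_action()
print("Best arm:", best, "with q* =", demo_bandit.q_star[best])
print("One sampled reward from that arm:", demo_bandit.get_reward(best))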

Agent (Reinforcement Learning Strategy)

class Agent:
    def __init__(self, num_arms=10, epsilon=0.1):
        self.epsilon = epsilon
        self.q_estimates = np.zeros(num_arms)  # Initialize Q-values to 0
        self.action_counts = np.zeros(num_arms)  # Track action selection counts

    def select_action(self):
        if np.random.rand() < self.epsilon:
            return np.random.randint(len(self.q_estimates))  # Random action (exploration)
        else:
            return np.argmax(self.q_estimates)  # Greedy action (exploitation)

    def update(self, action, reward):
        self.action_counts[action] += 1
        alpha = 1 / self.action_counts[action]  # Incremental sample averaging
        self.q_estimates[action] += alpha * (reward - self.q_estimates[action])
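
The update method is the incremental form of a plain sample average: with step size 1/n, each Q estimate converges to the mean of the rewards observed for that action. A tiny check (the rewards_seen values below are made up purely for illustration):

rewards_seen = [1.0, 0.5, 2.0]
q = 0.0
for n, r in enumerate(rewards_seen, start=1):
    q += (1 / n) * (r - q)          # same rule as Agent.update
print(q, np.mean(rewards_seen))     # both print ~1.1667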

Running a Single Test and Observing Q-Value Updates

num_steps = 1000
bandit = Bandit()
agent = Agent(epsilon=0.1)

rewards = []
optimal_action_counts = []

for step in range(num_steps):
    action = agent.select_action()
    reward = bandit.get_reward(action)
    agent.update(action, reward)

    rewards.append(reward)
    optimal_action_counts.append(action == bandit.optimal_action())  # tracked here but not plotted; the multi-run experiment below reports % optimal action

print("Final Q estimates:", agent.q_estimates)
print("True q* values:", bandit.q_star)

plt.plot(rewards)
plt.xlabel("Steps")
plt.ylabel("Reward")
plt.title("Reward over time")
plt.show()

[Figure: reward per step over a single 1000-step run]
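
The per-step reward is noisy because each reward is a single sample from N(q*(a), 1). If the trend is hard to read, a moving average can be plotted instead; the window size of 50 below is an arbitrary choice, not something from the original run:

window = 50
smoothed = np.convolve(rewards, np.ones(window) / window, mode="valid")
plt.plot(smoothed)
plt.xlabel("Steps")
plt.ylabel(f"Reward ({window}-step moving average)")
plt.title("Smoothed reward over time")
plt.show()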

Running 2000 Experiments and Calculating Average Reward

num_experiments = 2000
epsilons = [0, 0.01, 0.1]
num_arms = 10

avg_rewards = {eps: np.zeros(num_steps) for eps in epsilons}
optimal_action_pct = {eps: np.zeros(num_steps) for eps in epsilons}

for experiment in range(num_experiments):
    bandit = Bandit()

    for eps in epsilons:
        agent = Agent(num_arms=num_arms, epsilon=eps)
        optimal_action = bandit.optimal_action()

        for step in range(num_steps):
            action = agent.select_action()
            reward = bandit.get_reward(action)
            agent.update(action, reward)

            avg_rewards[eps][step] += reward
            optimal_action_pct[eps][step] += (action == optimal_action)

# Compute final averages
for eps in epsilons:
    avg_rewards[eps] /= num_experiments
    optimal_action_pct[eps] = (optimal_action_pct[eps] / num_experiments) * 100

print("Finished 2000 experiments!")

plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
for eps in epsilons:
    plt.plot(avg_rewards[eps], label=f'ε={eps}')
plt.xlabel("Steps")
plt.ylabel("Average Reward")
plt.legend()
plt.title("Average Reward vs Steps")

plt.subplot(1, 2, 2)
for eps in epsilons:
    plt.plot(optimal_action_pct[eps], label=f'ε={eps}')
plt.xlabel("Steps")
plt.ylabel("% Optimal Action")
plt.legend()
plt.title("Optimal Action Selection vs Steps")

plt.show()

[Figure: average reward (left) and % optimal action (right) over 1000 steps, averaged over 2000 runs, for ε = 0, 0.01, 0.1]
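
A typical pattern in this testbed: ε = 0.1 finds good arms fastest, ε = 0.01 improves more slowly but can eventually edge slightly ahead, and the purely greedy agent (ε = 0) often locks onto a suboptimal arm early and plateaus.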

Summary

  • Bandit Testbed: a 10-armed bandit whose true action values q*(a) are drawn from N(0, 1) and whose rewards are drawn from N(q*(a), 1).
  • Agent (RL Strategy): ε-greedy action selection with incremental sample-average Q-value updates.
  • Single Test Run: one 1000-step run showing how the Q estimates and per-step rewards evolve.
  • Multiple Experiments (2000 runs): average reward and % optimal action compared across ε = 0, 0.01, and 0.1.
