<?xml version="1.0" encoding="UTF-8"?>
<rss version="2.0" xmlns:atom="http://www.w3.org/2005/Atom" xmlns:dc="http://purl.org/dc/elements/1.1/">
  <channel>
    <title>DEV Community: Akshay Ballal</title>
    <description>The latest articles on DEV Community by Akshay Ballal (@akshayballal).</description>
    <link>https://dev.to/akshayballal</link>
    <image>
      <url>https://media2.dev.to/dynamic/image/width=90,height=90,fit=cover,gravity=auto,format=auto/https:%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Fuser%2Fprofile_image%2F1077423%2F2d712653-ad84-4734-9f44-35e0cb51d2d0.png</url>
      <title>DEV Community: Akshay Ballal</title>
      <link>https://dev.to/akshayballal</link>
    </image>
    <atom:link rel="self" type="application/rss+xml" href="https://dev.to/feed/akshayballal"/>
    <language>en</language>
    <item>
      <title>Utility is all you need</title>
      <dc:creator>Akshay Ballal</dc:creator>
      <pubDate>Tue, 07 Apr 2026 13:54:49 +0000</pubDate>
      <link>https://dev.to/akshayballal/utility-is-all-you-need-823</link>
      <guid>https://dev.to/akshayballal/utility-is-all-you-need-823</guid>
      <description>&lt;h2&gt;
  
  
  Closing the Agent Learning Loop with Utility-Ranked Memory
&lt;/h2&gt;

&lt;p&gt;Your agent failed on the same edge case last Tuesday. Someone noticed, tweaked the prompt, redeployed. This Tuesday it failed again: different input, same root cause. The prompt fix was too narrow. The trace was right there in your observability stack, but nothing connected that signal back to the agent's actual behavior.&lt;/p&gt;

&lt;p&gt;This is where most production agent systems sit today. Teams have traces. They have evals. What they don't have is a mechanism that turns failure signals into better retrieval on the next run.&lt;/p&gt;

&lt;p&gt;We built &lt;a href="https://reflect.starlight-search.com" rel="noopener noreferrer"&gt;Reflect&lt;/a&gt; to close that gap. It's an outcome-informed memory layer: it stores reflections from past runs, scores them by how useful they actually were, and re-ranks retrieval so your agent improves with every reviewed trace, not just every prompt rewrite.&lt;/p&gt;




&lt;h2&gt;
  
  
  The gap between evals and action
&lt;/h2&gt;

&lt;p&gt;Production agent stacks generally have three pieces: observability (traces), evaluation (pass/fail judgments), and the agent itself. The problem is these layers don't talk to each other.&lt;/p&gt;

&lt;p&gt;Your observability stack captures every tool call, every LLM completion, every exception. Your eval suite judges whether the final output was correct. But the agent that runs tomorrow? It starts from a blank slate. It has no access to &lt;em&gt;why&lt;/em&gt; yesterday's runs passed or failed. The eval signal just dies in a dashboard.&lt;/p&gt;

&lt;p&gt;What's missing is a system that consumes traces, absorbs eval outcomes, and converts both into retrievable guidance for future runs.&lt;/p&gt;

&lt;p&gt;Reflect sits between your evals and your agent. It treats traces not as passive audit logs but as training signal. When a review marks a trace as failed, Reflect doesn't just file it. It extracts a reflection, a compressed lesson about what went wrong, and stores it as a memory tied to that task type. When a similar task shows up, that reflection surfaces. The agent acts differently because it now has context about what not to do.&lt;/p&gt;

&lt;p&gt;The key idea is that memories become &lt;em&gt;outcome-addressable&lt;/em&gt;. You don't retrieve by keyword. You retrieve by semantic similarity to the current task, weighted by whether those memories have historically helped or hurt. The eval outcome becomes a first-class signal in retrieval ranking, not just a post-mortem footnote.&lt;/p&gt;




&lt;h2&gt;
  
  
  The manual improvement tax
&lt;/h2&gt;

&lt;p&gt;After talking to dozens of enterprise agent teams, we keep seeing the same pattern. They can tell when something went wrong. They can measure pass/fail rates. But improvement still happens outside the data loop:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;A developer reviews a failing trace, rewrites part of the prompt, redeploys.&lt;/li&gt;
&lt;li&gt;The team swaps to a more capable (and more expensive) model.&lt;/li&gt;
&lt;li&gt;Someone restructures the tool-call harness.&lt;/li&gt;
&lt;li&gt;In extreme cases, the team fine-tunes a custom model.&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;These are all valid moves. But they share a weakness: they're coarse, slow, and global. A prompt change shifts behavior for every future request, whether or not the change was relevant. A model upgrade costs more across the board, even for tasks the cheaper model handled fine.&lt;/p&gt;

&lt;p&gt;None of these approaches directly answer the question every agent team actually cares about: &lt;em&gt;what from past experience helped this task succeed, and what made it fail?&lt;/em&gt;&lt;/p&gt;

&lt;p&gt;Automated prompt optimization can reduce costs dramatically. &lt;a href="https://www.databricks.com/blog/building-state-art-enterprise-agents-90x-cheaper-automated-prompt-optimization" rel="noopener noreferrer"&gt;Databricks showed up to 90x savings&lt;/a&gt; while matching or exceeding hand-tuned performance. But even automated optimization treats the prompt as the unit of improvement. We think the unit should be smaller: the individual memory retrieved at inference time.&lt;/p&gt;




&lt;h2&gt;
  
  
  Why current agent memory systems stall
&lt;/h2&gt;

&lt;p&gt;Memory was supposed to solve this. In practice, most agent memory frameworks have focused on user continuity: preferences, profile facts, conversation history, personalization. That's useful for chat experiences, but it doesn't create a self-improving system for production agents.&lt;/p&gt;

&lt;p&gt;Here's a rough breakdown of the dominant approaches and where they fall short for outcome-based learning:&lt;/p&gt;

&lt;div class="table-wrapper-paragraph"&gt;&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;Approach&lt;/th&gt;
&lt;th&gt;What it stores&lt;/th&gt;
&lt;th&gt;Retrieval signal&lt;/th&gt;
&lt;th&gt;Learns from outcomes?&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;Conversation buffer (LangChain, most frameworks)&lt;/td&gt;
&lt;td&gt;Raw chat history&lt;/td&gt;
&lt;td&gt;Recency&lt;/td&gt;
&lt;td&gt;No&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Semantic memory (Mem0, LangMem)&lt;/td&gt;
&lt;td&gt;Extracted facts, preferences&lt;/td&gt;
&lt;td&gt;Embedding similarity&lt;/td&gt;
&lt;td&gt;No&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Temporal knowledge graph (Zep/Graphiti)&lt;/td&gt;
&lt;td&gt;Entity relationships over time&lt;/td&gt;
&lt;td&gt;Graph traversal + recency&lt;/td&gt;
&lt;td&gt;No&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Episodic + Reflexion&lt;/td&gt;
&lt;td&gt;Verbal self-reflections&lt;/td&gt;
&lt;td&gt;Similarity to current task&lt;/td&gt;
&lt;td&gt;Partially: reflections capture lessons, but retrieval isn't ranked by outcome&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Reflect&lt;/td&gt;
&lt;td&gt;Task-linked reflections with utility scores&lt;/td&gt;
&lt;td&gt;Similarity weighted by outcome-derived utility&lt;/td&gt;
&lt;td&gt;Yes: utility scores update after every reviewed run&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;&lt;/div&gt;

&lt;p&gt;Systems like Mem0 get strong results on conversational benchmarks (26% improvement over OpenAI Memory on LoCoMo). But their retrieval is driven by semantic similarity and recency, not by whether a memory actually helped produce a good outcome. Reflexion (Shinn et al., NeurIPS 2023) introduced the important idea that agents can learn from verbal self-reflection, hitting 91% pass@1 on HumanEval. But Reflexion's retrieval doesn't incorporate a learned utility signal from downstream results.&lt;/p&gt;

&lt;p&gt;What production agents need is memory that's tied to tasks (not just users), updated by outcomes (not just recency), shaped by failure (not only similarity), and grounded in reflections about what actually happened.&lt;/p&gt;

&lt;p&gt;Without that, memory is a static index. It retrieves things that &lt;em&gt;sound&lt;/em&gt; similar, not things that have a track record of helping.&lt;/p&gt;




&lt;h2&gt;
  
  
  Utility: a credit score for memories
&lt;/h2&gt;

&lt;p&gt;Think of utility like a credit score. A credit score doesn't just record that you had a loan; it tracks whether you paid it back. Utility doesn't just record that a memory was retrieved; it tracks whether the run it participated in succeeded or failed.&lt;/p&gt;

&lt;p&gt;In Reflect, a memory carries:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;The original task the agent was solving&lt;/li&gt;
&lt;li&gt;A reflection about what worked or failed&lt;/li&gt;
&lt;li&gt;The review outcome (pass, fail, or detailed feedback)&lt;/li&gt;
&lt;li&gt;A utility score that changes every time the memory is retrieved and the run is reviewed&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;The ranking formula combines both signals:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;score = (1 - λ) × similarity + λ × utility
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Similarity tells you whether a memory is about the same kind of problem. Utility tells you whether that memory has historically helped. You need both. At &lt;code&gt;λ = 0&lt;/code&gt; retrieval is pure semantic search; at &lt;code&gt;λ = 1&lt;/code&gt; it ranks entirely by past outcomes.&lt;/p&gt;

&lt;p&gt;If a memory was retrieved and the run passed, its utility rises. If it was retrieved and the run failed, its utility falls. Over repeated runs, Reflect starts preferring memories that repeatedly contribute to successful outcomes and down-ranking the ones that keep showing up in failures.&lt;/p&gt;

&lt;p&gt;That's how the loop closes. The signal isn't going to a dashboard for a human to interpret. It feeds directly into retrieval ranking.&lt;/p&gt;




&lt;h2&gt;
  
  
  Storing reasoning, not just facts
&lt;/h2&gt;

&lt;p&gt;Most memory systems store facts: user preferences, conversation history, document chunks. These are static artifacts. Reflect stores something different: reasoning about outcomes.&lt;/p&gt;

&lt;p&gt;A fact memory says: "The user prefers dark mode." A reasoning memory says: "When processing a duplicate charge, check settlement status before issuing refunds because last time we skipped that check, the customer got two refunds."&lt;/p&gt;

&lt;p&gt;Reflect generates these by combining two inputs: the trace (what the agent actually did) and the review (whether it worked). The trace gives you the trajectory: tool calls, LLM completions, which memories influenced decisions. The review gives you the outcome signal.&lt;/p&gt;

&lt;p&gt;From those two inputs, Reflect constructs a reflection. Not a human-written note, but an LLM-generated compression of the lesson in that trace-review pair. What the agent tried, what went wrong, what should happen next time. The reflection gets embedded and stored as a memory, linked to the original task type.&lt;/p&gt;

&lt;p&gt;When your agent queries Reflect before a run, it doesn't get document snippets. It gets distilled experience: "Last time you saw a task like this, you retrieved memory X, took action Y, and it failed. Consider action Z instead."&lt;/p&gt;

&lt;p&gt;Over many runs, the memory store accumulates a layer of learned reasoning. Each memory carries not just the reflection text but a utility score tracking whether following that advice led to success. Retrieval becomes experience-weighted reasoning: the agent surfaces advice that has survived the test of multiple reviewed runs.&lt;/p&gt;




&lt;h2&gt;
  
  
  Where utility changes behavior: two examples
&lt;/h2&gt;

&lt;h3&gt;
  
  
  Customer support: learning which resolution strategies work
&lt;/h3&gt;

&lt;p&gt;A support agent handles refund requests. Early on, Reflect retrieves a memory: "When a customer reports a duplicate charge, issue an immediate refund." The agent follows this and processes a refund, but the customer already had a pending settlement. The run gets reviewed as a failure: "Did not check existing settlement state before issuing refund."&lt;/p&gt;

&lt;p&gt;Two things happen. The utility of that "issue an immediate refund" memory drops. And Reflect stores a new reflection: "For duplicate charge complaints, check settlement status before initiating any refund." Next time a similar ticket comes in, the new reflection ranks higher because its utility starts neutral while the old one has been penalized. The agent now checks settlement state first.&lt;/p&gt;

&lt;p&gt;No prompt rewrite. No model swap. The retrieval layer learned from the outcome.&lt;/p&gt;

&lt;h3&gt;
  
  
  Code review: surfacing advice developers actually accept
&lt;/h3&gt;

&lt;p&gt;A code review agent suggests changes on PRs. Accepted suggestions count as a pass; dismissed ones as a fail.&lt;/p&gt;

&lt;p&gt;Over hundreds of PRs, Reflect learns which categories of advice are consistently useful. Memories about "add null checks before database lookups" accumulate high utility because reviewers almost always accept them. Memories about "rename this variable for clarity" accumulate lower utility because reviewers frequently dismiss them as noise.&lt;/p&gt;

&lt;p&gt;After a few weeks, the agent's suggestions shift. It still retrieves both types, but the ranking now favors the high-acceptance patterns. The agent naturally focuses on feedback that developers actually value.&lt;/p&gt;




&lt;h2&gt;
  
  
  Integration
&lt;/h2&gt;

&lt;p&gt;If you're using the Python SDK, the simplest way to wire in the loop is with &lt;code&gt;client.trace()&lt;/code&gt;. It retrieves memories on entry and records which memories were used when the trace is submitted.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;reflect_sdk&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;ReflectClient&lt;/span&gt;

&lt;span class="n"&gt;client&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;ReflectClient&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
    &lt;span class="n"&gt;base_url&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;https://api.starlight-search.com&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;api_key&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;rf_live_...&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;project_id&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;support-agent&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="n"&gt;client&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;trace&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Handle a duplicate refund request for an enterprise customer&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;ctx&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="n"&gt;messages&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;
        &lt;span class="p"&gt;{&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;role&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;user&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;content&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;ctx&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;augmented_task&lt;/span&gt;&lt;span class="p"&gt;},&lt;/span&gt;
    &lt;span class="p"&gt;]&lt;/span&gt;
    &lt;span class="n"&gt;response&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;my_llm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;ctx&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;augmented_task&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;messages&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;({&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;role&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;assistant&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;content&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;response&lt;/span&gt;&lt;span class="p"&gt;})&lt;/span&gt;

    &lt;span class="n"&gt;ctx&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_output&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
        &lt;span class="n"&gt;trajectory&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;messages&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
        &lt;span class="n"&gt;final_response&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;response&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
        &lt;span class="n"&gt;result&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;pass&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
        &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;gpt-5.4-mini&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;That block handles the core work: query memories before the run, inject them into the task, submit the trace, attach the review result, and preserve the retrieved memory IDs so utility can update correctly.&lt;/p&gt;




&lt;h2&gt;
  
  
  What happens after a review
&lt;/h2&gt;

&lt;p&gt;Once a review comes in, Reflect does three things:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;Creates a reflection from the task, trajectory, result, and feedback.&lt;/li&gt;
&lt;li&gt;Stores that reflection as a new memory with an initial utility score.&lt;/li&gt;
&lt;li&gt;Updates the utility of the memories that were retrieved for that run.&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;Every reviewed run leaves behind two kinds of value: a new reflection for future similar tasks and a better ordering of existing memories. This is what makes utility different from plain semantic search. Semantic search tells you what looks similar. Utility tells you what has earned trust.&lt;/p&gt;




&lt;h2&gt;
  
  
  Limitations and open questions
&lt;/h2&gt;

&lt;p&gt;Some things we haven't solved yet, and want to be upfront about.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Cold start.&lt;/strong&gt; A new project has no memories and no utility signal. The first N runs are pure semantic retrieval until enough reviews accumulate. We're exploring seeding strategies (importing reflections from related projects, bootstrapping from existing evals), but this is still active work.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Utility drift.&lt;/strong&gt; Task distributions change. A memory that was highly useful six months ago may be irrelevant or harmful today. We apply time-decay to utility scores, but the right decay schedule is task-dependent and hard to set in advance.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Review quality.&lt;/strong&gt; The whole loop depends on accurate reviews. Noisy pass/fail labels mean noisy utility scores. Automated review via judge models helps with throughput but introduces its own failure modes.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Lambda tuning.&lt;/strong&gt; The similarity-vs-utility balance (&lt;code&gt;λ&lt;/code&gt;) is a hyperparameter. We default to 0.5, but the right value depends on how much outcome signal your project has. We plan to make this adaptive.&lt;/p&gt;




&lt;h2&gt;
  
  
  Wrapping up
&lt;/h2&gt;

&lt;p&gt;The bottleneck in most agent systems isn't missing traces or missing evals. It's the missing bridge between observation and action. Utility is that bridge. it turns outcomes into a learned ranking signal so your agent retrieves what has worked, not just what sounds relevant.&lt;/p&gt;

&lt;p&gt;Once you have that, memory stops being a static retrieval layer and starts acting like experience.&lt;/p&gt;

&lt;p&gt;We're building Reflect to make this the default. Try it at &lt;a href="https://reflect.starlight-search.com" rel="noopener noreferrer"&gt;reflect.starlight-search.com&lt;/a&gt; and check out the full API reference in our &lt;a href="https://docs.starlight-search.com" rel="noopener noreferrer"&gt;docs&lt;/a&gt;.&lt;/p&gt;




&lt;p&gt;&lt;strong&gt;References&lt;/strong&gt;&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;Databricks Engineering Blog, "Building State-of-the-Art Enterprise Agents 90x Cheaper with Automated Prompt Optimization," 2025. &lt;a href="https://www.databricks.com/blog/building-state-art-enterprise-agents-90x-cheaper-automated-prompt-optimization" rel="noopener noreferrer"&gt;Link&lt;/a&gt;
&lt;/li&gt;
&lt;li&gt;N. Shinn et al., "Reflexion: Language Agents with Verbal Reinforcement Learning," NeurIPS 2023. &lt;a href="https://arxiv.org/abs/2303.11366" rel="noopener noreferrer"&gt;arXiv:2303.11366&lt;/a&gt;
&lt;/li&gt;
&lt;li&gt;P. Chhikara et al., "Mem0: Building Production-Ready AI Agents with Scalable Long-Term Memory," ECAI 2025. &lt;a href="https://arxiv.org/abs/2504.19413" rel="noopener noreferrer"&gt;arXiv:2504.19413&lt;/a&gt;
&lt;/li&gt;
&lt;/ul&gt;

</description>
      <category>llm</category>
      <category>agentskills</category>
      <category>agents</category>
      <category>ai</category>
    </item>
    <item>
      <title>Infinite Compute Glitch - Why Local AI Matters</title>
      <dc:creator>Akshay Ballal</dc:creator>
      <pubDate>Fri, 28 Mar 2025 13:13:05 +0000</pubDate>
      <link>https://dev.to/akshayballal/infinite-compute-glitch-why-local-ai-matters-4n64</link>
      <guid>https://dev.to/akshayballal/infinite-compute-glitch-why-local-ai-matters-4n64</guid>
      <description>&lt;h1&gt;
  
  
  The Infinite Compute Glitch: Why Local AI Matters
&lt;/h1&gt;

&lt;p&gt;In our headlong rush toward AI-powered everything, we've created an interesting paradox. Companies and individuals alike have begun to operate as if compute resources are infinite—just fire off another API call to OpenAI, Claude, or Gemini. But this mindset is creating a significant blind spot in how we approach technology, with consequences that extend far beyond our screens.&lt;/p&gt;

&lt;p&gt;I wasn't just creating another tool when I was building &lt;a href="https://app.starlight-search.com" rel="noopener noreferrer"&gt;Starlight&lt;/a&gt;, a desktop app that indexes and enables semantic search across your files using local models. I was making a statement about how we should approach AI in our daily lives.&lt;/p&gt;

&lt;p&gt;Starlight lets you chat with your data and find information across your files without ever sending that data to external services. Everything happens on your machine, using your compute resources. This blog explores why apps like Starlight matter and why we should not ignore local AI. &lt;/p&gt;

&lt;h2&gt;
  
  
  The Infinite Compute Glitch
&lt;/h2&gt;

&lt;p&gt;I've started calling our current approach to AI "the infinite compute glitch"—a collective delusion that we can just keep scaling up models and data centers without consequences.&lt;/p&gt;

&lt;p&gt;Let's look at the hard numbers:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;A single ChatGPT query consumes approximately 10-15x more energy than a Google search &lt;a href="https://www.rwdigital.ca/blog/how-much-energy-do-google-search-and-chatgpt-use/" rel="noopener noreferrer"&gt;RW Digital&lt;/a&gt;
&lt;/li&gt;
&lt;li&gt;Training GPT-4 is estimated to have used over &lt;a href="https://www.bestbrokers.com/forex-brokers/ais-power-demand-calculating-chatgpts-electricity-consumption-for-handling-over-78-billion-user-queries-every-year/" rel="noopener noreferrer"&gt;62 million kilowatt-hours of electricity&lt;/a&gt;
&lt;/li&gt;
&lt;li&gt;AI data centers can use between &lt;a href="https://www.scmp.com/news/china/science/article/3259230/chinas-growing-data-centres-and-ai-industry-could-strain-scarce-water-resources-according-new-report" rel="noopener noreferrer"&gt;3-5 million gallons&lt;/a&gt; of water per day for cooling&lt;/li&gt;
&lt;li&gt;The carbon footprint of training a single large language model can equal that of five cars over their entire lifetimes&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;A 2023 study from the University of Massachusetts found that training a single transformer-based AI model can emit as much carbon as five cars over their entire lifetimes. Meanwhile, Microsoft had to limit Azure AI services in some regions due to capacity constraints, showing that even the cloud giants are hitting physical limits.&lt;/p&gt;

&lt;h2&gt;
  
  
  Why Local First Matters
&lt;/h2&gt;

&lt;h3&gt;
  
  
  Your Data Stays Home
&lt;/h3&gt;

&lt;p&gt;The most obvious benefit is privacy. When your files never leave your computer, you maintain complete control over your information. No terms of service changes, no worrying about how your data might be used to train future models.&lt;/p&gt;

&lt;p&gt;Consider this: in 2023 alone, major AI companies updated their terms of service multiple times, often expanding their rights to use user data. Samsung even had to ban employees from using generative AI tools after sensitive code was uploaded to ChatGPT and became part of its training data. With local AI, these scenarios become impossible.&lt;/p&gt;

&lt;p&gt;At ASML, we are not allowed to use AI tools because the company carries a lot of sensitive information that is also of national interest. I understand their concern, but when other less sensitive industries utilize AI to make their employees more productive, we feel stuck in the past era. This can easily be solved by empowering the employees with local AI, either on-device or on-prem.&lt;/p&gt;

&lt;h3&gt;
  
  
  Distributed Computing Is Resilient Computing
&lt;/h3&gt;

&lt;p&gt;We're creating a more distributed system by leveraging the computing power already sitting on our desks and in our laps. Instead of centralizing all AI compute in massive data centers owned by a handful of companies, we're spreading the load across millions of devices.&lt;/p&gt;

&lt;p&gt;The average modern laptop now has more computing power than what was used to train early versions of BERT, one of the breakthrough language models from 2018. A typical gaming PC with a decent GPU can run models with billions of parameters. We're sitting on a vast, untapped ocean of compute power.&lt;/p&gt;

&lt;h3&gt;
  
  
  Speed Without the Wait
&lt;/h3&gt;

&lt;p&gt;For many tasks, local processing is actually faster. No upload time, no API latency, no waiting for your request to be queued behind thousands of others. The results appear as quickly as your computer can process them.&lt;/p&gt;

&lt;p&gt;In our internal testing, Starlight's local search returned results from a 1GB document collection in under 200ms, while the equivalent cloud API call took over 2 seconds when accounting for network latency and server processing time. That's a 10x difference in speed for common queries.&lt;/p&gt;

&lt;h3&gt;
  
  
  Right-Sized Solutions
&lt;/h3&gt;

&lt;p&gt;Here's something that might surprise you: more than 80% of the AI tasks the average person needs day-to-day can be handled by smaller, local models. We don't need GPT-4 to summarize a document or help draft an email. By matching the task to the appropriate model size, we save enormous resources.&lt;/p&gt;

&lt;p&gt;Take document summarization: a 1.5 billion parameter model running locally can produce summaries nearly indistinguishable from those generated by models 100x larger for most business documents. The difference? The smaller model uses a fraction of the energy and requires no data transmission.&lt;/p&gt;

&lt;h2&gt;
  
  
  What is Propelling Local AI Forward
&lt;/h2&gt;

&lt;p&gt;Here's what's making local AI happen:&lt;/p&gt;

&lt;h3&gt;
  
  
  1. Hardware: Increased Processing Power
&lt;/h3&gt;

&lt;p&gt;It's pretty amazing how much more powerful our devices have become! Laptops, smartphones, and even smaller embedded systems now pack a serious punch with their processors and memory. This means they're not just for basic tasks anymore; they've got the muscle to handle sophisticated AI models directly. Think about it: your phone today has more computing power than entire computers from not too long ago. This increase in processing power is a huge factor in enabling local AI. With the recent release of ARM based processors on PCs, local AI has become even more plausible. The key advancements have been make CPU and GPU share an unified memory, allowing you to fit even larger models on the GPU and super efficient CPU/GPU. This is only going to improve going forward so we need to keep pushing AI on the edge to make use of this untapped compute power. &lt;/p&gt;

&lt;h3&gt;
  
  
  2. Model Architecture and Data
&lt;/h3&gt;

&lt;p&gt;There is no doubt that the LLM models have become ever more intelligent. Every day a new, SOTA model is released. But even more important is that the models are getting smaller and more efficient. While companies like OpenAI are pushing toward ever more larger models which I believe is the wrong direction to move in, companies like Google, Mistral and Meta continue to surprise us with small models that are as powerful as the largest models from last year. For example Gemma3 27b, an open-source 27 billion parameter model from Google is as good as their own closed-source Gemini-1.5-pro which is believed to be over 300 billion parameters.  We can see from the below plot that both Mistral Small 3.1 and Gemma-3 are better than GPT-4o-mini which must be a very large model at this point. Even though these small models are good, they can only work well on high end computers. The user has to settle for even smaller versions of these models if they have a medium spec computer. But this is a start none the less. &lt;/p&gt;

&lt;p&gt;Several new pieces of innovations have made this possible. Introduction of KV caching, Flash Attention, Rotary Position Embedding (RoPE) are just some of the innovations that have truly reduced the gap between the performance of a small local model and a large cloud model. Also the quality of data that is used to train these models has improved over the years. Instead of training on just raw data from the internet which can be very noisy, a lot of clever preprocessing is being done to allow smaller models to learn just as well.&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fcms.mistral.ai%2Fassets%2F7ede577c-1f94-4391-ac44-3a54af3d2962.webp%3Fwidth%3D2102%26height%3D1138" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fcms.mistral.ai%2Fassets%2F7ede577c-1f94-4391-ac44-3a54af3d2962.webp%3Fwidth%3D2102%26height%3D1138" alt="graph" width="800" height="433"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;h3&gt;
  
  
  3. Model Serving: Open-Source Models and Tools
&lt;/h3&gt;

&lt;p&gt;Lastly, the final piece is how are these local AI models distributed. Here, the open source community has been quite active and projects like Ollama, LlamaCPP, Candle, vLLM are making it easier to allow consumer grade hardware to run these open source models efficiently. These projects are the interface between the model weights and the the actual hardware. Making this interaction efficient is the key to making the best models run on the least amount of compute. &lt;/p&gt;

&lt;p&gt;But just serving these models isn’t enough. Tools like ChatGPT provide several more user experiences like Canvas, Web Search, Voice to Text etc, which make people use their services. Even if we made best models, available locally, people will still gravitate towards, OpenAI, Claude etc for the user experience. That’s where products like Starlight come in. The new age software has to be as much local first as possible and allow users to get the same experience that they get on ChatGPT or other similar platforms using their own local AI models. At this point, it is more about breaking the habit. Let’s say I want to summarize a research paper. The default habit that has developed over the last three years for people is to take the paper and upload it to ChatGPT and ask summarization. But this is a very easy task for local models. So how can we push software to break this habit and exploit use cases where local AI can outperform cloud AI. This is going to be a big challenge which we at Starlight strive to solve by building features that make you gravitate towards local AI rather than cloud models for most of your tasks.&lt;/p&gt;

&lt;h2&gt;
  
  
  Innovating Within Constraints
&lt;/h2&gt;

&lt;p&gt;At Starlight, we're embracing constraints rather than ignoring them. We're finding clever ways to make smaller models more effective through better indexing, retrieval, and context management.&lt;/p&gt;

&lt;p&gt;For example, we've developed a hybrid approach that uses local embedding models to index documents and then employs efficient retrieval algorithms to find the most relevant content. This approach reduces the context window needed for generation tasks, allowing even small local models to produce high-quality results on par with much larger cloud models. Moreover, we are implementing methods like recursive summarization and graph traversal to compress large amounts of information into the constrained context of local models. Also, it's not just about context; even if the local models have a long context, usually, the hardware cannot process this context as it might not be able to fit on the memory or the computation takes just too long. &lt;/p&gt;

&lt;p&gt;This approach forces us to be more thoughtful about how we use AI. Rather than throwing compute at every problem, we ask: What's the most efficient way to solve this? How can we deliver value while minimizing resource use?&lt;/p&gt;

&lt;h2&gt;
  
  
  A More Balanced Future: The Hybrid Approach
&lt;/h2&gt;

&lt;p&gt;I'm not suggesting we abandon cloud AI entirely—there are absolutely use cases where larger models are necessary. But by being intentional about when and how we use these resources, we can create a more balanced, sustainable approach to AI.&lt;/p&gt;

&lt;p&gt;Imagine a world where:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;Your personal documents, emails, and notes are processed entirely on your devices, maintaining your privacy&lt;/li&gt;
&lt;li&gt;Basic creative tasks like drafting emails or summarizing articles happen locally, with instant response times&lt;/li&gt;
&lt;li&gt;Only specialized tasks that truly require massive models—like complex code generation or scientific research—call out to cloud services&lt;/li&gt;
&lt;li&gt;Organizations maintain their own small, specialized models trained on their data, running on local infrastructure&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;This isn't science fiction—all the technology exists today. What's missing is the mindset shift.&lt;/p&gt;

&lt;h2&gt;
  
  
  Taking Action
&lt;/h2&gt;

&lt;p&gt;So what can you do to be part of this shift?&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;Try local-first AI tools&lt;/strong&gt; like Starlight for your personal and professional needs&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Be mindful of your API usage&lt;/strong&gt; when you do use cloud services&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Support open-source AI projects&lt;/strong&gt; that are making models more accessible and efficient&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;The future isn't about unlimited compute—it's about smart compute. It's about knowing when a local model is sufficient and when you truly need something more powerful.&lt;/p&gt;

&lt;p&gt;So next time you're about to send that API call, ask yourself: Could this happen locally instead? Your privacy, your wallet, and our planet might thank you for it.&lt;/p&gt;

</description>
      <category>llm</category>
      <category>ai</category>
      <category>chatgpt</category>
      <category>openai</category>
    </item>
    <item>
      <title>Optimize VLM Tokens with EmbedAnything x ColPali</title>
      <dc:creator>Akshay Ballal</dc:creator>
      <pubDate>Sun, 12 Jan 2025 12:38:49 +0000</pubDate>
      <link>https://dev.to/akshayballal/optimize-vlm-tokens-with-embedanything-x-colpali-2d1g</link>
      <guid>https://dev.to/akshayballal/optimize-vlm-tokens-with-embedanything-x-colpali-2d1g</guid>
      <description>&lt;p&gt;&lt;strong&gt;ColPali&lt;/strong&gt;, a late-interaction vision model, leverages this power to enable text searches within images. This means you can pinpoint the exact pages in a PDF containing relevant text, even if the text exists only as part of an image. For example, suppose you have hundreds of pages in a PDF and even hundreds of PDFs. In that case, &lt;strong&gt;ColPali&lt;/strong&gt; can identify the specific pages matching a query—an impressive feat for streamlining information retrieval. This system is widely come to be known as Vision RAG. &lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fhdlm7goq32mwh6u3hvgx.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fhdlm7goq32mwh6u3hvgx.png" alt="Image of a robot reading documents" width="800" height="714"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;However, due to its computational demands, running the ColPali model directly on a local machine might not always be feasible. To address this, I developed a &lt;a href="https://huggingface.co/akshayballal/colpali-v1.2-merged-onnx" rel="noopener noreferrer"&gt;&lt;strong&gt;onnx version of ColPali&lt;/strong&gt;&lt;/a&gt; which can be quantized to different precisions. &lt;strong&gt;Quantization&lt;/strong&gt; reduces the precision of the model's weights, significantly lowering computational and memory requirements. Despite this optimization, the quantized model maintains performance nearly equivalent to the original. In this article we will look at how to use ColPali for Vision RAG by using the &lt;a href="https://github.com/StarlightSearch/EmbedAnything" rel="noopener noreferrer"&gt;EmbedAnything&lt;/a&gt; library that I have been developing for the last few months. You can read more about EmbedAnything &lt;a href="https://starlight-search.com/blog/2024/03/31/embed-anything/" rel="noopener noreferrer"&gt;here&lt;/a&gt; &lt;/p&gt;

&lt;h3&gt;
  
  
  What is Vision RAG?
&lt;/h3&gt;

&lt;p&gt;Let’s look a bit deeper into what Vision RAG is. Traditional RAG methods use text throughout the pipeline. They store text chunks and their embeddings in a vector database and then retrieve these chunks for further downstream tasks. A simplest / naive RAG attaches these chunks as context to the original query and aims to provide more information to the model. There are two problems here. One is that getting text from many data sources may not be possible. Think about scanned PDFs or documents with many graphics, like design pamphlets, etc. The traditional RAG falls apart if any documents you work with are like this. A bandaid to the problem is to use OCR engines to somehow extract text. This adds additional moving parts to the process, and OCR engines are pretty fragile.  The second problem, even if you manage to get the text, is the chunking process. Again, how do you decide what the chunk size should be and what the overlap should be? Even if you find optimal parameters for a few documents, will they hold for new ones? All these parameters add to the design space, and the RAG performance needs to be continuously evaluated based on these design choices. Vision RAG tries to solve this by removing the whole chunking process from the system and instead storing the image as a multi-vector embedding in the database. When there is a query, a Late Interaction Score (LIS), similar to the classical cosine similarity but for multi-vector, is measured, and the DB returns the document pages with the highest LIS scores. These documents can now be sent to a Vision Language Model (VLM) along with the original query to get the answer to the questions. &lt;/p&gt;

&lt;p&gt;The image below shows this process from start to end. Since vision language models are more expensive than text models, Vision RAG is even more important because you don’t have to send complete PDFs to the model. You are just sending the relevant pages. This can save a lot of costs. The document embedding generation happens offline and is taken care of by EmbedAnything. One drawback with this approach is that not all vector databases today support storing multi-vectors. A few that support these are Qdrant and Vespa. &lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fcfdzxja46d3zf20zciya.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fcfdzxja46d3zf20zciya.png" alt="Vision RAG Flow in EmbedAnything" width="800" height="401"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Let us look at how you can use Colpali models with EmbedAnything and convert PDFs into multi-vector embeddings. In this example, we will not use a vector database but find the late interaction score of the query against all the pages. &lt;/p&gt;

&lt;h3&gt;
  
  
  Step 1: Install the dependencies
&lt;/h3&gt;

&lt;p&gt;Since we are going to convert pdfs into images, we need poppler-utils. &lt;/p&gt;

&lt;p&gt;EmbedAnything requires poppler to convert pdfs to images. So make sure you have it installed.&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;For Linux:
&lt;/li&gt;
&lt;/ul&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight shell"&gt;&lt;code&gt;apt &lt;span class="nb"&gt;install &lt;/span&gt;poppler-utils
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;ul&gt;
&lt;li&gt;For Mac
&lt;/li&gt;
&lt;/ul&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;brew install poppler

&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;ul&gt;
&lt;li&gt;For Windows&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;&lt;a href="https://github.com/oschwartz10612/poppler-windows/releases/tag/v24.08.0-0" rel="noopener noreferrer"&gt;https://github.com/oschwartz10612/poppler-windows/releases/tag/v24.08.0-0&lt;/a&gt;&lt;br&gt;
Download the binary from here, unzip it and add the &lt;code&gt;bin&lt;/code&gt; folder to your system path.&lt;/p&gt;

&lt;p&gt;Using the GPU version of EmbedAnything is highly recommended because ColPali is based on paligemma and requires a computation like any other small language model.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight shell"&gt;&lt;code&gt;pip &lt;span class="nb"&gt;install &lt;/span&gt;embed-anything-gpu tabulate openai
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Let’s import EmbedAnything and the other dependencies:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;base64&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;embed_anything&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;EmbedData&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ColpaliModel&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;numpy&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;tabulate&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;tabulate&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;pathlib&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;Path&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;PIL&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;Image&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;io&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;matplotlib.pyplot&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;plt&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;openai&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;os&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;h3&gt;
  
  
  Step 2: Get the files that need to be indexed
&lt;/h3&gt;

&lt;p&gt;For this demo, we will clone the EmbedAnything repo which has some test pdfs with the “Attention is all you need” and a Mistral paper.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="n"&gt;os&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;path&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;exists&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;EmbedAnything&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
  &lt;span class="err"&gt;!&lt;/span&gt;&lt;span class="n"&gt;git&lt;/span&gt; &lt;span class="n"&gt;clone&lt;/span&gt; &lt;span class="n"&gt;https&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="o"&gt;//&lt;/span&gt;&lt;span class="n"&gt;github&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;com&lt;/span&gt;&lt;span class="o"&gt;/&lt;/span&gt;&lt;span class="n"&gt;StarlightSearch&lt;/span&gt;&lt;span class="o"&gt;/&lt;/span&gt;&lt;span class="n"&gt;EmbedAnything&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;gi&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;h3&gt;
  
  
  Step 3 : Load the ColPali Onnx Model
&lt;/h3&gt;

&lt;p&gt;Use the &lt;code&gt;embed_anything&lt;/code&gt; function with &lt;code&gt;from_pretrained_onnx&lt;/code&gt; to load the &lt;strong&gt;ColPali Onnx model&lt;/strong&gt; from the specified link. This initializes the model for embedding tasks. If you are using a python notebook, this can take some time because the model is being downloaded. Unfortunately, the progress bar is not visible on a notebook. You can also load the original Colpali model and not the &lt;code&gt;onnx&lt;/code&gt; model using the &lt;code&gt;from_pretrained_hf&lt;/code&gt; function.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;ColpaliModel&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;ColpaliModel&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;from_pretrained_onnx&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;starlight-ai/colpali-v1.2-merged-onnx&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;h3&gt;
  
  
  Step 4: Load the files and embed them.
&lt;/h3&gt;

&lt;p&gt;Now, we just load all the files from the directory with a PDF extension. Then, for each file, we run the &lt;code&gt;embed_file&lt;/code&gt; function with a &lt;code&gt;batch_size&lt;/code&gt; of 1. You can increase the batch size if you have higher VRAM, but one works well.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;directory&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;Path&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;EmbedAnything/test_files&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;files&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;list&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;directory&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;glob&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;*.pdf&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;span class="n"&gt;file_embed_data&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nb"&gt;list&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;EmbedData&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[]&lt;/span&gt;
&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="nb"&gt;file&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;files&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="k"&gt;try&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
        &lt;span class="n"&gt;embedding&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nb"&gt;list&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;EmbedData&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;embed_file&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;str&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nb"&gt;file&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;file_embed_data&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;extend&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;embedding&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="k"&gt;except&lt;/span&gt; &lt;span class="nb"&gt;Exception&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;e&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
        &lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Error embedding file &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="nb"&gt;file&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt;: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;e&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;file_embeddings&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;array&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="n"&gt;e&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;embedding&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;e&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;file_embed_data&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Embedded Files: &lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;files&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;code&gt;file_embeddings&lt;/code&gt; is a list of &lt;code&gt;EmbedData&lt;/code&gt; object which contains other metadata along with the embeddings like page number, file name and the image of the page in string base64 format. You can now store these embeddings in a vector database of choice. &lt;/p&gt;

&lt;h3&gt;
  
  
  Step 5: Process the query
&lt;/h3&gt;

&lt;p&gt;We do the same for the query as well using &lt;code&gt;embed_query&lt;/code&gt; function.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;query&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;What is positional encoding?&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;
&lt;span class="n"&gt;query_embedding&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;embed_query&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;query&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;query_embeddings&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;array&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="n"&gt;e&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;embedding&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;e&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;query_embedding&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;

&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;h3&gt;
  
  
  Step 6: Compute Similarity Scores
&lt;/h3&gt;

&lt;p&gt;We can calculate the Late Interaction Score between &lt;strong&gt;query&lt;/strong&gt; and &lt;strong&gt;file embeddings using the Einstein summation function&lt;/strong&gt;. This identifies the most relevant pages based on the highest scores. Extract the &lt;strong&gt;top 3 pages&lt;/strong&gt; for further processing. We also take out the &lt;code&gt;image&lt;/code&gt; field from the &lt;code&gt;EmbedData&lt;/code&gt; object of the embeddings. This is a base64 string representation of the image that will send to GPT.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;score&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;query_embeddings&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;file_embed_data&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;file_embeddings&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;array&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="n"&gt;e&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;embedding&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;e&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;file_embed_data&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
    &lt;span class="n"&gt;scores&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;einsum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;bnd,csd-&amp;gt;bcns&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;query_embeddings&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;file_embeddings&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;max&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;squeeze&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

    &lt;span class="c1"&gt;# Get top pages
&lt;/span&gt;    &lt;span class="n"&gt;top_pages&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;argsort&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;scores&lt;/span&gt;&lt;span class="p"&gt;)[::&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;][:&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;

    &lt;span class="c1"&gt;# Extract file names and page numbers
&lt;/span&gt;    &lt;span class="n"&gt;table&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;
        &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;file_embed_data&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;page&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="n"&gt;metadata&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;file_path&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nf"&gt;split&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;/&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;file_embed_data&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;page&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="n"&gt;metadata&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;page_number&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;]]&lt;/span&gt;
        &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;page&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;top_pages&lt;/span&gt;
    &lt;span class="p"&gt;]&lt;/span&gt;

    &lt;span class="c1"&gt;# Print the results in a table
&lt;/span&gt;    &lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;tabulate&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;table&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;headers&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;File Name&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Page Number&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;tablefmt&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;grid&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
    &lt;span class="n"&gt;results_str&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;tabulate&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;table&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;headers&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;File Name&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Page Number&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;tablefmt&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;grid&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="n"&gt;images&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;file_embed_data&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;page&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="n"&gt;metadata&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;image&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;page&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;top_pages&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
    &lt;span class="n"&gt;images_pil&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;Image&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;open&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;io&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;BytesIO&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;base64&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;b64decode&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;image&lt;/span&gt;&lt;span class="p"&gt;)))&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;image&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;images&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;images_pil&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;results_str&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;images_str&lt;/span&gt;

&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;The result will look something like this:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight markdown"&gt;&lt;code&gt;+----------------------------------------+---------------+
| File Name                              |   Page Number |
+========================================+===============+
| EmbedAnything/test_files/attention.pdf |             6 |
+----------------------------------------+---------------+
| EmbedAnything/test_files/attention.pdf |             9 |
+----------------------------------------+---------------+
| EmbedAnything/test_files/linear.pdf    |            34 |
+----------------------------------------+---------------+
| EmbedAnything/test_files/attention.pdf |             3 |
+----------------------------------------+---------------+
| EmbedAnything/test_files/attention.pdf |            15 |
+----------------------------------------+---------------+
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;We can visualize the top 3 pages using this command &lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Ffctc5p6ujnsc5armmzev.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Ffctc5p6ujnsc5armmzev.png" alt="image.png" width="800" height="316"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;h3&gt;
  
  
  Step 7: Send these images to OpenAI
&lt;/h3&gt;

&lt;p&gt;Now we can send these top 3 retrieved images to OpenAI gpt-4o-mini model along with the original query. You can add further instructions for the model here as per your needs. Don’t forget to add your OpenAI key to the client.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;openai&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;OpenAI&lt;/span&gt;

&lt;span class="n"&gt;client&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;OpenAI&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;api_key&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="o"&gt;&amp;lt;&lt;/span&gt;&lt;span class="n"&gt;openai&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;key&lt;/span&gt;&lt;span class="o"&gt;&amp;gt;&lt;/span&gt; &lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="n"&gt;image_contents&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;
    &lt;span class="p"&gt;{&lt;/span&gt;
        &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;type&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;image_url&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
        &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;image_url&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="p"&gt;{&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;url&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;data:image/jpeg;base64,&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;image_str&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;}&lt;/span&gt;
    &lt;span class="p"&gt;}&lt;/span&gt;
    &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;image_str&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;images_str&lt;/span&gt;
&lt;span class="p"&gt;]&lt;/span&gt;

&lt;span class="n"&gt;response&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;client&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;chat&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;completions&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;create&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
    &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;gpt-4o-mini&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;messages&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;
        &lt;span class="p"&gt;{&lt;/span&gt;
            &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;role&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;user&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
            &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;content&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;
                &lt;span class="p"&gt;{&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;type&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;text&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;text&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;query&lt;/span&gt;&lt;span class="p"&gt;},&lt;/span&gt;
            &lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;image_contents&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
        &lt;span class="p"&gt;}&lt;/span&gt;
    &lt;span class="p"&gt;],&lt;/span&gt;

&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;The output looks like this&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight markdown"&gt;&lt;code&gt;
Positional encoding is a critical concept in transformer models, which addresses the inherent limitation of self-attention mechanisms: they do not consider the order of input tokens. Since transformers process all tokens simultaneously, they require a way to encode the order of tokens in a sequence to maintain their relative positions.

&lt;span class="gu"&gt;### Key Aspects of Positional Encoding:&lt;/span&gt;
&lt;span class="p"&gt;
1.&lt;/span&gt; &lt;span class="gs"&gt;**Purpose**&lt;/span&gt;: It helps the model understand the sequence of data since transformers lack recurrence or convolution that traditionally encode this information.
&lt;span class="p"&gt;
2.&lt;/span&gt; &lt;span class="gs"&gt;**Method**&lt;/span&gt;: 
&lt;span class="p"&gt;   -&lt;/span&gt; Positional encodings are added to the input embeddings of tokens.
&lt;span class="p"&gt;   -&lt;/span&gt; A common approach is to use sine and cosine functions of different frequencies, defined mathematically as:&lt;span class="sb"&gt;

     \[
     PE(pos, 2i) = \sin\left(\frac{pos}{10000^{\frac{2i}{d_{model}}}}\right)
     \]
     \[
     PE(pos, 2i+1) = \cos\left(\frac{pos}{10000^{\frac{2i}{d_{model}}}}\right)
     \]

&lt;/span&gt;&lt;span class="p"&gt;   -&lt;/span&gt; Here, &lt;span class="se"&gt;\(&lt;/span&gt; pos &lt;span class="se"&gt;\)&lt;/span&gt; is the position of the token, &lt;span class="se"&gt;\(&lt;/span&gt; i &lt;span class="se"&gt;\)&lt;/span&gt; is the dimension, and &lt;span class="se"&gt;\(&lt;/span&gt; d_{model} &lt;span class="se"&gt;\)&lt;/span&gt; is the dimensionality of the embedding.
&lt;span class="p"&gt;
3.&lt;/span&gt; &lt;span class="gs"&gt;**Frequency**&lt;/span&gt;: The functions allow for various wavelengths, making it possible to learn relationships at different scales, which enables the model to understand both short-range and long-range dependencies in the sequence.
&lt;span class="p"&gt;
4.&lt;/span&gt; &lt;span class="gs"&gt;**Alternatives**&lt;/span&gt;: While sinusoidal encodings are widely used, learned positional embeddings can also be employed, which allows the model to learn the optimal way to encode positions during training.
y
In summary, positional encoding is vital for allowing transformer models to grasp the order of tokens in sequences, facilitating effective learning from sequential data.

&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;This response used a total of 2500 tokens which translates to $0.006. If we would have sent the entire &lt;code&gt;pdf&lt;/code&gt; of 15 pages, without retrieval to the model, it would have cost about 12,500 tokens which is five times higher than this system. And this is assuming we know which &lt;code&gt;pdf&lt;/code&gt; to send. Also the response may not be accurate because the model has too much unnecessary information to filter out.&lt;/p&gt;

&lt;p&gt;I hope that this blog was useful. Vision RAG is going to be a staple block in future retrieval systems so its good to start using it to make your LLM pipelines more efficient. &lt;/p&gt;

&lt;p&gt;Check out the demo notebook at &lt;/p&gt;

&lt;p&gt;&lt;a href="https://colab.research.google.com/drive/1JoqVs3athdd9FZOrrHb9WagxIK45CZUX?usp=sharing" rel="noopener noreferrer"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fcolab.research.google.com%2Fassets%2Fcolab-badge.svg" alt="Open in Colab" width="117" height="20"&gt;&lt;/a&gt;&lt;/p&gt;




&lt;p&gt;Want to connect?&lt;br&gt;
😸&lt;a href="https://github.com/akshayballal95/rl-blog-series/tree/reinforce" rel="noopener noreferrer"&gt;GitHub Repo&lt;/a&gt;&lt;br&gt;
🌍&lt;a href="http://www.akshaymakes.com" rel="noopener noreferrer"&gt;My Website&lt;/a&gt;&lt;br&gt;&lt;br&gt;
🐦&lt;a href="https://twitter.com/akshayballal95" rel="noopener noreferrer"&gt;My Twitter&lt;/a&gt;&lt;br&gt;&lt;br&gt;
👨&lt;a href="https://www.linkedin.com/in/akshay-ballal/" rel="noopener noreferrer"&gt;My LinkedIn&lt;/a&gt;&lt;/p&gt;

</description>
      <category>rag</category>
      <category>llm</category>
      <category>openai</category>
      <category>ai</category>
    </item>
    <item>
      <title>Build ReAct Agents using SLMs from Scratch</title>
      <dc:creator>Akshay Ballal</dc:creator>
      <pubDate>Sun, 29 Sep 2024 10:00:52 +0000</pubDate>
      <link>https://dev.to/akshayballal/build-react-agents-using-slms-from-scratch-4877</link>
      <guid>https://dev.to/akshayballal/build-react-agents-using-slms-from-scratch-4877</guid>
      <description>&lt;p&gt;In this post, I will demonstrate how to create a function-calling agent using &lt;strong&gt;Small Language Models (SLMs)&lt;/strong&gt;. Leveraging SLMs offers a range of benefits, especially when paired with tools like LoRA adapters for efficient fine-tuning and execution. While Large Language Models (LLMs) are powerful, they can be resource-intensive and slow. On the other hand, SLMs are more lightweight, making them ideal for environments with limited hardware resources or specific use cases where lower latency is critical.&lt;/p&gt;

&lt;p&gt;By using &lt;strong&gt;SLMs&lt;/strong&gt; with &lt;strong&gt;LoRA adapters&lt;/strong&gt;, we can separate reasoning and function execution tasks to optimize performance. For instance, the model can execute complex function calls using the adapter and handle reasoning or thinking tasks without it, thus conserving memory and improving speed. This flexibility is perfect for building applications like function-calling agents without needing the infrastructure required for larger models.&lt;/p&gt;

&lt;p&gt;Moreover, SLMs can be easily scaled to run on devices with limited computational power, making them ideal for production environments where cost and efficiency are prioritized. In this example, we'll use a custom model trained on the &lt;strong&gt;Salesforce/xlam-function-calling-60k&lt;/strong&gt; dataset via Unsloth, demonstrating how you can utilize SLMs to create high-performance, low-resource AI applications.&lt;/p&gt;

&lt;p&gt;Additionally, the approach discussed here can be scaled to more powerful models, such as &lt;strong&gt;LLaMA 3.1-8B&lt;/strong&gt;, which have in-built function-calling capabilities, offering a smooth transition when larger models are necessary.&lt;/p&gt;

&lt;h3&gt;
  
  
  1. Initiate the Model and Tokenizer with Unsloth
&lt;/h3&gt;

&lt;p&gt;We’ll first set up the model and tokenizer using &lt;strong&gt;Unsloth&lt;/strong&gt;. Here, we define a max sequence length of 2048, though this can be adjusted. We also enable &lt;strong&gt;4-bit quantization&lt;/strong&gt; to reduce memory usage, ideal for running models on lower-memory hardware.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;unsloth&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;FastLanguageModel&lt;/span&gt;
&lt;span class="n"&gt;max_seq_length&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;2048&lt;/span&gt; &lt;span class="c1"&gt;# Choose any! We auto support RoPE Scaling internally!
&lt;/span&gt;&lt;span class="n"&gt;dtype&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt; &lt;span class="c1"&gt;# None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
&lt;/span&gt;&lt;span class="n"&gt;load_in_4bit&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;True&lt;/span&gt; &lt;span class="c1"&gt;# Use 4bit quantization to reduce memory usage. Can be False.
&lt;/span&gt;
&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tokenizer&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;FastLanguageModel&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;from_pretrained&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
    &lt;span class="n"&gt;model_name&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;akshayballal/phi-3.5-mini-xlam-function-calling&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;max_seq_length&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;max_seq_length&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;dtype&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;dtype&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;load_in_4bit&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;load_in_4bit&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="n"&gt;FastLanguageModel&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;for_inference&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;);&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;h3&gt;
  
  
  2. Implement Stopping Criteria for Controlled Generation
&lt;/h3&gt;

&lt;p&gt;To ensure that the agent pauses execution after function calls, we define a &lt;strong&gt;stopping criteria&lt;/strong&gt;. This will halt the generation when the model outputs the keyword "PAUSE," allowing the agent to fetch the result of the function call.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;transformers&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;StoppingCriteria&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;StoppingCriteriaList&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;

&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;KeywordsStoppingCriteria&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;StoppingCriteria&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;keywords_ids&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="nb"&gt;list&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;keywords&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;keywords_ids&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__call__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;input_ids&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;LongTensor&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;FloatTensor&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="n"&gt;kwargs&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;-&amp;gt;&lt;/span&gt; &lt;span class="nb"&gt;bool&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
        &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;input_ids&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;][&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;keywords&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
            &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="bp"&gt;True&lt;/span&gt;
        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="bp"&gt;False&lt;/span&gt;

&lt;span class="n"&gt;stop_ids&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;17171&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;span class="n"&gt;stop_criteria&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;KeywordsStoppingCriteria&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;stop_ids&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;h3&gt;
  
  
  3. Define the Tools for Function Calling
&lt;/h3&gt;

&lt;p&gt;Next, we define the functions the agent will use during execution. These Python functions will act as "tools" that the agent can call. The return type must be clear, and the function should include a descriptive docstring, as the agent will rely on this to choose the correct tool.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;add_numbers&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;a&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nb"&gt;int&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;b&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nb"&gt;int&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;-&amp;gt;&lt;/span&gt; &lt;span class="nb"&gt;int&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="sh"&gt;"""&lt;/span&gt;&lt;span class="s"&gt;
    This function takes two integers and returns their sum.

    Parameters:
    a (int): The first integer to add.
    b (int): The second integer to add.
    &lt;/span&gt;&lt;span class="sh"&gt;"""&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;a&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;b&lt;/span&gt; 

&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;square_number&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;a&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nb"&gt;int&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;-&amp;gt;&lt;/span&gt; &lt;span class="nb"&gt;int&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="sh"&gt;"""&lt;/span&gt;&lt;span class="s"&gt;
    This function takes an integer and returns its square.

    Parameters:
    a (int): The integer to be squared.
    &lt;/span&gt;&lt;span class="sh"&gt;"""&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;a&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;a&lt;/span&gt;

&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;square_root_number&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;a&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nb"&gt;int&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;-&amp;gt;&lt;/span&gt; &lt;span class="nb"&gt;int&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="sh"&gt;"""&lt;/span&gt;&lt;span class="s"&gt;
    This function takes an integer and returns its square root.

    Parameters:
    a (int): The integer to calculate the square root of.
    &lt;/span&gt;&lt;span class="sh"&gt;"""&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;a&lt;/span&gt; &lt;span class="o"&gt;**&lt;/span&gt; &lt;span class="mf"&gt;0.5&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;h3&gt;
  
  
  4. Generate Tool Descriptions for the Agent
&lt;/h3&gt;

&lt;p&gt;These function descriptions will be structured into a list of dictionaries. The agent will use these to understand the available tools and their parameters.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;tool_descriptions&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[]&lt;/span&gt;
&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;tool&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;tools&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="n"&gt;spec&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;{&lt;/span&gt;
        &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;name&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;tool&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;__name__&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
        &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;description&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;tool&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;__doc__&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;strip&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt;
        &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;parameters&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;
            &lt;span class="p"&gt;{&lt;/span&gt;
                &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;name&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;param&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
                &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;type&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;arg&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;__name__&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="nf"&gt;hasattr&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;arg&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;__name__&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="k"&gt;else&lt;/span&gt; &lt;span class="nf"&gt;str&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;arg&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
            &lt;span class="p"&gt;}&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;param&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;arg&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;tool&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;__annotations__&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;items&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;param&lt;/span&gt; &lt;span class="o"&gt;!=&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;return&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;
        &lt;span class="p"&gt;]&lt;/span&gt;
    &lt;span class="p"&gt;}&lt;/span&gt;
    &lt;span class="n"&gt;tool_descriptions&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;spec&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;tool_descriptions&lt;/span&gt;

&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;This is how the output looks like&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight shell"&gt;&lt;code&gt;&lt;span class="o"&gt;[{&lt;/span&gt;&lt;span class="s1"&gt;'name'&lt;/span&gt;: &lt;span class="s1"&gt;'add_numbers'&lt;/span&gt;,
  &lt;span class="s1"&gt;'description'&lt;/span&gt;: &lt;span class="s1"&gt;'This function takes two integers and returns their sum.\n\n    Parameters:\n    a (int): The first integer to add.\n    b (int): The second integer to add.'&lt;/span&gt;,
  &lt;span class="s1"&gt;'parameters'&lt;/span&gt;: &lt;span class="o"&gt;[{&lt;/span&gt;&lt;span class="s1"&gt;'name'&lt;/span&gt;: &lt;span class="s1"&gt;'a'&lt;/span&gt;, &lt;span class="s1"&gt;'type'&lt;/span&gt;: &lt;span class="s1"&gt;'int'&lt;/span&gt;&lt;span class="o"&gt;}&lt;/span&gt;, &lt;span class="o"&gt;{&lt;/span&gt;&lt;span class="s1"&gt;'name'&lt;/span&gt;: &lt;span class="s1"&gt;'b'&lt;/span&gt;, &lt;span class="s1"&gt;'type'&lt;/span&gt;: &lt;span class="s1"&gt;'int'&lt;/span&gt;&lt;span class="o"&gt;}]}&lt;/span&gt;,
 &lt;span class="o"&gt;{&lt;/span&gt;&lt;span class="s1"&gt;'name'&lt;/span&gt;: &lt;span class="s1"&gt;'square_number'&lt;/span&gt;,
  &lt;span class="s1"&gt;'description'&lt;/span&gt;: &lt;span class="s1"&gt;'This function takes an integer and returns its square.\n\n    Parameters:\n    a (int): The integer to be squared.'&lt;/span&gt;,
  &lt;span class="s1"&gt;'parameters'&lt;/span&gt;: &lt;span class="o"&gt;[{&lt;/span&gt;&lt;span class="s1"&gt;'name'&lt;/span&gt;: &lt;span class="s1"&gt;'a'&lt;/span&gt;, &lt;span class="s1"&gt;'type'&lt;/span&gt;: &lt;span class="s1"&gt;'int'&lt;/span&gt;&lt;span class="o"&gt;}]}&lt;/span&gt;,
 &lt;span class="o"&gt;{&lt;/span&gt;&lt;span class="s1"&gt;'name'&lt;/span&gt;: &lt;span class="s1"&gt;'square_root_number'&lt;/span&gt;,
  &lt;span class="s1"&gt;'description'&lt;/span&gt;: &lt;span class="s1"&gt;'This function takes an integer and returns its square root.\n\n    Parameters:\n    a (int): The integer to calculate the square root of.'&lt;/span&gt;,
  &lt;span class="s1"&gt;'parameters'&lt;/span&gt;: &lt;span class="o"&gt;[{&lt;/span&gt;&lt;span class="s1"&gt;'name'&lt;/span&gt;: &lt;span class="s1"&gt;'a'&lt;/span&gt;, &lt;span class="s1"&gt;'type'&lt;/span&gt;: &lt;span class="s1"&gt;'int'&lt;/span&gt;&lt;span class="o"&gt;}]}]&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;h3&gt;
  
  
  5. Create the Agent Class
&lt;/h3&gt;

&lt;p&gt;We then create the agent class that takes the system prompt, the function calling prompt, the tools and the messages as input and returns the response from the agent.&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;code&gt;__call__&lt;/code&gt; is the function that is called when the agent is called with a message. It adds the message to the messages list and returns the response from the agent.&lt;/li&gt;
&lt;li&gt;
&lt;code&gt;execute&lt;/code&gt; is the function that is called to generate the response from the agent. It uses the model to generate the response.&lt;/li&gt;
&lt;li&gt;
&lt;code&gt;function_call&lt;/code&gt; is the function that is called to generate the response from the agent. It uses the function calling model to generate the response.
&lt;/li&gt;
&lt;/ul&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;ast&lt;/span&gt;

&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;Agent&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;system&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nb"&gt;str&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="sh"&gt;""&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;function_calling_prompt&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nb"&gt;str&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="sh"&gt;""&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tools&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;[]&lt;/span&gt;
    &lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;-&amp;gt;&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;system&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;system&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;tools&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tools&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;function_calling_prompt&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;function_calling_prompt&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;messages&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nb"&gt;list&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[]&lt;/span&gt;
        &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;system&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
            &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;messages&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;({&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;role&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;system&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;content&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;system&lt;/span&gt;&lt;span class="p"&gt;})&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__call__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;message&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;""&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;message&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
            &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;messages&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;({&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;role&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;user&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;content&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;message&lt;/span&gt;&lt;span class="p"&gt;})&lt;/span&gt;
        &lt;span class="n"&gt;result&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;execute&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;messages&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;({&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;role&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;assistant&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;content&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;result&lt;/span&gt;&lt;span class="p"&gt;})&lt;/span&gt;
        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;result&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;execute&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;disable_adapter&lt;/span&gt;&lt;span class="p"&gt;():&lt;/span&gt;  &lt;span class="c1"&gt;# disable the adapter for thinking and reasoning
&lt;/span&gt;            &lt;span class="n"&gt;inputs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;apply_chat_template&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
                &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;messages&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
                &lt;span class="n"&gt;tokenize&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
                &lt;span class="n"&gt;add_generation_prompt&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
                &lt;span class="n"&gt;return_tensors&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;pt&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
            &lt;span class="p"&gt;)&lt;/span&gt;
            &lt;span class="n"&gt;output&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;generate&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
                &lt;span class="n"&gt;input_ids&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;inputs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
                &lt;span class="n"&gt;max_new_tokens&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;128&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
                &lt;span class="n"&gt;stopping_criteria&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="nc"&gt;StoppingCriteriaList&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="n"&gt;stop_criteria&lt;/span&gt;&lt;span class="p"&gt;]),&lt;/span&gt;
            &lt;span class="p"&gt;)&lt;/span&gt;
            &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;decode&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
                &lt;span class="n"&gt;output&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;][&lt;/span&gt;&lt;span class="n"&gt;inputs&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="p"&gt;:],&lt;/span&gt; &lt;span class="n"&gt;skip_special_tokens&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;
            &lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;function_call&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;message&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="n"&gt;inputs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;apply_chat_template&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
            &lt;span class="p"&gt;[&lt;/span&gt;
                &lt;span class="p"&gt;{&lt;/span&gt;
                    &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;role&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;user&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
                    &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;content&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;function_calling_prompt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;format&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
                        &lt;span class="n"&gt;tool_descriptions&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;tool_descriptions&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;query&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;message&lt;/span&gt;
                    &lt;span class="p"&gt;),&lt;/span&gt;
                &lt;span class="p"&gt;}&lt;/span&gt;
            &lt;span class="p"&gt;],&lt;/span&gt;
            &lt;span class="n"&gt;tokenize&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
            &lt;span class="n"&gt;add_generation_prompt&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
            &lt;span class="n"&gt;return_tensors&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;pt&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
        &lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;output&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;generate&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;input_ids&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;inputs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_new_tokens&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;128&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;temperature&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;prompt_length&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;inputs&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;

        &lt;span class="n"&gt;answer&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;ast&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;literal_eval&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
            &lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;decode&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;output&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;][&lt;/span&gt;&lt;span class="n"&gt;prompt_length&lt;/span&gt;&lt;span class="p"&gt;:],&lt;/span&gt; &lt;span class="n"&gt;skip_special_tokens&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="p"&gt;)[&lt;/span&gt;
            &lt;span class="mi"&gt;0&lt;/span&gt;
        &lt;span class="p"&gt;]&lt;/span&gt;  &lt;span class="c1"&gt;# get the output of the function call model as a dictionary
&lt;/span&gt;        &lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;answer&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;tool_output&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;run_tool&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;answer&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;name&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="n"&gt;answer&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;arguments&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;tool_output&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;run_tool&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;name&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt;&lt;span class="n"&gt;args&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="n"&gt;kwargs&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;tool&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;tools&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
            &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;tool&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;__name__&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="n"&gt;name&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
                &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="nf"&gt;tool&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;*&lt;/span&gt;&lt;span class="n"&gt;args&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="n"&gt;kwargs&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;h3&gt;
  
  
  6. Define System and Function-Calling Prompts
&lt;/h3&gt;

&lt;p&gt;We now define two key prompts:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;System Prompt&lt;/strong&gt;: The core logic for the agent's reasoning and tool use, following the &lt;strong&gt;ReAct&lt;/strong&gt; pattern.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Function-Calling Prompt&lt;/strong&gt;: This enables function calling by passing the relevant tool descriptions and queries.
&lt;/li&gt;
&lt;/ul&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;system_prompt&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"""&lt;/span&gt;&lt;span class="s"&gt;
You run in a loop of Thought, Action, PAUSE, Observation.
At the end of the loop you output an Answer
Use Thought to describe your thoughts about the question you have been asked.
Use Action to run one of the actions available to you - then return PAUSE.
Observation will be the result of running those actions. Stop when you have the Answer. 
Your available actions are:

&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;tools&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt;

Example session:

Question: What is the mass of Earth times 2?
Thought: I need to find the mass of Earth
Action: get_planet_mass: Earth
PAUSE 

Observation: 5.972e24

Thought: I need to multiply this by 2
Action: calculate: 5.972e24 * 2
PAUSE

Observation: 1,1944×10e25

If you have the answer, output it as the Answer.

Answer: &lt;/span&gt;&lt;span class="se"&gt;\\&lt;/span&gt;&lt;span class="s"&gt;{{1,1944×10e25&lt;/span&gt;&lt;span class="se"&gt;\\&lt;/span&gt;&lt;span class="s"&gt;}}.
PAUSE
Now it&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;s your turn:
&lt;/span&gt;&lt;span class="sh"&gt;"""&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;strip&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

&lt;span class="n"&gt;function_calling_prompt&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="sh"&gt;"""&lt;/span&gt;&lt;span class="s"&gt;
You are a helpful assistant. Below are the tools that you have access to.  &lt;/span&gt;&lt;span class="se"&gt;\n\n&lt;/span&gt;&lt;span class="s"&gt;### Tools: &lt;/span&gt;&lt;span class="se"&gt;\n&lt;/span&gt;&lt;span class="s"&gt;{tool_descriptions} &lt;/span&gt;&lt;span class="se"&gt;\n\n&lt;/span&gt;&lt;span class="s"&gt;### Query: &lt;/span&gt;&lt;span class="se"&gt;\n&lt;/span&gt;&lt;span class="s"&gt;{query} &lt;/span&gt;&lt;span class="se"&gt;\n&lt;/span&gt;&lt;span class="s"&gt;
&lt;/span&gt;&lt;span class="sh"&gt;"""&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;h3&gt;
  
  
  7. Implement the ReAct Loop
&lt;/h3&gt;

&lt;p&gt;Finally, we define the loop that enables the agent to interact with the user, execute the necessary function calls, and return the correct answers.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;re&lt;/span&gt;

&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;loop_agent&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;agent&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;Agent&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;question&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_iterations&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;

    &lt;span class="n"&gt;next_prompt&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;question&lt;/span&gt;
    &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;
    &lt;span class="k"&gt;while&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="o"&gt;&amp;lt;&lt;/span&gt; &lt;span class="n"&gt;max_iterations&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
        &lt;span class="n"&gt;result&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;agent&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;next_prompt&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;result&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Answer:&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;result&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
            &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;result&lt;/span&gt;

        &lt;span class="n"&gt;action&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;re&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;findall&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;r&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Action: (.*)&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;result&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;action&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
            &lt;span class="n"&gt;tool_output&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;agent&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;function_call&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;action&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
            &lt;span class="n"&gt;next_prompt&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Observation: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;tool_output&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;
            &lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;next_prompt&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="k"&gt;else&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
            &lt;span class="n"&gt;next_prompt&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Observation: tool not found&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;
        &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;result&lt;/span&gt;

&lt;span class="n"&gt;agent&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;Agent&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt; &lt;span class="n"&gt;system&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;system_prompt&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;function_calling_prompt&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;function_calling_prompt&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tools&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;tools&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="nf"&gt;loop_agent&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;agent&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;what is the square root of the difference between 32^2 and 54&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;);&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Check out the complete notebook on Colab &lt;a href="https://colab.research.google.com/drive/1rpJZ-hSGG0E83xUNE5oAU2R0UmANn7Tb" rel="noopener noreferrer"&gt;here&lt;/a&gt;.&lt;/p&gt;

&lt;h3&gt;
  
  
  Conclusion
&lt;/h3&gt;

&lt;p&gt;By following this step-by-step guide, you can create a function-calling agent using a custom model trained with &lt;strong&gt;Unsloth&lt;/strong&gt; and &lt;strong&gt;LoRA adapters&lt;/strong&gt;. This approach ensures efficient memory use while maintaining robust reasoning and function execution capabilities.&lt;/p&gt;

&lt;p&gt;Explore further by extending this method to larger models or customizing the functions available to the agent.&lt;/p&gt;

</description>
      <category>llm</category>
      <category>ai</category>
      <category>python</category>
      <category>machinelearning</category>
    </item>
    <item>
      <title>Vector Streaming with EmbedAnything</title>
      <dc:creator>Akshay Ballal</dc:creator>
      <pubDate>Thu, 12 Sep 2024 19:05:46 +0000</pubDate>
      <link>https://dev.to/akshayballal/vector-streaming-with-embedanything-4hnb</link>
      <guid>https://dev.to/akshayballal/vector-streaming-with-embedanything-4hnb</guid>
      <description>&lt;p&gt;In my previous article &lt;a href="https://dev.to/akshayballal/introducing-embedanything-950"&gt;Introducing EmbedAnything&lt;/a&gt;, I discussed the idea behind EmbedAnything and how it makes creating embeddings from multiple modalities easy. In this article, I want to introduce a new feature of EmbedAnything called vector streaming and see how it works with Weaviate Vector Database. &lt;/p&gt;

&lt;h3&gt;
  
  
  What is the problem?
&lt;/h3&gt;

&lt;p&gt;First, let's examine the current problem with creating embeddings, especially in large-scale documents. The current embedding frameworks operate on a two-step process: chunking and embedding. First, the text is extracted from all the files, and chunks/nodes are created. Then, these chunks are fed to an embedding model with a specific batch size to process the embeddings. While this is done, the chunks and the embeddings stay on the system memory. This is not a problem when the files are small, and the embedding dimensions are small. But this becomes a problem when there are many files and you are working with large models and, even worse, multi-vector embeddings. Thus, to work with this, a high RAM is required to process the embeddings. Also, if this is done synchronously, a lot of time is wasted while the chunks are being created, as chunking is not a compute-heavy operation. As the chunks are being made, passing them to the embedding model would be efficient. &lt;/p&gt;

&lt;h3&gt;
  
  
  Our Solution
&lt;/h3&gt;

&lt;p&gt;The solution is to create an asynchronous chunking and embedding task. We can effectively spawn threads to handle this task using Rust's concurrency patterns and thread safety. This is done using Rust's MPSC (Multi-producer Single Consumer) module, which passes messages between threads. Thus, this creates a stream of chunks passed into the embedding thread with a buffer. Once the buffer is complete, it embeds the chunks and sends the embeddings back to the main thread, where they are sent to the vector database. This ensures no time is wasted on a single operation and no bottlenecks. Moreover, only the chunks and embeddings in the buffer are stored in the system memory. They are erased from the memory once moved to the vector database. &lt;/p&gt;

&lt;p&gt;&lt;a href="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fiy2dss0x8gzotwwyqbap.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fiy2dss0x8gzotwwyqbap.png" alt="streaming"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;h3&gt;
  
  
  Example Use Case
&lt;/h3&gt;

&lt;p&gt;Now, let's see this feature in action. &lt;/p&gt;

&lt;p&gt;With EmbedAnything, streaming the vectors from a directory of files to the vector database is a simple three-step process. &lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;&lt;p&gt;&lt;strong&gt;Create an adapter for your vector database:&lt;/strong&gt; This is a wrapper around the database's functions that allows you to create an index, convert metadata from EmbedAnything's format to the format required by the database, and the function to insert the embeddings in the index. Adapters for the prominent databases are already created and present &lt;a href="https://github.com/StarlightSearch/EmbedAnything/tree/main/examples/adapters" rel="noopener noreferrer"&gt;here&lt;/a&gt;: &lt;/p&gt;&lt;/li&gt;
&lt;li&gt;&lt;p&gt;&lt;strong&gt;Initiate an embedding model of your choice:&lt;/strong&gt; You can choose from different local models or even cloud models. The configuration can also be determined to set the chunk size and buffer size for how many embeddings need to be streamed at once. Ideally, this should be as high as possible, but the system RAM limits this. &lt;/p&gt;&lt;/li&gt;
&lt;li&gt;&lt;p&gt;&lt;strong&gt;Call the embedding function from EmbedAnything:&lt;/strong&gt; Just pass the directory path to be embedded, the embedding model, the adapter, and the configuration. &lt;/p&gt;&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;In this example, we will embed a directory of images and send it to the vector databases. &lt;/p&gt;

&lt;h4&gt;
  
  
  Step 1: Create the Adapter
&lt;/h4&gt;

&lt;p&gt;In EmbedAnything, the adapters are created outside so as to not make the library heavy and you get to choose which database you want to work with. Here is a simple adapter for Weaviate.&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;

&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;embed_anything&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;EmbedData&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;embed_anything.vectordb&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;Adapter&lt;/span&gt;

&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;WeaviateAdapter&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Adapter&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;api_key&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;url&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="nf"&gt;super&lt;/span&gt;&lt;span class="p"&gt;().&lt;/span&gt;&lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;api_key&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;client&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;weaviate&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;connect_to_weaviate_cloud&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
            &lt;span class="n"&gt;cluster_url&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;url&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;auth_credentials&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;wvc&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;init&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Auth&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;api_key&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;api_key&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;client&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;is_ready&lt;/span&gt;&lt;span class="p"&gt;():&lt;/span&gt;
            &lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Weaviate is ready&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;create_index&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;index_name&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nb"&gt;str&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;index_name&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;index_name&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;collection&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;client&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;collections&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;create&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
            &lt;span class="n"&gt;index_name&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;vectorizer_config&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;wvc&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;config&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Configure&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Vectorizer&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;none&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
        &lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;collection&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;convert&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;embeddings&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;List&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;EmbedData&lt;/span&gt;&lt;span class="p"&gt;]):&lt;/span&gt;
        &lt;span class="n"&gt;data&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[]&lt;/span&gt;
        &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;embedding&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;embeddings&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
            &lt;span class="nb"&gt;property&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;embedding&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;metadata&lt;/span&gt;
            &lt;span class="nb"&gt;property&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;text&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;embedding&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;text&lt;/span&gt;
            &lt;span class="n"&gt;data&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
                &lt;span class="n"&gt;wvc&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;data&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;DataObject&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;properties&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="nb"&gt;property&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;vector&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;embedding&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;embedding&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
            &lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;data&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;upsert&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;embeddings&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="n"&gt;data&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;convert&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;embeddings&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;client&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;collections&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;get&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;index_name&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="n"&gt;data&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;insert_many&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;data&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;delete_index&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;index_name&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nb"&gt;str&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;client&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;collections&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;delete&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;index_name&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;### Start the client and index
&lt;/span&gt;
&lt;span class="n"&gt;URL&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;your-weaviate-url&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;
&lt;span class="n"&gt;API_KEY&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;your-weaviate-api-key&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;
&lt;span class="n"&gt;weaviate_adapter&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;WeaviateAdapter&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;API_KEY&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;URL&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="n"&gt;index_name&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Test_index&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;
&lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;index_name&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;weaviate_adapter&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;client&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;collections&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;list_all&lt;/span&gt;&lt;span class="p"&gt;():&lt;/span&gt;
    &lt;span class="n"&gt;weaviate_adapter&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;delete_index&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;index_name&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;weaviate_adapter&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;create_index&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Test_index&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;


&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
&lt;h4&gt;
  
  
  Step 2: Create the Embedding Model
&lt;/h4&gt;

&lt;p&gt;Here, since we are embedding images, we can use the clip model &lt;/p&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;

&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;embed_anything&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;WhichModel&lt;/span&gt;

&lt;span class="n"&gt;model&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;embed_anything&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;EmbeddingModel&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;from_pretrained_cloud&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
    &lt;span class="n"&gt;embed_anything&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;WhichModel&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Clip&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;model_id&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;""&lt;/span&gt;
&lt;span class="p"&gt;)&lt;/span&gt;



&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
&lt;h3&gt;
  
  
  Step 3
&lt;/h3&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;


&lt;span class="n"&gt;data&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;embed_anything&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;embed_image_directory&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
    &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;\image_directory&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;embeder&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;adapter&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;weaviate_adapter&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;config&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;embed_anything&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;ImageEmbedConfig&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;buffer_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;100&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;span class="p"&gt;)&lt;/span&gt;



&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
&lt;h3&gt;
  
  
  Step 4
&lt;/h3&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;

&lt;span class="n"&gt;query_vector&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;embed_anything&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;embed_query&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;image of a cat&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;embeder&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;)[&lt;/span&gt;
    &lt;span class="mi"&gt;0&lt;/span&gt;
&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="n"&gt;embedding&lt;/span&gt;


&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
&lt;h3&gt;
  
  
  Step 5
&lt;/h3&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;

&lt;span class="n"&gt;response&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;weaviate_adapter&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;collection&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;query&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;near_vector&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
    &lt;span class="n"&gt;near_vector&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;query_vector&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;limit&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;return_metadata&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;wvc&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;query&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;MetadataQuery&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;certainty&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;span class="p"&gt;)&lt;/span&gt;


&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;

&lt;p&gt;Check the response;&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F15fuvpuk26048oraonjd.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F15fuvpuk26048oraonjd.png" alt="Output"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;You can check out the notebook here in Colab&lt;/p&gt;

&lt;p&gt;&lt;a href="https://colab.research.google.com/drive/17vUZEh-ZSpN339pIXSkyxtDHS5Sz6DqD?usp=sharing" rel="noopener noreferrer"&gt;&lt;img src="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fcolab.research.google.com%2Fassets%2Fcolab-badge.svg" alt="Open in Colab"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;We think it’s one of the features that will empower many engineers to opt for a more optimized and no-tech debt solution. Instead of using bulky frameworks on the cloud, you can use a lightweight streaming option. Please don't forget to give us a ⭐ on our GitHub repo over here: &lt;a href="https://github.com/StarlightSearch/EmbedAnything" rel="noopener noreferrer"&gt;https://github.com/StarlightSearch/EmbedAnything&lt;/a&gt;&lt;/p&gt;

</description>
      <category>ai</category>
      <category>rust</category>
      <category>rag</category>
      <category>vectordatabase</category>
    </item>
    <item>
      <title>Reinforcement Learning from Scratch - Part 3 - REINFORCE Algorithm</title>
      <dc:creator>Akshay Ballal</dc:creator>
      <pubDate>Tue, 06 Aug 2024 14:01:08 +0000</pubDate>
      <link>https://dev.to/akshayballal/reinforcement-learning-from-scratch-part-3-reinforce-algorithm-4724</link>
      <guid>https://dev.to/akshayballal/reinforcement-learning-from-scratch-part-3-reinforce-algorithm-4724</guid>
      <description>&lt;p&gt;Hey all! Welcome to the third part of the Reinforcement Learning Series, where I explain the different RL algorithms and introduce you to some tips and tricks that RL practitioners use to make them work. So far, we have looked at the Tabular Q Learning Method and the Deep Q Learning Method. If you haven’t checked these out yet, look at the links below for motivation as to why we are still investigating more algorithms. &lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;&lt;a href="https://dev.to/akshayballal/reinforcement-learning-from-scratch-part-1-tabular-q-learning-22k3"&gt;Tabular Q Learning&lt;/a&gt;&lt;/li&gt;
&lt;li&gt;&lt;a href="https://dev.to/akshayballal/reinforcement-learning-from-scratch-part-2-deep-q-learning-4opo"&gt;Deep Q Learning&lt;/a&gt;&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;The main problem with these algorithms was that they did not allow us to use continuous action space environments, and we always had to discretize the action space, which may not always be desirable. So, in this part, we will explore a new RL variant that is different from the Q-learning variant we have looked at until now. These methods are called Policy Gradient methods.  Let’s look at some details about Policy Learning or Policy Gradient Methods.&lt;/p&gt;

&lt;p&gt;With Q-learning, our neural network model learns the Q function at each state, from which we derive the policy. The policy was simple: take the epsilon greedy action over the Q values or select the action with the largest Q value. But do we actually need to learn the Q-values when we want the policy? The answer is a resounding Yes. We can learn the policy directly, and the REINFORCE algorithm we are looking at in this part is a straightforward algorithm that does just that. While it may not be the most suitable algorithm for all environments, it serves as an excellent introduction to policy learning or policy gradient algorithms.  &lt;/p&gt;

&lt;h3&gt;
  
  
  Benefits of Policy Gradient Methods
&lt;/h3&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;Stochastic Policies:&lt;/strong&gt; Policy gradient methods enable learning state-dependent stochastic policies. Unlike Q-learning, where exploration is introduced through the epsilon-greedy method, policy gradient methods allow for state-specific stochastic behavior. This approach is particularly valuable for partially observable states, where actions need to be sampled from a probability distribution rather than determined deterministically. Importantly, this method doesn't preclude deterministic policies for certain states. The result is a flexible mix of stochastic and deterministic actions, tailored to the confidence level in each state.&lt;/li&gt;
&lt;li&gt;This is a more straightforward objective for many environments. Learning the value function can be more challenging and may also not be required. The quote below sums it up.&lt;/li&gt;
&lt;/ul&gt;

&lt;blockquote&gt;
&lt;p&gt;When solving a problem of interest, do not solve a more general problem as an intermediate step&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;&lt;em&gt;Vladimir Vapnik&lt;/em&gt;&lt;/li&gt;
&lt;/ul&gt;
&lt;/blockquote&gt;

&lt;p&gt;This means that if our objective is to get the policy and we are not interested in the values, we don’t need to solve for the values to get to the policy. &lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;Policy Gradient methods also show better convergence properties as the action probabilities change smoothly. Value-based techniques are prone to oscillations because the policy can change drastically throughout training.&lt;/li&gt;
&lt;/ul&gt;

&lt;h3&gt;
  
  
  Deriving Policy Gradient
&lt;/h3&gt;

&lt;p&gt;Let's explore how we can derive the policy. This theory could be a little dense if you are new to the concept. But trust me, if you can understand the idea of policy gradient, many algorithms like A2C and PPO become pretty easy to understand. So it’s OK to spend some time here. You can check out the references at the end of this article to some sources that elaborate on this topic.  First, we need an objective to optimize. In Q-learning, we aimed to minimize the loss between predicted and target values. Specifically, our goal was to match the actual action-value function of a given policy. We parameterized a value function and minimized the mean squared error between predicted and target values. Note that we didn't have true target values; instead, we used actual returns in Monte Carlo methods or predicted returns in bootstrapping methods.&lt;/p&gt;

&lt;p&gt;In policy-based methods, however, the objective is to maximize the performance of a parameterized policy. We're running gradient ascent (or regular gradient descent on the negative performance). An agent’s performance is the expected total discounted reward from the initial state—equivalent to the expected state-value function from all initial states of a given policy.&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F03kfuqjyxcpiec9zregx.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F03kfuqjyxcpiec9zregx.png" alt="policy loss"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Let’s understand these equations step by step.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight markdown"&gt;&lt;code&gt;&lt;span class="p"&gt;1.&lt;/span&gt; The objective function L(θ) is defined as the expected value of the initial state s_0
&lt;span class="p"&gt;2.&lt;/span&gt; The expected value of the initial state is nothing but the expected total return from the trajectory.
&lt;span class="p"&gt;3.&lt;/span&gt; The expectation can be shown as the weighted sum of the expected total return across all the different possible trajectories weighted by the probability of a trajectory given the weights θ of the policy network.
&lt;span class="p"&gt;4.&lt;/span&gt; We take the gradient of the objective as it is required for the gradient ascent. 
&lt;span class="p"&gt;5.&lt;/span&gt; We then use the gradient of a log trick to refactor the gradient of the probability of the trajectory. 
&lt;span class="p"&gt;6.&lt;/span&gt; Finally, we can pack the whole thing back into an expectation.
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;


&lt;p&gt;&lt;a href="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F3tv7vqeoa4guw82x6faq.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F3tv7vqeoa4guw82x6faq.png" alt="likelihood"&gt;&lt;/a&gt;&lt;br&gt;
&lt;/p&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight markdown"&gt;&lt;code&gt;&lt;span class="p"&gt;7.&lt;/span&gt; We can see that the probability of a trajectory is the product of the likelihood of the state transition and the probability of taking the action (π).
&lt;span class="p"&gt;8.&lt;/span&gt; We can compactly write the above part.
&lt;span class="p"&gt;9.&lt;/span&gt; We take the log and the gradient. The log operations make the product a sum. 
&lt;span class="p"&gt;10.&lt;/span&gt; Only the last term depends on theta as π depends on theta.
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;


&lt;p&gt;&lt;a href="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fomunprkblg62s409evmm.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fomunprkblg62s409evmm.png" alt="policy loss"&gt;&lt;/a&gt;&lt;br&gt;
&lt;/p&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight markdown"&gt;&lt;code&gt;&lt;span class="p"&gt;11.&lt;/span&gt; Plugging 10 into 6, we get the final (almost) equation of the policy loss / objective gradient. We usually calculate the value without the gradient and then use the PyTorch backward function to get the gradient.
&lt;span class="p"&gt;12.&lt;/span&gt; One final modification we make is that instead of using the return of the complete trajectory we calculate the return from the time step t. This is because the action at time step t cannot influence the previous rewards and can only influence the future rewards.
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;


&lt;p&gt;This is the so called Vanilla Reinforcement Learning equation. But as you will see later in the results of this algorithm, there are some issues with this vanilla approach. Since the return is for the whole trajectory, there can be a high variance in the return value across different trajectories. Thus, to get a reasonable estimate, we need to sample a lot of trajectories. We need to reduce this variance and introduce some bias. This is the typical bias variance trade-off we see everywhere in machine learning. One of the simplest ways to do this is by introducing the Value Function estimate and a new term called Advantage. We define advantage as follows. &lt;/p&gt;

&lt;p&gt;

&lt;/p&gt;
&lt;div class="katex-element"&gt;
  &lt;span class="katex-display"&gt;&lt;span class="katex"&gt;&lt;span class="katex-mathml"&gt;At(s)=Gt(τ)−Vt(s)
A_t(s) = G_t(\tau) - V_t(s)
&lt;/span&gt;&lt;span class="katex-html"&gt;&lt;span class="base"&gt;&lt;span class="strut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;A&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;t&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mopen"&gt;(&lt;/span&gt;&lt;span class="mord mathnormal"&gt;s&lt;/span&gt;&lt;span class="mclose"&gt;)&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mrel"&gt;=&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="base"&gt;&lt;span class="strut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;G&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;t&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mopen"&gt;(&lt;/span&gt;&lt;span class="mord mathnormal"&gt;τ&lt;/span&gt;&lt;span class="mclose"&gt;)&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mbin"&gt;−&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="base"&gt;&lt;span class="strut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;V&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;t&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mopen"&gt;(&lt;/span&gt;&lt;span class="mord mathnormal"&gt;s&lt;/span&gt;&lt;span class="mclose"&gt;)&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;
&lt;/div&gt;



&lt;p&gt;This says how good the actual return is compared to the estimated return for that state. If the advantage is positive, that means the action that the agent took at time step t was good compared to the other actions, and if the advantage is negative, that means the action was bad and made the agent perform worse subsequently. &lt;/p&gt;

&lt;p&gt;We then modify equation 12 as follows:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fasijexuownniuiti7n44.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fasijexuownniuiti7n44.png" alt="final policy loss"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;This is called REINFORCE with Baseline or Value Policy Gradient (VPG) algorithm. What this equation does is that if the advantage is positive for a given action and its log probability given by log⁡ 𝜋(𝐴𝑡∣𝑆𝑡)﻿, likelihood of taking this action. If the advantage is negative, it decreases the probability of taking that action. &lt;/p&gt;

&lt;p&gt;In addition to the policy loss, in the policy gradient methods, we also have the entropy loss. Entropy defines how much randomness is there in the system. More entropy corresponds to more randomness. It is good to maintain some entropy in the agent so that the agent does not become completely deterministic.&lt;/p&gt;


&lt;div class="katex-element"&gt;
  &lt;span class="katex-display"&gt;&lt;span class="katex"&gt;&lt;span class="katex-mathml"&gt;Lent(θ)=∑aπθ(a∣s)θlog⁡πθ(a∣s)
\begin{equation}
L_{ent}(\theta) = \sum_a \pi_\theta(a|s)\theta \log\pi\theta(a|s)
\end{equation}
&lt;/span&gt;&lt;span class="katex-html"&gt;&lt;span class="base"&gt;&lt;span class="strut"&gt;&lt;/span&gt;&lt;span class="mtable"&gt;&lt;span class="col-align-c"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;L&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;e&lt;/span&gt;&lt;span class="mord mathnormal mtight"&gt;n&lt;/span&gt;&lt;span class="mord mathnormal mtight"&gt;t&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mopen"&gt;(&lt;/span&gt;&lt;span class="mord mathnormal"&gt;θ&lt;/span&gt;&lt;span class="mclose"&gt;)&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mrel"&gt;=&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mop op-limits"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;a&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="mop op-symbol large-op"&gt;∑&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;π&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;θ&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mopen"&gt;(&lt;/span&gt;&lt;span class="mord mathnormal"&gt;a&lt;/span&gt;&lt;span class="mord"&gt;∣&lt;/span&gt;&lt;span class="mord mathnormal"&gt;s&lt;/span&gt;&lt;span class="mclose"&gt;)&lt;/span&gt;&lt;span class="mord mathnormal"&gt;θ&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mop"&gt;lo&lt;span&gt;g&lt;/span&gt;&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mord mathnormal"&gt;π&lt;/span&gt;&lt;span class="mord mathnormal"&gt;θ&lt;/span&gt;&lt;span class="mopen"&gt;(&lt;/span&gt;&lt;span class="mord mathnormal"&gt;a&lt;/span&gt;&lt;span class="mord"&gt;∣&lt;/span&gt;&lt;span class="mord mathnormal"&gt;s&lt;/span&gt;&lt;span class="mclose"&gt;)&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="tag"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="eqn-num"&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;
&lt;/div&gt;


&lt;p&gt;Finally the total loss is given as:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Ftq809ypcoeis21hz06oi.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Ftq809ypcoeis21hz06oi.png" alt="total loss"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The value loss is given by the mean square error between the actual returns and the estimated value from the value network. The parameters of the value network are given by η.&lt;/p&gt;


&lt;div class="katex-element"&gt;
  &lt;span class="katex-display"&gt;&lt;span class="katex"&gt;&lt;span class="katex-mathml"&gt;Lvalue(η)=1T(∑t=0T−1(Gt(τ)−V(st))2)
L_{value}(\eta) = \frac{1}{T} \left(\sum_{t=0}^{T-1} (G_t(\tau) - V(s_t))^2 \right)
&lt;/span&gt;&lt;span class="katex-html"&gt;&lt;span class="base"&gt;&lt;span class="strut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;L&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;v&lt;/span&gt;&lt;span class="mord mathnormal mtight"&gt;a&lt;/span&gt;&lt;span class="mord mathnormal mtight"&gt;l&lt;/span&gt;&lt;span class="mord mathnormal mtight"&gt;u&lt;/span&gt;&lt;span class="mord mathnormal mtight"&gt;e&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mopen"&gt;(&lt;/span&gt;&lt;span class="mord mathnormal"&gt;η&lt;/span&gt;&lt;span class="mclose"&gt;)&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mrel"&gt;=&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="base"&gt;&lt;span class="strut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mopen nulldelimiter"&gt;&lt;/span&gt;&lt;span class="mfrac"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;T&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="frac-line"&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord"&gt;1&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mclose nulldelimiter"&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="minner"&gt;&lt;span class="mopen delimcenter"&gt;&lt;span class="delimsizing size4"&gt;(&lt;/span&gt;&lt;/span&gt;&lt;span class="mop op-limits"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;t&lt;/span&gt;&lt;span class="mrel mtight"&gt;=&lt;/span&gt;&lt;span class="mord mtight"&gt;0&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="mop op-symbol large-op"&gt;∑&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;T&lt;/span&gt;&lt;span class="mbin mtight"&gt;−&lt;/span&gt;&lt;span class="mord mtight"&gt;1&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mopen"&gt;(&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;G&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;t&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mopen"&gt;(&lt;/span&gt;&lt;span class="mord mathnormal"&gt;τ&lt;/span&gt;&lt;span class="mclose"&gt;)&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mbin"&gt;−&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mord mathnormal"&gt;V&lt;/span&gt;&lt;span class="mopen"&gt;(&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;s&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;t&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mclose"&gt;)&lt;/span&gt;&lt;span class="mclose"&gt;&lt;span class="mclose"&gt;)&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;2&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mclose delimcenter"&gt;&lt;span class="delimsizing size4"&gt;)&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;
&lt;/div&gt;


&lt;p&gt;The value loss is given by the mean square error between the actual returns and the estimated value from the value network. The parameters of the value network are given by $\eta$.&lt;/p&gt;


&lt;div class="katex-element"&gt;
  &lt;span class="katex-display"&gt;&lt;span class="katex"&gt;&lt;span class="katex-mathml"&gt;Lvalue(η)=1T(∑t=0T−1(Gt(τ)−V(st))2)
L_{value}(\eta) = \frac{1}{T} \left(\sum_{t=0}^{T-1} (G_t(\tau) - V(s_t))^2 \right)
&lt;/span&gt;&lt;span class="katex-html"&gt;&lt;span class="base"&gt;&lt;span class="strut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;L&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;v&lt;/span&gt;&lt;span class="mord mathnormal mtight"&gt;a&lt;/span&gt;&lt;span class="mord mathnormal mtight"&gt;l&lt;/span&gt;&lt;span class="mord mathnormal mtight"&gt;u&lt;/span&gt;&lt;span class="mord mathnormal mtight"&gt;e&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mopen"&gt;(&lt;/span&gt;&lt;span class="mord mathnormal"&gt;η&lt;/span&gt;&lt;span class="mclose"&gt;)&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mrel"&gt;=&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="base"&gt;&lt;span class="strut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mopen nulldelimiter"&gt;&lt;/span&gt;&lt;span class="mfrac"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;T&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="frac-line"&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord"&gt;1&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mclose nulldelimiter"&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="minner"&gt;&lt;span class="mopen delimcenter"&gt;&lt;span class="delimsizing size4"&gt;(&lt;/span&gt;&lt;/span&gt;&lt;span class="mop op-limits"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;t&lt;/span&gt;&lt;span class="mrel mtight"&gt;=&lt;/span&gt;&lt;span class="mord mtight"&gt;0&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="mop op-symbol large-op"&gt;∑&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;T&lt;/span&gt;&lt;span class="mbin mtight"&gt;−&lt;/span&gt;&lt;span class="mord mtight"&gt;1&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mopen"&gt;(&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;G&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;t&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mopen"&gt;(&lt;/span&gt;&lt;span class="mord mathnormal"&gt;τ&lt;/span&gt;&lt;span class="mclose"&gt;)&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mbin"&gt;−&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mord mathnormal"&gt;V&lt;/span&gt;&lt;span class="mopen"&gt;(&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;s&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;t&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mclose"&gt;)&lt;/span&gt;&lt;span class="mclose"&gt;&lt;span class="mclose"&gt;)&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;2&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mclose delimcenter"&gt;&lt;span class="delimsizing size4"&gt;)&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;
&lt;/div&gt;


&lt;p&gt;That’s it. To use this equation, all we need to do is the following:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;Create two neural networks: one for the policy and another for the value function. &lt;/li&gt;
&lt;li&gt;Use the network to get a trajectory by sampling and performing the actions until the termination state. &lt;/li&gt;
&lt;li&gt;We also store the actions' log probability and each state's value estimate.&lt;/li&gt;
&lt;li&gt;Calculate the return $G_t$ for the trajectory at each time step. &lt;/li&gt;
&lt;li&gt;Using the log probability, $G_t$ , and Values, we use equation 13 by calculating the advantage to get the policy objective. &lt;/li&gt;
&lt;li&gt;We can take the negative of this to get the loss that can be minimized because PyTorch optimizers support minimization by default.
&lt;/li&gt;
&lt;li&gt;The value loss is calculated. &lt;/li&gt;
&lt;li&gt;We can also calculate the entropy of the policy using 13. &lt;/li&gt;
&lt;/ol&gt;

&lt;h3&gt;
  
  
  Disadvantages of REINFORCE Algorithm
&lt;/h3&gt;

&lt;p&gt;Even though REINFORCE algorithm is a simple algorithm to get started with Policy Gradient methods there is a disadvantage. As this algorithm relies on complete episodes for the learning we need to have a terminal state so that an episode can be terminated. That’s why we cannot use our Pendulum environment here because it does not have a terminal state. We are going to use the Cartpole environment as that gives us a termination condition. But don’t worry, in the next part we will look into how to solve this problem. But if you understand this algorithm rest of the algorithms will only be minor adjustments. &lt;/p&gt;

&lt;p&gt;Also the learning is slow because even with baseline the variance is high. &lt;/p&gt;

&lt;p&gt;Enough with the theory. Now let’s look at the code implementation first without the Value Estimate or Baseline.&lt;/p&gt;

&lt;h3&gt;
  
  
  Code Implementation: REINFORCE w/o Baseline
&lt;/h3&gt;

&lt;p&gt;You can check out the GitHub Repo here: &lt;br&gt;
&lt;a href="https://github.com/akshayballal95/rl-blog-series/tree/reinforce" rel="noopener noreferrer"&gt;https://github.com/akshayballal95/rl-blog-series/tree/reinforce&lt;/a&gt;&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;Import the dependencies:
&lt;/li&gt;
&lt;/ul&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;gymnasium&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;gym&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;numpy&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;tqdm&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;tqdm&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;torch.optim&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;AdamW&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;torch.utils.tensorboard&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;SummaryWriter&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;torch.nn.functional&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;F&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;


&lt;ul&gt;
&lt;li&gt;Create the Policy Network&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;We create a simple network with the number of features as the input dimension and the number of available actions as the output. &lt;code&gt;torch.distributions.Categorical&lt;/code&gt; function helps us to create a distribution from the output logits. Its very similar to using SoftMax but we get some handy functions like &lt;code&gt;sample()&lt;/code&gt; , &lt;code&gt;entropy()&lt;/code&gt; and &lt;code&gt;log_prob()&lt;/code&gt; .&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;PolicyNet&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;nvec_s&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nb"&gt;int&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;nvec_u&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nb"&gt;int&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="nf"&gt;super&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;PolicyNet&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;fc1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nvec_s&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;128&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;fc2&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;128&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;64&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;fc3&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;64&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;nvec_u&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;F&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;relu&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;fc1&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
        &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;F&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;relu&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;fc2&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
        &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;fc3&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;dist&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;distributions&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Categorical&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;action&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;dist&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sample&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
        &lt;span class="n"&gt;entropy&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;dist&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;entropy&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
        &lt;span class="n"&gt;log_prob&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;dist&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;log_prob&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;action&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;action&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;log_prob&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;entropy&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;ul&gt;
&lt;li&gt;The REINFORCE Agent Class&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;In our REINFORCE class we have the &lt;code&gt;__init__&lt;/code&gt;  function that initiates the attributes that we need. The main hyperparameters here are &lt;code&gt;gamma&lt;/code&gt; , learning rate, and the number of steps. You can see that we no longer have epsilon because the exploration is intrinsically baked into the algorithm. We are using the Adam optimizer as usual.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;Reinforce&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="n"&gt;gym&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Env&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;lr&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;gamma&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_steps&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;

        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;env&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;lr&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;lr&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;gamma&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;gamma&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;n_steps&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;n_steps&lt;/span&gt;

        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;device&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;cuda&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cuda&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;is_available&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="k"&gt;else&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;cpu&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;policy_net&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;PolicyNet&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;observation_space&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;action_space&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;n&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;optimizer_policy&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;AdamW&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;policy_net&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;parameters&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt; &lt;span class="n"&gt;lr&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;lr&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;total_steps&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;

        &lt;span class="c1"&gt;# stats
&lt;/span&gt;        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;episodes&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;total_rewards&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;mean_episode_reward&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Next is the rollout function. Here rollout is nothing but one episode with a fixed policy. We get the action, &lt;code&gt;log_prob&lt;/code&gt; and entropy using the policy network. After taking a step in the environment using the action we store the rewards, &lt;code&gt;log_probs&lt;/code&gt; and the entropies to a list. We also write the current mean episodic reward to TensorBoard after 100 episodes.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;rollout&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;

        &lt;span class="n"&gt;state&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;info&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;reset&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
        &lt;span class="n"&gt;terminated&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;False&lt;/span&gt;
        &lt;span class="n"&gt;truncated&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;False&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;log_probs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[]&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;rewards&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[]&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;entropies&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[]&lt;/span&gt;

        &lt;span class="k"&gt;while&lt;/span&gt; &lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;

            &lt;span class="n"&gt;action&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;log_prob&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;entropy&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;policy_net&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;from_numpy&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;state&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;float&lt;/span&gt;&lt;span class="p"&gt;().&lt;/span&gt;&lt;span class="nf"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
            &lt;span class="n"&gt;next_state&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;reward&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;terminated&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;truncated&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;step&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;action&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;item&lt;/span&gt;&lt;span class="p"&gt;())&lt;/span&gt;

            &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;rewards&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;reward&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
            &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;log_probs&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;log_prob&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
            &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;entropies&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;entropy&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

            &lt;span class="n"&gt;state&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;next_state&lt;/span&gt;

            &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;total_rewards&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="n"&gt;reward&lt;/span&gt;
            &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;total_steps&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;
            &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;pbar&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;update&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
            &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;terminated&lt;/span&gt; &lt;span class="ow"&gt;or&lt;/span&gt; &lt;span class="n"&gt;truncated&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
                &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;episodes&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;

                &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;episodes&lt;/span&gt; &lt;span class="o"&gt;%&lt;/span&gt; &lt;span class="mi"&gt;100&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;

                    &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;mean_episode_reward&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;total_rewards&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;episodes&lt;/span&gt;
                    &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;pbar&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_description&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Reward: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;mean_episode_reward&lt;/span&gt; &lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
                    &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;writer&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;add_scalar&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Reward&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;mean_episode_reward&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;total_steps&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
                    &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;episodes&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;
                    &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;total_rewards&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;

                &lt;span class="k"&gt;break&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;This handy method allows us to calculate the returns at every time step of the trajectory. Notice that we do it in a reverse way but we could as easily implement in the forward way. I just find this implementation more neat.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;calculate_returns&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;    

        &lt;span class="n"&gt;next_returns&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;
        &lt;span class="n"&gt;returns&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;zeros_like&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;rewards&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dtype&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;float32&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;reversed&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;rewards&lt;/span&gt;&lt;span class="p"&gt;))):&lt;/span&gt;
            &lt;span class="n"&gt;next_returns&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;rewards&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;gamma&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;next_returns&lt;/span&gt;
            &lt;span class="n"&gt;returns&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;next_returns&lt;/span&gt;   

        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;returns&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dtype&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;float32&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Now we can write our &lt;code&gt;learn()&lt;/code&gt; function. Here we just retrieve everything that is required to be plugged in equation 15. We weight the entropy loss by 0.001 to keep some entropy for exploration.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;learn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;

        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;log_probs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;stack&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;log_probs&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;entropies&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;stack&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;entropies&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; 

        &lt;span class="n"&gt;returns&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;calculate_returns&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

        &lt;span class="n"&gt;advantages&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;returns&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;squeeze&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; 

        &lt;span class="n"&gt;policy_loss&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;advantages&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;detach&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;log_probs&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

        &lt;span class="n"&gt;entropy_loss&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;entropies&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;policy_loss&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;policy_loss&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="mf"&gt;0.001&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;entropy_loss&lt;/span&gt;

        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;optimizer_policy&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;zero_grad&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
        &lt;span class="n"&gt;policy_loss&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;backward&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
        &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;utils&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;clip_grad_norm_&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;policy_net&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;parameters&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;optimizer_policy&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;step&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Finally we have the &lt;code&gt;train&lt;/code&gt; function which just calls the rollout and the learn function until the number of steps is completed.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;train&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;writer&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;SummaryWriter&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;log_dir&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;runs/reinforce_logs/REINFORCE_NO_BASELINE&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;pbar&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;tqdm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;total&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;n_steps&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;position&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;leave&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

        &lt;span class="k"&gt;while&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;total_steps&lt;/span&gt; &lt;span class="o"&gt;&amp;lt;&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;n_steps&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;

            &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;rollout&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
            &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;learn&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;h3&gt;
  
  
  Results
&lt;/h3&gt;

&lt;p&gt;We can run this agent as we did in the earlier parts:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;env&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;gym&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;make&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;CartPole-v1&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
               &lt;span class="n"&gt;render_mode&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;human&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;
               &lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="n"&gt;n_episodes&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;100&lt;/span&gt;
&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n_episodes&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;info&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;reset&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
    &lt;span class="n"&gt;terminated&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;False&lt;/span&gt;
    &lt;span class="n"&gt;truncated&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;False&lt;/span&gt;
    &lt;span class="k"&gt;while&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="n"&gt;terminated&lt;/span&gt; &lt;span class="ow"&gt;and&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="n"&gt;truncated&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
        &lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;no_grad&lt;/span&gt;&lt;span class="p"&gt;():&lt;/span&gt;
            &lt;span class="n"&gt;action&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;agent&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;policy_net&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;from_numpy&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;float&lt;/span&gt;&lt;span class="p"&gt;().&lt;/span&gt;&lt;span class="nf"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;agent&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;))[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nf"&gt;item&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
            &lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;reward&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;terminated&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;  &lt;span class="n"&gt;truncated&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;info&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;step&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;action&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
            &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;render&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;a href="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fres.cloudinary.com%2Fdltwftrgc%2Fimage%2Fupload%2Fv1722951574%2FBlogs%2Freinforcement_learning%2Fwithout_baseline_ezquaq.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fres.cloudinary.com%2Fdltwftrgc%2Fimage%2Fupload%2Fv1722951574%2FBlogs%2Freinforcement_learning%2Fwithout_baseline_ezquaq.png" alt="Results without Baseline"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;We can see that vanilla REINFORCE algorithm struggles to learn reliably. The learning is very unstable due to the high variance. The maximum reward in the CartPole environment is 500 and our agent is able to achieve 300. So there is scope for improvement.  Now let’s see how baseline can help. &lt;/p&gt;

&lt;h3&gt;
  
  
  REINFORCE with Baseline (Value Policy Gradient - VPG)
&lt;/h3&gt;

&lt;h3&gt;
  
  
  Code Implementation: REINFORCE with Baseline
&lt;/h3&gt;

&lt;p&gt;This code is almost similar. The only difference is that we have a value network and store the values during rollout. We use these values in the learn function to gain an advantage. &lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;Creating the Value Network
&lt;/li&gt;
&lt;/ul&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;ValueNet&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_features&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_hidden&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="nf"&gt;super&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;ValueNet&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;fc1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n_features&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;256&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;fc2&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;256&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;128&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;fc3&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;128&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;-&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Tensor&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
        &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;F&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;relu&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;fc1&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
        &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;F&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;relu&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;fc2&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;fc3&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Then we just introduce the values in the agent class. I am not writing the whole code here as most of the other things are same.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="n"&gt;gym&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Env&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;lr&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;gamma&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_steps&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;

                &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;value_net&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;ValueNet&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;observation_space&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="mi"&gt;128&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;optimizer_value&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;AdamW&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;value_net&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;parameters&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt; &lt;span class="n"&gt;lr&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;lr&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

        &lt;span class="c1"&gt;#&amp;lt;SAME AS BEFORE&amp;gt;
&lt;/span&gt;
    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;rollout&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;

              &lt;span class="c1"&gt;#&amp;lt;SAME AS BEFORE&amp;gt;
&lt;/span&gt;        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;values&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[]&lt;/span&gt;


        &lt;span class="k"&gt;while&lt;/span&gt; &lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;

             &lt;span class="c1"&gt;#&amp;lt;SAME AS BEFORE&amp;gt;
&lt;/span&gt;           &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;values&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;value_net&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;from_numpy&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;state&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;float&lt;/span&gt;&lt;span class="p"&gt;().&lt;/span&gt;&lt;span class="nf"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)))&lt;/span&gt;

           &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;terminated&lt;/span&gt; &lt;span class="ow"&gt;or&lt;/span&gt; &lt;span class="n"&gt;truncated&lt;/span&gt;
           &lt;span class="c1"&gt;#&amp;lt;SAME AS BEFORE&amp;gt;
&lt;/span&gt;
        &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;learn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;

        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;values&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;cat&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;values&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;value_loss&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;F&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;mse_loss&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;values&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;returns&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

                &lt;span class="c1"&gt;#&amp;lt;SAME AS BEFORE&amp;gt;
&lt;/span&gt;
          &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;optimizer_value&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;zero_grad&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
        &lt;span class="n"&gt;value_loss&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;backward&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
        &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;utils&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;clip_grad_norm_&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;value_net&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;parameters&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt; &lt;span class="nf"&gt;float&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;inf&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;optimizer_value&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;step&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

        &lt;span class="c1"&gt;#&amp;lt;SAME AS BEFORE&amp;gt; 
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;h3&gt;
  
  
  Results
&lt;/h3&gt;

&lt;p&gt;We can run this agent just like before.&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fres.cloudinary.com%2Fdltwftrgc%2Fimage%2Fupload%2Fv1722951646%2FBlogs%2Freinforcement_learning%2Fcartpole_uu4p8e.gif" class="article-body-image-wrapper"&gt;&lt;img src="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fres.cloudinary.com%2Fdltwftrgc%2Fimage%2Fupload%2Fv1722951646%2FBlogs%2Freinforcement_learning%2Fcartpole_uu4p8e.gif" alt="Cartpole"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fres.cloudinary.com%2Fdltwftrgc%2Fimage%2Fupload%2Fv1722951575%2FBlogs%2Freinforcement_learning%2Fwith_baseline_pfnq2b.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fres.cloudinary.com%2Fdltwftrgc%2Fimage%2Fupload%2Fv1722951575%2FBlogs%2Freinforcement_learning%2Fwith_baseline_pfnq2b.png" alt="With Baseline result"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The baseline version shows amazing improvements. The learning is stable, and the episodic reward grows consistently. So, if you ever have an episodic environment, REINFORCE with Baseline can be a viable choice. &lt;/p&gt;

&lt;h3&gt;
  
  
  Conclusion
&lt;/h3&gt;

&lt;p&gt;In this part, we learned about the policy gradient method and wrote our first algorithm, REINFORCE. We saw that although this algorithm is powerful, it has some limitations,. Namely, it cannot work well with non-episodic tasks, and there is high variance. Thus, we must implement bootstrapping to reduce our dependence on complete episodic returns. We will look at this in the next part, where we will see our first actor critic method: Advantage Actor Critic (A2C). With that we will also see our first continuous action space implementation. Things are going to get exciting. Trust me!&lt;/p&gt;




&lt;p&gt;Want to connect?&lt;br&gt;
😸&lt;a href="https://github.com/akshayballal95/rl-blog-series/tree/reinforce" rel="noopener noreferrer"&gt;GitHub Repo&lt;/a&gt;&lt;br&gt;
🌍&lt;a href="http://www.akshaymakes.com" rel="noopener noreferrer"&gt;My Website&lt;/a&gt;&lt;br&gt;&lt;br&gt;
🐦&lt;a href="https://twitter.com/akshayballal95" rel="noopener noreferrer"&gt;My Twitter&lt;/a&gt;&lt;br&gt;&lt;br&gt;
👨&lt;a href="https://www.linkedin.com/in/akshay-ballal/" rel="noopener noreferrer"&gt;My LinkedIn&lt;/a&gt;&lt;/p&gt;

</description>
      <category>machinelearning</category>
      <category>python</category>
      <category>ai</category>
      <category>tutorial</category>
    </item>
    <item>
      <title>Using YOLO with CLIP to improve Retrieval</title>
      <dc:creator>Akshay Ballal</dc:creator>
      <pubDate>Fri, 02 Aug 2024 14:10:15 +0000</pubDate>
      <link>https://dev.to/akshayballal/using-yolo-with-clip-to-improve-retrieval-4l61</link>
      <guid>https://dev.to/akshayballal/using-yolo-with-clip-to-improve-retrieval-4l61</guid>
      <description>&lt;p&gt;In this article we are going to see how we can use object detection models like YOLO along with multimodal embedding models like CLIP to make image retrieval better. &lt;/p&gt;

&lt;p&gt;Here is the idea: CLIP image retrieval works as follows: We embed the images we have using a CLIP model and store them somewhere, like in a vector database. Then, during inference, we can use a query image or a prompt, embed that, and find the closest images from the stored embeddings that can be retrieved. The problem is when the embedded images have too many objects or some objects are in the background, and we still want our system to retrieve them. This is because CLIP embeds the image as a whole. Think of it like what a word embedding model is to a sentence embedding model. We want to be able to search for words that are equivalent to objects in an image. So, the solution is to decompose the image into different objects using an object detection model. Then, embed these decomposed images but link them to their parent image. This will allow us to retrieve the crops and get the parent from which the crop originated.  Let’s see how it works. &lt;/p&gt;

&lt;h3&gt;
  
  
  Install the Dependencies and import them
&lt;/h3&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;

&lt;span class="err"&gt;!&lt;/span&gt;&lt;span class="n"&gt;pip&lt;/span&gt; &lt;span class="n"&gt;install&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;q&lt;/span&gt; &lt;span class="n"&gt;ultralytics&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt; &lt;span class="n"&gt;matplotlib&lt;/span&gt; &lt;span class="n"&gt;numpy&lt;/span&gt; &lt;span class="n"&gt;pillow&lt;/span&gt; &lt;span class="n"&gt;zipfile36&lt;/span&gt; &lt;span class="n"&gt;transformers&lt;/span&gt;

&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;ultralytics&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;YOLO&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;matplotlib.pyplot&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;plt&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;PIL&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;pillow&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;os&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;Zipfile&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;Zipfile&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;BadZipFile&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;transformers&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;CLIPProcessor&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;CLIPModel&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;CLIPVisionModelWithProjection&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;CLIPTextModelWithProjection&lt;/span&gt;


&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
&lt;h3&gt;
  
  
  Download the COCO Dataset and unzip
&lt;/h3&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;

&lt;span class="err"&gt;!&lt;/span&gt;&lt;span class="n"&gt;wget&lt;/span&gt; &lt;span class="n"&gt;http&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="o"&gt;//&lt;/span&gt;&lt;span class="n"&gt;images&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cocodataset&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;org&lt;/span&gt;&lt;span class="o"&gt;/&lt;/span&gt;&lt;span class="n"&gt;zips&lt;/span&gt;&lt;span class="o"&gt;/&lt;/span&gt;&lt;span class="n"&gt;val2017&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nb"&gt;zip&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;O&lt;/span&gt; &lt;span class="n"&gt;coco_val2017&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nb"&gt;zip&lt;/span&gt;

&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;extract_zip_file&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;extract_path&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="k"&gt;try&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
        &lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="nc"&gt;ZipFile&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;extract_path&lt;/span&gt;&lt;span class="o"&gt;+&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;.zip&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;zfile&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
            &lt;span class="n"&gt;zfile&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;extractall&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;extract_path&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="c1"&gt;# remove zipfile
&lt;/span&gt;        &lt;span class="n"&gt;zfileTOremove&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;extract_path&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="o"&gt;+&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;.zip&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;
        &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;os&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;path&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;isfile&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;zfileTOremove&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
            &lt;span class="n"&gt;os&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;remove&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;zfileTOremove&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="k"&gt;else&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
            &lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Error: %s file not found&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt; &lt;span class="o"&gt;%&lt;/span&gt; &lt;span class="n"&gt;zfileTOremove&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="k"&gt;except&lt;/span&gt; &lt;span class="n"&gt;BadZipFile&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;e&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
        &lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Error:&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;e&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="n"&gt;extract_val_path&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;./coco_val2017&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;
&lt;span class="nf"&gt;extract_zip_file&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;extract_val_path&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;


&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;

&lt;p&gt;We can then take some of the images and create a list of examples.&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;

&lt;span class="n"&gt;source&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;coco_val2017/val2017/000000000139.jpg&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;/content/coco_val2017/val2017/000000000632.jpg&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;/content/coco_val2017/val2017/000000000776.jpg&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;/content/coco_val2017/val2017/000000001503.jpg&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;/content/coco_val2017/val2017/000000001353.jpg&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;/content/coco_val2017/val2017/000000003661.jpg&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;


&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
&lt;h3&gt;
  
  
  Initiate the YOLO model and the CLIP Model
&lt;/h3&gt;

&lt;p&gt;In this example we are going to use the latest Ultralytics &lt;code&gt;Yolo10x&lt;/code&gt; model along with OpenAI &lt;code&gt;clip-vit-base-patch32&lt;/code&gt; .&lt;/p&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight shell"&gt;&lt;code&gt;

device &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="s2"&gt;"cuda"&lt;/span&gt;

 &lt;span class="c"&gt;# YOLO Model&lt;/span&gt;
model &lt;span class="o"&gt;=&lt;/span&gt; YOLO&lt;span class="o"&gt;(&lt;/span&gt;&lt;span class="s1"&gt;'yolov10x.pt'&lt;/span&gt;&lt;span class="o"&gt;)&lt;/span&gt;

&lt;span class="c"&gt;# Clip model&lt;/span&gt;
model_id &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="s2"&gt;"openai/clip-vit-base-patch32"&lt;/span&gt;
image_model &lt;span class="o"&gt;=&lt;/span&gt; CLIPVisionModelWithProjection.from_pretrained&lt;span class="o"&gt;(&lt;/span&gt;model_id, device_map &lt;span class="o"&gt;=&lt;/span&gt; device&lt;span class="o"&gt;)&lt;/span&gt;
text_model &lt;span class="o"&gt;=&lt;/span&gt; CLIPTextModelWithProjection.from_pretrained&lt;span class="o"&gt;(&lt;/span&gt;model_id, device_map &lt;span class="o"&gt;=&lt;/span&gt; device&lt;span class="o"&gt;)&lt;/span&gt;
processor &lt;span class="o"&gt;=&lt;/span&gt; CLIPProcessor.from_pretrained&lt;span class="o"&gt;(&lt;/span&gt;model_id&lt;span class="o"&gt;)&lt;/span&gt;


&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
&lt;h3&gt;
  
  
  Running the detection model
&lt;/h3&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight shell"&gt;&lt;code&gt;

results &lt;span class="o"&gt;=&lt;/span&gt; model&lt;span class="o"&gt;(&lt;/span&gt;&lt;span class="nb"&gt;source&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="nb"&gt;source&lt;/span&gt;, device &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="s2"&gt;"cuda"&lt;/span&gt;&lt;span class="o"&gt;)&lt;/span&gt;


&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;

&lt;p&gt;Let’s show us results with this code snippet&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;

&lt;span class="c1"&gt;# Visualize the results
&lt;/span&gt;&lt;span class="n"&gt;fig&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ax&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;subplots&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;figsize&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;15&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;

&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;r&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;enumerate&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;results&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="c1"&gt;# Plot results image
&lt;/span&gt;    &lt;span class="n"&gt;im_bgr&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;r&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;plot&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;  &lt;span class="c1"&gt;# BGR-order numpy array
&lt;/span&gt;    &lt;span class="n"&gt;im_rgb&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Image&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;fromarray&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;im_bgr&lt;/span&gt;&lt;span class="p"&gt;[...,&lt;/span&gt; &lt;span class="p"&gt;::&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;  &lt;span class="c1"&gt;# RGB-order PIL image
&lt;/span&gt;
    &lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="o"&gt;%&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="o"&gt;//&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nf"&gt;imshow&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;im_rgb&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="o"&gt;%&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="o"&gt;//&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nf"&gt;set_title&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Image &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="o"&gt;+&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;



&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;

&lt;p&gt;&lt;a href="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fmiro.medium.com%2Fv2%2Fresize%3Afit%3A1100%2Fformat%3Awebp%2F1%2Ajg23xIlQ6oZNHwcI5Dr0YA.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fmiro.medium.com%2Fv2%2Fresize%3Afit%3A1100%2Fformat%3Awebp%2F1%2Ajg23xIlQ6oZNHwcI5Dr0YA.png" alt="Yolo"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;So we can see that the YOLO model works quite well in detecting the objects in the images. It does make some mistakes where it has tagged the monitor as TV. But that is fine. The actual classes that YOLO assigns are not that essential because we are going to use CLIP to do the inference.&lt;/p&gt;

&lt;h3&gt;
  
  
  Defining some helper Classes
&lt;/h3&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;

&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;CroppedImage&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;

  &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;parent&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;box&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cls&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;

    &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;parent&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;parent&lt;/span&gt;
    &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;box&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;box&lt;/span&gt;
    &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cls&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;cls&lt;/span&gt;

  &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;display&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ax&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;im_rgb&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Image&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;open&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;parent&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;cropped_image&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;im_rgb&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;crop&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;box&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;ax&lt;/span&gt; &lt;span class="ow"&gt;is&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
      &lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;imshow&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;cropped_image&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
      &lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_title&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cls&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="k"&gt;else&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
      &lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;figure&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;figsize&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
      &lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;imshow&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;cropped_image&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
      &lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;title&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cls&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
      &lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;show&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

  &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;get_cropped_image&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;im_rgb&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Image&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;open&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;parent&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;cropped_image&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;im_rgb&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;crop&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;box&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;cropped_image&lt;/span&gt;

  &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__str__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;CroppedImage(parent=&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;parent&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt;, boxes=&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;box&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt;, cls=&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cls&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt;)&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;

  &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__repr__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;__str__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;YOLOImage&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
  &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;image_path&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cropped_images&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;image_path&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;str&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;image_path&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cropped_images&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;cropped_images&lt;/span&gt;

  &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;get_image&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;Image&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;open&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;image_path&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

  &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;get_caption&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;cls&lt;/span&gt;  &lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;[]&lt;/span&gt;
    &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;cropped_image&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cropped_images&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
      &lt;span class="n"&gt;cls&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;cropped_image&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cls&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="n"&gt;unique_cls&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;set&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;cls&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;count_cls&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;{&lt;/span&gt;&lt;span class="n"&gt;cls&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;cls&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;count&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;cls&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;cls&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;unique_cls&lt;/span&gt;&lt;span class="p"&gt;}&lt;/span&gt;

    &lt;span class="n"&gt;count_string&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt; &lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;join&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;count&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt; &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;cls&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt;,&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;cls&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;count&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;count_cls&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;items&lt;/span&gt;&lt;span class="p"&gt;())&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;this image contains &lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;count_string&lt;/span&gt;

  &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__str__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;__repr__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

  &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__repr__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;cls&lt;/span&gt;  &lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;[]&lt;/span&gt;
    &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;cropped_image&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cropped_images&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
      &lt;span class="n"&gt;cls&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;cropped_image&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cls&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;YOLOImage(image=&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;image_path&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt;, cropped_images=&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;cls&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt;)&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;

&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;ImageEmbedding&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
  &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;image_path&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;embedding&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cropped_image&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;image_path&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;image_path&lt;/span&gt;
    &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cropped_image&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;cropped_image&lt;/span&gt;
    &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;embedding&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;embedding&lt;/span&gt;



&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
&lt;h3&gt;
  
  
  CroppedImage Class
&lt;/h3&gt;

&lt;p&gt;The &lt;code&gt;CroppedImage&lt;/code&gt; class represents a portion of an image cropped from a larger parent image. It is initialized with the path to the parent image, the bounding box defining the crop area, and a class label (e.g., "cat" or "dog"). This class includes methods to display the cropped image and to retrieve it as an image object. The &lt;code&gt;display&lt;/code&gt; method allows for visualizing the cropped portion either on a provided axis or by creating a new figure, making it versatile for different use cases. Additionally, &lt;code&gt;__str__&lt;/code&gt; and &lt;code&gt;__repr__&lt;/code&gt; methods are implemented for easy and informative string representation of the object.&lt;/p&gt;
&lt;h3&gt;
  
  
  YOLOImage Class
&lt;/h3&gt;

&lt;p&gt;The &lt;code&gt;YOLOImage&lt;/code&gt; class is designed to handle images processed with the YOLO object detection model. It takes the path to the original image and a list of &lt;code&gt;CroppedImage&lt;/code&gt; instances that represent the detected objects within the image. The class provides methods to open and display the full image and to generate a caption summarizing the objects detected in the image. The caption method aggregates and counts the unique class labels from the cropped images, providing a concise description of the image contents. This class is particularly useful for managing and interpreting results from object detection tasks.&lt;/p&gt;
&lt;h3&gt;
  
  
  ImageEmbedding Class
&lt;/h3&gt;

&lt;p&gt;The &lt;code&gt;ImageEmbedding&lt;/code&gt; class has an image and its associated embedding, which is a numerical representation of the image's features. This class can be initialized with the path to the image, the embedding vector, and optionally a &lt;code&gt;CroppedImage&lt;/code&gt; instance if the embedding corresponds to a specific cropped portion of the image. The &lt;code&gt;ImageEmbedding&lt;/code&gt; class is essential for tasks involving image similarity, classification, and retrieval, as it provides a structured way to store and access the image data alongside its computed features. This integration facilitates efficient image processing and machine learning workflows.&lt;/p&gt;
&lt;h3&gt;
  
  
  Crop each image and create a list of YOLOImage Objects
&lt;/h3&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;

&lt;span class="n"&gt;yolo_images&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nb"&gt;list&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;YOLOImage&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[]&lt;/span&gt;

&lt;span class="n"&gt;names&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;names&lt;/span&gt;

&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;r&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;enumerate&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;results&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
  &lt;span class="n"&gt;crops&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="nb"&gt;list&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;CroppedImage&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[]&lt;/span&gt;
  &lt;span class="n"&gt;boxes&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;r&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;boxes&lt;/span&gt;
  &lt;span class="n"&gt;classes&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;r&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;boxes&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cls&lt;/span&gt;
  &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;j&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;box&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;enumerate&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;r&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;boxes&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;box&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;tuple&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;box&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;xyxy&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;flatten&lt;/span&gt;&lt;span class="p"&gt;().&lt;/span&gt;&lt;span class="nf"&gt;cpu&lt;/span&gt;&lt;span class="p"&gt;().&lt;/span&gt;&lt;span class="nf"&gt;numpy&lt;/span&gt;&lt;span class="p"&gt;())&lt;/span&gt;
    &lt;span class="n"&gt;cropped_image&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;CroppedImage&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;parent&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;r&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;path&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;box&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;box&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cls&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;names&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;classes&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;j&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nf"&gt;int&lt;/span&gt;&lt;span class="p"&gt;().&lt;/span&gt;&lt;span class="nf"&gt;item&lt;/span&gt;&lt;span class="p"&gt;()])&lt;/span&gt;
    &lt;span class="n"&gt;crops&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;cropped_image&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
  &lt;span class="n"&gt;yolo_images&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nc"&gt;YOLOImage&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;image_path&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;r&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;path&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cropped_images&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;crops&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;


&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
&lt;h3&gt;
  
  
  Embed Images using CLIP
&lt;/h3&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;

&lt;span class="n"&gt;image_embeddings&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[]&lt;/span&gt;

&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;image&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;yolo_images&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
  &lt;span class="nb"&gt;input&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;processor&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;image_processor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;images&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;image&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;get_image&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt; &lt;span class="n"&gt;return_tensors&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;pt&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
  &lt;span class="nb"&gt;input&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
  &lt;span class="n"&gt;embeddings&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;image_model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;pixel_values&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nb"&gt;input&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;pixel_values&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="n"&gt;image_embeds&lt;/span&gt;
  &lt;span class="n"&gt;embeddings&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;embeddings&lt;/span&gt;&lt;span class="o"&gt;/&lt;/span&gt;&lt;span class="n"&gt;embeddings&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;norm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;keepdim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# Normalize the embeddings
&lt;/span&gt;  &lt;span class="n"&gt;image_embedding&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;ImageEmbedding&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;image_path&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;image&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;image_path&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;embedding&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;embeddings&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
  &lt;span class="n"&gt;image_embeddings&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;image_embedding&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

  &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;cropped_image&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;image&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cropped_images&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="nb"&gt;input&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;processor&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;image_processor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;images&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;cropped_image&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;get_cropped_image&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt; &lt;span class="n"&gt;return_tensors&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;pt&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="nb"&gt;input&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;embeddings&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;image_model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;pixel_values&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nb"&gt;input&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;pixel_values&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="n"&gt;image_embeds&lt;/span&gt;
    &lt;span class="n"&gt;embeddings&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;embeddings&lt;/span&gt;&lt;span class="o"&gt;/&lt;/span&gt;&lt;span class="n"&gt;embeddings&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;norm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;keepdim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# Normalize the embeddings
&lt;/span&gt;
    &lt;span class="n"&gt;image_embedding&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;ImageEmbedding&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;image_path&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;image&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;image_path&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;embedding&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;embeddings&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cropped_image&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;cropped_image&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;image_embeddings&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;image_embedding&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

   &lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="n"&gt;image_embeddings_tensor&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;stack&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="n"&gt;image_embedding&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;embedding&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;image_embedding&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;image_embeddings&lt;/span&gt;&lt;span class="p"&gt;]).&lt;/span&gt;&lt;span class="nf"&gt;squeeze&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;


&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;

&lt;p&gt;We can now take these image embeddings and store in a vector database if we want to. But in this example we will just use the inner dot product technique to check the similarity and retrieve the images.&lt;/p&gt;

&lt;h3&gt;
  
  
  Retrieval
&lt;/h3&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;

&lt;span class="n"&gt;query&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;image of a flowerpot&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;

&lt;span class="n"&gt;text_embedding&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;processor&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tokenizer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;query&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;return_tensors&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;pt&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;text_embedding&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;text_model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="n"&gt;text_embedding&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="n"&gt;text_embeds&lt;/span&gt;

&lt;span class="n"&gt;similarities&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;matmul&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;text_embedding&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;image_embeddings_tensor&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;T&lt;/span&gt;&lt;span class="p"&gt;)).&lt;/span&gt;&lt;span class="nf"&gt;flatten&lt;/span&gt;&lt;span class="p"&gt;().&lt;/span&gt;&lt;span class="nf"&gt;detach&lt;/span&gt;&lt;span class="p"&gt;().&lt;/span&gt;&lt;span class="nf"&gt;cpu&lt;/span&gt;&lt;span class="p"&gt;().&lt;/span&gt;&lt;span class="nf"&gt;numpy&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

&lt;span class="c1"&gt;# get the top 5 similar images
&lt;/span&gt;&lt;span class="n"&gt;k&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;5&lt;/span&gt;
&lt;span class="n"&gt;top_k_indices&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;similarities&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;argsort&lt;/span&gt;&lt;span class="p"&gt;()[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;:]&lt;/span&gt;

&lt;span class="c1"&gt;# Display the top 5 results
&lt;/span&gt;&lt;span class="n"&gt;fig&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ax&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;subplots&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;figsize&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;20&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;index&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;enumerate&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;top_k_indices&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
  &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;image_embeddings&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;index&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="n"&gt;cropped_image&lt;/span&gt; &lt;span class="ow"&gt;is&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="n"&gt;image_embeddings&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;index&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="n"&gt;cropped_image&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;display&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;ax&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;][&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
  &lt;span class="k"&gt;else&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
  &lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;][&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nf"&gt;imshow&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Image&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;open&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;image_embeddings&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;index&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="n"&gt;image_path&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
  &lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;][&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nf"&gt;imshow&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Image&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;open&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;image_embeddings&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;index&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="n"&gt;image_path&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
  &lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;][&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nf"&gt;axis&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;off&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
  &lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;][&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nf"&gt;axis&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;off&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
  &lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;][&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nf"&gt;set_title&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Original Image&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;show&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;


&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;

&lt;p&gt;&lt;a href="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fmiro.medium.com%2Fv2%2Fresize%3Afit%3A2000%2Fformat%3Awebp%2F1%2AyvMhqWoetAgw8oEM_NzXpw.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fmiro.medium.com%2Fv2%2Fresize%3Afit%3A2000%2Fformat%3Awebp%2F1%2AyvMhqWoetAgw8oEM_NzXpw.png" alt="living room"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fmiro.medium.com%2Fv2%2Fresize%3Afit%3A2000%2Fformat%3Awebp%2F1%2AllUvk40MKaUAzjYivpFLCw.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fmiro.medium.com%2Fv2%2Fresize%3Afit%3A2000%2Fformat%3Awebp%2F1%2AllUvk40MKaUAzjYivpFLCw.png" alt="monitor"&gt;&lt;/a&gt;&lt;br&gt;
&lt;a href="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fmiro.medium.com%2Fv2%2Fresize%3Afit%3A2000%2Fformat%3Awebp%2F1%2Ar_VWg5cS4TLbVHwlmFBTKg.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fmiro.medium.com%2Fv2%2Fresize%3Afit%3A2000%2Fformat%3Awebp%2F1%2Ar_VWg5cS4TLbVHwlmFBTKg.png" alt="plant"&gt;&lt;/a&gt;&lt;br&gt;
&lt;a href="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fmiro.medium.com%2Fv2%2Fresize%3Afit%3A2000%2Fformat%3Awebp%2F1%2Ajzmfmkd0zrw_Tnr_AREMuw.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fmiro.medium.com%2Fv2%2Fresize%3Afit%3A2000%2Fformat%3Awebp%2F1%2Ajzmfmkd0zrw_Tnr_AREMuw.png" alt="boy"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;You can see that we are able to retrieve even small plants which are hidden away in the background. Also sometimes it pulls the original image as the result because we are also embedding that .&lt;/p&gt;

&lt;p&gt;This can be a very powerful technique. You can also finetune both the models for detection and embedding for your own images and improve the performance even more.&lt;/p&gt;

&lt;p&gt;One downside is that we have to run the CLIP model on all the objects detected. One way to mitigate this is by limiting the number of boxes that YOLO produces.&lt;/p&gt;

&lt;p&gt;You can check out the code on Colab at this link.&lt;/p&gt;

&lt;p&gt;&lt;a href="https://colab.research.google.com/drive/1wBnRxX-0DaWtbgLzbNoPvjOcPTRzkQnP?usp=sharing" rel="noopener noreferrer"&gt;&lt;img src="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fcolab.research.google.com%2Fassets%2Fcolab-badge.svg" alt="Open in Colab"&gt;&lt;/a&gt;&lt;/p&gt;




&lt;p&gt;Want to connect?&lt;/p&gt;

&lt;p&gt;🌍&lt;a href="http://www.akshaymakes.com" rel="noopener noreferrer"&gt;My Website&lt;/a&gt;&lt;br&gt;&lt;br&gt;
🐦&lt;a href="https://twitter.com/akshayballal95" rel="noopener noreferrer"&gt;My Twitter&lt;/a&gt;&lt;br&gt;&lt;br&gt;
👨&lt;a href="https://www.linkedin.com/in/akshay-ballal/" rel="noopener noreferrer"&gt;My LinkedIn&lt;/a&gt;&lt;/p&gt;

</description>
      <category>machinelearning</category>
      <category>computervision</category>
      <category>ai</category>
      <category>python</category>
    </item>
    <item>
      <title>Reinforcement Learning from Scratch - Part 2 - Deep Q Learning</title>
      <dc:creator>Akshay Ballal</dc:creator>
      <pubDate>Tue, 30 Jul 2024 05:46:17 +0000</pubDate>
      <link>https://dev.to/akshayballal/reinforcement-learning-from-scratch-part-2-deep-q-learning-4opo</link>
      <guid>https://dev.to/akshayballal/reinforcement-learning-from-scratch-part-2-deep-q-learning-4opo</guid>
      <description>&lt;p&gt;&lt;a href="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F5tyfondpk48vfmruej1l.jpg" class="article-body-image-wrapper"&gt;&lt;img src="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F5tyfondpk48vfmruej1l.jpg" alt="cover image"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Hey everyone. This is the second part of my Reinforcement Learning Series, where we look at the different RL algorithms and discuss their implementation details. In the &lt;a href="https://dev.to/akshayballal/reinforcement-learning-from-scratch-part-1-tabular-q-learning-22k3"&gt;last part&lt;/a&gt;, we saw a basic RL algorithm that uses tabular Q learning. One disadvantage of that technique was that it required us to discretize the observation and action spaces. This increases the dimensionality of the observation space quite a bit and makes learning hard and slower. &lt;/p&gt;

&lt;p&gt;To overcome this problem, at least partially, in this part, we look at Deep Q Learning (DQN), a prominent technique that brought Reinforcement learning into the mainstream. We can use the continuous observation space in this technique but still need a discrete action space. First, we will put together a simple DQN algorithm and then see how we can use LSTM architecture as the network in DQN. Let’s dive in.&lt;/p&gt;

&lt;h2&gt;
  
  
  DQN Theory
&lt;/h2&gt;

&lt;p&gt;If you are not interested in the theory of DQN, that is fine. The code implementation below explains a lot. However, I am including some of the main driving equations below for completeness. &lt;/p&gt;

&lt;p&gt;From the previous part, we know that the q value of a state action pair is the expected total return from that state if a certain action is taken. Using the Bellman Equation, we can write it out as:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fzh0k6ynj0fwv16g0omk7.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fzh0k6ynj0fwv16g0omk7.png" alt="bellman-equation"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;If we are following the policy π we can say that:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fs4x8p90u0ebd3x1127kt.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fs4x8p90u0ebd3x1127kt.png" alt="bootstraping"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;From this, we can say that the q-value of a state action pair is equal to the average (expectation) of the sum of reward and γ times the maximum q-value of the next state. This is called bootstrapping&lt;/p&gt;

&lt;p&gt;For example, if we have a transition (s, 2, r, s’), we take action 2. Then, we do the following.&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;Pass the state through the q-network and get the q-values of all the actions. &lt;/li&gt;
&lt;li&gt;Pick the q-value of the action taken, which is 2 here. &lt;/li&gt;
&lt;li&gt;Pass the next state s’ through the &lt;strong&gt;target q-network&lt;/strong&gt; and get q-values of all the actions of s’.&lt;/li&gt;
&lt;li&gt;Pick the maximum q-value.&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;&lt;a href="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fjca4s695yof9qxnvsjxq.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fjca4s695yof9qxnvsjxq.png" alt="target"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;We can then take a Mean Square Error Loss between the target and the prediction and perform backward propagation. &lt;/p&gt;

&lt;p&gt;The below image can make this more clear&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F9a600jm7bg5jbso3vn0b.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F9a600jm7bg5jbso3vn0b.png" alt="visualization"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;h3&gt;
  
  
  The Algorithm
&lt;/h3&gt;

&lt;p&gt;Below you can see the pseudocode for the Deep Q Learning. It looks more complex than it actually is. You can follow along with the code to get a better understanding.&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F9tm3378h11yn5nh78yyd.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F9tm3378h11yn5nh78yyd.png" alt="pseudocode"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;h2&gt;
  
  
  Code Implementation
&lt;/h2&gt;

&lt;p&gt;We will again use the Pendulum environment we used in the last part. We will also log some metrics to TensorBoard to see how the learning is going. &lt;/p&gt;

&lt;p&gt;You can check out the code at: &lt;a href="https://github.com/akshayballal95/rl-blog-series/tree/deep-q-learning" rel="noopener noreferrer"&gt;GitHub Repo&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Let’s start with importing the dependencies&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;

&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;gymnasium&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;gym&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;numpy&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;tqdm&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;tqdm&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;torch.optim&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;AdamW&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;copy&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;deepcopy&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;torch.utils.tensorboard&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;SummaryWriter&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;torch.nn.functional&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;F&lt;/span&gt;


&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
&lt;h3&gt;
  
  
  Pendulum Environment Wrapper
&lt;/h3&gt;

&lt;p&gt;We define the Gaussian Function that I introduced in Part 1, which creates the set of possible actions.&lt;/p&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;

&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;gaussian&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_amp&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;max_amp&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;exp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;


&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;

&lt;p&gt;Then, we create the wrapper. This is very similar, in fact more simpler than the previous wrapper because we can use the observation space from the Pendulum Environment as it is and don't have to do any discretization. &lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;

&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;PendulumDiscreteStateAction&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;gym&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Wrapper&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;gym&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Env&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;nvec_u&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nb"&gt;int&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nb"&gt;float&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;

        &lt;span class="nf"&gt;super&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;PendulumDiscreteStateAction&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;env&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;nvec_u&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nvec_u&lt;/span&gt;

        &lt;span class="c1"&gt;# Create a Discrete action space
&lt;/span&gt;        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;action_space&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;gym&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;spaces&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Discrete&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nvec_u&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

        &lt;span class="n"&gt;kernel&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;gaussian&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;linspace&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;nvec_u&lt;/span&gt;&lt;span class="o"&gt;//&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; 
                                               &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;action_space&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;high&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;actions&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;kernel&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;tolist&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;flip&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;kernel&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;tolist&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;step&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;action&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nb"&gt;int&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;-&amp;gt;&lt;/span&gt; &lt;span class="nb"&gt;tuple&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ndarray&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="nb"&gt;float&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="nb"&gt;float&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nb"&gt;bool&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nb"&gt;dict&lt;/span&gt;&lt;span class="p"&gt;]:&lt;/span&gt;

        &lt;span class="n"&gt;action&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;actions&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;action&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
        &lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;reward&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;terminated&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;truncated&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;info&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;step&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="n"&gt;action&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
        &lt;span class="n"&gt;reward&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;reward&lt;/span&gt;&lt;span class="o"&gt;/&lt;/span&gt;&lt;span class="mf"&gt;16.2736044&lt;/span&gt; &lt;span class="c1"&gt;# normalize the reward between -1 and 1
&lt;/span&gt;        &lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ndarray&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="nb"&gt;float&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="o"&gt;/&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;observation_space&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;high&lt;/span&gt; &lt;span class="c1"&gt;# normalize the observation between -1 and 0
&lt;/span&gt;        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;reward&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;terminated&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;truncated&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;info&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;reset&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;-&amp;gt;&lt;/span&gt; &lt;span class="nb"&gt;tuple&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ndarray&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="nb"&gt;float&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="nb"&gt;dict&lt;/span&gt;&lt;span class="p"&gt;]:&lt;/span&gt;
        &lt;span class="sh"&gt;"""&lt;/span&gt;&lt;span class="s"&gt;
        Resets the environment.

        Returns:
        - The initial discrete observation and additional information.
        &lt;/span&gt;&lt;span class="sh"&gt;"""&lt;/span&gt;
        &lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;info&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;reset&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
        &lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ndarray&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="nb"&gt;float&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="o"&gt;/&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;observation_space&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;high&lt;/span&gt;
        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;info&lt;/span&gt;


&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;

&lt;p&gt;&lt;strong&gt;Notable Implementation Details&lt;/strong&gt;&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;Reward Scaling&lt;/strong&gt;: The rewards were scaled using MinMax Scaler to make the learning smoother. This was done by dividing the returned reward by the minimum possible reward: rscaled = r/rmin. The rmin is defined in the documentation of the pendulum environment on gymnasium. The intuition behind this is that the network does not have to output significantly big q values, and the weights can remain smaller, making learning smooth&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Scaled Observations&lt;/strong&gt;: The observation space is scaled by dividing the maximum value of the observation space.&lt;/li&gt;
&lt;/ul&gt;

&lt;h3&gt;
  
  
  Define the Q-Network
&lt;/h3&gt;

&lt;p&gt;Here we define a simple Fully Connected Network which maps the state to the q-values of all possible actions. &lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;

&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;QNetwork&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;nvec_s&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;nvec_u&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="nf"&gt;super&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;QNetwork&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;fc1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nvec_s&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;64&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;fc2&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;64&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;64&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;fc3&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;64&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;nvec_u&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Tensor&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;-&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Tensor&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
        &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;relu&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;fc1&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
        &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;relu&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;fc2&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
        &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;fc3&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;


&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
&lt;h3&gt;
  
  
  Create ReplayMemory Class
&lt;/h3&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;

&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;ReplayMemory&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;capacity&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;gym&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Env&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;

        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;position&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;size&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;capacity&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;capacity&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;device&lt;/span&gt;

        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;n_states&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;observation_space&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;  &lt;span class="c1"&gt;# Number of dimensions in the state space
&lt;/span&gt;        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;n_actions&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;action_space&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;n&lt;/span&gt;  &lt;span class="c1"&gt;# Number of discrete actions
&lt;/span&gt;
        &lt;span class="c1"&gt;# Initialize arrays to store the replay memory
&lt;/span&gt;        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;states&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;zeros&lt;/span&gt;&lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="n"&gt;capacity&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;n_states&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;actions&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;zeros&lt;/span&gt;&lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="n"&gt;capacity&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;rewards&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;zeros&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;capacity&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;next_states&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;zeros&lt;/span&gt;&lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="n"&gt;capacity&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;n_states&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;terminated&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;zeros&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;capacity&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;truncated&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;zeros&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;capacity&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;push&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;state&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ndarray&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;action&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="nb"&gt;int&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;next_state&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ndarray&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;reward&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="nb"&gt;float&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;terminated&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nb"&gt;bool&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;truncated&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="nb"&gt;bool&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;

        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;states&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;position&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;state&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;flatten&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;actions&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;position&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;action&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;next_states&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;position&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;next_state&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;flatten&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;rewards&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;position&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;reward&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;terminated&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;position&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;terminated&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;truncated&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;position&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;truncated&lt;/span&gt;

        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;position&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;position&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;%&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;capacity&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;size&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;min&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;size&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;capacity&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;sample&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;

        &lt;span class="n"&gt;indices&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;choice&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;replace&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;False&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

        &lt;span class="n"&gt;states&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;states&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;indices&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;dtype&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;float32&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;actions&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;actions&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;indices&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;dtype&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;int64&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;next_states&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;next_states&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;indices&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;dtype&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;float32&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;rewards&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;rewards&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;indices&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;dtype&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;float32&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;terminated&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;terminated&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;indices&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;dtype&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;float32&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;truncated&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;truncated&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;indices&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;dtype&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;float32&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;states&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;actions&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;next_states&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;rewards&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;terminated&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;truncated&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__len__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="nf"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;size&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;


&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
&lt;h3&gt;
  
  
  Create Agent Class
&lt;/h3&gt;

&lt;p&gt;&lt;strong&gt;Initialization and some helper functions&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;Here we define the agent initialization function that initiates the various variables and hyperparameters. We also have a &lt;code&gt;get_action&lt;/code&gt; function that takes random action with a probability of ϵ and otherwise takes an action by passing the state through the q-network and taking argmax i.e selecting action with the highest q-value. &lt;/p&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;

&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;Agent&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
        &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;gym&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Env&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
        &lt;span class="n"&gt;gamma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.99&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
        &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.0003&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
        &lt;span class="n"&gt;initial_epsilon&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
        &lt;span class="n"&gt;min_epsilon&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
        &lt;span class="n"&gt;decay_rate&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.9999&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
        &lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;64&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
        &lt;span class="n"&gt;n_rollouts&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2000&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
        &lt;span class="n"&gt;capacity&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;100000&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
        &lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;device&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;cpu&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
    &lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;env&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;  &lt;span class="c1"&gt;# Environment
&lt;/span&gt;        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;device&lt;/span&gt;  &lt;span class="c1"&gt;# Computation device (CPU or GPU)
&lt;/span&gt;        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;gamma&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;gamma&lt;/span&gt;  &lt;span class="c1"&gt;# Discount factor
&lt;/span&gt;        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;alpha&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;  &lt;span class="c1"&gt;# Learning rate
&lt;/span&gt;        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;epsilon&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;initial_epsilon&lt;/span&gt;  &lt;span class="c1"&gt;# Initial epsilon value for exploration
&lt;/span&gt;        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;batch_size&lt;/span&gt;  &lt;span class="c1"&gt;# Batch size for training
&lt;/span&gt;        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;n_rollouts&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;n_rollouts&lt;/span&gt;  &lt;span class="c1"&gt;# Number of rollouts to collect
&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;epsilon&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;  &lt;span class="c1"&gt;# Initial epsilon value
&lt;/span&gt;        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;min_epsilon&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;min_epsilon&lt;/span&gt;  &lt;span class="c1"&gt;# Minimum epsilon value
&lt;/span&gt;        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;decay_rate&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;decay_rate&lt;/span&gt;  &lt;span class="c1"&gt;# Epsilon decay rate
&lt;/span&gt;
        &lt;span class="c1"&gt;# Replay memory to store experiences
&lt;/span&gt;        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;replay_memory&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;ReplayMemory&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;capacity&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="c1"&gt;# Q-network and target network for Q-learning
&lt;/span&gt;        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;q_network&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;QNetwork&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
            &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;observation_space&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;action_space&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;n&lt;/span&gt;
        &lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;target_network&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;deepcopy&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;q_network&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="c1"&gt;# Optimizer for Q-network
&lt;/span&gt;        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;optimizer&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;AdamW&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;q_network&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;parameters&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt; &lt;span class="n"&gt;lr&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

        &lt;span class="c1"&gt;# Number of dimensions in the state space
&lt;/span&gt;        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;n_states&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;observation_space&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;

        &lt;span class="c1"&gt;# For metrics
&lt;/span&gt;        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;n_time_steps&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;  &lt;span class="c1"&gt;# Number of time steps
&lt;/span&gt;        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;episodes&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;  &lt;span class="c1"&gt;# Number of episodes
&lt;/span&gt;        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;n_updates&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;  &lt;span class="c1"&gt;# Number of gradient updates
&lt;/span&gt;        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;best_reward&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;inf&lt;/span&gt;  &lt;span class="c1"&gt;# Best reward seen so far
&lt;/span&gt;
    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;get_action&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;greedy&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;False&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;

        &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="n"&gt;greedy&lt;/span&gt; &lt;span class="ow"&gt;and&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;rand&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="o"&gt;&amp;lt;&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;epsilon&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;  &lt;span class="c1"&gt;# Epsilon-greedy exploration
&lt;/span&gt;            &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;randint&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;action_space&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;n&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# Random action
&lt;/span&gt;        &lt;span class="n"&gt;obs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dtype&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;float32&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;unsqueeze&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# Convert observation to tensor
&lt;/span&gt;        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;q_network&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;eval&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;  &lt;span class="c1"&gt;# Set Q-network to evaluation mode
&lt;/span&gt;        &lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;no_grad&lt;/span&gt;&lt;span class="p"&gt;():&lt;/span&gt;
            &lt;span class="n"&gt;q_values&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Tensor&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;q_network&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# Get Q-values for the observation
&lt;/span&gt;            &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;q_values&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;argmax&lt;/span&gt;&lt;span class="p"&gt;().&lt;/span&gt;&lt;span class="nf"&gt;item&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;  &lt;span class="c1"&gt;# Return action with highest Q-value
&lt;/span&gt;
    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;sample_experience&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;

        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;replay_memory&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sample&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# Sample experiences from replay memory
&lt;/span&gt;
    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;update_target&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;

        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;target_network&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;load_state_dict&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;q_network&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;state_dict&lt;/span&gt;&lt;span class="p"&gt;())&lt;/span&gt;  &lt;span class="c1"&gt;# Update target network
&lt;/span&gt;

&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;

&lt;p&gt;&lt;strong&gt;Collect Rollouts:&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;The &lt;code&gt;collect_rollouts&lt;/code&gt; function collects the transitions by simulating the environment by using the epsilon-greedy policy that we have defined in the &lt;code&gt;get_actions&lt;/code&gt; function. Each transition is stored in the replay memory and we decay the epsilon to reduce exploration. &lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;

&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;collect_rollouts&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;info&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;reset&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;  &lt;span class="c1"&gt;# Reset environment
&lt;/span&gt;    &lt;span class="n"&gt;terminated&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;False&lt;/span&gt;
    &lt;span class="n"&gt;truncated&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;False&lt;/span&gt;
    &lt;span class="n"&gt;rewards&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;  &lt;span class="c1"&gt;# Total rewards
&lt;/span&gt;    &lt;span class="n"&gt;episodes&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;  &lt;span class="c1"&gt;# Total episodes
&lt;/span&gt;    &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;n_rollouts&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="n"&gt;action&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;get_action&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;greedy&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;False&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# Get action
&lt;/span&gt;        &lt;span class="n"&gt;next_obs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;reward&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;terminated&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;truncated&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;step&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;action&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# Step environment
&lt;/span&gt;        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;replay_memory&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;push&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
            &lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;action&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;next_obs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;reward&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;terminated&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;truncated&lt;/span&gt;
        &lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# Save the transition
&lt;/span&gt;        &lt;span class="n"&gt;obs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;next_obs&lt;/span&gt;  &lt;span class="c1"&gt;# Update observation
&lt;/span&gt;        &lt;span class="n"&gt;rewards&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="n"&gt;reward&lt;/span&gt;  &lt;span class="c1"&gt;# Accumulate reward
&lt;/span&gt;        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;n_time_steps&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;  &lt;span class="c1"&gt;# Increment time steps
&lt;/span&gt;        &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;terminated&lt;/span&gt; &lt;span class="ow"&gt;or&lt;/span&gt; &lt;span class="n"&gt;truncated&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;  &lt;span class="c1"&gt;# Check if episode ended
&lt;/span&gt;            &lt;span class="n"&gt;episodes&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;
            &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;episodes&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;
            &lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;info&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;reset&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;  &lt;span class="c1"&gt;# Reset environment
&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;epsilon&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;max&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
            &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;min_epsilon&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;decay_rate&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;epsilon&lt;/span&gt;
        &lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# Decrease epsilon
&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;rewards&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;episodes&lt;/span&gt;  &lt;span class="c1"&gt;# Return the average reward per episode
&lt;/span&gt;

&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;

&lt;p&gt;&lt;strong&gt;Learn&lt;/strong&gt; &lt;/p&gt;

&lt;p&gt;The &lt;code&gt;learn&lt;/code&gt; function samples the experiences of &lt;code&gt;batch_size&lt;/code&gt; from the replay memory, predicts the q-values and gets the target values. We then get the loss and perform backward propagation. We do this for the defined number of &lt;code&gt;epochs&lt;/code&gt;. &lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;

&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;learn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;epochs&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;

    &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;q_network&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;train&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;  &lt;span class="c1"&gt;# Set Q-network to training mode
&lt;/span&gt;
    &lt;span class="n"&gt;average_loss&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;
    &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;epochs&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;action&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;next_obs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;reward&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;terminated&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;truncated&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;
            &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sample_experience&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
        &lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# Sample a batch of experiences
&lt;/span&gt;        &lt;span class="n"&gt;q_values&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Tensor&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;q_network&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# Get Q-values for the batch
&lt;/span&gt;        &lt;span class="n"&gt;next_q_values&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;target_network&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;next_obs&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# Get Q-values for the next states
&lt;/span&gt;
        &lt;span class="n"&gt;q_value&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;q_values&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;gather&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;action&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;unsqueeze&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)).&lt;/span&gt;&lt;span class="nf"&gt;squeeze&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# Gather Q-values for taken actions
&lt;/span&gt;        &lt;span class="n"&gt;next_q_value&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;next_q_values&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;max&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="n"&gt;values&lt;/span&gt;  &lt;span class="c1"&gt;# Get max Q-value for next states
&lt;/span&gt;        &lt;span class="n"&gt;target&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;reward&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;gamma&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;next_q_value&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;terminated&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;
            &lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;truncated&lt;/span&gt;
        &lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# Compute target Q-value
&lt;/span&gt;
        &lt;span class="n"&gt;loss&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;F&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;smooth_l1_loss&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;q_value&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;target&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# Compute loss
&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;optimizer&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;zero_grad&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;  &lt;span class="c1"&gt;# Zero gradients
&lt;/span&gt;        &lt;span class="n"&gt;loss&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;backward&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;  &lt;span class="c1"&gt;# Backpropagate
&lt;/span&gt;        &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;utils&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;clip_grad_norm_&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;q_network&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;parameters&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt; &lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# Clip gradients
&lt;/span&gt;        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;optimizer&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;step&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;  &lt;span class="c1"&gt;# Update Q-network
&lt;/span&gt;
        &lt;span class="n"&gt;average_loss&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;loss&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;item&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;average_loss&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# Update average loss
&lt;/span&gt;        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;n_updates&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;  &lt;span class="c1"&gt;# Increment update count
&lt;/span&gt;
        &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;n_updates&lt;/span&gt; &lt;span class="o"&gt;%&lt;/span&gt; &lt;span class="mi"&gt;1000&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;  &lt;span class="c1"&gt;# Update target network periodically
&lt;/span&gt;            &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;update_target&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;average_loss&lt;/span&gt;  &lt;span class="c1"&gt;# Return average loss
&lt;/span&gt;

&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;

&lt;p&gt;&lt;strong&gt;Some Implementation Details&lt;/strong&gt;&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;Smooth L1 Loss:&lt;/strong&gt; This was inspired by the Stable Baselines implementation of DQN. Smooth L1 loss prevents exploding gradients due to outliers by transitioning from L2 loss to L1 loss. This is useful in off-policy algorithms because when transitions are sampled from the Replay Buffer, they can belong to an old, much different policy. It is not desirable to fit to these ”outliers.”&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Gradient Clipping:&lt;/strong&gt; Before performing the optimizer step, the L2 norm of the gradients is clipped to 10 to counteract the influence of outliers further. This prevents exploding gradients by keeping all gradient norms below 10.&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;&lt;strong&gt;Evaluate the current Policy&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;According to &lt;a href="https://arxiv.org/abs/2307.09218" rel="noopener noreferrer"&gt;Zhenyi Wang&lt;/a&gt;, in many cases, DQN starts to forget and then relearn again once it has reached near the maximum reward. This can also be seen in the below figure, where the agent sometimes drops to a low reward. This can happen because even when the agent has achieved an optimal greedy policy, the rollout still happens with ϵ − greedy policy. This can cause the agent to suddenly take sub-optimal action, which can cause a large change in the network and cause the agent to unlearn. This can also be sometimes beneficial for making the agent more robust to perturbance. Nevertheless, the policy is evaluated after every learning cycle, and the best model is saved. This makes sure that, in the end, the best-performing model is retrievable.&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F2lxy93rkmuo583ruo5f6.jpeg" class="article-body-image-wrapper"&gt;&lt;img src="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F2lxy93rkmuo583ruo5f6.jpeg" alt="evaluation"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;

&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;evaluate&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_steps&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;q_network&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;eval&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;  &lt;span class="c1"&gt;# Set Q-network to evaluation mode
&lt;/span&gt;    &lt;span class="n"&gt;rewards&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;  &lt;span class="c1"&gt;# Total rewards
&lt;/span&gt;    &lt;span class="n"&gt;episodes&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;  &lt;span class="c1"&gt;# Total episodes
&lt;/span&gt;    &lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;no_grad&lt;/span&gt;&lt;span class="p"&gt;():&lt;/span&gt;
        &lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;info&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;reset&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;  &lt;span class="c1"&gt;# Reset environment
&lt;/span&gt;        &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n_steps&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
            &lt;span class="n"&gt;action&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;get_action&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;greedy&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# Get action
&lt;/span&gt;            &lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;reward&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;terminated&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;truncated&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;step&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;action&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# Step environment
&lt;/span&gt;            &lt;span class="n"&gt;rewards&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="n"&gt;reward&lt;/span&gt;  &lt;span class="c1"&gt;# Accumulate reward
&lt;/span&gt;
            &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;terminated&lt;/span&gt; &lt;span class="ow"&gt;or&lt;/span&gt; &lt;span class="n"&gt;truncated&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;  &lt;span class="c1"&gt;# Check if episode ended
&lt;/span&gt;                &lt;span class="n"&gt;episodes&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;
                &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;reset&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;  &lt;span class="c1"&gt;# Reset environment
&lt;/span&gt;
    &lt;span class="n"&gt;rewards&lt;/span&gt; &lt;span class="o"&gt;/=&lt;/span&gt; &lt;span class="n"&gt;episodes&lt;/span&gt;  &lt;span class="c1"&gt;# Compute average reward per episode   
&lt;/span&gt;    &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;rewards&lt;/span&gt; &lt;span class="o"&gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;best_reward&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;  &lt;span class="c1"&gt;# Save best model if improved
&lt;/span&gt;        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;best_reward&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;rewards&lt;/span&gt;
        &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;save&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;q_network&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;state_dict&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;dqn_best_model.pth&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;New best model saved!&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;rewards&lt;/span&gt;


&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;

&lt;p&gt;&lt;strong&gt;Train&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;Finally we have the &lt;code&gt;train&lt;/code&gt; function that runs the &lt;code&gt;collect_rollouts&lt;/code&gt;, &lt;code&gt;learn&lt;/code&gt; and &lt;code&gt;evaluate&lt;/code&gt; functions in a loop for a number of &lt;code&gt;epochs&lt;/code&gt;. &lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;

 &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;train&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;epochs&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
                &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;writer&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;SummaryWriter&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;log_dir&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;dqn_logs/DQN_2&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# TensorBoard writer
&lt;/span&gt;
        &lt;span class="n"&gt;pbar&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;tqdm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;epochs&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;  &lt;span class="c1"&gt;# Progress bar
&lt;/span&gt;        &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;pbar&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
            &lt;span class="n"&gt;rewards&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;collect_rollouts&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;  &lt;span class="c1"&gt;# Collect rollouts
&lt;/span&gt;            &lt;span class="n"&gt;loss&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;learn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;int&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;n_rollouts&lt;/span&gt;&lt;span class="o"&gt;/&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;  &lt;span class="c1"&gt;# Perform learning
&lt;/span&gt;            &lt;span class="n"&gt;eval_reward&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;evaluate&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1000&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# Evaluate agent
&lt;/span&gt;
            &lt;span class="n"&gt;pbar&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_description&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
                &lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Iteration &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="o"&gt;+&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt; || Reward: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;rewards&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt; || Eval Reward: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;eval_reward&lt;/span&gt; &lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt; || Loss: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;loss&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt; || Epsilon: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;epsilon&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt; || Time steps: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;n_time_steps&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt; || N updates: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;n_updates&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;
            &lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# Update progress bar
&lt;/span&gt;            &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;writer&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;add_scalar&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Training/Loss&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;loss&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;n_updates&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# Log training loss
&lt;/span&gt;            &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;writer&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;add_scalar&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
                &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Training/Rollout: Mean Episode Reward&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;rewards&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;n_updates&lt;/span&gt;
            &lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# Log mean episode reward during training
&lt;/span&gt;            &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;writer&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;add_scalar&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
                &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Evaluation/Mean Episode Reward&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;eval_reward&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;n_updates&lt;/span&gt;
            &lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# Log mean episode reward during evaluation
&lt;/span&gt;

&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;

&lt;p&gt;&lt;strong&gt;Hyperparameters:&lt;/strong&gt; &lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;For learning and exploration:&lt;code&gt;gamma&lt;/code&gt;, &lt;code&gt;alpha&lt;/code&gt;, &lt;code&gt;initial_epsilon&lt;/code&gt;, &lt;code&gt;min_epsilon&lt;/code&gt;, &lt;code&gt;decay_rate&lt;/code&gt;:&lt;/li&gt;
&lt;li&gt;For experience replay and training:  &lt;code&gt;batch_size&lt;/code&gt;, &lt;code&gt;n_rollouts&lt;/code&gt;, &lt;code&gt;capacity&lt;/code&gt;
&lt;/li&gt;
&lt;/ul&gt;

&lt;h3&gt;
  
  
  Start Training
&lt;/h3&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;

&lt;span class="n"&gt;env&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;gym&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;make&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Pendulum-v1&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;_max_episode_steps&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;500&lt;/span&gt;
&lt;span class="n"&gt;env&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;PendulumDiscreteStateAction&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;11&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.4&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="n"&gt;agent&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;Agent&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;agent&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;train&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;30&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;


&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
&lt;h3&gt;
  
  
  Testing
&lt;/h3&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;

&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;matplotlib.pyplot&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;plt&lt;/span&gt;

&lt;span class="n"&gt;env&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;gym&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;make&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Pendulum-v1&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
            &lt;span class="c1"&gt;#    render_mode='human'
&lt;/span&gt;               &lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;_max_episode_steps&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;500&lt;/span&gt;
&lt;span class="n"&gt;env&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;PendulumDiscreteStateAction&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;11&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.4&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="n"&gt;agent&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;Agent&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;agent&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;q_network&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;load_state_dict&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;load&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;models/deep_q_learning/dqn_best_model.pth&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;

&lt;span class="n"&gt;rewards&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[]&lt;/span&gt;
&lt;span class="n"&gt;thetas&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[]&lt;/span&gt;
&lt;span class="n"&gt;tot_reward&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;
&lt;span class="n"&gt;n_episodes&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;100&lt;/span&gt;
&lt;span class="n"&gt;n_steps&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;
&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n_episodes&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;info&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;reset&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
    &lt;span class="n"&gt;terminated&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;False&lt;/span&gt;
    &lt;span class="n"&gt;truncated&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;False&lt;/span&gt;
    &lt;span class="k"&gt;while&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="n"&gt;terminated&lt;/span&gt; &lt;span class="ow"&gt;and&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="n"&gt;truncated&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
        &lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;no_grad&lt;/span&gt;&lt;span class="p"&gt;():&lt;/span&gt;
            &lt;span class="n"&gt;action&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;agent&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;get_action&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;greedy&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
            &lt;span class="c1"&gt;# print(action)
&lt;/span&gt;            &lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;reward&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;terminated&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;  &lt;span class="n"&gt;truncated&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;info&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;step&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;action&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
            &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;observation_space&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;high&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
            &lt;span class="n"&gt;y&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;observation_space&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;high&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
            &lt;span class="n"&gt;theta&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;arctan2&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
            &lt;span class="n"&gt;thetas&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;theta&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
            &lt;span class="n"&gt;rewards&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;reward&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
            &lt;span class="n"&gt;tot_reward&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="n"&gt;reward&lt;/span&gt;
            &lt;span class="n"&gt;n_steps&lt;/span&gt;&lt;span class="o"&gt;+=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;
            &lt;span class="c1"&gt;# env.render()
&lt;/span&gt;&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Total reward: &lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tot_reward&lt;/span&gt;&lt;span class="o"&gt;/&lt;/span&gt;&lt;span class="n"&gt;n_episodes&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;


&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
&lt;h3&gt;
  
  
  Results
&lt;/h3&gt;

&lt;p&gt;Below, we can see the results of our trained agent.  The agent can reach an angle of zero no matter where it starts. Also, the reward is zero.  &lt;/p&gt;

&lt;p&gt;The average episodic reward across 100 episodes is &lt;strong&gt;9.8&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fshdvqui9oqyvw5yvrkit.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fshdvqui9oqyvw5yvrkit.png" alt="results"&gt;&lt;/a&gt;&lt;/p&gt;
&lt;h2&gt;
  
  
  LSTM Implementation
&lt;/h2&gt;

&lt;p&gt;The transitions we collect are temporal. That is, there is a correlation between nearby transitions. To capture this temporal relation, we can change the state of the environment to be a history of past (state, action) pairs. So we need to add the previous action to the observation and maintain a list of &lt;code&gt;Nhist&lt;/code&gt; number of previous observations. Let’s see if this improves the agent since now we have more information and a better network.  &lt;/p&gt;

&lt;p&gt;In the figure below, you can see the LSTM network architecture. There are several ways to implement this. Apart from this, another option is to take a linear output from the last hidden layer directly. But let’s proceed with this for now.&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F6lime926dxvpxdgq5krx.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F6lime926dxvpxdgq5krx.png" alt="lstm"&gt;&lt;/a&gt;&lt;/p&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;

&lt;span class="c1"&gt;# Qnetwork with LSTM
&lt;/span&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;QNetwork&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;nvec_s&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;nvec_u&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="nf"&gt;super&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;QNetwork&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;lstm&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;LSTM&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nvec_s&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;64&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;batch_first&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;fc1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;64&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;64&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;fc2&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;64&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;nvec_u&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Tensor&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;-&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Tensor&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
        &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;lstm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;F&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;relu&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;fc1&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
        &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;fc2&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;


&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;

&lt;p&gt;Also, we need to change the wrapper to get the history of observations.&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;

&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;collections&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;deque&lt;/span&gt;

&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;PendulumDiscreteStateAction&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;gym&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Wrapper&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;gym&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Env&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;nvec_u&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nb"&gt;int&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nb"&gt;float&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;nhist&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nb"&gt;int&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;

        &lt;span class="nf"&gt;super&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;PendulumDiscreteStateAction&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;env&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;nvec_u&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nvec_u&lt;/span&gt;

        &lt;span class="c1"&gt;# Check if the observation space is of type Box
&lt;/span&gt;        &lt;span class="k"&gt;assert&lt;/span&gt; &lt;span class="nf"&gt;isinstance&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
            &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;observation_space&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;gym&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;spaces&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Box&lt;/span&gt;
        &lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Error: observation space is not of type Box&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;

        &lt;span class="c1"&gt;# Create a Discrete action space
&lt;/span&gt;        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;action_space&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;gym&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;spaces&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Discrete&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nvec_u&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;


        &lt;span class="c1"&gt;# Define the possible actions
&lt;/span&gt;        &lt;span class="n"&gt;kernel&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ndarray&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;gaussian&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;linspace&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;actions&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;kernel&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;tolist&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;flip&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;kernel&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;tolist&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

        &lt;span class="n"&gt;low&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[]&lt;/span&gt;
        &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nhist&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
            &lt;span class="n"&gt;temp_low&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[]&lt;/span&gt;
            &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;value&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;observation_space&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;low&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
                &lt;span class="n"&gt;temp_low&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;value&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
            &lt;span class="n"&gt;temp_low&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;min&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;actions&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
            &lt;span class="n"&gt;low&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;temp_low&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

        &lt;span class="n"&gt;low&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;array&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="nf"&gt;list&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;observation_space&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;low&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="nf"&gt;min&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;actions&lt;/span&gt;&lt;span class="p"&gt;)]&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nhist&lt;/span&gt;&lt;span class="p"&gt;)])&lt;/span&gt;
        &lt;span class="n"&gt;high&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;array&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="nf"&gt;list&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;observation_space&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;high&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="nf"&gt;max&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;actions&lt;/span&gt;&lt;span class="p"&gt;)]&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nhist&lt;/span&gt;&lt;span class="p"&gt;)])&lt;/span&gt;


        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;observation_space&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;gym&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;spaces&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Box&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
            &lt;span class="n"&gt;low&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;low&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
            &lt;span class="n"&gt;high&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;high&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
            &lt;span class="n"&gt;dtype&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;float32&lt;/span&gt;
        &lt;span class="p"&gt;)&lt;/span&gt;

        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;prev_action&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;history&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;deque&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;maxlen&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nhist&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

        &lt;span class="c1"&gt;# Initialize the history with zeros
&lt;/span&gt;        &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nhist&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
            &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;history&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;array&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]))&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;step&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;action&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nb"&gt;int&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;-&amp;gt;&lt;/span&gt; &lt;span class="nb"&gt;tuple&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ndarray&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="nb"&gt;float&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="nb"&gt;float&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nb"&gt;bool&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nb"&gt;dict&lt;/span&gt;&lt;span class="p"&gt;]:&lt;/span&gt;

        &lt;span class="n"&gt;action&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;actions&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;action&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
        &lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;reward&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;terminated&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;truncated&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;info&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;step&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="n"&gt;action&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
        &lt;span class="n"&gt;reward&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;reward&lt;/span&gt;&lt;span class="o"&gt;/&lt;/span&gt;&lt;span class="mf"&gt;16.2736044&lt;/span&gt; &lt;span class="c1"&gt;# normalize the reward between -1 and 1
&lt;/span&gt;        &lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ndarray&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="nb"&gt;float&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="o"&gt;/&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;observation_space&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;high&lt;/span&gt; &lt;span class="c1"&gt;# normalize the observation between -1 and 0
&lt;/span&gt;        &lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;action&lt;/span&gt;&lt;span class="o"&gt;/&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;history&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;obs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;array&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;list&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;history&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;  
        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;reward&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;terminated&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;truncated&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;info&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;reset&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;-&amp;gt;&lt;/span&gt; &lt;span class="nb"&gt;tuple&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ndarray&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="nb"&gt;float&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="nb"&gt;dict&lt;/span&gt;&lt;span class="p"&gt;]:&lt;/span&gt;

        &lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;info&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;reset&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
        &lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ndarray&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="nb"&gt;float&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="o"&gt;/&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;observation_space&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;high&lt;/span&gt;
        &lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

        &lt;span class="c1"&gt;# Initialize the history with the same observation
&lt;/span&gt;        &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;history&lt;/span&gt;&lt;span class="p"&gt;)):&lt;/span&gt;
            &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;history&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;obs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;array&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;list&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;history&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;info&lt;/span&gt;


&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;

&lt;p&gt;Here, you can see that we are using &lt;code&gt;deque&lt;/code&gt; to keep a list of the history of fixed-size &lt;code&gt;Nhist&lt;/code&gt;. We also added the action to the observation. Now, we can train the agent as usual. Below, we can see how the training occurs in comparison to FCN. It is pretty similar. This means the LSTM is not having too big of an effect here. But in some complex environments this can be useful and worth trying.&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Froakolski2lt3mao8wcf.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Froakolski2lt3mao8wcf.png" alt="results"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;That's all for this part. We saw how to implement a deep Q learning algorithm from scratch and use any network as the Q network, not just FCN. We saw LSTM, but you do this with transformers for an even longer context history. In the next part we will look into how to solve the limitation of using a discrete action space by implementing the Reinforce Algorithm.&lt;/p&gt;




&lt;p&gt;Want to connect?&lt;/p&gt;

&lt;p&gt;🌍&lt;a href="http://www.akshaymakes.com" rel="noopener noreferrer"&gt;My Website&lt;/a&gt;&lt;br&gt;&lt;br&gt;
🐦&lt;a href="https://twitter.com/akshayballal95" rel="noopener noreferrer"&gt;My Twitter&lt;/a&gt;&lt;br&gt;&lt;br&gt;
👨&lt;a href="https://www.linkedin.com/in/akshay-ballal/" rel="noopener noreferrer"&gt;My LinkedIn&lt;/a&gt; &lt;/p&gt;

</description>
      <category>machinelearning</category>
      <category>ai</category>
      <category>python</category>
      <category>tutorial</category>
    </item>
    <item>
      <title>Reinforcement Learning from Scratch - Part 1 - Tabular Q Learning</title>
      <dc:creator>Akshay Ballal</dc:creator>
      <pubDate>Fri, 26 Jul 2024 05:36:08 +0000</pubDate>
      <link>https://dev.to/akshayballal/reinforcement-learning-from-scratch-part-1-tabular-q-learning-22k3</link>
      <guid>https://dev.to/akshayballal/reinforcement-learning-from-scratch-part-1-tabular-q-learning-22k3</guid>
      <description>&lt;p&gt;Reinforcement Learning is becoming the new trend. From controlling robots to optimizing logistics to tuning language models, reinforcement learning is the go-to strategy. However, newcomers to the field face a fragmented landscape and heavy reliance on implementation details. Even if you find a suitable RL algorithm from the many available, implementing it requires attention to fine details that aren’t part of the main RL algorithm, and getting these details right is challenging. While this is often problem-dependent, this series aims to show examples of implementation details across different RL problems and algorithms.&lt;/p&gt;

&lt;p&gt;In this first edition, we explore the most basic form of Q-learning: Vanilla or Tabular Q-learning. You’ll see that even with the most basic algorithm, there are several tricks we can use to improve and adapt it to various problems.&lt;/p&gt;

&lt;p&gt;This RL series is geared towards practitioners who already have basic knowledge of RL and have trained some basic RL agents. For this reason, I may not go into complete theory details and will mainly focus on finer implementation tricks. Nonetheless, let’s begin with some required basics. Also, in practice, one would use libraries like stable-baselines, but they are black boxes and do not provide a good understanding of how things work under the hood.&lt;/p&gt;

&lt;h3&gt;
  
  
  What is the agent’s objective in RL?
&lt;/h3&gt;

&lt;p&gt;Every RL agent is trained for one simple objective: to maximize the total rewards accumulated. To do this, we need the agent to learn to take the most optimal decision in every state it can face. The equation below shows the total reward accumulated by the agent starting at time step &lt;em&gt;t&lt;/em&gt;.&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fbkh65unbpdjz49zsrpke.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fbkh65unbpdjz49zsrpke.png" alt="Returns Equation"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;h3&gt;
  
  
  Action-value function
&lt;/h3&gt;

&lt;p&gt;The action-value function, often denoted as Q(s,a), is a fundamental concept in reinforcement learning. It represents the expected return (total accumulated reward) starting from a state s, taking an action a, and subsequently following a policy π. Formally, it can be expressed as:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fglavq8mgbmtni4zff67f.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fglavq8mgbmtni4zff67f.png" alt="Action value equation"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;This is all the theory we need to implement our first Q-learning algorithm: Tabular Q-learning. Initially, we'll use Tabular Q-Learning to solve the Frozen Lake Environment, ensuring our algorithm setup is correct. Following this, we will tackle the Pendulum Environment. This environment is chosen for several reasons: it has relatively few states and only one action. The states and actions are also continuous, meaning they can take any real value within specified limits. For Tabular Q-Learning, we require discrete states and actions. This scenario will demonstrate how we can modify the environment to fit our algorithm.&lt;/p&gt;

&lt;h2&gt;
  
  
  Pseudocode for Tabular Q Learning
&lt;/h2&gt;

&lt;p&gt;&lt;a href="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F0ekwie1e2khvr0afq1go.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F0ekwie1e2khvr0afq1go.png" alt="pseudocode"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;Initialize Parameters&lt;/strong&gt;: Set the initial values for all the parameters required for the algorithm. These include the exploration rate (ϵ), learning rate (α), reward discounting factor (γ), and the total number of steps the algorithm should run.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Initialize Q-matrix and Step Counter&lt;/strong&gt;: The Q-matrix (

&lt;span class="katex-element"&gt;
  &lt;span class="katex"&gt;&lt;span class="katex-mathml"&gt;QmatQmat&lt;/span&gt;&lt;span class="katex-html"&gt;&lt;span class="base"&gt;&lt;span class="strut"&gt;&lt;/span&gt;&lt;span class="mord mathnormal"&gt;Q&lt;/span&gt;&lt;span class="mord mathnormal"&gt;ma&lt;/span&gt;&lt;span class="mord mathnormal"&gt;t&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;
&lt;/span&gt;
) stores the Q-values for state-action pairs. Initially, it is empty. The step counter (&lt;em&gt;k&lt;/em&gt;) keeps track of the number of steps executed.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Get first observation&lt;/strong&gt;: The environment is reset to its initial state to start the learning process. The initial observation (
&lt;span class="katex-element"&gt;
  &lt;span class="katex"&gt;&lt;span class="katex-mathml"&gt;xkx_k&lt;/span&gt;&lt;span class="katex-html"&gt;&lt;span class="base"&gt;&lt;span class="strut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;x&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;k&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;
&lt;/span&gt;
) is obtained from the environment.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Action Selection&lt;/strong&gt;: At each step, the algorithm decides whether to explore or exploit. With probability 1−ϵ, it exploits by choosing the action that has the highest Q-value for the current state. With probability ϵ, it explores by choosing a random action.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Step Environment&lt;/strong&gt;: The selected action is executed in the environment, which returns the new state (
&lt;span class="katex-element"&gt;
  &lt;span class="katex"&gt;&lt;span class="katex-mathml"&gt;xk+1x_{k+1} &lt;/span&gt;&lt;span class="katex-html"&gt;&lt;span class="base"&gt;&lt;span class="strut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;x&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;k&lt;/span&gt;&lt;span class="mbin mtight"&gt;+&lt;/span&gt;&lt;span class="mord mtight"&gt;1&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;
&lt;/span&gt;
) and the reward (
&lt;span class="katex-element"&gt;
  &lt;span class="katex"&gt;&lt;span class="katex-mathml"&gt;rkr_k&lt;/span&gt;&lt;span class="katex-html"&gt;&lt;span class="base"&gt;&lt;span class="strut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;r&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;k&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;
&lt;/span&gt;
).&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Check if Terminal State&lt;/strong&gt;: If the new state is terminal (end of an episode), the Q-value for the state-action pair is updated using the reward received, and the environment is reset.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Non-Terminal State&lt;/strong&gt;: If the new state is not terminal, the Q-value is updated considering the reward received and the maximum Q-value for the next state.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Decay ϵ&lt;/strong&gt;: The exploration rate (ϵ) is decayed to balance exploration and exploitation over time.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Increment Step Counter&lt;/strong&gt;: The step counter is incremented to proceed to the next step.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;End While Loop&lt;/strong&gt;: The process continues until the specified number of steps (
&lt;span class="katex-element"&gt;
  &lt;span class="katex"&gt;&lt;span class="katex-mathml"&gt;nstepsnsteps&lt;/span&gt;&lt;span class="katex-html"&gt;&lt;span class="base"&gt;&lt;span class="strut"&gt;&lt;/span&gt;&lt;span class="mord mathnormal"&gt;n&lt;/span&gt;&lt;span class="mord mathnormal"&gt;s&lt;/span&gt;&lt;span class="mord mathnormal"&gt;t&lt;/span&gt;&lt;span class="mord mathnormal"&gt;e&lt;/span&gt;&lt;span class="mord mathnormal"&gt;p&lt;/span&gt;&lt;span class="mord mathnormal"&gt;s&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;
&lt;/span&gt;
) is reached.&lt;/li&gt;
&lt;/ol&gt;

&lt;h2&gt;
  
  
  Code Implementation for FrozenLake Environment
&lt;/h2&gt;

&lt;ul&gt;
&lt;li&gt;First, let’s import the required dependencies. Here, we use the gym library, which was originally developed by OpenAI and has several environments for RL.&lt;/li&gt;
&lt;/ul&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;

&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;gymnasium&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;gym&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;numpy&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;collections&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;defaultdict&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;tqdm&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;tqdm&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;copy&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;deepcopy&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;matplotlib.pyplot&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;plt&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;typing&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;Optional&lt;/span&gt;


&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
&lt;ul&gt;
&lt;li&gt;Next, we create a method for building the environment. For the FrozenLake environment, we build the environment as such without any modification, but in this method, we can wrap the environment in different modifications. We will see these modifications for the Pendulum Environment.&lt;/li&gt;
&lt;/ul&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;

&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;get_env&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
    &lt;span class="n"&gt;env_name&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nb"&gt;str&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;render_mode&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;rgb_array&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;-&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;gym&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Env&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="n"&gt;env&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;gym&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;make&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;env_name&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;render_mode&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;render_mode&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;


&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
&lt;ul&gt;
&lt;li&gt;We then define a custom &lt;code&gt;argmax&lt;/code&gt; function to choose the action with the maximum q value. But if multiple actions have the same q value, this function chooses uniformly between the ties. The standard &lt;code&gt;np.argmax&lt;/code&gt; function would just pick the first action and we don’t want that.&lt;/li&gt;
&lt;/ul&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;

&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;argmax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;a&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="c1"&gt;# random argmax
&lt;/span&gt;    &lt;span class="n"&gt;a&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;array&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;a&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;choice&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;arange&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;a&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;dtype&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="nb"&gt;int&lt;/span&gt;&lt;span class="p"&gt;)[&lt;/span&gt;&lt;span class="n"&gt;a&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;max&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;a&lt;/span&gt;&lt;span class="p"&gt;)])&lt;/span&gt;


&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
&lt;ul&gt;
&lt;li&gt;Now we implement the main algorithm&lt;/li&gt;
&lt;/ul&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;

&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;Qlearn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
    &lt;span class="n"&gt;build_env&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;gamma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.99&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;min_epsilon&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;nsteps&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;800000&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;Qmat&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;callback_freq&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;5000&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;callback&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;Qmat&lt;/span&gt; &lt;span class="ow"&gt;is&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
        &lt;span class="n"&gt;Qmat&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;defaultdict&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nb"&gt;float&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="n"&gt;episode_reward&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;
    &lt;span class="n"&gt;mean_episodic_reward&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;
    &lt;span class="n"&gt;n_episodes&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;

    &lt;span class="n"&gt;epsilon&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mf"&gt;1.0&lt;/span&gt;

    &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;gym&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Env&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;build_env&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
    &lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;info&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;reset&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

    &lt;span class="n"&gt;pbar&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;tqdm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nsteps&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;colour&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;green&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;pbar&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
        &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;rand&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="o"&gt;&amp;lt;&lt;/span&gt; &lt;span class="n"&gt;epsilon&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
            &lt;span class="c1"&gt;# Exploration: choose a random action
&lt;/span&gt;            &lt;span class="n"&gt;action&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;action_space&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sample&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
        &lt;span class="k"&gt;else&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
            &lt;span class="c1"&gt;# Exploitation: choose the action with the highest Q-value
&lt;/span&gt;            &lt;span class="n"&gt;action&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;argmax&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="n"&gt;Qmat&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;action_space&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;n&lt;/span&gt;&lt;span class="p"&gt;)])&lt;/span&gt;

        &lt;span class="n"&gt;next_obs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;reward&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;terminated&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;truncated&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;info&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;step&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;action&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;episode_reward&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="n"&gt;reward&lt;/span&gt;

        &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="n"&gt;terminated&lt;/span&gt; &lt;span class="ow"&gt;and&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="n"&gt;truncated&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;

            &lt;span class="n"&gt;Q_next&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;max&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
                &lt;span class="n"&gt;Qmat&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;next_obs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;action_next&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;action_next&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;action_space&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;n&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
            &lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# Get the maximum Q-value for the next state
&lt;/span&gt;
            &lt;span class="c1"&gt;# Update Q-value using Q-learning update rule
&lt;/span&gt;            &lt;span class="n"&gt;Qmat&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;action&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;reward&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;gamma&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;Q_next&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;Qmat&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;action&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;

            &lt;span class="n"&gt;obs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;next_obs&lt;/span&gt;
        &lt;span class="k"&gt;else&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
            &lt;span class="c1"&gt;# Episode ends, update mean episodic reward and reset environment
&lt;/span&gt;            &lt;span class="n"&gt;Qmat&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;action&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;reward&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;Qmat&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;action&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
            &lt;span class="n"&gt;n_episodes&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;
            &lt;span class="n"&gt;mean_episodic_reward&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;episode_reward&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;mean_episodic_reward&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;n_episodes&lt;/span&gt;
            &lt;span class="n"&gt;episode_reward&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;
            &lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;info&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;reset&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

        &lt;span class="n"&gt;epsilon&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;max&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;min_epsilon&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.99995&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;epsilon&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

        &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="o"&gt;%&lt;/span&gt; &lt;span class="n"&gt;callback_freq&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
            &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;callback&lt;/span&gt; &lt;span class="ow"&gt;is&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
                &lt;span class="c1"&gt;# Execute callback function if provided
&lt;/span&gt;                &lt;span class="nf"&gt;callback&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;build_env&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Qmat&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;pbar&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_description&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
            &lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Mean episodic reward: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;mean_episodic_reward&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt; | Epsilon: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;epsilon&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt; | Best reward: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;best_reward&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;
        &lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;Qbest&lt;/span&gt;


&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
&lt;p&gt;You can see here a few fine details:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;We initiate the ϵ at 1.0 and then decay it at every step by multiplying it with 0.99995 until it reaches some minimum epsilon. This decay rate becomes one of the hyperparameters.&lt;/li&gt;
&lt;li&gt;We are keeping track of the mean episodic reward.&lt;/li&gt;
&lt;li&gt;We have a callback function which we haven't talked about yet. This is the test function that we define next. This is used for a couple of reasons. One needs to know how the greedy policy is performing because, in the end, we are going to just use the greedy policy. This is similar to a validation set that is used in Machine Learning. Another reason is to save the best-performing policy every time the test is performed. This is done to take into account the fact that RL algorithms face a phenomenon of catastrophic forgetting. This happens when the agent learns good policies, but then suddenly an update puts it on the wrong track, and it forgets what it has learned and then relearns. This is quite common with my RL algorithms and Q learning is one of them. That's why we save the best policy so that if the agent is in a forgetting state at the end of the training, we can still retrieve the best policy. Here is the test code.&lt;/li&gt;
&lt;/ul&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;

&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;test&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;build_env&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Qmat&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;test_steps&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1000&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="k"&gt;global&lt;/span&gt; &lt;span class="n"&gt;Qbest&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;best_reward&lt;/span&gt;

    &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;gym&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Env&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;build_env&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
    &lt;span class="n"&gt;n_episodes&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;
    &lt;span class="n"&gt;tot_rewards&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;

    &lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;info&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;reset&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
    &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;test_steps&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="c1"&gt;# Choose the action with the highest Q-value
&lt;/span&gt;        &lt;span class="n"&gt;action&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;argmax&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="n"&gt;Qmat&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;action_space&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;n&lt;/span&gt;&lt;span class="p"&gt;)])&lt;/span&gt;
        &lt;span class="n"&gt;next_obs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;reward&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;terminated&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;truncated&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;info&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;step&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;action&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

        &lt;span class="n"&gt;tot_rewards&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="n"&gt;reward&lt;/span&gt;
        &lt;span class="n"&gt;obs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;next_obs&lt;/span&gt;

        &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;terminated&lt;/span&gt; &lt;span class="ow"&gt;or&lt;/span&gt; &lt;span class="n"&gt;truncated&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
            &lt;span class="n"&gt;n_episodes&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;
            &lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;info&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;reset&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

    &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;best_reward&lt;/span&gt; &lt;span class="o"&gt;&amp;lt;&lt;/span&gt; &lt;span class="n"&gt;tot_rewards&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;n_episodes&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
        &lt;span class="c1"&gt;# Update the best reward and Q-best if a better reward is achieved
&lt;/span&gt;        &lt;span class="n"&gt;best_reward&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tot_rewards&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;n_episodes&lt;/span&gt;
        &lt;span class="n"&gt;Qbest&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;deepcopy&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Qmat&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;tot_rewards&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;n_episodes&lt;/span&gt;


&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
&lt;p&gt;You can see that we check if the total reward during testing is better than the best rewards up to now, then update the best rewards and save the current &lt;code&gt;*Qmat&lt;/code&gt;* as &lt;em&gt;&lt;code&gt;Qbest&lt;/code&gt;.&lt;/em&gt;&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;Next, we just run the algorithm, by first initiating &lt;code&gt;*Qbest&lt;/code&gt;* as None and best reward as &lt;code&gt;-np.inf&lt;/code&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;

&lt;span class="n"&gt;Qbest&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt;
&lt;span class="n"&gt;best_reward&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;inf&lt;/span&gt;
&lt;span class="n"&gt;build_env&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="k"&gt;lambda&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nf"&gt;get_env&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;FrozenLake-v1&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="n"&gt;Qbest&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;Qlearn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;build_env&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;callback&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;test&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;nsteps&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;100000&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;


&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
&lt;ul&gt;
&lt;li&gt;After training, we can just test the agent using this code&lt;/li&gt;
&lt;/ul&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;

&lt;span class="n"&gt;env&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;get_env&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;FrozenLake-v1&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;render_mode&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;human&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# Slippery Frozen Lake is Default
&lt;/span&gt;&lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;info&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;reset&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="n"&gt;n_episodes&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;3&lt;/span&gt;

&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n_episodes&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;terminated&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;False&lt;/span&gt;
    &lt;span class="n"&gt;truncated&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;False&lt;/span&gt;
    &lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;info&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;reset&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
    &lt;span class="k"&gt;while&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="n"&gt;terminated&lt;/span&gt; &lt;span class="ow"&gt;and&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="n"&gt;truncated&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
        &lt;span class="n"&gt;action&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;argmax&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="n"&gt;Qbest&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;action_space&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;n&lt;/span&gt;&lt;span class="p"&gt;)])&lt;/span&gt;
        &lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;reward&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;terminated&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;truncated&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;info&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;step&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;action&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;


&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
&lt;p&gt;A new window opens, and we can see our little guy walking around. You will see that even though the ground is slippery the agent takes "safe" actions to both avoid the holes and still reach the goal.&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fz27ub4agpyej47i1e475.gif" class="article-body-image-wrapper"&gt;&lt;img src="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fz27ub4agpyej47i1e475.gif" alt="Frozen Lake"&gt;&lt;/a&gt;&lt;/p&gt;
&lt;h2&gt;
  
  
  Code Implementation for the Pendulum Environment
&lt;/h2&gt;

&lt;p&gt;Alright, now that we know that our algorithm works for the Frozen Lake Environment, we are ready to switch to a slightly complex environment. For the Pendulum Environment, we need to make a few modifications because our algorithm supports only discrete observation and action space whereas the Pendulum environment has continuous observation and action space. So, we need to discretize the environmental spaces. We will be creating a wrapper for this. But first, let’s see what this environment contains.&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fgytg1e4pschg6uacylz8.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fgytg1e4pschg6uacylz8.png" alt="Pendulum Environment"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;So for the action space, we need to discretize the actions between -2 and 2. We can just choose a list of actions like [-2, -1.5, -1, -0.5, 0, 0.5, 1, 1.5, 2] , but this is manual and may not scale very well to other environments. Also, we need more actions that are smaller, because we need to have more fine control once the pendulum learns to swing up. Swinging up can be done easily using some large actions. For this, we can create a Gaussian kernel. A Gaussian kernel looks like this:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F9yfn1z4023wwhs1vovrv.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F9yfn1z4023wwhs1vovrv.png" alt="Gaussian Kernel"&gt;&lt;/a&gt;&lt;/p&gt;


&lt;div class="katex-element"&gt;
  &lt;span class="katex-display"&gt;&lt;span class="katex"&gt;&lt;span class="katex-mathml"&gt;y=A∗ex−μ2σ2
y = A * e^{\frac{x-\mu}{2 \sigma^2}}
&lt;/span&gt;&lt;span class="katex-html"&gt;&lt;span class="base"&gt;&lt;span class="strut"&gt;&lt;/span&gt;&lt;span class="mord mathnormal"&gt;y&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mrel"&gt;=&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="base"&gt;&lt;span class="strut"&gt;&lt;/span&gt;&lt;span class="mord mathnormal"&gt;A&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mbin"&gt;∗&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="base"&gt;&lt;span class="strut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;e&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mopen nulldelimiter sizing reset-size3 size6"&gt;&lt;/span&gt;&lt;span class="mfrac"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size3 size1 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mtight"&gt;2&lt;/span&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;σ&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="mord mtight"&gt;2&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="frac-line mtight"&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size3 size1 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;x&lt;/span&gt;&lt;span class="mbin mtight"&gt;−&lt;/span&gt;&lt;span class="mord mathnormal mtight"&gt;μ&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mclose nulldelimiter sizing reset-size3 size6"&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;
&lt;/div&gt;



&lt;p&gt;Here, A ﻿is the maximum value of the action. In our case, we have μ = 0. You can see that changing the sigma changes the way our actions are distributed. Thus σ becomes one of the hyperparameters along with the number of actions available. Let’s define a function to create the Gaussian kernel. &lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;

&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;gaussian&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_amp&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;max_amp&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;exp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;


&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
&lt;p&gt;The above plot is created using this code:&lt;/p&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;

&lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;linspace&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;20&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# Compute the Gaussian kernel centered at 0 with standard deviation of 1
&lt;/span&gt;&lt;span class="n"&gt;mu&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;
&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;plot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nf"&gt;gaussian&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_amp&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;label&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Sigma = 0.2&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;plot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nf"&gt;gaussian&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_amp&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;label&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Sigma = 0.4&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;plot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nf"&gt;gaussian&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.6&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_amp&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;label&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Sigma = 0.6&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;xlabel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;x&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;ylabel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Amplitude&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;title&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;1D Scaled Gaussian Kernel&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;legend&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;grid&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;show&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;


&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
&lt;p&gt;We are going to use $\sigma$ = 0.5 for the rest of the code. To discretize the state we use the formula, &lt;/p&gt;


&lt;div class="katex-element"&gt;
  &lt;span class="katex-display"&gt;&lt;span class="katex"&gt;&lt;span class="katex-mathml"&gt;xdiscrete=x−xminxmax−xmin⋅nvecs
x_{discrete} = \frac{x - x_{min}}{x_{max}-x_{min}} \cdot nvec_s
&lt;/span&gt;&lt;span class="katex-html"&gt;&lt;span class="base"&gt;&lt;span class="strut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;x&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;d&lt;/span&gt;&lt;span class="mord mathnormal mtight"&gt;i&lt;/span&gt;&lt;span class="mord mathnormal mtight"&gt;scre&lt;/span&gt;&lt;span class="mord mathnormal mtight"&gt;t&lt;/span&gt;&lt;span class="mord mathnormal mtight"&gt;e&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mrel"&gt;=&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="base"&gt;&lt;span class="strut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mopen nulldelimiter"&gt;&lt;/span&gt;&lt;span class="mfrac"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;x&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;ma&lt;/span&gt;&lt;span class="mord mathnormal mtight"&gt;x&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mbin"&gt;−&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;x&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;min&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="frac-line"&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;x&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mbin"&gt;−&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;x&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;min&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mclose nulldelimiter"&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mbin"&gt;⋅&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="base"&gt;&lt;span class="strut"&gt;&lt;/span&gt;&lt;span class="mord mathnormal"&gt;n&lt;/span&gt;&lt;span class="mord mathnormal"&gt;v&lt;/span&gt;&lt;span class="mord mathnormal"&gt;e&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;c&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;s&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;
&lt;/div&gt;



&lt;p&gt;Where &lt;code&gt;nvec_s&lt;/code&gt; is the number of discrete states that we want to make. For this implementation, we are going to use [10, 10, 10], which means that the x value, y value, and the angular velocity of the pendulum are discretized to 10 values. For example, state corresponding to [0, 0, 0] will map to [-1, -1, -8], and state corresponding to [10, 10, 10] will map to [1, 1, 8]. Here is the code to do all this. We inherit the &lt;code&gt;gym.Wrapper&lt;/code&gt; &lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;

&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;PendulumDiscreteStateAction&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;gym&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Wrapper&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;gym&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Env&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;nvec_s&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nb"&gt;list&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="nb"&gt;int&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;nvec_u&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nb"&gt;int&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nb"&gt;float&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;

        &lt;span class="nf"&gt;super&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;PendulumDiscreteStateAction&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;env&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;nvec_s&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nvec_s&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;nvec_u&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nvec_u&lt;/span&gt;

        &lt;span class="c1"&gt;# Check if the observation space is of type Box
&lt;/span&gt;        &lt;span class="k"&gt;assert&lt;/span&gt; &lt;span class="nf"&gt;isinstance&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
            &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;observation_space&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;gym&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;spaces&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Box&lt;/span&gt;
        &lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Error: observation space is not of type Box&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;

        &lt;span class="c1"&gt;# Check if the length of nvec_s matches the shape of the observation space
&lt;/span&gt;        &lt;span class="nf"&gt;assert &lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
            &lt;span class="nf"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nvec_s&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;observation_space&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
        &lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Error: nvec_s does not match the shape of the observation space&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;

        &lt;span class="c1"&gt;# Create a MultiDiscrete observation space and action space
&lt;/span&gt;        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;observation_space&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;gym&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;spaces&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;MultiDiscrete&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nvec_s&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;action_space&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;gym&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;spaces&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Discrete&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nvec_u&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

        &lt;span class="c1"&gt;# Define the possible actions
&lt;/span&gt;
        &lt;span class="n"&gt;kernel&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;gaussian&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;linspace&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;actions&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;kernel&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;tolist&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;flip&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;kernel&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;tolist&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
        &lt;span class="c1"&gt;# self.actions = [-2.0, -1.0, -0.5, -0.25, -0.15, 0.0, 0.15, 0.25, 0.5, 1.0, 2.0]
&lt;/span&gt;
    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;_discretize_observation&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ndarray&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="nb"&gt;int&lt;/span&gt; &lt;span class="o"&gt;|&lt;/span&gt; &lt;span class="nb"&gt;float&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt; &lt;span class="o"&gt;-&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ndarray&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="nb"&gt;int&lt;/span&gt;&lt;span class="p"&gt;]:&lt;/span&gt;

        &lt;span class="c1"&gt;# Discretize each dimension of the observation
&lt;/span&gt;        &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;)):&lt;/span&gt;
            &lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;int&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
            &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;observation_space&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;low&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
            &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;
                &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;observation_space&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;high&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
                &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;observation_space&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;low&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
            &lt;span class="p"&gt;)&lt;/span&gt;
            &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;nvec_s&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
            &lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# equation: (x - min) / (max - min) * nvec
&lt;/span&gt;        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;astype&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nb"&gt;int&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;step&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;action&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nb"&gt;int&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;-&amp;gt;&lt;/span&gt; &lt;span class="nb"&gt;tuple&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ndarray&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="nb"&gt;int&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="nb"&gt;float&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nb"&gt;bool&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nb"&gt;dict&lt;/span&gt;&lt;span class="p"&gt;]:&lt;/span&gt;

        &lt;span class="n"&gt;action&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;actions&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;action&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
        &lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;reward&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;terminated&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;truncated&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;info&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;step&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="n"&gt;action&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
        &lt;span class="n"&gt;obs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;_discretize_observation&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="nf"&gt;tuple&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;reward&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;terminated&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;truncated&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;info&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;reset&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;-&amp;gt;&lt;/span&gt; &lt;span class="nb"&gt;tuple&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ndarray&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="nb"&gt;int&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="nb"&gt;dict&lt;/span&gt;&lt;span class="p"&gt;]:&lt;/span&gt;

        &lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;info&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;reset&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
        &lt;span class="n"&gt;obs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;_discretize_observation&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="nf"&gt;tuple&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;info&lt;/span&gt;


&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;

&lt;p&gt;Now we can modify our method to build the environment to include this wrapper.&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;

&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;get_env&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
    &lt;span class="n"&gt;env_name&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nb"&gt;str&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;nvec_s&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;Optional&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="nb"&gt;list&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="nb"&gt;int&lt;/span&gt;&lt;span class="p"&gt;]]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;nvec_u&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;Optional&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="nb"&gt;int&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;time_limit&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;Optional&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="nb"&gt;int&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;Optional&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="nb"&gt;float&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt;  &lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;render_mode&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;rgb_array&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;-&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;gym&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Env&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="n"&gt;env&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;gym&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;make&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;env_name&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;render_mode&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;render_mode&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;nvec_s&lt;/span&gt; &lt;span class="ow"&gt;is&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt; &lt;span class="ow"&gt;and&lt;/span&gt; &lt;span class="n"&gt;nvec_u&lt;/span&gt; &lt;span class="ow"&gt;is&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
        &lt;span class="n"&gt;env&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;PendulumDiscreteStateAction&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;nvec_s&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;nvec_u&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;time_limit&lt;/span&gt; &lt;span class="ow"&gt;is&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
        &lt;span class="n"&gt;env&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;gym&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;wrappers&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;TimeLimit&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_episode_steps&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;time_limit&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;


&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;

&lt;p&gt;We can also provide a different time limit to end the episode within a certain number of steps. Now, we can just call the &lt;code&gt;Qlearn&lt;/code&gt; method again to train the agent on the new environment. &lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;

&lt;span class="n"&gt;build_env&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="k"&gt;lambda&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nf"&gt;get_env&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Pendulum-v1&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="mi"&gt;11&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;200&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mf"&gt;0.5&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nc"&gt;Qlearn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;build_env&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;callback&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;test&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;nsteps&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1000000&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;


&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;

&lt;p&gt;&lt;a href="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fmiro.medium.com%2Fv2%2Fresize%3Afit%3A1100%2Fformat%3Awebp%2F0%2AMlVBFre3Bm5F_BzS.gif" class="article-body-image-wrapper"&gt;&lt;img src="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fmiro.medium.com%2Fv2%2Fresize%3Afit%3A1100%2Fformat%3Awebp%2F0%2AMlVBFre3Bm5F_BzS.gif" alt="Pendulum Agent"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;As you can see, the agent learns to swing up pretty well, but keeping it centered on the top is quite hard. This shows that Tabular Q Learning is not the best method for complex environments. This is because discretizing to just 10 states yields 10 x 10 x 10 observation space. And this increases rapidly with more states and more discretization. Thus we need a better method. In the next part, we will look into Deep Q Learning, which uses a neural network to map the states to the Q values.&lt;/p&gt;

&lt;p&gt;You can find the code for this part at the following GitHub Repo:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://github.com/akshayballal95/rl-blog-series/tree/tabular" rel="noopener noreferrer"&gt;akshayballal95/rl-blog-series at tabular (github.com)&lt;/a&gt;&lt;/p&gt;




&lt;p&gt;Want to connect?&lt;/p&gt;

&lt;p&gt;🌍&lt;a href="http://www.akshaymakes.com" rel="noopener noreferrer"&gt;My Website&lt;/a&gt;&lt;br&gt;&lt;br&gt;
🐦&lt;a href="https://twitter.com/akshayballal95" rel="noopener noreferrer"&gt;My Twitter&lt;/a&gt;&lt;br&gt;&lt;br&gt;
👨&lt;a href="https://www.linkedin.com/in/akshay-ballal/" rel="noopener noreferrer"&gt;My LinkedIn&lt;/a&gt; &lt;/p&gt;

</description>
      <category>ai</category>
      <category>machinelearning</category>
      <category>python</category>
      <category>tutorial</category>
    </item>
    <item>
      <title>Supercharge your Tests with CodiumAI Cover-Agent</title>
      <dc:creator>Akshay Ballal</dc:creator>
      <pubDate>Mon, 20 May 2024 18:02:13 +0000</pubDate>
      <link>https://dev.to/akshayballal/supercharge-your-tests-with-codiumai-cover-agent-4j9m</link>
      <guid>https://dev.to/akshayballal/supercharge-your-tests-with-codiumai-cover-agent-4j9m</guid>
      <description>&lt;p&gt;Writing great tests is a skill. Achieving comprehensive test coverage across a codebase is even more challenging. In this article, I will introduce a new tool to boost your test coverage and ensure that all your functions are tested thoroughly. Cover-Agent, an open-source tool by CodiumAI, leverages AI to generate tests aimed at increasing test coverage. Previously, I wrote about another tool from Codium, the PR-Agent, which you can read &lt;a href="https://dev.to/akshayballal/i-used-codiumais-pr-agent-and-i-love-it-1nb8"&gt;here&lt;/a&gt;.&lt;/p&gt;

&lt;h3&gt;
  
  
  &lt;strong&gt;Understanding Test Coverage&lt;/strong&gt;
&lt;/h3&gt;

&lt;p&gt;Before diving into CodiumAI Cover-Agent, let’s discuss what test coverage entails, particularly code coverage. Code coverage measures how much of the code is executed by tests. Key aspects of code coverage include:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;Function Coverage&lt;/strong&gt;: The percentage of functions or methods called by the test suite.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Statement Coverage&lt;/strong&gt;: The percentage of executable statements run by the test suite.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Branch Coverage&lt;/strong&gt;: The percentage of branches (like if-else and switch-case) executed.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Condition Coverage&lt;/strong&gt;: The percentage of Boolean expressions evaluated as both true and false.&lt;/li&gt;
&lt;/ul&gt;

&lt;h3&gt;
  
  
  &lt;strong&gt;Benefits of Test Coverage&lt;/strong&gt;
&lt;/h3&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;Identifying Gaps in Testing&lt;/strong&gt;: Helps developers spot untested parts of the code.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Improving Software Quality&lt;/strong&gt;: Comprehensive tests catch bugs and issues early.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Maintaining Code Reliability&lt;/strong&gt;: Regularly measuring test coverage maintains and enhances code reliability.&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;However, high test coverage alone does not guarantee high-quality tests. It is possible to have high coverage with tests that do not thoroughly validate the code's correctness. Therefore, while aiming for high test coverage, it’s crucial to ensure that the tests are meaningful and well-designed. Codium Cover-Agent addresses this by not only aiming for coverage but also generating quality test cases that test edge cases, invalid inputs, and errors.&lt;/p&gt;

&lt;h3&gt;
  
  
  &lt;strong&gt;A Simple Example with CodiumAI Cover-Agent&lt;/strong&gt;
&lt;/h3&gt;

&lt;p&gt;To illustrate how CodiumAI Cover-Agent works, let’s start with a basic example. We will create a simple &lt;strong&gt;&lt;code&gt;calculator.py&lt;/code&gt;&lt;/strong&gt; file with functions for addition, subtraction, multiplication, and division.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# calculator.py
&lt;/span&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;add&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;a&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;b&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;a&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;b&lt;/span&gt;

&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;subtract&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;a&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;b&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;a&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;b&lt;/span&gt;

&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;multiply&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;a&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;b&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;a&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;b&lt;/span&gt;

&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;divide&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;a&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;b&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;b&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
        &lt;span class="k"&gt;raise&lt;/span&gt; &lt;span class="nc"&gt;ValueError&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Cannot divide by zero&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;a&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;b&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Next, we write a test file &lt;strong&gt;&lt;code&gt;test_calculator.py&lt;/code&gt;&lt;/strong&gt; and place it in the tests folder.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# tests/test_calculator.py
&lt;/span&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;calculator&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;add&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;subtract&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;multiply&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;divide&lt;/span&gt;

&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;TestCalculator&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;test_add&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="k"&gt;assert&lt;/span&gt; &lt;span class="nf"&gt;add&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;5&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;To see the test coverage, we need to install &lt;strong&gt;&lt;code&gt;pytest-cov&lt;/code&gt;&lt;/strong&gt;, a pytest extension for coverage reporting.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight shell"&gt;&lt;code&gt;pip &lt;span class="nb"&gt;install &lt;/span&gt;pytest-cov

&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Run the coverage analysis with:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight shell"&gt;&lt;code&gt;pytest &lt;span class="nt"&gt;--cov&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;calculator
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;The output shows:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight markdown"&gt;&lt;code&gt;&lt;span class="gh"&gt;Name            Stmts   Miss  Cover
-----------------------------------
calculator.py      10      5    50%
-----------------------------------
&lt;/span&gt;TOTAL              10      5    50%
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;This indicates that 5 of the 10 statements in &lt;strong&gt;&lt;code&gt;calculator.py&lt;/code&gt;&lt;/strong&gt; are not executed, resulting in 50% code coverage. For simple examples like this, adding more tests manually is straightforward. But let's see how Codium Cover-Agent can enhance this process.&lt;/p&gt;

&lt;h3&gt;
  
  
  &lt;strong&gt;Setting Up CodiumAI Cover-Agent&lt;/strong&gt;
&lt;/h3&gt;

&lt;p&gt;To set up Codium Cover-Agent, follow these steps:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;p&gt;Install Cover-Agent:&lt;br&gt;
&lt;/p&gt;
&lt;pre class="highlight shell"&gt;&lt;code&gt;pip &lt;span class="nb"&gt;install &lt;/span&gt;git+https://github.com/Codium-ai/cover-agent.git
&lt;/code&gt;&lt;/pre&gt;

&lt;/li&gt;
&lt;li&gt;&lt;p&gt;Ensure &lt;strong&gt;&lt;code&gt;OPENAI_API_KEY&lt;/code&gt;&lt;/strong&gt; is set in your environment variables, as it is required for the OpenAI API.&lt;/p&gt;&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;Create the command to start generating tests:&lt;br&gt;
&lt;/p&gt;
&lt;pre class="highlight shell"&gt;&lt;code&gt;cover-agent &lt;span class="se"&gt;\&lt;/span&gt;
&lt;span class="nt"&gt;--source-file-path&lt;/span&gt; &lt;span class="s2"&gt;"calculator.py"&lt;/span&gt; &lt;span class="se"&gt;\&lt;/span&gt;
&lt;span class="nt"&gt;--test-file-path&lt;/span&gt; &lt;span class="s2"&gt;"tests/test_calculator.py"&lt;/span&gt; &lt;span class="se"&gt;\&lt;/span&gt;
&lt;span class="nt"&gt;--code-coverage-report-path&lt;/span&gt; &lt;span class="s2"&gt;"coverage.xml"&lt;/span&gt; &lt;span class="se"&gt;\&lt;/span&gt;
&lt;span class="nt"&gt;--test-command&lt;/span&gt; &lt;span class="s2"&gt;"pytest --cov=. --cov-report=xml --cov-report=term"&lt;/span&gt; &lt;span class="se"&gt;\&lt;/span&gt;
&lt;span class="nt"&gt;--test-command-dir&lt;/span&gt; &lt;span class="s2"&gt;"./"&lt;/span&gt; &lt;span class="se"&gt;\&lt;/span&gt;
&lt;span class="nt"&gt;--coverage-type&lt;/span&gt; &lt;span class="s2"&gt;"cobertura"&lt;/span&gt; &lt;span class="se"&gt;\&lt;/span&gt;
&lt;span class="nt"&gt;--desired-coverage&lt;/span&gt; 80 &lt;span class="se"&gt;\&lt;/span&gt;
&lt;span class="nt"&gt;--max-iterations&lt;/span&gt; 3 &lt;span class="se"&gt;\&lt;/span&gt;
&lt;span class="nt"&gt;--openai-model&lt;/span&gt; &lt;span class="s2"&gt;"gpt-4o"&lt;/span&gt; &lt;span class="se"&gt;\&lt;/span&gt;
&lt;span class="nt"&gt;--additional-instructions&lt;/span&gt; &lt;span class="s2"&gt;"Since I am using a test class, each line of code (including the first line) needs to be prepended with 4 whitespaces. This is extremely important to ensure that every line returned contains that 4 whitespace indent; otherwise, my code will not run."&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/li&gt;
&lt;/ol&gt;

&lt;h3&gt;
  
  
  &lt;strong&gt;Command Arguments Explained&lt;/strong&gt;
&lt;/h3&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;&lt;code&gt;source-file-path&lt;/code&gt;&lt;/strong&gt;: Path of the file containing the functions for which tests need to be generated.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;&lt;code&gt;test-file-path&lt;/code&gt;&lt;/strong&gt;: Path of the file where the tests will be written by the agent. It’s best to create a skeleton of this file with at least one test and the necessary import statements.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;&lt;code&gt;code-coverage-report-path&lt;/code&gt;&lt;/strong&gt;: Path where the code coverage report is saved.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;&lt;code&gt;test-command&lt;/code&gt;&lt;/strong&gt;: Command to run the tests (e.g., pytest).&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;&lt;code&gt;test-command-dir&lt;/code&gt;&lt;/strong&gt;: Directory where the test command should run. Set this to the root or the location of your main file to avoid issues with relative imports.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;&lt;code&gt;coverage-type&lt;/code&gt;&lt;/strong&gt;: Type of coverage to use. Cobertura is a good default.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;&lt;code&gt;desired-coverage&lt;/code&gt;&lt;/strong&gt;: Coverage goal. Higher is better, though 100% is often impractical.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;&lt;code&gt;max-iterations&lt;/code&gt;&lt;/strong&gt;: Number of times the agent should retry to generate test code. More iterations may lead to higher OpenAI token usage.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;&lt;code&gt;additional-instructions&lt;/code&gt;&lt;/strong&gt;: Prompts to ensure the code is written in a specific way. For example, here we specify that the code should be formatted to work within a test class.&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;On running the command, the agent starts writing and iterating on the tests. &lt;/p&gt;

&lt;p&gt;This is the code that it generates&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;pytest&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;calculator&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;add&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;subtract&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;multiply&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;divide&lt;/span&gt;

&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;TestCalculator&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;test_add&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="nf"&gt;assert&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;add&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="mi"&gt;5&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;test_subtract&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="sh"&gt;"""&lt;/span&gt;&lt;span class="s"&gt;
        Test subtracting two numbers.
        &lt;/span&gt;&lt;span class="sh"&gt;"""&lt;/span&gt;
        &lt;span class="k"&gt;assert&lt;/span&gt; &lt;span class="nf"&gt;subtract&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;
        &lt;span class="k"&gt;assert&lt;/span&gt; &lt;span class="nf"&gt;subtract&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;test_multiply&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="sh"&gt;"""&lt;/span&gt;&lt;span class="s"&gt;
        Test multiplying two numbers.
        &lt;/span&gt;&lt;span class="sh"&gt;"""&lt;/span&gt;
        &lt;span class="k"&gt;assert&lt;/span&gt; &lt;span class="nf"&gt;multiply&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;6&lt;/span&gt;
        &lt;span class="k"&gt;assert&lt;/span&gt; &lt;span class="nf"&gt;multiply&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;6&lt;/span&gt;
        &lt;span class="k"&gt;assert&lt;/span&gt; &lt;span class="nf"&gt;multiply&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;6&lt;/span&gt;
        &lt;span class="k"&gt;assert&lt;/span&gt; &lt;span class="nf"&gt;multiply&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;6&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;test_divide&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="sh"&gt;"""&lt;/span&gt;&lt;span class="s"&gt;
        Test dividing two numbers.
        &lt;/span&gt;&lt;span class="sh"&gt;"""&lt;/span&gt;
        &lt;span class="k"&gt;assert&lt;/span&gt; &lt;span class="nf"&gt;divide&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;6&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;
        &lt;span class="k"&gt;assert&lt;/span&gt; &lt;span class="nf"&gt;divide&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;6&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;
        &lt;span class="k"&gt;assert&lt;/span&gt; &lt;span class="nf"&gt;divide&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;6&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;
        &lt;span class="k"&gt;assert&lt;/span&gt; &lt;span class="nf"&gt;divide&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;6&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;test_divide_by_zero&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="sh"&gt;"""&lt;/span&gt;&lt;span class="s"&gt;
        Test dividing by zero, should raise ValueError.
        &lt;/span&gt;&lt;span class="sh"&gt;"""&lt;/span&gt;
        &lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="n"&gt;pytest&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;raises&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nb"&gt;ValueError&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;match&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Cannot divide by zero&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
            &lt;span class="nf"&gt;divide&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;You can see that the agent also wrote tests for checking errors for edge cases. &lt;/p&gt;

&lt;p&gt;&lt;a href="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F82nz3gti8qlie4gwws96.gif" class="article-body-image-wrapper"&gt;&lt;img src="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F82nz3gti8qlie4gwws96.gif" alt="Video of calculator test generation"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;We can now test the coverage again&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight markdown"&gt;&lt;code&gt;pytest --cov=calculator
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;





&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight markdown"&gt;&lt;code&gt;&lt;span class="gh"&gt;Name            Stmts   Miss  Cover
-----------------------------------
calculator.py      10      0   100%
-----------------------------------
&lt;/span&gt;TOTAL              10      0   100%
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;In this simple example we reached 100% test coverage. &lt;/p&gt;

&lt;h3&gt;
  
  
  Trying Cover-Agent in a real Codebase
&lt;/h3&gt;

&lt;p&gt;The previous example was simple, and even a beginner can completely cover the codebase. Let’s see how the agent works for a more real-world application. I have a codebase with more than 15 files with approximately 1000 lines of code. This is the coverage of a certain module of the application.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight markdown"&gt;&lt;code&gt;&lt;span class="gh"&gt;Name                                      Stmts   Miss  Cover
-------------------------------------------------------------
&lt;/span&gt;maintenance&lt;span class="se"&gt;\_&lt;/span&gt;&lt;span class="ge"&gt;_init_&lt;/span&gt;_.py                       0      0   100%
maintenance&lt;span class="se"&gt;\e&lt;/span&gt;vents&lt;span class="se"&gt;\d&lt;/span&gt;egradation_event.py       6      2    67%
maintenance&lt;span class="se"&gt;\e&lt;/span&gt;vents&lt;span class="se"&gt;\e&lt;/span&gt;vent.py                  22      8    64%
maintenance&lt;span class="se"&gt;\e&lt;/span&gt;vents&lt;span class="se"&gt;\r&lt;/span&gt;each_machine.py           4      1    75%
maintenance&lt;span class="se"&gt;\e&lt;/span&gt;vents&lt;span class="se"&gt;\r&lt;/span&gt;epaired_event.py          4      1    75%
maintenance&lt;span class="se"&gt;\f&lt;/span&gt;uture_event_set.py              24     14    42%
maintenance&lt;span class="se"&gt;\m&lt;/span&gt;odels&lt;span class="se"&gt;\_&lt;/span&gt;&lt;span class="ge"&gt;_init_&lt;/span&gt;_.py                0      0   100%
maintenance&lt;span class="se"&gt;\m&lt;/span&gt;odels&lt;span class="se"&gt;\a&lt;/span&gt;ction.py                 12      0   100%
maintenance&lt;span class="se"&gt;\m&lt;/span&gt;odels&lt;span class="se"&gt;\e&lt;/span&gt;ngineer.py               15      5    67%
maintenance&lt;span class="se"&gt;\m&lt;/span&gt;odels&lt;span class="se"&gt;\f&lt;/span&gt;actory.py                63     36    43%
maintenance&lt;span class="se"&gt;\m&lt;/span&gt;odels&lt;span class="se"&gt;\m&lt;/span&gt;achine.py                79      1    99%
maintenance&lt;span class="se"&gt;\m&lt;/span&gt;odels&lt;span class="se"&gt;\s&lt;/span&gt;tate.py                  14      0   100%
maintenance&lt;span class="se"&gt;\p&lt;/span&gt;olicies&lt;span class="se"&gt;\p&lt;/span&gt;olicy.py               12      1    92%
maintenance&lt;span class="se"&gt;\p&lt;/span&gt;olicies&lt;span class="se"&gt;\r&lt;/span&gt;eactive_policy.py      27      8    70%
maintenance&lt;span class="se"&gt;\s&lt;/span&gt;imResults.py                    65     65     0%
maintenance&lt;span class="se"&gt;\s&lt;/span&gt;imulation.py                    43     43     0%
maintenance&lt;span class="se"&gt;\t&lt;/span&gt;ests&lt;span class="se"&gt;\_&lt;/span&gt;&lt;span class="ge"&gt;_init_&lt;/span&gt;_.py                 0      0   100%
maintenance&lt;span class="se"&gt;\t&lt;/span&gt;ests&lt;span class="se"&gt;\c&lt;/span&gt;onftest.py                25      0   100%
maintenance&lt;span class="se"&gt;\t&lt;/span&gt;ests&lt;span class="se"&gt;\f&lt;/span&gt;actory_test.py            12      0   100%
&lt;span class="gh"&gt;maintenance\tests\machine_test.py            63      0   100%
-------------------------------------------------------------
&lt;/span&gt;TOTAL                                       490    185    62%
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;The goal is to increase the coverage by writing tests for the &lt;code&gt;factory.py&lt;/code&gt; file. Let’s do it&lt;/p&gt;

&lt;p&gt;As before we create a skeleton for the test file. I have written one simple test here.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;pytest&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;maintenance.models.factory&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;Factory&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;maintenance.tests.conftest&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;factory&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;maintenance.models.action&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;Action&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ActionType&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;maintenance.policies.reactive_policy&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;ReactivePolicy&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;maintenance.events.event&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;Event&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;EventType&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;numpy&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;

&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;TestFactory&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;test_execute_action&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;factory&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="n"&gt;Factory&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="n"&gt;factory&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;execute_action&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;factory&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;engineers&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;factory&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;policy&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;determine_action&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;factory&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;state&lt;/span&gt;&lt;span class="p"&gt;)[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
        &lt;span class="k"&gt;assert&lt;/span&gt; &lt;span class="n"&gt;factory&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;state&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;fse_available&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="bp"&gt;True&lt;/span&gt;
        &lt;span class="k"&gt;assert&lt;/span&gt; &lt;span class="n"&gt;factory&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;state&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;fse_location&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Then we run the following command&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight shell"&gt;&lt;code&gt;cover-agent &lt;span class="se"&gt;\&lt;/span&gt;
 &lt;span class="nt"&gt;--source-file-path&lt;/span&gt; &lt;span class="s2"&gt;"maintenance/models/factory.py"&lt;/span&gt; &lt;span class="se"&gt;\&lt;/span&gt;
 &lt;span class="nt"&gt;--test-file-path&lt;/span&gt; &lt;span class="s2"&gt;"maintenance/tests/factory_test.py"&lt;/span&gt; &lt;span class="se"&gt;\&lt;/span&gt;
 &lt;span class="nt"&gt;--code-coverage-report-path&lt;/span&gt; &lt;span class="s2"&gt;"coverage.xml"&lt;/span&gt; &lt;span class="se"&gt;\&lt;/span&gt;
 &lt;span class="nt"&gt;--test-command&lt;/span&gt; &lt;span class="s2"&gt;"pytest --cov=maintenance &lt;/span&gt;&lt;span class="se"&gt;\&lt;/span&gt;&lt;span class="s2"&gt;
 --cov-report=xml --cov-report=term"&lt;/span&gt; &lt;span class="se"&gt;\&lt;/span&gt;
 &lt;span class="nt"&gt;--test-command-dir&lt;/span&gt; &lt;span class="s2"&gt;"./"&lt;/span&gt; &lt;span class="se"&gt;\&lt;/span&gt;
 &lt;span class="nt"&gt;--coverage-type&lt;/span&gt; &lt;span class="s2"&gt;"cobertura"&lt;/span&gt; &lt;span class="se"&gt;\&lt;/span&gt;
 &lt;span class="nt"&gt;--desired-coverage&lt;/span&gt; 80 &lt;span class="se"&gt;\&lt;/span&gt;
 &lt;span class="nt"&gt;--max-iterations&lt;/span&gt; 3 &lt;span class="se"&gt;\&lt;/span&gt;
 &lt;span class="nt"&gt;--openai-model&lt;/span&gt; &lt;span class="s2"&gt;"gpt-4o"&lt;/span&gt; &lt;span class="se"&gt;\ &lt;/span&gt;
 &lt;span class="nt"&gt;--additional-instructions&lt;/span&gt; &lt;span class="s2"&gt;"Since I am using a test class, each line of code (including the first line), In your response will need to be prepended with 4 whitespaces. This is extremely important to check to make sure every line returned contains that 4 whitespace indent otherwise my code will not run."&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;a href="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fsnoyhtcpmx7jdabhacvt.gif" class="article-body-image-wrapper"&gt;&lt;img src="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fsnoyhtcpmx7jdabhacvt.gif" alt="Video of factory test generation"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The agent starts writing the code. This takes about a minute to finish. Once the code is generated, the updated coverage report can now be checked.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight markdown"&gt;&lt;code&gt;&lt;span class="gh"&gt;Name                                      Stmts   Miss  Cover
-------------------------------------------------------------
&lt;/span&gt;maintenance&lt;span class="se"&gt;\_&lt;/span&gt;&lt;span class="ge"&gt;_init_&lt;/span&gt;_.py                       0      0   100%
maintenance&lt;span class="se"&gt;\e&lt;/span&gt;vents&lt;span class="se"&gt;\d&lt;/span&gt;egradation_event.py       6      0   100%
maintenance&lt;span class="se"&gt;\e&lt;/span&gt;vents&lt;span class="se"&gt;\e&lt;/span&gt;vent.py                  22      3    86%
maintenance&lt;span class="se"&gt;\e&lt;/span&gt;vents&lt;span class="se"&gt;\r&lt;/span&gt;each_machine.py           4      1    75%
maintenance&lt;span class="se"&gt;\e&lt;/span&gt;vents&lt;span class="se"&gt;\r&lt;/span&gt;epaired_event.py          4      0   100%
maintenance&lt;span class="se"&gt;\f&lt;/span&gt;uture_event_set.py              24     13    46%
maintenance&lt;span class="se"&gt;\m&lt;/span&gt;odels&lt;span class="se"&gt;\_&lt;/span&gt;&lt;span class="ge"&gt;_init_&lt;/span&gt;_.py                0      0   100%
maintenance&lt;span class="se"&gt;\m&lt;/span&gt;odels&lt;span class="se"&gt;\a&lt;/span&gt;ction.py                 12      0   100%
maintenance&lt;span class="se"&gt;\m&lt;/span&gt;odels&lt;span class="se"&gt;\e&lt;/span&gt;ngineer.py               15      3    80%
maintenance&lt;span class="se"&gt;\m&lt;/span&gt;odels&lt;span class="se"&gt;\f&lt;/span&gt;actory.py                63     10    84%
maintenance&lt;span class="se"&gt;\m&lt;/span&gt;odels&lt;span class="se"&gt;\m&lt;/span&gt;achine.py                79      0   100%
maintenance&lt;span class="se"&gt;\m&lt;/span&gt;odels&lt;span class="se"&gt;\s&lt;/span&gt;tate.py                  14      0   100%
maintenance&lt;span class="se"&gt;\p&lt;/span&gt;olicies&lt;span class="se"&gt;\p&lt;/span&gt;olicy.py               12      1    92%
maintenance&lt;span class="se"&gt;\p&lt;/span&gt;olicies&lt;span class="se"&gt;\r&lt;/span&gt;eactive_policy.py      27      2    93%
maintenance&lt;span class="se"&gt;\s&lt;/span&gt;imResults.py                    65     65     0%
maintenance&lt;span class="se"&gt;\s&lt;/span&gt;imulation.py                    43     43     0%
maintenance&lt;span class="se"&gt;\t&lt;/span&gt;ests&lt;span class="se"&gt;\_&lt;/span&gt;&lt;span class="ge"&gt;_init_&lt;/span&gt;_.py                 0      0   100%
maintenance&lt;span class="se"&gt;\t&lt;/span&gt;ests&lt;span class="se"&gt;\c&lt;/span&gt;onftest.py                25      0   100%
maintenance&lt;span class="se"&gt;\t&lt;/span&gt;ests&lt;span class="se"&gt;\f&lt;/span&gt;actory_test.py            45      0   100%
&lt;span class="gh"&gt;maintenance\tests\machine_test.py            63      0   100%
-------------------------------------------------------------
&lt;/span&gt;TOTAL                                       523    141    73%
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;We have increased the total coverage for the module by 10% and the coverage of &lt;code&gt;factory.py&lt;/code&gt; from 43% to 84%. This is amazing, given almost no effort from our side. This is the test code written by the agent. More edge cases can be tested. We can add them manually or instruct the agent to handle some specific edge cases.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;pytest&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;maintenance.models.factory&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;Factory&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;maintenance.tests.conftest&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;factory&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;maintenance.models.action&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;Action&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ActionType&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;maintenance.policies.reactive_policy&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;ReactivePolicy&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;maintenance.events.event&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;Event&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;EventType&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;maintenance.models.engineer&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;Engineer&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;maintenance.models.machine&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;Machine&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;numpy&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;

&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;TestFactory&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;test_execute_action&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;factory&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="n"&gt;Factory&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="n"&gt;factory&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;execute_action&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;factory&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;engineers&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;factory&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;policy&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;determine_action&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;factory&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;state&lt;/span&gt;&lt;span class="p"&gt;)[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
        &lt;span class="k"&gt;assert&lt;/span&gt; &lt;span class="n"&gt;factory&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;state&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;fse_available&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="bp"&gt;True&lt;/span&gt;
        &lt;span class="k"&gt;assert&lt;/span&gt; &lt;span class="n"&gt;factory&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;state&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;fse_location&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;test_machine_from_id_invalid_id&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;factory&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;Factory&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="sh"&gt;"""&lt;/span&gt;&lt;span class="s"&gt;
        Test the machine_from_id method with an invalid machine ID.
        Ensures that the method returns None for an invalid machine ID.
        &lt;/span&gt;&lt;span class="sh"&gt;"""&lt;/span&gt;
        &lt;span class="n"&gt;machine&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;factory&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;machine_from_id&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# Invalid machine ID
&lt;/span&gt;        &lt;span class="k"&gt;assert&lt;/span&gt; &lt;span class="n"&gt;machine&lt;/span&gt; &lt;span class="ow"&gt;is&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;test_process_event_degradation_machine_under_repair&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;factory&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;Factory&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="sh"&gt;"""&lt;/span&gt;&lt;span class="s"&gt;
        Test the process_event method for the DEGRADATION event type when the machine is under repair.
        Ensures that the degradation event is ignored.
        &lt;/span&gt;&lt;span class="sh"&gt;"""&lt;/span&gt;
        &lt;span class="n"&gt;machine&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;factory&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;machines&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
        &lt;span class="n"&gt;machine&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;under_repair&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;True&lt;/span&gt;
        &lt;span class="n"&gt;event&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;Event&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;EventType&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;DEGRADATION&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;machine&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nb"&gt;id&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;intensity&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;factory&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;process_event&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;event&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="k"&gt;assert&lt;/span&gt; &lt;span class="n"&gt;machine&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;degradation&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;  &lt;span class="c1"&gt;# Degradation should not change
&lt;/span&gt;
    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;test_process_event_degradation_machine_failed&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;factory&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;Factory&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="sh"&gt;"""&lt;/span&gt;&lt;span class="s"&gt;
        Test the process_event method for the DEGRADATION event type when the machine has failed.
        Ensures that the degradation event is ignored.
        &lt;/span&gt;&lt;span class="sh"&gt;"""&lt;/span&gt;
        &lt;span class="n"&gt;machine&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;factory&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;machines&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
        &lt;span class="n"&gt;machine&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;degradation&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;machine&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;failure_theshold&lt;/span&gt;
        &lt;span class="n"&gt;event&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;Event&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;EventType&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;DEGRADATION&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;machine&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nb"&gt;id&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;intensity&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;factory&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;process_event&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;event&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="k"&gt;assert&lt;/span&gt; &lt;span class="n"&gt;machine&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;degradation&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="n"&gt;machine&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;failure_theshold&lt;/span&gt;  &lt;span class="c1"&gt;# Degradation should not change
&lt;/span&gt;
    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;test_process_event_degradation_threshold_reached&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;factory&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;Factory&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="sh"&gt;"""&lt;/span&gt;&lt;span class="s"&gt;
        Test the process_event method for the DEGRADATION event type when the degradation reaches the failure threshold.
        Ensures that the machine handles failure correctly.
        &lt;/span&gt;&lt;span class="sh"&gt;"""&lt;/span&gt;
        &lt;span class="n"&gt;machine&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;factory&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;machines&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
        &lt;span class="n"&gt;event&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;Event&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;EventType&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;DEGRADATION&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;machine&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nb"&gt;id&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;intensity&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;machine&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;failure_theshold&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;factory&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;process_event&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;event&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="k"&gt;assert&lt;/span&gt; &lt;span class="n"&gt;machine&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;degradation&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="n"&gt;machine&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;failure_theshold&lt;/span&gt;
        &lt;span class="k"&gt;assert&lt;/span&gt; &lt;span class="n"&gt;machine&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;has_failed&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;test_process_event_repaired_machine_not_under_repair&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;factory&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;Factory&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="sh"&gt;"""&lt;/span&gt;&lt;span class="s"&gt;
        Test the process_event method for the REPAIRED event type when the machine is not under repair.
        Ensures that the method handles this scenario gracefully.
        &lt;/span&gt;&lt;span class="sh"&gt;"""&lt;/span&gt;
        &lt;span class="n"&gt;machine&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;factory&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;machines&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
        &lt;span class="n"&gt;machine&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;under_repair&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;False&lt;/span&gt;
        &lt;span class="n"&gt;event&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;Event&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;EventType&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;REPAIRED&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;machine&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nb"&gt;id&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;factory&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;process_event&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;event&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="k"&gt;assert&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="n"&gt;machine&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;under_repair&lt;/span&gt;  &lt;span class="c1"&gt;# Machine should still not be under repair
&lt;/span&gt;
    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;test_get_results_zero_sim_time&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;factory&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;Factory&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="sh"&gt;"""&lt;/span&gt;&lt;span class="s"&gt;
        Test the get_results method with zero simulation time.
        Ensures that the method handles division by zero gracefully.
        &lt;/span&gt;&lt;span class="sh"&gt;"""&lt;/span&gt;
        &lt;span class="n"&gt;sim_time&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;
        &lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="n"&gt;pytest&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;raises&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nb"&gt;ZeroDivisionError&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
            &lt;span class="n"&gt;factory&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;get_results&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;sim_time&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;h3&gt;
  
  
  &lt;strong&gt;Conclusion&lt;/strong&gt;
&lt;/h3&gt;

&lt;p&gt;CodiumAI Cover-Agent simplifies achieving high test coverage and generating meaningful, high-quality tests. By using AI to identify and address gaps in testing, it ensures robust software development practices. Try CodiumAI Cover-Agent on your projects to see the difference it can make in your testing workflow.&lt;/p&gt;

&lt;p&gt;You can check out the open-source repository for the cover agent at &lt;a href="https://github.com/Codium-ai/cover-agent" rel="noopener noreferrer"&gt;https://github.com/Codium-ai/cover-agent&lt;/a&gt;.&lt;/p&gt;




&lt;p&gt;Want to connect?&lt;/p&gt;

&lt;p&gt;🌍&lt;a href="http://www.akshaymakes.com" rel="noopener noreferrer"&gt;My Website&lt;/a&gt;&lt;br&gt;&lt;br&gt;
🐦&lt;a href="https://twitter.com/akshayballal95" rel="noopener noreferrer"&gt;My Twitter&lt;/a&gt;&lt;br&gt;&lt;br&gt;
👨&lt;a href="https://www.linkedin.com/in/akshay-ballal/" rel="noopener noreferrer"&gt;My LinkedIn&lt;/a&gt; &lt;/p&gt;

</description>
      <category>testing</category>
      <category>tooling</category>
      <category>ai</category>
      <category>programming</category>
    </item>
    <item>
      <title>Introducing EmbedAnything</title>
      <dc:creator>Akshay Ballal</dc:creator>
      <pubDate>Thu, 18 Apr 2024 10:01:21 +0000</pubDate>
      <link>https://dev.to/akshayballal/introducing-embedanything-950</link>
      <guid>https://dev.to/akshayballal/introducing-embedanything-950</guid>
      <description>&lt;h2&gt;
  
  
  Motivation
&lt;/h2&gt;

&lt;p&gt;Transformer models have become increasingly popular in recent times. One crucial requirement for all transformer models is a latent representation of the input data in embeddings. Word embeddings are used in language models, while vision models rely on patch embeddings. However, currently, there are no existing solutions to extract embeddings and custom metadata from various file formats. LangChain offers some solutions, but it is a bulky package, and extracting only the embedding data is not easy. Moreover, LangChain is not very suitable for vision-related tasks. Embeddings are helpful for language models and other models trained for various tasks, such as semantic segmentation and object detection.&lt;/p&gt;

&lt;p&gt;This is where EmbedAnything comes in. It is a lightweight library that allows you to generate embeddings from different file formats and modalities. Currently, EmbedAnything supports PDFs and images, with many more formats in the pipeline. The idea is to provide an end-to-end solution where you can give the file and get the embeddings with the appropriate metadata.&lt;/p&gt;

&lt;p&gt;Development of EmbedAnything started with these goals in mind:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;Compatibility with Local and Cloud Models&lt;/strong&gt;: Seamless integration with local and cloud-based embedding models.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;High-Speed Performance&lt;/strong&gt;: Fast processing to meet demanding application requirements.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Multimodal Capability&lt;/strong&gt;: Flexibility to handle various modalities.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;CPU and GPU Compatibility&lt;/strong&gt;: Performance optimization for both CPU and GPU environments.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Lightweight Design&lt;/strong&gt;: Minimized footprint for efficient resource utilization.&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;In this blog, we will see how we achieve these goals and what more must be done to improve EmbedAnything. We will also see why EmbedAnything is packaged the way it is with Rust as a backend and a Python interface. &lt;/p&gt;

&lt;h2&gt;
  
  
  Keeping it Local
&lt;/h2&gt;

&lt;p&gt;While cloud-based embedding services like OpenAI, Jina, and Mistral offer convenience, many users require the flexibility and control of local embedding models. Here's why local models are crucial for some use cases:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;Cost-Effectiveness:&lt;/strong&gt; Cloud services often charge per API call or model usage. Running embeddings locally on your own hardware can significantly reduce costs, especially for projects with frequent or high-volume embedding needs.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Data Privacy:&lt;/strong&gt; Certain data, like medical records or financial documents, might be too sensitive to upload to the cloud. Local embedding keeps your data confidential and under your control.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Offline Functionality:&lt;/strong&gt; An internet connection isn't always guaranteed. Local models ensure your embedding tasks can run uninterrupted even without an internet connection.&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  Performance
&lt;/h2&gt;

&lt;p&gt;EmbedAnything is built with Rust. This makes it faster and provides type safety and a much better development experience. But why is speed so crucial in this process?&lt;/p&gt;

&lt;h3&gt;
  
  
  &lt;strong&gt;The Need for Speed&lt;/strong&gt;
&lt;/h3&gt;

&lt;p&gt;Creating embeddings from files involves two steps that demand significant computational power:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;Extracting Text from Files, Especially PDFs:&lt;/strong&gt; Text can exist in different formats such as markdown, PDFs, and Word documents. However, extracting text from PDFs can be challenging and often causes slowdowns. It is especially difficult to extract text in manageable batches as embedding models have a context limit. Breaking the text into paragraphs containing focused information can help.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Inferencing on the Transformer Embedding Model:&lt;/strong&gt; The transformer model is usually at the core of the embedding process, but it is known for being computationally expensive. To address this, EmbedAnything utilizes the Candle Framework by Hugging Face, a machine-learning framework built entirely in Rust for optimized performance.&lt;/li&gt;
&lt;/ol&gt;

&lt;h3&gt;
  
  
  &lt;strong&gt;The Benefit of Rust for Speed&lt;/strong&gt;
&lt;/h3&gt;

&lt;p&gt;By using Rust for its core functionalities, EmbedAnything offers significant speed advantages:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;Rust is Compiled:&lt;/strong&gt; Unlike Python,  Rust compiles directly to machine code, resulting in faster execution.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Memory Management:&lt;/strong&gt; Rust enforces memory management simultaneously, preventing memory leaks and crashes that can plague other languages.&lt;/li&gt;
&lt;li&gt;Rust achieves true multithreading.&lt;/li&gt;
&lt;/ul&gt;

&lt;h3&gt;
  
  
  What does Candle bring to the table?
&lt;/h3&gt;

&lt;p&gt;Running language models or embedding models locally can be difficult, especially when you want to deploy a product that utilizes these models. If you use the transformers library from Hugging Face in Python, you will depend on PyTorch for tensor operations. This, in turn, has a dependency on Libtorch, which means that you will need to include the entire Libtorch library with your product. Also, Candle allows inferences on CUDA-enabled GPUs right out of the box. We will soon post on how we use Candle to increase the performance and decrease the memory usage of EmbedAnything.&lt;/p&gt;

&lt;h2&gt;
  
  
  Multimodality
&lt;/h2&gt;

&lt;p&gt;Finally, let's see how EmbedAnything handles multimodality. When a directory is passed for embedding to EmbedAnything, the file extension is checked to see if it is text or image and a suitable embedding model is used to generate the embeddings.&lt;/p&gt;

&lt;p&gt;Check out an example of an image search on this Google Colab Notebook: &lt;a href="https://2ly.link/1xf4A" rel="noopener noreferrer"&gt;&lt;img src="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fcolab.research.google.com%2Fassets%2Fcolab-badge.svg" alt="Open in Colab"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;You can start using EmbedAnything with&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight shell"&gt;&lt;code&gt;pip &lt;span class="nb"&gt;install &lt;/span&gt;embed_anything
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;You can view and contribute to the project on GitHub at:&lt;br&gt;
&lt;a href="https://github.com/StarlightSearch/EmbedAnything" rel="noopener noreferrer"&gt;EmbedAnything&lt;/a&gt;&lt;/p&gt;




&lt;p&gt;my website: &lt;a href="http://www.akshaymakes.com/" rel="noopener noreferrer"&gt;http://www.akshaymakes.com/&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;linkedin: &lt;a href="https://www.linkedin.com/in/akshay-ballal/" rel="noopener noreferrer"&gt;https://www.linkedin.com/in/akshay-ballal/&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;twitter: &lt;a href="https://twitter.com/akshayballal95/" rel="noopener noreferrer"&gt;https://twitter.com/akshayballal95/&lt;/a&gt;&lt;/p&gt;

</description>
      <category>python</category>
      <category>rust</category>
      <category>machinelearning</category>
      <category>ai</category>
    </item>
    <item>
      <title>Building an AI Game Bot 🤖 Using Imitation Learning and 3D Convolution ResNet</title>
      <dc:creator>Akshay Ballal</dc:creator>
      <pubDate>Tue, 02 Jan 2024 19:40:33 +0000</pubDate>
      <link>https://dev.to/akshayballal/building-an-ai-game-bot-using-imitation-learning-and-3d-convolution-resnet-3fe4</link>
      <guid>https://dev.to/akshayballal/building-an-ai-game-bot-using-imitation-learning-and-3d-convolution-resnet-3fe4</guid>
      <description>&lt;p&gt;&lt;a href="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F29jy0lkg0lj3pi1q5bfv.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F29jy0lkg0lj3pi1q5bfv.png" alt="Cover"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;h2&gt;
  
  
  Introduction
&lt;/h2&gt;

&lt;p&gt;Hey guys, this article is going to be fun, and I am sure you will get to learn a lot about using AI in a fun way. Even though this article is about making an AI engine to play a game, the learnings can be used in more serious applications because the process is more or less the same. The good thing about learning to build AI with games, as seen in my previous &lt;a href="https://dev.to/akshayballal/how-to-build-an-ai-powered-game-bot-with-pytorch-and-efficientnet-18jm"&gt;article&lt;/a&gt;, is that you get to experiment a lot without consequences, and getting the data is cheap. Normally, for games, one would use reinforcement learning to train an agent. This can be done but requires a lot more effort to set up in regards to building your own simulator. But with this one, we are going to build an AI agent that plays the Google Snake game. You don’t have to build the game yourself. &lt;/p&gt;

&lt;p&gt;Broadly, this is what we are going to do. We are going to use the technique of imitation learning, where the agent learns from the human counterpart about how to make decisions. This means that we as humans are first going to play the game for a while and collect the data. And then, the AI uses this data to train itself on how to play the game and pick up patterns from the human counterpart. This is known as imitation learning (not to be confused with transfer learning). &lt;/p&gt;

&lt;p&gt;Getting into the technicalities, we are going to train a 3D convolution ResNet model. This is because we want to capture motion to know which direction the snake is heading. For this reason, we are going to feed our model 4 frames of the game at a time to infer motion. You could try just to use one frame and use a standard convolution model like EfficientNet, but it is not going to work that well without the motion information. &lt;/p&gt;

&lt;p&gt;Without further ado, let’s get started. First, we will see how to collect the data.&lt;/p&gt;

&lt;h2&gt;
  
  
  Data Collection
&lt;/h2&gt;

&lt;p&gt;There are two methods to gather data. The first involves using a Python script to capture the game screen and label the images with the keyboard command. The second method, which we will be using, relies on selenium, a Python tool that automates browser navigation and control. This is the code to use selenium and save the screen captures as images with the appropriate labels.&lt;/p&gt;

&lt;p&gt;We can create a python script &lt;code&gt;capture.py&lt;/code&gt; to collect the screenshots.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# Import the required modules
&lt;/span&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;base64&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;io&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;cv2&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;PIL&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;Image&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;numpy&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;keyboard&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;os&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;datetime&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;datetime&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;selenium&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;webdriver&lt;/span&gt;

&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;selenium.webdriver.common.by&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;By&lt;/span&gt;

&lt;span class="c1"&gt;# Check if the captures folder exists and delete any existing files in it
&lt;/span&gt;&lt;span class="n"&gt;isExist&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;os&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;path&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;exists&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;captures&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;isExist&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="nb"&gt;dir&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;captures&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;
    &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;f&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;os&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;listdir&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nb"&gt;dir&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="n"&gt;os&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;remove&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;os&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;path&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;join&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nb"&gt;dir&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;

&lt;span class="k"&gt;else&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="n"&gt;os&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;mkdir&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;captures&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="n"&gt;current_key&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;1&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;
&lt;span class="nb"&gt;buffer&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[]&lt;/span&gt;

&lt;span class="c1"&gt;# Define a function to record the keyboard inputs
&lt;/span&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;keyboardCallBack&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;key&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;keyboard&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;KeyboardEvent&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="k"&gt;global&lt;/span&gt; &lt;span class="n"&gt;current_key&lt;/span&gt;

    &lt;span class="c1"&gt;# If a key is pressed and not in the buffer, add it to the buffer
&lt;/span&gt;    &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;key&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;event_type&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;down&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt; &lt;span class="ow"&gt;and&lt;/span&gt; &lt;span class="n"&gt;key&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;name&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nb"&gt;buffer&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
        &lt;span class="nb"&gt;buffer&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;key&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;name&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="c1"&gt;# If a key is released, remove it from the buffer
&lt;/span&gt;    &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;key&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;event_type&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;up&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
        &lt;span class="nb"&gt;buffer&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;remove&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;key&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;name&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="c1"&gt;# Sort the buffer and join the keys with spaces
&lt;/span&gt;    &lt;span class="nb"&gt;buffer&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sort&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
    &lt;span class="n"&gt;current_key&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt; &lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;join&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nb"&gt;buffer&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# Hook the function to the keyboard events
&lt;/span&gt;&lt;span class="n"&gt;keyboard&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;hook&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;callback&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;keyboardCallBack&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# Create a webdriver instance using Firefox
&lt;/span&gt;&lt;span class="n"&gt;driver&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;webdriver&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Firefox&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="c1"&gt;# Navigate to the Google Snake game website
&lt;/span&gt;&lt;span class="n"&gt;driver&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;get&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;https://www.google.com/fbx?fbx=snake_arcade&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# Loop indefinitely
&lt;/span&gt;&lt;span class="k"&gt;while&lt;/span&gt; &lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="c1"&gt;# Find the canvas element on the webpage
&lt;/span&gt;    &lt;span class="n"&gt;canvas&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;driver&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;find_element&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;By&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;CSS_SELECTOR&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;canvas&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="c1"&gt;# Get the base64 encoded image data from the canvas
&lt;/span&gt;    &lt;span class="n"&gt;canvas_base64&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;driver&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;execute_script&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
        &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;return arguments[0].toDataURL(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;image/png&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;).substring(21);&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;canvas&lt;/span&gt;
    &lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="c1"&gt;# Decode the base64 data to get the PNG image
&lt;/span&gt;    &lt;span class="n"&gt;canvas_png&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;base64&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;b64decode&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;canvas_base64&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="c1"&gt;# Convert the PNG image to a grayscale numpy array
&lt;/span&gt;    &lt;span class="n"&gt;image&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;cv2&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;cvtColor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
        &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;array&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Image&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;open&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;io&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;BytesIO&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;canvas_png&lt;/span&gt;&lt;span class="p"&gt;))),&lt;/span&gt; &lt;span class="n"&gt;cv2&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;COLOR_BGR2GRAY&lt;/span&gt;
    &lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="c1"&gt;# Save the image to the captures folder with the current timestamp and keyboard inputs as the file name
&lt;/span&gt;    &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="nf"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nb"&gt;buffer&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;!=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
        &lt;span class="n"&gt;cv2&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;imwrite&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
            &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;captures/&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;
            &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="nf"&gt;str&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;datetime&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;now&lt;/span&gt;&lt;span class="p"&gt;()).&lt;/span&gt;&lt;span class="nf"&gt;replace&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;-&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;_&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;replace&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;:&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;_&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;replace&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt; &lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;_&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
            &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt; &lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;
            &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;current_key&lt;/span&gt;
            &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;.png&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
            &lt;span class="n"&gt;image&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
        &lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="k"&gt;else&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
        &lt;span class="n"&gt;cv2&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;imwrite&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
            &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;captures/&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;
            &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="nf"&gt;str&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;datetime&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;now&lt;/span&gt;&lt;span class="p"&gt;()).&lt;/span&gt;&lt;span class="nf"&gt;replace&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;-&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;_&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;replace&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;:&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;_&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;replace&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt; &lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;_&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
            &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt; n&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;
            &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;.png&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
            &lt;span class="n"&gt;image&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
        &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Once you run this script, you can start playing the game. In the background, the script will keep saving screenshots of the game screen and name the images with a unique timestamp and the key being pressed right now. When no key is pressed, it is marked as &lt;code&gt;n&lt;/code&gt;. &lt;/p&gt;

&lt;p&gt;This is how the directory will look after the images are saved.&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fzvnjxtft8gdpfrxqp844.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fzvnjxtft8gdpfrxqp844.png" alt="Image Directory"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;h2&gt;
  
  
  Convert Image Folder to CSV File
&lt;/h2&gt;

&lt;p&gt;Now we can convert these images into a csv file with the file names and the corresponding actions. We do this in a new file &lt;code&gt;process.ipynb&lt;/code&gt;.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight jsx"&gt;&lt;code&gt;&lt;span class="k"&gt;import&lt;/span&gt; &lt;span class="nx"&gt;pandas&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="nx"&gt;pd&lt;/span&gt;
&lt;span class="k"&gt;import&lt;/span&gt; &lt;span class="nx"&gt;matplotlib&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nx"&gt;pyplot&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="nx"&gt;plt&lt;/span&gt;
&lt;span class="k"&gt;import&lt;/span&gt; &lt;span class="nx"&gt;os&lt;/span&gt;
&lt;span class="k"&gt;import&lt;/span&gt; &lt;span class="nx"&gt;csv&lt;/span&gt; 
&lt;span class="k"&gt;import&lt;/span&gt; &lt;span class="nx"&gt;os&lt;/span&gt;

&lt;span class="nx"&gt;labels&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[]&lt;/span&gt;
&lt;span class="nx"&gt;dir&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="dl"&gt;'&lt;/span&gt;&lt;span class="s1"&gt;captures&lt;/span&gt;&lt;span class="dl"&gt;'&lt;/span&gt;
&lt;span class="nx"&gt;file_path&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="dl"&gt;"&lt;/span&gt;&lt;span class="s2"&gt;data/labels_snake.csv&lt;/span&gt;&lt;span class="dl"&gt;"&lt;/span&gt;

&lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="nx"&gt;not&lt;/span&gt; &lt;span class="nx"&gt;os&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nx"&gt;path&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;exists&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nx"&gt;file_path&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="nx"&gt;os&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;mkdir&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="dl"&gt;'&lt;/span&gt;&lt;span class="s1"&gt;data&lt;/span&gt;&lt;span class="dl"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="nx"&gt;f&lt;/span&gt; &lt;span class="k"&gt;in&lt;/span&gt; &lt;span class="nx"&gt;os&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;listdir&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nx"&gt;dir&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="nx"&gt;key&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nx"&gt;f&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;rsplit&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="dl"&gt;'&lt;/span&gt;&lt;span class="s1"&gt;.&lt;/span&gt;&lt;span class="dl"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nf"&gt;rsplit&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="dl"&gt;"&lt;/span&gt;&lt;span class="s2"&gt; &lt;/span&gt;&lt;span class="dl"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;

    &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="nx"&gt;key&lt;/span&gt;&lt;span class="o"&gt;==&lt;/span&gt;&lt;span class="dl"&gt;"&lt;/span&gt;&lt;span class="s2"&gt;n&lt;/span&gt;&lt;span class="dl"&gt;"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
        &lt;span class="nx"&gt;labels&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;({&lt;/span&gt;&lt;span class="dl"&gt;'&lt;/span&gt;&lt;span class="s1"&gt;file_name&lt;/span&gt;&lt;span class="dl"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nx"&gt;f&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="dl"&gt;'&lt;/span&gt;&lt;span class="s1"&gt;class&lt;/span&gt;&lt;span class="dl"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;})&lt;/span&gt;
    &lt;span class="nx"&gt;elif&lt;/span&gt; &lt;span class="nx"&gt;key&lt;/span&gt;&lt;span class="o"&gt;==&lt;/span&gt;&lt;span class="dl"&gt;"&lt;/span&gt;&lt;span class="s2"&gt;left&lt;/span&gt;&lt;span class="dl"&gt;"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
        &lt;span class="nx"&gt;labels&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;({&lt;/span&gt;&lt;span class="dl"&gt;'&lt;/span&gt;&lt;span class="s1"&gt;file_name&lt;/span&gt;&lt;span class="dl"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nx"&gt;f&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="dl"&gt;'&lt;/span&gt;&lt;span class="s1"&gt;class&lt;/span&gt;&lt;span class="dl"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;})&lt;/span&gt;
    &lt;span class="nx"&gt;elif&lt;/span&gt; &lt;span class="nx"&gt;key&lt;/span&gt;&lt;span class="o"&gt;==&lt;/span&gt;&lt;span class="dl"&gt;"&lt;/span&gt;&lt;span class="s2"&gt;up&lt;/span&gt;&lt;span class="dl"&gt;"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
        &lt;span class="nx"&gt;labels&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;({&lt;/span&gt;&lt;span class="dl"&gt;'&lt;/span&gt;&lt;span class="s1"&gt;file_name&lt;/span&gt;&lt;span class="dl"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nx"&gt;f&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="dl"&gt;'&lt;/span&gt;&lt;span class="s1"&gt;class&lt;/span&gt;&lt;span class="dl"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;})&lt;/span&gt;
    &lt;span class="nx"&gt;elif&lt;/span&gt; &lt;span class="nx"&gt;key&lt;/span&gt;&lt;span class="o"&gt;==&lt;/span&gt;&lt;span class="dl"&gt;"&lt;/span&gt;&lt;span class="s2"&gt;right&lt;/span&gt;&lt;span class="dl"&gt;"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
        &lt;span class="nx"&gt;labels&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;({&lt;/span&gt;&lt;span class="dl"&gt;'&lt;/span&gt;&lt;span class="s1"&gt;file_name&lt;/span&gt;&lt;span class="dl"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nx"&gt;f&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="dl"&gt;'&lt;/span&gt;&lt;span class="s1"&gt;class&lt;/span&gt;&lt;span class="dl"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;})&lt;/span&gt;
    &lt;span class="nx"&gt;elif&lt;/span&gt; &lt;span class="nx"&gt;key&lt;/span&gt;&lt;span class="o"&gt;==&lt;/span&gt;&lt;span class="dl"&gt;"&lt;/span&gt;&lt;span class="s2"&gt;down&lt;/span&gt;&lt;span class="dl"&gt;"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
        &lt;span class="nx"&gt;labels&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;({&lt;/span&gt;&lt;span class="dl"&gt;'&lt;/span&gt;&lt;span class="s1"&gt;file_name&lt;/span&gt;&lt;span class="dl"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nx"&gt;f&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="dl"&gt;'&lt;/span&gt;&lt;span class="s1"&gt;class&lt;/span&gt;&lt;span class="dl"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;})&lt;/span&gt;


&lt;span class="nx"&gt;field_names&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="dl"&gt;'&lt;/span&gt;&lt;span class="s1"&gt;file_name&lt;/span&gt;&lt;span class="dl"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="dl"&gt;'&lt;/span&gt;&lt;span class="s1"&gt;class&lt;/span&gt;&lt;span class="dl"&gt;'&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;

&lt;span class="kd"&gt;with&lt;/span&gt; &lt;span class="nf"&gt;open&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="dl"&gt;'&lt;/span&gt;&lt;span class="s1"&gt;data/labels_snake.csv&lt;/span&gt;&lt;span class="dl"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="dl"&gt;'&lt;/span&gt;&lt;span class="s1"&gt;w&lt;/span&gt;&lt;span class="dl"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="nx"&gt;csvfile&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="nx"&gt;writer&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nx"&gt;csv&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;DictWriter&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nx"&gt;csvfile&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nx"&gt;fieldnames&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="nx"&gt;field_names&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="nx"&gt;writer&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;writeheader&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
    &lt;span class="nx"&gt;writer&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;writerows&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nx"&gt;labels&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;In this process, we are essentially creating a dataset of images with their corresponding labels, where each label represents the direction in which a key was pressed. Here are the steps involved:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;Importing dependencies:&lt;/strong&gt; We need certain libraries and modules to read and process the image files and write the labels to a CSV file. These are imported at the beginning of the process.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Creating a directory&lt;/strong&gt;: We create a directory named "data" to save the CSV file that will contain the labels and image filenames.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Reading the file nam&lt;/strong&gt;e: We extract the key pressed from the file name of each image. The key pressed can be either left, right, up, down, or no key at all.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Classifying the images&lt;/strong&gt;: Based on the key pressed, we classify each image into one of four classes: 0 for no key being pressed, 1 for left, 2 for up, 3 for right, and 4 for down. This information is stored along with the filename of the image in a list called "labels".&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Writing the labels&lt;/strong&gt;: Finally, we write the labels to the CSV file. This dataset can now be used to train a machine learning model to recognize the direction in which a key is pressed, given an image.&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;Now, we create a new file &lt;code&gt;snake_resnet.ipynb&lt;/code&gt;.&lt;/p&gt;

&lt;h2&gt;
  
  
  Import Dependencies
&lt;/h2&gt;



&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;torch.utils.data&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;Dataset&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;DataLoader&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;WeightedRandomSampler&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;os&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;PIL&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;Image&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;sklearn.model_selection&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;train_test_split&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;pandas&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;pd&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;torchvision.transforms&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;transforms&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Compose&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ToTensor&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Resize&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Normalize&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;CenterCrop&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Grayscale&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;tqdm&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;tqdm&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;torchinfo&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;summary&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;numpy&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;math&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;torchvision.models.video&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;r3d_18&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;R3D_18_Weights&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mc3_18&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;MC3_18_Weights&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;h2&gt;
  
  
  Create Dataset
&lt;/h2&gt;

&lt;p&gt;To get started, we need to create a custom dataset object using PyTorch. Our dataset consists of a stack of four images that are arranged in chronological order. Each item drawn from the dataset represents a sequence of four frames, where the last frame is associated with a key press. Essentially, this dataset captures motion through the last four frames and associates it with a key press.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;SnakeDataSet&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Dataset&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dataframe&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;root_dir&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;stack_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;transform&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;stack_size&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;stack_size&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;key_frame&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;dataframe&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;root_dir&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;root_dir&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;transform&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;transform&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__len__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="nf"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;key_frame&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;stack_size&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__getitem__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;idx&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;is_tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;idx&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
            &lt;span class="n"&gt;idx&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;idx&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;to_list&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
        &lt;span class="k"&gt;try&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
            &lt;span class="n"&gt;img_names&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;os&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;path&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;join&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;root_dir&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;key_frame&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;iloc&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;idx&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;stack_size&lt;/span&gt;&lt;span class="p"&gt;)]&lt;/span&gt;
            &lt;span class="n"&gt;images&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;Image&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;open&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;img_name&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;img_name&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;img_names&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
            &lt;span class="n"&gt;label&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;key_frame&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;iloc&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;idx&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;stack_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
            &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;transform&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
                &lt;span class="n"&gt;images&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;transform&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;image&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;image&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;images&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
        &lt;span class="k"&gt;except&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
            &lt;span class="n"&gt;img_names&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;os&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;path&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;join&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;root_dir&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;key_frame&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;iloc&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;stack_size&lt;/span&gt;&lt;span class="p"&gt;)]&lt;/span&gt;
            &lt;span class="n"&gt;images&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;Image&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;open&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;img_name&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;img_name&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;img_names&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
            &lt;span class="n"&gt;label&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;key_frame&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;iloc&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;stack_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
            &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;transform&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
                &lt;span class="n"&gt;images&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;transform&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;image&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;image&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;images&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;stack&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;images&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;squeeze&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt; &lt;span class="n"&gt;label&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Let's break down the code:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;Initialization (&lt;code&gt;__init__&lt;/code&gt; method):&lt;/strong&gt;

&lt;ul&gt;
&lt;li&gt;Parameters:

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;&lt;code&gt;dataframe&lt;/code&gt;&lt;/strong&gt;: The dataset containing information about images and labels.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;&lt;code&gt;root_dir&lt;/code&gt;&lt;/strong&gt;: The root directory where the image files are located.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;&lt;code&gt;stack_size&lt;/code&gt;&lt;/strong&gt;: The number of images to be stacked together as a single data point.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;&lt;code&gt;transform&lt;/code&gt;&lt;/strong&gt;: An optional parameter for image transformations (e.g., data augmentation).&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Length method (&lt;code&gt;__len__&lt;/code&gt; method):&lt;/strong&gt;

&lt;ul&gt;
&lt;li&gt;Returns the length of the dataset, which is the total number of data points. The length is calculated as the length of &lt;strong&gt;&lt;code&gt;key_frame&lt;/code&gt;&lt;/strong&gt; minus three times the &lt;strong&gt;&lt;code&gt;stack_size&lt;/code&gt;&lt;/strong&gt;. This indicates that the dataset is expected to contain sequences of images, and each data point consists of a stack of images.&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Get item method (&lt;code&gt;__getitem__&lt;/code&gt; method):&lt;/strong&gt;

&lt;ul&gt;
&lt;li&gt;Takes an index &lt;strong&gt;&lt;code&gt;idx&lt;/code&gt;&lt;/strong&gt; and returns the corresponding data point.&lt;/li&gt;
&lt;li&gt;The code attempts to load a sequence of images starting from the index &lt;strong&gt;&lt;code&gt;idx&lt;/code&gt;&lt;/strong&gt;. It constructs a list of image file paths (&lt;strong&gt;&lt;code&gt;img_names&lt;/code&gt;&lt;/strong&gt;) based on the specified &lt;strong&gt;&lt;code&gt;stack_size&lt;/code&gt;&lt;/strong&gt; and &lt;strong&gt;&lt;code&gt;root_dir&lt;/code&gt;&lt;/strong&gt;.&lt;/li&gt;
&lt;li&gt;It then opens each image using the Python Imaging Library (PIL) and stores the resulting image objects in a list (&lt;strong&gt;&lt;code&gt;images&lt;/code&gt;&lt;/strong&gt;).&lt;/li&gt;
&lt;li&gt;The label for the data point is extracted from the &lt;strong&gt;&lt;code&gt;key_frame&lt;/code&gt;&lt;/strong&gt; dataframe.&lt;/li&gt;
&lt;li&gt;If an image transformation function (&lt;strong&gt;&lt;code&gt;transform&lt;/code&gt;&lt;/strong&gt;) is provided, it applies the transformation to each image in the sequence.&lt;/li&gt;
&lt;li&gt;The images are stacked along a new dimension (dimension 1) using &lt;strong&gt;&lt;code&gt;torch.stack(images, dim=1)&lt;/code&gt;&lt;/strong&gt;, and then the singleton dimension is removed using &lt;strong&gt;&lt;code&gt;squeeze()&lt;/code&gt;&lt;/strong&gt;. This results in a tensor representing the stacked images.&lt;/li&gt;
&lt;li&gt;The stacked images tensor and the label tensor are returned as a tuple.&lt;/li&gt;
&lt;li&gt;Note: There's a &lt;strong&gt;&lt;code&gt;try-except&lt;/code&gt;&lt;/strong&gt; block that catches potential errors, and if an error occurs, it falls back to using the first sequence in the dataset. This is a basic form of error handling, but it's often better to handle errors more explicitly depending on the use case.&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;/ol&gt;

&lt;h2&gt;
  
  
  Balancing the Dataset and creating Dataloader
&lt;/h2&gt;

&lt;p&gt;If you analyze the dataset closely, you will be able to see that there is a severe class imbalance in the dataset. This is because most of the time, while playing the game, the user is not pressing any key. Thus, most of the dataset items belong to class 0.&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fz4ptbb0yc6gt74ut3aun.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fz4ptbb0yc6gt74ut3aun.png" alt="Histogram"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;We fix this by using a PyTorch &lt;code&gt;RandomWeightedSampler&lt;/code&gt; that takes as input the weights of each sample. We calculate these weights using the code below.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;STACK_SIZE = 4
BATCH_SIZE = 32

train, test = train_test_split(pd.read_csv("data/labels_snake.csv"), test_size=0.2, shuffle=False)
classes = ["n", "left", "up", "right", "down"]

labels_unique, counts = np.unique(train["class"], return_counts=True)
class_weights = [sum(counts)/c for c in counts]
example_weights = np.array([class_weights[l] for l in train['class']])
example_weights = np.roll(example_weights, -STACK_SIZE)
sampler = WeightedRandomSampler(example_weights, len(train))

labels_unique, counts = np.unique(test["class"], return_counts=True)
class_weights = [sum(counts)/c for c in counts]
test_example_weights = np.array([class_weights[l] for l in test['class']])
test_example_weights = np.roll(test_example_weights, -STACK_SIZE)
test_sampler = WeightedRandomSampler(test_example_weights, len(test))
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;The process of calculating the weight of each sample in a dataset is an important step in various machine-learning tasks such as image classification, text classification, and object detection, among others. The weight of each sample is used to balance the dataset so that each class contributes equally to the learning process. When a dataset is imbalanced, that is, when some classes have more samples than others, the learning algorithm may focus more on the majority class and ignore the minority class, leading to poor performance.&lt;/p&gt;

&lt;p&gt;To calculate the weight of each sample, we first split the dataset csv into train and test using the train_test_split method from the scikit-learn (sklearn) library. After that, we get the unique labels and the count of each label. This helps us to understand the distribution of classes in the dataset. We then calculate the weightage of each class by dividing the sum of all counts (size of the dataset) by the number of times that class occurs. This gives us a measure of how important each class is in the dataset.&lt;/p&gt;

&lt;p&gt;The next step is to get the weight of each example by assigning the class weight to the example. This is done by iterating through the dataset and assigning the weight to each sample based on its class label. We roll the example weights by the stack size because the label associated with a certain image is actually the label of the index of that image + STACK_SIZE. This ensures that each sample is given the correct weight based on its class label.&lt;/p&gt;

&lt;p&gt;In summary, calculating the weight of each sample in a dataset is a crucial step in machine learning tasks. It helps to balance the dataset and ensures that each class contributes equally to the learning process. The process involves splitting the dataset into train and test sets, obtaining the unique labels and the count of each label, calculating the weightage of each class, and assigning the weight to each sample based on its class label.&lt;/p&gt;

&lt;p&gt;Finally, we can create the dataloader for both the training and test datasets.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;dataset&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;SnakeDataSet&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;root_dir&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;captures&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dataframe&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;train&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;stack_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;STACK_SIZE&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;transform&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;transformer&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;dataloader&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;DataLoader&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dataset&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;BATCH_SIZE&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sampler&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;sampler&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;drop_last&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;test_dataset&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;SnakeDataSet&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;root_dir&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;captures&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dataframe&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;test&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;stack_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;STACK_SIZE&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;  &lt;span class="n"&gt;transform&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;transformer&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;test_dataloader&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;DataLoader&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;test_dataset&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;BATCH_SIZE&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sampler&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;test_sampler&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;drop_last&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;h2&gt;
  
  
  Create Transforms
&lt;/h2&gt;

&lt;p&gt;For the model that we are going to use, we need to perform certain transforms on the data. First, we need to compute the mean and standard deviation of the dataset images. We can do this using this code.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;compute_mean_std&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dataloader&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="sh"&gt;'''&lt;/span&gt;&lt;span class="s"&gt;
    We assume that the images of the dataloader have the same height and width
    source: https://github.com/aladdinpersson/Machine-Learning-Collection/blob/master/ML/Pytorch/Basics/pytorch_std_mean.py
    &lt;/span&gt;&lt;span class="sh"&gt;'''&lt;/span&gt;
    &lt;span class="c1"&gt;# var[X] = E[X**2] - E[X]**2
&lt;/span&gt;    &lt;span class="n"&gt;channels_sum&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;channels_sqrd_sum&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_batches&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;

    &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;batch_images&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;labels&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;tqdm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dataloader&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;  &lt;span class="c1"&gt;# (B,H,W,C)
&lt;/span&gt;        &lt;span class="n"&gt;batch_images&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;batch_images&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;permute&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;channels_sum&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_images&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
        &lt;span class="n"&gt;channels_sqrd_sum&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_images&lt;/span&gt; &lt;span class="o"&gt;**&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
        &lt;span class="n"&gt;num_batches&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;

    &lt;span class="n"&gt;mean&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;channels_sum&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;num_batches&lt;/span&gt;
    &lt;span class="n"&gt;std&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;channels_sqrd_sum&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;num_batches&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;mean&lt;/span&gt; &lt;span class="o"&gt;**&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;**&lt;/span&gt; &lt;span class="mf"&gt;0.5&lt;/span&gt;

    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;std&lt;/span&gt;

&lt;span class="nf"&gt;compute_mean_std&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dataloader&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;This will return the mean and standard deviation, which we can plug into the transformation below.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;transformer&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;Compose&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;
    &lt;span class="nc"&gt;Resize&lt;/span&gt;&lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="mi"&gt;84&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;84&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;antialias&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
    &lt;span class="nc"&gt;CenterCrop&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;84&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
    &lt;span class="nc"&gt;ToTensor&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt;
    &lt;span class="nc"&gt;Normalize&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;mean&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.7138&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;2.9883&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;  &lt;span class="mf"&gt;1.5832&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;std&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mf"&gt;0.2253&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.2192&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.2149&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt; 
&lt;span class="p"&gt;])&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;This resizes the image to 84x84, crops it, converts it to tensor and normalizes it. &lt;/p&gt;

&lt;h2&gt;
  
  
  Creating the Model
&lt;/h2&gt;

&lt;p&gt;As discussed in the introduction, we are going to use the PyTorch-provided &lt;code&gt;r3d&lt;/code&gt; model. This is a 3D Convolution model that uses ResNet architecture.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;model&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;r3d_18&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;weights&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;R3D_18_Weights&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;DEFAULT&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;fc&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;in_features&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;512&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;out_features&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;bias&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;summary&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;32&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;84&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;84&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;We load the default weights and replace the last fully connected layer to predict 5 output classes as required in our application. This is what the network looks like. We have about 33 million trainable parameters.&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fu9f4owee1ajxg3bbioxp.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fu9f4owee1ajxg3bbioxp.png" alt="Network Architecture"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;h2&gt;
  
  
  Training the Model
&lt;/h2&gt;

&lt;p&gt;Now, we can finally train the model. We first create the optimizer and the loss criterion. Here, we use a learning rate of 10e-5 and a weight decay of 0.1.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;num_epochs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;
&lt;span class="n"&gt;device&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;device&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;cuda&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cuda&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;is_available&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="k"&gt;else&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;cpu&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;optimizer&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;optim&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;AdamW&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;parameters&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt; &lt;span class="mf"&gt;10e-5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;weight_decay&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;criterion&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;CrossEntropyLoss&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Now, we create the training loop.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;epoch&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;num_epochs&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;total_loss&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mf"&gt;0.0&lt;/span&gt;
    &lt;span class="n"&gt;correct_predictions&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;
    &lt;span class="n"&gt;total_samples&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;

    &lt;span class="n"&gt;val_loss&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mf"&gt;0.0&lt;/span&gt;
    &lt;span class="n"&gt;val_correct_predictions&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;
    &lt;span class="n"&gt;val_total_samples&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;

    &lt;span class="c1"&gt;# Set model to training mode
&lt;/span&gt;    &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;train&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

    &lt;span class="c1"&gt;# tqdm bar for progress visualization
&lt;/span&gt;    &lt;span class="n"&gt;pbar&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;tqdm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dataloader&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;desc&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Epoch &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;epoch&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt;/&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;num_epochs&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;leave&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;inputs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;labels&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;pbar&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
        &lt;span class="n"&gt;inputs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;labels&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;inputs&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;labels&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

        &lt;span class="n"&gt;outputs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;inputs&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
        &lt;span class="n"&gt;loss&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;criterion&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;outputs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;labels&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

        &lt;span class="c1"&gt;# Backward pass and optimization
&lt;/span&gt;        &lt;span class="n"&gt;optimizer&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;zero_grad&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
        &lt;span class="n"&gt;loss&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;backward&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
        &lt;span class="n"&gt;optimizer&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;step&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

        &lt;span class="c1"&gt;# Update statistics
&lt;/span&gt;        &lt;span class="n"&gt;total_loss&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="n"&gt;loss&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;item&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
        &lt;span class="n"&gt;_&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;predicted&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;max&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;softmax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;outputs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;correct_predictions&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;predicted&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="n"&gt;labels&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;().&lt;/span&gt;&lt;span class="nf"&gt;item&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
        &lt;span class="n"&gt;total_samples&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="n"&gt;labels&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;size&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

        &lt;span class="c1"&gt;# Update tqdm bar with current loss and accuracy
&lt;/span&gt;        &lt;span class="n"&gt;pbar&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_postfix&lt;/span&gt;&lt;span class="p"&gt;({&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Loss&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;total_loss&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;total_samples&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Accuracy&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;correct_predictions&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;total_samples&lt;/span&gt;&lt;span class="p"&gt;})&lt;/span&gt;
        &lt;span class="n"&gt;steps&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;steps&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;

    &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;eval&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
    &lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;inference_mode&lt;/span&gt;&lt;span class="p"&gt;():&lt;/span&gt;
        &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;inputs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;labels&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;test_dataloader&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
            &lt;span class="n"&gt;inputs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;labels&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;inputs&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;labels&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

            &lt;span class="n"&gt;outputs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;inputs&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
            &lt;span class="n"&gt;loss&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;criterion&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;outputs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;labels&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

            &lt;span class="c1"&gt;# Update statistics
&lt;/span&gt;            &lt;span class="n"&gt;val_loss&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="n"&gt;loss&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;item&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
            &lt;span class="n"&gt;_&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;predicted&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;max&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;softmax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;outputs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
            &lt;span class="n"&gt;val_correct_predictions&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;predicted&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="n"&gt;labels&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;().&lt;/span&gt;&lt;span class="nf"&gt;item&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
            &lt;span class="n"&gt;val_total_samples&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="n"&gt;labels&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;size&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="c1"&gt;# Calculate and print epoch-level accuracy and loss for validation
&lt;/span&gt;    &lt;span class="n"&gt;epoch_loss&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;val_loss&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;val_total_samples&lt;/span&gt;
    &lt;span class="n"&gt;epoch_accuracy&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;val_correct_predictions&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;val_total_samples&lt;/span&gt;
    &lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Epoch &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;epoch&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt;/&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;num_epochs&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt;, Val Loss: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;epoch_loss&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt;, Val Accuracy: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;epoch_accuracy&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;save&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;state_dict&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;model_r3d.pth&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;This will train the network. &lt;/p&gt;

&lt;h2&gt;
  
  
  Play the game
&lt;/h2&gt;

&lt;p&gt;Create a new python script &lt;code&gt;play.py&lt;/code&gt;. In this script, we will load the model and write code similar to &lt;code&gt;[capture.py](http://capture.py)&lt;/code&gt; to load the game and collect screenshots of the game. We will stack 4 screenshots and pass it through our network.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;base64&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;cv2&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;keyboard&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;PIL&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;Image&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;numpy&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;torchvision.transforms&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;Compose&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Resize&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;CenterCrop&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ToTensor&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Normalize&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;  &lt;span class="n"&gt;Grayscale&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;collections&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;deque&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;torchvision.models.video&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;r3d_18&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;selenium&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;webdriver&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;selenium.webdriver.common.by&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;By&lt;/span&gt;

&lt;span class="n"&gt;label_keys&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;{&lt;/span&gt;
    &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="sh"&gt;""&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;left&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;up&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;right&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;down&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;
&lt;span class="p"&gt;}&lt;/span&gt;

&lt;span class="n"&gt;device&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;device&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;cuda&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cuda&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;is_available&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="k"&gt;else&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;cpu&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;model&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;r3d_18&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;weights&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;fc&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;in_features&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;512&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;out_features&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;bias&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;load_state_dict&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;load&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;model_mc3.pth&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;eval&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

&lt;span class="n"&gt;transformer&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;Compose&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;
    &lt;span class="nc"&gt;Resize&lt;/span&gt;&lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="mi"&gt;84&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;84&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;antialias&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
    &lt;span class="nc"&gt;CenterCrop&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;84&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
    &lt;span class="nc"&gt;ToTensor&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt;
    &lt;span class="nc"&gt;Normalize&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;mean&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.7138&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;2.9883&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;  &lt;span class="mf"&gt;1.5832&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;std&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mf"&gt;0.2253&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.2192&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.2149&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="p"&gt;])&lt;/span&gt;

&lt;span class="c1"&gt;# Create a webdriver instance using Firefox
&lt;/span&gt;&lt;span class="n"&gt;driver&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;webdriver&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Firefox&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="c1"&gt;# Navigate to the Google Snake game website
&lt;/span&gt;&lt;span class="n"&gt;driver&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;get&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;https://www.google.com/fbx?fbx=snake_arcade&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="n"&gt;frame_stack&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;deque&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;maxlen&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="k"&gt;while&lt;/span&gt; &lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
     &lt;span class="c1"&gt;# Find the canvas element on the webpage
&lt;/span&gt;    &lt;span class="n"&gt;canvas&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;driver&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;find_element&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;By&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;CSS_SELECTOR&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;canvas&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="c1"&gt;# Get the base64 encoded image data from the canvas
&lt;/span&gt;    &lt;span class="n"&gt;canvas_base64&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;driver&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;execute_script&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
        &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;return arguments[0].toDataURL(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;image/png&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;).substring(21);&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;canvas&lt;/span&gt;
    &lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="c1"&gt;# Decode the base64 data to get the PNG image
&lt;/span&gt;    &lt;span class="n"&gt;canvas_png&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;base64&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;b64decode&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;canvas_base64&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="c1"&gt;# Convert the PNG image to a grayscale numpy array
&lt;/span&gt;    &lt;span class="n"&gt;image&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;cv2&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;cvtColor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
        &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;array&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Image&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;open&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;io&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;BytesIO&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;canvas_png&lt;/span&gt;&lt;span class="p"&gt;))),&lt;/span&gt; &lt;span class="n"&gt;cv2&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;COLOR_BGR2RGB&lt;/span&gt;
    &lt;span class="p"&gt;)&lt;/span&gt; 

    &lt;span class="n"&gt;frame_stack&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;transformer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;image&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
    &lt;span class="nb"&gt;input&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;stack&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="o"&gt;*&lt;/span&gt;&lt;span class="n"&gt;frame_stack&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;squeeze&lt;/span&gt;&lt;span class="p"&gt;().&lt;/span&gt;&lt;span class="nf"&gt;unsqueeze&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="nf"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;frame_stack&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
        &lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;inference_mode&lt;/span&gt;&lt;span class="p"&gt;():&lt;/span&gt;
            &lt;span class="n"&gt;outputs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nb"&gt;input&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
            &lt;span class="n"&gt;preds&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;softmax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;outputs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;argmax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

            &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;preds&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;item&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="o"&gt;!=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
                &lt;span class="n"&gt;keyboard&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;press_and_release&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;label_keys&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;preds&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;item&lt;/span&gt;&lt;span class="p"&gt;()])&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Once we run this script, you can see the snake playing.&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fkvt78c8gzk4ojrw7f3gh.gif" class="article-body-image-wrapper"&gt;&lt;img src="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fkvt78c8gzk4ojrw7f3gh.gif" alt="Demo"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;h2&gt;
  
  
  Conclusion
&lt;/h2&gt;

&lt;p&gt;In wrapping up our journey into building an AI game bot for Google Snake, we've explored the exciting realm of imitation learning and harnessed the power of a 3D Convolution ResNet model. Beyond the fun of game development, the acquired skills have broader applications in serious AI scenarios.&lt;/p&gt;

&lt;p&gt;We began by understanding the significance of imitation learning, focusing on training our AI agent through human gameplay. The use of a 3D Convolution ResNet model allowed us to accurately capture motion by stacking four frames of the game.&lt;/p&gt;

&lt;p&gt;Practical details covered data collection using Selenium, creating labeled datasets, and transforming them into a PyTorch-friendly format. We highlighted the importance of addressing class imbalance through WeightedRandomSampler.&lt;/p&gt;

&lt;p&gt;Implementation involved constructing a custom dataset class, selecting an appropriate model architecture, and training the model using essential transforms. The training loop and evaluation on a test set provided insights into the model's performance.&lt;/p&gt;

&lt;p&gt;To complete the journey, we showcased how the trained model can be used to play the game in real-time. The provided Python script demonstrated capturing frames, making predictions, and controlling the snake's movements.&lt;/p&gt;

&lt;p&gt;In essence, this concise guide empowers you to master AI in gaming through imitation learning and 3D Convolution ResNet models, with skills extending to broader AI applications.&lt;/p&gt;

&lt;p&gt;Git Repo: &lt;a href="https://github.com/akshayballal95/autodrive-snake/tree/blog" rel="noopener noreferrer"&gt;https://github.com/akshayballal95/autodrive-snake/tree/blog&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Want to connect?&lt;/p&gt;

&lt;p&gt;🌍&lt;a href="http://www.akshaymakes.com" rel="noopener noreferrer"&gt;My Website&lt;/a&gt;&lt;br&gt;&lt;br&gt;
🐦&lt;a href="https://twitter.com/akshayballal95" rel="noopener noreferrer"&gt;My Twitter&lt;/a&gt;&lt;br&gt;&lt;br&gt;
👨&lt;a href="https://www.linkedin.com/in/akshay-ballal/" rel="noopener noreferrer"&gt;My LinkedIn&lt;/a&gt;&lt;/p&gt;

</description>
      <category>ai</category>
      <category>machinelearning</category>
      <category>python</category>
      <category>tutorial</category>
    </item>
  </channel>
</rss>
