DEV Community

David Haley


Get GPU names + counts with TensorFlow API

I'm running some benchmarks and want to output the GPU in use, rather than track it manually.

In my case there will only ever be zero or one GPU type, though possibly more than one GPU of that type. Hence the error if there are multiple names.

May the graphics be ever in your favor! 🖥️

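from itertools import groupby

import tensorflow as tf

# List the physical GPUs TensorFlow can see, then fetch their details
# (get_device_details still lives under tf.config.experimental).
gpu_devices = tf.config.list_physical_devices('GPU')
gpu_details = [tf.config.experimental.get_device_details(gpu) for gpu in gpu_devices]

def device_name(details):
    # 'device_name' may be missing from the details dict, so use a placeholder.
    return details.get('device_name', 'Unknown GPU')

# groupby only merges *adjacent* items, so sort by name first;
# otherwise interleaved names would overwrite each other in the dict.
gpus_by_name = {
    name: list(group)
    for name, group in groupby(sorted(gpu_details, key=device_name), key=device_name)
}

gpu_names = list(gpus_by_name)

if len(gpu_names) == 0:
    gpu_name = "None"
    gpu_count = 0
elif len(gpu_names) == 1:
    gpu_name = gpu_names[0]
    gpu_count = len(gpus_by_name[gpu_name])
else:
    # Python 3 can only raise exception instances, not plain strings.
    raise RuntimeError(f"Dunno how to handle multiple GPU types: {gpu_names}")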
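If you'd rather keep the counting logic separate from TensorFlow (handy for testing the benchmark harness on a machine without a GPU), here is a minimal sketch that applies the same rules to a list of dicts shaped like `get_device_details()` output. The function name `summarize_gpus` and the `"Unknown GPU"` placeholder are hypothetical choices of mine, not part of TensorFlow. It uses `collections.Counter`, which tallies names in any order, so unlike `itertools.groupby` it needs no sorting step.

```python
from collections import Counter
from typing import Iterable, Mapping, Tuple


def summarize_gpus(gpu_details: Iterable[Mapping[str, object]]) -> Tuple[str, int]:
    """Return (gpu_name, gpu_count) from get_device_details()-style dicts.

    Follows the same rules as the script above: ("None", 0) when there are
    no GPUs, (name, count) for a single GPU type, and an error for mixed types.
    Assumes each dict may carry a 'device_name' key (hypothetical fallback
    label "Unknown GPU" when it is missing).
    """
    # Counter tallies every name regardless of input order.
    counts = Counter(str(d.get("device_name", "Unknown GPU")) for d in gpu_details)

    if not counts:
        return "None", 0
    if len(counts) > 1:
        raise RuntimeError(f"Dunno how to handle multiple GPU types: {dict(counts)}")

    # Exactly one GPU type remains: take its (name, count) pair.
    name, count = next(iter(counts.items()))
    return name, count
```

With TensorFlow installed, you would call it as `summarize_gpus([tf.config.experimental.get_device_details(g) for g in tf.config.list_physical_devices('GPU')])` and log the returned name and count alongside the benchmark results.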
