HiveBrain v1.2.0
Get Started
← Back to all entries
patternpythonMinor

Plot heat map from csv file using numpy and matplotlib

Submitted by: @import:stackexchange-codereview··
0
Viewed 0 times
mapfilenumpymatplotlibcsvandheatusingplotfrom

Problem

There's a csv file with format:

x0, y0, v00
x0, y1, v01
...
x1, y0, v10
...


I want to plot a heat map in which the value v at each location (x, y) is drawn in the corresponding color. Below is my current implementation.

```
import random
import numpy as np
import matplotlib.pyplot as plt

def create_test_csv(file):
    '''
    write a deterministic 300 x 600 grid of random test values to a csv file
    csv file format: x,y,value  (one triple per line)

    x runs over range(300), y over range(600); value comes from
    random.randrange(255) after seeding the generator with 42, so the
    file content is reproducible.

    file: path of the csv file to create (an existing file is overwritten)
    '''
    random.seed(42)
    # The context manager flushes and closes the handle before returning, so
    # a caller can safely read the file back right away.
    with open(file, "w") as f:
        for x in range(300):
            for y in range(600):
                value = random.randrange(255)
                f.write("{},{},{}\n".format(x, y, value))

def get_xyz_from_csv_file(csv_file_path):
    '''
    get x, y, z value from csv file
    csv file format: x0,y0,z0  (one triple per line, blank lines are skipped)

    returns (x, y, map_value):
        x, y       lists of the float coordinates, in file order
        map_value  dict mapping each (x, y) pair to its float z value;
                   if a pair occurs more than once the last line wins
    '''
    x = []
    y = []
    map_value = {}

    # `with` closes the file even when a line fails to parse.
    with open(csv_file_path) as csv_file:
        for line in csv_file:
            if not line.strip():
                continue
            # Called `fields`, not `list`, so the builtin is not shadowed.
            fields = line.split(",")
            temp_x = float(fields[0])
            temp_y = float(fields[1])
            temp_z = float(fields[2])
            x.append(temp_x)
            y.append(temp_y)
            map_value[(temp_x, temp_y)] = temp_z

    return x, y, map_value

def draw_heatmap(x, y, map_value):
    '''
    draw a heat map of the values in map_value and return the figure

    x, y       sequences of coordinates (duplicates allowed)
    map_value  dict mapping (x, y) pairs to the value to plot;
               grid cells without an entry are plotted as 0
    '''
    # np.unique returns the distinct coordinates in ascending order;
    # list(set(...)) gave them in arbitrary order, which scrambles the axes.
    plt_x = np.unique(x)
    plt_y = np.unique(y)
    plt_z = np.zeros(shape=(len(plt_x), len(plt_y)))

    # dict.get replaces the Python-2-only dict.has_key lookup.
    for i, x_value in enumerate(plt_x.tolist()):
        for j, y_value in enumerate(plt_y.tolist()):
            plt_z[i][j] = map_value.get((x_value, y_value), 0.0)

    z_min = plt_z.min()
    z_max = plt_z.max()
    # pcolor expects rows to follow y and columns to follow x.
    plt_z = np.transpose(plt_z)

    plot_name = "demo"

    color_map = plt.cm.gist_heat  # alternatives: plt.cm.rainbow, plt.cm.hot
    plt.clf()
    plt.pcolor(plt_x, plt_y, plt_z, cmap=color_map, vmin=z_min, vmax=z_max)
    plt.axis([plt_x.min(), plt_x.max(), plt_y.min(), plt_y.max()])
    plt.title(plot_name)
    plt.colorbar().set_label(plot_name, rotation=270)
    ax = plt.gca()
    ax.set_aspect('equal')
    figure = plt.gcf()
    plt.show()
    return figure

if __name__ == "__main__":
csv_file_nam

Solution

Since you are already using numpy, you can use numpy's loadtxt function to read in all the data at once as numpy arrays from the start. This allows you to avoid having to worry about opening or closing files (this is done automatically), converting to numpy arrays, etc. Then it is a simple matter of using the x and y columns as array indexes to put each value into its place in a 2D grid.

You can also vectorize the test data creation, using numpy's meshgrid function to get a grid of corresponding X and Y coordinates.

You can make the plotting better, in my opinion at least, by using plt.subplots() to get a figure and axes object right at the beginning, then using those to do the plotting.

So here is how I would do it:

import numpy as np
import matplotlib.pyplot as plt

def create_test_csv(fname):
    '''
    write a reproducible 300 x 600 grid of random test values to a csv file
    csv file format: x,y,value  (integers, one triple per line)
    '''
    np.random.seed(42)

    # Coordinate grids: one entry per (x, y) location
    grid_x, grid_y = np.meshgrid(np.arange(300), np.arange(600))

    # One random value per location
    values = np.random.randint(0, 255, size=grid_x.size)

    # Flatten everything and lay the three sequences side by side as columns
    table = np.column_stack((grid_x.ravel(), grid_y.ravel(), values.ravel()))

    np.savetxt(fname, table, delimiter=',', fmt='%d')

def get_xyz_from_csv_file_np(csv_file_path):
    '''
    get a grid of values from a csv file
    csv file format: x0,y0,z0

    All three columns are parsed as integers. x and y are used directly as
    column and row indexes, so the returned float array has shape
    (y.max() + 1, x.max() + 1); locations missing from the file stay 0.
    '''

    # Load the csv file into a single 2D array,
    # then split the columns into individual variables.
    # The builtin int is used because the np.int alias was removed in
    # NumPy 1.24.
    x, y, z = np.loadtxt(csv_file_path, delimiter=',', dtype=int).T

    # Create an empty 2D array of pixels and
    # put all the values into the correct place
    plt_z = np.zeros((y.max()+1, x.max()+1))
    plt_z[y, x] = z

    return plt_z

def draw_heatmap(plt_z):
    '''
    draw a heat map of a 2D grid of values and return the figure

    plt_z: 2D array; rows are plotted along y and columns along x
    '''
    # Generate y and x values from the dimension lengths
    plt_y = np.arange(plt_z.shape[0])
    plt_x = np.arange(plt_z.shape[1])

    # Colour scale spans the full range of the data
    z_min = plt_z.min()
    z_max = plt_z.max()

    plot_name = "demo"

    color_map = plt.cm.gist_heat  # alternatives: plt.cm.rainbow, plt.cm.hot
    fig, ax = plt.subplots()
    cax = ax.pcolor(plt_x, plt_y, plt_z, cmap=color_map, vmin=z_min, vmax=z_max)
    ax.set_xlim(plt_x.min(), plt_x.max())
    ax.set_ylim(plt_y.min(), plt_y.max())
    fig.colorbar(cax).set_label(plot_name, rotation=270)
    ax.set_title(plot_name)
    ax.set_aspect('equal')
    plt.show()
    # Return the figure created above; the previous `return figure` referred
    # to a name that was only assigned in unreachable code after it.
    return fig

if __name__ == "__main__":
    # Demo round trip: generate the test csv, load it as a grid, plot it.
    csv_path = 'temp.csv'
    create_test_csv(csv_path)
    draw_heatmap(get_xyz_from_csv_file_np(csv_path))

Code Snippets

import numpy as np
import matplotlib.pyplot as plt


def create_test_csv(fname):
    '''
    write a reproducible 300 x 600 grid of random test values to a csv file
    csv file format: x,y,value  (integers, one triple per line)
    '''
    np.random.seed(42)

    # Coordinate grids: one entry per (x, y) location
    grid_x, grid_y = np.meshgrid(np.arange(300), np.arange(600))

    # One random value per location
    values = np.random.randint(0, 255, size=grid_x.size)

    # Flatten everything and lay the three sequences side by side as columns
    table = np.column_stack((grid_x.ravel(), grid_y.ravel(), values.ravel()))

    np.savetxt(fname, table, delimiter=',', fmt='%d')


def get_xyz_from_csv_file_np(csv_file_path):
    '''
    get a grid of values from a csv file
    csv file format: x0,y0,z0

    All three columns are parsed as integers. x and y are used directly as
    column and row indexes, so the returned float array has shape
    (y.max() + 1, x.max() + 1); locations missing from the file stay 0.
    '''

    # Load the csv file into a single 2D array,
    # then split the columns into individual variables.
    # The builtin int is used because the np.int alias was removed in
    # NumPy 1.24.
    x, y, z = np.loadtxt(csv_file_path, delimiter=',', dtype=int).T

    # Create an empty 2D array of pixels and
    # put all the values into the correct place
    plt_z = np.zeros((y.max()+1, x.max()+1))
    plt_z[y, x] = z

    return plt_z


def draw_heatmap(plt_z):
    '''
    draw a heat map of a 2D grid of values and return the figure

    plt_z: 2D array; rows are plotted along y and columns along x
    '''
    # Generate y and x values from the dimension lengths
    plt_y = np.arange(plt_z.shape[0])
    plt_x = np.arange(plt_z.shape[1])

    # Colour scale spans the full range of the data
    z_min = plt_z.min()
    z_max = plt_z.max()

    plot_name = "demo"

    color_map = plt.cm.gist_heat  # alternatives: plt.cm.rainbow, plt.cm.hot
    fig, ax = plt.subplots()
    cax = ax.pcolor(plt_x, plt_y, plt_z, cmap=color_map, vmin=z_min, vmax=z_max)
    ax.set_xlim(plt_x.min(), plt_x.max())
    ax.set_ylim(plt_y.min(), plt_y.max())
    fig.colorbar(cax).set_label(plot_name, rotation=270)
    ax.set_title(plot_name)
    ax.set_aspect('equal')
    plt.show()
    # Return the figure created above; the previous `return figure` referred
    # to a name that was only assigned in unreachable code after it.
    return fig


if __name__ == "__main__":
    # Demo round trip: generate the test csv, load it as a grid, plot it.
    csv_path = 'temp.csv'
    create_test_csv(csv_path)
    draw_heatmap(get_xyz_from_csv_file_np(csv_path))

Context

StackExchange Code Review Q#132499, answer score: 3

Revisions (0)

No revisions yet.