Hi All,
I would like to scroll on an axis created with the twinx command.
I have described the problem in a question there:
I really would like to implement this feature.
Any help is welcome.
Regards
Patrick
Hi All,
I would like to scroll on an axis created with the twinx command.
I have described the problem in a question there:
I really would like to implement this feature.
Any help is welcome.
Regards
Patrick
0
Here is a solution to the issue of having separate axes that each respond to zooming and scrolling.
I’ve also added the option to display the range either for each axis individually or for the overall plot.
The solution is not as simple as expected: axes created with twinx do not respond to ax.xaxis.contains(event), so I had to use the distance from the mouse to each axis's spines instead.
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.axis import XAxis, YAxis
from shapely.geometry import LineString, Point
#========================================================================================
# Sample data: three quantities sampled at the same three distances.
x = np.arange(3)  # Distance: 0, 1, 2
y1 = np.array([0, 1, 2])  # Density
y2 = np.array([0, 3, 2])  # Temperature
y3 = np.array([50, 30, 15])  # Velocity
#========================================================================================
# Margins, as fractions of the figure, passed to subplots_adjust below.
margin_left, margin_bottom = 0.4, 0.2

# Build the figure with a single main axes and apply the margins.
fig, ax = plt.subplots(nrows=1, ncols=1)
fig.subplots_adjust(bottom=margin_bottom, left=margin_left)
ax.pan_start = None  # Starting position of a pan; None while no pan is active
ax.patch.set_visible(False)  # Hide the background patch of the main axes

# Size of the main axes in canvas units: its figure-relative bounding box
# scaled by the canvas width and height.
bbox = ax.get_position()
fig_width, fig_height = fig.canvas.get_width_height()
ax_width, ax_height = bbox.width * fig_width, bbox.height * fig_height
# Create the secondary y-axes (twinx). Each one gets its left spine shifted
# to the left of the main axes by a fixed offset (80 and 160 canvas units),
# expressed as a fraction of the main axes width.
_twins = []
for _offset in (80, 160):
    _twin = ax.twinx()
    position = -_offset / ax_width
    _twin.spines['left'].set_position(('axes', position))  # Shift left spine
    _twin.spines['left'].set_visible(True)
    _twin.yaxis.set_label_position('left')
    _twin.yaxis.set_ticks_position('left')
    _twin.set_zorder(-10)  # Send this axes to the background
    _twins.append(_twin)
twin1, twin2 = _twins

# Every axes the event handlers hit-test, main axes first.
axs = [ax, twin1, twin2]
def _plot_series(axes, values, color, name, zorder):
    """Plot *values* against ``x`` on *axes* and label its y-axis in the same colour."""
    axes.plot(x, values, color=color, label=name, zorder=zorder)
    axes.set(ylabel=name)
    axes.yaxis.label.set_color(color)


# Main axes: density, plus the shared x label and the grid.
_plot_series(ax, y1, 'red', 'Density', 1)
ax.set(xlabel='Distance')
ax.grid(visible=True, which='major', color='lightgray', linestyle='dashed', linewidth=0.5)

# Secondary axes: temperature and velocity.
_plot_series(twin1, y2, 'blue', 'Temperature', -10)
_plot_series(twin2, y3, 'green', 'Velocity', -10)
#========================================================================================
def _spine_line(spine):
    """Return *spine*'s path, mapped through the spine's own transform, as a LineString."""
    path = spine.get_path().transformed(spine.get_transform())
    return LineString(path.vertices)


def detect_artist(event):
    """Find which axes, and which part of it, a canvas event refers to.

    Returns a tuple ``(axes, artist)``:

    * ``(axes, axes)`` for the first axes in ``axs`` whose ``contains(event)``
      reports a hit;
    * otherwise ``(axes, axes.yaxis)`` or ``(axes, axes.xaxis)`` for the axes
      whose left or bottom spine is nearest to ``(event.x, event.y)``; the
      y-axis is chosen only when the left spine is strictly closer;
    * ``(None, None)`` when the event carries no pointer position or ``axs``
      is empty.
    """
    # An event may carry no pointer position (x/y set to None), e.g. a key
    # press while the pointer is off the canvas. Hit-testing is impossible
    # then, and Point()/contains() would fail on None coordinates.
    if event.x is None or event.y is None:
        return None, None

    mouse = Point((event.x, event.y))  # Same for every axes, so build it once
    closest_axis = None
    axis_type = None
    distance_min = float('inf')

    for ax_current in axs:
        if ax_current.contains(event)[0]:
            return ax_current, ax_current

        distance_left = mouse.distance(_spine_line(ax_current.spines['left']))
        distance_bottom = mouse.distance(_spine_line(ax_current.spines['bottom']))
        distance = min(distance_left, distance_bottom)

        # Strict comparison: on a tie the axes listed first in ``axs`` wins.
        if distance < distance_min:
            distance_min = distance
            closest_axis = ax_current
            if distance_left < distance_bottom:
                axis_type = ax_current.yaxis
            else:
                axis_type = ax_current.xaxis

    return closest_axis, axis_type
#========================================================================================
def _zoom_limits(limits, center, scale_factor):
    """Return *limits* scaled by *scale_factor* about *center*.

    *limits* is a ``(low, high)`` pair. When *center* is None the midpoint of
    *limits* is used instead.
    """
    low, high = limits
    if center is None:
        center = (low + high) / 2
    return [center - (center - low) * scale_factor,
            center + (high - center) * scale_factor]


def on_scroll(event):
    """Zoom in and out based on mouse scroll.

    Scrolling ``'up'`` shrinks the visible range to 90 % (zoom in) and
    ``'down'`` grows it to 110 % (zoom out). What is zoomed depends on what
    ``detect_artist`` reports under the pointer:

    * an x-axis: only the x limits, about ``event.xdata`` or, when that is
      None, about the middle of the current range;
    * a y-axis: only the y limits, likewise with ``event.ydata``;
    * an axes: both limits about ``(event.xdata, event.ydata)``; nothing
      happens if either coordinate is missing.
    """
    scale_factor = {'up': 0.9, 'down': 1.1}.get(event.button)
    if scale_factor is None:
        # Neither 'up' nor 'down': there is no zoom direction to apply.
        return

    ax, artist = detect_artist(event)  # Detect the Artist element under the mouse
    if artist is None:
        return

    if isinstance(artist, XAxis):
        ax.set_xlim(_zoom_limits(ax.get_xlim(), event.xdata, scale_factor))
    elif isinstance(artist, YAxis):
        ax.set_ylim(_zoom_limits(ax.get_ylim(), event.ydata, scale_factor))
    elif isinstance(artist, plt.Axes):
        if event.xdata is None or event.ydata is None:
            return  # No data coordinates to zoom about
        ax.set_xlim(_zoom_limits(ax.get_xlim(), event.xdata, scale_factor))
        ax.set_ylim(_zoom_limits(ax.get_ylim(), event.ydata, scale_factor))

    ax.figure.canvas.draw()  # Redraw the canvas
#========================================================================================
def on_press(event):
    """Record where a left-button press happened so a later drag can pan from it.

    The press position is converted to the data coordinates of the axes
    reported by ``detect_artist`` and stored on that axes as ``pan_start``.
    Presses with any other button, or with nothing detected, are ignored.
    """
    left_button = 1
    if event.button == left_button:
        target_axes, hit = detect_artist(event)
        if hit is not None:
            to_data = target_axes.transData.inverted()
            start_x, start_y = to_data.transform((event.x, event.y))
            target_axes.pan_start = start_x, start_y
#========================================================================================
def on_motion(event):
    """Handle panning when the mouse is moved.

    Acts only when the axes reported by ``detect_artist`` has a ``pan_start``
    set by ``on_press``. The limits are shifted so that the data point stored
    in ``pan_start`` ends up under the pointer again: only x when the pointer
    is on an x-axis, only y on a y-axis, both otherwise.
    """
    ax, artist = detect_artist(event)  # Detect the Artist element under the mouse
    if artist is None:
        return

    pan_start = getattr(ax, "pan_start", None)
    if not pan_start:
        # No drag in progress on this axes: nothing changed, so skip the redraw.
        return

    xstart, ystart = pan_start
    coordx, coordy = ax.transData.inverted().transform((event.x, event.y))
    dx = xstart - coordx
    dy = ystart - coordy

    # An x-axis pans horizontally only, a y-axis vertically only, and
    # anything else (the plot area) pans in both directions.
    if not isinstance(artist, YAxis):
        x_low, x_high = ax.get_xlim()
        ax.set_xlim(x_low + dx, x_high + dx)
    if not isinstance(artist, XAxis):
        y_low, y_high = ax.get_ylim()
        ax.set_ylim(y_low + dy, y_high + dy)

    # Motion events arrive in rapid succession; draw_idle lets the backend
    # schedule the redraw instead of rendering synchronously on each event.
    ax.figure.canvas.draw_idle()
#========================================================================================
def on_release(event):
    """End any pan in progress by clearing ``pan_start`` on every axes in ``axs``."""
    for tracked_axes in axs:
        tracked_axes.pan_start = None
#========================================================================================
def _visible_lines(axes):
    """Return the lines of *axes* that are currently visible."""
    return [line for line in axes.lines if line.get_visible()]


def _padded_range(arrays, margin=0.05):
    """Return ``(low, high)`` spanning all data in *arrays*, widened on each side.

    The padding on each side is *margin* times the data span. Empty arrays
    are skipped; None is returned when no data is left, so the caller can
    leave the current limits untouched.
    """
    filled = [a for a in (np.asarray(values) for values in arrays) if a.size]
    if not filled:
        return None
    low = min(a.min() for a in filled)
    high = max(a.max() for a in filled)
    pad = (high - low) * margin
    return low - pad, high + pad


def on_key_press(event):
    """Fit the view to the visible lines when the ``'a'`` key is pressed.

    What is rescaled depends on what ``detect_artist`` reports under the
    pointer:

    * an axes (plot area): the y range of every axes in ``axs`` is fitted to
      its own visible lines, and the x range to the visible lines of all of
      them together;
    * an x-axis: only the x range of that axes;
    * a y-axis: only the y range of that axes.

    Ranges are padded by 5 % of the data span on each side. An axes without
    visible data keeps its current limits.
    """
    if event.key != 'a':
        return  # Skip the hit-test for keys that do nothing
    if event.x is None or event.y is None:
        return  # No pointer position, so no target to hit-test

    ax, artist = detect_artist(event)  # Detect the Artist element under the mouse

    if isinstance(artist, plt.Axes):
        all_lines = []
        for ax_current in axs:
            lines = _visible_lines(ax_current)
            all_lines.extend(lines)
            y_range = _padded_range(line.get_ydata() for line in lines)
            if y_range is not None:
                ax_current.set_ylim(*y_range)
        # Use the lines of every axes for the x range, not only those of the
        # last axes visited.
        x_range = _padded_range(line.get_xdata() for line in all_lines)
        if x_range is not None:
            ax.set_xlim(*x_range)
    elif isinstance(artist, XAxis):
        x_range = _padded_range(line.get_xdata() for line in _visible_lines(ax))
        if x_range is not None:
            ax.set_xlim(*x_range)
    elif isinstance(artist, YAxis):
        y_range = _padded_range(line.get_ydata() for line in _visible_lines(ax))
        if y_range is not None:
            ax.set_ylim(*y_range)
    else:
        return

    ax.figure.canvas.draw()  # Redraw the canvas
#========================================================================================
# Register each handler on the figure canvas under its event name, then
# display the figure.
_handlers = {
    'scroll_event': on_scroll,
    'button_press_event': on_press,
    'motion_notify_event': on_motion,
    'button_release_event': on_release,
    'key_press_event': on_key_press,
}
for _event_name, _callback in _handlers.items():
    fig.canvas.mpl_connect(_event_name, _callback)
plt.show()