"""
Core module for MatplotlibPickleEditor.
Handles loading, modifying and saving serialized matplotlib figures.
"""
import pickle
import logging
import warnings
from pathlib import Path
from typing import Dict, Optional
import copy
try:
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from matplotlib.axes import Axes
except ImportError as e:
raise ImportError(
"matplotlib not installed. Run: pip install matplotlib"
) from e
logger = logging.getLogger(__name__)
[docs]
class PickleSecurityError(Exception):
"""Exception for security issues in pickle loading."""
pass
[docs]
class MatplotlibPickleEditor:
"""
Core editor to modify colors, labels, and styles in serialized matplotlib figures.
Attributes:
COMMON_COLORS (List[str]): List of common predefined colors for quick selection.
"""
COMMON_COLORS = [
'red', 'blue', 'green', 'black', 'orange', 'purple', 'brown',
'pink', 'gray', 'cyan', 'magenta', 'yellow',
'#2E86AB', '#A23B72', '#F18F01', '#C73E1D', '#6A994E',
'#BC4B51', '#5B8E7D', '#8B5A3C', '#264653', '#E76F51'
]
def __init__(self, pickle_path: Path, strict_validation: bool = True):
"""
Initialize the editor.
Args:
pickle_path (Path): Path to the pickle file.
strict_validation (bool): If True, ensures the loaded object is a Figure.
"""
self.pickle_path = pickle_path
self.strict_validation = strict_validation
self.figure: Optional[Figure] = None
self.axes: Optional[Axes] = None
self._modifications_made = False
if not self.pickle_path.exists():
raise FileNotFoundError(f"File not found: {self.pickle_path}")
[docs]
def load(self) -> None:
"""
Load and deserialize the pickle file.
Raises:
PickleSecurityError: If the file is corrupted or not a valid Figure.
"""
with warnings.catch_warnings():
warnings.simplefilter("ignore")
try:
with open(self.pickle_path, 'rb') as f:
obj = pickle.load(f)
except pickle.UnpicklingError as e:
raise PickleSecurityError(
f"Pickle file corrupted or invalid: {e}"
) from e
if not isinstance(obj, Figure):
if self.strict_validation:
raise PickleSecurityError(
f"Object is not matplotlib.figure.Figure, but {type(obj)}"
)
else:
logger.warning(f"WARNING: unexpected type {type(obj)}")
self.figure = obj
if self.figure.axes:
self.axes = self.figure.axes[0]
else:
raise PickleSecurityError("No axes found in figure")
[docs]
def get_legend_labels(self) -> Dict[int, str]:
"""
Retrieve current legend labels mapped by index.
Returns:
Dict[int, str]: Dictionary {index: label_text}.
"""
if not self.axes:
raise RuntimeError("You must call load() first")
legend = self.axes.get_legend()
if not legend:
return {}
labels = {}
for idx, text in enumerate(legend.get_texts()):
labels[idx] = text.get_text()
return labels
[docs]
def get_line_colors(self) -> Dict[str, str]:
"""
Get current line colors.
Returns:
Dictionary {label: hex_color}
Raises:
RuntimeError: If load() has not been called
"""
if not self.axes:
raise RuntimeError("You must call load() first")
legend = self.axes.get_legend()
if not legend:
return {}
lines = self.axes.get_lines()
colors = {}
for line, text in zip(lines, legend.get_texts()):
label = text.get_text()
color = mpl.colors.to_hex(line.get_color())
colors[label] = color
return colors
[docs]
def rename_legend_labels(self, mapping: Dict[str, str]) -> int:
"""
Rename specific legend labels.
Args:
mapping (Dict[str, str]): Map of {old_name: new_name}.
Returns:
int: Number of labels successfully renamed.
"""
if not self.axes:
raise RuntimeError("You must call load() first")
legend = self.axes.get_legend()
if not legend:
return 0
changed = 0
for text in legend.get_texts():
current_label = text.get_text()
if current_label in mapping:
text.set_text(mapping[current_label])
changed += 1
self._modifications_made = True
return changed
[docs]
def change_line_colors(self, label_color_map: Dict[str, str]) -> int:
"""
Update the color of lines associated with specific legend labels.
Args:
label_color_map (Dict[str, str]): Map of {label: hex_color/name}.
Returns:
int: Number of lines updated.
"""
if not self.axes:
raise RuntimeError("You must call load() first")
legend = self.axes.get_legend()
if not legend:
return 0
lines = self.axes.get_lines()
legend_texts = legend.get_texts()
legend_lines = legend.get_lines()
changed = 0
for line, leg_line, text in zip(lines, legend_lines, legend_texts):
label = text.get_text()
if label in label_color_map:
color = label_color_map[label]
try:
# Update both the plot line and legend line
line.set_color(color)
leg_line.set_color(color)
changed += 1
self._modifications_made = True
except ValueError as e:
logger.warning(f"Invalid color '{color}' for '{label}': {e}")
return changed
[docs]
def change_line_linestyle(self, style_map: Dict[str, str]) -> int:
"""
Update the line style (e.g., solid, dashed) for specific labels.
Args:
style_map (Dict[str, str]): Map of {label: style_string} (e.g. '--').
Returns:
int: Number of lines updated.
"""
if not self.axes:
raise RuntimeError("You must call load() first")
legend = self.axes.get_legend()
if not legend:
return 0
lines = self.axes.get_lines()
legend_texts = legend.get_texts()
legend_lines = legend.get_lines()
changed = 0
for line, leg_line, text in zip(lines, legend_lines, legend_texts):
label = text.get_text()
if label in style_map:
style = style_map[label]
try:
# Update both the plot line and legend line
line.set_linestyle(style)
leg_line.set_linestyle(style)
changed += 1
self._modifications_made = True
except Exception as e:
logger.warning(f"Invalid style '{style}' for '{label}': {e}")
return changed
[docs]
def change_line_linewidth(self, width_map: Dict[str, float]) -> int:
"""
Change line widths and legend widths.
Args:
width_map: Dictionary {label: width}
Returns:
Number of widths changed
Raises:
RuntimeError: If load() has not been called
"""
if not self.axes:
raise RuntimeError("You must call load() first")
legend = self.axes.get_legend()
if not legend:
return 0
lines = self.axes.get_lines()
legend_texts = legend.get_texts()
legend_lines = legend.get_lines()
changed = 0
for line, leg_line, text in zip(lines, legend_lines, legend_texts):
label = text.get_text()
if label in width_map:
width = width_map[label]
try:
# Update both the plot line and legend line
line.set_linewidth(width)
leg_line.set_linewidth(width)
changed += 1
self._modifications_made = True
except Exception as e:
logger.warning(f"Invalid width '{width}' for '{label}': {e}")
return changed
[docs]
def change_line_alpha(self, alpha_map: Dict[str, float]) -> int:
"""
Change line transparency and legend transparency.
Args:
alpha_map: Dictionary {label: alpha} (0-1)
Returns:
Number of alpha values changed
Raises:
RuntimeError: If load() has not been called
"""
if not self.axes:
raise RuntimeError("You must call load() first")
legend = self.axes.get_legend()
if not legend:
return 0
lines = self.axes.get_lines()
legend_texts = legend.get_texts()
legend_lines = legend.get_lines()
changed = 0
for line, leg_line, text in zip(lines, legend_lines, legend_texts):
label = text.get_text()
if label in alpha_map:
alpha = alpha_map[label]
try:
if not 0 <= alpha <= 1:
logger.warning(f"Alpha must be between 0 and 1, received {alpha}")
continue
# Update both the plot line and legend line
line.set_alpha(alpha)
leg_line.set_alpha(alpha)
changed += 1
self._modifications_made = True
except Exception as e:
logger.warning(f"Invalid alpha '{alpha}' for '{label}': {e}")
return changed
[docs]
def change_line_visibility(self, visibility_map: Dict[str, bool]) -> int:
"""
Change line visibility (show/hide).
Args:
visibility_map: Dictionary {label: bool} (True=visible, False=hidden)
Returns:
Number of lines changed
Raises:
RuntimeError: If load() has not been called
"""
if not self.axes:
raise RuntimeError("You must call load() first")
legend = self.axes.get_legend()
if not legend:
return 0
lines = self.axes.get_lines()
legend_texts = legend.get_texts()
legend_lines = legend.get_lines()
changed = 0
for line, leg_line, text in zip(lines, legend_lines, legend_texts):
label = text.get_text()
if label in visibility_map:
visible = visibility_map[label]
try:
# Update plot line visibility
line.set_visible(visible)
# Update legend handle visibility
leg_line.set_visible(visible)
# Dim the legend text if hidden, restore if visible
text.set_alpha(1.0 if visible else 0.5)
changed += 1
self._modifications_made = True
except Exception as e:
logger.warning(f"Invalid visibility '{visible}' for '{label}': {e}")
return changed
[docs]
def save(self, output_path: Optional[Path] = None,
format: str = 'pickle') -> Path:
"""
Serialize the modified figure to disk.
Args:
output_path (Optional[Path]): Destination file. If None, overwrites original.
format (str): Output format ('pickle', 'png', 'pdf', 'svg').
Returns:
Path: The actual path of the saved file.
"""
if not self.figure:
raise RuntimeError("You must call load() first")
if output_path is None:
if format == 'pickle':
output_path = self.pickle_path
else:
output_path = self.pickle_path.with_suffix(f'.{format}')
if format == 'pickle':
with open(output_path, 'wb') as f:
pickle.dump(self.figure, f, protocol=pickle.HIGHEST_PROTOCOL)
else:
self.figure.savefig(output_path, format=format, dpi=300,
bbox_inches='tight')
self._modifications_made = False
return output_path
[docs]
def preview(self) -> None:
"""
Display a copy of the current figure for preview.
Raises:
RuntimeError: If load() has not been called
"""
if not self.figure:
raise RuntimeError("You must call load() first")
plt.show()
[docs]
def has_modifications(self) -> bool:
"""
Check if there are unsaved modifications.
Returns:
True if there are unsaved modifications
"""
return self._modifications_made