# total-classifier / modeling.py
# Uploaded by ianpan ("Upload model", commit dbdbb0d, verified); file size 20.2 kB.
import cv2
import glob
import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers import PreTrainedModel
from timm import create_model
from .configuration import TotalClassifierConfig
from .label2index import label2index
# Optional dependencies. pydicom is needed by the DICOM loading helpers and
# pandas by `return_as_df=True`; each flag records whether its import
# succeeded so those code paths can fail with a clear message instead of a
# NameError.
_PYDICOM_AVAILABLE = False
try:
    from pydicom import dcmread

    _PYDICOM_AVAILABLE = True
except ModuleNotFoundError:
    pass
_PANDAS_AVAILABLE = False
try:
    import pandas as pd

    _PANDAS_AVAILABLE = True
except ModuleNotFoundError:
    pass
class RNNHead(nn.Module):
def __init__(
self,
rnn_type: str,
rnn_num_layers: int,
rnn_dropout: float,
feature_dim: int,
linear_dropout: float,
num_classes: int,
):
super().__init__()
self.rnn = getattr(nn, rnn_type)(
input_size=feature_dim,
hidden_size=feature_dim // 2,
num_layers=rnn_num_layers,
dropout=rnn_dropout,
batch_first=True,
bidirectional=True,
)
self.dropout = nn.Dropout(linear_dropout)
self.linear = nn.Linear(feature_dim, num_classes)
@staticmethod
def convert_seq_and_mask_to_packed_sequence(
seq: torch.Tensor, mask: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
assert seq.shape[0] == mask.shape[0]
lengths = mask.sum(1)
seq = nn.utils.rnn.pack_padded_sequence(
seq, lengths.cpu().int(), batch_first=True, enforce_sorted=False
)
return seq
def forward(
self, x: torch.Tensor, mask: torch.Tensor | None = None
) -> torch.Tensor:
skip = x
if mask is not None:
# convert to PackedSequence
L = x.shape[1]
x = self.convert_seq_and_mask_to_packed_sequence(x, mask)
x, _ = self.rnn(x)
if mask is not None:
# convert back to tensor
x = nn.utils.rnn.pad_packed_sequence(x, batch_first=True, total_length=L)[0]
x = x + skip
return self.linear(self.dropout(x))
class TotalClassifierModel(PreTrainedModel):
    """Per-slice multi-label classifier for CT image stacks.

    A timm CNN backbone encodes each 2D slice into a feature vector. An
    ``RNNHead`` then classifies every slice in the context of the whole stack.
    The class also bundles helpers for loading DICOM data, CT windowing,
    preprocessing and cropping a volume to the region where given organs are
    predicted.
    """

    config_class = TotalClassifierConfig

    def __init__(self, config):
        super().__init__(config)
        self.image_size = config.image_size
        self.backbone = create_model(
            model_name=config.backbone,
            pretrained=False,
            num_classes=0,
            global_pool="",
            features_only=True,
            in_chans=config.in_chans,
        )
        self.cnn_dropout = nn.Dropout(p=config.cnn_dropout)
        self.head = RNNHead(
            rnn_type=config.rnn_type,
            rnn_num_layers=config.rnn_num_layers,
            rnn_dropout=config.rnn_dropout,
            feature_dim=config.feature_dim,
            linear_dropout=config.linear_dropout,
            num_classes=config.num_classes,
        )
        # label name <-> class index, used to decode outputs and look up organs
        self.label2index = label2index
        self.index2label = {v: k for k, v in self.label2index.items()}

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor | None = None,
        return_logits: bool = False,
        return_as_dict: bool = False,
        return_as_list: bool = False,
        return_as_df: bool = False,
        threshold: float = 0.5,  # only used for return_as_list=True
    ) -> torch.Tensor | list:
        """Classify every slice of a batch of image stacks.

        Args:
            x: Stacks of shape ``(b, n, c, h, w)``. Pixel values are expected
                in ``[0, 255]`` and are rescaled by ``normalize``.
            mask: Optional ``(b, n)`` mask of valid slices, forwarded to the
                RNN head.
            return_logits: Return raw logits ``(b, n, num_classes)``. This takes
                precedence over every other ``return_*`` flag.
            return_as_dict: Return a list with one dict per study, mapping each
                label name to its ``(n,)`` probability tensor.
            return_as_list: Return, per study, one list per slice containing
                the label names whose probability is ``>= threshold``.
            return_as_df: Like ``return_as_dict``, but each item is a pandas
                DataFrame (one row per slice, one column per label). Requires
                pandas.
            threshold: Probability cutoff; only used with ``return_as_list``.

        Returns:
            Sigmoid probabilities ``(b, n, num_classes)``, unless one of the
            flags above selects another format.
        """
        if return_as_df:
            assert (
                _PANDAS_AVAILABLE
            ), "`return_as_df=True` requires pandas to be installed"
        b, n, c, h, w = x.shape
        # fold slices into the batch axis so the 2D backbone sees every slice
        x = x.reshape(b * n, c, h, w)
        x = self.normalize(x)
        features = self.backbone(x)
        # global average pool of the last feature map -> (b * n, d)
        features = F.adaptive_avg_pool2d(features[-1], 1).flatten(1)
        features = self.cnn_dropout(features)
        features = features.reshape(b, n, -1)
        logits = self.head(features, mask=mask)
        if return_logits:
            return logits
        probas = logits.sigmoid()
        if return_as_dict or return_as_df:
            batch_list = []
            for probas_i in probas:
                # probas_i.shape = (num_slices, num_classes)
                dict_for_batch = {
                    self.index2label[each_class]: probas_i[:, each_class]
                    for each_class in range(probas_i.shape[1])
                }
                if return_as_df:
                    batch_list.append(
                        pd.DataFrame(
                            {
                                k: v.detach().cpu().numpy()
                                for k, v in dict_for_batch.items()
                            }
                        )
                    )
                else:
                    batch_list.append(dict_for_batch)
            return batch_list
        if return_as_list:
            # outer list: one entry per study in the batch
            # middle list: one entry per slice
            # inner list: names of labels with probability >= threshold
            batch_list = []
            for probas_i in probas:
                present = (probas_i >= threshold).tolist()
                batch_list.append(
                    [
                        [
                            self.index2label[each_class]
                            for each_class, is_present in enumerate(slice_flags)
                            if is_present
                        ]
                        for slice_flags in present
                    ]
                )
            return batch_list
        return probas

    def normalize(self, x: torch.Tensor) -> torch.Tensor:
        """Linearly map pixel values from ``[0, 255]`` to ``[-1, 1]``."""
        mini, maxi = 0.0, 255.0
        x = (x - mini) / (maxi - mini)
        x = (x - 0.5) * 2.0
        return x

    @staticmethod
    def window(x: np.ndarray, WL: int, WW: int) -> np.ndarray[np.uint8]:
        """Apply a CT window and scale to uint8.

        Values are clipped to ``[WL - WW // 2, WL + WW // 2]`` and mapped
        linearly to ``[0, 255]``. ``WW`` must be at least 2; otherwise the
        window has zero width and the division is by zero.
        """
        lower, upper = WL - WW // 2, WL + WW // 2
        x = np.clip(x, lower, upper)
        x = (x - lower) / (upper - lower)
        return (x * 255.0).astype("uint8")

    @staticmethod
    def validate_windows_type(windows):
        """Assert ``windows`` is one ``(WL, WW)`` integer pair or a list of such tuples."""
        int_types = (int, np.integer)
        assert isinstance(windows, (tuple, list)), "windows must be a tuple or list"
        if isinstance(windows, tuple):
            assert len(windows) == 2, "a window must be a (WL, WW) pair"
            assert all(
                isinstance(v, int_types) for v in windows
            ), "window values must be integers"
        else:
            assert all(isinstance(w, tuple) for w in windows), "windows must be tuples"
            assert all(len(w) == 2 for w in windows), "each window must be a (WL, WW) pair"
            assert all(
                isinstance(v, int_types) for w in windows for v in w
            ), "window values must be integers"

    @staticmethod
    def determine_dicom_orientation(ds) -> int:
        """Return the axis of the slice normal: 0 sagittal, 1 coronal, 2 axial.

        The normal is the cross product of the row and column direction
        cosines in ``ImageOrientationPatient``. Its largest absolute component
        picks the plane, and ties fall through to axial.
        """
        iop = ds.ImageOrientationPatient
        normal_vector = np.cross(iop[:3], iop[3:])
        abs_normal = np.abs(normal_vector)
        if abs_normal[0] > abs_normal[1] and abs_normal[0] > abs_normal[2]:
            return 0  # sagittal
        elif abs_normal[1] > abs_normal[0] and abs_normal[1] > abs_normal[2]:
            return 1  # coronal
        else:
            return 2  # axial

    def _apply_windows(
        self,
        array: np.ndarray,
        windows: tuple[int, int] | list[tuple[int, int]],
    ) -> np.ndarray:
        """Window ``array`` once per ``(WL, WW)`` pair, stacking results on a new last axis.

        A single window produces no extra axis (it is squeezed away).
        """
        self.validate_windows_type(windows)
        if isinstance(windows, tuple):
            windows = [windows]
        # `window` returns a new array, so the input is never modified
        stacked = np.stack([self.window(array, WL, WW) for WL, WW in windows], axis=-1)
        if stacked.shape[-1] == 1:
            stacked = np.squeeze(stacked, axis=-1)
        return stacked

    def load_image_from_dicom(
        self, path: str, windows: tuple[int, int] | list[tuple[int, int]] | None = None
    ) -> np.ndarray:
        """Load one DICOM file, rescaled with RescaleSlope/RescaleIntercept.

        Args:
            path: DICOM file path.
            windows: ``None`` to return the rescaled float32 array, a
                ``(WINDOW_LEVEL, WINDOW_WIDTH)`` tuple for a single uint8 image,
                or a list of tuples for a multi-channel image with one channel
                per window.

        Raises:
            ImportError: if pydicom is not installed.
        """
        if not _PYDICOM_AVAILABLE:
            raise ImportError("`pydicom` is not installed")
        dicom = dcmread(path)
        array = dicom.pixel_array.astype("float32")
        m, b = float(dicom.RescaleSlope), float(dicom.RescaleIntercept)
        array = array * m + b
        if windows is None:
            return array
        return self._apply_windows(array, windows)

    @staticmethod
    def is_valid_dicom(
        ds,
        fname: str = "",
        sort_by_instance_number: bool = False,
        exclude_invalid_dicoms: bool = False,
    ) -> bool:
        """Check that ``ds`` has every attribute the stack loader reads.

        The required attributes are the pixels, the rescale parameters,
        ``ImageOrientationPatient`` (always needed to determine orientation),
        and the sort key: ``InstanceNumber`` or ``ImagePositionPatient``.

        Raises:
            ValueError: if attributes are missing and ``exclude_invalid_dicoms``
                is False. Otherwise the function just returns False.
        """
        attributes = ["pixel_array", "RescaleSlope", "RescaleIntercept"]
        if sort_by_instance_number:
            attributes.append("InstanceNumber")
        else:
            attributes.append("ImagePositionPatient")
        attributes.append("ImageOrientationPatient")
        missing = [attr for attr in attributes if not hasattr(ds, attr)]
        valid = len(missing) == 0
        if not valid and not exclude_invalid_dicoms:
            raise ValueError(
                f"invalid DICOM file [{fname}]: missing attributes: {missing}"
            )
        return valid

    @staticmethod
    def most_common_element(lst):
        """Return the most frequent element of ``lst`` (ties are broken arbitrarily)."""
        return max(set(lst), key=lst.count)

    @staticmethod
    def center_crop_or_pad_borders(image, size):
        """Center-crop and/or zero-pad the first two axes of ``image`` to ``size``.

        Args:
            image: Array whose first two axes are (height, width). Any trailing
                axes (e.g. channels) are left untouched.
            size: Target ``(height, width)``.
        """
        height, width = image.shape[:2]
        new_height, new_width = size
        # no padding for axes beyond height/width
        trailing = [(0, 0)] * (image.ndim - 2)
        if new_height < height:
            crop_top = (height - new_height) // 2
            image = image[crop_top : crop_top + new_height]
        elif new_height > height:
            pad_top = (new_height - height) // 2
            pad_bottom = new_height - height - pad_top
            image = np.pad(
                image,
                [(pad_top, pad_bottom), (0, 0)] + trailing,
                mode="constant",
                constant_values=0,
            )
        if new_width < width:
            crop_left = (width - new_width) // 2
            image = image[:, crop_left : crop_left + new_width]
        elif new_width > width:
            pad_left = (new_width - width) // 2
            pad_right = new_width - width - pad_left
            image = np.pad(
                image,
                [(0, 0), (pad_left, pad_right)] + trailing,
                mode="constant",
                constant_values=0,
            )
        return image

    def load_stack_from_dicom_folder(
        self,
        path: str,
        windows: tuple[int, int] | list[tuple[int, int]] | None = None,
        dicom_extension: str = ".dcm",
        sort_by_instance_number: bool = False,
        exclude_invalid_dicoms: bool = False,
        fix_unequal_shapes: str = "crop_pad",
        return_sorted_dicom_files: bool = False,
    ) -> np.ndarray | tuple[np.ndarray, list[str]]:
        """Load a folder of DICOM slices as one sorted, rescaled volume.

        Args:
            path: Folder containing the slices.
            windows: Same meaning as in ``load_image_from_dicom``.
            dicom_extension: Suffix used to glob slice files.
            sort_by_instance_number: Sort by ``InstanceNumber`` instead of the
                ``ImagePositionPatient`` component along the most common
                slice normal.
            exclude_invalid_dicoms: Skip files missing required attributes
                instead of raising.
            fix_unequal_shapes: ``"crop_pad"`` or ``"resize"``. Controls how
                slices that differ from the most common shape are conformed to it.
            return_sorted_dicom_files: Also return the file paths in slice order.

        Raises:
            ImportError: if pydicom is not installed.
            FileNotFoundError: if no files match ``dicom_extension``.
            ValueError: if no valid files remain, or ``fix_unequal_shapes`` is
                unknown while shapes differ.
        """
        if not _PYDICOM_AVAILABLE:
            raise ImportError("`pydicom` is not installed")
        # sorted so ties in slice position resolve in a deterministic order
        dicom_files = sorted(glob.glob(os.path.join(path, f"*{dicom_extension}")))
        if len(dicom_files) == 0:
            raise FileNotFoundError(
                f"No DICOM files found in `{path}` using `dicom_extension={dicom_extension}`"
            )
        dicoms = [dcmread(f) for f in dicom_files]
        # keep datasets and filenames paired so return_sorted_dicom_files only
        # lists valid files when exclude_invalid_dicoms=True
        pairs = [
            (d, f)
            for d, f in zip(dicoms, dicom_files)
            if self.is_valid_dicom(
                d, f, sort_by_instance_number, exclude_invalid_dicoms
            )
        ]
        if len(pairs) == 0:
            raise ValueError(f"No valid DICOM files found in `{path}`")
        dicoms = [d for d, _ in pairs]
        dicom_files = [f for _, f in pairs]
        slices = [dcm.pixel_array.astype("float32") for dcm in dicoms]
        shapes = np.stack([s.shape for s in slices], axis=0)
        if not np.all(shapes == shapes[0]):
            unique_shapes, counts = np.unique(shapes, axis=0, return_counts=True)
            standard_shape = tuple(int(v) for v in unique_shapes[np.argmax(counts)])
            print(
                f"warning: different array shapes present, using {fix_unequal_shapes} -> {standard_shape}"
            )
            if fix_unequal_shapes == "crop_pad":
                slices = [
                    self.center_crop_or_pad_borders(s, standard_shape)
                    if s.shape != standard_shape
                    else s
                    for s in slices
                ]
            elif fix_unequal_shapes == "resize":
                # cv2.resize expects dsize as (width, height)
                dsize = (standard_shape[1], standard_shape[0])
                slices = [
                    cv2.resize(s, dsize) if s.shape != standard_shape else s
                    for s in slices
                ]
            else:
                raise ValueError(
                    f"`fix_unequal_shapes` must be 'crop_pad' or 'resize', got {fix_unequal_shapes!r}"
                )
        slices = np.stack(slices, axis=0)
        # the most common slice-normal axis is the ImagePositionPatient
        # component used for sorting
        orientation = self.most_common_element(
            [self.determine_dicom_orientation(dcm) for dcm in dicoms]
        )
        if sort_by_instance_number:
            positions = [float(d.InstanceNumber) for d in dicoms]
        else:
            positions = [float(d.ImagePositionPatient[orientation]) for d in dicoms]
        indices = np.argsort(positions)
        slices = slices[indices]
        # one rescale for the whole stack, using the most common slope/intercept
        m = self.most_common_element([float(d.RescaleSlope) for d in dicoms])
        b = self.most_common_element([float(d.RescaleIntercept) for d in dicoms])
        slices = slices * m + b
        if windows is not None:
            slices = self._apply_windows(slices, windows)
        if return_sorted_dicom_files:
            return slices, [dicom_files[idx] for idx in indices]
        return slices

    def preprocess(
        self,
        x: np.ndarray,
        mode: str = "2d",
        torchify: bool = True,
        add_batch_dim: bool = False,
        device: str | torch.device | None = None,
    ) -> np.ndarray | torch.Tensor:
        """Resize an image or slice stack to ``self.image_size``, optionally as a tensor.

        Args:
            x: For ``mode="2d"``, one image ``(h, w)`` or ``(h, w, c)``. For
                ``mode="3d"``, a stack ``(n, h, w)`` or ``(n, h, w, c)`` whose
                slices are resized one at a time.
            mode: ``"2d"`` or ``"3d"`` (case-insensitive). Any other value skips
                resizing.
            torchify: Return a float tensor with channels first: ``(c, h, w)``
                or ``(n, c, h, w)``.
            add_batch_dim: Prepend a batch axis of size 1.
            device: Move the tensor to this device (requires ``torchify=True``).

        NOTE(review): ``cv2.resize`` interprets ``self.image_size`` as
        ``(width, height)``; confirm this matches the config's ordering.
        """
        if device is not None:
            assert torchify, "`torchify` must be `True` if specifying `device`"
        mode = mode.lower()
        if mode == "2d":
            x = cv2.resize(x, self.image_size)
            if x.ndim == 2:
                x = x[:, :, np.newaxis]
        elif mode == "3d":
            x = np.stack([cv2.resize(s, self.image_size) for s in x], axis=0)
            if x.ndim == 3:
                x = x[:, :, :, np.newaxis]
        if torchify:
            if x.ndim == 3:
                x = rearrange(torch.from_numpy(x).float(), "h w c -> c h w")
            elif x.ndim == 4:
                x = rearrange(torch.from_numpy(x).float(), "n h w c -> n c h w")
        if add_batch_dim:
            if torchify:
                x = x.unsqueeze(0)
            else:
                x = x[np.newaxis]
        if device is not None:
            x = x.to(device)
        return x

    def crop_single_plane(
        self,
        x: np.ndarray,
        device: str | torch.device,
        organ: str | list[str],
        threshold: float = 0.5,
        buffer: float | int = 0,
        speed_up: str | None = None,
    ) -> tuple[int, int]:
        """Find the inclusive slice range along axis 0 where ``organ`` is predicted.

        Args:
            x: Stack ``(n, h, w)`` or ``(n, h, w, c)``. Multi-channel input is
                averaged to a single channel before inference.
            device: Device used for inference.
            organ: Label name or list of names. The range covers every slice
                where any of them reaches ``threshold``.
            threshold: Probability cutoff.
            buffer: Margin added on both sides: an int is a number of slices,
                a float is a fraction of the detected extent.
            speed_up: ``None``, ``"fast"``, ``"faster"`` or ``"fastest"``.
                Classifies about 75%, 50% or 33% of slices and linearly
                interpolates the predictions back to every slice.

        Returns:
            ``(slice_min, slice_max)``, both inclusive and clamped to ``[0, n - 1]``.

        Raises:
            RuntimeError: if no slice reaches ``threshold`` for the organs.
        """
        if isinstance(organ, str):
            organ = [organ]
        num_slices = x.shape[0]
        if speed_up is not None:
            assert speed_up in ["fast", "faster", "fastest"]
            if speed_up == "fast":
                reduce_num_slices = 3 * num_slices // 4  # 75% of slices
            elif speed_up == "faster":
                reduce_num_slices = num_slices // 2  # 50% of slices
            else:
                reduce_num_slices = num_slices // 3  # 33% of slices
            # always classify at least one slice
            reduce_num_slices = max(1, reduce_num_slices)
            indices = np.linspace(0, num_slices - 1, reduce_num_slices).astype(int)
            x = x[indices]
        # -> float tensor (n, c, h, w) on `device`, then add batch dim
        x = self.preprocess(x, mode="3d", torchify=True, device=device)
        x = x.unsqueeze(0)
        if x.size(2) > 1:
            # if multi-channel, take mean
            x = x.mean(2, keepdim=True)
        organ_cls = self.forward(x)[0]
        if speed_up is not None:
            # (reduced_slices, num_classes) -> (num_slices, num_classes)
            organ_cls = (
                F.interpolate(
                    organ_cls.transpose(1, 0).unsqueeze(0),
                    size=(num_slices,),
                    mode="linear",
                )
                .squeeze(0)
                .transpose(1, 0)
            )
        assert organ_cls.shape[0] == num_slices
        slices = torch.cat(
            [
                torch.where(organ_cls[:, self.label2index[each_organ]] >= threshold)[0]
                for each_organ in organ
            ]
        )
        if slices.numel() == 0:
            raise RuntimeError(
                f"no slices found for organ(s) {organ} at threshold={threshold}"
            )
        slice_min, slice_max = slices.min().item(), slices.max().item()
        if buffer > 0:
            if isinstance(buffer, float):
                # buffer as a fraction of the detected extent
                buf = int(buffer * (slice_max - slice_min))
            else:
                # absolute slice buffer
                buf = buffer
            slice_min = max(0, slice_min - buf)
            slice_max = min(num_slices - 1, slice_max + buf)
        return slice_min, slice_max

    @torch.no_grad()
    def crop(
        self,
        x: np.ndarray,
        organ: str | list[str],
        crop_dims: int | list[int] = 0,
        device: str | torch.device | None = None,
        raw_hu: bool = False,
        threshold: float = 0.5,
        buffer: float | int = 0,
        speed_up: str | None = None,
    ) -> np.ndarray:
        """Crop a volume to the region where the requested organs are predicted.

        Args:
            x: 3D ``(d0, d1, d2)`` or 4D ``(d0, d1, d2, c)`` array. Pixel values
                are expected in ``[0, 255]`` unless ``raw_hu=True``.
            organ: Label name or list of names. The crop spans the union of the
                slices where any of them is predicted.
            crop_dims: Axis or axes (0, 1 and/or 2) to crop along. For each axis,
                the volume is treated as a stack of slices along that axis and
                classified.
            device: Inference device. Defaults to CUDA when available, else CPU.
            raw_hu: Treat ``x`` as Hounsfield units and first apply a soft tissue
                window (WL=50, WW=400). The returned crop is then the windowed
                uint8 volume.
            threshold: Probability cutoff for an organ to count as present.
            buffer: Margin per cropped axis: an int is a number of slices, a
                float in ``[0, 1)`` is a fraction of the detected extent.
            speed_up: See ``crop_single_plane``.

        Returns:
            ``x`` (possibly windowed) sliced to the detected range on each
            axis in ``crop_dims``.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        assert isinstance(x, np.ndarray)
        assert x.ndim in {
            3,
            4,
        }, f"x should be a 3D or 4D array, but got {x.ndim} dimensions"
        if raw_hu:
            # if input is in Hounsfield units, apply soft tissue window
            x = self.window(x, WL=50, WW=400)
        x0 = x
        if not isinstance(organ, list):
            organ = [organ]
        if not isinstance(crop_dims, list):
            crop_dims = [crop_dims]
        assert max(crop_dims) <= 2
        assert min(crop_dims) >= 0
        if isinstance(buffer, float):
            # percentage of cropped axis dimension
            assert buffer < 1
        # (min, max) inclusive bounds per axis; uncropped axes keep full extent
        bounds = []
        for axis in range(3):
            if axis in crop_dims:
                # bring `axis` to the front so its slices are classified
                plane = x0 if axis == 0 else x0.swapaxes(axis, 0)
                bounds.append(
                    self.crop_single_plane(
                        plane, device, organ, threshold, buffer, speed_up
                    )
                )
            else:
                bounds.append((0, x0.shape[axis] - 1))
        (smin0, smax0), (smin1, smax1), (smin2, smax2) = bounds
        return x0[smin0 : smax0 + 1, smin1 : smax1 + 1, smin2 : smax2 + 1]