ray.data.preprocessors.TorchVisionPreprocessor
- class ray.data.preprocessors.TorchVisionPreprocessor(columns: List[str], transform: Callable[[Union[np.ndarray, torch.Tensor]], torch.Tensor], output_columns: Optional[List[str]] = None, batched: bool = False)
Bases: ray.data.preprocessor.Preprocessor

Apply a TorchVision transform to image columns.
Examples
Torch models expect inputs of shape \((B, C, H, W)\) in the range \([0.0, 1.0]\). To convert images to this format, add ToTensor to your preprocessing pipeline.

```python
from torchvision import transforms

import ray
from ray.data.preprocessors import TorchVisionPreprocessor

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224)),
])
preprocessor = TorchVisionPreprocessor(["image"], transform=transform)

dataset = ray.data.read_images("s3://[email protected]/imagenet-sample-images")
dataset = preprocessor.transform(dataset)
```
For better performance, set batched to True and replace ToTensor with a batch-supporting Lambda.

```python
import numpy as np
import torch
from torchvision import transforms

import ray
from ray.data.preprocessors import TorchVisionPreprocessor

def to_tensor(batch: np.ndarray) -> torch.Tensor:
    tensor = torch.as_tensor(batch, dtype=torch.float)
    # (B, H, W, C) -> (B, C, H, W)
    tensor = tensor.permute(0, 3, 1, 2).contiguous()
    # [0., 255.] -> [0., 1.]
    tensor = tensor.div(255)
    return tensor

transform = transforms.Compose([
    transforms.Lambda(to_tensor),
    transforms.Resize((224, 224)),
])
preprocessor = TorchVisionPreprocessor(["image"], transform=transform, batched=True)

dataset = ray.data.read_images("s3://[email protected]/imagenet-sample-images")
dataset = preprocessor.transform(dataset)
```
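The layout and range conversion that to_tensor performs can be checked in isolation. The following is a minimal sketch that uses NumPy as a stand-in for torch; the helper name to_channels_first_unit_range and the random sample batch are illustrative assumptions, not part of Ray or TorchVision.

```python
import numpy as np


def to_channels_first_unit_range(batch: np.ndarray) -> np.ndarray:
    """Convert a (B, H, W, C) uint8 image batch to (B, C, H, W) floats in [0, 1].

    Hypothetical helper for illustration; it mirrors the two steps in to_tensor.
    """
    if batch.ndim != 4:
        raise ValueError(f"expected a 4-D batch, got shape {batch.shape}")
    # (B, H, W, C) -> (B, C, H, W)
    channels_first = np.transpose(batch, (0, 3, 1, 2))
    # [0, 255] -> [0.0, 1.0]
    return np.ascontiguousarray(channels_first, dtype=np.float32) / 255.0


# Made-up sample data: 2 images, 4 pixels high, 6 pixels wide, 3 channels.
batch = np.random.default_rng(0).integers(0, 256, size=(2, 4, 6, 3), dtype=np.uint8)
converted = to_channels_first_unit_range(batch)
print(converted.shape)  # (2, 3, 4, 6)
```

The channel axis moves from last to second position, so the value at converted[b, c, h, w] comes from batch[b, h, w, c].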
- Parameters
columns – The columns to apply the TorchVision transform to.
transform – The TorchVision transform you want to apply. This transform should accept a np.ndarray or torch.Tensor as input and return a torch.Tensor as output.
output_columns – The output name for each input column. If not specified, this defaults to the names given in columns, so each transformed column replaces its input.
batched – If True, apply transform to batches of shape \((B, H, W, C)\). Otherwise, apply transform to individual images.
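To make the interaction between output_columns and batched concrete, here is a rough sketch of the semantics described above. It is not Ray's implementation: apply_to_columns and scale are hypothetical names, and NumPy arrays stand in for the torch tensors a real transform would return.

```python
from typing import Callable, Dict, List, Optional

import numpy as np


def apply_to_columns(
    batch: Dict[str, np.ndarray],
    columns: List[str],
    transform: Callable[[np.ndarray], np.ndarray],
    output_columns: Optional[List[str]] = None,
    batched: bool = False,
) -> Dict[str, np.ndarray]:
    """Hypothetical helper mirroring the documented parameter semantics."""
    # output_columns defaults to the input column names.
    output_columns = columns if output_columns is None else output_columns
    if len(output_columns) != len(columns):
        raise ValueError("output_columns must have one name per input column")

    result = dict(batch)
    for src, dst in zip(columns, output_columns):
        images = batch[src]
        if batched:
            # The transform sees the whole (B, H, W, C) batch at once.
            result[dst] = transform(images)
        else:
            # The transform sees one (H, W, C) image at a time.
            result[dst] = np.stack([transform(image) for image in images])
    return result


def scale(x: np.ndarray) -> np.ndarray:
    """Illustrative transform: [0, 255] -> [0.0, 1.0]."""
    return x.astype(np.float32) / 255.0


batch = {"image": np.full((2, 4, 4, 3), 255, dtype=np.uint8)}
# Per-image mode, writing to a new column and keeping the original.
per_image = apply_to_columns(batch, ["image"], scale, output_columns=["scaled"])
# Batched mode, overwriting the input column (the default).
whole_batch = apply_to_columns(batch, ["image"], scale, batched=True)
```

In this sketch, passing output_columns keeps the original image column next to the transformed one, while omitting it overwrites the input column in place.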
PublicAPI (alpha): This API is in alpha and may change before becoming stable.
Methods
fit(ds) – Fit this Preprocessor to the Dataset.
fit_transform(ds) – Fit this Preprocessor to the Dataset and then transform the Dataset.
transform(ds) – Transform the given dataset.
transform_batch(data) – Transform a single batch of data.
Return Dataset stats for the most recent transform call, if any.