Porting a Legacy Remote Sensing CNN

From Theano/Lasagne to PyTorch: A Practical Case Study

Gabriel Oduori

Why This Matters

Theano development ended in 2017. Lasagne is unmaintained.

Many published papers in remote sensing, medical imaging, and signal processing ship code that no longer runs on a current Python stack.

Porting gives you:

  • Reproducibility — results others can actually verify
  • Extensibility — build on top of the work
  • Understanding — you cannot port what you do not understand

Note

The hardest bugs in porting are not in the architecture — they are in the data pipeline and normalisation.

The Paper — What Are We Porting?

Scarpa et al. (2018). A CNN-Based Fusion Method for Feature Extraction from Sentinel Data. Remote Sensing 10(2), 236.

The problem: optical satellites cannot see through clouds.

  t−1  →  cloud-free NDVI image   F₋     ✓
  t    →  CLOUDY — NDVI unknown   F = ?  ✗
  t+1  →  cloud-free NDVI image   F₊     ✓

SAR (radar) satellites penetrate clouds. They image the surface at all three dates regardless of weather.

Goal: use the SAR image at t — which we have — to reconstruct the NDVI image at t — which clouds stole from us.

Why SAR Helps — The Harvest Example

Suppose a crop field is harvested between t−1 and t+1:

Date   NDVI   SAR (VH)   Meaning
──────────────────────────────────────────────────────────────
t−1    0.60   0.12       Dense crop
t      ?      0.01       Harvest happened: near-zero VH = bare soil
t+1    0.10   0.02       Bare soil

Linear interpolation (baseline):

\[\hat{F}(t) = \frac{F_- + F_+}{2} = \frac{0.60 + 0.10}{2} = 0.35 \quad \text{✗ wrong}\]

CNN fusion sees VH = 0.01 at t → bare soil signature → predicts ≈ 0.10 ✓
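The baseline arithmetic is easy to check in a few lines. This sketch uses the illustrative numbers from the harvest table, not measured data; `linear_interp_baseline` is a hypothetical helper name, and the true value at t is assumed to be about 0.10 (bare soil, as at t+1).

```python
def linear_interp_baseline(f_minus: float, f_plus: float) -> float:
    """Temporal linear interpolation at the midpoint date t: F̂(t) = (F₋ + F₊) / 2."""
    return 0.5 * (f_minus + f_plus)

# Illustrative values from the harvest example (not real measurements)
f_minus, f_plus = 0.60, 0.10
f_true = 0.10  # assumed: bare soil after harvest, like t+1

f_hat = linear_interp_baseline(f_minus, f_plus)
abs_error = abs(f_hat - f_true)
print(f"baseline = {f_hat:.2f}, absolute error = {abs_error:.2f}")
# baseline = 0.35, absolute error = 0.25
```

The baseline can only average the two optical dates, so any abrupt change between them, such as a harvest, is smeared into a value that matches neither.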

What SAR Actually Measures

The satellite sends a microwave pulse and listens to the echo.

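VV polarisation: send vertical → receive vertical

  • Dominated by surface scattering from rough ground such as bare soil; very smooth surfaces such as calm water or paved roads reflect the pulse away from the sensor and appear dark
  • Range: [0, 0.3] linear σ⁰

VH polarisation: send vertical → receive horizontal

  • Arises mainly from volume scattering inside a plant canopy or forest, which depolarises the pulse
  • High VH = dense leafy canopy; low VH = bare soil or water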

Important

VH and NDVI are correlated — both respond to vegetation biomass — but through different physics. SAR measures canopy geometry. NDVI measures chlorophyll content. That partial decorrelation is what makes fusion valuable.

The Input Stack — All 9 Channels

Channel  Content        Sensor        Raw range       Normalisation
──────────────────────────────────────────────────────────────────────
  0      VH at t−1      Sentinel-1    [0, 0.3]        as-is
  1      VV at t−1      Sentinel-1    [0, 0.3]        as-is
  2      VH at t        Sentinel-1    [0, 0.3]        as-is   ← key: SAR at target date
  3      VV at t        Sentinel-1    [0, 0.3]        as-is
  4      VV at t+1      Sentinel-1    [0, 0.3]        as-is
  5      VH at t+1      Sentinel-1    [0, 0.3]        as-is
  6      NDVI at t−1    Sentinel-2    NDVI × 1000     ÷1000 → [−1, 1]
  7      NDVI at t+1    Sentinel-2    NDVI × 1000     ÷1000 → [−1, 1]
  8      DEM            SRTM          [0, 1]          as-is

Channels 2–3 are the SAR images at the target date — the cloud-free observation that the optical sensor cannot make. This is the information the baseline interpolation never has access to.
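A minimal sketch of assembling the stack in the table's channel order, assuming every layer is already co-registered on the same H × W grid as a NumPy array. The layer names and `build_stack` are hypothetical, not the repository's code.

```python
import numpy as np

# Channel order from the input-stack table (hypothetical layer names)
CHANNEL_ORDER = (
    "vh_tm1", "vv_tm1",      # 0, 1: SAR at t−1
    "vh_t", "vv_t",          # 2, 3: SAR at target date t
    "vv_tp1", "vh_tp1",      # 4, 5: SAR at t+1 (VV before VH, as in the table)
    "ndvi_tm1", "ndvi_tp1",  # 6, 7: NDVI stored as NDVI × 1000
    "dem",                   # 8: elevation, already in [0, 1]
)

def build_stack(layers: dict) -> np.ndarray:
    """Stack co-registered 2-D layers into a (9, H, W) float32 array."""
    missing = [name for name in CHANNEL_ORDER if name not in layers]
    if missing:
        raise KeyError(f"missing layers: {missing}")
    shapes = {layers[name].shape for name in CHANNEL_ORDER}
    if len(shapes) != 1:
        raise ValueError(f"layers are not on one grid: {shapes}")
    return np.stack([layers[name] for name in CHANNEL_ORDER]).astype(np.float32)

# Usage with random placeholder data on a 64 × 64 grid
rng = np.random.default_rng(0)
layers = {name: rng.random((64, 64)) for name in CHANNEL_ORDER}
stack = build_stack(layers)
print(stack.shape)  # (9, 64, 64)
```

Naming the layers once, in one tuple, keeps the channel indices used elsewhere (6 and 7 for NDVI, 8 for DEM) tied to a single definition.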

Original Stack vs PyTorch Stack

Component     Theano / Lasagne               PyTorch
──────────────────────────────────────────────────────────────
Computation   Symbolic graph, compiled       Eager execution
Layers        lasagne.layers.*               nn.Module
Training      theano.function()              Python for loop
Debugging     Print symbolic tensor shapes   print() anywhere
Status        Dead since 2017                Actively maintained

The good news: CNN architectures map almost 1:1. The challenges are almost always in the data pipeline and normalisation.

Porting the Model — The Easy Part

Lasagne (original)

from lasagne.layers import InputLayer, Conv2DLayer
from lasagne.nonlinearities import rectify, linear

l_in    = InputLayer(shape=(None, bands, 33, 33))
l_conv1 = Conv2DLayer(l_in,    48, (9,9), pad='valid', nonlinearity=rectify)
l_conv2 = Conv2DLayer(l_conv1, 32, (5,5), pad='valid', nonlinearity=rectify)
l_out   = Conv2DLayer(l_conv2,  1, (5,5), pad='valid', nonlinearity=linear)

PyTorch (ported)

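import torch
import torch.nn as nn

class FusionCNN(nn.Module):
    def __init__(self, bands: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(bands, 48, kernel_size=9, padding=0),  # 'valid'
            nn.ReLU(),
            nn.Conv2d(48,    32, kernel_size=5, padding=0),
            nn.ReLU(),
            nn.Conv2d(32,     1, kernel_size=5, padding=0),  # linear output
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (N, bands, 33, 33) -> (N, 1, 17, 17)
        return self.net(x)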

padding=0 in PyTorch is equivalent to Lasagne's pad='valid' (recent PyTorch versions also accept padding='valid'), so the translation is direct.
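A quick way to confirm the port is a shape check with a dummy batch. The snippet repeats the FusionCNN definition so it runs on its own, assumes a forward() that simply calls self.net, and uses the 9-channel full stack.

```python
import torch
import torch.nn as nn

class FusionCNN(nn.Module):
    def __init__(self, bands: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(bands, 48, kernel_size=9, padding=0),  # 'valid'
            nn.ReLU(),
            nn.Conv2d(48, 32, kernel_size=5, padding=0),
            nn.ReLU(),
            nn.Conv2d(32, 1, kernel_size=5, padding=0),  # linear output
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

model = FusionCNN(bands=9)
x = torch.zeros(4, 9, 33, 33)  # batch of 4 full-stack patches
with torch.no_grad():
    y = model(x)
print(tuple(y.shape))  # (4, 1, 17, 17)

# Parameter count: a cheap cross-check against the original Lasagne model
n_params = sum(p.numel() for p in model.parameters())
print(n_params)  # 74273
```

If the Lasagne model reports the same parameter count for the same number of bands, the layer shapes almost certainly match.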

Porting the Architecture — Valid Convolutions

Why does padding=0 matter? With no padding, the output shrinks:

Input          (bands, 33, 33)
After Conv 9×9    (48, 25, 25)   shrinks by (9−1)/2 = 4 on each side
After Conv 5×5    (32, 21, 21)   shrinks by (5−1)/2 = 2 on each side
After Conv 5×5     (1, 17, 17)   shrinks by 2 again

Border formula carried over directly from the paper:

def compute_border(k1=9, k2=5, k3=5) -> int:
    return ((k1-1) + (k2-1) + (k3-1)) // 2   # = 8

A 33×33 input patch → 17×17 prediction. The ground truth target during training is the central 17×17 pixels of the patch.
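Cropping the training target follows directly from the border. A sketch, assuming the target NDVI patch is cut at the same 33 × 33 window as the input; `center_target` is a hypothetical helper name.

```python
import numpy as np

def compute_border(k1: int = 9, k2: int = 5, k3: int = 5) -> int:
    # Total shrinkage of three valid convolutions, split evenly over both sides
    return ((k1 - 1) + (k2 - 1) + (k3 - 1)) // 2

def center_target(patch: np.ndarray, border: int) -> np.ndarray:
    """Return the central window of a (..., H, W) array, dropping `border` pixels per side."""
    h, w = patch.shape[-2], patch.shape[-1]
    return patch[..., border:h - border, border:w - border]

border = compute_border()  # 8
ndvi_patch = np.arange(33 * 33, dtype=np.float32).reshape(33, 33)
target = center_target(ndvi_patch, border)
print(target.shape)                      # (17, 17)
print(target[0, 0] == ndvi_patch[8, 8])  # True
```

Slicing with `h - border` rather than `-border` keeps the helper correct when the border is 0, where `patch[0:-0]` would return an empty array.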

Note

Valid convolutions mean every output pixel is computed from real data only — no artificial zeros at the boundary. Important for field edges and water bodies.

Porting the Data Pipeline — The Hard Part

Legacy approach (Theano era):

# Load everything into RAM upfront
X_train = np.zeros((N, bands, 33, 33), dtype=np.float32)
Y_train = np.zeros((N, 1,     17, 17), dtype=np.float32)
# ... fill arrays, then pass to Theano

PyTorch approach:

from typing import List, Tuple
from torch.utils.data import Dataset

class SentinelPatchDataset(Dataset):
    def __init__(self, ...):
        # Only store patch COORDINATES, not the patches themselves
        self._index: List[Tuple[int, int, int]] = []  # (scene, y, x)

    def __len__(self) -> int:
        # DataLoader's samplers need the dataset size
        return len(self._index)

    def __getitem__(self, idx):
        # Extract patch ON DEMAND when DataLoader requests it
        scene_idx, y0, x0 = self._index[idx]
        return _make_patch(self._scenes[scene_idx], y0, x0, ...)

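DataLoader handles batching, shuffling, and multiprocessing for free. Because overlapping patches are never materialised, memory stays proportional to the scenes themselves, even for full 5253 × 4797-pixel scenes; pre-extracting overlapping patches would store each pixel many times over.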
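A self-contained sketch of the on-demand pattern, using random arrays in place of real Sentinel scenes. The class name `LazyPatchDataset`, the stride of 17 (which tiles the 17 × 17 targets without overlap), and storing each scene as an (inputs, target) pair are assumptions for illustration, not the repository's layout.

```python
from typing import List, Tuple

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset

PATCH = 33   # input patch size
BORDER = 8   # valid-convolution border, so the target is 17 × 17

class LazyPatchDataset(Dataset):
    """Hypothetical on-demand dataset: keeps scenes plus patch coordinates only."""

    def __init__(self, scenes: List[Tuple[np.ndarray, np.ndarray]], stride: int = 17):
        # Each scene is (inputs of shape (C, H, W), target NDVI of shape (H, W))
        self._scenes = scenes
        self._index: List[Tuple[int, int, int]] = []  # (scene, y, x)
        for s, (inputs, _) in enumerate(scenes):
            _, h, w = inputs.shape
            for y0 in range(0, h - PATCH + 1, stride):
                for x0 in range(0, w - PATCH + 1, stride):
                    self._index.append((s, y0, x0))

    def __len__(self) -> int:
        return len(self._index)

    def __getitem__(self, idx: int):
        s, y0, x0 = self._index[idx]
        inputs, target = self._scenes[s]
        # Slice only now; nothing is duplicated in memory up front
        x = inputs[:, y0:y0 + PATCH, x0:x0 + PATCH]
        y = target[y0 + BORDER:y0 + PATCH - BORDER, x0 + BORDER:x0 + PATCH - BORDER]
        return torch.from_numpy(x.copy()), torch.from_numpy(y.copy())[None]

# Usage with two small random "scenes"
rng = np.random.default_rng(0)
scenes = [
    (rng.random((9, 100, 120), dtype=np.float32), rng.random((100, 120), dtype=np.float32))
    for _ in range(2)
]
ds = LazyPatchDataset(scenes)
loader = DataLoader(ds, batch_size=8, shuffle=True)
xb, yb = next(iter(loader))
print(len(ds), tuple(xb.shape), tuple(yb.shape))
# 48 (9, 33, 33) batches: (8, 9, 33, 33) inputs, (8, 1, 17, 17) targets
```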

The Normalisation Bug — The Interesting Part

The original code divided all channels by 1000:

# Original intent: normalise NDVI (stored as NDVI × 1000)
x = x / 1000.0    # applied to ALL channels

This is correct for the NDVI channels (350 → 0.35). But look at the SAR and DEM values:

Channel        Raw range     After ÷1000     Effect
──────────────────────────────────────────────────────────────────
VH, VV (SAR)   [0, 0.3]      [0, 0.0003]     Invisible to network
NDVI           [−74, 800]    [−0.074, 0.8]   Correct
DEM            [0, 1]        [0, 0.001]      Invisible to network

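In effect, the network was training as an optical-only model without anyone knowing it. Inputs 1000× too small barely move the first-layer activations, and the gradients of their weights are scaled down by the same factor, so the entire multi-sensor fusion was silently disabled.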
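A per-channel summary printed before training would have exposed this bug immediately. A sketch with synthetic values drawn uniformly from the ranges in the table (not real data); `channel_report` is a hypothetical helper, and the 100× spread threshold is an arbitrary assumption.

```python
import numpy as np

def channel_report(x: np.ndarray) -> np.ndarray:
    """Print min / max / std for each channel of an (N, C, H, W) batch; return the stds."""
    stds = x.std(axis=(0, 2, 3))
    for c in range(x.shape[1]):
        print(f"ch {c}: min={x[:, c].min():9.5f} max={x[:, c].max():9.5f} std={stds[c]:9.5f}")
    return stds

rng = np.random.default_rng(0)
n = 16
sar = rng.uniform(0.0, 0.3, size=(n, 6, 33, 33))   # channels 0–5, linear σ⁰
ndvi = rng.uniform(-74, 800, size=(n, 2, 33, 33))  # channels 6–7, stored × 1000
dem = rng.uniform(0.0, 1.0, size=(n, 1, 33, 33))   # channel 8
x_raw = np.concatenate([sar, ndvi, dem], axis=1).astype(np.float32)

x_buggy = x_raw / 1000.0  # the original behaviour: every channel divided by 1000
stds = channel_report(x_buggy)

# Flag a stack whose channel spreads differ by more than 100× (arbitrary threshold)
if stds.max() / stds.min() > 100:
    print("WARNING: channel scales differ by more than 100x")
```

With the buggy division the NDVI channels end up with a spread thousands of times larger than the SAR channels, which is exactly the imbalance the table describes.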

The Fix — Per-Channel Normalisation

import numpy as np
import torch
from src.config import NDVI_SCALE   # = 1000.0

def __getitem__(self, idx):
    x_raw, y = self._get_raw_item(idx)

    # Work in float32 so the divisions below cannot be truncated by an integer dtype
    x_norm = x_raw.astype(np.float32, copy=True)
    x_norm[6] = x_raw[6] / NDVI_SCALE   # NDVI₀ (F₋): stored × 1000 → actual
    x_norm[7] = x_raw[7] / NDVI_SCALE   # NDVI₂ (F₊): stored × 1000 → actual
    # SAR channels 0–5: used as-is  (already ≈ [0, 0.3])
    # DEM channel   8:  used as-is  (already in [0, 1])

    x = select_channels(x_norm[None, ...], self.identity)[0]
    y_t = torch.from_numpy(y).float() / NDVI_SCALE
    return torch.from_numpy(x).float(), y_t

NDVI_SCALE = 1000.0 lives in src/config.py — one place to change it.
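An alternative sketch keeps the per-channel rule in a single vector, so adding or reordering channels means editing one array. `CHANNEL_SCALE` and `normalise` are hypothetical names; the divisors follow the input-stack table (1000 for channels 6 and 7, 1 elsewhere).

```python
import numpy as np

NDVI_SCALE = 1000.0  # mirrors the value in src/config.py

# One divisor per channel: SAR (0–5) and DEM (8) as-is, NDVI (6–7) ÷ NDVI_SCALE
CHANNEL_SCALE = np.ones(9, dtype=np.float32)
CHANNEL_SCALE[[6, 7]] = NDVI_SCALE

def normalise(x_raw: np.ndarray) -> np.ndarray:
    """Divide each channel of a (9, H, W) or (N, 9, H, W) array by its own scale."""
    # Shape (9, 1, 1) broadcasts along the channel axis only
    return x_raw.astype(np.float32) / CHANNEL_SCALE[:, None, None]

x_raw = np.zeros((9, 33, 33), dtype=np.float32)
x_raw[2] = 0.05   # VH at t, linear σ⁰
x_raw[6] = 350.0  # NDVI at t−1, stored × 1000
x = normalise(x_raw)
print(x[2, 0, 0], x[6, 0, 0])
```

Because broadcasting aligns trailing dimensions, the same function works on a single stack or a whole batch without changes.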

Important

Lesson: porting forces you to understand the original code rather than just translate it. Silent numerical bugs do not raise exceptions.

The torch.compile Gotcha

PyTorch 2.0 introduced torch.compile, which wraps the model in an optimised module. In this project it ran about 1.4× faster on CPU.

model = FusionCNN(bands=bands)
model = torch.compile(model)   # PyTorch 2.0+

The bug: torch.compile prefixes all weight keys with _orig_mod.:

# Before compile:   net.0.weight
# After compile:    _orig_mod.net.0.weight

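Saving model.state_dict() after compiling and then calling load_state_dict() on a fresh, uncompiled model fails: every saved key is unexpected and every expected key is missing.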

The fix:

base = model._orig_mod if hasattr(model, "_orig_mod") else model
torch.save({"model_state_dict": base.state_dict(), ...}, path)

Small fix. Took a while to track down.
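Checkpoints already saved from the compiled wrapper can still be rescued by stripping the prefix at load time. A minimal sketch: `strip_compile_prefix` and `Tiny` are hypothetical, and the prefixed dictionary is built by hand so the example does not depend on a working torch.compile backend.

```python
from typing import Dict, Mapping

import torch
import torch.nn as nn

PREFIX = "_orig_mod."  # key prefix added by torch.compile's wrapper

def strip_compile_prefix(state_dict: Mapping[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Return a copy of state_dict with the torch.compile prefix removed from each key."""
    return {
        (k[len(PREFIX):] if k.startswith(PREFIX) else k): v
        for k, v in state_dict.items()
    }

class Tiny(nn.Module):
    """Stand-in model with the same key layout as FusionCNN (net.0.weight, ...)."""

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Conv2d(9, 4, kernel_size=3))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

trained = Tiny()
# Simulate a checkpoint saved from the compiled wrapper (keys like "_orig_mod.net.0.weight")
saved = {PREFIX + k: v for k, v in trained.state_dict().items()}

fresh = Tiny()
fresh.load_state_dict(strip_compile_prefix(saved))  # strict loading now succeeds
print(sorted(fresh.state_dict()))  # ['net.0.bias', 'net.0.weight']
```

Keys without the prefix pass through unchanged, so the same loader handles checkpoints from both compiled and uncompiled runs.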

What the Port Enabled

Running python run_train.py now executes the full pipeline:

Stage 1 — Train      SGD + momentum 0.9, MAE loss, torch.compile speedup
             ↓
Stage 2 — Predict    inference on the 470×450 test region → GeoTIFF
             ↓
Stage 3 — Evaluate   MAE / RMSE / CC vs ground-truth NDVI
             ↓
Stage 4 — Baseline   same metrics for linear interpolation F̂ = ½(F₋+F₊)
             ↓
Stage 5 — Summary    % gain of CNN over baseline
             ↓
Stage 6 — Visualise  loss curves + NDVI map comparison PNGs
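Stage 1 pairs SGD with momentum 0.9 and an MAE loss; in PyTorch, nn.L1Loss computes the mean absolute error. A minimal one-epoch sketch on random tensors: the learning rate of 1e-3, the batch size and the synthetic data are placeholder assumptions, not the paper's or the repository's settings.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

model = nn.Sequential(  # same layout as FusionCNN for the 9-channel stack
    nn.Conv2d(9, 48, kernel_size=9), nn.ReLU(),
    nn.Conv2d(48, 32, kernel_size=5), nn.ReLU(),
    nn.Conv2d(32, 1, kernel_size=5),
)
criterion = nn.L1Loss()  # mean absolute error
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)  # lr is an assumption

# Synthetic stand-in data: 33×33 input patches, 17×17 NDVI targets in [−1, 1]
x = torch.rand(64, 9, 33, 33)
y = torch.rand(64, 1, 17, 17) * 2 - 1
loader = DataLoader(TensorDataset(x, y), batch_size=16, shuffle=True)

model.train()
epoch_loss = 0.0
for xb, yb in loader:
    optimizer.zero_grad()
    loss = criterion(model(xb), yb)
    loss.backward()
    optimizer.step()
    epoch_loss += loss.item() * xb.size(0)  # weight by batch size
epoch_loss /= len(loader.dataset)
print(f"epoch MAE: {epoch_loss:.4f}")
```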

All outputs go to a timestamped folder — results are never overwritten.
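Stages 3 to 5 reduce to a few array operations. A sketch in NumPy, assuming CC is the Pearson correlation coefficient and that the percentage gain is the relative error reduction, (baseline − CNN) / baseline × 100; the repository may define either differently, and the arrays are toy values.

```python
import numpy as np

def metrics(pred: np.ndarray, truth: np.ndarray) -> dict:
    """MAE, RMSE and Pearson correlation (CC) over all finite pixels."""
    p, t = pred.ravel(), truth.ravel()
    valid = np.isfinite(p) & np.isfinite(t)  # ignore NaN no-data pixels
    p, t = p[valid], t[valid]
    err = p - t
    return {
        "MAE": float(np.mean(np.abs(err))),
        "RMSE": float(np.sqrt(np.mean(err ** 2))),
        "CC": float(np.corrcoef(p, t)[0, 1]),
    }

def percent_gain(cnn: float, baseline: float) -> float:
    """Relative error reduction of the CNN over the baseline, in percent (assumed definition)."""
    return 100.0 * (baseline - cnn) / baseline

truth = np.array([0.10, 0.20, 0.50, 0.70])
cnn = np.array([0.12, 0.18, 0.45, 0.72])
baseline = np.array([0.35, 0.30, 0.40, 0.60])

m_cnn, m_base = metrics(cnn, truth), metrics(baseline, truth)
print(f"MAE gain: {percent_gain(m_cnn['MAE'], m_base['MAE']):.1f}%")  # MAE gain: 80.0%
```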

The 8 Model Variants — One Port, Eight Experiments

Same architecture, different input stacks: only select_channels() changes. Five of the variants are listed here:

Variant    Inputs                        Tests
──────────────────────────────────────────────────────────
SAR        VH, VV at t                   SAR only
OPTI       F₋                            Optical, causal only
OPTII      F₋, F₊                        Optical, both sides
SOPTI      S₋, S, F₋                     SAR + 1 optical
SOPTIIp    S₋, S, S₊, F₋, F₊, DEM        Full fusion

Comparing MAE/RMSE across variants directly measures the marginal value of each information source — SAR, DEM, temporal context.
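A sketch of how the listed variants might map onto indices of the 9-channel stack, reading S₋, S, S₊ as the SAR pairs at t−1, t, t+1 (channels 0–1, 2–3, 4–5) and F₋, F₊ as channels 6 and 7. The mapping and `select_variant` are assumptions reconstructed from the two tables, not the repository's select_channels().

```python
import numpy as np

# Hypothetical mapping reconstructed from the input-stack and variant tables
S_MINUS, S_T, S_PLUS = [0, 1], [2, 3], [4, 5]
F_MINUS, F_PLUS, DEM = [6], [7], [8]

VARIANT_CHANNELS = {
    "SAR": S_T,
    "OPTI": F_MINUS,
    "OPTII": F_MINUS + F_PLUS,
    "SOPTI": S_MINUS + S_T + F_MINUS,
    "SOPTIIp": S_MINUS + S_T + S_PLUS + F_MINUS + F_PLUS + DEM,
}

def select_variant(x: np.ndarray, variant: str) -> np.ndarray:
    """Keep only the channels a variant uses from an (N, 9, H, W) batch."""
    return x[:, VARIANT_CHANNELS[variant]]

batch = np.zeros((2, 9, 33, 33), dtype=np.float32)
for name in VARIANT_CHANNELS:
    # The channel count is what FusionCNN(bands=...) would receive
    print(name, select_variant(batch, name).shape[1])
```

Deriving `bands` from the selected channels, rather than hard-coding it per variant, keeps the model and the data pipeline from drifting apart.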

Lessons Learned

  1. Read the paper, not just the code. The code may have bugs the paper does not.
  2. Normalisation is the most common silent bug. Check every channel’s scale before and after.
  3. PyTorch’s eager execution makes debugging easier. You can call print() or drop into pdb anywhere; there is no symbolic graph to reason about.
  4. Write tests after porting. Use the paper’s numbers as ground truth for expected outputs.
  5. Port the simplest variant first. OPTI (1 channel) is trivial to verify; SOPTIIp (9 channels) hides bugs.
  6. New framework features (torch.compile) have edge cases. Read the docs for anything touching model serialisation.

The Broader Point

Most published CNNs from 2014–2018 built on Theano, Lasagne, or early Caffe can be ported using the same approach:

Almost always clean:

  • Layer definitions
  • Activation functions
  • Loss functions
  • Optimiser hyperparameters

Always check carefully:

  • Data loading and normalisation
  • Train/val split logic
  • Mask and boundary handling
  • Checkpoint saving/loading

The reward is a maintainable, testable, extensible implementation that others can actually run — and that you can build on for your own research.

References

Scarpa G., Gargiulo M., Mazza A., Gaetano R. (2018). A CNN-Based Fusion Method for Feature Extraction from Sentinel Data. Remote Sensing, 10(2), 236. https://doi.org/10.3390/rs10020236

Repository:

git clone <repo>
pip install torch numpy scipy matplotlib
conda install -c conda-forge gdal

python run_train.py   # full pipeline, one command

Note

Code, tests, and this presentation live in the same repository.