Avatarr05 commited on
Commit
de1b7f1
·
verified ·
1 Parent(s): 38f1e1e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +297 -296
app.py CHANGED
@@ -1,297 +1,298 @@
1
- import pandas as pd
2
- import numpy as np
3
- import matplotlib.pyplot as plt
4
- from scipy.signal import savgol_filter
5
- import rasterio
6
- import multiprocessing
7
- import time
8
- import torch
9
- from pickle import load
10
- import warnings
11
-
12
- import gradio as gr
13
- import os
14
-
15
- from matplotlib.pyplot import figure
16
- from mpl_toolkits.axes_grid1 import make_axes_locatable
17
- import matplotlib.ticker as ticker
18
- from matplotlib.animation import FuncAnimation
19
- from matplotlib import rc
20
-
21
- from rasterio.plot import show
22
-
23
- warnings.filterwarnings("ignore")
24
-
25
- rc('animation', html='jshtml')
26
-
27
-
28
- # ---------------------------
29
- # Trait list (unchanged)
30
- # ---------------------------
31
- Traits = ["cab", "cw", "cm", "LAI", "cp", "cbc", "car", "anth"]
32
-
33
- # ---------------------------
34
- # Spectral preprocessing
35
- # ---------------------------
36
- def filter_segment(features_noWtab, order=1, der=False):
37
- part1 = features_noWtab.copy()
38
- if der:
39
- fr1 = savgol_filter(part1, 65, 1, deriv=1)
40
- else:
41
- fr1 = savgol_filter(part1, 65, order)
42
- return pd.DataFrame(data=fr1, columns=part1.columns)
43
-
44
- def feature_preparation(features, inval=[1351,1431,1801,2051], frmax=2451, order=1, der=False):
45
- other = features.copy()
46
- other.columns = other.columns.astype('int')
47
- other[other < 0] = np.nan
48
- other[other > 1] = np.nan
49
- other = (other.ffill() + other.bfill())/2
50
- other = other.interpolate(method='linear', axis=1, limit_direction='both')
51
-
52
- wt_ab = [i for i in range(inval[0],inval[1])] + [i for i in range(inval[2],inval[3])] + [i for i in range(2451,2501)]
53
- features_noWtab = other.drop(wt_ab, axis=1)
54
-
55
- fr1 = filter_segment(features_noWtab.loc[:,:inval[0]-1], order=order, der=der)
56
- fr2 = filter_segment(features_noWtab.loc[:,inval[1]:inval[2]-1], order=order, der=der)
57
- fr3 = filter_segment(features_noWtab.loc[:,inval[3]:frmax], order=order, der=der)
58
-
59
- inter = pd.concat([fr1,fr2,fr3], axis=1, join='inner')
60
- inter[inter<0]=0
61
- return inter
62
-
63
- def plot_fig(features, save=False, file=None, figsize=(15,10)):
64
- plt.figure(figsize=figsize)
65
- plt.plot(features.T)
66
- plt.ylim(0, features.max().max())
67
- if save:
68
- plt.savefig(file + '.pdf', bbox_inches='tight', dpi=1000)
69
- plt.savefig(file + '.svg', bbox_inches='tight', dpi=1000)
70
- plt.show()
71
-
72
- # ---------------------------
73
- # Image handling
74
- # ---------------------------
75
- def image_processing(enmap_im_path, bands_path):
76
- bands = pd.read_csv(bands_path)['bands'].astype(float)
77
- src = rasterio.open(enmap_im_path)
78
- array = src.read()
79
- sp_px = np.stack([array[i].reshape(-1,1) for i in range(array.shape[0])], axis=0)
80
- sp_px = np.swapaxes(sp_px.mean(axis=2),0,1)
81
- assert (sp_px.shape[1] == bands.shape[0]), "Mismatch between image bands and CSV bands!"
82
- df = pd.DataFrame(sp_px, columns=bands.to_list())
83
- df[df < df.quantile(0.01).min() + 10] = np.nan
84
- idx_null = df[df.T.isna().all()].index
85
- return src, df, idx_null
86
-
87
- def process_dataframe(veg_spec):
88
- veg_reindex = veg_spec.reindex(columns=sorted(veg_spec.columns.tolist() +
89
- [i for i in range(400,2501) if i not in veg_spec.columns.tolist()]))
90
- veg_reindex = veg_reindex/10000
91
- veg_reindex.columns = veg_reindex.columns.astype(int)
92
- inter = veg_reindex.loc[:,~veg_reindex.columns.duplicated()]
93
- inter = feature_preparation(veg_reindex, order=1)
94
- inter = inter.loc[:,~inter.columns.duplicated()]
95
- return inter.loc[:,400:]
96
-
97
- def transform_data(df):
98
- num_cpus = multiprocessing.cpu_count()
99
- df_chunks = [chunk for chunk in np.array_split(df, num_cpus)]
100
- print("Starting data transformation ...")
101
- with multiprocessing.Pool(num_cpus) as pool:
102
- results = pool.map(process_dataframe, df_chunks)
103
- pool.close(); pool.join()
104
- df_transformed = pd.concat(results).reset_index(drop=True)
105
- print("Transformation complete.")
106
- return df_transformed
107
-
108
- # ---------------------------
109
- # Model loading (PyTorch)
110
- # ---------------------------
111
- def load_model(dir_data, gp=None):
112
- """
113
- Loads a PyTorch model and its associated scaler from a directory.
114
- Replaces the original TensorFlow-based loading logic.
115
- """
116
- model_path = os.path.join(dir_data, "model.pt")
117
- scaler_path = os.path.join(dir_data, "scaler_global.pkl")
118
-
119
- if not os.path.exists(model_path):
120
- raise FileNotFoundError(f"Model weights not found in {dir_data}")
121
-
122
- model = torch.load(model_path, map_location="cpu")
123
- model.eval()
124
-
125
- if os.path.exists(scaler_path):
126
- scaler_list = load(open(scaler_path, "rb"))
127
- else:
128
- scaler_list = None
129
-
130
- return model, scaler_list
131
-
132
- # ---------------------------
133
- # Visualization utilities
134
- # ---------------------------
135
- def animation_preds(src, preds_tr, Traits=Traits):
136
- from matplotlib.animation import FuncAnimation
137
- import matplotlib.ticker as ticker
138
-
139
- def update(frame):
140
- tr = frame
141
- preds_tr_ = pd.DataFrame(np.array(preds_tr.loc[:, tr]))
142
- preds_vis = preds_tr_.copy()[preds_tr_ < preds_tr_.quantile(0.99)]
143
- flag = np.array(preds_vis)
144
- maxv = pd.DataFrame(flag).max().max()
145
- minv = pd.DataFrame(flag).min().min()
146
- pred_im.set_array(preds_tr_.values.reshape(src.shape[0], src.shape[1]))
147
- pred_im.set_clim(vmin=minv, vmax=maxv)
148
- ax2.set_title(f"{Traits[tr]} map")
149
- return pred_im
150
-
151
- plt.rc('font', size=3)
152
- fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(3, 2), dpi=300,
153
- sharex=True, sharey=True,
154
- gridspec_kw={'width_ratios': [1, 1.09]})
155
-
156
- nir = src.read(72)/10000
157
- red = src.read(47)/10000
158
- green = src.read(28)/10000
159
- blue = src.read(6)/10000
160
- nrg = np.dstack((nir, red, green))
161
- ax1.imshow(nrg)
162
-
163
- tr = 0
164
- preds_tr_ = pd.DataFrame(np.array(preds_tr.loc[:, tr]))
165
- preds_vis = preds_tr_.copy()[preds_tr_ < preds_tr_.quantile(0.99)]
166
- flag = np.array(preds_vis)
167
- maxv = pd.DataFrame(flag).max().max()
168
- minv = pd.DataFrame(flag).min().min()
169
-
170
- pred_im = ax2.imshow(preds_tr_.values.reshape(src.shape[0], src.shape[1]), vmin=minv, vmax=maxv)
171
- plt.colorbar(pred_im, ax=ax2, fraction=0.04, pad=0.04)
172
-
173
- ax1.set(title="Original scene (False Color)")
174
- ax2.set(title=f"{Traits[tr]} map")
175
- for ax in (ax1, ax2):
176
- ax.set_aspect("equal")
177
- ax.axis("off")
178
- ax.xaxis.set_major_locator(ticker.NullLocator())
179
- ax.yaxis.set_major_locator(ticker.NullLocator())
180
-
181
- animation = FuncAnimation(fig, update, frames=range(1, 20), interval=1000)
182
- animation.save("Traits_predictions.gif")
183
- return "Traits_predictions.gif"
184
-
185
- def geo_tiff_save(src, preds):
186
- size = (src.height, src.width, preds.shape[1])
187
- new_image_path = "./twentyTraitPredictions.tif"
188
- with rasterio.open(
189
- new_image_path, "w",
190
- driver="GTiff",
191
- width=size[1], height=size[0],
192
- count=size[2], dtype="float32",
193
- crs=src.crs, transform=src.transform
194
- ) as new_image:
195
- for i in range(1, size[2] + 1):
196
- array_data = np.array(preds.loc[:, i-1]).reshape((src.height, src.width))
197
- new_image.write(array_data, i)
198
- return new_image_path
199
-
200
-
201
- # -------------------------------
202
- # Model configuration
203
- # -------------------------------
204
- repo_id = "Avatarr05/Multi-trait_SSL"
205
-
206
- # Map of available pretrained weights in your repo
207
- model_file_map = {
208
- ("MAE", "Full Range"): "mae/MAE_FR_400-2449_FT_155.pt",
209
- ("MAE", "Half Range"): "mae/MAE_HR_VNIR_400-899_FT_155.pt",
210
- ("GAN", "Full Range"): "Gans_models/checkpoints_GanFR_seed140/best_model.pt",
211
- ("GAN", "Half Range"): "Gans_models/checkpoints_GanHR_seed140/best_model.pt",
212
- }
213
-
214
- _model_cache = {}
215
-
216
-
217
- def load_pretrained_model(model_name, range_type):
218
- """Downloads and loads pretrained weights and associated scaler."""
219
- key = (model_name, range_type)
220
- if key in _model_cache:
221
- return _model_cache[key]
222
-
223
- if key not in model_file_map:
224
- raise ValueError(f"No pretrained weights found for {model_name} ({range_type})")
225
-
226
- model_path = model_file_map[key]
227
- # Download from your Hugging Face repo
228
- file_path = hf_hub_download(repo_id=repo_id, filename=model_path)
229
-
230
- # Load PyTorch model and scaler
231
- best_model, scaler_list = load_model(os.path.dirname(file_path))
232
- _model_cache[key] = (best_model, scaler_list)
233
- return best_model, scaler_list
234
-
235
-
236
- # -------------------------------
237
- # Core function: regression + visualization
238
- # -------------------------------
239
- def apply_regression(input_image, input_csv, model_choice, range_choice):
240
- """
241
- Applies the pretrained model to the uploaded hyperspectral scene (.tif)
242
- and associated band CSV, using your original preprocessing + transformations.
243
- """
244
- # 1️⃣ Load model + scaler
245
- best_model, scaler_list = load_pretrained_model(model_choice, range_choice)
246
- best_model.eval()
247
-
248
- # 2️⃣ Preprocess input data (your unchanged pipeline)
249
- src, df, idx_null = image_processing(input_image, input_csv)
250
- df_transformed = transform_data(df)
251
-
252
- # 3️⃣ Run inference (PyTorch forward pass)
253
- with torch.no_grad():
254
- x = torch.tensor(df_transformed.values, dtype=torch.float32)
255
- tf_preds = best_model(x).numpy()
256
-
257
- # 4️⃣ Reverse scaling
258
- if scaler_list is not None:
259
- tf_preds = scaler_list.inverse_transform(tf_preds)
260
-
261
- # 5️⃣ Build prediction DataFrame
262
- preds = pd.DataFrame(tf_preds)
263
- preds.loc[idx_null] = np.nan
264
-
265
- # 6️⃣ Generate visualization and save GeoTIFF
266
- fig = animation_preds(src, preds)
267
- raster_path = geo_tiff_save(src, preds)
268
-
269
- return fig, raster_path
270
-
271
- # -------------------------------
272
- # Gradio interface
273
- # -------------------------------
274
- iface = gr.Interface(
275
- fn=apply_regression,
276
- inputs=[
277
- gr.File(type="filepath", label="Upload Hyperspectral Scene (.tif)"),
278
- gr.File(type="filepath", label="Upload Band Information (.csv)"),
279
- gr.Dropdown(["MAE", "GAN"], label="Select Model Type"),
280
- gr.Radio(["Full Range", "Half Range"], label="Scene Range"),
281
- ],
282
- outputs=[
283
- gr.Image(label="Predicted Trait Maps (Animation)", show_download_button=False),
284
- gr.File(label="Download Predicted GeoTIFF"),
285
- ],
286
- title="🛰️ Multi-Trait Prediction from Hyperspectral Scenes (PyTorch)",
287
- description=(
288
- "Upload your hyperspectral scene (.tif) and its corresponding CSV file. "
289
- "The selected pretrained model will process the data, predict multiple traits, "
290
- "and generate both an animated visualization and a downloadable GeoTIFF."
291
- ),
292
- # article=copyright_html,
293
- theme="soft",
294
- )
295
-
296
- # Launch the Gradio app
 
297
  iface.launch() #share=False
 
1
+ import pandas as pd
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ from scipy.signal import savgol_filter
5
+ import rasterio
6
+ import multiprocessing
7
+ import time
8
+ import torch
9
+ from pickle import load
10
+ import warnings
11
+
12
+ import gradio as gr
13
+ import os
14
+
15
+ from matplotlib.pyplot import figure
16
+ from mpl_toolkits.axes_grid1 import make_axes_locatable
17
+ import matplotlib.ticker as ticker
18
+ from matplotlib.animation import FuncAnimation
19
+ from matplotlib import rc
20
+
21
+ from rasterio.plot import show
22
+ from huggingface_hub import hf_hub_download
23
+
24
+ warnings.filterwarnings("ignore")
25
+
26
+ rc('animation', html='jshtml')
27
+
28
+
29
+ # ---------------------------
30
+ # Trait list (unchanged)
31
+ # ---------------------------
32
+ Traits = ["cab", "cw", "cm", "LAI", "cp", "cbc", "car", "anth"]
33
+
34
+ # ---------------------------
35
+ # Spectral preprocessing
36
+ # ---------------------------
37
+ def filter_segment(features_noWtab, order=1, der=False):
38
+ part1 = features_noWtab.copy()
39
+ if der:
40
+ fr1 = savgol_filter(part1, 65, 1, deriv=1)
41
+ else:
42
+ fr1 = savgol_filter(part1, 65, order)
43
+ return pd.DataFrame(data=fr1, columns=part1.columns)
44
+
45
+ def feature_preparation(features, inval=[1351,1431,1801,2051], frmax=2451, order=1, der=False):
46
+ other = features.copy()
47
+ other.columns = other.columns.astype('int')
48
+ other[other < 0] = np.nan
49
+ other[other > 1] = np.nan
50
+ other = (other.ffill() + other.bfill())/2
51
+ other = other.interpolate(method='linear', axis=1, limit_direction='both')
52
+
53
+ wt_ab = [i for i in range(inval[0],inval[1])] + [i for i in range(inval[2],inval[3])] + [i for i in range(2451,2501)]
54
+ features_noWtab = other.drop(wt_ab, axis=1)
55
+
56
+ fr1 = filter_segment(features_noWtab.loc[:,:inval[0]-1], order=order, der=der)
57
+ fr2 = filter_segment(features_noWtab.loc[:,inval[1]:inval[2]-1], order=order, der=der)
58
+ fr3 = filter_segment(features_noWtab.loc[:,inval[3]:frmax], order=order, der=der)
59
+
60
+ inter = pd.concat([fr1,fr2,fr3], axis=1, join='inner')
61
+ inter[inter<0]=0
62
+ return inter
63
+
64
+ def plot_fig(features, save=False, file=None, figsize=(15,10)):
65
+ plt.figure(figsize=figsize)
66
+ plt.plot(features.T)
67
+ plt.ylim(0, features.max().max())
68
+ if save:
69
+ plt.savefig(file + '.pdf', bbox_inches='tight', dpi=1000)
70
+ plt.savefig(file + '.svg', bbox_inches='tight', dpi=1000)
71
+ plt.show()
72
+
73
+ # ---------------------------
74
+ # Image handling
75
+ # ---------------------------
76
+ def image_processing(enmap_im_path, bands_path):
77
+ bands = pd.read_csv(bands_path)['bands'].astype(float)
78
+ src = rasterio.open(enmap_im_path)
79
+ array = src.read()
80
+ sp_px = np.stack([array[i].reshape(-1,1) for i in range(array.shape[0])], axis=0)
81
+ sp_px = np.swapaxes(sp_px.mean(axis=2),0,1)
82
+ assert (sp_px.shape[1] == bands.shape[0]), "Mismatch between image bands and CSV bands!"
83
+ df = pd.DataFrame(sp_px, columns=bands.to_list())
84
+ df[df < df.quantile(0.01).min() + 10] = np.nan
85
+ idx_null = df[df.T.isna().all()].index
86
+ return src, df, idx_null
87
+
88
+ def process_dataframe(veg_spec):
89
+ veg_reindex = veg_spec.reindex(columns=sorted(veg_spec.columns.tolist() +
90
+ [i for i in range(400,2501) if i not in veg_spec.columns.tolist()]))
91
+ veg_reindex = veg_reindex/10000
92
+ veg_reindex.columns = veg_reindex.columns.astype(int)
93
+ inter = veg_reindex.loc[:,~veg_reindex.columns.duplicated()]
94
+ inter = feature_preparation(veg_reindex, order=1)
95
+ inter = inter.loc[:,~inter.columns.duplicated()]
96
+ return inter.loc[:,400:]
97
+
98
+ def transform_data(df):
99
+ num_cpus = multiprocessing.cpu_count()
100
+ df_chunks = [chunk for chunk in np.array_split(df, num_cpus)]
101
+ print("Starting data transformation ...")
102
+ with multiprocessing.Pool(num_cpus) as pool:
103
+ results = pool.map(process_dataframe, df_chunks)
104
+ pool.close(); pool.join()
105
+ df_transformed = pd.concat(results).reset_index(drop=True)
106
+ print("Transformation complete.")
107
+ return df_transformed
108
+
109
+ # ---------------------------
110
+ # Model loading (PyTorch)
111
+ # ---------------------------
112
+ def load_model(dir_data, gp=None):
113
+ """
114
+ Loads a PyTorch model and its associated scaler from a directory.
115
+ Replaces the original TensorFlow-based loading logic.
116
+ """
117
+ model_path = os.path.join(dir_data, "model.pt")
118
+ scaler_path = os.path.join(dir_data, "scaler_global.pkl")
119
+
120
+ if not os.path.exists(model_path):
121
+ raise FileNotFoundError(f"Model weights not found in {dir_data}")
122
+
123
+ model = torch.load(model_path, map_location="cpu")
124
+ model.eval()
125
+
126
+ if os.path.exists(scaler_path):
127
+ scaler_list = load(open(scaler_path, "rb"))
128
+ else:
129
+ scaler_list = None
130
+
131
+ return model, scaler_list
132
+
133
+ # ---------------------------
134
+ # Visualization utilities
135
+ # ---------------------------
136
+ def animation_preds(src, preds_tr, Traits=Traits):
137
+ from matplotlib.animation import FuncAnimation
138
+ import matplotlib.ticker as ticker
139
+
140
+ def update(frame):
141
+ tr = frame
142
+ preds_tr_ = pd.DataFrame(np.array(preds_tr.loc[:, tr]))
143
+ preds_vis = preds_tr_.copy()[preds_tr_ < preds_tr_.quantile(0.99)]
144
+ flag = np.array(preds_vis)
145
+ maxv = pd.DataFrame(flag).max().max()
146
+ minv = pd.DataFrame(flag).min().min()
147
+ pred_im.set_array(preds_tr_.values.reshape(src.shape[0], src.shape[1]))
148
+ pred_im.set_clim(vmin=minv, vmax=maxv)
149
+ ax2.set_title(f"{Traits[tr]} map")
150
+ return pred_im
151
+
152
+ plt.rc('font', size=3)
153
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(3, 2), dpi=300,
154
+ sharex=True, sharey=True,
155
+ gridspec_kw={'width_ratios': [1, 1.09]})
156
+
157
+ nir = src.read(72)/10000
158
+ red = src.read(47)/10000
159
+ green = src.read(28)/10000
160
+ blue = src.read(6)/10000
161
+ nrg = np.dstack((nir, red, green))
162
+ ax1.imshow(nrg)
163
+
164
+ tr = 0
165
+ preds_tr_ = pd.DataFrame(np.array(preds_tr.loc[:, tr]))
166
+ preds_vis = preds_tr_.copy()[preds_tr_ < preds_tr_.quantile(0.99)]
167
+ flag = np.array(preds_vis)
168
+ maxv = pd.DataFrame(flag).max().max()
169
+ minv = pd.DataFrame(flag).min().min()
170
+
171
+ pred_im = ax2.imshow(preds_tr_.values.reshape(src.shape[0], src.shape[1]), vmin=minv, vmax=maxv)
172
+ plt.colorbar(pred_im, ax=ax2, fraction=0.04, pad=0.04)
173
+
174
+ ax1.set(title="Original scene (False Color)")
175
+ ax2.set(title=f"{Traits[tr]} map")
176
+ for ax in (ax1, ax2):
177
+ ax.set_aspect("equal")
178
+ ax.axis("off")
179
+ ax.xaxis.set_major_locator(ticker.NullLocator())
180
+ ax.yaxis.set_major_locator(ticker.NullLocator())
181
+
182
+ animation = FuncAnimation(fig, update, frames=range(1, 20), interval=1000)
183
+ animation.save("Traits_predictions.gif")
184
+ return "Traits_predictions.gif"
185
+
186
+ def geo_tiff_save(src, preds):
187
+ size = (src.height, src.width, preds.shape[1])
188
+ new_image_path = "./twentyTraitPredictions.tif"
189
+ with rasterio.open(
190
+ new_image_path, "w",
191
+ driver="GTiff",
192
+ width=size[1], height=size[0],
193
+ count=size[2], dtype="float32",
194
+ crs=src.crs, transform=src.transform
195
+ ) as new_image:
196
+ for i in range(1, size[2] + 1):
197
+ array_data = np.array(preds.loc[:, i-1]).reshape((src.height, src.width))
198
+ new_image.write(array_data, i)
199
+ return new_image_path
200
+
201
+
202
+ # -------------------------------
203
+ # Model configuration
204
+ # -------------------------------
205
+ repo_id = "Avatarr05/Multi-trait_SSL"
206
+
207
+ # Map of available pretrained weights in your repo
208
+ model_file_map = {
209
+ ("MAE", "Full Range"): "mae/MAE_FR_400-2449_FT_155.pt",
210
+ ("MAE", "Half Range"): "mae/MAE_HR_VNIR_400-899_FT_155.pt",
211
+ ("GAN", "Full Range"): "Gans_models/checkpoints_GanFR_seed140/best_model.pt",
212
+ ("GAN", "Half Range"): "Gans_models/checkpoints_GanHR_seed140/best_model.pt",
213
+ }
214
+
215
+ _model_cache = {}
216
+
217
+
218
+ def load_pretrained_model(model_name, range_type):
219
+ """Downloads and loads pretrained weights and associated scaler."""
220
+ key = (model_name, range_type)
221
+ if key in _model_cache:
222
+ return _model_cache[key]
223
+
224
+ if key not in model_file_map:
225
+ raise ValueError(f"No pretrained weights found for {model_name} ({range_type})")
226
+
227
+ model_path = model_file_map[key]
228
+ # Download from your Hugging Face repo
229
+ file_path = hf_hub_download(repo_id=repo_id, filename=model_path)
230
+
231
+ # Load PyTorch model and scaler
232
+ best_model, scaler_list = load_model(os.path.dirname(file_path))
233
+ _model_cache[key] = (best_model, scaler_list)
234
+ return best_model, scaler_list
235
+
236
+
237
+ # -------------------------------
238
+ # Core function: regression + visualization
239
+ # -------------------------------
240
+ def apply_regression(input_image, input_csv, model_choice, range_choice):
241
+ """
242
+ Applies the pretrained model to the uploaded hyperspectral scene (.tif)
243
+ and associated band CSV, using your original preprocessing + transformations.
244
+ """
245
+ # 1️⃣ Load model + scaler
246
+ best_model, scaler_list = load_pretrained_model(model_choice, range_choice)
247
+ best_model.eval()
248
+
249
+ # 2️⃣ Preprocess input data (your unchanged pipeline)
250
+ src, df, idx_null = image_processing(input_image, input_csv)
251
+ df_transformed = transform_data(df)
252
+
253
+ # 3️⃣ Run inference (PyTorch forward pass)
254
+ with torch.no_grad():
255
+ x = torch.tensor(df_transformed.values, dtype=torch.float32)
256
+ tf_preds = best_model(x).numpy()
257
+
258
+ # 4️⃣ Reverse scaling
259
+ if scaler_list is not None:
260
+ tf_preds = scaler_list.inverse_transform(tf_preds)
261
+
262
+ # 5️⃣ Build prediction DataFrame
263
+ preds = pd.DataFrame(tf_preds)
264
+ preds.loc[idx_null] = np.nan
265
+
266
+ # 6️⃣ Generate visualization and save GeoTIFF
267
+ fig = animation_preds(src, preds)
268
+ raster_path = geo_tiff_save(src, preds)
269
+
270
+ return fig, raster_path
271
+
272
+ # -------------------------------
273
+ # Gradio interface
274
+ # -------------------------------
275
+ iface = gr.Interface(
276
+ fn=apply_regression,
277
+ inputs=[
278
+ gr.File(type="filepath", label="Upload Hyperspectral Scene (.tif)"),
279
+ gr.File(type="filepath", label="Upload Band Information (.csv)"),
280
+ gr.Dropdown(["MAE", "GAN"], label="Select Model Type"),
281
+ gr.Radio(["Full Range", "Half Range"], label="Scene Range"),
282
+ ],
283
+ outputs=[
284
+ gr.Image(label="Predicted Trait Maps (Animation)", show_download_button=False),
285
+ gr.File(label="Download Predicted GeoTIFF"),
286
+ ],
287
+ title="🛰️ Multi-Trait Prediction from Hyperspectral Scenes (PyTorch)",
288
+ description=(
289
+ "Upload your hyperspectral scene (.tif) and its corresponding CSV file. "
290
+ "The selected pretrained model will process the data, predict multiple traits, "
291
+ "and generate both an animated visualization and a downloadable GeoTIFF."
292
+ ),
293
+ # article=copyright_html,
294
+ theme="soft",
295
+ )
296
+
297
+ # Launch the Gradio app
298
  iface.launch() #share=False