chriswu25 committed
Commit 687d3ce · verified · 1 Parent(s): b4060af

Update src/lora_helper.py

Files changed (1):
  1. src/lora_helper.py  +195 -195
src/lora_helper.py CHANGED
@@ -1,196 +1,196 @@
 from diffusers.models.attention_processor import FluxAttnProcessor2_0
 from safetensors import safe_open
 import re
 import torch
 from .layers_cache import MultiDoubleStreamBlockLoraProcessor, MultiSingleStreamBlockLoraProcessor

 device = "cuda"

 def load_safetensors(path):
     tensors = {}
     with safe_open(path, framework="pt", device="cpu") as f:
         for key in f.keys():
             tensors[key] = f.get_tensor(key)
     return tensors

 def get_lora_rank(checkpoint):
     for k in checkpoint.keys():
         if k.endswith(".down.weight"):
             return checkpoint[k].shape[0]

 def load_checkpoint(local_path):
     if local_path is not None:
         if '.safetensors' in local_path:
             print(f"Loading .safetensors checkpoint from {local_path}")
             checkpoint = load_safetensors(local_path)
         else:
             print(f"Loading checkpoint from {local_path}")
             checkpoint = torch.load(local_path, map_location='cpu')
     return checkpoint

-def update_model_with_lora(checkpoint, lora_weights, transformer, cond_size):
+def update_model_with_lora(checkpoint, lora_weights, transformer, cond_size, device="cpu"):
     number = len(lora_weights)
     ranks = [get_lora_rank(checkpoint) for _ in range(number)]
     lora_attn_procs = {}
     double_blocks_idx = list(range(19))
     single_blocks_idx = list(range(38))
     for name, attn_processor in transformer.attn_processors.items():
         match = re.search(r'\.(\d+)\.', name)
         if match:
             layer_index = int(match.group(1))

         if name.startswith("transformer_blocks") and layer_index in double_blocks_idx:

             lora_state_dicts = {}
             for key, value in checkpoint.items():
                 # Match based on the layer index in the key (assuming the key contains layer index)
                 if re.search(r'\.(\d+)\.', key):
                     checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1))
                     if checkpoint_layer_index == layer_index and key.startswith("transformer_blocks"):
                         lora_state_dicts[key] = value

             lora_attn_procs[name] = MultiDoubleStreamBlockLoraProcessor(
                 dim=3072, ranks=ranks, network_alphas=ranks, lora_weights=lora_weights, device=device, dtype=torch.bfloat16, cond_width=cond_size, cond_height=cond_size, n_loras=number
             )

             # Load the weights from the checkpoint dictionary into the corresponding layers
             for n in range(number):
                 lora_attn_procs[name].q_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.down.weight', None)
                 lora_attn_procs[name].q_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.up.weight', None)
                 lora_attn_procs[name].k_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.down.weight', None)
                 lora_attn_procs[name].k_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.up.weight', None)
                 lora_attn_procs[name].v_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.down.weight', None)
                 lora_attn_procs[name].v_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.up.weight', None)
                 lora_attn_procs[name].proj_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.proj_loras.{n}.down.weight', None)
                 lora_attn_procs[name].proj_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.proj_loras.{n}.up.weight', None)
                 lora_attn_procs[name].to(device)

         elif name.startswith("single_transformer_blocks") and layer_index in single_blocks_idx:

             lora_state_dicts = {}
             for key, value in checkpoint.items():
                 # Match based on the layer index in the key (assuming the key contains layer index)
                 if re.search(r'\.(\d+)\.', key):
                     checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1))
                     if checkpoint_layer_index == layer_index and key.startswith("single_transformer_blocks"):
                         lora_state_dicts[key] = value

             lora_attn_procs[name] = MultiSingleStreamBlockLoraProcessor(
                 dim=3072, ranks=ranks, network_alphas=ranks, lora_weights=lora_weights, device=device, dtype=torch.bfloat16, cond_width=cond_size, cond_height=cond_size, n_loras=number
             )
             # Load the weights from the checkpoint dictionary into the corresponding layers
             for n in range(number):
                 lora_attn_procs[name].q_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.down.weight', None)
                 lora_attn_procs[name].q_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.up.weight', None)
                 lora_attn_procs[name].k_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.down.weight', None)
                 lora_attn_procs[name].k_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.up.weight', None)
                 lora_attn_procs[name].v_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.down.weight', None)
                 lora_attn_procs[name].v_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.up.weight', None)
                 lora_attn_procs[name].to(device)
         else:
             lora_attn_procs[name] = FluxAttnProcessor2_0()

     transformer.set_attn_processor(lora_attn_procs)


 def update_model_with_multi_lora(checkpoints, lora_weights, transformer, cond_size):
     ck_number = len(checkpoints)
     cond_lora_number = [len(ls) for ls in lora_weights]
     cond_number = sum(cond_lora_number)
     ranks = [get_lora_rank(checkpoint) for checkpoint in checkpoints]
     multi_lora_weight = []
     for ls in lora_weights:
         for n in ls:
             multi_lora_weight.append(n)

     lora_attn_procs = {}
     double_blocks_idx = list(range(19))
     single_blocks_idx = list(range(38))
     for name, attn_processor in transformer.attn_processors.items():
         match = re.search(r'\.(\d+)\.', name)
         if match:
             layer_index = int(match.group(1))

         if name.startswith("transformer_blocks") and layer_index in double_blocks_idx:
             lora_state_dicts = [{} for _ in range(ck_number)]
             for idx, checkpoint in enumerate(checkpoints):
                 for key, value in checkpoint.items():
                     # Match based on the layer index in the key (assuming the key contains layer index)
                     if re.search(r'\.(\d+)\.', key):
                         checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1))
                         if checkpoint_layer_index == layer_index and key.startswith("transformer_blocks"):
                             lora_state_dicts[idx][key] = value

             lora_attn_procs[name] = MultiDoubleStreamBlockLoraProcessor(
                 dim=3072, ranks=ranks, network_alphas=ranks, lora_weights=multi_lora_weight, device=device, dtype=torch.bfloat16, cond_width=cond_size, cond_height=cond_size, n_loras=cond_number
             )

             # Load the weights from the checkpoint dictionary into the corresponding layers
             num = 0
             for idx in range(ck_number):
                 for n in range(cond_lora_number[idx]):
                     lora_attn_procs[name].q_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.q_loras.{n}.down.weight', None)
                     lora_attn_procs[name].q_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.q_loras.{n}.up.weight', None)
                     lora_attn_procs[name].k_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.k_loras.{n}.down.weight', None)
                     lora_attn_procs[name].k_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.k_loras.{n}.up.weight', None)
                     lora_attn_procs[name].v_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.v_loras.{n}.down.weight', None)
                     lora_attn_procs[name].v_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.v_loras.{n}.up.weight', None)
                     lora_attn_procs[name].proj_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.proj_loras.{n}.down.weight', None)
                     lora_attn_procs[name].proj_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.proj_loras.{n}.up.weight', None)
                     lora_attn_procs[name].to(device)
                     num += 1

         elif name.startswith("single_transformer_blocks") and layer_index in single_blocks_idx:

             lora_state_dicts = [{} for _ in range(ck_number)]
             for idx, checkpoint in enumerate(checkpoints):
                 for key, value in checkpoint.items():
                     # Match based on the layer index in the key (assuming the key contains layer index)
                     if re.search(r'\.(\d+)\.', key):
                         checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1))
                         if checkpoint_layer_index == layer_index and key.startswith("single_transformer_blocks"):
                             lora_state_dicts[idx][key] = value

             lora_attn_procs[name] = MultiSingleStreamBlockLoraProcessor(
                 dim=3072, ranks=ranks, network_alphas=ranks, lora_weights=multi_lora_weight, device=device, dtype=torch.bfloat16, cond_width=cond_size, cond_height=cond_size, n_loras=cond_number
             )
             # Load the weights from the checkpoint dictionary into the corresponding layers
             num = 0
             for idx in range(ck_number):
                 for n in range(cond_lora_number[idx]):
                     lora_attn_procs[name].q_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.q_loras.{n}.down.weight', None)
                     lora_attn_procs[name].q_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.q_loras.{n}.up.weight', None)
                     lora_attn_procs[name].k_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.k_loras.{n}.down.weight', None)
                     lora_attn_procs[name].k_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.k_loras.{n}.up.weight', None)
                     lora_attn_procs[name].v_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.v_loras.{n}.down.weight', None)
                     lora_attn_procs[name].v_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.v_loras.{n}.up.weight', None)
                     lora_attn_procs[name].to(device)
                     num += 1

         else:
             lora_attn_procs[name] = FluxAttnProcessor2_0()

     transformer.set_attn_processor(lora_attn_procs)


 def set_single_lora(transformer, local_path, lora_weights=[], cond_size=512):
     checkpoint = load_checkpoint(local_path)
     update_model_with_lora(checkpoint, lora_weights, transformer, cond_size)

 def set_multi_lora(transformer, local_paths, lora_weights=[[]], cond_size=512):
     checkpoints = [load_checkpoint(local_path) for local_path in local_paths]
     update_model_with_multi_lora(checkpoints, lora_weights, transformer, cond_size)

 def unset_lora(transformer):
     lora_attn_procs = {}
     for name, attn_processor in transformer.attn_processors.items():
         lora_attn_procs[name] = FluxAttnProcessor2_0()
     transformer.set_attn_processor(lora_attn_procs)


 '''
 unset_lora(pipe.transformer)
 lora_path = "./lora.safetensors"
 lora_weights = [1, 1]
 set_lora(pipe.transformer, local_path=lora_path, lora_weights=lora_weights, cond_size=512)
 '''
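
The only line that visibly changes in this hunk is the signature of update_model_with_lora, which gains a device="cpu" keyword; set_single_lora still calls it without a device argument, so as written the LoRA attention processors are created on the CPU default unless a caller passes device explicitly. Below is a minimal usage sketch, not part of the commit: the base-model id, file paths, weight values, and the src.lora_helper import path are placeholder assumptions, and only set_single_lora, set_multi_lora, and unset_lora come from the file shown above.

import torch
from diffusers import FluxPipeline

from src.lora_helper import set_single_lora, set_multi_lora, unset_lora  # assumed import path

# Assumed base model; any FLUX transformer with the 19 double / 38 single blocks indexed above fits.
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)

# One LoRA checkpoint: lora_weights needs one entry per conditioning LoRA stored in the file.
set_single_lora(pipe.transformer, "./lora.safetensors", lora_weights=[1], cond_size=512)

# Several LoRA checkpoints at once: one weight list per checkpoint (placeholder paths).
set_multi_lora(
    pipe.transformer,
    ["./subject.safetensors", "./spatial.safetensors"],
    lora_weights=[[1], [1]],
    cond_size=512,
)

# Restore the stock Flux attention processors.
unset_lora(pipe.transformer)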