Skip to content

Commit 271112c

Browse files
authored
fix vits reduce_sum's input/output dtype, test=tts (#3028)
1 parent 34f2995 commit 271112c

File tree

3 files changed

+23
-15
lines changed

3 files changed

+23
-15
lines changed

paddlespeech/t2s/models/vits/duration_predictor.py

Lines changed: 6 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -155,21 +155,19 @@ def forward(
155155
z_u, z1 = paddle.split(z_q, [1, 1], 1)
156156
u = F.sigmoid(z_u) * x_mask
157157
z0 = (w - u) * x_mask
158-
logdet_tot_q += paddle.sum(
159-
(F.log_sigmoid(z_u) + F.log_sigmoid(-z_u)) * x_mask, [1, 2])
160-
logq = (paddle.sum(-0.5 *
161-
(math.log(2 * math.pi) +
162-
(e_q**2)) * x_mask, [1, 2]) - logdet_tot_q)
163-
158+
tmp1 = (F.log_sigmoid(z_u) + F.log_sigmoid(-z_u)) * x_mask
159+
logdet_tot_q += paddle.sum(tmp1, [1, 2])
160+
tmp2 = -0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask
161+
logq = (paddle.sum(tmp2, [1, 2]) - logdet_tot_q)
164162
logdet_tot = 0
165163
z0, logdet = self.log_flow(z0, x_mask)
166164
logdet_tot += logdet
167165
z = paddle.concat([z0, z1], 1)
168166
for flow in self.flows:
169167
z, logdet = flow(z, x_mask, g=x, inverse=inverse)
170168
logdet_tot = logdet_tot + logdet
171-
nll = (paddle.sum(0.5 * (math.log(2 * math.pi) +
172-
(z**2)) * x_mask, [1, 2]) - logdet_tot)
169+
tmp3 = 0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask
170+
nll = (paddle.sum(tmp3, [1, 2]) - logdet_tot)
173171
# (B,)
174172
return nll + logq
175173
else:

paddlespeech/t2s/models/vits/generator.py

Lines changed: 8 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -371,8 +371,9 @@ def forward(
371371
# (B, H, T_text)
372372
s_p_sq_r = paddle.exp(-2 * logs_p)
373373
# (B, 1, T_text)
374+
tmp1 = -0.5 * math.log(2 * math.pi) - logs_p
374375
neg_x_ent_1 = paddle.sum(
375-
-0.5 * math.log(2 * math.pi) - logs_p,
376+
tmp1,
376377
[1],
377378
keepdim=True, )
378379
# (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text)
@@ -384,8 +385,9 @@ def forward(
384385
z_p.transpose([0, 2, 1]),
385386
(m_p * s_p_sq_r), )
386387
# (B, 1, T_text)
388+
tmp2 = -0.5 * (m_p**2) * s_p_sq_r
387389
neg_x_ent_4 = paddle.sum(
388-
-0.5 * (m_p**2) * s_p_sq_r,
390+
tmp2,
389391
[1],
390392
keepdim=True, )
391393
# (B, T_feats, T_text)
@@ -403,7 +405,6 @@ def forward(
403405
w = attn.sum(2)
404406
dur_nll = self.duration_predictor(x, x_mask, w=w, g=g)
405407
dur_nll = dur_nll / paddle.sum(x_mask)
406-
407408
# expand the length to match with the feature sequence
408409
# (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
409410
m_p = paddle.matmul(attn.squeeze(1),
@@ -511,8 +512,9 @@ def inference(
511512
# (B, H, T_text)
512513
s_p_sq_r = paddle.exp(-2 * logs_p)
513514
# (B, 1, T_text)
515+
tmp3 = -0.5 * math.log(2 * math.pi) - logs_p
514516
neg_x_ent_1 = paddle.sum(
515-
-0.5 * math.log(2 * math.pi) - logs_p,
517+
tmp3,
516518
[1],
517519
keepdim=True, )
518520
# (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text)
@@ -524,8 +526,9 @@ def inference(
524526
z_p.transpose([0, 2, 1]),
525527
(m_p * s_p_sq_r), )
526528
# (B, 1, T_text)
529+
tmp4 = -0.5 * (m_p**2) * s_p_sq_r
527530
neg_x_ent_4 = paddle.sum(
528-
-0.5 * (m_p**2) * s_p_sq_r,
531+
tmp4,
529532
[1],
530533
keepdim=True, )
531534
# (B, T_feats, T_text)

paddlespeech/t2s/models/vits/transform.py

Lines changed: 9 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -61,8 +61,12 @@ def piecewise_rational_quadratic_transform(
6161

6262

6363
def mask_preprocess(x, mask):
64+
# bins.dtype = int32
6465
B, C, T, bins = paddle.shape(x)
65-
new_x = paddle.zeros([mask.sum(), bins])
66+
mask_int = paddle.cast(mask, dtype='int64')
67+
# when the input of paddle.sum is int32 or bool, its output is int64
68+
# the shape passed to paddle.zeros (fill_constant) is forcibly cast to int32
69+
new_x = paddle.zeros([paddle.sum(mask_int), bins])
6670
for i in range(bins):
6771
new_x[:, i] = x[:, :, :, i][mask]
6872
return new_x
@@ -240,4 +244,7 @@ def rational_quadratic_spline(
240244

241245
def _searchsorted(bin_locations, inputs, eps=1e-6):
242246
bin_locations[..., -1] += eps
243-
return paddle.sum(inputs[..., None] >= bin_locations, axis=-1) - 1
247+
mask = inputs[..., None] >= bin_locations
248+
mask_int = paddle.cast(mask, 'int64')
249+
out = paddle.sum(mask_int, axis=-1) - 1
250+
return out

0 commit comments

Comments
 (0)