from model import StyledGenerator

# Pick the GPU when one is present, but fall back to CPU so the sampling
# script still runs on machines without CUDA (the original hard-coded
# 'cuda' and crashed on CPU-only hosts).
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# 512 is the style/latent dimensionality this checkpoint was trained with.
generator = StyledGenerator(512).to(device)
# map_location is required so a checkpoint saved from a CUDA run can be
# restored onto the CPU fallback device without raising.
generator.load_state_dict(
    torch.load('checkpoint/180000.model', map_location=device)
)
# Inference only: switch off dropout/batch-norm training behavior.
generator.eval()
|
|
16 | 18 |
|
17 | 19 | with torch.no_grad():
|
18 | 20 | for i in range(10):
|
19 |
| - style = generator.mean_style(torch.randn(1024, 512).cuda()) |
| 21 | + style = generator.mean_style(torch.randn(1024, 512).to(device)) |
20 | 22 |
|
21 | 23 | if mean_style is None:
|
22 | 24 | mean_style = style
|
|
27 | 29 | mean_style /= 10
|
28 | 30 |
|
29 | 31 | image = generator(
|
30 |
| - torch.randn(15, 512).cuda(), |
| 32 | + torch.randn(15, 512).to(device), |
31 | 33 | step=step,
|
32 | 34 | alpha=alpha,
|
33 | 35 | mean_style=mean_style,
|
|
37 | 39 | utils.save_image(image, 'sample.png', nrow=5, normalize=True, range=(-1, 1))
|
38 | 40 |
|
39 | 41 | for j in range(20):
|
40 |
| - source_code = torch.randn(5, 512).cuda() |
41 |
| - target_code = torch.randn(3, 512).cuda() |
| 42 | + source_code = torch.randn(5, 512).to(device) |
| 43 | + target_code = torch.randn(3, 512).to(device) |
42 | 44 |
|
43 |
| - images = [torch.ones(1, 3, shape, shape).cuda() * -1] |
| 45 | + images = [torch.ones(1, 3, shape, shape).to(device) * -1] |
44 | 46 |
|
45 | 47 | source_image = generator(
|
46 | 48 | source_code, step=step, alpha=alpha, mean_style=mean_style, style_weight=0.7
|
|
0 commit comments