Skip to content

Commit c63e7c2

Browse files
committed
Added device option on generate.py
1 parent 866e506 commit c63e7c2

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

generate.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33

44
from model import StyledGenerator
55

6-
generator = StyledGenerator(512).cuda()
6+
device = 'cuda'
7+
8+
generator = StyledGenerator(512).to(device)
79
generator.load_state_dict(torch.load('checkpoint/180000.model'))
810
generator.eval()
911

@@ -16,7 +18,7 @@
1618

1719
with torch.no_grad():
1820
for i in range(10):
19-
style = generator.mean_style(torch.randn(1024, 512).cuda())
21+
style = generator.mean_style(torch.randn(1024, 512).to(device))
2022

2123
if mean_style is None:
2224
mean_style = style
@@ -27,7 +29,7 @@
2729
mean_style /= 10
2830

2931
image = generator(
30-
torch.randn(15, 512).cuda(),
32+
torch.randn(15, 512).to(device),
3133
step=step,
3234
alpha=alpha,
3335
mean_style=mean_style,
@@ -37,10 +39,10 @@
3739
utils.save_image(image, 'sample.png', nrow=5, normalize=True, range=(-1, 1))
3840

3941
for j in range(20):
40-
source_code = torch.randn(5, 512).cuda()
41-
target_code = torch.randn(3, 512).cuda()
42+
source_code = torch.randn(5, 512).to(device)
43+
target_code = torch.randn(3, 512).to(device)
4244

43-
images = [torch.ones(1, 3, shape, shape).cuda() * -1]
45+
images = [torch.ones(1, 3, shape, shape).to(device) * -1]
4446

4547
source_image = generator(
4648
source_code, step=step, alpha=alpha, mean_style=mean_style, style_weight=0.7

0 commit comments

Comments
 (0)