.png)
我就废话不多说了，直接上代码吧！
"""Minimal GAN demo: a generator G learns to draw parabolas like an "artist".

Real paintings are curves a*x**2 + (a - 1) with a ~ U[1, 2), sampled at
ART_COMPONENTS points on [-1, 1]. D learns to tell real curves from G's
output while G learns to fool D.
"""
import numpy as np
import torch
import torch.nn as nn

# Fixed seeds so runs (including the initial weights of G and D below) repeat.
torch.manual_seed(1)
np.random.seed(1)

BATCH_SIZE = 64      # paintings per training batch
LR_G = 0.0001        # generator learning rate
LR_D = 0.0001        # discriminator learning rate
N_IDEAS = 5          # length of the random noise vector fed to G
ART_COMPONENTS = 15  # sample points per painting (curve)
# The same ART_COMPONENTS x-coordinates in [-1, 1], repeated for every row.
PAINT_POINTS = np.vstack(
    [np.linspace(-1, 1, ART_COMPONENTS) for _ in range(BATCH_SIZE)]
)


def artist_works():
    """Return a batch of "real" paintings: parabolas ``a*x**2 + (a - 1)``.

    ``a`` is drawn uniformly from [1, 2) independently per row, so every curve
    lies between the lower bound ``x**2`` and the upper bound ``2*x**2 + 1``.

    Returns:
        float32 tensor of shape (BATCH_SIZE, ART_COMPONENTS).
    """
    a = np.random.uniform(1, 2, size=BATCH_SIZE)[:, np.newaxis]
    paintings = a * np.power(PAINT_POINTS, 2) + (a - 1)
    # Plain tensors carry autograd state; torch.autograd.Variable is deprecated.
    return torch.from_numpy(paintings).float()


# Generator: noise vector (N_IDEAS) -> painting (ART_COMPONENTS values).
G = nn.Sequential(
    nn.Linear(N_IDEAS, 128),
    nn.ReLU(),
    nn.Linear(128, ART_COMPONENTS),
)

# Discriminator: painting -> probability (via Sigmoid) that it is real.
D = nn.Sequential(
    nn.Linear(ART_COMPONENTS, 128),
    nn.ReLU(),
    nn.Linear(128, 1),
    nn.Sigmoid(),
)

opt_D = torch.optim.Adam(D.parameters(), lr=LR_D)
opt_G = torch.optim.Adam(G.parameters(), lr=LR_G)


def _plot_progress(G_paintings, prob_artist0, D_loss):
    """Redraw the first generated painting against the real-data bounds."""
    import matplotlib.pyplot as plt  # only needed when plotting is enabled

    plt.cla()
    plt.plot(PAINT_POINTS[0], G_paintings.detach().numpy()[0],
             c='#4ad631', lw=3, label='Generated painting')
    plt.plot(PAINT_POINTS[0], 2 * np.power(PAINT_POINTS[0], 2) + 1,
             c='#74BCFF', lw=3, label='upper bound')
    plt.plot(PAINT_POINTS[0], 1 * np.power(PAINT_POINTS[0], 2) + 0,
             c='#FF9359', lw=3, label='lower bound')
    # Mean probability D assigns to the real batch; it tends to 0.5 when
    # D can no longer separate real from fake.
    plt.text(-.5, 2.3,
             'D accuracy=%.2f (0.5 for D to converge)'
             % prob_artist0.detach().mean().item(),
             fontdict={'size': 15})
    # With D outputting 0.5 everywhere, D_loss = -2*ln(0.5) ~= 1.386.
    plt.text(-.5, 2,
             'D score= %.2f (-1.38 for G to converge)' % -D_loss.item(),
             fontdict={'size': 15})
    plt.ylim((0, 3))
    plt.legend(loc='upper right', fontsize=12)
    plt.draw()
    plt.pause(0.01)


def train(steps=10000, plot=True):
    """Train the module-level G and D adversarially.

    Args:
        steps: number of training iterations.
        plot: if True, show a live matplotlib plot every 50 steps.

    Returns:
        ``(D_loss, G_loss)`` of the last step as Python floats, or
        ``(None, None)`` when ``steps`` is 0.
    """
    if plot:
        import matplotlib.pyplot as plt
        plt.ion()

    D_loss = G_loss = None
    for step in range(steps):
        artist_paintings = artist_works()
        G_ideas = torch.randn(BATCH_SIZE, N_IDEAS)
        G_paintings = G(G_ideas)

        # --- Discriminator update ---
        # detach() stops D's loss from back-propagating into G. G's graph is
        # therefore still intact for the generator update below, and no
        # retain_graph (formerly retain_variables) is needed.
        prob_artist0 = D(artist_paintings)
        prob_artist1 = D(G_paintings.detach())
        D_loss = -torch.mean(torch.log(prob_artist0) + torch.log(1. - prob_artist1))
        opt_D.zero_grad()
        D_loss.backward()
        opt_D.step()

        # --- Generator update ---
        # Re-score the fakes with the freshly updated D. Reusing the pre-step
        # scores would make backward() read D weights that opt_D.step() has
        # modified in place, which autograd rejects with a RuntimeError.
        # (This also accumulates gradients into D; they are cleared by
        # opt_D.zero_grad() before D's next backward pass.)
        G_loss = torch.mean(torch.log(1. - D(G_paintings)))
        opt_G.zero_grad()
        G_loss.backward()
        opt_G.step()

        if plot and step % 50 == 0:
            _plot_progress(G_paintings, prob_artist0, D_loss)

    if plot:
        plt.ioff()
        plt.show()

    if D_loss is None:
        return None, None
    return D_loss.item(), G_loss.item()


if __name__ == "__main__":
    train()

以上这篇 PyTorch GAN 生成对抗网络实例就是小编分享给大家的全部内容了，希望能给大家一个参考，也希望大家多多支持脚本之家。