Re: [Dclib-devel] [dclib-devel] Implementing a GAN with dlib
From: Eloi Du B. <elo...@gm...> - 2017-10-22 05:40:20
Ok, I'm starting to dig a little deeper into how dlib is built, and I think I have an idea of how to implement this. I have a few questions about the forward pass vs. backpropagation. I'm making a discriminator defined as:

    #include <dlib/dnn.h>

    // A 1-filter, 5x1 conv layer that does 2x downsampling
    template <typename SUBNET>
    using con2 = dlib::con<1, 1, 5, 1, 2, SUBNET>;

    template <class SUBNET>
    using DiscriminatorT = dlib::loss_binary_log<dlib::fc<1,
        dlib::relu<dlib::max_pool<1, 2, 1, 2,
        dlib::relu<dlib::max_pool<1, 2, 1, 2,
        dlib::relu<con2<SUBNET>>>>>>>>;

This is just to detect whether or not a vector of 128 floats follows a Gaussian distribution. My discriminator is trained on some test data, positive and negative examples (Gaussian vs. noise).

Now I'm trying to check whether calling forward() on the network gives the same result as calling the network directly:

    #include <cassert>

    std::vector<float> classificationForTest = discriminator(testVectors);

    dlib::resizable_tensor tempTensor;
    discriminator.to_tensor(&testVectors[0], &testVectors[0] + 1, tempTensor);
    const dlib::tensor& tensorOut = discriminator.subnet().forward(tempTensor);
    assert(tensorOut.host()[0] == classificationForTest[0]);

This test passes, so the forward pass is what I thought it was.

Now I'm trying to back-propagate the error. Say classificationForTest[0] is -8 and it should be 8, so my error is 16. I'm tempted to do something like (pseudo code):

    discriminator.subnet().back_propagate_error(tensor = 16);

My question is: how do I back-propagate through (rewind) the full network to get the error contributed by each scalar of an input vector fed to the discriminator? If I get this, then I know how to solve my GAN, as I would inject this error into the output tensor of the generator.

Many thanks for any help,

2017-10-21 20:08 GMT-05:00 Davis King <dav...@gm...>:
> I would make a network that forks into two subnetworks and then write my
> own loss layer that did whatever you want to do.
> I doubt you could implement any of the GAN papers using a binary
> classification loss like loss_binary_log. They all have very specific loss
> functions that only make sense in the context of GAN.
>
> _______________________________________________
> Dclib-devel mailing list
> Dcl...@li...
> https://lists.sourceforge.net/lists/listinfo/dclib-devel
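To make the question concrete, here is a self-contained sketch in plain C++ (it does not use dlib) of what back-propagating to the input means for a discriminator shaped like the one above. It is scaled down to a 12-sample input with a single conv/pool stage: a 5-tap conv with stride 2, ReLU, a 2-wide max pool with stride 2, and a 1-output fc layer. The names Tiny1DNet and binary_log_grad, the sizes, and all weights are hypothetical. It assumes a loss of the loss_binary_log form, log(1 + exp(-y*f)), so the value seeded at the output for a single sample is the loss derivative -y/(1 + exp(y*f)), about -1 for f = -8 and y = +1, rather than the raw difference 16. In dlib itself, the layer objects document get_gradient_input(), back_propagate_error() and get_final_data_gradient() for this kind of thing; check dlib/dnn/core_abstract.h for the exact signatures in your version.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <cmath>
#include <cstddef>

// Hypothetical toy discriminator: 12 inputs -> conv(5 taps, stride 2) -> ReLU
// -> max_pool(2, stride 2) -> fc(1). Weights are made up for illustration.
struct Tiny1DNet {
    static constexpr std::size_t N = 12, K = 5, S = 2;
    static constexpr std::size_t C = (N - K) / S + 1;  // 4 conv outputs
    static constexpr std::size_t P = C / 2;            // 2 pooled outputs

    std::array<double, K> w{0.2, -0.5, 0.3, 0.1, -0.4};
    double b = 0.05;
    std::array<double, P> fc_w{0.7, -1.2};
    double fc_b = 0.1;

    // Values cached by forward() and reused by backward().
    std::array<double, C> pre{};          // conv outputs before ReLU
    std::array<std::size_t, P> argmax{};  // winning conv index per pool window

    double forward(const std::array<double, N>& x) {
        for (std::size_t i = 0; i < C; ++i) {
            double s = b;
            for (std::size_t k = 0; k < K; ++k) s += w[k] * x[S * i + k];
            pre[i] = s;
        }
        double f = fc_b;
        for (std::size_t j = 0; j < P; ++j) {
            std::size_t a = 2 * j, c = 2 * j + 1;
            double ra = std::max(pre[a], 0.0), rc = std::max(pre[c], 0.0);
            argmax[j] = (rc > ra) ? c : a;
            f += fc_w[j] * std::max(ra, rc);
        }
        return f;
    }

    // Chain rule from the output back to every input scalar.
    // grad_out is dLoss/df; the return value is dLoss/dx.
    std::array<double, N> backward(double grad_out) const {
        std::array<double, C> g_conv{};  // gradient w.r.t. conv outputs
        for (std::size_t j = 0; j < P; ++j) {
            std::size_t i = argmax[j];
            // fc: df/dpool[j] = fc_w[j]; max pool routes it to the winner;
            // ReLU passes it only where the pre-activation was positive.
            if (pre[i] > 0.0) g_conv[i] += grad_out * fc_w[j];
        }
        std::array<double, N> g_x{};
        for (std::size_t i = 0; i < C; ++i)
            for (std::size_t k = 0; k < K; ++k)
                g_x[S * i + k] += g_conv[i] * w[k];  // conv: scatter back through taps
        return g_x;
    }
};

// Derivative of log(1 + exp(-y * f)) with respect to f, for label y = +1 or -1.
double binary_log_grad(double f, double y) {
    return -y / (1.0 + std::exp(y * f));
}
```

For example, calling net.forward(x) and then net.backward(binary_log_grad(f, y)) gives the gradient of the loss with respect to each of the 12 inputs, using whatever label y you want the discriminator to output. That per-input vector is the quantity a generator would receive as the gradient on its output tensor. Caching pre-activations and pool winners during forward() is what lets backward() reuse them, which mirrors how a layered network keeps its forward outputs around for backpropagation.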