from numpy import exp, array, random, dot


class NeuralNetwork:
    """A single neuron with three inputs and one output.

    The neuron computes ``sigmoid(dot(inputs, synaptic_weights))`` and is
    trained with full-batch gradient steps on the sigmoid output.
    """

    def __init__(self, seed=1):
        """Create the neuron with random weights in the range [-1, 1).

        Args:
            seed: Seed for a private NumPy ``RandomState`` used only to draw
                the initial weights. The default reproduces the previous
                starting weights without reseeding NumPy's global generator.
        """
        rng = random.RandomState(seed)
        # random_sample draws from [0, 1); scale and shift into [-1, 1).
        # Shape (3, 1): three inputs feeding one output.
        self.synaptic_weights = 2 * rng.random_sample((3, 1)) - 1

    def __sigmoid(self, x):
        """Return the logistic sigmoid 1 / (1 + e^-x), element-wise."""
        return 1 / (1 + exp(-x))

    def __sigmoid_derivative(self, x):
        """Return the sigmoid gradient, given x = sigmoid(z) (not z itself).

        Uses the identity sigmoid'(z) = sigmoid(z) * (1 - sigmoid(z)), so the
        argument must already be an activated output.
        """
        return x * (1 - x)

    def train(self, training_set_inputs, training_set_outputs,
              number_of_training_iterations):
        """Adjust the weights with repeated full-batch gradient steps.

        Args:
            training_set_inputs: Array of shape (n_samples, 3).
            training_set_outputs: Array of shape (n_samples, 1). A 1-D target
                array would broadcast against the (n_samples, 1) predictions
                into an (n_samples, n_samples) error matrix.
            number_of_training_iterations: Number of weight updates to apply.
        """
        for _ in range(number_of_training_iterations):
            # Pass the training set through our neural network.
            output = self.predict(training_set_inputs)
            error = training_set_outputs - output
            # Scale each error by the sigmoid gradient at that output, then
            # project back onto the inputs. The error term is what ties the
            # update to the targets; without it the step ignores them.
            adjustment = dot(training_set_inputs.T,
                             error * self.__sigmoid_derivative(output))
            self.synaptic_weights += adjustment

    def predict(self, inputs):
        """Return the neuron's sigmoid output for ``inputs``.

        ``inputs`` must have a last dimension of 3 to match the weights.
        """
        return self.__sigmoid(dot(inputs, self.synaptic_weights))

    def think(self, inputs):
        """Alias of :meth:`predict`, kept for backward compatibility."""
        return self.predict(inputs)


if __name__ == '__main__':
    # Initialize a single-neuron neural network.
    neural_network = NeuralNetwork()
    print('Random starting synaptic weights:')
    print(neural_network.synaptic_weights)

    # The training set: 4 examples, each with 3 inputs and 1 output.
    training_set_inputs = array([[0, 0, 1], [1, 1, 1], [1, 0, 1], [0, 1, 1]])
    training_set_outputs = array([[0, 1, 1, 0]]).T

    # Train the neural network on the training data.
    neural_network.train(training_set_inputs, training_set_outputs, 10000)
    print('New synaptic weights after training:')
    print(neural_network.synaptic_weights)

    # Test the neural network on a new situation.
    print('Consider new situation [1,0,0] -> ?')
    print(neural_network.think(array([1, 0, 0])))
Random starting synaptic weights:
[[-0.16595599]
[ 0.44064899]
[-0.99977125]]
New synaptic weights after training:
[[ 2.25398875]
[ 2.47194872]
[ 9.40867424]]
Consider new situation [1,0,0] -> ?
[ 0.90499404]