본문 바로가기

Python/DeepLearning

역.전.파의 기본 개념

ㅇ 무슨 조폭 이름도 아니고 처음 들어보지만 개념은 이해할 수 있다. 여태 Input에서 Output으로 결과를 살펴보는 것을 하였다. 이 것을 순전파(Foward)라고 한다. 하지만 weight를 찾으려고 얼마나 고생했나? 몇 십개의 노드와 레이어를 달고 계속 기울기를 계산해 편향을 계산했다. 지금이야 간단한 것들을 하니까 괜찮지만 대용량의 데이터를 처리할 때에는 곤란하다. 그래서 나타난 개념이 역전파(Backward)이다.

 순방향과는 반대로 반대 방향으로 국소적 미분을 곱하며 다음 노드로 전달한다. 무슨 소린지 잘 모르겠고 구현하면서 이해해보자.

class MulLayer:
	def __init__(self):
		self.x = None
		self.y = None
	
	def forward(self, x, y):
		self.x = x
		self.y = y
		out = x*y
		return out
	
	def backward(self, dout):
		dx = dout * self.y
		dy = dout * self.x
		return dx, dy

 곱셉노드의 역전파를 구현해본다.

apple = 100
apple_num = 2
tax = 1.1

mul_apple_layer = MulLayer()
mul_tax_layer = MulLayer()

apple_price = mul_apple_layer.forward(apple, apple_num)
price = mul_tax_layer.forward(apple_price, tax)

print(price)
#220

책에서 나오는 것 처럼 사과 2개를 사고 택스가 1.1퍼 붙는 것을 생각해보자. 이 과정을 forward로 본다면 사과 -> 곱하기2 -> 곱하기1.1 -> 계산일 것이다. 

dprice = 1

dapple_price, dtax = mul_tax_layer.backward(dprice)
dapple, dapple_num = mul_apple_layer.backward(dapple_price)

print(dapple, dapple_num, dtax)

#2.2 110.00000000000001 200

잘 나온다. 덧셈도 구현해본다. 덧셈은 그대로 흘려보내주면된다.

class AddLayer:
	def __init__(self):
		pass
		
	def forward(self, x, y):
		out = x+y
		return out
	
	def backward(self, dout):
		dx = dout*1
		dy = dout*1
		return dx, dy

오렌지를 넣어서 같이 계산해본다 사과 2개에 오렌지 3개 살거다 

apple = 100
apple_num = 2
orange = 150
orange_num = 3
tax = 1.1

mul_apple_layer = MulLayer()
mul_orange_layer = MulLayer()
mul_tax_layer = MulLayer()
add_fruits_layer = AddLayer()

apple_price = mul_apple_layer.forward(apple, apple_num)
orange_price = mul_orange_layer.forward(orange, orange_num)
price = add_fruits_layer.forward(apple_price,orange_price)
price = mul_tax_layer.forward(price, tax)

print(price)
#715

backward 해본당.

dprice = 1

dfruits_price, dtax = mul_tax_layer.backward(dprice)
dapple_sum, dorange_sum = add_fruits_layer.backward(dfruits_price)
dapple, dapple_count = mul_apple_layer.backward(dapple_sum)
dorange, dorange_count = mul_orange_layer.backward(dorange_sum)
print(dprice,dfruits_price,dapple_sum,dorange_sum,dapple,dapple_count,dorange,dorange_count,dtax)
#1 1.1 1.1 1.1 2.2 110 3.3 165 650

 나는 코드 재사용 할 수 있을 줄 알았는데 보니까 얘네가 계속 값을 가지고 있다. 그런 멍청한 짓 하면 안됨. 알고리즘의 기본도 안된 사람이였다니... 아무튼 이렇게 진행된다. 하지만.. 얘들은 단순하기 짝이 없다. 다음엔 더 어려운 것을 구현해본다.